use std::ffi::CString;
use std::ptr::NonNull;
use crate::error::{check, ErrorCode, Result, ZvecError};
use crate::ffi_util::{cstr_to_string, cstring, slice_as_bytes};
use crate::query_params::{FlatQueryParams, HnswQueryParams, IvfQueryParams};
use crate::sys;
/// Converts Rust field names into owned C strings plus a parallel array of raw pointers.
///
/// The first element of the returned tuple owns the NUL-terminated buffers; the second
/// holds pointers into those buffers, so the pointers are only valid while the first
/// element is alive. Moving the `Vec<CString>` does not move the heap buffers each
/// `CString` owns, so returning both together keeps the pointers valid.
///
/// # Errors
///
/// Propagates the error returned by `cstring` for the first name it rejects
/// (presumably names containing an interior NUL byte — TODO confirm).
fn field_name_vec_to_c(fields: &[&str]) -> Result<(Vec<CString>, Vec<*const core::ffi::c_char>)> {
    let owned: Vec<CString> = fields
        .iter()
        .map(|name| cstring(name))
        .collect::<Result<_>>()?;
    let raw: Vec<*const core::ffi::c_char> = owned.iter().map(|c| c.as_ptr()).collect();
    Ok((owned, raw))
}
/// Safe owner of a native `zvec_vector_query_t` handle.
///
/// The handle is created by [`VectorQuery::new`] and destroyed when this value is dropped.
/// `Debug` prints only the handle address.
#[derive(Debug)]
pub struct VectorQuery {
    /// Non-null handle returned by `zvec_vector_query_create`.
    ptr: NonNull<sys::zvec_vector_query_t>,
}
impl VectorQuery {
    /// Creates a new query backed by a fresh native `zvec_vector_query_t`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorCode::ResourceExhausted`] error if `zvec_vector_query_create`
    /// returns NULL.
    pub fn new() -> Result<Self> {
        // SAFETY: the constructor takes no arguments; a NULL result is handled below.
        let ptr = unsafe { sys::zvec_vector_query_create() };
        NonNull::new(ptr).map(|ptr| Self { ptr }).ok_or_else(|| {
            ZvecError::with_message(
                ErrorCode::ResourceExhausted,
                "zvec_vector_query_create returned NULL",
            )
        })
    }

    /// Returns the native handle as a const pointer for read-only C calls in this crate.
    pub(crate) fn as_ptr(&self) -> *const sys::zvec_vector_query_t {
        self.ptr.as_ptr() as *const _
    }

    /// Returns the `topk` value stored on the query
    /// (presumably the maximum number of hits to return — TODO confirm).
    pub fn topk(&self) -> i32 {
        // SAFETY: `self.ptr` is a live handle owned by `self`.
        unsafe { sys::zvec_vector_query_get_topk(self.as_ptr()) }
    }

    /// Sets the `topk` value on the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_topk`, mapped through `check`.
    pub fn set_topk(&mut self, topk: i32) -> Result<()> {
        // SAFETY: `self.ptr` is a live handle and `&mut self` gives exclusive access.
        check(unsafe { sys::zvec_vector_query_set_topk(self.ptr.as_ptr(), topk) })
    }

    /// Returns the name of the vector field being queried, converted by `cstr_to_string`
    /// (`None` presumably when the C side returns NULL — TODO confirm).
    pub fn field_name(&self) -> Option<String> {
        // SAFETY: `self.ptr` is live; the returned pointer is read immediately.
        unsafe { cstr_to_string(sys::zvec_vector_query_get_field_name(self.as_ptr())) }
    }

    /// Sets the name of the vector field to query.
    ///
    /// The temporary C string is dropped right after the call. This relies on the C
    /// side copying it rather than keeping the pointer — NOTE(review): confirm.
    ///
    /// # Errors
    ///
    /// Fails if `name` cannot be converted by `cstring`, or if the C call reports an error.
    pub fn set_field_name(&mut self, name: &str) -> Result<()> {
        let c = cstring(name)?;
        // SAFETY: `c` is a valid NUL-terminated string that outlives the call.
        check(unsafe { sys::zvec_vector_query_set_field_name(self.ptr.as_ptr(), c.as_ptr()) })
    }

    /// Sets the query vector from raw bytes, passing the byte length to the C API.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_query_vector`.
    pub fn set_query_vector_raw(&mut self, bytes: &[u8]) -> Result<()> {
        // SAFETY: `bytes` is a valid slice of `bytes.len()` bytes for the whole call.
        check(unsafe {
            sys::zvec_vector_query_set_query_vector(
                self.ptr.as_ptr(),
                bytes.as_ptr() as *const core::ffi::c_void,
                bytes.len(),
            )
        })
    }

    /// Sets the query vector from `f32` components, reinterpreted as bytes.
    ///
    /// # Errors
    ///
    /// Same as [`VectorQuery::set_query_vector_raw`].
    pub fn set_query_vector_fp32(&mut self, vec: &[f32]) -> Result<()> {
        self.set_query_vector_raw(slice_as_bytes(vec))
    }

    /// Sets the query vector from `f64` components, reinterpreted as bytes.
    ///
    /// # Errors
    ///
    /// Same as [`VectorQuery::set_query_vector_raw`].
    pub fn set_query_vector_fp64(&mut self, vec: &[f64]) -> Result<()> {
        self.set_query_vector_raw(slice_as_bytes(vec))
    }

    /// Sets the query vector from half-precision components, reinterpreted as bytes.
    ///
    /// # Errors
    ///
    /// Same as [`VectorQuery::set_query_vector_raw`].
    #[cfg(feature = "half")]
    #[cfg_attr(docsrs, doc(cfg(feature = "half")))]
    pub fn set_query_vector_fp16(&mut self, vec: &[half::f16]) -> Result<()> {
        // SAFETY: `half::f16` has the same size and alignment as `u16` (it is a
        // transparent wrapper in the `half` crate), so the element count is unchanged.
        let bits: &[u16] =
            unsafe { core::slice::from_raw_parts(vec.as_ptr() as *const u16, vec.len()) };
        self.set_query_vector_raw(slice_as_bytes(bits))
    }

    /// Returns the filter expression set on the query, converted by `cstr_to_string`.
    pub fn filter(&self) -> Option<String> {
        // SAFETY: `self.ptr` is live; the returned pointer is read immediately.
        unsafe { cstr_to_string(sys::zvec_vector_query_get_filter(self.as_ptr())) }
    }

    /// Sets the filter expression.
    ///
    /// The temporary C string is dropped right after the call. This relies on the C
    /// side copying it — NOTE(review): confirm.
    ///
    /// # Errors
    ///
    /// Fails if `filter` cannot be converted by `cstring`, or if the C call reports an error.
    pub fn set_filter(&mut self, filter: &str) -> Result<()> {
        let c = cstring(filter)?;
        // SAFETY: `c` is a valid NUL-terminated string that outlives the call.
        check(unsafe { sys::zvec_vector_query_set_filter(self.ptr.as_ptr(), c.as_ptr()) })
    }

    /// Returns the `include_vector` flag of the query.
    pub fn include_vector(&self) -> bool {
        // SAFETY: `self.ptr` is a live handle owned by `self`.
        unsafe { sys::zvec_vector_query_get_include_vector(self.as_ptr()) }
    }

    /// Sets the `include_vector` flag of the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_include_vector`.
    pub fn set_include_vector(&mut self, b: bool) -> Result<()> {
        // SAFETY: `self.ptr` is a live handle and `&mut self` gives exclusive access.
        check(unsafe { sys::zvec_vector_query_set_include_vector(self.ptr.as_ptr(), b) })
    }

    /// Returns the `include_doc_id` flag of the query.
    pub fn include_doc_id(&self) -> bool {
        // SAFETY: `self.ptr` is a live handle owned by `self`.
        unsafe { sys::zvec_vector_query_get_include_doc_id(self.as_ptr()) }
    }

    /// Sets the `include_doc_id` flag of the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_include_doc_id`.
    pub fn set_include_doc_id(&mut self, b: bool) -> Result<()> {
        // SAFETY: `self.ptr` is a live handle and `&mut self` gives exclusive access.
        check(unsafe { sys::zvec_vector_query_set_include_doc_id(self.ptr.as_ptr(), b) })
    }

    /// Sets the list of output field names.
    ///
    /// # Errors
    ///
    /// Fails if any name cannot be converted to a C string, or if the C call reports an error.
    pub fn set_output_fields(&mut self, fields: &[&str]) -> Result<()> {
        // `keep` owns the C strings that `ptrs` points into; it must outlive the call.
        let (keep, mut ptrs) = field_name_vec_to_c(fields)?;
        let len = ptrs.len();
        // SAFETY: `ptrs` holds `len` valid NUL-terminated pointers owned by `keep`.
        // `as_mut_ptr` provides the `*mut` the binding expects without casting away
        // constness from a shared borrow.
        let rc = unsafe {
            sys::zvec_vector_query_set_output_fields(self.ptr.as_ptr(), ptrs.as_mut_ptr(), len)
        };
        drop(keep);
        check(rc)
    }

    /// Returns the output field names configured on the query.
    ///
    /// Entries for which `cstr_to_string` yields `None` are skipped. The C-allocated
    /// array is released with `zvec_free`; the individual strings are not freed here
    /// (NOTE(review): confirm the C API does not hand over ownership of them).
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_get_output_fields`.
    pub fn output_fields(&self) -> Result<Vec<String>> {
        let mut arr: *mut *const core::ffi::c_char = core::ptr::null_mut();
        let mut count: usize = 0;
        // SAFETY: `self.ptr` is live and both out-pointers refer to valid locals.
        check(unsafe {
            sys::zvec_vector_query_get_output_fields(self.as_ptr(), &mut arr, &mut count)
        })?;
        // Never index through a NULL array, even if the C side reported a non-zero count.
        if arr.is_null() {
            return Ok(Vec::new());
        }
        // SAFETY: `arr` is non-null and the C API reported `count` readable entries.
        let entries = unsafe { core::slice::from_raw_parts(arr as *const *const _, count) };
        let out = entries
            .iter()
            // SAFETY: each entry is a pointer produced by the C API for this call.
            .filter_map(|&p| unsafe { cstr_to_string(p) })
            .collect();
        // SAFETY: `arr` was allocated by the C API and is no longer referenced.
        unsafe { sys::zvec_free(arr as *mut _) };
        Ok(out)
    }

    /// Attaches HNSW-specific search parameters to the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_hnsw_params`.
    pub fn set_hnsw_params(&mut self, params: HnswQueryParams) -> Result<()> {
        let raw = params.into_raw();
        // SAFETY: `self.ptr` is live; `raw` ownership is handed to the C call by `into_raw`.
        check(unsafe { sys::zvec_vector_query_set_hnsw_params(self.ptr.as_ptr(), raw) })
    }

    /// Attaches IVF-specific search parameters to the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_ivf_params`.
    pub fn set_ivf_params(&mut self, params: IvfQueryParams) -> Result<()> {
        let raw = params.into_raw();
        // SAFETY: `self.ptr` is live; `raw` ownership is handed to the C call by `into_raw`.
        check(unsafe { sys::zvec_vector_query_set_ivf_params(self.ptr.as_ptr(), raw) })
    }

    /// Attaches flat (brute-force) search parameters to the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_vector_query_set_flat_params`.
    pub fn set_flat_params(&mut self, params: FlatQueryParams) -> Result<()> {
        let raw = params.into_raw();
        // SAFETY: `self.ptr` is live; `raw` ownership is handed to the C call by `into_raw`.
        check(unsafe { sys::zvec_vector_query_set_flat_params(self.ptr.as_ptr(), raw) })
    }
}
impl Drop for VectorQuery {
    /// Releases the native handle through `zvec_vector_query_destroy`.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // SAFETY: `handle` came from `zvec_vector_query_create` (the only constructor in
        // this module) and is destroyed exactly once, here.
        unsafe {
            sys::zvec_vector_query_destroy(handle);
        }
    }
}

// SAFETY: the wrapper is the sole owner of its handle, so moving it to another thread
// transfers exclusive access. This assumes the zvec C API does not tie query objects to
// the creating thread — NOTE(review): confirm. `Sync` is not implemented.
unsafe impl Send for VectorQuery {}
/// Safe owner of a native `zvec_group_by_vector_query_t` handle.
///
/// The handle is created by [`GroupByVectorQuery::new`] and destroyed when this value is
/// dropped. `Debug` prints only the handle address.
#[derive(Debug)]
pub struct GroupByVectorQuery {
    /// Non-null handle returned by `zvec_group_by_vector_query_create`.
    ptr: NonNull<sys::zvec_group_by_vector_query_t>,
}
impl GroupByVectorQuery {
    /// Creates a new group-by query backed by a fresh native handle.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorCode::ResourceExhausted`] error if
    /// `zvec_group_by_vector_query_create` returns NULL.
    pub fn new() -> Result<Self> {
        // SAFETY: the constructor takes no arguments; a NULL result is handled below.
        let ptr = unsafe { sys::zvec_group_by_vector_query_create() };
        NonNull::new(ptr).map(|ptr| Self { ptr }).ok_or_else(|| {
            ZvecError::with_message(
                ErrorCode::ResourceExhausted,
                "zvec_group_by_vector_query_create returned NULL",
            )
        })
    }

    /// Returns the native handle as a const pointer for read-only C calls in this crate.
    pub(crate) fn as_ptr(&self) -> *const sys::zvec_group_by_vector_query_t {
        self.ptr.as_ptr() as *const _
    }

    /// Returns the name of the vector field being queried, converted by `cstr_to_string`.
    pub fn field_name(&self) -> Option<String> {
        // SAFETY: `self.ptr` is live; the returned pointer is read immediately.
        unsafe {
            cstr_to_string(sys::zvec_group_by_vector_query_get_field_name(
                self.as_ptr(),
            ))
        }
    }

    /// Sets the name of the vector field to query.
    ///
    /// The temporary C string is dropped right after the call. This relies on the C
    /// side copying it — NOTE(review): confirm.
    ///
    /// # Errors
    ///
    /// Fails if `name` cannot be converted by `cstring`, or if the C call reports an error.
    pub fn set_field_name(&mut self, name: &str) -> Result<()> {
        let c = cstring(name)?;
        // SAFETY: `c` is a valid NUL-terminated string that outlives the call.
        check(unsafe {
            sys::zvec_group_by_vector_query_set_field_name(self.ptr.as_ptr(), c.as_ptr())
        })
    }

    /// Returns the name of the field results are grouped by, converted by `cstr_to_string`.
    pub fn group_by_field_name(&self) -> Option<String> {
        // SAFETY: `self.ptr` is live; the returned pointer is read immediately.
        unsafe {
            cstr_to_string(sys::zvec_group_by_vector_query_get_group_by_field_name(
                self.as_ptr(),
            ))
        }
    }

    /// Sets the name of the field results are grouped by.
    ///
    /// # Errors
    ///
    /// Fails if `name` cannot be converted by `cstring`, or if the C call reports an error.
    pub fn set_group_by_field_name(&mut self, name: &str) -> Result<()> {
        let c = cstring(name)?;
        // SAFETY: `c` is a valid NUL-terminated string that outlives the call.
        check(unsafe {
            sys::zvec_group_by_vector_query_set_group_by_field_name(self.ptr.as_ptr(), c.as_ptr())
        })
    }

    /// Returns the configured group count.
    pub fn group_count(&self) -> u32 {
        // SAFETY: `self.ptr` is a live handle owned by `self`.
        unsafe { sys::zvec_group_by_vector_query_get_group_count(self.as_ptr()) }
    }

    /// Sets the group count.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_group_count`.
    pub fn set_group_count(&mut self, n: u32) -> Result<()> {
        // SAFETY: `self.ptr` is a live handle and `&mut self` gives exclusive access.
        check(unsafe { sys::zvec_group_by_vector_query_set_group_count(self.ptr.as_ptr(), n) })
    }

    /// Returns the configured per-group `topk`.
    pub fn group_topk(&self) -> u32 {
        // SAFETY: `self.ptr` is a live handle owned by `self`.
        unsafe { sys::zvec_group_by_vector_query_get_group_topk(self.as_ptr()) }
    }

    /// Sets the per-group `topk`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_group_topk`.
    pub fn set_group_topk(&mut self, n: u32) -> Result<()> {
        // SAFETY: `self.ptr` is a live handle and `&mut self` gives exclusive access.
        check(unsafe { sys::zvec_group_by_vector_query_set_group_topk(self.ptr.as_ptr(), n) })
    }

    /// Sets the query vector from raw bytes, passing the byte length to the C API.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_query_vector`.
    pub fn set_query_vector_raw(&mut self, bytes: &[u8]) -> Result<()> {
        // SAFETY: `bytes` is a valid slice of `bytes.len()` bytes for the whole call.
        check(unsafe {
            sys::zvec_group_by_vector_query_set_query_vector(
                self.ptr.as_ptr(),
                bytes.as_ptr() as *const core::ffi::c_void,
                bytes.len(),
            )
        })
    }

    /// Sets the query vector from `f32` components, reinterpreted as bytes.
    ///
    /// # Errors
    ///
    /// Same as [`GroupByVectorQuery::set_query_vector_raw`].
    pub fn set_query_vector_fp32(&mut self, vec: &[f32]) -> Result<()> {
        self.set_query_vector_raw(slice_as_bytes(vec))
    }

    /// Sets the query vector from `f64` components, reinterpreted as bytes.
    ///
    /// Added for parity with `VectorQuery::set_query_vector_fp64`.
    ///
    /// # Errors
    ///
    /// Same as [`GroupByVectorQuery::set_query_vector_raw`].
    pub fn set_query_vector_fp64(&mut self, vec: &[f64]) -> Result<()> {
        self.set_query_vector_raw(slice_as_bytes(vec))
    }

    /// Sets the query vector from half-precision components, reinterpreted as bytes.
    ///
    /// Added for parity with `VectorQuery::set_query_vector_fp16`.
    ///
    /// # Errors
    ///
    /// Same as [`GroupByVectorQuery::set_query_vector_raw`].
    #[cfg(feature = "half")]
    #[cfg_attr(docsrs, doc(cfg(feature = "half")))]
    pub fn set_query_vector_fp16(&mut self, vec: &[half::f16]) -> Result<()> {
        // SAFETY: `half::f16` has the same size and alignment as `u16` (it is a
        // transparent wrapper in the `half` crate), so the element count is unchanged.
        let bits: &[u16] =
            unsafe { core::slice::from_raw_parts(vec.as_ptr() as *const u16, vec.len()) };
        self.set_query_vector_raw(slice_as_bytes(bits))
    }

    /// Returns the filter expression set on the query, converted by `cstr_to_string`.
    pub fn filter(&self) -> Option<String> {
        // SAFETY: `self.ptr` is live; the returned pointer is read immediately.
        unsafe { cstr_to_string(sys::zvec_group_by_vector_query_get_filter(self.as_ptr())) }
    }

    /// Sets the filter expression.
    ///
    /// # Errors
    ///
    /// Fails if `filter` cannot be converted by `cstring`, or if the C call reports an error.
    pub fn set_filter(&mut self, filter: &str) -> Result<()> {
        let c = cstring(filter)?;
        // SAFETY: `c` is a valid NUL-terminated string that outlives the call.
        check(unsafe { sys::zvec_group_by_vector_query_set_filter(self.ptr.as_ptr(), c.as_ptr()) })
    }

    /// Returns the `include_vector` flag of the query.
    pub fn include_vector(&self) -> bool {
        // SAFETY: `self.ptr` is a live handle owned by `self`.
        unsafe { sys::zvec_group_by_vector_query_get_include_vector(self.as_ptr()) }
    }

    /// Sets the `include_vector` flag of the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_include_vector`.
    pub fn set_include_vector(&mut self, b: bool) -> Result<()> {
        // SAFETY: `self.ptr` is a live handle and `&mut self` gives exclusive access.
        check(unsafe { sys::zvec_group_by_vector_query_set_include_vector(self.ptr.as_ptr(), b) })
    }

    /// Sets the list of output field names.
    ///
    /// # Errors
    ///
    /// Fails if any name cannot be converted to a C string, or if the C call reports an error.
    pub fn set_output_fields(&mut self, fields: &[&str]) -> Result<()> {
        // `keep` owns the C strings that `ptrs` points into; it must outlive the call.
        let (keep, mut ptrs) = field_name_vec_to_c(fields)?;
        let len = ptrs.len();
        // SAFETY: `ptrs` holds `len` valid NUL-terminated pointers owned by `keep`.
        // `as_mut_ptr` provides the `*mut` the binding expects without casting away
        // constness from a shared borrow.
        let rc = unsafe {
            sys::zvec_group_by_vector_query_set_output_fields(
                self.ptr.as_ptr(),
                ptrs.as_mut_ptr(),
                len,
            )
        };
        drop(keep);
        check(rc)
    }

    /// Returns the output field names configured on the query.
    ///
    /// Entries for which `cstr_to_string` yields `None` are skipped. The C-allocated
    /// array is released with `zvec_free`; the individual strings are not freed here
    /// (NOTE(review): confirm the C API does not hand over ownership of them).
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_get_output_fields`.
    pub fn output_fields(&self) -> Result<Vec<String>> {
        let mut arr: *mut *const core::ffi::c_char = core::ptr::null_mut();
        let mut count: usize = 0;
        // SAFETY: `self.ptr` is live and both out-pointers refer to valid locals.
        // The mutable handle pointer is passed as in the original binding call; the
        // binding may declare this getter's handle as non-const — NOTE(review): confirm.
        check(unsafe {
            sys::zvec_group_by_vector_query_get_output_fields(
                self.ptr.as_ptr(),
                &mut arr,
                &mut count,
            )
        })?;
        // Never index through a NULL array, even if the C side reported a non-zero count.
        if arr.is_null() {
            return Ok(Vec::new());
        }
        // SAFETY: `arr` is non-null and the C API reported `count` readable entries.
        let entries = unsafe { core::slice::from_raw_parts(arr as *const *const _, count) };
        let out = entries
            .iter()
            // SAFETY: each entry is a pointer produced by the C API for this call.
            .filter_map(|&p| unsafe { cstr_to_string(p) })
            .collect();
        // SAFETY: `arr` was allocated by the C API and is no longer referenced.
        unsafe { sys::zvec_free(arr as *mut _) };
        Ok(out)
    }

    /// Attaches HNSW-specific search parameters to the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_hnsw_params`.
    pub fn set_hnsw_params(&mut self, params: HnswQueryParams) -> Result<()> {
        let raw = params.into_raw();
        // SAFETY: `self.ptr` is live; `raw` ownership is handed to the C call by `into_raw`.
        check(unsafe { sys::zvec_group_by_vector_query_set_hnsw_params(self.ptr.as_ptr(), raw) })
    }

    /// Attaches IVF-specific search parameters to the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_ivf_params`.
    pub fn set_ivf_params(&mut self, params: IvfQueryParams) -> Result<()> {
        let raw = params.into_raw();
        // SAFETY: `self.ptr` is live; `raw` ownership is handed to the C call by `into_raw`.
        check(unsafe { sys::zvec_group_by_vector_query_set_ivf_params(self.ptr.as_ptr(), raw) })
    }

    /// Attaches flat (brute-force) search parameters to the query.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `zvec_group_by_vector_query_set_flat_params`.
    pub fn set_flat_params(&mut self, params: FlatQueryParams) -> Result<()> {
        let raw = params.into_raw();
        // SAFETY: `self.ptr` is live; `raw` ownership is handed to the C call by `into_raw`.
        check(unsafe { sys::zvec_group_by_vector_query_set_flat_params(self.ptr.as_ptr(), raw) })
    }
}
impl Drop for GroupByVectorQuery {
    /// Releases the native handle through `zvec_group_by_vector_query_destroy`.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // SAFETY: `handle` came from `zvec_group_by_vector_query_create` (the only
        // constructor in this module) and is destroyed exactly once, here.
        unsafe {
            sys::zvec_group_by_vector_query_destroy(handle);
        }
    }
}

// SAFETY: the wrapper is the sole owner of its handle, so moving it to another thread
// transfers exclusive access. This assumes the zvec C API does not tie query objects to
// the creating thread — NOTE(review): confirm. `Sync` is not implemented.
unsafe impl Send for GroupByVectorQuery {}