use std::ffi::CString;
use std::os::raw::c_int;
use crate::error::{check_error, Result};
use crate::ffi;
use crate::query::{FlatQueryParam, HnswQueryParam, IVFQueryParam};
/// Wrapper around a native `zvec_sub_query_t` handle describing one sub-query
/// (field name, candidate count, query vector and index-specific parameters).
///
/// The handle is obtained from `zvec_sub_query_create` and released by the
/// `Drop` implementation via `zvec_sub_query_destroy`.
pub struct SubQuery {
/// Raw native handle; `SubQuery::new` only stores it after a null check.
pub(crate) ptr: *mut ffi::zvec_sub_query_t,
}
impl SubQuery {
pub fn new() -> Result<Self> {
let ptr = unsafe { ffi::zvec_sub_query_create() };
if ptr.is_null() {
return Err(crate::error::Error::InternalError(
"zvec_sub_query_create returned null".into(),
));
}
Ok(Self { ptr })
}
pub fn field_name(self, name: &str) -> Result<Self> {
let cstr =
CString::new(name).map_err(|e| crate::error::Error::InvalidArgument(e.to_string()))?;
let code = unsafe { ffi::zvec_sub_query_set_field_name(self.ptr, cstr.as_ptr()) };
check_error(code as c_int)?;
Ok(self)
}
pub fn num_candidates(self, n: u32) -> Self {
let code = unsafe { ffi::zvec_sub_query_set_num_candidates(self.ptr, n as c_int) };
let _ = check_error(code as c_int);
self
}
pub fn vector(self, data: &[f32]) -> Result<Self> {
let code = unsafe {
ffi::zvec_sub_query_set_query_vector(
self.ptr,
data.as_ptr() as *const std::os::raw::c_void,
std::mem::size_of_val(data),
)
};
check_error(code as c_int)?;
Ok(self)
}
pub fn sparse_vector(self, indices: &[u32], values: &[f32]) -> Result<Self> {
if indices.len() != values.len() {
return Err(crate::error::Error::InvalidArgument(
"indices and values must have same length".into(),
));
}
let n = indices.len();
let code = unsafe { ffi::zvec_sub_query_set_sparse_indices(self.ptr, indices.as_ptr(), n) };
check_error(code as c_int)?;
let code = unsafe { ffi::zvec_sub_query_set_sparse_values(self.ptr, values.as_ptr(), n) };
check_error(code as c_int)?;
Ok(self)
}
pub fn hnsw_params(self, params: HnswQueryParam) -> Self {
let ptr = params.ptr;
std::mem::forget(params);
let code = unsafe { ffi::zvec_sub_query_set_hnsw_params(self.ptr, ptr) };
let _ = check_error(code as c_int);
self
}
pub fn ivf_params(self, params: IVFQueryParam) -> Self {
let ptr = params.ptr;
std::mem::forget(params);
let code = unsafe { ffi::zvec_sub_query_set_ivf_params(self.ptr, ptr) };
let _ = check_error(code as c_int);
self
}
pub fn flat_params(self, params: FlatQueryParam) -> Self {
let ptr = params.ptr;
std::mem::forget(params);
let code = unsafe { ffi::zvec_sub_query_set_flat_params(self.ptr, ptr) };
let _ = check_error(code as c_int);
self
}
}
impl Default for SubQuery {
fn default() -> Self {
Self::new().expect("zvec_sub_query_create failed in Default")
}
}
impl Drop for SubQuery {
    /// Releases the native handle with `zvec_sub_query_destroy`; a null handle is skipped.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null here; assumes it is the handle produced by
        // `zvec_sub_query_create` and has not been destroyed elsewhere.
        unsafe { ffi::zvec_sub_query_destroy(self.ptr) };
    }
}
// SAFETY(review): `SubQuery` holds only a raw pointer, which is not `Send` by default; this
// impl asserts that the native handle may be moved to another thread. That relies on the
// threading guarantees of the zvec C library, which are not visible here — TODO confirm.
unsafe impl Send for SubQuery {}
/// Wrapper around a native `zvec_multi_query_t` handle that combines sub-queries with
/// top-k, filter, output-field and rerank settings.
///
/// The handle is obtained from `zvec_multi_query_create` and released by the `Drop`
/// implementation via `zvec_multi_query_destroy`.
pub struct MultiQuery {
/// Raw native handle; `MultiQuery::new` only stores it after a null check.
pub(crate) ptr: *mut ffi::zvec_multi_query_t,
}
impl MultiQuery {
pub fn new() -> Result<Self> {
let ptr = unsafe { ffi::zvec_multi_query_create() };
if ptr.is_null() {
return Err(crate::error::Error::InternalError(
"zvec_multi_query_create returned null".into(),
));
}
Ok(Self { ptr })
}
pub fn add_sub_query(&mut self, sub: SubQuery) -> Result<&mut Self> {
let code = unsafe { ffi::zvec_multi_query_add_sub_query(self.ptr, sub.ptr) };
drop(sub);
check_error(code as c_int)?;
Ok(self)
}
pub fn topk(self, k: i32) -> Self {
let code = unsafe { ffi::zvec_multi_query_set_topk(self.ptr, k) };
let _ = check_error(code as c_int);
self
}
pub fn filter(self, filter: &str) -> Result<Self> {
let cstr = CString::new(filter)
.map_err(|e| crate::error::Error::InvalidArgument(e.to_string()))?;
let code = unsafe { ffi::zvec_multi_query_set_filter(self.ptr, cstr.as_ptr()) };
check_error(code as c_int)?;
Ok(self)
}
pub fn output_fields(self, fields: &[&str]) -> Result<Self> {
let fields_c: Vec<CString> = fields
.iter()
.map(|f| {
CString::new(*f).map_err(|e| crate::error::Error::InvalidArgument(e.to_string()))
})
.collect::<Result<_>>()?;
let mut fields_ptr: Vec<*const std::os::raw::c_char> =
fields_c.iter().map(|f| f.as_ptr()).collect();
let code = unsafe {
ffi::zvec_multi_query_set_output_fields(
self.ptr,
fields_ptr.as_mut_ptr(),
fields_ptr.len(),
)
};
check_error(code as c_int)?;
Ok(self)
}
pub fn rerank_rrf(self, rank_constant: i32) -> Self {
let code = unsafe { ffi::zvec_multi_query_set_rerank_rrf(self.ptr, rank_constant) };
let _ = check_error(code as c_int);
self
}
pub fn rerank_weighted(self, weights: &[f64]) -> Result<Self> {
let code = unsafe {
ffi::zvec_multi_query_set_rerank_weighted(self.ptr, weights.as_ptr(), weights.len())
};
check_error(code as c_int)?;
Ok(self)
}
pub(crate) fn as_ptr(&self) -> *mut ffi::zvec_multi_query_t {
self.ptr
}
}
impl Default for MultiQuery {
fn default() -> Self {
Self::new().expect("zvec_multi_query_create failed in Default")
}
}
impl Drop for MultiQuery {
    /// Releases the native handle with `zvec_multi_query_destroy`; a null handle is skipped.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null here; assumes it is the handle produced by
        // `zvec_multi_query_create` and has not been destroyed elsewhere.
        unsafe { ffi::zvec_multi_query_destroy(self.ptr) };
    }
}
// SAFETY(review): `MultiQuery` holds only a raw pointer, which is not `Send` by default; this
// impl asserts that the native handle may be moved to another thread. That relies on the
// threading guarantees of the zvec C library, which are not visible here — TODO confirm.
unsafe impl Send for MultiQuery {}