use crate::error::{check_cuvs, Result};
use std::fmt;
use std::io::{stderr, Write};
pub use ffi::cudaDataType_t;
/// Owning wrapper around a raw `ffi::cuvsIvfPqSearchParams_t` handle.
///
/// The handle is obtained from `cuvsIvfPqSearchParamsCreate` in
/// [`SearchParams::new`] and handed to `cuvsIvfPqSearchParamsDestroy` when the
/// wrapper is dropped. The inner field is public, so code that needs the raw
/// handle can read it directly.
pub struct SearchParams(pub ffi::cuvsIvfPqSearchParams_t);
impl SearchParams {
pub fn new() -> Result<SearchParams> {
unsafe {
let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfPqSearchParams_t>::uninit();
check_cuvs(ffi::cuvsIvfPqSearchParamsCreate(params.as_mut_ptr()))?;
Ok(SearchParams(params.assume_init()))
}
}
pub fn set_n_probes(self, n_probes: u32) -> SearchParams {
unsafe {
(*self.0).n_probes = n_probes;
}
self
}
pub fn set_lut_dtype(self, lut_dtype: cudaDataType_t) -> SearchParams {
unsafe {
(*self.0).lut_dtype = lut_dtype;
}
self
}
pub fn set_internal_distance_dtype(
self,
internal_distance_dtype: cudaDataType_t,
) -> SearchParams {
unsafe {
(*self.0).internal_distance_dtype = internal_distance_dtype;
}
self
}
}
impl fmt::Debug for SearchParams {
    /// Formats the pointed-to parameter struct using its own `Debug` output,
    /// wrapped as `SearchParams { params: ... }`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: assumes `self.0` points at the live parameter object owned
        // by this wrapper; the struct is copied out so no reference escapes.
        let snapshot = unsafe { *self.0 };
        write!(f, "SearchParams {{ params: {:?} }}", snapshot)
    }
}
impl Drop for SearchParams {
    /// Hands the wrapped handle to `cuvsIvfPqSearchParamsDestroy`.
    ///
    /// If the destroy call reports an error, a diagnostic is written to stderr
    /// on a best-effort basis. This method never panics: a panic raised from
    /// `drop` while the thread is already unwinding aborts the process, so a
    /// failure to write the diagnostic is deliberately ignored.
    fn drop(&mut self) {
        // SAFETY: assumes `self.0` is the handle obtained in `new` and that it
        // has not been destroyed elsewhere; `drop` runs at most once per value.
        let outcome = check_cuvs(unsafe { ffi::cuvsIvfPqSearchParamsDestroy(self.0) });
        if let Err(e) = outcome {
            // Ignore write errors (e.g. a closed stderr) instead of panicking
            // inside a destructor.
            let _ = write!(
                stderr(),
                "failed to call cuvsIvfPqSearchParamsDestroy {:?}",
                e
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A value passed to `set_n_probes` must be readable back from the
    /// underlying parameter object.
    #[test]
    fn test_search_params() {
        let expected_probes = 128;
        let params = SearchParams::new()
            .unwrap()
            .set_n_probes(expected_probes);
        // SAFETY: `params` is alive for the whole test, so its handle is
        // assumed to still point at the object created by `new`.
        let observed_probes = unsafe { (*params.0).n_probes };
        assert_eq!(observed_probes, expected_probes);
    }
}