use crate::distance_type::DistanceType;
use crate::error::{check_cuvs, Result};
use std::fmt;
use std::io::{stderr, Write};
/// Owning wrapper around a cuVS k-means parameter handle (`ffi::cuvsKMeansParams_t`).
///
/// Build one with [`Params::new`], then chain the `set_*` methods; each setter writes a single
/// field of the underlying C struct and returns the wrapper so calls can be chained.
pub struct Params(pub ffi::cuvsKMeansParams_t);

impl Params {
    /// Allocates a fresh k-means parameter object through `ffi::cuvsKMeansParamsCreate`.
    ///
    /// # Errors
    /// Propagates the error produced by `check_cuvs` when the create call reports failure.
    pub fn new() -> Result<Params> {
        let mut handle = std::mem::MaybeUninit::<ffi::cuvsKMeansParams_t>::uninit();
        // SAFETY: `handle.as_mut_ptr()` is a valid out-pointer. `assume_init` is reached only
        // after `check_cuvs` accepted the returned status; the C call is presumed to have
        // written the handle on success (TODO confirm against the cuVS C API docs).
        unsafe {
            check_cuvs(ffi::cuvsKMeansParamsCreate(handle.as_mut_ptr()))?;
            Ok(Params(handle.assume_init()))
        }
    }

    /// Runs `edit` against the raw handle and returns `self`, so each public setter is a
    /// one-liner that only supplies the field assignment.
    fn update(self, edit: impl FnOnce(ffi::cuvsKMeansParams_t)) -> Params {
        edit(self.0);
        self
    }

    // SAFETY (all setters below): each dereferences `self.0`, which is assumed to point at the
    // live struct obtained from `cuvsKMeansParamsCreate` in `new`.

    /// Sets the `metric` field.
    pub fn set_metric(self, metric: DistanceType) -> Params {
        self.update(|raw| unsafe { (*raw).metric = metric })
    }

    /// Sets the `n_clusters` field.
    pub fn set_n_clusters(self, n_clusters: i32) -> Params {
        self.update(|raw| unsafe { (*raw).n_clusters = n_clusters })
    }

    /// Sets the `max_iter` field.
    pub fn set_max_iter(self, max_iter: i32) -> Params {
        self.update(|raw| unsafe { (*raw).max_iter = max_iter })
    }

    /// Sets the `tol` field.
    pub fn set_tol(self, tol: f64) -> Params {
        self.update(|raw| unsafe { (*raw).tol = tol })
    }

    /// Sets the `n_init` field.
    pub fn set_n_init(self, n_init: i32) -> Params {
        self.update(|raw| unsafe { (*raw).n_init = n_init })
    }

    /// Sets the `oversampling_factor` field.
    pub fn set_oversampling_factor(self, oversampling_factor: f64) -> Params {
        self.update(|raw| unsafe { (*raw).oversampling_factor = oversampling_factor })
    }

    /// Sets the `batch_samples` field.
    pub fn set_batch_samples(self, batch_samples: i32) -> Params {
        self.update(|raw| unsafe { (*raw).batch_samples = batch_samples })
    }

    /// Sets the `batch_centroids` field.
    pub fn set_batch_centroids(self, batch_centroids: i32) -> Params {
        self.update(|raw| unsafe { (*raw).batch_centroids = batch_centroids })
    }

    /// Sets the `hierarchical` field.
    pub fn set_hierarchical(self, hierarchical: bool) -> Params {
        self.update(|raw| unsafe { (*raw).hierarchical = hierarchical })
    }

    /// Sets the `hierarchical_n_iters` field.
    pub fn set_hierarchical_n_iters(self, hierarchical_n_iters: i32) -> Params {
        self.update(|raw| unsafe { (*raw).hierarchical_n_iters = hierarchical_n_iters })
    }
}
impl fmt::Debug for Params {
    /// Renders as `Params(<inner>)`, where `<inner>` is the `Debug` output of the C struct the
    /// handle points at.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: `self.0` is assumed to point at the live parameter struct owned by this
        // wrapper; a shared borrow for the duration of formatting is enough, no copy required.
        let inner = unsafe { &*self.0 };
        write!(f, "Params({:?})", inner)
    }
}
impl Drop for Params {
    /// Releases the underlying parameter object via `ffi::cuvsKMeansParamsDestroy`.
    ///
    /// A failed destroy is reported on stderr on a best-effort basis. This never panics:
    /// panicking inside `drop` while the thread is already unwinding aborts the process.
    fn drop(&mut self) {
        // SAFETY: assumes `self.0` came from `cuvsKMeansParamsCreate` (as `Params::new` does)
        // and has not been destroyed elsewhere; it is destroyed exactly once here.
        if let Err(e) = check_cuvs(unsafe { ffi::cuvsKMeansParamsDestroy(self.0) }) {
            // Ignore write failures: a destructor has nowhere else to report them.
            let _ = writeln!(stderr(), "failed to call cuvsKMeansParamsDestroy {:?}", e);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every setter (except `set_metric`) must write through to the matching field of the
    /// underlying C struct.
    #[test]
    fn test_params() {
        let params = Params::new()
            .unwrap()
            .set_n_clusters(128)
            .set_max_iter(20)
            .set_tol(1e-4)
            .set_n_init(3)
            .set_oversampling_factor(2.0)
            .set_batch_samples(1 << 15)
            .set_batch_centroids(256)
            .set_hierarchical(true)
            .set_hierarchical_n_iters(10);

        // SAFETY: `params.0` was produced by a successful `Params::new` and is still alive.
        unsafe {
            let raw = &*params.0;
            assert_eq!(raw.n_clusters, 128);
            assert_eq!(raw.max_iter, 20);
            assert_eq!(raw.tol, 1e-4);
            assert_eq!(raw.n_init, 3);
            assert_eq!(raw.oversampling_factor, 2.0);
            assert_eq!(raw.batch_samples, 1 << 15);
            assert_eq!(raw.batch_centroids, 256);
            assert!(raw.hierarchical);
            assert_eq!(raw.hierarchical_n_iters, 10);
        }

        assert!(format!("{:?}", params).starts_with("Params("));
    }
}