mod params;
pub use params::Params;
use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;
/// Runs k-means fitting by delegating to the cuVS C API (`cuvsKMeansFit`).
///
/// # Arguments
///
/// * `res` - Resources handle forwarded to the C call
/// * `params` - K-means parameters forwarded to the C call
/// * `x` - Input tensor to cluster
/// * `sample_weight` - Optional sample weights; `None` is passed to the C
///   call as a null pointer
/// * `centroids` - Centroid tensor handed to the C call; borrowed mutably
///   because the fit is expected to update it
///
/// # Returns
///
/// `(inertia, n_iter)` as written by the C call into its two out-parameters.
///
/// # Errors
///
/// Returns whatever error `check_cuvs` produces for the status of the C call.
pub fn fit(
    res: &Resources,
    params: &Params,
    x: &ManagedTensor,
    sample_weight: &Option<ManagedTensor>,
    centroids: &mut ManagedTensor,
) -> Result<(f64, i32)> {
    let mut inertia = 0.0_f64;
    let mut n_iter = 0_i32;

    // SAFETY: the tensor pointers are obtained from tensors borrowed for the
    // whole call, and the two out-pointers refer to locals of this function.
    // NOTE(review): assumes `as_ptr()` yields pointers the C API accepts.
    unsafe {
        // A missing weight tensor is signalled to the C side with null.
        let weights = if let Some(tensor) = sample_weight {
            tensor.as_ptr()
        } else {
            std::ptr::null_mut()
        };

        let status = ffi::cuvsKMeansFit(
            res.0,
            params.0,
            x.as_ptr(),
            weights,
            centroids.as_ptr(),
            &mut inertia,
            &mut n_iter,
        );
        check_cuvs(status)?;
    }

    Ok((inertia, n_iter))
}
/// Assigns samples to centroids by delegating to the cuVS C API
/// (`cuvsKMeansPredict`).
///
/// # Arguments
///
/// * `res` - Resources handle forwarded to the C call
/// * `params` - K-means parameters forwarded to the C call
/// * `x` - Input tensor whose rows are to be labelled
/// * `sample_weight` - Optional sample weights; `None` is passed to the C
///   call as a null pointer
/// * `centroids` - Centroid tensor to predict against
/// * `labels` - Output tensor handed to the C call; borrowed mutably because
///   the predicted labels are expected to be written into it
/// * `normalize_weight` - Forwarded unchanged to the C call
///   (NOTE(review): exact semantics are defined by the cuVS C API)
///
/// # Returns
///
/// The inertia value written by the C call into its out-parameter.
///
/// # Errors
///
/// Returns whatever error `check_cuvs` produces for the status of the C call.
pub fn predict(
    res: &Resources,
    params: &Params,
    x: &ManagedTensor,
    sample_weight: &Option<ManagedTensor>,
    centroids: &ManagedTensor,
    labels: &mut ManagedTensor,
    normalize_weight: bool,
) -> Result<f64> {
    let mut inertia = 0.0_f64;

    // SAFETY: the tensor pointers are obtained from tensors borrowed for the
    // whole call, and the out-pointer refers to a local of this function.
    // NOTE(review): assumes `as_ptr()` yields pointers the C API accepts.
    unsafe {
        // A missing weight tensor is signalled to the C side with null.
        let weights = if let Some(tensor) = sample_weight {
            tensor.as_ptr()
        } else {
            std::ptr::null_mut()
        };

        let status = ffi::cuvsKMeansPredict(
            res.0,
            params.0,
            x.as_ptr(),
            weights,
            centroids.as_ptr(),
            labels.as_ptr(),
            normalize_weight,
            &mut inertia,
        );
        check_cuvs(status)?;
    }

    Ok(inertia)
}
/// Computes the clustering cost of `x` against `centroids` by delegating to
/// the cuVS C API (`cuvsKMeansClusterCost`).
///
/// # Arguments
///
/// * `res` - Resources handle forwarded to the C call
/// * `x` - Input tensor
/// * `centroids` - Centroid tensor the cost is evaluated against
///
/// # Returns
///
/// The cost value written by the C call into its out-parameter.
///
/// # Errors
///
/// Returns whatever error `check_cuvs` produces for the status of the C call.
pub fn cluster_cost(res: &Resources, x: &ManagedTensor, centroids: &ManagedTensor) -> Result<f64> {
    let mut cost = 0.0_f64;

    // SAFETY: the tensor pointers are obtained from tensors borrowed for the
    // whole call, and the out-pointer refers to a local of this function.
    // NOTE(review): assumes `as_ptr()` yields pointers the C API accepts.
    unsafe {
        let status = ffi::cuvsKMeansClusterCost(res.0, x.as_ptr(), centroids.as_ptr(), &mut cost);
        check_cuvs(status)?;
    }

    Ok(cost)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// End-to-end check of `cluster_cost`, `fit` and `predict` on random data.
    #[test]
    fn test_kmeans() {
        let res = Resources::new().unwrap();

        let n_clusters = 4;
        let n_datapoints = 256;
        let n_features = 16;

        // Random dataset, copied to the device.
        let dataset =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
        let dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap();

        // Centroids start out as all zeros on the device.
        let centroids_host = ndarray::Array::<f32, _>::zeros((n_clusters, n_features));
        let mut centroids = ManagedTensor::from(&centroids_host)
            .to_device(&res)
            .unwrap();

        let params = Params::new().unwrap().set_n_clusters(n_clusters as i32);

        // Cost measured before fitting serves as the baseline to beat.
        let original_inertia = cluster_cost(&res, &dataset, &centroids).unwrap();

        let (inertia, n_iter) = fit(&res, &params, &dataset, &None, &mut centroids).unwrap();

        assert!(inertia < original_inertia);
        assert!(n_iter >= 1);

        // Predict on the centroids themselves: one label per centroid, and
        // each centroid is expected to be assigned to its own index.
        let mut labels_host = ndarray::Array::<i32, _>::zeros((n_clusters,));
        let mut labels = ManagedTensor::from(&labels_host).to_device(&res).unwrap();

        predict(
            &res,
            &params,
            &centroids,
            &None,
            &centroids,
            &mut labels,
            false,
        )
        .unwrap();

        labels.to_host(&res, &mut labels_host).unwrap();
        assert_eq!(labels_host[[0,]], 0);
        assert_eq!(labels_host[[1,]], 1);
        assert_eq!(labels_host[[2,]], 2);
        assert_eq!(labels_host[[3,]], 3);
    }
}