use crate::DType;
use crate::cluster::traits::kmeans::{KMeansOptions, KMeansResult};
use crate::cluster::traits::mini_batch_kmeans::MiniBatchKMeansOptions;
use crate::cluster::validation::{validate_cluster_dtype, validate_data_2d, validate_n_clusters};
use numr::error::Result;
use numr::ops::{
CompareOps, ConditionalOps, CumulativeOps, DistanceMetric, DistanceOps, IndexingOps, RandomOps,
ReduceOps, ScalarOps, ShapeOps, SortingOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Clusters the rows of `data` with mini-batch k-means.
///
/// The input is validated first (dtype, 2-D shape, and `options.n_clusters`
/// against the number of rows). Initial centroids come from a
/// single-iteration, single-restart run of [`super::kmeans::kmeans_impl`]
/// that uses `options.init`.
///
/// Each of the up to `options.max_iter` iterations then:
///
/// 1. draws `min(options.batch_size, n)` distinct rows through a random
///    permutation of the row indices,
/// 2. assigns every batch row to its nearest centroid under squared
///    Euclidean distance,
/// 3. moves each centroid toward the mean of the batch rows assigned to it,
///    using the per-cluster step `batch_count / cumulative_count`.
///
/// When `options.tol > 0.0` or `options.max_no_improvement < options.max_iter`,
/// the mean nearest-centroid distance of the batch is tracked, and the loop
/// stops once it has failed to improve by more than `options.tol` for
/// `options.max_no_improvement` iterations.
///
/// # Returns
///
/// A [`KMeansResult`] holding the final centroids, the nearest-centroid label
/// of every row of `data`, the summed squared distance of all rows to their
/// nearest centroid, and the number of mini-batch iterations performed.
///
/// # Errors
///
/// Propagates any error from the validation helpers, from the initial
/// k-means run, or from the tensor operations invoked on `client`.
pub fn mini_batch_kmeans_impl<R, C>(
    client: &C,
    data: &Tensor<R>,
    options: &MiniBatchKMeansOptions<R>,
) -> Result<KMeansResult<R>>
where
    R: Runtime<DType = DType>,
    C: DistanceOps<R>
        + IndexingOps<R>
        + ReduceOps<R>
        + ScalarOps<R>
        + TensorOps<R>
        + TypeConversionOps<R>
        + UnaryOps<R>
        + CumulativeOps<R>
        + ConditionalOps<R>
        + CompareOps<R>
        + RandomOps<R>
        + SortingOps<R>
        + ShapeOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    validate_cluster_dtype(data.dtype(), "mini_batch_kmeans")?;
    validate_data_2d(data.shape(), "mini_batch_kmeans")?;
    validate_n_clusters(options.n_clusters, data.shape()[0], "mini_batch_kmeans")?;

    let n = data.shape()[0];
    let d = data.shape()[1];
    let k = options.n_clusters;
    let dtype = data.dtype();
    let device = data.device();
    // A batch can never hold more rows than the data set has.
    // NOTE(review): a `batch_size` of 0 is not rejected here and would yield
    // empty batches — confirm whether validation should refuse it.
    let batch_size = options.batch_size.min(n);

    // Seed the centroids with one iteration of regular k-means so the
    // configured initialisation strategy is reused.
    let kmeans_opts = KMeansOptions {
        n_clusters: k,
        max_iter: 1,
        tol: 0.0,
        n_init: 1,
        init: options.init.clone(),
        ..Default::default()
    };
    let init_result = super::kmeans::kmeans_impl(client, data, &kmeans_opts)?;
    let mut centroids = init_result.centroids;

    // Running per-cluster counts start at one, which keeps the denominator of
    // the per-cluster step size non-zero.
    let mut counts = Tensor::<R>::ones(&[k], dtype, device);

    // Loop-invariant helpers, built once instead of once per iteration.
    let ones_k = Tensor::<R>::ones(&[k], dtype, device);
    let zeros_k = Tensor::<R>::zeros(&[k], dtype, device);
    let track_inertia = options.tol > 0.0 || options.max_no_improvement < options.max_iter;

    let mut best_inertia = f64::INFINITY;
    let mut no_improvement = 0usize;
    let mut n_iter = 0;

    for iter in 0..options.max_iter {
        n_iter = iter + 1;

        // Sample the batch without replacement: first `batch_size` entries of
        // a fresh permutation of the row indices.
        let perm = client.randperm(n)?;
        let batch_idx = perm.narrow(0, 0, batch_size)?;
        let batch = client.index_select(data, 0, &batch_idx)?;

        // Nearest-centroid assignment for the batch.
        let dists = client.cdist(&batch, &centroids, DistanceMetric::SquaredEuclidean)?;
        let labels = client.argmin(&dists, 1, false)?;

        // Per-cluster sum of the batch rows, accumulated into a fresh
        // zero-filled `[k, d]` destination.
        let labels_exp = labels.unsqueeze(1)?.broadcast_to(&[batch_size, d])?;
        let dst = Tensor::<R>::zeros(&[k, d], dtype, device);
        let batch_sums = client.scatter_reduce(
            &dst,
            0,
            &labels_exp,
            &batch,
            numr::ops::ScatterReduceOp::Sum,
            false,
        )?;

        // Per-cluster batch counts, folded into the running totals.
        let batch_counts = client.bincount(&labels, None, k)?;
        let batch_counts_f = client.cast(&batch_counts, dtype)?;
        counts = client.add(&counts, &batch_counts_f)?;

        // Step size per cluster: share of the running count contributed by
        // this batch.
        let eta = client.div(&batch_counts_f, &counts)?;
        let eta_exp = eta.unsqueeze(1)?.broadcast_to(&[k, d])?;

        // Batch mean per cluster; the count is clamped to at least one so
        // clusters that received no rows do not divide by zero.
        let bc_safe = client.maximum(&batch_counts_f, &ones_k)?;
        let bc_exp = bc_safe.unsqueeze(1)?.broadcast_to(&[k, d])?;
        let batch_centroids = client.div(&batch_sums, &bc_exp)?;

        // Mask selecting clusters that received at least one batch row. `eta`
        // is already zero for the others; the mask makes that explicit.
        let has_points = client.gt(&batch_counts_f, &zeros_k)?;
        let has_points_exp = has_points.unsqueeze(1)?.broadcast_to(&[k, d])?;
        let has_points_f = client.cast(&has_points_exp, dtype)?;

        // centroid += eta * (batch_mean - centroid), only where the cluster
        // was hit by this batch.
        let diff = client.sub(&batch_centroids, &centroids)?;
        let update = client.mul(&eta_exp, &diff)?;
        let update = client.mul(&update, &has_points_f)?;
        centroids = client.add(&centroids, &update)?;

        if track_inertia {
            // Batch inertia is measured on the distances computed before the
            // centroid update above.
            let min_dists = client.min(&dists, &[1], false)?;
            let inertia: f64 = client.mean(&min_dists, &[0], false)?.item()?;
            if inertia < best_inertia - options.tol {
                best_inertia = inertia;
                no_improvement = 0;
            } else {
                no_improvement += 1;
                if no_improvement >= options.max_no_improvement {
                    break;
                }
            }
        }
    }

    // Final pass over the whole data set: labels and total inertia.
    let final_dists = client.cdist(data, &centroids, DistanceMetric::SquaredEuclidean)?;
    let labels = client.argmin(&final_dists, 1, false)?;
    let min_dists = client.min(&final_dists, &[1], false)?;
    let inertia = client.sum(&min_dists, &[0], false)?;

    Ok(KMeansResult {
        centroids,
        labels,
        inertia,
        n_iter,
    })
}