use crate::clump_compat::{map_clump_error, soa_to_aos};
use crate::RetrieveError;
/// K-means clustering over fixed-dimension `f32` vectors, backed by
/// `clump::Kmeans` configured with `clump::CosineDistance`.
pub struct KMeans {
/// Number of components per vector; `new` rejects zero.
dimension: usize,
/// Requested number of clusters; `fit` caps it at the number of input vectors.
k: usize,
/// Optional seed forwarded to the `clump` builder when set.
seed: Option<u64>,
/// Result of the most recent successful `fit`; `None` until then.
fit: Option<clump::KmeansFit<clump::CosineDistance>>,
}
impl KMeans {
pub fn new(dimension: usize, k: usize) -> Result<Self, RetrieveError> {
if dimension == 0 || k == 0 {
return Err(RetrieveError::InvalidParameter(
"Dimension and k must be greater than 0".to_string(),
));
}
Ok(Self {
dimension,
k,
seed: None,
fit: None,
})
}
#[must_use]
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
if vectors.len() < num_vectors * self.dimension {
return Err(RetrieveError::InvalidParameter(
"Insufficient vectors".to_string(),
));
}
let data = soa_to_aos(vectors, num_vectors, self.dimension);
let effective_k = self.k.min(num_vectors);
let mut builder =
clump::Kmeans::with_metric(effective_k, clump::CosineDistance).with_tol(1e-6);
if let Some(s) = self.seed {
builder = builder.with_seed(s);
}
let result = builder.fit(&data).map_err(map_clump_error)?;
self.fit = Some(result);
Ok(())
}
pub fn assign_clusters(&self, vectors: &[f32], num_vectors: usize) -> Vec<usize> {
let Some(ref fit) = self.fit else {
return Vec::new();
};
let data = soa_to_aos(vectors, num_vectors, self.dimension);
fit.predict(&data).unwrap_or_default()
}
pub fn centroids(&self) -> &[Vec<f32>] {
match self.fit {
Some(ref f) => &f.centroids,
None => &[],
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Rescales each of the first `num_vectors` consecutive `dimension`-sized
    /// chunks of `vecs` to unit L2 norm. A chunk whose norm is not positive
    /// gets its first component set to `1.0` instead (when it has one).
    fn l2_normalize_in_place(vecs: &mut [f32], num_vectors: usize, dimension: usize) {
        for row in 0..num_vectors {
            let offset = row * dimension;
            let chunk = &mut vecs[offset..offset + dimension];
            let length = chunk.iter().map(|&c| c * c).sum::<f32>().sqrt();
            if length > 0.0 {
                chunk.iter_mut().for_each(|c| *c /= length);
            } else if let Some(head) = chunk.first_mut() {
                *head = 1.0;
            }
        }
    }

    proptest! {
        // Two models built with the same seed and fitted on the same data
        // must produce identical cluster assignments.
        #[test]
        fn prop_kmeans_fit_is_deterministic_given_seed(
            seed in any::<u64>(),
            dimension in 1usize..16,
            num_vectors in 2usize..64,
            k in 1usize..16,
            raw in proptest::collection::vec(-1.0f32..1.0f32, 2usize..(64*16)),
        ) {
            prop_assume!(k <= num_vectors);
            let total_len = num_vectors * dimension;
            prop_assume!(raw.len() >= total_len);

            let mut unit_vectors: Vec<f32> = Vec::from(&raw[..total_len]);
            l2_normalize_in_place(&mut unit_vectors, num_vectors, dimension);

            // Build, fit and query a fresh model each time it is called.
            let labels_from_fresh_model = || {
                let mut model = KMeans::new(dimension, k).unwrap().with_seed(seed);
                model.fit(&unit_vectors, num_vectors).unwrap();
                model.assign_clusters(&unit_vectors, num_vectors)
            };

            let first_labels = labels_from_fresh_model();
            let second_labels = labels_from_fresh_model();
            prop_assert_eq!(first_labels, second_labels);
        }
    }
}