use ndarray::{Array2, ArrayView2, Axis};
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use crate::error::{Error, Result};
use crate::maxsim;
pub use fastkmeans_rs::{kmeans_double_chunked, FastKMeans, KMeansConfig, KMeansError};
#[cfg(feature = "cuda")]
pub use fastkmeans_rs::FastKMeansCuda;
#[cfg(feature = "metal_gpu")]
pub use fastkmeans_rs::FastKMeansMetal;
/// Tuning parameters for [`compute_kmeans`] centroid training.
#[derive(Debug, Clone)]
pub struct ComputeKmeansConfig {
    /// Number of k-means iterations to run (maps to `KMeansConfig::max_iters`).
    pub kmeans_niters: usize,
    /// Cap on training points per centroid, forwarded to
    /// `KMeansConfig::max_points_per_centroid`.
    // NOTE(review): enforcement semantics live inside fastkmeans_rs — confirm there.
    pub max_points_per_centroid: usize,
    /// RNG seed, used both for document sampling and for k-means itself.
    pub seed: u64,
    /// Number of documents to sample for training; `None` uses the heuristic
    /// `1 + 16 * sqrt(120 * num_documents)`, clamped to the corpus size.
    pub n_samples_kmeans: Option<usize>,
    /// Number of centroids (partitions) to train; `None` uses a power-of-two
    /// heuristic derived from the estimated total token count.
    pub num_partitions: Option<usize>,
    /// When `true`, skip any GPU backend and train on the CPU.
    pub force_cpu: bool,
}
impl Default for ComputeKmeansConfig {
    /// Defaults mirror the constants used by [`default_config`]
    /// (4 iterations, 256 points per centroid, seed 42).
    fn default() -> Self {
        Self {
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            seed: 42,
            n_samples_kmeans: None,
            num_partitions: None,
            force_cpu: false,
        }
    }
}
/// Build the crate's standard [`KMeansConfig`] for `num_centroids` clusters.
///
/// The constants here match [`ComputeKmeansConfig`]'s defaults: 4 iterations,
/// seed 42, and at most 256 points per centroid.
pub fn default_config(num_centroids: usize) -> KMeansConfig {
    // Chunk sizes control how fastkmeans_rs tiles the distance computation.
    let (chunk_size_data, chunk_size_centroids) = (51_200, 10_240);
    KMeansConfig {
        k: num_centroids,
        max_iters: 4,
        tol: 1e-8,
        seed: 42,
        max_points_per_centroid: Some(256),
        chunk_size_data,
        chunk_size_centroids,
        verbose: false,
    }
}
fn compute_centroids_cpu(
embeddings: &ArrayView2<f32>,
config: KMeansConfig,
) -> Result<Array2<f32>> {
let result = kmeans_double_chunked(embeddings, &config)
.map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
Ok(result.centroids)
}
/// Train `num_centroids` centroids on `embeddings` (CPU-only build).
///
/// `_force_cpu` is accepted for signature parity with the GPU builds but has
/// no effect here: the CPU path is the only one available.
#[cfg(not(any(feature = "cuda", feature = "metal_gpu")))]
pub fn compute_centroids(
    embeddings: &ArrayView2<f32>,
    num_centroids: usize,
    config: Option<KMeansConfig>,
    _force_cpu: bool,
) -> Result<Array2<f32>> {
    let cfg = match config {
        Some(c) => c,
        None => default_config(num_centroids),
    };
    compute_centroids_cpu(embeddings, cfg)
}
/// Train `num_centroids` centroids on `embeddings` (CUDA build).
///
/// Tries the CUDA backend first; on any CUDA failure (init error, training
/// error, missing centroids, or a panic inside the CUDA bindings) it marks
/// CUDA as broken for the remainder of the process and retrains on the CPU.
///
/// # Errors
/// Returns [`Error::IndexCreation`] only when the CPU fallback itself fails.
#[cfg(feature = "cuda")]
pub fn compute_centroids(
    embeddings: &ArrayView2<f32>,
    num_centroids: usize,
    config: Option<KMeansConfig>,
    force_cpu: bool,
) -> Result<Array2<f32>> {
    let config = config.unwrap_or_else(|| default_config(num_centroids));
    // Skip the GPU entirely when asked to, or when a previous call already
    // observed a broken CUDA runtime.
    if force_cpu || crate::cuda::is_cuda_broken() {
        return compute_centroids_cpu(embeddings, config.clone());
    }
    // catch_cuda_panic converts panics from the CUDA bindings (e.g. a stub
    // library with missing symbols) into an outer Err; CUDA-reported errors
    // travel as the inner Err(String).
    let cuda_result = crate::cuda::catch_cuda_panic(std::panic::AssertUnwindSafe(|| {
        match FastKMeansCuda::with_config(config.clone()) {
            Ok(mut kmeans) => match kmeans.train(embeddings) {
                Ok(()) => kmeans
                    .centroids()
                    .map(|c| c.to_owned())
                    .ok_or_else(|| "CUDA K-means did not produce centroids".to_string()),
                Err(e) => Err(format!("CUDA K-means training failed: {}", e)),
            },
            Err(e) => Err(format!("CUDA K-means init failed: {}", e)),
        }
    }));
    match cuda_result {
        // No panic and CUDA produced centroids: done.
        Ok(Ok(centroids)) => Ok(centroids),
        // CUDA ran but reported an error: remember that and retrain on CPU.
        Ok(Err(e)) => {
            crate::cuda::mark_cuda_broken();
            eprintln!(
                "[next-plaid] CUDA K-means error: {}. Falling back to CPU.",
                e
            );
            compute_centroids_cpu(embeddings, config)
        }
        // The CUDA call panicked (typically a stub/ABI problem): same fallback.
        Err(_) => {
            crate::cuda::mark_cuda_broken();
            eprintln!(
                "[next-plaid] CUDA library found but missing required symbols (stub or incompatible driver). \
K-means will use CPU instead."
            );
            compute_centroids_cpu(embeddings, config)
        }
    }
}
/// Train `num_centroids` centroids on `embeddings` (Metal build, no CUDA).
///
/// Uses the Metal backend unless `force_cpu` is set; unlike the CUDA build,
/// Metal failures are returned to the caller rather than falling back to CPU.
///
/// # Errors
/// Returns [`Error::IndexCreation`] on Metal init/training failure, missing
/// centroids, or CPU training failure when `force_cpu` is set.
#[cfg(all(feature = "metal_gpu", not(feature = "cuda")))]
pub fn compute_centroids(
    embeddings: &ArrayView2<f32>,
    num_centroids: usize,
    config: Option<KMeansConfig>,
    force_cpu: bool,
) -> Result<Array2<f32>> {
    let cfg = config.unwrap_or_else(|| default_config(num_centroids));
    if force_cpu {
        compute_centroids_cpu(embeddings, cfg)
    } else {
        let mut trainer = FastKMeansMetal::with_config(cfg).map_err(|e| {
            Error::IndexCreation(format!("Metal K-means initialization failed: {}", e))
        })?;
        trainer
            .train(embeddings)
            .map_err(|e| Error::IndexCreation(format!("Metal K-means training failed: {}", e)))?;
        match trainer.centroids() {
            Some(c) => Ok(c.to_owned()),
            None => Err(Error::IndexCreation(
                "Metal K-means did not produce centroids".into(),
            )),
        }
    }
}
pub fn compute_centroids_from_documents(
documents: &[Array2<f32>],
num_centroids: usize,
config: Option<KMeansConfig>,
force_cpu: bool,
) -> Result<Array2<f32>> {
if documents.is_empty() {
return Err(Error::IndexCreation("No documents provided".into()));
}
let dim = documents[0].ncols();
let total_tokens: usize = documents.iter().map(|d| d.nrows()).sum();
let mut flat = Array2::<f32>::zeros((total_tokens, dim));
let mut offset = 0;
for doc in documents {
let n = doc.nrows();
flat.slice_mut(ndarray::s![offset..offset + n, ..])
.assign(doc);
offset += n;
}
compute_centroids(&flat.view(), num_centroids, config, force_cpu)
}
/// Assign each embedding row to the index of its nearest centroid.
///
/// Thin wrapper over `maxsim::assign_to_centroids`. (Fixes HTML-entity
/// mojibake in the original: `&cent` had been rendered as `¢`, corrupting
/// `&centroids` and breaking compilation.)
pub fn assign_to_centroids(embeddings: &ArrayView2<f32>, centroids: &Array2<f32>) -> Vec<usize> {
    maxsim::assign_to_centroids(embeddings, &centroids.view())
}
/// Train L2-normalized centroids from a corpus of per-document embeddings.
///
/// Pipeline: deterministically sample documents (seeded by `config.seed`),
/// flatten their token embeddings into one `(tokens, dim)` matrix, resolve
/// the partition count, run k-means on the configured backend, then
/// normalize every centroid row to unit length.
///
/// # Errors
/// Returns [`Error::IndexCreation`] when no documents are given, when the
/// resolved centroid count is zero, or when k-means training fails.
pub fn compute_kmeans(
    documents_embeddings: &[Array2<f32>],
    config: &ComputeKmeansConfig,
) -> Result<Array2<f32>> {
    if documents_embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }
    let num_documents = documents_embeddings.len();
    let dim = documents_embeddings[0].ncols();
    // Heuristic sample size: 1 + 16 * sqrt(120 * N), capped at N.
    let n_samples_kmeans = config.n_samples_kmeans.unwrap_or_else(|| {
        (1.0 + 16.0 * (120.0 * num_documents as f64).sqrt()).min(num_documents as f64) as usize
    });
    // Also clamp an explicitly supplied sample size to the corpus size.
    let n_samples_kmeans = n_samples_kmeans.min(num_documents);
    // Deterministic document sampling: shuffle indices, keep a prefix.
    let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    indices.truncate(n_samples_kmeans);
    let sampled_indices = indices;
    // Flatten the sampled documents into a single (tokens x dim) matrix.
    let total_sample_tokens: usize = sampled_indices
        .iter()
        .map(|&i| documents_embeddings[i].nrows())
        .sum();
    let mut samples_tensor = Array2::<f32>::zeros((total_sample_tokens, dim));
    let mut current_offset = 0;
    for &i in &sampled_indices {
        let tensor_slice = &documents_embeddings[i];
        let length = tensor_slice.nrows();
        samples_tensor
            .slice_mut(ndarray::s![current_offset..current_offset + length, ..])
            .assign(tensor_slice);
        current_offset += length;
    }
    // Heuristic partition count: largest power of two not exceeding
    // 16 * sqrt(estimated total tokens in the full corpus).
    let num_partitions = config.num_partitions.unwrap_or_else(|| {
        let avg_tokens_per_doc = total_sample_tokens as f64 / n_samples_kmeans as f64;
        let estimated_total_tokens = avg_tokens_per_doc * num_documents as f64;
        2usize.pow((16.0 * estimated_total_tokens.sqrt()).log2().floor() as u32)
    });
    // k cannot exceed the number of available training points.
    let actual_k = num_partitions.min(total_sample_tokens);
    if actual_k == 0 {
        return Err(Error::IndexCreation("Cannot compute 0 centroids".into()));
    }
    let kmeans_config = KMeansConfig {
        k: actual_k,
        max_iters: config.kmeans_niters,
        tol: 1e-8,
        seed: config.seed,
        max_points_per_centroid: Some(config.max_points_per_centroid),
        chunk_size_data: 51_200,
        chunk_size_centroids: 10_240,
        verbose: false,
    };
    // --- CPU-only build ---
    #[cfg(not(any(feature = "cuda", feature = "metal_gpu")))]
    let centroids = {
        let mut kmeans = FastKMeans::with_config(kmeans_config);
        kmeans
            .train(&samples_tensor.view())
            .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
        kmeans
            .centroids()
            .ok_or_else(|| Error::IndexCreation("K-means did not produce centroids".into()))?
            .to_owned()
    };
    // --- CUDA build: try GPU, fall back to CPU on any failure ---
    #[cfg(feature = "cuda")]
    let centroids = if config.force_cpu || crate::cuda::is_cuda_broken() {
        let result = kmeans_double_chunked(&samples_tensor.view(), &kmeans_config)
            .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
        result.centroids
    } else {
        let samples_view = samples_tensor.view();
        // Panics inside the CUDA bindings surface as the outer Err.
        let cuda_result = crate::cuda::catch_cuda_panic(std::panic::AssertUnwindSafe(|| {
            match FastKMeansCuda::with_config(kmeans_config.clone()) {
                Ok(mut kmeans) => match kmeans.train(&samples_view) {
                    Ok(()) => kmeans.centroids().map(|c| c.to_owned()),
                    Err(_) => None,
                },
                Err(_) => None,
            }
        }));
        match cuda_result {
            Ok(Some(c)) => c,
            Ok(None) | Err(_) => {
                crate::cuda::mark_cuda_broken();
                if cuda_result.is_err() {
                    eprintln!("[next-plaid] CUDA library found but missing required symbols (stub or incompatible driver). \
K-means will use CPU instead.");
                } else {
                    eprintln!(
                        "[next-plaid] CUDA K-means did not produce centroids. Falling back to CPU."
                    );
                }
                let result = kmeans_double_chunked(&samples_tensor.view(), &kmeans_config)
                    .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
                result.centroids
            }
        }
    };
    // --- Metal build (no CUDA): GPU unless force_cpu ---
    #[cfg(all(feature = "metal_gpu", not(feature = "cuda")))]
    let centroids = if config.force_cpu {
        let mut kmeans = FastKMeans::with_config(kmeans_config);
        kmeans
            .train(&samples_tensor.view())
            .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
        kmeans
            .centroids()
            .ok_or_else(|| Error::IndexCreation("K-means did not produce centroids".into()))?
            .to_owned()
    } else {
        let mut kmeans = FastKMeansMetal::with_config(kmeans_config).map_err(|e| {
            Error::IndexCreation(format!("Metal K-means initialization failed: {}", e))
        })?;
        kmeans
            .train(&samples_tensor.view())
            .map_err(|e| Error::IndexCreation(format!("Metal K-means training failed: {}", e)))?;
        kmeans
            .centroids()
            .ok_or_else(|| Error::IndexCreation("Metal K-means did not produce centroids".into()))?
            .to_owned()
    };
    // L2-normalize each centroid row in place. `centroids` is already owned,
    // so no clone is needed (the previous version copied the whole matrix).
    let mut centroids = centroids;
    for mut row in centroids.axis_iter_mut(Axis(0)) {
        // Guard against a zero norm to avoid dividing by zero.
        let norm = row.dot(&row).sqrt().max(1e-12);
        row /= norm;
    }
    Ok(centroids)
}
/// Estimate the partition count [`compute_kmeans`] would choose for this
/// corpus when `num_partitions` is left unset, without training anything.
///
/// Returns 0 for an empty corpus.
pub fn estimate_num_partitions(documents_embeddings: &[Array2<f32>]) -> usize {
    let num_documents = documents_embeddings.len();
    if num_documents == 0 {
        return 0;
    }
    let docs_f = num_documents as f64;
    // Same sampling heuristic as compute_kmeans: 1 + 16*sqrt(120*N), capped at N.
    let n_samples = (1.0 + 16.0 * (120.0 * docs_f).sqrt()).min(docs_f) as usize;
    let sampled_count = n_samples.min(num_documents);
    let total_tokens: usize = documents_embeddings.iter().map(|d| d.nrows()).sum();
    let avg_tokens_per_doc = total_tokens as f64 / docs_f;
    let estimated_total_tokens = avg_tokens_per_doc * docs_f;
    // Largest power of two not exceeding 16 * sqrt(estimated tokens) ...
    let k = 2usize.pow((16.0 * estimated_total_tokens.sqrt()).log2().floor() as u32);
    // ... capped by how many tokens the sampled documents would supply.
    let sample_tokens = (avg_tokens_per_doc * sampled_count as f64) as usize;
    k.min(sample_tokens)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Centroid training on a flat embedding matrix yields the requested shape.
    #[test]
    fn test_compute_centroids() {
        let data: Array2<f32> = Array2::random((500, 32), Uniform::new(-1.0f32, 1.0));
        let centroids = compute_centroids(&data.view(), 10, None, false).unwrap();
        assert_eq!(centroids.nrows(), 10);
        assert_eq!(centroids.ncols(), 32);
    }

    /// Document-list entry point flattens and trains to the requested shape.
    #[test]
    fn test_compute_centroids_from_documents() {
        let docs: Vec<Array2<f32>> = (0..10)
            .map(|_| Array2::random((50, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();
        let centroids = compute_centroids_from_documents(&docs, 8, None, false).unwrap();
        assert_eq!(centroids.nrows(), 8);
        assert_eq!(centroids.ncols(), 16);
    }

    /// Every token gets a valid centroid index.
    // Fixes HTML-entity mojibake in the original: `&centroids` had been
    // corrupted to `¢roids`.
    #[test]
    fn test_assign_to_centroids() {
        let data: Array2<f32> = Array2::random((100, 16), Uniform::new(-1.0f32, 1.0));
        let centroids = compute_centroids(&data.view(), 5, None, false).unwrap();
        let assignments = assign_to_centroids(&data.view(), &centroids);
        assert_eq!(assignments.len(), 100);
        for &label in &assignments {
            assert!(label < 5);
        }
    }

    /// compute_kmeans returns unit-norm centroids of the right width.
    #[test]
    fn test_compute_kmeans() {
        let docs: Vec<Array2<f32>> = (0..100)
            .map(|_| Array2::random((50, 32), Uniform::new(-1.0f32, 1.0)))
            .collect();
        let config = ComputeKmeansConfig::default();
        let centroids = compute_kmeans(&docs, &config).unwrap();
        assert!(centroids.nrows() > 0);
        assert_eq!(centroids.ncols(), 32);
        for row in centroids.axis_iter(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "Centroid not normalized: norm={}",
                norm
            );
        }
    }

    /// An explicit `num_partitions` overrides the heuristic exactly.
    #[test]
    fn test_compute_kmeans_with_explicit_k() {
        let docs: Vec<Array2<f32>> = (0..50)
            .map(|_| Array2::random((30, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();
        let config = ComputeKmeansConfig {
            num_partitions: Some(16),
            ..Default::default()
        };
        let centroids = compute_kmeans(&docs, &config).unwrap();
        assert_eq!(centroids.nrows(), 16);
        assert_eq!(centroids.ncols(), 16);
    }

    /// The partition estimate grows with corpus size.
    #[test]
    fn test_estimate_num_partitions() {
        let small_docs: Vec<Array2<f32>> = (0..10)
            .map(|_| Array2::random((20, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();
        let k_small = estimate_num_partitions(&small_docs);
        assert!(k_small > 0);
        let large_docs: Vec<Array2<f32>> = (0..1000)
            .map(|_| Array2::random((50, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();
        let k_large = estimate_num_partitions(&large_docs);
        assert!(
            k_large > k_small,
            "Larger dataset should have more partitions"
        );
    }
}