use core::f32;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::ops::{AddAssign, DivAssign};
use std::sync::Arc;
use std::vec;
use std::{collections::HashMap, ops::MulAssign};

use arrow_array::{
    cast::AsArray,
    types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, UInt8Type},
    Array, ArrayRef, FixedSizeListArray, Float32Array, PrimitiveArray, UInt32Array,
};
use arrow_array::{ArrowNumericType, UInt8Array};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{ArrowError, DataType};
use bitvec::prelude::*;
use lance_arrow::FixedSizeListArrayExt;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_linalg::distance::hamming::{hamming, hamming_distance_batch};
use lance_linalg::distance::{dot_distance_batch, DistanceType, Normalize};
use lance_linalg::kernels::{argmin_value_float, argmin_value_float_with_bias};
use log::{info, warn};
use num_traits::One;
use num_traits::{AsPrimitive, Float, FromPrimitive, Num, Zero};
use rand::prelude::*;
use rayon::prelude::*;
use snafu::location;
use {
    lance_linalg::distance::{
        l2::{l2_distance_batch, L2},
        Dot,
    },
    lance_linalg::kernels::argmin_value,
};

use crate::vector::utils::SimpleIndex;
use crate::{Error, Result};

/// Centroid initialization strategy for K-means training.
#[derive(Debug, PartialEq)]
pub enum KMeanInit {
    /// Pick `k` random vectors from the training data as the initial centroids.
    Random,
    /// Start from an existing set of centroids, e.g. from a previous training run.
    Incremental(Arc<FixedSizeListArray>),
}

/// Parameters of K-means training.
#[derive(Debug)]
pub struct KMeansParams {
    /// Maximum number of iterations per run.
    pub max_iters: u32,

    /// Relative change of the loss below which training is considered converged.
    pub tolerance: f64,

    /// Number of times to re-run K-means; the run with the lowest loss is kept.
    pub redos: usize,

    /// How the centroids are initialized.
    pub init: KMeanInit,

    /// Distance type used to assign vectors to centroids.
    pub distance_type: DistanceType,

    /// Weight of the cluster-size balance penalty added to the assignment cost.
    /// `0.0` disables balancing.
    pub balance_factor: f32,

    /// Branching factor used when `k` is large enough to trigger hierarchical
    /// (recursive) clustering; values `<= 1` disable the hierarchical path.
    pub hierarchical_k: usize,
}

impl Default for KMeansParams {
    fn default() -> Self {
        Self {
            max_iters: 50,
            tolerance: 1e-4,
            redos: 1,
            init: KMeanInit::Random,
            distance_type: DistanceType::L2,
            balance_factor: 0.0,
            hierarchical_k: 16,
        }
    }
}

impl KMeansParams {
    pub fn new(
        centroids: Option<Arc<FixedSizeListArray>>,
        max_iters: u32,
        redos: usize,
        distance_type: DistanceType,
    ) -> Self {
        let init = match centroids {
            Some(centroids) => KMeanInit::Incremental(centroids),
            None => KMeanInit::Random,
        };
        Self {
            max_iters,
            redos,
            distance_type,
            init,
            ..Default::default()
        }
    }

    pub fn with_balance_factor(mut self, balance_factor: f32) -> Self {
        self.balance_factor = balance_factor;
        self
    }

    pub fn with_hierarchical_k(mut self, hierarchical_k: usize) -> Self {
        self.hierarchical_k = hierarchical_k;
        self
    }
}

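// Illustrative only (not part of the original API surface): parameters are
// typically built with the constructor plus the builder-style helpers above.
// The concrete numbers below are placeholders, not recommendations.
//
//     let params = KMeansParams::new(None, 50, 1, DistanceType::L2)
//         .with_balance_factor(0.1)
//         .with_hierarchical_k(32);
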
/// Randomly initialize a K-means model by sampling `k` vectors from `data`.
fn kmeans_random_init<T: ArrowPrimitiveType>(
    data: &[T::Native],
    dimension: usize,
    k: usize,
    mut rng: impl Rng,
    distance_type: DistanceType,
) -> KMeans {
    assert!(data.len() >= k * dimension);
    let chosen = (0..data.len() / dimension).choose_multiple(&mut rng, k);
    let centroids = PrimitiveArray::<T>::from_iter_values(
        chosen
            .iter()
            .flat_map(|&i| data[i * dimension..(i + 1) * dimension].iter())
            .copied(),
    );
    KMeans {
        centroids: Arc::new(centroids),
        dimension,
        distance_type,
        loss: f64::MAX,
    }
}

/// Re-seed empty clusters by splitting populated ones.
///
/// For every empty cluster, a donor cluster is picked with probability roughly
/// proportional to its size; half of the donor's count is handed to the empty
/// cluster and the two centroids are nudged apart by a small epsilon so they can
/// diverge in the next iteration.
fn split_clusters<T: Float + MulAssign>(
    n: usize,
    cnts: &mut [usize],
    centroids: &mut [T],
    dim: usize,
) {
    let eps = T::from(1.0 / 1024.0).unwrap();
    let mut rng = SmallRng::from_os_rng();
    for i in 0..cnts.len() {
        if cnts[i] == 0 {
            // Pick the donor cluster `j`.
            let mut j = 0;
            loop {
                let p = (cnts[j] as f32 - 1.0) / (n - cnts.len()) as f32;
                if rng.random::<f32>() < p {
                    break;
                }
                j += 1;
                j %= cnts.len();
            }

            cnts[i] = cnts[j] / 2;
            cnts[j] -= cnts[i];
            // Perturb the two centroids in opposite directions so they separate.
            for k in 0..dim {
                if k % 2 == 0 {
                    centroids[i * dim + k] = centroids[j * dim + k] * (T::one() + eps);
                    centroids[j * dim + k] *= T::one() - eps;
                } else {
                    centroids[i * dim + k] = centroids[j * dim + k] * (T::one() - eps);
                    centroids[j * dim + k] *= T::one() + eps;
                }
            }
        }
    }
}

/// Recompute `cluster_sizes` from the membership assignments.
///
/// Returns an adjusted balance factor derived from the largest cluster (the gap
/// between its radius and its mean distance, normalized by the total number of
/// vectors), which is used to cap the balance penalty in the next iteration.
fn compute_cluster_sizes(
    membership: &[Option<u32>],
    radius: &[f32],
    losses: &[f64],
    cluster_sizes: &mut [usize],
) -> f32 {
    cluster_sizes.fill(0);
    let mut max_cluster_id = 0;
    let mut max_cluster_size = 0;
    membership.iter().for_each(|cluster_id| {
        if let Some(cluster_id) = cluster_id {
            let cluster_id = *cluster_id as usize;
            cluster_sizes[cluster_id] += 1;
            if cluster_sizes[cluster_id] > max_cluster_size {
                max_cluster_size = cluster_sizes[cluster_id];
                max_cluster_id = cluster_id;
            }
        }
    });

    (radius[max_cluster_id] - losses[max_cluster_id] as f32 / cluster_sizes[max_cluster_id] as f32)
        / membership.len() as f32
}

/// Balance penalty `balance_factor * (sum(size_i^2) - n^2 / k)`, which is zero
/// when all clusters have the same size and grows as they become unbalanced.
fn compute_balance_loss(cluster_sizes: &[usize], n: usize, balance_factor: f32) -> f32 {
    let size_loss = cluster_sizes.iter().map(|size| size.pow(2)).sum::<usize>() as f32;
    balance_factor * (size_loss - n.pow(2) as f32 / cluster_sizes.len() as f32)
}

/// The operations a K-means implementation must provide for element type `T`.
pub trait KMeansAlgo<T: Num> {
    /// Assign each vector in `data` to its nearest centroid, and accumulate per
    /// cluster the maximum distance (radius) and the sum of distances (loss).
    fn compute_membership_and_loss(
        centroids: &[T],
        data: &[T],
        dimension: usize,
        distance_type: DistanceType,
        balance_factor: f32,
        cluster_sizes: Option<&[usize]>,
        index: Option<&SimpleIndex>,
    ) -> (Vec<Option<u32>>, Vec<f32>, Vec<f64>) {
        let (membership, dists) = Self::compute_membership_and_dist(
            centroids,
            data,
            dimension,
            distance_type,
            balance_factor,
            cluster_sizes,
            index,
        );

        let k = centroids.len() / dimension;
        let mut cluster_radius = vec![0.0; k];
        let mut losses = vec![0.0; k];
        for (cluster_id, dist) in membership.iter().zip(dists.iter()) {
            if let (Some(cluster_id), Some(dist)) = (cluster_id, dist) {
                let cluster_id = *cluster_id as usize;
                cluster_radius[cluster_id] = cluster_radius[cluster_id].max(*dist);
                losses[cluster_id] += *dist as f64;
            }
        }

        (membership, cluster_radius, losses)
    }

    /// Assign each vector in `data` to its nearest centroid, returning the
    /// cluster id and distance per vector (`None` when no assignment is
    /// possible, e.g. for vectors containing NaN).
    fn compute_membership_and_dist(
        centroids: &[T],
        data: &[T],
        dimension: usize,
        distance_type: DistanceType,
        balance_factor: f32,
        cluster_sizes: Option<&[usize]>,
        index: Option<&SimpleIndex>,
    ) -> (Vec<Option<u32>>, Vec<Option<f32>>);

    /// Recompute the centroids from the membership assignments and build a new
    /// [`KMeans`] model carrying the given loss.
    fn to_kmeans(
        data: &[T],
        dimension: usize,
        k: usize,
        membership: &[Option<u32>],
        cluster_sizes: &mut [usize],
        distance_type: DistanceType,
        loss: f64,
    ) -> KMeans;
}

/// K-means over floating point vectors (f16 / f32 / f64).
pub struct KMeansAlgoFloat<T: ArrowNumericType>
where
    T::Native: Float + Num,
{
    phantom_data: std::marker::PhantomData<T>,
}

impl<T: ArrowNumericType> KMeansAlgo<T::Native> for KMeansAlgoFloat<T>
where
    T::Native: Float + Dot + L2 + MulAssign + DivAssign + AddAssign + FromPrimitive + Sync,
    PrimitiveArray<T>: From<Vec<T::Native>>,
{
    fn compute_membership_and_dist(
        centroids: &[T::Native],
        data: &[T::Native],
        dimension: usize,
        distance_type: DistanceType,
        balance_factor: f32,
        cluster_sizes: Option<&[usize]>,
        index: Option<&SimpleIndex>,
    ) -> (Vec<Option<u32>>, Vec<Option<f32>>) {
        let cluster_and_dists = match index {
            // If an index over the centroids is available, search it instead of
            // brute-forcing all centroids.
            Some(index) => data
                .par_chunks(dimension)
                .map(|vec| {
                    let query = PrimitiveArray::<T>::from_iter_values(vec.iter().copied());
                    index
                        .search(Arc::new(query))
                        .map(|(id, dist)| Some((id, dist)))
                        .unwrap()
                })
                .collect::<Vec<_>>(),
            // Otherwise scan all centroids, optionally biasing each distance by
            // `balance_factor * cluster_size` to discourage oversized clusters.
            None => match distance_type {
                DistanceType::L2 => data
                    .par_chunks(dimension)
                    .map(|vec| {
                        argmin_value_float_with_bias(
                            l2_distance_batch(vec, centroids, dimension),
                            cluster_sizes
                                .map(|size| size.iter().map(|size| balance_factor * *size as f32)),
                        )
                    })
                    .collect::<Vec<_>>(),
                DistanceType::Dot => data
                    .par_chunks(dimension)
                    .map(|vec| {
                        argmin_value_float_with_bias(
                            dot_distance_batch(vec, centroids, dimension),
                            cluster_sizes
                                .map(|size| size.iter().map(|size| balance_factor * *size as f32)),
                        )
                    })
                    .collect::<Vec<_>>(),
                _ => {
                    panic!(
                        "KMeans::find_partitions: {} is not supported",
                        distance_type
                    );
                }
            },
        };

        cluster_and_dists.into_iter().map(Option::unzip).unzip()
    }

    fn to_kmeans(
        data: &[T::Native],
        dimension: usize,
        k: usize,
        membership: &[Option<u32>],
        cluster_sizes: &mut [usize],
        distance_type: DistanceType,
        loss: f64,
    ) -> KMeans {
        let mut centroids = vec![T::Native::zero(); k * dimension];

        // Accumulate vector sums per centroid, sharded by cluster range so each
        // rayon worker writes to a disjoint slice of `centroids`.
        let mut num_cpus = get_num_compute_intensive_cpus();
        if k < num_cpus || k < 16 {
            num_cpus = 1;
        }
        let chunk_size = k / num_cpus;

        centroids
            .par_chunks_mut(dimension * chunk_size)
            .enumerate()
            .with_max_len(1)
            .for_each(|(i, centroids)| {
                let start = i * chunk_size;
                let end = ((i + 1) * chunk_size).min(k);
                data.chunks(dimension)
                    .zip(membership.iter())
                    .filter_map(|(vector, cluster_id)| {
                        cluster_id.map(|cluster_id| (vector, cluster_id as usize))
                    })
                    .for_each(|(vector, cluster_id)| {
                        if start <= cluster_id && cluster_id < end {
                            let local_id = cluster_id - start;
                            let centroid =
                                &mut centroids[local_id * dimension..(local_id + 1) * dimension];
                            centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v);
                        }
                    });
            });

        // Divide the sums by the cluster sizes to obtain the means.
        centroids
            .par_chunks_mut(dimension)
            .zip(cluster_sizes.par_iter())
            .for_each(|(centroid, &cnt)| {
                if cnt > 0 {
                    let norm = T::Native::one() / T::Native::from_usize(cnt).unwrap();
                    centroid.iter_mut().for_each(|v| *v *= norm);
                }
            });

        let empty_clusters = cluster_sizes.iter().filter(|&cnt| *cnt == 0).count();
        if empty_clusters as f32 / k as f32 > 0.1 {
            if data.len() / dimension < k * 256 {
                warn!("KMeans: more than 10% of clusters are empty: {} of {}.\nHelp: this could mean your dataset \
                    is too small to have a meaningful index ({} < {}) or has many duplicate vectors.",
                    empty_clusters, k, data.len() / dimension, k * 256);
            } else {
                warn!("KMeans: more than 10% of clusters are empty: {} of {}.\nHelp: this could mean your dataset \
                    has many duplicate vectors.",
                    empty_clusters, k);
            }
        }

        split_clusters(
            data.len() / dimension,
            cluster_sizes,
            &mut centroids,
            dimension,
        );

        KMeans {
            centroids: Arc::new(PrimitiveArray::<T>::from(centroids)),
            dimension,
            distance_type,
            loss,
        }
    }
}

/// K-modes over bit-packed binary vectors (u8) with Hamming distance.
struct KModeAlgo {}

impl KMeansAlgo<u8> for KModeAlgo {
    fn compute_membership_and_dist(
        centroids: &[u8],
        data: &[u8],
        dimension: usize,
        distance_type: DistanceType,
        balance_factor: f32,
        cluster_sizes: Option<&[usize]>,
        _: Option<&SimpleIndex>,
    ) -> (Vec<Option<u32>>, Vec<Option<f32>>) {
        assert_eq!(distance_type, DistanceType::Hamming);
        let cluster_and_dists = data
            .par_chunks(dimension)
            .map(|vec| {
                argmin_value(
                    centroids
                        .chunks_exact(dimension)
                        .enumerate()
                        .map(|(id, c)| {
                            hamming(vec, c)
                                + balance_factor
                                    * cluster_sizes.map(|sizes| sizes[id] as f32).unwrap_or(0.0)
                        }),
                )
            })
            .collect::<Vec<_>>();
        cluster_and_dists.into_iter().map(Option::unzip).unzip()
    }

    fn to_kmeans(
        data: &[u8],
        dimension: usize,
        k: usize,
        membership: &[Option<u32>],
        _cluster_sizes: &mut [usize],
        distance_type: DistanceType,
        loss: f64,
    ) -> KMeans {
        assert_eq!(distance_type, DistanceType::Hamming);

        let mut clusters = HashMap::<u32, Vec<usize>>::new();
        membership.iter().enumerate().for_each(|(i, part_id)| {
            if let Some(part_id) = part_id {
                clusters.entry(*part_id).or_default().push(i);
            }
        });
        let centroids = (0..k as u32)
            .into_par_iter()
            .flat_map(|part_id| {
                if let Some(vecs) = clusters.get(&part_id) {
                    // Majority vote per bit: count how many vectors in the
                    // cluster have each bit set, then set the centroid bit if
                    // that count exceeds half the cluster size.
                    let mut ones = vec![0_u32; dimension * 8];
                    let cnt = vecs.len() as u32;
                    vecs.iter().for_each(|&i| {
                        let vec = &data[i * dimension..(i + 1) * dimension];
                        ones.iter_mut()
                            .zip(vec.view_bits::<Lsb0>())
                            .for_each(|(c, v)| {
                                if *v.as_ref() {
                                    *c += 1;
                                }
                            });
                    });

                    let bits = ones.iter().map(|&c| c * 2 > cnt).collect::<BitVec<u8>>();
                    bits.as_raw_slice()
                        .iter()
                        .copied()
                        .map(Some)
                        .collect::<Vec<_>>()
                } else {
                    vec![None; dimension]
                }
            })
            .collect::<Vec<_>>();

        KMeans {
            centroids: Arc::new(UInt8Array::from(centroids)),
            dimension,
            distance_type,
            loss,
        }
    }
}

/// A trained K-means model.
#[derive(Debug, Clone)]
pub struct KMeans {
    /// Flattened centroids: `k * dimension` values.
    pub centroids: ArrayRef,

    /// Vector dimension.
    pub dimension: usize,

    /// Distance type used to train the model.
    pub distance_type: DistanceType,

    /// Total training loss of the model.
    pub loss: f64,
}

impl KMeans {
    fn empty(dimension: usize, distance_type: DistanceType) -> Self {
        Self {
            centroids: arrow_array::array::new_empty_array(&DataType::Float32),
            dimension,
            distance_type,
            loss: f64::MAX,
        }
    }

    /// Create a K-means model from an existing set of centroids.
    pub fn with_centroids(
        centroids: ArrayRef,
        dimension: usize,
        distance_type: DistanceType,
        loss: f64,
    ) -> Self {
        assert!(matches!(
            centroids.data_type(),
            DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8
        ));
        Self {
            centroids,
            dimension,
            distance_type,
            loss,
        }
    }

    /// Randomly initialize the centroids by picking `k` vectors from `data`.
    fn init_random<T: ArrowPrimitiveType>(
        data: &[T::Native],
        dimension: usize,
        k: usize,
        rng: impl Rng,
        distance_type: DistanceType,
    ) -> Self {
        kmeans_random_init::<T>(data, dimension, k, rng, distance_type)
    }

    /// Train a `k`-means model over `data` with L2 distance and default parameters.
    pub fn new(data: &FixedSizeListArray, k: usize, max_iters: u32) -> arrow::error::Result<Self> {
        let params = KMeansParams {
            max_iters,
            distance_type: DistanceType::L2,
            ..Default::default()
        };
        Self::new_with_params(data, k, &params)
    }

    /// Run (possibly multiple restarts of) Lloyd-style K-means for `k` centroids
    /// over `data`, keeping the run with the lowest loss.
    fn train_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
        data: &FixedSizeListArray,
        k: usize,
        params: &KMeansParams,
    ) -> arrow::error::Result<Self>
    where
        T::Native: Num,
    {
        // Use at most 512 vectors per centroid for training.
        let data = if data.len() >= k * 512 {
            data.slice(0, k * 512)
        } else {
            data.clone()
        };

        let n = data.len();
        let dimension = data.value_length() as usize;

        let data =
            data.values()
                .as_primitive_opt::<T>()
                .ok_or(ArrowError::InvalidArgumentError(format!(
                    "KMeans: data must be {}, got: {}",
                    T::DATA_TYPE,
                    data.value_type()
                )))?;

        let mut best_kmeans = Self::empty(dimension, params.distance_type);
        let mut cluster_sizes = vec![0; k];
        let mut adjusted_balance_factor = f32::MAX;

        let rng = SmallRng::from_os_rng();
        for redo in 1..=params.redos {
            let mut kmeans: Self = match &params.init {
                KMeanInit::Random => Self::init_random::<T>(
                    data.values(),
                    dimension,
                    k,
                    rng.clone(),
                    params.distance_type,
                ),
                KMeanInit::Incremental(centroids) => Self::with_centroids(
                    centroids.values().clone(),
                    dimension,
                    params.distance_type,
                    f64::MAX,
                ),
            };

            let mut loss = f64::MAX;
            for i in 1..=params.max_iters {
                if i % 10 == 0 {
                    info!(
                        "KMeans training: iteration {} / {}, redo={}",
                        i, params.max_iters, redo
                    );
                }

                let index = SimpleIndex::may_train_index(
                    kmeans.centroids.clone(),
                    kmeans.dimension,
                    kmeans.distance_type,
                )?;

                // Assignment step.
                let balance_factor = adjusted_balance_factor.min(params.balance_factor);
                let (membership, radius, losses) = Algo::compute_membership_and_loss(
                    kmeans.centroids.as_primitive::<T>().values(),
                    data.values(),
                    dimension,
                    params.distance_type,
                    balance_factor,
                    Some(&cluster_sizes),
                    index.as_ref(),
                );

                adjusted_balance_factor =
                    compute_cluster_sizes(&membership, &radius, &losses, &mut cluster_sizes);
                let balance_loss = compute_balance_loss(&cluster_sizes, n, balance_factor);
                let last_loss = losses.iter().sum::<f64>() + balance_loss as f64;

                // Update step.
                kmeans = Algo::to_kmeans(
                    data.values(),
                    dimension,
                    k,
                    &membership,
                    &mut cluster_sizes,
                    params.distance_type,
                    last_loss,
                );
                if (loss - last_loss).abs() < params.tolerance * last_loss {
                    info!(
                        "KMeans training: converged at iteration {} / {}, redo={}, loss={}, last_loss={}, loss_diff={}",
                        i, params.max_iters, redo, loss, last_loss, (loss - last_loss).abs() / last_loss
                    );
                    break;
                }
                loss = last_loss;
            }
            if kmeans.loss < best_kmeans.loss {
                best_kmeans = kmeans;
            }
        }

        Ok(best_kmeans)
    }

    /// Gather the vectors at `indices` into a new `FixedSizeListArray`.
    fn create_array_from_indices<T: ArrowNumericType>(
        indices: &[usize],
        data_values: &[T::Native],
        dimension: usize,
    ) -> arrow::error::Result<FixedSizeListArray>
    where
        T::Native: Clone,
        PrimitiveArray<T>: From<Vec<T::Native>>,
    {
        let mut subset_data = Vec::with_capacity(indices.len() * dimension);
        for &idx in indices {
            let start = idx * dimension;
            let end = start + dimension;
            subset_data.extend_from_slice(&data_values[start..end]);
        }
        let array = PrimitiveArray::<T>::from(subset_data);
        FixedSizeListArray::try_new_from_values(array, dimension as i32)
    }

    /// Hierarchical (recursive) K-means: train a coarse model with at most
    /// `params.hierarchical_k` clusters, then repeatedly split the largest
    /// cluster until `target_k` clusters are obtained.
    fn train_hierarchical_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
        data: &FixedSizeListArray,
        target_k: usize,
        params: &KMeansParams,
    ) -> arrow::error::Result<Self>
    where
        T::Native: Num,
        PrimitiveArray<T>: From<Vec<T::Native>>,
    {
        // A candidate cluster in the max-heap. `finalized` marks clusters that
        // cannot be split any further.
        #[derive(Clone, Debug)]
        struct Cluster<N> {
            id: usize,
            indices: Vec<usize>,
            centroid: Vec<N>,
            finalized: bool,
        }

        impl<N> Eq for Cluster<N> {}

        impl<N> PartialEq for Cluster<N> {
            fn eq(&self, other: &Self) -> bool {
                self.indices.len() == other.indices.len()
            }
        }

        impl<N> Ord for Cluster<N> {
            fn cmp(&self, other: &Self) -> Ordering {
                // Splittable (non-finalized) clusters rank above finalized ones;
                // ties are broken by size so the heap pops the largest cluster.
                match (self.finalized, other.finalized) {
                    (false, true) => Ordering::Greater,
                    (true, false) => Ordering::Less,
                    _ => self.indices.len().cmp(&other.indices.len()),
                }
            }
        }

        impl<N> PartialOrd for Cluster<N> {
            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                Some(self.cmp(other))
            }
        }

        let n = data.len();
        let dimension = data.value_length() as usize;

        let data_values = data
            .values()
            .as_primitive_opt::<T>()
            .ok_or(ArrowError::InvalidArgumentError(format!(
                "KMeans: data must be {}, got: {}",
                T::DATA_TYPE,
                data.value_type()
            )))?
            .values();

        let initial_k = params.hierarchical_k.min(target_k).min(n);
        info!(
            "Hierarchical clustering: initial k={}, target k={}",
            initial_k, target_k
        );

        // Train the coarse top-level model and assign every vector to one of
        // its clusters.
        let initial_kmeans = Self::train_kmeans::<T, Algo>(data, initial_k, params)?;

        let (membership, _, _) = Algo::compute_membership_and_loss(
            initial_kmeans.centroids.as_primitive::<T>().values(),
            data_values,
            dimension,
            params.distance_type,
            0.0,
            None,
            None,
        );

        let mut heap: BinaryHeap<Cluster<T::Native>> = BinaryHeap::new();
        let mut next_cluster_id = 0;
        let initial_centroids = initial_kmeans.centroids.as_primitive::<T>().values();

        for i in 0..initial_k {
            let mut cluster_indices = Vec::new();
            for (idx, &cluster_id) in membership.iter().enumerate() {
                if let Some(cid) = cluster_id {
                    if cid as usize == i {
                        cluster_indices.push(idx);
                    }
                }
            }

            if !cluster_indices.is_empty() {
                let centroid_start = i * dimension;
                let centroid_end = centroid_start + dimension;
                let centroid = initial_centroids[centroid_start..centroid_end].to_vec();

                heap.push(Cluster {
                    id: next_cluster_id,
                    indices: cluster_indices,
                    centroid,
                    finalized: false,
                });
                next_cluster_id += 1;
            }
        }

        // Keep splitting the largest cluster until we reach `target_k` clusters
        // or nothing can be split any further.
        while heap.len() < target_k {
            let mut largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError(
                "No cluster can be further split".to_string(),
            ))?;

            if largest_cluster.finalized {
                log::warn!("Cluster {} is already finalized, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len() + 1);
                heap.push(largest_cluster);
                break;
            }

            if largest_cluster.indices.len() <= 1 {
                log::warn!("Cluster {} has only 1 point, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len() + 1);
                heap.push(largest_cluster);
                break;
            }

            let cluster_size = largest_cluster.indices.len();
            log::debug!(
                "Splitting cluster {} with {} points (current total clusters: {})",
                largest_cluster.id,
                cluster_size,
                heap.len() + 1
            );

            // Decide how many sub-clusters to split this cluster into.
            let remaining_k = target_k - heap.len();
            let cluster_k = if cluster_size <= params.hierarchical_k {
                2.min(remaining_k).min(cluster_size)
            } else {
                let suggested_k = cluster_size / params.hierarchical_k;
                suggested_k
                    .min(remaining_k)
                    .min(params.hierarchical_k)
                    .max(2)
            };

            let sub_data = Self::create_array_from_indices::<T>(
                &largest_cluster.indices,
                data_values,
                dimension,
            )?;

            let sub_kmeans = Self::train_kmeans::<T, Algo>(&sub_data, cluster_k, params)?;

            let sub_data = sub_data.values().as_primitive::<T>().values();
            let (sub_membership, _, _) = Algo::compute_membership_and_loss(
                sub_kmeans.centroids.as_primitive::<T>().values(),
                sub_data,
                dimension,
                params.distance_type,
                0.0,
                None,
                None,
            );

            let approx_cluster_capacity = if cluster_k > 0 {
                largest_cluster.indices.len().div_ceil(cluster_k)
            } else {
                0
            };
            let mut cluster_assignments: Vec<Vec<usize>> = (0..cluster_k)
                .map(|_| Vec::with_capacity(approx_cluster_capacity))
                .collect();

            // Map the local sub-cluster assignments back to the original row
            // indices, and detect the degenerate case where every point ends up
            // in the same sub-cluster.
            let mut first_sid: Option<u32> = None;
            let mut all_same = true;
            for (local_idx, &membership) in sub_membership.iter().enumerate() {
                let Some(sub_cluster_id) = membership else {
                    continue;
                };

                if let Some(first) = first_sid {
                    if sub_cluster_id != first {
                        all_same = false;
                    }
                } else {
                    first_sid = Some(sub_cluster_id);
                }

                let sub_cluster_id = sub_cluster_id as usize;
                if let Some(indices) = cluster_assignments.get_mut(sub_cluster_id) {
                    indices.push(largest_cluster.indices[local_idx]);
                } else {
                    all_same = false;
                }
            }

            if all_same {
                // The split did not separate anything; mark the cluster as
                // finalized so it is not picked for splitting again.
                largest_cluster.finalized = true;
                heap.push(largest_cluster);
                continue;
            }

            let sub_centroids = sub_kmeans.centroids.as_primitive::<T>().values();
            for (i, new_cluster_indices) in cluster_assignments.into_iter().enumerate() {
                if new_cluster_indices.is_empty() {
                    continue;
                }

                let centroid_start = i * dimension;
                let centroid_end = centroid_start + dimension;
                let centroid = sub_centroids[centroid_start..centroid_end].to_vec();

                heap.push(Cluster {
                    id: next_cluster_id,
                    indices: new_cluster_indices,
                    centroid,
                    finalized: false,
                });
                next_cluster_id += 1;
            }

            log::debug!(
                "Split complete: now have {} clusters (target: {})",
                heap.len(),
                target_k
            );
        }

        debug_assert_eq!(heap.len(), target_k);

        // Assemble the final centroid matrix, ordered by cluster creation id.
        let mut all_clusters: Vec<Cluster<T::Native>> = heap.into_vec();
        all_clusters.sort_by_key(|c| c.id);

        let flat_centroids: Vec<T::Native> =
            all_clusters.into_iter().flat_map(|c| c.centroid).collect();
        let centroids_array = PrimitiveArray::<T>::from(flat_centroids);

        Ok(Self {
            centroids: Arc::new(centroids_array),
            dimension,
            distance_type: params.distance_type,
            loss: 0.0,
        })
    }

    /// Train a K-means model with `k` clusters and the given parameters,
    /// dispatching on the element type and distance type. Large `k` (> 256)
    /// uses the hierarchical clustering path when it is enabled.
    pub fn new_with_params(
        data: &FixedSizeListArray,
        k: usize,
        params: &KMeansParams,
    ) -> arrow::error::Result<Self> {
        let n = data.len();
        if n < k {
            return Err(ArrowError::InvalidArgumentError(
                format!(
                    "KMeans: training does not have sufficient data points: n({}) is smaller than k({})",
                    n, k
                )
            ));
        }

        if k > 256 && params.hierarchical_k > 1 {
            log::debug!("Using hierarchical clustering for k={}", k);
            return match (data.value_type(), params.distance_type) {
                (DataType::Float16, _) => Self::train_hierarchical_kmeans::<
                    Float16Type,
                    KMeansAlgoFloat<Float16Type>,
                >(data, k, params),
                (DataType::Float32, _) => Self::train_hierarchical_kmeans::<
                    Float32Type,
                    KMeansAlgoFloat<Float32Type>,
                >(data, k, params),
                (DataType::Float64, _) => Self::train_hierarchical_kmeans::<
                    Float64Type,
                    KMeansAlgoFloat<Float64Type>,
                >(data, k, params),
                (DataType::UInt8, DistanceType::Hamming) => {
                    Self::train_hierarchical_kmeans::<UInt8Type, KModeAlgo>(data, k, params)
                }
                _ => Err(ArrowError::InvalidArgumentError(format!(
                    "KMeans: can not train data type {} with distance type: {}",
                    data.value_type(),
                    params.distance_type
                ))),
            };
        }

        match (data.value_type(), params.distance_type) {
            (DataType::Float16, _) => {
                Self::train_kmeans::<Float16Type, KMeansAlgoFloat<Float16Type>>(data, k, params)
            }
            (DataType::Float32, _) => {
                Self::train_kmeans::<Float32Type, KMeansAlgoFloat<Float32Type>>(data, k, params)
            }
            (DataType::Float64, _) => {
                Self::train_kmeans::<Float64Type, KMeansAlgoFloat<Float64Type>>(data, k, params)
            }
            (DataType::UInt8, DistanceType::Hamming) => {
                Self::train_kmeans::<UInt8Type, KModeAlgo>(data, k, params)
            }
            _ => Err(ArrowError::InvalidArgumentError(format!(
                "KMeans: can not train data type {} with distance type: {}",
                data.value_type(),
                params.distance_type
            ))),
        }
    }
}

/// Find the `nprobes` nearest partitions to `query`, dispatching on the Arrow
/// data types of the centroids and the query.
pub fn kmeans_find_partitions_arrow_array(
    centroids: &FixedSizeListArray,
    query: &dyn Array,
    nprobes: usize,
    distance_type: DistanceType,
) -> arrow::error::Result<(UInt32Array, Float32Array)> {
    if centroids.value_length() as usize != query.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Centroids and vectors have different dimensions: {} != {}",
            centroids.value_length(),
            query.len()
        )));
    }

    match (centroids.value_type(), query.data_type()) {
        (DataType::Float16, DataType::Float16) => Ok(kmeans_find_partitions(
            centroids.values().as_primitive::<Float16Type>().values(),
            query.as_primitive::<Float16Type>().values(),
            nprobes,
            distance_type,
        )?),
        (DataType::Float32, DataType::Float32) => Ok(kmeans_find_partitions(
            centroids.values().as_primitive::<Float32Type>().values(),
            query.as_primitive::<Float32Type>().values(),
            nprobes,
            distance_type,
        )?),
        (DataType::Float64, DataType::Float64) => Ok(kmeans_find_partitions(
            centroids.values().as_primitive::<Float64Type>().values(),
            query.as_primitive::<Float64Type>().values(),
            nprobes,
            distance_type,
        )?),
        (DataType::UInt8, DataType::UInt8) => Ok(kmeans_find_partitions_binary(
            centroids.values().as_primitive::<UInt8Type>().values(),
            query.as_primitive::<UInt8Type>().values(),
            nprobes,
            distance_type,
        )?),
        _ => Err(ArrowError::InvalidArgumentError(format!(
            "Centroids and vectors have different types: {} != {}",
            centroids.value_type(),
            query.data_type()
        ))),
    }
}

/// Find the `nprobes` partitions closest to `query`.
///
/// Returns the partition ids and their distances, sorted by distance in
/// ascending order.
pub fn kmeans_find_partitions<T: Float + L2 + Dot>(
    centroids: &[T],
    query: &[T],
    nprobes: usize,
    distance_type: DistanceType,
) -> arrow::error::Result<(UInt32Array, Float32Array)> {
    let dists: Vec<f32> = match distance_type {
        DistanceType::L2 => l2_distance_batch(query, centroids, query.len()).collect(),
        DistanceType::Dot => dot_distance_batch(query, centroids, query.len()).collect(),
        _ => {
            panic!(
                "KMeans::find_partitions: {} is not supported",
                distance_type
            );
        }
    };

    let dists_arr = Float32Array::from(dists);
    let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?;
    let dists = arrow::compute::take(&dists_arr, &indices, None)?
        .as_primitive::<Float32Type>()
        .clone();
    Ok((indices, dists))
}

/// Same as [`kmeans_find_partitions`], but for binary vectors with Hamming distance.
pub fn kmeans_find_partitions_binary(
    centroids: &[u8],
    query: &[u8],
    nprobes: usize,
    distance_type: DistanceType,
) -> arrow::error::Result<(UInt32Array, Float32Array)> {
    let dists: Vec<f32> = match distance_type {
        DistanceType::Hamming => hamming_distance_batch(query, centroids, query.len()).collect(),
        _ => {
            panic!(
                "KMeans::find_partitions: {} is not supported",
                distance_type
            );
        }
    };

    let dists_arr = Float32Array::from(dists);
    let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?;
    let dists = arrow::compute::take(&dists_arr, &indices, None)?
        .as_primitive::<Float32Type>()
        .clone();
    Ok((indices, dists))
}

/// Compute the partition and distance for every vector in `vectors`, dispatching
/// on the Arrow data types of the centroids and the vectors.
#[allow(clippy::type_complexity)]
pub fn compute_partitions_arrow_array(
    centroids: &FixedSizeListArray,
    vectors: &FixedSizeListArray,
    distance_type: DistanceType,
) -> arrow::error::Result<(Vec<Option<u32>>, Vec<Option<f32>>)> {
    if centroids.value_length() != vectors.value_length() {
        return Err(ArrowError::InvalidArgumentError(
            "Centroids and vectors have different dimensions".to_string(),
        ));
    }
    match (centroids.value_type(), vectors.value_type()) {
        (DataType::Float16, DataType::Float16) => Ok(compute_partitions_with_dists::<
            Float16Type,
            KMeansAlgoFloat<Float16Type>,
        >(
            centroids.values().as_primitive(),
            vectors.values().as_primitive(),
            centroids.value_length(),
            distance_type,
        )),
        (DataType::Float32, DataType::Float32) => Ok(compute_partitions_with_dists::<
            Float32Type,
            KMeansAlgoFloat<Float32Type>,
        >(
            centroids.values().as_primitive(),
            vectors.values().as_primitive(),
            centroids.value_length(),
            distance_type,
        )),
        (DataType::Float32, DataType::Int8) => Ok(compute_partitions_with_dists::<
            Float32Type,
            KMeansAlgoFloat<Float32Type>,
        >(
            centroids.values().as_primitive(),
            vectors.convert_to_floating_point()?.values().as_primitive(),
            centroids.value_length(),
            distance_type,
        )),
        (DataType::Float64, DataType::Float64) => Ok(compute_partitions_with_dists::<
            Float64Type,
            KMeansAlgoFloat<Float64Type>,
        >(
            centroids.values().as_primitive(),
            vectors.values().as_primitive(),
            centroids.value_length(),
            distance_type,
        )),
        (DataType::UInt8, DataType::UInt8) => {
            Ok(compute_partitions_with_dists::<UInt8Type, KModeAlgo>(
                centroids.values().as_primitive(),
                vectors.values().as_primitive(),
                centroids.value_length(),
                distance_type,
            ))
        }
        _ => Err(ArrowError::InvalidArgumentError(
            "Centroids and vectors have incompatible types".to_string(),
        )),
    }
}

/// Compute the partition for each vector in `vectors` and the total loss
/// (sum of distances to the assigned centroids).
pub fn compute_partitions<T: ArrowNumericType, K: KMeansAlgo<T::Native>>(
    centroids: &PrimitiveArray<T>,
    vectors: &PrimitiveArray<T>,
    dimension: impl AsPrimitive<usize>,
    distance_type: DistanceType,
) -> (Vec<Option<u32>>, f64)
where
    T::Native: Num,
{
    let dimension = dimension.as_();
    let (membership, _, losses) = K::compute_membership_and_loss(
        centroids.values(),
        vectors.values(),
        dimension,
        distance_type,
        0.0,
        None,
        None,
    );
    (membership, losses.iter().sum::<f64>())
}

/// Compute, for each vector in `vectors`, the partition it belongs to and the
/// distance to that partition's centroid.
pub fn compute_partitions_with_dists<T: ArrowNumericType, K: KMeansAlgo<T::Native>>(
    centroids: &PrimitiveArray<T>,
    vectors: &PrimitiveArray<T>,
    dimension: impl AsPrimitive<usize>,
    distance_type: DistanceType,
) -> (Vec<Option<u32>>, Vec<Option<f32>>)
where
    T::Native: Num,
{
    let dimension = dimension.as_();
    K::compute_membership_and_dist(
        centroids.values(),
        vectors.values(),
        dimension,
        distance_type,
        0.0,
        None,
        None,
    )
}

/// Train a K-means model over a flat `array` of vectors with `dimension`
/// elements each. At most `sample_rate * k` vectors are used for training.
#[allow(clippy::too_many_arguments)]
pub fn train_kmeans<T: ArrowPrimitiveType>(
    array: &PrimitiveArray<T>,
    mut params: KMeansParams,
    dimension: usize,
    k: usize,
    sample_rate: usize,
) -> Result<KMeans>
where
    T::Native: Dot + L2 + Normalize,
    PrimitiveArray<T>: From<Vec<T::Native>>,
{
    let num_rows = array.len() / dimension;
    if num_rows < k {
        return Err(Error::Unprocessable {
            message: format!(
                "KMeans cannot train {k} centroids with {num_rows} vectors; choose a smaller K (< {num_rows})"
            ),
            location: location!(),
        });
    }

    let data = if num_rows > sample_rate * k {
        log::info!(
            "Sample {} out of {} to train kmeans of {} dim, {} clusters",
            sample_rate * k,
            array.len() / dimension,
            dimension,
            k,
        );
        let sample_size = sample_rate * k;
        array.slice(0, sample_size * dimension)
    } else {
        array.clone()
    };

    let data = FixedSizeListArray::try_new_from_values(data, dimension as i32)?;

    // Normalize the balance factor by the number of training vectors so its
    // strength does not depend on the sample size.
    params.balance_factor /= data.len() as f32;
    let model = KMeans::new_with_params(&data, k, &params)?;
    Ok(model)
}

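// Illustrative only: a typical call for 32-dimensional f32 vectors, sampling at
// most 256 vectors per centroid. The values are placeholders, not recommendations.
//
//     let model = train_kmeans::<Float32Type>(
//         &vectors,                 // flat PrimitiveArray<Float32Type>
//         KMeansParams::default(),
//         32,                       // dimension
//         256,                      // k
//         256,                      // sample_rate
//     )?;
//     assert_eq!(model.centroids.len(), 256 * 32);
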
/// Find the nearest centroid for a single vector, or `None` when no finite
/// distance can be computed (e.g. the vector contains NaN).
#[inline]
pub fn compute_partition<T: Float + L2 + Dot>(
    centroids: &[T],
    vector: &[T],
    distance_type: DistanceType,
) -> Option<u32> {
    match distance_type {
        DistanceType::L2 => {
            argmin_value_float(l2_distance_batch(vector, centroids, vector.len())).map(|(c, _)| c)
        }
        DistanceType::Dot => {
            argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|(c, _)| c)
        }
        _ => {
            panic!(
                "KMeans::compute_partition: distance type {} is not supported",
                distance_type
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_n;

    use arrow_array::types::Float16Type;
    use arrow_array::Float16Array;
    use half::f16;
    use lance_arrow::*;
    use lance_testing::datagen::generate_random_array;

    use super::*;
    use lance_linalg::distance::l2;
    use lance_linalg::kernels::argmin;

    #[test]
    fn test_train_with_small_dataset() {
        let data = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let data = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        match KMeans::new(&data, 128, 5) {
            Ok(_) => panic!("Should fail to train KMeans"),
            Err(e) => {
                assert!(e.to_string().contains("smaller than"));
            }
        }
    }

    #[test]
    fn test_compute_partitions() {
        const DIM: usize = 256;
        let centroids = generate_random_array(DIM * 18);
        let data = generate_random_array(DIM * 20);

        let expected = data
            .values()
            .chunks(DIM)
            .map(|row| {
                argmin(
                    centroids
                        .values()
                        .chunks(DIM)
                        .map(|centroid| l2(row, centroid)),
                )
            })
            .collect::<Vec<_>>();
        let (actual, _) = compute_partitions::<Float32Type, KMeansAlgoFloat<Float32Type>>(
            &centroids,
            &data,
            DIM,
            DistanceType::L2,
        );
        assert_eq!(expected, actual);
    }

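    // Added sanity check (not from the original suite): `compute_partition` over a
    // single query vector should agree with a brute-force argmin over L2 distances.
    #[test]
    fn test_compute_single_partition() {
        const DIM: usize = 8;
        let centroids = generate_random_array(DIM * 4);
        let query = generate_random_array(DIM);

        // Brute force: index of the centroid with the smallest L2 distance.
        let expected = argmin(
            centroids
                .values()
                .chunks(DIM)
                .map(|centroid| l2(query.values(), centroid)),
        );
        let actual = compute_partition(centroids.values(), query.values(), DistanceType::L2);
        assert_eq!(expected, actual);
    }
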
    #[tokio::test]
    async fn test_compute_membership_and_loss() {
        const DIM: usize = 256;
        let centroids = generate_random_array(DIM * 18);
        let data = generate_random_array(DIM * 20);

        let (membership, _, losses) = KMeansAlgoFloat::<Float32Type>::compute_membership_and_loss(
            centroids.as_slice(),
            data.values(),
            DIM,
            DistanceType::L2,
            0.0,
            None,
            None,
        );
        let loss = losses.iter().sum::<f64>();
        assert!(loss > 0.0, "loss is not zero: {}", loss);
        membership.iter().for_each(|cd| {
            assert!(cd.is_some());
        });
    }

    #[tokio::test]
    async fn test_l2_with_nans() {
        const DIM: usize = 8;
        const K: usize = 32;
        const NUM_CENTROIDS: usize = 16 * 2048;
        let centroids = generate_random_array(DIM * NUM_CENTROIDS);
        let values = Float32Array::from_iter_values(repeat_n(f32::NAN, DIM * K));

        compute_partitions::<Float32Type, KMeansAlgoFloat<Float32Type>>(
            &centroids,
            &values,
            DIM,
            DistanceType::L2,
        )
        .0
        .iter()
        .for_each(|cd| {
            assert!(cd.is_none());
        });
    }

    #[tokio::test]
    async fn test_train_l2_kmeans_with_nans() {
        const DIM: usize = 8;
        const K: usize = 32;
        const NUM_CENTROIDS: usize = 16 * 2048;
        let centroids = generate_random_array(DIM * NUM_CENTROIDS);
        let values = repeat_n(f32::NAN, DIM * K).collect::<Vec<_>>();

        let (membership, _, _) = KMeansAlgoFloat::<Float32Type>::compute_membership_and_loss(
            centroids.as_slice(),
            &values,
            DIM,
            DistanceType::L2,
            0.0,
            None,
            None,
        );

        membership.iter().for_each(|cd| assert!(cd.is_none()));
    }

    #[tokio::test]
    async fn test_train_kmode() {
        const DIM: usize = 16;
        const K: usize = 32;
        const NUM_VALUES: usize = 256 * K;

        let mut rng = SmallRng::from_os_rng();
        let values =
            UInt8Array::from_iter_values((0..NUM_VALUES * DIM).map(|_| rng.random_range(0..255)));

        let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();

        let params = KMeansParams {
            distance_type: DistanceType::Hamming,
            ..Default::default()
        };
        let kmeans = KMeans::new_with_params(&fsl, K, &params).unwrap();
        assert_eq!(kmeans.centroids.len(), K * DIM);
        assert_eq!(kmeans.dimension, DIM);
        assert_eq!(kmeans.centroids.data_type(), &DataType::UInt8);
    }

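    // Added sanity check (not from the original suite): `kmeans_find_partitions`
    // should return exactly `nprobes` partitions with distances sorted ascending.
    #[test]
    fn test_find_partitions_sorted() {
        const DIM: usize = 8;
        const K: usize = 16;
        const NPROBES: usize = 4;
        let centroids = generate_random_array(DIM * K);
        let query = generate_random_array(DIM);

        let (ids, dists) =
            kmeans_find_partitions(centroids.values(), query.values(), NPROBES, DistanceType::L2)
                .unwrap();
        assert_eq!(ids.len(), NPROBES);
        assert_eq!(dists.len(), NPROBES);
        // Distances come back ascending because they are sorted before taking.
        assert!(dists.values().windows(2).all(|w| w[0] <= w[1]));
    }
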
    #[tokio::test]
    async fn test_hierarchical_kmeans() {
        const DIM: usize = 64;
        // k > 256 triggers the hierarchical clustering path in `new_with_params`.
        const K: usize = 257;
        const NUM_VALUES: usize = 1024 * K;

        let values = generate_random_array(NUM_VALUES * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();

        let params = KMeansParams {
            max_iters: 10,
            hierarchical_k: 16,
            ..Default::default()
        };

        let kmeans = KMeans::new_with_params(&fsl, K, &params).unwrap();

        assert_eq!(kmeans.centroids.len(), K * DIM);
        assert_eq!(kmeans.dimension, DIM);
        assert_eq!(kmeans.centroids.data_type(), &DataType::Float32);

        let centroids = kmeans.centroids.as_primitive::<Float32Type>().values();
        for val in centroids {
            assert!(!val.is_nan(), "Centroid should not contain NaN values");
        }
    }

    #[tokio::test]
    async fn test_float16_underflow_fix() {
        // Regression test: the trained f16 centroids must stay finite and
        // non-zero (see the assertions at the end).
        const DIM: usize = 2;
        const K: usize = 2;
        const NUM_VALUES: usize = K * 65536;

        let f32_values = generate_random_array(NUM_VALUES * DIM);
        let f16_values = Float16Array::from_iter_values(
            f32_values.values().iter().map(|&v| half::f16::from_f32(v)),
        );
        let fsl = FixedSizeListArray::try_new_from_values(f16_values, DIM as i32).unwrap();

        let params = KMeansParams {
            max_iters: 10,
            ..Default::default()
        };

        let kmeans = KMeans::new_with_params(&fsl, K, &params).unwrap();

        assert_eq!(kmeans.centroids.len(), K * DIM);
        assert_eq!(kmeans.dimension, DIM);
        assert_eq!(kmeans.centroids.data_type(), &DataType::Float16);

        let centroids = kmeans.centroids.as_primitive::<Float16Type>().values();
        for &val in centroids {
            assert!(!val.is_nan(), "Centroid should not contain NaN values");
            assert!(val != f16::ZERO);
        }
    }
}