lance_index/vector/
kmeans.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
//! KMeans implementation for Apache Arrow Arrays.
//!
//! Supports ``l2``, ``cosine`` and ``dot`` distances, see [DistanceType].
//! Binary vectors packed into ``u8`` are clustered with ``hamming`` distance (k-modes).
//!
//! ``Cosine`` distance is calculated by normalizing the vectors to unit length
//! and running ``l2`` distance on the unit vectors.
10//!
11
12use core::f32;
13use std::cmp::Ordering;
14use std::collections::BinaryHeap;
15use std::ops::{AddAssign, DivAssign};
16use std::sync::Arc;
17use std::vec;
18use std::{collections::HashMap, ops::MulAssign};
19
20use arrow_array::{
21    cast::AsArray,
22    types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, UInt8Type},
23    Array, ArrayRef, FixedSizeListArray, Float32Array, PrimitiveArray, UInt32Array,
24};
25use arrow_array::{ArrowNumericType, UInt8Array};
26use arrow_ord::sort::sort_to_indices;
27use arrow_schema::{ArrowError, DataType};
28use bitvec::prelude::*;
29use lance_arrow::FixedSizeListArrayExt;
30use lance_core::utils::tokio::get_num_compute_intensive_cpus;
31use lance_linalg::distance::hamming::{hamming, hamming_distance_batch};
32use lance_linalg::distance::{dot_distance_batch, DistanceType, Normalize};
33use lance_linalg::kernels::{argmin_value_float, argmin_value_float_with_bias};
34use log::{info, warn};
35use num_traits::One;
36use num_traits::{AsPrimitive, Float, FromPrimitive, Num, Zero};
37use rand::prelude::*;
38use rayon::prelude::*;
39use snafu::location;
40use {
41    lance_linalg::distance::{
42        l2::{l2_distance_batch, L2},
43        Dot,
44    },
45    lance_linalg::kernels::argmin_value,
46};
47
48use crate::vector::utils::SimpleIndex;
49use crate::{Error, Result};
50
51/// KMean initialization method.
52#[derive(Debug, PartialEq)]
53pub enum KMeanInit {
54    Random,
55    Incremental(Arc<FixedSizeListArray>),
56}
57
58/// KMean Training Parameters
59#[derive(Debug)]
60pub struct KMeansParams {
61    /// Max number of iterations.
62    pub max_iters: u32,
63
64    /// When the difference of mean distance to the centroids is less than this `tolerance`
65    /// threshold, stop the training.
66    pub tolerance: f64,
67
68    /// Run kmeans multiple times and pick the best (balanced) one.
69    pub redos: usize,
70
71    /// Init methods.
72    pub init: KMeanInit,
73
74    /// The metric to calculate distance.
75    pub distance_type: DistanceType,
76
77    /// Balance factor for the kmeans clustering.
78    /// Higher value means more balanced clustering.
79    ///
80    /// Setting this value to 0 means no balance factor,
81    /// which is the same as normal kmeans clustering.
82    pub balance_factor: f32,
83
84    /// The number of clusters to train in each hierarchical level.
85    ///
86    /// Default is 16, which performs the best performance in our experiments.
87    /// Higher would split the clusters more aggressively, which would be more accurate but slower.
88    /// hierarchical kmeans is enabled only if hierarchical_k > 1 and k > 256.
89    pub hierarchical_k: usize,
90}
91
92impl Default for KMeansParams {
93    fn default() -> Self {
94        Self {
95            max_iters: 50,
96            tolerance: 1e-4,
97            redos: 1,
98            init: KMeanInit::Random,
99            distance_type: DistanceType::L2,
100            balance_factor: 0.0,
101            hierarchical_k: 16,
102        }
103    }
104}
105
106impl KMeansParams {
107    pub fn new(
108        centroids: Option<Arc<FixedSizeListArray>>,
109        max_iters: u32,
110        redos: usize,
111        distance_type: DistanceType,
112    ) -> Self {
113        let init = match centroids {
114            Some(centroids) => KMeanInit::Incremental(centroids),
115            None => KMeanInit::Random,
116        };
117        Self {
118            max_iters,
119            redos,
120            distance_type,
121            init,
122            ..Default::default()
123        }
124    }
125
126    /// Set the balance factor for the kmeans clustering.
127    ///
128    /// Higher value means more balanced clustering.
129    /// Setting this value to 0 means no balance factor,
130    /// which is the same as normal kmeans clustering.
131    pub fn with_balance_factor(mut self, balance_factor: f32) -> Self {
132        self.balance_factor = balance_factor;
133        self
134    }
135
136    /// Set the number of clusters to train in each hierarchical level.
137    ///
138    /// Higher would split the clusters more aggressively, which would be more accurate but slower.
139    /// hierarchical kmeans is enabled only if hierarchical_k > 1 and k > 256.
140    pub fn with_hierarchical_k(mut self, hierarchical_k: usize) -> Self {
141        self.hierarchical_k = hierarchical_k;
142        self
143    }
144}
145
146/// Randomly initialize kmeans centroids.
147///
148///
149fn kmeans_random_init<T: ArrowPrimitiveType>(
150    data: &[T::Native],
151    dimension: usize,
152    k: usize,
153    mut rng: impl Rng,
154    distance_type: DistanceType,
155) -> KMeans {
156    assert!(data.len() >= k * dimension);
157    let chosen = (0..data.len() / dimension).choose_multiple(&mut rng, k);
158    let centroids = PrimitiveArray::<T>::from_iter_values(
159        chosen
160            .iter()
161            .flat_map(|&i| data[i * dimension..(i + 1) * dimension].iter())
162            .copied(),
163    );
164    KMeans {
165        centroids: Arc::new(centroids),
166        dimension,
167        distance_type,
168        loss: f64::MAX,
169    }
170}
171
172/// Split one big cluster into two smaller clusters. After split, each
173/// cluster has approximately half of the vectors.
174fn split_clusters<T: Float + MulAssign>(
175    n: usize,
176    cnts: &mut [usize],
177    centroids: &mut [T],
178    dim: usize,
179) {
180    let eps = T::from(1.0 / 1024.0).unwrap();
181    let mut rng = SmallRng::from_os_rng();
182    for i in 0..cnts.len() {
183        if cnts[i] == 0 {
184            let mut j = 0;
185            loop {
186                let p = (cnts[j] as f32 - 1.0) / (n - cnts.len()) as f32;
187                if rng.random::<f32>() < p {
188                    break;
189                }
190                j += 1;
191                j %= cnts.len();
192            }
193
194            cnts[i] = cnts[j] / 2;
195            cnts[j] -= cnts[i];
196            for k in 0..dim {
197                if k % 2 == 0 {
198                    centroids[i * dim + k] = centroids[j * dim + k] * (T::one() + eps);
199                    centroids[j * dim + k] *= T::one() - eps;
200                } else {
201                    centroids[i * dim + k] = centroids[j * dim + k] * (T::one() - eps);
202                    centroids[j * dim + k] *= T::one() + eps;
203                }
204            }
205        }
206    }
207}
208
/// Recount `cluster_sizes` from `membership` and return an adjusted balance factor.
///
/// The largest cluster is the first one to reach the maximum count while scanning
/// `membership`. The returned value is that cluster's radius minus its mean member
/// distance (`losses / size`), divided by the total number of vectors.
///
/// NOTE(review): if no vector is assigned to any cluster the result is NaN.
fn compute_cluster_sizes(
    membership: &[Option<u32>],
    radius: &[f32],
    losses: &[f64],
    cluster_sizes: &mut [usize],
) -> f32 {
    cluster_sizes.fill(0);
    let (mut largest_id, mut largest_size) = (0usize, 0usize);
    for id in membership.iter().flatten().map(|&c| c as usize) {
        let size = &mut cluster_sizes[id];
        *size += 1;
        // Strictly greater: on ties the cluster that got there first is kept.
        if *size > largest_size {
            largest_size = *size;
            largest_id = id;
        }
    }

    let mean_dist = losses[largest_id] as f32 / cluster_sizes[largest_id] as f32;
    (radius[largest_id] - mean_dist) / membership.len() as f32
}
233
/// Balance penalty of a clustering: `balance_factor` times how far the sum of
/// squared cluster sizes exceeds its value for perfectly equal clusters
/// (`n² / k`).
fn compute_balance_loss(cluster_sizes: &[usize], n: usize, balance_factor: f32) -> f32 {
    let sum_of_squares: usize = cluster_sizes.iter().map(|&size| size * size).sum();
    let balanced = (n * n) as f32 / cluster_sizes.len() as f32;
    balance_factor * (sum_of_squares as f32 - balanced)
}
238
239pub trait KMeansAlgo<T: Num> {
240    /// Recompute the membership of each vector.
241    ///
242    /// Parameters:
243    ///
244    /// - *data*: a `N * dimension` floating array. Not necessarily normalized.
245    ///
246    /// Returns:
247    /// - *membership*: the membership of each vector.
248    /// - *cluster_radius*: the radius of each cluster.
249    /// - *losses*: the losses of each cluster.
250    fn compute_membership_and_loss(
251        centroids: &[T],
252        data: &[T],
253        dimension: usize,
254        distance_type: DistanceType,
255        balance_factor: f32,
256        cluster_sizes: Option<&[usize]>,
257        index: Option<&SimpleIndex>,
258    ) -> (Vec<Option<u32>>, Vec<f32>, Vec<f64>) {
259        let (membership, dists) = Self::compute_membership_and_dist(
260            centroids,
261            data,
262            dimension,
263            distance_type,
264            balance_factor,
265            cluster_sizes,
266            index,
267        );
268
269        let k = centroids.len() / dimension;
270        let mut cluster_radius = vec![0.0; k];
271        let mut losses = vec![0.0; k];
272        for (cluster_id, dist) in membership.iter().zip(dists.iter()) {
273            if let (Some(cluster_id), Some(dist)) = (cluster_id, dist) {
274                let cluster_id = *cluster_id as usize;
275                cluster_radius[cluster_id] = cluster_radius[cluster_id].max(*dist);
276                losses[cluster_id] += *dist as f64;
277            }
278        }
279
280        (membership, cluster_radius, losses)
281    }
282
283    fn compute_membership_and_dist(
284        centroids: &[T],
285        data: &[T],
286        dimension: usize,
287        distance_type: DistanceType,
288        balance_factor: f32,
289        cluster_sizes: Option<&[usize]>,
290        index: Option<&SimpleIndex>,
291    ) -> (Vec<Option<u32>>, Vec<Option<f32>>);
292
293    /// Construct a new KMeans model.
294    fn to_kmeans(
295        data: &[T],
296        dimension: usize,
297        k: usize,
298        membership: &[Option<u32>],
299        cluster_sizes: &mut [usize],
300        distance_type: DistanceType,
301        loss: f64,
302    ) -> KMeans;
303}
304
305pub struct KMeansAlgoFloat<T: ArrowNumericType>
306where
307    T::Native: Float + Num,
308{
309    phantom_data: std::marker::PhantomData<T>,
310}
311
impl<T: ArrowNumericType> KMeansAlgo<T::Native> for KMeansAlgoFloat<T>
where
    T::Native: Float + Dot + L2 + MulAssign + DivAssign + AddAssign + FromPrimitive + Sync,
    PrimitiveArray<T>: From<Vec<T::Native>>,
{
    /// Assign every vector in `data` to a centroid, in parallel over the vectors.
    ///
    /// When `index` is provided, each vector is looked up through it, and
    /// `balance_factor` / `cluster_sizes` are ignored. Otherwise, for
    /// [`DistanceType::L2`] and [`DistanceType::Dot`], the distances to all centroids
    /// are computed directly. When `cluster_sizes` is given, they are passed to the
    /// argmin together with a bias of `balance_factor * cluster_size` per centroid.
    ///
    /// # Panics
    /// Panics for any other distance type, and if the index search fails (its
    /// result is unwrapped).
    fn compute_membership_and_dist(
        centroids: &[T::Native],
        data: &[T::Native],
        dimension: usize,
        distance_type: DistanceType,
        balance_factor: f32,
        cluster_sizes: Option<&[usize]>,
        index: Option<&SimpleIndex>,
    ) -> (Vec<Option<u32>>, Vec<Option<f32>>) {
        let cluster_and_dists = match index {
            Some(index) => data
                .par_chunks(dimension)
                .map(|vec| {
                    let query = PrimitiveArray::<T>::from_iter_values(vec.iter().copied());
                    // unable to use balance_factor here because index.search returns the closest centroid
                    index
                        .search(Arc::new(query))
                        .map(|(id, dist)| Some((id, dist)))
                        .unwrap()
                })
                .collect::<Vec<_>>(),
            // Brute force: distances to every centroid, then a (biased) argmin.
            // NOTE(review): presumably `argmin_value_float_with_bias` adds the i-th bias
            // to the i-th distance before taking the minimum — confirm in lance_linalg.
            None => match distance_type {
                DistanceType::L2 => data
                    .par_chunks(dimension)
                    .map(|vec| {
                        argmin_value_float_with_bias(
                            l2_distance_batch(vec, centroids, dimension),
                            cluster_sizes
                                .map(|size| size.iter().map(|size| balance_factor * *size as f32)),
                        )
                    })
                    .collect::<Vec<_>>(),
                DistanceType::Dot => data
                    .par_chunks(dimension)
                    .map(|vec| {
                        argmin_value_float_with_bias(
                            dot_distance_batch(vec, centroids, dimension),
                            cluster_sizes
                                .map(|size| size.iter().map(|size| balance_factor * *size as f32)),
                        )
                    })
                    .collect::<Vec<_>>(),
                _ => {
                    panic!(
                        "KMeans::find_partitions: {} is not supported",
                        distance_type
                    );
                }
            },
        };

        // Split the per-vector `Option<(id, dist)>` pairs into two parallel columns.
        cluster_and_dists.into_iter().map(Option::unzip).unzip()
    }

    /// Recompute each centroid as the mean of the vectors assigned to it.
    ///
    /// `cluster_sizes` is expected to hold the member count of each cluster for
    /// `membership` (it is used as the divisor). Empty clusters are then re-seeded
    /// by [`split_clusters`], which updates `cluster_sizes` in place. A warning is
    /// logged when more than 10% of the clusters are empty.
    fn to_kmeans(
        data: &[T::Native],
        dimension: usize,
        k: usize,
        membership: &[Option<u32>],
        cluster_sizes: &mut [usize],
        distance_type: DistanceType,
        loss: f64,
    ) -> KMeans {
        let mut centroids = vec![T::Native::zero(); k * dimension];

        // Small `k` is summed as a single chunk; otherwise clusters are split into
        // contiguous ranges of `chunk_size` clusters (the last range may be shorter).
        let mut num_cpus = get_num_compute_intensive_cpus();
        if k < num_cpus || k < 16 {
            num_cpus = 1;
        }
        let chunk_size = k / num_cpus;

        // Lock-free parallel accumulation: each task owns a disjoint slice of the
        // centroid buffer, scans all of `data`, and only adds vectors whose cluster
        // id falls inside its own [start, end) range.
        // NOTE(review): sums are accumulated in `T::Native`; for Float16 this may
        // overflow or lose precision on large clusters — confirm this is acceptable.
        centroids
            .par_chunks_mut(dimension * chunk_size)
            .enumerate()
            .with_max_len(1)
            .for_each(|(i, centroids)| {
                let start = i * chunk_size;
                let end = ((i + 1) * chunk_size).min(k);
                data.chunks(dimension)
                    .zip(membership.iter())
                    .filter_map(|(vector, cluster_id)| {
                        cluster_id.map(|cluster_id| (vector, cluster_id as usize))
                    })
                    .for_each(|(vector, cluster_id)| {
                        if start <= cluster_id && cluster_id < end {
                            let local_id = cluster_id - start;
                            let centroid =
                                &mut centroids[local_id * dimension..(local_id + 1) * dimension];
                            centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v);
                        }
                    });
            });

        // Turn sums into means. Empty clusters keep an all-zero centroid here;
        // `split_clusters` below re-seeds them.
        centroids
            .par_chunks_mut(dimension)
            .zip(cluster_sizes.par_iter())
            .for_each(|(centroid, &cnt)| {
                if cnt > 0 {
                    let norm = T::Native::one() / T::Native::from_usize(cnt).unwrap();
                    centroid.iter_mut().for_each(|v| *v *= norm);
                }
            });

        let empty_clusters = cluster_sizes.iter().filter(|&cnt| *cnt == 0).count();
        if empty_clusters as f32 / k as f32 > 0.1 {
            if data.len() / dimension < k * 256 {
                warn!("KMeans: more than 10% of clusters are empty: {} of {}.\nHelp: this could mean your dataset \
                is too small to have a meaningful index ({} < {}) or has many duplicate vectors.",
                empty_clusters, k, data.len() / dimension, k * 256);
            } else {
                warn!("KMeans: more than 10% of clusters are empty: {} of {}.\nHelp: this could mean your dataset \
                has many duplicate vectors.",
                empty_clusters, k);
            }
        }

        split_clusters(
            data.len() / dimension,
            cluster_sizes,
            &mut centroids,
            dimension,
        );

        KMeans {
            centroids: Arc::new(PrimitiveArray::<T>::from(centroids)),
            dimension,
            distance_type,
            loss,
        }
    }
}
448
/// K-modes clustering of binary vectors packed into `u8` bytes, using Hamming distance.
struct KModeAlgo {}
450
451impl KMeansAlgo<u8> for KModeAlgo {
452    fn compute_membership_and_dist(
453        centroids: &[u8],
454        data: &[u8],
455        dimension: usize,
456        distance_type: DistanceType,
457        balance_factor: f32,
458        cluster_sizes: Option<&[usize]>,
459        _: Option<&SimpleIndex>,
460    ) -> (Vec<Option<u32>>, Vec<Option<f32>>) {
461        assert_eq!(distance_type, DistanceType::Hamming);
462        let cluster_and_dists = data
463            .par_chunks(dimension)
464            .map(|vec| {
465                argmin_value(
466                    centroids
467                        .chunks_exact(dimension)
468                        .enumerate()
469                        .map(|(id, c)| {
470                            hamming(vec, c)
471                                + balance_factor
472                                    * cluster_sizes.map(|sizes| sizes[id] as f32).unwrap_or(0.0)
473                        }),
474                )
475            })
476            .collect::<Vec<_>>();
477        cluster_and_dists.into_iter().map(Option::unzip).unzip()
478    }
479
480    fn to_kmeans(
481        data: &[u8],
482        dimension: usize,
483        k: usize,
484        membership: &[Option<u32>],
485        _cluster_sizes: &mut [usize],
486        distance_type: DistanceType,
487        loss: f64,
488    ) -> KMeans {
489        assert_eq!(distance_type, DistanceType::Hamming);
490
491        let mut clusters = HashMap::<u32, Vec<usize>>::new();
492        membership.iter().enumerate().for_each(|(i, part_id)| {
493            if let Some(part_id) = part_id {
494                clusters.entry(*part_id).or_default().push(i);
495            }
496        });
497        let centroids = (0..k as u32)
498            .into_par_iter()
499            .flat_map(|part_id| {
500                if let Some(vecs) = clusters.get(&part_id) {
501                    let mut ones = vec![0_u32; dimension * 8];
502                    let cnt = vecs.len() as u32;
503                    vecs.iter().for_each(|&i| {
504                        let vec = &data[i * dimension..(i + 1) * dimension];
505                        ones.iter_mut()
506                            .zip(vec.view_bits::<Lsb0>())
507                            .for_each(|(c, v)| {
508                                if *v.as_ref() {
509                                    *c += 1;
510                                }
511                            });
512                    });
513
514                    let bits = ones.iter().map(|&c| c * 2 > cnt).collect::<BitVec<u8>>();
515                    bits.as_raw_slice()
516                        .iter()
517                        .copied()
518                        .map(Some)
519                        .collect::<Vec<_>>()
520                } else {
521                    vec![None; dimension]
522                }
523            })
524            .collect::<Vec<_>>();
525
526        KMeans {
527            centroids: Arc::new(UInt8Array::from(centroids)),
528            dimension,
529            distance_type,
530            loss,
531        }
532    }
533}
534
535/// KMeans implementation for Apache Arrow Arrays.
536#[derive(Debug, Clone)]
537pub struct KMeans {
538    /// Flattened array of centroids.
539    ///
540    /// dimension * k of floating number.
541    pub centroids: ArrayRef,
542
543    /// The dimension of each vector.
544    pub dimension: usize,
545
546    /// How to calculate distance between two vectors.
547    pub distance_type: DistanceType,
548
549    /// The loss of the last training.
550    pub loss: f64,
551}
552
553impl KMeans {
554    fn empty(dimension: usize, distance_type: DistanceType) -> Self {
555        Self {
556            centroids: arrow_array::array::new_empty_array(&DataType::Float32),
557            dimension,
558            distance_type,
559            loss: f64::MAX,
560        }
561    }
562
563    /// Create a [`KMeans`] with existing centroids.
564    /// It is useful for continuing training.
565    pub fn with_centroids(
566        centroids: ArrayRef,
567        dimension: usize,
568        distance_type: DistanceType,
569        loss: f64,
570    ) -> Self {
571        assert!(matches!(
572            centroids.data_type(),
573            DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8
574        ));
575        Self {
576            centroids,
577            dimension,
578            distance_type,
579            loss,
580        }
581    }
582
583    /// Initialize a [`KMeans`] with random centroids.
584    ///
585    /// Parameters
586    /// - *data*: training data. provided to do samplings.
587    /// - *k*: the number of clusters.
588    /// - *distance_type*: the distance type to calculate distance.
589    /// - *rng*: random generator.
590    fn init_random<T: ArrowPrimitiveType>(
591        data: &[T::Native],
592        dimension: usize,
593        k: usize,
594        rng: impl Rng,
595        distance_type: DistanceType,
596    ) -> Self {
597        kmeans_random_init::<T>(data, dimension, k, rng, distance_type)
598    }
599
600    /// Train a KMeans model on data with `k` clusters.
601    pub fn new(data: &FixedSizeListArray, k: usize, max_iters: u32) -> arrow::error::Result<Self> {
602        let params = KMeansParams {
603            max_iters,
604            distance_type: DistanceType::L2,
605            ..Default::default()
606        };
607        Self::new_with_params(data, k, &params)
608    }
609
610    fn train_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
611        data: &FixedSizeListArray,
612        k: usize,
613        params: &KMeansParams,
614    ) -> arrow::error::Result<Self>
615    where
616        T::Native: Num,
617    {
618        // the data is `num_partitions * sample_rate` vectors,
619        // but here `k` may be not `num_partitions` in the case of hierarchical kmeans,
620        // so we need to sample the sampled data again here.
621        // we have to limit the number of data to avoid division underflow,
622        // the threshold 512 is chosen because the minimal normal f16 value will be 0 if divided by 1024.
623        let data = if data.len() >= k * 512 {
624            data.slice(0, k * 512)
625        } else {
626            data.clone()
627        };
628
629        let n = data.len();
630        let dimension = data.value_length() as usize;
631
632        let data =
633            data.values()
634                .as_primitive_opt::<T>()
635                .ok_or(ArrowError::InvalidArgumentError(format!(
636                    "KMeans: data must be {}, got: {}",
637                    T::DATA_TYPE,
638                    data.value_type()
639                )))?;
640
641        let mut best_kmeans = Self::empty(dimension, params.distance_type);
642        let mut cluster_sizes = vec![0; k];
643        let mut adjusted_balance_factor = f32::MAX;
644
645        // TODO: use seed for Rng.
646        let rng = SmallRng::from_os_rng();
647        for redo in 1..=params.redos {
648            let mut kmeans: Self = match &params.init {
649                KMeanInit::Random => Self::init_random::<T>(
650                    data.values(),
651                    dimension,
652                    k,
653                    rng.clone(),
654                    params.distance_type,
655                ),
656                KMeanInit::Incremental(centroids) => Self::with_centroids(
657                    centroids.values().clone(),
658                    dimension,
659                    params.distance_type,
660                    f64::MAX,
661                ),
662            };
663
664            let mut loss = f64::MAX;
665            for i in 1..=params.max_iters {
666                if i % 10 == 0 {
667                    info!(
668                        "KMeans training: iteration {} / {}, redo={}",
669                        i, params.max_iters, redo
670                    );
671                };
672
673                let index = SimpleIndex::may_train_index(
674                    kmeans.centroids.clone(),
675                    kmeans.dimension,
676                    kmeans.distance_type,
677                )?;
678
679                let balance_factor = adjusted_balance_factor.min(params.balance_factor);
680                let (membership, radius, losses) = Algo::compute_membership_and_loss(
681                    kmeans.centroids.as_primitive::<T>().values(),
682                    data.values(),
683                    dimension,
684                    params.distance_type,
685                    balance_factor,
686                    Some(&cluster_sizes),
687                    index.as_ref(),
688                );
689
690                adjusted_balance_factor =
691                    compute_cluster_sizes(&membership, &radius, &losses, &mut cluster_sizes);
692                let balance_loss = compute_balance_loss(&cluster_sizes, n, balance_factor);
693                let last_loss = losses.iter().sum::<f64>() + balance_loss as f64;
694
695                kmeans = Algo::to_kmeans(
696                    data.values(),
697                    dimension,
698                    k,
699                    &membership,
700                    &mut cluster_sizes,
701                    params.distance_type,
702                    last_loss,
703                );
704                if (loss - last_loss).abs() < params.tolerance * last_loss {
705                    info!(
706                        "KMeans training: converged at iteration {} / {}, redo={}, loss={}, last_loss={}, loss_diff={}",
707                        i, params.max_iters, redo, loss, last_loss, (loss - last_loss).abs() / last_loss
708                    );
709                    break;
710                }
711                loss = last_loss;
712            }
713            if kmeans.loss < best_kmeans.loss {
714                best_kmeans = kmeans;
715            }
716        }
717
718        Ok(best_kmeans)
719    }
720
721    /// Helper function to create a FixedSizeListArray from indices
722    fn create_array_from_indices<T: ArrowNumericType>(
723        indices: &[usize],
724        data_values: &[T::Native],
725        dimension: usize,
726    ) -> arrow::error::Result<FixedSizeListArray>
727    where
728        T::Native: Clone,
729        PrimitiveArray<T>: From<Vec<T::Native>>,
730    {
731        let mut subset_data = Vec::with_capacity(indices.len() * dimension);
732        for &idx in indices {
733            let start = idx * dimension;
734            let end = start + dimension;
735            subset_data.extend_from_slice(&data_values[start..end]);
736        }
737        let array = PrimitiveArray::<T>::from(subset_data);
738        FixedSizeListArray::try_new_from_values(array, dimension as i32)
739    }
740
741    /// Train a hierarchical KMeans model when k > 256
742    ///
743    /// This function implements a hierarchical clustering approach:
744    /// 1. Start with k'=256 initial clusters
745    /// 2. Iteratively split the largest cluster until we have k clusters
746    fn train_hierarchical_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
747        data: &FixedSizeListArray,
748        target_k: usize,
749        params: &KMeansParams,
750    ) -> arrow::error::Result<Self>
751    where
752        T::Native: Num,
753        PrimitiveArray<T>: From<Vec<T::Native>>,
754    {
755        // Cluster structure for the heap
756        #[derive(Clone, Debug)]
757        struct Cluster<N> {
758            id: usize,
759            indices: Vec<usize>,
760            centroid: Vec<N>,
761            finalized: bool,
762        }
763
764        impl<N> Eq for Cluster<N> {}
765
766        impl<N> PartialEq for Cluster<N> {
767            fn eq(&self, other: &Self) -> bool {
768                self.indices.len() == other.indices.len()
769            }
770        }
771
772        impl<N> Ord for Cluster<N> {
773            fn cmp(&self, other: &Self) -> Ordering {
774                // Non-finalized clusters should always have higher priority than finalized ones
775                match (self.finalized, other.finalized) {
776                    (false, true) => Ordering::Greater,
777                    (true, false) => Ordering::Less,
778                    _ => {
779                        // Max heap: larger clusters first
780                        self.indices.len().cmp(&other.indices.len())
781                    }
782                }
783            }
784        }
785
786        impl<N> PartialOrd for Cluster<N> {
787            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
788                Some(self.cmp(other))
789            }
790        }
791
792        let n = data.len();
793        let dimension = data.value_length() as usize;
794
795        let data_values = data
796            .values()
797            .as_primitive_opt::<T>()
798            .ok_or(ArrowError::InvalidArgumentError(format!(
799                "KMeans: data must be {}, got: {}",
800                T::DATA_TYPE,
801                data.value_type()
802            )))?
803            .values();
804
805        // Initial clustering with k'=16
806        let initial_k = params.hierarchical_k.min(target_k).min(n);
807        info!(
808            "Hierarchical clustering: initial k={}, target k={}",
809            initial_k, target_k
810        );
811
812        let initial_kmeans = Self::train_kmeans::<T, Algo>(data, initial_k, params)?;
813
814        // Get membership for all data points
815        let (membership, _, _) = Algo::compute_membership_and_loss(
816            initial_kmeans.centroids.as_primitive::<T>().values(),
817            data_values,
818            dimension,
819            params.distance_type,
820            0.0, // No balance factor for membership computation
821            None,
822            None,
823        );
824
825        // Build initial clusters and add to heap
826        let mut heap: BinaryHeap<Cluster<T::Native>> = BinaryHeap::new();
827        let mut next_cluster_id = 0;
828        let initial_centroids = initial_kmeans.centroids.as_primitive::<T>().values();
829
830        for i in 0..initial_k {
831            let mut cluster_indices = Vec::new();
832            for (idx, &cluster_id) in membership.iter().enumerate() {
833                if let Some(cid) = cluster_id {
834                    if cid as usize == i {
835                        cluster_indices.push(idx);
836                    }
837                }
838            }
839
840            if !cluster_indices.is_empty() {
841                let centroid_start = i * dimension;
842                let centroid_end = centroid_start + dimension;
843                let centroid = initial_centroids[centroid_start..centroid_end].to_vec();
844
845                heap.push(Cluster {
846                    id: next_cluster_id,
847                    indices: cluster_indices,
848                    centroid,
849                    finalized: false,
850                });
851                next_cluster_id += 1;
852            }
853        }
854
855        // Iteratively split largest clusters until we have target_k clusters
856        while heap.len() < target_k {
857            // Get the largest cluster
858            let mut largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError(
859                "No cluster can be further split".to_string(),
860            ))?;
861
862            // If this cluster is already finalized, no further split is possible; stop splitting
863            if largest_cluster.finalized {
864                log::warn!("Cluster {} is already finalized, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len()+ 1);
865                heap.push(largest_cluster);
866                break;
867            }
868
869            // Because the clusters are sorted by size, if the cluster has only 1 point, no further split is possible; stop splitting
870            if largest_cluster.indices.len() <= 1 {
871                log::warn!("Cluster {} has only 1 point, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len()+ 1);
872                heap.push(largest_cluster);
873                break;
874            }
875
876            let cluster_size = largest_cluster.indices.len();
877            log::debug!(
878                "Splitting cluster {} with {} points (current total clusters: {})",
879                largest_cluster.id,
880                cluster_size,
881                heap.len() + 1 // +1 for the cluster we just popped
882            );
883
884            // Determine k' for this cluster based on its size
885            let remaining_k = target_k - heap.len(); // Spaces left to fill
886            let cluster_k = if cluster_size <= params.hierarchical_k {
887                2.min(remaining_k).min(cluster_size)
888            } else {
889                // For larger clusters, split more aggressively
890                let suggested_k = cluster_size / params.hierarchical_k;
891                suggested_k
892                    .min(remaining_k)
893                    .min(params.hierarchical_k)
894                    .max(2)
895            };
896
897            // Create sub-dataset for this cluster using indices
898            let sub_data = Self::create_array_from_indices::<T>(
899                &largest_cluster.indices,
900                data_values,
901                dimension,
902            )?;
903
904            // Run kmeans on this cluster
905            let sub_kmeans = Self::train_kmeans::<T, Algo>(&sub_data, cluster_k, params)?;
906
907            // Get membership for points in the sub-cluster
908            let sub_data = sub_data.values().as_primitive::<T>().values();
909            let (sub_membership, _, _) = Algo::compute_membership_and_loss(
910                sub_kmeans.centroids.as_primitive::<T>().values(),
911                sub_data,
912                dimension,
913                params.distance_type,
914                0.0,
915                None,
916                None,
917            );
918
919            // Build per-cluster membership while checking whether the split is effective
920            let approx_cluster_capacity = if cluster_k > 0 {
921                largest_cluster.indices.len().div_ceil(cluster_k)
922            } else {
923                0
924            };
925            let mut cluster_assignments: Vec<Vec<usize>> = (0..cluster_k)
926                .map(|_| Vec::with_capacity(approx_cluster_capacity))
927                .collect();
928
929            let mut first_sid: Option<u32> = None;
930            let mut all_same = true;
931            for (local_idx, &membership) in sub_membership.iter().enumerate() {
932                let Some(sub_cluster_id) = membership else {
933                    continue;
934                };
935
936                if let Some(first) = first_sid {
937                    if sub_cluster_id != first {
938                        all_same = false;
939                    }
940                } else {
941                    first_sid = Some(sub_cluster_id);
942                }
943
944                let sub_cluster_id = sub_cluster_id as usize;
945                if let Some(indices) = cluster_assignments.get_mut(sub_cluster_id) {
946                    indices.push(largest_cluster.indices[local_idx]);
947                } else {
948                    // Unexpected assignment outside [0, cluster_k); treat as ineffective split.
949                    all_same = false;
950                }
951            }
952
953            // If all memberships are identical, the split is ineffective; finalize the original cluster
954            if all_same {
955                largest_cluster.finalized = true;
956                heap.push(largest_cluster);
957                continue;
958            }
959
960            // Create new sub-clusters and add to heap
961            let sub_centroids = sub_kmeans.centroids.as_primitive::<T>().values();
962            for (i, new_cluster_indices) in cluster_assignments.into_iter().enumerate() {
963                if new_cluster_indices.is_empty() {
964                    continue;
965                }
966
967                let centroid_start = i * dimension;
968                let centroid_end = centroid_start + dimension;
969                let centroid = sub_centroids[centroid_start..centroid_end].to_vec();
970
971                heap.push(Cluster {
972                    id: next_cluster_id,
973                    indices: new_cluster_indices,
974                    centroid,
975                    finalized: false,
976                });
977                next_cluster_id += 1;
978            }
979
980            log::debug!(
981                "Split complete: now have {} clusters (target: {})",
982                heap.len(),
983                target_k
984            );
985        }
986        debug_assert_eq!(heap.len(), target_k);
987
988        // Construct final KMeans model with all centroids
989        let mut all_clusters: Vec<Cluster<T::Native>> = heap.into_vec();
990        // Sort by ID to ensure consistent ordering
991        all_clusters.sort_by_key(|c| c.id);
992
993        let flat_centroids: Vec<T::Native> =
994            all_clusters.into_iter().flat_map(|c| c.centroid).collect();
995        let centroids_array = PrimitiveArray::<T>::from(flat_centroids);
996
997        Ok(Self {
998            centroids: Arc::new(centroids_array),
999            dimension,
1000            distance_type: params.distance_type,
1001            loss: 0.0, // Loss is not meaningful for hierarchical clustering
1002        })
1003    }
1004
1005    /// Train a [`KMeans`] model with full parameters.
1006    ///
1007    /// If the DistanceType is `Cosine`, the input vectors will be normalized with each iteration.
1008    pub fn new_with_params(
1009        data: &FixedSizeListArray,
1010        k: usize,
1011        params: &KMeansParams,
1012    ) -> arrow::error::Result<Self> {
1013        let n = data.len();
1014        if n < k {
1015            return Err(ArrowError::InvalidArgumentError(
1016                format!(
1017                    "KMeans: training does not have sufficient data points: n({}) is smaller than k({})",
1018                    n, k
1019                )
1020            ));
1021        }
1022
1023        // use hierarchical clustering if k > 256 and hierarchical_k > 1
1024        // we set 256 as the threshold because:
1025        // 1. PQ would run kmeans with k=256, in that case we don't want to use hierarchical clustering for accuracy
1026        // 2. kmeans with k=256 is small enough that we don't need to use hierarchical clustering for efficiency
1027        if k > 256 && params.hierarchical_k > 1 {
1028            log::debug!("Using hierarchical clustering for k={}", k);
1029            return match (data.value_type(), params.distance_type) {
1030                (DataType::Float16, _) => Self::train_hierarchical_kmeans::<
1031                    Float16Type,
1032                    KMeansAlgoFloat<Float16Type>,
1033                >(data, k, params),
1034                (DataType::Float32, _) => Self::train_hierarchical_kmeans::<
1035                    Float32Type,
1036                    KMeansAlgoFloat<Float32Type>,
1037                >(data, k, params),
1038                (DataType::Float64, _) => Self::train_hierarchical_kmeans::<
1039                    Float64Type,
1040                    KMeansAlgoFloat<Float64Type>,
1041                >(data, k, params),
1042                (DataType::UInt8, DistanceType::Hamming) => {
1043                    Self::train_hierarchical_kmeans::<UInt8Type, KModeAlgo>(data, k, params)
1044                }
1045                _ => Err(ArrowError::InvalidArgumentError(format!(
1046                    "KMeans: can not train data type {} with distance type: {}",
1047                    data.value_type(),
1048                    params.distance_type
1049                ))),
1050            };
1051        }
1052
1053        match (data.value_type(), params.distance_type) {
1054            (DataType::Float16, _) => {
1055                Self::train_kmeans::<Float16Type, KMeansAlgoFloat<Float16Type>>(data, k, params)
1056            }
1057
1058            (DataType::Float32, _) => {
1059                Self::train_kmeans::<Float32Type, KMeansAlgoFloat<Float32Type>>(data, k, params)
1060            }
1061            (DataType::Float64, _) => {
1062                Self::train_kmeans::<Float64Type, KMeansAlgoFloat<Float64Type>>(data, k, params)
1063            }
1064            (DataType::UInt8, DistanceType::Hamming) => {
1065                Self::train_kmeans::<UInt8Type, KModeAlgo>(data, k, params)
1066            }
1067            _ => Err(ArrowError::InvalidArgumentError(format!(
1068                "KMeans: can not train data type {} with distance type: {}",
1069                data.value_type(),
1070                params.distance_type
1071            ))),
1072        }
1073    }
1074}
1075
1076pub fn kmeans_find_partitions_arrow_array(
1077    centroids: &FixedSizeListArray,
1078    query: &dyn Array,
1079    nprobes: usize,
1080    distance_type: DistanceType,
1081) -> arrow::error::Result<(UInt32Array, Float32Array)> {
1082    if centroids.value_length() as usize != query.len() {
1083        return Err(ArrowError::InvalidArgumentError(format!(
1084            "Centroids and vectors have different dimensions: {} != {}",
1085            centroids.value_length(),
1086            query.len()
1087        )));
1088    }
1089
1090    match (centroids.value_type(), query.data_type()) {
1091        (DataType::Float16, DataType::Float16) => Ok(kmeans_find_partitions(
1092            centroids.values().as_primitive::<Float16Type>().values(),
1093            query.as_primitive::<Float16Type>().values(),
1094            nprobes,
1095            distance_type,
1096        )?),
1097        (DataType::Float32, DataType::Float32) => Ok(kmeans_find_partitions(
1098            centroids.values().as_primitive::<Float32Type>().values(),
1099            query.as_primitive::<Float32Type>().values(),
1100            nprobes,
1101            distance_type,
1102        )?),
1103        (DataType::Float64, DataType::Float64) => Ok(kmeans_find_partitions(
1104            centroids.values().as_primitive::<Float64Type>().values(),
1105            query.as_primitive::<Float64Type>().values(),
1106            nprobes,
1107            distance_type,
1108        )?),
1109        (DataType::UInt8, DataType::UInt8) => Ok(kmeans_find_partitions_binary(
1110            centroids.values().as_primitive::<UInt8Type>().values(),
1111            query.as_primitive::<UInt8Type>().values(),
1112            nprobes,
1113            distance_type,
1114        )?),
1115        _ => Err(ArrowError::InvalidArgumentError(format!(
1116            "Centroids and vectors have different types: {} != {}",
1117            centroids.value_type(),
1118            query.data_type()
1119        ))),
1120    }
1121}
1122
1123/// KMeans finds N nearest partitions.
1124///
1125/// Parameters:
1126/// - *centroids*: a `k * dimension` floating array.
1127/// - *query*: a `dimension` floating array.
1128/// - *nprobes*: the number of partitions to find.
1129/// - *distance_type*: the distance type to calculate distance.
1130///
1131/// This function allows to conduct kmeans search without constructing
1132/// `Arrow Array` or `Vec<Float>` types.
1133///
1134pub fn kmeans_find_partitions<T: Float + L2 + Dot>(
1135    centroids: &[T],
1136    query: &[T],
1137    nprobes: usize,
1138    distance_type: DistanceType,
1139) -> arrow::error::Result<(UInt32Array, Float32Array)> {
1140    let dists: Vec<f32> = match distance_type {
1141        DistanceType::L2 => l2_distance_batch(query, centroids, query.len()).collect(),
1142        DistanceType::Dot => dot_distance_batch(query, centroids, query.len()).collect(),
1143        _ => {
1144            panic!(
1145                "KMeans::find_partitions: {} is not supported",
1146                distance_type
1147            );
1148        }
1149    };
1150
1151    // TODO: use heap to just keep nprobes smallest values.
1152    let dists_arr = Float32Array::from(dists);
1153    let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?;
1154    let dists = arrow::compute::take(&dists_arr, &indices, None)?
1155        .as_primitive::<Float32Type>()
1156        .clone();
1157    Ok((indices, dists))
1158}
1159
1160pub fn kmeans_find_partitions_binary(
1161    centroids: &[u8],
1162    query: &[u8],
1163    nprobes: usize,
1164    distance_type: DistanceType,
1165) -> arrow::error::Result<(UInt32Array, Float32Array)> {
1166    let dists: Vec<f32> = match distance_type {
1167        DistanceType::Hamming => hamming_distance_batch(query, centroids, query.len()).collect(),
1168        _ => {
1169            panic!(
1170                "KMeans::find_partitions: {} is not supported",
1171                distance_type
1172            );
1173        }
1174    };
1175
1176    // TODO: use heap to just keep nprobes smallest values.
1177    let dists_arr = Float32Array::from(dists);
1178    let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?;
1179    let dists = arrow::compute::take(&dists_arr, &indices, None)?
1180        .as_primitive::<Float32Type>()
1181        .clone();
1182    Ok((indices, dists))
1183}
1184
1185/// Compute partitions from Arrow FixedSizeListArray.
1186#[allow(clippy::type_complexity)]
1187pub fn compute_partitions_arrow_array(
1188    centroids: &FixedSizeListArray,
1189    vectors: &FixedSizeListArray,
1190    distance_type: DistanceType,
1191) -> arrow::error::Result<(Vec<Option<u32>>, Vec<Option<f32>>)> {
1192    if centroids.value_length() != vectors.value_length() {
1193        return Err(ArrowError::InvalidArgumentError(
1194            "Centroids and vectors have different dimensions".to_string(),
1195        ));
1196    }
1197    match (centroids.value_type(), vectors.value_type()) {
1198        (DataType::Float16, DataType::Float16) => Ok(compute_partitions_with_dists::<
1199            Float16Type,
1200            KMeansAlgoFloat<Float16Type>,
1201        >(
1202            centroids.values().as_primitive(),
1203            vectors.values().as_primitive(),
1204            centroids.value_length(),
1205            distance_type,
1206        )),
1207        (DataType::Float32, DataType::Float32) => Ok(compute_partitions_with_dists::<
1208            Float32Type,
1209            KMeansAlgoFloat<Float32Type>,
1210        >(
1211            centroids.values().as_primitive(),
1212            vectors.values().as_primitive(),
1213            centroids.value_length(),
1214            distance_type,
1215        )),
1216        (DataType::Float32, DataType::Int8) => Ok(compute_partitions_with_dists::<
1217            Float32Type,
1218            KMeansAlgoFloat<Float32Type>,
1219        >(
1220            centroids.values().as_primitive(),
1221            vectors.convert_to_floating_point()?.values().as_primitive(),
1222            centroids.value_length(),
1223            distance_type,
1224        )),
1225        (DataType::Float64, DataType::Float64) => Ok(compute_partitions_with_dists::<
1226            Float64Type,
1227            KMeansAlgoFloat<Float64Type>,
1228        >(
1229            centroids.values().as_primitive(),
1230            vectors.values().as_primitive(),
1231            centroids.value_length(),
1232            distance_type,
1233        )),
1234        (DataType::UInt8, DataType::UInt8) => {
1235            Ok(compute_partitions_with_dists::<UInt8Type, KModeAlgo>(
1236                centroids.values().as_primitive(),
1237                vectors.values().as_primitive(),
1238                centroids.value_length(),
1239                distance_type,
1240            ))
1241        }
1242        _ => Err(ArrowError::InvalidArgumentError(
1243            "Centroids and vectors have incompatible types".to_string(),
1244        )),
1245    }
1246}
1247
1248/// Compute partition ID of each vector in the KMeans.
1249///
1250/// If returns `None`, means the vector is not valid, i.e., all `NaN`.
1251pub fn compute_partitions<T: ArrowNumericType, K: KMeansAlgo<T::Native>>(
1252    centroids: &PrimitiveArray<T>,
1253    vectors: &PrimitiveArray<T>,
1254    dimension: impl AsPrimitive<usize>,
1255    distance_type: DistanceType,
1256) -> (Vec<Option<u32>>, f64)
1257where
1258    T::Native: Num,
1259{
1260    let dimension = dimension.as_();
1261    let (membership, _, losses) = K::compute_membership_and_loss(
1262        centroids.values(),
1263        vectors.values(),
1264        dimension,
1265        distance_type,
1266        0.0,
1267        None,
1268        None,
1269    );
1270    (membership, losses.iter().sum::<f64>())
1271}
1272
1273/// compute the partition id and the distance to the centroid for each vector,
1274/// NOTE the distance is squared distance for L2
1275pub fn compute_partitions_with_dists<T: ArrowNumericType, K: KMeansAlgo<T::Native>>(
1276    centroids: &PrimitiveArray<T>,
1277    vectors: &PrimitiveArray<T>,
1278    dimension: impl AsPrimitive<usize>,
1279    distance_type: DistanceType,
1280) -> (Vec<Option<u32>>, Vec<Option<f32>>)
1281where
1282    T::Native: Num,
1283{
1284    let dimension = dimension.as_();
1285    K::compute_membership_and_dist(
1286        centroids.values(),
1287        vectors.values(),
1288        dimension,
1289        distance_type,
1290        0.0,
1291        None,
1292        None,
1293    )
1294}
1295
1296/// Train KMeans model and returns the centroids of each cluster.
1297///
1298/// Parameters
1299/// ----------
1300/// - *centroids*: initial centroids, use the random initialization if None
1301/// - *array*: a flatten floating number array of vectors
1302/// - *dimension*: dimension of the vector
1303/// - *k*: number of clusters
1304/// - *max_iterations*: maximum number of iterations
1305/// - *redos*: number of times to redo the k-means clustering
1306/// - *distance_type*: distance type to compute pair-wise vector distance
1307/// - *sample_rate*: sample rate to select the data for training
1308#[allow(clippy::too_many_arguments)]
1309pub fn train_kmeans<T: ArrowPrimitiveType>(
1310    array: &PrimitiveArray<T>,
1311    mut params: KMeansParams,
1312    dimension: usize,
1313    k: usize,
1314    sample_rate: usize,
1315) -> Result<KMeans>
1316where
1317    T::Native: Dot + L2 + Normalize,
1318    PrimitiveArray<T>: From<Vec<T::Native>>,
1319{
1320    let num_rows = array.len() / dimension;
1321    if num_rows < k {
1322        return Err(Error::Unprocessable {
1323            message: format!(
1324                "KMeans cannot train {k} centroids with {num_rows} vectors; choose a smaller K (< {num_rows})"
1325            ),
1326            location: location!(),
1327        });
1328    }
1329
1330    // Only sample sample_rate * num_clusters. See Faiss
1331    let data = if num_rows > sample_rate * k {
1332        log::info!(
1333            "Sample {} out of {} to train kmeans of {} dim, {} clusters",
1334            sample_rate * k,
1335            array.len() / dimension,
1336            dimension,
1337            k,
1338        );
1339        let sample_size = sample_rate * k;
1340        array.slice(0, sample_size * dimension)
1341    } else {
1342        array.clone()
1343    };
1344
1345    let data = FixedSizeListArray::try_new_from_values(data, dimension as i32)?;
1346
1347    params.balance_factor /= data.len() as f32;
1348    let model = KMeans::new_with_params(&data, k, &params)?;
1349    Ok(model)
1350}
1351
1352#[inline]
1353pub fn compute_partition<T: Float + L2 + Dot>(
1354    centroids: &[T],
1355    vector: &[T],
1356    distance_type: DistanceType,
1357) -> Option<u32> {
1358    match distance_type {
1359        DistanceType::L2 => {
1360            argmin_value_float(l2_distance_batch(vector, centroids, vector.len())).map(|(c, _)| c)
1361        }
1362        DistanceType::Dot => {
1363            argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|(c, _)| c)
1364        }
1365        _ => {
1366            panic!(
1367                "KMeans::compute_partition: distance type {} is not supported",
1368                distance_type
1369            );
1370        }
1371    }
1372}
1373
#[cfg(test)]
mod tests {
    use std::iter::repeat_n;

    use arrow_array::types::Float16Type;
    use arrow_array::Float16Array;
    use half::f16;
    use lance_arrow::*;
    use lance_testing::datagen::generate_random_array;

    use super::*;
    use lance_linalg::distance::l2;
    use lance_linalg::kernels::argmin;

    /// Training must fail when there are fewer vectors than requested clusters.
    #[test]
    fn test_train_with_small_dataset() {
        // Only 2 vectors of dimension 2. The requested cluster count (128, presumably
        // `k` in `KMeans::new`) is far larger.
        let data = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let data = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        match KMeans::new(&data, 128, 5) {
            Ok(_) => panic!("Should fail to train KMeans"),
            Err(e) => {
                assert!(e.to_string().contains("smaller than"));
            }
        }
    }

    /// `compute_partitions` must agree with a brute-force L2 argmin.
    #[test]
    fn test_compute_partitions() {
        const DIM: usize = 256;
        let centroids = generate_random_array(DIM * 18);
        let data = generate_random_array(DIM * 20);

        let expected = data
            .values()
            .chunks(DIM)
            .map(|row| {
                argmin(
                    centroids
                        .values()
                        .chunks(DIM)
                        .map(|centroid| l2(row, centroid)),
                )
            })
            .collect::<Vec<_>>();
        let (actual, _) = compute_partitions::<Float32Type, KMeansAlgoFloat<Float32Type>>(
            &centroids,
            &data,
            DIM,
            DistanceType::L2,
        );
        assert_eq!(expected, actual);
    }

    /// Random data should produce a positive total loss and assign every vector.
    #[tokio::test]
    async fn test_compute_membership_and_loss() {
        const DIM: usize = 256;
        let centroids = generate_random_array(DIM * 18);
        let data = generate_random_array(DIM * 20);

        let (membership, _, losses) = KMeansAlgoFloat::<Float32Type>::compute_membership_and_loss(
            centroids.as_slice(),
            data.values(),
            DIM,
            DistanceType::L2,
            0.0,
            None,
            None,
        );
        let loss = losses.iter().sum::<f64>();
        assert!(loss > 0.0, "loss should be positive, got: {}", loss);
        membership.iter().for_each(|cd| {
            assert!(cd.is_some());
        });
    }

    /// All-NaN vectors must not be assigned to any partition.
    #[tokio::test]
    async fn test_l2_with_nans() {
        const DIM: usize = 8;
        // Here `K` is the number of all-NaN query vectors, not a cluster count.
        const K: usize = 32;
        const NUM_CENTROIDS: usize = 16 * 2048;
        let centroids = generate_random_array(DIM * NUM_CENTROIDS);
        let values = Float32Array::from_iter_values(repeat_n(f32::NAN, DIM * K));

        compute_partitions::<Float32Type, KMeansAlgoFloat<Float32Type>>(
            &centroids,
            &values,
            DIM,
            DistanceType::L2,
        )
        .0
        .iter()
        .for_each(|cd| {
            assert!(cd.is_none());
        });
    }

    /// All-NaN vectors must not be assigned a membership.
    ///
    /// Despite the name, this calls `compute_membership_and_loss` directly rather than
    /// running a full training loop.
    #[tokio::test]
    async fn test_train_l2_kmeans_with_nans() {
        const DIM: usize = 8;
        // Number of all-NaN vectors.
        const K: usize = 32;
        const NUM_CENTROIDS: usize = 16 * 2048;
        let centroids = generate_random_array(DIM * NUM_CENTROIDS);
        let values = repeat_n(f32::NAN, DIM * K).collect::<Vec<_>>();

        let (membership, _, _) = KMeansAlgoFloat::<Float32Type>::compute_membership_and_loss(
            centroids.as_slice(),
            &values,
            DIM,
            DistanceType::L2,
            0.0,
            None,
            None,
        );

        membership.iter().for_each(|cd| assert!(cd.is_none()));
    }

    /// Hamming-distance training on UInt8 data produces `K` byte centroids.
    #[tokio::test]
    async fn test_train_kmode() {
        const DIM: usize = 16;
        const K: usize = 32;
        const NUM_VALUES: usize = 256 * K;

        let mut rng = SmallRng::from_os_rng();
        let values =
            UInt8Array::from_iter_values((0..NUM_VALUES * DIM).map(|_| rng.random_range(0..255)));

        let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();

        let params = KMeansParams {
            distance_type: DistanceType::Hamming,
            ..Default::default()
        };
        let kmeans = KMeans::new_with_params(&fsl, K, &params).unwrap();
        assert_eq!(kmeans.centroids.len(), K * DIM);
        assert_eq!(kmeans.dimension, DIM);
        assert_eq!(kmeans.centroids.data_type(), &DataType::UInt8);
    }

    /// `k > 256` with `hierarchical_k > 1` takes the hierarchical path and must still
    /// yield exactly `K` NaN-free centroids.
    #[tokio::test]
    async fn test_hierarchical_kmeans() {
        const DIM: usize = 64;
        const K: usize = 257; // Greater than 256 to trigger hierarchical clustering
        const NUM_VALUES: usize = 1024 * K;

        let values = generate_random_array(NUM_VALUES * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();

        let params = KMeansParams {
            max_iters: 10,
            hierarchical_k: 16,
            ..Default::default()
        };

        let kmeans = KMeans::new_with_params(&fsl, K, &params).unwrap();

        // Verify that we have the correct number of clusters
        assert_eq!(kmeans.centroids.len(), K * DIM);
        assert_eq!(kmeans.dimension, DIM);
        assert_eq!(kmeans.centroids.data_type(), &DataType::Float32);

        // Verify that all centroids are valid (not NaN)
        let centroids = kmeans.centroids.as_primitive::<Float32Type>().values();
        for val in centroids {
            assert!(!val.is_nan(), "Centroid should not contain NaN values");
        }
    }

    #[tokio::test]
    async fn test_float16_underflow_fix() {
        // This test verifies the fix for float16 division underflow
        // When training k-means on many float16 vectors with small k,
        // without limiting the data size, dividing centroids by count
        // can underflow to 0,
        // The fix limits data to k * 512 to prevent this
        const DIM: usize = 2;
        const K: usize = 2;
        const NUM_VALUES: usize = K * 65536; // Many vectors to trigger the issue

        let f32_values = generate_random_array(NUM_VALUES * DIM);
        let f16_values = Float16Array::from_iter_values(
            f32_values.values().iter().map(|&v| half::f16::from_f32(v)),
        );
        let fsl = FixedSizeListArray::try_new_from_values(f16_values, DIM as i32).unwrap();

        let params = KMeansParams {
            max_iters: 10,
            ..Default::default()
        };

        let kmeans = KMeans::new_with_params(&fsl, K, &params).unwrap();

        // Verify that we have the correct number of clusters
        assert_eq!(kmeans.centroids.len(), K * DIM);
        assert_eq!(kmeans.dimension, DIM);
        assert_eq!(kmeans.centroids.data_type(), &DataType::Float16);

        // Verify that all centroids are valid (not zero or NaN)
        // Without the fix, they would all be zero due to underflow
        let centroids = kmeans.centroids.as_primitive::<Float16Type>().values();
        for &val in centroids {
            assert!(!val.is_nan(), "Centroid should not contain NaN values");
            assert!(val != f16::ZERO);
        }
    }
}