use core::f32;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::ops::{AddAssign, DivAssign};
use std::sync::Arc;
use std::vec;
use std::{collections::HashMap, ops::MulAssign};
use arrow_array::{
Array, ArrayRef, FixedSizeListArray, Float32Array, PrimitiveArray, UInt32Array,
cast::AsArray,
types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, UInt8Type},
};
use arrow_array::{ArrowNumericType, UInt8Array};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{ArrowError, DataType};
use bitvec::prelude::*;
use lance_arrow::FixedSizeListArrayExt;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_linalg::distance::hamming::{hamming, hamming_distance_batch};
use lance_linalg::distance::{DistanceType, Normalize, dot_distance_batch};
use lance_linalg::kernels::{argmin_value_float, argmin_value_float_with_bias};
use log::{info, warn};
use num_traits::One;
use num_traits::{AsPrimitive, Float, FromPrimitive, Num, Zero};
use rand::prelude::*;
use rayon::prelude::*;
use {
lance_linalg::distance::{
Dot,
l2::{L2, l2_distance_batch},
},
lance_linalg::kernels::argmin_value,
};
use crate::vector::utils::SimpleIndex;
use crate::{Error, Result};
#[derive(Debug, PartialEq)]
pub enum KMeanInit {
Random,
Incremental(Arc<FixedSizeListArray>),
}
pub struct KMeansParams {
pub max_iters: u32,
pub tolerance: f64,
pub redos: usize,
pub init: KMeanInit,
pub distance_type: DistanceType,
pub balance_factor: f32,
pub hierarchical_k: usize,
pub on_progress: Option<Arc<dyn Fn(u32, u32) + Send + Sync>>,
}
impl std::fmt::Debug for KMeansParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("KMeansParams")
.field("max_iters", &self.max_iters)
.field("tolerance", &self.tolerance)
.field("redos", &self.redos)
.field("init", &self.init)
.field("distance_type", &self.distance_type)
.field("balance_factor", &self.balance_factor)
.field("hierarchical_k", &self.hierarchical_k)
.field("on_progress", &self.on_progress.as_ref().map(|_| "..."))
.finish()
}
}
impl Default for KMeansParams {
fn default() -> Self {
Self {
max_iters: 50,
tolerance: 1e-4,
redos: 1,
init: KMeanInit::Random,
distance_type: DistanceType::L2,
balance_factor: 0.0,
hierarchical_k: 16,
on_progress: None,
}
}
}
impl KMeansParams {
pub fn new(
centroids: Option<Arc<FixedSizeListArray>>,
max_iters: u32,
redos: usize,
distance_type: DistanceType,
) -> Self {
let init = match centroids {
Some(centroids) => KMeanInit::Incremental(centroids),
None => KMeanInit::Random,
};
Self {
max_iters,
redos,
distance_type,
init,
..Default::default()
}
}
pub fn with_balance_factor(mut self, balance_factor: f32) -> Self {
self.balance_factor = balance_factor;
self
}
pub fn with_on_progress(mut self, cb: Arc<dyn Fn(u32, u32) + Send + Sync>) -> Self {
self.on_progress = Some(cb);
self
}
pub fn with_hierarchical_k(mut self, hierarchical_k: usize) -> Self {
self.hierarchical_k = hierarchical_k;
self
}
}
fn kmeans_random_init<T: ArrowPrimitiveType>(
data: &[T::Native],
dimension: usize,
k: usize,
mut rng: impl Rng,
distance_type: DistanceType,
) -> KMeans {
assert!(data.len() >= k * dimension);
let chosen = (0..data.len() / dimension).choose_multiple(&mut rng, k);
let centroids = PrimitiveArray::<T>::from_iter_values(
chosen
.iter()
.flat_map(|&i| data[i * dimension..(i + 1) * dimension].iter())
.copied(),
);
KMeans {
centroids: Arc::new(centroids),
dimension,
distance_type,
loss: f64::MAX,
}
}
fn split_clusters<T: Float + MulAssign>(
n: usize,
cnts: &mut [usize],
centroids: &mut [T],
dim: usize,
) {
let eps = T::from(1.0 / 1024.0).unwrap();
let mut rng = SmallRng::from_os_rng();
for i in 0..cnts.len() {
if cnts[i] == 0 {
let mut j = 0;
loop {
let p = (cnts[j] as f32 - 1.0) / (n - cnts.len()) as f32;
if rng.random::<f32>() < p {
break;
}
j += 1;
j %= cnts.len();
}
cnts[i] = cnts[j] / 2;
cnts[j] -= cnts[i];
for k in 0..dim {
if k % 2 == 0 {
centroids[i * dim + k] = centroids[j * dim + k] * (T::one() + eps);
centroids[j * dim + k] *= T::one() - eps;
} else {
centroids[i * dim + k] = centroids[j * dim + k] * (T::one() - eps);
centroids[j * dim + k] *= T::one() + eps;
}
}
}
}
}
fn compute_cluster_sizes(
membership: &[Option<u32>],
radius: &[f32],
losses: &[f64],
cluster_sizes: &mut [usize],
) -> f32 {
cluster_sizes.fill(0);
let mut max_cluster_id = 0;
let mut max_cluster_size = 0;
membership.iter().for_each(|cluster_id| {
if let Some(cluster_id) = cluster_id {
let cluster_id = *cluster_id as usize;
cluster_sizes[cluster_id] += 1;
if cluster_sizes[cluster_id] > max_cluster_size {
max_cluster_size = cluster_sizes[cluster_id];
max_cluster_id = cluster_id;
}
}
});
(radius[max_cluster_id] - losses[max_cluster_id] as f32 / cluster_sizes[max_cluster_id] as f32)
/ membership.len() as f32
}
fn compute_balance_loss(cluster_sizes: &[usize], n: usize, balance_factor: f32) -> f32 {
let size_loss = cluster_sizes.iter().map(|size| size.pow(2)).sum::<usize>() as f32;
balance_factor * (size_loss - n.pow(2) as f32 / cluster_sizes.len() as f32)
}
pub trait KMeansAlgo<T: Num> {
fn compute_membership_and_loss(
centroids: &[T],
data: &[T],
dimension: usize,
distance_type: DistanceType,
balance_factor: f32,
cluster_sizes: Option<&[usize]>,
index: Option<&SimpleIndex>,
) -> (Vec<Option<u32>>, Vec<f32>, Vec<f64>) {
let (membership, dists) = Self::compute_membership_and_dist(
centroids,
data,
dimension,
distance_type,
balance_factor,
cluster_sizes,
index,
);
let k = centroids.len() / dimension;
let mut cluster_radius = vec![0.0; k];
let mut losses = vec![0.0; k];
for (cluster_id, dist) in membership.iter().zip(dists.iter()) {
if let (Some(cluster_id), Some(dist)) = (cluster_id, dist) {
let cluster_id = *cluster_id as usize;
cluster_radius[cluster_id] = cluster_radius[cluster_id].max(*dist);
losses[cluster_id] += *dist as f64;
}
}
(membership, cluster_radius, losses)
}
fn compute_membership_and_dist(
centroids: &[T],
data: &[T],
dimension: usize,
distance_type: DistanceType,
balance_factor: f32,
cluster_sizes: Option<&[usize]>,
index: Option<&SimpleIndex>,
) -> (Vec<Option<u32>>, Vec<Option<f32>>);
fn to_kmeans(
data: &[T],
dimension: usize,
k: usize,
membership: &[Option<u32>],
cluster_sizes: &mut [usize],
distance_type: DistanceType,
loss: f64,
) -> KMeans;
}
pub struct KMeansAlgoFloat<T: ArrowNumericType>
where
T::Native: Float + Num,
{
phantom_data: std::marker::PhantomData<T>,
}
impl<T: ArrowNumericType> KMeansAlgo<T::Native> for KMeansAlgoFloat<T>
where
T::Native: Float + Dot + L2 + MulAssign + DivAssign + AddAssign + FromPrimitive + Sync,
PrimitiveArray<T>: From<Vec<T::Native>>,
{
fn compute_membership_and_dist(
centroids: &[T::Native],
data: &[T::Native],
dimension: usize,
distance_type: DistanceType,
balance_factor: f32,
cluster_sizes: Option<&[usize]>,
index: Option<&SimpleIndex>,
) -> (Vec<Option<u32>>, Vec<Option<f32>>) {
let cluster_and_dists = match index {
Some(index) => data
.par_chunks(dimension)
.map(|vec| {
let query = PrimitiveArray::<T>::from_iter_values(vec.iter().copied());
index
.search(Arc::new(query))
.map(|(id, dist)| Some((id, dist)))
.unwrap()
})
.collect::<Vec<_>>(),
None => match distance_type {
DistanceType::L2 => data
.par_chunks(dimension)
.map(|vec| {
argmin_value_float_with_bias(
l2_distance_batch(vec, centroids, dimension),
cluster_sizes
.map(|size| size.iter().map(|size| balance_factor * *size as f32)),
)
})
.collect::<Vec<_>>(),
DistanceType::Dot => data
.par_chunks(dimension)
.map(|vec| {
argmin_value_float_with_bias(
dot_distance_batch(vec, centroids, dimension),
cluster_sizes
.map(|size| size.iter().map(|size| balance_factor * *size as f32)),
)
})
.collect::<Vec<_>>(),
_ => {
panic!(
"KMeans::find_partitions: {} is not supported",
distance_type
);
}
},
};
cluster_and_dists.into_iter().map(Option::unzip).unzip()
}
fn to_kmeans(
data: &[T::Native],
dimension: usize,
k: usize,
membership: &[Option<u32>],
cluster_sizes: &mut [usize],
distance_type: DistanceType,
loss: f64,
) -> KMeans {
let mut centroids = vec![T::Native::zero(); k * dimension];
let mut num_cpus = get_num_compute_intensive_cpus();
if k < num_cpus || k < 16 {
num_cpus = 1;
}
let chunk_size = k / num_cpus;
centroids
.par_chunks_mut(dimension * chunk_size)
.enumerate()
.with_max_len(1)
.for_each(|(i, centroids)| {
let start = i * chunk_size;
let end = ((i + 1) * chunk_size).min(k);
data.chunks(dimension)
.zip(membership.iter())
.filter_map(|(vector, cluster_id)| {
cluster_id.map(|cluster_id| (vector, cluster_id as usize))
})
.for_each(|(vector, cluster_id)| {
if start <= cluster_id && cluster_id < end {
let local_id = cluster_id - start;
let centroid =
&mut centroids[local_id * dimension..(local_id + 1) * dimension];
centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v);
}
});
});
centroids
.par_chunks_mut(dimension)
.zip(cluster_sizes.par_iter())
.for_each(|(centroid, &cnt)| {
if cnt > 0 {
let norm = T::Native::one() / T::Native::from_usize(cnt).unwrap();
centroid.iter_mut().for_each(|v| *v *= norm);
}
});
let empty_clusters = cluster_sizes.iter().filter(|&cnt| *cnt == 0).count();
if empty_clusters as f32 / k as f32 > 0.1 {
if data.len() / dimension < k * 256 {
warn!(
"KMeans: more than 10% of clusters are empty: {} of {}.\nHelp: this could mean your dataset \
is too small to have a meaningful index ({} < {}) or has many duplicate vectors.",
empty_clusters,
k,
data.len() / dimension,
k * 256
);
} else {
warn!(
"KMeans: more than 10% of clusters are empty: {} of {}.\nHelp: this could mean your dataset \
has many duplicate vectors.",
empty_clusters, k
);
}
}
split_clusters(
data.len() / dimension,
cluster_sizes,
&mut centroids,
dimension,
);
KMeans {
centroids: Arc::new(PrimitiveArray::<T>::from(centroids)),
dimension,
distance_type,
loss,
}
}
}
struct KModeAlgo {}
impl KMeansAlgo<u8> for KModeAlgo {
fn compute_membership_and_dist(
centroids: &[u8],
data: &[u8],
dimension: usize,
distance_type: DistanceType,
balance_factor: f32,
cluster_sizes: Option<&[usize]>,
_: Option<&SimpleIndex>,
) -> (Vec<Option<u32>>, Vec<Option<f32>>) {
assert_eq!(distance_type, DistanceType::Hamming);
let cluster_and_dists = data
.par_chunks(dimension)
.map(|vec| {
argmin_value(
centroids
.chunks_exact(dimension)
.enumerate()
.map(|(id, c)| {
hamming(vec, c)
+ balance_factor
* cluster_sizes.map(|sizes| sizes[id] as f32).unwrap_or(0.0)
}),
)
})
.collect::<Vec<_>>();
cluster_and_dists.into_iter().map(Option::unzip).unzip()
}
fn to_kmeans(
data: &[u8],
dimension: usize,
k: usize,
membership: &[Option<u32>],
_cluster_sizes: &mut [usize],
distance_type: DistanceType,
loss: f64,
) -> KMeans {
assert_eq!(distance_type, DistanceType::Hamming);
let mut clusters = HashMap::<u32, Vec<usize>>::new();
membership.iter().enumerate().for_each(|(i, part_id)| {
if let Some(part_id) = part_id {
clusters.entry(*part_id).or_default().push(i);
}
});
let centroids = (0..k as u32)
.into_par_iter()
.flat_map(|part_id| {
if let Some(vecs) = clusters.get(&part_id) {
let mut ones = vec![0_u32; dimension * 8];
let cnt = vecs.len() as u32;
vecs.iter().for_each(|&i| {
let vec = &data[i * dimension..(i + 1) * dimension];
ones.iter_mut()
.zip(vec.view_bits::<Lsb0>())
.for_each(|(c, v)| {
if *v.as_ref() {
*c += 1;
}
});
});
let bits = ones.iter().map(|&c| c * 2 > cnt).collect::<BitVec<u8>>();
bits.as_raw_slice()
.iter()
.copied()
.map(Some)
.collect::<Vec<_>>()
} else {
vec![None; dimension]
}
})
.collect::<Vec<_>>();
KMeans {
centroids: Arc::new(UInt8Array::from(centroids)),
dimension,
distance_type,
loss,
}
}
}
#[derive(Debug, Clone)]
pub struct KMeans {
pub centroids: ArrayRef,
pub dimension: usize,
pub distance_type: DistanceType,
pub loss: f64,
}
impl KMeans {
fn empty(dimension: usize, distance_type: DistanceType) -> Self {
Self {
centroids: arrow_array::array::new_empty_array(&DataType::Float32),
dimension,
distance_type,
loss: f64::MAX,
}
}
pub fn with_centroids(
centroids: ArrayRef,
dimension: usize,
distance_type: DistanceType,
loss: f64,
) -> Self {
assert!(matches!(
centroids.data_type(),
DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8
));
Self {
centroids,
dimension,
distance_type,
loss,
}
}
fn init_random<T: ArrowPrimitiveType>(
data: &[T::Native],
dimension: usize,
k: usize,
rng: impl Rng,
distance_type: DistanceType,
) -> Self {
kmeans_random_init::<T>(data, dimension, k, rng, distance_type)
}
pub fn new(data: &FixedSizeListArray, k: usize, max_iters: u32) -> arrow::error::Result<Self> {
let params = KMeansParams {
max_iters,
distance_type: DistanceType::L2,
..Default::default()
};
Self::new_with_params(data, k, ¶ms)
}
fn train_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
data: &FixedSizeListArray,
k: usize,
params: &KMeansParams,
) -> arrow::error::Result<Self>
where
T::Native: Num,
{
let data = if data.len() >= k * 512 {
data.slice(0, k * 512)
} else {
data.clone()
};
let n = data.len();
let dimension = data.value_length() as usize;
let data =
data.values()
.as_primitive_opt::<T>()
.ok_or(ArrowError::InvalidArgumentError(format!(
"KMeans: data must be {}, got: {}",
T::DATA_TYPE,
data.value_type()
)))?;
let mut best_kmeans = Self::empty(dimension, params.distance_type);
let mut cluster_sizes = vec![0; k];
let mut adjusted_balance_factor = f32::MAX;
let rng = SmallRng::from_os_rng();
for redo in 1..=params.redos {
let mut kmeans: Self = match ¶ms.init {
KMeanInit::Random => Self::init_random::<T>(
data.values(),
dimension,
k,
rng.clone(),
params.distance_type,
),
KMeanInit::Incremental(centroids) => Self::with_centroids(
centroids.values().clone(),
dimension,
params.distance_type,
f64::MAX,
),
};
let mut loss = f64::MAX;
for i in 1..=params.max_iters {
if let Some(cb) = ¶ms.on_progress {
cb(i, params.max_iters);
}
if i % 10 == 0 {
info!(
"KMeans training: iteration {} / {}, redo={}",
i, params.max_iters, redo
);
};
let index = SimpleIndex::may_train_index(
kmeans.centroids.clone(),
kmeans.dimension,
kmeans.distance_type,
)?;
let balance_factor = adjusted_balance_factor.min(params.balance_factor);
let (membership, radius, losses) = Algo::compute_membership_and_loss(
kmeans.centroids.as_primitive::<T>().values(),
data.values(),
dimension,
params.distance_type,
balance_factor,
Some(&cluster_sizes),
index.as_ref(),
);
adjusted_balance_factor =
compute_cluster_sizes(&membership, &radius, &losses, &mut cluster_sizes);
let balance_loss = compute_balance_loss(&cluster_sizes, n, balance_factor);
let last_loss = losses.iter().sum::<f64>() + balance_loss as f64;
kmeans = Algo::to_kmeans(
data.values(),
dimension,
k,
&membership,
&mut cluster_sizes,
params.distance_type,
last_loss,
);
if (loss - last_loss).abs() < params.tolerance * last_loss {
info!(
"KMeans training: converged at iteration {} / {}, redo={}, loss={}, last_loss={}, loss_diff={}",
i,
params.max_iters,
redo,
loss,
last_loss,
(loss - last_loss).abs() / last_loss
);
break;
}
loss = last_loss;
}
if kmeans.loss < best_kmeans.loss {
best_kmeans = kmeans;
}
}
Ok(best_kmeans)
}
fn create_array_from_indices<T: ArrowNumericType>(
indices: &[usize],
data_values: &[T::Native],
dimension: usize,
) -> arrow::error::Result<FixedSizeListArray>
where
T::Native: Clone,
PrimitiveArray<T>: From<Vec<T::Native>>,
{
let mut subset_data = Vec::with_capacity(indices.len() * dimension);
for &idx in indices {
let start = idx * dimension;
let end = start + dimension;
subset_data.extend_from_slice(&data_values[start..end]);
}
let array = PrimitiveArray::<T>::from(subset_data);
FixedSizeListArray::try_new_from_values(array, dimension as i32)
}
fn train_hierarchical_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
data: &FixedSizeListArray,
target_k: usize,
params: &KMeansParams,
) -> arrow::error::Result<Self>
where
T::Native: Num,
PrimitiveArray<T>: From<Vec<T::Native>>,
{
#[derive(Clone, Debug)]
struct Cluster<N> {
id: usize,
indices: Vec<usize>,
centroid: Vec<N>,
finalized: bool,
}
impl<N> Eq for Cluster<N> {}
impl<N> PartialEq for Cluster<N> {
fn eq(&self, other: &Self) -> bool {
self.indices.len() == other.indices.len()
}
}
impl<N> Ord for Cluster<N> {
fn cmp(&self, other: &Self) -> Ordering {
match (self.finalized, other.finalized) {
(false, true) => Ordering::Greater,
(true, false) => Ordering::Less,
_ => {
self.indices.len().cmp(&other.indices.len())
}
}
}
}
impl<N> PartialOrd for Cluster<N> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
let n = data.len();
let dimension = data.value_length() as usize;
let data_values = data
.values()
.as_primitive_opt::<T>()
.ok_or(ArrowError::InvalidArgumentError(format!(
"KMeans: data must be {}, got: {}",
T::DATA_TYPE,
data.value_type()
)))?
.values();
let initial_k = params.hierarchical_k.min(target_k).min(n);
info!(
"Hierarchical clustering: initial k={}, target k={}",
initial_k, target_k
);
let initial_kmeans = Self::train_kmeans::<T, Algo>(data, initial_k, params)?;
let (membership, _, _) = Algo::compute_membership_and_loss(
initial_kmeans.centroids.as_primitive::<T>().values(),
data_values,
dimension,
params.distance_type,
0.0, None,
None,
);
let mut heap: BinaryHeap<Cluster<T::Native>> = BinaryHeap::new();
let mut next_cluster_id = 0;
let initial_centroids = initial_kmeans.centroids.as_primitive::<T>().values();
for i in 0..initial_k {
let mut cluster_indices = Vec::new();
for (idx, &cluster_id) in membership.iter().enumerate() {
if let Some(cid) = cluster_id
&& cid as usize == i
{
cluster_indices.push(idx);
}
}
if !cluster_indices.is_empty() {
let centroid_start = i * dimension;
let centroid_end = centroid_start + dimension;
let centroid = initial_centroids[centroid_start..centroid_end].to_vec();
heap.push(Cluster {
id: next_cluster_id,
indices: cluster_indices,
centroid,
finalized: false,
});
next_cluster_id += 1;
}
}
while heap.len() < target_k {
let mut largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError(
"No cluster can be further split".to_string(),
))?;
if largest_cluster.finalized {
log::warn!(
"Cluster {} is already finalized, no further split is possible, finish with {} clusters",
largest_cluster.id,
heap.len() + 1
);
heap.push(largest_cluster);
break;
}
if largest_cluster.indices.len() <= 1 {
log::warn!(
"Cluster {} has only 1 point, no further split is possible, finish with {} clusters",
largest_cluster.id,
heap.len() + 1
);
heap.push(largest_cluster);
break;
}
let cluster_size = largest_cluster.indices.len();
log::debug!(
"Splitting cluster {} with {} points (current total clusters: {})",
largest_cluster.id,
cluster_size,
heap.len() + 1 );
let remaining_k = target_k - heap.len(); let cluster_k = if cluster_size <= params.hierarchical_k {
2.min(remaining_k).min(cluster_size)
} else {
let suggested_k = cluster_size / params.hierarchical_k;
suggested_k
.min(remaining_k)
.min(params.hierarchical_k)
.max(2)
};
let sub_data = Self::create_array_from_indices::<T>(
&largest_cluster.indices,
data_values,
dimension,
)?;
let sub_kmeans = Self::train_kmeans::<T, Algo>(&sub_data, cluster_k, params)?;
let sub_data = sub_data.values().as_primitive::<T>().values();
let (sub_membership, _, _) = Algo::compute_membership_and_loss(
sub_kmeans.centroids.as_primitive::<T>().values(),
sub_data,
dimension,
params.distance_type,
0.0,
None,
None,
);
let approx_cluster_capacity = if cluster_k > 0 {
largest_cluster.indices.len().div_ceil(cluster_k)
} else {
0
};
let mut cluster_assignments: Vec<Vec<usize>> = (0..cluster_k)
.map(|_| Vec::with_capacity(approx_cluster_capacity))
.collect();
let mut first_sid: Option<u32> = None;
let mut all_same = true;
for (local_idx, &membership) in sub_membership.iter().enumerate() {
let Some(sub_cluster_id) = membership else {
continue;
};
if let Some(first) = first_sid {
if sub_cluster_id != first {
all_same = false;
}
} else {
first_sid = Some(sub_cluster_id);
}
let sub_cluster_id = sub_cluster_id as usize;
if let Some(indices) = cluster_assignments.get_mut(sub_cluster_id) {
indices.push(largest_cluster.indices[local_idx]);
} else {
all_same = false;
}
}
if all_same {
largest_cluster.finalized = true;
heap.push(largest_cluster);
continue;
}
let sub_centroids = sub_kmeans.centroids.as_primitive::<T>().values();
for (i, new_cluster_indices) in cluster_assignments.into_iter().enumerate() {
if new_cluster_indices.is_empty() {
continue;
}
let centroid_start = i * dimension;
let centroid_end = centroid_start + dimension;
let centroid = sub_centroids[centroid_start..centroid_end].to_vec();
heap.push(Cluster {
id: next_cluster_id,
indices: new_cluster_indices,
centroid,
finalized: false,
});
next_cluster_id += 1;
}
log::debug!(
"Split complete: now have {} clusters (target: {})",
heap.len(),
target_k
);
}
debug_assert_eq!(heap.len(), target_k);
let mut all_clusters: Vec<Cluster<T::Native>> = heap.into_vec();
all_clusters.sort_by_key(|c| c.id);
let flat_centroids: Vec<T::Native> =
all_clusters.into_iter().flat_map(|c| c.centroid).collect();
let centroids_array = PrimitiveArray::<T>::from(flat_centroids);
Ok(Self {
centroids: Arc::new(centroids_array),
dimension,
distance_type: params.distance_type,
loss: 0.0, })
}
pub fn new_with_params(
data: &FixedSizeListArray,
k: usize,
params: &KMeansParams,
) -> arrow::error::Result<Self> {
let n = data.len();
if n < k {
return Err(ArrowError::InvalidArgumentError(format!(
"KMeans: training does not have sufficient data points: n({}) is smaller than k({})",
n, k
)));
}
if k > 256 && params.hierarchical_k > 1 {
log::debug!("Using hierarchical clustering for k={}", k);
return match (data.value_type(), params.distance_type) {
(DataType::Float16, _) => Self::train_hierarchical_kmeans::<
Float16Type,
KMeansAlgoFloat<Float16Type>,
>(data, k, params),
(DataType::Float32, _) => Self::train_hierarchical_kmeans::<
Float32Type,
KMeansAlgoFloat<Float32Type>,
>(data, k, params),
(DataType::Float64, _) => Self::train_hierarchical_kmeans::<
Float64Type,
KMeansAlgoFloat<Float64Type>,
>(data, k, params),
(DataType::UInt8, DistanceType::Hamming) => {
Self::train_hierarchical_kmeans::<UInt8Type, KModeAlgo>(data, k, params)
}
_ => Err(ArrowError::InvalidArgumentError(format!(
"KMeans: can not train data type {} with distance type: {}",
data.value_type(),
params.distance_type
))),
};
}
match (data.value_type(), params.distance_type) {
(DataType::Float16, _) => {
Self::train_kmeans::<Float16Type, KMeansAlgoFloat<Float16Type>>(data, k, params)
}
(DataType::Float32, _) => {
Self::train_kmeans::<Float32Type, KMeansAlgoFloat<Float32Type>>(data, k, params)
}
(DataType::Float64, _) => {
Self::train_kmeans::<Float64Type, KMeansAlgoFloat<Float64Type>>(data, k, params)
}
(DataType::UInt8, DistanceType::Hamming) => {
Self::train_kmeans::<UInt8Type, KModeAlgo>(data, k, params)
}
_ => Err(ArrowError::InvalidArgumentError(format!(
"KMeans: can not train data type {} with distance type: {}",
data.value_type(),
params.distance_type
))),
}
}
}
pub fn kmeans_find_partitions_arrow_array(
centroids: &FixedSizeListArray,
query: &dyn Array,
nprobes: usize,
distance_type: DistanceType,
) -> arrow::error::Result<(UInt32Array, Float32Array)> {
if centroids.value_length() as usize != query.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"Centroids and vectors have different dimensions: {} != {}",
centroids.value_length(),
query.len()
)));
}
match (centroids.value_type(), query.data_type()) {
(DataType::Float16, DataType::Float16) => Ok(kmeans_find_partitions(
centroids.values().as_primitive::<Float16Type>().values(),
query.as_primitive::<Float16Type>().values(),
nprobes,
distance_type,
)?),
(DataType::Float32, DataType::Float32) => Ok(kmeans_find_partitions(
centroids.values().as_primitive::<Float32Type>().values(),
query.as_primitive::<Float32Type>().values(),
nprobes,
distance_type,
)?),
(DataType::Float64, DataType::Float64) => Ok(kmeans_find_partitions(
centroids.values().as_primitive::<Float64Type>().values(),
query.as_primitive::<Float64Type>().values(),
nprobes,
distance_type,
)?),
(DataType::UInt8, DataType::UInt8) => Ok(kmeans_find_partitions_binary(
centroids.values().as_primitive::<UInt8Type>().values(),
query.as_primitive::<UInt8Type>().values(),
nprobes,
distance_type,
)?),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Centroids and vectors have different types: {} != {}",
centroids.value_type(),
query.data_type()
))),
}
}
pub fn kmeans_find_partitions<T: Float + L2 + Dot>(
centroids: &[T],
query: &[T],
nprobes: usize,
distance_type: DistanceType,
) -> arrow::error::Result<(UInt32Array, Float32Array)> {
let dists: Vec<f32> = match distance_type {
DistanceType::L2 => l2_distance_batch(query, centroids, query.len()).collect(),
DistanceType::Dot => dot_distance_batch(query, centroids, query.len()).collect(),
_ => {
panic!(
"KMeans::find_partitions: {} is not supported",
distance_type
);
}
};
let dists_arr = Float32Array::from(dists);
let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?;
let dists = arrow::compute::take(&dists_arr, &indices, None)?
.as_primitive::<Float32Type>()
.clone();
Ok((indices, dists))
}
pub fn kmeans_find_partitions_binary(
centroids: &[u8],
query: &[u8],
nprobes: usize,
distance_type: DistanceType,
) -> arrow::error::Result<(UInt32Array, Float32Array)> {
let dists: Vec<f32> = match distance_type {
DistanceType::Hamming => hamming_distance_batch(query, centroids, query.len()).collect(),
_ => {
panic!(
"KMeans::find_partitions: {} is not supported",
distance_type
);
}
};
let dists_arr = Float32Array::from(dists);
let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?;
let dists = arrow::compute::take(&dists_arr, &indices, None)?
.as_primitive::<Float32Type>()
.clone();
Ok((indices, dists))
}
#[allow(clippy::type_complexity)]
pub fn compute_partitions_arrow_array(
centroids: &FixedSizeListArray,
vectors: &FixedSizeListArray,
distance_type: DistanceType,
) -> arrow::error::Result<(Vec<Option<u32>>, Vec<Option<f32>>)> {
if centroids.value_length() != vectors.value_length() {
return Err(ArrowError::InvalidArgumentError(
"Centroids and vectors have different dimensions".to_string(),
));
}
match (centroids.value_type(), vectors.value_type()) {
(DataType::Float16, DataType::Float16) => Ok(compute_partitions_with_dists::<
Float16Type,
KMeansAlgoFloat<Float16Type>,
>(
centroids.values().as_primitive(),
vectors.values().as_primitive(),
centroids.value_length(),
distance_type,
)),
(DataType::Float32, DataType::Float32) => Ok(compute_partitions_with_dists::<
Float32Type,
KMeansAlgoFloat<Float32Type>,
>(
centroids.values().as_primitive(),
vectors.values().as_primitive(),
centroids.value_length(),
distance_type,
)),
(DataType::Float32, DataType::Int8) => Ok(compute_partitions_with_dists::<
Float32Type,
KMeansAlgoFloat<Float32Type>,
>(
centroids.values().as_primitive(),
vectors.convert_to_floating_point()?.values().as_primitive(),
centroids.value_length(),
distance_type,
)),
(DataType::Float64, DataType::Float64) => Ok(compute_partitions_with_dists::<
Float64Type,
KMeansAlgoFloat<Float64Type>,
>(
centroids.values().as_primitive(),
vectors.values().as_primitive(),
centroids.value_length(),
distance_type,
)),
(DataType::UInt8, DataType::UInt8) => {
Ok(compute_partitions_with_dists::<UInt8Type, KModeAlgo>(
centroids.values().as_primitive(),
vectors.values().as_primitive(),
centroids.value_length(),
distance_type,
))
}
_ => Err(ArrowError::InvalidArgumentError(
"Centroids and vectors have incompatible types".to_string(),
)),
}
}
pub fn compute_partitions<T: ArrowNumericType, K: KMeansAlgo<T::Native>>(
centroids: &PrimitiveArray<T>,
vectors: &PrimitiveArray<T>,
dimension: impl AsPrimitive<usize>,
distance_type: DistanceType,
) -> (Vec<Option<u32>>, f64)
where
T::Native: Num,
{
let dimension = dimension.as_();
let (membership, _, losses) = K::compute_membership_and_loss(
centroids.values(),
vectors.values(),
dimension,
distance_type,
0.0,
None,
None,
);
(membership, losses.iter().sum::<f64>())
}
pub fn compute_partitions_with_dists<T: ArrowNumericType, K: KMeansAlgo<T::Native>>(
centroids: &PrimitiveArray<T>,
vectors: &PrimitiveArray<T>,
dimension: impl AsPrimitive<usize>,
distance_type: DistanceType,
) -> (Vec<Option<u32>>, Vec<Option<f32>>)
where
T::Native: Num,
{
let dimension = dimension.as_();
K::compute_membership_and_dist(
centroids.values(),
vectors.values(),
dimension,
distance_type,
0.0,
None,
None,
)
}
#[allow(clippy::too_many_arguments)]
pub fn train_kmeans<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
mut params: KMeansParams,
dimension: usize,
k: usize,
sample_rate: usize,
) -> Result<KMeans>
where
T::Native: Dot + L2 + Normalize,
PrimitiveArray<T>: From<Vec<T::Native>>,
{
let num_rows = array.len() / dimension;
if num_rows < k {
return Err(Error::unprocessable(format!(
"KMeans cannot train {k} centroids with {num_rows} vectors; choose a smaller K (< {num_rows})"
)));
}
let data = if num_rows > sample_rate * k {
log::info!(
"Sample {} out of {} to train kmeans of {} dim, {} clusters",
sample_rate * k,
array.len() / dimension,
dimension,
k,
);
let sample_size = sample_rate * k;
array.slice(0, sample_size * dimension)
} else {
array.clone()
};
let data = FixedSizeListArray::try_new_from_values(data, dimension as i32)?;
params.balance_factor /= data.len() as f32;
let model = KMeans::new_with_params(&data, k, ¶ms)?;
Ok(model)
}
#[inline]
pub fn compute_partition<T: Float + L2 + Dot>(
centroids: &[T],
vector: &[T],
distance_type: DistanceType,
) -> Option<u32> {
match distance_type {
DistanceType::L2 => {
argmin_value_float(l2_distance_batch(vector, centroids, vector.len())).map(|(c, _)| c)
}
DistanceType::Dot => {
argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|(c, _)| c)
}
_ => {
panic!(
"KMeans::compute_partition: distance type {} is not supported",
distance_type
);
}
}
}
#[cfg(test)]
mod tests {
use std::iter::repeat_n;
use arrow_array::Float16Array;
use arrow_array::types::Float16Type;
use half::f16;
use lance_arrow::*;
use lance_testing::datagen::generate_random_array;
use super::*;
use lance_linalg::distance::l2;
use lance_linalg::kernels::argmin;
#[test]
fn test_train_with_small_dataset() {
let data = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
let data = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
match KMeans::new(&data, 128, 5) {
Ok(_) => panic!("Should fail to train KMeans"),
Err(e) => {
assert!(e.to_string().contains("smaller than"));
}
}
}
#[test]
fn test_compute_partitions() {
const DIM: usize = 256;
let centroids = generate_random_array(DIM * 18);
let data = generate_random_array(DIM * 20);
let expected = data
.values()
.chunks(DIM)
.map(|row| {
argmin(
centroids
.values()
.chunks(DIM)
.map(|centroid| l2(row, centroid)),
)
})
.collect::<Vec<_>>();
let (actual, _) = compute_partitions::<Float32Type, KMeansAlgoFloat<Float32Type>>(
¢roids,
&data,
DIM,
DistanceType::L2,
);
assert_eq!(expected, actual);
}
#[tokio::test]
async fn test_compute_membership_and_loss() {
const DIM: usize = 256;
let centroids = generate_random_array(DIM * 18);
let data = generate_random_array(DIM * 20);
let (membership, _, losses) = KMeansAlgoFloat::<Float32Type>::compute_membership_and_loss(
centroids.as_slice(),
data.values(),
DIM,
DistanceType::L2,
0.0,
None,
None,
);
let loss = losses.iter().sum::<f64>();
assert!(loss > 0.0, "loss is not zero: {}", loss);
membership.iter().for_each(|cd| {
assert!(cd.is_some());
});
}
#[tokio::test]
async fn test_l2_with_nans() {
const DIM: usize = 8;
const K: usize = 32;
const NUM_CENTROIDS: usize = 16 * 2048;
let centroids = generate_random_array(DIM * NUM_CENTROIDS);
let values = Float32Array::from_iter_values(repeat_n(f32::NAN, DIM * K));
compute_partitions::<Float32Type, KMeansAlgoFloat<Float32Type>>(
¢roids,
&values,
DIM,
DistanceType::L2,
)
.0
.iter()
.for_each(|cd| {
assert!(cd.is_none());
});
}
#[tokio::test]
async fn test_train_l2_kmeans_with_nans() {
const DIM: usize = 8;
const K: usize = 32;
const NUM_CENTROIDS: usize = 16 * 2048;
let centroids = generate_random_array(DIM * NUM_CENTROIDS);
let values = repeat_n(f32::NAN, DIM * K).collect::<Vec<_>>();
let (membership, _, _) = KMeansAlgoFloat::<Float32Type>::compute_membership_and_loss(
centroids.as_slice(),
&values,
DIM,
DistanceType::L2,
0.0,
None,
None,
);
membership.iter().for_each(|cd| assert!(cd.is_none()));
}
#[tokio::test]
async fn test_train_kmode() {
const DIM: usize = 16;
const K: usize = 32;
const NUM_VALUES: usize = 256 * K;
let mut rng = SmallRng::from_os_rng();
let values =
UInt8Array::from_iter_values((0..NUM_VALUES * DIM).map(|_| rng.random_range(0..255)));
let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
let params = KMeansParams {
distance_type: DistanceType::Hamming,
..Default::default()
};
let kmeans = KMeans::new_with_params(&fsl, K, ¶ms).unwrap();
assert_eq!(kmeans.centroids.len(), K * DIM);
assert_eq!(kmeans.dimension, DIM);
assert_eq!(kmeans.centroids.data_type(), &DataType::UInt8);
}
#[tokio::test]
async fn test_hierarchical_kmeans() {
const DIM: usize = 64;
const K: usize = 257; const NUM_VALUES: usize = 1024 * K;
let values = generate_random_array(NUM_VALUES * DIM);
let fsl = FixedSizeListArray::try_new_from_values(values, DIM as i32).unwrap();
let params = KMeansParams {
max_iters: 10,
hierarchical_k: 16,
..Default::default()
};
let kmeans = KMeans::new_with_params(&fsl, K, ¶ms).unwrap();
assert_eq!(kmeans.centroids.len(), K * DIM);
assert_eq!(kmeans.dimension, DIM);
assert_eq!(kmeans.centroids.data_type(), &DataType::Float32);
let centroids = kmeans.centroids.as_primitive::<Float32Type>().values();
for val in centroids {
assert!(!val.is_nan(), "Centroid should not contain NaN values");
}
}
#[tokio::test]
async fn test_float16_underflow_fix() {
const DIM: usize = 2;
const K: usize = 2;
const NUM_VALUES: usize = K * 65536;
let f32_values = generate_random_array(NUM_VALUES * DIM);
let f16_values = Float16Array::from_iter_values(
f32_values.values().iter().map(|&v| half::f16::from_f32(v)),
);
let fsl = FixedSizeListArray::try_new_from_values(f16_values, DIM as i32).unwrap();
let params = KMeansParams {
max_iters: 10,
..Default::default()
};
let kmeans = KMeans::new_with_params(&fsl, K, ¶ms).unwrap();
assert_eq!(kmeans.centroids.len(), K * DIM);
assert_eq!(kmeans.dimension, DIM);
assert_eq!(kmeans.centroids.data_type(), &DataType::Float16);
let centroids = kmeans.centroids.as_primitive::<Float16Type>().values();
for &val in centroids {
assert!(!val.is_nan(), "Centroid should not contain NaN values");
assert!(val != f16::ZERO);
}
}
}