use std::collections::{HashMap, HashSet};
use crate::distance::scalar::scalar_distance;
use nodedb_types::vector_distance::DistanceMetric;
use super::storage::MultiVectorStore;
fn assign(vectors: &[Vec<f32>], centroids: &[Vec<f32>]) -> Vec<usize> {
vectors
.iter()
.map(|v| {
centroids
.iter()
.enumerate()
.map(|(i, c)| (i, scalar_distance(v, c, DistanceMetric::L2)))
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
})
.collect()
}
fn lcg_next(s: &mut u64) -> u64 {
*s = s
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
*s
}
fn min_dist_to_centroids(v: &[f32], centroids: &[Vec<f32>]) -> f32 {
centroids
.iter()
.map(|c| scalar_distance(v, c, DistanceMetric::L2))
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(f32::INFINITY)
}
fn recompute(
vectors: &[Vec<f32>],
assignments: &[usize],
num_centroids: usize,
dim: usize,
prev_centroids: &[Vec<f32>],
) -> Vec<Vec<f32>> {
let mut sums = vec![vec![0.0f32; dim]; num_centroids];
let mut counts = vec![0usize; num_centroids];
for (v, &c) in vectors.iter().zip(assignments.iter()) {
for (s, x) in sums[c].iter_mut().zip(v.iter()) {
*s += x;
}
counts[c] += 1;
}
for (s, &n) in sums.iter_mut().zip(counts.iter()) {
if n > 0 {
s.iter_mut().for_each(|x| *x /= n as f32);
}
}
for c_idx in 0..num_centroids {
if counts[c_idx] != 0 {
continue;
}
let live: Vec<Vec<f32>> = counts
.iter()
.enumerate()
.filter(|(i, cnt)| *i != c_idx && **cnt > 0)
.map(|(i, _)| sums[i].clone())
.collect();
let seed_pool: &[Vec<f32>] = if live.is_empty() {
prev_centroids
} else {
&live
};
let farthest = vectors
.iter()
.map(|v| (v, min_dist_to_centroids(v, seed_pool)))
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(v, _)| v.clone());
if let Some(v) = farthest {
sums[c_idx] = v;
counts[c_idx] = 1; } else if c_idx < prev_centroids.len() {
sums[c_idx] = prev_centroids[c_idx].clone();
counts[c_idx] = 1;
}
}
sums
}
fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize, seed: u64) -> Vec<Vec<f32>> {
let mut state = seed.wrapping_add(1);
let first = (lcg_next(&mut state) as usize) % vectors.len();
let mut centroids: Vec<Vec<f32>> = vec![vectors[first].clone()];
while centroids.len() < k {
let dists: Vec<f32> = vectors
.iter()
.map(|v| {
let d = min_dist_to_centroids(v, ¢roids);
d * d
})
.collect();
let total: f64 = dists.iter().map(|&d| d as f64).sum();
if total <= 0.0 {
let idx = (lcg_next(&mut state) as usize) % vectors.len();
centroids.push(vectors[idx].clone());
continue;
}
let r = (lcg_next(&mut state) as f64) / (u64::MAX as f64) * total;
let mut acc = 0.0f64;
let mut pick = vectors.len() - 1;
for (i, &d) in dists.iter().enumerate() {
acc += d as f64;
if acc >= r {
pick = i;
break;
}
}
centroids.push(vectors[pick].clone());
}
centroids
}
fn kmeans(
vectors: &[Vec<f32>],
num_centroids: usize,
iters: usize,
seed: u64,
dim: usize,
) -> Vec<Vec<f32>> {
if vectors.is_empty() || num_centroids == 0 {
return Vec::new();
}
let k = num_centroids.min(vectors.len());
let mut centroids = kmeans_plus_plus_init(vectors, k, seed);
for _ in 0..iters {
let assignments = assign(vectors, ¢roids);
let new_centroids = recompute(vectors, &assignments, k, dim, ¢roids);
centroids = new_centroids;
}
centroids
}
pub struct PlaidPruner {
pub centroids: Vec<Vec<f32>>,
doc_centroids: HashMap<u32, Vec<u16>>,
}
impl PlaidPruner {
pub fn train(
store: &MultiVectorStore,
num_centroids: u16,
kmeans_iters: usize,
seed: u64,
) -> Self {
let dim = store.dim;
let nc = num_centroids as usize;
let all_vectors: Vec<Vec<f32>> = store
.iter()
.flat_map(|doc| doc.vectors.iter().cloned())
.collect();
let centroids = kmeans(&all_vectors, nc, kmeans_iters, seed, dim);
let doc_centroids: HashMap<u32, Vec<u16>> = store
.iter()
.map(|doc| {
let mut ids: Vec<u16> = doc
.vectors
.iter()
.map(|v| {
centroids
.iter()
.enumerate()
.map(|(i, c)| (i as u16, scalar_distance(v, c, DistanceMetric::L2)))
.min_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0)
})
.collect();
ids.sort_unstable();
ids.dedup();
(doc.doc_id, ids)
})
.collect();
Self {
centroids,
doc_centroids,
}
}
pub fn candidates(&self, query: &[Vec<f32>]) -> Vec<u32> {
if self.centroids.is_empty() || query.is_empty() {
return Vec::new();
}
let query_bag: HashSet<u16> = query
.iter()
.filter_map(|v| {
self.centroids
.iter()
.enumerate()
.map(|(i, c)| (i as u16, scalar_distance(v, c, DistanceMetric::L2)))
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(id, _)| id)
})
.collect();
self.doc_centroids
.iter()
.filter(|(_, doc_ids)| doc_ids.iter().any(|id| query_bag.contains(id)))
.map(|(&doc_id, _)| doc_id)
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::multivec::storage::{MultiVecMode, MultiVectorDoc, MultiVectorStore};
fn build_store() -> MultiVectorStore {
let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
for i in 0u32..3 {
store
.insert(MultiVectorDoc {
doc_id: i,
vectors: vec![vec![i as f32 * 0.1, i as f32 * 0.1]],
})
.unwrap();
}
for i in 3u32..6 {
store
.insert(MultiVectorDoc {
doc_id: i,
vectors: vec![vec![10.0 + i as f32 * 0.1, 0.0]],
})
.unwrap();
}
for i in 6u32..9 {
store
.insert(MultiVectorDoc {
doc_id: i,
vectors: vec![vec![0.0, 10.0 + i as f32 * 0.1]],
})
.unwrap();
}
store
}
#[test]
fn train_produces_correct_centroid_count() {
let store = build_store();
let pruner = PlaidPruner::train(&store, 3, 10, 42);
assert_eq!(pruner.centroids.len(), 3);
}
#[test]
fn centroids_have_correct_dim() {
let store = build_store();
let pruner = PlaidPruner::train(&store, 3, 10, 42);
for c in &pruner.centroids {
assert_eq!(c.len(), 2);
}
}
#[test]
fn candidates_non_empty_for_matching_query() {
let store = build_store();
let pruner = PlaidPruner::train(&store, 3, 10, 42);
let query = vec![vec![0.0f32, 0.0f32]];
let cands = pruner.candidates(&query);
assert!(!cands.is_empty(), "expected at least one candidate");
}
#[test]
fn candidates_empty_when_no_centroids() {
let store = MultiVectorStore::new(2, MultiVecMode::PerToken);
let pruner = PlaidPruner::train(&store, 3, 5, 1);
let query = vec![vec![0.0f32, 0.0f32]];
assert!(pruner.candidates(&query).is_empty());
}
#[test]
fn candidates_cover_input_range() {
let store = build_store();
let pruner = PlaidPruner::train(&store, 3, 15, 7);
let query = vec![
vec![0.0f32, 0.0f32],
vec![10.0f32, 0.0f32],
vec![0.0f32, 10.0f32],
];
let mut cands = pruner.candidates(&query);
cands.sort_unstable();
cands.dedup();
assert_eq!(cands.len(), 9, "all docs should be candidates: {:?}", cands);
}
}