1use std::collections::{HashMap, HashSet};
14
15use crate::distance::scalar::scalar_distance;
16use nodedb_types::vector_distance::DistanceMetric;
17
18use super::storage::MultiVectorStore;
19
20fn assign(vectors: &[Vec<f32>], centroids: &[Vec<f32>]) -> Vec<usize> {
26 vectors
27 .iter()
28 .map(|v| {
29 centroids
30 .iter()
31 .enumerate()
32 .map(|(i, c)| (i, scalar_distance(v, c, DistanceMetric::L2)))
33 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
34 .map(|(i, _)| i)
35 .unwrap_or(0)
36 })
37 .collect()
38}
39
/// Advances a 64-bit linear congruential generator in place and returns the new state.
///
/// The multiplier/increment pair is Knuth's MMIX LCG; arithmetic wraps modulo 2^64.
fn lcg_next(s: &mut u64) -> u64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let next = MULTIPLIER.wrapping_mul(*s).wrapping_add(INCREMENT);
    *s = next;
    next
}
47
48fn min_dist_to_centroids(v: &[f32], centroids: &[Vec<f32>]) -> f32 {
50 centroids
51 .iter()
52 .map(|c| scalar_distance(v, c, DistanceMetric::L2))
53 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
54 .unwrap_or(f32::INFINITY)
55}
56
57fn recompute(
62 vectors: &[Vec<f32>],
63 assignments: &[usize],
64 num_centroids: usize,
65 dim: usize,
66 prev_centroids: &[Vec<f32>],
67) -> Vec<Vec<f32>> {
68 let mut sums = vec![vec![0.0f32; dim]; num_centroids];
69 let mut counts = vec![0usize; num_centroids];
70
71 for (v, &c) in vectors.iter().zip(assignments.iter()) {
72 for (s, x) in sums[c].iter_mut().zip(v.iter()) {
73 *s += x;
74 }
75 counts[c] += 1;
76 }
77
78 for (s, &n) in sums.iter_mut().zip(counts.iter()) {
80 if n > 0 {
81 s.iter_mut().for_each(|x| *x /= n as f32);
82 }
83 }
84
85 for c_idx in 0..num_centroids {
90 if counts[c_idx] != 0 {
91 continue;
92 }
93 let live: Vec<Vec<f32>> = counts
94 .iter()
95 .enumerate()
96 .filter(|(i, cnt)| *i != c_idx && **cnt > 0)
97 .map(|(i, _)| sums[i].clone())
98 .collect();
99 let seed_pool: &[Vec<f32>] = if live.is_empty() {
100 prev_centroids
101 } else {
102 &live
103 };
104 let farthest = vectors
105 .iter()
106 .map(|v| (v, min_dist_to_centroids(v, seed_pool)))
107 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
108 .map(|(v, _)| v.clone());
109 if let Some(v) = farthest {
110 sums[c_idx] = v;
111 counts[c_idx] = 1; } else if c_idx < prev_centroids.len() {
113 sums[c_idx] = prev_centroids[c_idx].clone();
114 counts[c_idx] = 1;
115 }
116 }
117
118 sums
119}
120
121fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize, seed: u64) -> Vec<Vec<f32>> {
125 let mut state = seed.wrapping_add(1);
126 let first = (lcg_next(&mut state) as usize) % vectors.len();
127 let mut centroids: Vec<Vec<f32>> = vec![vectors[first].clone()];
128
129 while centroids.len() < k {
130 let dists: Vec<f32> = vectors
131 .iter()
132 .map(|v| {
133 let d = min_dist_to_centroids(v, ¢roids);
134 d * d
135 })
136 .collect();
137 let total: f64 = dists.iter().map(|&d| d as f64).sum();
138 if total <= 0.0 {
139 let idx = (lcg_next(&mut state) as usize) % vectors.len();
142 centroids.push(vectors[idx].clone());
143 continue;
144 }
145 let r = (lcg_next(&mut state) as f64) / (u64::MAX as f64) * total;
147 let mut acc = 0.0f64;
148 let mut pick = vectors.len() - 1;
149 for (i, &d) in dists.iter().enumerate() {
150 acc += d as f64;
151 if acc >= r {
152 pick = i;
153 break;
154 }
155 }
156 centroids.push(vectors[pick].clone());
157 }
158
159 centroids
160}
161
162fn kmeans(
166 vectors: &[Vec<f32>],
167 num_centroids: usize,
168 iters: usize,
169 seed: u64,
170 dim: usize,
171) -> Vec<Vec<f32>> {
172 if vectors.is_empty() || num_centroids == 0 {
173 return Vec::new();
174 }
175
176 let k = num_centroids.min(vectors.len());
177 let mut centroids = kmeans_plus_plus_init(vectors, k, seed);
178
179 for _ in 0..iters {
180 let assignments = assign(vectors, ¢roids);
181 let new_centroids = recompute(vectors, &assignments, k, dim, ¢roids);
182 centroids = new_centroids;
183 }
184
185 centroids
186}
187
/// Centroid-based candidate pruner for multi-vector documents.
///
/// Built by [`PlaidPruner::train`]: all document vectors are clustered with k-means,
/// and each document records which centroids its vectors fall closest to. At query
/// time [`PlaidPruner::candidates`] keeps only documents sharing a centroid with the
/// query's vectors.
#[derive(Debug, Clone)]
pub struct PlaidPruner {
    /// Trained centroids. Empty when training saw no vectors or was asked for zero
    /// centroids.
    pub centroids: Vec<Vec<f32>>,
    /// Document id -> sorted, deduplicated indices into `centroids` of the centroids
    /// nearest to that document's vectors.
    doc_centroids: HashMap<u32, Vec<u16>>,
}
201
202impl PlaidPruner {
203 pub fn train(
209 store: &MultiVectorStore,
210 num_centroids: u16,
211 kmeans_iters: usize,
212 seed: u64,
213 ) -> Self {
214 let dim = store.dim;
215 let nc = num_centroids as usize;
216
217 let all_vectors: Vec<Vec<f32>> = store
219 .iter()
220 .flat_map(|doc| doc.vectors.iter().cloned())
221 .collect();
222
223 let centroids = kmeans(&all_vectors, nc, kmeans_iters, seed, dim);
224
225 let doc_centroids: HashMap<u32, Vec<u16>> = store
227 .iter()
228 .map(|doc| {
229 let mut ids: Vec<u16> = doc
230 .vectors
231 .iter()
232 .map(|v| {
233 centroids
234 .iter()
235 .enumerate()
236 .map(|(i, c)| (i as u16, scalar_distance(v, c, DistanceMetric::L2)))
237 .min_by(|a, b| {
238 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
239 })
240 .map(|(i, _)| i)
241 .unwrap_or(0)
242 })
243 .collect();
244 ids.sort_unstable();
245 ids.dedup();
246 (doc.doc_id, ids)
247 })
248 .collect();
249
250 Self {
251 centroids,
252 doc_centroids,
253 }
254 }
255
256 pub fn candidates(&self, query: &[Vec<f32>]) -> Vec<u32> {
262 if self.centroids.is_empty() || query.is_empty() {
263 return Vec::new();
264 }
265
266 let query_bag: HashSet<u16> = query
268 .iter()
269 .filter_map(|v| {
270 self.centroids
271 .iter()
272 .enumerate()
273 .map(|(i, c)| (i as u16, scalar_distance(v, c, DistanceMetric::L2)))
274 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
275 .map(|(id, _)| id)
276 })
277 .collect();
278
279 self.doc_centroids
281 .iter()
282 .filter(|(_, doc_ids)| doc_ids.iter().any(|id| query_bag.contains(id)))
283 .map(|(&doc_id, _)| doc_id)
284 .collect()
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multivec::storage::{MultiVecMode, MultiVectorDoc, MultiVectorStore};

    /// Inserts a single-vector document into `store`, panicking if insertion fails.
    fn put(store: &mut MultiVectorStore, doc_id: u32, vector: Vec<f32>) {
        store
            .insert(MultiVectorDoc {
                doc_id,
                vectors: vec![vector],
            })
            .unwrap();
    }

    /// Builds a 2-D store with three well-separated groups of three documents:
    /// ids 0..3 near the origin, ids 3..6 near (10, 0), ids 6..9 near (0, 10).
    fn build_store() -> MultiVectorStore {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        for id in 0u32..3 {
            let offset = id as f32 * 0.1;
            put(&mut store, id, vec![offset, offset]);
        }
        for id in 3u32..6 {
            put(&mut store, id, vec![10.0 + id as f32 * 0.1, 0.0]);
        }
        for id in 6u32..9 {
            put(&mut store, id, vec![0.0, 10.0 + id as f32 * 0.1]);
        }
        store
    }

    #[test]
    fn train_produces_correct_centroid_count() {
        let pruner = PlaidPruner::train(&build_store(), 3, 10, 42);
        assert_eq!(pruner.centroids.len(), 3);
    }

    #[test]
    fn centroids_have_correct_dim() {
        let pruner = PlaidPruner::train(&build_store(), 3, 10, 42);
        for (idx, centroid) in pruner.centroids.iter().enumerate() {
            assert_eq!(centroid.len(), 2, "centroid {} has wrong dim", idx);
        }
    }

    #[test]
    fn candidates_non_empty_for_matching_query() {
        let pruner = PlaidPruner::train(&build_store(), 3, 10, 42);
        // Probe at the origin, where the first group of documents lives.
        let origin = vec![vec![0.0f32, 0.0f32]];
        let found = pruner.candidates(&origin);
        assert!(!found.is_empty(), "expected at least one candidate");
    }

    #[test]
    fn candidates_empty_when_no_centroids() {
        // An empty store yields no training vectors and therefore no centroids.
        let empty = MultiVectorStore::new(2, MultiVecMode::PerToken);
        let pruner = PlaidPruner::train(&empty, 3, 5, 1);
        assert!(pruner.candidates(&[vec![0.0f32, 0.0f32]]).is_empty());
    }

    #[test]
    fn candidates_cover_input_range() {
        let pruner = PlaidPruner::train(&build_store(), 3, 15, 7);
        // One probe per group should reach every document.
        let probes = [
            vec![0.0f32, 0.0f32],
            vec![10.0f32, 0.0f32],
            vec![0.0f32, 10.0f32],
        ];
        let mut found: Vec<u32> = pruner.candidates(&probes);
        found.sort();
        found.dedup();
        assert_eq!(found.len(), 9, "all docs should be candidates: {:?}", found);
    }
}
383}