Skip to main content

nodedb_vector/multivec/
plaid.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! PLAID-style centroid-based candidate pruning for multi-vector search.
4//!
5//! Builds K-means centroids over all document vectors.  Each document is
6//! encoded as a sorted bag of centroid IDs.  At query time the query's
7//! centroid bag is computed and only documents whose centroid bag overlaps
8//! the query bag are returned as candidates.
9//!
10//! Reference: Santhanam et al., "PLAID: An Efficient Engine for Late
11//! Interaction Retrieval", CIKM 2022.
12
13use std::collections::{HashMap, HashSet};
14
15use crate::distance::scalar::scalar_distance;
16use nodedb_types::vector_distance::DistanceMetric;
17
18use super::storage::MultiVectorStore;
19
20// ---------------------------------------------------------------------------
21// Internal Lloyd's K-means (tiny, self-contained)
22// ---------------------------------------------------------------------------
23
24/// Assign each vector to its nearest centroid index.
25fn assign(vectors: &[Vec<f32>], centroids: &[Vec<f32>]) -> Vec<usize> {
26    vectors
27        .iter()
28        .map(|v| {
29            centroids
30                .iter()
31                .enumerate()
32                .map(|(i, c)| (i, scalar_distance(v, c, DistanceMetric::L2)))
33                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
34                .map(|(i, _)| i)
35                .unwrap_or(0)
36        })
37        .collect()
38}
39
/// Advance the 64-bit linear congruential generator stored in `s` and return
/// the new state. It uses Knuth's MMIX multiplier and increment with wrapping
/// arithmetic, so a given seed always yields the same sequence.
fn lcg_next(s: &mut u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    let next = MULTIPLIER.wrapping_mul(*s).wrapping_add(INCREMENT);
    *s = next;
    next
}
47
48/// Distance from `v` to the nearest centroid in `centroids`.
49fn min_dist_to_centroids(v: &[f32], centroids: &[Vec<f32>]) -> f32 {
50    centroids
51        .iter()
52        .map(|c| scalar_distance(v, c, DistanceMetric::L2))
53        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
54        .unwrap_or(f32::INFINITY)
55}
56
57/// Recompute centroids as the mean of their assigned vectors. Empty clusters
58/// are re-seeded to the input vector farthest from all live centroids — this
59/// guarantees every centroid covers some part of the input space and prevents
60/// k-means from collapsing into fewer-than-k effective clusters.
61fn recompute(
62    vectors: &[Vec<f32>],
63    assignments: &[usize],
64    num_centroids: usize,
65    dim: usize,
66    prev_centroids: &[Vec<f32>],
67) -> Vec<Vec<f32>> {
68    let mut sums = vec![vec![0.0f32; dim]; num_centroids];
69    let mut counts = vec![0usize; num_centroids];
70
71    for (v, &c) in vectors.iter().zip(assignments.iter()) {
72        for (s, x) in sums[c].iter_mut().zip(v.iter()) {
73            *s += x;
74        }
75        counts[c] += 1;
76    }
77
78    // First pass: average populated clusters in place.
79    for (s, &n) in sums.iter_mut().zip(counts.iter()) {
80        if n > 0 {
81            s.iter_mut().for_each(|x| *x /= n as f32);
82        }
83    }
84
85    // Second pass: re-seed empty clusters from vectors farthest from any live
86    // centroid. Snapshot the live set first so each empty slot is filled
87    // deterministically and subsequent re-seeds in the same call see the
88    // updated pool.
89    for c_idx in 0..num_centroids {
90        if counts[c_idx] != 0 {
91            continue;
92        }
93        let live: Vec<Vec<f32>> = counts
94            .iter()
95            .enumerate()
96            .filter(|(i, cnt)| *i != c_idx && **cnt > 0)
97            .map(|(i, _)| sums[i].clone())
98            .collect();
99        let seed_pool: &[Vec<f32>] = if live.is_empty() {
100            prev_centroids
101        } else {
102            &live
103        };
104        let farthest = vectors
105            .iter()
106            .map(|v| (v, min_dist_to_centroids(v, seed_pool)))
107            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
108            .map(|(v, _)| v.clone());
109        if let Some(v) = farthest {
110            sums[c_idx] = v;
111            counts[c_idx] = 1; // mark live so later empty slots see it as a seed.
112        } else if c_idx < prev_centroids.len() {
113            sums[c_idx] = prev_centroids[c_idx].clone();
114            counts[c_idx] = 1;
115        }
116    }
117
118    sums
119}
120
121/// k-means++ initialisation: first centroid uniform random, subsequent
122/// centroids drawn with probability proportional to squared distance from the
123/// nearest already-chosen centroid. Deterministic given `seed`.
124fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize, seed: u64) -> Vec<Vec<f32>> {
125    let mut state = seed.wrapping_add(1);
126    let first = (lcg_next(&mut state) as usize) % vectors.len();
127    let mut centroids: Vec<Vec<f32>> = vec![vectors[first].clone()];
128
129    while centroids.len() < k {
130        let dists: Vec<f32> = vectors
131            .iter()
132            .map(|v| {
133                let d = min_dist_to_centroids(v, &centroids);
134                d * d
135            })
136            .collect();
137        let total: f64 = dists.iter().map(|&d| d as f64).sum();
138        if total <= 0.0 {
139            // All remaining vectors coincide with existing centroids; just
140            // pick any unique-by-index vector to fill k.
141            let idx = (lcg_next(&mut state) as usize) % vectors.len();
142            centroids.push(vectors[idx].clone());
143            continue;
144        }
145        // Deterministic weighted draw from the LCG.
146        let r = (lcg_next(&mut state) as f64) / (u64::MAX as f64) * total;
147        let mut acc = 0.0f64;
148        let mut pick = vectors.len() - 1;
149        for (i, &d) in dists.iter().enumerate() {
150            acc += d as f64;
151            if acc >= r {
152                pick = i;
153                break;
154            }
155        }
156        centroids.push(vectors[pick].clone());
157    }
158
159    centroids
160}
161
162/// Run Lloyd's K-means with k-means++ initialisation. Empty clusters are
163/// re-seeded each iteration so the result always has exactly `k` distinct
164/// centroids covering the input space.
165fn kmeans(
166    vectors: &[Vec<f32>],
167    num_centroids: usize,
168    iters: usize,
169    seed: u64,
170    dim: usize,
171) -> Vec<Vec<f32>> {
172    if vectors.is_empty() || num_centroids == 0 {
173        return Vec::new();
174    }
175
176    let k = num_centroids.min(vectors.len());
177    let mut centroids = kmeans_plus_plus_init(vectors, k, seed);
178
179    for _ in 0..iters {
180        let assignments = assign(vectors, &centroids);
181        let new_centroids = recompute(vectors, &assignments, k, dim, &centroids);
182        centroids = new_centroids;
183    }
184
185    centroids
186}
187
188// ---------------------------------------------------------------------------
189// PlaidPruner
190// ---------------------------------------------------------------------------
191
/// PLAID centroid-based candidate pruner.
///
/// After `train`, call `candidates` at query time to get the set of document
/// IDs whose centroid bag overlaps the query's centroid bag.
#[derive(Debug, Clone)]
pub struct PlaidPruner {
    /// K-means centroids learned from the training vectors. A centroid's
    /// position in this list is its ID.
    pub centroids: Vec<Vec<f32>>,
    /// Sorted, deduplicated list of centroid IDs for each document, keyed by
    /// document ID.
    doc_centroids: HashMap<u32, Vec<u16>>,
}
201
202impl PlaidPruner {
203    /// Train the pruner from a `MultiVectorStore`.
204    ///
205    /// * `num_centroids` — number of K-means clusters.
206    /// * `kmeans_iters` — Lloyd iterations.
207    /// * `seed` — deterministic seed for centroid initialisation.
208    pub fn train(
209        store: &MultiVectorStore,
210        num_centroids: u16,
211        kmeans_iters: usize,
212        seed: u64,
213    ) -> Self {
214        let dim = store.dim;
215        let nc = num_centroids as usize;
216
217        // Collect all document vectors for K-means training.
218        let all_vectors: Vec<Vec<f32>> = store
219            .iter()
220            .flat_map(|doc| doc.vectors.iter().cloned())
221            .collect();
222
223        let centroids = kmeans(&all_vectors, nc, kmeans_iters, seed, dim);
224
225        // Encode each document as a sorted, deduplicated bag of centroid IDs.
226        let doc_centroids: HashMap<u32, Vec<u16>> = store
227            .iter()
228            .map(|doc| {
229                let mut ids: Vec<u16> = doc
230                    .vectors
231                    .iter()
232                    .map(|v| {
233                        centroids
234                            .iter()
235                            .enumerate()
236                            .map(|(i, c)| (i as u16, scalar_distance(v, c, DistanceMetric::L2)))
237                            .min_by(|a, b| {
238                                a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
239                            })
240                            .map(|(i, _)| i)
241                            .unwrap_or(0)
242                    })
243                    .collect();
244                ids.sort_unstable();
245                ids.dedup();
246                (doc.doc_id, ids)
247            })
248            .collect();
249
250        Self {
251            centroids,
252            doc_centroids,
253        }
254    }
255
256    /// Return candidate doc IDs whose centroid bag overlaps the query's
257    /// centroid bag.
258    ///
259    /// The query centroid bag is the set of nearest centroids for each query
260    /// vector.
261    pub fn candidates(&self, query: &[Vec<f32>]) -> Vec<u32> {
262        if self.centroids.is_empty() || query.is_empty() {
263            return Vec::new();
264        }
265
266        // Build query centroid bag.
267        let query_bag: HashSet<u16> = query
268            .iter()
269            .filter_map(|v| {
270                self.centroids
271                    .iter()
272                    .enumerate()
273                    .map(|(i, c)| (i as u16, scalar_distance(v, c, DistanceMetric::L2)))
274                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
275                    .map(|(id, _)| id)
276            })
277            .collect();
278
279        // Collect docs that share at least one centroid with the query.
280        self.doc_centroids
281            .iter()
282            .filter(|(_, doc_ids)| doc_ids.iter().any(|id| query_bag.contains(id)))
283            .map(|(&doc_id, _)| doc_id)
284            .collect()
285    }
286}
287
288// ---------------------------------------------------------------------------
289// Tests
290// ---------------------------------------------------------------------------
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multivec::storage::{MultiVecMode, MultiVectorDoc, MultiVectorStore};

    /// Insert a single-vector document into `store`, panicking on failure.
    fn put(store: &mut MultiVectorStore, doc_id: u32, point: [f32; 2]) {
        store
            .insert(MultiVectorDoc {
                doc_id,
                vectors: vec![point.to_vec()],
            })
            .unwrap();
    }

    /// Nine single-vector documents in three well-separated clusters:
    /// A = docs 0–2 near (0,0), B = docs 3–5 near (10,0), C = docs 6–8 near (0,10).
    fn build_store() -> MultiVectorStore {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        for id in 0u32..9 {
            let offset = id as f32 * 0.1;
            let point = match id / 3 {
                0 => [offset, offset],
                1 => [10.0 + offset, 0.0],
                _ => [0.0, 10.0 + offset],
            };
            put(&mut store, id, point);
        }
        store
    }

    #[test]
    fn train_produces_correct_centroid_count() {
        let trained = PlaidPruner::train(&build_store(), 3, 10, 42);
        assert_eq!(trained.centroids.len(), 3);
    }

    #[test]
    fn centroids_have_correct_dim() {
        let trained = PlaidPruner::train(&build_store(), 3, 10, 42);
        assert!(trained.centroids.iter().all(|c| c.len() == 2));
        for centroid in &trained.centroids {
            assert_eq!(centroid.len(), 2);
        }
    }

    #[test]
    fn candidates_non_empty_for_matching_query() {
        let trained = PlaidPruner::train(&build_store(), 3, 10, 42);

        // A query near cluster A should return at least some candidates.
        let near_a = vec![vec![0.0f32, 0.0f32]];
        let found = trained.candidates(&near_a);
        assert!(!found.is_empty(), "expected at least one candidate");
    }

    #[test]
    fn candidates_empty_when_no_centroids() {
        // An empty store yields a pruner without centroids.
        let empty_store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        let trained = PlaidPruner::train(&empty_store, 3, 5, 1);
        let probe = vec![vec![0.0f32, 0.0f32]];
        assert!(trained.candidates(&probe).is_empty());
    }

    #[test]
    fn candidates_cover_input_range() {
        // Query vectors placed on all three clusters should make every one of
        // the nine documents a candidate.
        let trained = PlaidPruner::train(&build_store(), 3, 15, 7);
        let spanning = vec![
            vec![0.0f32, 0.0f32],
            vec![10.0f32, 0.0f32],
            vec![0.0f32, 10.0f32],
        ];
        let mut found = trained.candidates(&spanning);
        found.sort_unstable();
        found.dedup();
        assert_eq!(found.len(), 9, "all docs should be candidates: {:?}", found);
    }
}