Skip to main content

citadel_vector/
ann.rs

//! In-memory approximate-nearest-neighbour (ANN) index wrapping the vendored PRISM engine.
2
3use crate::prism::{Filter, Metric, PointStore, PrismConfig, PrismIndex};
4
/// Over-fetch multiplier: searches ask PRISM for `k * OVER_FETCH` candidates
/// to compensate for its recall being below 1.0.
pub const OVER_FETCH: usize = 4;

/// Largest post-snapshot tail (in rows) that is still brute-force merged with
/// ANN results; a longer tail triggers a rebuild instead.
pub const REBUILD_TAIL_MAX: usize = 2048;
10
11#[derive(Debug, thiserror::Error)]
12pub enum AnnError {
13    #[error("ANN build requires at least one row")]
14    EmptyInput,
15    #[error("ANN build vector dim mismatch: expected {expected}, got {got} for row_id {row_id}")]
16    DimMismatch {
17        expected: u16,
18        got: usize,
19        row_id: u64,
20    },
21    #[error(
22        "ANN build attribute arity mismatch: expected {expected}, got {got} for row_id {row_id}"
23    )]
24    AttrArityMismatch {
25        expected: usize,
26        got: usize,
27        row_id: u64,
28    },
29}
30
31// binary_rerank=0: the Hamming pre-filter kills recall on continuous vectors.
32// sigma_high low: stay in the fast HIGH search regime.
33fn prism_config(metric: Metric) -> PrismConfig {
34    PrismConfig {
35        metric,
36        binary_rerank: 0,
37        sigma_high: 0.001,
38        ..PrismConfig::default()
39    }
40}
41
42/// In-memory ANN index over a `(row_id, vector)` snapshot.
43pub struct AnnIndex {
44    prism: PrismIndex,
45    /// PRISM internal id -> external row_id.
46    id_map: Vec<u64>,
47    /// Highest row_id in the snapshot.
48    pub snapshot_max: u64,
49    pub metric: Metric,
50    pub dim: u16,
51}
52
53impl std::fmt::Debug for AnnIndex {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("AnnIndex")
56            .field("snapshot_max", &self.snapshot_max)
57            .field("metric", &self.metric)
58            .field("dim", &self.dim)
59            .field("indexed_len", &self.id_map.len())
60            .finish()
61    }
62}
63
64impl AnnIndex {
65    /// Build an unfiltered index from `(row_id, vector)` pairs.
66    pub fn build(rows: Vec<(u64, Vec<f32>)>, metric: Metric, dim: u16) -> Result<Self, AnnError> {
67        let with_attrs = rows
68            .into_iter()
69            .map(|(id, v)| (id, v, Vec::new()))
70            .collect();
71        Self::build_with_attrs(with_attrs, 0, metric, dim)
72    }
73
74    /// Build a filtered index from `(row_id, vector, attr_codes)` triples. Each
75    /// attribute is a PRISM dimension; distinct tuples form the searchable cells.
76    pub fn build_with_attrs(
77        mut rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
78        num_attrs: usize,
79        metric: Metric,
80        dim: u16,
81    ) -> Result<Self, AnnError> {
82        if rows.is_empty() {
83            return Err(AnnError::EmptyInput);
84        }
85        for (rid, v, a) in &rows {
86            if v.len() != dim as usize {
87                return Err(AnnError::DimMismatch {
88                    expected: dim,
89                    got: v.len(),
90                    row_id: *rid,
91                });
92            }
93            if a.len() != num_attrs {
94                return Err(AnnError::AttrArityMismatch {
95                    expected: num_attrs,
96                    got: a.len(),
97                    row_id: *rid,
98                });
99            }
100        }
101
102        rows.sort_unstable_by_key(|(id, _, _)| *id);
103        let snapshot_max = rows.last().map(|(id, _, _)| *id).unwrap_or(0);
104
105        let n = rows.len();
106        let mut flat: Vec<f32> = Vec::with_capacity(n * dim as usize);
107        let mut row_ids: Vec<u64> = Vec::with_capacity(n);
108        // PRISM needs >=1 attribute dim; an all-zero column = one cell.
109        let attr_dims = num_attrs.max(1);
110        let mut attr_cols: Vec<Vec<u32>> = vec![Vec::with_capacity(n); attr_dims];
111        for (rid, v, a) in &rows {
112            flat.extend_from_slice(v);
113            row_ids.push(*rid);
114            if num_attrs == 0 {
115                attr_cols[0].push(0);
116            } else {
117                for (j, &code) in a.iter().enumerate() {
118                    attr_cols[j].push(code);
119                }
120            }
121        }
122
123        let store = PointStore::from_parts(flat, dim as usize, attr_cols);
124        let prism = PrismIndex::build(store, prism_config(metric));
125
126        // PRISM reorders points by cell; remap to external row_ids.
127        let id_map: Vec<u64> = prism
128            .original_ids
129            .iter()
130            .map(|&old| row_ids[old as usize])
131            .collect();
132
133        Ok(Self {
134            prism,
135            id_map,
136            snapshot_max,
137            metric,
138            dim,
139        })
140    }
141
142    /// Reassemble from persisted parts (the ANN segment decode path). The
143    /// caller is responsible for `prism.store.vectors` being in PRISM-internal
144    /// (cell-reordered) order - see `segment::SegmentParts::into_index`.
145    pub fn from_parts(
146        prism: PrismIndex,
147        id_map: Vec<u64>,
148        snapshot_max: u64,
149        metric: Metric,
150        dim: u16,
151    ) -> Self {
152        Self {
153            prism,
154            id_map,
155            snapshot_max,
156            metric,
157            dim,
158        }
159    }
160
161    pub fn prism(&self) -> &PrismIndex {
162        &self.prism
163    }
164
165    /// PRISM internal id -> external row_id.
166    pub fn id_map(&self) -> &[u64] {
167        &self.id_map
168    }
169
170    /// The PRISM configuration this index family builds with - part of the
171    /// persisted segment's binding (a config change invalidates segments).
172    pub fn active_config(metric: Metric) -> PrismConfig {
173        prism_config(metric)
174    }
175
176    /// Top-k search returning `(row_id, distance)` ascending, at the default ef.
177    pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
178        let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
179        self.search_with_ef(query, k, ef)
180    }
181
182    /// Unfiltered search with an explicit beam width `ef`.
183    pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> {
184        self.search_filtered(query, k, ef, &Filter::none())
185    }
186
187    /// Filtered search at the default ef.
188    pub fn search_filtered_default_ef(
189        &self,
190        query: &[f32],
191        k: usize,
192        filter: &Filter,
193    ) -> Vec<(u64, f32)> {
194        let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
195        self.search_filtered(query, k, ef, filter)
196    }
197
198    /// Filtered search; `filter` dims index the `build_with_attrs` attributes
199    /// (`Filter::none()` matches all).
200    ///
201    /// Distances use the same units as the SQL operators (`<->` = true L2,
202    /// `<#>` = -dot, `<=>` = 1-cos), so callers may mix them with exact-ranked
203    /// scores. PRISM reports squared L2 internally; converted here.
204    pub fn search_filtered(
205        &self,
206        query: &[f32],
207        k: usize,
208        ef: usize,
209        filter: &Filter,
210    ) -> Vec<(u64, f32)> {
211        debug_assert_eq!(query.len(), self.dim as usize);
212        let sqrt_l2 = self.metric == Metric::L2;
213        self.prism
214            .search(query, filter, k, ef)
215            .into_iter()
216            .map(|r| {
217                let dist = if sqrt_l2 { r.dist.sqrt() } else { r.dist };
218                (self.id_map[r.id as usize], dist)
219            })
220            .collect()
221    }
222
223    /// Number of indexed rows.
224    pub fn indexed_len(&self) -> usize {
225        self.id_map.len()
226    }
227
228    /// Post-snapshot tail exceeds the rebuild cap or a quarter of the indexed set.
229    pub fn tail_is_stale(&self, live_max: u64) -> bool {
230        let tail = live_max.saturating_sub(self.snapshot_max) as usize;
231        tail > REBUILD_TAIL_MAX || tail > self.indexed_len() / 4
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` rows with row_ids `1..=n`; component `d` of row `i` is `(i + d) * 0.01`.
    fn synth_rows(n: usize, dim: u16) -> Vec<(u64, Vec<f32>)> {
        (0..n)
            .map(|i| {
                let row_id = (i as u64) + 1;
                let v: Vec<f32> = (0..dim).map(|d| (i as f32 + d as f32) * 0.01).collect();
                (row_id, v)
            })
            .collect()
    }

    #[test]
    fn build_empty_input_errors() {
        let err = AnnIndex::build(Vec::new(), Metric::L2, 4).unwrap_err();
        assert!(matches!(err, AnnError::EmptyInput));
    }

    #[test]
    fn build_dim_mismatch_errors() {
        let rows = vec![(1u64, vec![1.0, 2.0])];
        let err = AnnIndex::build(rows, Metric::L2, 4).unwrap_err();
        assert!(matches!(
            err,
            AnnError::DimMismatch {
                expected: 4,
                got: 2,
                row_id: 1
            }
        ));
    }

    #[test]
    fn build_single_row_succeeds() {
        let rows = vec![(7u64, vec![0.1, 0.2, 0.3, 0.4])];
        let idx = AnnIndex::build(rows, Metric::L2, 4).unwrap();
        assert_eq!(idx.indexed_len(), 1);
        assert_eq!(idx.snapshot_max, 7);
    }

    #[test]
    fn build_small_n_succeeds() {
        let rows = synth_rows(5, 8);
        let idx = AnnIndex::build(rows, Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 5);
    }

    #[test]
    fn build_large_n_succeeds() {
        let rows = synth_rows(500, 16);
        let idx = AnnIndex::build(rows, Metric::L2, 16).unwrap();
        assert_eq!(idx.indexed_len(), 500);
    }

    #[test]
    fn search_returns_row_ids_not_internal_ids() {
        let n = 200;
        let rows = synth_rows(n, 8);
        let idx = AnnIndex::build(rows, Metric::L2, 8).unwrap();
        let hits = idx.search(&[0.5; 8], 5);
        assert!(!hits.is_empty());
        for (rid, _d) in &hits {
            assert!(*rid >= 1 && *rid <= n as u64);
        }
    }

    #[test]
    fn snapshot_max_tracks_highest_row_id() {
        let rows = vec![
            (5u64, vec![1.0, 0.0]),
            (10u64, vec![0.0, 1.0]),
            (3u64, vec![1.0, 1.0]),
        ];
        let idx = AnnIndex::build(rows, Metric::L2, 2).unwrap();
        assert_eq!(idx.snapshot_max, 10);
    }

    #[test]
    fn cosine_metric_propagates_to_prism() {
        let rows = synth_rows(50, 16);
        let idx = AnnIndex::build(rows, Metric::Cosine, 16).unwrap();
        assert_eq!(idx.metric, Metric::Cosine);
        assert_eq!(idx.prism.config.metric, Metric::Cosine);
    }

    #[test]
    fn inner_metric_propagates_to_prism() {
        let rows = synth_rows(50, 16);
        let idx = AnnIndex::build(rows, Metric::InnerProduct, 16).unwrap();
        assert_eq!(idx.metric, Metric::InnerProduct);
        assert_eq!(idx.prism.config.metric, Metric::InnerProduct);
    }

    /// attr 0 = i % 2; row_id = i + 1.
    fn attr_rows(n: u64, dim: u16) -> Vec<(u64, Vec<f32>, Vec<u32>)> {
        (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|d| (i as f32 + d as f32) * 0.01).collect();
                (i + 1, v, vec![(i % 2) as u32])
            })
            .collect()
    }

    #[test]
    fn build_with_attrs_filters_by_attribute() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &Filter::eq(0, 1));
        assert!(!hits.is_empty());
        assert!(hits.len() <= 10);
        for (rid, _) in &hits {
            // category 1 == odd i == even row_id.
            assert_eq!(rid % 2, 0, "row {rid} is not category 1");
        }
    }

    #[test]
    fn search_filtered_default_ef_respects_filter() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_filtered_default_ef(&[0.5; 8], 10, &Filter::eq(0, 1));
        assert!(!hits.is_empty());
        assert!(hits.len() <= 10);
        for (rid, _) in &hits {
            assert_eq!(rid % 2, 0, "row {rid} is not category 1");
        }
    }

    #[test]
    fn build_with_attrs_unfiltered_spans_all_cells() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_with_ef(&[0.5; 8], 10, 200);
        assert_eq!(hits.len(), 10);
        for (rid, _) in &hits {
            assert!(*rid >= 1 && *rid <= 100);
        }
    }

    #[test]
    fn build_with_attrs_two_dims_conjunctive_filter() {
        let n = 180u64;
        let dim = 8u16;
        let rows: Vec<(u64, Vec<f32>, Vec<u32>)> = (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|d| (i as f32 + d as f32) * 0.01).collect();
                (i + 1, v, vec![(i % 2) as u32, (i % 3) as u32])
            })
            .collect();
        let idx = AnnIndex::build_with_attrs(rows, 2, Metric::L2, dim).unwrap();
        let filter = Filter::new(vec![(0, vec![1]), (1, vec![2])]);
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &filter);
        assert!(!hits.is_empty());
        for (rid, _) in &hits {
            let i = rid - 1;
            assert_eq!(i % 2, 1, "row {rid} fails attr0 = 1");
            assert_eq!(i % 3, 2, "row {rid} fails attr1 = 2");
        }
    }

    #[test]
    fn build_with_attrs_arity_mismatch_errors() {
        let rows = vec![(1u64, vec![0.0; 4], vec![0u32])];
        let err = AnnIndex::build_with_attrs(rows, 2, Metric::L2, 4).unwrap_err();
        assert!(matches!(
            err,
            AnnError::AttrArityMismatch {
                expected: 2,
                got: 1,
                row_id: 1
            }
        ));
    }

    #[test]
    fn build_delegates_to_attrs_path() {
        let idx = AnnIndex::build(synth_rows(50, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 50);
        let hits = idx.search(&[0.3; 8], 5);
        assert!(!hits.is_empty());
    }

    #[test]
    fn tail_is_stale_uses_quarter_of_indexed_set() {
        // snapshot_max = 100 and indexed_len = 100, so the threshold is a tail of 25
        // (well under REBUILD_TAIL_MAX).
        let idx = AnnIndex::build(synth_rows(100, 4), Metric::L2, 4).unwrap();
        assert!(!idx.tail_is_stale(100));
        assert!(!idx.tail_is_stale(125));
        assert!(idx.tail_is_stale(126));
        // live_max below the snapshot saturates to an empty tail.
        assert!(!idx.tail_is_stale(10));
        assert!(idx.tail_is_stale(u64::MAX));
    }

    #[test]
    fn from_parts_round_trips_built_index() {
        let idx = AnnIndex::build(synth_rows(50, 8), Metric::L2, 8).unwrap();
        let AnnIndex {
            prism,
            id_map,
            snapshot_max,
            metric,
            dim,
        } = idx;
        let rebuilt = AnnIndex::from_parts(prism, id_map, snapshot_max, metric, dim);
        assert_eq!(rebuilt.indexed_len(), 50);
        assert_eq!(rebuilt.snapshot_max, 50);
        let hits = rebuilt.search(&[0.3; 8], 5);
        assert!(!hits.is_empty());
        for (rid, _) in &hits {
            assert!((1..=50).contains(rid));
        }
    }
}