Skip to main content

citadel_vector/
ann.rs

1//! In-memory ANN index wrapping the vendored PRISM engine.
2
3use crate::prism::{Filter, Metric, PointStore, PrismConfig, PrismIndex};
4
/// Over-fetch multiplier used by the default-`ef` searches: they request
/// `k * OVER_FETCH` candidates (or the configured PRISM beam width, whichever
/// is larger) to offset PRISM recall below 1.0.
pub const OVER_FETCH: usize = 4;
7
8#[derive(Debug, thiserror::Error)]
9pub enum AnnError {
10    #[error("ANN build requires at least one row")]
11    EmptyInput,
12    #[error("ANN build vector dim mismatch: expected {expected}, got {got} for row_id {row_id}")]
13    DimMismatch {
14        expected: u16,
15        got: usize,
16        row_id: u64,
17    },
18    #[error(
19        "ANN build attribute arity mismatch: expected {expected}, got {got} for row_id {row_id}"
20    )]
21    AttrArityMismatch {
22        expected: usize,
23        got: usize,
24        row_id: u64,
25    },
26}
27
28// binary_rerank=0: the Hamming pre-filter kills recall on continuous vectors.
29// sigma_high low: stay in the fast HIGH search regime.
30fn prism_config(metric: Metric) -> PrismConfig {
31    PrismConfig {
32        metric,
33        binary_rerank: 0,
34        sigma_high: 0.001,
35        ..PrismConfig::default()
36    }
37}
38
/// In-memory ANN index over a `(row_id, vector)` snapshot.
///
/// Pairs a [`PrismIndex`] with the table that translates PRISM's internal
/// point ids back into the caller's row ids.
pub struct AnnIndex {
    /// The underlying PRISM index that answers searches.
    prism: PrismIndex,
    /// PRISM internal id -> external row_id.
    id_map: Vec<u64>,
    /// Highest row_id in the snapshot.
    pub snapshot_max: u64,
    /// Distance metric of the index. For `Metric::L2`, `search_filtered`
    /// takes the square root of the distances PRISM reports.
    pub metric: Metric,
    /// Vector dimensionality; `build_with_attrs` rejects rows whose vector
    /// has any other length.
    pub dim: u16,
}
49
50impl std::fmt::Debug for AnnIndex {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("AnnIndex")
53            .field("snapshot_max", &self.snapshot_max)
54            .field("metric", &self.metric)
55            .field("dim", &self.dim)
56            .field("indexed_len", &self.id_map.len())
57            .finish()
58    }
59}
60
61impl AnnIndex {
62    /// Build an unfiltered index from `(row_id, vector)` pairs.
63    pub fn build(rows: Vec<(u64, Vec<f32>)>, metric: Metric, dim: u16) -> Result<Self, AnnError> {
64        let with_attrs = rows
65            .into_iter()
66            .map(|(id, v)| (id, v, Vec::new()))
67            .collect();
68        Self::build_with_attrs(with_attrs, 0, metric, dim)
69    }
70
71    /// Build a filtered index from `(row_id, vector, attr_codes)` triples. Each
72    /// attribute is a PRISM dimension; distinct tuples form the searchable cells.
73    pub fn build_with_attrs(
74        mut rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
75        num_attrs: usize,
76        metric: Metric,
77        dim: u16,
78    ) -> Result<Self, AnnError> {
79        if rows.is_empty() {
80            return Err(AnnError::EmptyInput);
81        }
82        for (rid, v, a) in &rows {
83            if v.len() != dim as usize {
84                return Err(AnnError::DimMismatch {
85                    expected: dim,
86                    got: v.len(),
87                    row_id: *rid,
88                });
89            }
90            if a.len() != num_attrs {
91                return Err(AnnError::AttrArityMismatch {
92                    expected: num_attrs,
93                    got: a.len(),
94                    row_id: *rid,
95                });
96            }
97        }
98
99        rows.sort_unstable_by_key(|(id, _, _)| *id);
100        let snapshot_max = rows.last().map(|(id, _, _)| *id).unwrap_or(0);
101
102        let n = rows.len();
103        let mut flat: Vec<f32> = Vec::with_capacity(n * dim as usize);
104        let mut row_ids: Vec<u64> = Vec::with_capacity(n);
105        // PRISM needs >=1 attribute dim; an all-zero column = one cell.
106        let attr_dims = num_attrs.max(1);
107        let mut attr_cols: Vec<Vec<u32>> = vec![Vec::with_capacity(n); attr_dims];
108        for (rid, v, a) in &rows {
109            flat.extend_from_slice(v);
110            row_ids.push(*rid);
111            if num_attrs == 0 {
112                attr_cols[0].push(0);
113            } else {
114                for (j, &code) in a.iter().enumerate() {
115                    attr_cols[j].push(code);
116                }
117            }
118        }
119
120        let store = PointStore::from_parts(flat, dim as usize, attr_cols);
121        let prism = PrismIndex::build(store, prism_config(metric));
122
123        // PRISM reorders points by cell; remap to external row_ids.
124        let id_map: Vec<u64> = prism
125            .original_ids
126            .iter()
127            .map(|&old| row_ids[old as usize])
128            .collect();
129
130        Ok(Self {
131            prism,
132            id_map,
133            snapshot_max,
134            metric,
135            dim,
136        })
137    }
138
    /// Reassemble from persisted parts (the ANN segment decode path). The
    /// caller is responsible for `prism.store.vectors` being in PRISM-internal
    /// (cell-reordered) order - see `segment::SegmentParts::into_index`.
    ///
    /// The arguments are stored as given; no consistency checks are performed
    /// here. `search_filtered` indexes `id_map` with the ids PRISM returns, so
    /// an `id_map` that does not cover every PRISM point id will panic at
    /// search time.
    pub fn from_parts(
        prism: PrismIndex,
        id_map: Vec<u64>,
        snapshot_max: u64,
        metric: Metric,
        dim: u16,
    ) -> Self {
        Self {
            prism,
            id_map,
            snapshot_max,
            metric,
            dim,
        }
    }
157
    /// Borrow the underlying PRISM index.
    pub fn prism(&self) -> &PrismIndex {
        &self.prism
    }
161
    /// PRISM internal id -> external row_id.
    ///
    /// Entry `i` is the external row id of PRISM point `i`; `search_filtered`
    /// uses this table to translate result ids.
    pub fn id_map(&self) -> &[u64] {
        &self.id_map
    }
166
    /// The PRISM configuration this index family builds with - part of the
    /// persisted segment's binding (a config change invalidates segments).
    ///
    /// Returns the same configuration `build_with_attrs` passes to
    /// `PrismIndex::build` for the given `metric`.
    pub fn active_config(metric: Metric) -> PrismConfig {
        prism_config(metric)
    }
172
173    /// Top-k search returning `(row_id, distance)` ascending, at the default ef.
174    pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
175        let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
176        self.search_with_ef(query, k, ef)
177    }
178
    /// Unfiltered search with an explicit beam width `ef`.
    ///
    /// Equivalent to [`AnnIndex::search_filtered`] with `Filter::none()`.
    pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> {
        self.search_filtered(query, k, ef, &Filter::none())
    }
183
184    /// Filtered search at the default ef.
185    pub fn search_filtered_default_ef(
186        &self,
187        query: &[f32],
188        k: usize,
189        filter: &Filter,
190    ) -> Vec<(u64, f32)> {
191        let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
192        self.search_filtered(query, k, ef, filter)
193    }
194
195    /// Filtered search; `filter` dims index the `build_with_attrs` attributes
196    /// (`Filter::none()` matches all).
197    ///
198    /// Distances use the same units as the SQL operators (`<->` = true L2,
199    /// `<#>` = -dot, `<=>` = 1-cos), so callers may mix them with exact-ranked
200    /// scores. PRISM reports squared L2 internally; converted here.
201    pub fn search_filtered(
202        &self,
203        query: &[f32],
204        k: usize,
205        ef: usize,
206        filter: &Filter,
207    ) -> Vec<(u64, f32)> {
208        debug_assert_eq!(query.len(), self.dim as usize);
209        let sqrt_l2 = self.metric == Metric::L2;
210        self.prism
211            .search(query, filter, k, ef)
212            .into_iter()
213            .map(|r| {
214                let dist = if sqrt_l2 { r.dist.sqrt() } else { r.dist };
215                (self.id_map[r.id as usize], dist)
216            })
217            .collect()
218    }
219
    /// Number of indexed rows (the length of the internal-id -> row_id map).
    pub fn indexed_len(&self) -> usize {
        self.id_map.len()
    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test vector: component `d` is `(seed + d) * 0.01`.
    fn ramp(seed: f32, dim: u16) -> Vec<f32> {
        (0..dim).map(|d| (seed + d as f32) * 0.01).collect()
    }

    /// `n` rows with row_id = i + 1 and a ramp vector seeded by `i`.
    fn synth_rows(n: usize, dim: u16) -> Vec<(u64, Vec<f32>)> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            rows.push(((i as u64) + 1, ramp(i as f32, dim)));
        }
        rows
    }

    #[test]
    fn build_empty_input_errors() {
        let outcome = AnnIndex::build(Vec::new(), Metric::L2, 4);
        assert!(matches!(outcome, Err(AnnError::EmptyInput)));
    }

    #[test]
    fn build_dim_mismatch_errors() {
        let outcome = AnnIndex::build(vec![(1u64, vec![1.0, 2.0])], Metric::L2, 4);
        assert!(matches!(
            outcome,
            Err(AnnError::DimMismatch {
                expected: 4,
                got: 2,
                row_id: 1
            })
        ));
    }

    #[test]
    fn build_single_row_succeeds() {
        let only_row = (7u64, vec![0.1, 0.2, 0.3, 0.4]);
        let idx = AnnIndex::build(vec![only_row], Metric::L2, 4).unwrap();
        assert_eq!(idx.indexed_len(), 1);
        assert_eq!(idx.snapshot_max, 7);
    }

    #[test]
    fn build_small_n_succeeds() {
        let idx = AnnIndex::build(synth_rows(5, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 5);
    }

    #[test]
    fn build_large_n_succeeds() {
        let idx = AnnIndex::build(synth_rows(500, 16), Metric::L2, 16).unwrap();
        assert_eq!(idx.indexed_len(), 500);
    }

    #[test]
    fn search_returns_row_ids_not_internal_ids() {
        let n = 200;
        let idx = AnnIndex::build(synth_rows(n, 8), Metric::L2, 8).unwrap();
        let hits = idx.search(&[0.5; 8], 5);
        assert!(!hits.is_empty());
        // External row ids are 1..=n; internal ids would start at 0.
        assert!(hits.iter().all(|(rid, _)| (1..=n as u64).contains(rid)));
    }

    #[test]
    fn snapshot_max_tracks_highest_row_id() {
        // Deliberately unsorted, with the maximum in the middle.
        let mut rows = Vec::new();
        rows.push((5u64, vec![1.0, 0.0]));
        rows.push((10u64, vec![0.0, 1.0]));
        rows.push((3u64, vec![1.0, 1.0]));
        let idx = AnnIndex::build(rows, Metric::L2, 2).unwrap();
        assert_eq!(idx.snapshot_max, 10);
    }

    #[test]
    fn cosine_metric_propagates_to_prism() {
        let idx = AnnIndex::build(synth_rows(50, 16), Metric::Cosine, 16).unwrap();
        assert_eq!(idx.metric, Metric::Cosine);
        assert_eq!(idx.prism.config.metric, Metric::Cosine);
    }

    #[test]
    fn inner_metric_propagates_to_prism() {
        let idx = AnnIndex::build(synth_rows(50, 16), Metric::InnerProduct, 16).unwrap();
        assert_eq!(idx.metric, Metric::InnerProduct);
        assert_eq!(idx.prism.config.metric, Metric::InnerProduct);
    }

    /// attr 0 = i % 2; row_id = i + 1.
    fn attr_rows(n: u64, dim: u16) -> Vec<(u64, Vec<f32>, Vec<u32>)> {
        let mut rows = Vec::new();
        for i in 0..n {
            rows.push((i + 1, ramp(i as f32, dim), vec![(i % 2) as u32]));
        }
        rows
    }

    #[test]
    fn build_with_attrs_filters_by_attribute() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let only_category_one = Filter::eq(0, 1);
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &only_category_one);
        assert!(!hits.is_empty());
        assert!(hits.len() <= 10);
        for (rid, _) in hits.iter() {
            // category 1 == odd i == even row_id.
            assert_eq!(rid % 2, 0, "row {rid} is not category 1");
        }
    }

    #[test]
    fn build_with_attrs_unfiltered_spans_all_cells() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_with_ef(&[0.5; 8], 10, 200);
        assert_eq!(hits.len(), 10);
        assert!(hits.iter().all(|(rid, _)| (1..=100).contains(rid)));
    }

    #[test]
    fn build_with_attrs_two_dims_conjunctive_filter() {
        let n = 180u64;
        let dim = 8u16;
        // attr 0 = i % 2, attr 1 = i % 3; row_id = i + 1.
        let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();
        for i in 0..n {
            let attrs = vec![(i % 2) as u32, (i % 3) as u32];
            rows.push((i + 1, ramp(i as f32, dim), attrs));
        }
        let idx = AnnIndex::build_with_attrs(rows, 2, Metric::L2, dim).unwrap();
        let filter = Filter::new(vec![(0, vec![1]), (1, vec![2])]);
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &filter);
        assert!(!hits.is_empty());
        for (rid, _) in hits.iter() {
            let i = rid - 1;
            assert_eq!(i % 2, 1, "row {rid} fails attr0 = 1");
            assert_eq!(i % 3, 2, "row {rid} fails attr1 = 2");
        }
    }

    #[test]
    fn build_with_attrs_arity_mismatch_errors() {
        // One attribute code supplied, two requested.
        let rows = vec![(1u64, vec![0.0; 4], vec![0u32])];
        let outcome = AnnIndex::build_with_attrs(rows, 2, Metric::L2, 4);
        assert!(matches!(
            outcome,
            Err(AnnError::AttrArityMismatch {
                expected: 2,
                got: 1,
                row_id: 1
            })
        ));
    }

    #[test]
    fn build_delegates_to_attrs_path() {
        let idx = AnnIndex::build(synth_rows(50, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 50);
        assert!(!idx.search(&[0.3; 8], 5).is_empty());
    }
}