1use crate::prism::{Filter, Metric, PointStore, PrismConfig, PrismIndex};
4
/// Over-fetch multiplier for searches that do not pass an explicit `ef`.
///
/// The default candidate beam for a top-`k` query is `k * OVER_FETCH`, never smaller than the
/// PRISM config's `beam_width`.
pub const OVER_FETCH: usize = 4;

/// Absolute limit on the row-id gap between the live table and the index snapshot.
///
/// Once the gap exceeds this value, [`AnnIndex::tail_is_stale`] reports the index as stale
/// regardless of how many rows are indexed.
pub const REBUILD_TAIL_MAX: usize = 2_048;
10
11#[derive(Debug, thiserror::Error)]
12pub enum AnnError {
13 #[error("ANN build requires at least one row")]
14 EmptyInput,
15 #[error("ANN build vector dim mismatch: expected {expected}, got {got} for row_id {row_id}")]
16 DimMismatch {
17 expected: u16,
18 got: usize,
19 row_id: u64,
20 },
21 #[error(
22 "ANN build attribute arity mismatch: expected {expected}, got {got} for row_id {row_id}"
23 )]
24 AttrArityMismatch {
25 expected: usize,
26 got: usize,
27 row_id: u64,
28 },
29}
30
31fn prism_config(metric: Metric) -> PrismConfig {
34 PrismConfig {
35 metric,
36 binary_rerank: 0,
37 sigma_high: 0.001,
38 ..PrismConfig::default()
39 }
40}
41
42pub struct AnnIndex {
44 prism: PrismIndex,
45 id_map: Vec<u64>,
47 pub snapshot_max: u64,
49 pub metric: Metric,
50 pub dim: u16,
51}
52
53impl std::fmt::Debug for AnnIndex {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("AnnIndex")
56 .field("snapshot_max", &self.snapshot_max)
57 .field("metric", &self.metric)
58 .field("dim", &self.dim)
59 .field("indexed_len", &self.id_map.len())
60 .finish()
61 }
62}
63
64impl AnnIndex {
65 pub fn build(rows: Vec<(u64, Vec<f32>)>, metric: Metric, dim: u16) -> Result<Self, AnnError> {
67 let with_attrs = rows
68 .into_iter()
69 .map(|(id, v)| (id, v, Vec::new()))
70 .collect();
71 Self::build_with_attrs(with_attrs, 0, metric, dim)
72 }
73
74 pub fn build_with_attrs(
77 mut rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
78 num_attrs: usize,
79 metric: Metric,
80 dim: u16,
81 ) -> Result<Self, AnnError> {
82 if rows.is_empty() {
83 return Err(AnnError::EmptyInput);
84 }
85 for (rid, v, a) in &rows {
86 if v.len() != dim as usize {
87 return Err(AnnError::DimMismatch {
88 expected: dim,
89 got: v.len(),
90 row_id: *rid,
91 });
92 }
93 if a.len() != num_attrs {
94 return Err(AnnError::AttrArityMismatch {
95 expected: num_attrs,
96 got: a.len(),
97 row_id: *rid,
98 });
99 }
100 }
101
102 rows.sort_unstable_by_key(|(id, _, _)| *id);
103 let snapshot_max = rows.last().map(|(id, _, _)| *id).unwrap_or(0);
104
105 let n = rows.len();
106 let mut flat: Vec<f32> = Vec::with_capacity(n * dim as usize);
107 let mut row_ids: Vec<u64> = Vec::with_capacity(n);
108 let attr_dims = num_attrs.max(1);
110 let mut attr_cols: Vec<Vec<u32>> = vec![Vec::with_capacity(n); attr_dims];
111 for (rid, v, a) in &rows {
112 flat.extend_from_slice(v);
113 row_ids.push(*rid);
114 if num_attrs == 0 {
115 attr_cols[0].push(0);
116 } else {
117 for (j, &code) in a.iter().enumerate() {
118 attr_cols[j].push(code);
119 }
120 }
121 }
122
123 let store = PointStore::from_parts(flat, dim as usize, attr_cols);
124 let prism = PrismIndex::build(store, prism_config(metric));
125
126 let id_map: Vec<u64> = prism
128 .original_ids
129 .iter()
130 .map(|&old| row_ids[old as usize])
131 .collect();
132
133 Ok(Self {
134 prism,
135 id_map,
136 snapshot_max,
137 metric,
138 dim,
139 })
140 }
141
142 pub fn from_parts(
146 prism: PrismIndex,
147 id_map: Vec<u64>,
148 snapshot_max: u64,
149 metric: Metric,
150 dim: u16,
151 ) -> Self {
152 Self {
153 prism,
154 id_map,
155 snapshot_max,
156 metric,
157 dim,
158 }
159 }
160
161 pub fn prism(&self) -> &PrismIndex {
162 &self.prism
163 }
164
165 pub fn id_map(&self) -> &[u64] {
167 &self.id_map
168 }
169
170 pub fn active_config(metric: Metric) -> PrismConfig {
173 prism_config(metric)
174 }
175
176 pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
178 let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
179 self.search_with_ef(query, k, ef)
180 }
181
182 pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> {
184 self.search_filtered(query, k, ef, &Filter::none())
185 }
186
187 pub fn search_filtered_default_ef(
189 &self,
190 query: &[f32],
191 k: usize,
192 filter: &Filter,
193 ) -> Vec<(u64, f32)> {
194 let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
195 self.search_filtered(query, k, ef, filter)
196 }
197
198 pub fn search_filtered(
205 &self,
206 query: &[f32],
207 k: usize,
208 ef: usize,
209 filter: &Filter,
210 ) -> Vec<(u64, f32)> {
211 debug_assert_eq!(query.len(), self.dim as usize);
212 let sqrt_l2 = self.metric == Metric::L2;
213 self.prism
214 .search(query, filter, k, ef)
215 .into_iter()
216 .map(|r| {
217 let dist = if sqrt_l2 { r.dist.sqrt() } else { r.dist };
218 (self.id_map[r.id as usize], dist)
219 })
220 .collect()
221 }
222
223 pub fn indexed_len(&self) -> usize {
225 self.id_map.len()
226 }
227
228 pub fn tail_is_stale(&self, live_max: u64) -> bool {
230 let tail = live_max.saturating_sub(self.snapshot_max) as usize;
231 tail > REBUILD_TAIL_MAX || tail > self.indexed_len() / 4
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` rows with ids `1..=n` and smoothly varying vectors.
    fn synth_rows(n: usize, dim: u16) -> Vec<(u64, Vec<f32>)> {
        (0..n)
            .map(|i| {
                let row_id = (i as u64) + 1;
                let v: Vec<f32> = (0..dim).map(|d| (i as f32 + d as f32) * 0.01).collect();
                (row_id, v)
            })
            .collect()
    }

    #[test]
    fn build_empty_input_errors() {
        let err = AnnIndex::build(Vec::new(), Metric::L2, 4).unwrap_err();
        assert!(matches!(err, AnnError::EmptyInput));
    }

    #[test]
    fn build_dim_mismatch_errors() {
        let rows = vec![(1u64, vec![1.0, 2.0])];
        let err = AnnIndex::build(rows, Metric::L2, 4).unwrap_err();
        assert!(matches!(
            err,
            AnnError::DimMismatch {
                expected: 4,
                got: 2,
                row_id: 1
            }
        ));
    }

    #[test]
    fn build_single_row_succeeds() {
        let rows = vec![(7u64, vec![0.1, 0.2, 0.3, 0.4])];
        let idx = AnnIndex::build(rows, Metric::L2, 4).unwrap();
        assert_eq!(idx.indexed_len(), 1);
        assert_eq!(idx.snapshot_max, 7);
    }

    #[test]
    fn build_small_n_succeeds() {
        let rows = synth_rows(5, 8);
        let idx = AnnIndex::build(rows, Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 5);
    }

    #[test]
    fn build_large_n_succeeds() {
        let rows = synth_rows(500, 16);
        let idx = AnnIndex::build(rows, Metric::L2, 16).unwrap();
        assert_eq!(idx.indexed_len(), 500);
    }

    #[test]
    fn search_returns_row_ids_not_internal_ids() {
        let n = 200;
        let rows = synth_rows(n, 8);
        let idx = AnnIndex::build(rows, Metric::L2, 8).unwrap();
        let hits = idx.search(&[0.5; 8], 5);
        assert!(!hits.is_empty());
        for (rid, _d) in &hits {
            assert!(*rid >= 1 && *rid <= n as u64);
        }
    }

    #[test]
    fn snapshot_max_tracks_highest_row_id() {
        let rows = vec![
            (5u64, vec![1.0, 0.0]),
            (10u64, vec![0.0, 1.0]),
            (3u64, vec![1.0, 1.0]),
        ];
        let idx = AnnIndex::build(rows, Metric::L2, 2).unwrap();
        assert_eq!(idx.snapshot_max, 10);
    }

    #[test]
    fn cosine_metric_propagates_to_prism() {
        let rows = synth_rows(50, 16);
        let idx = AnnIndex::build(rows, Metric::Cosine, 16).unwrap();
        assert_eq!(idx.metric, Metric::Cosine);
        assert_eq!(idx.prism.config.metric, Metric::Cosine);
    }

    #[test]
    fn inner_metric_propagates_to_prism() {
        let rows = synth_rows(50, 16);
        let idx = AnnIndex::build(rows, Metric::InnerProduct, 16).unwrap();
        assert_eq!(idx.metric, Metric::InnerProduct);
        assert_eq!(idx.prism.config.metric, Metric::InnerProduct);
    }

    /// `n` rows with ids `1..=n` and a single attribute `(id - 1) % 2`.
    fn attr_rows(n: u64, dim: u16) -> Vec<(u64, Vec<f32>, Vec<u32>)> {
        (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|d| (i as f32 + d as f32) * 0.01).collect();
                (i + 1, v, vec![(i % 2) as u32])
            })
            .collect()
    }

    #[test]
    fn build_with_attrs_filters_by_attribute() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &Filter::eq(0, 1));
        assert!(!hits.is_empty());
        assert!(hits.len() <= 10);
        for (rid, _) in &hits {
            // Attribute 1 means `i` is odd, i.e. the row id `i + 1` is even.
            assert_eq!(rid % 2, 0, "row {rid} is not category 1");
        }
    }

    #[test]
    fn build_with_attrs_unfiltered_spans_all_cells() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_with_ef(&[0.5; 8], 10, 200);
        assert_eq!(hits.len(), 10);
        for (rid, _) in &hits {
            assert!(*rid >= 1 && *rid <= 100);
        }
    }

    #[test]
    fn build_with_attrs_two_dims_conjunctive_filter() {
        let n = 180u64;
        let dim = 8u16;
        let rows: Vec<(u64, Vec<f32>, Vec<u32>)> = (0..n)
            .map(|i| {
                let v: Vec<f32> = (0..dim).map(|d| (i as f32 + d as f32) * 0.01).collect();
                (i + 1, v, vec![(i % 2) as u32, (i % 3) as u32])
            })
            .collect();
        let idx = AnnIndex::build_with_attrs(rows, 2, Metric::L2, dim).unwrap();
        let filter = Filter::new(vec![(0, vec![1]), (1, vec![2])]);
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &filter);
        assert!(!hits.is_empty());
        for (rid, _) in &hits {
            let i = rid - 1;
            assert_eq!(i % 2, 1, "row {rid} fails attr0 = 1");
            assert_eq!(i % 3, 2, "row {rid} fails attr1 = 2");
        }
    }

    #[test]
    fn build_with_attrs_arity_mismatch_errors() {
        let rows = vec![(1u64, vec![0.0; 4], vec![0u32])];
        let err = AnnIndex::build_with_attrs(rows, 2, Metric::L2, 4).unwrap_err();
        assert!(matches!(
            err,
            AnnError::AttrArityMismatch {
                expected: 2,
                got: 1,
                row_id: 1
            }
        ));
    }

    #[test]
    fn build_with_attrs_empty_input_errors() {
        let err = AnnIndex::build_with_attrs(Vec::new(), 1, Metric::L2, 4).unwrap_err();
        assert!(matches!(err, AnnError::EmptyInput));
    }

    #[test]
    fn build_delegates_to_attrs_path() {
        let idx = AnnIndex::build(synth_rows(50, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 50);
        let hits = idx.search(&[0.3; 8], 5);
        assert!(!hits.is_empty());
    }

    #[test]
    fn default_ef_filtered_path_matches_unfiltered_search() {
        // With an empty filter, both default-ef entry points must issue the same query.
        let idx = AnnIndex::build(synth_rows(100, 8), Metric::L2, 8).unwrap();
        let via_search = idx.search(&[0.4; 8], 5);
        let via_filtered = idx.search_filtered_default_ef(&[0.4; 8], 5, &Filter::none());
        assert!(!via_search.is_empty());
        assert_eq!(via_search, via_filtered);
    }

    #[test]
    fn tail_is_stale_uses_quarter_of_indexed_len() {
        let idx = AnnIndex::build(synth_rows(500, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.snapshot_max, 500);
        // At or behind the snapshot there is no tail at all.
        assert!(!idx.tail_is_stale(400));
        assert!(!idx.tail_is_stale(500));
        // 500 / 4 = 125 is the relative threshold, and only a strictly larger gap is stale.
        assert!(!idx.tail_is_stale(625));
        assert!(idx.tail_is_stale(626));
    }

    #[test]
    fn debug_reports_metadata_and_indexed_len() {
        let idx = AnnIndex::build(synth_rows(5, 8), Metric::L2, 8).unwrap();
        let rendered = format!("{idx:?}");
        assert!(rendered.starts_with("AnnIndex"));
        assert!(rendered.contains("snapshot_max: 5"));
        assert!(rendered.contains("dim: 8"));
        assert!(rendered.contains("indexed_len: 5"));
    }
}