1use crate::prism::{Filter, Metric, PointStore, PrismConfig, PrismIndex};
4
/// Multiplier applied to the requested `k` when a search derives its default
/// candidate-list size: `AnnIndex::search` and
/// `AnnIndex::search_filtered_default_ef` use the larger of `k * OVER_FETCH`
/// and the index's configured `beam_width` as `ef`.
pub const OVER_FETCH: usize = 4;
7
/// Errors reported while validating the rows handed to an ANN index build.
#[derive(Debug, thiserror::Error)]
pub enum AnnError {
    /// The caller supplied no rows at all.
    #[error("ANN build requires at least one row")]
    EmptyInput,
    /// A row's vector length differs from the dimension declared for the build.
    #[error("ANN build vector dim mismatch: expected {expected}, got {got} for row_id {row_id}")]
    DimMismatch {
        /// Dimension declared by the caller.
        expected: u16,
        /// Actual length of the offending vector.
        got: usize,
        /// Row id of the offending row.
        row_id: u64,
    },
    /// A row's attribute list length differs from the declared attribute count.
    #[error(
        "ANN build attribute arity mismatch: expected {expected}, got {got} for row_id {row_id}"
    )]
    AttrArityMismatch {
        /// Attribute count declared by the caller.
        expected: usize,
        /// Actual number of attribute codes on the offending row.
        got: usize,
        /// Row id of the offending row.
        row_id: u64,
    },
}
27
28fn prism_config(metric: Metric) -> PrismConfig {
31 PrismConfig {
32 metric,
33 binary_rerank: 0,
34 sigma_high: 0.001,
35 ..PrismConfig::default()
36 }
37}
38
/// An approximate-nearest-neighbour index: a `PrismIndex` plus the mapping
/// from Prism's internal ids back to caller-supplied row ids.
pub struct AnnIndex {
    /// Underlying Prism index that performs the actual search.
    prism: PrismIndex,
    /// Row id for each Prism-internal id; indexed by the `id` of a Prism
    /// search result to translate it back to the caller's row id.
    id_map: Vec<u64>,
    /// Largest row id among the rows the index was built from (or whatever
    /// value the caller passed to `from_parts`).
    pub snapshot_max: u64,
    /// Distance metric the index was built with.
    pub metric: Metric,
    /// Vector dimension every indexed row (and every query) must have.
    pub dim: u16,
}
49
50impl std::fmt::Debug for AnnIndex {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("AnnIndex")
53 .field("snapshot_max", &self.snapshot_max)
54 .field("metric", &self.metric)
55 .field("dim", &self.dim)
56 .field("indexed_len", &self.id_map.len())
57 .finish()
58 }
59}
60
61impl AnnIndex {
62 pub fn build(rows: Vec<(u64, Vec<f32>)>, metric: Metric, dim: u16) -> Result<Self, AnnError> {
64 let with_attrs = rows
65 .into_iter()
66 .map(|(id, v)| (id, v, Vec::new()))
67 .collect();
68 Self::build_with_attrs(with_attrs, 0, metric, dim)
69 }
70
71 pub fn build_with_attrs(
74 mut rows: Vec<(u64, Vec<f32>, Vec<u32>)>,
75 num_attrs: usize,
76 metric: Metric,
77 dim: u16,
78 ) -> Result<Self, AnnError> {
79 if rows.is_empty() {
80 return Err(AnnError::EmptyInput);
81 }
82 for (rid, v, a) in &rows {
83 if v.len() != dim as usize {
84 return Err(AnnError::DimMismatch {
85 expected: dim,
86 got: v.len(),
87 row_id: *rid,
88 });
89 }
90 if a.len() != num_attrs {
91 return Err(AnnError::AttrArityMismatch {
92 expected: num_attrs,
93 got: a.len(),
94 row_id: *rid,
95 });
96 }
97 }
98
99 rows.sort_unstable_by_key(|(id, _, _)| *id);
100 let snapshot_max = rows.last().map(|(id, _, _)| *id).unwrap_or(0);
101
102 let n = rows.len();
103 let mut flat: Vec<f32> = Vec::with_capacity(n * dim as usize);
104 let mut row_ids: Vec<u64> = Vec::with_capacity(n);
105 let attr_dims = num_attrs.max(1);
107 let mut attr_cols: Vec<Vec<u32>> = vec![Vec::with_capacity(n); attr_dims];
108 for (rid, v, a) in &rows {
109 flat.extend_from_slice(v);
110 row_ids.push(*rid);
111 if num_attrs == 0 {
112 attr_cols[0].push(0);
113 } else {
114 for (j, &code) in a.iter().enumerate() {
115 attr_cols[j].push(code);
116 }
117 }
118 }
119
120 let store = PointStore::from_parts(flat, dim as usize, attr_cols);
121 let prism = PrismIndex::build(store, prism_config(metric));
122
123 let id_map: Vec<u64> = prism
125 .original_ids
126 .iter()
127 .map(|&old| row_ids[old as usize])
128 .collect();
129
130 Ok(Self {
131 prism,
132 id_map,
133 snapshot_max,
134 metric,
135 dim,
136 })
137 }
138
139 pub fn from_parts(
143 prism: PrismIndex,
144 id_map: Vec<u64>,
145 snapshot_max: u64,
146 metric: Metric,
147 dim: u16,
148 ) -> Self {
149 Self {
150 prism,
151 id_map,
152 snapshot_max,
153 metric,
154 dim,
155 }
156 }
157
158 pub fn prism(&self) -> &PrismIndex {
159 &self.prism
160 }
161
162 pub fn id_map(&self) -> &[u64] {
164 &self.id_map
165 }
166
167 pub fn active_config(metric: Metric) -> PrismConfig {
170 prism_config(metric)
171 }
172
173 pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
175 let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
176 self.search_with_ef(query, k, ef)
177 }
178
179 pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f32)> {
181 self.search_filtered(query, k, ef, &Filter::none())
182 }
183
184 pub fn search_filtered_default_ef(
186 &self,
187 query: &[f32],
188 k: usize,
189 filter: &Filter,
190 ) -> Vec<(u64, f32)> {
191 let ef = (k * OVER_FETCH).max(self.prism.config.beam_width);
192 self.search_filtered(query, k, ef, filter)
193 }
194
195 pub fn search_filtered(
202 &self,
203 query: &[f32],
204 k: usize,
205 ef: usize,
206 filter: &Filter,
207 ) -> Vec<(u64, f32)> {
208 debug_assert_eq!(query.len(), self.dim as usize);
209 let sqrt_l2 = self.metric == Metric::L2;
210 self.prism
211 .search(query, filter, k, ef)
212 .into_iter()
213 .map(|r| {
214 let dist = if sqrt_l2 { r.dist.sqrt() } else { r.dist };
215 (self.id_map[r.id as usize], dist)
216 })
217 .collect()
218 }
219
220 pub fn indexed_len(&self) -> usize {
222 self.id_map.len()
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` attribute-free rows with ids `1..=n`; component `d` of
    /// row `i` (zero-based) is `(i + d) * 0.01`.
    fn synth_rows(n: usize, dim: u16) -> Vec<(u64, Vec<f32>)> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let base = i as f32;
            let vector: Vec<f32> = (0..dim).map(|d| (base + d as f32) * 0.01).collect();
            rows.push((i as u64 + 1, vector));
        }
        rows
    }

    /// An empty row list is rejected with `EmptyInput`.
    #[test]
    fn build_empty_input_errors() {
        let no_rows: Vec<(u64, Vec<f32>)> = Vec::new();
        let err = AnnIndex::build(no_rows, Metric::L2, 4).unwrap_err();
        assert!(matches!(err, AnnError::EmptyInput));
    }

    /// A 2-element vector against `dim = 4` yields a fully populated `DimMismatch`.
    #[test]
    fn build_dim_mismatch_errors() {
        let short_row = (1u64, vec![1.0, 2.0]);
        let err = AnnIndex::build(vec![short_row], Metric::L2, 4).unwrap_err();
        assert!(matches!(
            err,
            AnnError::DimMismatch {
                expected: 4,
                got: 2,
                row_id: 1
            }
        ));
    }

    /// One row is enough to build; its id becomes `snapshot_max`.
    #[test]
    fn build_single_row_succeeds() {
        let only_row = (7u64, vec![0.1, 0.2, 0.3, 0.4]);
        let idx = AnnIndex::build(vec![only_row], Metric::L2, 4).unwrap();
        assert_eq!(idx.indexed_len(), 1);
        assert_eq!(idx.snapshot_max, 7);
    }

    /// A handful of rows builds and every row is indexed.
    #[test]
    fn build_small_n_succeeds() {
        let idx = AnnIndex::build(synth_rows(5, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 5);
    }

    /// A few hundred rows builds and every row is indexed.
    #[test]
    fn build_large_n_succeeds() {
        let idx = AnnIndex::build(synth_rows(500, 16), Metric::L2, 16).unwrap();
        assert_eq!(idx.indexed_len(), 500);
    }

    /// Search results carry caller row ids (`1..=n`), not Prism-internal ids.
    #[test]
    fn search_returns_row_ids_not_internal_ids() {
        let n = 200;
        let idx = AnnIndex::build(synth_rows(n, 8), Metric::L2, 8).unwrap();
        let query = [0.5; 8];
        let hits = idx.search(&query, 5);
        assert!(!hits.is_empty());
        let valid_ids = 1..=n as u64;
        for (rid, _d) in &hits {
            assert!(valid_ids.contains(rid));
        }
    }

    /// `snapshot_max` is the largest row id regardless of input order.
    #[test]
    fn snapshot_max_tracks_highest_row_id() {
        let mut rows = Vec::new();
        rows.push((5u64, vec![1.0, 0.0]));
        rows.push((10u64, vec![0.0, 1.0]));
        rows.push((3u64, vec![1.0, 1.0]));
        let idx = AnnIndex::build(rows, Metric::L2, 2).unwrap();
        assert_eq!(idx.snapshot_max, 10);
    }

    /// The cosine metric is recorded on the wrapper and in the Prism config.
    #[test]
    fn cosine_metric_propagates_to_prism() {
        let idx = AnnIndex::build(synth_rows(50, 16), Metric::Cosine, 16).unwrap();
        assert_eq!(idx.metric, Metric::Cosine);
        assert_eq!(idx.prism.config.metric, Metric::Cosine);
    }

    /// The inner-product metric is recorded on the wrapper and in the Prism config.
    #[test]
    fn inner_metric_propagates_to_prism() {
        let idx = AnnIndex::build(synth_rows(50, 16), Metric::InnerProduct, 16).unwrap();
        assert_eq!(idx.metric, Metric::InnerProduct);
        assert_eq!(idx.prism.config.metric, Metric::InnerProduct);
    }

    /// Generates `n` rows with ids `1..=n` and one attribute equal to `i % 2`
    /// for zero-based row `i`; vectors follow the same ramp as `synth_rows`.
    fn attr_rows(n: u64, dim: u16) -> Vec<(u64, Vec<f32>, Vec<u32>)> {
        let mut rows = Vec::with_capacity(n as usize);
        for i in 0..n {
            let base = i as f32;
            let vector: Vec<f32> = (0..dim).map(|d| (base + d as f32) * 0.01).collect();
            let category = (i % 2) as u32;
            rows.push((i + 1, vector, vec![category]));
        }
        rows
    }

    /// Filtering on attribute 0 == 1 only returns rows built with that code
    /// (odd zero-based index, hence even row id).
    #[test]
    fn build_with_attrs_filters_by_attribute() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let only_category_one = Filter::eq(0, 1);
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &only_category_one);
        assert!(!hits.is_empty());
        assert!(hits.len() <= 10);
        for (rid, _) in &hits {
            assert_eq!(rid % 2, 0, "row {rid} is not category 1");
        }
    }

    /// Without a filter, an attribute-bearing index still returns a full top-k.
    #[test]
    fn build_with_attrs_unfiltered_spans_all_cells() {
        let idx = AnnIndex::build_with_attrs(attr_rows(100, 8), 1, Metric::L2, 8).unwrap();
        let hits = idx.search_with_ef(&[0.5; 8], 10, 200);
        assert_eq!(hits.len(), 10);
        for (rid, _) in &hits {
            assert!((1..=100).contains(rid));
        }
    }

    /// A filter over two attributes is applied conjunctively.
    #[test]
    fn build_with_attrs_two_dims_conjunctive_filter() {
        let n = 180u64;
        let dim = 8u16;
        let mut rows: Vec<(u64, Vec<f32>, Vec<u32>)> = Vec::new();
        for i in 0..n {
            let base = i as f32;
            let vector: Vec<f32> = (0..dim).map(|d| (base + d as f32) * 0.01).collect();
            rows.push((i + 1, vector, vec![(i % 2) as u32, (i % 3) as u32]));
        }
        let idx = AnnIndex::build_with_attrs(rows, 2, Metric::L2, dim).unwrap();
        let filter = Filter::new(vec![(0, vec![1]), (1, vec![2])]);
        let hits = idx.search_filtered(&[0.5; 8], 10, 200, &filter);
        assert!(!hits.is_empty());
        for (rid, _) in &hits {
            // Row ids are one-based; recover the zero-based index the
            // attribute codes were derived from.
            let i = rid - 1;
            assert_eq!(i % 2, 1, "row {rid} fails attr0 = 1");
            assert_eq!(i % 3, 2, "row {rid} fails attr1 = 2");
        }
    }

    /// One attribute supplied where two are declared yields `AttrArityMismatch`.
    #[test]
    fn build_with_attrs_arity_mismatch_errors() {
        let under_specified = (1u64, vec![0.0; 4], vec![0u32]);
        let err =
            AnnIndex::build_with_attrs(vec![under_specified], 2, Metric::L2, 4).unwrap_err();
        assert!(matches!(
            err,
            AnnError::AttrArityMismatch {
                expected: 2,
                got: 1,
                row_id: 1
            }
        ));
    }

    /// The attribute-free `build` entry point produces a searchable index.
    #[test]
    fn build_delegates_to_attrs_path() {
        let idx = AnnIndex::build(synth_rows(50, 8), Metric::L2, 8).unwrap();
        assert_eq!(idx.indexed_len(), 50);
        let query = [0.3; 8];
        let hits = idx.search(&query, 5);
        assert!(!hits.is_empty());
    }
}