alopex_core/vector/
flat.rs

1//! Flat planner: filter → score → top-k using vector metrics.
2use crate::types::Key;
3use crate::vector::{score, Metric};
4use crate::Result;
5
6/// A candidate item for flat search.
7#[derive(Debug)]
8pub struct Candidate<'a> {
9    /// Key identifying the item.
10    pub key: &'a Key,
11    /// Vector embedding.
12    pub vector: &'a [f32],
13}
14
15/// Scored output item.
16#[derive(Debug, Clone, PartialEq)]
17pub struct ScoredItem {
18    /// Key identifying the item.
19    pub key: Key,
20    /// Similarity score (higher is better).
21    pub score: f32,
22}
23
24/// Executes a flat search:
25/// 1) Optionally filters candidates
26/// 2) Scores using the selected metric
27/// 3) Returns the top-k results sorted by descending score, then key for stability
28pub fn search_flat<'a, F>(
29    query: &[f32],
30    metric: Metric,
31    top_k: usize,
32    candidates: impl IntoIterator<Item = Candidate<'a>>,
33    mut filter: Option<F>,
34) -> Result<Vec<ScoredItem>>
35where
36    F: FnMut(&Candidate<'a>) -> bool,
37{
38    if top_k == 0 {
39        return Ok(Vec::new());
40    }
41
42    let mut results = Vec::new();
43    for cand in candidates {
44        if let Some(ref mut pred) = filter {
45            if !pred(&cand) {
46                continue;
47            }
48        }
49
50        // score() will validate dimension equality and return typed errors.
51        let s = score(metric, query, cand.vector)?;
52        results.push(ScoredItem {
53            key: cand.key.clone(),
54            score: s,
55        });
56    }
57
58    results.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.key.cmp(&b.key)));
59    if results.len() > top_k {
60        results.truncate(top_k);
61    }
62    Ok(results)
63}
64
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::Error;

    /// Builds an owned key from a byte literal.
    fn key(bytes: &[u8]) -> Key {
        bytes.to_vec()
    }

    #[test]
    fn respects_filter_before_scoring() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
            Candidate {
                key: &kb,
                vector: &[0.0, 1.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            Some(|c: &Candidate| c.key != b"b"),
        )
        .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].key, b"a");
    }

    #[test]
    fn orders_by_score_then_key() {
        let query = [1.0, 0.0];
        let kb = key(b"b");
        let ka = key(b"a");
        let items = vec![
            Candidate {
                key: &kb,
                vector: &[1.0, 0.0],
            },
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap();
        // same score, sorted by key for determinism
        assert_eq!(res[0].key, b"a");
        assert_eq!(res[1].key, b"b");
    }

    #[test]
    fn switches_metric() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[2.0, 0.0],
            },
            Candidate {
                key: &kb,
                vector: &[0.0, 2.0],
            },
        ];
        let res =
            search_flat(&query, Metric::L2, 1, items, None::<fn(&Candidate) -> bool>).unwrap();
        // L2 negative distance: closer is higher (less negative), so "a" should win.
        assert_eq!(res[0].key, b"a");
    }

    #[test]
    fn enforces_dimension_match() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let items = vec![Candidate {
            key: &ka,
            vector: &[1.0, 0.0, 1.0],
        }];
        let err = search_flat(
            &query,
            Metric::Cosine,
            1,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap_err();
        assert!(matches!(err, Error::DimensionMismatch { .. }));
    }

    #[test]
    fn zero_top_k_returns_empty() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let items = vec![Candidate {
            key: &ka,
            vector: &[1.0, 0.0],
        }];
        let res = search_flat(
            &query,
            Metric::Cosine,
            0,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap();
        assert!(res.is_empty());
    }

    #[test]
    fn truncates_to_top_k_in_rank_order() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let kc = key(b"c");
        let items = vec![
            Candidate {
                key: &kb,
                vector: &[0.0, 1.0],
            },
            Candidate {
                key: &kc,
                vector: &[1.0, 1.0],
            },
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            2,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap();
        // Cosine to [1, 0]: a = 1.0, c ≈ 0.707, b = 0.0; only the best two remain.
        assert_eq!(res.len(), 2);
        assert_eq!(res[0].key, b"a");
        assert_eq!(res[1].key, b"c");
    }

    #[test]
    fn filtered_out_candidates_are_not_validated() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
            // Wrong dimension, but rejected by the filter before scoring.
            Candidate {
                key: &kb,
                vector: &[1.0, 0.0, 1.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            Some(|c: &Candidate| c.key != b"b"),
        )
        .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].key, b"a");
    }
}
168}