// alopex_core/vector/flat.rs
use crate::types::Key;
use crate::vector::{score, Metric};
use crate::Result;

6#[derive(Debug)]
8pub struct Candidate<'a> {
9 pub key: &'a Key,
11 pub vector: &'a [f32],
13}
14
15#[derive(Debug, Clone, PartialEq)]
17pub struct ScoredItem {
18 pub key: Key,
20 pub score: f32,
22}
23
24pub fn search_flat<'a, F>(
29 query: &[f32],
30 metric: Metric,
31 top_k: usize,
32 candidates: impl IntoIterator<Item = Candidate<'a>>,
33 mut filter: Option<F>,
34) -> Result<Vec<ScoredItem>>
35where
36 F: FnMut(&Candidate<'a>) -> bool,
37{
38 if top_k == 0 {
39 return Ok(Vec::new());
40 }
41
42 let mut results = Vec::new();
43 for cand in candidates {
44 if let Some(ref mut pred) = filter {
45 if !pred(&cand) {
46 continue;
47 }
48 }
49
50 let s = score(metric, query, cand.vector)?;
52 results.push(ScoredItem {
53 key: cand.key.clone(),
54 score: s,
55 });
56 }
57
58 results.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.key.cmp(&b.key)));
59 if results.len() > top_k {
60 results.truncate(top_k);
61 }
62 Ok(results)
63}
64
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::Error;

    fn key(bytes: &[u8]) -> Key {
        bytes.to_vec()
    }

    #[test]
    fn respects_filter_before_scoring() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
            Candidate {
                key: &kb,
                vector: &[0.0, 1.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            Some(|c: &Candidate| c.key != b"b"),
        )
        .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].key, b"a");
    }

    #[test]
    fn orders_by_score_then_key() {
        let query = [1.0, 0.0];
        let kb = key(b"b");
        let ka = key(b"a");
        let items = vec![
            Candidate {
                key: &kb,
                vector: &[1.0, 0.0],
            },
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap();
        assert_eq!(res[0].key, b"a");
        assert_eq!(res[1].key, b"b");
    }

    #[test]
    fn switches_metric() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[2.0, 0.0],
            },
            Candidate {
                key: &kb,
                vector: &[0.0, 2.0],
            },
        ];
        let res =
            search_flat(&query, Metric::L2, 1, items, None::<fn(&Candidate) -> bool>).unwrap();
        assert_eq!(res[0].key, b"a");
    }

    #[test]
    fn truncates_to_top_k() {
        // Under L2 the closer vector ranks first (see `switches_metric`).
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let kc = key(b"c");
        let items = vec![
            Candidate {
                key: &kc,
                vector: &[0.0, 3.0],
            },
            Candidate {
                key: &kb,
                vector: &[2.0, 0.0],
            },
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
        ];
        let res =
            search_flat(&query, Metric::L2, 2, items, None::<fn(&Candidate) -> bool>).unwrap();
        assert_eq!(res.len(), 2);
        assert_eq!(res[0].key, b"a");
        assert_eq!(res[1].key, b"b");
    }

    #[test]
    fn zero_top_k_returns_empty_without_scoring() {
        // A mismatched dimension would error if scored; `top_k == 0` must
        // short-circuit before any candidate is examined.
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let items = vec![Candidate {
            key: &ka,
            vector: &[1.0, 0.0, 1.0],
        }];
        let res = search_flat(
            &query,
            Metric::Cosine,
            0,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap();
        assert!(res.is_empty());
    }

    #[test]
    fn enforces_dimension_match() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let items = vec![Candidate {
            key: &ka,
            vector: &[1.0, 0.0, 1.0],
        }];
        let err = search_flat(
            &query,
            Metric::Cosine,
            1,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap_err();
        assert!(matches!(err, Error::DimensionMismatch { .. }));
    }
}