embeddenator_retrieval/retrieval/
signature.rs1use std::collections::{HashMap, HashSet};
11
12use embeddenator_vsa::{SparseVec, DIM};
13
/// A source of retrieval candidates for queries of type `V`.
///
/// Implementors decide what a candidate is (see [`CandidateGenerator::Candidate`])
/// and how the requested count `k` is honoured.
pub trait CandidateGenerator<V> {
    /// The value handed back for each candidate (for example an item id).
    type Candidate;

    /// Produces candidates for `query`.
    ///
    /// `k` is the number of candidates the caller asks for; the exact
    /// interpretation (hard limit, hint, ...) is left to the implementor.
    fn candidates(&self, query: &V, k: usize) -> Vec<Self::Candidate>;
}
31
/// Number of probe dimensions sampled when an index is built with the
/// default settings.
pub const DEFAULT_SIGNATURE_PROBES: usize = 24;
36
/// Tuning knobs for a single signature-index lookup.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SignatureQueryOptions {
    /// Upper bound on the number of candidate ids returned.
    /// A value of `0` yields an empty result.
    pub max_candidates: usize,

    /// How far from the query signature to probe.
    /// `0` looks only at the query's exact bucket; values above `1` are
    /// treated as `1` by the lookup.
    pub probe_radius: u8,

    /// Upper bound on the number of bucket signatures examined, counting the
    /// query's own signature. A value of `0` means nothing is probed.
    pub max_probes: usize,
}
54
55impl Default for SignatureQueryOptions {
56 fn default() -> Self {
57 Self {
58 max_candidates: 1_000,
59 probe_radius: 1,
60 max_probes: 1 + (2 * DEFAULT_SIGNATURE_PROBES),
61 }
62 }
63}
64
/// Bucketed index keyed by a packed ternary signature.
///
/// Each indexed vector is reduced to a `u64` signature by sampling its sign
/// at the probe dimensions; ids sharing a signature live in the same bucket.
#[derive(Clone, Debug)]
pub struct TernarySignatureIndex {
    /// Dimensions sampled, in lane order, to form a signature.
    probe_dims: Vec<usize>,
    /// Signature -> ids of the vectors that produced it.
    buckets: HashMap<u64, Vec<usize>>,
}
74
75impl TernarySignatureIndex {
76 pub fn build_from_map(map: &HashMap<usize, SparseVec>) -> Self {
80 let probe_dims = default_probe_dims(DEFAULT_SIGNATURE_PROBES);
81 Self::build_from_map_with_probes(map, probe_dims)
82 }
83
84 pub fn build_from_map_with_probes(
86 map: &HashMap<usize, SparseVec>,
87 probe_dims: Vec<usize>,
88 ) -> Self {
89 let mut buckets: HashMap<u64, Vec<usize>> = HashMap::new();
90
91 let mut ids: Vec<usize> = map.keys().copied().collect();
93 ids.sort_unstable();
94
95 for id in ids {
96 let Some(vec) = map.get(&id) else { continue };
97 let sig = signature_for(vec, &probe_dims);
98 buckets.entry(sig).or_default().push(id);
99 }
100
101 for ids in buckets.values_mut() {
103 ids.sort_unstable();
104 ids.dedup();
105 }
106
107 Self {
108 probe_dims,
109 buckets,
110 }
111 }
112
113 pub fn probe_dims(&self) -> &[usize] {
114 &self.probe_dims
115 }
116
117 pub fn candidates_with_options(
119 &self,
120 query: &SparseVec,
121 opts: SignatureQueryOptions,
122 ) -> Vec<usize> {
123 if opts.max_candidates == 0 {
124 return Vec::new();
125 }
126
127 let sig = signature_for(query, &self.probe_dims);
128 let probe_radius = opts.probe_radius.min(1);
129 let probe_sigs =
130 probe_signatures(sig, self.probe_dims.len(), probe_radius, opts.max_probes);
131
132 let mut seen: HashSet<usize> = HashSet::new();
133 let mut out: Vec<usize> = Vec::new();
134
135 for ps in probe_sigs {
136 let Some(ids) = self.buckets.get(&ps) else {
137 continue;
138 };
139 for &id in ids {
140 if seen.insert(id) {
141 out.push(id);
142 if out.len() >= opts.max_candidates {
143 break;
144 }
145 }
146 }
147 if out.len() >= opts.max_candidates {
148 break;
149 }
150 }
151
152 out.sort_unstable();
154 out
155 }
156}
157
158impl CandidateGenerator<SparseVec> for TernarySignatureIndex {
159 type Candidate = usize;
160
161 fn candidates(&self, query: &SparseVec, k: usize) -> Vec<Self::Candidate> {
163 self.candidates_with_options(
164 query,
165 SignatureQueryOptions {
166 max_candidates: k,
167 ..SignatureQueryOptions::default()
168 },
169 )
170 }
171}
172
173fn default_probe_dims(count: usize) -> Vec<usize> {
174 let mut out = Vec::with_capacity(count);
175 let mut seen = HashSet::with_capacity(count * 2);
176
177 let mut state: u64 = 0xED00_0000_0000_0001u64;
180
181 while out.len() < count {
182 state = splitmix64(state);
183 let d = (state as usize) % DIM;
184 if seen.insert(d) {
185 out.push(d);
186 }
187 }
188
189 out
190}
191
/// One step of the SplitMix64 generator: advances `x` by the increment
/// constant and returns the mixed value of the advanced state.
///
/// All arithmetic wraps, so every `u64` input is valid.
fn splitmix64(mut x: u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    x = x.wrapping_add(INCREMENT);
    x = (x ^ (x >> 30)).wrapping_mul(MIX_A);
    x = (x ^ (x >> 27)).wrapping_mul(MIX_B);
    x ^ (x >> 31)
}
199
200fn signature_for(vec: &SparseVec, probe_dims: &[usize]) -> u64 {
205 let mut sig: u64 = 0;
206 for (i, &d) in probe_dims.iter().enumerate() {
207 let lane = match sign_at(vec, d) {
208 0 => 0b00u64,
209 1 => 0b01u64,
210 -1 => 0b10u64,
211 _ => 0b00u64,
212 };
213 sig |= lane << (2 * i);
214 }
215 sig
216}
217
218fn sign_at(vec: &SparseVec, dim: usize) -> i8 {
219 if vec.pos.contains(&dim) {
220 1
221 } else if vec.neg.contains(&dim) {
222 -1
223 } else {
224 0
225 }
226}
227
/// Lists the bucket signatures to examine for a query signature `base`.
///
/// * `probes` - number of two-bit lanes in use; values above 32 are clamped,
///   since a `u64` holds at most 32 lanes and shifting further would overflow.
/// * `radius` - `0` yields only `base`; any other value additionally yields
///   every signature that differs from `base` in exactly one lane.
/// * `max_probes` - hard cap on the length of the result; `0` yields an
///   empty list.
///
/// The result starts with `base`, followed by the single-lane variations in
/// ascending lane order.
fn probe_signatures(base: u64, probes: usize, radius: u8, max_probes: usize) -> Vec<u64> {
    // Two bits per lane.
    const MAX_LANES: usize = (u64::BITS / 2) as usize;
    // The three valid lane encodings: zero, positive, negative.
    const LANE_VALUES: [u64; 3] = [0b00, 0b01, 0b10];

    if max_probes == 0 {
        return Vec::new();
    }
    if radius == 0 {
        return vec![base];
    }

    let lanes = probes.min(MAX_LANES);
    let mut out = Vec::with_capacity(max_probes.min(1 + 2 * lanes));
    out.push(base);

    'lanes: for lane in 0..lanes {
        if out.len() >= max_probes {
            break;
        }

        let shift = 2 * lane;
        let mask = 0b11u64 << shift;
        let current = (base & mask) >> shift;

        for &alt in &LANE_VALUES {
            if alt == current {
                continue;
            }
            // Replace just this lane, keeping every other lane of `base`.
            out.push((base & !mask) | (alt << shift));
            if out.len() >= max_probes {
                break 'lanes;
            }
        }
    }

    out
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// The default probe set must be reproducible, in range and duplicate-free.
    #[test]
    fn default_probe_dims_are_stable_and_in_range() {
        let first = default_probe_dims(DEFAULT_SIGNATURE_PROBES);
        let second = default_probe_dims(DEFAULT_SIGNATURE_PROBES);

        assert_eq!(first, second);
        assert_eq!(first.len(), DEFAULT_SIGNATURE_PROBES);
        assert!(first.iter().all(|&d| d < DIM));

        let distinct: HashSet<usize> = first.iter().copied().collect();
        assert_eq!(distinct.len(), DEFAULT_SIGNATURE_PROBES);
    }

    /// Querying with an indexed vector must find its own id in the exact bucket.
    #[test]
    fn candidates_are_deterministic_and_include_self_when_exact_bucket_hits() {
        let cfg = ReversibleVSAConfig::default();
        let alpha = SparseVec::encode_data(b"alpha", &cfg, None);
        let beta = SparseVec::encode_data(b"beta", &cfg, None);

        let mut map = HashMap::new();
        map.insert(0, alpha.clone());
        map.insert(1, beta);
        let idx = TernarySignatureIndex::build_from_map(&map);

        // Exact-bucket lookup only: no neighbouring signatures are probed.
        let opts = SignatureQueryOptions {
            max_candidates: 10,
            probe_radius: 0,
            max_probes: 1,
        };

        let first = idx.candidates_with_options(&alpha, opts);
        let second = idx.candidates_with_options(&alpha, opts);
        assert_eq!(first, second);
        assert!(first.contains(&0));
    }

    /// Radius-1 probing starts with the base signature and honours the cap.
    #[test]
    fn probe_signatures_radius_one_includes_base_and_is_bounded() {
        let base = 0u64;
        let sigs = probe_signatures(base, 4, 1, 5);
        assert_eq!(sigs.len(), 5);
        assert_eq!(sigs.first().copied(), Some(base));
    }
}