Skip to main content

embeddenator_retrieval/retrieval/
signature.rs

1//! Optional signature-based candidate generation for sparse ternary vectors.
2//!
3//! This is an opt-in alternative to inverted-index candidate generation.
4//!
5//! Design goals:
6//! - Deterministic across runs (fixed probe dimensions + stable iteration)
7//! - Fast to build + query
8//! - Multi-probe support (radius-1) to soften bucket boundary effects
9
10use std::collections::{HashMap, HashSet};
11
12use embeddenator_vsa::{SparseVec, DIM};
13
/// Trait for generating candidate IDs from a query vector.
///
/// # Note on Local Definition
///
/// This trait is defined locally to avoid cyclic dependencies with
/// `embeddenator-interop`. The external `embeddenator_interop::CandidateGenerator`
/// trait is NOT implemented by types in this module.
///
/// **Migration note (v0.22.0):** If your code relied on `TernarySignatureIndex`
/// implementing `embeddenator_interop::CandidateGenerator`, you will need to
/// use this local trait instead, or create an adapter. This was necessary to
/// break the dependency cycle between retrieval and interop crates.
pub trait CandidateGenerator<V> {
    /// Type of each candidate returned for a query (`usize` IDs for
    /// `TernarySignatureIndex`).
    type Candidate;

    /// Produce candidates for `query`.
    ///
    /// `k` is the requested number of candidates; the `TernarySignatureIndex`
    /// implementation in this module treats it as an upper bound on the
    /// length of the returned vector.
    fn candidates(&self, query: &V, k: usize) -> Vec<Self::Candidate>;
}
31
/// How many probe dimensions are used for the default signature.
///
/// Each probe consumes 2 bits in the `u64` signature encoding, so the default
/// of 24 probes occupies 48 of the 64 available bits.
pub const DEFAULT_SIGNATURE_PROBES: usize = 24;
36
/// Query-time knobs for signature candidate generation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SignatureQueryOptions {
    /// Maximum number of candidate IDs to return.
    ///
    /// A value of 0 yields an empty result.
    pub max_candidates: usize,

    /// Multi-probe radius. Values above 1 are currently treated as 1.
    ///
    /// - 0: only exact signature bucket
    /// - 1: also probe one-dimension variants (two alternates per probe)
    pub probe_radius: u8,

    /// Upper bound on the number of signature buckets to probe.
    ///
    /// The exact-match bucket counts towards this bound, so 0 probes no
    /// bucket at all. This protects against expensive probing when there are
    /// many probe dimensions.
    pub max_probes: usize,
}
54
55impl Default for SignatureQueryOptions {
56    fn default() -> Self {
57        Self {
58            max_candidates: 1_000,
59            probe_radius: 1,
60            max_probes: 1 + (2 * DEFAULT_SIGNATURE_PROBES),
61        }
62    }
63}
64
/// Signature-bucket index for sparse ternary vectors.
///
/// The signature is a compact encoding of the vector's values at a fixed set of
/// probe dimensions. Vectors sharing signatures are likely to be similar.
#[derive(Clone, Debug)]
pub struct TernarySignatureIndex {
    /// Dimensions sampled to build a signature; the i-th entry feeds the i-th
    /// 2-bit lane of the signature.
    probe_dims: Vec<usize>,
    /// Maps each signature to the IDs of the vectors that produced it.
    buckets: HashMap<u64, Vec<usize>>, // signature -> sorted IDs
}
74
75impl TernarySignatureIndex {
76    /// Build a signature index from a codebook-style map.
77    ///
78    /// IDs do not need to be contiguous.
79    pub fn build_from_map(map: &HashMap<usize, SparseVec>) -> Self {
80        let probe_dims = default_probe_dims(DEFAULT_SIGNATURE_PROBES);
81        Self::build_from_map_with_probes(map, probe_dims)
82    }
83
84    /// Build a signature index from a map using explicit probe dimensions.
85    pub fn build_from_map_with_probes(
86        map: &HashMap<usize, SparseVec>,
87        probe_dims: Vec<usize>,
88    ) -> Self {
89        let mut buckets: HashMap<u64, Vec<usize>> = HashMap::new();
90
91        // Deterministic build: iterate IDs in sorted order.
92        let mut ids: Vec<usize> = map.keys().copied().collect();
93        ids.sort_unstable();
94
95        for id in ids {
96            let Some(vec) = map.get(&id) else { continue };
97            let sig = signature_for(vec, &probe_dims);
98            buckets.entry(sig).or_default().push(id);
99        }
100
101        // Buckets are already in increasing ID order due to sorted iteration, but keep it explicit.
102        for ids in buckets.values_mut() {
103            ids.sort_unstable();
104            ids.dedup();
105        }
106
107        Self {
108            probe_dims,
109            buckets,
110        }
111    }
112
113    pub fn probe_dims(&self) -> &[usize] {
114        &self.probe_dims
115    }
116
117    /// Get candidate IDs for a query vector.
118    pub fn candidates_with_options(
119        &self,
120        query: &SparseVec,
121        opts: SignatureQueryOptions,
122    ) -> Vec<usize> {
123        if opts.max_candidates == 0 {
124            return Vec::new();
125        }
126
127        let sig = signature_for(query, &self.probe_dims);
128        let probe_radius = opts.probe_radius.min(1);
129        let probe_sigs =
130            probe_signatures(sig, self.probe_dims.len(), probe_radius, opts.max_probes);
131
132        let mut seen: HashSet<usize> = HashSet::new();
133        let mut out: Vec<usize> = Vec::new();
134
135        for ps in probe_sigs {
136            let Some(ids) = self.buckets.get(&ps) else {
137                continue;
138            };
139            for &id in ids {
140                if seen.insert(id) {
141                    out.push(id);
142                    if out.len() >= opts.max_candidates {
143                        break;
144                    }
145                }
146            }
147            if out.len() >= opts.max_candidates {
148                break;
149            }
150        }
151
152        // Keep deterministic ordering for downstream callers.
153        out.sort_unstable();
154        out
155    }
156}
157
158impl CandidateGenerator<SparseVec> for TernarySignatureIndex {
159    type Candidate = usize;
160
161    /// Generate up to `k` candidate IDs.
162    fn candidates(&self, query: &SparseVec, k: usize) -> Vec<Self::Candidate> {
163        self.candidates_with_options(
164            query,
165            SignatureQueryOptions {
166                max_candidates: k,
167                ..SignatureQueryOptions::default()
168            },
169        )
170    }
171}
172
173fn default_probe_dims(count: usize) -> Vec<usize> {
174    let mut out = Vec::with_capacity(count);
175    let mut seen = HashSet::with_capacity(count * 2);
176
177    // Deterministic pseudo-random stream (SplitMix64).
178    // Arbitrary fixed seed for deterministic probe selection.
179    let mut state: u64 = 0xED00_0000_0000_0001u64;
180
181    while out.len() < count {
182        state = splitmix64(state);
183        let d = (state as usize) % DIM;
184        if seen.insert(d) {
185            out.push(d);
186        }
187    }
188
189    out
190}
191
/// One SplitMix64 step: advance `x` by the golden-ratio increment and return
/// the bit-mixed result. All arithmetic wraps modulo 2^64.
fn splitmix64(x: u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let advanced = x.wrapping_add(INCREMENT);
    let first_round = (advanced ^ (advanced >> 30)).wrapping_mul(MULTIPLIER_A);
    let second_round = (first_round ^ (first_round >> 27)).wrapping_mul(MULTIPLIER_B);
    second_round ^ (second_round >> 31)
}
199
200/// Encode a signature into 2-bit lanes:
201/// - 0 => 00
202/// - +1 => 01
203/// - -1 => 10
204fn signature_for(vec: &SparseVec, probe_dims: &[usize]) -> u64 {
205    let mut sig: u64 = 0;
206    for (i, &d) in probe_dims.iter().enumerate() {
207        let lane = match sign_at(vec, d) {
208            0 => 0b00u64,
209            1 => 0b01u64,
210            -1 => 0b10u64,
211            _ => 0b00u64,
212        };
213        sig |= lane << (2 * i);
214    }
215    sig
216}
217
218fn sign_at(vec: &SparseVec, dim: usize) -> i8 {
219    if vec.pos.contains(&dim) {
220        1
221    } else if vec.neg.contains(&dim) {
222        -1
223    } else {
224        0
225    }
226}
227
/// List the signatures to probe for a query signature `base`.
///
/// - `probes`: number of 2-bit lanes in use. A `u64` holds at most 32 lanes,
///   so larger values are clamped to 32 instead of overflowing the shift.
/// - `radius`: `0` probes only `base`; any other value also probes every
///   signature that differs from `base` in exactly one lane.
/// - `max_probes`: hard cap on the number of signatures returned, `base`
///   included. `0` returns an empty list.
///
/// The order is deterministic: `base` first, then lanes in ascending order,
/// and within a lane the alternates in the order `00`, `01`, `10` (skipping
/// the lane's current value).
fn probe_signatures(base: u64, probes: usize, radius: u8, max_probes: usize) -> Vec<u64> {
    // Number of 2-bit lanes that fit in the signature word.
    const MAX_LANES: usize = (u64::BITS / 2) as usize;
    // Deterministic order: 00 -> 01 -> 10.
    const LANE_VALUES: [u64; 3] = [0b00, 0b01, 0b10];

    if max_probes == 0 {
        return Vec::new();
    }
    if radius == 0 {
        return vec![base];
    }

    let lanes = probes.min(MAX_LANES);
    // At most the base plus two alternates per lane are ever produced.
    let mut out = Vec::with_capacity(max_probes.min(1 + 2 * lanes));
    out.push(base);

    // Radius-1 probing: for each probe lane, flip it to the other two values.
    'lanes: for lane in 0..lanes {
        if out.len() >= max_probes {
            break;
        }

        let shift = 2 * lane;
        let mask = 0b11u64 << shift;
        let cur = (base & mask) >> shift;

        for &alt in LANE_VALUES.iter().filter(|&&value| value != cur) {
            out.push((base & !mask) | (alt << shift));
            if out.len() >= max_probes {
                break 'lanes;
            }
        }
    }

    out
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Default probe selection is repeatable, in range, and free of duplicates.
    #[test]
    fn default_probe_dims_are_stable_and_in_range() {
        let first = default_probe_dims(DEFAULT_SIGNATURE_PROBES);
        let second = default_probe_dims(DEFAULT_SIGNATURE_PROBES);

        assert_eq!(first, second);
        assert_eq!(first.len(), DEFAULT_SIGNATURE_PROBES);
        assert!(first.iter().all(|&d| d < DIM));

        let distinct: HashSet<usize> = first.iter().copied().collect();
        assert_eq!(distinct.len(), DEFAULT_SIGNATURE_PROBES);
    }

    /// Querying with an indexed vector finds that vector's own ID in the exact
    /// bucket, and repeated queries agree.
    #[test]
    fn candidates_are_deterministic_and_include_self_when_exact_bucket_hits() {
        let cfg = ReversibleVSAConfig::default();
        let alpha = SparseVec::encode_data(b"alpha", &cfg, None);
        let beta = SparseVec::encode_data(b"beta", &cfg, None);

        let map: HashMap<usize, SparseVec> = HashMap::from([(0, alpha.clone()), (1, beta)]);
        let idx = TernarySignatureIndex::build_from_map(&map);

        // Exact-bucket lookup only.
        let opts = SignatureQueryOptions {
            max_candidates: 10,
            probe_radius: 0,
            max_probes: 1,
        };

        let first_run = idx.candidates_with_options(&alpha, opts);
        let second_run = idx.candidates_with_options(&alpha, opts);
        assert_eq!(first_run, second_run);
        assert!(first_run.contains(&0));
    }

    /// Radius-1 probing starts with the base signature and honours the cap.
    #[test]
    fn probe_signatures_radius_one_includes_base_and_is_bounded() {
        let base = 0u64;
        let sigs = probe_signatures(base, 4, 1, 5);

        assert_eq!(sigs.first().copied(), Some(base));
        assert_eq!(sigs.len(), 5);
    }
}