Skip to main content

citadel_vector/vendored/prism/
binary.rs

1use super::point::PointStore;
2use rayon::prelude::*;
3
/// Packed 1-bit-per-dimension codes used as a cheap Hamming-distance
/// pre-filter.
///
/// Each vector is rotated by a randomized Walsh-Hadamard transform (a fixed
/// ±1 flip per dimension, then a WHT over power-of-two blocks) and the sign
/// of every rotated component becomes one bit (SimHash).
pub struct BinaryStore {
    /// Per-dimension ±1 factors applied before the transform.
    signs: Vec<f32>,
    /// Width of each independent Walsh-Hadamard block.
    block_size: usize,
    /// Number of `u64` words in one code (`dim.div_ceil(64)` for stores built
    /// in this module).
    code_words: usize,
    /// All point codes stored back to back, `code_words` words per point.
    codes: Vec<u64>,
}
13
14impl BinaryStore {
15    /// Reassemble from persisted parts (the ANN segment decode path). The
16    /// SIGNS are persisted rather than re-derived from the seed, so a future
17    /// seed change can never silently desynchronize codes from queries.
18    pub fn from_parts(
19        codes: Vec<u64>,
20        code_words: usize,
21        signs: Vec<f32>,
22        block_size: usize,
23    ) -> Self {
24        Self {
25            codes,
26            code_words,
27            signs,
28            block_size,
29        }
30    }
31
32    pub fn codes(&self) -> &[u64] {
33        &self.codes
34    }
35
36    pub fn signs(&self) -> &[f32] {
37        &self.signs
38    }
39
40    pub fn block_size(&self) -> usize {
41        self.block_size
42    }
43
44    /// Build binary codes: random sign flips (D) + Walsh-Hadamard in blocks of
45    /// `largest_pow2_factor(dim)`. Fixed seed for build/query consistency.
46    pub fn build(store: &PointStore) -> Self {
47        let n = store.len;
48        let dim = store.dim;
49        let code_words = dim.div_ceil(64);
50        let block_size = largest_pow2_factor(dim);
51        let signs = seeded_signs(dim);
52
53        let mut codes = vec![0u64; n * code_words];
54        codes
55            .par_chunks_mut(code_words)
56            .enumerate()
57            .for_each(|(i, chunk)| {
58                encode_vector(store.vector(i as u32), &signs, block_size, chunk);
59            });
60
61        Self {
62            codes,
63            code_words,
64            signs,
65            block_size,
66        }
67    }
68
69    /// A store with signs but no codes, for configs that never consult the
70    /// binary pre-filter (`binary_rerank == 0`). `encode_query` stays valid;
71    /// `code()` must not be reached (every caller is gated on the rerank
72    /// factor), so the per-point encoding pass and its memory are skipped.
73    pub fn empty(dim: usize) -> Self {
74        Self {
75            codes: Vec::new(),
76            code_words: dim.div_ceil(64),
77            signs: seeded_signs(dim),
78            block_size: largest_pow2_factor(dim),
79        }
80    }
81
82    /// Get the binary code (packed u64 words) for point id.
83    #[inline]
84    pub fn code(&self, id: u32) -> &[u64] {
85        let start = id as usize * self.code_words;
86        &self.codes[start..start + self.code_words]
87    }
88
89    /// Number of u64 words per binary code.
90    #[inline]
91    pub fn code_words(&self) -> usize {
92        self.code_words
93    }
94
95    /// Encode a query vector to binary code using the same HD rotation.
96    pub fn encode_query(&self, query: &[f32]) -> Vec<u64> {
97        let mut code = vec![0u64; self.code_words];
98        encode_vector(query, &self.signs, self.block_size, &mut code);
99        code
100    }
101}
102
103/// Seed-fixed random sign flips shared by build and query encoding.
104fn seeded_signs(dim: usize) -> Vec<f32> {
105    use rand::{Rng, SeedableRng};
106    let mut rng = rand::rngs::StdRng::seed_from_u64(0x505249534D);
107    (0..dim)
108        .map(|_| if rng.gen_bool(0.5) { 1.0 } else { -1.0 })
109        .collect()
110}
111
112/// Apply HD rotation (sign flip + WHT) and extract signs into packed u64 code.
113fn encode_vector(vec: &[f32], signs: &[f32], block_size: usize, out: &mut [u64]) {
114    let dim = vec.len();
115    let mut buf: Vec<f32> = vec.iter().enumerate().map(|(d, &v)| v * signs[d]).collect();
116    for start in (0..dim).step_by(block_size) {
117        walsh_hadamard(&mut buf[start..start + block_size]);
118    }
119    for d in 0..dim {
120        if buf[d] >= 0.0 {
121            out[d / 64] |= 1u64 << (d % 64);
122        }
123    }
124}
125
/// Unnormalized fast Walsh-Hadamard transform, applied in place.
///
/// `data.len()` must be a power of two (checked only in debug builds). The
/// result is `sqrt(len)` times the orthonormal transform. That scaling is
/// irrelevant when only the signs of the output are used.
fn walsh_hadamard(data: &mut [f32]) {
    let len = data.len();
    debug_assert!(len.is_power_of_two());
    // Butterfly passes with span 1, 2, 4, ...: in every group of `2 * span`
    // elements, element `lo` is combined with its partner `lo + span`.
    let mut span = 1usize;
    while span < len {
        let group = span << 1;
        let mut base = 0usize;
        while base < len {
            for lo in base..base + span {
                let hi = lo + span;
                let (x, y) = (data[lo], data[hi]);
                data[lo] = x + y;
                data[hi] = x - y;
            }
            base += group;
        }
        span = group;
    }
}
148
/// Largest power of two dividing `n`, i.e. `n`'s lowest set bit
/// (`2^trailing_zeros(n)`). Returns 1 for `n == 0`, which has no such factor.
fn largest_pow2_factor(n: usize) -> usize {
    match n {
        0 => 1,
        // Two's complement: `n & -n` isolates the lowest set bit.
        _ => n & n.wrapping_neg(),
    }
}
156
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    #[test]
    fn test_walsh_hadamard_identity() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_walsh_hadamard_butterfly() {
        let mut data = vec![1.0, 1.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![2.0, 0.0]);

        let mut data = vec![1.0, -1.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![0.0, 2.0]);
    }

    #[test]
    fn test_walsh_hadamard_single_element_is_noop() {
        let mut data = vec![3.5];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![3.5]);
    }

    #[test]
    fn test_largest_pow2_factor() {
        assert_eq!(largest_pow2_factor(384), 128);
        assert_eq!(largest_pow2_factor(128), 128);
        assert_eq!(largest_pow2_factor(256), 256);
        assert_eq!(largest_pow2_factor(12), 4);
        assert_eq!(largest_pow2_factor(7), 1);
        assert_eq!(largest_pow2_factor(1), 1);
        // Degenerate dimension: the block size falls back to 1.
        assert_eq!(largest_pow2_factor(0), 1);
    }

    #[test]
    fn test_binary_query_encoding() {
        let dim = 128;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut vecs = Vec::with_capacity(dim);
        vecs.extend_from_slice(&p0);

        let store = PointStore::from_parts(vecs, dim, vec![vec![0]]);
        let binary = BinaryStore::build(&store);

        let q = binary.encode_query(&p0);
        let c0 = binary.code(0);
        assert_eq!(q, c0, "query encoding must match point encoding");
    }

    /// 384 dims -> block_size 128, so the rotation runs three independent
    /// Walsh-Hadamard blocks; the 128-dim tests only ever exercise one.
    #[test]
    fn test_multi_block_encoding_roundtrip() {
        let dim = 384;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37).cos()).collect();

        let store = PointStore::from_parts(p0.clone(), dim, vec![vec![0]]);
        let binary = BinaryStore::build(&store);
        assert_eq!(binary.block_size(), 128);
        assert_eq!(binary.code_words(), 6);
        assert_eq!(
            binary.encode_query(&p0),
            binary.code(0),
            "multi-block query encoding must match point encoding"
        );

        // A code-less store must rotate queries exactly like a built one.
        let empty = BinaryStore::empty(dim);
        assert_eq!(empty.signs(), binary.signs());
        assert_eq!(empty.block_size(), binary.block_size());
        assert_eq!(empty.code_words(), binary.code_words());
        assert_eq!(empty.encode_query(&p0), binary.code(0));

        // Persisted parts must reproduce both stored codes and query codes.
        let rebuilt = BinaryStore::from_parts(
            binary.codes().to_vec(),
            binary.code_words(),
            binary.signs().to_vec(),
            binary.block_size(),
        );
        assert_eq!(rebuilt.code(0), binary.code(0));
        assert_eq!(rebuilt.encode_query(&p0), binary.encode_query(&p0));
    }

    #[test]
    fn test_hamming_distance_ordering() {
        use super::super::distance;
        let dim = 128;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let p1: Vec<f32> = p0.iter().map(|&v| v + 0.001).collect();
        let p2: Vec<f32> = p0.iter().map(|&v| -v).collect();

        let mut vecs = Vec::with_capacity(3 * dim);
        vecs.extend_from_slice(&p0);
        vecs.extend_from_slice(&p1);
        vecs.extend_from_slice(&p2);

        let store = PointStore::from_parts(vecs, dim, vec![vec![0, 0, 0]]);
        let binary = BinaryStore::build(&store);
        let q = binary.encode_query(&p0);

        let d0 = distance::hamming(&q, binary.code(0));
        let d1 = distance::hamming(&q, binary.code(1));
        let d2 = distance::hamming(&q, binary.code(2));

        assert_eq!(d0, 0, "same vector must have 0 Hamming distance");
        assert!(
            d1 < d2,
            "close vector (d={d1}) must have smaller Hamming than opposite (d={d2})"
        );
    }

    #[test]
    fn test_binary_code_words() {
        let store = PointStore::from_parts(vec![0.0; 128], 128, vec![vec![0]]);
        let binary = BinaryStore::build(&store);
        assert_eq!(binary.code_words(), 2);

        let store = PointStore::from_parts(vec![0.0; 384], 384, vec![vec![0]]);
        let binary = BinaryStore::build(&store);
        assert_eq!(binary.code_words(), 6);
    }
}