Skip to main content

citadel_vector/vendored/prism/
binary.rs

1use super::point::PointStore;
2use rayon::prelude::*;
3
/// Store of binary codes used to pre-filter candidates by Hamming distance.
///
/// Every vector is reduced to one bit per dimension (SimHash): the vector is
/// rotated with a randomized Walsh-Hadamard transform and only the sign of
/// each rotated component is kept.
pub struct BinaryStore {
    /// Per-dimension multipliers (`1.0` or `-1.0` when produced by `build`)
    /// applied to a vector before the Walsh-Hadamard transform.
    signs: Vec<f32>,
    /// Length of the blocks the Walsh-Hadamard transform is applied to;
    /// `build` sets it to the largest power-of-two factor of the dimension.
    block_size: usize,
    /// All codes back to back: the code of point `id` occupies the
    /// `code_words` words starting at `id * code_words`.
    codes: Vec<u64>,
    /// Number of `u64` words per code; `build` sets it to `dim.div_ceil(64)`.
    code_words: usize,
}
13
14impl BinaryStore {
15    /// Reassemble from persisted parts (the ANN segment decode path). The
16    /// SIGNS are persisted rather than re-derived from the seed, so a future
17    /// seed change can never silently desynchronize codes from queries.
18    pub fn from_parts(
19        codes: Vec<u64>,
20        code_words: usize,
21        signs: Vec<f32>,
22        block_size: usize,
23    ) -> Self {
24        Self {
25            codes,
26            code_words,
27            signs,
28            block_size,
29        }
30    }
31
32    pub fn codes(&self) -> &[u64] {
33        &self.codes
34    }
35
36    pub fn signs(&self) -> &[f32] {
37        &self.signs
38    }
39
40    pub fn block_size(&self) -> usize {
41        self.block_size
42    }
43
44    /// Build binary codes: random sign flips (D) + Walsh-Hadamard in blocks of
45    /// `largest_pow2_factor(dim)`. Fixed seed for build/query consistency.
46    pub fn build(store: &PointStore) -> Self {
47        let n = store.len;
48        let dim = store.dim;
49        let code_words = dim.div_ceil(64);
50        let block_size = largest_pow2_factor(dim);
51
52        use rand::{Rng, SeedableRng};
53        let mut rng = rand::rngs::StdRng::seed_from_u64(0x505249534D);
54        let signs: Vec<f32> = (0..dim)
55            .map(|_| if rng.gen_bool(0.5) { 1.0 } else { -1.0 })
56            .collect();
57
58        let mut codes = vec![0u64; n * code_words];
59        codes
60            .par_chunks_mut(code_words)
61            .enumerate()
62            .for_each(|(i, chunk)| {
63                encode_vector(store.vector(i as u32), &signs, block_size, chunk);
64            });
65
66        Self {
67            codes,
68            code_words,
69            signs,
70            block_size,
71        }
72    }
73
74    /// Get the binary code (packed u64 words) for point id.
75    #[inline]
76    pub fn code(&self, id: u32) -> &[u64] {
77        let start = id as usize * self.code_words;
78        &self.codes[start..start + self.code_words]
79    }
80
81    /// Number of u64 words per binary code.
82    #[inline]
83    pub fn code_words(&self) -> usize {
84        self.code_words
85    }
86
87    /// Encode a query vector to binary code using the same HD rotation.
88    pub fn encode_query(&self, query: &[f32]) -> Vec<u64> {
89        let mut code = vec![0u64; self.code_words];
90        encode_vector(query, &self.signs, self.block_size, &mut code);
91        code
92    }
93}
94
95/// Apply HD rotation (sign flip + WHT) and extract signs into packed u64 code.
96fn encode_vector(vec: &[f32], signs: &[f32], block_size: usize, out: &mut [u64]) {
97    let dim = vec.len();
98    let mut buf: Vec<f32> = vec.iter().enumerate().map(|(d, &v)| v * signs[d]).collect();
99    for start in (0..dim).step_by(block_size) {
100        walsh_hadamard(&mut buf[start..start + block_size]);
101    }
102    for d in 0..dim {
103        if buf[d] >= 0.0 {
104            out[d / 64] |= 1u64 << (d % 64);
105        }
106    }
107}
108
/// In-place Walsh-Hadamard transform of a slice whose length is a power of
/// two. Slices of length 0 or 1 are returned unchanged.
///
/// The result is not normalized, which is irrelevant for sign extraction.
///
/// # Panics
///
/// Panics if the length is greater than one and not a power of two.
fn walsh_hadamard(data: &mut [f32]) {
    let n = data.len();
    // Handle the trivial lengths before checking for a power of two:
    // `0usize.is_power_of_two()` is false, so checking first would reject
    // the empty slice that this early return is meant to accept.
    if n <= 1 {
        return;
    }
    assert!(
        n.is_power_of_two(),
        "Walsh-Hadamard length must be a power of two, got {n}"
    );

    // Butterfly passes with doubling stride: each pass splits the data into
    // groups of `2 * half` and combines the lower and upper half of each
    // group element-wise into (sum, difference).
    let mut half = 1;
    while half < n {
        for group in data.chunks_exact_mut(half * 2) {
            let (lower, upper) = group.split_at_mut(half);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let (a, b) = (*lo, *hi);
                *lo = a + b;
                *hi = a - b;
            }
        }
        half *= 2;
    }
}
131
/// Largest power of two that divides `n`, i.e. the value of its lowest set
/// bit. Zero has no set bit and is mapped to 1.
fn largest_pow2_factor(n: usize) -> usize {
    match n {
        0 => 1,
        // Two's-complement negation keeps the lowest set bit and inverts
        // every bit above it, so the AND isolates exactly that bit.
        _ => n & n.wrapping_neg(),
    }
}
139
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// A unit impulse in the first slot transforms to the all-ones vector.
    #[test]
    fn test_walsh_hadamard_identity() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![1.0, 1.0, 1.0, 1.0]);
    }

    /// A single butterfly maps `(a, b)` to `(a + b, a - b)`.
    #[test]
    fn test_walsh_hadamard_butterfly() {
        let mut data = vec![1.0, 1.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![2.0, 0.0]);

        let mut data = vec![1.0, -1.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![0.0, 2.0]);
    }

    /// A non-trivial input exercises both butterfly passes; applying the
    /// unnormalized transform twice scales the input by its length.
    #[test]
    fn test_walsh_hadamard_involution() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![10.0, -2.0, -4.0, 0.0]);
        walsh_hadamard(&mut data);
        assert_eq!(data, vec![4.0, 8.0, 12.0, 16.0]);
    }

    #[test]
    fn test_largest_pow2_factor() {
        assert_eq!(largest_pow2_factor(384), 128);
        assert_eq!(largest_pow2_factor(128), 128);
        assert_eq!(largest_pow2_factor(256), 256);
        assert_eq!(largest_pow2_factor(12), 4);
        assert_eq!(largest_pow2_factor(1), 1);
        // Odd inputs have no factor of two beyond 1.
        assert_eq!(largest_pow2_factor(7), 1);
        // Zero takes the explicit guard branch instead of shifting by the
        // full bit width.
        assert_eq!(largest_pow2_factor(0), 1);
    }

    /// Encoding a vector as a query must give the same code that `build`
    /// stored for it.
    #[test]
    fn test_binary_query_encoding() {
        let dim = 128;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut vecs = Vec::with_capacity(dim);
        vecs.extend_from_slice(&p0);

        let store = PointStore::from_parts(vecs, dim, vec![vec![0]]);
        let binary = BinaryStore::build(&store);

        let q = binary.encode_query(&p0);
        let c0 = binary.code(0);
        assert_eq!(q, c0, "query encoding must match point encoding");
    }

    /// Hamming distance between codes must rank a slightly shifted copy of
    /// the query closer than the negated query.
    #[test]
    fn test_hamming_distance_ordering() {
        use super::super::distance;
        let dim = 128;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let p1: Vec<f32> = p0.iter().map(|&v| v + 0.001).collect();
        let p2: Vec<f32> = p0.iter().map(|&v| -v).collect();

        let mut vecs = Vec::with_capacity(3 * dim);
        vecs.extend_from_slice(&p0);
        vecs.extend_from_slice(&p1);
        vecs.extend_from_slice(&p2);

        let store = PointStore::from_parts(vecs, dim, vec![vec![0, 0, 0]]);
        let binary = BinaryStore::build(&store);
        let q = binary.encode_query(&p0);

        let d0 = distance::hamming(&q, binary.code(0));
        let d1 = distance::hamming(&q, binary.code(1));
        let d2 = distance::hamming(&q, binary.code(2));

        assert_eq!(d0, 0, "same vector must have 0 Hamming distance");
        assert!(
            d1 < d2,
            "close vector (d={d1}) must have smaller Hamming than opposite (d={d2})"
        );
    }

    /// One `u64` word is needed per 64 dimensions.
    #[test]
    fn test_binary_code_words() {
        let store = PointStore::from_parts(vec![0.0; 128], 128, vec![vec![0]]);
        let binary = BinaryStore::build(&store);
        assert_eq!(binary.code_words(), 2);

        let store = PointStore::from_parts(vec![0.0; 384], 384, vec![vec![0]]);
        let binary = BinaryStore::build(&store);
        assert_eq!(binary.code_words(), 6);
    }
}
225}