Skip to main content

citadel_vector/vendored/prism/
binary.rs

1use super::point::PointStore;
2use rayon::prelude::*;
3
/// Binary code store for Hamming distance pre-filtering.
///
/// Each vector is encoded as a 1-bit-per-dimension code via a randomized
/// Walsh-Hadamard rotation followed by sign extraction (SimHash). Codes for all
/// points are stored back to back, `code_words` `u64` words per point.
#[derive(Debug, Clone)]
pub struct BinaryStore {
    /// Packed codes for every point, `code_words` words each, in point-id order.
    codes: Vec<u64>,
    /// Number of `u64` words per code (`dim.div_ceil(64)`).
    code_words: usize,
    /// The ±1 sign-flip diagonal applied before the transform, one entry per dimension.
    signs: Vec<f32>,
    /// Walsh-Hadamard block length: the largest power of two dividing `dim`.
    block_size: usize,
}
13
14impl BinaryStore {
15    /// Build binary codes: random sign flips (D) + Walsh-Hadamard in blocks of
16    /// `largest_pow2_factor(dim)`. Fixed seed for build/query consistency.
17    pub fn build(store: &PointStore) -> Self {
18        let n = store.len;
19        let dim = store.dim;
20        let code_words = dim.div_ceil(64);
21        let block_size = largest_pow2_factor(dim);
22
23        use rand::{Rng, SeedableRng};
24        let mut rng = rand::rngs::StdRng::seed_from_u64(0x505249534D);
25        let signs: Vec<f32> = (0..dim)
26            .map(|_| if rng.gen_bool(0.5) { 1.0 } else { -1.0 })
27            .collect();
28
29        let mut codes = vec![0u64; n * code_words];
30        codes
31            .par_chunks_mut(code_words)
32            .enumerate()
33            .for_each(|(i, chunk)| {
34                encode_vector(store.vector(i as u32), &signs, block_size, chunk);
35            });
36
37        Self {
38            codes,
39            code_words,
40            signs,
41            block_size,
42        }
43    }
44
45    /// Get the binary code (packed u64 words) for point id.
46    #[inline]
47    pub fn code(&self, id: u32) -> &[u64] {
48        let start = id as usize * self.code_words;
49        &self.codes[start..start + self.code_words]
50    }
51
52    /// Number of u64 words per binary code.
53    #[inline]
54    pub fn code_words(&self) -> usize {
55        self.code_words
56    }
57
58    /// Encode a query vector to binary code using the same HD rotation.
59    pub fn encode_query(&self, query: &[f32]) -> Vec<u64> {
60        let mut code = vec![0u64; self.code_words];
61        encode_vector(query, &self.signs, self.block_size, &mut code);
62        code
63    }
64}
65
66/// Apply HD rotation (sign flip + WHT) and extract signs into packed u64 code.
67fn encode_vector(vec: &[f32], signs: &[f32], block_size: usize, out: &mut [u64]) {
68    let dim = vec.len();
69    let mut buf: Vec<f32> = vec.iter().enumerate().map(|(d, &v)| v * signs[d]).collect();
70    for start in (0..dim).step_by(block_size) {
71        walsh_hadamard(&mut buf[start..start + block_size]);
72    }
73    for d in 0..dim {
74        if buf[d] >= 0.0 {
75            out[d / 64] |= 1u64 << (d % 64);
76        }
77    }
78}
79
/// Unnormalized, in-place Walsh-Hadamard transform of a slice whose length is a
/// power of two.
///
/// Runs log2(len) passes. Each pass splits the slice into windows of twice the
/// current stride and butterflies the window's low and high halves element by
/// element: `(a, b) -> (a + b, a - b)`. The `1/sqrt(len)` scaling is skipped
/// because callers only inspect signs.
fn walsh_hadamard(data: &mut [f32]) {
    let len = data.len();
    debug_assert!(len.is_power_of_two());

    let mut stride = 1;
    while stride < len {
        let window = stride << 1;
        for origin in (0..len).step_by(window) {
            let (low, high) = data[origin..origin + window].split_at_mut(stride);
            for (x, y) in low.iter_mut().zip(high) {
                let (a, b) = (*x, *y);
                *x = a + b;
                *y = a - b;
            }
        }
        stride = window;
    }
}
102
/// Largest power of two that divides `n`, i.e. `2^(trailing zeros of n)`.
/// Returns 1 for `n == 0` by convention.
///
/// In two's complement, `n & -n` isolates the lowest set bit, which is exactly
/// that power of two.
fn largest_pow2_factor(n: usize) -> usize {
    match n {
        0 => 1,
        nonzero => nonzero & nonzero.wrapping_neg(),
    }
}
110
#[cfg(test)]
mod tests {
    use super::super::distance;
    use super::super::point::PointStore;
    use super::*;

    /// Flattens `points` (each exactly `dim` floats) into a `PointStore`.
    ///
    /// NOTE(review): the third `from_parts` argument is passed as one inner Vec
    /// holding a `0` per point. Its meaning is defined by `PointStore`.
    fn store_of(points: &[&[f32]], dim: usize) -> PointStore {
        let flat: Vec<f32> = points.iter().flat_map(|p| p.iter().copied()).collect();
        PointStore::from_parts(flat, dim, vec![vec![0; points.len()]])
    }

    #[test]
    fn test_walsh_hadamard_identity() {
        // A unit impulse spreads to all-ones: WHT([1, 0, 0, 0]) = [1, 1, 1, 1].
        let mut data = [1.0f32, 0.0, 0.0, 0.0];
        walsh_hadamard(&mut data);
        assert_eq!(data, [1.0; 4]);
    }

    #[test]
    fn test_walsh_hadamard_butterfly() {
        // Length 2 is a single butterfly: (a, b) -> (a + b, a - b).
        let cases: [([f32; 2], [f32; 2]); 2] =
            [([1.0, 1.0], [2.0, 0.0]), ([1.0, -1.0], [0.0, 2.0])];
        for (input, expected) in cases {
            let mut data = input;
            walsh_hadamard(&mut data);
            assert_eq!(data, expected);
        }
    }

    #[test]
    fn test_largest_pow2_factor() {
        // 384 = 3 × 128 and 12 = 3 × 4 cover inputs that are not powers of two.
        let cases: [(usize, usize); 5] = [(384, 128), (128, 128), (256, 256), (12, 4), (1, 1)];
        for (n, expected) in cases {
            assert_eq!(largest_pow2_factor(n), expected);
        }
    }

    #[test]
    fn test_binary_query_encoding() {
        // Re-encoding an indexed vector as a query must reproduce its stored code.
        let dim = 128;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
        let binary = BinaryStore::build(&store_of(&[p0.as_slice()], dim));

        let query_code = binary.encode_query(&p0);
        assert_eq!(
            query_code,
            binary.code(0),
            "query encoding must match point encoding"
        );
    }

    #[test]
    fn test_hamming_distance_ordering() {
        let dim = 128;
        // A smooth increasing ramp, a tiny perturbation of it, and its exact
        // negation (the angularly farthest vector).
        let base: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let near: Vec<f32> = base.iter().map(|v| v + 0.001).collect();
        let opposite: Vec<f32> = base.iter().map(|v| -v).collect();

        let store = store_of(&[base.as_slice(), near.as_slice(), opposite.as_slice()], dim);
        let binary = BinaryStore::build(&store);
        let q = binary.encode_query(&base);

        let [d0, d1, d2] = [0u32, 1, 2].map(|id| distance::hamming(&q, binary.code(id)));

        assert_eq!(d0, 0, "same vector must have 0 Hamming distance");
        assert!(
            d1 < d2,
            "close vector (d={d1}) must have smaller Hamming than opposite (d={d2})"
        );
    }

    #[test]
    fn test_binary_code_words() {
        // One u64 word per 64 dimensions, rounded up: 128d → 2 words, 384d → 6 words.
        for (dim, words) in [(128usize, 2usize), (384, 6)] {
            let zeros = vec![0.0f32; dim];
            let binary = BinaryStore::build(&store_of(&[zeros.as_slice()], dim));
            assert_eq!(binary.code_words(), words);
        }
    }
}