citadel_vector/vendored/prism/
binary.rs1use super::point::PointStore;
2use rayon::prelude::*;
3
/// Packed binary (sign-bit) codes for a set of vectors.
///
/// Each vector is reduced to one bit per dimension by the sign of its
/// component after a sign-randomised, blockwise Walsh-Hadamard transform;
/// the bits are packed into `u64` words.
#[derive(Debug, Clone)]
pub struct BinaryStore {
    /// Codes of all vectors laid out back to back, `code_words` words per vector.
    codes: Vec<u64>,
    /// Number of `u64` words in a single code.
    code_words: usize,
    /// Per-dimension multipliers (`1.0` or `-1.0` when produced by this module)
    /// applied before the transform.
    signs: Vec<f32>,
    /// Length of each block handed to the Walsh-Hadamard transform.
    block_size: usize,
}
13
14impl BinaryStore {
15 pub fn from_parts(
19 codes: Vec<u64>,
20 code_words: usize,
21 signs: Vec<f32>,
22 block_size: usize,
23 ) -> Self {
24 Self {
25 codes,
26 code_words,
27 signs,
28 block_size,
29 }
30 }
31
32 pub fn codes(&self) -> &[u64] {
33 &self.codes
34 }
35
36 pub fn signs(&self) -> &[f32] {
37 &self.signs
38 }
39
40 pub fn block_size(&self) -> usize {
41 self.block_size
42 }
43
44 pub fn build(store: &PointStore) -> Self {
47 let n = store.len;
48 let dim = store.dim;
49 let code_words = dim.div_ceil(64);
50 let block_size = largest_pow2_factor(dim);
51 let signs = seeded_signs(dim);
52
53 let mut codes = vec![0u64; n * code_words];
54 codes
55 .par_chunks_mut(code_words)
56 .enumerate()
57 .for_each(|(i, chunk)| {
58 encode_vector(store.vector(i as u32), &signs, block_size, chunk);
59 });
60
61 Self {
62 codes,
63 code_words,
64 signs,
65 block_size,
66 }
67 }
68
69 pub fn empty(dim: usize) -> Self {
74 Self {
75 codes: Vec::new(),
76 code_words: dim.div_ceil(64),
77 signs: seeded_signs(dim),
78 block_size: largest_pow2_factor(dim),
79 }
80 }
81
82 #[inline]
84 pub fn code(&self, id: u32) -> &[u64] {
85 let start = id as usize * self.code_words;
86 &self.codes[start..start + self.code_words]
87 }
88
89 #[inline]
91 pub fn code_words(&self) -> usize {
92 self.code_words
93 }
94
95 pub fn encode_query(&self, query: &[f32]) -> Vec<u64> {
97 let mut code = vec![0u64; self.code_words];
98 encode_vector(query, &self.signs, self.block_size, &mut code);
99 code
100 }
101}
102
103fn seeded_signs(dim: usize) -> Vec<f32> {
105 use rand::{Rng, SeedableRng};
106 let mut rng = rand::rngs::StdRng::seed_from_u64(0x505249534D);
107 (0..dim)
108 .map(|_| if rng.gen_bool(0.5) { 1.0 } else { -1.0 })
109 .collect()
110}
111
112fn encode_vector(vec: &[f32], signs: &[f32], block_size: usize, out: &mut [u64]) {
114 let dim = vec.len();
115 let mut buf: Vec<f32> = vec.iter().enumerate().map(|(d, &v)| v * signs[d]).collect();
116 for start in (0..dim).step_by(block_size) {
117 walsh_hadamard(&mut buf[start..start + block_size]);
118 }
119 for d in 0..dim {
120 if buf[d] >= 0.0 {
121 out[d / 64] |= 1u64 << (d % 64);
122 }
123 }
124}
125
/// Applies an unnormalised, in-place fast Walsh-Hadamard transform to `data`.
///
/// Slices of length 0 or 1 are returned unchanged. Any longer slice must
/// have a power-of-two length.
///
/// # Panics
///
/// For a length greater than one that is not a power of two, debug builds
/// fail the assertion and release builds panic on an out-of-bounds index
/// part-way through the transform.
fn walsh_hadamard(data: &mut [f32]) {
    let n = data.len();
    // Handle the trivial lengths before the power-of-two check:
    // `0usize.is_power_of_two()` is false, so asserting first made an empty
    // slice panic in debug builds while release builds returned quietly.
    if n <= 1 {
        return;
    }
    debug_assert!(n.is_power_of_two());

    // Butterfly passes with the pair distance doubling each round.
    let mut half = 1;
    while half < n {
        let step = half * 2;
        for i in (0..n).step_by(step) {
            for j in 0..half {
                let a = data[i + j];
                let b = data[i + j + half];
                data[i + j] = a + b;
                data[i + j + half] = a - b;
            }
        }
        half = step;
    }
}
148
/// Returns the largest power of two that divides `n`, or `1` when `n` is zero.
fn largest_pow2_factor(n: usize) -> usize {
    match n {
        0 => 1,
        // `n & -n` (two's complement) keeps only the lowest set bit, which is
        // exactly the largest power of two dividing `n`.
        _ => n & n.wrapping_neg(),
    }
}
156
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// A unit impulse is spread to a constant vector by the transform.
    #[test]
    fn test_walsh_hadamard_identity() {
        let mut impulse = [1.0_f32, 0.0, 0.0, 0.0];
        walsh_hadamard(&mut impulse);
        assert_eq!(impulse, [1.0; 4]);
    }

    /// Length-two inputs exercise a single butterfly: (a, b) -> (a + b, a - b).
    #[test]
    fn test_walsh_hadamard_butterfly() {
        let cases = [
            ([1.0_f32, 1.0], [2.0_f32, 0.0]),
            ([1.0, -1.0], [0.0, 2.0]),
        ];
        for (mut input, expected) in cases {
            walsh_hadamard(&mut input);
            assert_eq!(input, expected);
        }
    }

    #[test]
    fn test_largest_pow2_factor() {
        let cases = [(384, 128), (128, 128), (256, 256), (12, 4), (1, 1)];
        for (n, expected) in cases {
            assert_eq!(largest_pow2_factor(n), expected);
        }
    }

    /// Encoding a vector as a query must reproduce the code stored for it.
    #[test]
    fn test_binary_query_encoding() {
        let dim = 128_usize;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();

        let store = PointStore::from_parts(p0.clone(), dim, vec![vec![0]]);
        let binary = BinaryStore::build(&store);

        let q = binary.encode_query(&p0);
        let c0 = binary.code(0);
        assert_eq!(q, c0, "query encoding must match point encoding");
    }

    /// A slightly shifted copy must land closer in Hamming space than the
    /// negated vector.
    #[test]
    fn test_hamming_distance_ordering() {
        use super::super::distance;
        let dim = 128_usize;
        let p0: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let p1: Vec<f32> = p0.iter().map(|&v| v + 0.001).collect();
        let p2: Vec<f32> = p0.iter().map(|&v| -v).collect();

        // Three vectors stored back to back: the original, the near copy, the opposite.
        let vecs: Vec<f32> = [p0.as_slice(), p1.as_slice(), p2.as_slice()].concat();

        let store = PointStore::from_parts(vecs, dim, vec![vec![0, 0, 0]]);
        let binary = BinaryStore::build(&store);
        let q = binary.encode_query(&p0);

        let d0 = distance::hamming(&q, binary.code(0));
        let d1 = distance::hamming(&q, binary.code(1));
        let d2 = distance::hamming(&q, binary.code(2));

        assert_eq!(d0, 0, "same vector must have 0 Hamming distance");
        assert!(
            d1 < d2,
            "close vector (d={d1}) must have smaller Hamming than opposite (d={d2})"
        );
    }

    /// One `u64` word is needed per 64 dimensions.
    #[test]
    fn test_binary_code_words() {
        for (dim, expected_words) in [(128_usize, 2_usize), (384, 6)] {
            let store = PointStore::from_parts(vec![0.0; dim], dim, vec![vec![0]]);
            let binary = BinaryStore::build(&store);
            assert_eq!(binary.code_words(), expected_words);
        }
    }
}