citadel_vector/vendored/prism/
binary.rs1use super::point::PointStore;
2use rayon::prelude::*;
3
/// Sign-bit codes for a set of points, together with the parameters needed to
/// encode further vectors the same way.
#[derive(Debug, Clone)]
pub struct BinaryStore {
    /// All codes stored back to back: `code_words` consecutive words per point id.
    codes: Vec<u64>,
    /// Number of `u64` words that make up one code.
    code_words: usize,
    /// Per-dimension multipliers applied to a vector before it is transformed.
    signs: Vec<f32>,
    /// Length of each block handed to the Walsh–Hadamard transform.
    block_size: usize,
}
13
14impl BinaryStore {
15 pub fn from_parts(
19 codes: Vec<u64>,
20 code_words: usize,
21 signs: Vec<f32>,
22 block_size: usize,
23 ) -> Self {
24 Self {
25 codes,
26 code_words,
27 signs,
28 block_size,
29 }
30 }
31
32 pub fn codes(&self) -> &[u64] {
33 &self.codes
34 }
35
36 pub fn signs(&self) -> &[f32] {
37 &self.signs
38 }
39
40 pub fn block_size(&self) -> usize {
41 self.block_size
42 }
43
44 pub fn build(store: &PointStore) -> Self {
47 let n = store.len;
48 let dim = store.dim;
49 let code_words = dim.div_ceil(64);
50 let block_size = largest_pow2_factor(dim);
51
52 use rand::{Rng, SeedableRng};
53 let mut rng = rand::rngs::StdRng::seed_from_u64(0x505249534D);
54 let signs: Vec<f32> = (0..dim)
55 .map(|_| if rng.gen_bool(0.5) { 1.0 } else { -1.0 })
56 .collect();
57
58 let mut codes = vec![0u64; n * code_words];
59 codes
60 .par_chunks_mut(code_words)
61 .enumerate()
62 .for_each(|(i, chunk)| {
63 encode_vector(store.vector(i as u32), &signs, block_size, chunk);
64 });
65
66 Self {
67 codes,
68 code_words,
69 signs,
70 block_size,
71 }
72 }
73
74 #[inline]
76 pub fn code(&self, id: u32) -> &[u64] {
77 let start = id as usize * self.code_words;
78 &self.codes[start..start + self.code_words]
79 }
80
81 #[inline]
83 pub fn code_words(&self) -> usize {
84 self.code_words
85 }
86
87 pub fn encode_query(&self, query: &[f32]) -> Vec<u64> {
89 let mut code = vec![0u64; self.code_words];
90 encode_vector(query, &self.signs, self.block_size, &mut code);
91 code
92 }
93}
94
95fn encode_vector(vec: &[f32], signs: &[f32], block_size: usize, out: &mut [u64]) {
97 let dim = vec.len();
98 let mut buf: Vec<f32> = vec.iter().enumerate().map(|(d, &v)| v * signs[d]).collect();
99 for start in (0..dim).step_by(block_size) {
100 walsh_hadamard(&mut buf[start..start + block_size]);
101 }
102 for d in 0..dim {
103 if buf[d] >= 0.0 {
104 out[d / 64] |= 1u64 << (d % 64);
105 }
106 }
107}
108
/// Applies an unnormalised Walsh–Hadamard transform to `data` in place.
///
/// The transform is a sequence of butterfly passes with the pair distance doubling
/// each pass (1, 2, 4, ...). Every butterfly replaces a pair `(a, b)` with
/// `(a + b, a - b)`. No scaling is applied, and a slice of length 1 is left unchanged.
///
/// The length is expected to be a power of two; this is checked in debug builds only.
fn walsh_hadamard(data: &mut [f32]) {
    let len = data.len();
    debug_assert!(len.is_power_of_two());

    // For `len <= 1` the loop below never runs, so the data is left as is.
    let mut width = 1;
    while width < len {
        let span = width * 2;
        let mut base = 0;
        while base < len {
            for lo in base..base + width {
                let hi = lo + width;
                let a = data[lo];
                let b = data[hi];
                data[lo] = a + b;
                data[hi] = a - b;
            }
            base += span;
        }
        width = span;
    }
}
131
/// Returns the largest power of two that divides `n`, or `1` when `n` is zero.
fn largest_pow2_factor(n: usize) -> usize {
    match n {
        0 => 1,
        // `n & -n` keeps only the lowest set bit, which is the largest power-of-two divisor.
        _ => n & n.wrapping_neg(),
    }
}
139
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// A unit impulse transforms to the all-ones vector.
    #[test]
    fn test_walsh_hadamard_identity() {
        let mut impulse = vec![1.0, 0.0, 0.0, 0.0];
        walsh_hadamard(&mut impulse);
        assert_eq!(impulse, vec![1.0; 4]);
    }

    /// A single butterfly maps `(a, b)` to `(a + b, a - b)`.
    #[test]
    fn test_walsh_hadamard_butterfly() {
        let mut equal = vec![1.0, 1.0];
        walsh_hadamard(&mut equal);
        assert_eq!(equal, vec![2.0, 0.0]);

        let mut opposite = vec![1.0, -1.0];
        walsh_hadamard(&mut opposite);
        assert_eq!(opposite, vec![0.0, 2.0]);
    }

    /// Spot-checks the power-of-two divisor for a handful of dimensions.
    #[test]
    fn test_largest_pow2_factor() {
        for (input, expected) in [(384, 128), (128, 128), (256, 256), (12, 4), (1, 1)] {
            assert_eq!(largest_pow2_factor(input), expected);
        }
    }

    /// Encoding a stored vector as a query reproduces its stored code.
    #[test]
    fn test_binary_query_encoding() {
        let dim: usize = 128;
        let point: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();

        let store = PointStore::from_parts(point.clone(), dim, vec![vec![0]]);
        let binary = BinaryStore::build(&store);

        let query_code = binary.encode_query(&point);
        assert_eq!(
            query_code,
            binary.code(0),
            "query encoding must match point encoding"
        );
    }

    /// A slightly shifted vector must be closer in Hamming distance than the negated one.
    #[test]
    fn test_hamming_distance_ordering() {
        use super::super::distance;

        let dim: usize = 128;
        let base: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let nearby: Vec<f32> = base.iter().map(|&v| v + 0.001).collect();
        let negated: Vec<f32> = base.iter().map(|&v| -v).collect();

        // Store layout: point 0 = base, point 1 = nearby, point 2 = negated.
        let flat = [&base[..], &nearby[..], &negated[..]].concat();
        let store = PointStore::from_parts(flat, dim, vec![vec![0, 0, 0]]);
        let binary = BinaryStore::build(&store);
        let q = binary.encode_query(&base);

        let [d0, d1, d2] = [0u32, 1, 2].map(|id| distance::hamming(&q, binary.code(id)));

        assert_eq!(d0, 0, "same vector must have 0 Hamming distance");
        assert!(
            d1 < d2,
            "close vector (d={d1}) must have smaller Hamming than opposite (d={d2})"
        );
    }

    /// The code width is `ceil(dim / 64)` words.
    #[test]
    fn test_binary_code_words() {
        for (dim, expected_words) in [(128usize, 2usize), (384, 6)] {
            let store = PointStore::from_parts(vec![0.0; dim], dim, vec![vec![0]]);
            let binary = BinaryStore::build(&store);
            assert_eq!(binary.code_words(), expected_words);
        }
    }
}