citadel_vector/vendored/prism/
binary.rs1use super::point::PointStore;
2use rayon::prelude::*;
3
/// Packed one-bit-per-dimension sign codes for every vector of a point store,
/// together with the parameters needed to encode a query the same way.
#[derive(Debug, Clone)]
pub struct BinaryStore {
    /// All codes back to back: point `i` occupies the `code_words` words
    /// starting at index `i * code_words`.
    codes: Vec<u64>,
    /// Number of `u64` words per code (`dim.div_ceil(64)`).
    code_words: usize,
    /// Per-dimension multipliers, each either `1.0` or `-1.0`, applied to a
    /// vector before the Walsh-Hadamard transform.
    signs: Vec<f32>,
    /// Length of each independently transformed block: the largest power of
    /// two that divides the dimension.
    block_size: usize,
}
13
14impl BinaryStore {
15 pub fn build(store: &PointStore) -> Self {
18 let n = store.len;
19 let dim = store.dim;
20 let code_words = dim.div_ceil(64);
21 let block_size = largest_pow2_factor(dim);
22
23 use rand::{Rng, SeedableRng};
24 let mut rng = rand::rngs::StdRng::seed_from_u64(0x505249534D);
25 let signs: Vec<f32> = (0..dim)
26 .map(|_| if rng.gen_bool(0.5) { 1.0 } else { -1.0 })
27 .collect();
28
29 let mut codes = vec![0u64; n * code_words];
30 codes
31 .par_chunks_mut(code_words)
32 .enumerate()
33 .for_each(|(i, chunk)| {
34 encode_vector(store.vector(i as u32), &signs, block_size, chunk);
35 });
36
37 Self {
38 codes,
39 code_words,
40 signs,
41 block_size,
42 }
43 }
44
45 #[inline]
47 pub fn code(&self, id: u32) -> &[u64] {
48 let start = id as usize * self.code_words;
49 &self.codes[start..start + self.code_words]
50 }
51
52 #[inline]
54 pub fn code_words(&self) -> usize {
55 self.code_words
56 }
57
58 pub fn encode_query(&self, query: &[f32]) -> Vec<u64> {
60 let mut code = vec![0u64; self.code_words];
61 encode_vector(query, &self.signs, self.block_size, &mut code);
62 code
63 }
64}
65
66fn encode_vector(vec: &[f32], signs: &[f32], block_size: usize, out: &mut [u64]) {
68 let dim = vec.len();
69 let mut buf: Vec<f32> = vec.iter().enumerate().map(|(d, &v)| v * signs[d]).collect();
70 for start in (0..dim).step_by(block_size) {
71 walsh_hadamard(&mut buf[start..start + block_size]);
72 }
73 for d in 0..dim {
74 if buf[d] >= 0.0 {
75 out[d / 64] |= 1u64 << (d % 64);
76 }
77 }
78}
79
/// In-place, unnormalized fast Walsh-Hadamard transform of `data`.
///
/// Slices of length 0 or 1 are left untouched. For longer slices the length
/// must be a power of two; this is checked with a `debug_assert!`, so it is
/// only enforced in builds with debug assertions enabled.
fn walsh_hadamard(data: &mut [f32]) {
    let n = data.len();
    // Handle the trivial sizes before the power-of-two check:
    // `0usize.is_power_of_two()` is false, so asserting first would make an
    // empty slice panic in debug builds while being a no-op in release.
    if n <= 1 {
        return;
    }
    debug_assert!(n.is_power_of_two());

    // Butterfly passes: `half` is the distance between the two elements
    // combined in the current pass and doubles after each pass.
    let mut half = 1;
    while half < n {
        let step = half * 2;
        for i in (0..n).step_by(step) {
            for j in 0..half {
                let a = data[i + j];
                let b = data[i + j + half];
                data[i + j] = a + b;
                data[i + j + half] = a - b;
            }
        }
        half = step;
    }
}
102
/// Returns the largest power of two that divides `n`.
///
/// Zero has no lowest set bit; it is mapped to `1`.
fn largest_pow2_factor(n: usize) -> usize {
    match n {
        0 => 1,
        // `x & -x` isolates the lowest set bit, which is exactly
        // `1 << x.trailing_zeros()` for non-zero `x`.
        nonzero => nonzero & nonzero.wrapping_neg(),
    }
}
110
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// A unit impulse in the first slot transforms to an all-ones vector.
    #[test]
    fn test_walsh_hadamard_identity() {
        let mut impulse = vec![1.0, 0.0, 0.0, 0.0];
        walsh_hadamard(&mut impulse);
        assert_eq!(impulse, vec![1.0; 4]);
    }

    /// A length-2 transform is a single butterfly: `(a, b) -> (a + b, a - b)`.
    #[test]
    fn test_walsh_hadamard_butterfly() {
        let cases = [([1.0, 1.0], [2.0, 0.0]), ([1.0, -1.0], [0.0, 2.0])];
        for (input, expected) in cases {
            let mut data: [f32; 2] = input;
            walsh_hadamard(&mut data);
            assert_eq!(data, expected);
        }
    }

    #[test]
    fn test_largest_pow2_factor() {
        let cases = [(384, 128), (128, 128), (256, 256), (12, 4), (1, 1)];
        for (input, expected) in cases {
            assert_eq!(largest_pow2_factor(input), expected);
        }
    }

    /// Encoding a stored point as a query must reproduce its stored code.
    #[test]
    fn test_binary_query_encoding() {
        let dim: usize = 128;
        let point: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();

        let store = PointStore::from_parts(point.clone(), dim, vec![vec![0]]);
        let binary = BinaryStore::build(&store);

        let encoded_query = binary.encode_query(&point);
        assert_eq!(
            encoded_query,
            binary.code(0),
            "query encoding must match point encoding"
        );
    }

    /// A slightly shifted copy of the query must be closer in Hamming
    /// distance than the negated query.
    #[test]
    fn test_hamming_distance_ordering() {
        use super::super::distance;

        let dim: usize = 128;
        let base: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / dim as f32).collect();
        let nearby: Vec<f32> = base.iter().map(|&v| v + 0.001).collect();
        let opposite: Vec<f32> = base.iter().map(|&v| -v).collect();

        // Points 0, 1 and 2 laid out back to back.
        let flat = [base.as_slice(), nearby.as_slice(), opposite.as_slice()].concat();

        let store = PointStore::from_parts(flat, dim, vec![vec![0, 0, 0]]);
        let binary = BinaryStore::build(&store);
        let q = binary.encode_query(&base);

        let d0 = distance::hamming(&q, binary.code(0));
        let d1 = distance::hamming(&q, binary.code(1));
        let d2 = distance::hamming(&q, binary.code(2));

        assert_eq!(d0, 0, "same vector must have 0 Hamming distance");
        assert!(
            d1 < d2,
            "close vector (d={d1}) must have smaller Hamming than opposite (d={d2})"
        );
    }

    /// One `u64` word per 64 dimensions.
    #[test]
    fn test_binary_code_words() {
        for (dim, expected_words) in [(128usize, 2usize), (384, 6)] {
            let store = PointStore::from_parts(vec![0.0; dim], dim, vec![vec![0]]);
            let binary = BinaryStore::build(&store);
            assert_eq!(binary.code_words(), expected_words);
        }
    }
}