reddb_server/storage/engine/
binary_quantize.rs1use std::cmp::Ordering;
42
43#[derive(Clone, Debug)]
45pub struct BinaryVector {
46 data: Vec<u64>,
48 dim: usize,
50}
51
52impl BinaryVector {
53 pub fn from_f32(values: &[f32]) -> Self {
57 let dim = values.len();
58 let n_words = dim.div_ceil(64); let mut data = vec![0u64; n_words];
60
61 for (i, &v) in values.iter().enumerate() {
62 if v > 0.0 {
63 let word_idx = i / 64;
64 let bit_idx = i % 64;
65 data[word_idx] |= 1u64 << bit_idx;
66 }
67 }
68
69 Self { data, dim }
70 }
71
72 pub fn from_f32_threshold(values: &[f32], threshold: f32) -> Self {
76 let dim = values.len();
77 let n_words = dim.div_ceil(64);
78 let mut data = vec![0u64; n_words];
79
80 for (i, &v) in values.iter().enumerate() {
81 if v > threshold {
82 let word_idx = i / 64;
83 let bit_idx = i % 64;
84 data[word_idx] |= 1u64 << bit_idx;
85 }
86 }
87
88 Self { data, dim }
89 }
90
91 pub fn from_f32_median(values: &[f32]) -> Self {
96 let mut sorted = values.to_vec();
97 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
98 let median = if sorted.len().is_multiple_of(2) {
99 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
100 } else {
101 sorted[sorted.len() / 2]
102 };
103
104 Self::from_f32_threshold(values, median)
105 }
106
107 pub fn from_raw(data: Vec<u64>, dim: usize) -> Self {
109 Self { data, dim }
110 }
111
112 #[inline]
114 pub fn dim(&self) -> usize {
115 self.dim
116 }
117
118 #[inline]
120 pub fn data(&self) -> &[u64] {
121 &self.data
122 }
123
124 #[inline]
126 pub fn size_bytes(&self) -> usize {
127 self.data.len() * 8
128 }
129
130 #[inline]
135 pub fn hamming_distance(&self, other: &Self) -> u32 {
136 debug_assert_eq!(self.dim, other.dim, "Dimensions must match");
137
138 hamming_distance_simd(&self.data, &other.data)
139 }
140
141 #[inline]
145 pub fn hamming_distance_normalized(&self, other: &Self) -> f32 {
146 let dist = self.hamming_distance(other) as f32;
147 dist / self.dim as f32
148 }
149
150 #[inline]
156 pub fn approx_cosine_similarity(&self, other: &Self) -> f32 {
157 let normalized_dist = self.hamming_distance_normalized(other);
158 1.0 - 2.0 * normalized_dist
159 }
160}
161
162#[inline]
168pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
169 debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
170
171 #[cfg(target_arch = "x86_64")]
172 {
173 if is_x86_feature_detected!("popcnt") {
174 return unsafe { hamming_distance_popcnt(a, b) };
175 }
176 }
177
178 hamming_distance_scalar(a, b)
179}
180
/// Portable Hamming distance: popcount of the XOR of each word pair.
///
/// Only the common prefix is compared when the slices differ in length (`zip`).
#[inline]
fn hamming_distance_scalar(a: &[u64], b: &[u64]) -> u32 {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| (lhs ^ rhs).count_ones())
        .sum()
}
190
/// Hamming distance compiled with the `popcnt` target feature.
///
/// Enabling the feature lets the compiler emit the hardware `popcnt`
/// instruction for `count_ones`. Like the scalar kernel, only the common
/// prefix of `a` and `b` is compared when their lengths differ, so the two
/// dispatch targets of `hamming_distance_simd` always agree. (The previous
/// version indexed `b` with `a`'s length and panicked when `b` was shorter.)
///
/// # Safety
///
/// The caller must ensure the running CPU supports `popcnt`, e.g. by checking
/// `is_x86_feature_detected!("popcnt")` first.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
unsafe fn hamming_distance_popcnt(a: &[u64], b: &[u64]) -> u32 {
    // `zip` bounds the walk by the shorter slice and avoids per-index bounds checks.
    a.iter().zip(b).map(|(x, y)| (x ^ y).count_ones()).sum()
}
222
223#[derive(Clone)]
229pub struct BinaryIndex {
230 vectors: Vec<u64>,
232 words_per_vector: usize,
234 n_vectors: usize,
236 dim: usize,
238}
239
240impl BinaryIndex {
241 pub fn new(dim: usize) -> Self {
243 let words_per_vector = dim.div_ceil(64);
244 Self {
245 vectors: Vec::new(),
246 words_per_vector,
247 n_vectors: 0,
248 dim,
249 }
250 }
251
252 pub fn with_capacity(dim: usize, capacity: usize) -> Self {
254 let words_per_vector = dim.div_ceil(64);
255 Self {
256 vectors: Vec::with_capacity(capacity * words_per_vector),
257 words_per_vector,
258 n_vectors: 0,
259 dim,
260 }
261 }
262
263 pub fn add(&mut self, vector: &BinaryVector) {
265 debug_assert_eq!(vector.dim, self.dim, "Dimension mismatch");
266 self.vectors.extend_from_slice(&vector.data);
267 self.n_vectors += 1;
268 }
269
270 pub fn add_f32(&mut self, vector: &[f32]) {
272 let binary = BinaryVector::from_f32(vector);
273 self.add(&binary);
274 }
275
276 #[inline]
278 pub fn len(&self) -> usize {
279 self.n_vectors
280 }
281
282 #[inline]
284 pub fn is_empty(&self) -> bool {
285 self.n_vectors == 0
286 }
287
288 pub fn memory_bytes(&self) -> usize {
290 self.vectors.len() * 8
291 }
292
293 pub fn get(&self, idx: usize) -> Option<BinaryVector> {
295 if idx >= self.n_vectors {
296 return None;
297 }
298 let start = idx * self.words_per_vector;
299 let end = start + self.words_per_vector;
300 Some(BinaryVector::from_raw(
301 self.vectors[start..end].to_vec(),
302 self.dim,
303 ))
304 }
305
306 pub fn search(&self, query: &BinaryVector, k: usize) -> Vec<(usize, u32)> {
310 if self.n_vectors == 0 {
311 return Vec::new();
312 }
313
314 let k = k.min(self.n_vectors);
315 let mut results: Vec<(usize, u32)> = Vec::with_capacity(self.n_vectors);
316
317 for i in 0..self.n_vectors {
319 let start = i * self.words_per_vector;
320 let end = start + self.words_per_vector;
321 let dist = hamming_distance_simd(&query.data, &self.vectors[start..end]);
322 results.push((i, dist));
323 }
324
325 if k < self.n_vectors {
327 results.select_nth_unstable_by_key(k - 1, |&(_, d)| d);
328 results.truncate(k);
329 }
330 results.sort_by_key(|&(_, d)| d);
331
332 results
333 }
334
335 pub fn search_f32(&self, query: &[f32], k: usize) -> Vec<(usize, u32)> {
337 let binary_query = BinaryVector::from_f32(query);
338 self.search(&binary_query, k)
339 }
340
341 pub fn batch_search(&self, queries: &[BinaryVector], k: usize) -> Vec<Vec<(usize, u32)>> {
345 queries.iter().map(|q| self.search(q, k)).collect()
346 }
347}
348
/// One hit from a binary (Hamming) search, optionally refined by a later
/// rescoring step.
#[derive(Debug, Clone)]
pub struct BinarySearchResult {
    /// Identifier of the matched vector (presumably its id in the index — confirm with callers).
    pub id: usize,
    /// Hamming distance between the query and the matched vector.
    pub hamming_distance: u32,
    /// Refined distance, if a rescoring pass has filled it in.
    pub rescored_distance: Option<f32>,
}

impl BinarySearchResult {
    /// Creates a result that has not been rescored yet.
    pub fn new(id: usize, hamming_distance: u32) -> Self {
        let rescored_distance = None;
        Self {
            id,
            hamming_distance,
            rescored_distance,
        }
    }

    /// The rescored distance when present, otherwise the Hamming distance as `f32`.
    pub fn final_distance(&self) -> f32 {
        match self.rescored_distance {
            Some(refined) => refined,
            None => self.hamming_distance as f32,
        }
    }
}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `len`-long vector with `1.0` where `positive(i)` holds and `-1.0` elsewhere.
    fn signs(len: usize, positive: impl Fn(usize) -> bool) -> Vec<f32> {
        (0..len)
            .map(|i| if positive(i) { 1.0 } else { -1.0 })
            .collect()
    }

    #[test]
    fn test_binary_quantization_positive() {
        let input = [1.0f32, -1.0, 0.5, -0.5, 0.0, 2.0, -2.0, 0.1];
        let quantized = BinaryVector::from_f32(&input);

        // Only the strictly positive entries (indices 0, 2, 5, 7) set a bit.
        let low_byte = quantized.data[0] & 0xFF;
        assert_eq!(low_byte, 0b1010_0101);
    }

    #[test]
    fn test_hamming_distance_identical() {
        let pattern = [1.0f32, -1.0, 1.0, -1.0];
        let left = BinaryVector::from_f32(&pattern);
        let right = BinaryVector::from_f32(&pattern);
        assert_eq!(left.hamming_distance(&right), 0);
    }

    #[test]
    fn test_hamming_distance_opposite() {
        let ones = BinaryVector::from_f32(&[1.0; 4]);
        let neg_ones = BinaryVector::from_f32(&[-1.0; 4]);
        assert_eq!(ones.hamming_distance(&neg_ones), 4);
    }

    #[test]
    fn test_hamming_distance_partial() {
        let first = BinaryVector::from_f32(&[1.0, 1.0, -1.0, -1.0]);
        let second = BinaryVector::from_f32(&[1.0, -1.0, 1.0, -1.0]);
        // Positions 1 and 2 differ.
        assert_eq!(first.hamming_distance(&second), 2);
    }

    #[test]
    fn test_large_vector() {
        let evens = BinaryVector::from_f32(&signs(1024, |i| i % 2 == 0));
        let thirds = BinaryVector::from_f32(&signs(1024, |i| i % 3 == 0));

        // 1024 bits pack into 16 words, i.e. 128 bytes.
        assert_eq!(evens.size_bytes(), 128);
        assert_eq!(evens.data.len(), 16);

        let dist = evens.hamming_distance(&thirds);
        assert!(dist > 0 && dist < 1024);
    }

    #[test]
    fn test_binary_index_search() {
        let mut index = BinaryIndex::new(64);
        let stored = [
            signs(64, |_| true),
            signs(64, |_| false),
            signs(64, |i| i < 32),
        ];
        for vector in &stored {
            index.add_f32(vector);
        }

        // The query differs from the all-positive vector only in its last 4 bits.
        let query = signs(64, |i| i < 60);
        let results = index.search_f32(&query, 3);

        let (best_id, best_dist) = results[0];
        assert_eq!(best_id, 0);
        assert_eq!(best_dist, 4);
    }

    #[test]
    fn test_approx_cosine() {
        let positive = BinaryVector::from_f32(&[1.0; 128]);
        let same = BinaryVector::from_f32(&[1.0; 128]);
        let negative = BinaryVector::from_f32(&[-1.0; 128]);

        let close_to = |actual: f32, expected: f32| (actual - expected).abs() < 0.001;
        assert!(close_to(positive.approx_cosine_similarity(&same), 1.0));
        assert!(close_to(positive.approx_cosine_similarity(&negative), -1.0));
    }

    #[test]
    fn test_compression_ratio() {
        let fp32_size = 1024 * std::mem::size_of::<f32>();
        let binary_size = BinaryVector::from_f32(&[1.0; 1024]).size_bytes();

        assert_eq!(binary_size, 128);
        assert_eq!(fp32_size / binary_size, 32);
    }
}