reddb_server/storage/engine/turboquant/
codebook.rs1use std::collections::HashMap;
7use std::sync::{Mutex, OnceLock};
8
/// Scalar quantization codebook that maps values in `[-1, 1]` to `bit_width`-bit codes.
///
/// Each code `i` corresponds to `centroids()[i]`. The `levels - 1` entries of
/// `boundaries()` are the decision thresholds between neighbouring centroids,
/// both sorted ascending.
#[derive(Debug, Clone)]
pub struct Codebook {
    /// Dimensionality the codebook was requested for; part of the cache key.
    dim: usize,
    /// Bits per code; the codebook has `2^bit_width` levels.
    bit_width: u8,
    /// Ascending thresholds between adjacent centroids (`levels - 1` entries).
    boundaries: Vec<f64>,
    /// Ascending reconstruction values, one per code (`levels` entries).
    centroids: Vec<f64>,
    /// Number of refinement iterations recorded for this codebook.
    iterations: usize,
}

impl Codebook {
    /// Largest supported bit width: `quantize` returns codes as `u8`, so at most
    /// 256 levels can be addressed.
    pub const MAX_BIT_WIDTH: u8 = 8;

    /// Returns the codebook for `dim`-dimensional vectors quantized to `bit_width` bits.
    ///
    /// Codebooks are memoised in a process-wide `Mutex<HashMap>` keyed by
    /// `(dim, bit_width)`. Each call returns a clone of the cached value. A
    /// poisoned lock is recovered rather than propagated.
    ///
    /// # Panics
    ///
    /// Panics if `bit_width > Codebook::MAX_BIT_WIDTH`. Larger widths would
    /// produce codes that cannot be represented by the `u8` returned from
    /// [`Codebook::quantize`].
    pub fn for_dim_bits(dim: usize, bit_width: u8) -> Self {
        // Validate before touching the cache so a rejected request cannot poison the lock.
        assert!(
            bit_width <= Self::MAX_BIT_WIDTH,
            "codebook bit width {bit_width} exceeds the maximum of {} (codes are stored as u8)",
            Self::MAX_BIT_WIDTH
        );

        static CACHE: OnceLock<Mutex<HashMap<(usize, u8), Codebook>>> = OnceLock::new();
        let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new()));

        // One lock acquisition covers both lookup and insertion. Concurrent
        // callers therefore never build the same codebook twice. Construction
        // is cheap (at most 256 levels), so holding the lock while building is fine.
        cache
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .entry((dim, bit_width))
            .or_insert_with(|| Self::uniform(dim, bit_width))
            .clone()
    }

    /// Builds an evenly spaced codebook with `2^bit_width` cells tiling `[-1, 1]`.
    ///
    /// The result does not currently depend on `dim`; `dim` is only recorded.
    fn uniform(dim: usize, bit_width: u8) -> Self {
        let levels = 1usize << bit_width;
        let step = 2.0 / levels as f64;
        // Each centroid sits at the midpoint of its equal-width cell.
        let centroids: Vec<f64> = (0..levels)
            .map(|i| -1.0 + (i as f64 + 0.5) * step)
            .collect();
        // Thresholds lie halfway between neighbours, so the bin a value falls in
        // is the bin of its nearest centroid.
        let boundaries: Vec<f64> = centroids
            .windows(2)
            .map(|pair| (pair[0] + pair[1]) * 0.5)
            .collect();
        Self {
            dim,
            bit_width,
            boundaries,
            centroids,
            iterations: 1,
        }
    }

    /// Dimensionality this codebook was requested for.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of bits per code.
    pub fn bit_width(&self) -> u8 {
        self.bit_width
    }

    /// Ascending decision thresholds between adjacent centroids.
    pub fn boundaries(&self) -> &[f64] {
        &self.boundaries
    }

    /// Ascending reconstruction values; `centroids()[code]` dequantizes `code`.
    pub fn centroids(&self) -> &[f64] {
        &self.centroids
    }

    /// Refinement iterations recorded for this codebook.
    ///
    /// The uniform construction always reports 1. NOTE(review): presumably
    /// reserved for an iterative (e.g. Lloyd-Max) fit; confirm.
    pub fn iterations(&self) -> usize {
        self.iterations
    }

    /// Maps `value` to the code of the cell containing it.
    ///
    /// - The value is first clamped to `[-1, 1]`.
    /// - A value lying exactly on a boundary goes to the lower cell, because the
    ///   comparison is strict.
    /// - NaN is left as NaN by `f32::clamp`. Every comparison against it is
    ///   false, so NaN maps to code 0.
    /// - The lookup is a binary search over `boundaries`.
    pub fn quantize(&self, value: f32) -> u8 {
        let value = value.clamp(-1.0, 1.0) as f64;
        let code = self
            .boundaries
            .partition_point(|boundary| value > *boundary);
        // `code <= boundaries.len() = 2^bit_width - 1 <= 255`, which is guaranteed
        // by the `MAX_BIT_WIDTH` check in `for_dim_bits`, so this cast never truncates.
        code as u8
    }
}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Codebooks for the supported embedding sizes have the expected shape and ordering.
    #[test]
    fn codebook_for_supported_dims_is_monotonic_and_converges() {
        for dim in [384, 768, 1536, 3072] {
            let codebook = Codebook::for_dim_bits(dim, 4);
            assert_eq!(codebook.dim(), dim);
            assert_eq!(codebook.bit_width(), 4);
            assert_eq!(codebook.centroids().len(), 16);
            assert_eq!(codebook.boundaries().len(), 15);
            assert!(codebook.boundaries().windows(2).all(|w| w[0] < w[1]));
            assert!(codebook.centroids().windows(2).all(|w| w[0] < w[1]));
            assert!(codebook.iterations() <= 200);
        }
    }

    /// Out-of-range inputs clamp to the extreme codes, and boundary ties go to the lower cell.
    #[test]
    fn quantize_clamps_and_breaks_ties_low() {
        let codebook = Codebook::for_dim_bits(384, 4);
        assert_eq!(codebook.quantize(-1.0), 0);
        assert_eq!(codebook.quantize(-5.0), 0);
        assert_eq!(codebook.quantize(1.0), 15);
        assert_eq!(codebook.quantize(5.0), 15);
        // 0.0 is exactly the middle boundary of the 4-bit grid.
        assert_eq!(codebook.quantize(0.0), 7);
        assert_eq!(codebook.quantize(0.01), 8);
    }

    /// Quantizing a centroid returns that centroid's own code.
    #[test]
    fn quantize_round_trips_every_centroid() {
        let codebook = Codebook::for_dim_bits(768, 4);
        for (code, &centroid) in codebook.centroids().iter().enumerate() {
            assert_eq!(codebook.quantize(centroid as f32) as usize, code);
        }
    }

    /// The cached path and the first construction yield identical tables.
    #[test]
    fn repeated_requests_return_identical_codebooks() {
        let first = Codebook::for_dim_bits(1536, 3);
        let second = Codebook::for_dim_bits(1536, 3);
        assert_eq!(first.centroids(), second.centroids());
        assert_eq!(first.boundaries(), second.boundaries());
    }
}