1use crate::QuantError;
11
/// A scalar quantization codebook: a set of reconstruction values plus the
/// decision thresholds used to pick one of them for a given input.
///
/// Both tables are borrowed as `'static` slices, so cloning a `Codebook`
/// copies only the two slice references and the bit width.
#[derive(Debug, Clone)]
pub struct Codebook {
    /// Bit width of the codebook. The built-in codebooks in this file hold
    /// `2^bits` centroids and `2^bits - 1` boundaries.
    pub bits: u8,
    /// Reconstruction values, indexed by quantization index
    /// (see `dequantize_scalar`).
    pub centroids: &'static [f32],
    /// Decision thresholds scanned front to back by `quantize_scalar`.
    /// The built-in tables are strictly ascending.
    pub boundaries: &'static [f32],
}
22
23impl Codebook {
24 pub fn for_bits(bits: u8) -> Result<&'static Codebook, QuantError> {
26 match bits {
27 1 => Ok(&CODEBOOK_1BIT),
28 2 => Ok(&CODEBOOK_2BIT),
29 3 => Ok(&CODEBOOK_3BIT),
30 4 => Ok(&CODEBOOK_4BIT),
31 _ => Err(QuantError::UnsupportedBitWidth(bits)),
32 }
33 }
34
35 #[inline]
37 pub fn quantize_scalar(&self, x: f32) -> u8 {
38 let mut idx = 0u8;
40 for &b in self.boundaries {
41 if x > b {
42 idx += 1;
43 } else {
44 break;
45 }
46 }
47 idx
48 }
49
50 #[inline]
52 pub fn dequantize_scalar(&self, idx: u8) -> f32 {
53 self.centroids[idx as usize]
54 }
55}
56
// Centroid / boundary tables for the built-in codebooks.
//
// Each centroid table is ascending and symmetric about zero. Each boundary
// table has one entry fewer than its centroid table, and every boundary is the
// midpoint of the two neighbouring centroids (to the four decimals given here).
//
// NOTE(review): the levels look like Lloyd-Max quantizer outputs for a
// zero-mean, unit-variance Gaussian (e.g. 0.7979 ~= sqrt(2/pi)) — confirm the
// intended source distribution before relying on that.

/// 1-bit reconstruction levels (2 entries).
static CENTROIDS_1BIT: [f32; 2] = [-0.7979, 0.7979];
/// 1-bit decision threshold (1 entry).
static BOUNDARIES_1BIT: [f32; 1] = [0.0];

/// 2-bit reconstruction levels (4 entries).
static CENTROIDS_2BIT: [f32; 4] = [-1.5104, -0.4528, 0.4528, 1.5104];
/// 2-bit decision thresholds (3 entries).
static BOUNDARIES_2BIT: [f32; 3] = [-0.9816, 0.0, 0.9816];

/// 3-bit reconstruction levels (8 entries).
static CENTROIDS_3BIT: [f32; 8] = [
    -2.1520, -1.3440, -0.7560, -0.2450, 0.2450, 0.7560, 1.3440, 2.1520,
];
/// 3-bit decision thresholds (7 entries).
static BOUNDARIES_3BIT: [f32; 7] = [-1.7480, -1.0500, -0.5005, 0.0, 0.5005, 1.0500, 1.7480];

/// 4-bit reconstruction levels (16 entries).
static CENTROIDS_4BIT: [f32; 16] = [
    -2.7326, -2.0690, -1.6180, -1.2562, -0.9424, -0.6568, -0.3880, -0.1284, 0.1284, 0.3880, 0.6568,
    0.9424, 1.2562, 1.6180, 2.0690, 2.7326,
];
/// 4-bit decision thresholds (15 entries).
static BOUNDARIES_4BIT: [f32; 15] = [
    -2.4008, -1.8435, -1.4371, -1.0993, -0.7996, -0.5224, -0.2582, 0.0, 0.2582, 0.5224, 0.7996,
    1.0993, 1.4371, 1.8435, 2.4008,
];
93
/// Built-in 1-bit codebook; returned by `Codebook::for_bits(1)`.
static CODEBOOK_1BIT: Codebook = Codebook {
    bits: 1,
    centroids: &CENTROIDS_1BIT,
    boundaries: &BOUNDARIES_1BIT,
};

/// Built-in 2-bit codebook; returned by `Codebook::for_bits(2)`.
static CODEBOOK_2BIT: Codebook = Codebook {
    bits: 2,
    centroids: &CENTROIDS_2BIT,
    boundaries: &BOUNDARIES_2BIT,
};

/// Built-in 3-bit codebook; returned by `Codebook::for_bits(3)`.
static CODEBOOK_3BIT: Codebook = Codebook {
    bits: 3,
    centroids: &CENTROIDS_3BIT,
    boundaries: &BOUNDARIES_3BIT,
};

/// Built-in 4-bit codebook; returned by `Codebook::for_bits(4)`.
static CODEBOOK_4BIT: Codebook = Codebook {
    bits: 4,
    centroids: &CENTROIDS_4BIT,
    boundaries: &BOUNDARIES_4BIT,
};
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported width yields `2^bits` centroids and one fewer boundary.
    #[test]
    fn codebook_lookup_valid() {
        for bits in 1..=4 {
            let book = Codebook::for_bits(bits).unwrap();
            let levels = 1usize << bits;
            assert_eq!(book.centroids.len(), levels);
            assert_eq!(book.boundaries.len(), levels - 1);
        }
    }

    /// Widths just outside the supported range are rejected.
    #[test]
    fn codebook_invalid_bits() {
        for bits in [0u8, 5] {
            assert!(Codebook::for_bits(bits).is_err());
        }
    }

    /// The 1-bit codebook splits at zero.
    #[test]
    fn quantize_scalar_1bit() {
        let book = Codebook::for_bits(1).unwrap();
        // The last case sits exactly on the boundary; the comparison is
        // strict, so it maps to the lower bin.
        let cases = [(-0.5f32, 0u8), (0.5, 1), (0.0, 0)];
        for (x, expected) in cases {
            assert_eq!(book.quantize_scalar(x), expected);
        }
    }

    /// One probe value per 2-bit bin.
    #[test]
    fn quantize_scalar_2bit() {
        let book = Codebook::for_bits(2).unwrap();
        let cases = [(-2.0f32, 0u8), (-0.5, 1), (0.5, 2), (2.0, 3)];
        for (x, expected) in cases {
            assert_eq!(book.quantize_scalar(x), expected);
        }
    }

    /// Both tables of every codebook are strictly ascending.
    #[test]
    fn centroids_are_sorted() {
        for bits in 1..=4 {
            let book = Codebook::for_bits(bits).unwrap();
            assert!(
                book.centroids.windows(2).all(|pair| pair[0] < pair[1]),
                "centroids not sorted for {bits}-bit codebook"
            );
            assert!(
                book.boundaries.windows(2).all(|pair| pair[0] < pair[1]),
                "boundaries not sorted for {bits}-bit codebook"
            );
        }
    }

    /// Boundary `i` lies strictly between centroids `i` and `i + 1`.
    #[test]
    fn boundaries_between_centroids() {
        for bits in 1..=4 {
            let book = Codebook::for_bits(bits).unwrap();
            for (i, &b) in book.boundaries.iter().enumerate() {
                let below = book.centroids[i];
                let above = book.centroids[i + 1];
                assert!(
                    below < b && b < above,
                    "boundary {b} not between centroids {} and {} for {bits}-bit",
                    below,
                    above
                );
            }
        }
    }

    /// Quantizing a 2-bit centroid gives back the index it was read from.
    #[test]
    fn dequantize_roundtrip() {
        let book = Codebook::for_bits(2).unwrap();
        for i in 0..4u8 {
            let val = book.dequantize_scalar(i);
            let idx = book.quantize_scalar(val);
            assert_eq!(idx, i, "dequantize({i}) = {val}, quantize({val}) = {idx}");
        }
    }
}