nodedb_vector/quantize/
sq8.rs1use serde::{Deserialize, Serialize};
15
16#[derive(Clone, Serialize, Deserialize)]
18pub struct Sq8Codec {
19 pub dim: usize,
20 mins: Vec<f32>,
22 maxs: Vec<f32>,
24 scales: Vec<f32>,
27 inv_scales: Vec<f32>,
29}
30
31impl Sq8Codec {
32 pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
39 assert!(!vectors.is_empty(), "cannot calibrate on empty set");
40 assert!(dim > 0);
41
42 let mut mins = vec![f32::MAX; dim];
43 let mut maxs = vec![f32::MIN; dim];
44
45 for v in vectors {
46 debug_assert_eq!(v.len(), dim);
47 for d in 0..dim {
48 if v[d] < mins[d] {
49 mins[d] = v[d];
50 }
51 if v[d] > maxs[d] {
52 maxs[d] = v[d];
53 }
54 }
55 }
56
57 let mut scales = vec![0.0f32; dim];
58 let mut inv_scales = vec![0.0f32; dim];
59 for d in 0..dim {
60 let range = maxs[d] - mins[d];
61 if range > f32::EPSILON {
62 scales[d] = range / 255.0;
63 inv_scales[d] = 255.0 / range;
64 }
65 }
66
67 Self {
68 dim,
69 mins,
70 maxs,
71 scales,
72 inv_scales,
73 }
74 }
75
76 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
78 debug_assert_eq!(vector.len(), self.dim);
79 let mut out = Vec::with_capacity(self.dim);
80 for ((&v, &min), (&max, &inv_scale)) in vector
81 .iter()
82 .zip(self.mins.iter())
83 .zip(self.maxs.iter().zip(self.inv_scales.iter()))
84 {
85 let clamped = v.clamp(min, max);
86 let q = ((clamped - min) * inv_scale).round() as u8;
87 out.push(q);
88 }
89 out
90 }
91
92 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
96 let mut out = Vec::with_capacity(self.dim * vectors.len());
97 for v in vectors {
98 out.extend(self.quantize(v));
99 }
100 out
101 }
102
103 pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
105 debug_assert_eq!(quantized.len(), self.dim);
106 let mut out = Vec::with_capacity(self.dim);
107 for ((&q, &min), &scale) in quantized
108 .iter()
109 .zip(self.mins.iter())
110 .zip(self.scales.iter())
111 {
112 out.push(min + q as f32 * scale);
113 }
114 out
115 }
116
117 #[inline]
122 pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
123 debug_assert_eq!(query.len(), self.dim);
124 debug_assert_eq!(candidate.len(), self.dim);
125 let mut sum = 0.0f32;
126 for d in 0..self.dim {
127 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
128 let diff = query[d] - dequant;
129 sum += diff * diff;
130 }
131 sum
132 }
133
134 #[inline]
136 pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
137 debug_assert_eq!(query.len(), self.dim);
138 debug_assert_eq!(candidate.len(), self.dim);
139 let mut dot = 0.0f32;
140 let mut norm_q = 0.0f32;
141 let mut norm_c = 0.0f32;
142 for d in 0..self.dim {
143 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
144 dot += query[d] * dequant;
145 norm_q += query[d] * query[d];
146 norm_c += dequant * dequant;
147 }
148 let denom = (norm_q * norm_c).sqrt();
149 if denom < f32::EPSILON {
150 return 1.0;
151 }
152 (1.0 - dot / denom).max(0.0)
153 }
154
155 #[inline]
157 pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
158 debug_assert_eq!(query.len(), self.dim);
159 debug_assert_eq!(candidate.len(), self.dim);
160 let mut dot = 0.0f32;
161 for d in 0..self.dim {
162 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
163 dot += query[d] * dequant;
164 }
165 -dot
166 }
167
168 pub fn dim(&self) -> usize {
170 self.dim
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 100 three-component sample vectors: a linear ramp, a sine and a cosine.
    fn make_vectors() -> Vec<Vec<f32>> {
        let mut samples = Vec::new();
        for i in 0..100 {
            let t = i as f32;
            samples.push(vec![t * 0.1, t.sin(), t.cos()]);
        }
        samples
    }

    /// Borrows every owned vector as a slice, the shape `calibrate` expects.
    fn as_slices(vectors: &[Vec<f32>]) -> Vec<&[f32]> {
        vectors.iter().map(Vec::as_slice).collect()
    }

    /// Decoding an encoded vector must land within one quantization step per dimension.
    #[test]
    fn quantize_dequantize_roundtrip() {
        let samples = make_vectors();
        let views = as_slices(&samples);
        let codec = Sq8Codec::calibrate(&views, 3);

        for sample in &samples {
            let codes = codec.quantize(sample);
            let restored = codec.dequantize(&codes);
            for d in 0..3 {
                let max_step = (codec.maxs[d] - codec.mins[d]) / 255.0;
                let error = (sample[d] - restored[d]).abs();
                assert!(
                    error <= max_step + 1e-6,
                    "d={d}: error={error}, max_step={}",
                    max_step
                );
            }
        }
    }

    /// The asymmetric L2 kernel must track the exact squared distance closely.
    #[test]
    fn asymmetric_l2_close_to_exact() {
        let samples = make_vectors();
        let views = as_slices(&samples);
        let codec = Sq8Codec::calibrate(&views, 3);

        let query = &[5.0, 0.5, -0.5];
        for sample in &samples {
            let codes = codec.quantize(sample);
            let exact = crate::distance::l2_squared(query, sample);
            let approx = codec.asymmetric_l2(query, &codes);
            let abs_error = (exact - approx).abs();
            // Use the relative error only when the exact distance is large
            // enough for the ratio to be meaningful.
            let rel_error = if exact > 0.01 {
                abs_error / exact
            } else {
                abs_error
            };
            assert!(
                rel_error < 0.05 || abs_error < 0.1,
                "exact={exact}, approx={approx}, rel_error={rel_error}"
            );
        }
    }

    /// Batch encoding yields one flat buffer whose first chunk equals the single-vector result.
    #[test]
    fn batch_quantize() {
        let samples = make_vectors();
        let views = as_slices(&samples);
        let codec = Sq8Codec::calibrate(&views, 3);

        let flat = codec.quantize_batch(&views);
        assert_eq!(flat.len(), 3 * 100);

        let first = codec.quantize(&samples[0]);
        assert_eq!(&flat[0..3], &first[..]);
    }

    /// A dimension with zero spread must encode to code 0.
    #[test]
    fn constant_dimension_handled() {
        let mut samples: Vec<Vec<f32>> = Vec::new();
        for i in 0..10 {
            samples.push(vec![5.0, i as f32]);
        }
        let views = as_slices(&samples);
        let codec = Sq8Codec::calibrate(&views, 2);

        let codes = codec.quantize(&[5.0, 3.0]);
        assert_eq!(codes[0], 0);
    }
}