nodedb_vector/quantize/
sq8.rs1use serde::{Deserialize, Serialize};
17
18use crate::error::VectorError;
19
/// Six-byte magic prefix that starts every serialized [`Sq8Codec`].
///
/// `Sq8Codec::to_bytes` writes it first and `Sq8Codec::from_bytes` rejects any
/// buffer that does not begin with it.
pub const MAGIC: &[u8; 6] = b"NDSQ\0\0";
24
/// Version byte stored directly after [`MAGIC`] in the serialized form.
///
/// `Sq8Codec::from_bytes` refuses buffers whose version byte differs from this
/// value.
pub const SQ8_FORMAT_VERSION: u8 = 1;
27
28#[derive(Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
30pub struct Sq8Codec {
31 pub dim: usize,
32 mins: Vec<f32>,
34 maxs: Vec<f32>,
36 scales: Vec<f32>,
39 inv_scales: Vec<f32>,
41}
42
43impl Sq8Codec {
44 pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
51 assert!(!vectors.is_empty(), "cannot calibrate on empty set");
52 assert!(dim > 0);
53
54 let mut mins = vec![f32::MAX; dim];
55 let mut maxs = vec![f32::MIN; dim];
56
57 for v in vectors {
58 debug_assert_eq!(v.len(), dim);
59 for d in 0..dim {
60 if v[d] < mins[d] {
61 mins[d] = v[d];
62 }
63 if v[d] > maxs[d] {
64 maxs[d] = v[d];
65 }
66 }
67 }
68
69 let mut scales = vec![0.0f32; dim];
70 let mut inv_scales = vec![0.0f32; dim];
71 for d in 0..dim {
72 let range = maxs[d] - mins[d];
73 if range > f32::EPSILON {
74 scales[d] = range / 255.0;
75 inv_scales[d] = 255.0 / range;
76 }
77 }
78
79 Self {
80 dim,
81 mins,
82 maxs,
83 scales,
84 inv_scales,
85 }
86 }
87
88 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
90 debug_assert_eq!(vector.len(), self.dim);
91 let mut out = Vec::with_capacity(self.dim);
93 for ((&v, &min), (&max, &inv_scale)) in vector
94 .iter()
95 .zip(self.mins.iter())
96 .zip(self.maxs.iter().zip(self.inv_scales.iter()))
97 {
98 let clamped = v.clamp(min, max);
99 let q = ((clamped - min) * inv_scale).round() as u8;
100 out.push(q);
101 }
102 out
103 }
104
105 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
109 let mut out = Vec::with_capacity(self.dim * vectors.len());
111 for v in vectors {
112 out.extend(self.quantize(v));
113 }
114 out
115 }
116
117 pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
119 debug_assert_eq!(quantized.len(), self.dim);
120 let mut out = Vec::with_capacity(self.dim);
122 for ((&q, &min), &scale) in quantized
123 .iter()
124 .zip(self.mins.iter())
125 .zip(self.scales.iter())
126 {
127 out.push(min + q as f32 * scale);
128 }
129 out
130 }
131
132 #[inline]
137 pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
138 debug_assert_eq!(query.len(), self.dim);
139 debug_assert_eq!(candidate.len(), self.dim);
140 let mut sum = 0.0f32;
141 for d in 0..self.dim {
142 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
143 let diff = query[d] - dequant;
144 sum += diff * diff;
145 }
146 sum
147 }
148
149 #[inline]
151 pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
152 debug_assert_eq!(query.len(), self.dim);
153 debug_assert_eq!(candidate.len(), self.dim);
154 let mut dot = 0.0f32;
155 let mut norm_q = 0.0f32;
156 let mut norm_c = 0.0f32;
157 for d in 0..self.dim {
158 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
159 dot += query[d] * dequant;
160 norm_q += query[d] * query[d];
161 norm_c += dequant * dequant;
162 }
163 let denom = (norm_q * norm_c).sqrt();
164 if denom < f32::EPSILON {
165 return 1.0;
166 }
167 (1.0 - dot / denom).max(0.0)
168 }
169
170 #[inline]
172 pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
173 debug_assert_eq!(query.len(), self.dim);
174 debug_assert_eq!(candidate.len(), self.dim);
175 let mut dot = 0.0f32;
176 for d in 0..self.dim {
177 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
178 dot += query[d] * dequant;
179 }
180 -dot
181 }
182
183 pub fn dim(&self) -> usize {
185 self.dim
186 }
187
188 pub fn to_bytes(&self) -> Vec<u8> {
192 let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
193 let mut out = Vec::with_capacity(7 + payload.len());
195 out.extend_from_slice(MAGIC);
196 out.push(SQ8_FORMAT_VERSION);
197 out.extend_from_slice(&payload);
198 out
199 }
200
201 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
206 if bytes.len() < 7 || &bytes[0..6] != MAGIC {
207 return Err(VectorError::InvalidMagic);
208 }
209 let version = bytes[6];
210 if version != SQ8_FORMAT_VERSION {
211 return Err(VectorError::UnsupportedVersion {
212 found: version,
213 expected: SQ8_FORMAT_VERSION,
214 });
215 }
216 zerompk::from_msgpack::<Self>(&bytes[7..])
217 .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 3-dimensional fixture: a linear ramp, a sine and a cosine
    /// sampled at the integers `0..100`.
    fn make_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
            .collect()
    }

    /// Codec calibrated on the `make_vectors` fixture.
    fn make_codec() -> Sq8Codec {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        Sq8Codec::calibrate(&refs, 3)
    }

    #[test]
    fn sq8_codec_golden_format() {
        let codec = make_codec();
        let bytes = codec.to_bytes();
        // Bytes 0..6 carry the magic prefix.
        assert_eq!(&bytes[0..6], MAGIC);
        // Byte 6 carries the format version.
        assert_eq!(bytes[6], SQ8_FORMAT_VERSION);
        // Everything after the header decodes as MessagePack on its own.
        let decoded = zerompk::from_msgpack::<Sq8Codec>(&bytes[7..]).unwrap();
        assert_eq!(decoded.dim, 3);
    }

    #[test]
    fn sq8_roundtrip() {
        let codec = make_codec();
        let bytes = codec.to_bytes();
        let restored = Sq8Codec::from_bytes(&bytes).unwrap();
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.inv_scales.len(), codec.inv_scales.len());
        for (a, b) in restored.inv_scales.iter().zip(codec.inv_scales.iter()) {
            assert!((a - b).abs() < 1e-6, "inv_scales mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn sq8_invalid_magic_returns_error() {
        let mut bytes = make_codec().to_bytes();
        // Corrupt the first magic byte.
        bytes[0] = b'X';
        assert!(matches!(
            Sq8Codec::from_bytes(&bytes),
            Err(VectorError::InvalidMagic)
        ));
    }

    #[test]
    fn sq8_version_mismatch_returns_error() {
        let mut bytes = make_codec().to_bytes();
        // Overwrite the version byte with a value that is not the current one.
        bytes[6] = 0;
        assert!(matches!(
            Sq8Codec::from_bytes(&bytes),
            Err(VectorError::UnsupportedVersion {
                found: 0,
                expected: SQ8_FORMAT_VERSION
            })
        ));
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let vecs = make_vectors();
        let codec = make_codec();

        for v in &vecs {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);
            for d in 0..3 {
                let error = (v[d] - dq[d]).abs();
                let range = codec.maxs[d] - codec.mins[d];
                // Reconstruction error is bounded by one quantization step.
                assert!(
                    error <= range / 255.0 + 1e-6,
                    "d={d}: error={error}, max_step={}",
                    range / 255.0
                );
            }
        }
    }

    #[test]
    fn asymmetric_l2_close_to_exact() {
        let vecs = make_vectors();
        let codec = make_codec();

        let query = &[5.0, 0.5, -0.5];
        for v in &vecs {
            let q = codec.quantize(v);
            let exact = crate::distance::l2_squared(query, v);
            let approx = codec.asymmetric_l2(query, &q);
            // Relative error for meaningful distances, absolute error near zero.
            let rel_error = if exact > 0.01 {
                (exact - approx).abs() / exact
            } else {
                (exact - approx).abs()
            };
            assert!(
                rel_error < 0.05 || (exact - approx).abs() < 0.1,
                "exact={exact}, approx={approx}, rel_error={rel_error}"
            );
        }
    }

    #[test]
    fn batch_quantize() {
        let vecs = make_vectors();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&refs, 3);

        let batch = codec.quantize_batch(&refs);
        assert_eq!(batch.len(), 3 * 100);

        // The first `dim` bytes of the batch equal the first vector's codes.
        let single = codec.quantize(&vecs[0]);
        assert_eq!(&batch[0..3], &single[..]);
    }

    #[test]
    fn constant_dimension_handled() {
        // Dimension 0 is constant (5.0); dimension 1 spans 0..=9.
        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&refs, 2);

        let q = codec.quantize(&[5.0, 3.0]);
        // A zero-range dimension always encodes to code 0.
        assert_eq!(q[0], 0);
    }

    #[test]
    fn quantize_clamps_out_of_range_values() {
        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&refs, 2);

        // Values beyond the calibrated range saturate at the end codes.
        assert_eq!(codec.quantize(&[5.0, 100.0]), vec![0, 255]);
        assert_eq!(codec.quantize(&[5.0, -3.0]), vec![0, 0]);
    }
}