nodedb_vector/quantize/
sq8.rs1use serde::{Deserialize, Serialize};
17
18use crate::error::VectorError;
19
/// Magic prefix identifying a serialized SQ8 codec: ASCII `NDSQ` padded
/// with two zero bytes.
pub const MAGIC: &[u8; 6] = &[b'N', b'D', b'S', b'Q', 0x00, 0x00];

/// Format-version byte written immediately after [`MAGIC`] in the serialized
/// header.
pub const SQ8_FORMAT_VERSION: u8 = 0x01;
28#[derive(Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
30pub struct Sq8Codec {
31 pub dim: usize,
32 mins: Vec<f32>,
34 maxs: Vec<f32>,
36 scales: Vec<f32>,
39 inv_scales: Vec<f32>,
41}
42
43impl Sq8Codec {
44 pub fn calibrate(vectors: &[&[f32]], dim: usize) -> Self {
51 assert!(!vectors.is_empty(), "cannot calibrate on empty set");
52 assert!(dim > 0);
53
54 let mut mins = vec![f32::MAX; dim];
55 let mut maxs = vec![f32::MIN; dim];
56
57 for v in vectors {
58 debug_assert_eq!(v.len(), dim);
59 for d in 0..dim {
60 if v[d] < mins[d] {
61 mins[d] = v[d];
62 }
63 if v[d] > maxs[d] {
64 maxs[d] = v[d];
65 }
66 }
67 }
68
69 let mut scales = vec![0.0f32; dim];
70 let mut inv_scales = vec![0.0f32; dim];
71 for d in 0..dim {
72 let range = maxs[d] - mins[d];
73 if range > f32::EPSILON {
74 scales[d] = range / 255.0;
75 inv_scales[d] = 255.0 / range;
76 }
77 }
78
79 Self {
80 dim,
81 mins,
82 maxs,
83 scales,
84 inv_scales,
85 }
86 }
87
88 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
90 debug_assert_eq!(vector.len(), self.dim);
91 let mut out = Vec::with_capacity(self.dim);
92 for ((&v, &min), (&max, &inv_scale)) in vector
93 .iter()
94 .zip(self.mins.iter())
95 .zip(self.maxs.iter().zip(self.inv_scales.iter()))
96 {
97 let clamped = v.clamp(min, max);
98 let q = ((clamped - min) * inv_scale).round() as u8;
99 out.push(q);
100 }
101 out
102 }
103
104 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<u8> {
108 let mut out = Vec::with_capacity(self.dim * vectors.len());
109 for v in vectors {
110 out.extend(self.quantize(v));
111 }
112 out
113 }
114
115 pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
117 debug_assert_eq!(quantized.len(), self.dim);
118 let mut out = Vec::with_capacity(self.dim);
119 for ((&q, &min), &scale) in quantized
120 .iter()
121 .zip(self.mins.iter())
122 .zip(self.scales.iter())
123 {
124 out.push(min + q as f32 * scale);
125 }
126 out
127 }
128
129 #[inline]
134 pub fn asymmetric_l2(&self, query: &[f32], candidate: &[u8]) -> f32 {
135 debug_assert_eq!(query.len(), self.dim);
136 debug_assert_eq!(candidate.len(), self.dim);
137 let mut sum = 0.0f32;
138 for d in 0..self.dim {
139 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
140 let diff = query[d] - dequant;
141 sum += diff * diff;
142 }
143 sum
144 }
145
146 #[inline]
148 pub fn asymmetric_cosine(&self, query: &[f32], candidate: &[u8]) -> f32 {
149 debug_assert_eq!(query.len(), self.dim);
150 debug_assert_eq!(candidate.len(), self.dim);
151 let mut dot = 0.0f32;
152 let mut norm_q = 0.0f32;
153 let mut norm_c = 0.0f32;
154 for d in 0..self.dim {
155 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
156 dot += query[d] * dequant;
157 norm_q += query[d] * query[d];
158 norm_c += dequant * dequant;
159 }
160 let denom = (norm_q * norm_c).sqrt();
161 if denom < f32::EPSILON {
162 return 1.0;
163 }
164 (1.0 - dot / denom).max(0.0)
165 }
166
167 #[inline]
169 pub fn asymmetric_ip(&self, query: &[f32], candidate: &[u8]) -> f32 {
170 debug_assert_eq!(query.len(), self.dim);
171 debug_assert_eq!(candidate.len(), self.dim);
172 let mut dot = 0.0f32;
173 for d in 0..self.dim {
174 let dequant = self.mins[d] + candidate[d] as f32 * self.scales[d];
175 dot += query[d] * dequant;
176 }
177 -dot
178 }
179
180 pub fn dim(&self) -> usize {
182 self.dim
183 }
184
185 pub fn to_bytes(&self) -> Vec<u8> {
189 let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
190 let mut out = Vec::with_capacity(7 + payload.len());
191 out.extend_from_slice(MAGIC);
192 out.push(SQ8_FORMAT_VERSION);
193 out.extend_from_slice(&payload);
194 out
195 }
196
197 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
202 if bytes.len() < 7 || &bytes[0..6] != MAGIC {
203 return Err(VectorError::InvalidMagic);
204 }
205 let version = bytes[6];
206 if version != SQ8_FORMAT_VERSION {
207 return Err(VectorError::UnsupportedVersion {
208 found: version,
209 expected: SQ8_FORMAT_VERSION,
210 });
211 }
212 zerompk::from_msgpack::<Self>(&bytes[7..])
213 .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 three-dimensional vectors: a linear ramp plus sin/cos components.
    fn make_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| vec![i as f32 * 0.1, (i as f32).sin(), (i as f32).cos()])
            .collect()
    }

    /// Borrows each vector as a slice, the shape `calibrate` expects.
    fn as_refs(vecs: &[Vec<f32>]) -> Vec<&[f32]> {
        vecs.iter().map(Vec::as_slice).collect()
    }

    /// A codec calibrated on `make_vectors()`.
    fn make_codec() -> Sq8Codec {
        let vecs = make_vectors();
        Sq8Codec::calibrate(&as_refs(&vecs), 3)
    }

    #[test]
    fn sq8_codec_golden_format() {
        let codec = make_codec();
        let bytes = codec.to_bytes();
        assert_eq!(&bytes[0..6], MAGIC);
        assert_eq!(bytes[6], SQ8_FORMAT_VERSION);
        let decoded = zerompk::from_msgpack::<Sq8Codec>(&bytes[7..]).unwrap();
        assert_eq!(decoded.dim, 3);
    }

    #[test]
    fn sq8_roundtrip() {
        let codec = make_codec();
        let bytes = codec.to_bytes();
        let restored = Sq8Codec::from_bytes(&bytes).unwrap();
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.inv_scales.len(), codec.inv_scales.len());
        for (a, b) in restored.inv_scales.iter().zip(codec.inv_scales.iter()) {
            assert!((a - b).abs() < 1e-6, "inv_scales mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn sq8_invalid_magic_returns_error() {
        let mut bytes = make_codec().to_bytes();
        bytes[0] = b'X';
        assert!(matches!(
            Sq8Codec::from_bytes(&bytes),
            Err(VectorError::InvalidMagic)
        ));
    }

    #[test]
    fn sq8_truncated_header_returns_invalid_magic() {
        // Anything shorter than magic + version byte is rejected up front.
        assert!(matches!(
            Sq8Codec::from_bytes(&[]),
            Err(VectorError::InvalidMagic)
        ));
        assert!(matches!(
            Sq8Codec::from_bytes(&MAGIC[..]),
            Err(VectorError::InvalidMagic)
        ));
    }

    #[test]
    fn sq8_version_mismatch_returns_error() {
        let mut bytes = make_codec().to_bytes();
        bytes[6] = 0;
        assert!(matches!(
            Sq8Codec::from_bytes(&bytes),
            Err(VectorError::UnsupportedVersion {
                found: 0,
                expected: 1
            })
        ));
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let vecs = make_vectors();
        let codec = Sq8Codec::calibrate(&as_refs(&vecs), 3);

        for v in &vecs {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);
            for d in 0..3 {
                let error = (v[d] - dq[d]).abs();
                let range = codec.maxs[d] - codec.mins[d];
                // Reconstruction error is bounded by one quantization step.
                assert!(
                    error <= range / 255.0 + 1e-6,
                    "d={d}: error={error}, max_step={}",
                    range / 255.0
                );
            }
        }
    }

    #[test]
    fn asymmetric_l2_close_to_exact() {
        let vecs = make_vectors();
        let codec = Sq8Codec::calibrate(&as_refs(&vecs), 3);

        let query = &[5.0, 0.5, -0.5];
        for v in &vecs {
            let q = codec.quantize(v);
            let exact = crate::distance::l2_squared(query, v);
            let approx = codec.asymmetric_l2(query, &q);
            // Use relative error for non-trivial distances, absolute otherwise.
            let rel_error = if exact > 0.01 {
                (exact - approx).abs() / exact
            } else {
                (exact - approx).abs()
            };
            assert!(
                rel_error < 0.05 || (exact - approx).abs() < 0.1,
                "exact={exact}, approx={approx}, rel_error={rel_error}"
            );
        }
    }

    #[test]
    fn asymmetric_cosine_and_ip_agree_with_decoded_vector() {
        let codec = make_codec();
        for v in &make_vectors() {
            let q = codec.quantize(v);
            let dq = codec.dequantize(&q);

            // The decoded vector compared against its own codes is identical.
            let cos = codec.asymmetric_cosine(&dq, &q);
            assert!(cos < 1e-4, "self cosine distance should be ~0, got {cos}");

            let self_dot: f32 = dq.iter().map(|x| x * x).sum();
            let ip = codec.asymmetric_ip(&dq, &q);
            assert!(
                (ip + self_dot).abs() <= 1e-4 * self_dot.max(1.0),
                "ip={ip}, expected={}",
                -self_dot
            );
        }
    }

    #[test]
    fn batch_quantize() {
        let vecs = make_vectors();
        let refs = as_refs(&vecs);
        let codec = Sq8Codec::calibrate(&refs, 3);

        let batch = codec.quantize_batch(&refs);
        assert_eq!(batch.len(), 3 * 100);

        // The first `dim` bytes of the batch are the first vector's codes.
        let single = codec.quantize(&vecs[0]);
        assert_eq!(&batch[0..3], &single[..]);
    }

    #[test]
    fn constant_dimension_handled() {
        // Dimension 0 is constant (range 0), so its step is zero.
        let vecs: Vec<Vec<f32>> = (0..10).map(|i| vec![5.0, i as f32]).collect();
        let codec = Sq8Codec::calibrate(&as_refs(&vecs), 2);

        let q = codec.quantize(&[5.0, 3.0]);
        assert_eq!(q[0], 0);
        assert_eq!(codec.dequantize(&q)[0], 5.0);
    }
}