/// Per-subspace codebook: the centroids that sub-vectors are snapped to during encoding.
///
/// Codes are `u8`, so only the first 256 centroids can ever be addressed by a code.
#[derive(Debug, Clone, PartialEq)]
pub struct Codebook {
    /// Centroid vectors; code `c` refers to `centroids[c]`.
    pub centroids: Vec<Vec<f32>>,
    /// Length of the sub-vectors (and centroids) this codebook is meant for.
    pub sub_dimension: usize,
}

impl Codebook {
    /// Creates an empty codebook for sub-vectors of length `sub_dimension`.
    pub fn new(sub_dimension: usize) -> Self {
        Self {
            centroids: Vec::new(),
            sub_dimension,
        }
    }

    /// Returns the code of the centroid nearest to `sub_vector` (squared Euclidean distance).
    ///
    /// Only the first 256 centroids are searched, because a larger index cannot be represented
    /// as a `u8` code and would otherwise wrap to the code of a different centroid. Ties keep
    /// the lowest code. Returns `0` when the codebook is empty or when no distance compares
    /// below `f32::INFINITY` (e.g. all NaN), so callers should confirm the code via
    /// [`Codebook::centroid`]. If lengths differ, only the overlapping prefix is compared.
    pub fn nearest_centroid(&self, sub_vector: &[f32]) -> u8 {
        let mut best_code = 0u8;
        let mut best_dist = f32::INFINITY;
        // Pairing with `0..=u8::MAX` yields each code directly and caps the search at 256.
        for (code, centroid) in (0..=u8::MAX).zip(&self.centroids) {
            let dist: f32 = sub_vector
                .iter()
                .zip(centroid.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < best_dist {
                best_dist = dist;
                best_code = code;
            }
        }
        best_code
    }

    /// Looks up the centroid for `code`, or `None` if the code is out of range.
    pub fn centroid(&self, code: u8) -> Option<&Vec<f32>> {
        self.centroids.get(usize::from(code))
    }
}

/// Compact product-quantized form of a vector: one centroid code per subspace.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QuantizedVector {
    /// Centroid code for each subspace, in subspace order.
    pub codes: Vec<u8>,
    /// Dimensionality of the vector these codes were produced from.
    pub original_dim: usize,
}

/// Approximate vector rebuilt from quantization codes.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconstructedVector {
    /// The reconstructed components.
    pub vector: Vec<f32>,
    /// Reconstruction error reported by whoever produced this value.
    /// NOTE(review): may be a placeholder when the original vector was unavailable — confirm
    /// with the producer before relying on it.
    pub quantization_error: f32,
}

/// Tuning knobs for the product quantizer.
#[derive(Debug, Clone, Copy)]
pub struct QuantizerConfig {
    /// How many equal-width slices each vector is cut into (one code per slice).
    pub n_subspaces: usize,
    /// How many centroids are learned for every slice.
    pub n_clusters: usize,
}

impl Default for QuantizerConfig {
    /// Four subspaces, each with 256 centroids (the full range of a `u8` code).
    fn default() -> Self {
        const SUBSPACES: usize = 4;
        const CLUSTERS: usize = 256;
        QuantizerConfig {
            n_clusters: CLUSTERS,
            n_subspaces: SUBSPACES,
        }
    }
}

/// Failures reported by the quantizer.
#[derive(Debug)]
pub enum QuantizerError {
    /// An operation that needs codebooks ran before a successful `train`.
    NotTrained,
    /// A vector, code sequence, or code did not fit the trained layout.
    DimensionMismatch,
    /// Not enough training vectors; carries how many were supplied.
    InsufficientData(usize),
    /// Unusable configuration or data shape; carries a human-readable reason.
    InvalidConfig(String),
}

impl std::fmt::Display for QuantizerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuantizerError::NotTrained => f.write_str("Quantizer is not trained"),
            QuantizerError::DimensionMismatch => f.write_str("Vector dimension mismatch"),
            QuantizerError::InsufficientData(count) => {
                write!(f, "Insufficient training data: {} vectors", count)
            }
            QuantizerError::InvalidConfig(reason) => write!(f, "Invalid config: {}", reason),
        }
    }
}

impl std::error::Error for QuantizerError {}

108#[derive(Debug)]
110pub struct Quantizer {
111 config: QuantizerConfig,
112 codebooks: Vec<Codebook>,
113}
114
115impl Quantizer {
116 pub fn new(config: QuantizerConfig) -> Self {
118 Self {
119 config,
120 codebooks: Vec::new(),
121 }
122 }
123
124 pub fn train(&mut self, data: &[Vec<f32>]) -> Result<(), QuantizerError> {
128 if self.config.n_clusters > 256 {
130 return Err(QuantizerError::InvalidConfig(
131 "n_clusters must be ≤ 256".to_string(),
132 ));
133 }
134 if self.config.n_subspaces == 0 {
135 return Err(QuantizerError::InvalidConfig(
136 "n_subspaces must be > 0".to_string(),
137 ));
138 }
139 if data.is_empty() {
140 return Err(QuantizerError::InsufficientData(0));
141 }
142 if data.len() < self.config.n_clusters {
143 return Err(QuantizerError::InsufficientData(data.len()));
144 }
145
146 let dim = data[0].len();
147 if dim % self.config.n_subspaces != 0 {
148 return Err(QuantizerError::InvalidConfig(format!(
149 "Dimension {dim} is not divisible by n_subspaces {}",
150 self.config.n_subspaces
151 )));
152 }
153 let sub_dim = dim / self.config.n_subspaces;
154
155 for v in data {
157 if v.len() != dim {
158 return Err(QuantizerError::DimensionMismatch);
159 }
160 }
161
162 let actual_k = self.config.n_clusters.min(data.len());
164 let mut codebooks = Vec::with_capacity(self.config.n_subspaces);
165 for sub in 0..self.config.n_subspaces {
166 let start = sub * sub_dim;
167 let end = start + sub_dim;
168 let sub_vecs: Vec<Vec<f32>> = data.iter().map(|v| v[start..end].to_vec()).collect();
170 let cb = kmeans_train(&sub_vecs, actual_k, sub_dim, 50)?;
171 codebooks.push(cb);
172 }
173 self.codebooks = codebooks;
174 Ok(())
175 }
176
177 pub fn encode(&self, vector: &[f32]) -> Result<QuantizedVector, QuantizerError> {
179 if self.codebooks.is_empty() {
180 return Err(QuantizerError::NotTrained);
181 }
182 let dim = vector.len();
183 let expected_dim = self.codebooks.len() * self.codebooks[0].sub_dimension;
184 if dim != expected_dim {
185 return Err(QuantizerError::DimensionMismatch);
186 }
187 let sub_dim = self.codebooks[0].sub_dimension;
188 let codes: Vec<u8> = self
189 .codebooks
190 .iter()
191 .enumerate()
192 .map(|(i, cb)| {
193 let start = i * sub_dim;
194 let end = start + sub_dim;
195 cb.nearest_centroid(&vector[start..end])
196 })
197 .collect();
198 Ok(QuantizedVector {
199 codes,
200 original_dim: dim,
201 })
202 }
203
204 pub fn decode(&self, qv: &QuantizedVector) -> Result<ReconstructedVector, QuantizerError> {
206 if self.codebooks.is_empty() {
207 return Err(QuantizerError::NotTrained);
208 }
209 if qv.codes.len() != self.codebooks.len() {
210 return Err(QuantizerError::DimensionMismatch);
211 }
212 let mut result = Vec::with_capacity(qv.original_dim);
213 for (cb, &code) in self.codebooks.iter().zip(qv.codes.iter()) {
214 match cb.centroid(code) {
215 Some(c) => result.extend_from_slice(c),
216 None => return Err(QuantizerError::DimensionMismatch),
217 }
218 }
219 let error = 0.0_f32; Ok(ReconstructedVector {
221 vector: result,
222 quantization_error: error,
223 })
224 }
225
226 pub fn encode_batch(
228 &self,
229 vectors: &[Vec<f32>],
230 ) -> Result<Vec<QuantizedVector>, QuantizerError> {
231 vectors.iter().map(|v| self.encode(v)).collect()
232 }
233
234 pub fn is_trained(&self) -> bool {
236 !self.codebooks.is_empty()
237 }
238
239 pub fn compression_ratio(&self, original_dim: usize) -> f32 {
244 let n_sub = self.config.n_subspaces;
245 if n_sub == 0 {
246 return 1.0;
247 }
248 (original_dim as f32 * 4.0) / n_sub as f32
249 }
250
251 pub fn codebook_count(&self) -> usize {
253 self.codebooks.len()
254 }
255}
256
257fn kmeans_train(
261 sub_vecs: &[Vec<f32>],
262 k: usize,
263 sub_dim: usize,
264 max_iters: usize,
265) -> Result<Codebook, QuantizerError> {
266 let n = sub_vecs.len();
267 if k == 0 || n == 0 {
268 return Err(QuantizerError::InvalidConfig(
269 "k and n must be > 0".to_string(),
270 ));
271 }
272
273 let mut centroids = kmeans_init(sub_vecs, k, sub_dim);
275
276 for _iter in 0..max_iters {
277 let assignments: Vec<usize> = sub_vecs
279 .iter()
280 .map(|v| nearest_centroid_idx(¢roids, v))
281 .collect();
282
283 let mut sums: Vec<Vec<f64>> = vec![vec![0.0_f64; sub_dim]; k];
285 let mut counts: Vec<usize> = vec![0; k];
286
287 for (v, &a) in sub_vecs.iter().zip(assignments.iter()) {
288 for (d, &x) in v.iter().enumerate() {
289 sums[a][d] += x as f64;
290 }
291 counts[a] += 1;
292 }
293
294 let mut converged = true;
295 for (ci, centroid) in centroids.iter_mut().enumerate() {
296 if counts[ci] == 0 {
297 continue;
298 }
299 for d in 0..sub_dim {
300 let new_val = (sums[ci][d] / counts[ci] as f64) as f32;
301 if (new_val - centroid[d]).abs() > 1e-6 {
302 converged = false;
303 }
304 centroid[d] = new_val;
305 }
306 }
307 if converged {
308 break;
309 }
310 }
311
312 Ok(Codebook {
313 centroids,
314 sub_dimension: sub_dim,
315 })
316}
317
/// Chooses `k` initial centroids by deterministic farthest-point seeding.
///
/// The first centroid is a fixed index derived from `sub_dim` and `data.len()` (no RNG).
/// Each following pick is the point whose squared distance to its nearest already-chosen
/// centroid is largest; ties go to the highest index. Points may repeat when `k` exceeds
/// the number of distinct points. Returns an empty vector if `data` is empty or `k == 0`.
fn kmeans_init(data: &[Vec<f32>], k: usize, sub_dim: usize) -> Vec<Vec<f32>> {
    let n = data.len();
    // Empty data would make the modulo below divide by zero; `k == 0` asks for nothing.
    if n == 0 || k == 0 {
        return Vec::new();
    }

    // Deterministic seed. (An extra `+ n * 3` term is a no-op modulo `n` and was dropped
    // to avoid needless overflow risk.)
    let first_idx = (sub_dim * 7) % n;
    let mut chosen: Vec<usize> = Vec::with_capacity(k);
    chosen.push(first_idx);

    // min_dist[i] = squared distance from data[i] to its nearest chosen centroid so far.
    let mut min_dist: Vec<f32> = vec![f32::INFINITY; n];
    let mut latest = first_idx;
    for _ in 1..k {
        // Only the most recently chosen centroid can lower any point's distance.
        let anchor = &data[latest];
        for (best, v) in min_dist.iter_mut().zip(data) {
            let dist: f32 = v
                .iter()
                .zip(anchor.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < *best {
                *best = dist;
            }
        }
        latest = min_dist
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        chosen.push(latest);
    }

    chosen.into_iter().map(|i| data[i].clone()).collect()
}

/// Returns the index of the centroid with the smallest squared Euclidean distance to `v`.
///
/// Each distance is computed exactly once (the comparator previously recomputed both sides
/// on every comparison). Ties, and NaN distances (treated as equal), resolve to the earliest
/// index. Returns `0` when `centroids` is empty.
fn nearest_centroid_idx(centroids: &[Vec<f32>], v: &[f32]) -> usize {
    centroids
        .iter()
        .map(|c| {
            c.iter()
                .zip(v.iter())
                .map(|(x, y)| (x - y) * (x - y))
                .sum::<f32>()
        })
        .enumerate()
        .min_by(|(_, da), (_, db)| da.partial_cmp(db).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_config(n_subspaces: usize, n_clusters: usize) -> QuantizerConfig {
        QuantizerConfig {
            n_subspaces,
            n_clusters,
        }
    }

    /// Builds `n` vectors of length `dim` whose entries grow with both row and column.
    fn make_data(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dim).map(|d| (i as f32 * 0.1) + d as f32).collect())
            .collect()
    }

    /// Builds 4 well-separated clusters of 8 identical 8-dimensional points each.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        let mut data: Vec<Vec<f32>> = Vec::new();
        for c in 0..4 {
            for _ in 0..8 {
                let v: Vec<f32> = (0..8).map(|d| (c as f32 * 10.0) + d as f32 * 0.1).collect();
                data.push(v);
            }
        }
        data
    }

    // --- training state ---

    #[test]
    fn test_not_trained_initially() {
        let q = Quantizer::new(make_config(4, 8));
        assert!(!q.is_trained());
    }

    #[test]
    fn test_is_trained_after_train() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        assert!(q.is_trained());
    }

    #[test]
    fn test_failed_retrain_keeps_existing_codebooks() {
        let mut q = Quantizer::new(make_config(2, 4));
        q.train(&make_data(16, 8)).unwrap();
        assert!(q.train(&[]).is_err());
        assert!(q.is_trained());
        assert_eq!(q.codebook_count(), 2);
    }

    // --- training errors ---

    #[test]
    fn test_train_empty_data_error() {
        let mut q = Quantizer::new(make_config(4, 8));
        let err = q.train(&[]);
        assert!(matches!(err, Err(QuantizerError::InsufficientData(0))));
    }

    #[test]
    fn test_train_insufficient_data_error() {
        let mut q = Quantizer::new(make_config(2, 10));
        let data = make_data(5, 4);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InsufficientData(_))));
    }

    #[test]
    fn test_train_n_clusters_over_256() {
        let mut q = Quantizer::new(make_config(2, 300));
        let data = make_data(400, 8);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_zero_clusters_error() {
        let mut q = Quantizer::new(make_config(2, 0));
        let data = make_data(16, 8);
        assert!(matches!(q.train(&data), Err(QuantizerError::InvalidConfig(_))));
        assert!(!q.is_trained());
    }

    #[test]
    fn test_train_zero_subspaces_error() {
        let mut q = Quantizer::new(make_config(0, 4));
        let data = make_data(16, 8);
        assert!(matches!(q.train(&data), Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_dimension_not_divisible() {
        let mut q = Quantizer::new(make_config(3, 4));
        let data = make_data(20, 8);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_ragged_data_error() {
        let mut q = Quantizer::new(make_config(2, 2));
        let mut data = make_data(4, 8);
        data.push(vec![0.0; 6]);
        assert!(matches!(q.train(&data), Err(QuantizerError::DimensionMismatch)));
    }

    // --- encoding ---

    #[test]
    fn test_encode_not_trained_error() {
        let q = Quantizer::new(make_config(4, 8));
        let v = vec![0.0f32; 8];
        assert!(matches!(q.encode(&v), Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_encode_dimension_mismatch() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let v = vec![0.0f32; 4];
        assert!(matches!(
            q.encode(&v),
            Err(QuantizerError::DimensionMismatch)
        ));
    }

    #[test]
    fn test_encode_codes_length() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        let v = vec![1.0f32; 8];
        let qv = q.encode(&v).unwrap();
        assert_eq!(qv.codes.len(), 4);
    }

    #[test]
    fn test_encode_original_dim_stored() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let v = vec![0.0f32; 8];
        let qv = q.encode(&v).unwrap();
        assert_eq!(qv.original_dim, 8);
    }

    // --- decoding ---

    #[test]
    fn test_decode_produces_correct_dim() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        let v = vec![0.5f32; 8];
        let qv = q.encode(&v).unwrap();
        let rv = q.decode(&qv).unwrap();
        assert_eq!(rv.vector.len(), 8);
    }

    #[test]
    fn test_decode_not_trained_error() {
        let q = Quantizer::new(make_config(4, 8));
        let qv = QuantizedVector {
            codes: vec![0; 4],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_decode_wrong_code_count() {
        let mut q = Quantizer::new(make_config(2, 4));
        q.train(&make_data(16, 8)).unwrap();
        let qv = QuantizedVector {
            codes: vec![0; 3],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::DimensionMismatch)));
    }

    #[test]
    fn test_decode_unknown_code() {
        let mut q = Quantizer::new(make_config(2, 4));
        q.train(&make_data(16, 8)).unwrap();
        let qv = QuantizedVector {
            codes: vec![0, 200],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::DimensionMismatch)));
    }

    #[test]
    fn test_encode_decode_approximates_original() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_clustered_data();
        q.train(&data).unwrap();
        let test = data[0].clone();
        let qv = q.encode(&test).unwrap();
        let rv = q.decode(&qv).unwrap();
        for (&orig, &rec) in test.iter().zip(rv.vector.iter()) {
            assert!((orig - rec).abs() < 5.0, "orig={orig}, rec={rec}");
        }
    }

    #[test]
    fn test_encode_decode_exact_for_cluster_points() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_clustered_data();
        q.train(&data).unwrap();
        let mut seen_codes: Vec<Vec<u8>> = Vec::new();
        // One representative per cluster: every 8th vector starts a new cluster.
        for rep in data.iter().step_by(8) {
            let qv = q.encode(rep).unwrap();
            let rv = q.decode(&qv).unwrap();
            for (&orig, &rec) in rep.iter().zip(rv.vector.iter()) {
                assert!((orig - rec).abs() < 1e-5, "orig={orig}, rec={rec}");
            }
            seen_codes.push(qv.codes);
        }
        seen_codes.sort();
        seen_codes.dedup();
        assert_eq!(seen_codes.len(), 4);
    }

    // --- batch encoding ---

    #[test]
    fn test_encode_batch_empty() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let result = q.encode_batch(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_encode_batch_multiple() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let batch = data.clone();
        let result = q.encode_batch(&batch).unwrap();
        assert_eq!(result.len(), data.len());
    }

    #[test]
    fn test_encode_batch_propagates_error() {
        let mut q = Quantizer::new(make_config(2, 4));
        q.train(&make_data(16, 8)).unwrap();
        let batch = vec![vec![0.0f32; 8], vec![0.0f32; 3]];
        assert!(matches!(
            q.encode_batch(&batch),
            Err(QuantizerError::DimensionMismatch)
        ));
    }

    // --- compression ratio ---

    #[test]
    fn test_compression_ratio_basic() {
        let q = Quantizer::new(make_config(4, 8));
        let ratio = q.compression_ratio(128);
        // 128 f32s = 512 bytes vs 4 one-byte codes.
        assert!((ratio - 128.0).abs() < 0.001);
    }

    #[test]
    fn test_compression_ratio_formula() {
        let q = Quantizer::new(make_config(8, 256));
        let ratio = q.compression_ratio(64);
        // 64 f32s = 256 bytes vs 8 one-byte codes.
        assert!((ratio - 32.0).abs() < 0.001);
    }

    // --- codebook count ---

    #[test]
    fn test_codebook_count_before_training() {
        let q = Quantizer::new(make_config(4, 8));
        assert_eq!(q.codebook_count(), 0);
    }

    #[test]
    fn test_codebook_count_after_training_matches_n_subspaces() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        assert_eq!(q.codebook_count(), 4);
    }

    // --- Codebook ---

    #[test]
    fn test_codebook_new() {
        let cb = Codebook::new(4);
        assert_eq!(cb.sub_dimension, 4);
        assert!(cb.centroids.is_empty());
    }

    #[test]
    fn test_nearest_centroid_single() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let code = cb.nearest_centroid(&[1.0, 1.0]);
        assert_eq!(code, 0);
    }

    #[test]
    fn test_nearest_centroid_second() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let code = cb.nearest_centroid(&[9.0, 9.0]);
        assert_eq!(code, 1);
    }

    #[test]
    fn test_centroid_valid_code() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![1.0, 2.0]];
        let c = cb.centroid(0).unwrap();
        assert_eq!(c[0], 1.0);
    }

    #[test]
    fn test_centroid_out_of_range() {
        let cb = Codebook::new(2);
        assert!(cb.centroid(5).is_none());
    }

    // --- k-means helpers ---

    #[test]
    fn test_nearest_centroid_idx_ties_pick_first() {
        let cs = vec![vec![1.0f32], vec![1.0]];
        assert_eq!(nearest_centroid_idx(&cs, &[1.0]), 0);
    }

    #[test]
    fn test_kmeans_init_farthest_first() {
        let data = vec![vec![0.0f32], vec![1.0], vec![10.0]];
        assert_eq!(
            kmeans_init(&data, 3, 1),
            vec![vec![1.0f32], vec![10.0], vec![0.0]]
        );
    }

    // --- error display ---

    #[test]
    fn test_not_trained_display() {
        let e = QuantizerError::NotTrained;
        assert!(format!("{e}").contains("trained"));
    }

    #[test]
    fn test_dimension_mismatch_display() {
        let e = QuantizerError::DimensionMismatch;
        assert!(format!("{e}").contains("mismatch"));
    }

    #[test]
    fn test_insufficient_data_display() {
        let e = QuantizerError::InsufficientData(3);
        assert!(format!("{e}").contains("3"));
    }

    #[test]
    fn test_invalid_config_display() {
        let e = QuantizerError::InvalidConfig("bad".to_string());
        assert!(format!("{e}").contains("bad"));
    }
}