/// Centroids for a single subspace of the quantizer.
///
/// A sub-vector is represented by the `u8` index ("code") of the centroid
/// closest to it, so at most 256 centroids are addressable.
#[derive(Debug, Clone, PartialEq)]
pub struct Codebook {
    /// Centroid vectors; the position in this list is the code.
    pub centroids: Vec<Vec<f32>>,
    /// Number of components in each sub-vector handled by this codebook.
    pub sub_dimension: usize,
}

impl Codebook {
    /// Creates an empty codebook for sub-vectors of length `sub_dimension`.
    pub fn new(sub_dimension: usize) -> Self {
        Self {
            centroids: Vec::new(),
            sub_dimension,
        }
    }

    /// Returns the code of the centroid with the smallest squared Euclidean
    /// distance to `sub_vector`.
    ///
    /// Ties go to the lowest code. Returns `0` when the codebook is empty or
    /// when no distance compares below infinity (e.g. all distances are NaN).
    /// If `sub_vector` and a centroid differ in length, only the overlapping
    /// components are compared.
    ///
    /// Only the first 256 centroids are considered: a `u8` cannot address
    /// more, and casting a larger index would wrap around to the code of an
    /// unrelated centroid.
    pub fn nearest_centroid(&self, sub_vector: &[f32]) -> u8 {
        let mut best_code = 0u8;
        let mut best_dist = f32::INFINITY;
        // Zipping with the full `u8` range yields the code directly and stops
        // after 256 centroids, so no truncating cast is needed.
        for (code, centroid) in (0..=u8::MAX).zip(&self.centroids) {
            let dist: f32 = sub_vector
                .iter()
                .zip(centroid)
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < best_dist {
                best_dist = dist;
                best_code = code;
            }
        }
        best_code
    }

    /// Returns the centroid stored for `code`, or `None` if the codebook has
    /// no centroid at that index.
    pub fn centroid(&self, code: u8) -> Option<&Vec<f32>> {
        self.centroids.get(usize::from(code))
    }
}
43
/// Compressed representation of a vector: one centroid code per codebook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuantizedVector {
    /// Centroid codes, in codebook (subspace) order.
    pub codes: Vec<u8>,
    /// Length of the vector that was encoded.
    pub original_dim: usize,
}
52
/// Result of decoding a quantized vector back into floating-point form.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconstructedVector {
    /// Concatenation of the centroids selected by each code.
    pub vector: Vec<f32>,
    /// Error attributed to quantization.
    ///
    /// NOTE(review): `Quantizer::decode` currently always stores `0.0` here,
    /// since it has no access to the original vector.
    pub quantization_error: f32,
}
61
/// Tunable parameters of a quantizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuantizerConfig {
    /// Number of equally sized subspaces each vector is split into; one
    /// codebook is trained per subspace.
    pub n_subspaces: usize,
    /// Number of centroids per codebook. Training rejects values above 256
    /// because codes are stored as `u8`.
    pub n_clusters: usize,
}

impl Default for QuantizerConfig {
    /// Four subspaces with 256 centroids each.
    fn default() -> Self {
        QuantizerConfig {
            n_subspaces: 4,
            n_clusters: 256,
        }
    }
}
79
/// Errors reported by quantizer training, encoding and decoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantizerError {
    /// An operation that needs trained codebooks was called before training.
    NotTrained,
    /// A vector (or code list) does not have the expected length.
    DimensionMismatch,
    /// Too few training vectors were supplied; carries the number received.
    InsufficientData(usize),
    /// The configuration cannot be used; carries a human-readable reason.
    InvalidConfig(String),
}

impl std::fmt::Display for QuantizerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::NotTrained => f.write_str("Quantizer is not trained"),
            Self::DimensionMismatch => f.write_str("Vector dimension mismatch"),
            Self::InsufficientData(n) => {
                write!(f, "Insufficient training data: {n} vectors")
            }
            Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
        }
    }
}

impl std::error::Error for QuantizerError {}
107
108#[derive(Debug)]
110pub struct Quantizer {
111 config: QuantizerConfig,
112 codebooks: Vec<Codebook>,
113}
114
115impl Quantizer {
116 pub fn new(config: QuantizerConfig) -> Self {
118 Self {
119 config,
120 codebooks: Vec::new(),
121 }
122 }
123
124 pub fn train(&mut self, data: &[Vec<f32>]) -> Result<(), QuantizerError> {
128 if self.config.n_clusters > 256 {
130 return Err(QuantizerError::InvalidConfig(
131 "n_clusters must be ≤ 256".to_string(),
132 ));
133 }
134 if self.config.n_subspaces == 0 {
135 return Err(QuantizerError::InvalidConfig(
136 "n_subspaces must be > 0".to_string(),
137 ));
138 }
139 if data.is_empty() {
140 return Err(QuantizerError::InsufficientData(0));
141 }
142 if data.len() < self.config.n_clusters {
143 return Err(QuantizerError::InsufficientData(data.len()));
144 }
145
146 let dim = data[0].len();
147 if dim % self.config.n_subspaces != 0 {
148 return Err(QuantizerError::InvalidConfig(format!(
149 "Dimension {dim} is not divisible by n_subspaces {}",
150 self.config.n_subspaces
151 )));
152 }
153 let sub_dim = dim / self.config.n_subspaces;
154
155 for v in data {
157 if v.len() != dim {
158 return Err(QuantizerError::DimensionMismatch);
159 }
160 }
161
162 let actual_k = self.config.n_clusters.min(data.len());
164 let mut codebooks = Vec::with_capacity(self.config.n_subspaces);
165 for sub in 0..self.config.n_subspaces {
166 let start = sub * sub_dim;
167 let end = start + sub_dim;
168 let sub_vecs: Vec<Vec<f32>> = data.iter().map(|v| v[start..end].to_vec()).collect();
170 let cb = kmeans_train(&sub_vecs, actual_k, sub_dim, 50)?;
171 codebooks.push(cb);
172 }
173 self.codebooks = codebooks;
174 Ok(())
175 }
176
177 pub fn encode(&self, vector: &[f32]) -> Result<QuantizedVector, QuantizerError> {
179 if self.codebooks.is_empty() {
180 return Err(QuantizerError::NotTrained);
181 }
182 let dim = vector.len();
183 let expected_dim = self.codebooks.len() * self.codebooks[0].sub_dimension;
184 if dim != expected_dim {
185 return Err(QuantizerError::DimensionMismatch);
186 }
187 let sub_dim = self.codebooks[0].sub_dimension;
188 let codes: Vec<u8> = self
189 .codebooks
190 .iter()
191 .enumerate()
192 .map(|(i, cb)| {
193 let start = i * sub_dim;
194 let end = start + sub_dim;
195 cb.nearest_centroid(&vector[start..end])
196 })
197 .collect();
198 Ok(QuantizedVector {
199 codes,
200 original_dim: dim,
201 })
202 }
203
204 pub fn decode(&self, qv: &QuantizedVector) -> Result<ReconstructedVector, QuantizerError> {
206 if self.codebooks.is_empty() {
207 return Err(QuantizerError::NotTrained);
208 }
209 if qv.codes.len() != self.codebooks.len() {
210 return Err(QuantizerError::DimensionMismatch);
211 }
212 let mut result = Vec::with_capacity(qv.original_dim);
213 for (cb, &code) in self.codebooks.iter().zip(qv.codes.iter()) {
214 match cb.centroid(code) {
215 Some(c) => result.extend_from_slice(c),
216 None => return Err(QuantizerError::DimensionMismatch),
217 }
218 }
219 let error = 0.0_f32; Ok(ReconstructedVector {
221 vector: result,
222 quantization_error: error,
223 })
224 }
225
226 pub fn encode_batch(
228 &self,
229 vectors: &[Vec<f32>],
230 ) -> Result<Vec<QuantizedVector>, QuantizerError> {
231 vectors.iter().map(|v| self.encode(v)).collect()
232 }
233
234 pub fn is_trained(&self) -> bool {
236 !self.codebooks.is_empty()
237 }
238
239 pub fn compression_ratio(&self, original_dim: usize) -> f32 {
244 let n_sub = self.config.n_subspaces;
245 if n_sub == 0 {
246 return 1.0;
247 }
248 (original_dim as f32 * 4.0) / n_sub as f32
249 }
250
251 pub fn codebook_count(&self) -> usize {
253 self.codebooks.len()
254 }
255}
256
257fn kmeans_train(
261 sub_vecs: &[Vec<f32>],
262 k: usize,
263 sub_dim: usize,
264 max_iters: usize,
265) -> Result<Codebook, QuantizerError> {
266 let n = sub_vecs.len();
267 if k == 0 || n == 0 {
268 return Err(QuantizerError::InvalidConfig(
269 "k and n must be > 0".to_string(),
270 ));
271 }
272
273 let mut centroids = kmeans_init(sub_vecs, k, sub_dim);
275
276 for _iter in 0..max_iters {
277 let assignments: Vec<usize> = sub_vecs
279 .iter()
280 .map(|v| nearest_centroid_idx(¢roids, v))
281 .collect();
282
283 let mut sums: Vec<Vec<f64>> = vec![vec![0.0_f64; sub_dim]; k];
285 let mut counts: Vec<usize> = vec![0; k];
286
287 for (v, &a) in sub_vecs.iter().zip(assignments.iter()) {
288 for (d, &x) in v.iter().enumerate() {
289 sums[a][d] += x as f64;
290 }
291 counts[a] += 1;
292 }
293
294 let mut converged = true;
295 for (ci, centroid) in centroids.iter_mut().enumerate() {
296 if counts[ci] == 0 {
297 continue;
298 }
299 for d in 0..sub_dim {
300 let new_val = (sums[ci][d] / counts[ci] as f64) as f32;
301 if (new_val - centroid[d]).abs() > 1e-6 {
302 converged = false;
303 }
304 centroid[d] = new_val;
305 }
306 }
307 if converged {
308 break;
309 }
310 }
311
312 Ok(Codebook {
313 centroids,
314 sub_dimension: sub_dim,
315 })
316}
317
/// Picks `k` initial centroids from `data` deterministically.
///
/// The first centroid is the point at index `(sub_dim * 7) % data.len()`.
/// Each further centroid is the point whose squared distance to its closest
/// already-chosen centroid is largest; on ties the highest index wins. When
/// `data` contains duplicates the same point may be chosen more than once.
///
/// Returns an empty list when `data` is empty or `k` is 0.
fn kmeans_init(data: &[Vec<f32>], k: usize, sub_dim: usize) -> Vec<Vec<f32>> {
    let n = data.len();
    if n == 0 || k == 0 {
        // Nothing to choose from (and `% n` below would divide by zero).
        return Vec::new();
    }

    // Deterministic seed. The `+ n * 3` term of the original formula is a
    // multiple of `n` and therefore has no effect on the remainder.
    let mut latest = (sub_dim * 7) % n;
    let mut chosen: Vec<usize> = Vec::with_capacity(k);
    chosen.push(latest);

    // distances[i] = squared distance from data[i] to its nearest chosen
    // centroid so far.
    let mut distances: Vec<f32> = vec![f32::INFINITY; n];
    for _ in 1..k {
        // Only the most recently chosen centroid can shrink a distance.
        let anchor = &data[latest];
        for (nearest, v) in distances.iter_mut().zip(data) {
            let dist: f32 = v
                .iter()
                .zip(anchor)
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < *nearest {
                *nearest = dist;
            }
        }

        latest = distances
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        chosen.push(latest);
    }

    chosen.into_iter().map(|i| data[i].clone()).collect()
}
358
/// Returns the index of the centroid with the smallest squared Euclidean
/// distance to `v`.
///
/// Ties go to the lowest index, and distances that cannot be ordered (NaN)
/// are treated as equal. Returns `0` when `centroids` is empty. If a centroid
/// and `v` differ in length, only the overlapping components are compared.
fn nearest_centroid_idx(centroids: &[Vec<f32>], v: &[f32]) -> usize {
    centroids
        .iter()
        // Compute each distance once, rather than inside the comparator where
        // it would be re-evaluated on every comparison.
        .map(|c| -> f32 { c.iter().zip(v).map(|(x, y)| (x - y) * (x - y)).sum() })
        .enumerate()
        .min_by(|(_, da), (_, db)| da.partial_cmp(db).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Lets tests propagate any error with `?`.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Builds a config from its two parameters.
    fn make_config(n_subspaces: usize, n_clusters: usize) -> QuantizerConfig {
        QuantizerConfig {
            n_subspaces,
            n_clusters,
        }
    }

    /// Generates `n` vectors of length `dim` where component `d` of vector
    /// `i` is `i * 0.1 + d`.
    fn make_data(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..dim).map(|d| (i as f32 * 0.1) + d as f32).collect();
            rows.push(row);
        }
        rows
    }

    /// Trains a quantizer on `make_data(n, dim)` and returns it together with
    /// the training data.
    fn trained(
        n_subspaces: usize,
        n_clusters: usize,
        n: usize,
        dim: usize,
    ) -> Result<(Quantizer, Vec<Vec<f32>>)> {
        let mut quantizer = Quantizer::new(make_config(n_subspaces, n_clusters));
        let data = make_data(n, dim);
        quantizer.train(&data)?;
        Ok((quantizer, data))
    }

    /// Codebook with two well-separated 2-D centroids.
    fn two_point_codebook() -> Codebook {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        cb
    }

    // --- training -------------------------------------------------------

    #[test]
    fn test_not_trained_initially() {
        let quantizer = Quantizer::new(make_config(4, 8));
        assert!(!quantizer.is_trained());
    }

    #[test]
    fn test_is_trained_after_train() -> Result<()> {
        let (quantizer, _) = trained(4, 8, 32, 8)?;
        assert!(quantizer.is_trained());
        Ok(())
    }

    #[test]
    fn test_train_empty_data_error() {
        let mut quantizer = Quantizer::new(make_config(4, 8));
        let outcome = quantizer.train(&[]);
        assert!(matches!(outcome, Err(QuantizerError::InsufficientData(0))));
    }

    #[test]
    fn test_train_insufficient_data_error() {
        let mut quantizer = Quantizer::new(make_config(2, 10));
        // Only 5 vectors for 10 clusters.
        let outcome = quantizer.train(&make_data(5, 4));
        assert!(matches!(outcome, Err(QuantizerError::InsufficientData(_))));
    }

    #[test]
    fn test_train_n_clusters_over_256() {
        let mut quantizer = Quantizer::new(make_config(2, 300));
        let outcome = quantizer.train(&make_data(400, 8));
        assert!(matches!(outcome, Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_dimension_not_divisible() {
        // 8 components cannot be split into 3 equal subspaces.
        let mut quantizer = Quantizer::new(make_config(3, 4));
        let outcome = quantizer.train(&make_data(20, 8));
        assert!(matches!(outcome, Err(QuantizerError::InvalidConfig(_))));
    }

    // --- encoding / decoding ---------------------------------------------

    #[test]
    fn test_encode_not_trained_error() {
        let quantizer = Quantizer::new(make_config(4, 8));
        let outcome = quantizer.encode(&[0.0f32; 8]);
        assert!(matches!(outcome, Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_encode_dimension_mismatch() -> Result<()> {
        let (quantizer, _) = trained(2, 4, 16, 8)?;
        // Trained on 8 components, queried with 4.
        let too_short = vec![0.0f32; 4];
        assert!(matches!(
            quantizer.encode(&too_short),
            Err(QuantizerError::DimensionMismatch)
        ));
        Ok(())
    }

    #[test]
    fn test_encode_codes_length() -> Result<()> {
        let (quantizer, _) = trained(4, 8, 32, 8)?;
        let encoded = quantizer.encode(&vec![1.0f32; 8])?;
        assert_eq!(encoded.codes.len(), 4);
        Ok(())
    }

    #[test]
    fn test_encode_original_dim_stored() -> Result<()> {
        let (quantizer, _) = trained(2, 4, 16, 8)?;
        let encoded = quantizer.encode(&vec![0.0f32; 8])?;
        assert_eq!(encoded.original_dim, 8);
        Ok(())
    }

    #[test]
    fn test_decode_produces_correct_dim() -> Result<()> {
        let (quantizer, _) = trained(4, 8, 32, 8)?;
        let encoded = quantizer.encode(&vec![0.5f32; 8])?;
        let decoded = quantizer.decode(&encoded)?;
        assert_eq!(decoded.vector.len(), 8);
        Ok(())
    }

    #[test]
    fn test_decode_not_trained_error() {
        let quantizer = Quantizer::new(make_config(4, 8));
        let encoded = QuantizedVector {
            codes: vec![0; 4],
            original_dim: 8,
        };
        let outcome = quantizer.decode(&encoded);
        assert!(matches!(outcome, Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_encode_decode_approximates_original() -> Result<()> {
        let mut quantizer = Quantizer::new(make_config(2, 4));
        // Four groups of eight identical vectors, 10.0 apart per component.
        let data: Vec<Vec<f32>> = (0..4)
            .flat_map(|c| {
                let v: Vec<f32> = (0..8).map(|d| (c as f32 * 10.0) + d as f32 * 0.1).collect();
                std::iter::repeat(v).take(8)
            })
            .collect();
        quantizer.train(&data)?;

        let sample = data[0].clone();
        let decoded = quantizer.decode(&quantizer.encode(&sample)?)?;
        for (&orig, &rec) in sample.iter().zip(decoded.vector.iter()) {
            assert!((orig - rec).abs() < 5.0, "orig={orig}, rec={rec}");
        }
        Ok(())
    }

    // --- batch encoding --------------------------------------------------

    #[test]
    fn test_encode_batch_empty() -> Result<()> {
        let (quantizer, _) = trained(2, 4, 16, 8)?;
        let encoded = quantizer.encode_batch(&[])?;
        assert!(encoded.is_empty());
        Ok(())
    }

    #[test]
    fn test_encode_batch_multiple() -> Result<()> {
        let (quantizer, data) = trained(2, 4, 16, 8)?;
        let batch = data.clone();
        let encoded = quantizer.encode_batch(&batch)?;
        assert_eq!(encoded.len(), data.len());
        Ok(())
    }

    // --- compression ratio -----------------------------------------------

    #[test]
    fn test_compression_ratio_basic() {
        let quantizer = Quantizer::new(make_config(4, 8));
        // 128 components * 4 bytes / 4 code bytes.
        let ratio = quantizer.compression_ratio(128);
        assert!((ratio - 128.0).abs() < 0.001);
    }

    #[test]
    fn test_compression_ratio_formula() {
        let quantizer = Quantizer::new(make_config(8, 256));
        // 64 components * 4 bytes / 8 code bytes.
        let ratio = quantizer.compression_ratio(64);
        assert!((ratio - 32.0).abs() < 0.001);
    }

    // --- codebook count --------------------------------------------------

    #[test]
    fn test_codebook_count_before_training() {
        let quantizer = Quantizer::new(make_config(4, 8));
        assert_eq!(quantizer.codebook_count(), 0);
    }

    #[test]
    fn test_codebook_count_after_training_matches_n_subspaces() -> Result<()> {
        let (quantizer, _) = trained(4, 8, 32, 8)?;
        assert_eq!(quantizer.codebook_count(), 4);
        Ok(())
    }

    // --- Codebook --------------------------------------------------------

    #[test]
    fn test_codebook_new() {
        let cb = Codebook::new(4);
        assert!(cb.centroids.is_empty());
        assert_eq!(cb.sub_dimension, 4);
    }

    #[test]
    fn test_nearest_centroid_single() {
        let code = two_point_codebook().nearest_centroid(&[1.0, 1.0]);
        assert_eq!(code, 0);
    }

    #[test]
    fn test_nearest_centroid_second() {
        let code = two_point_codebook().nearest_centroid(&[9.0, 9.0]);
        assert_eq!(code, 1);
    }

    #[test]
    fn test_centroid_valid_code() -> Result<()> {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![1.0, 2.0]];
        let first = cb.centroid(0).expect("centroid at index 0 should exist");
        assert_eq!(first[0], 1.0);
        Ok(())
    }

    #[test]
    fn test_centroid_out_of_range() {
        let cb = Codebook::new(2);
        assert!(cb.centroid(5).is_none());
    }

    // --- error display ---------------------------------------------------

    #[test]
    fn test_not_trained_display() {
        let rendered = format!("{}", QuantizerError::NotTrained);
        assert!(rendered.contains("trained"));
    }

    #[test]
    fn test_dimension_mismatch_display() {
        let rendered = format!("{}", QuantizerError::DimensionMismatch);
        assert!(rendered.contains("mismatch"));
    }

    #[test]
    fn test_insufficient_data_display() {
        let rendered = format!("{}", QuantizerError::InsufficientData(3));
        assert!(rendered.contains("3"));
    }

    #[test]
    fn test_invalid_config_display() {
        let rendered = format!("{}", QuantizerError::InvalidConfig("bad".to_string()));
        assert!(rendered.contains("bad"));
    }
}