1use std::f64;
40
/// Strategy a [`PythagoreanQuantizer`] applies to an input vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QuantizationMode {
    /// Every component becomes `-1.0`, `0.0` or `1.0`.
    Ternary,
    /// The vector is normalised, consecutive component pairs are snapped to a fixed set of
    /// angles, and the output is rescaled to unit length.
    Polar,
    /// Uniform scalar quantization over the input's `[min, max]` range, followed by snapping to
    /// the ratio table of [`PythagoreanQuantizer::snap_to_pythagorean`].
    Turbo,
    /// One of the three concrete modes is chosen per input from its norm and sparsity.
    Hybrid,
}

/// Output of a single [`PythagoreanQuantizer::quantize`] call.
#[derive(Clone, Debug)]
pub struct QuantizationResult {
    /// Quantized components, one per input component.
    pub data: Vec<f64>,
    /// Concrete mode that produced `data`; `quantize` resolves `Hybrid` before storing it.
    pub mode: QuantizationMode,
    /// Bit width copied from the quantizer that produced this result.
    pub bits: u8,
    /// Mean squared error of `data`. Polar mode measures it against the normalised input,
    /// the other modes against the raw input.
    pub mse: f64,
    /// `false` only when polar mode was used and the output is not unit length.
    pub constraints_satisfied: bool,
    /// Whether the L2 norm of `data` is within 0.01 of 1.0.
    pub unit_norm_preserved: bool,
}

impl QuantizationResult {
    /// Wraps quantized `data` with zero error and both flags set to `true`.
    ///
    /// Callers that know better (such as `PythagoreanQuantizer::quantize`) overwrite `mse` and
    /// the flags afterwards.
    pub fn new(data: Vec<f64>, mode: QuantizationMode, bits: u8) -> Self {
        Self {
            data,
            mode,
            bits,
            mse: 0.0,
            constraints_satisfied: true,
            unit_norm_preserved: true,
        }
    }

    /// Euclidean (L2) norm of `data`.
    pub fn norm(&self) -> f64 {
        let sum_of_squares: f64 = self.data.iter().map(|x| x * x).sum();
        sum_of_squares.sqrt()
    }

    /// Returns `true` when the norm differs from 1.0 by strictly less than `tolerance`.
    pub fn check_unit_norm(&self, tolerance: f64) -> bool {
        (self.norm() - 1.0).abs() < tolerance
    }
}

/// Quantizer that snaps values onto ratios and angles derived from Pythagorean triples.
#[derive(Clone, Debug)]
pub struct PythagoreanQuantizer {
    /// Requested strategy; `Hybrid` defers the choice to each `quantize` call.
    pub mode: QuantizationMode,
    /// Bit width used by turbo mode. `new` raises 0 to 1, but the field is public, so
    /// `quantize_turbo` clamps it again.
    pub bits: u8,
    /// Initialised to 100 by `new`; no method in this block reads it.
    max_denominator: usize,
}

impl PythagoreanQuantizer {
    /// Largest deviation of an L2 norm from 1.0 that still counts as unit length.
    const UNIT_NORM_TOLERANCE: f64 = 0.01;
    /// Components with a magnitude below this count as "near zero" during hybrid selection.
    const SPARSE_MAGNITUDE: f64 = 0.1;
    /// Norms below this are treated as zero by polar quantization.
    const ZERO_NORM_EPSILON: f64 = 1e-10;
    /// Value ranges narrower than this are treated as constant by turbo quantization.
    const FLAT_RANGE_EPSILON: f64 = 1e-10;

    /// Creates a quantizer for `mode`; a `bits` value of 0 is raised to 1.
    pub fn new(mode: QuantizationMode, bits: u8) -> Self {
        Self {
            mode,
            bits: bits.max(1),
            max_denominator: 100,
        }
    }

    /// Preset: ternary mode, 1 bit.
    pub fn for_llm() -> Self {
        Self::new(QuantizationMode::Ternary, 1)
    }

    /// Preset: polar mode, 8 bits.
    pub fn for_embeddings() -> Self {
        Self::new(QuantizationMode::Polar, 8)
    }

    /// Preset: turbo mode, 4 bits.
    pub fn for_vector_db() -> Self {
        Self::new(QuantizationMode::Turbo, 4)
    }

    /// Preset: hybrid mode, 4 bits.
    pub fn hybrid() -> Self {
        Self::new(QuantizationMode::Hybrid, 4)
    }

    /// Quantizes `data` and reports the concrete mode used, the error and the norm flags.
    ///
    /// The returned `data` always has the same length as the input.
    pub fn quantize(&self, data: &[f64]) -> QuantizationResult {
        let mode = self.select_mode(data);

        let (quantized, mse) = match mode {
            QuantizationMode::Ternary => self.quantize_ternary(data),
            QuantizationMode::Polar => self.quantize_polar(data),
            QuantizationMode::Turbo => self.quantize_turbo(data),
            QuantizationMode::Hybrid => self.quantize_hybrid(data),
        };

        let unit_norm_preserved = self.check_unit_norm(&quantized);
        QuantizationResult {
            data: quantized,
            mode,
            bits: self.bits,
            mse,
            // Only polar mode promises a unit-length output.
            constraints_satisfied: unit_norm_preserved || mode != QuantizationMode::Polar,
            unit_norm_preserved,
        }
    }

    /// Resolves the mode to use for `data`.
    ///
    /// A non-hybrid quantizer always returns its own mode. In hybrid mode: unit-length input
    /// selects `Polar`, input with more than half of its components below 0.1 in magnitude
    /// selects `Ternary`, and everything else (including empty input) selects `Turbo`.
    fn select_mode(&self, data: &[f64]) -> QuantizationMode {
        if self.mode != QuantizationMode::Hybrid {
            return self.mode;
        }
        // Explicit guard so the sparsity ratio below is never 0/0.
        if data.is_empty() {
            return QuantizationMode::Turbo;
        }
        if self.check_unit_norm(data) {
            return QuantizationMode::Polar;
        }

        let near_zero = data
            .iter()
            .filter(|x| x.abs() < Self::SPARSE_MAGNITUDE)
            .count();
        let sparsity = near_zero as f64 / data.len() as f64;

        if sparsity > 0.5 {
            QuantizationMode::Ternary
        } else {
            QuantizationMode::Turbo
        }
    }

    /// Maps each component to `-1.0`, `0.0` or `1.0` and returns the values with their MSE
    /// against the raw input.
    ///
    /// Components whose magnitude is below 10% of the mean magnitude become `0.0`.
    fn quantize_ternary(&self, data: &[f64]) -> (Vec<f64>, f64) {
        let mean_abs = data.iter().map(|x| x.abs()).sum::<f64>() / data.len().max(1) as f64;
        let dead_zone = mean_abs * 0.1;

        let quantized: Vec<f64> = data
            .iter()
            .map(|&x| {
                if x.abs() < dead_zone {
                    0.0
                } else if x > 0.0 {
                    1.0
                } else if x < 0.0 {
                    -1.0
                } else {
                    // Exact zero with a zero dead zone (all-zero input), or NaN: neither has
                    // a sign to keep.
                    0.0
                }
            })
            .collect();

        let mse = Self::mean_squared_error(data, &quantized);
        (quantized, mse)
    }

    /// Normalises `data`, snaps the direction of each consecutive component pair to the angle
    /// table, and rescales the result to unit length.
    ///
    /// Returns the quantized vector (same length as `data`) and its MSE against the normalised
    /// input. Inputs shorter than two components are returned unchanged with an error of 0.
    /// A (near-)zero input has no direction, so the first basis vector is returned and the
    /// error is measured against the raw input.
    fn quantize_polar(&self, data: &[f64]) -> (Vec<f64>, f64) {
        let n = data.len();
        if n < 2 {
            return (data.to_vec(), 0.0);
        }

        let norm = Self::l2_norm(data);
        if norm < Self::ZERO_NORM_EPSILON {
            let mut fallback = vec![0.0; n];
            fallback[0] = 1.0;
            let mse = Self::mean_squared_error(data, &fallback);
            return (fallback, mse);
        }

        let normalized: Vec<f64> = data.iter().map(|&x| x / norm).collect();

        let mut quantized = Vec::with_capacity(n);
        let mut pairs = normalized.chunks_exact(2);
        for pair in pairs.by_ref() {
            let (qx, qy) = self.quantize_polar_pair(pair[0], pair[1]);
            quantized.push(qx);
            quantized.push(qy);
        }
        if let [last] = pairs.remainder() {
            // The ratio table is non-negative: snap the magnitude, then restore the sign.
            quantized.push(self.snap_to_pythagorean(last.abs()).copysign(*last));
        }

        // Snapping the unpaired component can change the overall length; restore unit norm.
        let q_norm = Self::l2_norm(&quantized);
        if q_norm > Self::ZERO_NORM_EPSILON {
            for q in quantized.iter_mut() {
                *q /= q_norm;
            }
        }

        let mse = Self::mean_squared_error(&normalized, &quantized);
        (quantized, mse)
    }

    /// Snaps the direction of `(x, y)` to the angle table while keeping its length, so each
    /// pair retains its share of the vector's norm.
    fn quantize_polar_pair(&self, x: f64, y: f64) -> (f64, f64) {
        let radius = x.hypot(y);
        let snapped = self.snap_angle_to_pythagorean(y.atan2(x));
        (radius * snapped.cos(), radius * snapped.sin())
    }

    /// Returns the table angle closest to `angle` (radians), measured around the circle.
    ///
    /// Ties go to the earlier table entry. A NaN `angle` is returned unchanged.
    fn snap_angle_to_pythagorean(&self, angle: f64) -> f64 {
        use std::f64::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, FRAC_PI_6, PI, TAU};

        // NOTE(review): apart from the four axis directions, every entry lies in the first
        // quadrant, so pairs with a negative component snap to an axis. Confirm whether the
        // table should be mirrored into the other quadrants.
        let candidates = [
            0.0,
            FRAC_PI_2,
            PI,
            -FRAC_PI_2,
            (4.0_f64 / 3.0).atan(),
            (3.0_f64 / 4.0).atan(),
            (12.0_f64 / 5.0).atan(),
            (5.0_f64 / 12.0).atan(),
            (15.0_f64 / 8.0).atan(),
            (8.0_f64 / 15.0).atan(),
            FRAC_PI_4,
            FRAC_PI_6,
            FRAC_PI_3,
        ];

        let mut best = angle;
        let mut best_dist = f64::MAX;
        for &candidate in candidates.iter() {
            // Circular distance: going the other way round may be shorter (e.g. -3.0 vs PI).
            let direct = (angle - candidate).abs() % TAU;
            let dist = direct.min(TAU - direct);
            if dist < best_dist {
                best_dist = dist;
                best = candidate;
            }
        }
        best
    }

    /// Uniformly quantizes `data` over its `[min, max]` range with `2^bits` levels, snaps each
    /// level to the ratio table, and maps it back into the original range.
    ///
    /// Returns the quantized values and their MSE against the raw input. Empty input yields an
    /// empty vector; a (near-)constant input is returned as its minimum repeated.
    fn quantize_turbo(&self, data: &[f64]) -> (Vec<f64>, f64) {
        let n = data.len();
        if n == 0 {
            return (vec![], 0.0);
        }

        let lo = data.iter().cloned().fold(f64::INFINITY, f64::min);
        let hi = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let span = hi - lo;

        if span < Self::FLAT_RANGE_EPSILON {
            return (vec![lo; n], 0.0);
        }

        // `bits` is a public field, so it may be 0 (which would divide by zero below) or
        // larger than an integer shift allows; a float power is exact for every `u8` value.
        let top_level = 2.0_f64.powi(i32::from(self.bits.max(1))) - 1.0;

        let quantized: Vec<f64> = data
            .iter()
            .map(|&x| {
                let level = ((x - lo) / span * top_level).round();
                lo + self.snap_to_pythagorean(level / top_level) * span
            })
            .collect();

        let mse = Self::mean_squared_error(data, &quantized);
        (quantized, mse)
    }

    /// Re-resolves the mode for `data` and dispatches to the matching quantizer.
    fn quantize_hybrid(&self, data: &[f64]) -> (Vec<f64>, f64) {
        match self.select_mode(data) {
            QuantizationMode::Ternary => self.quantize_ternary(data),
            QuantizationMode::Polar => self.quantize_polar(data),
            // `select_mode` never yields `Hybrid`; turbo is kept as the fallback.
            QuantizationMode::Turbo | QuantizationMode::Hybrid => self.quantize_turbo(data),
        }
    }

    /// Returns the entry of the built-in ratio table closest to `value`.
    ///
    /// The table holds `0`, `1`, the leg/hypotenuse ratios of several Pythagorean triples,
    /// `0.5` and `1/sqrt(2)`; every entry is non-negative. Ties go to the earlier entry. NaN
    /// and infinite inputs are returned unchanged.
    pub fn snap_to_pythagorean(&self, value: f64) -> f64 {
        let table = [
            0.0,
            1.0,
            3.0 / 5.0,
            4.0 / 5.0,
            5.0 / 13.0,
            12.0 / 13.0,
            8.0 / 17.0,
            15.0 / 17.0,
            7.0 / 25.0,
            24.0 / 25.0,
            20.0 / 29.0,
            21.0 / 29.0,
            9.0 / 41.0,
            40.0 / 41.0,
            0.5,
            0.7071067811865476,
        ];

        let (nearest, _) = table
            .iter()
            .fold((value, f64::MAX), |(best, best_dist), &ratio| {
                let dist = (value - ratio).abs();
                if dist < best_dist {
                    (ratio, dist)
                } else {
                    (best, best_dist)
                }
            });
        nearest
    }

    /// Searches all Pythagorean triples `(a, b, c)` with `2 <= c <= max_denominator` for the
    /// leg ratio `a/c` or `b/c` closest to `value`.
    ///
    /// Returns `(ratio, numerator, denominator)`. The first ratio reaching the minimum error
    /// wins, so reduced fractions are preferred over their multiples. When no triple exists in
    /// range, the result is `(value, value.round() as i64, 1)`.
    pub fn snap_to_lattice(&self, value: f64, max_denominator: usize) -> (f64, i64, u64) {
        let mut best = (value, value.round() as i64, 1u64);
        let mut best_err = f64::MAX;

        for c in 2..=max_denominator {
            let c_sq = c * c;
            for a in 1..c {
                let b_sq = c_sq - a * a;
                let b = (b_sq as f64).sqrt() as usize;
                if b * b != b_sq {
                    // `a` is not a leg of an integer right triangle with hypotenuse `c`.
                    continue;
                }
                for &leg in &[a, b] {
                    let ratio = leg as f64 / c as f64;
                    let err = (value - ratio).abs();
                    if err < best_err {
                        best_err = err;
                        best = (ratio, leg as i64, c as u64);
                    }
                }
            }
        }

        best
    }

    /// Returns `true` when the L2 norm of `data` is within 0.01 of 1.0.
    fn check_unit_norm(&self, data: &[f64]) -> bool {
        (Self::l2_norm(data) - 1.0).abs() < Self::UNIT_NORM_TOLERANCE
    }

    /// Quantizes every vector independently, preserving order.
    pub fn quantize_batch(&self, vectors: &[Vec<f64>]) -> Vec<QuantizationResult> {
        vectors.iter().map(|v| self.quantize(v)).collect()
    }

    /// Euclidean (L2) norm of `values`.
    fn l2_norm(values: &[f64]) -> f64 {
        values.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Mean of the squared component-wise differences; 0.0 for empty input.
    fn mean_squared_error(original: &[f64], quantized: &[f64]) -> f64 {
        let total: f64 = original
            .iter()
            .zip(quantized.iter())
            .map(|(o, q)| (o - q).powi(2))
            .sum();
        total / original.len().max(1) as f64
    }
}

impl Default for PythagoreanQuantizer {
    /// The default quantizer is the hybrid preset.
    fn default() -> Self {
        Self::hybrid()
    }
}
499
/// A fraction `num / den` stored exactly as given (it is not reduced).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Rational {
    /// Signed numerator.
    pub num: i64,
    /// Denominator; zero is representable but never counts as Pythagorean.
    pub den: u64,
}

impl Rational {
    /// Builds the fraction `num / den` without validating or reducing it.
    pub fn new(num: i64, den: u64) -> Self {
        Self { num, den }
    }

    /// Converts to floating point; a zero denominator follows IEEE division (infinity or NaN).
    pub fn to_f64(&self) -> f64 {
        self.num as f64 / self.den as f64
    }

    /// Returns `true` when `|num|` and `den` are a leg and the hypotenuse of an integer right
    /// triangle, i.e. `den^2 - num^2` is a perfect square.
    ///
    /// The degenerate cases `num == 0` and `|num| == den` count as Pythagorean. A zero
    /// denominator or `|num| > den` does not.
    pub fn is_pythagorean(&self) -> bool {
        if self.den == 0 {
            return false;
        }

        // Widen to u128: `den * den` overflows u64 once `den` reaches 2^32.
        let leg = u128::from(self.num.unsigned_abs());
        let hyp = u128::from(self.den);
        if leg > hyp {
            return false;
        }

        // (hyp - leg)(hyp + leg) = hyp^2 - leg^2 <= hyp^2 < 2^128, so this cannot overflow.
        let other_sq = (hyp - leg) * (hyp + leg);
        let other = Self::integer_sqrt(other_sq);
        other * other == other_sq
    }

    /// Exact `floor(sqrt(n))` by Newton iteration; avoids the rounding of an `f64` square root
    /// for values above 2^53.
    fn integer_sqrt(n: u128) -> u128 {
        if n < 2 {
            return n;
        }
        // Start from a power of two that is at least sqrt(n); the iteration then decreases
        // monotonically until it reaches the floor of the root.
        let bits = 128 - n.leading_zeros();
        let mut estimate = 1u128 << ((bits + 1) / 2);
        loop {
            let next = (estimate + n / estimate) / 2;
            if next >= estimate {
                return estimate;
            }
            estimate = next;
        }
    }
}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by every unit-norm assertion in this module.
    const NORM_TOLERANCE: f64 = 0.1;

    /// Each explicitly requested mode is honoured, and polar output is unit length.
    #[test]
    fn test_quantization_modes() {
        let input = vec![0.6, 0.8, 0.0, 0.0];

        let ternary = PythagoreanQuantizer::new(QuantizationMode::Ternary, 1).quantize(&input);
        assert_eq!(ternary.mode, QuantizationMode::Ternary);

        let polar = PythagoreanQuantizer::new(QuantizationMode::Polar, 8).quantize(&input);
        assert!(polar.check_unit_norm(NORM_TOLERANCE));

        let turbo = PythagoreanQuantizer::new(QuantizationMode::Turbo, 4).quantize(&input);
        assert_eq!(turbo.mode, QuantizationMode::Turbo);
    }

    /// The embeddings preset returns unit-length vectors for several unit-length inputs.
    #[test]
    fn test_polar_unit_norm() {
        let quantizer = PythagoreanQuantizer::for_embeddings();

        let cases = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.707, 0.707, 0.0, 0.0],
            vec![0.6, 0.8, 0.0, 0.0],
            vec![0.5, 0.5, 0.5, 0.5],
        ];

        for v in cases {
            let outcome = quantizer.quantize(&v);
            assert!(
                outcome.check_unit_norm(NORM_TOLERANCE),
                "Failed for vector {:?}",
                v
            );
        }
    }

    /// Ternary output contains nothing but -1, 0 and 1.
    #[test]
    fn test_ternary_quantization() {
        let outcome = PythagoreanQuantizer::for_llm().quantize(&[-0.8, -0.1, 0.1, 0.9]);

        for &val in outcome.data.iter() {
            assert!(val == -1.0 || val == 0.0 || val == 1.0);
        }
    }

    /// Values that already are table ratios snap (almost) onto themselves.
    #[test]
    fn test_snap_to_pythagorean() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);

        for &target in [0.6, 0.8].iter() {
            let snapped = quantizer.snap_to_pythagorean(target);
            assert!((snapped - target).abs() < 0.01);
        }
    }

    /// 0.6 is found as the reduced fraction 3/5.
    #[test]
    fn test_snap_to_lattice() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);

        let (val, num, den) = quantizer.snap_to_lattice(0.6, 20);
        assert_eq!(num, 3);
        assert_eq!(den, 5);
        assert!((val - 0.6).abs() < 0.01);
    }

    /// Hybrid mode picks polar for unit vectors, ternary for sparse ones, turbo otherwise.
    #[test]
    fn test_hybrid_mode_selection() {
        let quantizer = PythagoreanQuantizer::hybrid();

        let unit = vec![0.6, 0.8];
        assert_eq!(quantizer.select_mode(&unit), QuantizationMode::Polar);

        let sparse = vec![0.01, 0.02, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(quantizer.select_mode(&sparse), QuantizationMode::Ternary);

        let dense = vec![0.5, 0.6, 0.7, 0.8];
        assert_eq!(quantizer.select_mode(&dense), QuantizationMode::Turbo);
    }

    /// 3/5 and 4/5 are Pythagorean ratios, 1/3 is not.
    #[test]
    fn test_rational() {
        let three_fifths = Rational::new(3, 5);
        assert!((three_fifths.to_f64() - 0.6).abs() < 1e-10);
        assert!(three_fifths.is_pythagorean());

        let four_fifths = Rational::new(4, 5);
        assert!((four_fifths.to_f64() - 0.8).abs() < 1e-10);
        assert!(four_fifths.is_pythagorean());

        let one_third = Rational::new(1, 3);
        assert!(!one_third.is_pythagorean());
    }

    /// Batch quantization yields one unit-length result per input vector.
    #[test]
    fn test_batch_quantization() {
        let quantizer = PythagoreanQuantizer::for_embeddings();
        let batch = vec![vec![0.6, 0.8], vec![1.0, 0.0], vec![0.707, 0.707]];

        let outcomes = quantizer.quantize_batch(&batch);
        assert_eq!(outcomes.len(), 3);

        for outcome in outcomes {
            assert!(outcome.check_unit_norm(NORM_TOLERANCE));
        }
    }

    /// Empty input produces empty output.
    #[test]
    fn test_empty_input() {
        let outcome = PythagoreanQuantizer::hybrid().quantize(&[]);
        assert!(outcome.data.is_empty());
    }

    /// A single component survives polar quantization as a single component.
    #[test]
    fn test_single_element() {
        let outcome = PythagoreanQuantizer::new(QuantizationMode::Polar, 8).quantize(&[1.0]);
        assert_eq!(outcome.data.len(), 1);
    }
}