Skip to main content

constraint_theory_core/
quantizer.rs

1//! Pythagorean Quantizer - Unified Quantization with Constraint Preservation
2//!
3//! This module provides the `PythagoreanQuantizer` which synthesizes multiple
4//! quantization technologies for exact constraint satisfaction:
5//!
6//! - **TurboQuant**: Near-optimal distortion, works online, O(d log d)
7//! - **BitNet**: Ternary weights {-1, 0, 1} for LLM inference
8//! - **PolarQuant**: Exact unit norm preservation via polar coordinate quantization
9//!
10//! # Architecture
11//!
12//! ```text
13//! Input ──► [Mode Selector] ──► [Quantizer] ──► [Constraint Layer]
14//!
15//! Modes:
16//! • TERNARY  (BitNet): {-1, 0, 1} for LLM weights
17//! • POLAR    (PolarQuant): Exact unit norm preservation
18//! • TURBO    (TurboQuant): Near-optimal distortion
19//! • HYBRID: Auto-select based on input characteristics
20//! ```
21//!
22//! # Example
23//!
24//! ```
25//! use constraint_theory_core::quantizer::{PythagoreanQuantizer, QuantizationMode};
26//!
27//! // Create a quantizer with POLAR mode for unit norm preservation
28//! let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);
29//!
30//! // Quantize a vector
31//! let vector = vec![0.6, 0.8, 0.0, 0.0];
32//! let result = quantizer.quantize(&vector);
33//!
//! // The result has unit norm (up to floating-point rounding)
35//! let norm: f64 = result.data.iter().map(|x| x * x).sum::<f64>().sqrt();
36//! assert!((norm - 1.0).abs() < 0.01);
37//! ```
38
39use std::f64;
40
/// Strategy a [`PythagoreanQuantizer`] uses to quantize a vector.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum QuantizationMode {
    /// BitNet-style ternary output: every element becomes -1, 0 or 1.
    /// Suited to LLM weights and sparse representations.
    Ternary,

    /// PolarQuant-style quantization of coordinate pairs by angle, aimed at
    /// keeping the output on the unit sphere. Suited to embeddings.
    Polar,

    /// TurboQuant-style scalar quantization aimed at low distortion.
    /// Suited to vector databases and general-purpose use.
    Turbo,

    /// Pick one of the other modes per input, based on its norm and sparsity.
    Hybrid,
}
59
60/// Result of quantization operation.
61#[derive(Clone, Debug)]
62pub struct QuantizationResult {
63    /// Quantized data
64    pub data: Vec<f64>,
65    /// Quantization mode used
66    pub mode: QuantizationMode,
67    /// Bits per element
68    pub bits: u8,
69    /// Mean squared error from original
70    pub mse: f64,
71    /// Whether constraints are satisfied
72    pub constraints_satisfied: bool,
73    /// Unit norm preserved (for Polar mode)
74    pub unit_norm_preserved: bool,
75}
76
77impl QuantizationResult {
78    /// Create a new quantization result.
79    pub fn new(data: Vec<f64>, mode: QuantizationMode, bits: u8) -> Self {
80        Self {
81            data,
82            mode,
83            bits,
84            mse: 0.0,
85            constraints_satisfied: true,
86            unit_norm_preserved: true,
87        }
88    }
89    
90    /// Compute the norm of the quantized vector.
91    pub fn norm(&self) -> f64 {
92        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
93    }
94    
95    /// Check if unit norm is preserved within tolerance.
96    pub fn check_unit_norm(&self, tolerance: f64) -> bool {
97        (self.norm() - 1.0).abs() < tolerance
98    }
99}
100
101/// Pythagorean Quantizer - Unified quantization with constraint preservation.
102///
103/// Integrates TurboQuant, BitNet, and PolarQuant technologies with
104/// Pythagorean snapping for exact constraint satisfaction.
105#[derive(Clone, Debug)]
106pub struct PythagoreanQuantizer {
107    /// Quantization mode
108    pub mode: QuantizationMode,
109    /// Bits per element (1 for ternary, 4-8 for others)
110    pub bits: u8,
111    /// Maximum denominator for Pythagorean ratios
112    max_denominator: usize,
113}
114
115impl PythagoreanQuantizer {
116    /// Create a new Pythagorean quantizer.
117    ///
118    /// # Arguments
119    ///
120    /// * `mode` - Quantization mode to use
121    /// * `bits` - Bits per element (1 for ternary, 4-8 for others)
122    ///
123    /// # Example
124    ///
125    /// ```
126    /// use constraint_theory_core::quantizer::{PythagoreanQuantizer, QuantizationMode};
127    ///
128    /// let quantizer = PythagoreanQuantizer::new(QuantizationMode::Ternary, 1);
129    /// ```
130    pub fn new(mode: QuantizationMode, bits: u8) -> Self {
131        Self {
132            mode,
133            bits: bits.max(1),
134            max_denominator: 100,
135        }
136    }
137    
138    /// Create a quantizer optimized for LLM weights (ternary).
139    pub fn for_llm() -> Self {
140        Self::new(QuantizationMode::Ternary, 1)
141    }
142    
143    /// Create a quantizer optimized for embeddings (polar).
144    pub fn for_embeddings() -> Self {
145        Self::new(QuantizationMode::Polar, 8)
146    }
147    
148    /// Create a quantizer optimized for vector databases (turbo).
149    pub fn for_vector_db() -> Self {
150        Self::new(QuantizationMode::Turbo, 4)
151    }
152    
153    /// Create a hybrid quantizer that auto-selects mode.
154    pub fn hybrid() -> Self {
155        Self::new(QuantizationMode::Hybrid, 4)
156    }
157    
158    /// Quantize data with constraint preservation.
159    ///
160    /// # Arguments
161    ///
162    /// * `data` - Input data to quantize
163    ///
164    /// # Returns
165    ///
166    /// Quantization result with preserved constraints
167    ///
168    /// # Example
169    ///
170    /// ```
171    /// use constraint_theory_core::quantizer::{PythagoreanQuantizer, QuantizationMode};
172    ///
173    /// let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);
174    /// let data = vec![0.6, 0.8, 0.0, 0.0];
175    /// let result = quantizer.quantize(&data);
176    ///
177    /// assert_eq!(result.data.len(), 4);
178    /// ```
179    pub fn quantize(&self, data: &[f64]) -> QuantizationResult {
180        let mode = self.select_mode(data);
181        
182        let (quantized, mse) = match mode {
183            QuantizationMode::Ternary => self.quantize_ternary(data),
184            QuantizationMode::Polar => self.quantize_polar(data),
185            QuantizationMode::Turbo => self.quantize_turbo(data),
186            QuantizationMode::Hybrid => self.quantize_hybrid(data),
187        };
188        
189        let mut result = QuantizationResult::new(quantized, mode, self.bits);
190        result.mse = mse;
191        result.unit_norm_preserved = self.check_unit_norm(&result.data);
192        result.constraints_satisfied = result.unit_norm_preserved || mode != QuantizationMode::Polar;
193        
194        result
195    }
196    
197    /// Auto-select quantization mode based on input characteristics.
198    fn select_mode(&self, data: &[f64]) -> QuantizationMode {
199        if self.mode != QuantizationMode::Hybrid {
200            return self.mode;
201        }
202        
203        // Check if input is already unit normalized
204        let norm: f64 = data.iter().map(|x| x * x).sum::<f64>().sqrt();
205        let is_unit_norm = (norm - 1.0).abs() < 0.01;
206        
207        // Check sparsity (for ternary mode)
208        let threshold = 0.1;
209        let sparse_count = data.iter().filter(|&&x| x.abs() < threshold).count();
210        let sparsity = sparse_count as f64 / data.len() as f64;
211        
212        if is_unit_norm {
213            QuantizationMode::Polar
214        } else if sparsity > 0.5 {
215            QuantizationMode::Ternary
216        } else {
217            QuantizationMode::Turbo
218        }
219    }
220    
221    /// Ternary quantization (BitNet style): {-1, 0, 1}.
222    ///
223    /// Achieves 16x memory reduction for LLM weights.
224    fn quantize_ternary(&self, data: &[f64]) -> (Vec<f64>, f64) {
225        // Compute threshold for zero bucket
226        let mean_abs: f64 = data.iter().map(|x| x.abs()).sum::<f64>() / data.len().max(1) as f64;
227        let threshold = mean_abs * 0.1; // Small values -> 0
228        
229        let quantized: Vec<f64> = data.iter().map(|&x| {
230            if x.abs() < threshold {
231                0.0
232            } else if x > 0.0 {
233                1.0
234            } else {
235                -1.0
236            }
237        }).collect();
238        
239        let mse: f64 = data.iter()
240            .zip(quantized.iter())
241            .map(|(o, q)| (o - q).powi(2))
242            .sum::<f64>() / data.len().max(1) as f64;
243        
244        (quantized, mse)
245    }
246    
247    /// Polar coordinate quantization (PolarQuant style).
248    ///
249    /// Preserves unit norm exactly via polar coordinate quantization.
250    fn quantize_polar(&self, data: &[f64]) -> (Vec<f64>, f64) {
251        let n = data.len();
252        if n < 2 {
253            return (data.to_vec(), 0.0);
254        }
255        
256        // Compute current norm
257        let norm: f64 = data.iter().map(|x| x * x).sum::<f64>().sqrt();
258        if norm < 1e-10 {
259            return (vec![1.0], 0.0);
260        }
261        
262        // Normalize first
263        let normalized: Vec<f64> = data.iter().map(|&x| x / norm).collect();
264        
265        // Convert to polar coordinates for each pair
266        let mut quantized = vec![0.0; n];
267        
268        for i in (0..n).step_by(2) {
269            if i + 1 < n {
270                let (q0, q1) = self.quantize_polar_pair(normalized[i], normalized[i + 1]);
271                quantized[i] = q0;
272                quantized[i + 1] = q1;
273            } else {
274                // Odd dimension - snap to nearest Pythagorean ratio
275                quantized[i] = self.snap_to_pythagorean(normalized[i]);
276            }
277        }
278        
279        // Re-normalize to ensure exact unit norm
280        let q_norm: f64 = quantized.iter().map(|x| x * x).sum::<f64>().sqrt();
281        if q_norm > 1e-10 {
282            quantized = quantized.iter().map(|&x| x / q_norm).collect();
283        }
284        
285        let mse: f64 = normalized.iter()
286            .zip(quantized.iter())
287            .map(|(o, q)| (o - q).powi(2))
288            .sum::<f64>() / n as f64;
289        
290        (quantized, mse)
291    }
292    
293    /// Quantize a 2D point using polar coordinates.
294    fn quantize_polar_pair(&self, x: f64, y: f64) -> (f64, f64) {
295        // Convert to angle
296        let angle = y.atan2(x);
297        
298        // Snap angle to nearest Pythagorean angle
299        let snapped_angle = self.snap_angle_to_pythagorean(angle);
300        
301        // Convert back to Cartesian (unit norm preserved)
302        (snapped_angle.cos(), snapped_angle.sin())
303    }
304    
305    /// Snap an angle to the nearest Pythagorean angle.
306    fn snap_angle_to_pythagorean(&self, angle: f64) -> f64 {
307        // Angles corresponding to common Pythagorean triples
308        let pythagorean_angles: &[f64] = &[
309            0.0, std::f64::consts::FRAC_PI_2, std::f64::consts::PI, -std::f64::consts::FRAC_PI_2,
310            // 3-4-5 triangle: atan(4/3) ≈ 0.927 radians
311            (4.0_f64 / 3.0).atan(),
312            (3.0_f64 / 4.0).atan(),
313            // 5-12-13 triangle
314            (12.0_f64 / 5.0).atan(),
315            (5.0_f64 / 12.0).atan(),
316            // 8-15-17 triangle
317            (15.0_f64 / 8.0).atan(),
318            (8.0_f64 / 15.0).atan(),
319            // 45 degrees
320            std::f64::consts::FRAC_PI_4,
321            // 30 degrees
322            std::f64::consts::FRAC_PI_6,
323            // 60 degrees
324            std::f64::consts::FRAC_PI_3,
325        ];
326        
327        let mut best = angle;
328        let mut min_diff = f64::MAX;
329        
330        for &pyth_angle in pythagorean_angles {
331            // Handle angle wrapping
332            let diff = ((angle - pyth_angle).abs() % std::f64::consts::TAU)
333                .min((pyth_angle - angle).abs() % std::f64::consts::TAU);
334            if diff < min_diff {
335                min_diff = diff;
336                best = pyth_angle;
337            }
338        }
339        
340        best
341    }
342    
343    /// Turbo quantization (TurboQuant style).
344    ///
345    /// Near-optimal distortion: D(b,d) ≤ 2.7 · D*(b,d)
346    fn quantize_turbo(&self, data: &[f64]) -> (Vec<f64>, f64) {
347        let n = data.len();
348        if n == 0 {
349            return (vec![], 0.0);
350        }
351        
352        // Compute statistics
353        let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
354        let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
355        let range = max_val - min_val;
356        
357        if range < 1e-10 {
358            return (vec![min_val; n], 0.0);
359        }
360        
361        // Number of quantization levels
362        let levels = (1 << self.bits) as f64; // 2^bits
363        
364        // Quantize each value
365        let quantized: Vec<f64> = data.iter().map(|&x| {
366            // Scale to [0, levels-1]
367            let scaled = ((x - min_val) / range * (levels - 1.0)).round();
368            // Snap to Pythagorean ratio if close
369            let snapped = self.snap_to_pythagorean(scaled / (levels - 1.0));
370            // Scale back
371            min_val + snapped * range
372        }).collect();
373        
374        let mse: f64 = data.iter()
375            .zip(quantized.iter())
376            .map(|(o, q)| (o - q).powi(2))
377            .sum::<f64>() / n as f64;
378        
379        (quantized, mse)
380    }
381    
382    /// Hybrid quantization - combines best aspects of all modes.
383    fn quantize_hybrid(&self, data: &[f64]) -> (Vec<f64>, f64) {
384        let mode = self.select_mode(data);
385        match mode {
386            QuantizationMode::Ternary => self.quantize_ternary(data),
387            QuantizationMode::Polar => self.quantize_polar(data),
388            QuantizationMode::Turbo => self.quantize_turbo(data),
389            QuantizationMode::Hybrid => self.quantize_turbo(data), // Default to Turbo
390        }
391    }
392    
393    /// Snap a value to the nearest Pythagorean ratio.
394    ///
395    /// Pythagorean ratios are of the form a/c or b/c where a² + b² = c².
396    pub fn snap_to_pythagorean(&self, value: f64) -> f64 {
397        // Common Pythagorean ratios from primitive triples
398        let pythagorean_ratios: &[f64] = &[
399            0.0, 1.0,
400            3.0/5.0, 4.0/5.0,
401            5.0/13.0, 12.0/13.0,
402            8.0/17.0, 15.0/17.0,
403            7.0/25.0, 24.0/25.0,
404            20.0/29.0, 21.0/29.0,
405            9.0/41.0, 40.0/41.0,
406            0.5, 0.7071067811865476, // sqrt(2)/2
407        ];
408        
409        let mut best = value;
410        let mut min_dist = f64::MAX;
411        
412        for &ratio in pythagorean_ratios {
413            let dist = (value - ratio).abs();
414            if dist < min_dist {
415                min_dist = dist;
416                best = ratio;
417            }
418        }
419        
420        best
421    }
422    
423    /// Snap to Pythagorean lattice with explicit rational representation.
424    ///
425    /// # Arguments
426    ///
427    /// * `value` - Value to snap
428    /// * `max_denominator` - Maximum denominator for rational approximation
429    ///
430    /// # Returns
431    ///
432    /// Tuple of (snapped_value, numerator, denominator)
433    pub fn snap_to_lattice(&self, value: f64, max_denominator: usize) -> (f64, i64, u64) {
434        // Find best rational approximation with Pythagorean constraint
435        let mut best_val = value;
436        let mut best_num = value.round() as i64;
437        let mut best_den = 1u64;
438        let mut best_err = f64::MAX;
439        
440        // Check Pythagorean triples up to max_denominator
441        for c in 2..=max_denominator {
442            for a in 1..c {
443                let b_sq = (c * c - a * a) as f64;
444                if b_sq > 0.0 {
445                    let b = b_sq.sqrt() as usize;
446                    if b * b == (c * c - a * a) {
447                        // This is a valid Pythagorean triple
448                        let ratio_a = a as f64 / c as f64;
449                        let ratio_b = b as f64 / c as f64;
450                        
451                        let err_a = (value - ratio_a).abs();
452                        if err_a < best_err {
453                            best_err = err_a;
454                            best_val = ratio_a;
455                            best_num = a as i64;
456                            best_den = c as u64;
457                        }
458                        
459                        let err_b = (value - ratio_b).abs();
460                        if err_b < best_err {
461                            best_err = err_b;
462                            best_val = ratio_b;
463                            best_num = b as i64;
464                            best_den = c as u64;
465                        }
466                    }
467                }
468            }
469        }
470        
471        (best_val, best_num, best_den)
472    }
473    
474    /// Check if unit norm is preserved within tolerance.
475    fn check_unit_norm(&self, data: &[f64]) -> bool {
476        let norm: f64 = data.iter().map(|x| x * x).sum::<f64>().sqrt();
477        (norm - 1.0).abs() < 0.01
478    }
479    
480    /// Batch quantization for multiple vectors.
481    ///
482    /// # Arguments
483    ///
484    /// * `vectors` - Slice of vectors to quantize
485    ///
486    /// # Returns
487    ///
488    /// Vector of quantization results
489    pub fn quantize_batch(&self, vectors: &[Vec<f64>]) -> Vec<QuantizationResult> {
490        vectors.iter().map(|v| self.quantize(v)).collect()
491    }
492}
493
494impl Default for PythagoreanQuantizer {
495    fn default() -> Self {
496        Self::hybrid()
497    }
498}
499
/// A rational number `num / den`, used for exact representation of ratios.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Rational {
    /// Numerator (carries the sign).
    pub num: i64,
    /// Denominator. Expected to be non-zero; [`Rational::new`] does not check this.
    pub den: u64,
}

impl Rational {
    /// Create a new rational number `num / den`.
    ///
    /// The fraction is not reduced to lowest terms and `den == 0` is not
    /// rejected; such a value converts to an infinity or NaN in
    /// [`Rational::to_f64`].
    pub fn new(num: i64, den: u64) -> Self {
        Self { num, den }
    }

    /// Convert to floating point (`num / den`); may round for large magnitudes.
    pub fn to_f64(&self) -> f64 {
        self.num as f64 / self.den as f64
    }

    /// Check whether `|num| / den` is a leg/hypotenuse ratio of a Pythagorean
    /// triple, i.e. whether `den² - num²` is a perfect square.
    ///
    /// Degenerate triples count: `0/c` and `c/c` return `true`. Returns `false`
    /// when `den == 0` or `|num| > den`. The arithmetic is done in `u128` with an
    /// integer square root, so it is exact and cannot overflow for any
    /// `i64`/`u64` input.
    pub fn is_pythagorean(&self) -> bool {
        let a = u128::from(self.num.unsigned_abs());
        let c = u128::from(self.den);

        if c == 0 || a > c {
            return false;
        }

        // c <= u64::MAX, so c * c < 2^128 fits in u128.
        let b_sq = c * c - a * a;
        let b = Self::isqrt(b_sq);
        b * b == b_sq
    }

    /// Floor of the square root of `n`, computed with integer Newton iteration.
    fn isqrt(n: u128) -> u128 {
        if n < 2 {
            return n;
        }
        let mut x = n;
        // ceil(n / 2), written to avoid overflowing `n + 1`.
        let mut y = n / 2 + n % 2;
        while y < x {
            x = y;
            y = (x + n / x) / 2;
        }
        x
    }
}
539
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_modes() {
        let input = [0.6, 0.8, 0.0, 0.0];

        // Ternary mode reports itself.
        let ternary = PythagoreanQuantizer::new(QuantizationMode::Ternary, 1).quantize(&input);
        assert_eq!(ternary.mode, QuantizationMode::Ternary);

        // Polar mode yields (approximately) unit norm.
        let polar = PythagoreanQuantizer::new(QuantizationMode::Polar, 8).quantize(&input);
        assert!(polar.check_unit_norm(0.1));

        // Turbo mode reports itself.
        let turbo = PythagoreanQuantizer::new(QuantizationMode::Turbo, 4).quantize(&input);
        assert_eq!(turbo.mode, QuantizationMode::Turbo);
    }

    #[test]
    fn test_polar_unit_norm() {
        let quantizer = PythagoreanQuantizer::for_embeddings();
        let unit_vectors: [[f64; 4]; 4] = [
            [1.0, 0.0, 0.0, 0.0],
            [0.707, 0.707, 0.0, 0.0],
            [0.6, 0.8, 0.0, 0.0],
            [0.5, 0.5, 0.5, 0.5],
        ];

        for v in &unit_vectors {
            assert!(
                quantizer.quantize(v).check_unit_norm(0.1),
                "Failed for vector {:?}",
                v
            );
        }
    }

    #[test]
    fn test_ternary_quantization() {
        let result = PythagoreanQuantizer::for_llm().quantize(&[-0.8, -0.1, 0.1, 0.9]);

        // Every output element must be one of -1, 0, 1.
        assert!(result
            .data
            .iter()
            .all(|&v| v == -1.0 || v == 0.0 || v == 1.0));
    }

    #[test]
    fn test_snap_to_pythagorean() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);

        // 0.6 ≈ 3/5 and 0.8 ≈ 4/5 are both in the ratio table.
        for target in [0.6, 0.8] {
            let snapped = quantizer.snap_to_pythagorean(target);
            assert!((snapped - target).abs() < 0.01);
        }
    }

    #[test]
    fn test_snap_to_lattice() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);
        let (value, numerator, denominator) = quantizer.snap_to_lattice(0.6, 20);

        assert_eq!((numerator, denominator), (3, 5));
        assert!((value - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_hybrid_mode_selection() {
        let quantizer = PythagoreanQuantizer::hybrid();
        let cases: [(&[f64], QuantizationMode); 3] = [
            // Unit norm -> Polar
            (&[0.6, 0.8], QuantizationMode::Polar),
            // Mostly near zero -> Ternary
            (&[0.01, 0.02, 0.0, 0.0, 0.0, 0.0], QuantizationMode::Ternary),
            // Dense, not unit norm -> Turbo
            (&[0.5, 0.6, 0.7, 0.8], QuantizationMode::Turbo),
        ];

        for (data, expected) in cases.iter() {
            assert_eq!(quantizer.select_mode(data), *expected);
        }
    }

    #[test]
    fn test_rational() {
        let three_fifths = Rational::new(3, 5);
        assert!((three_fifths.to_f64() - 0.6).abs() < 1e-10);
        assert!(three_fifths.is_pythagorean());

        let four_fifths = Rational::new(4, 5);
        assert!((four_fifths.to_f64() - 0.8).abs() < 1e-10);
        assert!(four_fifths.is_pythagorean());

        assert!(!Rational::new(1, 3).is_pythagorean());
    }

    #[test]
    fn test_batch_quantization() {
        let quantizer = PythagoreanQuantizer::for_embeddings();
        let batch = vec![vec![0.6, 0.8], vec![1.0, 0.0], vec![0.707, 0.707]];

        let results = quantizer.quantize_batch(&batch);
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.check_unit_norm(0.1)));
    }

    #[test]
    fn test_empty_input() {
        let result = PythagoreanQuantizer::hybrid().quantize(&[]);
        assert!(result.data.is_empty());
    }

    #[test]
    fn test_single_element() {
        let result = PythagoreanQuantizer::new(QuantizationMode::Polar, 8).quantize(&[1.0]);
        assert_eq!(result.data.len(), 1);
    }
}