use std::f64;
/// Strategy a quantizer applies to an input vector.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum QuantizationMode {
    /// Map every value to one of -1, 0 or +1.
    Ternary,
    /// Normalize the vector and snap coordinate pairs to a fixed set of angles.
    Polar,
    /// Uniform levels over the input's value range, snapped to a table of ratios.
    Turbo,
    /// Choose one of the other modes per input, based on the input's shape.
    Hybrid,
}
/// Outcome of quantizing one vector.
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// The quantized values.
    pub data: Vec<f64>,
    /// Mode recorded for this result.
    pub mode: QuantizationMode,
    /// Bit width recorded for this result.
    pub bits: u8,
    /// Mean squared error; `new` initialises it to `0.0`.
    pub mse: f64,
    /// Whether the mode's constraints hold; `new` initialises it to `true`.
    pub constraints_satisfied: bool,
    /// Whether the data has unit L2 norm; `new` initialises it to `true`.
    pub unit_norm_preserved: bool,
}

impl QuantizationResult {
    /// Wraps `data` with its `mode` and `bits`, using zero error and both
    /// flags set to `true` as defaults.
    pub fn new(data: Vec<f64>, mode: QuantizationMode, bits: u8) -> Self {
        let mse = 0.0;
        let constraints_satisfied = true;
        let unit_norm_preserved = true;
        Self {
            data,
            mode,
            bits,
            mse,
            constraints_satisfied,
            unit_norm_preserved,
        }
    }

    /// Euclidean (L2) norm of `data`.
    pub fn norm(&self) -> f64 {
        let sum_of_squares = self.data.iter().map(|v| v * v).sum::<f64>();
        sum_of_squares.sqrt()
    }

    /// Returns `true` when the norm is strictly within `tolerance` of 1.0.
    pub fn check_unit_norm(&self, tolerance: f64) -> bool {
        let deviation = (self.norm() - 1.0).abs();
        deviation < tolerance
    }
}
/// Quantizer configuration: a mode plus a bit width.
#[derive(Debug, Clone)]
pub struct PythagoreanQuantizer {
    /// Mode applied by this quantizer.
    pub mode: QuantizationMode,
    /// Bit width; the field is public, so it can be changed after construction.
    pub bits: u8,
    /// Private denominator bound stored on the quantizer.
    max_denominator: usize,
}
impl PythagoreanQuantizer {
    /// Largest deviation of the L2 norm from 1.0 that still counts as unit norm.
    const UNIT_NORM_TOLERANCE: f64 = 0.01;
    /// Norms and value ranges below this are treated as degenerate.
    const DEGENERATE_EPS: f64 = 1e-10;
    /// Hybrid selection: magnitudes below this count as "near zero".
    const SPARSE_MAGNITUDE: f64 = 0.1;
    /// Candidate values for `snap_to_pythagorean`: leg/hypotenuse ratios of
    /// small Pythagorean triples plus 0, 1, 0.5 and 1/sqrt(2).
    const PYTHAGOREAN_RATIOS: [f64; 16] = [
        0.0,
        1.0,
        3.0 / 5.0,
        4.0 / 5.0,
        5.0 / 13.0,
        12.0 / 13.0,
        8.0 / 17.0,
        15.0 / 17.0,
        7.0 / 25.0,
        24.0 / 25.0,
        20.0 / 29.0,
        21.0 / 29.0,
        9.0 / 41.0,
        40.0 / 41.0,
        0.5,
        0.7071067811865476,
    ];

    /// Creates a quantizer for `mode`; `bits` is raised to at least 1.
    pub fn new(mode: QuantizationMode, bits: u8) -> Self {
        Self {
            mode,
            bits: bits.max(1),
            max_denominator: 100,
        }
    }

    /// Preset: ternary mode, 1 bit.
    pub fn for_llm() -> Self {
        Self::new(QuantizationMode::Ternary, 1)
    }

    /// Preset: polar mode, 8 bits.
    pub fn for_embeddings() -> Self {
        Self::new(QuantizationMode::Polar, 8)
    }

    /// Preset: turbo mode, 4 bits.
    pub fn for_vector_db() -> Self {
        Self::new(QuantizationMode::Turbo, 4)
    }

    /// Preset: hybrid mode (per-input mode selection), 4 bits.
    pub fn hybrid() -> Self {
        Self::new(QuantizationMode::Hybrid, 4)
    }

    /// Quantizes `data` and reports the mode actually used, the mean squared
    /// error and the constraint flags.
    ///
    /// `constraints_satisfied` is false only when polar mode was used and the
    /// output does not have unit norm.
    pub fn quantize(&self, data: &[f64]) -> QuantizationResult {
        let mode = self.select_mode(data);
        let (quantized, mse) = match mode {
            QuantizationMode::Ternary => self.quantize_ternary(data),
            QuantizationMode::Polar => self.quantize_polar(data),
            QuantizationMode::Turbo => self.quantize_turbo(data),
            QuantizationMode::Hybrid => self.quantize_hybrid(data),
        };
        let mut result = QuantizationResult::new(quantized, mode, self.bits);
        result.mse = mse;
        result.unit_norm_preserved = self.check_unit_norm(&result.data);
        result.constraints_satisfied =
            result.unit_norm_preserved || mode != QuantizationMode::Polar;
        result
    }

    /// Resolves the concrete mode for `data`.
    ///
    /// A non-hybrid quantizer always returns its own mode. In hybrid mode:
    /// unit-norm input selects `Polar`, input where more than half of the
    /// values are near zero selects `Ternary`, everything else `Turbo`.
    fn select_mode(&self, data: &[f64]) -> QuantizationMode {
        if self.mode != QuantizationMode::Hybrid {
            return self.mode;
        }
        if data.is_empty() {
            // Avoid computing a 0/0 sparsity; empty input has nothing to
            // normalize or sparsify, so use the range-based mode.
            return QuantizationMode::Turbo;
        }
        if self.check_unit_norm(data) {
            return QuantizationMode::Polar;
        }
        let near_zero = data
            .iter()
            .filter(|&&x| x.abs() < Self::SPARSE_MAGNITUDE)
            .count();
        let sparsity = near_zero as f64 / data.len() as f64;
        if sparsity > 0.5 {
            QuantizationMode::Ternary
        } else {
            QuantizationMode::Turbo
        }
    }

    /// Maps each value to -1, 0 or +1 and returns the result with its MSE
    /// against the input.
    ///
    /// Values whose magnitude is below 10% of the mean magnitude become 0.
    fn quantize_ternary(&self, data: &[f64]) -> (Vec<f64>, f64) {
        let mean_abs = data.iter().map(|x| x.abs()).sum::<f64>() / data.len().max(1) as f64;
        let threshold = mean_abs * 0.1;
        let quantized: Vec<f64> = data
            .iter()
            .map(|&x| {
                if x.abs() < threshold {
                    0.0
                } else if x > 0.0 {
                    1.0
                } else if x < 0.0 {
                    -1.0
                } else {
                    // Exact zeros reach this point when the threshold itself
                    // is zero (all-zero input); they must stay zero rather
                    // than fall through to -1. NaN also lands here.
                    0.0
                }
            })
            .collect();
        let mse = Self::mean_squared_error(data, &quantized);
        (quantized, mse)
    }

    /// Normalizes `data`, snaps consecutive coordinate pairs to the nearest
    /// candidate angle, rescales to unit length and returns the result with
    /// its MSE against the normalized input.
    ///
    /// Inputs with fewer than two elements are returned unchanged. The
    /// output always has the same length as the input.
    ///
    /// NOTE: only the angle of each pair is kept; the pair's radius is
    /// discarded, so every pair (including an all-zero one) ends up with the
    /// same magnitude after rescaling.
    fn quantize_polar(&self, data: &[f64]) -> (Vec<f64>, f64) {
        let n = data.len();
        if n < 2 {
            return (data.to_vec(), 0.0);
        }
        let norm = Self::l2_norm(data);
        if norm < Self::DEGENERATE_EPS {
            // A (near-)zero vector has no direction. Fall back to the first
            // basis vector so the result is unit-norm and keeps the input
            // length; no meaningful error can be reported, hence 0.0.
            let mut fallback = vec![0.0; n];
            fallback[0] = 1.0;
            return (fallback, 0.0);
        }
        let normalized: Vec<f64> = data.iter().map(|&x| x / norm).collect();

        let mut quantized = Vec::with_capacity(n);
        let mut pairs = normalized.chunks_exact(2);
        for pair in pairs.by_ref() {
            let (q0, q1) = self.quantize_polar_pair(pair[0], pair[1]);
            quantized.push(q0);
            quantized.push(q1);
        }
        if let Some(&last) = pairs.remainder().first() {
            // The ratio table is non-negative: snap the magnitude and put
            // the sign back so a negative trailing value is not zeroed.
            quantized.push(self.snap_to_pythagorean(last.abs()).copysign(last));
        }

        let q_norm = Self::l2_norm(&quantized);
        if q_norm > Self::DEGENERATE_EPS {
            for q in &mut quantized {
                *q /= q_norm;
            }
        }
        let mse = Self::mean_squared_error(&normalized, &quantized);
        (quantized, mse)
    }

    /// Replaces the point `(x, y)` by the unit vector at the nearest
    /// candidate angle.
    fn quantize_polar_pair(&self, x: f64, y: f64) -> (f64, f64) {
        let snapped_angle = self.snap_angle_to_pythagorean(y.atan2(x));
        (snapped_angle.cos(), snapped_angle.sin())
    }

    /// Returns the candidate angle closest to `angle` (radians), measured
    /// around the circle.
    ///
    /// Candidates are the four axis directions plus a set of first-quadrant
    /// angles and their mirror images in the other three quadrants. A NaN
    /// input is returned unchanged.
    fn snap_angle_to_pythagorean(&self, angle: f64) -> f64 {
        use std::f64::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, FRAC_PI_6, PI, TAU};

        let axis_angles = [0.0, FRAC_PI_2, PI, -FRAC_PI_2];
        let first_quadrant = [
            (4.0_f64 / 3.0).atan(),
            (3.0_f64 / 4.0).atan(),
            (12.0_f64 / 5.0).atan(),
            (5.0_f64 / 12.0).atan(),
            (15.0_f64 / 8.0).atan(),
            (8.0_f64 / 15.0).atan(),
            FRAC_PI_4,
            FRAC_PI_6,
            FRAC_PI_3,
        ];

        let mut best = angle;
        let mut min_diff = f64::MAX;
        let mut consider = |candidate: f64| {
            // Shortest way around the circle, so that e.g. -3.1 rad is
            // recognised as being next to PI.
            let raw = (angle - candidate).abs() % TAU;
            let diff = raw.min(TAU - raw);
            if diff < min_diff {
                min_diff = diff;
                best = candidate;
            }
        };

        for &axis in axis_angles.iter() {
            consider(axis);
        }
        for &theta in first_quadrant.iter() {
            // Mirror each first-quadrant angle into quadrants II, IV and III
            // so inputs with negative coordinates keep their signs.
            consider(theta);
            consider(PI - theta);
            consider(-theta);
            consider(theta - PI);
        }
        best
    }

    /// Quantizes `data` onto `2^bits` uniform levels across its value range,
    /// snaps each level's fraction of the range to the ratio table, and
    /// returns the result with its MSE against the input.
    ///
    /// Empty input yields an empty result; a (near-)constant input is
    /// returned as its minimum value repeated.
    fn quantize_turbo(&self, data: &[f64]) -> (Vec<f64>, f64) {
        let n = data.len();
        if n == 0 {
            return (Vec::new(), 0.0);
        }
        let min_val = data.iter().copied().fold(f64::INFINITY, f64::min);
        let max_val = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let range = max_val - min_val;
        if range < Self::DEGENERATE_EPS {
            return (vec![min_val; n], 0.0);
        }
        // 2^bits is computed in floating point: an integer `1 << bits`
        // overflows for wide settings of the public `bits` field. `bits` is
        // also floored at 1 so there is always more than one level.
        let top_level = 2.0_f64.powi(i32::from(self.bits.max(1))) - 1.0;
        let quantized: Vec<f64> = data
            .iter()
            .map(|&x| {
                let level = ((x - min_val) / range * top_level).round();
                let snapped = self.snap_to_pythagorean(level / top_level);
                min_val + snapped * range
            })
            .collect();
        let mse = Self::mean_squared_error(data, &quantized);
        (quantized, mse)
    }

    /// Quantizes with whatever mode `select_mode` picks for `data`, using
    /// turbo if the selection is still `Hybrid`.
    fn quantize_hybrid(&self, data: &[f64]) -> (Vec<f64>, f64) {
        match self.select_mode(data) {
            QuantizationMode::Ternary => self.quantize_ternary(data),
            QuantizationMode::Polar => self.quantize_polar(data),
            QuantizationMode::Turbo | QuantizationMode::Hybrid => self.quantize_turbo(data),
        }
    }

    /// Returns the entry of the fixed ratio table closest to `value`.
    ///
    /// The table only holds values in `[0, 1]`, so any negative input snaps
    /// to 0. A NaN input is returned unchanged.
    pub fn snap_to_pythagorean(&self, value: f64) -> f64 {
        let mut best = value;
        let mut min_dist = f64::MAX;
        for &ratio in Self::PYTHAGOREAN_RATIOS.iter() {
            let dist = (value - ratio).abs();
            if dist < min_dist {
                min_dist = dist;
                best = ratio;
            }
        }
        best
    }

    /// Searches all Pythagorean triples `(a, b, c)` with `c <= max_denominator`
    /// for the leg/hypotenuse ratio closest to `value`.
    ///
    /// Returns `(ratio, numerator, denominator)`. The search bound is the
    /// `max_denominator` argument, not the struct field. Only positive ratios
    /// are candidates. When no triple exists within the bound (any bound
    /// below 5), the result is `(value, value.round() as i64, 1)`.
    pub fn snap_to_lattice(&self, value: f64, max_denominator: usize) -> (f64, i64, u64) {
        let mut best = (value, value.round() as i64, 1u64);
        let mut best_err = f64::MAX;
        for c in 2..=max_denominator {
            let c_sq = c * c;
            for a in 1..c {
                let b_sq = c_sq - a * a;
                let b = (b_sq as f64).sqrt() as usize;
                if b * b != b_sq {
                    // `a` and `c` do not complete an integer triple.
                    continue;
                }
                for &leg in [a, b].iter() {
                    let ratio = leg as f64 / c as f64;
                    let err = (value - ratio).abs();
                    if err < best_err {
                        best_err = err;
                        best = (ratio, leg as i64, c as u64);
                    }
                }
            }
        }
        best
    }

    /// Returns `true` when the L2 norm of `data` is within 0.01 of 1.0.
    fn check_unit_norm(&self, data: &[f64]) -> bool {
        (Self::l2_norm(data) - 1.0).abs() < Self::UNIT_NORM_TOLERANCE
    }

    /// Quantizes every vector in `vectors` independently, in order.
    pub fn quantize_batch(&self, vectors: &[Vec<f64>]) -> Vec<QuantizationResult> {
        vectors.iter().map(|v| self.quantize(v)).collect()
    }

    /// Euclidean (L2) norm of `values`.
    fn l2_norm(values: &[f64]) -> f64 {
        values.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Mean of the squared element-wise differences; 0.0 for empty input.
    fn mean_squared_error(original: &[f64], quantized: &[f64]) -> f64 {
        let squared_error: f64 = original
            .iter()
            .zip(quantized)
            .map(|(o, q)| (o - q).powi(2))
            .sum();
        squared_error / original.len().max(1) as f64
    }
}
impl Default for PythagoreanQuantizer {
fn default() -> Self {
Self::hybrid()
}
}
/// A fraction `num / den`, stored as given (no reduction, no validation).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Rational {
    /// Signed numerator.
    pub num: i64,
    /// Unsigned denominator; zero is not rejected.
    pub den: u64,
}

impl Rational {
    /// Stores `num` and `den` as given.
    pub fn new(num: i64, den: u64) -> Self {
        Self { num, den }
    }

    /// Floating-point value of the fraction (non-finite when `den` is 0).
    pub fn to_f64(&self) -> f64 {
        self.num as f64 / self.den as f64
    }

    /// Returns `true` when `|num|` and `den` are a leg and the hypotenuse of
    /// an integer right triangle, i.e. `den^2 - num^2` is a perfect square.
    ///
    /// Returns `false` for a zero denominator or when `|num| > den`. The
    /// degenerate cases `num == 0` and `|num| == den` count as Pythagorean.
    pub fn is_pythagorean(&self) -> bool {
        // Widen to u128: squaring a u64 denominator above 2^32 overflows u64.
        let a = u128::from(self.num.unsigned_abs());
        let c = u128::from(self.den);
        if c == 0 || a > c {
            return false;
        }
        let b_sq = c * c - a * a;
        let b = Self::isqrt(b_sq);
        b * b == b_sq
    }

    /// Exact floor of the square root of `n`.
    fn isqrt(n: u128) -> u128 {
        // Start from the floating-point estimate, then correct its rounding
        // error with integer arithmetic.
        let mut root = (n as f64).sqrt() as u128;
        while root.checked_mul(root).map_or(true, |sq| sq > n) {
            root -= 1;
        }
        while (root + 1).checked_mul(root + 1).map_or(false, |sq| sq <= n) {
            root += 1;
        }
        root
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed modes are reported back unchanged; polar output has unit norm.
    #[test]
    fn test_quantization_modes() {
        let input: [f64; 4] = [0.6, 0.8, 0.0, 0.0];

        let ternary = PythagoreanQuantizer::new(QuantizationMode::Ternary, 1).quantize(&input);
        assert_eq!(ternary.mode, QuantizationMode::Ternary);

        let polar = PythagoreanQuantizer::new(QuantizationMode::Polar, 8).quantize(&input);
        assert!(polar.check_unit_norm(0.1));

        let turbo = PythagoreanQuantizer::new(QuantizationMode::Turbo, 4).quantize(&input);
        assert_eq!(turbo.mode, QuantizationMode::Turbo);
    }

    /// Polar quantization of (near) unit vectors yields unit-norm output.
    #[test]
    fn test_polar_unit_norm() {
        let quantizer = PythagoreanQuantizer::for_embeddings();
        let inputs: [[f64; 4]; 4] = [
            [1.0, 0.0, 0.0, 0.0],
            [0.707, 0.707, 0.0, 0.0],
            [0.6, 0.8, 0.0, 0.0],
            [0.5, 0.5, 0.5, 0.5],
        ];
        for input in inputs.iter() {
            let outcome = quantizer.quantize(input);
            assert!(outcome.check_unit_norm(0.1), "Failed for vector {:?}", input);
        }
    }

    /// Ternary output only contains -1, 0 and +1.
    #[test]
    fn test_ternary_quantization() {
        let input: [f64; 4] = [-0.8, -0.1, 0.1, 0.9];
        let outcome = PythagoreanQuantizer::for_llm().quantize(&input);
        for &value in outcome.data.iter() {
            assert!(value == -1.0 || value == 0.0 || value == 1.0);
        }
    }

    /// Table ratios snap to (approximately) themselves.
    #[test]
    fn test_snap_to_pythagorean() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);
        for &target in [0.6_f64, 0.8].iter() {
            let snapped = quantizer.snap_to_pythagorean(target);
            assert!((snapped - target).abs() < 0.01);
        }
    }

    /// 0.6 is found as the reduced ratio 3/5.
    #[test]
    fn test_snap_to_lattice() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);
        let (value, numerator, denominator) = quantizer.snap_to_lattice(0.6, 20);
        assert_eq!(numerator, 3);
        assert_eq!(denominator, 5);
        assert!((value - 0.6).abs() < 0.01);
    }

    /// Hybrid selection: unit norm -> polar, mostly tiny -> ternary, else turbo.
    #[test]
    fn test_hybrid_mode_selection() {
        let quantizer = PythagoreanQuantizer::hybrid();

        let unit_norm: [f64; 2] = [0.6, 0.8];
        assert_eq!(quantizer.select_mode(&unit_norm), QuantizationMode::Polar);

        let mostly_tiny: [f64; 6] = [0.01, 0.02, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(quantizer.select_mode(&mostly_tiny), QuantizationMode::Ternary);

        let dense: [f64; 4] = [0.5, 0.6, 0.7, 0.8];
        assert_eq!(quantizer.select_mode(&dense), QuantizationMode::Turbo);
    }

    /// 3/5 and 4/5 are Pythagorean ratios, 1/3 is not.
    #[test]
    fn test_rational() {
        for &(numerator, expected) in [(3_i64, 0.6_f64), (4, 0.8)].iter() {
            let ratio = Rational::new(numerator, 5);
            assert!((ratio.to_f64() - expected).abs() < 1e-10);
            assert!(ratio.is_pythagorean());
        }
        assert!(!Rational::new(1, 3).is_pythagorean());
    }

    /// Batch quantization returns one unit-norm result per input vector.
    #[test]
    fn test_batch_quantization() {
        let quantizer = PythagoreanQuantizer::for_embeddings();
        let batch = vec![vec![0.6, 0.8], vec![1.0, 0.0], vec![0.707, 0.707]];
        let outcomes = quantizer.quantize_batch(&batch);
        assert_eq!(outcomes.len(), 3);
        for outcome in outcomes.iter() {
            assert!(outcome.check_unit_norm(0.1));
        }
    }

    /// Empty input produces empty output.
    #[test]
    fn test_empty_input() {
        let outcome = PythagoreanQuantizer::hybrid().quantize(&[]);
        assert!(outcome.data.is_empty());
    }

    /// A single-element input keeps its length.
    #[test]
    fn test_single_element() {
        let quantizer = PythagoreanQuantizer::new(QuantizationMode::Polar, 8);
        let outcome = quantizer.quantize(&[1.0]);
        assert_eq!(outcome.data.len(), 1);
    }
}