mod accumulator;
pub use accumulator::{OptimizerStats, TernaryGradientAccumulator, TernaryOptimizerWrapper};
use candle_core::{DType, Tensor};
use rand::Rng;
use trit_vsa::{PackedTritVec, Trit};
use crate::error::Result;
/// Stochastically quantizes a tensor to ternary values in `{-1, 0, +1}`.
///
/// Each element `v` is normalized by `scale = max(|x|)` and then rounded to
/// `sign(v)` with probability `|v| / scale`, and to `0` otherwise. The draw
/// comes from the thread-local RNG, so results differ between calls.
///
/// NaN elements are ignored when computing the scale and always quantize to `0`
/// (every comparison against NaN is false).
///
/// # Arguments
///
/// * `x` - Input tensor of any dtype convertible to `F32`.
/// * `threshold` - Not used by the stochastic scheme; accepted so the signature
///   matches [`ternary_quantize_deterministic`].
///
/// # Returns
///
/// `(quantized, scale)`, where `quantized` is an `F32` tensor with the shape and
/// device of `x`, and `scale` is the maximum absolute value found in `x`. For an
/// empty or all-zero input the result is an `F32` zero tensor and a scale of `0.0`.
///
/// # Errors
///
/// Propagates any error from the dtype conversion, flattening, host copy or
/// tensor construction.
pub fn ternary_quantize_stochastic(x: &Tensor, threshold: Option<f32>) -> Result<(Tensor, f32)> {
    // The rounding probability depends only on the max-abs scale, so no threshold
    // is derived here (doing so would cost an extra pass over the data).
    let _ = threshold;

    let x_f32 = x.to_dtype(DType::F32)?;
    let data: Vec<f32> = x_f32.flatten_all()?.to_vec1()?;

    let scale = data.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    if scale == 0.0 {
        // Built from the F32 view so the output dtype matches the non-zero path.
        return Ok((x_f32.zeros_like()?, 0.0));
    }

    let mut rng = rand::thread_rng();
    let quantized: Vec<f32> = data
        .iter()
        .map(|&v| {
            let normalized = v / scale;
            let rand_val: f32 = rng.gen();
            if rand_val < normalized.abs() {
                if normalized > 0.0 {
                    1.0
                } else {
                    -1.0
                }
            } else {
                0.0
            }
        })
        .collect();

    let result = Tensor::from_vec(quantized, x.shape(), x.device())?;
    Ok((result, scale))
}
/// Deterministically quantizes a tensor to ternary values in `{-1, 0, +1}`.
///
/// Elements strictly greater than `threshold` become `+1`, elements strictly
/// less than `-threshold` become `-1`, and everything else becomes `0`. The
/// comparison is made on the raw (un-normalized) values.
///
/// # Arguments
///
/// * `x` - Input tensor of any dtype convertible to `F32`.
/// * `threshold` - Cut-off magnitude. When `None`, the mean absolute value of
///   `x` is used. Note that if `x` contains NaN, that derived threshold is NaN
///   and every element quantizes to `0`.
///
/// # Returns
///
/// `(quantized, scale)`, where `quantized` is an `F32` tensor with the shape and
/// device of `x`, and `scale` is the maximum absolute value found in `x`. For an
/// empty or all-zero input the result is an `F32` zero tensor and a scale of `0.0`.
///
/// # Errors
///
/// Propagates any error from the dtype conversion, flattening, host copy or
/// tensor construction.
// The element count is converted `usize -> f32` only to compute an average.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
pub fn ternary_quantize_deterministic(x: &Tensor, threshold: Option<f32>) -> Result<(Tensor, f32)> {
    let x_f32 = x.to_dtype(DType::F32)?;
    let data: Vec<f32> = x_f32.flatten_all()?.to_vec1()?;

    let scale = data.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    if scale == 0.0 {
        // Built from the F32 view so the output dtype matches the non-zero path.
        return Ok((x_f32.zeros_like()?, 0.0));
    }

    // `data` is non-empty here: an empty tensor has scale 0 and returned above,
    // so the mean below never divides by zero.
    let threshold = threshold
        .unwrap_or_else(|| data.iter().map(|v| v.abs()).sum::<f32>() / data.len() as f32);

    let quantized: Vec<f32> = data
        .iter()
        .map(|&v| {
            if v > threshold {
                1.0
            } else if v < -threshold {
                -1.0
            } else {
                0.0
            }
        })
        .collect();

    let result = Tensor::from_vec(quantized, x.shape(), x.device())?;
    Ok((result, scale))
}
/// Stochastically quantizes a tensor straight into a [`PackedTritVec`].
///
/// Each element is divided by the largest absolute value in `x`; one uniform
/// sample is drawn per element, and the element becomes `Trit::P` / `Trit::N`
/// (by sign) when the sample is below the normalized magnitude, `Trit::Z`
/// otherwise.
///
/// # Arguments
///
/// * `x` - Input tensor of any dtype convertible to `F32`; it is flattened, so
///   the packed vector has one trit per element.
/// * `_threshold` - Unused.
///
/// # Returns
///
/// `(packed, scale)` where `scale` is the maximum absolute value in `x`. If that
/// maximum is zero, a freshly created vector of the same length and `0.0` are
/// returned.
///
/// # Errors
///
/// Propagates any error from the dtype conversion, flattening or host copy.
#[allow(clippy::cast_possible_truncation)]
pub fn ternary_quantize_stochastic_packed(
    x: &Tensor,
    _threshold: Option<f32>,
) -> Result<(PackedTritVec, f32)> {
    let values: Vec<f32> = x.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;

    let max_abs = values.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
    let mut trits = PackedTritVec::new(values.len());
    if max_abs == 0.0 {
        return Ok((trits, 0.0));
    }

    let mut rng = rand::thread_rng();
    for (idx, value) in values.iter().copied().enumerate() {
        let ratio = value / max_abs;
        // Exactly one draw per element, in element order.
        let draw: f32 = rng.gen();
        let trit = match (draw < ratio.abs(), ratio > 0.0) {
            (false, _) => Trit::Z,
            (true, true) => Trit::P,
            (true, false) => Trit::N,
        };
        trits.set(idx, trit);
    }

    Ok((trits, max_abs))
}
/// Estimates the fraction of memory saved by ternary storage versus 32-bit storage.
///
/// The model is 32 bits per parameter for the full-precision baseline against
/// 2 bits per parameter plus 32 bits of overhead per tensor (presumably the
/// per-tensor `f32` scale) for the ternary form.
///
/// # Arguments
///
/// * `param_count` - Total number of parameters.
/// * `num_tensors` - Number of tensors those parameters are spread across.
///
/// # Returns
///
/// `1 - ternary_bits / full_bits`. The value can be negative when the per-tensor
/// overhead outweighs the savings (e.g. one parameter in one tensor). Returns
/// `0.0` when `param_count` is zero, since there is no baseline to compare to.
#[must_use]
// Counts are converted to `f64` for the ratio and the ratio is narrowed to `f32`;
// neither loss matters for an estimate.
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
pub fn calculate_memory_savings(param_count: usize, num_tensors: usize) -> f32 {
    if param_count == 0 {
        // Avoids 0/0 (NaN) and x/0 (-inf) from the ratio below.
        return 0.0;
    }

    // Floating-point arithmetic: `param_count * 32` can overflow `usize`.
    let params = param_count as f64;
    let full_bits = params * 32.0;
    let ternary_bits = params * 2.0 + num_tensors as f64 * 32.0;

    (1.0 - ternary_bits / full_bits) as f32
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Draws 1000 standard-normal samples on the CPU.
    fn random_input() -> Tensor {
        Tensor::randn(0.0f32, 1.0, 1000, &Device::Cpu).unwrap()
    }

    /// Copies a tensor's elements into a flat host vector.
    fn to_flat_vec(t: &Tensor) -> Vec<f32> {
        t.flatten_all().unwrap().to_vec1().unwrap()
    }

    /// Fails unless every element of `t` is exactly -1, 0 or +1.
    fn assert_all_ternary(t: &Tensor) {
        for v in to_flat_vec(t) {
            assert!(v == -1.0 || v == 0.0 || v == 1.0, "Unexpected value: {v}");
        }
    }

    #[test]
    fn test_ternary_quantize_stochastic() {
        let input = random_input();
        let (quantized, scale) = ternary_quantize_stochastic(&input, None).unwrap();
        assert!(scale > 0.0);
        assert_all_ternary(&quantized);
    }

    #[test]
    fn test_ternary_quantize_deterministic() {
        let input = random_input();
        let (quantized, scale) = ternary_quantize_deterministic(&input, None).unwrap();
        assert!(scale > 0.0);
        assert_all_ternary(&quantized);
    }

    #[test]
    fn test_ternary_quantize_zeros() {
        let input = Tensor::zeros(100, DType::F32, &Device::Cpu).unwrap();
        let (quantized, scale) = ternary_quantize_deterministic(&input, None).unwrap();
        assert_eq!(scale, 0.0);
        assert!(to_flat_vec(&quantized).iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_memory_savings() {
        let savings = calculate_memory_savings(1_000_000, 10);
        assert!(savings > 0.9, "Expected > 90% savings, got {:.2}%", savings * 100.0);
    }

    #[test]
    fn test_stochastic_packed() {
        let input = random_input();
        let (packed, scale) = ternary_quantize_stochastic_packed(&input, None).unwrap();
        assert_eq!(packed.len(), 1000);
        assert!(scale > 0.0);
    }
}