use ternlang_core::trit::Trit;
/// Ternary quantization of a single weight tensor.
///
/// Every weight is mapped to one of three [`Trit`] states that share a single
/// per-layer `scale`, so [`PerLayerQuant::reconstruct`] approximates the
/// original weights with values drawn from `{-scale, 0.0, +scale}`.
#[derive(Debug, Clone)]
pub struct PerLayerQuant {
    /// Quantized weights, in the same flat order as the input slice.
    pub trits: Vec<Trit>,
    /// Per-layer scale. As produced by `quantize`, this is the mean absolute
    /// weight, or 1.0 when every weight is zero or the layer is empty.
    pub scale: f32,
    /// Logical tensor shape; `quantize` checks that its product equals
    /// `trits.len()`.
    pub shape: Vec<usize>,
    /// Layer name as supplied by the caller.
    pub name: String,
    /// Fraction of weights quantized to `Trit::Tend` (0.0 for an empty layer).
    pub sparsity: f64,
}

impl PerLayerQuant {
    /// Threshold used by [`PerLayerQuant::quantize`], expressed as a fraction
    /// of `scale`.
    pub const DEFAULT_THRESHOLD: f32 = 0.5;

    /// Quantizes `weights` to trits using absmean scaling and
    /// [`PerLayerQuant::DEFAULT_THRESHOLD`].
    ///
    /// # Panics
    /// Panics if the product of `shape` does not equal `weights.len()`.
    pub fn quantize(weights: &[f32], name: impl Into<String>, shape: Vec<usize>) -> Self {
        Self::quantize_with_threshold(weights, name, shape, Self::DEFAULT_THRESHOLD)
    }

    /// Quantizes `weights` to trits using absmean scaling and a custom
    /// `threshold`.
    ///
    /// Each weight `w` is normalized to `n = w / scale`. Weights with
    /// `n > threshold` become `Trit::Affirm`, weights with `n < -threshold`
    /// become `Trit::Reject`, and all others become `Trit::Tend`.
    ///
    /// `threshold` is expected to be finite and non-negative. NaN weights
    /// propagate into `scale`; non-finite inputs are not otherwise handled.
    ///
    /// # Panics
    /// Panics if the product of `shape` does not equal `weights.len()`.
    pub fn quantize_with_threshold(
        weights: &[f32],
        name: impl Into<String>,
        shape: Vec<usize>,
        threshold: f32,
    ) -> Self {
        assert_eq!(
            shape.iter().product::<usize>(),
            weights.len(),
            "shape product must match weights length"
        );
        let scale = Self::absmean_scale(weights);
        let trits: Vec<Trit> = weights
            .iter()
            .map(|&w| {
                let n = w / scale;
                if n > threshold {
                    Trit::Affirm
                } else if n < -threshold {
                    Trit::Reject
                } else {
                    Trit::Tend
                }
            })
            .collect();
        let zeros = trits.iter().filter(|&&t| t == Trit::Tend).count();
        // An empty layer has nothing to be sparse about; avoid 0/0 = NaN.
        let sparsity = if trits.is_empty() {
            0.0
        } else {
            zeros as f64 / trits.len() as f64
        };
        Self { trits, scale, shape, name: name.into(), sparsity }
    }

    /// Mean absolute value of `weights`, or 1.0 when the sum is zero
    /// (all-zero or empty input). The 1.0 fallback keeps the later
    /// `w / scale` division well-defined.
    fn absmean_scale(weights: &[f32]) -> f32 {
        // Accumulate in f64: summing very many f32 magnitudes in f32 loses precision.
        let sum: f64 = weights.iter().map(|w| f64::from(w.abs())).sum();
        if sum == 0.0 {
            1.0
        } else {
            (sum / weights.len() as f64) as f32
        }
    }

    /// Maps a single trit to its dequantized value.
    fn dequantize(&self, t: Trit) -> f32 {
        match t {
            Trit::Affirm => self.scale,
            Trit::Reject => -self.scale,
            Trit::Tend => 0.0,
        }
    }

    /// Dequantizes every trit back to `f32`, giving one value from
    /// `{-scale, 0.0, +scale}` per weight.
    pub fn reconstruct(&self) -> Vec<f32> {
        self.trits.iter().map(|&t| self.dequantize(t)).collect()
    }

    /// Mean squared error between `original` and the dequantized weights.
    ///
    /// Returns 0.0 for an empty layer.
    ///
    /// # Panics
    /// Panics if `original.len()` differs from the number of quantized weights.
    pub fn mse(&self, original: &[f32]) -> f32 {
        assert_eq!(
            original.len(),
            self.trits.len(),
            "original length must match number of quantized weights"
        );
        if original.is_empty() {
            return 0.0;
        }
        // Dequantize on the fly instead of materializing `reconstruct()`, and
        // accumulate in f64 so large layers keep their precision.
        let sum_sq: f64 = original
            .iter()
            .zip(&self.trits)
            .map(|(&o, &t)| {
                let diff = f64::from(o) - f64::from(self.dequantize(t));
                diff * diff
            })
            .sum();
        (sum_sq / original.len() as f64) as f32
    }
}
/// Parallel counterpart of `quantize_layers`: quantizes every
/// `(name, weights, shape)` triple using rayon.
///
/// Results come back in input order, because rayon's `collect` into a `Vec`
/// preserves the source ordering. This function only exists when the
/// `parallel` feature is enabled.
///
/// # Panics
/// If any layer's shape product does not match its weight count.
#[cfg(feature = "parallel")]
pub fn quantize_layers_parallel(
    layers: Vec<(String, Vec<f32>, Vec<usize>)>,
) -> Vec<PerLayerQuant> {
    use rayon::prelude::*;
    let quantize_one =
        |(layer_name, layer_weights, layer_shape): (String, Vec<f32>, Vec<usize>)| {
            PerLayerQuant::quantize(&layer_weights, layer_name, layer_shape)
        };
    layers.into_par_iter().map(quantize_one).collect::<Vec<_>>()
}
/// Quantizes every `(name, weights, shape)` triple sequentially, returning
/// the results in the same order as `layers`.
///
/// # Panics
/// If any layer's shape product does not match its weight count.
pub fn quantize_layers(
    layers: Vec<(String, Vec<f32>, Vec<usize>)>,
) -> Vec<PerLayerQuant> {
    let mut quantized = Vec::with_capacity(layers.len());
    for (layer_name, layer_weights, layer_shape) in layers {
        quantized.push(PerLayerQuant::quantize(&layer_weights, layer_name, layer_shape));
    }
    quantized
}