ternlang_compress/
quantize.rs1use ternlang_core::trit::Trit;
15
16#[derive(Debug, Clone)]
18pub struct PerLayerQuant {
19 pub trits: Vec<Trit>,
21 pub scale: f32,
24 pub shape: Vec<usize>,
26 pub name: String,
28 pub sparsity: f64,
30}
31
32impl PerLayerQuant {
33 pub fn quantize(weights: &[f32], name: impl Into<String>, shape: Vec<usize>) -> Self {
38 assert_eq!(
39 shape.iter().product::<usize>(),
40 weights.len(),
41 "shape product must match weights length"
42 );
43
44 let scale = {
46 let sum: f32 = weights.iter().map(|w| w.abs()).sum();
47 if sum == 0.0 { 1.0 } else { sum / weights.len() as f32 }
48 };
49
50 let threshold = 0.5_f32;
52 let trits: Vec<Trit> = weights.iter().map(|&w| {
53 let n = w / scale;
54 if n > threshold {
55 Trit::Affirm
56 } else if n < -threshold {
57 Trit::Reject
58 } else {
59 Trit::Tend
60 }
61 }).collect();
62
63 let zeros = trits.iter().filter(|&&t| t == Trit::Tend).count();
64 let sparsity = zeros as f64 / trits.len() as f64;
65
66 Self { trits, scale, shape, name: name.into(), sparsity }
67 }
68
69 pub fn reconstruct(&self) -> Vec<f32> {
72 self.trits.iter().map(|&t| match t {
73 Trit::Affirm => self.scale,
74 Trit::Reject => -self.scale,
75 Trit::Tend => 0.0,
76 }).collect()
77 }
78
79 pub fn mse(&self, original: &[f32]) -> f32 {
81 assert_eq!(original.len(), self.trits.len());
82 let recon = self.reconstruct();
83 let sum_sq: f32 = original.iter().zip(recon.iter())
84 .map(|(o, r)| (o - r).powi(2))
85 .sum();
86 sum_sq / original.len() as f32
87 }
88}
89
/// Parallel counterpart of `quantize_layers`, available with the `parallel` feature.
///
/// Each `(name, weights, shape)` tuple is quantized independently on the rayon thread
/// pool via `PerLayerQuant::quantize`. Collecting an indexed parallel iterator into a
/// `Vec` keeps the results aligned with the input order.
///
/// # Panics
///
/// Panics if any layer's shape product does not match its weight count.
#[cfg(feature = "parallel")]
pub fn quantize_layers_parallel(
    layers: Vec<(String, Vec<f32>, Vec<usize>)>,
) -> Vec<PerLayerQuant> {
    use rayon::prelude::*;

    let quantize_one = |(name, weights, shape): (String, Vec<f32>, Vec<usize>)| {
        PerLayerQuant::quantize(&weights, name, shape)
    };

    layers.into_par_iter().map(quantize_one).collect::<Vec<_>>()
}

101pub fn quantize_layers(
103 layers: Vec<(String, Vec<f32>, Vec<usize>)>,
104) -> Vec<PerLayerQuant> {
105 layers.into_iter()
106 .map(|(name, weights, shape)| PerLayerQuant::quantize(&weights, name, shape))
107 .collect()
108}