Skip to main content

ternlang_compress/
quantize.rs

// Per-layer post-training ternary quantization (PTQ).
//
// Algorithm (BitNet b1.58 scheme):
//   1. Compute per-layer scale α = mean(|W|) (1.0 for an all-zero tensor)
//   2. Normalise: W_norm = W / α
//   3. Map to trits (round to nearest, clip to [-1, 1]; exact ties at ±0.5 go to 0):
//        W_norm >  0.5  →  +1 (Affirm)
//        W_norm < -0.5  →  -1 (Reject)
//        otherwise      →   0 (Tend)
//
// The scale α is stored alongside the ternary weights so the original
// magnitude can be approximately reconstructed: W ≈ α × W_t
13
14use ternlang_core::trit::Trit;
15
16/// The result of quantizing one weight tensor.
17#[derive(Debug, Clone)]
18pub struct PerLayerQuant {
19    /// Ternary weights in row-major order.
20    pub trits: Vec<Trit>,
21    /// Per-layer scale factor α = mean(|W|).
22    /// Multiply by this to approximate the original magnitudes.
23    pub scale: f32,
24    /// Shape: [rows, cols] for 2D tensors; arbitrary for higher ranks.
25    pub shape: Vec<usize>,
26    /// Human-readable layer name (e.g. "model.layers.0.self_attn.q_proj.weight").
27    pub name: String,
28    /// Fraction of trits that are Tend (zero) — 0.0 … 1.0.
29    pub sparsity: f64,
30}
31
32impl PerLayerQuant {
33    /// Quantize `weights` (flat, row-major) to balanced ternary.
34    ///
35    /// `name`  — layer identifier for diagnostics.
36    /// `shape` — logical shape of the tensor (product must equal `weights.len()`).
37    pub fn quantize(weights: &[f32], name: impl Into<String>, shape: Vec<usize>) -> Self {
38        assert_eq!(
39            shape.iter().product::<usize>(),
40            weights.len(),
41            "shape product must match weights length"
42        );
43
44        // Step 1: scale
45        let scale = {
46            let sum: f32 = weights.iter().map(|w| w.abs()).sum();
47            if sum == 0.0 { 1.0 } else { sum / weights.len() as f32 }
48        };
49
50        // Step 2 + 3: normalise and threshold
51        let threshold = 0.5_f32;
52        let trits: Vec<Trit> = weights.iter().map(|&w| {
53            let n = w / scale;
54            if n > threshold {
55                Trit::Affirm
56            } else if n < -threshold {
57                Trit::Reject
58            } else {
59                Trit::Tend
60            }
61        }).collect();
62
63        let zeros = trits.iter().filter(|&&t| t == Trit::Tend).count();
64        let sparsity = zeros as f64 / trits.len() as f64;
65
66        Self { trits, scale, shape, name: name.into(), sparsity }
67    }
68
69    /// Approximate reconstruction: W_approx = α × W_t (as f32).
70    /// Useful for measuring quantization error.
71    pub fn reconstruct(&self) -> Vec<f32> {
72        self.trits.iter().map(|&t| match t {
73            Trit::Affirm =>  self.scale,
74            Trit::Reject => -self.scale,
75            Trit::Tend   =>  0.0,
76        }).collect()
77    }
78
79    /// Mean squared error between original weights and ternary reconstruction.
80    pub fn mse(&self, original: &[f32]) -> f32 {
81        assert_eq!(original.len(), self.trits.len());
82        let recon = self.reconstruct();
83        let sum_sq: f32 = original.iter().zip(recon.iter())
84            .map(|(o, r)| (o - r).powi(2))
85            .sum();
86        sum_sq / original.len() as f32
87    }
88}
89
/// Quantize all layers of a model in parallel (requires `parallel` feature).
///
/// Each `(name, weights, shape)` tuple is quantized independently on the rayon
/// thread pool. The returned vector is in the same order as `layers`.
#[cfg(feature = "parallel")]
pub fn quantize_layers_parallel(
    layers: Vec<(String, Vec<f32>, Vec<usize>)>,
) -> Vec<PerLayerQuant> {
    use rayon::prelude::*;

    // Layers share no state, so any worker may take any of them. Collecting an
    // indexed parallel iterator into a Vec keeps the input order.
    let quantize_one = |(name, weights, shape): (String, Vec<f32>, Vec<usize>)| {
        PerLayerQuant::quantize(&weights, name, shape)
    };
    layers.into_par_iter().map(quantize_one).collect()
}
100
101/// Single-threaded fallback.
102pub fn quantize_layers(
103    layers: Vec<(String, Vec<f32>, Vec<usize>)>,
104) -> Vec<PerLayerQuant> {
105    layers.into_iter()
106        .map(|(name, weights, shape)| PerLayerQuant::quantize(&weights, name, shape))
107        .collect()
108}