Skip to main content

ternlang_ml/
qat.rs

1// SPDX-License-Identifier: LicenseRef-Ternlang-Commercial
2// Ternlang — RFI-IRFOS Ternary Intelligence Stack
3// Phase 12B: Quantization-Aware Training (QAT) with Straight-Through Estimator (STE)
4//
5// STE rule:  d(quantize(w))/dw = 1  when |w| <= clip_threshold, else 0.
6// Latent f32 weights are maintained throughout training; quantized weights are
7// used only during the forward pass. The backward pass treats the step function
8// as identity within the clip window — this is the "straight-through" estimate.
9
10use crate::{TernaryMLP, TritMatrix, bitnet_threshold, quantize};
11use ternlang_core::trit::Trit;
12
13// ─── Configuration ────────────────────────────────────────────────────────────
14
/// Hyper-parameters for quantization-aware training with a straight-through
/// estimator (STE).
#[derive(Debug, Clone, PartialEq)]
pub struct QatConfig {
    /// Learning rate for SGD on latent weights.
    pub lr: f32,
    /// Number of full passes over the training set.
    pub epochs: usize,
    /// STE clip threshold: the gradient of a latent weight is zeroed whenever
    /// `|w_latent| > clip_threshold`.
    ///
    /// With 1.0 every latent weight inside `[-1, 1]` receives gradient; a
    /// weight that drifts outside that window has its gradient masked to zero.
    /// A value such as 0.5 restricts updates to the "undecided" region near
    /// the quantization boundary.
    pub clip_threshold: f32,
    /// Print loss every N epochs (0 = silent).
    pub log_every: usize,
}
27
28impl Default for QatConfig {
29    fn default() -> Self {
30        Self {
31            lr: 0.01,
32            epochs: 100,
33            clip_threshold: 1.0,
34            log_every: 10,
35        }
36    }
37}
38
39// ─── Training result ─────────────────────────────────────────────────────────
40
/// Summary returned by [`SteTrainer::train`].
#[derive(Debug, Clone, PartialEq)]
pub struct QatResult {
    /// Mean per-sample MSE loss of the last epoch; 0.0 when no epoch ran.
    pub final_loss: f32,
    /// Number of epochs reported by the training loop.
    pub epochs_run: usize,
    /// Fraction of latent weights (both layers combined) that are in the
    /// `|w| <= clip_threshold` zone when training ends.
    pub active_gradient_fraction: f32,
}
47
48// ─── STE Trainer ─────────────────────────────────────────────────────────────
49
50/// Maintains latent f32 shadow weights for a 2-layer MLP.
51/// During each training step:
52///   1. Quantize latent weights → ternary {-1, 0, +1}
53///   2. Forward pass (f32 arithmetic on quantized weights)
54///   3. MSE loss + backprop through STE
55///   4. SGD update on latent weights
56pub struct SteTrainer {
57    pub w1_latent: Vec<f32>,   // shape [in_features × hidden_size]
58    pub w2_latent: Vec<f32>,   // shape [hidden_size × out_features]
59    pub in_features:  usize,
60    pub hidden_size:  usize,
61    pub out_features: usize,
62    pub config: QatConfig,
63}
64
65impl SteTrainer {
66    /// Initialise from an existing TernaryMLP's quantized weights.
67    /// The ternary {-1,0,+1} values become the initial latent floats.
68    pub fn from_mlp(mlp: &TernaryMLP, config: QatConfig) -> Self {
69        let w1_latent = mlp.w1.to_i8_vec().iter().map(|&v| v as f32).collect();
70        let w2_latent = mlp.w2.to_i8_vec().iter().map(|&v| v as f32).collect();
71        Self {
72            w1_latent,
73            w2_latent,
74            in_features:  mlp.in_features,
75            hidden_size:  mlp.hidden_size,
76            out_features: mlp.out_features,
77            config,
78        }
79    }
80
81    /// Initialise from raw f32 weights (quantization happens at first step).
82    pub fn from_f32(
83        in_features: usize,
84        hidden_size: usize,
85        out_features: usize,
86        w1_f32: Vec<f32>,
87        w2_f32: Vec<f32>,
88        config: QatConfig,
89    ) -> Self {
90        assert_eq!(w1_f32.len(), in_features * hidden_size);
91        assert_eq!(w2_f32.len(), hidden_size * out_features);
92        Self { w1_latent: w1_f32, w2_latent: w2_f32, in_features, hidden_size, out_features, config }
93    }
94
95    // ── Helpers ───────────────────────────────────────────────────────────────
96
97    fn quantize_latent(latent: &[f32]) -> Vec<f32> {
98        let tau = bitnet_threshold(latent);
99        quantize(latent, tau).iter().map(|&t| match t {
100            Trit::Affirm =>  1.0,
101            Trit::Reject => -1.0,
102            Trit::Tend   =>  0.0,
103        }).collect()
104    }
105
106    /// STE mask: 1.0 where |w_latent| <= clip, else 0.0.
107    fn ste_mask(latent: &[f32], clip: f32) -> Vec<f32> {
108        latent.iter().map(|&w| if w.abs() <= clip { 1.0 } else { 0.0 }).collect()
109    }
110
111    /// Matrix multiply A [m×k] × B [k×n] → C [m×n] (f32, row-major).
112    fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
113        let mut c = vec![0.0f32; m * n];
114        for i in 0..m {
115            for j in 0..n {
116                let mut acc = 0.0f32;
117                for p in 0..k {
118                    acc += a[i * k + p] * b[p * n + j];
119                }
120                c[i * n + j] = acc;
121            }
122        }
123        c
124    }
125
126    /// Transpose a matrix [rows×cols] → [cols×rows].
127    fn transpose(a: &[f32], rows: usize, cols: usize) -> Vec<f32> {
128        let mut out = vec![0.0f32; rows * cols];
129        for r in 0..rows {
130            for c in 0..cols {
131                out[c * rows + r] = a[r * cols + c];
132            }
133        }
134        out
135    }
136
137    // ── Forward + backward ────────────────────────────────────────────────────
138
139    /// One SGD step on a single sample.
140    ///
141    /// `input`  — flat f32 row vector, length `in_features`
142    /// `target` — flat f32 row vector, length `out_features`
143    ///
144    /// Returns the MSE loss for this sample.
145    pub fn train_step(&mut self, input: &[f32], target: &[f32]) -> f32 {
146        let (inf, hs, outf) = (self.in_features, self.hidden_size, self.out_features);
147
148        // Quantize latent → float {-1,0,+1}
149        let w1_q = Self::quantize_latent(&self.w1_latent);
150        let w2_q = Self::quantize_latent(&self.w2_latent);
151
152        // ── Forward ──────────────────────────────────────────────────────────
153        // hidden = input [1×in] × w1_q [in×hs] → [1×hs]
154        let hidden = Self::matmul(input, &w1_q, 1, inf, hs);
155
156        // Trit activation (sign) on hidden
157        let hidden_act: Vec<f32> = hidden.iter().map(|&h| {
158            if h > 0.0 { 1.0 } else if h < 0.0 { -1.0 } else { 0.0 }
159        }).collect();
160
161        // output = hidden_act [1×hs] × w2_q [hs×out] → [1×out]
162        let output = Self::matmul(&hidden_act, &w2_q, 1, hs, outf);
163
164        // ── Loss (MSE) ───────────────────────────────────────────────────────
165        let loss: f32 = output.iter().zip(target.iter())
166            .map(|(o, t)| (o - t).powi(2))
167            .sum::<f32>() / outf as f32;
168
169        // ── Backward (STE) ───────────────────────────────────────────────────
170        // d_loss / d_output = 2*(output - target) / out_features
171        let d_output: Vec<f32> = output.iter().zip(target.iter())
172            .map(|(o, t)| 2.0 * (o - t) / outf as f32)
173            .collect();
174
175        // Layer 2 weight gradient:
176        //   d_loss/d_w2_q  = hidden_act^T [hs×1] × d_output [1×out] → [hs×out]
177        //   d_loss/d_w2_latent = d_loss/d_w2_q ⊙ ste_mask(w2_latent)
178        let hidden_act_t = Self::transpose(&hidden_act, 1, hs);
179        let d_w2_q = Self::matmul(&hidden_act_t, &d_output, hs, 1, outf);
180        let ste2 = Self::ste_mask(&self.w2_latent, self.config.clip_threshold);
181        let d_w2: Vec<f32> = d_w2_q.iter().zip(ste2.iter()).map(|(g, m)| g * m).collect();
182
183        // Propagate gradient back through w2 to hidden_act:
184        //   d_loss/d_hidden_act = d_output [1×out] × w2_q^T [out×hs] → [1×hs]
185        let w2_q_t = Self::transpose(&w2_q, hs, outf);
186        let d_hidden_act = Self::matmul(&d_output, &w2_q_t, 1, outf, hs);
187
188        // Trit activation derivative (straight-through: 1 for non-zero, 0 for 0)
189        // We approximate: pass gradient through for hidden != 0.
190        let d_hidden: Vec<f32> = d_hidden_act.iter().zip(hidden.iter())
191            .map(|(g, h)| if *h != 0.0 { *g } else { 0.0 })
192            .collect();
193
194        // Layer 1 weight gradient:
195        //   d_loss/d_w1_q  = input^T [in×1] × d_hidden [1×hs] → [in×hs]
196        //   d_loss/d_w1_latent = d_loss/d_w1_q ⊙ ste_mask(w1_latent)
197        let input_t = Self::transpose(input, 1, inf);
198        let d_w1_q = Self::matmul(&input_t, &d_hidden, inf, 1, hs);
199        let ste1 = Self::ste_mask(&self.w1_latent, self.config.clip_threshold);
200        let d_w1: Vec<f32> = d_w1_q.iter().zip(ste1.iter()).map(|(g, m)| g * m).collect();
201
202        // ── SGD update ───────────────────────────────────────────────────────
203        let lr = self.config.lr;
204        for (w, g) in self.w1_latent.iter_mut().zip(d_w1.iter()) {
205            *w -= lr * g;
206        }
207        for (w, g) in self.w2_latent.iter_mut().zip(d_w2.iter()) {
208            *w -= lr * g;
209        }
210
211        loss
212    }
213
214    /// Run the full training loop over a dataset.
215    ///
216    /// `samples` — slice of (input, target) pairs
217    pub fn train(&mut self, samples: &[(Vec<f32>, Vec<f32>)]) -> QatResult {
218        let mut final_loss = 0.0f32;
219
220        for epoch in 0..self.config.epochs {
221            let mut epoch_loss = 0.0f32;
222            for (input, target) in samples.iter() {
223                epoch_loss += self.train_step(input, target);
224            }
225            epoch_loss /= samples.len() as f32;
226            final_loss = epoch_loss;
227
228            if self.config.log_every > 0 && (epoch + 1) % self.config.log_every == 0 {
229                println!("[QAT/STE] epoch {:>4}/{} | loss: {:.6}", epoch + 1, self.config.epochs, epoch_loss);
230            }
231        }
232
233        let active = self.w1_latent.iter().chain(self.w2_latent.iter())
234            .filter(|&&w| w.abs() <= self.config.clip_threshold)
235            .count();
236        let total = self.w1_latent.len() + self.w2_latent.len();
237        let active_gradient_fraction = active as f32 / total as f32;
238
239        QatResult {
240            final_loss,
241            epochs_run: self.config.epochs,
242            active_gradient_fraction,
243        }
244    }
245
246    /// Finalise training: quantize latent weights and return a TernaryMLP.
247    pub fn finalize(&self) -> TernaryMLP {
248        let tau1 = bitnet_threshold(&self.w1_latent);
249        let tau2 = bitnet_threshold(&self.w2_latent);
250        let w1 = TritMatrix::from_f32(self.in_features, self.hidden_size, &self.w1_latent, tau1);
251        let w2 = TritMatrix::from_f32(self.hidden_size, self.out_features, &self.w2_latent, tau2);
252        TernaryMLP::new(w1, w2)
253    }
254}
255
256// ─── Tests ────────────────────────────────────────────────────────────────────
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random fixture generator (64-bit LCG), returning
    /// `n` values derived from `seed`.
    ///
    /// NOTE(review): `s >> 33` keeps only 31 bits, so after dividing by
    /// `u32::MAX` the scaled result never becomes positive — every value lies
    /// in [-1.0, 0.0] rather than being spread over [-1, 1]. Left unchanged so
    /// the fixtures (and the loss assertion that depends on them) stay stable.
    fn lcg(n: usize, seed: u64) -> Vec<f32> {
        let mut s = seed;
        (0..n).map(|_| {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((s >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0
        }).collect()
    }

    /// Mean per-sample MSE of the quantized network over `samples`, computed
    /// with the trainer's current latent weights and without updating them.
    fn mean_quantized_loss(trainer: &SteTrainer, samples: &[(Vec<f32>, Vec<f32>)]) -> f32 {
        let (inf, hs, outf) = (trainer.in_features, trainer.hidden_size, trainer.out_features);
        // The weights do not change here, so quantize once for all samples.
        let w1_q = SteTrainer::quantize_latent(&trainer.w1_latent);
        let w2_q = SteTrainer::quantize_latent(&trainer.w2_latent);

        let mut total = 0.0f32;
        for (input, target) in samples {
            let hidden = SteTrainer::matmul(input, &w1_q, 1, inf, hs);
            let hidden_act: Vec<f32> = hidden.iter().map(|&h|
                if h > 0.0 { 1.0 } else if h < 0.0 { -1.0 } else { 0.0 }
            ).collect();
            let output = SteTrainer::matmul(&hidden_act, &w2_q, 1, hs, outf);
            total += output.iter().zip(target.iter()).map(|(o, t)| (o - t).powi(2)).sum::<f32>() / outf as f32;
        }
        total / samples.len() as f32
    }

    #[test]
    fn ste_trainer_reduces_loss() {
        let (inf, hs, outf) = (8, 16, 4);
        let w1 = lcg(inf * hs, 0xdead);
        let w2 = lcg(hs * outf, 0xbeef);
        let config = QatConfig { lr: 0.05, epochs: 50, clip_threshold: 1.0, log_every: 0 };
        let mut trainer = SteTrainer::from_f32(inf, hs, outf, w1, w2, config);

        // Every sample shares one fixed ternary target: output[0] = +1,
        // output[1] = -1, remaining outputs = 0. Only the inputs vary.
        let samples: Vec<(Vec<f32>, Vec<f32>)> = (0..8).map(|i| {
            let input = lcg(inf, i as u64 * 17 + 3);
            let target = vec![1.0, -1.0, 0.0, 0.0];
            (input, target)
        }).collect();

        let initial_loss = mean_quantized_loss(&trainer, &samples);

        let result = trainer.train(&samples);
        println!("[test] initial_loss={:.4} final_loss={:.4}", initial_loss, result.final_loss);
        assert!(result.final_loss <= initial_loss, "QAT training must not increase loss");
        assert!(result.active_gradient_fraction > 0.0, "Some gradients must flow through STE");
    }

    #[test]
    fn finalize_produces_valid_mlp() {
        let (inf, hs, outf) = (4, 8, 2);
        let w1 = lcg(inf * hs, 0xfeed);
        let w2 = lcg(hs * outf, 0xcafe);
        let config = QatConfig { lr: 0.01, epochs: 5, clip_threshold: 1.0, log_every: 0 };
        let mut trainer = SteTrainer::from_f32(inf, hs, outf, w1, w2, config);

        let samples = vec![
            (lcg(inf, 1), vec![1.0, -1.0]),
            (lcg(inf, 2), vec![-1.0, 1.0]),
        ];
        trainer.train(&samples);

        let mlp = trainer.finalize();
        assert_eq!(mlp.in_features, inf);
        assert_eq!(mlp.hidden_size, hs);
        assert_eq!(mlp.out_features, outf);

        // Forward pass smoke test
        let input = TritMatrix::from_f32(1, inf, &lcg(inf, 99), 0.3);
        let (output, _, _) = mlp.forward(&input);
        assert_eq!(output.rows, 1);
        assert_eq!(output.cols, outf);
    }
}
331}