Skip to main content

ternlang_ml/
perplexity.rs

1// SPDX-License-Identifier: LicenseRef-Ternlang-Commercial
2// Ternlang — RFI-IRFOS Ternary Intelligence Stack
3// Phase 12C: Perplexity Validation
4//
5// Pseudo-perplexity for ternary MLPs:
6//   1. Treat each test sample as a classification problem.
7//   2. The "correct class" is argmax(target).
8//   3. Compute softmax over f32 logits from forward_logits().
9//   4. CE = -log P(correct_class) per sample.
10//   5. PPL = exp(mean CE over all samples).
11//
12// Additionally reports top-1 accuracy and average ternary output entropy.
13
14use crate::TernaryMLP;
15
16// ─── Softmax / entropy helpers ───────────────────────────────────────────────
17
/// Numerically stable softmax over `logits`.
///
/// The largest logit is subtracted from every entry before exponentiating,
/// so the biggest exponent is `exp(0)`. The normalising sum is floored at
/// `1e-9` before dividing. An empty slice yields an empty vector.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut weights = Vec::with_capacity(logits.len());
    for &logit in logits {
        weights.push((logit - peak).exp());
    }

    // Floor the denominator once; it is the same for every element.
    let denom = weights.iter().sum::<f32>().max(1e-9);
    weights.into_iter().map(|w| w / denom).collect()
}
24
/// Negative natural-log likelihood of `target_class` under `probs`.
///
/// The selected probability is floored at `1e-9` before taking the
/// logarithm, which bounds the result from above.
///
/// # Panics
///
/// Panics if `target_class` is not a valid index into `probs`.
fn cross_entropy(probs: &[f32], target_class: usize) -> f32 {
    let p_target = probs[target_class];
    let floored = p_target.max(1e-9);
    -floored.ln()
}
28
/// Shannon entropy of `probs`, in nats.
///
/// Entries that are not strictly positive (zero, negative, or NaN)
/// contribute nothing to the sum.
fn entropy(probs: &[f32]) -> f32 {
    // Contribution of a single probability mass to the entropy.
    let term = |p: f32| -> f32 {
        if p > 0.0 {
            -p * p.ln()
        } else {
            0.0
        }
    };

    probs.iter().copied().map(term).sum()
}
32
33// ─── Evaluation results ───────────────────────────────────────────────────────
34
/// Aggregate metrics from one evaluation pass over a set of test samples.
#[derive(Clone, Debug)]
pub struct PerplexityReport {
    /// Exponential of `mean_cross_entropy`.
    pub pseudo_perplexity: f32,
    /// Cross-entropy of the target class averaged over all samples (natural log).
    pub mean_cross_entropy: f32,
    /// Fraction of samples whose predicted class equals the target class.
    pub top1_accuracy: f32,
    /// Entropy of the softmax output distribution averaged over all samples, in nats.
    pub mean_output_entropy: f32,
    /// Number of samples the averages were taken over.
    pub n_samples: usize,
}
43
44impl PerplexityReport {
45    pub fn print(&self, label: &str) {
46        println!("  [{}]", label);
47        println!("    Pseudo-PPL:         {:.4}", self.pseudo_perplexity);
48        println!("    Mean cross-entropy: {:.4}", self.mean_cross_entropy);
49        println!("    Top-1 accuracy:     {:.1}%", self.top1_accuracy * 100.0);
50        println!("    Output entropy:     {:.4} nats", self.mean_output_entropy);
51        println!("    Samples:            {}", self.n_samples);
52    }
53}
54
55// ─── Evaluator ───────────────────────────────────────────────────────────────
56
57pub struct PerplexityEvaluator;
58
59impl PerplexityEvaluator {
60    /// Evaluate `model` on `test_samples`.
61    ///
62    /// Each sample is `(input: Vec<f32>, target: Vec<f32>)`.
63    /// The correct class for cross-entropy is `argmax(target)`.
64    pub fn evaluate(model: &TernaryMLP, test_samples: &[(Vec<f32>, Vec<f32>)]) -> PerplexityReport {
65        assert!(!test_samples.is_empty(), "need at least one test sample");
66
67        let mut total_ce = 0.0f32;
68        let mut total_entropy = 0.0f32;
69        let mut correct = 0usize;
70
71        for (input, target) in test_samples {
72            let logits = model.forward_logits(input);
73            let probs  = softmax(&logits);
74
75            let target_class = argmax(target);
76            let pred_class   = argmax(&logits);
77
78            total_ce      += cross_entropy(&probs, target_class);
79            total_entropy += entropy(&probs);
80            if pred_class == target_class { correct += 1; }
81        }
82
83        let n = test_samples.len() as f32;
84        let mean_ce = total_ce / n;
85
86        PerplexityReport {
87            pseudo_perplexity:   mean_ce.exp(),
88            mean_cross_entropy:  mean_ce,
89            top1_accuracy:       correct as f32 / n,
90            mean_output_entropy: total_entropy / n,
91            n_samples:           test_samples.len(),
92        }
93    }
94}
95
/// Index of the largest non-NaN element of `v`.
///
/// NaN entries are skipped. Treating NaN as "equal" to every other value
/// makes the comparison non-transitive, so a NaN in the middle of the slice
/// could cause a smaller element to be reported as the maximum.
///
/// When several elements share the maximum value, the index of the last one
/// is returned. Returns `0` if `v` is empty or contains only NaN.
fn argmax(v: &[f32]) -> usize {
    v.iter()
        .enumerate()
        .filter(|(_, x)| !x.is_nan())
        // With NaN filtered out `partial_cmp` always succeeds; the fallback is
        // kept only to avoid a panic path.
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
102
103// ─── Comparison utility ───────────────────────────────────────────────────────
104
105pub struct ComparisonResult {
106    pub before: PerplexityReport,
107    pub after:  PerplexityReport,
108}
109
110impl ComparisonResult {
111    pub fn print(&self) {
112        println!("\n╔══════════════════════════════════════════════════════════════╗");
113        println!(  "║     Phase 12C — QAT Perplexity Comparison (RFI-IRFOS TIS)  ║");
114        println!(  "╠══════════════════════════════════════════════════════════════╣");
115        self.before.print("Pre-QAT  (baseline)");
116        println!(  "  ──────────────────────────────────────────────────────────");
117        self.after.print("Post-QAT (STE fine-tuned)");
118        println!(  "╠══════════════════════════════════════════════════════════════╣");
119
120        let ppl_delta  = self.after.pseudo_perplexity - self.before.pseudo_perplexity;
121        let acc_delta  = self.after.top1_accuracy     - self.before.top1_accuracy;
122        let improved   = ppl_delta < 0.0 || acc_delta > 0.0;
123
124        println!("  PPL delta:  {:+.4}  ({})",
125            ppl_delta, if ppl_delta < 0.0 { "IMPROVED" } else { "regressed" });
126        println!("  Acc delta:  {:+.1}%  ({})",
127            acc_delta * 100.0, if acc_delta > 0.0 { "IMPROVED" } else if acc_delta < 0.0 { "regressed" } else { "unchanged" });
128        println!("  Verdict:    {}", if improved { "[OK] QAT improved model quality" } else { "[WARN] No improvement — check lr/epochs" });
129        println!("╚══════════════════════════════════════════════════════════════╝");
130    }
131}
132
133/// Run a full pre/post QAT comparison and return the result.
134pub fn compare_perplexity(
135    pre_qat:  &TernaryMLP,
136    post_qat: &TernaryMLP,
137    test_samples: &[(Vec<f32>, Vec<f32>)],
138) -> ComparisonResult {
139    ComparisonResult {
140        before: PerplexityEvaluator::evaluate(pre_qat,  test_samples),
141        after:  PerplexityEvaluator::evaluate(post_qat, test_samples),
142    }
143}
144
145// ─── Tests ────────────────────────────────────────────────────────────────────
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TernaryMLP;

    /// Deterministic pseudo-random values in `[-1, 1)` from a 64-bit LCG.
    fn lcg(n: usize, seed: u64) -> Vec<f32> {
        let mut s = seed;
        (0..n).map(|_| {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            // Keep the top 24 bits: that fits an f32 mantissa exactly, so
            // dividing by 2^24 gives [0, 1) and the affine map gives [-1, 1).
            // (A 31-bit value divided by u32::MAX would only reach [0, 0.5),
            // which makes every generated value negative.)
            ((s >> 40) as f32) / ((1u32 << 24) as f32) * 2.0 - 1.0
        }).collect()
    }

    /// Builds a small MLP whose two weight matrices come from `lcg`.
    fn make_mlp(inf: usize, hs: usize, outf: usize, seed: u64) -> TernaryMLP {
        let w1 = lcg(inf * hs, seed);
        let w2 = lcg(hs * outf, seed + 1);
        TernaryMLP::from_f32(inf, hs, outf, &w1, &w2)
    }

    #[test]
    fn lcg_covers_both_signs_within_range() {
        let values = lcg(256, 1);
        assert_eq!(values.len(), 256);
        assert!(values.iter().all(|&v| (-1.0..1.0).contains(&v)));
        assert!(values.iter().any(|&v| v > 0.0));
        assert!(values.iter().any(|&v| v < 0.0));
    }

    #[test]
    fn perplexity_report_is_finite() {
        let mlp = make_mlp(8, 16, 4, 0xabc);
        // One-hot targets cycling through the four classes.
        let samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..10).map(|i| {
            let input  = lcg(8, i * 7 + 1);
            let mut target = vec![0.0f32; 4];
            target[i as usize % 4] = 1.0;
            (input, target)
        }).collect();

        let report = PerplexityEvaluator::evaluate(&mlp, &samples);
        assert!(report.pseudo_perplexity.is_finite());
        assert!(report.pseudo_perplexity > 0.0);
        assert!(report.mean_cross_entropy >= 0.0);
        assert!(report.top1_accuracy >= 0.0 && report.top1_accuracy <= 1.0);
        assert!(report.mean_output_entropy >= 0.0);
        assert_eq!(report.n_samples, samples.len());
    }

    #[test]
    fn qat_comparison_runs() {
        use crate::qat::{QatConfig, SteTrainer};

        let (inf, hs, outf) = (8, 16, 4);
        let w1 = lcg(inf * hs, 0x1234);
        let w2 = lcg(hs * outf, 0x5678);

        let pre_qat = TernaryMLP::from_f32(inf, hs, outf, &w1, &w2);

        let config = QatConfig { lr: 0.05, epochs: 30, clip_threshold: 1.0, log_every: 0 };
        let mut trainer = SteTrainer::from_f32(inf, hs, outf, w1.clone(), w2.clone(), config);

        // Targets are -1 everywhere except +1 at the sample's class.
        let train_samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..8).map(|i| {
            let input  = lcg(inf, i * 13 + 5);
            let mut target = vec![-1.0f32; outf];
            target[i as usize % outf] = 1.0;
            (input, target)
        }).collect();

        trainer.train(&train_samples);
        let post_qat = trainer.finalize();

        // Held-out inputs drawn from different seeds than the training set.
        let test_samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..8).map(|i| {
            let input  = lcg(inf, i * 19 + 3);
            let mut target = vec![-1.0f32; outf];
            target[i as usize % outf] = 1.0;
            (input, target)
        }).collect();

        let result = compare_perplexity(&pre_qat, &post_qat, &test_samples);
        // Both reports must be well-formed
        assert!(result.before.pseudo_perplexity.is_finite());
        assert!(result.after.pseudo_perplexity.is_finite());
        assert_eq!(result.before.n_samples, test_samples.len());
        assert_eq!(result.after.n_samples, test_samples.len());
    }
}