1use crate::TernaryMLP;
15
/// Numerically stable softmax over `logits`.
///
/// The largest logit is subtracted before exponentiating so large inputs do
/// not overflow, and the normaliser is floored at `1e-9` so the division can
/// never be by zero. An empty input yields an empty output.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights: Vec<f32> = logits.iter().map(|&z| (z - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    let denom = total.max(1e-9);
    for w in weights.iter_mut() {
        *w /= denom;
    }
    weights
}
24
/// Cross-entropy (in nats) of the distribution `probs` against `target_class`.
///
/// The target probability is clamped to at least `1e-9` before taking the
/// natural log, so a zero probability gives a large but finite loss
/// (about 20.7 nats) rather than infinity.
///
/// # Panics
///
/// Panics if `target_class` is not a valid index into `probs`.
fn cross_entropy(probs: &[f32], target_class: usize) -> f32 {
    let p_target = probs[target_class].max(1e-9);
    -p_target.ln()
}
28
/// Shannon entropy (in nats) of the distribution `probs`.
///
/// Entries that are not strictly positive (zero, negative or NaN) contribute
/// nothing, following the `0 * ln 0 = 0` convention.
fn entropy(probs: &[f32]) -> f32 {
    let term = |p: f32| if p > 0.0 { -p * p.ln() } else { 0.0 };
    probs.iter().copied().map(term).sum()
}
32
/// Summary statistics describing how well a model predicts a labelled test set,
/// as produced by [`PerplexityEvaluator::evaluate`].
///
/// `PartialEq` is derived so reports can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct PerplexityReport {
    /// `exp(mean_cross_entropy)`; lower means the model is more confident in
    /// the correct class.
    pub pseudo_perplexity: f32,
    /// Mean per-sample cross-entropy against the target class, in nats.
    pub mean_cross_entropy: f32,
    /// Fraction of samples (0.0 to 1.0) whose predicted class equals the target class.
    pub top1_accuracy: f32,
    /// Mean entropy of the model's predicted distribution, in nats.
    pub mean_output_entropy: f32,
    /// Number of samples the statistics were averaged over.
    pub n_samples: usize,
}
43
44impl PerplexityReport {
45 pub fn print(&self, label: &str) {
46 println!(" [{}]", label);
47 println!(" Pseudo-PPL: {:.4}", self.pseudo_perplexity);
48 println!(" Mean cross-entropy: {:.4}", self.mean_cross_entropy);
49 println!(" Top-1 accuracy: {:.1}%", self.top1_accuracy * 100.0);
50 println!(" Output entropy: {:.4} nats", self.mean_output_entropy);
51 println!(" Samples: {}", self.n_samples);
52 }
53}
54
/// Stateless entry point for computing a [`PerplexityReport`] from a model
/// and a labelled test set.
///
/// It carries no data, so the common value traits are derived for free.
#[derive(Debug, Clone, Copy, Default)]
pub struct PerplexityEvaluator;
58
59impl PerplexityEvaluator {
60 pub fn evaluate(model: &TernaryMLP, test_samples: &[(Vec<f32>, Vec<f32>)]) -> PerplexityReport {
65 assert!(!test_samples.is_empty(), "need at least one test sample");
66
67 let mut total_ce = 0.0f32;
68 let mut total_entropy = 0.0f32;
69 let mut correct = 0usize;
70
71 for (input, target) in test_samples {
72 let logits = model.forward_logits(input);
73 let probs = softmax(&logits);
74
75 let target_class = argmax(target);
76 let pred_class = argmax(&logits);
77
78 total_ce += cross_entropy(&probs, target_class);
79 total_entropy += entropy(&probs);
80 if pred_class == target_class { correct += 1; }
81 }
82
83 let n = test_samples.len() as f32;
84 let mean_ce = total_ce / n;
85
86 PerplexityReport {
87 pseudo_perplexity: mean_ce.exp(),
88 mean_cross_entropy: mean_ce,
89 top1_accuracy: correct as f32 / n,
90 mean_output_entropy: total_entropy / n,
91 n_samples: test_samples.len(),
92 }
93 }
94}
95
/// Index of the largest value in `v`.
///
/// Ties resolve to the *first* maximal index, matching the usual argmax
/// convention. `Iterator::max_by` would instead return the last one, which
/// biases predictions toward higher class indices when logits tie (common
/// with ternary weights). NaN entries are ignored. Returns `0` when `v` is
/// empty or contains only NaN.
fn argmax(v: &[f32]) -> usize {
    v.iter()
        .enumerate()
        .filter(|(_, x)| !x.is_nan())
        .fold(None, |best: Option<(usize, f32)>, (i, &x)| match best {
            // Strict `>` keeps the earliest index among equal maxima.
            Some((_, b)) if x <= b => best,
            _ => Some((i, x)),
        })
        .map_or(0, |(i, _)| i)
}
102
103pub struct ComparisonResult {
106 pub before: PerplexityReport,
107 pub after: PerplexityReport,
108}
109
110impl ComparisonResult {
111 pub fn print(&self) {
112 println!("\n╔══════════════════════════════════════════════════════════════╗");
113 println!( "║ Phase 12C — QAT Perplexity Comparison (RFI-IRFOS TIS) ║");
114 println!( "╠══════════════════════════════════════════════════════════════╣");
115 self.before.print("Pre-QAT (baseline)");
116 println!( " ──────────────────────────────────────────────────────────");
117 self.after.print("Post-QAT (STE fine-tuned)");
118 println!( "╠══════════════════════════════════════════════════════════════╣");
119
120 let ppl_delta = self.after.pseudo_perplexity - self.before.pseudo_perplexity;
121 let acc_delta = self.after.top1_accuracy - self.before.top1_accuracy;
122 let improved = ppl_delta < 0.0 || acc_delta > 0.0;
123
124 println!(" PPL delta: {:+.4} ({})",
125 ppl_delta, if ppl_delta < 0.0 { "IMPROVED" } else { "regressed" });
126 println!(" Acc delta: {:+.1}% ({})",
127 acc_delta * 100.0, if acc_delta > 0.0 { "IMPROVED" } else if acc_delta < 0.0 { "regressed" } else { "unchanged" });
128 println!(" Verdict: {}", if improved { "[OK] QAT improved model quality" } else { "[WARN] No improvement — check lr/epochs" });
129 println!("╚══════════════════════════════════════════════════════════════╝");
130 }
131}
132
133pub fn compare_perplexity(
135 pre_qat: &TernaryMLP,
136 post_qat: &TernaryMLP,
137 test_samples: &[(Vec<f32>, Vec<f32>)],
138) -> ComparisonResult {
139 ComparisonResult {
140 before: PerplexityEvaluator::evaluate(pre_qat, test_samples),
141 after: PerplexityEvaluator::evaluate(post_qat, test_samples),
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TernaryMLP;

    /// Deterministic pseudo-random values in `[-1, 1)` from a 64-bit LCG.
    ///
    /// The top 24 bits of each state are used. They are the best-mixed bits of
    /// an LCG and fit exactly in an `f32` mantissa, so dividing by `2^24`
    /// yields a uniform value in `[0, 1)` before rescaling.
    ///
    /// The previous version divided a 31-bit value (`s >> 33`) by `u32::MAX`,
    /// which only reached `[0, 0.5)`. Every generated weight and input was
    /// therefore negative.
    fn lcg(n: usize, seed: u64) -> Vec<f32> {
        let mut s = seed;
        (0..n)
            .map(|_| {
                s = s
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((s >> 40) as f32) / ((1u64 << 24) as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Builds a `TernaryMLP` from LCG-generated float weights.
    fn make_mlp(inf: usize, hs: usize, outf: usize, seed: u64) -> TernaryMLP {
        let w1 = lcg(inf * hs, seed);
        let w2 = lcg(hs * outf, seed + 1);
        TernaryMLP::from_f32(inf, hs, outf, &w1, &w2)
    }

    #[test]
    fn lcg_spans_both_signs() {
        let v = lcg(64, 42);
        assert!(v.iter().all(|&x| (-1.0..1.0).contains(&x)));
        assert!(v.iter().any(|&x| x > 0.0));
        assert!(v.iter().any(|&x| x < 0.0));
    }

    #[test]
    fn perplexity_report_is_finite() {
        let mlp = make_mlp(8, 16, 4, 0xabc);
        let samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..10)
            .map(|i| {
                let input = lcg(8, i * 7 + 1);
                let mut target = vec![0.0f32; 4];
                target[i as usize % 4] = 1.0;
                (input, target)
            })
            .collect();

        let report = PerplexityEvaluator::evaluate(&mlp, &samples);
        assert!(report.pseudo_perplexity.is_finite());
        assert!(report.pseudo_perplexity > 0.0);
        assert!(report.mean_cross_entropy >= 0.0);
        assert!(report.top1_accuracy >= 0.0 && report.top1_accuracy <= 1.0);
    }

    #[test]
    fn qat_comparison_runs() {
        use crate::qat::{QatConfig, SteTrainer};

        let (inf, hs, outf) = (8, 16, 4);
        let w1 = lcg(inf * hs, 0x1234);
        let w2 = lcg(hs * outf, 0x5678);

        let pre_qat = TernaryMLP::from_f32(inf, hs, outf, &w1, &w2);

        let config = QatConfig { lr: 0.05, epochs: 30, clip_threshold: 1.0, log_every: 0 };
        let mut trainer = SteTrainer::from_f32(inf, hs, outf, w1.clone(), w2.clone(), config);

        let train_samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..8)
            .map(|i| {
                let input = lcg(inf, i * 13 + 5);
                let mut target = vec![-1.0f32; outf];
                target[i as usize % outf] = 1.0;
                (input, target)
            })
            .collect();

        trainer.train(&train_samples);
        let post_qat = trainer.finalize();

        let test_samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..8)
            .map(|i| {
                let input = lcg(inf, i * 19 + 3);
                let mut target = vec![-1.0f32; outf];
                target[i as usize % outf] = 1.0;
                (input, target)
            })
            .collect();

        let result = compare_perplexity(&pre_qat, &post_qat, &test_samples);
        assert!(result.before.pseudo_perplexity.is_finite());
        assert!(result.after.pseudo_perplexity.is_finite());
    }
}