Skip to main content

entrenar/distill/
loss.rs

1//! Distillation loss functions
2
3use ndarray::{Array2, Axis};
4
/// Knowledge-distillation objective that blends teacher soft targets with hard labels.
///
/// The *soft* term compares temperature-softened teacher and student distributions
/// through a KL divergence. The *hard* term is the ordinary cross-entropy of the
/// student against the ground-truth class indices.
///
/// # Formula
///
/// ```text
/// L = α * T² * KL(softmax(teacher/T) || softmax(student/T))
///   + (1-α) * CE(student, labels)
/// ```
///
/// Here `T` is the temperature and `α` is the weight given to the distillation term.
/// The `T²` scaling of the soft term follows Hinton et al. (2015).
///
/// # Example
///
/// ```
/// use entrenar::distill::DistillationLoss;
/// use ndarray::array;
///
/// let loss_fn = DistillationLoss::new(2.0, 0.7);
/// let student_logits = array![[2.0, 1.0, 0.5]];
/// let teacher_logits = array![[1.5, 1.2, 0.8]];
/// let labels = vec![0];
///
/// let loss = loss_fn.forward(&student_logits, &teacher_logits, &labels);
/// assert!(loss > 0.0);
/// ```
#[derive(Clone, Debug)]
pub struct DistillationLoss {
    /// Softening temperature `T` applied to both logit sets before the softmax.
    pub temperature: f32,
    /// Weight `α` of the distillation (soft) term; the hard term is weighted by `1 - α`.
    pub alpha: f32,
}
40
41impl DistillationLoss {
42    /// Create a new distillation loss function
43    ///
44    /// # Arguments
45    ///
46    /// * `temperature` - Temperature for softening distributions (typically 2.0-5.0)
47    /// * `alpha` - Weight for distillation vs hard loss (typically 0.5-0.9)
48    ///
49    /// # Panics
50    ///
51    /// Panics if temperature <= 0 or alpha not in [0, 1]
52    pub fn new(temperature: f32, alpha: f32) -> Self {
53        assert!(temperature > 0.0, "Temperature must be positive, got {temperature}");
54        assert!((0.0..=1.0).contains(&alpha), "Alpha must be in [0, 1], got {alpha}");
55
56        Self { temperature, alpha }
57    }
58
59    /// Compute the distillation loss
60    ///
61    /// # Arguments
62    ///
63    /// * `student_logits` - Logits from student model [batch_size, num_classes]
64    /// * `teacher_logits` - Logits from teacher model [batch_size, num_classes]
65    /// * `labels` - Ground truth labels `[batch_size]`
66    ///
67    /// # Returns
68    ///
69    /// Combined distillation and hard loss (scalar)
70    pub fn forward(
71        &self,
72        student_logits: &Array2<f32>,
73        teacher_logits: &Array2<f32>,
74        labels: &[usize],
75    ) -> f32 {
76        assert_eq!(
77            student_logits.shape(),
78            teacher_logits.shape(),
79            "Student and teacher logits must have same shape"
80        );
81        assert_eq!(student_logits.nrows(), labels.len(), "Batch size must match number of labels");
82
83        // Soft targets: KL divergence with temperature scaling
84        let kl_loss = self.kl_divergence_loss(student_logits, teacher_logits);
85
86        // Hard targets: Cross-entropy with ground truth
87        let ce_loss = self.cross_entropy_loss(student_logits, labels);
88
89        // Combine with temperature adjust factor (T²)
90        self.alpha * kl_loss * self.temperature * self.temperature + (1.0 - self.alpha) * ce_loss
91    }
92
93    /// Temperature-scaled KL divergence loss
94    ///
95    /// KL(teacher || student) where both distributions are softened by temperature
96    fn kl_divergence_loss(
97        &self,
98        student_logits: &Array2<f32>,
99        teacher_logits: &Array2<f32>,
100    ) -> f32 {
101        let student_soft = softmax_2d(&(student_logits / self.temperature));
102        let teacher_soft = softmax_2d(&(teacher_logits / self.temperature));
103
104        kl_divergence(&teacher_soft, &student_soft)
105    }
106
107    /// Standard cross-entropy loss with hard labels
108    fn cross_entropy_loss(&self, logits: &Array2<f32>, labels: &[usize]) -> f32 {
109        let probs = softmax_2d(logits);
110
111        let mut loss = 0.0;
112        for (i, &label) in labels.iter().enumerate() {
113            let prob = probs[[i, label]].max(1e-10); // Avoid log(0)
114            loss -= prob.max(f32::MIN_POSITIVE).ln();
115        }
116
117        loss / labels.len().max(1) as f32
118    }
119}
120
121/// Compute softmax along last axis for 2D array
122///
123/// softmax(x)_i = exp(x_i) / Σ exp(x_j)
124fn softmax_2d(x: &Array2<f32>) -> Array2<f32> {
125    let mut result = x.clone();
126
127    for mut row in result.axis_iter_mut(Axis(0)) {
128        // Subtract max for numerical stability
129        let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
130        row.mapv_inplace(|v| (v - max_val).exp());
131
132        // Normalize
133        let sum: f32 = row.sum();
134        row.mapv_inplace(|v| v / sum);
135    }
136
137    result
138}
139
140/// KL divergence between two probability distributions
141///
142/// KL(p || q) = Σ p_i * log(p_i / q_i)
143///
144/// Average over batch dimension.
145fn kl_divergence(p: &Array2<f32>, q: &Array2<f32>) -> f32 {
146    assert_eq!(p.shape(), q.shape());
147
148    if p.nrows() == 0 {
149        return 0.0;
150    }
151
152    let mut total_kl = 0.0;
153
154    for (p_row, q_row) in p.axis_iter(Axis(0)).zip(q.axis_iter(Axis(0))) {
155        let mut kl = 0.0;
156        for (&p_i, &q_i) in p_row.iter().zip(q_row.iter()) {
157            if p_i > 1e-10 {
158                // Avoid log(0)
159                kl += p_i * (p_i / q_i.max(1e-10)).ln();
160            }
161        }
162        total_kl += kl;
163    }
164
165    total_kl / p.nrows() as f32
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_distillation_loss_basic() {
        let loss_fn = DistillationLoss::new(2.0, 0.5);
        let student = array![[2.0, 1.0, 0.5]];
        let teacher = array![[1.5, 1.2, 0.8]];
        let labels = vec![0];

        let loss = loss_fn.forward(&student, &teacher, &labels);
        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let probs = softmax_2d(&x);

        for row in probs.axis_iter(Axis(0)) {
            let sum: f32 = row.sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_kl_divergence_zero_for_identical() {
        let p = array![[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]];
        let kl = kl_divergence(&p, &p);
        assert_relative_eq!(kl, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_kl_divergence_positive() {
        let p = array![[0.7, 0.2, 0.1]];
        let q = array![[0.4, 0.4, 0.2]];
        let kl = kl_divergence(&p, &q);
        assert!(kl > 0.0);
    }

    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_negative_temperature_panics() {
        DistillationLoss::new(-1.0, 0.5);
    }

    #[test]
    #[should_panic(expected = "Alpha must be in [0, 1]")]
    fn test_invalid_alpha_panics() {
        DistillationLoss::new(2.0, 1.5);
    }

    #[test]
    fn test_temperature_effect() {
        let student = array![[10.0, 1.0, 0.1]];
        let teacher = array![[5.0, 4.0, 3.0]];
        let labels = vec![0];

        let low_temp_loss = DistillationLoss::new(1.0, 1.0);
        let high_temp_loss = DistillationLoss::new(5.0, 1.0);

        let loss_low = low_temp_loss.forward(&student, &teacher, &labels);
        let loss_high = high_temp_loss.forward(&student, &teacher, &labels);

        // T enters both the softened targets and the T² scale factor, so changing
        // it must change the pure-distillation loss.
        assert!(
            loss_low != loss_high,
            "changing temperature should change the loss, got {loss_low} at T=1 and {loss_high} at T=5"
        );
    }

    #[test]
    fn test_alpha_balances_losses() {
        let student = array![[2.0, 1.0, 0.5]];
        let teacher = array![[1.5, 1.2, 0.8]];
        let labels = vec![0];

        // Pure distillation (α=1)
        let pure_distill = DistillationLoss::new(2.0, 1.0);
        let loss_distill = pure_distill.forward(&student, &teacher, &labels);

        // Pure hard loss (α=0)
        let pure_hard = DistillationLoss::new(2.0, 0.0);
        let loss_hard = pure_hard.forward(&student, &teacher, &labels);

        // Balanced (α=0.5)
        let balanced = DistillationLoss::new(2.0, 0.5);
        let loss_balanced = balanced.forward(&student, &teacher, &labels);

        assert!(loss_balanced > 0.0);
        assert!(loss_distill > 0.0);
        assert!(loss_hard > 0.0);

        // The loss is linear in α, so α=0.5 lies exactly at the midpoint of the
        // two extremes (and therefore between them).
        assert!(loss_balanced >= loss_distill.min(loss_hard));
        assert!(loss_balanced <= loss_distill.max(loss_hard));
        assert_relative_eq!(loss_balanced, 0.5 * loss_distill + 0.5 * loss_hard, epsilon = 1e-5);
    }

    // =========================================================================
    // FALSIFY-EMB-006/007: Temperature scaling (embedding-algebra-v1.yaml)
    //
    // Five-Whys (PMAT-354):
    //   Why 1: entrenar had 0 FALSIFY-EMB-* temperature tests
    //   Why 2: temperature tests existed but weren't tagged to YAML contract
    //   Why 3: no mapping from embedding-algebra-v1.yaml to entrenar test names
    //   Why 4: distillation loss uses temperature but was not linked to contract
    //   Why 5: EMB-006/007 were treated as inference-only, not training
    //
    // References:
    //   - provable-contracts/contracts/embedding-algebra-v1.yaml
    //   - Hinton et al. (2015) "Distilling the Knowledge in a Neural Network"
    // =========================================================================

    /// FALSIFY-EMB-006: Temperature=1.0 is identity for softmax
    ///
    /// Contract: softmax(x / 1.0) == softmax(x)
    #[test]
    fn falsify_emb_006_temperature_identity() {
        let logits = array![[3.0, 1.0, 0.5, -1.0]];

        let softmax_raw = softmax_2d(&logits);
        let softmax_t1 = softmax_2d(&(&logits / 1.0));

        for (a, b) in softmax_raw.iter().zip(softmax_t1.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-6);
        }
    }

    /// FALSIFY-EMB-007: Higher temperature → more uniform distribution
    ///
    /// Contract: entropy(softmax(x/T_high)) > entropy(softmax(x/T_low))
    #[test]
    fn falsify_emb_007_temperature_monotonicity() {
        let logits = array![[5.0, 2.0, 0.1, -3.0]];

        let probs_low = softmax_2d(&(&logits / 1.0));
        let probs_high = softmax_2d(&(&logits / 10.0));

        // Compute Shannon entropy: -Σ p_i * log(p_i)
        let entropy = |probs: &Array2<f32>| -> f32 {
            probs.iter().filter(|&&p| p > 1e-10).map(|&p| -p * p.ln()).sum()
        };

        let h_low = entropy(&probs_low);
        let h_high = entropy(&probs_high);

        assert!(
            h_high > h_low,
            "FALSIFIED EMB-007: higher temperature should increase entropy, got h_low={h_low}, h_high={h_high}"
        );
    }

    // =========================================================================
    // FALSIFY-SM: softmax-kernel-v1.yaml contract (entrenar's softmax_2d)
    //
    // Five-Whys (PMAT-354):
    //   Why 1: entrenar had test_softmax_sums_to_one but no FALSIFY-SM-*
    //   Why 2: existing test checks 1 property, not all 3 contract invariants
    //   Why 3: no mapping from softmax-kernel-v1.yaml to entrenar tests
    //   Why 4: entrenar predates the provable-contracts YAML
    //   Why 5: distillation softmax was "obviously correct" (3 lines)
    // =========================================================================

    /// FALSIFY-SM-001: Softmax output sums to 1 per row
    #[test]
    fn falsify_sm_001_sums_to_one() {
        let x = array![[3.0, 1.0, 0.5, -1.0], [-2.0, 0.0, 4.0, 1.0]];
        let probs = softmax_2d(&x);

        for row in probs.axis_iter(Axis(0)) {
            let sum: f32 = row.sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-5);
        }
    }

    /// FALSIFY-SM-002: All softmax outputs strictly positive
    #[test]
    fn falsify_sm_002_strictly_positive() {
        let x = array![[-10.0, -5.0, 0.0, 5.0, 10.0]];
        let probs = softmax_2d(&x);

        for &p in &probs {
            assert!(p > 0.0, "FALSIFIED SM-002: softmax output {p} not strictly positive");
        }
    }

    /// FALSIFY-SM-003: Order preservation (argmax invariant)
    #[test]
    fn falsify_sm_003_order_preservation() {
        let x = array![[1.0, 5.0, 3.0, 2.0]];
        let probs = softmax_2d(&x);

        let input_argmax = x
            .row(0)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;
        let output_argmax = probs
            .row(0)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;

        assert_eq!(
            input_argmax, output_argmax,
            "FALSIFIED SM-003: argmax changed from {input_argmax} to {output_argmax}"
        );
    }

    /// FALSIFY-SM-004: Softmax outputs bounded in [0, 1]
    ///
    /// Contract: 0 <= softmax(x)_i <= 1 for all i
    ///
    /// N-10 escape: IEEE 754 f32 underflow — exp(-200) = 0.0 exactly, so the
    /// mathematical open interval (0,1) becomes closed [0,1] in floating point.
    /// This is correct behavior, not a bug.
    #[test]
    fn falsify_sm_004_bounded_zero_one() {
        let x = array![[-100.0, -10.0, 0.0, 10.0, 100.0]];
        let probs = softmax_2d(&x);

        for &p in &probs {
            assert!((0.0..=1.0).contains(&p), "FALSIFIED SM-004: softmax output {p} not in [0, 1]");
        }

        // For moderate inputs, outputs ARE strictly in (0, 1) — no underflow
        let moderate = array![[1.0, 2.0, 3.0]];
        let probs_mod = softmax_2d(&moderate);
        for &p in &probs_mod {
            assert!(
                p > 0.0 && p < 1.0,
                "FALSIFIED SM-004: moderate softmax output {p} not in (0, 1)"
            );
        }
    }

    /// FALSIFY-SM-005: Numerical stability — extreme inputs don't produce NaN/Inf
    ///
    /// Contract: softmax is stable for inputs near f32 limits (via max-subtraction trick)
    #[test]
    fn falsify_sm_005_numerical_stability() {
        let x = array![[1000.0, 999.0, 998.0]];
        let probs = softmax_2d(&x);

        for &p in &probs {
            assert!(
                p.is_finite(),
                "FALSIFIED SM-005: softmax output {p} not finite for extreme inputs"
            );
            assert!(
                p > 0.0,
                "FALSIFIED SM-005: softmax output {p} not positive for extreme inputs"
            );
        }

        let sum: f32 = probs.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-5);
    }

    /// FALSIFY-SM-006: Identical elements → uniform distribution
    ///
    /// Contract: softmax([c, c, ..., c]) = [1/n, 1/n, ..., 1/n]
    #[test]
    fn falsify_sm_006_identical_elements_uniform() {
        for n in [2, 4, 8, 16] {
            let data: Vec<f32> = vec![7.0; n];
            let x = Array2::from_shape_vec((1, n), data).expect("operation should succeed");
            let probs = softmax_2d(&x);

            let expected = 1.0 / n as f32;
            for &p in &probs {
                assert_relative_eq!(p, expected, epsilon = 1e-6);
            }
        }
    }

    /// FALSIFY-SM-009: Single element boundary — softmax([x]) = [1.0]
    ///
    /// Contract (listed as SM-005 in the YAML): the softmax of a single element
    /// is always 1.0.
    #[test]
    fn falsify_sm_009_single_element() {
        for x in [0.0_f32, 1.0, -1.0, 100.0, -100.0, f32::MIN_POSITIVE] {
            let t = array![[x]];
            let probs = softmax_2d(&t);
            assert!(
                (probs[[0, 0]] - 1.0).abs() < 1e-6,
                "FALSIFIED SM-009: softmax([{x}]) = {}, expected 1.0",
                probs[[0, 0]]
            );
        }
    }

    /// FALSIFY-SM-007: Translation invariance — σ(x + c) = σ(x) for any scalar c
    ///
    /// Five-Whys (PMAT-354):
    ///   Why 1: SM-INV-003 (translation invariance) had ZERO coverage
    ///   Why 2: max-subtraction trick IMPLEMENTS this but nobody tested it
    ///   Why 3: foundational to numerical stability but untested
    ///
    /// Contract: σ(x + c·1) = σ(x) for any scalar c.
    #[test]
    fn falsify_sm_007_translation_invariance() {
        let base = array![[1.0_f32, 3.0, -2.0, 0.5]];
        let base_probs = softmax_2d(&base);

        for c in [100.0_f32, -100.0, 0.0, 42.0, -999.0] {
            let shifted = array![[1.0 + c, 3.0 + c, -2.0 + c, 0.5 + c]];
            let shifted_probs = softmax_2d(&shifted);

            for (i, (&orig, &shift)) in base_probs.iter().zip(shifted_probs.iter()).enumerate() {
                assert!(
                    (orig - shift).abs() < 1e-5,
                    "FALSIFIED SM-007: σ(x+{c})[{i}] = {shift} != σ(x)[{i}] = {orig}"
                );
            }
        }
    }

    mod softmax_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // FALSIFY-SM-001-prop: Normalization for random vectors
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_sm_001_prop_sums_to_one(
                logits in proptest::collection::vec(-100.0_f32..100.0, 2..64),
            ) {
                let n = logits.len();
                let arr = Array2::from_shape_vec((1, n), logits).expect("operation should succeed");
                let probs = softmax_2d(&arr);
                let sum: f32 = probs.row(0).sum();
                prop_assert!(
                    (sum - 1.0).abs() < 1e-4,
                    "FALSIFIED SM-001-prop: sum={} for {} elements", sum, n
                );
            }
        }

        // FALSIFY-SM-002-prop: Positivity for random vectors
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_sm_002_prop_positive(
                logits in proptest::collection::vec(-500.0_f32..500.0, 2..32),
            ) {
                let n = logits.len();
                let arr = Array2::from_shape_vec((1, n), logits).expect("operation should succeed");
                let probs = softmax_2d(&arr);
                for (i, &p) in probs.row(0).iter().enumerate() {
                    prop_assert!(p >= 0.0, "FALSIFIED SM-002-prop: probs[{}]={} negative", i, p);
                    prop_assert!(p.is_finite(), "FALSIFIED SM-002-prop: probs[{}]={} non-finite", i, p);
                }
            }
        }

        // FALSIFY-SM-003-prop: Order preservation for random vectors
        //
        // Contract: argmax(softmax(x)) = argmax(x) when no duplicate max
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_sm_003_prop_order_preservation(
                logits in proptest::collection::vec(-50.0_f32..50.0, 2..32),
            ) {
                // Sort first so duplicates anywhere in the vector are detected,
                // not only duplicates that happen to be adjacent.
                let has_dupes = {
                    let mut sorted = logits.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
                    sorted.windows(2).any(|w| (w[0] - w[1]).abs() < 1e-10)
                };
                if has_dupes {
                    return Ok(());
                }

                let n = logits.len();
                let arr = Array2::from_shape_vec((1, n), logits.clone()).expect("operation should succeed");
                let probs = softmax_2d(&arr);
                let input_argmax = logits.iter().enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed")).expect("operation should succeed").0;
                let output_argmax = probs.row(0).iter().enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed")).expect("operation should succeed").0;
                prop_assert_eq!(
                    input_argmax, output_argmax,
                    "FALSIFIED SM-003-prop: argmax {} -> {} for {:?}", input_argmax, output_argmax, logits
                );
            }
        }
    }

    // =========================================================================
    // FALSIFY-APR-DISTILL-TRAIN-003 / TRAIN-004: apr-cli-distill-train-v1.yaml
    //
    // Pure-math property tests that partial-discharge 2 of 9 falsifiers in the
    // PROPOSED `apr-cli-distill-train-v1` contract. The full contract still
    // requires the missing real-training implementation (per spec §35);
    // these two falsifiers are the math invariants that any future
    // implementation must preserve.
    //
    // Five-Whys:
    //   Why 1: §35 found `apr distill --stage train` is a stub.
    //   Why 2: contract `apr-cli-distill-train-v1.yaml` was authored with 9
    //          falsifiers but 0 are tested.
    //   Why 3: TRAIN-003/004 are *purely-mathematical* invariants of the
    //          existing `softmax_2d` + `DistillationLoss` helpers; testable
    //          today against existing code.
    //   Why 4: pinning these now means a future real-training PR cannot
    //          regress the math without tripping these gates.
    //   Why 5: §26.8 stack-tool-extension methodology — extend apr in
    //          falsifier-sized slices, never one big PR.
    // =========================================================================

    /// FALSIFY-APR-DISTILL-TRAIN-003: temperature scaling preserves softmax ranking.
    ///
    /// Contract: For any (logits, T>0): argmax(softmax(logits/T)) == argmax(logits).
    ///
    /// This is the deterministic-greedy invariant: if knowledge distillation
    /// were to *reorder* the teacher's argmax under temperature, the student
    /// would learn a corrupted preference. The max-subtraction trick in
    /// `softmax_2d` already gives translation-invariance (FALSIFY-SM-007);
    /// this test pins the *positive-scaling* case across the canonical T set.
    #[test]
    fn falsify_apr_distill_train_003_t_scaling_preserves_argmax() {
        let logits = array![[3.0_f32, 1.0, 0.5, -1.0, 7.0, -3.0, 2.5, 0.0]];
        let baseline_argmax = logits
            .row(0)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;

        for &t in &[1.0_f32, 2.0, 3.0, 5.0, 10.0] {
            let scaled = &logits / t;
            let probs = softmax_2d(&scaled);
            let scaled_argmax = probs
                .row(0)
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                .expect("operation should succeed")
                .0;
            assert_eq!(
                baseline_argmax, scaled_argmax,
                "FALSIFIED APR-DISTILL-TRAIN-003: argmax shifted from {baseline_argmax} to {scaled_argmax} at T={t}"
            );
        }
    }

    /// FALSIFY-APR-DISTILL-TRAIN-004: alpha=1.0 reduces to pure KD.
    ///
    /// Contract: At alpha=1.0, total_loss = T*T * kl_loss exactly
    /// (the (1-alpha)*ce_loss term is zeroed).
    ///
    /// This is the alpha-weighting bookkeeping invariant. If alpha-handling
    /// regresses (e.g., off-by-one on the (1-alpha) coefficient), this gate
    /// catches it before the test_distillation_loss_basic top-level test does
    /// — because the top-level test only checks `loss > 0`, which a buggy
    /// alpha implementation can also satisfy.
    #[test]
    fn falsify_apr_distill_train_004_alpha_one_equals_pure_kd() {
        let student = array![[2.5_f32, 0.7, -1.3, 4.0]];
        let teacher = array![[1.8_f32, 1.1, -0.2, 3.5]];
        let labels = vec![3_usize];

        let temperature = 3.0_f32;
        let alpha_one = DistillationLoss::new(temperature, 1.0);
        let total_at_alpha_one = alpha_one.forward(&student, &teacher, &labels);

        // Reproduce the kl_loss term directly via the same helpers
        // forward() uses, scaled by T*T per the formula.
        let student_soft = softmax_2d(&(&student / temperature));
        let teacher_soft = softmax_2d(&(&teacher / temperature));
        let kl = kl_divergence(&teacher_soft, &student_soft);
        let pure_kd = kl * temperature * temperature;

        assert_relative_eq!(total_at_alpha_one, pure_kd, epsilon = 1e-5);
    }

    /// FALSIFY-APR-DISTILL-TRAIN-003 (proptest variant): random logit vectors
    /// across canonical T set must preserve argmax ranking.
    mod apr_distill_train_proptest {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_apr_distill_train_003_prop_t_scaling_preserves_argmax(
                logits in proptest::collection::vec(-50.0_f32..50.0, 2..32),
            ) {
                let has_dupes = {
                    let mut sorted = logits.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
                    sorted.windows(2).any(|w| (w[0] - w[1]).abs() < 1e-6)
                };
                if has_dupes {
                    return Ok(());
                }

                let n = logits.len();
                let baseline_argmax = logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                    .expect("operation should succeed")
                    .0;

                for &t in &[1.0_f32, 2.0, 3.0, 5.0, 10.0] {
                    let arr = Array2::from_shape_vec((1, n), logits.clone())
                        .expect("operation should succeed");
                    let scaled = &arr / t;
                    let probs = softmax_2d(&scaled);
                    let scaled_argmax = probs
                        .row(0)
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                        .expect("operation should succeed")
                        .0;
                    prop_assert_eq!(
                        baseline_argmax, scaled_argmax,
                        "FALSIFIED APR-DISTILL-TRAIN-003-prop: argmax shifted at T={}", t
                    );
                }
            }
        }
    }
}