1use ndarray::{Array2, Axis};
4
/// Knowledge-distillation loss that blends a temperature-softened KL term
/// (teacher vs. student) with a hard-label cross-entropy term.
///
/// Construct it with [`DistillationLoss::new`], which validates both fields.
/// The fields are public, so values assigned directly bypass that validation.
#[derive(Clone, Debug)]
pub struct DistillationLoss {
    /// Softmax temperature applied to both student and teacher logits before the KL term.
    pub temperature: f32,
    /// Weight of the soft-target (KL) term; the hard-label term is weighted by `1 - alpha`.
    pub alpha: f32,
}
40
41impl DistillationLoss {
42 pub fn new(temperature: f32, alpha: f32) -> Self {
53 assert!(temperature > 0.0, "Temperature must be positive, got {temperature}");
54 assert!((0.0..=1.0).contains(&alpha), "Alpha must be in [0, 1], got {alpha}");
55
56 Self { temperature, alpha }
57 }
58
59 pub fn forward(
71 &self,
72 student_logits: &Array2<f32>,
73 teacher_logits: &Array2<f32>,
74 labels: &[usize],
75 ) -> f32 {
76 assert_eq!(
77 student_logits.shape(),
78 teacher_logits.shape(),
79 "Student and teacher logits must have same shape"
80 );
81 assert_eq!(student_logits.nrows(), labels.len(), "Batch size must match number of labels");
82
83 let kl_loss = self.kl_divergence_loss(student_logits, teacher_logits);
85
86 let ce_loss = self.cross_entropy_loss(student_logits, labels);
88
89 self.alpha * kl_loss * self.temperature * self.temperature + (1.0 - self.alpha) * ce_loss
91 }
92
93 fn kl_divergence_loss(
97 &self,
98 student_logits: &Array2<f32>,
99 teacher_logits: &Array2<f32>,
100 ) -> f32 {
101 let student_soft = softmax_2d(&(student_logits / self.temperature));
102 let teacher_soft = softmax_2d(&(teacher_logits / self.temperature));
103
104 kl_divergence(&teacher_soft, &student_soft)
105 }
106
107 fn cross_entropy_loss(&self, logits: &Array2<f32>, labels: &[usize]) -> f32 {
109 let probs = softmax_2d(logits);
110
111 let mut loss = 0.0;
112 for (i, &label) in labels.iter().enumerate() {
113 let prob = probs[[i, label]].max(1e-10); loss -= prob.max(f32::MIN_POSITIVE).ln();
115 }
116
117 loss / labels.len().max(1) as f32
118 }
119}
120
121fn softmax_2d(x: &Array2<f32>) -> Array2<f32> {
125 let mut result = x.clone();
126
127 for mut row in result.axis_iter_mut(Axis(0)) {
128 let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
130 row.mapv_inplace(|v| (v - max_val).exp());
131
132 let sum: f32 = row.sum();
134 row.mapv_inplace(|v| v / sum);
135 }
136
137 result
138}
139
140fn kl_divergence(p: &Array2<f32>, q: &Array2<f32>) -> f32 {
146 assert_eq!(p.shape(), q.shape());
147
148 if p.nrows() == 0 {
149 return 0.0;
150 }
151
152 let mut total_kl = 0.0;
153
154 for (p_row, q_row) in p.axis_iter(Axis(0)).zip(q.axis_iter(Axis(0))) {
155 let mut kl = 0.0;
156 for (&p_i, &q_i) in p_row.iter().zip(q_row.iter()) {
157 if p_i > 1e-10 {
158 kl += p_i * (p_i / q_i.max(1e-10)).ln();
160 }
161 }
162 total_kl += kl;
163 }
164
165 total_kl / p.nrows() as f32
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_distillation_loss_basic() {
        let loss_fn = DistillationLoss::new(2.0, 0.5);
        let student = array![[2.0, 1.0, 0.5]];
        let teacher = array![[1.5, 1.2, 0.8]];
        let labels = vec![0];

        let loss = loss_fn.forward(&student, &teacher, &labels);
        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let probs = softmax_2d(&x);

        for row in probs.axis_iter(Axis(0)) {
            let sum: f32 = row.sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_kl_divergence_zero_for_identical() {
        let p = array![[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]];
        let kl = kl_divergence(&p, &p);
        assert_relative_eq!(kl, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_kl_divergence_positive() {
        let p = array![[0.7, 0.2, 0.1]];
        let q = array![[0.4, 0.4, 0.2]];
        let kl = kl_divergence(&p, &q);
        assert!(kl > 0.0);
    }

    /// With identical logits every class has probability 1/3, so the
    /// hard-label term alone (alpha = 0) must equal ln(3).
    #[test]
    fn test_cross_entropy_uniform_logits_is_ln_num_classes() {
        let loss_fn = DistillationLoss::new(1.0, 0.0);
        let logits = array![[0.0_f32, 0.0, 0.0]];
        let loss = loss_fn.forward(&logits, &logits, &[1]);
        assert_relative_eq!(loss, 3.0_f32.ln(), epsilon = 1e-6);
    }

    /// Pure distillation (alpha = 1) against an identical teacher has zero loss.
    #[test]
    fn test_identical_student_teacher_has_zero_soft_loss() {
        let loss_fn = DistillationLoss::new(2.0, 1.0);
        let logits = array![[2.0_f32, -1.0, 0.5], [0.0, 3.0, -2.0]];
        let loss = loss_fn.forward(&logits, &logits, &[0, 1]);
        assert_relative_eq!(loss, 0.0, epsilon = 1e-6);
    }

    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_negative_temperature_panics() {
        DistillationLoss::new(-1.0, 0.5);
    }

    #[test]
    #[should_panic(expected = "Alpha must be in [0, 1]")]
    fn test_invalid_alpha_panics() {
        DistillationLoss::new(2.0, 1.5);
    }

    #[test]
    fn test_temperature_effect() {
        let student = array![[10.0, 1.0, 0.1]];
        let teacher = array![[5.0, 4.0, 3.0]];
        let labels = vec![0];

        let low_temp_loss = DistillationLoss::new(1.0, 1.0);
        let high_temp_loss = DistillationLoss::new(5.0, 1.0);

        let loss_low = low_temp_loss.forward(&student, &teacher, &labels);
        let loss_high = high_temp_loss.forward(&student, &teacher, &labels);

        assert!(loss_low != loss_high);
    }

    #[test]
    fn test_alpha_balances_losses() {
        let student = array![[2.0, 1.0, 0.5]];
        let teacher = array![[1.5, 1.2, 0.8]];
        let labels = vec![0];

        let pure_distill = DistillationLoss::new(2.0, 1.0);
        let loss_distill = pure_distill.forward(&student, &teacher, &labels);

        let pure_hard = DistillationLoss::new(2.0, 0.0);
        let loss_hard = pure_hard.forward(&student, &teacher, &labels);

        let balanced = DistillationLoss::new(2.0, 0.5);
        let loss_balanced = balanced.forward(&student, &teacher, &labels);

        assert!(loss_balanced > 0.0);
        assert!(loss_distill > 0.0);
        assert!(loss_hard > 0.0);
    }

    /// EMB-006: dividing logits by `T = 1` must leave the softmax unchanged.
    #[test]
    fn falsify_emb_006_temperature_identity() {
        let logits = array![[3.0, 1.0, 0.5, -1.0]];

        let softmax_raw = softmax_2d(&logits);
        let softmax_t1 = softmax_2d(&(&logits / 1.0));

        for (a, b) in softmax_raw.iter().zip(softmax_t1.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-6);
        }
    }

    /// EMB-007: a higher temperature must give a flatter (higher-entropy) distribution.
    #[test]
    fn falsify_emb_007_temperature_monotonicity() {
        let logits = array![[5.0, 2.0, 0.1, -3.0]];

        let probs_low = softmax_2d(&(&logits / 1.0));
        let probs_high = softmax_2d(&(&logits / 10.0));

        // Shannon entropy, skipping (near-)zero probabilities (0 · ln 0 = 0).
        let entropy = |probs: &Array2<f32>| -> f32 {
            probs.iter().filter(|&&p| p > 1e-10).map(|&p| -p * p.ln()).sum()
        };

        let h_low = entropy(&probs_low);
        let h_high = entropy(&probs_high);

        assert!(
            h_high > h_low,
            "FALSIFIED EMB-007: higher temperature should increase entropy, got h_low={h_low}, h_high={h_high}"
        );
    }

    /// SM-001: every softmax row sums to one.
    #[test]
    fn falsify_sm_001_sums_to_one() {
        let x = array![[3.0, 1.0, 0.5, -1.0], [-2.0, 0.0, 4.0, 1.0]];
        let probs = softmax_2d(&x);

        for row in probs.axis_iter(Axis(0)) {
            let sum: f32 = row.sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-5);
        }
    }

    /// SM-002: softmax outputs are strictly positive for moderate inputs.
    #[test]
    fn falsify_sm_002_strictly_positive() {
        let x = array![[-10.0, -5.0, 0.0, 5.0, 10.0]];
        let probs = softmax_2d(&x);

        for &p in &probs {
            assert!(p > 0.0, "FALSIFIED SM-002: softmax output {p} not strictly positive");
        }
    }

    /// SM-003: softmax preserves the argmax of its input.
    #[test]
    fn falsify_sm_003_order_preservation() {
        let x = array![[1.0, 5.0, 3.0, 2.0]];
        let probs = softmax_2d(&x);

        let input_argmax = x
            .row(0)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;
        let output_argmax = probs
            .row(0)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;

        assert_eq!(
            input_argmax, output_argmax,
            "FALSIFIED SM-003: argmax changed from {input_argmax} to {output_argmax}"
        );
    }

    /// SM-004: outputs lie in [0, 1]. Extreme inputs may underflow to exactly 0,
    /// hence the closed range. Moderate inputs stay strictly inside (0, 1).
    #[test]
    fn falsify_sm_004_bounded_zero_one() {
        let x = array![[-100.0, -10.0, 0.0, 10.0, 100.0]];
        let probs = softmax_2d(&x);

        for &p in &probs {
            assert!((0.0..=1.0).contains(&p), "FALSIFIED SM-004: softmax output {p} not in [0, 1]");
        }

        let moderate = array![[1.0, 2.0, 3.0]];
        let probs_mod = softmax_2d(&moderate);
        for &p in &probs_mod {
            assert!(
                p > 0.0 && p < 1.0,
                "FALSIFIED SM-004: moderate softmax output {p} not in (0, 1)"
            );
        }
    }

    /// SM-005: max-subtraction keeps very large logits finite and positive.
    #[test]
    fn falsify_sm_005_numerical_stability() {
        let x = array![[1000.0, 999.0, 998.0]];
        let probs = softmax_2d(&x);

        for &p in &probs {
            assert!(
                p.is_finite(),
                "FALSIFIED SM-005: softmax output {p} not finite for extreme inputs"
            );
            assert!(
                p > 0.0,
                "FALSIFIED SM-005: softmax output {p} not positive for extreme inputs"
            );
        }

        let sum: f32 = probs.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-5);
    }

    /// SM-006: identical inputs give the uniform distribution 1/n.
    #[test]
    fn falsify_sm_006_identical_elements_uniform() {
        for n in [2, 4, 8, 16] {
            let data: Vec<f32> = vec![7.0; n];
            let x = Array2::from_shape_vec((1, n), data).expect("operation should succeed");
            let probs = softmax_2d(&x);

            let expected = 1.0 / n as f32;
            for &p in &probs {
                assert_relative_eq!(p, expected, epsilon = 1e-6);
            }
        }
    }

    /// SM-009: a single-element row always maps to probability 1.
    #[test]
    fn falsify_sm_009_single_element() {
        for x in [0.0_f32, 1.0, -1.0, 100.0, -100.0, f32::MIN_POSITIVE] {
            let t = array![[x]];
            let probs = softmax_2d(&t);
            assert!(
                (probs[[0, 0]] - 1.0).abs() < 1e-6,
                "FALSIFIED SM-009: softmax([{x}]) = {}, expected 1.0",
                probs[[0, 0]]
            );
        }
    }

    /// SM-007: adding the same constant to every logit leaves the softmax unchanged.
    #[test]
    fn falsify_sm_007_translation_invariance() {
        let base = array![[1.0_f32, 3.0, -2.0, 0.5]];
        let base_probs = softmax_2d(&base);

        for c in [100.0_f32, -100.0, 0.0, 42.0, -999.0] {
            let shifted = array![[1.0 + c, 3.0 + c, -2.0 + c, 0.5 + c]];
            let shifted_probs = softmax_2d(&shifted);

            for (i, (&orig, &shift)) in base_probs.iter().zip(shifted_probs.iter()).enumerate() {
                assert!(
                    (orig - shift).abs() < 1e-5,
                    "FALSIFIED SM-007: σ(x+{c})[{i}] = {shift} != σ(x)[{i}] = {orig}"
                );
            }
        }
    }

    mod softmax_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_sm_001_prop_sums_to_one(
                logits in proptest::collection::vec(-100.0_f32..100.0, 2..64),
            ) {
                let n = logits.len();
                let arr = Array2::from_shape_vec((1, n), logits).expect("operation should succeed");
                let probs = softmax_2d(&arr);
                let sum: f32 = probs.row(0).sum();
                prop_assert!(
                    (sum - 1.0).abs() < 1e-4,
                    "FALSIFIED SM-001-prop: sum={} for {} elements", sum, n
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_sm_002_prop_positive(
                logits in proptest::collection::vec(-500.0_f32..500.0, 2..32),
            ) {
                let n = logits.len();
                let arr = Array2::from_shape_vec((1, n), logits).expect("operation should succeed");
                let probs = softmax_2d(&arr);
                for (i, &p) in probs.row(0).iter().enumerate() {
                    prop_assert!(p >= 0.0, "FALSIFIED SM-002-prop: probs[{}]={} negative", i, p);
                    prop_assert!(p.is_finite(), "FALSIFIED SM-002-prop: probs[{}]={} non-finite", i, p);
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_sm_003_prop_order_preservation(
                logits in proptest::collection::vec(-50.0_f32..50.0, 2..32),
            ) {
                // Skip inputs with (near-)ties anywhere, not just between
                // neighbours. Sort first so every close pair becomes adjacent,
                // as in the APR-DISTILL-TRAIN-003 proptest.
                let has_dupes = {
                    let mut sorted = logits.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
                    sorted.windows(2).any(|w| (w[0] - w[1]).abs() < 1e-6)
                };
                if has_dupes {
                    return Ok(());
                }

                let n = logits.len();
                let arr = Array2::from_shape_vec((1, n), logits.clone()).expect("operation should succeed");
                let probs = softmax_2d(&arr);
                let input_argmax = logits.iter().enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed")).expect("operation should succeed").0;
                let output_argmax = probs.row(0).iter().enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed")).expect("operation should succeed").0;
                prop_assert_eq!(
                    input_argmax, output_argmax,
                    "FALSIFIED SM-003-prop: argmax {} -> {} for {:?}", input_argmax, output_argmax, logits
                );
            }
        }
    }

    /// APR-DISTILL-TRAIN-003: scaling logits by `1/T` (T > 0) never moves the argmax.
    #[test]
    fn falsify_apr_distill_train_003_t_scaling_preserves_argmax() {
        let logits = array![[3.0_f32, 1.0, 0.5, -1.0, 7.0, -3.0, 2.5, 0.0]];
        let baseline_argmax = logits
            .row(0)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;

        for &t in &[1.0_f32, 2.0, 3.0, 5.0, 10.0] {
            let scaled = &logits / t;
            let probs = softmax_2d(&scaled);
            let scaled_argmax = probs
                .row(0)
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                .expect("operation should succeed")
                .0;
            assert_eq!(
                baseline_argmax, scaled_argmax,
                "FALSIFIED APR-DISTILL-TRAIN-003: argmax shifted from {baseline_argmax} to {scaled_argmax} at T={t}"
            );
        }
    }

    /// APR-DISTILL-TRAIN-004: with `alpha = 1` the loss reduces to
    /// `T² · KL(teacher_T || student_T)`.
    #[test]
    fn falsify_apr_distill_train_004_alpha_one_equals_pure_kd() {
        let student = array![[2.5_f32, 0.7, -1.3, 4.0]];
        let teacher = array![[1.8_f32, 1.1, -0.2, 3.5]];
        let labels = vec![3_usize];

        let temperature = 3.0_f32;
        let alpha_one = DistillationLoss::new(temperature, 1.0);
        let total_at_alpha_one = alpha_one.forward(&student, &teacher, &labels);

        let student_soft = softmax_2d(&(&student / temperature));
        let teacher_soft = softmax_2d(&(&teacher / temperature));
        let kl = kl_divergence(&teacher_soft, &student_soft);
        let pure_kd = kl * temperature * temperature;

        assert_relative_eq!(total_at_alpha_one, pure_kd, epsilon = 1e-5);
    }

    mod apr_distill_train_proptest {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_apr_distill_train_003_prop_t_scaling_preserves_argmax(
                logits in proptest::collection::vec(-50.0_f32..50.0, 2..32),
            ) {
                let has_dupes = {
                    let mut sorted = logits.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
                    sorted.windows(2).any(|w| (w[0] - w[1]).abs() < 1e-6)
                };
                if has_dupes {
                    return Ok(());
                }

                let n = logits.len();
                let baseline_argmax = logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                    .expect("operation should succeed")
                    .0;

                for &t in &[1.0_f32, 2.0, 3.0, 5.0, 10.0] {
                    let arr = Array2::from_shape_vec((1, n), logits.clone())
                        .expect("operation should succeed");
                    let scaled = &arr / t;
                    let probs = softmax_2d(&scaled);
                    let scaled_argmax = probs
                        .row(0)
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
                        .expect("operation should succeed")
                        .0;
                    prop_assert_eq!(
                        baseline_argmax, scaled_argmax,
                        "FALSIFIED APR-DISTILL-TRAIN-003-prop: argmax shifted at T={}", t
                    );
                }
            }
        }
    }
}