Skip to main content

entrenar/hf_pipeline/distillation/
loss.rs

1//! Knowledge Distillation Loss
2//!
3//! Implements Hinton et al. (2015) distillation loss.
4
5use ndarray::{Array1, Array2};
6
7use super::utils::{cross_entropy_loss, kl_divergence, log_softmax, softmax};
8
/// Knowledge Distillation Loss
///
/// Configuration for the Hinton et al. (2015) distillation loss:
///
/// ```text
/// L_KD = α * T² * KL(softmax(z_t/T) || softmax(z_s/T)) + (1-α) * CE(y, z_s)
/// ```
///
/// Where:
/// - `z_s` = student logits
/// - `z_t` = teacher logits
/// - `T` = temperature (higher = softer targets)
/// - `α` = weight for distillation loss vs hard label loss
/// - `y` = ground truth labels
///
/// The teacher's softened distribution is the reference (first) argument of
/// the KL term: the implementation passes teacher probabilities and student
/// log-probabilities to `kl_divergence`.
/// NOTE(review): this assumes `utils::kl_divergence(log_q, p)` evaluates
/// `KL(p || q)` — confirm against the helper.
///
/// Both fields are public and are not range-checked by this type.
#[derive(Debug, Clone, PartialEq)]
pub struct DistillationLoss {
    /// Temperature for softening distributions (typical: 2-20).
    ///
    /// Logits are divided by this value, so it should be non-zero.
    pub temperature: f32,
    /// Weight for soft loss vs hard loss (typical: 0.5-0.9).
    ///
    /// The hard-label term is weighted by `1 - alpha`.
    pub alpha: f32,
}
30
31impl Default for DistillationLoss {
32    fn default() -> Self {
33        Self::new(4.0, 0.7)
34    }
35}
36
37impl DistillationLoss {
38    /// Create new distillation loss
39    ///
40    /// # Arguments
41    ///
42    /// * `temperature` - Temperature for softening (2-20 typical)
43    /// * `alpha` - Weight for soft loss (0.5-0.9 typical)
44    #[must_use]
45    pub fn new(temperature: f32, alpha: f32) -> Self {
46        Self { temperature, alpha }
47    }
48
49    /// Compute distillation loss for single sample
50    ///
51    /// # Arguments
52    ///
53    /// * `student_logits` - Student model output logits
54    /// * `teacher_logits` - Teacher model output logits
55    /// * `target` - Ground truth label index
56    ///
57    /// # Returns
58    ///
59    /// Combined distillation loss
60    pub fn forward_single(
61        &self,
62        student_logits: &Array1<f32>,
63        teacher_logits: &Array1<f32>,
64        target: usize,
65    ) -> f32 {
66        let t = self.temperature;
67
68        // Temperature-scaled logits
69        let student_scaled: Array1<f32> = student_logits.mapv(|x| x / t);
70        let teacher_scaled: Array1<f32> = teacher_logits.mapv(|x| x / t);
71
72        // Soft targets from teacher
73        let teacher_soft = softmax(&teacher_scaled);
74        let student_log_soft = log_softmax(&student_scaled);
75
76        // KL divergence (scaled by T²)
77        let kl_loss = kl_divergence(&student_log_soft, &teacher_soft) * t * t;
78
79        // Hard label cross-entropy
80        let ce_loss = cross_entropy_loss(student_logits, target);
81
82        // Combined loss
83        self.alpha * kl_loss + (1.0 - self.alpha) * ce_loss
84    }
85
86    /// Compute distillation loss for batch
87    ///
88    /// # Arguments
89    ///
90    /// * `student_logits` - [batch_size, vocab_size]
91    /// * `teacher_logits` - [batch_size, vocab_size]
92    /// * `targets` - Ground truth labels
93    ///
94    /// # Returns
95    ///
96    /// Mean batch loss
97    pub fn forward(
98        &self,
99        student_logits: &Array2<f32>,
100        teacher_logits: &Array2<f32>,
101        targets: &[usize],
102    ) -> f32 {
103        let batch_size = student_logits.nrows();
104        assert_eq!(batch_size, teacher_logits.nrows());
105        assert_eq!(batch_size, targets.len());
106
107        let mut total_loss = 0.0;
108        for (i, &target) in targets.iter().enumerate() {
109            let s_row = student_logits.row(i).to_owned();
110            let t_row = teacher_logits.row(i).to_owned();
111            total_loss += self.forward_single(&s_row, &t_row, target);
112        }
113
114        total_loss / batch_size as f32
115    }
116
117    /// Compute soft loss only (no hard labels)
118    pub fn soft_loss(&self, student_logits: &Array1<f32>, teacher_logits: &Array1<f32>) -> f32 {
119        let t = self.temperature;
120
121        let student_scaled: Array1<f32> = student_logits.mapv(|x| x / t);
122        let teacher_scaled: Array1<f32> = teacher_logits.mapv(|x| x / t);
123
124        let teacher_soft = softmax(&teacher_scaled);
125        let student_log_soft = log_softmax(&student_scaled);
126
127        kl_divergence(&student_log_soft, &teacher_soft) * t * t
128    }
129}