entrenar/hf_pipeline/distillation/
loss.rs1use ndarray::{Array1, Array2};
6
7use super::utils::{cross_entropy_loss, kl_divergence, log_softmax, softmax};
8
/// Knowledge-distillation loss that mixes a temperature-softened KL term against a teacher's
/// logits with the ordinary cross-entropy against the ground-truth class.
///
/// The per-example loss is `alpha * T² * KL(...) + (1 - alpha) * CE(student, target)`;
/// see [`DistillationLoss::forward_single`] for the exact computation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DistillationLoss {
    /// Softmax temperature `T`: student and teacher logits are divided by it in the soft term.
    pub temperature: f32,
    /// Weight of the soft (KL) term; the hard cross-entropy term is weighted by `1 - alpha`.
    pub alpha: f32,
}
30
31impl Default for DistillationLoss {
32 fn default() -> Self {
33 Self::new(4.0, 0.7)
34 }
35}
36
37impl DistillationLoss {
38 #[must_use]
45 pub fn new(temperature: f32, alpha: f32) -> Self {
46 Self { temperature, alpha }
47 }
48
49 pub fn forward_single(
61 &self,
62 student_logits: &Array1<f32>,
63 teacher_logits: &Array1<f32>,
64 target: usize,
65 ) -> f32 {
66 let t = self.temperature;
67
68 let student_scaled: Array1<f32> = student_logits.mapv(|x| x / t);
70 let teacher_scaled: Array1<f32> = teacher_logits.mapv(|x| x / t);
71
72 let teacher_soft = softmax(&teacher_scaled);
74 let student_log_soft = log_softmax(&student_scaled);
75
76 let kl_loss = kl_divergence(&student_log_soft, &teacher_soft) * t * t;
78
79 let ce_loss = cross_entropy_loss(student_logits, target);
81
82 self.alpha * kl_loss + (1.0 - self.alpha) * ce_loss
84 }
85
86 pub fn forward(
98 &self,
99 student_logits: &Array2<f32>,
100 teacher_logits: &Array2<f32>,
101 targets: &[usize],
102 ) -> f32 {
103 let batch_size = student_logits.nrows();
104 assert_eq!(batch_size, teacher_logits.nrows());
105 assert_eq!(batch_size, targets.len());
106
107 let mut total_loss = 0.0;
108 for (i, &target) in targets.iter().enumerate() {
109 let s_row = student_logits.row(i).to_owned();
110 let t_row = teacher_logits.row(i).to_owned();
111 total_loss += self.forward_single(&s_row, &t_row, target);
112 }
113
114 total_loss / batch_size as f32
115 }
116
117 pub fn soft_loss(&self, student_logits: &Array1<f32>, teacher_logits: &Array1<f32>) -> f32 {
119 let t = self.temperature;
120
121 let student_scaled: Array1<f32> = student_logits.mapv(|x| x / t);
122 let teacher_scaled: Array1<f32> = teacher_logits.mapv(|x| x / t);
123
124 let teacher_soft = softmax(&teacher_scaled);
125 let student_log_soft = log_softmax(&student_scaled);
126
127 kl_divergence(&student_log_soft, &teacher_soft) * t * t
128 }
129}