//! Knowledge-distillation loss (Hinton et al., 2015): blends a
//! temperature-scaled KL term against the teacher's soft targets with
//! standard cross-entropy against the hard labels.

use ndarray::{Array1, Array2};

use super::utils::{cross_entropy_loss, kl_divergence, log_softmax, softmax};
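// Assumed contracts for the `super::utils` helpers (they are defined
// elsewhere in this crate, so the exact conventions are an assumption):
//   softmax(x)               -> probability vector over `x`
//   log_softmax(x)           -> numerically stable ln(softmax(x))
//   kl_divergence(log_p, q)  -> sum_i q[i] * (ln q[i] - log_p[i])
//   cross_entropy_loss(l, t) -> -log_softmax(l)[t]
// If the real helpers differ, read the comments below accordingly.
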
/// Configuration for the combined distillation loss
/// `alpha * soft_term + (1 - alpha) * hard_term`.
#[derive(Debug, Clone)]
pub struct DistillationLoss {
    /// Softening temperature `T` applied to both logit vectors; higher values
    /// flatten the distributions. Expected to be positive.
    pub temperature: f32,
    /// Weight of the soft (KL) term in `[0, 1]`; `1 - alpha` weights the
    /// hard cross-entropy term.
    pub alpha: f32,
}

impl Default for DistillationLoss {
    /// Defaults to `T = 4.0` and `alpha = 0.7`, so the soft term dominates.
    fn default() -> Self {
        Self::new(4.0, 0.7)
    }
}

impl DistillationLoss {
    /// Creates a loss with the given temperature and soft/hard mixing weight.
    #[must_use]
    pub fn new(temperature: f32, alpha: f32) -> Self {
        Self { temperature, alpha }
    }

    /// Distillation loss for a single example.
    ///
    /// Both logit vectors are divided by the temperature before the softmax,
    /// and the KL term is rescaled by `T^2` to keep its gradient magnitude
    /// comparable across temperatures, as in Hinton et al. (2015).
    pub fn forward_single(
        &self,
        student_logits: &Array1<f32>,
        teacher_logits: &Array1<f32>,
        target: usize,
    ) -> f32 {
        let t = self.temperature;
        let student_scaled: Array1<f32> = student_logits.mapv(|x| x / t);
        let teacher_scaled: Array1<f32> = teacher_logits.mapv(|x| x / t);
        let teacher_soft = softmax(&teacher_scaled);
        let student_log_soft = log_softmax(&student_scaled);
        // Soft-target term: divergence between the temperature-softened
        // teacher and student distributions, rescaled by T^2.
        let kl_loss = kl_divergence(&student_log_soft, &teacher_soft) * t * t;
        // Hard-target term: plain cross-entropy on the unscaled student logits.
        let ce_loss = cross_entropy_loss(student_logits, target);
        self.alpha * kl_loss + (1.0 - self.alpha) * ce_loss
    }

    /// Mean distillation loss over a batch; each row holds one example's
    /// logits. Returns `NaN` for an empty batch.
    ///
    /// # Panics
    /// Panics if the logit matrices and `targets` disagree on batch size.
    pub fn forward(
        &self,
        student_logits: &Array2<f32>,
        teacher_logits: &Array2<f32>,
        targets: &[usize],
    ) -> f32 {
        let batch_size = student_logits.nrows();
        assert_eq!(batch_size, teacher_logits.nrows());
        assert_eq!(batch_size, targets.len());
        let mut total_loss = 0.0;
        for (i, &target) in targets.iter().enumerate() {
            // `row` returns a view; `to_owned` copies it into an `Array1`
            // so it can be passed to `forward_single`.
            let s_row = student_logits.row(i).to_owned();
            let t_row = teacher_logits.row(i).to_owned();
            total_loss += self.forward_single(&s_row, &t_row, target);
        }
        total_loss / batch_size as f32
    }

    /// The soft (KL) term alone, without the hard cross-entropy component,
    /// e.g. for unlabeled data.
    pub fn soft_loss(&self, student_logits: &Array1<f32>, teacher_logits: &Array1<f32>) -> f32 {
        let t = self.temperature;
        let student_scaled: Array1<f32> = student_logits.mapv(|x| x / t);
        let teacher_scaled: Array1<f32> = teacher_logits.mapv(|x| x / t);
        let teacher_soft = softmax(&teacher_scaled);
        let student_log_soft = log_softmax(&student_scaled);
        kl_divergence(&student_log_soft, &teacher_soft) * t * t
    }
}
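
// A minimal usage sketch, assuming the `super::utils` helpers follow the
// contracts noted near the top of this file. The expectations are structural
// (identical distributions give zero KL; `alpha = 0` reduces the loss to
// plain cross-entropy) rather than golden values.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn identical_logits_give_zero_soft_loss() {
        let loss = DistillationLoss::default();
        let logits = array![1.0_f32, 2.0, 3.0];
        // KL(p || p) = 0, so the soft term should vanish up to float error.
        assert!(loss.soft_loss(&logits, &logits).abs() < 1e-5);
    }

    #[test]
    fn alpha_zero_reduces_to_cross_entropy() {
        let loss = DistillationLoss::new(4.0, 0.0);
        let student = array![0.5_f32, 2.5, -1.0];
        let teacher = array![3.0_f32, 0.0, 0.0];
        // With alpha = 0 the teacher's soft targets carry no weight.
        let expected = cross_entropy_loss(&student, 1);
        let got = loss.forward_single(&student, &teacher, 1);
        assert!((got - expected).abs() < 1e-6);
    }
}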