use crate::distill::loss::{DistilLoss, softmax};
use crate::error::{QuantError, QuantResult};
/// Response-based knowledge distillation head.
///
/// Blends a hard-label cross-entropy term (against ground-truth class
/// indices) with a soft-label term computed between teacher and student
/// logits (e.g. temperature-scaled KL divergence), weighted by the two
/// fields below.
#[derive(Debug, Clone)]
pub struct ResponseDistiller {
    /// Loss applied between teacher and student logits (soft-label term).
    pub soft_loss: DistilLoss,
    /// Weight of the hard-label cross-entropy term.
    pub hard_label_weight: f32,
    /// Weight of the soft (teacher-matching) term.
    pub soft_label_weight: f32,
}
impl ResponseDistiller {
    /// Creates a distiller from a soft loss and the two mixing weights.
    #[must_use]
    pub fn new(soft_loss: DistilLoss, hard_label_weight: f32, soft_label_weight: f32) -> Self {
        Self {
            soft_loss,
            hard_label_weight,
            soft_label_weight,
        }
    }

    /// Pure soft-label distillation: KL divergence at `temperature`,
    /// with the hard-label term disabled (weight 0).
    #[must_use]
    pub fn pure_kl(temperature: f32) -> Self {
        Self::new(DistilLoss::kl_divergence(temperature), 0.0, 1.0)
    }

    /// Equal-weight mix of hard-label CE and KL divergence at temperature 4.0.
    #[must_use]
    pub fn balanced() -> Self {
        Self::new(DistilLoss::kl_divergence(4.0), 0.5, 0.5)
    }

    /// Combined distillation loss for a single example:
    /// `hard_label_weight * CE(student, hard_label)
    ///  + soft_label_weight * soft_loss(teacher, student)`.
    ///
    /// # Errors
    /// - [`QuantError::EmptyInput`] if `student_logits` is empty.
    /// - [`QuantError::TeacherStudentMismatch`] if the logit slices differ in length.
    /// - [`QuantError::DimensionMismatch`] if `hard_label` is out of range.
    pub fn compute_loss(
        &self,
        student_logits: &[f32],
        teacher_logits: &[f32],
        hard_label: usize,
    ) -> QuantResult<f32> {
        if student_logits.is_empty() {
            return Err(QuantError::EmptyInput(
                "ResponseDistiller: empty student logits",
            ));
        }
        if student_logits.len() != teacher_logits.len() {
            return Err(QuantError::TeacherStudentMismatch {
                teacher: teacher_logits.len(),
                student: student_logits.len(),
            });
        }
        if hard_label >= student_logits.len() {
            // `got` is reported as the minimum class count the label would need.
            return Err(QuantError::DimensionMismatch {
                expected: student_logits.len(),
                got: hard_label + 1,
            });
        }
        let soft = self.soft_loss.compute(teacher_logits, student_logits)?;
        let probs = softmax(student_logits);
        // Clamp before ln() so a zero probability yields a large finite CE
        // instead of +inf (which would poison the weighted sum).
        let ce = -(probs[hard_label].max(1e-12).ln());
        Ok(self.hard_label_weight * ce + self.soft_label_weight * soft)
    }

    /// Mean per-example [`Self::compute_loss`] over a batch stored row-major:
    /// example `b` occupies `[b * n_classes, (b + 1) * n_classes)` in both
    /// `student_batch` and `teacher_batch`.
    ///
    /// # Errors
    /// - [`QuantError::EmptyInput`] if `hard_labels` is empty or `n_classes` is 0
    ///   (previously a zero `n_classes` surfaced a misleading per-example
    ///   "empty student logits" error).
    /// - [`QuantError::DimensionMismatch`] if either batch slice is not exactly
    ///   `batch_size * n_classes` long.
    /// - Any error from [`Self::compute_loss`] on an individual example.
    pub fn compute_batch_loss(
        &self,
        student_batch: &[f32],
        teacher_batch: &[f32],
        hard_labels: &[usize],
        n_classes: usize,
    ) -> QuantResult<f32> {
        let batch_size = hard_labels.len();
        if batch_size == 0 || n_classes == 0 {
            return Err(QuantError::EmptyInput(
                "ResponseDistiller::compute_batch_loss",
            ));
        }
        if student_batch.len() != batch_size * n_classes {
            return Err(QuantError::DimensionMismatch {
                expected: batch_size * n_classes,
                got: student_batch.len(),
            });
        }
        if teacher_batch.len() != batch_size * n_classes {
            return Err(QuantError::DimensionMismatch {
                expected: batch_size * n_classes,
                got: teacher_batch.len(),
            });
        }
        // Lazy fold over per-example rows: no intermediate Vec, and the first
        // per-example error short-circuits exactly as the old collect() did.
        // `chunks_exact` is safe here because n_classes > 0 was checked above.
        let total = student_batch
            .chunks_exact(n_classes)
            .zip(teacher_batch.chunks_exact(n_classes))
            .zip(hard_labels)
            .try_fold(0.0_f32, |acc, ((s, t), &label)| -> QuantResult<f32> {
                Ok(acc + self.compute_loss(s, t, label)?)
            })?;
        Ok(total / batch_size as f32)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Identical teacher/student logits should give a (near-)zero KL loss.
    #[test]
    fn pure_kl_zero_when_student_equals_teacher() {
        let distiller = ResponseDistiller::pure_kl(1.0);
        let logits = [1.0_f32, 2.0, 3.0];
        let loss = distiller.compute_loss(&logits, &logits, 0).unwrap();
        assert!(
            loss.abs() < 1e-3,
            "pure KL with equal logits ≈ 0, got {loss}"
        );
    }

    /// With only the hard-label term active, a confident correct prediction
    /// has near-zero cross-entropy regardless of the teacher.
    #[test]
    fn hard_label_only() {
        let distiller = ResponseDistiller::new(DistilLoss::mse(), 1.0, 0.0);
        let student = [0.0_f32, 0.0, 10.0];
        let teacher = [0.0_f32, 0.0, 0.0];
        let loss = distiller.compute_loss(&student, &teacher, 2).unwrap();
        assert!(
            loss < 0.1,
            "CE for correct confident prediction ≈ 0, got {loss}"
        );
    }

    /// A label index equal to the class count is rejected.
    #[test]
    fn hard_label_out_of_range_error() {
        let distiller = ResponseDistiller::pure_kl(1.0);
        let logits = [1.0_f32, 2.0, 3.0];
        let result = distiller.compute_loss(&logits, &logits, 3);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    /// Mismatched logit lengths produce the dedicated mismatch error.
    #[test]
    fn teacher_student_mismatch_error() {
        let distiller = ResponseDistiller::pure_kl(1.0);
        let teacher = vec![1.0_f32; 3];
        let student = vec![1.0_f32; 4];
        let result = distiller.compute_loss(&student, &teacher, 0);
        assert!(matches!(
            result,
            Err(QuantError::TeacherStudentMismatch { .. })
        ));
    }

    /// Identical batches under a pure MSE soft loss average to zero.
    #[test]
    fn batch_loss_average() {
        let distiller = ResponseDistiller::new(DistilLoss::mse(), 0.0, 1.0);
        let (n_classes, batch_size) = (3, 4);
        let teacher = vec![1.0_f32; batch_size * n_classes];
        let student = teacher.clone();
        let labels = vec![0_usize; batch_size];
        let loss = distiller
            .compute_batch_loss(&student, &teacher, &labels, n_classes)
            .unwrap();
        assert_abs_diff_eq!(loss, 0.0, epsilon = 1e-5);
    }

    /// `balanced()` splits the weights 50/50.
    #[test]
    fn balanced_distiller_construction() {
        let distiller = ResponseDistiller::balanced();
        assert_abs_diff_eq!(distiller.hard_label_weight, 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(distiller.soft_label_weight, 0.5, epsilon = 1e-7);
    }
}