// oxicuda_quant/distill/response.rs
//! # Response-Based Knowledge Distillation
//!
//! Distills the teacher's output logit distribution to the student.  The total
//! training loss combines a hard-label cross-entropy term with a soft-label
//! distillation term:
//!
//! ```text
//! L = α × CE(student_logits, hard_labels) + β × distil_loss(teacher, student)
//! ```
//!
//! Setting `hard_label_weight = 0` performs pure distillation without ground-truth
//! labels.

use crate::distill::loss::{DistilLoss, softmax};
use crate::error::{QuantError, QuantResult};

// ─── ResponseDistiller ───────────────────────────────────────────────────────

19/// Response-based knowledge distillation.
20///
21/// Combines a hard-label cross-entropy loss with a soft-label distillation loss.
22#[derive(Debug, Clone)]
23pub struct ResponseDistiller {
24    /// Distillation loss applied to soft targets.
25    pub soft_loss: DistilLoss,
26    /// Weight for the hard-label (cross-entropy) term.
27    pub hard_label_weight: f32,
28    /// Weight for the soft-label (distillation) term.
29    pub soft_label_weight: f32,
30}
32impl ResponseDistiller {
33    /// Create a response distiller.
34    ///
35    /// # Parameters
36    ///
37    /// * `soft_loss`          — distillation loss (e.g., KL divergence with temperature).
38    /// * `hard_label_weight`  — weight α for cross-entropy term.
39    /// * `soft_label_weight`  — weight β for distillation term.
40    #[must_use]
41    pub fn new(soft_loss: DistilLoss, hard_label_weight: f32, soft_label_weight: f32) -> Self {
42        Self {
43            soft_loss,
44            hard_label_weight,
45            soft_label_weight,
46        }
47    }
48
49    /// Pure distillation (no hard labels): KL divergence at temperature `tau`.
50    #[must_use]
51    pub fn pure_kl(temperature: f32) -> Self {
52        Self::new(DistilLoss::kl_divergence(temperature), 0.0, 1.0)
53    }
54
55    /// Combined distillation: `0.5 × CE + 0.5 × KL(τ=4)`.
56    #[must_use]
57    pub fn balanced() -> Self {
58        Self::new(DistilLoss::kl_divergence(4.0), 0.5, 0.5)
59    }
60
61    /// Compute the combined distillation loss.
62    ///
63    /// # Parameters
64    ///
65    /// * `student_logits`  — unnormalised student output (length = n_classes).
66    /// * `teacher_logits`  — unnormalised teacher output (same length).
67    /// * `hard_label`      — integer ground-truth class index.
68    ///
69    /// # Errors
70    ///
71    /// * [`QuantError::EmptyInput`]            — either logit slice is empty.
72    /// * [`QuantError::TeacherStudentMismatch`] — logit slices differ in length.
73    /// * [`QuantError::DimensionMismatch`]     — `hard_label` ≥ `n_classes`.
74    pub fn compute_loss(
75        &self,
76        student_logits: &[f32],
77        teacher_logits: &[f32],
78        hard_label: usize,
79    ) -> QuantResult<f32> {
80        if student_logits.is_empty() {
81            return Err(QuantError::EmptyInput(
82                "ResponseDistiller: empty student logits",
83            ));
84        }
85        if student_logits.len() != teacher_logits.len() {
86            return Err(QuantError::TeacherStudentMismatch {
87                teacher: teacher_logits.len(),
88                student: student_logits.len(),
89            });
90        }
91        if hard_label >= student_logits.len() {
92            return Err(QuantError::DimensionMismatch {
93                expected: student_logits.len(),
94                got: hard_label + 1,
95            });
96        }
97
98        // Soft loss component.
99        let soft = self.soft_loss.compute(teacher_logits, student_logits)?;
100
101        // Hard cross-entropy: -log(softmax(student)[hard_label])
102        let probs = softmax(student_logits);
103        let ce = -(probs[hard_label].max(1e-12).ln());
104
105        Ok(self.hard_label_weight * ce + self.soft_label_weight * soft)
106    }
107
108    /// Compute distillation loss over a batch of examples.
109    ///
110    /// Returns the average loss over all examples in the batch.
111    ///
112    /// # Parameters
113    ///
114    /// * `student_batch`   — `[batch_size, n_classes]` row-major student logits.
115    /// * `teacher_batch`   — `[batch_size, n_classes]` row-major teacher logits.
116    /// * `hard_labels`     — `[batch_size]` integer class labels.
117    /// * `n_classes`       — number of output classes.
118    ///
119    /// # Errors
120    ///
121    /// Propagates dimension and empty-input errors.
122    pub fn compute_batch_loss(
123        &self,
124        student_batch: &[f32],
125        teacher_batch: &[f32],
126        hard_labels: &[usize],
127        n_classes: usize,
128    ) -> QuantResult<f32> {
129        let batch_size = hard_labels.len();
130        if batch_size == 0 {
131            return Err(QuantError::EmptyInput(
132                "ResponseDistiller::compute_batch_loss",
133            ));
134        }
135        if student_batch.len() != batch_size * n_classes {
136            return Err(QuantError::DimensionMismatch {
137                expected: batch_size * n_classes,
138                got: student_batch.len(),
139            });
140        }
141        if teacher_batch.len() != batch_size * n_classes {
142            return Err(QuantError::DimensionMismatch {
143                expected: batch_size * n_classes,
144                got: teacher_batch.len(),
145            });
146        }
147        let losses: QuantResult<Vec<f32>> = (0..batch_size)
148            .map(|b| {
149                let s = &student_batch[b * n_classes..(b + 1) * n_classes];
150                let t = &teacher_batch[b * n_classes..(b + 1) * n_classes];
151                self.compute_loss(s, t, hard_labels[b])
152            })
153            .collect();
154        let total: f32 = losses?.iter().sum();
155        Ok(total / batch_size as f32)
156    }
157}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn pure_kl_zero_when_student_equals_teacher() {
        // Identical teacher/student logits with zero hard-label weight ⇒ loss ≈ 0.
        let distiller = ResponseDistiller::pure_kl(1.0);
        let shared = vec![1.0_f32, 2.0, 3.0];
        let loss = distiller.compute_loss(&shared, &shared, 0).unwrap();
        assert!(
            loss.abs() < 1e-3,
            "pure KL with equal logits ≈ 0, got {loss}"
        );
    }

    #[test]
    fn hard_label_only() {
        // With soft_label_weight = 0 only the cross-entropy term contributes.
        let distiller = ResponseDistiller::new(DistilLoss::mse(), 1.0, 0.0);
        let student_logits = vec![0.0_f32, 0.0, 10.0]; // strongly predicts class 2
        let teacher_logits = vec![0.0_f32, 0.0, 0.0]; // irrelevant at weight 0
        let loss = distiller
            .compute_loss(&student_logits, &teacher_logits, 2)
            .unwrap();
        assert!(
            loss < 0.1,
            "CE for correct confident prediction ≈ 0, got {loss}"
        );
    }

    #[test]
    fn hard_label_out_of_range_error() {
        // A label index equal to n_classes must be rejected.
        let distiller = ResponseDistiller::pure_kl(1.0);
        let logits = vec![1.0_f32, 2.0, 3.0];
        let result = distiller.compute_loss(&logits, &logits, 3);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn teacher_student_mismatch_error() {
        // Differing logit lengths must produce a mismatch error.
        let distiller = ResponseDistiller::pure_kl(1.0);
        let teacher = vec![1.0_f32; 3];
        let student = vec![1.0_f32; 4];
        let result = distiller.compute_loss(&student, &teacher, 0);
        assert!(matches!(
            result,
            Err(QuantError::TeacherStudentMismatch { .. })
        ));
    }

    #[test]
    fn batch_loss_average() {
        // Identical batches under pure MSE distillation average to zero loss.
        let distiller = ResponseDistiller::new(DistilLoss::mse(), 0.0, 1.0);
        let n_classes = 3;
        let batch_size = 4;
        let teacher = vec![1.0_f32; batch_size * n_classes];
        let student = vec![1.0_f32; batch_size * n_classes];
        let labels = vec![0_usize; batch_size];
        let loss = distiller
            .compute_batch_loss(&student, &teacher, &labels, n_classes)
            .unwrap();
        assert_abs_diff_eq!(loss, 0.0, epsilon = 1e-5);
    }

    #[test]
    fn balanced_distiller_construction() {
        // `balanced()` splits the weight evenly between hard and soft terms.
        let distiller = ResponseDistiller::balanced();
        assert_abs_diff_eq!(distiller.hard_label_weight, 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(distiller.soft_label_weight, 0.5, epsilon = 1e-7);
    }
}