oxicuda_quant/distill/
response.rs1use crate::distill::loss::{DistilLoss, softmax};
15use crate::error::{QuantError, QuantResult};
16
/// Response-based knowledge distillation: combines a hard-label
/// cross-entropy term on the student's own logits with a soft-label
/// loss measured against the teacher's logits.
///
/// The final loss is
/// `hard_label_weight * CE + soft_label_weight * soft_loss`.
#[derive(Debug, Clone)]
pub struct ResponseDistiller {
    /// Soft-label loss between teacher and student logits
    /// (e.g. temperature-scaled KL divergence).
    pub soft_loss: DistilLoss,
    /// Weight applied to the hard-label cross-entropy term.
    pub hard_label_weight: f32,
    /// Weight applied to the soft (teacher-matching) loss term.
    pub soft_label_weight: f32,
}
31
32impl ResponseDistiller {
33 #[must_use]
41 pub fn new(soft_loss: DistilLoss, hard_label_weight: f32, soft_label_weight: f32) -> Self {
42 Self {
43 soft_loss,
44 hard_label_weight,
45 soft_label_weight,
46 }
47 }
48
49 #[must_use]
51 pub fn pure_kl(temperature: f32) -> Self {
52 Self::new(DistilLoss::kl_divergence(temperature), 0.0, 1.0)
53 }
54
55 #[must_use]
57 pub fn balanced() -> Self {
58 Self::new(DistilLoss::kl_divergence(4.0), 0.5, 0.5)
59 }
60
61 pub fn compute_loss(
75 &self,
76 student_logits: &[f32],
77 teacher_logits: &[f32],
78 hard_label: usize,
79 ) -> QuantResult<f32> {
80 if student_logits.is_empty() {
81 return Err(QuantError::EmptyInput(
82 "ResponseDistiller: empty student logits",
83 ));
84 }
85 if student_logits.len() != teacher_logits.len() {
86 return Err(QuantError::TeacherStudentMismatch {
87 teacher: teacher_logits.len(),
88 student: student_logits.len(),
89 });
90 }
91 if hard_label >= student_logits.len() {
92 return Err(QuantError::DimensionMismatch {
93 expected: student_logits.len(),
94 got: hard_label + 1,
95 });
96 }
97
98 let soft = self.soft_loss.compute(teacher_logits, student_logits)?;
100
101 let probs = softmax(student_logits);
103 let ce = -(probs[hard_label].max(1e-12).ln());
104
105 Ok(self.hard_label_weight * ce + self.soft_label_weight * soft)
106 }
107
108 pub fn compute_batch_loss(
123 &self,
124 student_batch: &[f32],
125 teacher_batch: &[f32],
126 hard_labels: &[usize],
127 n_classes: usize,
128 ) -> QuantResult<f32> {
129 let batch_size = hard_labels.len();
130 if batch_size == 0 {
131 return Err(QuantError::EmptyInput(
132 "ResponseDistiller::compute_batch_loss",
133 ));
134 }
135 if student_batch.len() != batch_size * n_classes {
136 return Err(QuantError::DimensionMismatch {
137 expected: batch_size * n_classes,
138 got: student_batch.len(),
139 });
140 }
141 if teacher_batch.len() != batch_size * n_classes {
142 return Err(QuantError::DimensionMismatch {
143 expected: batch_size * n_classes,
144 got: teacher_batch.len(),
145 });
146 }
147 let losses: QuantResult<Vec<f32>> = (0..batch_size)
148 .map(|b| {
149 let s = &student_batch[b * n_classes..(b + 1) * n_classes];
150 let t = &teacher_batch[b * n_classes..(b + 1) * n_classes];
151 self.compute_loss(s, t, hard_labels[b])
152 })
153 .collect();
154 let total: f32 = losses?.iter().sum();
155 Ok(total / batch_size as f32)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn pure_kl_zero_when_student_equals_teacher() {
        let distiller = ResponseDistiller::pure_kl(1.0);
        let shared_logits = vec![1.0_f32, 2.0, 3.0];

        let loss = distiller
            .compute_loss(&shared_logits, &shared_logits, 0)
            .unwrap();

        assert!(
            loss.abs() < 1e-3,
            "pure KL with equal logits ≈ 0, got {loss}"
        );
    }

    #[test]
    fn hard_label_only() {
        // Soft term weighted to zero: only the CE term contributes.
        let distiller = ResponseDistiller::new(DistilLoss::mse(), 1.0, 0.0);
        let student = vec![0.0_f32, 0.0, 10.0];
        let teacher = vec![0.0_f32, 0.0, 0.0];

        let loss = distiller.compute_loss(&student, &teacher, 2).unwrap();

        assert!(
            loss < 0.1,
            "CE for correct confident prediction ≈ 0, got {loss}"
        );
    }

    #[test]
    fn hard_label_out_of_range_error() {
        let distiller = ResponseDistiller::pure_kl(1.0);
        let logits = vec![1.0_f32, 2.0, 3.0];

        let result = distiller.compute_loss(&logits, &logits, 3);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn teacher_student_mismatch_error() {
        let distiller = ResponseDistiller::pure_kl(1.0);
        let teacher = vec![1.0_f32; 3];
        let student = vec![1.0_f32; 4];

        let result = distiller.compute_loss(&student, &teacher, 0);
        assert!(matches!(
            result,
            Err(QuantError::TeacherStudentMismatch { .. })
        ));
    }

    #[test]
    fn batch_loss_average() {
        let distiller = ResponseDistiller::new(DistilLoss::mse(), 0.0, 1.0);
        let (n_classes, batch_size) = (3_usize, 4_usize);
        let teacher = vec![1.0_f32; batch_size * n_classes];
        let student = vec![1.0_f32; batch_size * n_classes];
        let labels = vec![0_usize; batch_size];

        let loss = distiller
            .compute_batch_loss(&student, &teacher, &labels, n_classes)
            .unwrap();

        assert_abs_diff_eq!(loss, 0.0, epsilon = 1e-5);
    }

    #[test]
    fn balanced_distiller_construction() {
        let distiller = ResponseDistiller::balanced();
        assert_abs_diff_eq!(distiller.hard_label_weight, 0.5, epsilon = 1e-7);
        assert_abs_diff_eq!(distiller.soft_label_weight, 0.5, epsilon = 1e-7);
    }
}