1use ndarray::{Array2, Axis};
4
/// Configuration for layer-wise distillation of hidden states combined with
/// logit distillation.
///
/// Construct it with [`ProgressiveDistiller::new`] or
/// [`ProgressiveDistiller::uniform`], which validate and normalize the weights.
#[derive(Debug, Clone, PartialEq)]
pub struct ProgressiveDistiller {
    /// Per-layer loss weights, one per hidden-state layer.
    ///
    /// `new` normalizes these to sum to 1.0. Building the struct literally
    /// bypasses that normalization and validation.
    pub layer_weights: Vec<f32>,
    /// Temperature passed to the logit-distillation loss in `combined_loss`.
    /// `new` requires it to be positive.
    pub temperature: f32,
}
34
35impl ProgressiveDistiller {
36 pub fn new(layer_weights: Vec<f32>, temperature: f32) -> Self {
47 assert!(!layer_weights.is_empty(), "Must have at least one layer weight");
48 assert!(temperature > 0.0, "Temperature must be positive, got {temperature}");
49
50 let sum: f32 = layer_weights.iter().sum();
51 assert!(sum > 0.0, "Layer weights must sum to positive value");
52
53 let normalized: Vec<f32> = layer_weights.iter().map(|&w| w / sum).collect();
55
56 Self { layer_weights: normalized, temperature }
57 }
58
59 pub fn uniform(num_layers: usize, temperature: f32) -> Self {
61 Self::new(vec![1.0; num_layers], temperature)
62 }
63
64 pub fn layer_wise_mse_loss(
75 &self,
76 student_hiddens: &[Array2<f32>],
77 teacher_hiddens: &[Array2<f32>],
78 ) -> f32 {
79 assert_eq!(
80 student_hiddens.len(),
81 teacher_hiddens.len(),
82 "Number of layers must match (student vs teacher)"
83 );
84 assert_eq!(
85 student_hiddens.len(),
86 self.layer_weights.len(),
87 "Number of layers must match (student vs weights)"
88 );
89
90 let mut total_loss = 0.0;
91
92 for ((student, teacher), &weight) in
93 student_hiddens.iter().zip(teacher_hiddens).zip(&self.layer_weights)
94 {
95 assert_eq!(
96 student.shape(),
97 teacher.shape(),
98 "Student and teacher hidden states must have same shape"
99 );
100
101 let mse = mse_loss(student, teacher);
102 total_loss += weight * mse;
103 }
104
105 total_loss
106 }
107
108 pub fn layer_wise_cosine_loss(
113 &self,
114 student_hiddens: &[Array2<f32>],
115 teacher_hiddens: &[Array2<f32>],
116 ) -> f32 {
117 assert_eq!(
118 student_hiddens.len(),
119 teacher_hiddens.len(),
120 "Number of layers must match (student vs teacher)"
121 );
122 assert_eq!(
123 student_hiddens.len(),
124 self.layer_weights.len(),
125 "Number of layers must match (student vs weights)"
126 );
127
128 let mut total_loss = 0.0;
129
130 for ((student, teacher), &weight) in
131 student_hiddens.iter().zip(teacher_hiddens).zip(&self.layer_weights)
132 {
133 assert_eq!(
134 student.shape(),
135 teacher.shape(),
136 "Student and teacher hidden states must have same shape"
137 );
138
139 let cos_sim = cosine_similarity(student, teacher);
141 total_loss += weight * (1.0 - cos_sim);
142 }
143
144 total_loss
145 }
146
147 #[allow(clippy::too_many_arguments)]
161 pub fn combined_loss(
162 &self,
163 student_logits: &Array2<f32>,
164 teacher_logits: &Array2<f32>,
165 student_hiddens: &[Array2<f32>],
166 teacher_hiddens: &[Array2<f32>],
167 labels: &[usize],
168 alpha: f32,
169 beta: f32,
170 ) -> f32 {
171 use super::loss::DistillationLoss;
172
173 let logit_loss = DistillationLoss::new(self.temperature, alpha);
175 let logit_distill = logit_loss.forward(student_logits, teacher_logits, labels);
176
177 let hidden_loss = self.layer_wise_cosine_loss(student_hiddens, teacher_hiddens);
179
180 (1.0 - beta) * logit_distill + beta * hidden_loss
182 }
183}
184
185fn mse_loss(student: &Array2<f32>, teacher: &Array2<f32>) -> f32 {
187 assert_eq!(student.shape(), teacher.shape());
188
189 let diff = student - teacher;
190 let squared = diff.mapv(|x| x * x);
191 squared.mean().unwrap_or(0.0)
192}
193
194fn cosine_similarity(student: &Array2<f32>, teacher: &Array2<f32>) -> f32 {
200 assert_eq!(student.shape(), teacher.shape());
201
202 let batch_size = student.nrows();
203 if batch_size == 0 {
204 return 0.0;
205 }
206
207 let mut total_sim = 0.0;
208
209 for (s_row, t_row) in student.axis_iter(Axis(0)).zip(teacher.axis_iter(Axis(0))) {
210 let dot: f32 = s_row.iter().zip(t_row.iter()).map(|(a, b)| a * b).sum();
211 let s_norm: f32 = s_row.iter().map(|x| x * x).sum::<f32>().sqrt();
212 let t_norm: f32 = t_row.iter().map(|x| x * x).sum::<f32>().sqrt();
213
214 if s_norm > 1e-10 && t_norm > 1e-10 {
215 total_sim += dot / (s_norm * t_norm);
216 }
217 }
218
219 total_sim / batch_size as f32
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_uniform_progressive() {
        let distiller = ProgressiveDistiller::uniform(3, 2.0);
        assert_eq!(distiller.layer_weights.len(), 3);
        assert_relative_eq!(distiller.layer_weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);
        for &w in &distiller.layer_weights {
            assert_relative_eq!(w, 1.0 / 3.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_weighted_progressive() {
        let distiller = ProgressiveDistiller::new(vec![1.0, 2.0, 3.0], 2.0);
        assert_relative_eq!(distiller.layer_weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);

        // Normalization must preserve the 1:2:3 ratio.
        let expected: [f32; 3] = [1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0];
        assert_eq!(distiller.layer_weights.len(), expected.len());
        for (&w, e) in distiller.layer_weights.iter().zip(expected) {
            assert_relative_eq!(w, e, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_mse_loss_zero_for_identical() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let mse = mse_loss(&a, &a);
        assert_relative_eq!(mse, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_mse_loss_positive() {
        let a = array![[1.0, 2.0, 3.0]];
        let b = array![[2.0, 3.0, 4.0]];
        let mse = mse_loss(&a, &b);
        assert_relative_eq!(mse, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_one_for_identical() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let cos = cosine_similarity(&a, &a);
        assert_relative_eq!(cos, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_for_orthogonal() {
        let a = array![[1.0, 0.0]];
        let b = array![[0.0, 1.0]];
        let cos = cosine_similarity(&a, &b);
        assert_relative_eq!(cos, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_positive() {
        let a = array![[1.0, 2.0, 3.0]];
        let b = array![[2.0, 4.0, 6.0]];
        let cos = cosine_similarity(&a, &b);
        assert_relative_eq!(cos, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_negative_one_for_opposite() {
        let a = array![[1.0, 0.0]];
        let b = array![[-1.0, 0.0]];
        let cos = cosine_similarity(&a, &b);
        assert_relative_eq!(cos, -1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_norm_row_contributes_zero() {
        // Row 0 has a zero student vector (contributes 0); row 1 is identical (1).
        let a = array![[0.0, 0.0], [1.0, 0.0]];
        let b = array![[1.0, 0.0], [1.0, 0.0]];
        let cos = cosine_similarity(&a, &b);
        assert_relative_eq!(cos, 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_layer_wise_mse_loss() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);

        let student_hiddens = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0, 6.0], [7.0, 8.0]]];
        let teacher_hiddens = vec![array![[1.1, 2.1], [3.1, 4.1]], array![[5.1, 6.1], [7.1, 8.1]]];

        let loss = distiller.layer_wise_mse_loss(&student_hiddens, &teacher_hiddens);
        assert!(loss.is_finite());
        // Every element differs by 0.1, so each layer's MSE is 0.01 and the
        // uniform (0.5, 0.5) weighting keeps the total at 0.01.
        assert_relative_eq!(loss, 0.01, epsilon = 1e-5);
    }

    #[test]
    fn test_layer_wise_cosine_loss() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);

        let student_hiddens = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0, 6.0], [7.0, 8.0]]];
        let teacher_hiddens = vec![array![[1.1, 2.1], [3.1, 4.1]], array![[5.1, 6.1], [7.1, 8.1]]];

        let loss = distiller.layer_wise_cosine_loss(&student_hiddens, &teacher_hiddens);
        assert!(loss >= 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_combined_loss() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);

        let student_logits = array![[2.0, 1.0, 0.5]];
        let teacher_logits = array![[1.8, 1.1, 0.6]];

        let student_hiddens = vec![array![[1.0, 2.0]], array![[3.0, 4.0]]];
        let teacher_hiddens = vec![array![[1.1, 2.1]], array![[3.1, 4.1]]];

        let labels = vec![0];

        let loss = distiller.combined_loss(
            &student_logits,
            &teacher_logits,
            &student_hiddens,
            &teacher_hiddens,
            &labels,
            0.7,
            0.3,
        );

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    #[should_panic(expected = "Must have at least one layer weight")]
    fn test_empty_layers_panics() {
        ProgressiveDistiller::new(vec![], 2.0);
    }

    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_invalid_temperature_panics() {
        ProgressiveDistiller::new(vec![1.0], 0.0);
    }

    #[test]
    #[should_panic(expected = "Number of layers must match")]
    fn test_mismatched_layers_panics() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);
        let student = vec![array![[1.0, 2.0]]];
        let teacher = vec![array![[1.0, 2.0]], array![[3.0, 4.0]]];
        distiller.layer_wise_mse_loss(&student, &teacher);
    }

    #[test]
    #[should_panic(expected = "Student and teacher hidden states must have same shape")]
    fn test_mismatched_shapes_panics() {
        let distiller = ProgressiveDistiller::uniform(1, 2.0);
        let student = vec![array![[1.0, 2.0]]];
        let teacher = vec![array![[1.0, 2.0, 3.0]]];
        distiller.layer_wise_mse_loss(&student, &teacher);
    }
}