Skip to main content

entrenar/distill/
progressive.rs

1//! Progressive layer-wise distillation
2
3use ndarray::{Array2, Axis};
4
/// Progressive Layer-wise Distillation
///
/// Distills knowledge progressively from intermediate layers of the teacher
/// to corresponding layers of the student. This helps the student learn
/// better intermediate representations, not just final predictions.
///
/// # Approach
///
/// Instead of only matching final logits, progressive distillation also
/// matches intermediate layer outputs (hidden states, attention weights, etc.)
/// between teacher and student at multiple depths.
///
/// # Example
///
/// ```
/// use entrenar::distill::ProgressiveDistiller;
///
/// let distiller = ProgressiveDistiller::new(vec![1.0, 1.0, 2.0], 2.0);
///
/// // Match intermediate layers (e.g., layer 3, 6, 9 of teacher to student)
/// // let loss = distiller.layer_wise_mse_loss(&student_hiddens, &teacher_hiddens);
/// // let loss = distiller.layer_wise_cosine_loss(&student_hiddens, &teacher_hiddens);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct ProgressiveDistiller {
    /// Weight for each layer's distillation loss.
    ///
    /// `ProgressiveDistiller::new` stores these normalized so they sum to 1.
    /// The field is public, so values assigned directly are used as-is.
    pub layer_weights: Vec<f32>,
    /// Temperature for distillation (validated as strictly positive by `new`).
    pub temperature: f32,
}
34
35impl ProgressiveDistiller {
36    /// Create a new progressive distiller
37    ///
38    /// # Arguments
39    ///
40    /// * `layer_weights` - Weight for each layer (will be normalized)
41    /// * `temperature` - Temperature for softening distributions
42    ///
43    /// # Panics
44    ///
45    /// Panics if layer_weights is empty or temperature <= 0
46    pub fn new(layer_weights: Vec<f32>, temperature: f32) -> Self {
47        assert!(!layer_weights.is_empty(), "Must have at least one layer weight");
48        assert!(temperature > 0.0, "Temperature must be positive, got {temperature}");
49
50        let sum: f32 = layer_weights.iter().sum();
51        assert!(sum > 0.0, "Layer weights must sum to positive value");
52
53        // Normalize weights
54        let normalized: Vec<f32> = layer_weights.iter().map(|&w| w / sum).collect();
55
56        Self { layer_weights: normalized, temperature }
57    }
58
59    /// Create progressive distiller with uniform layer weights
60    pub fn uniform(num_layers: usize, temperature: f32) -> Self {
61        Self::new(vec![1.0; num_layers], temperature)
62    }
63
64    /// Compute layer-wise MSE loss between student and teacher hidden states
65    ///
66    /// # Arguments
67    ///
68    /// * `student_hiddens` - Hidden states from student layers `[num_layers]`
69    /// * `teacher_hiddens` - Hidden states from teacher layers `[num_layers]`
70    ///
71    /// # Returns
72    ///
73    /// Weighted MSE loss across all layers
74    pub fn layer_wise_mse_loss(
75        &self,
76        student_hiddens: &[Array2<f32>],
77        teacher_hiddens: &[Array2<f32>],
78    ) -> f32 {
79        assert_eq!(
80            student_hiddens.len(),
81            teacher_hiddens.len(),
82            "Number of layers must match (student vs teacher)"
83        );
84        assert_eq!(
85            student_hiddens.len(),
86            self.layer_weights.len(),
87            "Number of layers must match (student vs weights)"
88        );
89
90        let mut total_loss = 0.0;
91
92        for ((student, teacher), &weight) in
93            student_hiddens.iter().zip(teacher_hiddens).zip(&self.layer_weights)
94        {
95            assert_eq!(
96                student.shape(),
97                teacher.shape(),
98                "Student and teacher hidden states must have same shape"
99            );
100
101            let mse = mse_loss(student, teacher);
102            total_loss += weight * mse;
103        }
104
105        total_loss
106    }
107
108    /// Compute layer-wise cosine similarity loss
109    ///
110    /// Encourages student representations to have similar direction to teacher,
111    /// which can be more robust than MSE.
112    pub fn layer_wise_cosine_loss(
113        &self,
114        student_hiddens: &[Array2<f32>],
115        teacher_hiddens: &[Array2<f32>],
116    ) -> f32 {
117        assert_eq!(
118            student_hiddens.len(),
119            teacher_hiddens.len(),
120            "Number of layers must match (student vs teacher)"
121        );
122        assert_eq!(
123            student_hiddens.len(),
124            self.layer_weights.len(),
125            "Number of layers must match (student vs weights)"
126        );
127
128        let mut total_loss = 0.0;
129
130        for ((student, teacher), &weight) in
131            student_hiddens.iter().zip(teacher_hiddens).zip(&self.layer_weights)
132        {
133            assert_eq!(
134                student.shape(),
135                teacher.shape(),
136                "Student and teacher hidden states must have same shape"
137            );
138
139            // Cosine loss = 1 - cosine_similarity
140            let cos_sim = cosine_similarity(student, teacher);
141            total_loss += weight * (1.0 - cos_sim);
142        }
143
144        total_loss
145    }
146
147    /// Combined progressive distillation loss
148    ///
149    /// Combines final logit distillation with intermediate layer matching.
150    ///
151    /// # Arguments
152    ///
153    /// * `student_logits` - Final logits from student
154    /// * `teacher_logits` - Final logits from teacher
155    /// * `student_hiddens` - Intermediate hidden states from student
156    /// * `teacher_hiddens` - Intermediate hidden states from teacher
157    /// * `labels` - Ground truth labels
158    /// * `alpha` - Weight for distillation vs hard loss
159    /// * `beta` - Weight for hidden state matching vs logit matching
160    #[allow(clippy::too_many_arguments)]
161    pub fn combined_loss(
162        &self,
163        student_logits: &Array2<f32>,
164        teacher_logits: &Array2<f32>,
165        student_hiddens: &[Array2<f32>],
166        teacher_hiddens: &[Array2<f32>],
167        labels: &[usize],
168        alpha: f32,
169        beta: f32,
170    ) -> f32 {
171        use super::loss::DistillationLoss;
172
173        // Final logit distillation
174        let logit_loss = DistillationLoss::new(self.temperature, alpha);
175        let logit_distill = logit_loss.forward(student_logits, teacher_logits, labels);
176
177        // Intermediate layer matching
178        let hidden_loss = self.layer_wise_cosine_loss(student_hiddens, teacher_hiddens);
179
180        // Combine
181        (1.0 - beta) * logit_distill + beta * hidden_loss
182    }
183}
184
185/// Compute MSE loss between two arrays
186fn mse_loss(student: &Array2<f32>, teacher: &Array2<f32>) -> f32 {
187    assert_eq!(student.shape(), teacher.shape());
188
189    let diff = student - teacher;
190    let squared = diff.mapv(|x| x * x);
191    squared.mean().unwrap_or(0.0)
192}
193
194/// Compute cosine similarity between two arrays
195///
196/// cosine_sim(a, b) = (a ยท b) / (||a|| * ||b||)
197///
198/// Averaged over batch dimension.
199fn cosine_similarity(student: &Array2<f32>, teacher: &Array2<f32>) -> f32 {
200    assert_eq!(student.shape(), teacher.shape());
201
202    let batch_size = student.nrows();
203    if batch_size == 0 {
204        return 0.0;
205    }
206
207    let mut total_sim = 0.0;
208
209    for (s_row, t_row) in student.axis_iter(Axis(0)).zip(teacher.axis_iter(Axis(0))) {
210        let dot: f32 = s_row.iter().zip(t_row.iter()).map(|(a, b)| a * b).sum();
211        let s_norm: f32 = s_row.iter().map(|x| x * x).sum::<f32>().sqrt();
212        let t_norm: f32 = t_row.iter().map(|x| x * x).sum::<f32>().sqrt();
213
214        if s_norm > 1e-10 && t_norm > 1e-10 {
215            total_sim += dot / (s_norm * t_norm);
216        }
217    }
218
219    total_sim / batch_size as f32
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Two-layer student/teacher hidden states (2x2 per layer); each teacher
    /// value is the matching student value plus 0.1.
    fn two_layer_hiddens() -> (Vec<Array2<f32>>, Vec<Array2<f32>>) {
        let student = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0, 6.0], [7.0, 8.0]]];
        let teacher = vec![array![[1.1, 2.1], [3.1, 4.1]], array![[5.1, 6.1], [7.1, 8.1]]];
        (student, teacher)
    }

    // --- Construction -----------------------------------------------------

    #[test]
    fn test_uniform_progressive() {
        let distiller = ProgressiveDistiller::uniform(3, 2.0);
        let weights = &distiller.layer_weights;

        assert_eq!(weights.len(), 3);
        assert_relative_eq!(weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);
        for &weight in weights {
            assert_relative_eq!(weight, 1.0 / 3.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_weighted_progressive() {
        let distiller = ProgressiveDistiller::new(vec![1.0, 2.0, 3.0], 2.0);
        let total: f32 = distiller.layer_weights.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
    }

    // --- Free helper functions ---------------------------------------------

    #[test]
    fn test_mse_loss_zero_for_identical() {
        let hidden = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert_relative_eq!(mse_loss(&hidden, &hidden), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_mse_loss_positive() {
        let student = array![[1.0, 2.0, 3.0]];
        let teacher = array![[2.0, 3.0, 4.0]];

        let mse = mse_loss(&student, &teacher);

        assert!(mse > 0.0);
        // MSE = mean((1,1,1)^2) = 1.0
        assert_relative_eq!(mse, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_one_for_identical() {
        let hidden = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert_relative_eq!(cosine_similarity(&hidden, &hidden), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_for_orthogonal() {
        let x_axis = array![[1.0, 0.0]];
        let y_axis = array![[0.0, 1.0]];
        assert_relative_eq!(cosine_similarity(&x_axis, &y_axis), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_cosine_similarity_positive() {
        let base = array![[1.0, 2.0, 3.0]];
        let scaled = array![[2.0, 4.0, 6.0]]; // Same direction, scaled
        assert_relative_eq!(cosine_similarity(&base, &scaled), 1.0, epsilon = 1e-6);
    }

    // --- Layer-wise losses -------------------------------------------------

    #[test]
    fn test_layer_wise_mse_loss() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);
        let (student_hiddens, teacher_hiddens) = two_layer_hiddens();

        let loss = distiller.layer_wise_mse_loss(&student_hiddens, &teacher_hiddens);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_layer_wise_cosine_loss() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);
        let (student_hiddens, teacher_hiddens) = two_layer_hiddens();

        let loss = distiller.layer_wise_cosine_loss(&student_hiddens, &teacher_hiddens);

        assert!(loss >= 0.0); // Cosine loss should be >= 0
        assert!(loss.is_finite());
    }

    #[test]
    fn test_combined_loss() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);

        let student_logits = array![[2.0, 1.0, 0.5]];
        let teacher_logits = array![[1.8, 1.1, 0.6]];
        let student_hiddens = vec![array![[1.0, 2.0]], array![[3.0, 4.0]]];
        let teacher_hiddens = vec![array![[1.1, 2.1]], array![[3.1, 4.1]]];
        let labels = [0_usize];

        let alpha = 0.7;
        let beta = 0.3;
        let loss = distiller.combined_loss(
            &student_logits,
            &teacher_logits,
            &student_hiddens,
            &teacher_hiddens,
            &labels,
            alpha,
            beta,
        );

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    // --- Precondition violations -------------------------------------------

    #[test]
    #[should_panic(expected = "Must have at least one layer weight")]
    fn test_empty_layers_panics() {
        let _ = ProgressiveDistiller::new(Vec::new(), 2.0);
    }

    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_invalid_temperature_panics() {
        let _ = ProgressiveDistiller::new(vec![1.0], 0.0);
    }

    #[test]
    #[should_panic(expected = "Number of layers must match")]
    fn test_mismatched_layers_panics() {
        let distiller = ProgressiveDistiller::uniform(2, 2.0);
        let one_layer = vec![array![[1.0, 2.0]]]; // 1 layer
        let two_layers = vec![array![[1.0, 2.0]], array![[3.0, 4.0]]]; // 2 layers
        let _ = distiller.layer_wise_mse_loss(&one_layer, &two_layers);
    }
}