Skip to main content

entrenar/hf_pipeline/distillation/
progressive.rs

//! Progressive Knowledge Distillation
//!
//! Matches student hidden states to teacher hidden states at selected layers.
//! Based on Sun et al. (2019).
5
6use ndarray::Array2;
7
8/// Progressive Knowledge Distillation
9///
10/// Matches student hidden states to teacher hidden states at selected layers.
11/// Based on Sun et al. (2019).
12#[derive(Debug, Clone)]
13pub struct ProgressiveDistillation {
14    /// Layer mapping: (student_layer, teacher_layer)
15    pub layer_mapping: Vec<(usize, usize)>,
16    /// Loss weight for hidden state matching
17    pub hidden_weight: f32,
18    /// Projection matrix for dimension alignment (student_dim x teacher_dim)
19    /// Used when student hidden size differs from teacher hidden size.
20    pub projection: Option<Array2<f32>>,
21}
22
23impl Default for ProgressiveDistillation {
24    fn default() -> Self {
25        Self {
26            layer_mapping: vec![(0, 2), (1, 5), (2, 8), (3, 11)],
27            hidden_weight: 1.0,
28            projection: None,
29        }
30    }
31}
32
33impl ProgressiveDistillation {
34    /// Create new progressive distillation config
35    #[must_use]
36    pub fn new(layer_mapping: Vec<(usize, usize)>) -> Self {
37        Self { layer_mapping, hidden_weight: 1.0, projection: None }
38    }
39
40    /// Set projection layer for dimension alignment
41    ///
42    /// Creates a linear projection matrix to align student hidden states
43    /// to teacher hidden size. Initialized with Xavier uniform.
44    ///
45    /// # Arguments
46    ///
47    /// * `student_dim` - Student model hidden dimension
48    /// * `teacher_dim` - Teacher model hidden dimension
49    #[must_use]
50    pub fn with_projection(mut self, student_dim: usize, teacher_dim: usize) -> Self {
51        use rand::Rng;
52
53        // Xavier uniform initialization
54        let scale = (6.0 / (student_dim + teacher_dim) as f32).sqrt();
55        let mut rng = rand::rng();
56
57        let projection =
58            Array2::from_shape_fn((student_dim, teacher_dim), |_| rng.random_range(-scale..scale));
59
60        self.projection = Some(projection);
61        self
62    }
63
64    /// Set hidden state loss weight
65    #[must_use]
66    pub fn with_weight(mut self, weight: f32) -> Self {
67        self.hidden_weight = weight;
68        self
69    }
70
71    /// Compute hidden state matching loss
72    ///
73    /// Uses MSE loss between projected student and teacher hidden states.
74    /// If projection layer is set and shapes differ, projects student to teacher dimension.
75    pub fn hidden_state_loss(
76        &self,
77        student_hidden: &[Array2<f32>],
78        teacher_hidden: &[Array2<f32>],
79    ) -> f32 {
80        let mut total_loss = 0.0;
81        let mut count = 0;
82
83        for (s_idx, t_idx) in &self.layer_mapping {
84            if *s_idx < student_hidden.len() && *t_idx < teacher_hidden.len() {
85                let s_h = &student_hidden[*s_idx];
86                let t_h = &teacher_hidden[*t_idx];
87
88                // MSE loss - project student if dimensions differ
89                if s_h.dim() == t_h.dim() {
90                    // Same dimensions: direct MSE
91                    let diff = s_h - t_h;
92                    let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
93                    total_loss += mse;
94                    count += 1;
95                } else if let Some(ref proj) = self.projection {
96                    // Different dimensions: project student to teacher space
97                    // s_h: (batch, student_dim), proj: (student_dim, teacher_dim)
98                    // result: (batch, teacher_dim)
99                    let s_dim = s_h.shape()[1];
100                    let t_dim = t_h.shape()[1];
101
102                    // Verify projection dimensions match
103                    if proj.shape() == [s_dim, t_dim] {
104                        let projected = s_h.dot(proj);
105                        let diff = &projected - t_h;
106                        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
107                        total_loss += mse;
108                        count += 1;
109                    }
110                }
111                // Skip if shapes differ and no projection is set
112            }
113        }
114
115        if count > 0 {
116            self.hidden_weight * total_loss / count as f32
117        } else {
118            0.0
119        }
120    }
121}