entrenar/hf_pipeline/distillation/progressive.rs
1//! Progressive Knowledge Distillation
2//!
3//! Matches student hidden states to teacher hidden states at selected layers.
4//! Based on Sun et al. (2019).
5
6use ndarray::Array2;
7
8/// Progressive Knowledge Distillation
9///
10/// Matches student hidden states to teacher hidden states at selected layers.
11/// Based on Sun et al. (2019).
12#[derive(Debug, Clone)]
13pub struct ProgressiveDistillation {
14 /// Layer mapping: (student_layer, teacher_layer)
15 pub layer_mapping: Vec<(usize, usize)>,
16 /// Loss weight for hidden state matching
17 pub hidden_weight: f32,
18 /// Projection matrix for dimension alignment (student_dim x teacher_dim)
19 /// Used when student hidden size differs from teacher hidden size.
20 pub projection: Option<Array2<f32>>,
21}
22
23impl Default for ProgressiveDistillation {
24 fn default() -> Self {
25 Self {
26 layer_mapping: vec![(0, 2), (1, 5), (2, 8), (3, 11)],
27 hidden_weight: 1.0,
28 projection: None,
29 }
30 }
31}
32
33impl ProgressiveDistillation {
34 /// Create new progressive distillation config
35 #[must_use]
36 pub fn new(layer_mapping: Vec<(usize, usize)>) -> Self {
37 Self { layer_mapping, hidden_weight: 1.0, projection: None }
38 }
39
40 /// Set projection layer for dimension alignment
41 ///
42 /// Creates a linear projection matrix to align student hidden states
43 /// to teacher hidden size. Initialized with Xavier uniform.
44 ///
45 /// # Arguments
46 ///
47 /// * `student_dim` - Student model hidden dimension
48 /// * `teacher_dim` - Teacher model hidden dimension
49 #[must_use]
50 pub fn with_projection(mut self, student_dim: usize, teacher_dim: usize) -> Self {
51 use rand::Rng;
52
53 // Xavier uniform initialization
54 let scale = (6.0 / (student_dim + teacher_dim) as f32).sqrt();
55 let mut rng = rand::rng();
56
57 let projection =
58 Array2::from_shape_fn((student_dim, teacher_dim), |_| rng.random_range(-scale..scale));
59
60 self.projection = Some(projection);
61 self
62 }
63
64 /// Set hidden state loss weight
65 #[must_use]
66 pub fn with_weight(mut self, weight: f32) -> Self {
67 self.hidden_weight = weight;
68 self
69 }
70
71 /// Compute hidden state matching loss
72 ///
73 /// Uses MSE loss between projected student and teacher hidden states.
74 /// If projection layer is set and shapes differ, projects student to teacher dimension.
75 pub fn hidden_state_loss(
76 &self,
77 student_hidden: &[Array2<f32>],
78 teacher_hidden: &[Array2<f32>],
79 ) -> f32 {
80 let mut total_loss = 0.0;
81 let mut count = 0;
82
83 for (s_idx, t_idx) in &self.layer_mapping {
84 if *s_idx < student_hidden.len() && *t_idx < teacher_hidden.len() {
85 let s_h = &student_hidden[*s_idx];
86 let t_h = &teacher_hidden[*t_idx];
87
88 // MSE loss - project student if dimensions differ
89 if s_h.dim() == t_h.dim() {
90 // Same dimensions: direct MSE
91 let diff = s_h - t_h;
92 let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
93 total_loss += mse;
94 count += 1;
95 } else if let Some(ref proj) = self.projection {
96 // Different dimensions: project student to teacher space
97 // s_h: (batch, student_dim), proj: (student_dim, teacher_dim)
98 // result: (batch, teacher_dim)
99 let s_dim = s_h.shape()[1];
100 let t_dim = t_h.shape()[1];
101
102 // Verify projection dimensions match
103 if proj.shape() == [s_dim, t_dim] {
104 let projected = s_h.dot(proj);
105 let diff = &projected - t_h;
106 let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
107 total_loss += mse;
108 count += 1;
109 }
110 }
111 // Skip if shapes differ and no projection is set
112 }
113 }
114
115 if count > 0 {
116 self.hidden_weight * total_loss / count as f32
117 } else {
118 0.0
119 }
120 }
121}