Skip to main content

entrenar/distill/
ensemble.rs

1//! Multi-teacher ensemble distillation
2
3use ndarray::{Array2, Axis};
4
/// Multi-teacher ensemble distiller.
///
/// Holds one weight per teacher model plus a softening temperature, and merges
/// the teachers' outputs into a single target for a student model.
///
/// # Methods
///
/// - [`combine_teachers`](Self::combine_teachers): weighted mean of the raw teacher logits
/// - [`combine_via_probabilities`](Self::combine_via_probabilities): weighted mean of the
///   temperature-softened teacher probabilities, mapped back to logit space
///
/// With uniform weights (see [`uniform`](Self::uniform)) both reduce to a plain average.
///
/// # Example
///
/// ```
/// use entrenar::distill::EnsembleDistiller;
/// use ndarray::array;
///
/// let distiller = EnsembleDistiller::new(vec![1.0, 1.0], 2.0);
/// let teachers = vec![
///     array![[2.0, 1.0, 0.5]],
///     array![[1.5, 1.2, 0.8]],
/// ];
/// let ensemble_logits = distiller.combine_teachers(&teachers);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct EnsembleDistiller {
    /// Weight of each teacher, in teacher order.
    ///
    /// The constructors normalize these to sum to 1. The field is public, so
    /// values written directly are used as-is without re-normalization.
    pub weights: Vec<f32>,
    /// Temperature used to soften distributions during distillation.
    pub temperature: f32,
}
35
36impl EnsembleDistiller {
37    /// Create a new ensemble distiller with given teacher weights
38    ///
39    /// # Arguments
40    ///
41    /// * `weights` - Weight for each teacher (will be normalized)
42    /// * `temperature` - Temperature for softening distributions
43    ///
44    /// # Panics
45    ///
46    /// Panics if weights are empty, all zero, or temperature <= 0
47    pub fn new(weights: Vec<f32>, temperature: f32) -> Self {
48        assert!(!weights.is_empty(), "Must have at least one teacher");
49        assert!(temperature > 0.0, "Temperature must be positive, got {temperature}");
50
51        let sum: f32 = weights.iter().sum();
52        assert!(sum > 0.0, "Teacher weights must sum to positive value");
53
54        // Normalize weights
55        let normalized_weights: Vec<f32> = weights.iter().map(|&w| w / sum).collect();
56
57        Self { weights: normalized_weights, temperature }
58    }
59
60    /// Create an ensemble with uniform weights
61    pub fn uniform(num_teachers: usize, temperature: f32) -> Self {
62        Self::new(vec![1.0; num_teachers], temperature)
63    }
64
65    /// Combine multiple teacher logits into ensemble prediction
66    ///
67    /// # Arguments
68    ///
69    /// * `teacher_logits` - Vector of teacher logits, each [batch_size, num_classes]
70    ///
71    /// # Returns
72    ///
73    /// Weighted average of teacher logits [batch_size, num_classes]
74    pub fn combine_teachers(&self, teacher_logits: &[Array2<f32>]) -> Array2<f32> {
75        assert_eq!(
76            teacher_logits.len(),
77            self.weights.len(),
78            "Number of teachers must match number of weights"
79        );
80        assert!(!teacher_logits.is_empty(), "Must have at least one teacher");
81
82        // Check all teachers have same shape
83        let shape = teacher_logits[0].shape();
84        for t in teacher_logits.iter().skip(1) {
85            assert_eq!(t.shape(), shape, "All teacher logits must have the same shape");
86        }
87
88        // Weighted average
89        let mut ensemble = Array2::zeros((shape[0], shape[1]));
90
91        for (teacher, &weight) in teacher_logits.iter().zip(&self.weights) {
92            ensemble = ensemble + teacher * weight;
93        }
94
95        ensemble
96    }
97
98    /// Combine teachers via probability distribution averaging (more stable)
99    ///
100    /// Converts each teacher's logits to probabilities, averages them,
101    /// then converts back to logits.
102    pub fn combine_via_probabilities(&self, teacher_logits: &[Array2<f32>]) -> Array2<f32> {
103        assert_eq!(
104            teacher_logits.len(),
105            self.weights.len(),
106            "Number of teachers must match number of weights"
107        );
108        assert!(!teacher_logits.is_empty(), "Must have at least one teacher");
109
110        let shape = teacher_logits[0].shape();
111
112        // Convert each teacher to probabilities
113        let teacher_probs: Vec<Array2<f32>> =
114            teacher_logits.iter().map(|logits| softmax_2d(&(logits / self.temperature))).collect();
115
116        // Weighted average of probabilities
117        let mut ensemble_probs = Array2::zeros((shape[0], shape[1]));
118        for (probs, &weight) in teacher_probs.iter().zip(&self.weights) {
119            ensemble_probs = ensemble_probs + probs * weight;
120        }
121
122        // Convert back to logits (inverse softmax via log)
123        // Note: This is approximate - exact inverse doesn't exist
124        ensemble_probs.mapv(|p: f32| (p + 1e-10_f32).max(f32::MIN_POSITIVE).ln() * self.temperature)
125    }
126
127    /// Compute ensemble distillation loss
128    ///
129    /// # Arguments
130    ///
131    /// * `student_logits` - Logits from student [batch_size, num_classes]
132    /// * `teacher_logits` - Vector of teacher logits
133    /// * `labels` - Ground truth labels `[batch_size]`
134    /// * `alpha` - Weight for distillation vs hard loss
135    pub fn distillation_loss(
136        &self,
137        student_logits: &Array2<f32>,
138        teacher_logits: &[Array2<f32>],
139        labels: &[usize],
140        alpha: f32,
141    ) -> f32 {
142        use super::loss::DistillationLoss;
143
144        let ensemble_logits = self.combine_teachers(teacher_logits);
145        let loss_fn = DistillationLoss::new(self.temperature, alpha);
146
147        loss_fn.forward(student_logits, &ensemble_logits, labels)
148    }
149}
150
151/// Compute softmax along last axis for 2D array
152fn softmax_2d(x: &Array2<f32>) -> Array2<f32> {
153    let mut result = x.clone();
154
155    for mut row in result.axis_iter_mut(Axis(0)) {
156        let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
157        row.mapv_inplace(|v| (v - max_val).exp());
158        let sum: f32 = row.sum();
159        row.mapv_inplace(|v| v / sum);
160    }
161
162    result
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// `uniform` yields equal weights that sum to 1.
    #[test]
    fn test_uniform_ensemble() {
        let distiller = EnsembleDistiller::uniform(3, 2.0);
        assert_eq!(distiller.weights.len(), 3);
        assert_relative_eq!(distiller.weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);
        for &w in &distiller.weights {
            assert_relative_eq!(w, 1.0 / 3.0, epsilon = 1e-6);
        }
    }

    /// `new` normalizes arbitrary positive weights by their sum.
    #[test]
    fn test_weighted_ensemble() {
        let distiller = EnsembleDistiller::new(vec![1.0, 2.0, 3.0], 2.0);
        assert_relative_eq!(distiller.weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);
        assert_relative_eq!(distiller.weights[0], 1.0 / 6.0, epsilon = 1e-6);
        assert_relative_eq!(distiller.weights[1], 2.0 / 6.0, epsilon = 1e-6);
        assert_relative_eq!(distiller.weights[2], 3.0 / 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_combine_teachers() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);

        let t1 = array![[1.0, 2.0, 3.0]];
        let t2 = array![[3.0, 2.0, 1.0]];
        let teachers = vec![t1, t2];

        let ensemble = distiller.combine_teachers(&teachers);

        // Should be average: (1+3)/2=2, (2+2)/2=2, (3+1)/2=2
        assert_relative_eq!(ensemble[[0, 0]], 2.0, epsilon = 1e-6);
        assert_relative_eq!(ensemble[[0, 1]], 2.0, epsilon = 1e-6);
        assert_relative_eq!(ensemble[[0, 2]], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_weighted_combine() {
        let distiller = EnsembleDistiller::new(vec![1.0, 3.0], 2.0);

        let t1 = array![[1.0, 2.0, 3.0]];
        let t2 = array![[3.0, 2.0, 1.0]];
        let teachers = vec![t1, t2];

        let ensemble = distiller.combine_teachers(&teachers);

        // Should be weighted average: (1*0.25 + 3*0.75) = 2.5
        assert_relative_eq!(ensemble[[0, 0]], 2.5, epsilon = 1e-6);
        // (2*0.25 + 2*0.75) = 2.0
        assert_relative_eq!(ensemble[[0, 1]], 2.0, epsilon = 1e-6);
        // (3*0.25 + 1*0.75) = 1.5
        assert_relative_eq!(ensemble[[0, 2]], 1.5, epsilon = 1e-6);
    }

    /// The returned logits must reproduce the weighted mean of the teachers'
    /// softened probabilities once passed back through a softmax.
    #[test]
    fn test_combine_via_probabilities() {
        let temperature = 2.0_f32;
        let distiller = EnsembleDistiller::uniform(2, temperature);

        let t1: Array2<f32> = array![[2.0, 1.0, 0.5]];
        let t2: Array2<f32> = array![[1.5, 1.2, 0.8]];

        let p1 = softmax_2d(&t1.mapv(|v| v / temperature));
        let p2 = softmax_2d(&t2.mapv(|v| v / temperature));

        let ensemble = distiller.combine_via_probabilities(&[t1, t2]);

        // Result should be finite and keep the input shape
        assert!(ensemble.iter().all(|&x| x.is_finite()));
        assert_eq!(ensemble.shape(), &[1, 3]);

        let recovered = softmax_2d(&ensemble.mapv(|v| v / temperature));
        for j in 0..3 {
            let expected = 0.5 * p1[[0, j]] + 0.5 * p2[[0, j]];
            assert_relative_eq!(recovered[[0, j]], expected, epsilon = 1e-5);
        }
    }

    /// Softmax rows sum to 1 and large logits do not overflow.
    #[test]
    fn test_softmax_rows_are_normalized() {
        let logits: Array2<f32> = array![[1000.0, 1000.0], [0.0, 1.0]];
        let probs = softmax_2d(&logits);

        assert!(probs.iter().all(|&p| p.is_finite()));
        assert_relative_eq!(probs[[0, 0]], 0.5, epsilon = 1e-6);
        assert_relative_eq!(probs[[0, 1]], 0.5, epsilon = 1e-6);
        assert_relative_eq!(probs[[1, 0]] + probs[[1, 1]], 1.0, epsilon = 1e-6);
        assert!(probs[[1, 1]] > probs[[1, 0]]);
    }

    #[test]
    #[should_panic(expected = "Must have at least one teacher")]
    fn test_empty_weights_panics() {
        EnsembleDistiller::new(vec![], 2.0);
    }

    #[test]
    #[should_panic(expected = "Teacher weights must sum to positive")]
    fn test_zero_weights_panics() {
        EnsembleDistiller::new(vec![0.0, 0.0], 2.0);
    }

    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_non_positive_temperature_panics() {
        EnsembleDistiller::new(vec![1.0], 0.0);
    }

    #[test]
    #[should_panic(expected = "Number of teachers must match")]
    fn test_mismatched_teachers_panics() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);
        let teachers = vec![array![[1.0, 2.0]]]; // Only 1 teacher
        distiller.combine_teachers(&teachers);
    }

    #[test]
    #[should_panic(expected = "All teacher logits must have the same shape")]
    fn test_mismatched_shapes_panics() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);
        let teachers = vec![array![[1.0, 2.0]], array![[1.0, 2.0, 3.0]]];
        distiller.combine_teachers(&teachers);
    }

    #[test]
    fn test_distillation_loss() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);

        let student = array![[2.0, 1.0, 0.5]];
        let t1 = array![[1.8, 1.1, 0.6]];
        let t2 = array![[1.9, 0.9, 0.7]];
        let teachers = vec![t1, t2];
        let labels = vec![0];

        let loss = distiller.distillation_loss(&student, &teachers, &labels, 0.7);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }
}