1use ndarray::{Array2, Axis};
4
/// Configuration for distilling several teacher models into a single target.
///
/// Holds one mixing weight per teacher together with the softmax temperature
/// used when teacher outputs are turned into probabilities.
#[derive(Debug, Clone, PartialEq)]
pub struct EnsembleDistiller {
    /// Per-teacher mixing weights, in the same order as the teacher logits
    /// passed to the combine methods.
    pub weights: Vec<f32>,
    /// Temperature applied to logits before the softmax.
    pub temperature: f32,
}
35
36impl EnsembleDistiller {
37 pub fn new(weights: Vec<f32>, temperature: f32) -> Self {
48 assert!(!weights.is_empty(), "Must have at least one teacher");
49 assert!(temperature > 0.0, "Temperature must be positive, got {temperature}");
50
51 let sum: f32 = weights.iter().sum();
52 assert!(sum > 0.0, "Teacher weights must sum to positive value");
53
54 let normalized_weights: Vec<f32> = weights.iter().map(|&w| w / sum).collect();
56
57 Self { weights: normalized_weights, temperature }
58 }
59
60 pub fn uniform(num_teachers: usize, temperature: f32) -> Self {
62 Self::new(vec![1.0; num_teachers], temperature)
63 }
64
65 pub fn combine_teachers(&self, teacher_logits: &[Array2<f32>]) -> Array2<f32> {
75 assert_eq!(
76 teacher_logits.len(),
77 self.weights.len(),
78 "Number of teachers must match number of weights"
79 );
80 assert!(!teacher_logits.is_empty(), "Must have at least one teacher");
81
82 let shape = teacher_logits[0].shape();
84 for t in teacher_logits.iter().skip(1) {
85 assert_eq!(t.shape(), shape, "All teacher logits must have the same shape");
86 }
87
88 let mut ensemble = Array2::zeros((shape[0], shape[1]));
90
91 for (teacher, &weight) in teacher_logits.iter().zip(&self.weights) {
92 ensemble = ensemble + teacher * weight;
93 }
94
95 ensemble
96 }
97
98 pub fn combine_via_probabilities(&self, teacher_logits: &[Array2<f32>]) -> Array2<f32> {
103 assert_eq!(
104 teacher_logits.len(),
105 self.weights.len(),
106 "Number of teachers must match number of weights"
107 );
108 assert!(!teacher_logits.is_empty(), "Must have at least one teacher");
109
110 let shape = teacher_logits[0].shape();
111
112 let teacher_probs: Vec<Array2<f32>> =
114 teacher_logits.iter().map(|logits| softmax_2d(&(logits / self.temperature))).collect();
115
116 let mut ensemble_probs = Array2::zeros((shape[0], shape[1]));
118 for (probs, &weight) in teacher_probs.iter().zip(&self.weights) {
119 ensemble_probs = ensemble_probs + probs * weight;
120 }
121
122 ensemble_probs.mapv(|p: f32| (p + 1e-10_f32).max(f32::MIN_POSITIVE).ln() * self.temperature)
125 }
126
127 pub fn distillation_loss(
136 &self,
137 student_logits: &Array2<f32>,
138 teacher_logits: &[Array2<f32>],
139 labels: &[usize],
140 alpha: f32,
141 ) -> f32 {
142 use super::loss::DistillationLoss;
143
144 let ensemble_logits = self.combine_teachers(teacher_logits);
145 let loss_fn = DistillationLoss::new(self.temperature, alpha);
146
147 loss_fn.forward(student_logits, &ensemble_logits, labels)
148 }
149}
150
151fn softmax_2d(x: &Array2<f32>) -> Array2<f32> {
153 let mut result = x.clone();
154
155 for mut row in result.axis_iter_mut(Axis(0)) {
156 let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
157 row.mapv_inplace(|v| (v - max_val).exp());
158 let sum: f32 = row.sum();
159 row.mapv_inplace(|v| v / sum);
160 }
161
162 result
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Uniform construction yields equal weights that sum to one.
    #[test]
    fn test_uniform_ensemble() {
        let distiller = EnsembleDistiller::uniform(3, 2.0);
        assert_eq!(distiller.weights.len(), 3);
        assert_relative_eq!(distiller.weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);
        for &w in &distiller.weights {
            assert_relative_eq!(w, 1.0 / 3.0, epsilon = 1e-6);
        }
    }

    /// Raw weights are normalised by their sum.
    #[test]
    fn test_weighted_ensemble() {
        let distiller = EnsembleDistiller::new(vec![1.0, 2.0, 3.0], 2.0);
        assert_relative_eq!(distiller.weights.iter().sum::<f32>(), 1.0, epsilon = 1e-6);
        assert_relative_eq!(distiller.weights[0], 1.0 / 6.0, epsilon = 1e-6);
        assert_relative_eq!(distiller.weights[1], 2.0 / 6.0, epsilon = 1e-6);
        assert_relative_eq!(distiller.weights[2], 3.0 / 6.0, epsilon = 1e-6);
    }

    /// With uniform weights the combined logits are the plain average.
    #[test]
    fn test_combine_teachers() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);

        let t1 = array![[1.0, 2.0, 3.0]];
        let t2 = array![[3.0, 2.0, 1.0]];
        let teachers = vec![t1, t2];

        let ensemble = distiller.combine_teachers(&teachers);

        assert_relative_eq!(ensemble[[0, 0]], 2.0, epsilon = 1e-6);
        assert_relative_eq!(ensemble[[0, 1]], 2.0, epsilon = 1e-6);
        assert_relative_eq!(ensemble[[0, 2]], 2.0, epsilon = 1e-6);
    }

    /// Weights 1:3 normalise to 0.25 / 0.75 and are applied per teacher.
    #[test]
    fn test_weighted_combine() {
        let distiller = EnsembleDistiller::new(vec![1.0, 3.0], 2.0);

        let t1 = array![[1.0, 2.0, 3.0]];
        let t2 = array![[3.0, 2.0, 1.0]];
        let teachers = vec![t1, t2];

        let ensemble = distiller.combine_teachers(&teachers);

        assert_relative_eq!(ensemble[[0, 0]], 2.5, epsilon = 1e-6);
        assert_relative_eq!(ensemble[[0, 1]], 2.0, epsilon = 1e-6);
        assert_relative_eq!(ensemble[[0, 2]], 1.5, epsilon = 1e-6);
    }

    /// The probability-space combination must encode the weighted mixture of
    /// the teachers' tempered softmax distributions, not merely be finite.
    #[test]
    fn test_combine_via_probabilities() {
        let temperature = 2.0_f32;
        let distiller = EnsembleDistiller::uniform(2, temperature);

        let t1 = array![[2.0, 1.0, 0.5]];
        let t2 = array![[1.5, 1.2, 0.8]];
        let teachers = vec![t1, t2];

        let ensemble = distiller.combine_via_probabilities(&teachers);

        assert_eq!(ensemble.shape(), &[1, 3]);
        assert!(ensemble.iter().all(|&x| x.is_finite()));

        // Independent reference: tempered softmax computed with plain arrays.
        let softmax = |row: [f32; 3]| {
            let exps = row.map(|v| (v / temperature).exp());
            let total: f32 = exps.iter().sum();
            exps.map(|e| e / total)
        };
        let p1 = softmax([2.0, 1.0, 0.5]);
        let p2 = softmax([1.5, 1.2, 0.8]);

        // The output is ln(p) * T, so exp(output / T) recovers the mixture.
        for j in 0..3 {
            let expected = 0.5 * (p1[j] + p2[j]);
            let recovered = (ensemble[[0, j]] / temperature).exp();
            assert_relative_eq!(recovered, expected, epsilon = 1e-5);
        }
    }

    #[test]
    #[should_panic(expected = "Must have at least one teacher")]
    fn test_empty_weights_panics() {
        EnsembleDistiller::new(vec![], 2.0);
    }

    #[test]
    #[should_panic(expected = "Teacher weights must sum to positive")]
    fn test_zero_weights_panics() {
        EnsembleDistiller::new(vec![0.0, 0.0], 2.0);
    }

    #[test]
    #[should_panic(expected = "Number of teachers must match")]
    fn test_mismatched_teachers_panics() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);
        let teachers = vec![array![[1.0, 2.0]]];
        distiller.combine_teachers(&teachers);
    }

    /// Smoke test: the loss for a near-matching student is positive and finite.
    #[test]
    fn test_distillation_loss() {
        let distiller = EnsembleDistiller::uniform(2, 2.0);

        let student = array![[2.0, 1.0, 0.5]];
        let t1 = array![[1.8, 1.1, 0.6]];
        let t2 = array![[1.9, 0.9, 0.7]];
        let teachers = vec![t1, t2];
        let labels = vec![0];

        let loss = distiller.distillation_loss(&student, &teachers, &labels, 0.7);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }
}