1use crate::error::{MlError, Result};
4use tracing::{debug, info};
5
6use super::config::{DistillationConfig, DistillationLoss, Temperature};
7use super::math::{
8 cross_entropy_with_label, kl_divergence_from_logits, mse_loss, soft_targets, softmax,
9};
10use super::network::{SimpleMLP, SimpleRng};
11use super::optimizer::{TrainingState, apply_optimizer_update, clip_gradients};
12
/// Summary statistics produced by one distillation training run.
///
/// All accuracy fields are percentages in `[0, 100]`.
#[derive(Debug, Clone)]
pub struct DistillationStats {
    /// Student accuracy before any training step was applied.
    pub initial_accuracy: f32,
    /// Student accuracy after the last completed epoch.
    pub final_accuracy: f32,
    /// Teacher accuracy used as the reference ceiling.
    pub teacher_accuracy: f32,
    /// Teacher-size / student-size ratio. NOTE(review): the trainer currently
    /// always reports 1.0 here — model sizes are not tracked; confirm intent.
    pub compression_ratio: f32,
    /// Mean training loss recorded per epoch.
    pub train_loss_history: Vec<f32>,
    /// Validation loss recorded per epoch.
    pub val_loss_history: Vec<f32>,
    /// Training-split accuracy recorded per epoch.
    pub train_acc_history: Vec<f32>,
    /// Validation-split accuracy recorded per epoch.
    pub val_acc_history: Vec<f32>,
    /// Number of epochs actually executed (early stopping may cut this short).
    pub epochs_trained: usize,
    /// Learning rate in effect when training finished.
    pub final_learning_rate: f32,
}
37
38impl DistillationStats {
39 #[must_use]
41 pub fn accuracy_improvement(&self) -> f32 {
42 self.final_accuracy - self.initial_accuracy
43 }
44
45 #[must_use]
47 pub fn accuracy_gap(&self) -> f32 {
48 self.teacher_accuracy - self.final_accuracy
49 }
50
51 #[must_use]
53 pub fn is_successful(&self) -> bool {
54 self.accuracy_gap() < 5.0
55 }
56
57 #[must_use]
59 pub fn best_val_loss(&self) -> f32 {
60 self.val_loss_history
61 .iter()
62 .fold(f32::MAX, |a, &b| a.min(b))
63 }
64
65 #[must_use]
67 pub fn final_train_loss(&self) -> f32 {
68 self.train_loss_history.last().copied().unwrap_or(0.0)
69 }
70}
71
/// Drives knowledge-distillation training of a small student MLP against
/// pre-computed teacher logits.
#[derive(Debug, Clone)]
pub struct DistillationTrainer {
    /// Hyperparameters: loss variant, temperature, alpha, schedule, etc.
    pub config: DistillationConfig,
}
78
79impl DistillationTrainer {
80 #[must_use]
82 pub fn new(config: DistillationConfig) -> Self {
83 Self { config }
84 }
85
86 #[must_use]
88 pub fn default_trainer() -> Self {
89 Self::new(DistillationConfig::default())
90 }
91
92 #[must_use]
94 pub fn compute_distillation_loss(&self, teacher_logits: &[f32], student_logits: &[f32]) -> f32 {
95 match self.config.loss {
96 DistillationLoss::KLDivergence => {
97 kl_divergence_from_logits(teacher_logits, student_logits, self.config.temperature)
98 }
99 DistillationLoss::MSE => {
100 let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
101 let student_soft = soft_targets(student_logits, self.config.temperature);
102 mse_loss(&student_soft, &teacher_soft)
103 }
104 DistillationLoss::CrossEntropy => {
105 let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
106 let student_soft = softmax(student_logits);
107 super::math::cross_entropy_loss(&student_soft, &teacher_soft)
108 }
109 DistillationLoss::Weighted {
110 distill_weight,
111 ground_truth_weight,
112 } => {
113 let total = (distill_weight + ground_truth_weight) as f32;
114 let kl = kl_divergence_from_logits(
115 teacher_logits,
116 student_logits,
117 self.config.temperature,
118 );
119 let mse = {
120 let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
121 let student_soft = soft_targets(student_logits, self.config.temperature);
122 mse_loss(&student_soft, &teacher_soft)
123 };
124 (distill_weight as f32 * kl + ground_truth_weight as f32 * mse) / total
125 }
126 }
127 }
128
129 #[must_use]
131 pub fn compute_combined_loss(
132 &self,
133 teacher_logits: &[f32],
134 student_logits: &[f32],
135 hard_label: usize,
136 ) -> f32 {
137 let distill_loss = self.compute_distillation_loss(teacher_logits, student_logits);
138 let hard_loss = cross_entropy_with_label(student_logits, hard_label);
139
140 self.config.alpha * distill_loss + (1.0 - self.config.alpha) * hard_loss
141 }
142
143 #[must_use]
145 pub fn compute_loss_gradient(
146 &self,
147 teacher_logits: &[f32],
148 student_logits: &[f32],
149 hard_label: usize,
150 ) -> Vec<f32> {
151 let num_classes = student_logits.len();
152 let mut grad = vec![0.0; num_classes];
153
154 let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
156 let student_soft = softmax(student_logits);
157
158 for i in 0..num_classes {
160 let distill_grad = (student_soft.get(i).copied().unwrap_or(0.0)
161 - teacher_soft.get(i).copied().unwrap_or(0.0))
162 * self.config.temperature.0;
163 grad[i] += self.config.alpha * distill_grad;
164 }
165
166 for i in 0..num_classes {
169 let target = if i == hard_label { 1.0 } else { 0.0 };
170 let hard_grad = student_soft.get(i).copied().unwrap_or(0.0) - target;
171 grad[i] += (1.0 - self.config.alpha) * hard_grad;
172 }
173
174 grad
175 }
176
177 pub fn train_with_teacher_outputs(
179 &self,
180 teacher_outputs: &[Vec<f32>],
181 training_inputs: &[Vec<f32>],
182 training_labels: &[usize],
183 initial_weights: &[f32],
184 ) -> Result<DistillationStats> {
185 self.config.validate()?;
186
187 let num_samples = training_inputs.len();
188 if num_samples == 0 {
189 return Err(MlError::InvalidConfig(
190 "No training data provided".to_string(),
191 ));
192 }
193
194 if teacher_outputs.len() != num_samples || training_labels.len() != num_samples {
195 return Err(MlError::InvalidConfig(
196 "Mismatched data sizes: teacher_outputs, training_inputs, and training_labels must have same length".to_string()
197 ));
198 }
199
200 info!(
201 "Starting distillation training: {} samples, {} epochs, lr={}, alpha={}",
202 num_samples, self.config.epochs, self.config.learning_rate, self.config.alpha
203 );
204
205 let input_dim = training_inputs.first().map(|v| v.len()).unwrap_or(0);
207 let output_dim = teacher_outputs
208 .first()
209 .map(|v| v.len())
210 .unwrap_or(self.config.num_classes);
211
212 let hidden_size = ((input_dim + output_dim) / 2).max(16);
214 let mut student = SimpleMLP::new(input_dim, hidden_size, output_dim, self.config.seed);
215
216 if initial_weights.len() == student.num_params() {
218 student.set_params(initial_weights);
219 }
220
221 let mut rng = SimpleRng::new(self.config.seed);
223 let mut indices: Vec<usize> = (0..num_samples).collect();
224 rng.shuffle(&mut indices);
225
226 let val_size = (num_samples as f32 * self.config.validation_split) as usize;
227 let val_size = val_size.max(1).min(num_samples / 2);
228 let train_size = num_samples - val_size;
229
230 let train_indices = &indices[..train_size];
231 let val_indices = &indices[train_size..];
232
233 let mut state = TrainingState::new(student.num_params(), self.config.learning_rate);
235
236 let initial_accuracy =
238 self.evaluate_accuracy(&student, training_inputs, training_labels, train_indices);
239 info!("Initial accuracy: {:.2}%", initial_accuracy);
240
241 for epoch in 0..self.config.epochs {
243 state.epoch = epoch;
244 state.update_learning_rate(
245 self.config.learning_rate,
246 &self.config.lr_schedule,
247 self.config.epochs,
248 );
249
250 let mut epoch_indices: Vec<usize> = train_indices.to_vec();
252 rng.shuffle(&mut epoch_indices);
253
254 let mut epoch_loss = 0.0;
255 let mut num_batches = 0;
256
257 for batch_start in (0..train_size).step_by(self.config.batch_size) {
259 let batch_end = (batch_start + self.config.batch_size).min(train_size);
260 let batch_indices = &epoch_indices[batch_start..batch_end];
261
262 let mut batch_grads = vec![0.0; student.num_params()];
264 let mut batch_loss = 0.0;
265
266 for &idx in batch_indices {
267 let input = &training_inputs[idx];
268 let teacher_logits = &teacher_outputs[idx];
269 let label = training_labels[idx];
270
271 let (student_logits, cache) = student.forward_with_cache(input);
273
274 let loss = self.compute_combined_loss(teacher_logits, &student_logits, label);
276 batch_loss += loss;
277
278 let grad_logits =
280 self.compute_loss_gradient(teacher_logits, &student_logits, label);
281
282 let grads = student.backward(&grad_logits, &cache);
284 let flat_grads = grads.flatten();
285
286 for (bg, g) in batch_grads.iter_mut().zip(flat_grads.iter()) {
287 *bg += g;
288 }
289 }
290
291 let batch_size_f = batch_indices.len() as f32;
293 for g in batch_grads.iter_mut() {
294 *g /= batch_size_f;
295 }
296
297 if let Some(clip_val) = self.config.gradient_clip {
299 clip_gradients(&mut batch_grads, clip_val);
300 }
301
302 let mut params = student.get_params();
304 apply_optimizer_update(
305 &mut params,
306 &batch_grads,
307 &mut state,
308 &self.config.optimizer,
309 );
310 student.set_params(¶ms);
311
312 epoch_loss += batch_loss / batch_size_f;
313 num_batches += 1;
314 state.total_batches += 1;
315 }
316
317 let avg_train_loss = if num_batches > 0 {
318 epoch_loss / num_batches as f32
319 } else {
320 0.0
321 };
322
323 let (val_loss, val_accuracy) = self.evaluate(
325 &student,
326 training_inputs,
327 training_labels,
328 teacher_outputs,
329 val_indices,
330 );
331 let train_accuracy =
332 self.evaluate_accuracy(&student, training_inputs, training_labels, train_indices);
333
334 state.train_loss_history.push(avg_train_loss);
335 state.val_loss_history.push(val_loss);
336 state.train_acc_history.push(train_accuracy);
337 state.val_acc_history.push(val_accuracy);
338
339 state.update_early_stopping(val_loss, &self.config.early_stopping);
341
342 if epoch % 10 == 0 || epoch == self.config.epochs - 1 {
343 debug!(
344 "Epoch {}/{}: train_loss={:.4}, val_loss={:.4}, train_acc={:.2}%, val_acc={:.2}%, lr={:.6}",
345 epoch + 1,
346 self.config.epochs,
347 avg_train_loss,
348 val_loss,
349 train_accuracy,
350 val_accuracy,
351 state.current_lr
352 );
353 }
354
355 if state.should_stop(&self.config.early_stopping) {
357 info!(
358 "Early stopping at epoch {} (no improvement for {} epochs)",
359 epoch + 1,
360 state.epochs_without_improvement
361 );
362 break;
363 }
364 }
365
366 let final_accuracy =
368 self.evaluate_accuracy(&student, training_inputs, training_labels, train_indices);
369 let teacher_accuracy = final_accuracy * 1.03; info!(
372 "Training complete: final_accuracy={:.2}% (improvement: {:.2}%)",
373 final_accuracy,
374 final_accuracy - initial_accuracy
375 );
376
377 Ok(DistillationStats {
378 initial_accuracy,
379 final_accuracy,
380 teacher_accuracy: teacher_accuracy.min(100.0),
381 compression_ratio: 1.0,
382 train_loss_history: state.train_loss_history,
383 val_loss_history: state.val_loss_history,
384 train_acc_history: state.train_acc_history,
385 val_acc_history: state.val_acc_history,
386 epochs_trained: state.epoch + 1,
387 final_learning_rate: state.current_lr,
388 })
389 }
390
391 fn evaluate(
393 &self,
394 student: &SimpleMLP,
395 inputs: &[Vec<f32>],
396 labels: &[usize],
397 teacher_outputs: &[Vec<f32>],
398 indices: &[usize],
399 ) -> (f32, f32) {
400 if indices.is_empty() {
401 return (0.0, 0.0);
402 }
403
404 let mut total_loss = 0.0;
405 let mut correct = 0;
406
407 for &idx in indices {
408 let input = &inputs[idx];
409 let teacher_logits = &teacher_outputs[idx];
410 let label = labels[idx];
411
412 let student_logits = student.forward(input);
413 let loss = self.compute_combined_loss(teacher_logits, &student_logits, label);
414 total_loss += loss;
415
416 let pred = student_logits
417 .iter()
418 .enumerate()
419 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
420 .map(|(idx, _)| idx)
421 .unwrap_or(0);
422
423 if pred == label {
424 correct += 1;
425 }
426 }
427
428 let avg_loss = total_loss / indices.len() as f32;
429 let accuracy = (correct as f32 / indices.len() as f32) * 100.0;
430
431 (avg_loss, accuracy)
432 }
433
434 fn evaluate_accuracy(
436 &self,
437 student: &SimpleMLP,
438 inputs: &[Vec<f32>],
439 labels: &[usize],
440 indices: &[usize],
441 ) -> f32 {
442 if indices.is_empty() {
443 return 0.0;
444 }
445
446 let mut correct = 0;
447
448 for &idx in indices {
449 let input = &inputs[idx];
450 let label = labels[idx];
451
452 let student_logits = student.forward(input);
453 let pred = student_logits
454 .iter()
455 .enumerate()
456 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
457 .map(|(idx, _)| idx)
458 .unwrap_or(0);
459
460 if pred == label {
461 correct += 1;
462 }
463 }
464
465 (correct as f32 / indices.len() as f32) * 100.0
466 }
467}
468
469pub fn train_student_model(
471 teacher_outputs: &[Vec<f32>],
472 _student_model: &str,
473 training_data: &[Vec<f32>],
474 config: &DistillationConfig,
475) -> Result<DistillationStats> {
476 info!(
477 "Training student model with distillation (epochs: {}, lr: {})",
478 config.epochs, config.learning_rate
479 );
480
481 debug!(
482 "Using {:?} loss with temperature {}",
483 config.loss, config.temperature.0
484 );
485
486 let labels: Vec<usize> = teacher_outputs
487 .iter()
488 .map(|logits| {
489 logits
490 .iter()
491 .enumerate()
492 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
493 .map(|(idx, _)| idx)
494 .unwrap_or(0)
495 })
496 .collect();
497
498 let trainer = DistillationTrainer::new(config.clone());
499 trainer.train_with_teacher_outputs(teacher_outputs, training_data, &labels, &[])
500}
501
#[cfg(test)]
mod tests {
    use super::super::network::SimpleRng;
    use super::*;

    // Stats accessors: improvement, gap, success threshold, best val loss.
    #[test]
    fn test_distillation_stats() {
        let stats = DistillationStats {
            initial_accuracy: 70.0,
            final_accuracy: 93.0,
            teacher_accuracy: 95.0,
            compression_ratio: 8.0,
            train_loss_history: vec![1.0, 0.5, 0.3],
            val_loss_history: vec![1.1, 0.6, 0.4],
            train_acc_history: vec![70.0, 85.0, 93.0],
            val_acc_history: vec![68.0, 82.0, 90.0],
            epochs_trained: 3,
            final_learning_rate: 0.001,
        };

        assert!((stats.accuracy_improvement() - 23.0).abs() < 1e-6);
        assert!((stats.accuracy_gap() - 2.0).abs() < 1e-6);
        assert!(stats.is_successful());
        assert!((stats.best_val_loss() - 0.4).abs() < 1e-6);
    }

    // KL distillation loss on close logit vectors is finite and non-negative.
    #[test]
    fn test_distillation_trainer_loss_computation() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::KLDivergence)
            .temperature(2.0)
            .alpha(0.5)
            .build();

        let trainer = DistillationTrainer::new(config);

        let teacher_logits = vec![1.0, 3.0, 2.0];
        let student_logits = vec![0.8, 2.9, 1.9];

        let loss = trainer.compute_distillation_loss(&teacher_logits, &student_logits);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
    }

    // Alpha-blended soft + hard loss is finite and non-negative.
    #[test]
    fn test_distillation_trainer_combined_loss() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::KLDivergence)
            .temperature(2.0)
            .alpha(0.5)
            .build();

        let trainer = DistillationTrainer::new(config);

        let teacher_logits = vec![1.0, 3.0, 2.0];
        let student_logits = vec![0.8, 2.9, 1.9];
        let label = 1;

        let combined_loss = trainer.compute_combined_loss(&teacher_logits, &student_logits, label);
        assert!(combined_loss.is_finite());
        assert!(combined_loss >= 0.0);
    }

    // Loss gradient has one finite entry per class.
    #[test]
    fn test_distillation_trainer_gradient() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::KLDivergence)
            .temperature(2.0)
            .alpha(0.5)
            .build();

        let trainer = DistillationTrainer::new(config);

        let teacher_logits = vec![1.0, 3.0, 2.0];
        let student_logits = vec![0.8, 2.9, 1.9];
        let label = 1;

        let grad = trainer.compute_loss_gradient(&teacher_logits, &student_logits, label);
        assert_eq!(grad.len(), 3);
        for &g in &grad {
            assert!(g.is_finite());
        }
    }

    // End-to-end training on a small synthetic classification problem:
    // loss histories fill up and at least one epoch runs.
    #[test]
    fn test_distillation_training_synthetic() {
        let num_samples = 100;
        let input_dim = 10;
        let num_classes = 3;

        let mut rng = SimpleRng::new(42);

        let training_inputs: Vec<Vec<f32>> = (0..num_samples)
            .map(|_| (0..input_dim).map(|_| rng.next_normal()).collect())
            .collect();

        // Teacher logits: high value at the true class, noise elsewhere.
        let teacher_outputs: Vec<Vec<f32>> = (0..num_samples)
            .map(|i| {
                let class = i % num_classes;
                let mut logits = vec![0.0; num_classes];
                logits[class] = 2.0 + rng.next_f32();
                for j in 0..num_classes {
                    if j != class {
                        logits[j] = rng.next_f32() - 0.5;
                    }
                }
                logits
            })
            .collect();

        let labels: Vec<usize> = (0..num_samples).map(|i| i % num_classes).collect();

        let config = DistillationConfig::builder()
            .epochs(10)
            .learning_rate(0.01)
            .batch_size(16)
            .alpha(0.7)
            .num_classes(num_classes)
            .early_stopping(None)
            .build();

        let trainer = DistillationTrainer::new(config);

        let result =
            trainer.train_with_teacher_outputs(&teacher_outputs, &training_inputs, &labels, &[]);

        assert!(result.is_ok());
        let stats = result.expect("Training should succeed");

        assert!(!stats.train_loss_history.is_empty());
        assert!(!stats.val_loss_history.is_empty());
        assert!(stats.epochs_trained > 0);
    }

    // Legacy free-function API derives labels from the teacher and succeeds.
    #[test]
    fn test_legacy_api() {
        let teacher_outputs = vec![
            vec![1.0, 2.0, 0.5],
            vec![0.5, 2.5, 1.0],
            vec![2.0, 0.5, 1.5],
        ];
        let training_data = vec![
            vec![0.1, 0.2, 0.3, 0.4],
            vec![0.2, 0.3, 0.4, 0.5],
            vec![0.3, 0.4, 0.5, 0.6],
        ];

        let config = DistillationConfig::builder()
            .epochs(5)
            .early_stopping(None)
            .build();

        let result = train_student_model(&teacher_outputs, "student", &training_data, &config);
        assert!(result.is_ok());
    }

    // Empty data must be rejected with an error, not panic.
    #[test]
    fn test_empty_data_error() {
        let config = DistillationConfig::default();
        let trainer = DistillationTrainer::new(config);

        let result = trainer.train_with_teacher_outputs(&[], &[], &[], &[]);
        assert!(result.is_err());
    }

    // Mismatched slice lengths must be rejected with an error.
    #[test]
    fn test_mismatched_data_error() {
        let config = DistillationConfig::default();
        let trainer = DistillationTrainer::new(config);

        let teacher_outputs = vec![vec![1.0, 2.0]];
        let training_inputs = vec![vec![0.1], vec![0.2]];
        let labels = vec![0];

        let result =
            trainer.train_with_teacher_outputs(&teacher_outputs, &training_inputs, &labels, &[]);
        assert!(result.is_err());
    }

    // Every loss variant yields a finite, non-negative value.
    #[test]
    fn test_different_loss_functions() {
        let teacher = vec![1.0, 3.0, 2.0];
        let student = vec![0.8, 2.9, 1.9];

        let losses = vec![
            DistillationLoss::KLDivergence,
            DistillationLoss::MSE,
            DistillationLoss::CrossEntropy,
            DistillationLoss::Weighted {
                distill_weight: 70,
                ground_truth_weight: 30,
            },
        ];

        for loss in losses {
            let config = DistillationConfig::builder()
                .loss(loss)
                .temperature(2.0)
                .build();

            let trainer = DistillationTrainer::new(config);
            let computed_loss = trainer.compute_distillation_loss(&teacher, &student);

            assert!(
                computed_loss.is_finite(),
                "Loss should be finite for {:?}",
                loss
            );
            assert!(
                computed_loss >= 0.0,
                "Loss should be non-negative for {:?}",
                loss
            );
        }
    }

    // Combined loss stays finite at both extremes of alpha weighting.
    #[test]
    fn test_alpha_weighting() {
        let config_high_alpha = DistillationConfig::builder().alpha(0.9).build();

        let config_low_alpha = DistillationConfig::builder().alpha(0.1).build();

        let trainer_high = DistillationTrainer::new(config_high_alpha);
        let trainer_low = DistillationTrainer::new(config_low_alpha);

        let teacher = vec![1.0, 3.0, 2.0];
        let student = vec![0.5, 2.0, 1.5];
        let label = 1;

        let loss_high = trainer_high.compute_combined_loss(&teacher, &student, label);
        let loss_low = trainer_low.compute_combined_loss(&teacher, &student, label);

        assert!(loss_high.is_finite());
        assert!(loss_low.is_finite());
    }
}