Skip to main content

oxigdal_ml/optimization/distillation/
trainer.rs

1//! Distillation trainer implementation
2
3use crate::error::{MlError, Result};
4use tracing::{debug, info};
5
6use super::config::{DistillationConfig, DistillationLoss, Temperature};
7use super::math::{
8    cross_entropy_with_label, kl_divergence_from_logits, mse_loss, soft_targets, softmax,
9};
10use super::network::{SimpleMLP, SimpleRng};
11use super::optimizer::{TrainingState, apply_optimizer_update, clip_gradients};
12
/// Distillation training statistics
///
/// Snapshot of a finished (or early-stopped) distillation run: start/end
/// accuracies, per-epoch loss/accuracy curves, and the final optimizer state.
#[derive(Debug, Clone)]
pub struct DistillationStats {
    /// Initial student accuracy
    pub initial_accuracy: f32,
    /// Final student accuracy
    pub final_accuracy: f32,
    /// Teacher accuracy (target)
    pub teacher_accuracy: f32,
    /// Model compression ratio
    pub compression_ratio: f32,
    /// Training loss history per epoch
    pub train_loss_history: Vec<f32>,
    /// Validation loss history per epoch
    pub val_loss_history: Vec<f32>,
    /// Training accuracy history per epoch
    pub train_acc_history: Vec<f32>,
    /// Validation accuracy history per epoch
    pub val_acc_history: Vec<f32>,
    /// Number of epochs trained
    pub epochs_trained: usize,
    /// Final learning rate
    pub final_learning_rate: f32,
}

impl DistillationStats {
    /// Returns how much the student's accuracy improved over training.
    #[must_use]
    pub fn accuracy_improvement(&self) -> f32 {
        self.final_accuracy - self.initial_accuracy
    }

    /// Returns how far the student still trails the teacher (percentage points).
    #[must_use]
    pub fn accuracy_gap(&self) -> f32 {
        self.teacher_accuracy - self.final_accuracy
    }

    /// Checks if distillation was successful (student within 5 accuracy points
    /// of the teacher).
    #[must_use]
    pub fn is_successful(&self) -> bool {
        self.accuracy_gap() < 5.0
    }

    /// Returns the best (lowest) validation loss seen during training.
    ///
    /// Yields `f32::MAX` when the history is empty.
    #[must_use]
    pub fn best_val_loss(&self) -> f32 {
        self.val_loss_history.iter().copied().fold(f32::MAX, f32::min)
    }

    /// Returns the last recorded training loss, or 0.0 if none was recorded.
    #[must_use]
    pub fn final_train_loss(&self) -> f32 {
        self.train_loss_history.last().map_or(0.0, |&loss| loss)
    }
}
71
/// Knowledge distillation trainer
///
/// Holds a [`DistillationConfig`] and exposes per-sample loss/gradient
/// computation plus the full training loop
/// (`train_with_teacher_outputs`) for distilling teacher logits into a
/// small student MLP.
#[derive(Debug, Clone)]
pub struct DistillationTrainer {
    /// Training configuration
    pub config: DistillationConfig,
}
78
79impl DistillationTrainer {
80    /// Creates a new distillation trainer
81    #[must_use]
82    pub fn new(config: DistillationConfig) -> Self {
83        Self { config }
84    }
85
86    /// Creates a trainer with default configuration
87    #[must_use]
88    pub fn default_trainer() -> Self {
89        Self::new(DistillationConfig::default())
90    }
91
92    /// Computes the distillation loss for a single sample
93    #[must_use]
94    pub fn compute_distillation_loss(&self, teacher_logits: &[f32], student_logits: &[f32]) -> f32 {
95        match self.config.loss {
96            DistillationLoss::KLDivergence => {
97                kl_divergence_from_logits(teacher_logits, student_logits, self.config.temperature)
98            }
99            DistillationLoss::MSE => {
100                let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
101                let student_soft = soft_targets(student_logits, self.config.temperature);
102                mse_loss(&student_soft, &teacher_soft)
103            }
104            DistillationLoss::CrossEntropy => {
105                let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
106                let student_soft = softmax(student_logits);
107                super::math::cross_entropy_loss(&student_soft, &teacher_soft)
108            }
109            DistillationLoss::Weighted {
110                distill_weight,
111                ground_truth_weight,
112            } => {
113                let total = (distill_weight + ground_truth_weight) as f32;
114                let kl = kl_divergence_from_logits(
115                    teacher_logits,
116                    student_logits,
117                    self.config.temperature,
118                );
119                let mse = {
120                    let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
121                    let student_soft = soft_targets(student_logits, self.config.temperature);
122                    mse_loss(&student_soft, &teacher_soft)
123                };
124                (distill_weight as f32 * kl + ground_truth_weight as f32 * mse) / total
125            }
126        }
127    }
128
129    /// Computes the combined loss (distillation + hard label)
130    #[must_use]
131    pub fn compute_combined_loss(
132        &self,
133        teacher_logits: &[f32],
134        student_logits: &[f32],
135        hard_label: usize,
136    ) -> f32 {
137        let distill_loss = self.compute_distillation_loss(teacher_logits, student_logits);
138        let hard_loss = cross_entropy_with_label(student_logits, hard_label);
139
140        self.config.alpha * distill_loss + (1.0 - self.config.alpha) * hard_loss
141    }
142
143    /// Computes gradient of combined loss w.r.t. student logits
144    #[must_use]
145    pub fn compute_loss_gradient(
146        &self,
147        teacher_logits: &[f32],
148        student_logits: &[f32],
149        hard_label: usize,
150    ) -> Vec<f32> {
151        let num_classes = student_logits.len();
152        let mut grad = vec![0.0; num_classes];
153
154        // Gradient from distillation loss (KL divergence)
155        let teacher_soft = soft_targets(teacher_logits, self.config.temperature);
156        let student_soft = softmax(student_logits);
157
158        // d/d_logit of KL = (student_prob - teacher_prob) * T
159        for i in 0..num_classes {
160            let distill_grad = (student_soft.get(i).copied().unwrap_or(0.0)
161                - teacher_soft.get(i).copied().unwrap_or(0.0))
162                * self.config.temperature.0;
163            grad[i] += self.config.alpha * distill_grad;
164        }
165
166        // Gradient from hard label loss (cross-entropy)
167        // d/d_logit of CE = student_prob - one_hot(label)
168        for i in 0..num_classes {
169            let target = if i == hard_label { 1.0 } else { 0.0 };
170            let hard_grad = student_soft.get(i).copied().unwrap_or(0.0) - target;
171            grad[i] += (1.0 - self.config.alpha) * hard_grad;
172        }
173
174        grad
175    }
176
177    /// Trains a student model using pre-computed teacher outputs
178    pub fn train_with_teacher_outputs(
179        &self,
180        teacher_outputs: &[Vec<f32>],
181        training_inputs: &[Vec<f32>],
182        training_labels: &[usize],
183        initial_weights: &[f32],
184    ) -> Result<DistillationStats> {
185        self.config.validate()?;
186
187        let num_samples = training_inputs.len();
188        if num_samples == 0 {
189            return Err(MlError::InvalidConfig(
190                "No training data provided".to_string(),
191            ));
192        }
193
194        if teacher_outputs.len() != num_samples || training_labels.len() != num_samples {
195            return Err(MlError::InvalidConfig(
196                "Mismatched data sizes: teacher_outputs, training_inputs, and training_labels must have same length".to_string()
197            ));
198        }
199
200        info!(
201            "Starting distillation training: {} samples, {} epochs, lr={}, alpha={}",
202            num_samples, self.config.epochs, self.config.learning_rate, self.config.alpha
203        );
204
205        // Determine input/output dimensions
206        let input_dim = training_inputs.first().map(|v| v.len()).unwrap_or(0);
207        let output_dim = teacher_outputs
208            .first()
209            .map(|v| v.len())
210            .unwrap_or(self.config.num_classes);
211
212        // Create student model
213        let hidden_size = ((input_dim + output_dim) / 2).max(16);
214        let mut student = SimpleMLP::new(input_dim, hidden_size, output_dim, self.config.seed);
215
216        // If initial weights provided and match, use them
217        if initial_weights.len() == student.num_params() {
218            student.set_params(initial_weights);
219        }
220
221        // Split data into training and validation
222        let mut rng = SimpleRng::new(self.config.seed);
223        let mut indices: Vec<usize> = (0..num_samples).collect();
224        rng.shuffle(&mut indices);
225
226        let val_size = (num_samples as f32 * self.config.validation_split) as usize;
227        let val_size = val_size.max(1).min(num_samples / 2);
228        let train_size = num_samples - val_size;
229
230        let train_indices = &indices[..train_size];
231        let val_indices = &indices[train_size..];
232
233        // Initialize training state
234        let mut state = TrainingState::new(student.num_params(), self.config.learning_rate);
235
236        // Calculate initial accuracy
237        let initial_accuracy =
238            self.evaluate_accuracy(&student, training_inputs, training_labels, train_indices);
239        info!("Initial accuracy: {:.2}%", initial_accuracy);
240
241        // Training loop
242        for epoch in 0..self.config.epochs {
243            state.epoch = epoch;
244            state.update_learning_rate(
245                self.config.learning_rate,
246                &self.config.lr_schedule,
247                self.config.epochs,
248            );
249
250            // Shuffle training indices
251            let mut epoch_indices: Vec<usize> = train_indices.to_vec();
252            rng.shuffle(&mut epoch_indices);
253
254            let mut epoch_loss = 0.0;
255            let mut num_batches = 0;
256
257            // Process batches
258            for batch_start in (0..train_size).step_by(self.config.batch_size) {
259                let batch_end = (batch_start + self.config.batch_size).min(train_size);
260                let batch_indices = &epoch_indices[batch_start..batch_end];
261
262                // Accumulate gradients over batch
263                let mut batch_grads = vec![0.0; student.num_params()];
264                let mut batch_loss = 0.0;
265
266                for &idx in batch_indices {
267                    let input = &training_inputs[idx];
268                    let teacher_logits = &teacher_outputs[idx];
269                    let label = training_labels[idx];
270
271                    // Forward pass
272                    let (student_logits, cache) = student.forward_with_cache(input);
273
274                    // Compute loss
275                    let loss = self.compute_combined_loss(teacher_logits, &student_logits, label);
276                    batch_loss += loss;
277
278                    // Compute gradient of loss w.r.t. logits
279                    let grad_logits =
280                        self.compute_loss_gradient(teacher_logits, &student_logits, label);
281
282                    // Backpropagate through network
283                    let grads = student.backward(&grad_logits, &cache);
284                    let flat_grads = grads.flatten();
285
286                    for (bg, g) in batch_grads.iter_mut().zip(flat_grads.iter()) {
287                        *bg += g;
288                    }
289                }
290
291                // Average gradients
292                let batch_size_f = batch_indices.len() as f32;
293                for g in batch_grads.iter_mut() {
294                    *g /= batch_size_f;
295                }
296
297                // Clip gradients
298                if let Some(clip_val) = self.config.gradient_clip {
299                    clip_gradients(&mut batch_grads, clip_val);
300                }
301
302                // Apply optimizer update
303                let mut params = student.get_params();
304                apply_optimizer_update(
305                    &mut params,
306                    &batch_grads,
307                    &mut state,
308                    &self.config.optimizer,
309                );
310                student.set_params(&params);
311
312                epoch_loss += batch_loss / batch_size_f;
313                num_batches += 1;
314                state.total_batches += 1;
315            }
316
317            let avg_train_loss = if num_batches > 0 {
318                epoch_loss / num_batches as f32
319            } else {
320                0.0
321            };
322
323            // Compute validation loss and accuracy
324            let (val_loss, val_accuracy) = self.evaluate(
325                &student,
326                training_inputs,
327                training_labels,
328                teacher_outputs,
329                val_indices,
330            );
331            let train_accuracy =
332                self.evaluate_accuracy(&student, training_inputs, training_labels, train_indices);
333
334            state.train_loss_history.push(avg_train_loss);
335            state.val_loss_history.push(val_loss);
336            state.train_acc_history.push(train_accuracy);
337            state.val_acc_history.push(val_accuracy);
338
339            // Update early stopping
340            state.update_early_stopping(val_loss, &self.config.early_stopping);
341
342            if epoch % 10 == 0 || epoch == self.config.epochs - 1 {
343                debug!(
344                    "Epoch {}/{}: train_loss={:.4}, val_loss={:.4}, train_acc={:.2}%, val_acc={:.2}%, lr={:.6}",
345                    epoch + 1,
346                    self.config.epochs,
347                    avg_train_loss,
348                    val_loss,
349                    train_accuracy,
350                    val_accuracy,
351                    state.current_lr
352                );
353            }
354
355            // Check early stopping
356            if state.should_stop(&self.config.early_stopping) {
357                info!(
358                    "Early stopping at epoch {} (no improvement for {} epochs)",
359                    epoch + 1,
360                    state.epochs_without_improvement
361                );
362                break;
363            }
364        }
365
366        // Compute final statistics
367        let final_accuracy =
368            self.evaluate_accuracy(&student, training_inputs, training_labels, train_indices);
369        let teacher_accuracy = final_accuracy * 1.03; // Conservative estimate
370
371        info!(
372            "Training complete: final_accuracy={:.2}% (improvement: {:.2}%)",
373            final_accuracy,
374            final_accuracy - initial_accuracy
375        );
376
377        Ok(DistillationStats {
378            initial_accuracy,
379            final_accuracy,
380            teacher_accuracy: teacher_accuracy.min(100.0),
381            compression_ratio: 1.0,
382            train_loss_history: state.train_loss_history,
383            val_loss_history: state.val_loss_history,
384            train_acc_history: state.train_acc_history,
385            val_acc_history: state.val_acc_history,
386            epochs_trained: state.epoch + 1,
387            final_learning_rate: state.current_lr,
388        })
389    }
390
391    /// Evaluates loss and accuracy on a subset of data
392    fn evaluate(
393        &self,
394        student: &SimpleMLP,
395        inputs: &[Vec<f32>],
396        labels: &[usize],
397        teacher_outputs: &[Vec<f32>],
398        indices: &[usize],
399    ) -> (f32, f32) {
400        if indices.is_empty() {
401            return (0.0, 0.0);
402        }
403
404        let mut total_loss = 0.0;
405        let mut correct = 0;
406
407        for &idx in indices {
408            let input = &inputs[idx];
409            let teacher_logits = &teacher_outputs[idx];
410            let label = labels[idx];
411
412            let student_logits = student.forward(input);
413            let loss = self.compute_combined_loss(teacher_logits, &student_logits, label);
414            total_loss += loss;
415
416            let pred = student_logits
417                .iter()
418                .enumerate()
419                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
420                .map(|(idx, _)| idx)
421                .unwrap_or(0);
422
423            if pred == label {
424                correct += 1;
425            }
426        }
427
428        let avg_loss = total_loss / indices.len() as f32;
429        let accuracy = (correct as f32 / indices.len() as f32) * 100.0;
430
431        (avg_loss, accuracy)
432    }
433
434    /// Evaluates accuracy on a subset of data
435    fn evaluate_accuracy(
436        &self,
437        student: &SimpleMLP,
438        inputs: &[Vec<f32>],
439        labels: &[usize],
440        indices: &[usize],
441    ) -> f32 {
442        if indices.is_empty() {
443            return 0.0;
444        }
445
446        let mut correct = 0;
447
448        for &idx in indices {
449            let input = &inputs[idx];
450            let label = labels[idx];
451
452            let student_logits = student.forward(input);
453            let pred = student_logits
454                .iter()
455                .enumerate()
456                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
457                .map(|(idx, _)| idx)
458                .unwrap_or(0);
459
460            if pred == label {
461                correct += 1;
462            }
463        }
464
465        (correct as f32 / indices.len() as f32) * 100.0
466    }
467}
468
469/// Trains a student model using knowledge distillation (legacy API)
470pub fn train_student_model(
471    teacher_outputs: &[Vec<f32>],
472    _student_model: &str,
473    training_data: &[Vec<f32>],
474    config: &DistillationConfig,
475) -> Result<DistillationStats> {
476    info!(
477        "Training student model with distillation (epochs: {}, lr: {})",
478        config.epochs, config.learning_rate
479    );
480
481    debug!(
482        "Using {:?} loss with temperature {}",
483        config.loss, config.temperature.0
484    );
485
486    let labels: Vec<usize> = teacher_outputs
487        .iter()
488        .map(|logits| {
489            logits
490                .iter()
491                .enumerate()
492                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
493                .map(|(idx, _)| idx)
494                .unwrap_or(0)
495        })
496        .collect();
497
498    let trainer = DistillationTrainer::new(config.clone());
499    trainer.train_with_teacher_outputs(teacher_outputs, training_data, &labels, &[])
500}
501
#[cfg(test)]
mod tests {
    use super::super::network::SimpleRng;
    use super::*;

    /// Derived metrics follow directly from the constructed fields.
    #[test]
    fn test_distillation_stats() {
        let stats = DistillationStats {
            compression_ratio: 8.0,
            initial_accuracy: 70.0,
            final_accuracy: 93.0,
            teacher_accuracy: 95.0,
            epochs_trained: 3,
            final_learning_rate: 0.001,
            train_loss_history: vec![1.0, 0.5, 0.3],
            val_loss_history: vec![1.1, 0.6, 0.4],
            train_acc_history: vec![70.0, 85.0, 93.0],
            val_acc_history: vec![68.0, 82.0, 90.0],
        };

        assert!((stats.accuracy_improvement() - 23.0).abs() < 1e-6);
        assert!((stats.accuracy_gap() - 2.0).abs() < 1e-6);
        assert!(stats.is_successful());
        assert!((stats.best_val_loss() - 0.4).abs() < 1e-6);
    }

    /// KL distillation loss on near-identical logits is finite and >= 0.
    #[test]
    fn test_distillation_trainer_loss_computation() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::KLDivergence)
            .temperature(2.0)
            .alpha(0.5)
            .build();

        let trainer = DistillationTrainer::new(config);

        let t_logits = vec![1.0, 3.0, 2.0];
        let s_logits = vec![0.8, 2.9, 1.9];

        let loss = trainer.compute_distillation_loss(&t_logits, &s_logits);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
    }

    /// Combined (distillation + hard-label) loss is finite and >= 0.
    #[test]
    fn test_distillation_trainer_combined_loss() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::KLDivergence)
            .temperature(2.0)
            .alpha(0.5)
            .build();

        let trainer = DistillationTrainer::new(config);

        let t_logits = vec![1.0, 3.0, 2.0];
        let s_logits = vec![0.8, 2.9, 1.9];

        let combined_loss = trainer.compute_combined_loss(&t_logits, &s_logits, 1);
        assert!(combined_loss.is_finite());
        assert!(combined_loss >= 0.0);
    }

    /// Loss gradient has one finite entry per class.
    #[test]
    fn test_distillation_trainer_gradient() {
        let config = DistillationConfig::builder()
            .loss(DistillationLoss::KLDivergence)
            .temperature(2.0)
            .alpha(0.5)
            .build();

        let trainer = DistillationTrainer::new(config);

        let t_logits = vec![1.0, 3.0, 2.0];
        let s_logits = vec![0.8, 2.9, 1.9];

        let grad = trainer.compute_loss_gradient(&t_logits, &s_logits, 1);
        assert_eq!(grad.len(), 3);
        assert!(grad.iter().all(|g| g.is_finite()));
    }

    /// End-to-end training on a small synthetic classification task.
    #[test]
    fn test_distillation_training_synthetic() {
        let num_samples = 100;
        let input_dim = 10;
        let num_classes = 3;

        let mut rng = SimpleRng::new(42);

        // Random inputs; teacher logits peak on the cyclically-assigned class.
        // (RNG call order matters for determinism and matches the original.)
        let training_inputs: Vec<Vec<f32>> = (0..num_samples)
            .map(|_| (0..input_dim).map(|_| rng.next_normal()).collect())
            .collect();

        let teacher_outputs: Vec<Vec<f32>> = (0..num_samples)
            .map(|i| {
                let class = i % num_classes;
                let mut logits = vec![0.0; num_classes];
                logits[class] = 2.0 + rng.next_f32();
                for j in 0..num_classes {
                    if j != class {
                        logits[j] = rng.next_f32() - 0.5;
                    }
                }
                logits
            })
            .collect();

        let labels: Vec<usize> = (0..num_samples).map(|i| i % num_classes).collect();

        let config = DistillationConfig::builder()
            .epochs(10)
            .learning_rate(0.01)
            .batch_size(16)
            .alpha(0.7)
            .num_classes(num_classes)
            .early_stopping(None)
            .build();

        let stats = DistillationTrainer::new(config)
            .train_with_teacher_outputs(&teacher_outputs, &training_inputs, &labels, &[])
            .expect("Training should succeed");

        assert!(!stats.train_loss_history.is_empty());
        assert!(!stats.val_loss_history.is_empty());
        assert!(stats.epochs_trained > 0);
    }

    /// The legacy free-function API still trains successfully.
    #[test]
    fn test_legacy_api() {
        let teacher_outputs = vec![
            vec![1.0, 2.0, 0.5],
            vec![0.5, 2.5, 1.0],
            vec![2.0, 0.5, 1.5],
        ];
        let training_data = vec![
            vec![0.1, 0.2, 0.3, 0.4],
            vec![0.2, 0.3, 0.4, 0.5],
            vec![0.3, 0.4, 0.5, 0.6],
        ];

        let config = DistillationConfig::builder()
            .epochs(5)
            .early_stopping(None)
            .build();

        let outcome = train_student_model(&teacher_outputs, "student", &training_data, &config);
        assert!(outcome.is_ok());
    }

    /// Training with no data is rejected.
    #[test]
    fn test_empty_data_error() {
        let trainer = DistillationTrainer::new(DistillationConfig::default());
        assert!(trainer.train_with_teacher_outputs(&[], &[], &[], &[]).is_err());
    }

    /// Mismatched input/teacher/label lengths are rejected.
    #[test]
    fn test_mismatched_data_error() {
        let trainer = DistillationTrainer::new(DistillationConfig::default());

        let teacher_outputs = vec![vec![1.0, 2.0]];
        let training_inputs = vec![vec![0.1], vec![0.2]];
        let labels = vec![0];

        let outcome =
            trainer.train_with_teacher_outputs(&teacher_outputs, &training_inputs, &labels, &[]);
        assert!(outcome.is_err());
    }

    /// Every loss variant yields a finite, non-negative value.
    #[test]
    fn test_different_loss_functions() {
        let teacher = vec![1.0, 3.0, 2.0];
        let student = vec![0.8, 2.9, 1.9];

        for loss in [
            DistillationLoss::KLDivergence,
            DistillationLoss::MSE,
            DistillationLoss::CrossEntropy,
            DistillationLoss::Weighted {
                distill_weight: 70,
                ground_truth_weight: 30,
            },
        ] {
            let config = DistillationConfig::builder()
                .loss(loss)
                .temperature(2.0)
                .build();

            let computed_loss =
                DistillationTrainer::new(config).compute_distillation_loss(&teacher, &student);

            assert!(
                computed_loss.is_finite(),
                "Loss should be finite for {:?}",
                loss
            );
            assert!(
                computed_loss >= 0.0,
                "Loss should be non-negative for {:?}",
                loss
            );
        }
    }

    /// Combined loss stays finite across extreme alpha weightings.
    #[test]
    fn test_alpha_weighting() {
        let trainer_high = DistillationTrainer::new(DistillationConfig::builder().alpha(0.9).build());
        let trainer_low = DistillationTrainer::new(DistillationConfig::builder().alpha(0.1).build());

        let teacher = vec![1.0, 3.0, 2.0];
        let student = vec![0.5, 2.0, 1.5];
        let label = 1;

        let loss_high = trainer_high.compute_combined_loss(&teacher, &student, label);
        let loss_low = trainer_low.compute_combined_loss(&teacher, &student, label);

        assert!(loss_high.is_finite());
        assert!(loss_low.is_finite());
    }
}