Skip to main content

entrenar/hf_pipeline/trainer/
distillation_trainer.rs

1//! Knowledge Distillation Trainer implementation.
2
3use super::config::TrainerConfig;
4use super::state::TrainingState;
5use crate::hf_pipeline::fine_tune::FineTuneMethod;
6use crate::hf_pipeline::loader::TeacherModel;
7
8/// Knowledge Distillation Trainer
9///
10/// Orchestrates the training loop for distilling knowledge from
11/// a teacher model to a student model.
12pub struct DistillationTrainer<T: TeacherModel> {
13    /// Configuration
14    pub config: TrainerConfig,
15    /// Teacher model
16    teacher: T,
17    /// Training state
18    state: TrainingState,
19}
20
21impl<T: TeacherModel> DistillationTrainer<T> {
22    /// Create new trainer with teacher model
23    pub fn new(config: TrainerConfig, teacher: T) -> Self {
24        Self { config, teacher, state: TrainingState::new() }
25    }
26
27    /// Get current training state
28    #[must_use]
29    pub fn state(&self) -> &TrainingState {
30        &self.state
31    }
32
33    /// Get teacher model reference
34    #[must_use]
35    pub fn teacher(&self) -> &T {
36        &self.teacher
37    }
38
39    /// Compute total loss for a batch
40    ///
41    /// Combines distillation loss with optional progressive and attention transfer.
42    #[must_use]
43    #[allow(clippy::too_many_arguments)]
44    pub fn compute_loss(
45        &self,
46        student_logits: &ndarray::Array2<f32>,
47        teacher_logits: &ndarray::Array2<f32>,
48        targets: &[usize],
49        student_hidden: Option<&[ndarray::Array2<f32>]>,
50        teacher_hidden: Option<&[ndarray::Array2<f32>]>,
51        student_attention: Option<&[ndarray::Array2<f32>]>,
52        teacher_attention: Option<&[ndarray::Array2<f32>]>,
53    ) -> f32 {
54        // Contract: cross-entropy-kernel-v1.yaml precondition (pv codegen)
55        contract_pre_cross_entropy!();
56
57        // Base distillation loss
58        let mut total_loss =
59            self.config.distillation_loss.forward(student_logits, teacher_logits, targets);
60
61        // Progressive distillation (hidden state matching)
62        if let (Some(prog), Some(sh), Some(th)) =
63            (&self.config.progressive, student_hidden, teacher_hidden)
64        {
65            total_loss += prog.hidden_state_loss(sh, th);
66        }
67
68        // Attention transfer
69        if let (Some(at), Some(sa), Some(ta)) =
70            (&self.config.attention_transfer, student_attention, teacher_attention)
71        {
72            total_loss += at.loss(sa, ta);
73        }
74
75        contract_post_cross_entropy!(total_loss);
76        total_loss
77    }
78
79    /// Check if using LoRA/QLoRA for student fine-tuning
80    #[must_use]
81    pub fn is_parameter_efficient(&self) -> bool {
82        matches!(
83            self.config.fine_tune.method,
84            FineTuneMethod::LoRA(_)
85                | FineTuneMethod::QLoRA { .. }
86                | FineTuneMethod::PrefixTuning { .. }
87        )
88    }
89
90    /// Estimate total memory requirements
91    #[must_use]
92    pub fn estimate_total_memory(&self) -> u64 {
93        let teacher_mem = self.teacher.estimate_memory(
94            self.config.fine_tune.batch_size,
95            self.config.fine_tune.max_seq_length,
96        );
97        let student_mem = self.config.fine_tune.estimate_memory(self.teacher.param_count() / 4); // Assume 4x smaller student
98
99        teacher_mem.total() + student_mem.total()
100    }
101
102    /// Simulate one training step (for testing)
103    pub fn simulate_step(&mut self, loss: f32) {
104        self.state.record_loss(loss);
105        self.state.step();
106    }
107
108    /// Simulate epoch boundary
109    pub fn simulate_epoch(&mut self) {
110        self.state.new_epoch();
111    }
112}