entrenar/hf_pipeline/trainer/
distillation_trainer.rs1use super::config::TrainerConfig;
4use super::state::TrainingState;
5use crate::hf_pipeline::fine_tune::FineTuneMethod;
6use crate::hf_pipeline::loader::TeacherModel;
7
8pub struct DistillationTrainer<T: TeacherModel> {
13 pub config: TrainerConfig,
15 teacher: T,
17 state: TrainingState,
19}
20
21impl<T: TeacherModel> DistillationTrainer<T> {
22 pub fn new(config: TrainerConfig, teacher: T) -> Self {
24 Self { config, teacher, state: TrainingState::new() }
25 }
26
27 #[must_use]
29 pub fn state(&self) -> &TrainingState {
30 &self.state
31 }
32
33 #[must_use]
35 pub fn teacher(&self) -> &T {
36 &self.teacher
37 }
38
39 #[must_use]
43 #[allow(clippy::too_many_arguments)]
44 pub fn compute_loss(
45 &self,
46 student_logits: &ndarray::Array2<f32>,
47 teacher_logits: &ndarray::Array2<f32>,
48 targets: &[usize],
49 student_hidden: Option<&[ndarray::Array2<f32>]>,
50 teacher_hidden: Option<&[ndarray::Array2<f32>]>,
51 student_attention: Option<&[ndarray::Array2<f32>]>,
52 teacher_attention: Option<&[ndarray::Array2<f32>]>,
53 ) -> f32 {
54 contract_pre_cross_entropy!();
56
57 let mut total_loss =
59 self.config.distillation_loss.forward(student_logits, teacher_logits, targets);
60
61 if let (Some(prog), Some(sh), Some(th)) =
63 (&self.config.progressive, student_hidden, teacher_hidden)
64 {
65 total_loss += prog.hidden_state_loss(sh, th);
66 }
67
68 if let (Some(at), Some(sa), Some(ta)) =
70 (&self.config.attention_transfer, student_attention, teacher_attention)
71 {
72 total_loss += at.loss(sa, ta);
73 }
74
75 contract_post_cross_entropy!(total_loss);
76 total_loss
77 }
78
79 #[must_use]
81 pub fn is_parameter_efficient(&self) -> bool {
82 matches!(
83 self.config.fine_tune.method,
84 FineTuneMethod::LoRA(_)
85 | FineTuneMethod::QLoRA { .. }
86 | FineTuneMethod::PrefixTuning { .. }
87 )
88 }
89
90 #[must_use]
92 pub fn estimate_total_memory(&self) -> u64 {
93 let teacher_mem = self.teacher.estimate_memory(
94 self.config.fine_tune.batch_size,
95 self.config.fine_tune.max_seq_length,
96 );
97 let student_mem = self.config.fine_tune.estimate_memory(self.teacher.param_count() / 4); teacher_mem.total() + student_mem.total()
100 }
101
102 pub fn simulate_step(&mut self, loss: f32) {
104 self.state.record_loss(loss);
105 self.state.step();
106 }
107
108 pub fn simulate_epoch(&mut self) {
110 self.state.new_epoch();
111 }
112}