use super::config::TrainerConfig;
use super::state::TrainingState;
use crate::hf_pipeline::fine_tune::FineTuneMethod;
use crate::hf_pipeline::loader::TeacherModel;
/// Trainer that pairs a teacher model with a [`TrainerConfig`] and tracks
/// progress in a [`TrainingState`].
///
/// `T` is the concrete teacher implementation; it must implement
/// [`TeacherModel`].
pub struct DistillationTrainer<T>
where
    T: TeacherModel,
{
    /// Trainer configuration. Public, so callers may read or adjust it
    /// directly.
    pub config: TrainerConfig,
    /// Teacher model owned by this trainer.
    teacher: T,
    /// Training progress owned by this trainer.
    state: TrainingState,
}
impl<T> DistillationTrainer<T>
where
    T: TeacherModel,
{
    /// Builds a trainer from `config` and `teacher`.
    ///
    /// The training state is initialised with `TrainingState::new()`.
    pub fn new(config: TrainerConfig, teacher: T) -> Self {
        let state = TrainingState::new();
        Self { config, teacher, state }
    }

    /// Returns a shared reference to the current training state.
    #[must_use]
    pub fn state(&self) -> &TrainingState {
        &self.state
    }

    /// Returns a shared reference to the teacher model.
    #[must_use]
    pub fn teacher(&self) -> &T {
        &self.teacher
    }

    /// Computes the combined training loss for one set of inputs.
    ///
    /// The result is built up in this order:
    ///
    /// 1. `config.distillation_loss.forward(student_logits, teacher_logits, targets)`
    ///    is always evaluated and forms the base value.
    /// 2. If `config.progressive` is set **and** both `student_hidden` and
    ///    `teacher_hidden` are supplied, its `hidden_state_loss` is added.
    /// 3. If `config.attention_transfer` is set **and** both
    ///    `student_attention` and `teacher_attention` are supplied, its
    ///    `loss` is added.
    ///
    /// An optional term is skipped silently when its configuration entry or
    /// either of its two inputs is missing.
    ///
    /// NOTE(review): expected array shapes and the meaning of `targets` are
    /// defined by the loss implementations, not here; confirm against them.
    #[must_use]
    // The parameters mirror the paired student/teacher inputs one-to-one,
    // which is why the list is longer than clippy's default limit.
    #[allow(clippy::too_many_arguments)]
    pub fn compute_loss(
        &self,
        student_logits: &ndarray::Array2<f32>,
        teacher_logits: &ndarray::Array2<f32>,
        targets: &[usize],
        student_hidden: Option<&[ndarray::Array2<f32>]>,
        teacher_hidden: Option<&[ndarray::Array2<f32>]>,
        student_attention: Option<&[ndarray::Array2<f32>]>,
        teacher_attention: Option<&[ndarray::Array2<f32>]>,
    ) -> f32 {
        // Project-defined precondition contract; must run before any loss work.
        contract_pre_cross_entropy!();

        let base = self.config.distillation_loss.forward(student_logits, teacher_logits, targets);

        // Hidden-state term: needs the config entry and both sides' inputs.
        let with_hidden = match (&self.config.progressive, student_hidden, teacher_hidden) {
            (Some(schedule), Some(student_states), Some(teacher_states)) => {
                base + schedule.hidden_state_loss(student_states, teacher_states)
            }
            _ => base,
        };

        // Attention term: same all-or-nothing rule as the hidden-state term.
        match (&self.config.attention_transfer, student_attention, teacher_attention) {
            (Some(transfer), Some(student_maps), Some(teacher_maps)) => {
                with_hidden + transfer.loss(student_maps, teacher_maps)
            }
            _ => with_hidden,
        }
    }

    /// Returns `true` when the configured fine-tuning method is
    /// `FineTuneMethod::LoRA`, `FineTuneMethod::QLoRA` or
    /// `FineTuneMethod::PrefixTuning`, and `false` for every other method.
    #[must_use]
    pub fn is_parameter_efficient(&self) -> bool {
        let method = &self.config.fine_tune.method;
        matches!(
            method,
            FineTuneMethod::LoRA(_)
                | FineTuneMethod::QLoRA { .. }
                | FineTuneMethod::PrefixTuning { .. }
        )
    }

    /// Estimates the combined memory needed by the teacher and the
    /// fine-tuning run, as the sum of both estimates' `total()`.
    ///
    /// The teacher estimate uses the configured `batch_size` and
    /// `max_seq_length`. The fine-tune estimate is given a quarter of the
    /// teacher's parameter count (presumably the assumed student size;
    /// confirm with the owners of `estimate_memory`).
    #[must_use]
    pub fn estimate_total_memory(&self) -> u64 {
        let settings = &self.config.fine_tune;
        let teacher_footprint =
            self.teacher.estimate_memory(settings.batch_size, settings.max_seq_length);

        // Parameter count handed to the fine-tune estimate: teacher / 4.
        let student_params = self.teacher.param_count() / 4;
        let student_footprint = settings.estimate_memory(student_params);

        teacher_footprint.total() + student_footprint.total()
    }

    /// Records `loss` in the training state and then advances it by one step.
    pub fn simulate_step(&mut self, loss: f32) {
        let state = &mut self.state;
        state.record_loss(loss);
        state.step();
    }

    /// Tells the training state that a new epoch has started.
    pub fn simulate_epoch(&mut self) {
        self.state.new_epoch();
    }
}