use crate::hf_pipeline::distillation::{
AttentionTransfer, DistillationLoss, ProgressiveDistillation,
};
use crate::hf_pipeline::fine_tune::FineTuneConfig;
use std::path::PathBuf;
/// Default value for [`TrainerConfig::save_every_n_steps`].
const DEFAULT_SAVE_STEPS: usize = 500;

/// Configuration for a teacher → student distillation training run.
///
/// Construct it with [`TrainerConfig::new`] and the builder-style methods, or with
/// `TrainerConfig { .., ..Default::default() }`. Default values live in the
/// [`Default`] impl for this type.
#[derive(Clone, Debug)]
pub struct TrainerConfig {
    /// Identifier of the teacher model (presumably a hub id or local path — TODO confirm).
    pub teacher_model: String,
    /// Identifier of the student model (presumably a hub id or local path — TODO confirm).
    pub student_model: String,
    /// Directory the trainer writes its outputs to.
    pub output_dir: PathBuf,
    /// Distillation loss settings (carries at least `temperature` and `alpha`).
    pub distillation_loss: DistillationLoss,
    /// Optional progressive (layer-mapped) distillation; `None` disables it.
    pub progressive: Option<ProgressiveDistillation>,
    /// Optional attention-transfer term; `None` disables it.
    pub attention_transfer: Option<AttentionTransfer>,
    /// Fine-tuning settings forwarded to the fine-tune pipeline.
    pub fine_tune: FineTuneConfig,
    /// Number of training epochs.
    pub epochs: usize,
    /// Steps per epoch. The default of 0 likely means "derive from data" — TODO confirm.
    pub steps_per_epoch: usize,
    /// Logging interval, in steps.
    pub log_every_n_steps: usize,
    /// Save interval, in steps.
    pub save_every_n_steps: usize,
    /// Evaluation interval, in steps.
    pub eval_every_n_steps: usize,
    /// Maximum gradient norm (presumably used for gradient clipping — TODO confirm).
    pub max_grad_norm: f32,
    /// Seed for random number generation.
    pub seed: u64,
}
impl Default for TrainerConfig {
fn default() -> Self {
Self {
teacher_model: String::new(),
student_model: String::new(),
output_dir: PathBuf::from("./distillation_output"),
distillation_loss: DistillationLoss::default(),
progressive: None,
attention_transfer: None,
fine_tune: FineTuneConfig::default(),
epochs: 3,
steps_per_epoch: 0,
log_every_n_steps: 10,
save_every_n_steps: DEFAULT_SAVE_STEPS,
eval_every_n_steps: 100,
max_grad_norm: 1.0,
seed: 42,
}
}
}
impl TrainerConfig {
    /// Creates a configuration that distils `teacher` into `student`.
    ///
    /// Every other field takes its [`Default`] value.
    #[must_use]
    pub fn new(teacher: impl Into<String>, student: impl Into<String>) -> Self {
        let teacher_model = teacher.into();
        let student_model = student.into();
        Self { teacher_model, student_model, ..Self::default() }
    }

    /// Sets the distillation temperature and keeps the current `alpha`.
    ///
    /// The loss is rebuilt via `DistillationLoss::new`. NOTE(review): any other state
    /// held by the previous `DistillationLoss` is not carried over explicitly.
    #[must_use]
    pub fn temperature(mut self, temp: f32) -> Self {
        let kept_alpha = self.distillation_loss.alpha;
        self.distillation_loss = DistillationLoss::new(temp, kept_alpha);
        self
    }

    /// Sets the distillation `alpha` and keeps the current temperature.
    ///
    /// The loss is rebuilt via `DistillationLoss::new`. NOTE(review): any other state
    /// held by the previous `DistillationLoss` is not carried over explicitly.
    #[must_use]
    pub fn alpha(mut self, alpha: f32) -> Self {
        let kept_temperature = self.distillation_loss.temperature;
        self.distillation_loss = DistillationLoss::new(kept_temperature, alpha);
        self
    }

    /// Enables progressive distillation with the given layer mapping.
    ///
    /// Each pair is presumably `(student_layer, teacher_layer)` — TODO confirm against
    /// `ProgressiveDistillation::new`. Any previously configured mapping is replaced.
    #[must_use]
    pub fn with_progressive(mut self, layer_mapping: Vec<(usize, usize)>) -> Self {
        self.progressive = ProgressiveDistillation::new(layer_mapping).into();
        self
    }

    /// Enables attention transfer with the given loss `weight`.
    ///
    /// Any previously configured attention transfer is replaced.
    #[must_use]
    pub fn with_attention_transfer(mut self, weight: f32) -> Self {
        self.attention_transfer = AttentionTransfer::new(weight).into();
        self
    }

    /// Sets the directory the trainer writes its outputs to.
    #[must_use]
    pub fn output_dir(self, path: impl Into<PathBuf>) -> Self {
        Self { output_dir: path.into(), ..self }
    }

    /// Sets the number of training epochs.
    #[must_use]
    pub fn epochs(self, n: usize) -> Self {
        Self { epochs: n, ..self }
    }
}