entrenar/hf_pipeline/trainer/
config.rs1use crate::hf_pipeline::distillation::{
4 AttentionTransfer, DistillationLoss, ProgressiveDistillation,
5};
6use crate::hf_pipeline::fine_tune::FineTuneConfig;
7use std::path::PathBuf;
/// Default interval, in steps, used for `TrainerConfig::save_every_n_steps`.
const DEFAULT_SAVE_STEPS: usize = 500;

12#[derive(Debug, Clone)]
14pub struct TrainerConfig {
15 pub teacher_model: String,
17 pub student_model: String,
19 pub output_dir: PathBuf,
21 pub distillation_loss: DistillationLoss,
23 pub progressive: Option<ProgressiveDistillation>,
25 pub attention_transfer: Option<AttentionTransfer>,
27 pub fine_tune: FineTuneConfig,
29 pub epochs: usize,
31 pub steps_per_epoch: usize,
33 pub log_every_n_steps: usize,
35 pub save_every_n_steps: usize,
37 pub eval_every_n_steps: usize,
39 pub max_grad_norm: f32,
41 pub seed: u64,
43}
44
45impl Default for TrainerConfig {
46 fn default() -> Self {
47 Self {
48 teacher_model: String::new(),
49 student_model: String::new(),
50 output_dir: PathBuf::from("./distillation_output"),
51 distillation_loss: DistillationLoss::default(),
52 progressive: None,
53 attention_transfer: None,
54 fine_tune: FineTuneConfig::default(),
55 epochs: 3,
56 steps_per_epoch: 0,
57 log_every_n_steps: 10,
58 save_every_n_steps: DEFAULT_SAVE_STEPS,
59 eval_every_n_steps: 100,
60 max_grad_norm: 1.0,
61 seed: 42,
62 }
63 }
64}
65
66impl TrainerConfig {
67 #[must_use]
69 pub fn new(teacher: impl Into<String>, student: impl Into<String>) -> Self {
70 Self { teacher_model: teacher.into(), student_model: student.into(), ..Default::default() }
71 }
72
73 #[must_use]
75 pub fn temperature(mut self, temp: f32) -> Self {
76 contract_pre_temperature!();
77 self.distillation_loss = DistillationLoss::new(temp, self.distillation_loss.alpha);
78 contract_post_temperature_bounds!(temp);
79 self
80 }
81
82 #[must_use]
84 pub fn alpha(mut self, alpha: f32) -> Self {
85 self.distillation_loss = DistillationLoss::new(self.distillation_loss.temperature, alpha);
86 self
87 }
88
89 #[must_use]
91 pub fn with_progressive(mut self, layer_mapping: Vec<(usize, usize)>) -> Self {
92 self.progressive = Some(ProgressiveDistillation::new(layer_mapping));
93 self
94 }
95
96 #[must_use]
98 pub fn with_attention_transfer(mut self, weight: f32) -> Self {
99 self.attention_transfer = Some(AttentionTransfer::new(weight));
100 self
101 }
102
103 #[must_use]
105 pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
106 self.output_dir = path.into();
107 self
108 }
109
110 #[must_use]
112 pub fn epochs(mut self, n: usize) -> Self {
113 self.epochs = n;
114 self
115 }
116}