entrenar/hf_pipeline/config/
yaml_config.rs1use crate::hf_pipeline::error::{FetchError, Result};
4use crate::hf_pipeline::fine_tune::{FineTuneConfig, MixedPrecision};
5use crate::hf_pipeline::trainer::TrainerConfig;
6use crate::lora::LoRAConfig;
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use super::dataset::DatasetConfig;
11use super::distillation::DistillationConfig;
12use super::output::OutputConfig;
13use super::student::StudentConfig;
14use super::teacher::TeacherConfig;
15use super::training::TrainingConfig;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct DistillationYamlConfig {
20 pub teacher: TeacherConfig,
22 pub student: StudentConfig,
24 #[serde(default)]
26 pub distillation: DistillationConfig,
27 #[serde(default)]
29 pub training: TrainingConfig,
30 pub dataset: DatasetConfig,
32 #[serde(default)]
34 pub output: OutputConfig,
35}
36
37impl DistillationYamlConfig {
38 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
40 let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
41 FetchError::ConfigParseError { message: format!("Failed to read config file: {e}") }
42 })?;
43
44 Self::from_yaml(&content)
45 }
46
47 pub fn from_yaml(content: &str) -> Result<Self> {
49 serde_yaml::from_str(content).map_err(|e| FetchError::ConfigParseError {
50 message: format!("Failed to parse YAML: {e}"),
51 })
52 }
53
54 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
56 let content = self.to_yaml()?;
57 std::fs::write(path, content).map_err(|e| FetchError::ConfigParseError {
58 message: format!("Failed to write config file: {e}"),
59 })
60 }
61
62 pub fn to_yaml(&self) -> Result<String> {
64 serde_yaml::to_string(self).map_err(|e| FetchError::ConfigParseError {
65 message: format!("Failed to serialize YAML: {e}"),
66 })
67 }
68
69 pub fn validate(&self) -> Result<()> {
71 if self.teacher.model_id.is_empty() {
73 return Err(FetchError::ConfigParseError {
74 message: "teacher.model_id cannot be empty".into(),
75 });
76 }
77
78 if self.student.model_id.is_empty() {
80 return Err(FetchError::ConfigParseError {
81 message: "student.model_id cannot be empty".into(),
82 });
83 }
84
85 if self.distillation.temperature <= 0.0 {
87 return Err(FetchError::ConfigParseError {
88 message: "distillation.temperature must be positive".into(),
89 });
90 }
91
92 if !(0.0..=1.0).contains(&self.distillation.alpha) {
93 return Err(FetchError::ConfigParseError {
94 message: "distillation.alpha must be between 0 and 1".into(),
95 });
96 }
97
98 if self.training.batch_size == 0 {
100 return Err(FetchError::ConfigParseError {
101 message: "training.batch_size must be > 0".into(),
102 });
103 }
104
105 if self.training.learning_rate <= 0.0 {
106 return Err(FetchError::ConfigParseError {
107 message: "training.learning_rate must be positive".into(),
108 });
109 }
110
111 if self.dataset.path.is_empty() {
113 return Err(FetchError::ConfigParseError {
114 message: "dataset.path cannot be empty".into(),
115 });
116 }
117
118 Ok(())
119 }
120
121 pub fn to_trainer_config(&self) -> Result<TrainerConfig> {
123 self.validate()?;
124
125 let mut config = TrainerConfig::new(&self.teacher.model_id, &self.student.model_id)
126 .temperature(self.distillation.temperature)
127 .alpha(self.distillation.alpha)
128 .epochs(self.training.epochs)
129 .output_dir(&self.output.dir);
130
131 if let Some(ref prog) = self.distillation.progressive {
133 let mapping: Vec<(usize, usize)> =
134 prog.layer_mapping.iter().map(|[s, t]| (*s, *t)).collect();
135 config = config.with_progressive(mapping);
136 }
137
138 if let Some(ref at) = self.distillation.attention_transfer {
140 config = config.with_attention_transfer(at.weight);
141 }
142
143 let mut fine_tune = FineTuneConfig::new(&self.student.model_id)
145 .learning_rate(self.training.learning_rate)
146 .epochs(self.training.epochs)
147 .batch_size(self.training.batch_size);
148
149 if let Some(ref lora_yaml) = self.student.lora {
151 let lora_config = LoRAConfig::from(lora_yaml);
152 if self.student.load_in_4bit {
153 fine_tune = fine_tune.with_qlora(lora_config, 4);
154 } else {
155 fine_tune = fine_tune.with_lora(lora_config);
156 }
157 } else if !self.student.load_in_4bit {
158 fine_tune = fine_tune.full_fine_tune();
159 }
160
161 if let Some(ref mp) = self.training.mixed_precision {
163 fine_tune = fine_tune.mixed_precision(match mp.as_str() {
164 "fp16" => Some(MixedPrecision::Fp16),
165 "bf16" => Some(MixedPrecision::Bf16),
166 _ => None,
167 });
168 }
169
170 fine_tune = fine_tune.gradient_checkpointing(self.training.gradient_checkpointing);
171
172 config.fine_tune = fine_tune;
173 config.max_grad_norm = self.training.max_grad_norm;
174 config.seed = self.training.seed;
175 config.log_every_n_steps = self.output.log_steps;
176 config.save_every_n_steps = self.output.save_steps;
177 config.eval_every_n_steps = self.output.eval_steps;
178
179 Ok(config)
180 }
181}