Skip to main content

entrenar/hf_pipeline/config/
yaml_config.rs

1//! Complete distillation YAML configuration
2
3use crate::hf_pipeline::error::{FetchError, Result};
4use crate::hf_pipeline::fine_tune::{FineTuneConfig, MixedPrecision};
5use crate::hf_pipeline::trainer::TrainerConfig;
6use crate::lora::LoRAConfig;
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use super::dataset::DatasetConfig;
11use super::distillation::DistillationConfig;
12use super::output::OutputConfig;
13use super::student::StudentConfig;
14use super::teacher::TeacherConfig;
15use super::training::TrainingConfig;
16
17/// Complete distillation configuration
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct DistillationYamlConfig {
20    /// Teacher model config
21    pub teacher: TeacherConfig,
22    /// Student model config
23    pub student: StudentConfig,
24    /// Distillation loss config
25    #[serde(default)]
26    pub distillation: DistillationConfig,
27    /// Training hyperparameters
28    #[serde(default)]
29    pub training: TrainingConfig,
30    /// Dataset config
31    pub dataset: DatasetConfig,
32    /// Output config
33    #[serde(default)]
34    pub output: OutputConfig,
35}
36
37impl DistillationYamlConfig {
38    /// Load configuration from YAML file
39    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
40        let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
41            FetchError::ConfigParseError { message: format!("Failed to read config file: {e}") }
42        })?;
43
44        Self::from_yaml(&content)
45    }
46
47    /// Parse configuration from YAML string
48    pub fn from_yaml(content: &str) -> Result<Self> {
49        serde_yaml::from_str(content).map_err(|e| FetchError::ConfigParseError {
50            message: format!("Failed to parse YAML: {e}"),
51        })
52    }
53
54    /// Save configuration to YAML file
55    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
56        let content = self.to_yaml()?;
57        std::fs::write(path, content).map_err(|e| FetchError::ConfigParseError {
58            message: format!("Failed to write config file: {e}"),
59        })
60    }
61
62    /// Serialize to YAML string
63    pub fn to_yaml(&self) -> Result<String> {
64        serde_yaml::to_string(self).map_err(|e| FetchError::ConfigParseError {
65            message: format!("Failed to serialize YAML: {e}"),
66        })
67    }
68
69    /// Validate configuration
70    pub fn validate(&self) -> Result<()> {
71        // Validate teacher
72        if self.teacher.model_id.is_empty() {
73            return Err(FetchError::ConfigParseError {
74                message: "teacher.model_id cannot be empty".into(),
75            });
76        }
77
78        // Validate student
79        if self.student.model_id.is_empty() {
80            return Err(FetchError::ConfigParseError {
81                message: "student.model_id cannot be empty".into(),
82            });
83        }
84
85        // Validate distillation
86        if self.distillation.temperature <= 0.0 {
87            return Err(FetchError::ConfigParseError {
88                message: "distillation.temperature must be positive".into(),
89            });
90        }
91
92        if !(0.0..=1.0).contains(&self.distillation.alpha) {
93            return Err(FetchError::ConfigParseError {
94                message: "distillation.alpha must be between 0 and 1".into(),
95            });
96        }
97
98        // Validate training
99        if self.training.batch_size == 0 {
100            return Err(FetchError::ConfigParseError {
101                message: "training.batch_size must be > 0".into(),
102            });
103        }
104
105        if self.training.learning_rate <= 0.0 {
106            return Err(FetchError::ConfigParseError {
107                message: "training.learning_rate must be positive".into(),
108            });
109        }
110
111        // Validate dataset
112        if self.dataset.path.is_empty() {
113            return Err(FetchError::ConfigParseError {
114                message: "dataset.path cannot be empty".into(),
115            });
116        }
117
118        Ok(())
119    }
120
121    /// Convert to TrainerConfig
122    pub fn to_trainer_config(&self) -> Result<TrainerConfig> {
123        self.validate()?;
124
125        let mut config = TrainerConfig::new(&self.teacher.model_id, &self.student.model_id)
126            .temperature(self.distillation.temperature)
127            .alpha(self.distillation.alpha)
128            .epochs(self.training.epochs)
129            .output_dir(&self.output.dir);
130
131        // Add progressive distillation
132        if let Some(ref prog) = self.distillation.progressive {
133            let mapping: Vec<(usize, usize)> =
134                prog.layer_mapping.iter().map(|[s, t]| (*s, *t)).collect();
135            config = config.with_progressive(mapping);
136        }
137
138        // Add attention transfer
139        if let Some(ref at) = self.distillation.attention_transfer {
140            config = config.with_attention_transfer(at.weight);
141        }
142
143        // Set up fine-tuning config
144        let mut fine_tune = FineTuneConfig::new(&self.student.model_id)
145            .learning_rate(self.training.learning_rate)
146            .epochs(self.training.epochs)
147            .batch_size(self.training.batch_size);
148
149        // Set LoRA if configured
150        if let Some(ref lora_yaml) = self.student.lora {
151            let lora_config = LoRAConfig::from(lora_yaml);
152            if self.student.load_in_4bit {
153                fine_tune = fine_tune.with_qlora(lora_config, 4);
154            } else {
155                fine_tune = fine_tune.with_lora(lora_config);
156            }
157        } else if !self.student.load_in_4bit {
158            fine_tune = fine_tune.full_fine_tune();
159        }
160
161        // Set mixed precision
162        if let Some(ref mp) = self.training.mixed_precision {
163            fine_tune = fine_tune.mixed_precision(match mp.as_str() {
164                "fp16" => Some(MixedPrecision::Fp16),
165                "bf16" => Some(MixedPrecision::Bf16),
166                _ => None,
167            });
168        }
169
170        fine_tune = fine_tune.gradient_checkpointing(self.training.gradient_checkpointing);
171
172        config.fine_tune = fine_tune;
173        config.max_grad_norm = self.training.max_grad_norm;
174        config.seed = self.training.seed;
175        config.log_every_n_steps = self.output.log_steps;
176        config.save_every_n_steps = self.output.save_steps;
177        config.eval_every_n_steps = self.output.eval_steps;
178
179        Ok(config)
180    }
181}