entrenar/hf_pipeline/fine_tune/
config.rs1use std::path::PathBuf;
6
7use crate::hf_pipeline::error::Result;
8use crate::hf_pipeline::FetchError;
9use crate::lora::LoRAConfig;
10
11use super::memory::{MemoryRequirement, MixedPrecision};
12use super::method::FineTuneMethod;
13
/// Default checkpoint cadence, in training steps, used by [`FineTuneConfig::default`]
/// for the `save_steps` field.
const DEFAULT_SAVE_STEPS: usize = 500;
16
17#[derive(Debug, Clone)]
19pub struct FineTuneConfig {
20 pub model_id: String,
22 pub method: FineTuneMethod,
24 pub output_dir: PathBuf,
26 pub learning_rate: f64,
28 pub epochs: usize,
30 pub batch_size: usize,
32 pub max_seq_length: usize,
34 pub gradient_accumulation_steps: usize,
36 pub weight_decay: f64,
38 pub warmup_ratio: f32,
40 pub save_steps: usize,
42 pub eval_steps: usize,
44 pub gradient_checkpointing: bool,
46 pub mixed_precision: Option<MixedPrecision>,
48}
49
50impl Default for FineTuneConfig {
51 fn default() -> Self {
52 Self {
53 model_id: String::new(),
54 method: FineTuneMethod::default(),
55 output_dir: PathBuf::from("./output"),
56 learning_rate: 2e-4, epochs: 3,
58 batch_size: 8,
59 max_seq_length: 512,
60 gradient_accumulation_steps: 4,
61 weight_decay: 0.01,
62 warmup_ratio: 0.03,
63 save_steps: DEFAULT_SAVE_STEPS,
64 eval_steps: 100,
65 gradient_checkpointing: true,
66 mixed_precision: Some(MixedPrecision::Bf16),
67 }
68 }
69}
70
71impl FineTuneConfig {
72 #[must_use]
74 pub fn new(model_id: impl Into<String>) -> Self {
75 Self { model_id: model_id.into(), ..Default::default() }
76 }
77
78 #[must_use]
80 pub fn with_lora(mut self, config: LoRAConfig) -> Self {
81 self.method = FineTuneMethod::LoRA(config);
82 self
83 }
84
85 #[must_use]
87 pub fn with_qlora(mut self, lora_config: LoRAConfig, bits: u8) -> Self {
88 self.method = FineTuneMethod::QLoRA { lora_config, bits };
89 self
90 }
91
92 #[must_use]
94 pub fn full_fine_tune(mut self) -> Self {
95 self.method = FineTuneMethod::Full;
96 self
97 }
98
99 #[must_use]
101 pub fn learning_rate(mut self, lr: f64) -> Self {
102 self.learning_rate = lr;
103 self
104 }
105
106 #[must_use]
108 pub fn epochs(mut self, n: usize) -> Self {
109 self.epochs = n;
110 self
111 }
112
113 #[must_use]
115 pub fn batch_size(mut self, size: usize) -> Self {
116 self.batch_size = size;
117 self
118 }
119
120 #[must_use]
122 pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
123 self.output_dir = path.into();
124 self
125 }
126
127 #[must_use]
129 pub fn gradient_checkpointing(mut self, enabled: bool) -> Self {
130 self.gradient_checkpointing = enabled;
131 self
132 }
133
134 #[must_use]
136 pub fn mixed_precision(mut self, mode: Option<MixedPrecision>) -> Self {
137 self.mixed_precision = mode;
138 self
139 }
140
141 #[must_use]
146 pub fn estimate_trainable_params(&self, total_params: u64) -> u64 {
147 let d = ((total_params as f64 / 384.0).sqrt() as u64).max(64);
150 let num_layers_est = (total_params / (12 * d * d)).clamp(1, 128);
151
152 match &self.method {
153 FineTuneMethod::Full => total_params,
154 FineTuneMethod::LoRA(config) => {
155 let num_modules = config.num_target_modules().max(4);
157 2 * (config.rank as u64) * d * (num_modules as u64) * num_layers_est
158 }
159 FineTuneMethod::QLoRA { lora_config, .. } => {
160 let num_modules = lora_config.num_target_modules().max(4);
161 2 * (lora_config.rank as u64) * d * (num_modules as u64) * num_layers_est
162 }
163 FineTuneMethod::PrefixTuning { prefix_length } => {
164 (*prefix_length as u64) * d * 2 * num_layers_est
166 }
167 }
168 }
169
170 #[must_use]
172 pub fn estimate_memory(&self, total_params: u64) -> MemoryRequirement {
173 let trainable = self.estimate_trainable_params(total_params);
174
175 let model_bytes = match &self.method {
177 FineTuneMethod::Full => total_params * 4, FineTuneMethod::LoRA(_) => total_params * 2, FineTuneMethod::QLoRA { bits, .. } => {
180 let base = match bits {
182 4 => total_params / 2,
183 2 | 3 | 5..=8 | 0 | 1 | 9.. => total_params,
184 };
185 base + trainable * 2
186 }
187 FineTuneMethod::PrefixTuning { .. } => total_params * 2 + trainable * 4,
188 };
189
190 let optimizer_bytes = trainable * 4 * 2;
192
193 let gradient_bytes = trainable * 4;
195
196 let activation_bytes = (self.batch_size * self.max_seq_length * 4096 * 4) as u64
198 * if self.gradient_checkpointing { 1 } else { 4 };
199
200 MemoryRequirement {
201 model: model_bytes,
202 optimizer: optimizer_bytes,
203 gradients: gradient_bytes,
204 activations: activation_bytes,
205 }
206 }
207
208 pub fn validate(&self) -> Result<()> {
210 if self.model_id.is_empty() {
211 return Err(FetchError::InvalidRepoId { repo_id: String::new() });
212 }
213
214 if self.learning_rate <= 0.0 {
215 return Err(FetchError::ConfigParseError {
216 message: "Learning rate must be positive".into(),
217 });
218 }
219
220 if self.batch_size == 0 {
221 return Err(FetchError::ConfigParseError {
222 message: "Batch size must be greater than 0".into(),
223 });
224 }
225
226 if let FineTuneMethod::QLoRA { bits, .. } = &self.method {
227 if *bits != 4 && *bits != 8 {
228 return Err(FetchError::ConfigParseError {
229 message: format!("QLoRA bits must be 4 or 8, got {bits}"),
230 });
231 }
232 }
233
234 Ok(())
235 }
236}