Skip to main content

entrenar/hf_pipeline/fine_tune/
config.rs

1//! Fine-tuning configuration
2//!
3//! Provides configuration options for fine-tuning HuggingFace models.
4
5use std::path::PathBuf;
6
7use crate::hf_pipeline::error::Result;
8use crate::hf_pipeline::FetchError;
9use crate::lora::LoRAConfig;
10
11use super::memory::{MemoryRequirement, MixedPrecision};
12use super::method::FineTuneMethod;
13
/// Checkpoint interval, measured in training steps, that a freshly defaulted
/// [`FineTuneConfig`] uses for `save_steps`.
const DEFAULT_SAVE_STEPS: usize = 500;
16
17/// Fine-tuning configuration
18#[derive(Debug, Clone)]
19pub struct FineTuneConfig {
20    /// Base model repository ID
21    pub model_id: String,
22    /// Fine-tuning method
23    pub method: FineTuneMethod,
24    /// Output directory for checkpoints
25    pub output_dir: PathBuf,
26    /// Learning rate
27    pub learning_rate: f64,
28    /// Number of epochs
29    pub epochs: usize,
30    /// Batch size
31    pub batch_size: usize,
32    /// Maximum sequence length
33    pub max_seq_length: usize,
34    /// Gradient accumulation steps
35    pub gradient_accumulation_steps: usize,
36    /// Weight decay
37    pub weight_decay: f64,
38    /// Warmup ratio (fraction of total steps)
39    pub warmup_ratio: f32,
40    /// Save checkpoints every N steps
41    pub save_steps: usize,
42    /// Evaluate every N steps
43    pub eval_steps: usize,
44    /// Use gradient checkpointing (memory optimization)
45    pub gradient_checkpointing: bool,
46    /// Use mixed precision (fp16/bf16)
47    pub mixed_precision: Option<MixedPrecision>,
48}
49
50impl Default for FineTuneConfig {
51    fn default() -> Self {
52        Self {
53            model_id: String::new(),
54            method: FineTuneMethod::default(),
55            output_dir: PathBuf::from("./output"),
56            learning_rate: 2e-4, // Recommended for LoRA
57            epochs: 3,
58            batch_size: 8,
59            max_seq_length: 512,
60            gradient_accumulation_steps: 4,
61            weight_decay: 0.01,
62            warmup_ratio: 0.03,
63            save_steps: DEFAULT_SAVE_STEPS,
64            eval_steps: 100,
65            gradient_checkpointing: true,
66            mixed_precision: Some(MixedPrecision::Bf16),
67        }
68    }
69}
70
71impl FineTuneConfig {
72    /// Create new fine-tuning config for a model
73    #[must_use]
74    pub fn new(model_id: impl Into<String>) -> Self {
75        Self { model_id: model_id.into(), ..Default::default() }
76    }
77
78    /// Use LoRA fine-tuning
79    #[must_use]
80    pub fn with_lora(mut self, config: LoRAConfig) -> Self {
81        self.method = FineTuneMethod::LoRA(config);
82        self
83    }
84
85    /// Use QLoRA fine-tuning
86    #[must_use]
87    pub fn with_qlora(mut self, lora_config: LoRAConfig, bits: u8) -> Self {
88        self.method = FineTuneMethod::QLoRA { lora_config, bits };
89        self
90    }
91
92    /// Use full fine-tuning
93    #[must_use]
94    pub fn full_fine_tune(mut self) -> Self {
95        self.method = FineTuneMethod::Full;
96        self
97    }
98
99    /// Set learning rate
100    #[must_use]
101    pub fn learning_rate(mut self, lr: f64) -> Self {
102        self.learning_rate = lr;
103        self
104    }
105
106    /// Set number of epochs
107    #[must_use]
108    pub fn epochs(mut self, n: usize) -> Self {
109        self.epochs = n;
110        self
111    }
112
113    /// Set batch size
114    #[must_use]
115    pub fn batch_size(mut self, size: usize) -> Self {
116        self.batch_size = size;
117        self
118    }
119
120    /// Set output directory
121    #[must_use]
122    pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
123        self.output_dir = path.into();
124        self
125    }
126
127    /// Enable gradient checkpointing
128    #[must_use]
129    pub fn gradient_checkpointing(mut self, enabled: bool) -> Self {
130        self.gradient_checkpointing = enabled;
131        self
132    }
133
134    /// Set mixed precision mode
135    #[must_use]
136    pub fn mixed_precision(mut self, mode: Option<MixedPrecision>) -> Self {
137        self.mixed_precision = mode;
138        self
139    }
140
141    /// Estimate trainable parameters based on fine-tuning method.
142    ///
143    /// N-06 (Meyer DbC): Derives hidden_size and num_layers from total_params
144    /// using the approximation total ≈ 12 * L * d² (transformer scaling law).
145    #[must_use]
146    pub fn estimate_trainable_params(&self, total_params: u64) -> u64 {
147        // Estimate hidden_size from total params: total ≈ 12 * L * d²
148        // Rough: d ≈ sqrt(total / 384) for typical 32-layer model
149        let d = ((total_params as f64 / 384.0).sqrt() as u64).max(64);
150        let num_layers_est = (total_params / (12 * d * d)).clamp(1, 128);
151
152        match &self.method {
153            FineTuneMethod::Full => total_params,
154            FineTuneMethod::LoRA(config) => {
155                // LoRA params = 2 * rank * d * num_modules * num_layers
156                let num_modules = config.num_target_modules().max(4);
157                2 * (config.rank as u64) * d * (num_modules as u64) * num_layers_est
158            }
159            FineTuneMethod::QLoRA { lora_config, .. } => {
160                let num_modules = lora_config.num_target_modules().max(4);
161                2 * (lora_config.rank as u64) * d * (num_modules as u64) * num_layers_est
162            }
163            FineTuneMethod::PrefixTuning { prefix_length } => {
164                // Prefix params = prefix_length * hidden_size * 2 * num_layers
165                (*prefix_length as u64) * d * 2 * num_layers_est
166            }
167        }
168    }
169
170    /// Estimate memory requirements in bytes
171    #[must_use]
172    pub fn estimate_memory(&self, total_params: u64) -> MemoryRequirement {
173        let trainable = self.estimate_trainable_params(total_params);
174
175        // Model memory
176        let model_bytes = match &self.method {
177            FineTuneMethod::Full => total_params * 4,    // FP32
178            FineTuneMethod::LoRA(_) => total_params * 2, // FP16 base + LoRA
179            FineTuneMethod::QLoRA { bits, .. } => {
180                // Quantized base + FP16 LoRA
181                let base = match bits {
182                    4 => total_params / 2,
183                    2 | 3 | 5..=8 | 0 | 1 | 9.. => total_params,
184                };
185                base + trainable * 2
186            }
187            FineTuneMethod::PrefixTuning { .. } => total_params * 2 + trainable * 4,
188        };
189
190        // Optimizer states (Adam: 2x for momentum + variance)
191        let optimizer_bytes = trainable * 4 * 2;
192
193        // Gradients
194        let gradient_bytes = trainable * 4;
195
196        // Activations (rough estimate based on batch size and seq len)
197        let activation_bytes = (self.batch_size * self.max_seq_length * 4096 * 4) as u64
198            * if self.gradient_checkpointing { 1 } else { 4 };
199
200        MemoryRequirement {
201            model: model_bytes,
202            optimizer: optimizer_bytes,
203            gradients: gradient_bytes,
204            activations: activation_bytes,
205        }
206    }
207
208    /// Validate configuration
209    pub fn validate(&self) -> Result<()> {
210        if self.model_id.is_empty() {
211            return Err(FetchError::InvalidRepoId { repo_id: String::new() });
212        }
213
214        if self.learning_rate <= 0.0 {
215            return Err(FetchError::ConfigParseError {
216                message: "Learning rate must be positive".into(),
217            });
218        }
219
220        if self.batch_size == 0 {
221            return Err(FetchError::ConfigParseError {
222                message: "Batch size must be greater than 0".into(),
223            });
224        }
225
226        if let FineTuneMethod::QLoRA { bits, .. } = &self.method {
227            if *bits != 4 && *bits != 8 {
228                return Err(FetchError::ConfigParseError {
229                    message: format!("QLoRA bits must be 4 or 8, got {bits}"),
230                });
231            }
232        }
233
234        Ok(())
235    }
236}