use std::path::PathBuf;
use crate::hf_pipeline::error::Result;
use crate::hf_pipeline::FetchError;
use crate::lora::LoRAConfig;
use super::memory::{MemoryRequirement, MixedPrecision};
use super::method::FineTuneMethod;
/// Default interval, in training steps, between saved checkpoints.
const DEFAULT_SAVE_STEPS: usize = 500;
/// Configuration for a model fine-tuning run.
///
/// Construct with [`FineTuneConfig::new`], customize via the builder
/// methods, and check invariants with [`FineTuneConfig::validate`]
/// before starting training.
#[derive(Debug, Clone)]
pub struct FineTuneConfig {
/// Identifier of the base model to fine-tune; must be non-empty
/// (enforced by `validate`).
pub model_id: String,
/// Fine-tuning strategy (full, LoRA, QLoRA, or prefix tuning).
pub method: FineTuneMethod,
/// Directory where checkpoints and the final model are written.
pub output_dir: PathBuf,
/// Optimizer learning rate; must be positive (enforced by `validate`).
pub learning_rate: f64,
/// Number of passes over the training set.
pub epochs: usize,
/// Per-step batch size; must be > 0 (enforced by `validate`).
pub batch_size: usize,
/// Maximum token length of each training example.
pub max_seq_length: usize,
/// Forward/backward passes accumulated before each optimizer step.
pub gradient_accumulation_steps: usize,
/// Weight-decay coefficient for the optimizer.
pub weight_decay: f64,
/// Fraction of training used for learning-rate warmup.
pub warmup_ratio: f32,
/// Steps between checkpoint saves.
pub save_steps: usize,
/// Steps between evaluation runs.
pub eval_steps: usize,
/// Recompute activations in the backward pass to save memory
/// (see `estimate_memory`, which scales the activation term by this).
pub gradient_checkpointing: bool,
/// Optional reduced-precision training mode (`None` = full precision).
pub mixed_precision: Option<MixedPrecision>,
}
impl Default for FineTuneConfig {
    /// Baseline configuration: LoRA-friendly hyperparameters, output under
    /// `./output`, memory-saving options enabled.
    fn default() -> Self {
        Self {
            // Identity / destination.
            model_id: String::new(),
            method: FineTuneMethod::default(),
            output_dir: PathBuf::from("./output"),
            // Optimization hyperparameters.
            learning_rate: 2e-4,
            weight_decay: 0.01,
            warmup_ratio: 0.03,
            epochs: 3,
            batch_size: 8,
            gradient_accumulation_steps: 4,
            max_seq_length: 512,
            // Checkpointing and evaluation cadence.
            save_steps: DEFAULT_SAVE_STEPS,
            eval_steps: 100,
            // Memory-saving defaults.
            gradient_checkpointing: true,
            mixed_precision: Some(MixedPrecision::Bf16),
        }
    }
}
impl FineTuneConfig {
    /// Creates a configuration for `model_id` with every other setting at
    /// its [`Default`] value.
    #[must_use]
    pub fn new(model_id: impl Into<String>) -> Self {
        Self { model_id: model_id.into(), ..Default::default() }
    }

    /// Selects LoRA fine-tuning with the given adapter configuration.
    #[must_use]
    pub fn with_lora(mut self, config: LoRAConfig) -> Self {
        self.method = FineTuneMethod::LoRA(config);
        self
    }

    /// Selects QLoRA: LoRA adapters over a base model quantized to `bits`
    /// (only 4 or 8 pass `validate`).
    #[must_use]
    pub fn with_qlora(mut self, lora_config: LoRAConfig, bits: u8) -> Self {
        self.method = FineTuneMethod::QLoRA { lora_config, bits };
        self
    }

    /// Selects full fine-tuning (all parameters trainable).
    #[must_use]
    pub fn full_fine_tune(mut self) -> Self {
        self.method = FineTuneMethod::Full;
        self
    }

    /// Sets the optimizer learning rate.
    #[must_use]
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the number of training epochs.
    #[must_use]
    pub fn epochs(mut self, n: usize) -> Self {
        self.epochs = n;
        self
    }

    /// Sets the per-step batch size.
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Sets the directory where checkpoints and outputs are written.
    #[must_use]
    pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.output_dir = path.into();
        self
    }

    /// Enables or disables gradient checkpointing.
    #[must_use]
    pub fn gradient_checkpointing(mut self, enabled: bool) -> Self {
        self.gradient_checkpointing = enabled;
        self
    }

    /// Sets (or clears, with `None`) the mixed-precision training mode.
    #[must_use]
    pub fn mixed_precision(mut self, mode: Option<MixedPrecision>) -> Self {
        self.mixed_precision = mode;
        self
    }

    /// Heuristic estimate of how many parameters the chosen method trains.
    ///
    /// Backs out a rough transformer geometry from the total parameter
    /// count — hidden dimension `d ≈ sqrt(total_params / 384)` (floored at
    /// 64) and a layer count of `total_params / (12·d²)` clamped to
    /// `1..=128` — TODO(review): confirm these constants match the target
    /// architectures. Full fine-tuning trains everything; LoRA/QLoRA train
    /// two rank-`r` matrices per targeted module per layer; prefix tuning
    /// trains `2 · prefix_length · d` values per layer.
    #[must_use]
    pub fn estimate_trainable_params(&self, total_params: u64) -> u64 {
        let d = ((total_params as f64 / 384.0).sqrt() as u64).max(64);
        let num_layers_est = (total_params / (12 * d * d)).clamp(1, 128);
        match &self.method {
            FineTuneMethod::Full => total_params,
            // LoRA and QLoRA train identical adapters; only the base-model
            // precision differs (accounted for in `estimate_memory`).
            FineTuneMethod::LoRA(config)
            | FineTuneMethod::QLoRA { lora_config: config, .. } => {
                let num_modules = config.num_target_modules().max(4);
                2 * (config.rank as u64) * d * (num_modules as u64) * num_layers_est
            }
            FineTuneMethod::PrefixTuning { prefix_length } => {
                (*prefix_length as u64) * d * 2 * num_layers_est
            }
        }
    }

    /// Estimates peak training-memory requirements in bytes.
    ///
    /// Weight bytes depend on the method's storage precision (4 bytes/param
    /// for full, 2 for LoRA/prefix bases, packed 4-bit or 1 byte/param for
    /// QLoRA bases plus fp16 adapters). The optimizer term assumes two fp32
    /// state buffers per trainable parameter, gradients one fp32 each. The
    /// activation term hard-codes a hidden size of 4096 — TODO(review):
    /// derive this from the actual model config instead.
    #[must_use]
    pub fn estimate_memory(&self, total_params: u64) -> MemoryRequirement {
        let trainable = self.estimate_trainable_params(total_params);
        let model_bytes = match &self.method {
            // Full fine-tuning keeps the whole model at 4 bytes/param.
            FineTuneMethod::Full => total_params * 4,
            // LoRA keeps the frozen base at 2 bytes/param.
            FineTuneMethod::LoRA(_) => total_params * 2,
            FineTuneMethod::QLoRA { bits, .. } => {
                // 4-bit packs two weights per byte; every other width is
                // estimated at one byte per weight. Adapters add 2 bytes
                // per trainable parameter.
                let base = if *bits == 4 { total_params / 2 } else { total_params };
                base + trainable * 2
            }
            FineTuneMethod::PrefixTuning { .. } => total_params * 2 + trainable * 4,
        };
        // Two fp32 optimizer-state buffers per trainable parameter.
        let optimizer_bytes = trainable * 4 * 2;
        let gradient_bytes = trainable * 4;
        // Widen each factor BEFORE multiplying: the original computed the
        // product in `usize` and cast afterwards, which can overflow on
        // 32-bit targets for realistic batch/sequence sizes.
        let activation_bytes = (self.batch_size as u64)
            * (self.max_seq_length as u64)
            * 4096
            * 4
            * if self.gradient_checkpointing { 1 } else { 4 };
        MemoryRequirement {
            model: model_bytes,
            optimizer: optimizer_bytes,
            gradients: gradient_bytes,
            activations: activation_bytes,
        }
    }

    /// Checks the configuration's invariants.
    ///
    /// # Errors
    ///
    /// Returns [`FetchError::InvalidRepoId`] for an empty `model_id`, and
    /// [`FetchError::ConfigParseError`] for a non-positive or non-finite
    /// learning rate, a zero batch size, or a QLoRA bit width other than
    /// 4 or 8.
    pub fn validate(&self) -> Result<()> {
        if self.model_id.is_empty() {
            return Err(FetchError::InvalidRepoId { repo_id: String::new() });
        }
        // `is_finite` also rejects NaN, which the original `<= 0.0`
        // comparison silently let through (NaN <= 0.0 is false).
        if !self.learning_rate.is_finite() || self.learning_rate <= 0.0 {
            return Err(FetchError::ConfigParseError {
                message: "Learning rate must be positive".into(),
            });
        }
        if self.batch_size == 0 {
            return Err(FetchError::ConfigParseError {
                message: "Batch size must be greater than 0".into(),
            });
        }
        if let FineTuneMethod::QLoRA { bits, .. } = &self.method {
            if *bits != 4 && *bits != 8 {
                return Err(FetchError::ConfigParseError {
                    message: format!("QLoRA bits must be 4 or 8, got {bits}"),
                });
            }
        }
        Ok(())
    }
}