use serde::{Deserialize, Serialize};
/// Hyperparameters for a training run.
///
/// All fields are public and are (de)serialized by serde under their own
/// names. The [`Default`] implementation provides the baseline value quoted
/// on each field below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingHyperparams {
    /// Epoch count; defaults to `3`.
    pub epochs: u32,
    /// Batch size; defaults to `4`.
    pub batch_size: u32,
    /// Learning rate; defaults to `2e-5`.
    pub learning_rate: f64,
    /// Number of warmup steps; defaults to `100`.
    pub warmup_steps: u64,
    /// Weight-decay value; defaults to `0.01`.
    pub weight_decay: f64,
    /// Learning-rate scheduler selection; defaults to [`LrScheduler::Cosine`].
    pub lr_scheduler: LrScheduler,
    /// Seed value; defaults to `42`.
    pub seed: u64,
    /// Maximum sequence length; defaults to `2048`.
    ///
    /// NOTE(review): presumably counted in tokens — confirm with the consumer.
    pub max_seq_len: usize,
    /// Gradient-accumulation step count; defaults to `4`.
    pub gradient_accumulation_steps: u32,
    /// Maximum gradient norm; defaults to `1.0`.
    ///
    /// NOTE(review): presumably a clipping threshold — confirm with the consumer.
    pub max_grad_norm: f64,
}
impl Default for TrainingHyperparams {
fn default() -> Self {
Self {
epochs: 3,
batch_size: 4,
learning_rate: 2e-5,
warmup_steps: 100,
weight_decay: 0.01,
lr_scheduler: LrScheduler::Cosine,
seed: 42,
max_seq_len: 2048,
gradient_accumulation_steps: 4,
max_grad_norm: 1.0,
}
}
}
/// Selects which learning-rate schedule a training run should use.
///
/// This enum only names the choice; the schedule itself is implemented by
/// whatever code consumes the configuration. Because of
/// `rename_all = "snake_case"`, variants are (de)serialized under snake_case
/// names (for example `CosineWarmRestarts` as `cosine_warm_restarts`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LrScheduler {
    /// Constant schedule.
    Constant,
    /// Linear schedule.
    Linear,
    /// Cosine schedule; the value used by `TrainingHyperparams::default()`.
    Cosine,
    /// Cosine schedule with warm restarts.
    CosineWarmRestarts,
}
/// Adapter configuration: rank, alpha, dropout, target modules and the
/// adapter method to apply.
///
/// All fields are public and (de)serialized by serde under their own names.
/// The [`Default`] implementation provides the baseline quoted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Adapter rank; defaults to `16`.
    pub rank: u32,
    /// Adapter `alpha` value; defaults to `32.0`.
    ///
    /// NOTE(review): presumably the usual LoRA scaling numerator — confirm
    /// against the code that consumes this config.
    pub alpha: f32,
    /// Dropout value; defaults to `0.05`.
    pub dropout: f32,
    /// Names of the modules to adapt; defaults to `q_proj`, `k_proj`,
    /// `v_proj` and `o_proj`, in that order.
    pub target_modules: Vec<String>,
    /// Which adapter variant to use; defaults to [`AdapterMethod::LoRA`].
    pub method: AdapterMethod,
}
impl Default for LoraConfig {
fn default() -> Self {
Self {
rank: 16,
alpha: 32.0,
dropout: 0.05,
target_modules: vec![
"q_proj".to_string(),
"k_proj".to_string(),
"v_proj".to_string(),
"o_proj".to_string(),
],
method: AdapterMethod::LoRA,
}
}
}
/// Adapter variant selected by [`LoraConfig::method`].
///
/// The two `Q*` variants carry a `bits` payload; the other two carry no data.
///
/// NOTE(review): serde's `rename_all = "snake_case"` inserts an underscore
/// before every uppercase letter after the first, so these variants are
/// expected to appear on the wire as `lo_r_a`, `q_lo_r_a`, `do_r_a` and
/// `q_do_r_a` rather than `lora`, `qlora`, ... — confirm whether that format is
/// intended. It is left untouched here because changing it would alter
/// already-serialized data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AdapterMethod {
    /// Plain LoRA; the value used by `LoraConfig::default()`.
    LoRA,
    /// LoRA variant carrying a quantization bit width.
    QLoRA {
        /// Bit width reported by `quantization_bits()`.
        bits: u8,
    },
    /// DoRA variant without a quantization payload.
    DoRA,
    /// DoRA variant carrying a quantization bit width.
    QDoRA {
        /// Bit width reported by `quantization_bits()`.
        bits: u8,
    },
}
impl AdapterMethod {
    /// Returns `true` for the variants that carry a quantization bit width
    /// (`QLoRA` and `QDoRA`), `false` for `LoRA` and `DoRA`.
    ///
    /// Defined in terms of [`Self::quantization_bits`] so the two queries
    /// share one variant list and cannot drift apart.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        self.quantization_bits().is_some()
    }

    /// Returns the `bits` payload of a quantized variant, or `None` for the
    /// variants that have no such payload.
    #[must_use]
    pub fn quantization_bits(&self) -> Option<u8> {
        // Exhaustive on purpose (no `_` arm): adding a variant to the enum
        // must fail to compile here until it is classified explicitly.
        match self {
            Self::QLoRA { bits } | Self::QDoRA { bits } => Some(*bits),
            Self::LoRA | Self::DoRA => None,
        }
    }
}
/// Optional alignment method and its single tuning parameter.
///
/// `Default` yields [`AlignmentMethod::None`] (marked with `#[default]`).
///
/// NOTE(review): the acronyms presumably stand for Direct Preference
/// Optimization and Odds-Ratio Preference Optimization — confirm. Also, with
/// `rename_all = "snake_case"` serde inserts an underscore before every
/// uppercase letter after the first, so the expected wire names are `d_p_o`,
/// `o_r_p_o` and `none` — confirm that format is intended.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AlignmentMethod {
    /// DPO, parameterized by `beta`.
    DPO {
        /// `beta` parameter; `AlignmentMethod::dpo()` sets it to `0.1`.
        beta: f64,
    },
    /// ORPO, parameterized by `lambda`.
    ORPO {
        /// `lambda` parameter; `AlignmentMethod::orpo()` sets it to `0.5`.
        lambda: f64,
    },
    /// No alignment method; this is the `Default` value.
    #[default]
    None,
}
impl AlignmentMethod {
pub fn dpo() -> Self {
Self::DPO { beta: 0.1 }
}
pub fn orpo() -> Self {
Self::ORPO { lambda: 0.5 }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks three of the `TrainingHyperparams` defaults.
    #[test]
    fn test_hyperparams_defaults() {
        let defaults = TrainingHyperparams::default();

        assert_eq!(defaults.epochs, 3);
        assert_eq!(defaults.batch_size, 4);

        // The float default is compared with a tolerance rather than `==`.
        let lr_error = (defaults.learning_rate - 2e-5).abs();
        assert!(lr_error < f64::EPSILON);
    }

    /// Checks the default rank and the number of default target modules.
    #[test]
    fn test_lora_config_defaults() {
        let LoraConfig {
            rank,
            target_modules,
            ..
        } = LoraConfig::default();

        assert_eq!(rank, 16);
        assert_eq!(target_modules.len(), 4);
    }

    /// Exercises `is_quantized` and `quantization_bits` on plain and
    /// quantized variants.
    #[test]
    fn test_adapter_method_quantized() {
        let plain = AdapterMethod::LoRA;
        let four_bit = AdapterMethod::QLoRA { bits: 4 };

        assert!(!plain.is_quantized());
        assert!(four_bit.is_quantized());
        assert_eq!(four_bit.quantization_bits(), Some(4));
        assert!(AdapterMethod::DoRA.quantization_bits().is_none());
    }

    /// Verifies the parameter values baked into the convenience constructors.
    #[test]
    fn test_alignment_methods() {
        let dpo_ok = matches!(
            AlignmentMethod::dpo(),
            AlignmentMethod::DPO { beta } if (beta - 0.1).abs() < f64::EPSILON
        );
        assert!(dpo_ok);

        let orpo_ok = matches!(
            AlignmentMethod::orpo(),
            AlignmentMethod::ORPO { lambda } if (lambda - 0.5).abs() < f64::EPSILON
        );
        assert!(orpo_ok);
    }

    /// Round-trips a `LoraConfig` through JSON and checks that the adapter
    /// method, including its `bits` payload, survives.
    #[test]
    fn test_serialization_roundtrip() {
        let expected_method = AdapterMethod::QLoRA { bits: 4 };
        let original = LoraConfig {
            method: expected_method,
            ..Default::default()
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LoraConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.method, expected_method);
    }
}