//! Training configuration for `brainwires_training`: core hyperparameters,
//! LoRA-family adapter settings, and alignment-method selection.

use serde::{Deserialize, Serialize};

/// Training hyperparameters.
///
/// Serializes with `serde`, so a run's configuration can be persisted and
/// reloaded for reproducibility (together with `seed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingHyperparams {
    /// Number of training epochs.
    pub epochs: u32,
    /// Batch size per device.
    pub batch_size: u32,
    /// Initial learning rate.
    pub learning_rate: f64,
    /// Warmup steps for LR scheduler.
    pub warmup_steps: u64,
    /// Weight decay factor.
    pub weight_decay: f64,
    /// Learning rate scheduler type.
    pub lr_scheduler: LrScheduler,
    /// Random seed for reproducibility.
    pub seed: u64,
    /// Maximum sequence length (tokens).
    pub max_seq_len: usize,
    /// Gradient accumulation steps (effective batch = batch_size * grad_accum).
    pub gradient_accumulation_steps: u32,
    /// Maximum gradient norm for clipping.
    pub max_grad_norm: f64,
}

28impl Default for TrainingHyperparams {
29    fn default() -> Self {
30        Self {
31            epochs: 3,
32            batch_size: 4,
33            learning_rate: 2e-5,
34            warmup_steps: 100,
35            weight_decay: 0.01,
36            lr_scheduler: LrScheduler::Cosine,
37            seed: 42,
38            max_seq_len: 2048,
39            gradient_accumulation_steps: 4,
40            max_grad_norm: 1.0,
41        }
42    }
43}
44
/// Learning rate scheduler types.
///
/// Serialized in `snake_case` (e.g. `"cosine_warm_restarts"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LrScheduler {
    /// Constant learning rate.
    Constant,
    /// Linear decay to zero.
    Linear,
    /// Cosine annealing.
    Cosine,
    /// Cosine with warm restarts.
    CosineWarmRestarts,
}

/// LoRA adapter configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// LoRA rank (typical: 8, 16, 32, 64).
    pub rank: u32,
    /// LoRA alpha scaling factor (typical: rank * 2).
    pub alpha: f32,
    /// Dropout rate on LoRA layers.
    pub dropout: f32,
    /// Target modules to apply LoRA to (e.g., ["q_proj", "v_proj"]).
    pub target_modules: Vec<String>,
    /// Adapter method variant (plain LoRA, quantized, or DoRA-based).
    pub method: AdapterMethod,
}

74impl Default for LoraConfig {
75    fn default() -> Self {
76        Self {
77            rank: 16,
78            alpha: 32.0,
79            dropout: 0.05,
80            target_modules: vec![
81                "q_proj".to_string(),
82                "k_proj".to_string(),
83                "v_proj".to_string(),
84                "o_proj".to_string(),
85            ],
86            method: AdapterMethod::LoRA,
87        }
88    }
89}
90
/// Adapter method for parameter-efficient fine-tuning.
///
/// Serialized in `snake_case`; the quantized variants carry their bit width
/// as struct-variant data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AdapterMethod {
    /// Low-Rank Adaptation.
    LoRA,
    /// Quantized LoRA (4-bit or 8-bit base weights).
    QLoRA {
        /// Quantization bit width.
        bits: u8,
    },
    /// Weight-Decomposed Low-Rank Adaptation (direction + magnitude).
    DoRA,
    /// Quantized DoRA.
    QDoRA {
        /// Quantization bit width.
        bits: u8,
    },
}

111impl AdapterMethod {
112    /// Whether this adapter method uses quantization.
113    pub fn is_quantized(&self) -> bool {
114        matches!(self, Self::QLoRA { .. } | Self::QDoRA { .. })
115    }
116
117    /// Return quantization bit width, if applicable.
118    pub fn quantization_bits(&self) -> Option<u8> {
119        match self {
120            Self::QLoRA { bits } | Self::QDoRA { bits } => Some(*bits),
121            _ => None,
122        }
123    }
124}
125
126/// Alignment training method.
127#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
128#[serde(rename_all = "snake_case")]
129#[derive(Default)]
130pub enum AlignmentMethod {
131    /// Direct Preference Optimization.
132    DPO {
133        /// DPO beta parameter.
134        beta: f64,
135    },
136    /// Odds Ratio Preference Optimization (single-pass).
137    ORPO {
138        /// ORPO lambda parameter.
139        lambda: f64,
140    },
141    /// No alignment, standard SFT only.
142    #[default]
143    None,
144}
145
146impl AlignmentMethod {
147    /// Create DPO alignment with default beta.
148    pub fn dpo() -> Self {
149        Self::DPO { beta: 0.1 }
150    }
151
152    /// Create ORPO alignment with default lambda.
153    pub fn orpo() -> Self {
154        Self::ORPO { lambda: 0.5 }
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hyperparams_defaults() {
        let hp = TrainingHyperparams::default();
        assert_eq!(hp.epochs, 3);
        assert_eq!(hp.batch_size, 4);
        // Float compare with an epsilon tolerance rather than exact equality.
        assert!((hp.learning_rate - 2e-5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_lora_config_defaults() {
        let cfg = LoraConfig::default();
        assert_eq!(cfg.rank, 16);
        assert_eq!(cfg.target_modules.len(), 4);
    }

    #[test]
    fn test_adapter_method_quantized() {
        let qlora = AdapterMethod::QLoRA { bits: 4 };
        assert!(qlora.is_quantized());
        assert_eq!(qlora.quantization_bits(), Some(4));
        assert!(!AdapterMethod::LoRA.is_quantized());
        assert_eq!(AdapterMethod::DoRA.quantization_bits(), None);
    }

    #[test]
    fn test_alignment_methods() {
        match AlignmentMethod::dpo() {
            AlignmentMethod::DPO { beta } => assert!((beta - 0.1).abs() < f64::EPSILON),
            other => panic!("expected DPO, got {:?}", other),
        }
        match AlignmentMethod::orpo() {
            AlignmentMethod::ORPO { lambda } => assert!((lambda - 0.5).abs() < f64::EPSILON),
            other => panic!("expected ORPO, got {:?}", other),
        }
    }

    #[test]
    fn test_serialization_roundtrip() {
        // AdapterMethod is Copy, so `method` can be reused after the move.
        let method = AdapterMethod::QLoRA { bits: 4 };
        let config = LoraConfig {
            method,
            ..LoraConfig::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let decoded: LoraConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.method, method);
    }
}