// brainwires_training/config.rs
use serde::{Deserialize, Serialize};
/// Hyperparameters controlling a training run.
///
/// Serializable so runs can be configured from JSON/TOML and reproduced;
/// see [`TrainingHyperparams::default`] for the baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingHyperparams {
    /// Number of passes over the training set.
    pub epochs: u32,
    /// Mini-batch size (per optimizer step before accumulation).
    pub batch_size: u32,
    /// Optimizer learning rate.
    pub learning_rate: f64,
    /// Steps of learning-rate warmup before the scheduler takes over.
    pub warmup_steps: u64,
    /// L2 weight-decay coefficient.
    pub weight_decay: f64,
    /// Learning-rate schedule applied after warmup.
    pub lr_scheduler: LrScheduler,
    /// RNG seed for reproducibility.
    pub seed: u64,
    /// Maximum sequence length in tokens; longer inputs are presumably
    /// truncated by the data pipeline — confirm against the loader.
    pub max_seq_len: usize,
    /// Number of micro-batches accumulated before each optimizer update.
    pub gradient_accumulation_steps: u32,
    /// Gradient-norm clipping threshold.
    pub max_grad_norm: f64,
}
28impl Default for TrainingHyperparams {
29 fn default() -> Self {
30 Self {
31 epochs: 3,
32 batch_size: 4,
33 learning_rate: 2e-5,
34 warmup_steps: 100,
35 weight_decay: 0.01,
36 lr_scheduler: LrScheduler::Cosine,
37 seed: 42,
38 max_seq_len: 2048,
39 gradient_accumulation_steps: 4,
40 max_grad_norm: 1.0,
41 }
42 }
43}
44
/// Learning-rate schedule applied after warmup.
///
/// Serialized in snake_case (e.g. `cosine_warm_restarts`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LrScheduler {
    /// Fixed learning rate for the whole run.
    Constant,
    /// Linear decay.
    Linear,
    /// Cosine annealing.
    Cosine,
    /// Cosine annealing with periodic warm restarts.
    CosineWarmRestarts,
}
/// Configuration for a low-rank adapter attached to the base model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank of the low-rank decomposition.
    pub rank: u32,
    /// Scaling factor applied to the adapter output.
    pub alpha: f32,
    /// Dropout probability applied within the adapter.
    pub dropout: f32,
    /// Names of the modules to attach adapters to
    /// (e.g. attention projections — see [`LoraConfig::default`]).
    pub target_modules: Vec<String>,
    /// Which adapter/quantization variant to use.
    pub method: AdapterMethod,
}
74impl Default for LoraConfig {
75 fn default() -> Self {
76 Self {
77 rank: 16,
78 alpha: 32.0,
79 dropout: 0.05,
80 target_modules: vec![
81 "q_proj".to_string(),
82 "k_proj".to_string(),
83 "v_proj".to_string(),
84 "o_proj".to_string(),
85 ],
86 method: AdapterMethod::LoRA,
87 }
88 }
89}
90
/// Parameter-efficient adapter variants; serialized in snake_case.
///
/// The `Q*` variants additionally carry a quantization bit width.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AdapterMethod {
    /// Standard LoRA.
    LoRA,
    /// LoRA over a quantized base model.
    QLoRA {
        /// Quantization bit width (e.g. 4).
        bits: u8,
    },
    /// Weight-decomposed LoRA (DoRA).
    DoRA,
    /// DoRA over a quantized base model.
    QDoRA {
        /// Quantization bit width (e.g. 4).
        bits: u8,
    },
}
111impl AdapterMethod {
112 pub fn is_quantized(&self) -> bool {
114 matches!(self, Self::QLoRA { .. } | Self::QDoRA { .. })
115 }
116
117 pub fn quantization_bits(&self) -> Option<u8> {
119 match self {
120 Self::QLoRA { bits } | Self::QDoRA { bits } => Some(*bits),
121 _ => None,
122 }
123 }
124}
125
126#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
128#[serde(rename_all = "snake_case")]
129#[derive(Default)]
130pub enum AlignmentMethod {
131 DPO {
133 beta: f64,
135 },
136 ORPO {
138 lambda: f64,
140 },
141 #[default]
143 None,
144}
145
146impl AlignmentMethod {
147 pub fn dpo() -> Self {
149 Self::DPO { beta: 0.1 }
150 }
151
152 pub fn orpo() -> Self {
154 Self::ORPO { lambda: 0.5 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hyperparams_defaults() {
        let params = TrainingHyperparams::default();
        assert_eq!(params.epochs, 3);
        assert_eq!(params.batch_size, 4);
        assert!((params.learning_rate - 2e-5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_lora_config_defaults() {
        let config = LoraConfig::default();
        assert_eq!(config.rank, 16);
        assert_eq!(config.target_modules.len(), 4);
    }

    #[test]
    fn test_adapter_method_quantized() {
        let qlora = AdapterMethod::QLoRA { bits: 4 };
        assert!(qlora.is_quantized());
        assert_eq!(qlora.quantization_bits(), Some(4));
        assert!(!AdapterMethod::LoRA.is_quantized());
        assert!(AdapterMethod::DoRA.quantization_bits().is_none());
    }

    #[test]
    fn test_alignment_methods() {
        // Pattern-match directly so a wrong variant fails with a clear message.
        match AlignmentMethod::dpo() {
            AlignmentMethod::DPO { beta } => assert!((beta - 0.1).abs() < f64::EPSILON),
            other => panic!("expected DPO, got {other:?}"),
        }
        match AlignmentMethod::orpo() {
            AlignmentMethod::ORPO { lambda } => assert!((lambda - 0.5).abs() < f64::EPSILON),
            other => panic!("expected ORPO, got {other:?}"),
        }
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = LoraConfig {
            method: AdapterMethod::QLoRA { bits: 4 },
            ..LoraConfig::default()
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: LoraConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.method, AdapterMethod::QLoRA { bits: 4 });
    }
}