1use crate::transformer::QCT;
12
/// Offset applied to a single parameter, once upwards and once downwards, when `train`
/// estimates that parameter's gradient as `(loss(p + SHIFT) - loss(p - SHIFT)) / 2`.
const SHIFT: f32 = std::f32::consts::FRAC_PI_2;
14
/// Hyper-parameters consumed by `train` and by the learning-rate schedule.
#[derive(Clone, Debug)]
pub struct TrainConfig {
    /// Base learning rate; warmup ramps up to it and cosine decay starts from it.
    pub learning_rate: f32,
    /// Number of epochs (parameter updates) `train` performs.
    pub num_epochs: usize,
    /// Each training window holds up to `context_length + 1` consecutive tokens.
    pub context_length: usize,
    /// Metrics are recorded on epochs that are a multiple of this value; the final
    /// epoch is always recorded as well.
    pub log_interval: usize,
    /// Upper bound on the gradient's L2 norm; larger gradients are scaled down to it.
    pub grad_clip: f32,
    /// When `true`, the learning rate follows a half-cosine curve after warmup.
    pub use_cosine_decay: bool,
    /// Number of initial epochs with a linearly increasing learning rate (0 = no warmup).
    pub warmup_epochs: usize,
}
29
30impl Default for TrainConfig {
31 fn default() -> Self {
32 Self {
33 learning_rate: 0.03, num_epochs: 100,
35 context_length: 64, log_interval: 10,
37 grad_clip: 4.236, use_cosine_decay: false,
39 warmup_epochs: 0,
40 }
41 }
42}
43
/// Snapshot of one training epoch as recorded by `train`.
#[derive(Clone, Debug)]
pub struct EpochMetrics {
    /// Zero-based index of the epoch.
    pub epoch: usize,
    /// Loss on the epoch's token window, measured before the parameter update.
    pub loss: f32,
    /// Second value returned by the model's forward pass on the window without its
    /// last token, measured before the parameter update.
    pub free_energy: f32,
    /// L2 norm of the estimated gradient, taken before any clipping.
    pub grad_norm: f32,
    /// Wall-clock time from the start of the epoch to the end of the update, in milliseconds.
    pub elapsed_ms: f32,
    /// Learning rate used for this epoch's update.
    pub learning_rate: f32,
    /// Number of parameters the gradient was estimated for.
    pub params_trained: usize,
}
55
/// Module-private shorthand for `learning_rate_pub`; forwards both arguments unchanged.
fn learning_rate(config: &TrainConfig, epoch: usize) -> f32 {
    learning_rate_pub(config, epoch)
}
60
61pub fn learning_rate_pub(config: &TrainConfig, epoch: usize) -> f32 {
63 let base_lr = config.learning_rate;
64 if config.warmup_epochs > 0 && epoch < config.warmup_epochs {
65 return base_lr * (epoch + 1) as f32 / config.warmup_epochs as f32;
66 }
67 if config.use_cosine_decay {
68 let effective_epoch = epoch.saturating_sub(config.warmup_epochs);
69 let total = config.num_epochs.saturating_sub(config.warmup_epochs).max(1);
70 let progress = effective_epoch as f32 / total as f32;
71 return base_lr * 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
72 }
73 base_lr
74}
75
76pub fn train(model: &mut QCT, tokens: &[usize], config: &TrainConfig) -> Vec<EpochMetrics> {
80 let mut metrics = Vec::new();
81 let num_params = model.num_params();
82
83 for epoch in 0..config.num_epochs {
84 let start = std::time::Instant::now();
85 let lr = learning_rate(config, epoch);
86
87 let max_start = tokens.len().saturating_sub(config.context_length + 1);
89 let window_start = if max_start > 0 { epoch % max_start } else { 0 };
90 let window_end = (window_start + config.context_length + 1).min(tokens.len());
91 let window = &tokens[window_start..window_end];
92
93 let base_loss = model.loss(window);
95 let (_, base_free_energy) = model.forward(&window[..window.len() - 1]);
96
97 let all_params = model.all_params();
101 let window_vec: Vec<usize> = window.to_vec(); use rayon::prelude::*;
104 let mut gradients: Vec<f32> = (0..num_params)
105 .into_par_iter()
106 .map(|k| {
107 let mut local = model.clone();
108
109 let mut plus = all_params.clone();
110 plus[k] += SHIFT;
111 local.set_all_params(&plus);
112 let loss_plus = local.loss(&window_vec);
113
114 plus[k] = all_params[k] - SHIFT;
115 local.set_all_params(&plus);
116 let loss_minus = local.loss(&window_vec);
117
118 (loss_plus - loss_minus) / 2.0
119 })
120 .collect();
121
122 let grad_norm: f32 = gradients.iter().map(|g| g * g).sum::<f32>().sqrt();
124
125 if grad_norm > config.grad_clip && grad_norm > 0.0 {
127 let scale = config.grad_clip / grad_norm;
128 for g in &mut gradients {
129 *g *= scale;
130 }
131 }
132
133 let mut updated = all_params;
135 for k in 0..num_params {
136 updated[k] -= lr * gradients[k];
137 }
138 model.set_all_params(&updated);
139
140 let elapsed = start.elapsed().as_secs_f32() * 1000.0;
141
142 if epoch % config.log_interval == 0 || epoch == config.num_epochs - 1 {
143 let m = EpochMetrics {
144 epoch,
145 loss: base_loss,
146 free_energy: base_free_energy,
147 grad_norm,
148 elapsed_ms: elapsed,
149 learning_rate: lr,
150 params_trained: num_params,
151 };
152 log::info!(
153 "Epoch {:4}: loss={:.4} F={:.4} |∇|={:.6} lr={:.5} params={} ({:.0}ms)",
154 m.epoch,
155 m.loss,
156 m.free_energy,
157 m.grad_norm,
158 m.learning_rate,
159 m.params_trained,
160 m.elapsed_ms
161 );
162 metrics.push(m);
163 }
164 }
165
166 metrics
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::QCTConfig;

    /// Small model configuration shared by the tests that build a `QCT`.
    fn tiny_config() -> QCTConfig {
        QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        }
    }

    /// Runs a few epochs and checks that the loss stays finite.
    ///
    /// Note: despite its name this test does not assert that the loss decreases; the
    /// before/after values are only printed for inspection.
    #[test]
    fn training_reduces_loss() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5];

        let initial_loss = model.loss(&tokens[..8]);

        let train_config = TrainConfig {
            learning_rate: 0.05,
            num_epochs: 5,
            context_length: 8,
            log_interval: 5,
            ..Default::default()
        };
        let _metrics = train(&mut model, &tokens, &train_config);

        let final_loss = model.loss(&tokens[..8]);
        assert!(final_loss.is_finite(), "loss should be finite after training");
        eprintln!("Initial loss: {:.4}, Final loss: {:.4}", initial_loss, final_loss);
    }

    /// A single epoch on a fresh model must report a non-zero gradient norm.
    #[test]
    fn gradient_is_nonzero() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6, 7];

        let train_config = TrainConfig {
            learning_rate: 0.01,
            num_epochs: 1,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };
        let metrics = train(&mut model, &tokens, &train_config);

        assert!(!metrics.is_empty());
        assert!(metrics[0].grad_norm > 0.0, "gradient should be nonzero");
    }

    /// The reported parameter count must match the model's full parameter count.
    #[test]
    fn all_params_trained() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6, 7];

        let train_config = TrainConfig {
            learning_rate: 0.01,
            num_epochs: 1,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };
        let metrics = train(&mut model, &tokens, &train_config);
        assert_eq!(
            metrics[0].params_trained,
            model.num_params(),
            "should train ALL {} params, not a subset",
            model.num_params()
        );
    }

    /// Parameters read from one model and written into another must read back unchanged.
    #[test]
    fn all_params_roundtrip() {
        let model = QCT::new(tiny_config());
        let params = model.all_params();
        let mut model2 = QCT::new(tiny_config());
        model2.set_all_params(&params);
        let params2 = model2.all_params();
        assert_eq!(params.len(), params2.len());
        for (a, b) in params.iter().zip(params2.iter()) {
            assert!((a - b).abs() < 1e-6, "param roundtrip mismatch");
        }
    }

    /// Cosine decay starts at the base rate, is halved mid-way and ends near zero.
    #[test]
    fn cosine_lr_schedule() {
        let config = TrainConfig {
            learning_rate: 0.1,
            num_epochs: 100,
            use_cosine_decay: true,
            ..Default::default()
        };
        let lr_start = learning_rate(&config, 0);
        let lr_mid = learning_rate(&config, 50);
        let lr_end = learning_rate(&config, 99);
        assert!((lr_start - 0.1).abs() < 0.01, "start lr should be ~0.1");
        assert!((lr_mid - 0.05).abs() < 0.01, "mid lr should be ~0.05");
        assert!(lr_end < 0.01, "end lr should be near 0, got {lr_end}");
    }

    /// Warmup increases the rate monotonically until it reaches the base rate.
    #[test]
    fn warmup_lr_schedule() {
        let config = TrainConfig {
            learning_rate: 0.1,
            num_epochs: 100,
            warmup_epochs: 10,
            ..Default::default()
        };
        let lr_0 = learning_rate(&config, 0);
        let lr_5 = learning_rate(&config, 5);
        let lr_10 = learning_rate(&config, 10);
        assert!(lr_0 < lr_5, "lr should increase during warmup");
        assert!(lr_5 < lr_10, "lr should increase during warmup");
        assert!((lr_10 - 0.1).abs() < 0.01, "lr should reach base after warmup");
    }
}