Skip to main content

dreamwell_intelligence/
train.rs

// Training loop — parameter-shift gradient descent on ALL QCT parameters.
//
// Uses the exact parameter shift rule (Model 1 validated):
//   ∂L/∂θ_k = [L(θ_k + π/2) − L(θ_k − π/2)] / 2
//
// Trains ALL parameters: embedding angles + Hamiltonian + value weights + output.
// Gradients are evaluated in parallel with rayon (each parameter's shift pair is independent).
//
// Clean Compute: no autograd graph — one flat gradient buffer filled by explicit loops.
10
11use crate::transformer::QCT;
12
/// Parameter-shift offset: each parameter is evaluated at θ ± π/2.
/// (Halving π is exact in binary floating point, so this equals `FRAC_PI_2`.)
const SHIFT: f32 = std::f32::consts::PI / 2.0;
14
/// Training configuration for [`train`].
#[derive(Clone, Debug)]
pub struct TrainConfig {
    /// Base learning rate (the peak rate when warmup / cosine decay are enabled).
    pub learning_rate: f32,
    /// Number of epochs; each epoch is one gradient step on one token window.
    pub num_epochs: usize,
    /// Window size: each training window spans `context_length + 1` tokens.
    pub context_length: usize,
    /// Log (and return) metrics every `log_interval` epochs; the final epoch is always logged.
    pub log_interval: usize,
    /// Gradient clipping: max allowed L2 norm of the full gradient vector.
    pub grad_clip: f32,
    /// Use cosine learning-rate decay (applied after any warmup).
    pub use_cosine_decay: bool,
    /// Warmup epochs: the rate ramps linearly as
    /// `learning_rate * (epoch + 1) / warmup_epochs`, reaching `learning_rate`
    /// on the last warmup epoch.
    pub warmup_epochs: usize,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self {
            // Loom default.
            // NOTE(review): 1/φ⁴×φ = 1/φ³ ≈ 0.236, not ≈ 0.03 (the nearest power is
            // φ⁻⁷ ≈ 0.034) — confirm which derivation this value was meant to follow.
            learning_rate: 0.03,
            num_epochs: 100,
            context_length: 64, // Matches Loom context
            log_interval: 10,
            grad_clip: 4.236, // φ³ — gradient ceiling mirrors 1/φ³ convergence floor
            use_cosine_decay: false,
            warmup_epochs: 0,
        }
    }
}
43
/// Snapshot of one logged training epoch, as produced by [`train`].
#[derive(Clone, Debug)]
pub struct EpochMetrics {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Loss on this epoch's window, measured before the epoch's parameter update.
    pub loss: f32,
    /// Free energy from the forward pass over the window's inputs (pre-update).
    pub free_energy: f32,
    /// L2 norm of the raw gradient, measured before clipping.
    pub grad_norm: f32,
    /// Wall-clock time spent on the epoch, in milliseconds.
    pub elapsed_ms: f32,
    /// Learning rate used for this epoch's update.
    pub learning_rate: f32,
    /// Number of parameters updated in this epoch.
    pub params_trained: usize,
}
55
56/// Compute learning rate for a given epoch (warmup + optional cosine decay).
57fn learning_rate(config: &TrainConfig, epoch: usize) -> f32 {
58    learning_rate_pub(config, epoch)
59}
60
61/// Public version for use by adjoint module.
62pub fn learning_rate_pub(config: &TrainConfig, epoch: usize) -> f32 {
63    let base_lr = config.learning_rate;
64    if config.warmup_epochs > 0 && epoch < config.warmup_epochs {
65        return base_lr * (epoch + 1) as f32 / config.warmup_epochs as f32;
66    }
67    if config.use_cosine_decay {
68        let effective_epoch = epoch.saturating_sub(config.warmup_epochs);
69        let total = config.num_epochs.saturating_sub(config.warmup_epochs).max(1);
70        let progress = effective_epoch as f32 / total as f32;
71        return base_lr * 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
72    }
73    base_lr
74}
75
76/// Train the QCT on a token sequence using parameter shift gradient descent.
77/// Trains ALL parameters (embedding + hamiltonian + values + output).
78/// Returns metrics for each logged epoch.
79pub fn train(model: &mut QCT, tokens: &[usize], config: &TrainConfig) -> Vec<EpochMetrics> {
80    let mut metrics = Vec::new();
81    let num_params = model.num_params();
82
83    for epoch in 0..config.num_epochs {
84        let start = std::time::Instant::now();
85        let lr = learning_rate(config, epoch);
86
87        // Select training window
88        let max_start = tokens.len().saturating_sub(config.context_length + 1);
89        let window_start = if max_start > 0 { epoch % max_start } else { 0 };
90        let window_end = (window_start + config.context_length + 1).min(tokens.len());
91        let window = &tokens[window_start..window_end];
92
93        // Baseline loss and free energy
94        let base_loss = model.loss(window);
95        let (_, base_free_energy) = model.forward(&window[..window.len() - 1]);
96
97        // Parameter shift gradient over ALL parameters (rayon parallel).
98        // Each parameter's gradient is independent — embarrassingly parallel.
99        // Each thread clones the model and evaluates ±shift independently.
100        let all_params = model.all_params();
101        let window_vec: Vec<usize> = window.to_vec(); // owned for Send
102
103        use rayon::prelude::*;
104        let mut gradients: Vec<f32> = (0..num_params)
105            .into_par_iter()
106            .map(|k| {
107                let mut local = model.clone();
108
109                let mut plus = all_params.clone();
110                plus[k] += SHIFT;
111                local.set_all_params(&plus);
112                let loss_plus = local.loss(&window_vec);
113
114                plus[k] = all_params[k] - SHIFT;
115                local.set_all_params(&plus);
116                let loss_minus = local.loss(&window_vec);
117
118                (loss_plus - loss_minus) / 2.0
119            })
120            .collect();
121
122        // Gradient norm
123        let grad_norm: f32 = gradients.iter().map(|g| g * g).sum::<f32>().sqrt();
124
125        // Gradient clipping
126        if grad_norm > config.grad_clip && grad_norm > 0.0 {
127            let scale = config.grad_clip / grad_norm;
128            for g in &mut gradients {
129                *g *= scale;
130            }
131        }
132
133        // Update ALL parameters
134        let mut updated = all_params;
135        for k in 0..num_params {
136            updated[k] -= lr * gradients[k];
137        }
138        model.set_all_params(&updated);
139
140        let elapsed = start.elapsed().as_secs_f32() * 1000.0;
141
142        if epoch % config.log_interval == 0 || epoch == config.num_epochs - 1 {
143            let m = EpochMetrics {
144                epoch,
145                loss: base_loss,
146                free_energy: base_free_energy,
147                grad_norm,
148                elapsed_ms: elapsed,
149                learning_rate: lr,
150                params_trained: num_params,
151            };
152            log::info!(
153                "Epoch {:4}: loss={:.4} F={:.4} |∇|={:.6} lr={:.5} params={} ({:.0}ms)",
154                m.epoch,
155                m.loss,
156                m.free_energy,
157                m.grad_norm,
158                m.learning_rate,
159                m.params_trained,
160                m.elapsed_ms
161            );
162            metrics.push(m);
163        }
164    }
165
166    metrics
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::QCTConfig;

    /// Tiny deterministic model shared by every model-level test.
    fn tiny_config() -> QCTConfig {
        QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        }
    }

    /// One epoch over a short window, logging that epoch.
    fn single_epoch_config() -> TrainConfig {
        TrainConfig {
            learning_rate: 0.01,
            num_epochs: 1,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        }
    }

    #[test]
    fn training_reduces_loss() {
        // Smoke test: checks that a few epochs keep the loss finite. It does not
        // assert that the loss actually decreases.
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = (0..16).map(|i| i % 10).collect();
        let probe = &tokens[..8];

        let initial_loss = model.loss(probe);

        let schedule = TrainConfig {
            learning_rate: 0.05,
            num_epochs: 5,
            context_length: 8,
            log_interval: 5,
            ..Default::default()
        };
        let _ = train(&mut model, &tokens, &schedule);

        let final_loss = model.loss(probe);
        assert!(final_loss.is_finite(), "loss should be finite after training");
        eprintln!("Initial loss: {:.4}, Final loss: {:.4}", initial_loss, final_loss);
    }

    #[test]
    fn gradient_is_nonzero() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = (0..8).collect();

        let metrics = train(&mut model, &tokens, &single_epoch_config());

        assert!(!metrics.is_empty());
        assert!(metrics[0].grad_norm > 0.0, "gradient should be nonzero");
    }

    #[test]
    fn all_params_trained() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = (0..8).collect();

        let metrics = train(&mut model, &tokens, &single_epoch_config());

        let expected = model.num_params();
        assert_eq!(
            metrics[0].params_trained,
            expected,
            "should train ALL {} params, not a subset",
            expected
        );
    }

    #[test]
    fn all_params_roundtrip() {
        let source = QCT::new(tiny_config());
        let params = source.all_params();

        let mut target = QCT::new(tiny_config());
        target.set_all_params(&params);
        let restored = target.all_params();

        assert_eq!(params.len(), restored.len());
        for (a, b) in params.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-6, "param roundtrip mismatch");
        }
    }

    #[test]
    fn cosine_lr_schedule() {
        let config = TrainConfig {
            learning_rate: 0.1,
            num_epochs: 100,
            use_cosine_decay: true,
            ..Default::default()
        };
        let [lr_start, lr_mid, lr_end] = [0, 50, 99].map(|e| learning_rate(&config, e));
        assert!((lr_start - 0.1).abs() < 0.01, "start lr should be ~0.1");
        assert!((lr_mid - 0.05).abs() < 0.01, "mid lr should be ~0.05");
        assert!(lr_end < 0.01, "end lr should be near 0, got {lr_end}");
    }

    #[test]
    fn warmup_lr_schedule() {
        let config = TrainConfig {
            learning_rate: 0.1,
            num_epochs: 100,
            warmup_epochs: 10,
            ..Default::default()
        };
        let [lr_0, lr_5, lr_10] = [0, 5, 10].map(|e| learning_rate(&config, e));
        assert!(lr_0 < lr_5, "lr should increase during warmup");
        assert!(lr_5 < lr_10, "lr should increase during warmup");
        assert!((lr_10 - 0.1).abs() < 0.01, "lr should reach base after warmup");
    }
}