1use crate::attention;
15use crate::complex::Complex;
16use crate::density_matrix::DensityMatrixN;
17use crate::train::EpochMetrics;
18use crate::transformer::QCT;
19
/// Per-block record of the intermediate states produced by a forward pass.
///
/// The forward functions in this file build one `ForwardCache` per model
/// block. Every field except `values` receives one entry per token position,
/// pushed in position order.
pub struct ForwardCache {
    /// Density-matrix entries at each position, captured after the causal
    /// `couple_dephase` step and before the gated unitary evolution.
    pub rho_before: Vec<Vec<Complex>>,
    /// Unitary applied at each position; an empty vector marks a position
    /// where the free-energy gate skipped the evolution.
    pub unitaries: Vec<Vec<Complex>>,
    /// Density-matrix entries at each position after the (possibly skipped)
    /// evolution.
    pub rho_after: Vec<Vec<Complex>>,
    /// Result of `populations()` on the post-evolution state at each position.
    pub populations: Vec<Vec<f32>>,
    /// Snapshot of the value vectors as they were when the block started.
    pub values: Vec<Vec<f32>>,
}
35
/// Gradients for every parameter group of the model, as produced by the
/// backward passes in this file.
pub struct AllGradients {
    /// Gradient for the embedding parameters. The backward passes in this
    /// file always fill it with zeros of length `model.embedding.num_params()`.
    pub embed_grad: Vec<f32>,
    /// One entry per model block, in block order.
    pub block_grads: Vec<BlockGrad>,
    /// Gradient for `model.output_weights`.
    pub output_grad: Vec<f32>,
}
42
/// Gradients for the parameters of a single block.
pub struct BlockGrad {
    /// Gradient with respect to the block's Hamiltonian parameters
    /// (length `hamiltonian.num_params()`).
    pub hamiltonian_grad: Vec<f32>,
    /// Gradient with respect to the block's value weights.
    pub value_weight_grad: Vec<f32>,
}
47
48impl AllGradients {
49 pub fn flatten(&self) -> Vec<f32> {
51 let mut v = Vec::new();
52 v.extend_from_slice(&self.embed_grad);
53 for bg in &self.block_grads {
54 v.extend_from_slice(&bg.hamiltonian_grad);
55 v.extend_from_slice(&bg.value_weight_grad);
56 }
57 v.extend_from_slice(&self.output_grad);
58 v
59 }
60}
61
/// Runs the cached forward pass without a golden-ratio converter.
///
/// Delegates to [`forward_with_cache_converter`] with `converter = None` and
/// returns its `(logits, average free energy, per-block caches)` tuple
/// unchanged.
pub fn forward_with_cache(model: &QCT, tokens: &[usize]) -> (Vec<Vec<f32>>, f32, Vec<ForwardCache>) {
    forward_with_cache_converter(model, tokens, None)
}
67
68pub fn forward_with_cache_converter(
73 model: &QCT,
74 tokens: &[usize],
75 converter: Option<&crate::golden_ratio_converter::GoldenRatioConverter>,
76) -> (Vec<Vec<f32>>, f32, Vec<ForwardCache>) {
77 let dim = model.config.dim;
78 let t = tokens.len();
79
80 let mut states: Vec<DensityMatrixN> = tokens.iter().map(|&tok| model.embedding.embed(tok)).collect();
84 if let Some(conv) = converter {
85 for (i, &tok) in tokens.iter().enumerate() {
86 let eps = conv.dephasing_rate(tok);
87 if eps > 1e-6 {
88 states[i].dephase(eps);
89 }
90 }
91 }
92 let mut values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
93 let mut current_states = states;
94 let mut caches = Vec::with_capacity(model.blocks.len());
95 let mut total_f = 0.0f32;
96
97 for block in &model.blocks {
98 let mut cache = ForwardCache {
99 rho_before: Vec::with_capacity(t),
100 unitaries: Vec::with_capacity(t),
101 rho_after: Vec::with_capacity(t),
102 populations: Vec::with_capacity(t),
103 values: values.clone(),
104 };
105
106 let precomputed_unitaries: Vec<Vec<Complex>> = {
111 use rayon::prelude::*;
112 (0..t)
113 .into_par_iter()
114 .map(|i| {
115 let h_matrix = block.hamiltonian.build_matrix(i);
116 DensityMatrixN::hamiltonian_unitary(&h_matrix, dim, 0.090) })
118 .collect()
119 };
120
121 const COHERENCE_WINDOW: usize = 8;
137
138 let f_gate = 0.236 / dim as f32;
142
143 for i in 0..t {
144 let mut rho = current_states[i].clone();
145
146 let window_start = i.saturating_sub(COHERENCE_WINDOW);
149 for j in window_start..i {
150 let dist = i - j;
151 let eps = block.hamiltonian.causal_dephasing(dist);
152 rho.couple_dephase(¤t_states[j], eps);
153 }
154
155 cache.rho_before.push(rho.entries.clone());
157
158 let f_before = rho.free_energy(&block.hamiltonian.bias);
161 let unitary = &precomputed_unitaries[i];
162
163 if f_before.abs() > f_gate {
166 cache.unitaries.push(unitary.clone());
167 rho.evolve(unitary);
168 } else {
169 cache.unitaries.push(Vec::new());
170 }
171
172 cache.rho_after.push(rho.entries.clone());
174
175 let f = rho.free_energy(&block.hamiltonian.bias);
176 total_f += f;
177
178 let pops = rho.populations();
179 cache.populations.push(pops);
180 }
181
182 let attn_output = attention::AttentionOutput {
184 populations: cache.populations.clone(),
185 free_energies: vec![0.0; t],
186 coherences: vec![0.0; t],
187 };
188 let new_values = attention::attention_project(&attn_output, &values, dim);
189 values = new_values;
190
191 let eps_block = 0.236 / model.blocks.len().max(1) as f32;
195 for state in &mut current_states {
196 state.dephase(eps_block);
197 }
198
199 caches.push(cache);
200 }
201
202 let vocab = model.config.vocab_size;
204 let mut logits = Vec::with_capacity(t);
205 for i in 0..t {
206 let mut token_logits = vec![0.0f32; vocab];
207 for v in 0..vocab {
208 for d in 0..dim {
209 token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
210 }
211 }
212 logits.push(token_logits);
213 }
214
215 (logits, total_f / t as f32, caches)
216}
217
218pub fn qug_backward(model: &QCT, tokens: &[usize], logits: &[Vec<f32>], caches: &[ForwardCache]) -> AllGradients {
221 let dim = model.config.dim;
222 let vocab = model.config.vocab_size;
223 let t = tokens.len().saturating_sub(1); if t == 0 {
225 return AllGradients {
226 embed_grad: vec![0.0; model.embedding.num_params()],
227 block_grads: model
228 .blocks
229 .iter()
230 .map(|b| BlockGrad {
231 hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
232 value_weight_grad: vec![0.0; b.value_weights.len()],
233 })
234 .collect(),
235 output_grad: vec![0.0; model.output_weights.len()],
236 };
237 }
238
239 let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
241 for i in 0..t {
242 let target = tokens[i + 1];
243 let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
244 let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
245 let mut d_log = vec![0.0f32; vocab];
246 for v in 0..vocab {
247 let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
248 d_log[v] = softmax_v - if v == target { 1.0 } else { 0.0 };
249 }
250 for v in &mut d_log {
252 *v /= t as f32;
253 }
254 d_logits.push(d_log);
255 }
256
257 let mut d_output = vec![0.0f32; dim * vocab];
259 for i in 0..t {
260 let cache = caches.last().unwrap();
261 let vals = if i < cache.populations.len() {
262 &cache.populations[i]
263 } else {
264 continue;
265 };
266 for d_idx in 0..dim {
267 for v in 0..vocab {
268 d_output[d_idx * vocab + v] += vals.get(d_idx).copied().unwrap_or(0.0) * d_logits[i][v];
269 }
270 }
271 }
272
273 let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t.max(1)];
275 for i in 0..t {
276 for d_idx in 0..dim {
277 for v in 0..vocab {
278 d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
279 }
280 }
281 }
282
283 let mut block_grads = Vec::with_capacity(model.blocks.len());
285 for (block_idx, block) in model.blocks.iter().enumerate().rev() {
286 let cache = &caches[block_idx];
287 let num_h = block.hamiltonian.num_params();
288
289 let mut d_vw = vec![0.0f32; dim * dim];
291 for i in 0..t.min(cache.populations.len()) {
293 for d_idx in 0..dim {
294 for s in 0..dim {
295 let pop = cache.populations[i].get(s).copied().unwrap_or(0.0);
296 let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
297 d_vw[d_idx * dim + s] += pop * dv;
298 }
299 }
300 }
301
302 let dt = 0.090f32; let len = t.min(cache.unitaries.len());
307
308 let position_grads: Vec<Vec<f32>> = {
309 use rayon::prelude::*;
310 (0..len)
311 .into_par_iter()
312 .map(|i| {
313 let mut local_d_h = vec![0.0f32; num_h];
314
315 let u = &cache.unitaries[i];
319 if u.is_empty() {
320 return local_d_h;
321 }
322
323 let mut scratch_a = vec![Complex::ZERO; dim * dim];
324 let mut scratch_b = vec![Complex::ZERO; dim * dim];
325
326 let mut d_rho = vec![Complex::ZERO; dim * dim];
328 for k in 0..dim {
329 let dp = d_values[i].get(k).copied().unwrap_or(0.0);
330 d_rho[k * dim + k] = Complex::new(dp, 0.0);
331 }
332
333 let mut u_dag = vec![Complex::ZERO; dim * dim];
335 for r in 0..dim {
336 for c in 0..dim {
337 u_dag[r * dim + c] = u[c * dim + r].conj();
338 }
339 }
340
341 dreamwell_math::linalg::cgemm(&u_dag, &d_rho, &mut scratch_a, dim, dim, dim);
343 dreamwell_math::linalg::cgemm(&scratch_a, u, &mut scratch_b, dim, dim, dim);
344
345 let rho_before = &cache.rho_before[i];
347 let mut h_idx = 0;
348
349 for k in 0..dim {
351 let mut comm_diag = 0.0f32;
352 for j in 0..dim {
353 let ab = scratch_b[k * dim + j].mul(rho_before[j * dim + k]);
354 let ba = rho_before[k * dim + j].mul(scratch_b[j * dim + k]);
355 comm_diag += (ab.sub(ba)).im;
356 }
357 if h_idx < local_d_h.len() {
358 local_d_h[h_idx] += -dt * comm_diag;
359 }
360 h_idx += 1;
361 }
362
363 for p in 0..dim {
365 for q in (p + 1)..dim {
366 if h_idx >= local_d_h.len() {
367 break;
368 }
369 let ab_pq = scratch_b[p * dim + q].mul(rho_before[q * dim + p]);
370 let ba_pq = rho_before[p * dim + q].mul(scratch_b[q * dim + p]);
371 let comm_pq = ab_pq.sub(ba_pq);
372 local_d_h[h_idx] += -dt * 2.0 * comm_pq.im;
373 h_idx += 1;
374 }
375 }
376
377 local_d_h
378 })
379 .collect()
380 };
381
382 let mut d_h = vec![0.0f32; num_h];
384 for pg in &position_grads {
385 for (k, &v) in pg.iter().enumerate() {
386 d_h[k] += v;
387 }
388 }
389
390 block_grads.push(BlockGrad {
391 hamiltonian_grad: d_h,
392 value_weight_grad: d_vw,
393 });
394 }
395
396 block_grads.reverse();
398
399 let embed_grad = vec![0.0f32; model.embedding.num_params()];
403
404 AllGradients {
405 embed_grad,
406 block_grads,
407 output_grad: d_output,
408 }
409}
410
411pub fn forward_backward_epoch_gpu(
425 gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
426 model: &QCT,
427 windows: &[(usize, usize)],
428 tokens: &[usize],
429) -> (Vec<f32>, f32, f32) {
430 use dreamwell_math::Complex;
431 let dim = model.config.dim;
432 let vocab = model.config.vocab_size;
433 let stride = dim * dim;
434 let num_windows = windows.len();
435 let num_blocks = model.blocks.len();
436 let dt = 0.090f32; let mut all_window_data: Vec<WindowForwardState> = Vec::with_capacity(num_windows);
444
445 for &(ws, we) in windows {
446 let window_tokens = &tokens[ws..we];
447 let input = &window_tokens[..window_tokens.len().saturating_sub(1)];
448 let t = input.len();
449
450 let states: Vec<DensityMatrixN> = input.iter().map(|&tok| model.embedding.embed(tok)).collect();
452 let values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
453
454 let mut block_unitaries: Vec<Vec<Vec<Complex>>> = Vec::with_capacity(num_blocks);
456 for block in &model.blocks {
457 let mut all_h = vec![0.0f32; t * stride];
458 for i in 0..t {
459 let h = block.hamiltonian.build_matrix(i);
460 all_h[i * stride..(i + 1) * stride].copy_from_slice(&h);
461 }
462 let flat_unitaries = gpu.batched_expm(&all_h, dt, t);
463 let per_pos: Vec<Vec<Complex>> = (0..t)
464 .map(|i| flat_unitaries[i * stride..(i + 1) * stride].to_vec())
465 .collect();
466 block_unitaries.push(per_pos);
467 }
468
469 all_window_data.push(WindowForwardState {
470 window_tokens: window_tokens.to_vec(),
471 t,
472 states,
473 values,
474 block_unitaries,
475 });
476 }
477
478 let f_gate = 0.236 / dim as f32;
483 const COHERENCE_WINDOW: usize = 8;
484 let eps_block = 0.236 / num_blocks.max(1) as f32;
485
486 let mut all_window_results: Vec<WindowResult> = Vec::with_capacity(num_windows);
487
488 for wdata in &mut all_window_data {
489 let t = wdata.t;
490 let mut current_states = wdata.states.clone();
491 let mut values = wdata.values.clone();
492 let mut caches: Vec<ForwardCache> = Vec::with_capacity(num_blocks);
493 let mut total_f = 0.0f32;
494
495 for (block_idx, block) in model.blocks.iter().enumerate() {
496 let mut cache = ForwardCache {
497 rho_before: Vec::with_capacity(t),
498 unitaries: Vec::with_capacity(t),
499 rho_after: Vec::with_capacity(t),
500 populations: Vec::with_capacity(t),
501 values: values.clone(),
502 };
503
504 let precomputed_unitaries = &wdata.block_unitaries[block_idx];
505
506 for i in 0..t {
510 let mut rho = current_states[i].clone();
511
512 let window_start = i.saturating_sub(COHERENCE_WINDOW);
514 for j in window_start..i {
515 let dist = i - j;
516 let eps = block.hamiltonian.causal_dephasing(dist);
517 rho.couple_dephase(¤t_states[j], eps);
518 }
519
520 cache.rho_before.push(rho.entries.clone());
521
522 let f_before = rho.free_energy(&block.hamiltonian.bias);
523 let unitary = &precomputed_unitaries[i];
524
525 if f_before.abs() > f_gate {
526 cache.unitaries.push(unitary.clone());
527 rho.evolve(unitary);
528 } else {
529 cache.unitaries.push(Vec::new());
530 }
531
532 cache.rho_after.push(rho.entries.clone());
533
534 let f = rho.free_energy(&block.hamiltonian.bias);
535 total_f += f;
536
537 let pops = rho.populations();
538 cache.populations.push(pops);
539
540 current_states[i] = rho;
542 }
543
544 let attn_output = attention::AttentionOutput {
546 populations: cache.populations.clone(),
547 free_energies: vec![0.0; t],
548 coherences: vec![0.0; t],
549 };
550 values = attention::attention_project(&attn_output, &values, dim);
551
552 for state in &mut current_states {
554 state.dephase(eps_block);
555 }
556
557 caches.push(cache);
558 }
559
560 let mut logits = Vec::with_capacity(t);
562 for i in 0..t {
563 let mut token_logits = vec![0.0f32; vocab];
564 for v in 0..vocab {
565 for d in 0..dim {
566 token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
567 }
568 }
569 logits.push(token_logits);
570 }
571
572 let avg_f = total_f / t.max(1) as f32;
573 let loss = QCT::loss_from_logits(&logits, &wdata.window_tokens, avg_f);
574
575 all_window_results.push(WindowResult {
576 logits,
577 caches,
578 loss,
579 avg_f,
580 });
581 }
582
583 let num_params = model.num_params();
588 let mut total_grad = vec![0.0f32; num_params];
589 let mut total_loss = 0.0f32;
590 let mut total_f = 0.0f32;
591
592 for (w_idx, wdata) in all_window_data.iter().enumerate() {
593 let wr = &all_window_results[w_idx];
594
595 let grads = qug_backward_gpu(gpu, model, &wdata.window_tokens, &wr.logits, &wr.caches);
596 let grad_flat = grads.flatten();
597
598 total_loss += wr.loss;
599 total_f += wr.avg_f;
600 for (i, &g) in grad_flat.iter().enumerate() {
601 if i < num_params {
602 total_grad[i] += g;
603 }
604 }
605 }
606
607 let n = num_windows as f32;
609 for g in &mut total_grad {
610 *g /= n;
611 }
612 (total_grad, total_loss / n, total_f / n)
613}
614
615fn qug_backward_gpu(
620 gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
621 model: &QCT,
622 tokens: &[usize],
623 logits: &[Vec<f32>],
624 caches: &[ForwardCache],
625) -> AllGradients {
626 use dreamwell_math::Complex;
627 let dim = model.config.dim;
628 let vocab = model.config.vocab_size;
629 let stride = dim * dim;
630 let t = tokens.len().saturating_sub(1);
631 if t == 0 {
632 return AllGradients {
633 embed_grad: vec![0.0; model.embedding.num_params()],
634 block_grads: model
635 .blocks
636 .iter()
637 .map(|b| BlockGrad {
638 hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
639 value_weight_grad: vec![0.0; b.value_weights.len()],
640 })
641 .collect(),
642 output_grad: vec![0.0; model.output_weights.len()],
643 };
644 }
645
646 let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
648 for i in 0..t {
649 let target = tokens[i + 1];
650 let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
651 let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
652 let mut d_log = vec![0.0f32; vocab];
653 for v in 0..vocab {
654 let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
655 d_log[v] = (softmax_v - if v == target { 1.0 } else { 0.0 }) / t as f32;
656 }
657 d_logits.push(d_log);
658 }
659
660 let mut d_output = vec![0.0f32; dim * vocab];
662 for i in 0..t {
663 let cache = caches.last().unwrap();
664 let vals = if i < cache.populations.len() {
665 &cache.populations[i]
666 } else {
667 continue;
668 };
669 for d_idx in 0..dim {
670 for v in 0..vocab {
671 d_output[d_idx * vocab + v] += vals.get(d_idx).copied().unwrap_or(0.0) * d_logits[i][v];
672 }
673 }
674 }
675
676 let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t.max(1)];
678 for i in 0..t {
679 for d_idx in 0..dim {
680 for v in 0..vocab {
681 d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
682 }
683 }
684 }
685
686 let dt_val = 0.090f32;
688 let mut block_grads = Vec::with_capacity(model.blocks.len());
689
690 for (block_idx, block) in model.blocks.iter().enumerate().rev() {
691 let cache = &caches[block_idx];
692 let num_h = block.hamiltonian.num_params();
693
694 let mut d_vw = vec![0.0f32; dim * dim];
696 for i in 0..t.min(cache.populations.len()) {
697 for d_idx in 0..dim {
698 for s in 0..dim {
699 let pop = cache.populations[i].get(s).copied().unwrap_or(0.0);
700 let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
701 d_vw[d_idx * dim + s] += pop * dv;
702 }
703 }
704 }
705
706 let len = t.min(cache.unitaries.len());
708 let mut active_indices: Vec<usize> = Vec::new();
709 let mut u_batch: Vec<Complex> = Vec::new();
710 let mut d_rho_batch: Vec<Complex> = Vec::new();
711
712 for i in 0..len {
713 let u = &cache.unitaries[i];
714 if u.is_empty() {
715 continue;
716 }
717
718 active_indices.push(i);
719 u_batch.extend_from_slice(u);
720
721 let mut d_rho = vec![Complex::ZERO; stride];
723 for k in 0..dim {
724 let dp = d_values[i].get(k).copied().unwrap_or(0.0);
725 d_rho[k * dim + k] = Complex::new(dp, 0.0);
726 }
727 d_rho_batch.extend_from_slice(&d_rho);
728 }
729
730 let adjoint_results = if !active_indices.is_empty() {
732 gpu.batched_adjoint(&u_batch, &d_rho_batch, active_indices.len())
733 } else {
734 Vec::new()
735 };
736
737 let mut d_h = vec![0.0f32; num_h];
739 for (batch_idx, &pos_idx) in active_indices.iter().enumerate() {
740 let base = batch_idx * stride;
741 let scratch_b = &adjoint_results[base..base + stride];
742 let rho_before = &cache.rho_before[pos_idx];
743 let mut h_idx = 0;
744
745 for k in 0..dim {
747 let mut comm_diag = 0.0f32;
748 for j in 0..dim {
749 let ab = scratch_b[k * dim + j].mul(rho_before[j * dim + k]);
750 let ba = rho_before[k * dim + j].mul(scratch_b[j * dim + k]);
751 comm_diag += (ab.sub(ba)).im;
752 }
753 if h_idx < d_h.len() {
754 d_h[h_idx] += -dt_val * comm_diag;
755 }
756 h_idx += 1;
757 }
758
759 for p in 0..dim {
761 for q in (p + 1)..dim {
762 if h_idx >= d_h.len() {
763 break;
764 }
765 let ab_pq = scratch_b[p * dim + q].mul(rho_before[q * dim + p]);
766 let ba_pq = rho_before[p * dim + q].mul(scratch_b[q * dim + p]);
767 let comm_pq = ab_pq.sub(ba_pq);
768 d_h[h_idx] += -dt_val * 2.0 * comm_pq.im;
769 h_idx += 1;
770 }
771 }
772 }
773
774 block_grads.push(BlockGrad {
775 hamiltonian_grad: d_h,
776 value_weight_grad: d_vw,
777 });
778 }
779
780 block_grads.reverse();
781 let embed_grad = vec![0.0f32; model.embedding.num_params()];
782
783 AllGradients {
784 embed_grad,
785 block_grads,
786 output_grad: d_output,
787 }
788}
789
/// Per-window data prepared by `forward_backward_epoch_gpu` before its
/// sequential CPU forward pass.
struct WindowForwardState {
    /// Copy of the window's tokens, including the final token that is used
    /// only as a prediction target.
    window_tokens: Vec<usize>,
    /// Number of input positions: the window length minus one (saturating at zero).
    t: usize,
    /// Embedded density matrix for each input position.
    states: Vec<DensityMatrixN>,
    /// `populations()` of each embedded state.
    values: Vec<Vec<f32>>,
    /// Precomputed unitaries indexed as `[block][position]`, obtained from
    /// `gpu.batched_expm`.
    block_unitaries: Vec<Vec<Vec<dreamwell_math::Complex>>>,
}
798
/// Forward-pass output for one window, kept for the backward pass in
/// `forward_backward_epoch_gpu`.
struct WindowResult {
    /// One logit row per input position.
    logits: Vec<Vec<f32>>,
    /// One cache per model block, in block order.
    caches: Vec<ForwardCache>,
    /// Value returned by `QCT::loss_from_logits` for this window.
    loss: f32,
    /// Free energy summed over blocks and positions, divided by `max(t, 1)`.
    avg_f: f32,
}
806
807pub fn gpu_precompute_unitaries(
815 gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
816 block: &crate::transformer::QCTBlock,
817 dim: usize,
818 t: usize,
819 dt: f32,
820) -> Vec<Vec<dreamwell_math::Complex>> {
821 let n2 = dim * dim;
822
823 let mut all_h = vec![0.0f32; t * n2];
825 for i in 0..t {
826 let h = block.hamiltonian.build_matrix(i);
827 all_h[i * n2..(i + 1) * n2].copy_from_slice(&h);
828 }
829
830 let flat = gpu.batched_expm(&all_h, dt, t);
832
833 let stride = dim * dim;
835 (0..t).map(|i| flat[i * stride..(i + 1) * stride].to_vec()).collect()
836}
837
/// Forwards to `GpuTrainingContext::batched_evolve`, passing every argument
/// through unchanged and returning its result.
///
/// NOTE(review): presumably both flat slices hold `batch_count` matrices laid
/// out back to back — confirm against `batched_evolve`.
pub fn gpu_batch_evolve(
    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
    unitaries_flat: &[dreamwell_math::Complex],
    rhos_flat: &[dreamwell_math::Complex],
    batch_count: usize,
) -> Vec<dreamwell_math::Complex> {
    gpu.batched_evolve(unitaries_flat, rhos_flat, batch_count)
}
850
/// Forwards to `GpuTrainingContext::batched_adjoint`, passing every argument
/// through unchanged and returning its result.
///
/// NOTE(review): presumably both flat slices hold `batch_count` matrices laid
/// out back to back — confirm against `batched_adjoint`.
pub fn gpu_batch_adjoint(
    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
    unitaries_flat: &[dreamwell_math::Complex],
    d_rho_flat: &[dreamwell_math::Complex],
    batch_count: usize,
) -> Vec<dreamwell_math::Complex> {
    gpu.batched_adjoint(unitaries_flat, d_rho_flat, batch_count)
}
862
863pub fn train_qug(model: &mut QCT, tokens: &[usize], config: &crate::train::TrainConfig) -> Vec<EpochMetrics> {
866 let mut metrics = Vec::new();
867
868 for epoch in 0..config.num_epochs {
869 let start = std::time::Instant::now();
870 let lr = crate::train::learning_rate_pub(config, epoch);
871
872 let max_start = tokens.len().saturating_sub(config.context_length + 1);
874 let window_start = if max_start > 0 { epoch % max_start } else { 0 };
875 let window_end = (window_start + config.context_length + 1).min(tokens.len());
876 let window = &tokens[window_start..window_end];
877
878 let (logits, avg_f, caches) = forward_with_cache(model, &window[..window.len() - 1]);
880
881 let loss = QCT::loss_from_logits(&logits, window, avg_f);
883
884 let grads = qug_backward(model, window, &logits, &caches);
886 let grad_flat = grads.flatten();
887
888 let grad_norm: f32 = grad_flat.iter().map(|g| g * g).sum::<f32>().sqrt();
890
891 let scale = if grad_norm > config.grad_clip && grad_norm > 0.0 {
893 config.grad_clip / grad_norm
894 } else {
895 1.0
896 };
897
898 model.apply_gradient_update(&grad_flat, lr, scale);
900
901 let elapsed = start.elapsed().as_secs_f32() * 1000.0;
902
903 if epoch % config.log_interval == 0 || epoch == config.num_epochs - 1 {
904 let m = EpochMetrics {
905 epoch,
906 loss,
907 free_energy: avg_f,
908 grad_norm,
909 elapsed_ms: elapsed,
910 learning_rate: lr,
911 params_trained: grad_flat.len(),
912 };
913 log::info!(
914 "QUG Epoch {:4}: loss={:.4} F={:.4} |∇|={:.6} lr={:.5} ({:.1}ms)",
915 m.epoch,
916 m.loss,
917 m.free_energy,
918 m.grad_norm,
919 m.learning_rate,
920 m.elapsed_ms
921 );
922 metrics.push(m);
923 }
924 }
925
926 metrics
927}
928
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::QCTConfig;

    /// Small single-block configuration shared by every test in this module.
    fn tiny_config() -> QCTConfig {
        QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        }
    }

    /// The cached forward pass must agree with the plain forward pass on the
    /// logit count and, loosely, on the free energy.
    #[test]
    fn forward_cache_matches_forward() {
        let model = QCT::new(tiny_config());
        let tokens: Vec<usize> = (0..6).collect();

        let (logits_normal, f_normal) = model.forward(&tokens);
        let (logits_cached, f_cached, caches) = forward_with_cache(&model, &tokens);

        assert_eq!(logits_normal.len(), logits_cached.len());
        assert!(
            (f_normal - f_cached).abs() < 0.5,
            "free energy mismatch: {} vs {}",
            f_normal,
            f_cached
        );
        assert!(!caches.is_empty());
    }

    /// The backward pass must produce a finite, nonzero gradient.
    #[test]
    fn qug_gradient_nonzero() {
        let model = QCT::new(tiny_config());
        let tokens: Vec<usize> = (0..6).collect();

        // Forward over all but the last token; the last token is only a target.
        let (logits, _, caches) = forward_with_cache(&model, &tokens[..5]);
        let flat = qug_backward(&model, &tokens, &logits, &caches).flatten();

        let norm: f32 = flat.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm > 0.0, "QUG gradient should be nonzero");
        assert!(norm.is_finite(), "QUG gradient should be finite");
    }

    /// A short training run must log every epoch with a finite loss.
    #[test]
    fn qug_training_runs() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = (0..10).collect();

        let train_config = crate::train::TrainConfig {
            learning_rate: 0.01,
            num_epochs: 3,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };
        let metrics = train_qug(&mut model, &tokens, &train_config);

        assert_eq!(metrics.len(), 3);
        assert!(metrics[0].loss.is_finite());
        assert!(metrics[0].grad_norm > 0.0);
        eprintln!(
            "QUG training: {:.1}ms/epoch (vs ~7400ms for PSR)",
            metrics[0].elapsed_ms
        );
    }
}