Skip to main content

dreamwell_intelligence/
fock_forward.rs

1// BA-62: Fock Space Training — Second-Quantized Forward Pass.
2//
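// Operates on dim-dimensional amplitude vectors instead of dim×dim density matrices.
// Motivated by Kutzelnigg's principle (1982): diagonalize the Hamiltonian ONCE per block,
// then evolve all tokens via phase rotations in the eigenbasis.
// Implementation note: `fock_forward` below currently approximates this with a first-order
// BCH split (diagonal phase rotation plus off-diagonal coupling, then renormalization)
// rather than a full diagonalization; the eigen-decomposition helpers at the end of the
// file are not called from the code visible here.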
6//
7// Complexity reduction per token per block:
8//   Current (density matrix): O(dim³) — matrix exponential + two cgemm
9//   Fock (amplitudes):        O(dim²) — two mat-vec + dim phase rotations
10//
11// At dim=86: 636,056 → 14,878 complex FMAs per evolve. 43× per call.
12// With 8,192 evolves/epoch: 20.8G → 122M flops. 164× total.
13//
14// Prior art: Kutzelnigg, J. Chem. Phys. 77 (1982); Fock (1932); Glauber (1963).
15// Novel application: Fock space diagonalization for learnable quantum transformer training.
16
17use crate::adjoint::{AllGradients, BlockGrad};
18use crate::complex::Complex;
19use crate::train::EpochMetrics;
20use crate::transformer::QCT;
21
/// Golden ratio φ ≈ 1.618. Scales the coherence term in the temperature used by
/// `free_energy_from_amplitudes`.
const PHI: f32 = 1.618033988;
/// Reciprocal golden ratio 1/φ ≈ 0.618. Not referenced by the code visible in this file.
const PHI_INV: f32 = 0.618033988;
24
/// Cache from Fock space forward pass — stores amplitude vectors, not density matrices.
///
/// Produced by `fock_forward` and consumed by `fock_backward`.
pub struct FockCache {
    /// Per-block data, one entry per model block in forward order.
    pub blocks: Vec<FockBlockCache>,
    /// Per-position population vectors: a copy of the last block's `populations`
    /// (empty when the model has no blocks).
    pub final_populations: Vec<Vec<f32>>,
    /// Per-position value vectors as they stand after the last block's value projection;
    /// these are the vectors the output projection turns into logits.
    pub values: Vec<Vec<f32>>,
}
34
/// Cached data for one block's forward pass.
///
/// The field names follow the eigenbasis formulation, but `fock_forward` as written
/// fills them from its BCH-split path; the per-field notes describe what it stores.
pub struct FockBlockCache {
    /// Amplitude vectors before evolution [T × dim]: captured after causal dephasing,
    /// immediately before the phase rotation.
    pub amplitudes_before: Vec<Vec<Complex>>,
    /// Amplitude vectors after evolution [T × dim]: captured after the coupling step
    /// and renormalization.
    pub amplitudes_after: Vec<Vec<Complex>>,
    /// Eigenvectors V for this block's Hamiltonian [dim × dim, row-major].
    /// Left empty by `fock_forward` (its BCH path does not diagonalize).
    pub eigenvectors: Vec<f32>,
    /// Intended as eigenvalues E [dim]; `fock_forward` stores the diagonal entries
    /// H_kk of the Hamiltonian matrix here instead.
    pub eigenvalues: Vec<f32>,
    /// Phase factors [dim Complex], built as `Complex::exp_i(-H_kk·dt)` from the
    /// diagonal entries above.
    pub phases: Vec<Complex>,
    /// Population vectors [T × dim], |ψ_k|² of `amplitudes_after`.
    pub populations: Vec<Vec<f32>>,
    /// Value vectors entering this block [T × dim] (copied before the block's
    /// value projection replaces them).
    pub values_in: Vec<Vec<f32>>,
}
52
53/// Fock space forward pass — Kutzelnigg's universal energy operator applied to QCT.
54///
55/// For each block:
56///   1. Diagonalize H → (eigenvalues, eigenvectors V) — ONCE per block
57///   2. Precompute phase factors exp(-iE_k·dt) — ONCE per block
58///   3. For each token position:
59///      a. Causal dephasing in amplitude space
60///      b. Rotate to eigenbasis: ψ̃ = V†·ψ
61///      c. Apply phases: ψ̃'_k = ψ̃_k · exp(-iE_k·dt)
62///      d. Rotate back: ψ' = V·ψ̃'
63///      e. Measure: p_k = |ψ'_k|²
64///
65/// Returns (logits, avg_free_energy, FockCache).
66pub fn fock_forward(model: &QCT, tokens: &[usize]) -> (Vec<Vec<f32>>, f32, FockCache) {
67    let dim = model.config.dim;
68    let t = tokens.len();
69    let dt = 0.090f32; // 1/φ⁵
70
71    // Embed tokens as amplitude vectors (NOT density matrices)
72    let mut amplitudes: Vec<Vec<Complex>> = tokens.iter().map(|&tok| model.embedding.embed_amplitude(tok)).collect();
73
74    let mut values: Vec<Vec<f32>> = amplitudes.iter().map(|psi| populations_from_amplitudes(psi)).collect();
75
76    let mut block_caches = Vec::with_capacity(model.blocks.len());
77    let mut total_f = 0.0f32;
78
79    let eps_block = 0.236 / model.blocks.len().max(1) as f32;
80    const COHERENCE_WINDOW: usize = 8;
81
82    for block in &model.blocks {
83        // ═══ Kutzelnigg: Precompute H once per block ═══
84        // Using BCH split: exp(-iHdt) ≈ exp(-iH_diag·dt) · (I - iH_coupling·dt)
85        // Diagonal phase rotations (O(dim)) + first-order coupling (O(dim²))
86        // Error: O(dt²·||[H_diag, H_coupling]||) ≈ φ⁻¹⁰ < 0.01 per step
87        let h_matrix = block.hamiltonian.build_matrix(0);
88        let h_diag: Vec<f32> = (0..dim).map(|k| h_matrix[k * dim + k]).collect();
89        let diag_phases: Vec<Complex> = h_diag.iter().map(|&e| Complex::exp_i(-e * dt)).collect();
90
91        let mut cache = FockBlockCache {
92            amplitudes_before: Vec::with_capacity(t),
93            amplitudes_after: Vec::with_capacity(t),
94            eigenvectors: Vec::new(), // Not used in BCH path
95            eigenvalues: h_diag.clone(),
96            phases: diag_phases.clone(),
97            populations: Vec::with_capacity(t),
98            values_in: values.clone(),
99        };
100
101        // Causal loop (sequential — causal dependency)
102        for i in 0..t {
103            let mut psi = amplitudes[i].clone();
104
105            // φ-windowed causal dephasing in amplitude space
106            let window_start = i.saturating_sub(COHERENCE_WINDOW);
107            for j in window_start..i {
108                let dist = i - j;
109                let eps = block.hamiltonian.causal_dephasing(dist);
110                dephase_amplitude_coupled(&mut psi, &amplitudes[j], eps);
111            }
112
113            cache.amplitudes_before.push(psi.clone());
114
115            // ═══ Evolve via BCH split (Kutzelnigg-inspired) ═══
116            // exp(-iHdt) ≈ exp(-iH_diag·dt) · (I - iH_coupling·dt)
117            //
118            // Step 1: Diagonal phase rotation ψ_k *= exp(-iE_k·dt)
119            let mut psi_evolved = psi.clone();
120            for k in 0..dim {
121                psi_evolved[k] = psi_evolved[k].mul(diag_phases[k]);
122            }
123            // Step 2: First-order coupling: ψ += -i·H_coupling·ψ·dt
124            // H_coupling is the off-diagonal part of H
125            let psi_pre = psi_evolved.clone();
126            for ii in 0..dim {
127                let mut coupling_sum = Complex::ZERO;
128                for jj in 0..dim {
129                    if ii == jj {
130                        continue;
131                    }
132                    let h_ij = h_matrix[ii * dim + jj];
133                    if h_ij.abs() < 1e-10 {
134                        continue;
135                    }
136                    // -i · h_ij · dt · ψ_j
137                    coupling_sum =
138                        coupling_sum.add(Complex::new(h_ij * dt * psi_pre[jj].im, -h_ij * dt * psi_pre[jj].re));
139                }
140                psi_evolved[ii] = psi_evolved[ii].add(coupling_sum);
141            }
142            // Renormalize (BCH split is approximate — maintain trace = 1)
143            let norm_sq: f32 = psi_evolved.iter().map(|c| c.norm_sq()).sum();
144            if norm_sq > 1e-10 {
145                let inv = 1.0 / norm_sq.sqrt();
146                for c in &mut psi_evolved {
147                    *c = c.scale(inv);
148                }
149            }
150
151            cache.amplitudes_after.push(psi_evolved.clone());
152
153            // Populations from amplitudes
154            let pops = populations_from_amplitudes(&psi_evolved);
155            let f = free_energy_from_amplitudes(&psi_evolved, &block.hamiltonian.bias);
156            total_f += f;
157
158            cache.populations.push(pops);
159            amplitudes[i] = psi_evolved;
160        }
161
162        // Value projection (same as density matrix path)
163        let attn_output = crate::attention::AttentionOutput {
164            populations: cache.populations.clone(),
165            free_energies: vec![0.0; t],
166            coherences: vec![0.0; t],
167        };
168        values = crate::attention::attention_project(&attn_output, &values, dim);
169
170        // Inter-block dephasing in amplitude space
171        for psi in &mut amplitudes {
172            dephase_amplitude(psi, eps_block);
173        }
174
175        block_caches.push(cache);
176    }
177
178    // Output projection (CPU, same as density matrix path)
179    let vocab = model.config.vocab_size;
180    let mut logits = Vec::with_capacity(t);
181    for i in 0..t {
182        let mut token_logits = vec![0.0f32; vocab];
183        for v in 0..vocab {
184            for d in 0..dim {
185                token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
186            }
187        }
188        logits.push(token_logits);
189    }
190
191    let avg_f = total_f / t.max(1) as f32;
192    let final_pops = block_caches.last().map(|c| c.populations.clone()).unwrap_or_default();
193
194    (
195        logits,
196        avg_f,
197        FockCache {
198            blocks: block_caches,
199            final_populations: final_pops,
200            values,
201        },
202    )
203}
204
205/// Fock space backward pass — gradients via amplitude-space adjoint.
206///
207/// The QUG adjoint for pure states reduces from O(dim³) to O(dim²):
208///   ∂L/∂H via commutator [A, |ψ⟩⟨ψ|] = A|ψ⟩⟨ψ| - |ψ⟩⟨ψ|A
209///   which is two outer-product-vector operations, not full matrix commutator.
210pub fn fock_backward(model: &QCT, tokens: &[usize], logits: &[Vec<f32>], cache: &FockCache) -> AllGradients {
211    let dim = model.config.dim;
212    let vocab = model.config.vocab_size;
213    let t = tokens.len().saturating_sub(1);
214    if t == 0 {
215        return AllGradients {
216            embed_grad: vec![0.0; model.embedding.num_params()],
217            block_grads: model
218                .blocks
219                .iter()
220                .map(|b| BlockGrad {
221                    hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
222                    value_weight_grad: vec![0.0; b.value_weights.len()],
223                })
224                .collect(),
225            output_grad: vec![0.0; model.output_weights.len()],
226        };
227    }
228
229    // 1. ∂L/∂logits (same as density matrix path)
230    let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
231    for i in 0..t {
232        let target = tokens[i + 1];
233        let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
234        let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
235        let mut d_log = vec![0.0f32; vocab];
236        for v in 0..vocab {
237            let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
238            d_log[v] = (softmax_v - if v == target { 1.0 } else { 0.0 }) / t as f32;
239        }
240        d_logits.push(d_log);
241    }
242
243    // 2. ∂L/∂output_weights
244    let mut d_output = vec![0.0f32; dim * vocab];
245    if let Some(last_cache) = cache.blocks.last() {
246        for i in 0..t.min(last_cache.populations.len()) {
247            let pops = &last_cache.populations[i];
248            for d_idx in 0..dim {
249                for v in 0..vocab {
250                    d_output[d_idx * vocab + v] += pops.get(d_idx).copied().unwrap_or(0.0) * d_logits[i][v];
251                }
252            }
253        }
254    }
255
256    // 3. ∂L/∂values
257    let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t.max(1)];
258    for i in 0..t {
259        for d_idx in 0..dim {
260            for v in 0..vocab {
261                d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
262            }
263        }
264    }
265
266    // 4. Per-block backward (amplitude-space adjoint)
267    let dt = 0.090f32;
268    let mut block_grads = Vec::with_capacity(model.blocks.len());
269
270    for (block_idx, block) in model.blocks.iter().enumerate().rev() {
271        let bc = &cache.blocks[block_idx];
272        let num_h = block.hamiltonian.num_params();
273
274        // ∂L/∂value_weights
275        let mut d_vw = vec![0.0f32; dim * dim];
276        for i in 0..t.min(bc.populations.len()) {
277            for d_idx in 0..dim {
278                for s in 0..dim {
279                    let pop = bc.populations[i].get(s).copied().unwrap_or(0.0);
280                    let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
281                    d_vw[d_idx * dim + s] += pop * dv;
282                }
283            }
284        }
285
286        // ∂L/∂H via amplitude-space commutator.
287        //
288        // For rank-1 ρ = |ψ⟩⟨ψ|, the Hamiltonian gradient is:
289        //   ∂L/∂H_pq = -dt · Im(⟨ψ_before|A_p⟩⟨A_q|ψ_before⟩ - ⟨ψ_before|A_q⟩⟨A_p|ψ_before⟩)
290        //
291        // where A is constructed from d_values (the adjoint signal).
292        // This is O(dim²) per position instead of O(dim³).
293        let len = t.min(bc.amplitudes_before.len());
294        let mut d_h = vec![0.0f32; num_h];
295
296        // Use rayon for parallel gradient across positions
297        use rayon::prelude::*;
298        let position_grads: Vec<Vec<f32>> = (0..len)
299            .into_par_iter()
300            .map(|i| {
301                let mut local_d_h = vec![0.0f32; num_h];
302                let psi = &bc.amplitudes_before[i];
303
304                // Build diagonal adjoint from d_values
305                // d_rho_kk = d_values[i][k] → ∂L/∂ρ is diagonal
306                let d_pop: Vec<f32> = (0..dim).map(|k| d_values[i].get(k).copied().unwrap_or(0.0)).collect();
307
308                // Bias gradient: ∂L/∂E_k = -dt · 2 · d_pop[k] · |ψ_k|² · Im(ψ_k*/ψ_k)
309                // Simplified for real diagonal perturbation:
310                // ∂L/∂E_k = -dt · d_pop[k] (direct from chain rule through populations)
311                let mut h_idx = 0;
312                for k in 0..dim {
313                    // The bias affects ρ_kk through evolve.
314                    // For eigenbasis evolution: ∂ρ_kk/∂E_k = -2dt·Im(ψ̃_k*·ψ̃_k·i) = 0 (diagonal is real)
315                    // The gradient flows through the coupling of bias to off-diagonal terms.
316                    // Use finite-difference approximation on the amplitude populations.
317                    let pop_k = psi[k].norm_sq();
318                    local_d_h[h_idx] = -dt * d_pop[k] * pop_k;
319                    h_idx += 1;
320                }
321
322                // Coupling gradient: ∂L/∂g_pq from off-diagonal Hamiltonian terms
323                for p in 0..dim {
324                    for q in (p + 1)..dim {
325                        if h_idx >= local_d_h.len() {
326                            break;
327                        }
328                        // The coupling g_pq mixes modes p and q.
329                        // Gradient: -dt · (d_pop[p] · Re(ψ_p* · ψ_q) + d_pop[q] · Re(ψ_q* · ψ_p))
330                        let psi_p = psi[p];
331                        let psi_q = psi[q];
332                        let cross = psi_p.mul(psi_q.conj());
333                        local_d_h[h_idx] = -dt * 2.0 * (d_pop[p] + d_pop[q]) * cross.im;
334                        h_idx += 1;
335                    }
336                }
337
338                local_d_h
339            })
340            .collect();
341
342        // Reduce
343        for pg in &position_grads {
344            for (k, &v) in pg.iter().enumerate() {
345                d_h[k] += v;
346            }
347        }
348
349        block_grads.push(BlockGrad {
350            hamiltonian_grad: d_h,
351            value_weight_grad: d_vw,
352        });
353    }
354
355    block_grads.reverse();
356    let embed_grad = vec![0.0f32; model.embedding.num_params()];
357
358    AllGradients {
359        embed_grad,
360        block_grads,
361        output_grad: d_output,
362    }
363}
364
365// ── Amplitude-space helper functions ──────────────────────────
366
367/// Extract populations |ψ_k|² from amplitude vector.
368fn populations_from_amplitudes(psi: &[Complex]) -> Vec<f32> {
369    psi.iter().map(|c| c.norm_sq()).collect()
370}
371
372/// Free energy from amplitudes: F = ⟨H⟩ - T·S.
373/// ⟨H⟩ = Σ E_k |ψ_k|². S = -Σ p_k ln(p_k). T = 1/(1 + φ·coherence).
374fn free_energy_from_amplitudes(psi: &[Complex], bias: &[f32]) -> f32 {
375    let dim = psi.len();
376    let pops: Vec<f32> = psi.iter().map(|c| c.norm_sq()).collect();
377
378    // ⟨H⟩
379    let expected_h: f32 = pops.iter().zip(bias.iter()).map(|(p, e)| p * e).sum();
380
381    // Coherence magnitude
382    let mut coh = 0.0f32;
383    for i in 0..dim {
384        for j in (i + 1)..dim {
385            coh += psi[i].mul(psi[j].conj()).norm();
386        }
387    }
388
389    // Temperature
390    let temperature = 1.0 / (1.0 + PHI * coh);
391
392    // Von Neumann entropy from populations (approximation for near-pure states)
393    let mut entropy = 0.0f32;
394    for &p in &pops {
395        if p > 1e-10 {
396            entropy -= p * p.ln();
397        }
398    }
399
400    expected_h - temperature * entropy
401}
402
403/// Dephase amplitude vector: ψ_k *= √(1-ε) for all k.
404/// This reduces |ψ_k|² by factor (1-ε), equivalent to off-diagonal dephasing on ρ.
405fn dephase_amplitude(psi: &mut [Complex], epsilon: f32) {
406    let retain_sqrt = (1.0 - epsilon).max(0.0).sqrt();
407    for c in psi.iter_mut() {
408        *c = c.scale(retain_sqrt);
409    }
410    // Renormalize to maintain trace = 1
411    let norm_sq: f32 = psi.iter().map(|c| c.norm_sq()).sum();
412    if norm_sq > 1e-10 {
413        let inv_norm = 1.0 / norm_sq.sqrt();
414        for c in psi.iter_mut() {
415            *c = c.scale(inv_norm);
416        }
417    }
418}
419
420/// Coupled dephasing in amplitude space.
421/// Scales amplitudes based on coherence of the other state.
422fn dephase_amplitude_coupled(psi: &mut [Complex], other: &[Complex], strength: f32) {
423    // Coherence magnitude of other state
424    let dim = other.len();
425    let mut other_coh = 0.0f32;
426    for i in 0..dim {
427        for j in (i + 1)..dim {
428            other_coh += other[i].mul(other[j].conj()).norm();
429        }
430    }
431    other_coh = other_coh.min(1.0);
432    let retain = (1.0 - strength * (1.0 - other_coh)).max(0.0);
433    let retain_sqrt = retain.sqrt();
434    for c in psi.iter_mut() {
435        *c = c.scale(retain_sqrt);
436    }
437    // Renormalize
438    let norm_sq: f32 = psi.iter().map(|c| c.norm_sq()).sum();
439    if norm_sq > 1e-10 {
440        let inv_norm = 1.0 / norm_sq.sqrt();
441        for c in psi.iter_mut() {
442            *c = c.scale(inv_norm);
443        }
444    }
445}
446
447/// Matrix-vector product: y = M·x where M is dim×dim real (row-major), x is dim Complex.
448/// O(dim²) — the core operation replacing O(dim³) matrix multiply.
449fn matvec_real(m: &[f32], x: &[Complex], dim: usize) -> Vec<Complex> {
450    let mut y = vec![Complex::ZERO; dim];
451    for i in 0..dim {
452        let mut sum = Complex::ZERO;
453        for j in 0..dim {
454            let mij = m[i * dim + j];
455            sum = sum.add(x[j].scale(mij));
456        }
457        y[i] = sum;
458    }
459    y
460}
461
462/// Matrix-transpose-vector product: y = M†·x where M is dim×dim real, x is dim Complex.
463/// For real M, M† = M^T. O(dim²).
464fn matvec_transpose_real(m: &[f32], x: &[Complex], dim: usize) -> Vec<Complex> {
465    let mut y = vec![Complex::ZERO; dim];
466    for j in 0..dim {
467        for i in 0..dim {
468            let mij = m[i * dim + j]; // M[i,j], transposed access: M^T[j,i]
469            y[j] = y[j].add(x[i].scale(mij));
470        }
471    }
472    y
473}
474
475/// Diagonalize a real symmetric matrix using Jacobi eigenvalue algorithm.
476/// Fills eigenvalues (sorted) and eigenvectors (column-major in row-major storage).
477fn diagonalize_real_symmetric(h: &[f32], eigenvalues: &mut [f32], eigenvectors: &mut [f32], dim: usize) {
478    // Convert f32 Hamiltonian to Complex for existing Jacobi solver
479    let mut work = vec![Complex::ZERO; dim * dim];
480    for i in 0..dim * dim {
481        work[i] = Complex::new(h[i], 0.0);
482    }
483
484    dreamwell_math::eigen::eigenvalues_hermitian(&mut work, eigenvalues, dim, 50, 1e-6);
485
486    // Extract eigenvectors from the rotated work matrix
487    // After Jacobi, work columns are eigenvectors
488    for i in 0..dim * dim {
489        eigenvectors[i] = work[i].re;
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::{QCTConfig, QCT};

    /// Amplitude dimension shared by every test model.
    const DIM: usize = 5;

    /// Builds the small two-block, fixed-seed model used throughout these tests.
    fn test_model(vocab_size: usize) -> QCT {
        QCT::new(QCTConfig {
            vocab_size,
            dim: DIM,
            num_blocks: 2,
            seed: 42,
        })
    }

    /// The amplitude embedding must carry the same populations as the density-matrix embedding.
    #[test]
    fn embed_amplitude_matches_populations() {
        let model = test_model(65);

        for token in 0..10 {
            let rho = model.embedding.embed(token);
            let psi = model.embedding.embed_amplitude(token);
            let pops_rho = rho.populations();
            let pops_psi = populations_from_amplitudes(&psi);

            for k in 0..DIM {
                assert!(
                    (pops_rho[k] - pops_psi[k]).abs() < 1e-5,
                    "token {token} mode {k}: rho={} psi={}",
                    pops_rho[k],
                    pops_psi[k]
                );
            }
        }
    }

    /// One finite logit row of vocab width per token, plus a finite free energy.
    #[test]
    fn fock_forward_produces_valid_logits() {
        let model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5];

        let (logits, avg_f, _cache) = fock_forward(&model, &tokens);

        assert_eq!(logits.len(), tokens.len());
        for l in &logits {
            assert_eq!(l.len(), 10);
            // Check logits are finite
            for &v in l {
                assert!(v.is_finite(), "logit not finite: {v}");
            }
        }
        assert!(avg_f.is_finite(), "free energy not finite: {avg_f}");
    }

    /// The loss built from forward-pass logits is finite and positive.
    #[test]
    fn fock_forward_loss_is_finite() {
        let model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];

        // Forward on all but the last token; the full sequence supplies the targets.
        let (logits, avg_f, _cache) = fock_forward(&model, &tokens[..7]);
        let loss = QCT::loss_from_logits(&logits, &tokens, avg_f);

        assert!(loss.is_finite(), "loss not finite: {loss}");
        assert!(loss > 0.0, "loss should be positive: {loss}");
    }

    /// The backward pass yields a gradient that is not identically zero.
    #[test]
    fn fock_backward_produces_gradients() {
        let model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];

        let (logits, _avg_f, cache) = fock_forward(&model, &tokens[..7]);
        let grads = fock_backward(&model, &tokens, &logits, &cache);

        // Check gradient is nonzero
        let grad_flat = grads.flatten();
        let norm: f32 = grad_flat.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm > 1e-6, "gradient norm should be nonzero: {norm}");
    }

    /// Ten gradient steps must not make the loss meaningfully worse (tolerance 0.1).
    #[test]
    fn fock_training_reduces_loss() {
        let mut model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5];

        let (logits0, f0, _) = fock_forward(&model, &tokens[..15]);
        let loss0 = QCT::loss_from_logits(&logits0, &tokens, f0);

        // Train 10 steps
        for _ in 0..10 {
            let (logits, _avg_f, cache) = fock_forward(&model, &tokens[..15]);
            let grads = fock_backward(&model, &tokens, &logits, &cache);
            let grad_flat = grads.flatten();
            model.apply_gradient_update(&grad_flat, 0.03, 1.0);
        }

        let (logits1, f1, _) = fock_forward(&model, &tokens[..15]);
        let loss1 = QCT::loss_from_logits(&logits1, &tokens, f1);

        assert!(
            loss1 < loss0 + 0.1,
            "loss should not rise by more than 0.1: {loss0} → {loss1}"
        );
    }
}