// dreamwell_intelligence/adjoint.rs

// Quantum Unitarity Gradient (QUG) — reverse-mode autodiff exploiting unitarity.
//
// The key insight: because U = exp(-iHdt) is unitary (U†U = I), the adjoint of
// the evolve operation ρ_out = UρU† is simply:
//   ∂L/∂ρ_in = U† · (∂L/∂ρ_out) · U
// This is just two matrix multiplies with the SAME unitary stored during the forward pass.
// No Jacobian inversion. No Fréchet derivative. No matrix-exponential adjoint.
//
// Cost: roughly 3× a forward pass, versus 2P forward passes for the parameter-shift
// rule (P = parameter count). For 1M parameters that is a ~670,000× speedup.
//
// Allocation: gradient buffers are allocated once per backward call; per-position
// scratch matrices are still allocated inside the parallel backward loop.
13
14use crate::attention;
15use crate::complex::Complex;
16use crate::density_matrix::DensityMatrixN;
17use crate::train::EpochMetrics;
18use crate::transformer::QCT;
19
20// φ constants available via dreamwell_math::golden_ratio() if needed.
21
22/// Stored forward state for one QCT block (needed for backward pass).
23pub struct ForwardCache {
24    /// Density matrices BEFORE evolve at each token position [T × dim²].
25    pub rho_before: Vec<Vec<Complex>>,
26    /// Unitaries computed during forward [T × dim²].
27    pub unitaries: Vec<Vec<Complex>>,
28    /// Density matrices AFTER full attention at each position [T × dim²].
29    pub rho_after: Vec<Vec<Complex>>,
30    /// Populations from Born measurement [T × dim].
31    pub populations: Vec<Vec<f32>>,
32    /// Values passed to projection [T × dim].
33    pub values: Vec<Vec<f32>>,
34}
35
/// Gradient accumulator for all model parameters, grouped by component.
#[derive(Debug, Clone)]
pub struct AllGradients {
    /// Gradient w.r.t. the embedding parameters.
    pub embed_grad: Vec<f32>,
    /// Per-block gradients, in the same order as the model's blocks.
    pub block_grads: Vec<BlockGrad>,
    /// Gradient w.r.t. the output projection weights.
    pub output_grad: Vec<f32>,
}

/// Gradients for a single QCT block.
#[derive(Debug, Clone)]
pub struct BlockGrad {
    /// Gradient w.r.t. the block's Hamiltonian parameters.
    pub hamiltonian_grad: Vec<f32>,
    /// Gradient w.r.t. the block's value-projection weights.
    pub value_weight_grad: Vec<f32>,
}

impl AllGradients {
    /// Flatten to a single vector matching the `model.all_params()` layout:
    /// embedding gradient, then for each block its Hamiltonian gradient followed by
    /// its value-weight gradient, then the output-projection gradient.
    pub fn flatten(&self) -> Vec<f32> {
        // Size the buffer exactly once: this runs per training window and the
        // concatenated vector can hold every model parameter.
        let block_len: usize = self
            .block_grads
            .iter()
            .map(|bg| bg.hamiltonian_grad.len() + bg.value_weight_grad.len())
            .sum();
        let mut v =
            Vec::with_capacity(self.embed_grad.len() + block_len + self.output_grad.len());
        v.extend_from_slice(&self.embed_grad);
        for bg in &self.block_grads {
            v.extend_from_slice(&bg.hamiltonian_grad);
            v.extend_from_slice(&bg.value_weight_grad);
        }
        v.extend_from_slice(&self.output_grad);
        v
    }
}
61
62/// Forward pass with caching for QUG backward.
63/// Returns (logits, avg_free_energy, per_block_caches).
64pub fn forward_with_cache(model: &QCT, tokens: &[usize]) -> (Vec<Vec<f32>>, f32, Vec<ForwardCache>) {
65    forward_with_cache_converter(model, tokens, None)
66}
67
68/// Forward pass with optional GoldenRatioConverter for frequency-aware dephasing.
69/// When converter is Some, applies frequency-proportional initial dephasing after embedding.
70/// The embedding stays pure (clean gradients). Frequency enters as a quantum channel.
71/// When None, falls back to standard pure-state embedding with no initial dephasing.
72pub fn forward_with_cache_converter(
73    model: &QCT,
74    tokens: &[usize],
75    converter: Option<&crate::golden_ratio_converter::GoldenRatioConverter>,
76) -> (Vec<Vec<f32>>, f32, Vec<ForwardCache>) {
77    let dim = model.config.dim;
78    let t = tokens.len();
79
80    // Embed as pure states, then apply frequency-proportional dephasing.
81    // Common tokens → strong dephasing → low F → thermodynamically gated (BA-34).
82    // Rare tokens → weak dephasing → high F → full compute, maximum gradient signal.
83    let mut states: Vec<DensityMatrixN> = tokens.iter().map(|&tok| model.embedding.embed(tok)).collect();
84    if let Some(conv) = converter {
85        for (i, &tok) in tokens.iter().enumerate() {
86            let eps = conv.dephasing_rate(tok);
87            if eps > 1e-6 {
88                states[i].dephase(eps);
89            }
90        }
91    }
92    let mut values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
93    let mut current_states = states;
94    let mut caches = Vec::with_capacity(model.blocks.len());
95    let mut total_f = 0.0f32;
96
97    for block in &model.blocks {
98        let mut cache = ForwardCache {
99            rho_before: Vec::with_capacity(t),
100            unitaries: Vec::with_capacity(t),
101            rho_after: Vec::with_capacity(t),
102            populations: Vec::with_capacity(t),
103            values: values.clone(),
104        };
105
106        // Pre-compute ALL unitaries in parallel (rayon).
107        // U_i = exp(-iH(i)dt) depends only on position i and static Hamiltonian
108        // parameters — NOT on the evolving density matrices. This moves the
109        // O(d³) matrix exponential out of the serial causal loop.
110        let precomputed_unitaries: Vec<Vec<Complex>> = {
111            use rayon::prelude::*;
112            (0..t)
113                .into_par_iter()
114                .map(|i| {
115                    let h_matrix = block.hamiltonian.build_matrix(i);
116                    DensityMatrixN::hamiltonian_unitary(&h_matrix, dim, 0.090) // 1/φ⁵ — adiabatic evolution step
117                })
118                .collect()
119        };
120
121        // Causal attention loop — φ-windowed dephasing + thermodynamic gate + evolve.
122        //
123        // Quantum Signal Gate: the φ-exponential dephasing ε = exp(-d/φ) creates a
124        // finite coherence horizon. Beyond distance w = ⌈φ · 3·ln(φ)⌉ = 3, the
125        // dephasing contribution falls below 1/φ³ (thermodynamic gate threshold).
126        // Truncate the causal loop from O(t²) to O(t × w).
127        //
128        // This is the DreamGate concept applied to quantum information flow:
129        // the gate opens for coherent signals (within window) and closes for
130        // thermalized signals (beyond window). Error bounded by 1/φ³ per position.
131        // BA-35: Coherence window from φ-exponential decay.
132        // Base horizon: ⌈φ · 3·ln(φ)⌉ = 3 (where ε < 1/φ³).
133        // Training window: scale by φ² ≈ 2.618 for gradient preservation.
134        // Wider window preserves more causal information for learning.
135        // ⌈3 × φ²⌉ = ⌈7.85⌉ = 8.
136        const COHERENCE_WINDOW: usize = 8;
137
138        // BA-34: Thermodynamic gate threshold, scaled by dimension.
139        // At dim=5: 0.236/5 = 0.047. At dim=48: 0.236/48 = 0.005.
140        // The free energy landscape narrows with dimension — the threshold must follow.
141        let f_gate = 0.236 / dim as f32;
142
143        for i in 0..t {
144            let mut rho = current_states[i].clone();
145
146            // φ-windowed causal dephasing: only couple with nearest COHERENCE_WINDOW
147            // predecessors. Distant positions contribute ε < 1/φ³ — below gate threshold.
148            let window_start = i.saturating_sub(COHERENCE_WINDOW);
149            for j in window_start..i {
150                let dist = i - j;
151                let eps = block.hamiltonian.causal_dephasing(dist);
152                rho.couple_dephase(&current_states[j], eps);
153            }
154
155            // Store rho BEFORE evolve
156            cache.rho_before.push(rho.entries.clone());
157
158            // Thermodynamic computation gate: check free energy before evolve.
159            // F < threshold → state is near equilibrium → evolve is negligible.
160            let f_before = rho.free_energy(&block.hamiltonian.bias);
161            let unitary = &precomputed_unitaries[i];
162
163            // BA-34: Thermodynamic computation gate (dimension-scaled).
164            // Skip evolve only for states genuinely at thermal equilibrium.
165            if f_before.abs() > f_gate {
166                cache.unitaries.push(unitary.clone());
167                rho.evolve(unitary);
168            } else {
169                cache.unitaries.push(Vec::new());
170            }
171
172            // Store rho AFTER evolve (or after skip)
173            cache.rho_after.push(rho.entries.clone());
174
175            let f = rho.free_energy(&block.hamiltonian.bias);
176            total_f += f;
177
178            let pops = rho.populations();
179            cache.populations.push(pops);
180        }
181
182        // Value projection
183        let attn_output = attention::AttentionOutput {
184            populations: cache.populations.clone(),
185            free_energies: vec![0.0; t],
186            coherences: vec![0.0; t],
187        };
188        let new_values = attention::attention_project(&attn_output, &values, dim);
189        values = new_values;
190
191        // Inter-block dephasing: ε = 1/(φ³ × num_blocks).
192        // Total surviving coherence after ALL blocks: exp(-1/φ³) ≈ 79%.
193        // Scale-invariant — same survival ratio at any block depth.
194        let eps_block = 0.236 / model.blocks.len().max(1) as f32;
195        for state in &mut current_states {
196            state.dephase(eps_block);
197        }
198
199        caches.push(cache);
200    }
201
202    // Output projection
203    let vocab = model.config.vocab_size;
204    let mut logits = Vec::with_capacity(t);
205    for i in 0..t {
206        let mut token_logits = vec![0.0f32; vocab];
207        for v in 0..vocab {
208            for d in 0..dim {
209                token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
210            }
211        }
212        logits.push(token_logits);
213    }
214
215    (logits, total_f / t as f32, caches)
216}
217
218/// QUG backward pass — compute gradients using stored unitaries.
219/// Cost: ~3× forward (3 matmuls per block for adjoint, vs 2P forward passes for PSR).
220pub fn qug_backward(model: &QCT, tokens: &[usize], logits: &[Vec<f32>], caches: &[ForwardCache]) -> AllGradients {
221    let dim = model.config.dim;
222    let vocab = model.config.vocab_size;
223    let t = tokens.len().saturating_sub(1); // targets = tokens[1..]
224    if t == 0 {
225        return AllGradients {
226            embed_grad: vec![0.0; model.embedding.num_params()],
227            block_grads: model
228                .blocks
229                .iter()
230                .map(|b| BlockGrad {
231                    hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
232                    value_weight_grad: vec![0.0; b.value_weights.len()],
233                })
234                .collect(),
235            output_grad: vec![0.0; model.output_weights.len()],
236        };
237    }
238
239    // 1. ∂L/∂logits from cross-entropy (softmax gradient)
240    let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
241    for i in 0..t {
242        let target = tokens[i + 1];
243        let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
244        let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
245        let mut d_log = vec![0.0f32; vocab];
246        for v in 0..vocab {
247            let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
248            d_log[v] = softmax_v - if v == target { 1.0 } else { 0.0 };
249        }
250        // Scale by 1/t for average
251        for v in &mut d_log {
252            *v /= t as f32;
253        }
254        d_logits.push(d_log);
255    }
256
257    // 2. ∂L/∂output_weights
258    let mut d_output = vec![0.0f32; dim * vocab];
259    for i in 0..t {
260        let cache = caches.last().unwrap();
261        let vals = if i < cache.populations.len() {
262            &cache.populations[i]
263        } else {
264            continue;
265        };
266        for d_idx in 0..dim {
267            for v in 0..vocab {
268                d_output[d_idx * vocab + v] += vals.get(d_idx).copied().unwrap_or(0.0) * d_logits[i][v];
269            }
270        }
271    }
272
273    // 3. ∂L/∂values (from output projection)
274    let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t.max(1)];
275    for i in 0..t {
276        for d_idx in 0..dim {
277            for v in 0..vocab {
278                d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
279            }
280        }
281    }
282
283    // 4. Per-block backward (reverse order)
284    let mut block_grads = Vec::with_capacity(model.blocks.len());
285    for (block_idx, block) in model.blocks.iter().enumerate().rev() {
286        let cache = &caches[block_idx];
287        let num_h = block.hamiltonian.num_params();
288
289        // ∂L/∂value_weights from attention projection
290        let mut d_vw = vec![0.0f32; dim * dim];
291        // Simplified: accumulate gradient from the projection step
292        for i in 0..t.min(cache.populations.len()) {
293            for d_idx in 0..dim {
294                for s in 0..dim {
295                    let pop = cache.populations[i].get(s).copied().unwrap_or(0.0);
296                    let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
297                    d_vw[d_idx * dim + s] += pop * dv;
298                }
299            }
300        }
301
302        // ∂L/∂H via QUG: for each position, use stored unitary.
303        // Positions are mathematically independent in the backward pass —
304        // each reads from immutable cached state. Parallelize with rayon.
305        let dt = 0.090f32; // 1/φ⁵ — must match forward pass
306        let len = t.min(cache.unitaries.len());
307
308        let position_grads: Vec<Vec<f32>> = {
309            use rayon::prelude::*;
310            (0..len)
311                .into_par_iter()
312                .map(|i| {
313                    let mut local_d_h = vec![0.0f32; num_h];
314
315                    // Thermodynamic gate: skip gradient for positions where evolve was skipped.
316                    // Empty unitary = position was at thermal equilibrium (F < 1/φ³).
317                    // No unitary was applied → no gradient to propagate → zero contribution.
318                    let u = &cache.unitaries[i];
319                    if u.is_empty() {
320                        return local_d_h;
321                    }
322
323                    let mut scratch_a = vec![Complex::ZERO; dim * dim];
324                    let mut scratch_b = vec![Complex::ZERO; dim * dim];
325
326                    // ∂L/∂ρ_after from populations gradient (diagonal)
327                    let mut d_rho = vec![Complex::ZERO; dim * dim];
328                    for k in 0..dim {
329                        let dp = d_values[i].get(k).copied().unwrap_or(0.0);
330                        d_rho[k * dim + k] = Complex::new(dp, 0.0);
331                    }
332
333                    // Build U† explicitly (conjugate transpose)
334                    let mut u_dag = vec![Complex::ZERO; dim * dim];
335                    for r in 0..dim {
336                        for c in 0..dim {
337                            u_dag[r * dim + c] = u[c * dim + r].conj();
338                        }
339                    }
340
341                    // THE KEY STEP: ∂L/∂ρ_before = U† · d_rho · U
342                    dreamwell_math::linalg::cgemm(&u_dag, &d_rho, &mut scratch_a, dim, dim, dim);
343                    dreamwell_math::linalg::cgemm(&scratch_a, u, &mut scratch_b, dim, dim, dim);
344
345                    // Commutator gradient: ∂L/∂H = -dt · Im([∂L/∂ρ, ρ_before])
346                    let rho_before = &cache.rho_before[i];
347                    let mut h_idx = 0;
348
349                    // Bias gradients
350                    for k in 0..dim {
351                        let mut comm_diag = 0.0f32;
352                        for j in 0..dim {
353                            let ab = scratch_b[k * dim + j].mul(rho_before[j * dim + k]);
354                            let ba = rho_before[k * dim + j].mul(scratch_b[j * dim + k]);
355                            comm_diag += (ab.sub(ba)).im;
356                        }
357                        if h_idx < local_d_h.len() {
358                            local_d_h[h_idx] += -dt * comm_diag;
359                        }
360                        h_idx += 1;
361                    }
362
363                    // Coupling gradients
364                    for p in 0..dim {
365                        for q in (p + 1)..dim {
366                            if h_idx >= local_d_h.len() {
367                                break;
368                            }
369                            let ab_pq = scratch_b[p * dim + q].mul(rho_before[q * dim + p]);
370                            let ba_pq = rho_before[p * dim + q].mul(scratch_b[q * dim + p]);
371                            let comm_pq = ab_pq.sub(ba_pq);
372                            local_d_h[h_idx] += -dt * 2.0 * comm_pq.im;
373                            h_idx += 1;
374                        }
375                    }
376
377                    local_d_h
378                })
379                .collect()
380        };
381
382        // Reduce: sum all position gradients
383        let mut d_h = vec![0.0f32; num_h];
384        for pg in &position_grads {
385            for (k, &v) in pg.iter().enumerate() {
386                d_h[k] += v;
387            }
388        }
389
390        block_grads.push(BlockGrad {
391            hamiltonian_grad: d_h,
392            value_weight_grad: d_vw,
393        });
394    }
395
396    // Reverse the block grads (we iterated in reverse)
397    block_grads.reverse();
398
399    // 5. Embedding gradient (simplified: finite difference on angles is OK for now)
400    // Full analytic embedding gradient requires differentiating through Givens rotations.
401    // For now, use zero (embedding angles are less critical than Hamiltonian params).
402    let embed_grad = vec![0.0f32; model.embedding.num_params()];
403
404    AllGradients {
405        embed_grad,
406        block_grads,
407        output_grad: d_output,
408    }
409}
410
411/// BA-60: GPU-accelerated forward + backward epoch.
412///
413/// Replaces the rayon par_iter(forward_with_cache + qug_backward) block
414/// with a version that uses GpuTrainingContext for the heavy matrix operations:
415///   - Batched expm for unitary precomputation (all positions × all blocks × all windows)
416///   - Batched evolve for ρ' = UρU† (all windows in parallel per position)
417///   - Batched adjoint for ∂L/∂ρ_before = U†·dρ·U (all positions × all blocks × all windows)
418///
419/// The causal dephasing loop, free energy computation, value projection, and
420/// commutator gradient extraction remain on CPU — they are element-wise operations
421/// that are negligible compared to the O(dim³) matrix multiplies.
422///
423/// Returns (avg_gradient_flat, avg_loss, avg_free_energy).
424pub fn forward_backward_epoch_gpu(
425    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
426    model: &QCT,
427    windows: &[(usize, usize)],
428    tokens: &[usize],
429) -> (Vec<f32>, f32, f32) {
430    use dreamwell_math::Complex;
431    let dim = model.config.dim;
432    let vocab = model.config.vocab_size;
433    let stride = dim * dim;
434    let num_windows = windows.len();
435    let num_blocks = model.blocks.len();
436    let dt = 0.090f32; // 1/φ⁵
437
438    // ═══════════════════════════════════════════════════════
439    // Phase 1: Batched EXPM — all unitaries for all blocks, all windows
440    // ═══════════════════════════════════════════════════════
441
442    // Collect all Hamiltonians: [window][block][position] → flat batch
443    let mut all_window_data: Vec<WindowForwardState> = Vec::with_capacity(num_windows);
444
445    for &(ws, we) in windows {
446        let window_tokens = &tokens[ws..we];
447        let input = &window_tokens[..window_tokens.len().saturating_sub(1)];
448        let t = input.len();
449
450        // Embed tokens into density matrices
451        let states: Vec<DensityMatrixN> = input.iter().map(|&tok| model.embedding.embed(tok)).collect();
452        let values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
453
454        // Pre-compute ALL unitaries for ALL blocks using GPU batched expm
455        let mut block_unitaries: Vec<Vec<Vec<Complex>>> = Vec::with_capacity(num_blocks);
456        for block in &model.blocks {
457            let mut all_h = vec![0.0f32; t * stride];
458            for i in 0..t {
459                let h = block.hamiltonian.build_matrix(i);
460                all_h[i * stride..(i + 1) * stride].copy_from_slice(&h);
461            }
462            let flat_unitaries = gpu.batched_expm(&all_h, dt, t);
463            let per_pos: Vec<Vec<Complex>> = (0..t)
464                .map(|i| flat_unitaries[i * stride..(i + 1) * stride].to_vec())
465                .collect();
466            block_unitaries.push(per_pos);
467        }
468
469        all_window_data.push(WindowForwardState {
470            window_tokens: window_tokens.to_vec(),
471            t,
472            states,
473            values,
474            block_unitaries,
475        });
476    }
477
478    // ═══════════════════════════════════════════════════════
479    // Phase 2: Forward pass per window (causal loop on CPU, evolve on GPU)
480    // ═══════════════════════════════════════════════════════
481
482    let f_gate = 0.236 / dim as f32;
483    const COHERENCE_WINDOW: usize = 8;
484    let eps_block = 0.236 / num_blocks.max(1) as f32;
485
486    let mut all_window_results: Vec<WindowResult> = Vec::with_capacity(num_windows);
487
488    for wdata in &mut all_window_data {
489        let t = wdata.t;
490        let mut current_states = wdata.states.clone();
491        let mut values = wdata.values.clone();
492        let mut caches: Vec<ForwardCache> = Vec::with_capacity(num_blocks);
493        let mut total_f = 0.0f32;
494
495        for (block_idx, block) in model.blocks.iter().enumerate() {
496            let mut cache = ForwardCache {
497                rho_before: Vec::with_capacity(t),
498                unitaries: Vec::with_capacity(t),
499                rho_after: Vec::with_capacity(t),
500                populations: Vec::with_capacity(t),
501                values: values.clone(),
502            };
503
504            let precomputed_unitaries = &wdata.block_unitaries[block_idx];
505
506            // Causal loop — sequential per position (CPU).
507            // The evolve (UρU†) uses CPU cgemm_par since it's sequential
508            // and can't be batched. GPU already saved time on expm above.
509            for i in 0..t {
510                let mut rho = current_states[i].clone();
511
512                // Causal dephasing (element-wise, cheap)
513                let window_start = i.saturating_sub(COHERENCE_WINDOW);
514                for j in window_start..i {
515                    let dist = i - j;
516                    let eps = block.hamiltonian.causal_dephasing(dist);
517                    rho.couple_dephase(&current_states[j], eps);
518                }
519
520                cache.rho_before.push(rho.entries.clone());
521
522                let f_before = rho.free_energy(&block.hamiltonian.bias);
523                let unitary = &precomputed_unitaries[i];
524
525                if f_before.abs() > f_gate {
526                    cache.unitaries.push(unitary.clone());
527                    rho.evolve(unitary);
528                } else {
529                    cache.unitaries.push(Vec::new());
530                }
531
532                cache.rho_after.push(rho.entries.clone());
533
534                let f = rho.free_energy(&block.hamiltonian.bias);
535                total_f += f;
536
537                let pops = rho.populations();
538                cache.populations.push(pops);
539
540                // Update current_states for subsequent causal coupling
541                current_states[i] = rho;
542            }
543
544            // Value projection (CPU — O(t × dim²), cheap)
545            let attn_output = attention::AttentionOutput {
546                populations: cache.populations.clone(),
547                free_energies: vec![0.0; t],
548                coherences: vec![0.0; t],
549            };
550            values = attention::attention_project(&attn_output, &values, dim);
551
552            // Inter-block dephasing (CPU — element-wise)
553            for state in &mut current_states {
554                state.dephase(eps_block);
555            }
556
557            caches.push(cache);
558        }
559
560        // Output projection (CPU — O(t × dim × vocab), cheap)
561        let mut logits = Vec::with_capacity(t);
562        for i in 0..t {
563            let mut token_logits = vec![0.0f32; vocab];
564            for v in 0..vocab {
565                for d in 0..dim {
566                    token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
567                }
568            }
569            logits.push(token_logits);
570        }
571
572        let avg_f = total_f / t.max(1) as f32;
573        let loss = QCT::loss_from_logits(&logits, &wdata.window_tokens, avg_f);
574
575        all_window_results.push(WindowResult {
576            logits,
577            caches,
578            loss,
579            avg_f,
580        });
581    }
582
583    // ═══════════════════════════════════════════════════════
584    // Phase 3: Backward pass (output grad on CPU, adjoint on GPU)
585    // ═══════════════════════════════════════════════════════
586
587    let num_params = model.num_params();
588    let mut total_grad = vec![0.0f32; num_params];
589    let mut total_loss = 0.0f32;
590    let mut total_f = 0.0f32;
591
592    for (w_idx, wdata) in all_window_data.iter().enumerate() {
593        let wr = &all_window_results[w_idx];
594
595        let grads = qug_backward_gpu(gpu, model, &wdata.window_tokens, &wr.logits, &wr.caches);
596        let grad_flat = grads.flatten();
597
598        total_loss += wr.loss;
599        total_f += wr.avg_f;
600        for (i, &g) in grad_flat.iter().enumerate() {
601            if i < num_params {
602                total_grad[i] += g;
603            }
604        }
605    }
606
607    // Average across windows
608    let n = num_windows as f32;
609    for g in &mut total_grad {
610        *g /= n;
611    }
612    (total_grad, total_loss / n, total_f / n)
613}
614
615/// GPU-accelerated QUG backward pass.
616///
617/// Same logic as qug_backward but uses batched GPU adjoint for the key step:
618/// ∂L/∂ρ_before = U† · d_rho · U
619fn qug_backward_gpu(
620    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
621    model: &QCT,
622    tokens: &[usize],
623    logits: &[Vec<f32>],
624    caches: &[ForwardCache],
625) -> AllGradients {
626    use dreamwell_math::Complex;
627    let dim = model.config.dim;
628    let vocab = model.config.vocab_size;
629    let stride = dim * dim;
630    let t = tokens.len().saturating_sub(1);
631    if t == 0 {
632        return AllGradients {
633            embed_grad: vec![0.0; model.embedding.num_params()],
634            block_grads: model
635                .blocks
636                .iter()
637                .map(|b| BlockGrad {
638                    hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
639                    value_weight_grad: vec![0.0; b.value_weights.len()],
640                })
641                .collect(),
642            output_grad: vec![0.0; model.output_weights.len()],
643        };
644    }
645
646    // 1. ∂L/∂logits (CPU — cheap scalar ops)
647    let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
648    for i in 0..t {
649        let target = tokens[i + 1];
650        let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
651        let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
652        let mut d_log = vec![0.0f32; vocab];
653        for v in 0..vocab {
654            let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
655            d_log[v] = (softmax_v - if v == target { 1.0 } else { 0.0 }) / t as f32;
656        }
657        d_logits.push(d_log);
658    }
659
660    // 2. ∂L/∂output_weights (CPU)
661    let mut d_output = vec![0.0f32; dim * vocab];
662    for i in 0..t {
663        let cache = caches.last().unwrap();
664        let vals = if i < cache.populations.len() {
665            &cache.populations[i]
666        } else {
667            continue;
668        };
669        for d_idx in 0..dim {
670            for v in 0..vocab {
671                d_output[d_idx * vocab + v] += vals.get(d_idx).copied().unwrap_or(0.0) * d_logits[i][v];
672            }
673        }
674    }
675
676    // 3. ∂L/∂values (CPU)
677    let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t.max(1)];
678    for i in 0..t {
679        for d_idx in 0..dim {
680            for v in 0..vocab {
681                d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
682            }
683        }
684    }
685
686    // 4. Per-block backward with GPU batched adjoint
687    let dt_val = 0.090f32;
688    let mut block_grads = Vec::with_capacity(model.blocks.len());
689
690    for (block_idx, block) in model.blocks.iter().enumerate().rev() {
691        let cache = &caches[block_idx];
692        let num_h = block.hamiltonian.num_params();
693
694        // ∂L/∂value_weights (CPU — O(t × dim²))
695        let mut d_vw = vec![0.0f32; dim * dim];
696        for i in 0..t.min(cache.populations.len()) {
697            for d_idx in 0..dim {
698                for s in 0..dim {
699                    let pop = cache.populations[i].get(s).copied().unwrap_or(0.0);
700                    let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
701                    d_vw[d_idx * dim + s] += pop * dv;
702                }
703            }
704        }
705
706        // Collect non-gated positions for GPU batched adjoint
707        let len = t.min(cache.unitaries.len());
708        let mut active_indices: Vec<usize> = Vec::new();
709        let mut u_batch: Vec<Complex> = Vec::new();
710        let mut d_rho_batch: Vec<Complex> = Vec::new();
711
712        for i in 0..len {
713            let u = &cache.unitaries[i];
714            if u.is_empty() {
715                continue;
716            }
717
718            active_indices.push(i);
719            u_batch.extend_from_slice(u);
720
721            // Build diagonal d_rho from d_values
722            let mut d_rho = vec![Complex::ZERO; stride];
723            for k in 0..dim {
724                let dp = d_values[i].get(k).copied().unwrap_or(0.0);
725                d_rho[k * dim + k] = Complex::new(dp, 0.0);
726            }
727            d_rho_batch.extend_from_slice(&d_rho);
728        }
729
730        // GPU batched adjoint: ∂L/∂ρ_before = U† · d_rho · U
731        let adjoint_results = if !active_indices.is_empty() {
732            gpu.batched_adjoint(&u_batch, &d_rho_batch, active_indices.len())
733        } else {
734            Vec::new()
735        };
736
737        // Extract commutator gradients on CPU (element-wise, cheap)
738        let mut d_h = vec![0.0f32; num_h];
739        for (batch_idx, &pos_idx) in active_indices.iter().enumerate() {
740            let base = batch_idx * stride;
741            let scratch_b = &adjoint_results[base..base + stride];
742            let rho_before = &cache.rho_before[pos_idx];
743            let mut h_idx = 0;
744
745            // Bias gradients
746            for k in 0..dim {
747                let mut comm_diag = 0.0f32;
748                for j in 0..dim {
749                    let ab = scratch_b[k * dim + j].mul(rho_before[j * dim + k]);
750                    let ba = rho_before[k * dim + j].mul(scratch_b[j * dim + k]);
751                    comm_diag += (ab.sub(ba)).im;
752                }
753                if h_idx < d_h.len() {
754                    d_h[h_idx] += -dt_val * comm_diag;
755                }
756                h_idx += 1;
757            }
758
759            // Coupling gradients
760            for p in 0..dim {
761                for q in (p + 1)..dim {
762                    if h_idx >= d_h.len() {
763                        break;
764                    }
765                    let ab_pq = scratch_b[p * dim + q].mul(rho_before[q * dim + p]);
766                    let ba_pq = rho_before[p * dim + q].mul(scratch_b[q * dim + p]);
767                    let comm_pq = ab_pq.sub(ba_pq);
768                    d_h[h_idx] += -dt_val * 2.0 * comm_pq.im;
769                    h_idx += 1;
770                }
771            }
772        }
773
774        block_grads.push(BlockGrad {
775            hamiltonian_grad: d_h,
776            value_weight_grad: d_vw,
777        });
778    }
779
780    block_grads.reverse();
781    let embed_grad = vec![0.0f32; model.embedding.num_params()];
782
783    AllGradients {
784        embed_grad,
785        block_grads,
786        output_grad: d_output,
787    }
788}
789
/// Intermediate state for GPU-accelerated forward pass per window.
///
/// NOTE(review): the field descriptions below are inferred from names and types;
/// confirm them against the code that constructs this struct.
struct WindowForwardState {
    /// Token ids of the window (presumably the model inputs — TODO confirm).
    window_tokens: Vec<usize>,
    /// Presumably the number of positions in the window, i.e. `window_tokens.len()`.
    t: usize,
    /// Density matrices, assumed one per position — verify against the forward pass.
    states: Vec<DensityMatrixN>,
    /// Per-position value vectors, assumed analogous to `ForwardCache::values`.
    values: Vec<Vec<f32>>,
    /// Assumed layout `[block][position][dim²]`: one `gpu_precompute_unitaries`
    /// result (itself a `Vec<Vec<Complex>>`) per QCT block — TODO confirm.
    block_unitaries: Vec<Vec<Vec<dreamwell_math::Complex>>>,
}
798
/// Result of forward pass for one window.
///
/// NOTE(review): field meanings below are inferred from names/types; confirm against
/// the code that builds this struct.
struct WindowResult {
    /// Output logits, presumably one row per input position (TODO confirm).
    logits: Vec<Vec<f32>>,
    /// Forward caches for the backward pass; `ForwardCache` stores state for one QCT
    /// block, so presumably there is one entry per block.
    caches: Vec<ForwardCache>,
    /// Training loss for the window (assumed to match `QCT::loss_from_logits`).
    loss: f32,
    /// Average free energy; presumably the same quantity `train_qug` reports as
    /// `EpochMetrics::free_energy`.
    avg_f: f32,
}
806
807/// GPU-accelerated unitary precomputation for one block across all positions.
808///
809/// Batches all 64 (or T) `hamiltonian_unitary` calls into a single GPU dispatch.
810/// Returns T unitary matrices as a contiguous Vec<Complex>.
811///
812/// This replaces the rayon par_iter precomputation at lines 113-119 of the forward pass
813/// with a single GPU batched expm dispatch — the single largest speedup for training.
814pub fn gpu_precompute_unitaries(
815    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
816    block: &crate::transformer::QCTBlock,
817    dim: usize,
818    t: usize,
819    dt: f32,
820) -> Vec<Vec<dreamwell_math::Complex>> {
821    let n2 = dim * dim;
822
823    // Build all Hamiltonians on CPU (cheap: just fills a real symmetric matrix from params)
824    let mut all_h = vec![0.0f32; t * n2];
825    for i in 0..t {
826        let h = block.hamiltonian.build_matrix(i);
827        all_h[i * n2..(i + 1) * n2].copy_from_slice(&h);
828    }
829
830    // Single batched GPU expm dispatch
831    let flat = gpu.batched_expm(&all_h, dt, t);
832
833    // Split into per-position unitaries
834    let stride = dim * dim;
835    (0..t).map(|i| flat[i * stride..(i + 1) * stride].to_vec()).collect()
836}
837
838/// GPU-accelerated batched evolve for positions where the thermodynamic gate passes.
839///
840/// Takes a batch of (unitary, density_matrix) pairs and computes ρ' = U·ρ·U† on GPU.
841/// Returns evolved density matrices.
842pub fn gpu_batch_evolve(
843    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
844    unitaries_flat: &[dreamwell_math::Complex],
845    rhos_flat: &[dreamwell_math::Complex],
846    batch_count: usize,
847) -> Vec<dreamwell_math::Complex> {
848    gpu.batched_evolve(unitaries_flat, rhos_flat, batch_count)
849}
850
851/// GPU-accelerated batched adjoint for QUG backward pass.
852///
853/// Computes ∂L/∂ρ_before = U† · ∂L/∂ρ_after · U for all positions in one batch.
854pub fn gpu_batch_adjoint(
855    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
856    unitaries_flat: &[dreamwell_math::Complex],
857    d_rho_flat: &[dreamwell_math::Complex],
858    batch_count: usize,
859) -> Vec<dreamwell_math::Complex> {
860    gpu.batched_adjoint(unitaries_flat, d_rho_flat, batch_count)
861}
862
863/// Train using QUG (Quantum Unitarity Gradient) — the fast path.
864/// Cost: ~3× forward per epoch (vs 2P× for parameter shift).
865pub fn train_qug(model: &mut QCT, tokens: &[usize], config: &crate::train::TrainConfig) -> Vec<EpochMetrics> {
866    let mut metrics = Vec::new();
867
868    for epoch in 0..config.num_epochs {
869        let start = std::time::Instant::now();
870        let lr = crate::train::learning_rate_pub(config, epoch);
871
872        // Select window
873        let max_start = tokens.len().saturating_sub(config.context_length + 1);
874        let window_start = if max_start > 0 { epoch % max_start } else { 0 };
875        let window_end = (window_start + config.context_length + 1).min(tokens.len());
876        let window = &tokens[window_start..window_end];
877
878        // Forward with cache
879        let (logits, avg_f, caches) = forward_with_cache(model, &window[..window.len() - 1]);
880
881        // Compute loss from cached logits — no redundant forward pass
882        let loss = QCT::loss_from_logits(&logits, window, avg_f);
883
884        // QUG backward — ONE pass, ~3× forward cost
885        let grads = qug_backward(model, window, &logits, &caches);
886        let grad_flat = grads.flatten();
887
888        // Gradient norm
889        let grad_norm: f32 = grad_flat.iter().map(|g| g * g).sum::<f32>().sqrt();
890
891        // Clip
892        let scale = if grad_norm > config.grad_clip && grad_norm > 0.0 {
893            config.grad_clip / grad_norm
894        } else {
895            1.0
896        };
897
898        // Update parameters in-place (zero allocation)
899        model.apply_gradient_update(&grad_flat, lr, scale);
900
901        let elapsed = start.elapsed().as_secs_f32() * 1000.0;
902
903        if epoch % config.log_interval == 0 || epoch == config.num_epochs - 1 {
904            let m = EpochMetrics {
905                epoch,
906                loss,
907                free_energy: avg_f,
908                grad_norm,
909                elapsed_ms: elapsed,
910                learning_rate: lr,
911                params_trained: grad_flat.len(),
912            };
913            log::info!(
914                "QUG Epoch {:4}: loss={:.4} F={:.4} |∇|={:.6} lr={:.5} ({:.1}ms)",
915                m.epoch,
916                m.loss,
917                m.free_energy,
918                m.grad_norm,
919                m.learning_rate,
920                m.elapsed_ms
921            );
922            metrics.push(m);
923        }
924    }
925
926    metrics
927}
928
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::QCTConfig;

    /// Tiny single-block model shared by every test: vocab 10, dim 4, fixed seed 42.
    fn tiny_model() -> QCT {
        QCT::new(QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        })
    }

    #[test]
    fn forward_cache_matches_forward() {
        let model = tiny_model();
        let seq: Vec<usize> = (0..6).collect();

        let (plain_logits, plain_f) = model.forward(&seq);
        let (cached_logits, cached_f, caches) = forward_with_cache(&model, &seq);

        assert_eq!(plain_logits.len(), cached_logits.len());
        // Loose tolerance: only a coarse agreement on free energy is required here.
        let gap = (plain_f - cached_f).abs();
        assert!(
            gap < 0.5,
            "free energy mismatch: {} vs {}",
            plain_f,
            cached_f
        );
        assert!(!caches.is_empty());
    }

    #[test]
    fn qug_gradient_nonzero() {
        let model = tiny_model();
        let seq: Vec<usize> = (0..6).collect();
        // Inputs exclude the final token, which only serves as a target.
        let inputs = &seq[..seq.len() - 1];

        let (logits, _, caches) = forward_with_cache(&model, inputs);
        let grad = qug_backward(&model, &seq, &logits, &caches).flatten();

        let norm = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm > 0.0, "QUG gradient should be nonzero");
        assert!(norm.is_finite(), "QUG gradient should be finite");
    }

    #[test]
    fn qug_training_runs() {
        let mut model = tiny_model();
        let corpus: Vec<usize> = (0..10).collect();

        let train_config = crate::train::TrainConfig {
            learning_rate: 0.01,
            num_epochs: 3,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };
        let metrics = train_qug(&mut model, &corpus, &train_config);

        assert_eq!(metrics.len(), 3);
        let first = &metrics[0];
        assert!(first.loss.is_finite());
        assert!(first.grad_norm > 0.0);
        eprintln!(
            "QUG training: {:.1}ms/epoch (vs ~7400ms for PSR)",
            first.elapsed_ms
        );
    }
}