1use crate::adjoint::{AllGradients, BlockGrad};
18use crate::complex::Complex;
19use crate::train::EpochMetrics;
20use crate::transformer::QCT;
21
/// Golden ratio φ ≈ 1.618, truncated to `f32`; scales the coherence term in the
/// free-energy temperature.
const PHI: f32 = 1.618033988;
/// Reciprocal of [`PHI`] (≈ 0.618, which also equals φ − 1).
// NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
const PHI_INV: f32 = 0.618033988;
24
/// Intermediate results recorded by `fock_forward` and consumed by `fock_backward`.
pub struct FockCache {
    /// One entry per model block, in forward order.
    pub blocks: Vec<FockBlockCache>,
    /// Copy of the last block's per-position populations (empty when the model has no blocks).
    pub final_populations: Vec<Vec<f32>>,
    /// Per-position value vectors after the last block; these are what the output projection
    /// multiplies to produce the logits.
    pub values: Vec<Vec<f32>>,
}
34
/// Per-block snapshot captured by `fock_forward`.
pub struct FockBlockCache {
    /// Per-position amplitudes after coupling to earlier positions but before the
    /// Hamiltonian evolution step.
    pub amplitudes_before: Vec<Vec<Complex>>,
    /// Per-position amplitudes after evolution and renormalisation.
    pub amplitudes_after: Vec<Vec<Complex>>,
    /// Left empty by `fock_forward`, which does not diagonalise the Hamiltonian.
    pub eigenvectors: Vec<f32>,
    /// Holds the diagonal entries of the block's Hamiltonian matrix as filled by
    /// `fock_forward` (not the result of an eigen-decomposition).
    pub eigenvalues: Vec<f32>,
    /// `Complex::exp_i(-e * dt)` for each diagonal entry `e`.
    pub phases: Vec<Complex>,
    /// Per-position `norm_sq()` of each component of `amplitudes_after`.
    pub populations: Vec<Vec<f32>>,
    /// Copy of the value vectors as they were on entry to this block.
    pub values_in: Vec<Vec<f32>>,
}
52
53pub fn fock_forward(model: &QCT, tokens: &[usize]) -> (Vec<Vec<f32>>, f32, FockCache) {
67 let dim = model.config.dim;
68 let t = tokens.len();
69 let dt = 0.090f32; let mut amplitudes: Vec<Vec<Complex>> = tokens.iter().map(|&tok| model.embedding.embed_amplitude(tok)).collect();
73
74 let mut values: Vec<Vec<f32>> = amplitudes.iter().map(|psi| populations_from_amplitudes(psi)).collect();
75
76 let mut block_caches = Vec::with_capacity(model.blocks.len());
77 let mut total_f = 0.0f32;
78
79 let eps_block = 0.236 / model.blocks.len().max(1) as f32;
80 const COHERENCE_WINDOW: usize = 8;
81
82 for block in &model.blocks {
83 let h_matrix = block.hamiltonian.build_matrix(0);
88 let h_diag: Vec<f32> = (0..dim).map(|k| h_matrix[k * dim + k]).collect();
89 let diag_phases: Vec<Complex> = h_diag.iter().map(|&e| Complex::exp_i(-e * dt)).collect();
90
91 let mut cache = FockBlockCache {
92 amplitudes_before: Vec::with_capacity(t),
93 amplitudes_after: Vec::with_capacity(t),
94 eigenvectors: Vec::new(), eigenvalues: h_diag.clone(),
96 phases: diag_phases.clone(),
97 populations: Vec::with_capacity(t),
98 values_in: values.clone(),
99 };
100
101 for i in 0..t {
103 let mut psi = amplitudes[i].clone();
104
105 let window_start = i.saturating_sub(COHERENCE_WINDOW);
107 for j in window_start..i {
108 let dist = i - j;
109 let eps = block.hamiltonian.causal_dephasing(dist);
110 dephase_amplitude_coupled(&mut psi, &litudes[j], eps);
111 }
112
113 cache.amplitudes_before.push(psi.clone());
114
115 let mut psi_evolved = psi.clone();
120 for k in 0..dim {
121 psi_evolved[k] = psi_evolved[k].mul(diag_phases[k]);
122 }
123 let psi_pre = psi_evolved.clone();
126 for ii in 0..dim {
127 let mut coupling_sum = Complex::ZERO;
128 for jj in 0..dim {
129 if ii == jj {
130 continue;
131 }
132 let h_ij = h_matrix[ii * dim + jj];
133 if h_ij.abs() < 1e-10 {
134 continue;
135 }
136 coupling_sum =
138 coupling_sum.add(Complex::new(h_ij * dt * psi_pre[jj].im, -h_ij * dt * psi_pre[jj].re));
139 }
140 psi_evolved[ii] = psi_evolved[ii].add(coupling_sum);
141 }
142 let norm_sq: f32 = psi_evolved.iter().map(|c| c.norm_sq()).sum();
144 if norm_sq > 1e-10 {
145 let inv = 1.0 / norm_sq.sqrt();
146 for c in &mut psi_evolved {
147 *c = c.scale(inv);
148 }
149 }
150
151 cache.amplitudes_after.push(psi_evolved.clone());
152
153 let pops = populations_from_amplitudes(&psi_evolved);
155 let f = free_energy_from_amplitudes(&psi_evolved, &block.hamiltonian.bias);
156 total_f += f;
157
158 cache.populations.push(pops);
159 amplitudes[i] = psi_evolved;
160 }
161
162 let attn_output = crate::attention::AttentionOutput {
164 populations: cache.populations.clone(),
165 free_energies: vec![0.0; t],
166 coherences: vec![0.0; t],
167 };
168 values = crate::attention::attention_project(&attn_output, &values, dim);
169
170 for psi in &mut amplitudes {
172 dephase_amplitude(psi, eps_block);
173 }
174
175 block_caches.push(cache);
176 }
177
178 let vocab = model.config.vocab_size;
180 let mut logits = Vec::with_capacity(t);
181 for i in 0..t {
182 let mut token_logits = vec![0.0f32; vocab];
183 for v in 0..vocab {
184 for d in 0..dim {
185 token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
186 }
187 }
188 logits.push(token_logits);
189 }
190
191 let avg_f = total_f / t.max(1) as f32;
192 let final_pops = block_caches.last().map(|c| c.populations.clone()).unwrap_or_default();
193
194 (
195 logits,
196 avg_f,
197 FockCache {
198 blocks: block_caches,
199 final_populations: final_pops,
200 values,
201 },
202 )
203}
204
205pub fn fock_backward(model: &QCT, tokens: &[usize], logits: &[Vec<f32>], cache: &FockCache) -> AllGradients {
211 let dim = model.config.dim;
212 let vocab = model.config.vocab_size;
213 let t = tokens.len().saturating_sub(1);
214 if t == 0 {
215 return AllGradients {
216 embed_grad: vec![0.0; model.embedding.num_params()],
217 block_grads: model
218 .blocks
219 .iter()
220 .map(|b| BlockGrad {
221 hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
222 value_weight_grad: vec![0.0; b.value_weights.len()],
223 })
224 .collect(),
225 output_grad: vec![0.0; model.output_weights.len()],
226 };
227 }
228
229 let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
231 for i in 0..t {
232 let target = tokens[i + 1];
233 let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
234 let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
235 let mut d_log = vec![0.0f32; vocab];
236 for v in 0..vocab {
237 let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
238 d_log[v] = (softmax_v - if v == target { 1.0 } else { 0.0 }) / t as f32;
239 }
240 d_logits.push(d_log);
241 }
242
243 let mut d_output = vec![0.0f32; dim * vocab];
245 if let Some(last_cache) = cache.blocks.last() {
246 for i in 0..t.min(last_cache.populations.len()) {
247 let pops = &last_cache.populations[i];
248 for d_idx in 0..dim {
249 for v in 0..vocab {
250 d_output[d_idx * vocab + v] += pops.get(d_idx).copied().unwrap_or(0.0) * d_logits[i][v];
251 }
252 }
253 }
254 }
255
256 let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t.max(1)];
258 for i in 0..t {
259 for d_idx in 0..dim {
260 for v in 0..vocab {
261 d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
262 }
263 }
264 }
265
266 let dt = 0.090f32;
268 let mut block_grads = Vec::with_capacity(model.blocks.len());
269
270 for (block_idx, block) in model.blocks.iter().enumerate().rev() {
271 let bc = &cache.blocks[block_idx];
272 let num_h = block.hamiltonian.num_params();
273
274 let mut d_vw = vec![0.0f32; dim * dim];
276 for i in 0..t.min(bc.populations.len()) {
277 for d_idx in 0..dim {
278 for s in 0..dim {
279 let pop = bc.populations[i].get(s).copied().unwrap_or(0.0);
280 let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
281 d_vw[d_idx * dim + s] += pop * dv;
282 }
283 }
284 }
285
286 let len = t.min(bc.amplitudes_before.len());
294 let mut d_h = vec![0.0f32; num_h];
295
296 use rayon::prelude::*;
298 let position_grads: Vec<Vec<f32>> = (0..len)
299 .into_par_iter()
300 .map(|i| {
301 let mut local_d_h = vec![0.0f32; num_h];
302 let psi = &bc.amplitudes_before[i];
303
304 let d_pop: Vec<f32> = (0..dim).map(|k| d_values[i].get(k).copied().unwrap_or(0.0)).collect();
307
308 let mut h_idx = 0;
312 for k in 0..dim {
313 let pop_k = psi[k].norm_sq();
318 local_d_h[h_idx] = -dt * d_pop[k] * pop_k;
319 h_idx += 1;
320 }
321
322 for p in 0..dim {
324 for q in (p + 1)..dim {
325 if h_idx >= local_d_h.len() {
326 break;
327 }
328 let psi_p = psi[p];
331 let psi_q = psi[q];
332 let cross = psi_p.mul(psi_q.conj());
333 local_d_h[h_idx] = -dt * 2.0 * (d_pop[p] + d_pop[q]) * cross.im;
334 h_idx += 1;
335 }
336 }
337
338 local_d_h
339 })
340 .collect();
341
342 for pg in &position_grads {
344 for (k, &v) in pg.iter().enumerate() {
345 d_h[k] += v;
346 }
347 }
348
349 block_grads.push(BlockGrad {
350 hamiltonian_grad: d_h,
351 value_weight_grad: d_vw,
352 });
353 }
354
355 block_grads.reverse();
356 let embed_grad = vec![0.0f32; model.embedding.num_params()];
357
358 AllGradients {
359 embed_grad,
360 block_grads,
361 output_grad: d_output,
362 }
363}
364
365fn populations_from_amplitudes(psi: &[Complex]) -> Vec<f32> {
369 psi.iter().map(|c| c.norm_sq()).collect()
370}
371
372fn free_energy_from_amplitudes(psi: &[Complex], bias: &[f32]) -> f32 {
375 let dim = psi.len();
376 let pops: Vec<f32> = psi.iter().map(|c| c.norm_sq()).collect();
377
378 let expected_h: f32 = pops.iter().zip(bias.iter()).map(|(p, e)| p * e).sum();
380
381 let mut coh = 0.0f32;
383 for i in 0..dim {
384 for j in (i + 1)..dim {
385 coh += psi[i].mul(psi[j].conj()).norm();
386 }
387 }
388
389 let temperature = 1.0 / (1.0 + PHI * coh);
391
392 let mut entropy = 0.0f32;
394 for &p in &pops {
395 if p > 1e-10 {
396 entropy -= p * p.ln();
397 }
398 }
399
400 expected_h - temperature * entropy
401}
402
403fn dephase_amplitude(psi: &mut [Complex], epsilon: f32) {
406 let retain_sqrt = (1.0 - epsilon).max(0.0).sqrt();
407 for c in psi.iter_mut() {
408 *c = c.scale(retain_sqrt);
409 }
410 let norm_sq: f32 = psi.iter().map(|c| c.norm_sq()).sum();
412 if norm_sq > 1e-10 {
413 let inv_norm = 1.0 / norm_sq.sqrt();
414 for c in psi.iter_mut() {
415 *c = c.scale(inv_norm);
416 }
417 }
418}
419
420fn dephase_amplitude_coupled(psi: &mut [Complex], other: &[Complex], strength: f32) {
423 let dim = other.len();
425 let mut other_coh = 0.0f32;
426 for i in 0..dim {
427 for j in (i + 1)..dim {
428 other_coh += other[i].mul(other[j].conj()).norm();
429 }
430 }
431 other_coh = other_coh.min(1.0);
432 let retain = (1.0 - strength * (1.0 - other_coh)).max(0.0);
433 let retain_sqrt = retain.sqrt();
434 for c in psi.iter_mut() {
435 *c = c.scale(retain_sqrt);
436 }
437 let norm_sq: f32 = psi.iter().map(|c| c.norm_sq()).sum();
439 if norm_sq > 1e-10 {
440 let inv_norm = 1.0 / norm_sq.sqrt();
441 for c in psi.iter_mut() {
442 *c = c.scale(inv_norm);
443 }
444 }
445}
446
447fn matvec_real(m: &[f32], x: &[Complex], dim: usize) -> Vec<Complex> {
450 let mut y = vec![Complex::ZERO; dim];
451 for i in 0..dim {
452 let mut sum = Complex::ZERO;
453 for j in 0..dim {
454 let mij = m[i * dim + j];
455 sum = sum.add(x[j].scale(mij));
456 }
457 y[i] = sum;
458 }
459 y
460}
461
462fn matvec_transpose_real(m: &[f32], x: &[Complex], dim: usize) -> Vec<Complex> {
465 let mut y = vec![Complex::ZERO; dim];
466 for j in 0..dim {
467 for i in 0..dim {
468 let mij = m[i * dim + j]; y[j] = y[j].add(x[i].scale(mij));
470 }
471 }
472 y
473}
474
475fn diagonalize_real_symmetric(h: &[f32], eigenvalues: &mut [f32], eigenvectors: &mut [f32], dim: usize) {
478 let mut work = vec![Complex::ZERO; dim * dim];
480 for i in 0..dim * dim {
481 work[i] = Complex::new(h[i], 0.0);
482 }
483
484 dreamwell_math::eigen::eigenvalues_hermitian(&mut work, eigenvalues, dim, 50, 1e-6);
485
486 for i in 0..dim * dim {
489 eigenvectors[i] = work[i].re;
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::{QCTConfig, QCT};

    /// Number of modes used by every test model.
    const DIM: usize = 5;

    /// Builds the two-block, seed-42 model shared by the tests; only the vocabulary size varies.
    fn test_model(vocab_size: usize) -> QCT {
        QCT::new(QCTConfig {
            vocab_size,
            dim: DIM,
            num_blocks: 2,
            seed: 42,
        })
    }

    /// The amplitude embedding must yield the same populations as the density-matrix embedding.
    #[test]
    fn embed_amplitude_matches_populations() {
        let model = test_model(65);

        for token in 0..10 {
            let rho = model.embedding.embed(token);
            let psi = model.embedding.embed_amplitude(token);
            let pops_rho = rho.populations();
            let pops_psi = populations_from_amplitudes(&psi);

            for k in 0..DIM {
                assert!(
                    (pops_rho[k] - pops_psi[k]).abs() < 1e-5,
                    "token {token} mode {k}: rho={} psi={}",
                    pops_rho[k],
                    pops_psi[k]
                );
            }
        }
    }

    /// One logits row per input token, each of vocabulary width, and everything finite.
    #[test]
    fn fock_forward_produces_valid_logits() {
        let model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5];

        let (logits, avg_f, _cache) = fock_forward(&model, &tokens);

        assert_eq!(logits.len(), tokens.len());
        for l in &logits {
            assert_eq!(l.len(), 10);
            for &v in l {
                assert!(v.is_finite(), "logit not finite: {v}");
            }
        }
        assert!(avg_f.is_finite(), "free energy not finite: {avg_f}");
    }

    /// The loss computed from the forward pass is finite and positive.
    #[test]
    fn fock_forward_loss_is_finite() {
        let model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];

        // The last token is only ever a target, so it is not fed to the forward pass.
        let (logits, avg_f, _cache) = fock_forward(&model, &tokens[..7]);
        let loss = QCT::loss_from_logits(&logits, &tokens, avg_f);

        assert!(loss.is_finite(), "loss not finite: {loss}");
        assert!(loss > 0.0, "loss should be positive: {loss}");
    }

    /// The backward pass yields a gradient that is not identically zero.
    #[test]
    fn fock_backward_produces_gradients() {
        let model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];

        let (logits, _avg_f, cache) = fock_forward(&model, &tokens[..7]);
        let grads = fock_backward(&model, &tokens, &logits, &cache);

        let grad_flat = grads.flatten();
        let norm: f32 = grad_flat.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm > 1e-6, "gradient norm should be nonzero: {norm}");
    }

    /// Ten gradient steps must not make the loss noticeably worse.
    #[test]
    fn fock_training_reduces_loss() {
        let mut model = test_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5];

        let (logits0, f0, _) = fock_forward(&model, &tokens[..15]);
        let loss0 = QCT::loss_from_logits(&logits0, &tokens, f0);

        for _ in 0..10 {
            let (logits, _avg_f, cache) = fock_forward(&model, &tokens[..15]);
            let grads = fock_backward(&model, &tokens, &logits, &cache);
            let grad_flat = grads.flatten();
            model.apply_gradient_update(&grad_flat, 0.03, 1.0);
        }

        let (logits1, f1, _) = fock_forward(&model, &tokens[..15]);
        let loss1 = QCT::loss_from_logits(&logits1, &tokens, f1);

        // A small tolerance is allowed; the check is for "no blow-up", not strict descent.
        assert!(
            loss1 < loss0 + 0.1,
            "loss should decrease or stay flat: {loss0} → {loss1}"
        );
    }
}