use crate::adjoint::{AllGradients, BlockGrad};
use crate::complex::Complex;
use crate::train::EpochMetrics;
use crate::transformer::QCT;
/// Approximately the golden ratio. Weights the coherence term of the
/// effective temperature in `free_energy_from_amplitudes`.
const PHI: f32 = 1.618033988;
/// Approximately `1 / PHI`.
// NOTE(review): no use of this constant is visible in this file — confirm it is still needed.
const PHI_INV: f32 = 0.618033988;
/// Intermediate results recorded by `fock_forward` and consumed by `fock_backward`.
pub struct FockCache {
/// One entry per model block, in the order the blocks were applied.
pub blocks: Vec<FockBlockCache>,
/// Copy of the last block's `populations`; empty when the model has no blocks.
pub final_populations: Vec<Vec<f32>>,
/// Per-position value vectors after the last block; the logits are computed from these.
pub values: Vec<Vec<f32>>,
}
/// Per-block state captured by `fock_forward`. Every `Vec<Vec<_>>` field is
/// indexed by token position first.
pub struct FockBlockCache {
/// Amplitudes of each position after the windowed coupled dephasing and
/// before the Hamiltonian step.
pub amplitudes_before: Vec<Vec<Complex>>,
/// Amplitudes of each position after the Hamiltonian step and renormalisation.
pub amplitudes_after: Vec<Vec<Complex>>,
/// Left empty by `fock_forward`.
pub eigenvectors: Vec<f32>,
/// Filled by `fock_forward` with the diagonal of the block's Hamiltonian matrix.
pub eigenvalues: Vec<f32>,
/// Phase factors `Complex::exp_i(-h_kk * dt)`, one per diagonal entry.
pub phases: Vec<Complex>,
/// Populations computed from `amplitudes_after` for each position.
pub populations: Vec<Vec<f32>>,
/// Copy of the value vectors as they entered this block.
pub values_in: Vec<Vec<f32>>,
}
/// Runs the amplitude-based forward pass over `tokens`.
///
/// Every token is embedded as a complex amplitude vector and its populations
/// seed the value vectors. Each block then visits the positions in order and,
/// for each one:
///
/// 1. dephases the amplitude against up to `COHERENCE_WINDOW` preceding
///    positions (which have already been updated by this block),
/// 2. multiplies component `k` by `Complex::exp_i(-h_kk * dt)`,
/// 3. adds, for every sufficiently large off-diagonal `h_ij`, the term
///    `Complex::new(h_ij * dt * psi_j.im, -h_ij * dt * psi_j.re)` built from
///    the amplitudes as they were after step 2,
/// 4. renormalises the result and stores it back as the position's amplitude.
///
/// After all positions are processed the block's populations are passed to
/// `attention_project` to produce the next value vectors, and every amplitude
/// goes through `dephase_amplitude`.
///
/// # Returns
///
/// `(logits, avg_f, cache)` where `logits` holds one row of `vocab_size`
/// entries per token, `avg_f` is the free energy summed over all blocks and
/// positions divided by the number of tokens (by 1 for empty input), and
/// `cache` records the intermediates of every block.
pub fn fock_forward(model: &QCT, tokens: &[usize]) -> (Vec<Vec<f32>>, f32, FockCache) {
    const COHERENCE_WINDOW: usize = 8;

    let dim = model.config.dim;
    let seq_len = tokens.len();
    let dt = 0.090f32;
    let eps_block = 0.236 / model.blocks.len().max(1) as f32;

    let mut amplitudes: Vec<Vec<Complex>> = Vec::with_capacity(seq_len);
    for &tok in tokens {
        amplitudes.push(model.embedding.embed_amplitude(tok));
    }
    let mut values: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
    for psi in &amplitudes {
        values.push(populations_from_amplitudes(psi));
    }

    let mut block_caches = Vec::with_capacity(model.blocks.len());
    let mut total_f = 0.0f32;

    for block in &model.blocks {
        let h_matrix = block.hamiltonian.build_matrix(0);

        // Diagonal entries and the phase factor each one induces over `dt`.
        let mut h_diag: Vec<f32> = Vec::with_capacity(dim);
        let mut diag_phases: Vec<Complex> = Vec::with_capacity(dim);
        for k in 0..dim {
            let energy = h_matrix[k * dim + k];
            h_diag.push(energy);
            diag_phases.push(Complex::exp_i(-energy * dt));
        }

        let mut cache = FockBlockCache {
            amplitudes_before: Vec::with_capacity(seq_len),
            amplitudes_after: Vec::with_capacity(seq_len),
            eigenvectors: Vec::new(),
            eigenvalues: h_diag,
            phases: diag_phases.clone(),
            populations: Vec::with_capacity(seq_len),
            values_in: values.clone(),
        };

        for pos in 0..seq_len {
            // Couple to the preceding positions inside the coherence window.
            let mut psi = amplitudes[pos].clone();
            for prev in pos.saturating_sub(COHERENCE_WINDOW)..pos {
                let eps = block.hamiltonian.causal_dephasing(pos - prev);
                dephase_amplitude_coupled(&mut psi, &amplitudes[prev], eps);
            }

            let mut evolved = psi.clone();
            cache.amplitudes_before.push(psi);

            // Diagonal part: one phase factor per component.
            for k in 0..dim {
                evolved[k] = evolved[k].mul(diag_phases[k]);
            }

            // Off-diagonal part, evaluated on a snapshot so that every row
            // sees the same input amplitudes.
            let snapshot = evolved.clone();
            for row in 0..dim {
                let mut coupling = Complex::ZERO;
                for col in 0..dim {
                    if col == row {
                        continue;
                    }
                    let h_ij = h_matrix[row * dim + col];
                    if h_ij.abs() < 1e-10 {
                        continue;
                    }
                    let src = snapshot[col];
                    coupling = coupling.add(Complex::new(h_ij * dt * src.im, -h_ij * dt * src.re));
                }
                evolved[row] = evolved[row].add(coupling);
            }

            // Renormalise unless the vector has (numerically) vanished.
            let norm_sq: f32 = evolved.iter().map(|c| c.norm_sq()).sum();
            if norm_sq > 1e-10 {
                let inv = 1.0 / norm_sq.sqrt();
                for amp in &mut evolved {
                    *amp = amp.scale(inv);
                }
            }

            cache.amplitudes_after.push(evolved.clone());
            let pops = populations_from_amplitudes(&evolved);
            total_f += free_energy_from_amplitudes(&evolved, &block.hamiltonian.bias);
            cache.populations.push(pops);
            // Later positions of this block couple to the updated amplitude.
            amplitudes[pos] = evolved;
        }

        let attn_output = crate::attention::AttentionOutput {
            populations: cache.populations.clone(),
            free_energies: vec![0.0; seq_len],
            coherences: vec![0.0; seq_len],
        };
        values = crate::attention::attention_project(&attn_output, &values, dim);

        amplitudes
            .iter_mut()
            .for_each(|psi| dephase_amplitude(psi, eps_block));
        block_caches.push(cache);
    }

    // Output projection: logits[i][v] = sum_d values[i][d] * W[d * vocab + v].
    let vocab = model.config.vocab_size;
    let logits: Vec<Vec<f32>> = (0..seq_len)
        .map(|pos| {
            let mut row = vec![0.0f32; vocab];
            for v in 0..vocab {
                for d in 0..dim {
                    row[v] += values[pos][d] * model.output_weights[d * vocab + v];
                }
            }
            row
        })
        .collect();

    let avg_f = total_f / seq_len.max(1) as f32;
    let final_populations = match block_caches.last() {
        Some(last) => last.populations.clone(),
        None => Vec::new(),
    };

    let cache = FockCache {
        blocks: block_caches,
        final_populations,
        values,
    };
    (logits, avg_f, cache)
}
/// Computes parameter gradients of the next-token cross-entropy for a
/// sequence previously run through `fock_forward`.
///
/// Position `i` is trained to predict `tokens[i + 1]`, so only the first
/// `tokens.len() - 1` rows of `logits` are used and the loss gradient is
/// averaged over that many positions.
///
/// What is returned:
///
/// * `output_grad` — gradient with respect to `model.output_weights`. It is
///   accumulated from `cache.values`, the vectors the forward pass multiplied
///   by the output weights to obtain the logits. If a position has no row in
///   `cache.values`, the last block's populations for that position are used
///   instead.
/// * `block_grads` — one entry per block, in model order. Both the
///   value-weight and the Hamiltonian gradients are driven by the gradient
///   with respect to the final value vectors; that same gradient is reused
///   for every block.
/// * `embed_grad` — always zero; this pass does not compute embedding gradients.
///
/// When `tokens` has fewer than two entries all gradients are zero.
///
/// # Panics
///
/// Panics if `logits` has fewer than `tokens.len() - 1` rows or if
/// `cache.blocks` has fewer entries than `model.blocks`.
pub fn fock_backward(model: &QCT, tokens: &[usize], logits: &[Vec<f32>], cache: &FockCache) -> AllGradients {
    use rayon::prelude::*;

    let dim = model.config.dim;
    let vocab = model.config.vocab_size;
    // The final token has no successor to predict.
    let t = tokens.len().saturating_sub(1);
    if t == 0 {
        return AllGradients {
            embed_grad: vec![0.0; model.embedding.num_params()],
            block_grads: model
                .blocks
                .iter()
                .map(|b| BlockGrad {
                    hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
                    value_weight_grad: vec![0.0; b.value_weights.len()],
                })
                .collect(),
            output_grad: vec![0.0; model.output_weights.len()],
        };
    }

    // d(loss)/d(logits): softmax minus one-hot target, averaged over positions.
    // The maximum is subtracted before exponentiating for numerical stability.
    let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
    for i in 0..t {
        let target = tokens[i + 1];
        let max_l = logits[i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = logits[i].iter().map(|&l| (l - max_l).exp()).sum();
        let mut d_log = vec![0.0f32; vocab];
        for v in 0..vocab {
            let softmax_v = (logits[i][v] - max_l).exp() / exp_sum;
            d_log[v] = (softmax_v - if v == target { 1.0 } else { 0.0 }) / t as f32;
        }
        d_logits.push(d_log);
    }

    // d(loss)/d(output_weights). The forward pass computes
    // logits[i][v] = sum_d values[i][d] * W[d * vocab + v], so the gradient is
    // the outer product of the cached value vector with d_logits[i].
    let last_pops = cache.blocks.last().map(|b| &b.populations);
    let mut d_output = vec![0.0f32; dim * vocab];
    for i in 0..t {
        let source = cache
            .values
            .get(i)
            .or_else(|| last_pops.and_then(|p| p.get(i)));
        if let Some(row) = source {
            for d_idx in 0..dim {
                let value = row.get(d_idx).copied().unwrap_or(0.0);
                for v in 0..vocab {
                    d_output[d_idx * vocab + v] += value * d_logits[i][v];
                }
            }
        }
    }

    // d(loss)/d(values): back-projection of d_logits through the output weights.
    let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t];
    for i in 0..t {
        for d_idx in 0..dim {
            for v in 0..vocab {
                d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
            }
        }
    }

    let dt = 0.090f32;
    let mut block_grads = Vec::with_capacity(model.blocks.len());
    for (block_idx, block) in model.blocks.iter().enumerate().rev() {
        let bc = &cache.blocks[block_idx];
        let num_h = block.hamiltonian.num_params();

        // Value-weight gradient: outer product of d_values with the block's populations.
        let mut d_vw = vec![0.0f32; dim * dim];
        for i in 0..t.min(bc.populations.len()) {
            for d_idx in 0..dim {
                let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
                for s in 0..dim {
                    let pop = bc.populations[i].get(s).copied().unwrap_or(0.0);
                    d_vw[d_idx * dim + s] += pop * dv;
                }
            }
        }

        // Hamiltonian gradient, computed independently per position and then
        // summed in position order. Slot layout: `dim` diagonal entries first,
        // then the strict upper triangle in row-major order.
        let len = t.min(bc.amplitudes_before.len());
        let position_grads: Vec<Vec<f32>> = (0..len)
            .into_par_iter()
            .map(|i| {
                let mut local_d_h = vec![0.0f32; num_h];
                let psi = &bc.amplitudes_before[i];
                let d_pop: Vec<f32> = (0..dim).map(|k| d_values[i].get(k).copied().unwrap_or(0.0)).collect();
                let mut h_idx = 0;
                for k in 0..dim {
                    // Same guard as the off-diagonal loop below: never write
                    // past the parameter vector.
                    if h_idx >= local_d_h.len() {
                        break;
                    }
                    local_d_h[h_idx] = -dt * d_pop[k] * psi[k].norm_sq();
                    h_idx += 1;
                }
                'pairs: for p in 0..dim {
                    for q in (p + 1)..dim {
                        if h_idx >= local_d_h.len() {
                            break 'pairs;
                        }
                        let cross = psi[p].mul(psi[q].conj());
                        local_d_h[h_idx] = -dt * 2.0 * (d_pop[p] + d_pop[q]) * cross.im;
                        h_idx += 1;
                    }
                }
                local_d_h
            })
            .collect();

        let mut d_h = vec![0.0f32; num_h];
        for pg in &position_grads {
            for (acc, &g) in d_h.iter_mut().zip(pg) {
                *acc += g;
            }
        }

        block_grads.push(BlockGrad {
            hamiltonian_grad: d_h,
            value_weight_grad: d_vw,
        });
    }
    // Blocks were visited last-to-first; restore model order.
    block_grads.reverse();

    let embed_grad = vec![0.0f32; model.embedding.num_params()];
    AllGradients {
        embed_grad,
        block_grads,
        output_grad: d_output,
    }
}
/// Returns `norm_sq()` of every amplitude in `psi`, in the same order.
fn populations_from_amplitudes(psi: &[Complex]) -> Vec<f32> {
    let mut populations = Vec::with_capacity(psi.len());
    for amplitude in psi {
        populations.push(amplitude.norm_sq());
    }
    populations
}
/// Evaluates `expected_h - temperature * entropy` for the state `psi`.
///
/// * `expected_h` is the sum of `population * bias` over paired entries
///   (surplus entries of either slice are ignored).
/// * `entropy` is `-sum p * ln(p)` over populations above `1e-10`.
/// * `temperature` is `1 / (1 + PHI * coh)`, where `coh` sums
///   `(psi[i] * conj(psi[j])).norm()` over all pairs `i < j`.
fn free_energy_from_amplitudes(psi: &[Complex], bias: &[f32]) -> f32 {
    let pops: Vec<f32> = psi.iter().map(|c| c.norm_sq()).collect();
    let expected_h: f32 = pops.iter().zip(bias.iter()).map(|(p, e)| p * e).sum();

    // Total pairwise coherence magnitude.
    let n = psi.len();
    let mut coh = 0.0f32;
    for a in 0..n {
        let lhs = psi[a];
        for b in (a + 1)..n {
            coh += lhs.mul(psi[b].conj()).norm();
        }
    }

    // Vanishing populations are skipped so that `ln` is never taken near zero.
    let entropy = pops
        .iter()
        .filter(|&&p| p > 1e-10)
        .fold(0.0f32, |acc, &p| acc - p * p.ln());

    let temperature = 1.0 / (1.0 + PHI * coh);
    expected_h - temperature * entropy
}
/// Scales every component of `psi` by `sqrt(max(1 - epsilon, 0))` and then
/// rescales the vector to unit norm when its squared norm exceeds `1e-10`.
///
/// With `epsilon >= 1` the scale factor is zero and `psi` is left zeroed.
// NOTE(review): the scaling is uniform and is followed by renormalisation, so
// for `epsilon < 1` the relative amplitudes are unchanged — confirm this is
// the intended dephasing model.
fn dephase_amplitude(psi: &mut [Complex], epsilon: f32) {
    let keep = (1.0 - epsilon).max(0.0).sqrt();
    psi.iter_mut().for_each(|amp| *amp = amp.scale(keep));

    let norm_sq: f32 = psi.iter().map(|amp| amp.norm_sq()).sum();
    if norm_sq > 1e-10 {
        let inv_norm = 1.0 / norm_sq.sqrt();
        psi.iter_mut().for_each(|amp| *amp = amp.scale(inv_norm));
    }
}
/// Damps `psi` by an amount that depends on how coherent `other` is, then
/// rescales `psi` to unit norm when its squared norm exceeds `1e-10`.
///
/// The coherence of `other` is the sum of `(other[i] * conj(other[j])).norm()`
/// over all pairs `i < j`, capped at `1.0`. The retained fraction is
/// `max(1 - strength * (1 - coherence), 0)` and every component of `psi` is
/// multiplied by its square root.
// NOTE(review): the damping is uniform and is followed by renormalisation, so
// unless the retained fraction reaches zero the relative amplitudes are
// unchanged — confirm this is the intended coupling.
fn dephase_amplitude_coupled(psi: &mut [Complex], other: &[Complex], strength: f32) {
    let n = other.len();
    let mut pair_sum = 0.0f32;
    for a in 0..n {
        let lhs = other[a];
        for b in (a + 1)..n {
            pair_sum += lhs.mul(other[b].conj()).norm();
        }
    }
    let other_coh = pair_sum.min(1.0);

    let retain_sqrt = (1.0 - strength * (1.0 - other_coh)).max(0.0).sqrt();
    psi.iter_mut().for_each(|amp| *amp = amp.scale(retain_sqrt));

    let norm_sq: f32 = psi.iter().map(|amp| amp.norm_sq()).sum();
    if norm_sq > 1e-10 {
        let inv_norm = 1.0 / norm_sq.sqrt();
        psi.iter_mut().for_each(|amp| *amp = amp.scale(inv_norm));
    }
}
/// Multiplies the real row-major `dim x dim` matrix `m` by the complex vector
/// `x`: `y[i] = sum_j m[i * dim + j] * x[j]`.
fn matvec_real(m: &[f32], x: &[Complex], dim: usize) -> Vec<Complex> {
    (0..dim)
        .map(|row| {
            (0..dim).fold(Complex::ZERO, |acc, col| {
                let coeff = m[row * dim + col];
                acc.add(x[col].scale(coeff))
            })
        })
        .collect()
}
/// Multiplies the transpose of the real row-major `dim x dim` matrix `m` by
/// the complex vector `x`: `y[j] = sum_i m[i * dim + j] * x[i]`.
fn matvec_transpose_real(m: &[f32], x: &[Complex], dim: usize) -> Vec<Complex> {
    (0..dim)
        .map(|col| {
            (0..dim).fold(Complex::ZERO, |acc, row| {
                let coeff = m[row * dim + col];
                acc.add(x[row].scale(coeff))
            })
        })
        .collect()
}
fn diagonalize_real_symmetric(h: &[f32], eigenvalues: &mut [f32], eigenvectors: &mut [f32], dim: usize) {
let mut work = vec![Complex::ZERO; dim * dim];
for i in 0..dim * dim {
work[i] = Complex::new(h[i], 0.0);
}
dreamwell_math::eigen::eigenvalues_hermitian(&mut work, eigenvalues, dim, 50, 1e-6);
for i in 0..dim * dim {
eigenvectors[i] = work[i].re;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::{QCTConfig, QCT};

    /// Builds the small seeded model shared by every test; only the
    /// vocabulary size differs between them.
    fn small_model(vocab_size: usize) -> QCT {
        QCT::new(QCTConfig {
            vocab_size,
            dim: 5,
            num_blocks: 2,
            seed: 42,
        })
    }

    /// The amplitude embedding must give the same populations as the
    /// embedding returned by `embed`.
    #[test]
    fn embed_amplitude_matches_populations() {
        let model = small_model(65);
        for token in 0..10 {
            let rho = model.embedding.embed(token);
            let psi = model.embedding.embed_amplitude(token);
            let pops_rho = rho.populations();
            let pops_psi = populations_from_amplitudes(&psi);
            for k in 0..5 {
                assert!(
                    (pops_rho[k] - pops_psi[k]).abs() < 1e-5,
                    "token {token} mode {k}: rho={} psi={}",
                    pops_rho[k],
                    pops_psi[k]
                );
            }
        }
    }

    /// One logit row per token, each of vocabulary width, all finite.
    #[test]
    fn fock_forward_produces_valid_logits() {
        let model = small_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5];
        let (logits, avg_f, _cache) = fock_forward(&model, &tokens);
        assert_eq!(logits.len(), tokens.len());
        for l in &logits {
            assert_eq!(l.len(), 10);
            for &v in l {
                assert!(v.is_finite(), "logit not finite: {v}");
            }
        }
        assert!(avg_f.is_finite(), "free energy not finite: {avg_f}");
    }

    #[test]
    fn fock_forward_loss_is_finite() {
        let model = small_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];
        // The forward pass sees every token except the last.
        let (logits, avg_f, _cache) = fock_forward(&model, &tokens[..7]);
        let loss = QCT::loss_from_logits(&logits, &tokens, avg_f);
        assert!(loss.is_finite(), "loss not finite: {loss}");
        assert!(loss > 0.0, "loss should be positive: {loss}");
    }

    #[test]
    fn fock_backward_produces_gradients() {
        let model = small_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let (logits, _avg_f, cache) = fock_forward(&model, &tokens[..7]);
        let grads = fock_backward(&model, &tokens, &logits, &cache);
        let grad_flat = grads.flatten();
        let norm: f32 = grad_flat.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm > 1e-6, "gradient norm should be nonzero: {norm}");
    }

    /// Ten gradient steps must not make the loss noticeably worse.
    #[test]
    fn fock_training_reduces_loss() {
        let mut model = small_model(10);
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5];
        let (logits0, f0, _cache0) = fock_forward(&model, &tokens[..15]);
        let loss0 = QCT::loss_from_logits(&logits0, &tokens, f0);
        for _ in 0..10 {
            let (logits, _avg_f, cache) = fock_forward(&model, &tokens[..15]);
            let grads = fock_backward(&model, &tokens, &logits, &cache);
            let grad_flat = grads.flatten();
            model.apply_gradient_update(&grad_flat, 0.03, 1.0);
        }
        let (logits1, f1, _) = fock_forward(&model, &tokens[..15]);
        let loss1 = QCT::loss_from_logits(&logits1, &tokens, f1);
        // A slack of 0.1 tolerates small fluctuations around a flat loss.
        assert!(
            loss1 < loss0 + 0.1,
            "loss should decrease or stay flat: {loss0} → {loss1}"
        );
    }
}