use crate::density_matrix::DensityMatrixN;
use crate::hamiltonian::LearnedHamiltonian;
/// Per-token read-outs of an attention pass.
///
/// As built by `quantum_causal_attention`, the three vectors are index-aligned with the
/// input token sequence: entry `i` of each field describes the evolved state of token `i`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AttentionOutput {
    /// Population vector (`populations()`) of each evolved token state.
    pub populations: Vec<Vec<f32>>,
    /// `free_energy(&hamiltonian.bias)` of each evolved token state.
    pub free_energies: Vec<f32>,
    /// `coherence_magnitude()` of each evolved token state.
    pub coherences: Vec<f32>,
}
/// Runs causal, windowed dephasing "attention" over a sequence of token states.
///
/// For each position `pos`:
/// 1. the token's density matrix is cloned;
/// 2. it is coupled via `couple_dephase` to each of the (up to) 8 tokens immediately before
///    it, oldest first, with a strength from `hamiltonian.causal_dephasing(distance)`;
/// 3. it is evolved by `DensityMatrixN::hamiltonian_unitary` built from
///    `hamiltonian.build_matrix(pos)`, `hamiltonian.dim` and a fixed step of `0.090`;
/// 4. its `free_energy(&hamiltonian.bias)`, `coherence_magnitude()` and `populations()` are
///    recorded.
///
/// Only `tokens[..=pos]` influence position `pos`. Each vector of the returned
/// [`AttentionOutput`] has `tokens.len()` entries, in input order.
pub fn quantum_causal_attention(
    tokens: &[DensityMatrixN],
    hamiltonian: &LearnedHamiltonian,
) -> AttentionOutput {
    // Number of preceding tokens each position is dephase-coupled to.
    const LOOKBACK: usize = 8;
    // Evolution step handed to `hamiltonian_unitary`; its float type is inferred from that call.
    let step = 0.090;

    let len = tokens.len();
    let mut result = AttentionOutput {
        populations: Vec::with_capacity(len),
        free_energies: Vec::with_capacity(len),
        coherences: Vec::with_capacity(len),
    };

    for (pos, token) in tokens.iter().enumerate() {
        let mut state = token.clone();
        let first = pos.saturating_sub(LOOKBACK);

        // Walk the causal window in ascending index order (oldest neighbour first).
        for (offset, earlier) in tokens[first..pos].iter().enumerate() {
            let gap = pos - (first + offset);
            state.couple_dephase(earlier, hamiltonian.causal_dephasing(gap));
        }

        let unitary = DensityMatrixN::hamiltonian_unitary(
            &hamiltonian.build_matrix(pos),
            hamiltonian.dim,
            step,
        );
        state.evolve(&unitary);

        // Read-outs are taken in a fixed order: free energy, coherence, populations.
        let free_energy = state.free_energy(&hamiltonian.bias);
        let coherence = state.coherence_magnitude();
        let pops = state.populations();

        result.free_energies.push(free_energy);
        result.coherences.push(coherence);
        result.populations.push(pops);
    }

    result
}
/// Projects per-position `values` through the population vectors in `attention`.
///
/// Output row `i` is the causal average
/// `(1 / (i + 1)) * Σ_{j ≤ i} w(i, j) * values[j]`, where the weight
/// `w(i, j) = attention.populations[i][j % dim]` cycles through the first `dim`
/// populations stored for position `i`.
///
/// Edge cases:
/// * `dim` is clamped to each population vector's length, so the weight index never runs
///   past the stored populations.
/// * `dim == 0` (or an empty population vector) gives every value a weight of `0.0`, so that
///   output row is all zeros.
/// * The output width is the length of `values[0]`. Longer rows are truncated to it, and
///   shorter rows contribute only the entries they have.
/// * Empty `values` yields an empty result.
///
/// # Panics
///
/// Panics if `attention.populations` has fewer rows than `values`.
pub fn attention_project(
    attention: &AttentionOutput,
    values: &[Vec<f32>],
    dim: usize,
) -> Vec<Vec<f32>> {
    assert!(
        attention.populations.len() >= values.len(),
        "attention_project: {} population rows for {} value rows",
        attention.populations.len(),
        values.len()
    );
    let v_dim = values.first().map_or(0, Vec::len);

    attention
        .populations
        .iter()
        .take(values.len())
        .enumerate()
        .map(|(i, pops)| {
            // Cycle length for the weight index; 0 means "no usable populations".
            let cycle = dim.min(pops.len());
            let mut out = vec![0.0f32; v_dim];
            // Accumulate in ascending j, then ascending k, matching a plain nested loop.
            for (j, row) in values[..=i].iter().enumerate() {
                let weight = j.checked_rem(cycle).map_or(0.0, |idx| pops[idx]);
                for (acc, &v) in out.iter_mut().zip(row) {
                    *acc += weight * v;
                }
            }
            let norm = (i + 1) as f32;
            out.iter_mut().for_each(|v| *v /= norm);
            out
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embed::QuantumEmbedding;

    /// Every position yields `dim` populations summing to roughly one, and all three
    /// read-out vectors have one entry per token.
    #[test]
    fn attention_produces_valid_output() {
        let dim = 5;
        let emb = QuantumEmbedding::new(65, dim, 42);
        let h = LearnedHamiltonian::new(dim, 42);
        let tokens: Vec<DensityMatrixN> = (0..8).map(|t| emb.embed(t)).collect();
        let output = quantum_causal_attention(&tokens, &h);
        assert_eq!(output.populations.len(), 8);
        assert_eq!(output.free_energies.len(), 8);
        assert_eq!(output.coherences.len(), 8);
        for (i, pops) in output.populations.iter().enumerate() {
            assert_eq!(pops.len(), dim, "token {i} should have {dim} populations");
            let sum: f32 = pops.iter().sum();
            assert!(
                (sum - 1.0).abs() < 0.1,
                "token {i}: populations should sum to ~1.0, got {sum}"
            );
        }
    }

    /// With identical tokens and a raised dephasing rate, the last position (full causal
    /// window) should not be more coherent than position 1 (one predecessor), within 0.01.
    #[test]
    fn coherence_decays_with_context_length() {
        let dim = 5;
        let emb = QuantumEmbedding::new(65, dim, 42);
        let mut h = LearnedHamiltonian::new(dim, 42);
        h.dephasing_rate = 0.3;
        let token_state = emb.embed(10);
        let tokens: Vec<DensityMatrixN> = (0..16).map(|_| token_state.clone()).collect();
        let output = quantum_causal_attention(&tokens, &h);
        let first_coh = output.coherences[1];
        let last_coh = output.coherences[15];
        assert!(
            last_coh <= first_coh + 0.01,
            "coherence should decay: first={first_coh}, last={last_coh}"
        );
    }

    /// Free energies stay finite over a 32-token sequence cycling through the vocabulary.
    #[test]
    fn free_energy_finite_for_all_tokens() {
        let dim = 5;
        let emb = QuantumEmbedding::new(65, dim, 42);
        let h = LearnedHamiltonian::new(dim, 42);
        let tokens: Vec<DensityMatrixN> = (0..32).map(|t| emb.embed(t % 65)).collect();
        let output = quantum_causal_attention(&tokens, &h);
        for (i, &f) in output.free_energies.iter().enumerate() {
            assert!(f.is_finite(), "token {i}: free energy must be finite, got {f}");
        }
    }

    /// An empty token sequence produces empty read-outs.
    #[test]
    fn empty_sequence_yields_empty_output() {
        let h = LearnedHamiltonian::new(5, 42);
        let output = quantum_causal_attention(&[], &h);
        assert!(output.populations.is_empty());
        assert!(output.free_energies.is_empty());
        assert!(output.coherences.is_empty());
    }

    /// `attention_project` averages values causally: at position `i`, `values[j]` is
    /// weighted by `populations[i][j % dim]` and the sum is divided by `i + 1`.
    #[test]
    fn projection_weights_values_by_cyclic_populations() {
        let attention = AttentionOutput {
            populations: vec![vec![0.5, 0.5], vec![0.25, 0.75]],
            free_energies: vec![0.0; 2],
            coherences: vec![0.0; 2],
        };
        let values = vec![vec![2.0, 4.0], vec![8.0, 16.0]];
        let projected = attention_project(&attention, &values, 2);
        assert_eq!(projected, vec![vec![1.0, 2.0], vec![3.25, 6.5]]);
    }
}