use crate::attention::{attention_project, quantum_causal_attention, AttentionOutput};
use crate::density_matrix::DensityMatrixN;
use crate::embed::QuantumEmbedding;
use crate::hamiltonian::LearnedHamiltonian;
/// Hyperparameters used to construct a `QCT` model.
#[derive(Clone, Debug)]
pub struct QCTConfig {
    /// Number of token ids; passed to the embedding constructor and used as the width of
    /// every logit row.
    pub vocab_size: usize,
    /// Dimension passed to the embedding and to every block; also the number of value
    /// components the output projection reads per position.
    pub dim: usize,
    /// Number of blocks the model stacks.
    pub num_blocks: usize,
    /// Base seed from which all deterministic initial weights are derived.
    pub seed: u64,
}

impl Default for QCTConfig {
    /// Baseline configuration: 65-token vocabulary, dimension 5, two blocks, seed 42.
    fn default() -> Self {
        QCTConfig {
            vocab_size: 65,
            dim: 5,
            num_blocks: 2,
            seed: 42,
        }
    }
}
/// One layer of the model: a learned Hamiltonian plus a flat vector of value weights.
#[derive(Clone)]
pub struct QCTBlock {
/// Hamiltonian handed to `quantum_causal_attention` by `forward`; its `dim` field is also
/// passed to `attention_project`.
pub hamiltonian: LearnedHamiltonian,
/// `dim * dim` weights created by `new` and counted by `num_params`.
/// NOTE(review): `forward` does not read these weights in this file — confirm that is intended.
pub value_weights: Vec<f32>,
}
/// Initialisation scale for value weights; numerically the golden-ratio reciprocal
/// (1/phi ~= 0.618).
const PHI_INV: f32 = 0.618033988;

impl QCTBlock {
    /// Builds a block of dimension `dim` with deterministic, seed-derived weights.
    ///
    /// Each of the `dim * dim` value weights is obtained by hashing `seed` together with the
    /// weight index (wrapping add, then wrapping multiply by a fixed odd constant), reducing
    /// the hash modulo 2000 and mapping it into `[-PHI_INV, PHI_INV)`.
    pub fn new(dim: usize, seed: u64) -> Self {
        const MIX: u64 = 0x94d049bb133111eb;

        let value_weights: Vec<f32> = (0..dim * dim)
            .map(|idx| {
                let hashed = seed.wrapping_add((idx + 1000) as u64).wrapping_mul(MIX);
                PHI_INV * ((hashed % 2000) as f32 / 1000.0 - 1.0)
            })
            .collect();
        let hamiltonian = LearnedHamiltonian::new(dim, seed);

        Self {
            hamiltonian,
            value_weights,
        }
    }

    /// Runs causal attention over `states` with this block's Hamiltonian, then projects
    /// `values` through the attention result.
    ///
    /// Returns the attention output together with the projected values.
    pub fn forward(
        &self,
        states: &[DensityMatrixN],
        values: &[Vec<f32>],
    ) -> (AttentionOutput, Vec<Vec<f32>>) {
        let attention = quantum_causal_attention(states, &self.hamiltonian);
        let mixed_values = attention_project(&attention, values, self.hamiltonian.dim);
        (attention, mixed_values)
    }

    /// Number of scalar parameters: the Hamiltonian's own count plus the value weights.
    pub fn num_params(&self) -> usize {
        let own_weights = self.value_weights.len();
        self.hamiltonian.num_params() + own_weights
    }
}
/// Full model: token embedding, a stack of blocks, and a linear read-out to vocabulary logits.
#[derive(Clone)]
pub struct QCT {
/// Configuration the model was built from.
pub config: QCTConfig,
/// Maps a token id to its initial state via `embed`.
pub embedding: QuantumEmbedding,
/// Blocks applied in order by `forward`.
pub blocks: Vec<QCTBlock>,
/// Flat `dim * vocab_size` read-out weights, indexed as `d * vocab_size + v`.
pub output_weights: Vec<f32>,
}
impl QCT {
pub fn new(config: QCTConfig) -> Self {
let embedding = QuantumEmbedding::new(config.vocab_size, config.dim, config.seed);
let mut blocks = Vec::with_capacity(config.num_blocks);
for i in 0..config.num_blocks {
blocks.push(QCTBlock::new(config.dim, config.seed.wrapping_add(i as u64 * 1000)));
}
let out_scale = 0.090169944_f32; let mut output_weights = Vec::with_capacity(config.dim * config.vocab_size);
for i in 0..(config.dim * config.vocab_size) {
let s = config
.seed
.wrapping_add((i + 5000) as u64)
.wrapping_mul(0x517cc1b727220a95);
output_weights.push(out_scale * ((s % 2000) as f32 / 1000.0 - 1.0));
}
Self {
config,
embedding,
blocks,
output_weights,
}
}
pub fn forward(&self, tokens: &[usize]) -> (Vec<Vec<f32>>, f32) {
let dim = self.config.dim;
let t = tokens.len();
let states: Vec<DensityMatrixN> = tokens.iter().map(|&tok| self.embedding.embed(tok)).collect();
let mut values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
let mut total_free_energy = 0.0f32;
let mut current_states = states;
for block in &self.blocks {
let (attn, new_values) = block.forward(¤t_states, &values);
total_free_energy += attn.free_energies.iter().sum::<f32>();
values = new_values;
let eps_block = 0.236 / self.blocks.len().max(1) as f32;
for state in &mut current_states {
state.dephase(eps_block);
}
}
let vocab = self.config.vocab_size;
let mut logits = Vec::with_capacity(t);
for i in 0..t {
let mut token_logits = vec![0.0f32; vocab];
for v in 0..vocab {
for d in 0..dim {
token_logits[v] += values[i][d] * self.output_weights[d * vocab + v];
}
}
logits.push(token_logits);
}
(logits, total_free_energy / t as f32)
}
pub fn num_params(&self) -> usize {
let embed_params = self.embedding.num_params();
let block_params: usize = self.blocks.iter().map(|b| b.num_params()).sum();
let output_params = self.output_weights.len();
embed_params + block_params + output_params
}
/// Flattens every parameter into one vector.
///
/// Order: embedding angles, then for each block its Hamiltonian parameters followed by its
/// value weights, then the read-out weights.
pub fn all_params(&self) -> Vec<f32> {
    let mut flat: Vec<f32> = Vec::with_capacity(self.num_params());
    flat.extend_from_slice(&self.embedding.angles);
    self.blocks.iter().for_each(|layer| {
        flat.extend_from_slice(&layer.hamiltonian.params());
        flat.extend_from_slice(&layer.value_weights);
    });
    flat.extend_from_slice(&self.output_weights);
    flat
}
pub fn set_all_params(&mut self, params: &[f32]) {
let mut offset = 0;
let embed_len = self.embedding.angles.len();
self.embedding.angles[..embed_len].copy_from_slice(¶ms[offset..offset + embed_len]);
offset += embed_len;
for block in &mut self.blocks {
let h_len = block.hamiltonian.num_params();
block.hamiltonian.set_params(¶ms[offset..offset + h_len]);
offset += h_len;
let v_len = block.value_weights.len();
block.value_weights[..v_len].copy_from_slice(¶ms[offset..offset + v_len]);
offset += v_len;
}
let out_len = self.output_weights.len();
self.output_weights[..out_len].copy_from_slice(¶ms[offset..offset + out_len]);
}
/// Applies one gradient-descent step, `param -= lr * scale * grad`, across all parameters.
///
/// `grad` is read in this order: embedding angles, then per block the Hamiltonian bias
/// (`dim` entries), couplings, dephasing rate, temperature and value weights, then the
/// read-out weights. After its update the dephasing rate is clamped to `[0.013155617, 1.0]`
/// and the temperature to `[0.090169944, 11.09017]`.
///
/// If `grad` is shorter than the parameter layout, parameters without a matching gradient
/// entry are left untouched.
///
/// Returns the length of the parameter layout walked, capped at `grad.len()`.
pub fn apply_gradient_update(&mut self, grad: &[f32], lr: f32, scale: f32) -> usize {
    let step = lr * scale;
    let grad_len = grad.len();
    // How many gradient entries exist at or after position `start`.
    let remaining = |start: usize| grad_len.saturating_sub(start);

    let mut cursor = 0usize;

    let embed_len = self.embedding.angles.len();
    for k in 0..embed_len.min(remaining(cursor)) {
        self.embedding.angles[k] -= step * grad[k];
    }
    cursor += embed_len;

    for block in self.blocks.iter_mut() {
        let h = &mut block.hamiltonian;

        let bias_len = h.dim;
        for k in 0..bias_len.min(remaining(cursor)) {
            h.bias[k] -= step * grad[cursor + k];
        }
        cursor += bias_len;

        let coupling_len = h.couplings.len();
        for k in 0..coupling_len.min(remaining(cursor)) {
            h.couplings[k] -= step * grad[cursor + k];
        }
        cursor += coupling_len;

        // The two scalar parameters are kept inside fixed bounds after the step.
        if let Some(&g) = grad.get(cursor) {
            h.dephasing_rate = (h.dephasing_rate - step * g).clamp(0.013155617, 1.0);
        }
        cursor += 1;

        if let Some(&g) = grad.get(cursor) {
            h.temperature = (h.temperature - step * g).clamp(0.090169944, 11.09017);
        }
        cursor += 1;

        let value_len = block.value_weights.len();
        for (w, g) in block.value_weights.iter_mut().zip(grad.iter().skip(cursor)) {
            *w -= step * g;
        }
        cursor += value_len;
    }

    let out_len = self.output_weights.len();
    for (w, g) in self.output_weights.iter_mut().zip(grad.iter().skip(cursor)) {
        *w -= step * g;
    }
    cursor += out_len;

    cursor.min(grad_len)
}
/// Next-token loss for `tokens`.
///
/// Runs `forward` on every token except the last and scores the result against the full
/// sequence with `loss_from_logits`. Sequences shorter than two tokens give `0.0`.
pub fn loss(&self, tokens: &[usize]) -> f32 {
    match tokens.split_last() {
        // `context` is everything but the final token; it is non-empty iff len >= 2.
        Some((_, context)) if !context.is_empty() => {
            let (logits, avg_free_energy) = self.forward(context);
            Self::loss_from_logits(&logits, tokens, avg_free_energy)
        }
        _ => 0.0,
    }
}
/// Mean next-token cross-entropy plus a free-energy penalty.
///
/// Row `i` of `logits` is scored against `tokens[i + 1]` with a max-shifted log-softmax.
/// Rows that have no following token are ignored, and the cross-entropy is averaged over the
/// rows actually scored. The result is `mean_ce + 0.146 * avg_free_energy`.
///
/// Returns `0.0` when `tokens` has fewer than two entries or `logits` is empty.
///
/// # Panics
///
/// Panics if a target token id is not a valid index into its logit row.
pub fn loss_from_logits(logits: &[Vec<f32>], tokens: &[usize], avg_free_energy: f32) -> f32 {
    if tokens.len() < 2 || logits.is_empty() {
        return 0.0;
    }

    let mut total_ce = 0.0f32;
    let mut scored = 0usize;
    // Pair each logit row with the token that follows it; surplus rows are dropped by `zip`.
    for (token_logits, &target) in logits.iter().zip(&tokens[1..]) {
        // Subtract the row maximum before exponentiating to keep `exp` from overflowing.
        let max_logit = token_logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = token_logits.iter().map(|&l| (l - max_logit).exp()).sum();
        let log_prob = (token_logits[target] - max_logit) - exp_sum.ln();
        total_ce -= log_prob;
        scored += 1;
    }

    // The guards above ensure at least one row was scored, so `scored` is never zero.
    let avg_ce = total_ce / scored as f32;
    avg_ce + 0.146 * avg_free_energy
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `forward` yields one logit row per token, each `vocab_size` wide, and a finite free energy.
    #[test]
    fn qct_forward_produces_logits() {
        let config = QCTConfig::default();
        let expected_width = config.vocab_size;
        let model = QCT::new(config);
        let tokens: Vec<usize> = (0..8).collect();

        let (logits, free_energy) = model.forward(&tokens);

        assert_eq!(logits.len(), tokens.len());
        for (i, l) in logits.iter().enumerate() {
            assert_eq!(l.len(), expected_width, "token {i}: logit dim should be vocab_size");
        }
        assert!(free_energy.is_finite(), "free energy should be finite");
    }

    /// The loss on a short sequence is a finite, strictly positive number.
    #[test]
    fn qct_loss_finite() {
        let tokens: Vec<usize> = (0..6).collect();
        let loss = QCT::new(QCTConfig::default()).loss(&tokens);
        assert!(loss.is_finite(), "loss should be finite: {loss}");
        assert!(loss > 0.0, "loss should be positive: {loss}");
    }

    /// A freshly built model reports a non-zero parameter count.
    #[test]
    fn qct_param_count() {
        let model = QCT::new(QCTConfig {
            vocab_size: 65,
            dim: 5,
            num_blocks: 2,
            seed: 42,
        });
        let params = model.num_params();
        assert!(params > 0, "should have parameters: {params}");
        eprintln!("QCT parameter count: {params}");
    }

    /// Evaluating the same model on the same tokens twice gives bit-identical losses.
    #[test]
    fn qct_deterministic() {
        let model = QCT::new(QCTConfig::default());
        let tokens: [usize; 5] = [10, 20, 30, 40, 50];
        let loss_a = model.loss(&tokens);
        let loss_b = model.loss(&tokens);
        assert_eq!(loss_a, loss_b, "QCT should be deterministic");
    }
}