Skip to main content

dreamwell_intelligence/
transformer.rs

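// QCT — Quantum Causal Transformer.
//
// A stack of quantum attention blocks. Each block: evolve → dephase → measure → project;
// tokens are embedded as density matrices once, before the first block.
// The full transformer: token ids → embed → [QCT block × N] → readout (output projection) → loss.
//
// Clean Compute: all parameters are allocated once, when the model is built. The forward pass
// still allocates its per-call buffers (embedded states, value vectors, logits).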
7
8use crate::attention::{attention_project, quantum_causal_attention, AttentionOutput};
9use crate::density_matrix::DensityMatrixN;
10use crate::embed::QuantumEmbedding;
11use crate::hamiltonian::LearnedHamiltonian;
12
/// Configuration for a QCT model.
///
/// A plain value type: it derives `PartialEq`/`Eq` so configurations can be
/// compared (e.g. to check that a loaded model matches the expected setup).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QCTConfig {
    /// Vocabulary size (number of unique tokens).
    pub vocab_size: usize,
    /// Model dimension (density matrix size per token).
    pub dim: usize,
    /// Number of QCT blocks (layers).
    pub num_blocks: usize,
    /// Random seed for initialization.
    pub seed: u64,
}
25
26impl Default for QCTConfig {
27    fn default() -> Self {
28        Self {
29            vocab_size: 65, // ASCII printable (nanoGPT Shakespeare default)
30            dim: 5,         // 5-mode density matrix (matches our toy models)
31            num_blocks: 2,  // 2 blocks for proof of concept
32            seed: 42,
33        }
34    }
35}
36
37/// A single QCT block: quantum attention + value projection.
38#[derive(Clone)]
39pub struct QCTBlock {
40    pub hamiltonian: LearnedHamiltonian,
41    /// Value projection weights [dim × dim]. Maps populations → output features.
42    pub value_weights: Vec<f32>,
43}
44
/// Golden-ratio conjugate 1/φ = φ − 1 ≈ 0.618; used as the value-weight init scale.
const PHI_INV: f32 = 0.618_033_988;
46
47impl QCTBlock {
48    pub fn new(dim: usize, seed: u64) -> Self {
49        // Value weight scale: 1/φ — golden partition of [-1, 1] range.
50        // Matches cloud compression ratio and every other blend weight in the pipeline.
51        let scale = PHI_INV;
52        let mut value_weights = Vec::with_capacity(dim * dim);
53        for i in 0..(dim * dim) {
54            let s = seed.wrapping_add((i + 1000) as u64).wrapping_mul(0x94d049bb133111eb);
55            value_weights.push(scale * ((s % 2000) as f32 / 1000.0 - 1.0));
56        }
57        Self {
58            hamiltonian: LearnedHamiltonian::new(dim, seed),
59            value_weights,
60        }
61    }
62
63    /// Forward pass through one QCT block.
64    /// Input: sequence of density matrices + value vectors.
65    /// Output: updated value vectors (populations-weighted projection).
66    pub fn forward(&self, states: &[DensityMatrixN], values: &[Vec<f32>]) -> (AttentionOutput, Vec<Vec<f32>>) {
67        let attn = quantum_causal_attention(states, &self.hamiltonian);
68        let projected = attention_project(&attn, values, self.hamiltonian.dim);
69        (attn, projected)
70    }
71
72    /// Number of learnable parameters in this block.
73    pub fn num_params(&self) -> usize {
74        self.hamiltonian.num_params() + self.value_weights.len()
75    }
76}
77
78/// The full Quantum Causal Transformer.
79#[derive(Clone)]
80pub struct QCT {
81    pub config: QCTConfig,
82    pub embedding: QuantumEmbedding,
83    pub blocks: Vec<QCTBlock>,
84    /// Output projection: dim → vocab_size (logits).
85    pub output_weights: Vec<f32>,
86}
87
88impl QCT {
89    pub fn new(config: QCTConfig) -> Self {
90        let embedding = QuantumEmbedding::new(config.vocab_size, config.dim, config.seed);
91
92        let mut blocks = Vec::with_capacity(config.num_blocks);
93        for i in 0..config.num_blocks {
94            blocks.push(QCTBlock::new(config.dim, config.seed.wrapping_add(i as u64 * 1000)));
95        }
96
97        // Output projection: dim → vocab_size.
98        // Scale: 1/φ⁵ ≈ 0.090 — matches evolution dt and Hamiltonian bias range.
99        // Output is the final projection to logits; small init prevents saturation
100        // while maintaining the φ-chain through the entire parameter space.
101        let out_scale = 0.090169944_f32; // 1/φ⁵
102        let mut output_weights = Vec::with_capacity(config.dim * config.vocab_size);
103        for i in 0..(config.dim * config.vocab_size) {
104            let s = config
105                .seed
106                .wrapping_add((i + 5000) as u64)
107                .wrapping_mul(0x517cc1b727220a95);
108            output_weights.push(out_scale * ((s % 2000) as f32 / 1000.0 - 1.0));
109        }
110
111        Self {
112            config,
113            embedding,
114            blocks,
115            output_weights,
116        }
117    }
118
119    /// Forward pass: tokens → logits.
120    /// Returns (logits, total_free_energy).
121    pub fn forward(&self, tokens: &[usize]) -> (Vec<Vec<f32>>, f32) {
122        let dim = self.config.dim;
123        let t = tokens.len();
124
125        // 1. Embed tokens as density matrices
126        let states: Vec<DensityMatrixN> = tokens.iter().map(|&tok| self.embedding.embed(tok)).collect();
127
128        // 2. Initial value vectors = populations of embedded states
129        let mut values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
130
131        // 3. Pass through QCT blocks
132        let mut total_free_energy = 0.0f32;
133        let mut current_states = states;
134
135        for block in &self.blocks {
136            let (attn, new_values) = block.forward(&current_states, &values);
137
138            // Accumulate free energy from attention
139            total_free_energy += attn.free_energies.iter().sum::<f32>();
140
141            // Update values; states carry through (coherences persist between blocks)
142            values = new_values;
143
144            // Inter-block dephasing: ε = 1/(φ³ × num_blocks).
145            // Surviving coherence after all blocks: exp(-1/φ³) ≈ 79%. Scale-invariant.
146            let eps_block = 0.236 / self.blocks.len().max(1) as f32;
147            for state in &mut current_states {
148                state.dephase(eps_block);
149            }
150        }
151
152        // 4. Output projection: values → logits over vocabulary
153        let vocab = self.config.vocab_size;
154        let mut logits = Vec::with_capacity(t);
155        for i in 0..t {
156            let mut token_logits = vec![0.0f32; vocab];
157            for v in 0..vocab {
158                for d in 0..dim {
159                    token_logits[v] += values[i][d] * self.output_weights[d * vocab + v];
160                }
161            }
162            logits.push(token_logits);
163        }
164
165        (logits, total_free_energy / t as f32)
166    }
167
168    /// Total number of learnable parameters.
169    pub fn num_params(&self) -> usize {
170        let embed_params = self.embedding.num_params();
171        let block_params: usize = self.blocks.iter().map(|b| b.num_params()).sum();
172        let output_params = self.output_weights.len();
173        embed_params + block_params + output_params
174    }
175
176    /// Flatten ALL model parameters into a single Vec.
177    /// Order: [embedding_angles | block_0_hamiltonian | block_0_values | ... | output_weights]
178    pub fn all_params(&self) -> Vec<f32> {
179        let mut p = Vec::with_capacity(self.num_params());
180        p.extend_from_slice(&self.embedding.angles);
181        for block in &self.blocks {
182            p.extend_from_slice(&block.hamiltonian.params());
183            p.extend_from_slice(&block.value_weights);
184        }
185        p.extend_from_slice(&self.output_weights);
186        p
187    }
188
189    /// Set ALL model parameters from a flat Vec.
190    pub fn set_all_params(&mut self, params: &[f32]) {
191        let mut offset = 0;
192        let embed_len = self.embedding.angles.len();
193        self.embedding.angles[..embed_len].copy_from_slice(&params[offset..offset + embed_len]);
194        offset += embed_len;
195        for block in &mut self.blocks {
196            let h_len = block.hamiltonian.num_params();
197            block.hamiltonian.set_params(&params[offset..offset + h_len]);
198            offset += h_len;
199            let v_len = block.value_weights.len();
200            block.value_weights[..v_len].copy_from_slice(&params[offset..offset + v_len]);
201            offset += v_len;
202        }
203        let out_len = self.output_weights.len();
204        self.output_weights[..out_len].copy_from_slice(&params[offset..offset + out_len]);
205    }
206
207    /// Apply gradient update in-place: θ ← θ - lr * grad * scale.
208    /// Zero allocation — updates parameters directly without all_params()/set_all_params().
209    /// Returns the number of parameters updated.
210    pub fn apply_gradient_update(&mut self, grad: &[f32], lr: f32, scale: f32) -> usize {
211        let mut offset = 0;
212        let factor = lr * scale;
213
214        // Embedding angles
215        let embed_len = self.embedding.angles.len();
216        for k in 0..embed_len.min(grad.len()) {
217            self.embedding.angles[k] -= factor * grad[k];
218        }
219        offset += embed_len;
220
221        // Block parameters
222        for block in &mut self.blocks {
223            // Hamiltonian bias
224            let d = block.hamiltonian.dim;
225            for k in 0..d {
226                if offset + k < grad.len() {
227                    block.hamiltonian.bias[k] -= factor * grad[offset + k];
228                }
229            }
230            offset += d;
231
232            // Hamiltonian couplings
233            let nc = block.hamiltonian.couplings.len();
234            for k in 0..nc {
235                if offset + k < grad.len() {
236                    block.hamiltonian.couplings[k] -= factor * grad[offset + k];
237                }
238            }
239            offset += nc;
240
241            // Hamiltonian dephasing_rate + temperature
242            // Clamps: φ-derived bounds matching set_params()
243            if offset < grad.len() {
244                block.hamiltonian.dephasing_rate =
245                    (block.hamiltonian.dephasing_rate - factor * grad[offset]).clamp(0.013155617, 1.0);
246                // [1/φ⁸, 1]
247            }
248            offset += 1;
249            if offset < grad.len() {
250                block.hamiltonian.temperature =
251                    (block.hamiltonian.temperature - factor * grad[offset]).clamp(0.090169944, 11.09017);
252                // [1/φ⁵, φ⁵]
253            }
254            offset += 1;
255
256            // Value weights
257            let v_len = block.value_weights.len();
258            for k in 0..v_len {
259                if offset + k < grad.len() {
260                    block.value_weights[k] -= factor * grad[offset + k];
261                }
262            }
263            offset += v_len;
264        }
265
266        // Output weights
267        let out_len = self.output_weights.len();
268        for k in 0..out_len {
269            if offset + k < grad.len() {
270                self.output_weights[k] -= factor * grad[offset + k];
271            }
272        }
273        offset += out_len;
274
275        offset.min(grad.len())
276    }
277
278    /// Compute cross-entropy loss for next-token prediction.
279    /// NOTE: This recomputes the full forward pass. For training loops that already
280    /// have logits from forward_with_cache(), use loss_from_logits() instead.
281    pub fn loss(&self, tokens: &[usize]) -> f32 {
282        if tokens.len() < 2 {
283            return 0.0;
284        }
285        let (logits, avg_free_energy) = self.forward(&tokens[..tokens.len() - 1]);
286        Self::loss_from_logits(&logits, tokens, avg_free_energy)
287    }
288
289    /// Compute loss from pre-computed logits. Zero-cost — no forward pass.
290    /// Use this in training loops where forward_with_cache() already produced logits.
291    pub fn loss_from_logits(logits: &[Vec<f32>], tokens: &[usize], avg_free_energy: f32) -> f32 {
292        if tokens.len() < 2 || logits.is_empty() {
293            return 0.0;
294        }
295        let mut total_ce = 0.0f32;
296        let n = logits.len();
297
298        for (i, token_logits) in logits.iter().enumerate() {
299            let target = if i + 1 < tokens.len() { tokens[i + 1] } else { continue };
300
301            // Softmax + cross-entropy
302            let max_logit = token_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
303            let exp_sum: f32 = token_logits.iter().map(|&l| (l - max_logit).exp()).sum();
304            let log_prob = (token_logits[target] - max_logit) - exp_sum.ln();
305            total_ce -= log_prob;
306        }
307
308        let avg_ce = total_ce / n as f32;
309        // λ = 1/φ⁴ ≈ 0.146 — free energy gets real weight in the loss.
310        // The model optimizes for prediction accuracy AND coherent structure.
311        avg_ce + 0.146 * avg_free_energy
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn qct_forward_produces_logits() {
        let config = QCTConfig::default();
        let model = QCT::new(config.clone());
        let tokens = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let (logits, free_energy) = model.forward(&tokens);

        assert_eq!(logits.len(), tokens.len());
        for (i, l) in logits.iter().enumerate() {
            assert_eq!(l.len(), config.vocab_size, "token {i}: logit dim should be vocab_size");
        }
        assert!(free_energy.is_finite(), "free energy should be finite");
    }

    #[test]
    fn qct_loss_finite() {
        let model = QCT::new(QCTConfig::default());
        let tokens = vec![0, 1, 2, 3, 4, 5];
        let loss = model.loss(&tokens);
        assert!(loss.is_finite(), "loss should be finite: {loss}");
        assert!(loss > 0.0, "loss should be positive: {loss}");
    }

    #[test]
    fn qct_loss_short_sequences_are_zero() {
        // With fewer than two tokens there is no next-token target to score.
        let model = QCT::new(QCTConfig::default());
        assert_eq!(model.loss(&[]), 0.0);
        assert_eq!(model.loss(&[3]), 0.0);
    }

    #[test]
    fn qct_loss_matches_loss_from_logits() {
        let model = QCT::new(QCTConfig::default());
        let tokens = vec![3, 1, 4, 1, 5, 9];
        let (logits, free_energy) = model.forward(&tokens[..tokens.len() - 1]);
        assert_eq!(
            model.loss(&tokens),
            QCT::loss_from_logits(&logits, &tokens, free_energy),
            "loss() should equal loss_from_logits() on the same forward output"
        );
    }

    #[test]
    fn qct_param_count() {
        let config = QCTConfig {
            vocab_size: 65,
            dim: 5,
            num_blocks: 2,
            seed: 42,
        };
        let model = QCT::new(config);
        let params = model.num_params();
        // Embedding: 65 * 5 = 325
        // Per block: 5 bias + 10 couplings + 2 (dephasing+temp) + 25 value_weights = 42
        // 2 blocks: 84
        // Output: 5 * 65 = 325
        // Total: 325 + 84 + 325 = 734
        assert!(params > 0, "should have parameters: {params}");
        // The flat parameter vector consumed by set_all_params()/apply_gradient_update()
        // must have exactly num_params() entries.
        assert_eq!(
            model.all_params().len(),
            params,
            "all_params() length must equal num_params()"
        );
        eprintln!("QCT parameter count: {params}");
    }

    #[test]
    fn qct_deterministic() {
        let model = QCT::new(QCTConfig::default());
        let tokens = vec![10, 20, 30, 40, 50];
        let loss_a = model.loss(&tokens);
        let loss_b = model.loss(&tokens);
        assert_eq!(loss_a, loss_b, "QCT should be deterministic");
    }
}