// dreamwell_intelligence/transformer.rs
use crate::attention::{attention_project, quantum_causal_attention, AttentionOutput};
use crate::density_matrix::DensityMatrixN;
use crate::embed::QuantumEmbedding;
use crate::hamiltonian::LearnedHamiltonian;

/// Hyper-parameters for building a [`QCT`] model.
///
/// Plain data: it derives equality and hashing so two configurations can be
/// compared directly (e.g. when checking a checkpoint against a model).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct QCTConfig {
    /// Number of token ids; also the length of every logit vector produced.
    pub vocab_size: usize,
    /// Width of the per-token value vectors that are read out into logits;
    /// also passed to the embedding and to every block.
    pub dim: usize,
    /// Number of attention blocks stacked in the model.
    pub num_blocks: usize,
    /// Base seed for the deterministic pseudo-random parameter initialisation.
    pub seed: u64,
}

impl Default for QCTConfig {
    /// 65 tokens, `dim = 5`, 2 blocks, seed 42.
    fn default() -> Self {
        Self {
            vocab_size: 65,
            dim: 5,
            num_blocks: 2,
            seed: 42,
        }
    }
}

37#[derive(Clone)]
39pub struct QCTBlock {
40 pub hamiltonian: LearnedHamiltonian,
41 pub value_weights: Vec<f32>,
43}
44
45const PHI_INV: f32 = 0.618033988;
46
47impl QCTBlock {
48 pub fn new(dim: usize, seed: u64) -> Self {
49 let scale = PHI_INV;
52 let mut value_weights = Vec::with_capacity(dim * dim);
53 for i in 0..(dim * dim) {
54 let s = seed.wrapping_add((i + 1000) as u64).wrapping_mul(0x94d049bb133111eb);
55 value_weights.push(scale * ((s % 2000) as f32 / 1000.0 - 1.0));
56 }
57 Self {
58 hamiltonian: LearnedHamiltonian::new(dim, seed),
59 value_weights,
60 }
61 }
62
63 pub fn forward(&self, states: &[DensityMatrixN], values: &[Vec<f32>]) -> (AttentionOutput, Vec<Vec<f32>>) {
67 let attn = quantum_causal_attention(states, &self.hamiltonian);
68 let projected = attention_project(&attn, values, self.hamiltonian.dim);
69 (attn, projected)
70 }
71
72 pub fn num_params(&self) -> usize {
74 self.hamiltonian.num_params() + self.value_weights.len()
75 }
76}
77
78#[derive(Clone)]
80pub struct QCT {
81 pub config: QCTConfig,
82 pub embedding: QuantumEmbedding,
83 pub blocks: Vec<QCTBlock>,
84 pub output_weights: Vec<f32>,
86}
87
88impl QCT {
89 pub fn new(config: QCTConfig) -> Self {
90 let embedding = QuantumEmbedding::new(config.vocab_size, config.dim, config.seed);
91
92 let mut blocks = Vec::with_capacity(config.num_blocks);
93 for i in 0..config.num_blocks {
94 blocks.push(QCTBlock::new(config.dim, config.seed.wrapping_add(i as u64 * 1000)));
95 }
96
97 let out_scale = 0.090169944_f32; let mut output_weights = Vec::with_capacity(config.dim * config.vocab_size);
103 for i in 0..(config.dim * config.vocab_size) {
104 let s = config
105 .seed
106 .wrapping_add((i + 5000) as u64)
107 .wrapping_mul(0x517cc1b727220a95);
108 output_weights.push(out_scale * ((s % 2000) as f32 / 1000.0 - 1.0));
109 }
110
111 Self {
112 config,
113 embedding,
114 blocks,
115 output_weights,
116 }
117 }
118
119 pub fn forward(&self, tokens: &[usize]) -> (Vec<Vec<f32>>, f32) {
122 let dim = self.config.dim;
123 let t = tokens.len();
124
125 let states: Vec<DensityMatrixN> = tokens.iter().map(|&tok| self.embedding.embed(tok)).collect();
127
128 let mut values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();
130
131 let mut total_free_energy = 0.0f32;
133 let mut current_states = states;
134
135 for block in &self.blocks {
136 let (attn, new_values) = block.forward(¤t_states, &values);
137
138 total_free_energy += attn.free_energies.iter().sum::<f32>();
140
141 values = new_values;
143
144 let eps_block = 0.236 / self.blocks.len().max(1) as f32;
147 for state in &mut current_states {
148 state.dephase(eps_block);
149 }
150 }
151
152 let vocab = self.config.vocab_size;
154 let mut logits = Vec::with_capacity(t);
155 for i in 0..t {
156 let mut token_logits = vec![0.0f32; vocab];
157 for v in 0..vocab {
158 for d in 0..dim {
159 token_logits[v] += values[i][d] * self.output_weights[d * vocab + v];
160 }
161 }
162 logits.push(token_logits);
163 }
164
165 (logits, total_free_energy / t as f32)
166 }
167
168 pub fn num_params(&self) -> usize {
170 let embed_params = self.embedding.num_params();
171 let block_params: usize = self.blocks.iter().map(|b| b.num_params()).sum();
172 let output_params = self.output_weights.len();
173 embed_params + block_params + output_params
174 }
175
176 pub fn all_params(&self) -> Vec<f32> {
179 let mut p = Vec::with_capacity(self.num_params());
180 p.extend_from_slice(&self.embedding.angles);
181 for block in &self.blocks {
182 p.extend_from_slice(&block.hamiltonian.params());
183 p.extend_from_slice(&block.value_weights);
184 }
185 p.extend_from_slice(&self.output_weights);
186 p
187 }
188
189 pub fn set_all_params(&mut self, params: &[f32]) {
191 let mut offset = 0;
192 let embed_len = self.embedding.angles.len();
193 self.embedding.angles[..embed_len].copy_from_slice(¶ms[offset..offset + embed_len]);
194 offset += embed_len;
195 for block in &mut self.blocks {
196 let h_len = block.hamiltonian.num_params();
197 block.hamiltonian.set_params(¶ms[offset..offset + h_len]);
198 offset += h_len;
199 let v_len = block.value_weights.len();
200 block.value_weights[..v_len].copy_from_slice(¶ms[offset..offset + v_len]);
201 offset += v_len;
202 }
203 let out_len = self.output_weights.len();
204 self.output_weights[..out_len].copy_from_slice(¶ms[offset..offset + out_len]);
205 }
206
207 pub fn apply_gradient_update(&mut self, grad: &[f32], lr: f32, scale: f32) -> usize {
211 let mut offset = 0;
212 let factor = lr * scale;
213
214 let embed_len = self.embedding.angles.len();
216 for k in 0..embed_len.min(grad.len()) {
217 self.embedding.angles[k] -= factor * grad[k];
218 }
219 offset += embed_len;
220
221 for block in &mut self.blocks {
223 let d = block.hamiltonian.dim;
225 for k in 0..d {
226 if offset + k < grad.len() {
227 block.hamiltonian.bias[k] -= factor * grad[offset + k];
228 }
229 }
230 offset += d;
231
232 let nc = block.hamiltonian.couplings.len();
234 for k in 0..nc {
235 if offset + k < grad.len() {
236 block.hamiltonian.couplings[k] -= factor * grad[offset + k];
237 }
238 }
239 offset += nc;
240
241 if offset < grad.len() {
244 block.hamiltonian.dephasing_rate =
245 (block.hamiltonian.dephasing_rate - factor * grad[offset]).clamp(0.013155617, 1.0);
246 }
248 offset += 1;
249 if offset < grad.len() {
250 block.hamiltonian.temperature =
251 (block.hamiltonian.temperature - factor * grad[offset]).clamp(0.090169944, 11.09017);
252 }
254 offset += 1;
255
256 let v_len = block.value_weights.len();
258 for k in 0..v_len {
259 if offset + k < grad.len() {
260 block.value_weights[k] -= factor * grad[offset + k];
261 }
262 }
263 offset += v_len;
264 }
265
266 let out_len = self.output_weights.len();
268 for k in 0..out_len {
269 if offset + k < grad.len() {
270 self.output_weights[k] -= factor * grad[offset + k];
271 }
272 }
273 offset += out_len;
274
275 offset.min(grad.len())
276 }
277
278 pub fn loss(&self, tokens: &[usize]) -> f32 {
282 if tokens.len() < 2 {
283 return 0.0;
284 }
285 let (logits, avg_free_energy) = self.forward(&tokens[..tokens.len() - 1]);
286 Self::loss_from_logits(&logits, tokens, avg_free_energy)
287 }
288
289 pub fn loss_from_logits(logits: &[Vec<f32>], tokens: &[usize], avg_free_energy: f32) -> f32 {
292 if tokens.len() < 2 || logits.is_empty() {
293 return 0.0;
294 }
295 let mut total_ce = 0.0f32;
296 let n = logits.len();
297
298 for (i, token_logits) in logits.iter().enumerate() {
299 let target = if i + 1 < tokens.len() { tokens[i + 1] } else { continue };
300
301 let max_logit = token_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
303 let exp_sum: f32 = token_logits.iter().map(|&l| (l - max_logit).exp()).sum();
304 let log_prob = (token_logits[target] - max_logit) - exp_sum.ln();
305 total_ce -= log_prob;
306 }
307
308 let avg_ce = total_ce / n as f32;
309 avg_ce + 0.146 * avg_free_energy
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A model built from the default configuration.
    fn default_model() -> QCT {
        QCT::new(QCTConfig::default())
    }

    #[test]
    fn qct_forward_produces_logits() {
        let config = QCTConfig::default();
        let expected_width = config.vocab_size;
        let model = QCT::new(config);
        let tokens: Vec<usize> = (0..8).collect();

        let (logits, free_energy) = model.forward(&tokens);

        assert_eq!(logits.len(), tokens.len());
        for (i, row) in logits.iter().enumerate() {
            assert_eq!(row.len(), expected_width, "token {i}: logit dim should be vocab_size");
        }
        assert!(free_energy.is_finite(), "free energy should be finite");
    }

    #[test]
    fn qct_loss_finite() {
        let tokens: [usize; 6] = [0, 1, 2, 3, 4, 5];
        let loss = default_model().loss(&tokens);
        assert!(loss.is_finite(), "loss should be finite: {loss}");
        assert!(loss > 0.0, "loss should be positive: {loss}");
    }

    #[test]
    fn qct_param_count() {
        let model = QCT::new(QCTConfig {
            vocab_size: 65,
            dim: 5,
            num_blocks: 2,
            seed: 42,
        });
        let params = model.num_params();
        assert!(params > 0, "should have parameters: {params}");
        eprintln!("QCT parameter count: {params}");
    }

    #[test]
    fn qct_deterministic() {
        let model = default_model();
        let tokens: [usize; 5] = [10, 20, 30, 40, 50];
        let first = model.loss(&tokens);
        let second = model.loss(&tokens);
        assert_eq!(first, second, "QCT should be deterministic");
    }
}