dreamwell_intelligence/
embed.rs1use crate::complex::Complex;
8use crate::density_matrix::DensityMatrixN;
9
/// Token embedding that maps each token id to a pure quantum state in a
/// `dim`-dimensional space, parameterised by one rotation angle per
/// (token, basis index) pair.
#[derive(Clone, Debug)]
pub struct QuantumEmbedding {
    /// Number of token ids that have their own row of angles.
    pub vocab_size: usize,
    /// Dimension of the embedded state (length of the amplitude vector).
    pub dim: usize,
    /// Rotation angles in radians, laid out row-major as `vocab_size × dim`:
    /// token `t`'s angle for basis index `k` is `angles[t * dim + k]`.
    pub angles: Vec<f32>,
}
19
20impl QuantumEmbedding {
21 pub fn new(vocab_size: usize, dim: usize, seed: u64) -> Self {
22 let mut angles = Vec::with_capacity(vocab_size * dim);
23 for i in 0..(vocab_size * dim) {
24 let s = seed.wrapping_add(i as u64).wrapping_mul(0x517cc1b727220a95);
25 angles.push(std::f32::consts::PI * ((s % 2000) as f32 / 1000.0 - 1.0));
26 }
27 Self {
28 vocab_size,
29 dim,
30 angles,
31 }
32 }
33
34 pub fn embed(&self, token: usize) -> DensityMatrixN {
38 let d = self.dim;
39 let base = token % d;
40
41 let offset = token * d;
43 let mut amplitudes = vec![Complex::ZERO; d];
44
45 amplitudes[base] = Complex::ONE;
47
48 for k in 0..d {
50 if k == base {
51 continue;
52 }
53 let angle = self.angles.get(offset + k).copied().unwrap_or(0.0);
54 let c = angle.cos();
55 let s = angle.sin();
56 let old_base = amplitudes[base];
57 let old_k = amplitudes[k];
58 amplitudes[base] = Complex::new(c * old_base.re - s * old_k.re, c * old_base.im - s * old_k.im);
59 amplitudes[k] = Complex::new(s * old_base.re + c * old_k.re, s * old_base.im + c * old_k.im);
60 }
61
62 let mut entries = vec![Complex::ZERO; d * d];
64 for i in 0..d {
65 for j in 0..d {
66 entries[i * d + j] = amplitudes[i].mul(amplitudes[j].conj());
67 }
68 }
69
70 let n2 = d * d;
71 DensityMatrixN {
72 dim: d,
73 entries,
74 scratch_a: vec![Complex::ZERO; n2],
75 scratch_b: vec![Complex::ZERO; n2],
76 }
77 }
78
79 pub fn embed_amplitude(&self, token: usize) -> Vec<Complex> {
88 let d = self.dim;
89 let base = token % d;
90 let offset = token * d;
91 let mut amplitudes = vec![Complex::ZERO; d];
92
93 amplitudes[base] = Complex::ONE;
94
95 for k in 0..d {
96 if k == base {
97 continue;
98 }
99 let angle = self.angles.get(offset + k).copied().unwrap_or(0.0);
100 let c = angle.cos();
101 let s = angle.sin();
102 let old_base = amplitudes[base];
103 let old_k = amplitudes[k];
104 amplitudes[base] = Complex::new(c * old_base.re - s * old_k.re, c * old_base.im - s * old_k.im);
105 amplitudes[k] = Complex::new(s * old_base.re + c * old_k.re, s * old_base.im + c * old_k.im);
106 }
107
108 amplitudes
109 }
110
111 pub fn num_params(&self) -> usize {
113 self.angles.len()
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedding_produces_valid_state() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        for token in 0..65 {
            let rho = emb.embed(token);
            assert!(
                (rho.trace() - 1.0).abs() < 1e-4,
                "token {token}: trace = {}",
                rho.trace()
            );
            assert!(
                rho.purity() > 0.9,
                "token {token}: embedded state should be nearly pure, got {}",
                rho.purity()
            );
        }
    }

    #[test]
    fn different_tokens_different_states() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        let rho_a = emb.embed(0);
        let rho_b = emb.embed(1);
        let pops_a = rho_a.populations();
        let pops_b = rho_b.populations();
        let diff: f32 = pops_a.iter().zip(pops_b.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 0.01, "Different tokens should produce different states");
    }

    /// `embed_amplitude` must return a unit vector of length `dim`, and its
    /// squared magnitudes must match the populations of `embed`.
    #[test]
    fn amplitude_embedding_is_normalized_and_consistent() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        for token in 0..65 {
            let amps = emb.embed_amplitude(token);
            assert_eq!(amps.len(), 5, "token {token}: wrong amplitude length");

            let norm: f32 = amps.iter().map(|a| a.re * a.re + a.im * a.im).sum();
            assert!((norm - 1.0).abs() < 1e-4, "token {token}: norm = {norm}");

            let pops = emb.embed(token).populations();
            for (k, (p, a)) in pops.iter().zip(amps.iter()).enumerate() {
                let expected = a.re * a.re + a.im * a.im;
                assert!(
                    (p - expected).abs() < 1e-4,
                    "token {token}, index {k}: population {p} vs |amplitude|^2 {expected}"
                );
            }
        }
    }

    /// Tokens outside the angle table get no rotation, so they map to the
    /// basis vector `|token % dim⟩`.
    #[test]
    fn out_of_vocab_token_maps_to_basis_state() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        let token = 1000;
        let amps = emb.embed_amplitude(token);
        for (k, a) in amps.iter().enumerate() {
            let expected = if k == token % 5 { 1.0 } else { 0.0 };
            assert_eq!(a.re, expected, "index {k}");
            assert_eq!(a.im, 0.0, "index {k}");
        }
    }
}