Skip to main content

dreamwell_intelligence/
embed.rs

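// Quantum Token Embedding — maps tokens to density-matrix states.
//
// Each token is embedded as a parameterized pure state |ψ⟩ (a unit vector of
// `dim` amplitudes) obtained by rotating a basis state with learnable angles.
// Intended geometry, which the learned angles are meant to produce (the random
// initialisation alone does not guarantee it):
//   similar tokens    → nearby quantum states (high fidelity);
//   dissimilar tokens → near-orthogonal states (low fidelity).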
6
7use crate::complex::Complex;
8use crate::density_matrix::DensityMatrixN;
9
/// Quantum embedding: maps vocab indices to density matrices.
///
/// Token `t` owns the angle row `angles[t * dim .. (t + 1) * dim]`. Those angles
/// drive a sequence of plane rotations that spread the basis state `|t mod dim⟩`
/// across the other modes (see `embed` / `embed_amplitude`).
#[derive(Clone, Debug)]
pub struct QuantumEmbedding {
    /// Number of tokens that own an angle row.
    pub vocab_size: usize,
    /// Number of basis modes per embedded state (length of each amplitude vector).
    pub dim: usize,
    /// Per-token rotation angles, flattened row-major as `[vocab_size × dim]`. Learnable.
    ///
    /// Only `dim - 1` angles per row are read. The rotation loop skips the entry
    /// at the token's own base mode (`t mod dim`).
    pub angles: Vec<f32>,
}
19
20impl QuantumEmbedding {
21    pub fn new(vocab_size: usize, dim: usize, seed: u64) -> Self {
22        let mut angles = Vec::with_capacity(vocab_size * dim);
23        for i in 0..(vocab_size * dim) {
24            let s = seed.wrapping_add(i as u64).wrapping_mul(0x517cc1b727220a95);
25            angles.push(std::f32::consts::PI * ((s % 2000) as f32 / 1000.0 - 1.0));
26        }
27        Self {
28            vocab_size,
29            dim,
30            angles,
31        }
32    }
33
34    /// Embed a single token as a density matrix.
35    /// Start with |k mod dim⟩, then apply parameterized rotations to spread
36    /// the state across modes. The rotations are the learnable embedding.
37    pub fn embed(&self, token: usize) -> DensityMatrixN {
38        let d = self.dim;
39        let base = token % d;
40
41        // Build amplitudes from angles via parameterized rotation
42        let offset = token * d;
43        let mut amplitudes = vec![Complex::ZERO; d];
44
45        // Start with base state
46        amplitudes[base] = Complex::ONE;
47
48        // Apply rotation: each angle mixes the base with another mode
49        for k in 0..d {
50            if k == base {
51                continue;
52            }
53            let angle = self.angles.get(offset + k).copied().unwrap_or(0.0);
54            let c = angle.cos();
55            let s = angle.sin();
56            let old_base = amplitudes[base];
57            let old_k = amplitudes[k];
58            amplitudes[base] = Complex::new(c * old_base.re - s * old_k.re, c * old_base.im - s * old_k.im);
59            amplitudes[k] = Complex::new(s * old_base.re + c * old_k.re, s * old_base.im + c * old_k.im);
60        }
61
62        // Build density matrix ρ = |ψ⟩⟨ψ|
63        let mut entries = vec![Complex::ZERO; d * d];
64        for i in 0..d {
65            for j in 0..d {
66                entries[i * d + j] = amplitudes[i].mul(amplitudes[j].conj());
67            }
68        }
69
70        let n2 = d * d;
71        DensityMatrixN {
72            dim: d,
73            entries,
74            scratch_a: vec![Complex::ZERO; n2],
75            scratch_b: vec![Complex::ZERO; n2],
76        }
77    }
78
79    /// Embed a single token as an amplitude vector (Fock space representation).
80    ///
81    /// Returns the dim-dimensional complex state vector |ψ⟩ WITHOUT forming
82    /// the density matrix ρ = |ψ⟩⟨ψ|. This is the Fock space representation:
83    /// dim amplitudes instead of dim² matrix entries.
84    ///
85    /// BA-62: At dim=86, this returns 688 bytes instead of 59,168 bytes (86× smaller).
86    /// The populations |ψ_k|² are identical to the diagonal of ρ = |ψ⟩⟨ψ|.
87    pub fn embed_amplitude(&self, token: usize) -> Vec<Complex> {
88        let d = self.dim;
89        let base = token % d;
90        let offset = token * d;
91        let mut amplitudes = vec![Complex::ZERO; d];
92
93        amplitudes[base] = Complex::ONE;
94
95        for k in 0..d {
96            if k == base {
97                continue;
98            }
99            let angle = self.angles.get(offset + k).copied().unwrap_or(0.0);
100            let c = angle.cos();
101            let s = angle.sin();
102            let old_base = amplitudes[base];
103            let old_k = amplitudes[k];
104            amplitudes[base] = Complex::new(c * old_base.re - s * old_k.re, c * old_base.im - s * old_k.im);
105            amplitudes[k] = Complex::new(s * old_base.re + c * old_k.re, s * old_base.im + c * old_k.im);
106        }
107
108        amplitudes
109    }
110
111    /// Number of learnable parameters.
112    pub fn num_params(&self) -> usize {
113        self.angles.len()
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Every embedded state must have unit trace and be (nearly) pure.
    #[test]
    fn embedding_produces_valid_state() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        for token in 0..65 {
            let rho = emb.embed(token);
            assert!(
                (rho.trace() - 1.0).abs() < 1e-4,
                "token {token}: trace = {}",
                rho.trace()
            );
            assert!(
                rho.purity() > 0.9,
                "token {token}: embedded state should be nearly pure, got {}",
                rho.purity()
            );
        }
    }

    /// Tokens 0 and 1 start from different basis modes (0 % 5 vs 1 % 5), so
    /// their populations should differ noticeably.
    #[test]
    fn different_tokens_different_states() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        let rho_a = emb.embed(0);
        let rho_b = emb.embed(1);
        // Check that populations differ
        let pops_a = rho_a.populations();
        let pops_b = rho_b.populations();
        let diff: f32 = pops_a.iter().zip(pops_b.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 0.01, "Different tokens should produce different states");
    }

    /// `embed_amplitude` documents that |ψ_k|² equals the diagonal of the ρ
    /// produced by `embed`; pin that contract for every token.
    #[test]
    fn amplitude_populations_match_density_diagonal() {
        let emb = QuantumEmbedding::new(65, 5, 42);
        for token in 0..65 {
            let pops = emb.embed(token).populations();
            let amps = emb.embed_amplitude(token);
            assert_eq!(pops.len(), amps.len(), "token {token}: dimension mismatch");
            for (k, (p, a)) in pops.iter().zip(amps.iter()).enumerate() {
                let q = a.re * a.re + a.im * a.im;
                assert!(
                    (p - q).abs() < 1e-5,
                    "token {token}, mode {k}: population {p} vs |psi|^2 {q}"
                );
            }
        }
    }
}