Skip to main content

proof_engine/quantum/
entanglement.rs

1use std::f64::consts::PI;
2use super::schrodinger::Complex;
3
4/// Single qubit state: |psi> = alpha|0> + beta|1>.
5#[derive(Clone, Debug)]
6pub struct QubitState {
7    pub alpha: Complex,
8    pub beta: Complex,
9}
10
11impl QubitState {
12    pub fn new(alpha: Complex, beta: Complex) -> Self {
13        Self { alpha, beta }
14    }
15
16    pub fn zero() -> Self {
17        Self { alpha: Complex::one(), beta: Complex::zero() }
18    }
19
20    pub fn one() -> Self {
21        Self { alpha: Complex::zero(), beta: Complex::one() }
22    }
23
24    pub fn norm_sq(&self) -> f64 {
25        self.alpha.norm_sq() + self.beta.norm_sq()
26    }
27
28    pub fn normalize(&mut self) {
29        let n = self.norm_sq().sqrt();
30        if n > 1e-30 {
31            self.alpha = self.alpha / n;
32            self.beta = self.beta / n;
33        }
34    }
35}
36
37/// Two-qubit state: amplitudes for |00>, |01>, |10>, |11>.
38#[derive(Clone, Debug)]
39pub struct TwoQubitState {
40    pub amplitudes: [Complex; 4],
41}
42
43impl TwoQubitState {
44    pub fn new(amplitudes: [Complex; 4]) -> Self {
45        Self { amplitudes }
46    }
47
48    /// Product state |a>|b>.
49    pub fn product(a: &QubitState, b: &QubitState) -> Self {
50        Self {
51            amplitudes: [
52                a.alpha * b.alpha, // |00>
53                a.alpha * b.beta,  // |01>
54                a.beta * b.alpha,  // |10>
55                a.beta * b.beta,   // |11>
56            ],
57        }
58    }
59
60    pub fn norm_sq(&self) -> f64 {
61        self.amplitudes.iter().map(|c| c.norm_sq()).sum()
62    }
63
64    pub fn normalize(&mut self) {
65        let n = self.norm_sq().sqrt();
66        if n > 1e-30 {
67            for a in &mut self.amplitudes {
68                *a = *a / n;
69            }
70        }
71    }
72}
73
74/// Create a Bell state.
75/// 0: Phi+ = (|00> + |11>)/sqrt(2)
76/// 1: Phi- = (|00> - |11>)/sqrt(2)
77/// 2: Psi+ = (|01> + |10>)/sqrt(2)
78/// 3: Psi- = (|01> - |10>)/sqrt(2)
79pub fn bell_state(which: u8) -> TwoQubitState {
80    let s = 1.0 / 2.0_f64.sqrt();
81    match which {
82        0 => TwoQubitState::new([
83            Complex::new(s, 0.0), Complex::zero(),
84            Complex::zero(), Complex::new(s, 0.0),
85        ]),
86        1 => TwoQubitState::new([
87            Complex::new(s, 0.0), Complex::zero(),
88            Complex::zero(), Complex::new(-s, 0.0),
89        ]),
90        2 => TwoQubitState::new([
91            Complex::zero(), Complex::new(s, 0.0),
92            Complex::new(s, 0.0), Complex::zero(),
93        ]),
94        _ => TwoQubitState::new([
95            Complex::zero(), Complex::new(s, 0.0),
96            Complex::new(-s, 0.0), Complex::zero(),
97        ]),
98    }
99}
100
101/// Measure one qubit of a two-qubit state.
102/// Returns (outcome, collapsed state of the other qubit).
103pub fn measure_qubit(state: &TwoQubitState, which: usize, rng_val: f64) -> (u8, QubitState) {
104    let a = &state.amplitudes;
105    if which == 0 {
106        // Measuring first qubit
107        let p0 = a[0].norm_sq() + a[1].norm_sq(); // prob of first qubit = 0
108        if rng_val < p0 {
109            // Outcome 0: remaining state is alpha|0> + beta|1> from a[0]|0> + a[1]|1>
110            let mut q = QubitState::new(a[0], a[1]);
111            q.normalize();
112            (0, q)
113        } else {
114            let mut q = QubitState::new(a[2], a[3]);
115            q.normalize();
116            (1, q)
117        }
118    } else {
119        // Measuring second qubit
120        let p0 = a[0].norm_sq() + a[2].norm_sq();
121        if rng_val < p0 {
122            let mut q = QubitState::new(a[0], a[2]);
123            q.normalize();
124            (0, q)
125        } else {
126            let mut q = QubitState::new(a[1], a[3]);
127            q.normalize();
128            (1, q)
129        }
130    }
131}
132
133/// 2x2 density matrix.
134#[derive(Clone, Debug)]
135pub struct DensityMatrix2x2 {
136    pub rho: [[Complex; 2]; 2],
137}
138
139impl DensityMatrix2x2 {
140    pub fn trace(&self) -> f64 {
141        (self.rho[0][0] + self.rho[1][1]).re
142    }
143
144    pub fn purity(&self) -> f64 {
145        let mut sum = Complex::zero();
146        for i in 0..2 {
147            for j in 0..2 {
148                sum += self.rho[i][j] * self.rho[j][i];
149            }
150        }
151        sum.re
152    }
153
154    pub fn is_mixed(&self) -> bool {
155        self.purity() < 1.0 - 1e-6
156    }
157}
158
159/// Partial trace: trace out one qubit to get the density matrix of the other.
160pub fn partial_trace(state: &TwoQubitState, trace_out: usize) -> DensityMatrix2x2 {
161    let a = &state.amplitudes;
162    if trace_out == 1 {
163        // Trace out second qubit -> density matrix of first
164        let rho00 = a[0] * a[0].conj() + a[1] * a[1].conj();
165        let rho01 = a[0] * a[2].conj() + a[1] * a[3].conj();
166        let rho10 = a[2] * a[0].conj() + a[3] * a[1].conj();
167        let rho11 = a[2] * a[2].conj() + a[3] * a[3].conj();
168        DensityMatrix2x2 { rho: [[rho00, rho01], [rho10, rho11]] }
169    } else {
170        // Trace out first qubit -> density matrix of second
171        let rho00 = a[0] * a[0].conj() + a[2] * a[2].conj();
172        let rho01 = a[0] * a[1].conj() + a[2] * a[3].conj();
173        let rho10 = a[1] * a[0].conj() + a[3] * a[2].conj();
174        let rho11 = a[1] * a[1].conj() + a[3] * a[3].conj();
175        DensityMatrix2x2 { rho: [[rho00, rho01], [rho10, rho11]] }
176    }
177}
178
179/// Concurrence: entanglement measure for two-qubit pure states.
180/// C = 2|ad - bc| where state = a|00> + b|01> + c|10> + d|11>.
181pub fn concurrence(state: &TwoQubitState) -> f64 {
182    let a = state.amplitudes[0];
183    let b = state.amplitudes[1];
184    let c = state.amplitudes[2];
185    let d = state.amplitudes[3];
186    2.0 * (a * d - b * c).norm()
187}
188
189/// CHSH correlation: S = E(a1,b1) - E(a1,b2) + E(a2,b1) + E(a2,b2).
190/// Each angle specifies a measurement axis in the XZ plane.
191/// Returns |S|, which violates Bell inequality when > 2.
192pub fn chsh_correlation(
193    state: &TwoQubitState,
194    a1: f64,
195    a2: f64,
196    b1: f64,
197    b2: f64,
198) -> f64 {
199    let e = |a_angle: f64, b_angle: f64| -> f64 {
200        // E(a,b) = <psi| (sigma_a tensor sigma_b) |psi>
201        // sigma_n = cos(theta)*sigma_z + sin(theta)*sigma_x for angle theta
202        let ca = a_angle.cos();
203        let sa = a_angle.sin();
204        let cb = b_angle.cos();
205        let sb = b_angle.sin();
206
207        let amp = &state.amplitudes;
208        // Compute <psi| A tensor B |psi>
209        // A = [[ca, sa],[sa, -ca]], B = [[cb, sb],[sb, -cb]]
210        // A tensor B is 4x4
211        let mut result = Complex::zero();
212        let a_mat = [[ca, sa], [sa, -ca]];
213        let b_mat = [[cb, sb], [sb, -cb]];
214
215        for i in 0..2 {
216            for j in 0..2 {
217                let bra_idx = i * 2 + j;
218                for k in 0..2 {
219                    for l in 0..2 {
220                        let ket_idx = k * 2 + l;
221                        let coeff = a_mat[i][k] * b_mat[j][l];
222                        result += amp[bra_idx].conj() * amp[ket_idx] * coeff;
223                    }
224                }
225            }
226        }
227        result.re
228    };
229
230    let s = e(a1, b1) - e(a1, b2) + e(a2, b1) + e(a2, b2);
231    s.abs()
232}
233
234/// Renderer for entanglement visualization.
235pub struct EntanglementRenderer {
236    pub width: usize,
237}
238
239impl EntanglementRenderer {
240    pub fn new(width: usize) -> Self {
241        Self { width }
242    }
243
244    /// Render two particles with correlated states.
245    pub fn render(&self, state: &TwoQubitState, measured: Option<(u8, u8)>) -> Vec<(char, f64, f64, f64)> {
246        let mut result = Vec::with_capacity(self.width);
247        let mid = self.width / 2;
248
249        for i in 0..self.width {
250            if let Some((m0, m1)) = measured {
251                // After measurement: show definite states
252                if i < mid {
253                    let ch = if m0 == 0 { '0' } else { '1' };
254                    result.push((ch, 0.0, 1.0, 0.0));
255                } else {
256                    let ch = if m1 == 0 { '0' } else { '1' };
257                    result.push((ch, 1.0, 0.0, 0.0));
258                }
259            } else {
260                // Before measurement: show superposition
261                if i == mid - 2 || i == mid + 1 {
262                    let prob = if i < mid {
263                        state.amplitudes[0].norm_sq() + state.amplitudes[1].norm_sq()
264                    } else {
265                        state.amplitudes[0].norm_sq() + state.amplitudes[2].norm_sq()
266                    };
267                    let brightness = prob.min(1.0);
268                    result.push(('*', brightness, brightness, 0.5));
269                } else if i == mid - 1 || i == mid {
270                    result.push(('~', 0.3, 0.3, 0.8)); // entanglement link
271                } else {
272                    result.push((' ', 0.0, 0.0, 0.0));
273                }
274            }
275        }
276        result
277    }
278}
279
280/// N-particle GHZ state: (|00...0> + |11...1>)/sqrt(2).
281#[derive(Clone, Debug)]
282pub struct GHZState {
283    pub n_qubits: usize,
284    pub amplitudes: Vec<Complex>,
285}
286
287impl GHZState {
288    pub fn new(n_qubits: usize) -> Self {
289        let size = 1 << n_qubits;
290        let mut amplitudes = vec![Complex::zero(); size];
291        let s = 1.0 / 2.0_f64.sqrt();
292        amplitudes[0] = Complex::new(s, 0.0);             // |00...0>
293        amplitudes[size - 1] = Complex::new(s, 0.0);       // |11...1>
294        Self { n_qubits, amplitudes }
295    }
296
297    pub fn norm_sq(&self) -> f64 {
298        self.amplitudes.iter().map(|c| c.norm_sq()).sum()
299    }
300
301    /// Measure all qubits. Returns bit string.
302    pub fn measure(&self, rng_val: f64) -> Vec<u8> {
303        let n = self.amplitudes.len();
304        let mut cumulative = 0.0;
305        let mut outcome = 0;
306        for i in 0..n {
307            cumulative += self.amplitudes[i].norm_sq();
308            if rng_val < cumulative {
309                outcome = i;
310                break;
311            }
312        }
313        // Convert outcome to bits
314        (0..self.n_qubits)
315            .map(|bit| ((outcome >> (self.n_qubits - 1 - bit)) & 1) as u8)
316            .collect()
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bell_states_normalized() {
        for i in 0..4 {
            let state = bell_state(i);
            let norm = state.norm_sq();
            assert!((norm - 1.0).abs() < 1e-10, "Bell state {} norm: {}", i, norm);
        }
    }

    #[test]
    fn test_bell_state_maximally_entangled() {
        for i in 0..4 {
            let state = bell_state(i);
            let c = concurrence(&state);
            assert!((c - 1.0).abs() < 1e-10, "Bell state {} concurrence: {}", i, c);
        }
    }

    #[test]
    fn test_product_state_not_entangled() {
        let a = QubitState::zero();
        let b = QubitState::zero();
        let state = TwoQubitState::product(&a, &b);
        let c = concurrence(&state);
        assert!(c < 1e-10, "Product state concurrence: {}", c);
    }

    #[test]
    fn test_measurement_correlation() {
        // For Phi+, measuring qubit 0 leaves qubit 1 in the same basis state.
        let state = bell_state(0); // Phi+

        // P(first qubit = 0) = 1/2, so rng_val = 0.1 forces outcome 0.
        let (outcome, remaining) = measure_qubit(&state, 0, 0.1);
        assert_eq!(outcome, 0);
        assert!(remaining.alpha.norm_sq() > 0.9, "remaining qubit should be |0>");

        // ...and rng_val = 0.9 forces outcome 1.
        let (outcome, remaining) = measure_qubit(&state, 0, 0.9);
        assert_eq!(outcome, 1);
        assert!(remaining.beta.norm_sq() > 0.9, "remaining qubit should be |1>");
    }

    #[test]
    fn test_partial_trace_bell_gives_mixed() {
        let state = bell_state(0);
        let rho = partial_trace(&state, 1);
        assert!(rho.is_mixed(), "Partial trace of Bell state should be mixed");
        let purity = rho.purity();
        assert!((purity - 0.5).abs() < 1e-10, "Purity: {}", purity);
    }

    #[test]
    fn test_partial_trace_product_gives_pure() {
        let a = QubitState::new(
            Complex::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex::new(1.0 / 2.0_f64.sqrt(), 0.0),
        );
        let b = QubitState::zero();
        let state = TwoQubitState::product(&a, &b);
        let rho = partial_trace(&state, 1);
        let purity = rho.purity();
        assert!((purity - 1.0).abs() < 1e-10, "Product state purity: {}", purity);
    }

    #[test]
    fn test_chsh_violation() {
        // For Phi+, E(a, b) = cos(a - b). With S = E(a1,b1) - E(a1,b2) + E(a2,b1) + E(a2,b2),
        // the optimal settings are a1 = 0, a2 = pi/2, b1 = pi/4, b2 = 3pi/4, giving
        // S = 2*sqrt(2). (b2 = -pi/4 makes the four terms cancel to S = 0.)
        let state = bell_state(0);
        let s = chsh_correlation(&state, 0.0, PI / 2.0, PI / 4.0, 3.0 * PI / 4.0);
        assert!(s > 2.0, "CHSH S = {} should violate Bell inequality (> 2)", s);
        assert!((s - 2.0 * 2.0_f64.sqrt()).abs() < 1e-9, "S = {} should be ~2.828", s);
    }

    #[test]
    fn test_chsh_classical_bound() {
        // Product state |00> has E(a, b) = cos(a) cos(b): S = sqrt(2) at these settings.
        let state = TwoQubitState::product(&QubitState::zero(), &QubitState::zero());
        let s = chsh_correlation(&state, 0.0, PI / 2.0, PI / 4.0, 3.0 * PI / 4.0);
        assert!(s <= 2.0 + 1e-9, "Product state S = {} should be <= 2", s);
        assert!((s - 2.0_f64.sqrt()).abs() < 1e-9, "Product state S = {} should be ~1.414", s);
    }

    #[test]
    fn test_ghz_state() {
        let ghz = GHZState::new(3);
        assert_eq!(ghz.amplitudes.len(), 8);
        let norm = ghz.norm_sq();
        assert!((norm - 1.0).abs() < 1e-10);

        // Measurement should give all 0s or all 1s
        let result_0 = ghz.measure(0.1);
        assert!(result_0.iter().all(|&b| b == 0) || result_0.iter().all(|&b| b == 1));
        let result_1 = ghz.measure(0.9);
        assert!(result_1.iter().all(|&b| b == 0) || result_1.iter().all(|&b| b == 1));
    }

    #[test]
    fn test_renderer() {
        let state = bell_state(0);
        let renderer = EntanglementRenderer::new(20);
        let result = renderer.render(&state, None);
        assert_eq!(result.len(), 20);
        let measured = renderer.render(&state, Some((0, 0)));
        assert_eq!(measured.len(), 20);
    }
}