Skip to main content

quantrs2_core/networking/
e91.rs

//! E91 entanglement-based quantum key distribution.
//!
//! Implements the Ekert 1991 (E91) protocol with:
//! - Generation of Bell |Φ+⟩ pairs
//! - Depolarizing noise applied to each qubit of the pair
//! - Alice and Bob each measuring at one of three angles
//! - CHSH Bell inequality test (S parameter)
//! - Key extraction from the pairs where Alice and Bob both chose their first angle
//!
//! Under ideal conditions S ≈ 2√2 ≈ 2.828; with sufficient noise S ≤ 2 (the classical bound).
11
12use crate::error::{QuantRS2Error, QuantRS2Result};
13use crate::networking::channel::measure_computational;
14use scirs2_core::ndarray::Array2;
15use scirs2_core::random::prelude::*;
16use scirs2_core::random::ChaCha20Rng;
17use scirs2_core::Complex64;
18use std::f64::consts::PI;
19
20// ---------------------------------------------------------------------------
21// Helper: convert u64 seed → 32-byte array for ChaCha20
22// ---------------------------------------------------------------------------
/// Expands a 64-bit seed into the 32-byte seed expected by ChaCha20.
///
/// The little-endian encoding of `seed` is repeated four times, filling
/// bytes 0..8, 8..16, 16..24 and 24..32 of the result.
fn seed_from_u64(seed: u64) -> [u8; 32] {
    let word = seed.to_le_bytes();
    let mut expanded = [0u8; 32];
    // 32 is an exact multiple of 8, so every lane receives a full copy.
    for lane in expanded.chunks_exact_mut(word.len()) {
        lane.copy_from_slice(&word);
    }
    expanded
}
32
33// ---------------------------------------------------------------------------
34// Bell state preparation
35// ---------------------------------------------------------------------------
36
37/// 4×4 density matrix for |Φ+⟩⟨Φ+|.
38/// Basis: {|00⟩, |01⟩, |10⟩, |11⟩} → indices 0, 1, 2, 3.
39fn bell_phi_plus_4x4() -> Array2<Complex64> {
40    let v = 0.5_f64;
41    let mut rho = Array2::<Complex64>::zeros((4, 4));
42    rho[[0, 0]] = Complex64::new(v, 0.0);
43    rho[[0, 3]] = Complex64::new(v, 0.0);
44    rho[[3, 0]] = Complex64::new(v, 0.0);
45    rho[[3, 3]] = Complex64::new(v, 0.0);
46    rho
47}
48
49/// Partial trace of a 4×4 two-qubit density matrix.
50///
51/// `keep_first = true` → trace out qubit B → return ρ_A (2×2).
52fn partial_trace_2q(rho4: &Array2<Complex64>, keep_first: bool) -> Array2<Complex64> {
53    let mut out = Array2::<Complex64>::zeros((2, 2));
54    if keep_first {
55        for a in 0..2_usize {
56            for a2 in 0..2_usize {
57                for b in 0..2_usize {
58                    out[[a, a2]] += rho4[[2 * a + b, 2 * a2 + b]];
59                }
60            }
61        }
62    } else {
63        for b in 0..2_usize {
64            for b2 in 0..2_usize {
65                for a in 0..2_usize {
66                    out[[b, b2]] += rho4[[2 * a + b, 2 * a + b2]];
67                }
68            }
69        }
70    }
71    out
72}
73
74// ---------------------------------------------------------------------------
75// Depolarizing noise on each qubit of a 4×4 two-qubit state
76// ---------------------------------------------------------------------------
77
78fn apply_depolarizing_2q(rho4: &mut Array2<Complex64>, p: f64) {
79    if p <= 0.0 {
80        return;
81    }
82    apply_depolarizing_qubit_a(rho4, p);
83    apply_depolarizing_qubit_b(rho4, p);
84}
85
86fn apply_depolarizing_qubit_a(rho4: &mut Array2<Complex64>, p: f64) {
87    let rho_orig = rho4.clone();
88    let scale_id = Complex64::new(1.0 - p, 0.0);
89    let scale_p = Complex64::new(p / 3.0, 0.0);
90
91    rho4.mapv_inplace(|v| v * scale_id);
92
93    // X_A ⊗ I_B: i → i XOR 2
94    let mut t1 = Array2::<Complex64>::zeros((4, 4));
95    for i in 0..4 {
96        for j in 0..4 {
97            t1[[i, j]] = rho_orig[[i ^ 2, j ^ 2]];
98        }
99    }
100    add_scaled(rho4, &t1, scale_p);
101
102    // Y_A ⊗ I_B
103    let phase_a = |i: usize| -> Complex64 {
104        if (i >> 1) & 1 == 0 {
105            Complex64::new(0.0, 1.0)
106        } else {
107            Complex64::new(0.0, -1.0)
108        }
109    };
110    let mut t2 = Array2::<Complex64>::zeros((4, 4));
111    for i in 0..4 {
112        for j in 0..4 {
113            t2[[i, j]] = phase_a(i) * rho_orig[[i ^ 2, j ^ 2]] * phase_a(j).conj();
114        }
115    }
116    add_scaled(rho4, &t2, scale_p);
117
118    // Z_A ⊗ I_B
119    let sign_a = |i: usize| -> f64 {
120        if (i >> 1) & 1 == 0 {
121            1.0
122        } else {
123            -1.0
124        }
125    };
126    let mut t3 = Array2::<Complex64>::zeros((4, 4));
127    for i in 0..4 {
128        for j in 0..4 {
129            t3[[i, j]] = Complex64::new(sign_a(i) * sign_a(j), 0.0) * rho_orig[[i, j]];
130        }
131    }
132    add_scaled(rho4, &t3, scale_p);
133}
134
135fn apply_depolarizing_qubit_b(rho4: &mut Array2<Complex64>, p: f64) {
136    let rho_orig = rho4.clone();
137    let scale_id = Complex64::new(1.0 - p, 0.0);
138    let scale_p = Complex64::new(p / 3.0, 0.0);
139
140    rho4.mapv_inplace(|v| v * scale_id);
141
142    // I_A ⊗ X_B: i → i XOR 1
143    let mut t1 = Array2::<Complex64>::zeros((4, 4));
144    for i in 0..4 {
145        for j in 0..4 {
146            t1[[i, j]] = rho_orig[[i ^ 1, j ^ 1]];
147        }
148    }
149    add_scaled(rho4, &t1, scale_p);
150
151    // I_A ⊗ Y_B
152    let phase_b = |i: usize| -> Complex64 {
153        if i & 1 == 0 {
154            Complex64::new(0.0, 1.0)
155        } else {
156            Complex64::new(0.0, -1.0)
157        }
158    };
159    let mut t2 = Array2::<Complex64>::zeros((4, 4));
160    for i in 0..4 {
161        for j in 0..4 {
162            t2[[i, j]] = phase_b(i) * rho_orig[[i ^ 1, j ^ 1]] * phase_b(j).conj();
163        }
164    }
165    add_scaled(rho4, &t2, scale_p);
166
167    // I_A ⊗ Z_B
168    let sign_b = |i: usize| -> f64 {
169        if i & 1 == 0 {
170            1.0
171        } else {
172            -1.0
173        }
174    };
175    let mut t3 = Array2::<Complex64>::zeros((4, 4));
176    for i in 0..4 {
177        for j in 0..4 {
178            t3[[i, j]] = Complex64::new(sign_b(i) * sign_b(j), 0.0) * rho_orig[[i, j]];
179        }
180    }
181    add_scaled(rho4, &t3, scale_p);
182}
183
184fn add_scaled(dest: &mut Array2<Complex64>, src: &Array2<Complex64>, scale: Complex64) {
185    for i in 0..dest.nrows() {
186        for j in 0..dest.ncols() {
187            dest[[i, j]] += scale * src[[i, j]];
188        }
189    }
190}
191
192// ---------------------------------------------------------------------------
193// Angle-based measurement
194// ---------------------------------------------------------------------------
195
196/// Measure a qubit density matrix in the rotated basis at angle θ.
197///
198/// The positive eigenvector is cos(θ)|0⟩ + sin(θ)|1⟩.
199/// Returns `false` for the +1 outcome, `true` for −1.
200fn measure_in_angle(rho: &Array2<Complex64>, theta: f64, rand_val: f64) -> bool {
201    let c = theta.cos();
202    let s = theta.sin();
203    // P(+1) = ⟨+_θ|ρ|+_θ⟩ = c²ρ₀₀ + 2cs·Re(ρ₀₁) + s²ρ₁₁
204    let p_plus = (c * c * rho[[0, 0]].re + 2.0 * c * s * rho[[0, 1]].re + s * s * rho[[1, 1]].re)
205        .clamp(0.0, 1.0);
206    // false = +1 (rand < p_plus), true = -1 (rand >= p_plus)
207    rand_val >= p_plus
208}
209
210/// Compute E(θ_a, θ_b) = Pr(same) − Pr(different) from density matrix.
211fn compute_correlation(rho4: &Array2<Complex64>, theta_a: f64, theta_b: f64) -> f64 {
212    let ca = theta_a.cos();
213    let sa = theta_a.sin();
214    let cb = theta_b.cos();
215    let sb = theta_b.sin();
216
217    let m_plus_a = [
218        [Complex64::new(ca * ca, 0.0), Complex64::new(ca * sa, 0.0)],
219        [Complex64::new(ca * sa, 0.0), Complex64::new(sa * sa, 0.0)],
220    ];
221    let m_minus_a = [
222        [Complex64::new(sa * sa, 0.0), Complex64::new(-ca * sa, 0.0)],
223        [Complex64::new(-ca * sa, 0.0), Complex64::new(ca * ca, 0.0)],
224    ];
225    let m_plus_b = [
226        [Complex64::new(cb * cb, 0.0), Complex64::new(cb * sb, 0.0)],
227        [Complex64::new(cb * sb, 0.0), Complex64::new(sb * sb, 0.0)],
228    ];
229    let m_minus_b = [
230        [Complex64::new(sb * sb, 0.0), Complex64::new(-cb * sb, 0.0)],
231        [Complex64::new(-cb * sb, 0.0), Complex64::new(cb * cb, 0.0)],
232    ];
233
234    let p_pp = trace_joint(&m_plus_a, &m_plus_b, rho4);
235    let p_pm = trace_joint(&m_plus_a, &m_minus_b, rho4);
236    let p_mp = trace_joint(&m_minus_a, &m_plus_b, rho4);
237    let p_mm = trace_joint(&m_minus_a, &m_minus_b, rho4);
238
239    (p_pp + p_mm - p_pm - p_mp).clamp(-1.0, 1.0)
240}
241
242/// Tr[(M_a ⊗ M_b) ρ4] for 2×2 projectors M_a, M_b.
243fn trace_joint(
244    m_a: &[[Complex64; 2]; 2],
245    m_b: &[[Complex64; 2]; 2],
246    rho4: &Array2<Complex64>,
247) -> f64 {
248    let mut result = Complex64::new(0.0, 0.0);
249    for ia in 0..2 {
250        for ib in 0..2 {
251            for ja in 0..2 {
252                for jb in 0..2 {
253                    let i = 2 * ia + ib;
254                    let j = 2 * ja + jb;
255                    result += m_a[ia][ja] * m_b[ib][jb] * rho4[[i, j]];
256                }
257            }
258        }
259    }
260    result.re.clamp(0.0, 1.0)
261}
262
263// ---------------------------------------------------------------------------
264// E91 Protocol
265// ---------------------------------------------------------------------------
266
/// Configuration for one simulated run of the E91 entanglement-based QKD
/// protocol.
///
/// All fields are public; [`E91Protocol::new`] additionally clamps `noise`
/// into `[0, 1]`.
#[derive(Debug, Clone)]
pub struct E91Protocol {
    /// How many entangled pairs are generated and measured.
    pub n_pairs: usize,
    /// Depolarizing probability applied to each qubit, expected in `[0, 1]`.
    pub noise: f64,
    /// Seed from which the run's random number generator is initialised.
    pub rng_seed: u64,
}
277
/// Outcome of a single E91 protocol run.
#[derive(Debug, Clone)]
pub struct E91Result {
    /// The extracted key, one `bool` per retained pair.
    pub key: Vec<bool>,
    /// CHSH S parameter, |E(a1,b1) − E(a1,b2) + E(a2,b1) + E(a2,b2)|.
    pub chsh_value: f64,
    /// `true` when the Bell test passed, i.e. `chsh_value > 2.0`.
    pub passed_bell_test: bool,
    /// Key bits obtained per generated pair (key length / number of pairs).
    pub key_rate: f64,
}
290
291impl E91Protocol {
292    /// Create a new E91 protocol instance.
293    pub fn new(n_pairs: usize, noise: f64, rng_seed: u64) -> Self {
294        Self {
295            n_pairs,
296            noise: noise.clamp(0.0, 1.0),
297            rng_seed,
298        }
299    }
300
301    /// Execute the E91 protocol.
302    pub fn run(&self) -> QuantRS2Result<E91Result> {
303        if self.n_pairs == 0 {
304            return Err(QuantRS2Error::InvalidInput(
305                "n_pairs must be > 0".to_string(),
306            ));
307        }
308
309        let mut rng = ChaCha20Rng::from_seed(seed_from_u64(self.rng_seed));
310
311        // Alice's measurement angles: 0°, 45°, 90°
312        let alice_angles = [0.0_f64, PI / 4.0, PI / 2.0];
313        // Bob's measurement angles: 22.5°, 67.5°, 112.5°
314        let bob_angles = [PI / 8.0, 3.0 * PI / 8.0, 5.0 * PI / 8.0];
315
316        // Per-pair measurements
317        let mut alice_results: Vec<(usize, bool)> = Vec::with_capacity(self.n_pairs);
318        let mut bob_results: Vec<(usize, bool)> = Vec::with_capacity(self.n_pairs);
319
320        for _ in 0..self.n_pairs {
321            let mut rho4 = bell_phi_plus_4x4();
322            apply_depolarizing_2q(&mut rho4, self.noise);
323
324            let ai = rng.random_range(0..3_usize);
325            let bi = rng.random_range(0..3_usize);
326
327            let rho_a = partial_trace_2q(&rho4, true);
328            let rho_b = partial_trace_2q(&rho4, false);
329
330            let r_a: f64 = rng.random();
331            let r_b: f64 = rng.random();
332            let a_bit = measure_in_angle(&rho_a, alice_angles[ai], r_a);
333            let b_bit = measure_in_angle(&rho_b, bob_angles[bi], r_b);
334
335            alice_results.push((ai, a_bit));
336            bob_results.push((bi, b_bit));
337        }
338
339        // CHSH S-parameter computed analytically from a representative pair
340        let mut representative = bell_phi_plus_4x4();
341        apply_depolarizing_2q(&mut representative, self.noise);
342
343        // S = |E(0°,22.5°) - E(0°,67.5°) + E(45°,22.5°) + E(45°,67.5°)|
344        let e00 = compute_correlation(&representative, alice_angles[0], bob_angles[0]);
345        let e01 = compute_correlation(&representative, alice_angles[0], bob_angles[1]);
346        let e10 = compute_correlation(&representative, alice_angles[1], bob_angles[0]);
347        let e11 = compute_correlation(&representative, alice_angles[1], bob_angles[1]);
348        let chsh_value = (e00 - e01 + e10 + e11).abs();
349
350        // Key extraction: pairs where alice_idx=0 and bob_idx=0
351        let mut key: Vec<bool> = Vec::new();
352        for k in 0..self.n_pairs {
353            let (ai, a_bit) = alice_results[k];
354            let (bi, _b_bit) = bob_results[k];
355            if ai == 0 && bi == 0 {
356                key.push(a_bit);
357            }
358        }
359
360        let key_rate = key.len() as f64 / self.n_pairs as f64;
361        let passed_bell_test = chsh_value > 2.0;
362
363        Ok(E91Result {
364            key,
365            chsh_value,
366            passed_bell_test,
367            key_rate,
368        })
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Tests
374// ---------------------------------------------------------------------------
375
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Runs the protocol with the given settings, panicking if it fails.
    fn run_protocol(n_pairs: usize, noise: f64, seed: u64) -> E91Result {
        E91Protocol::new(n_pairs, noise, seed)
            .run()
            .expect("e91 run")
    }

    /// Without noise the CHSH value should sit near 2√2 ≈ 2.828.
    #[test]
    fn e91_ideal_chsh_near_2sqrt2() {
        let outcome = run_protocol(200, 0.0, 42);
        let s = outcome.chsh_value;
        assert!(s > 2.5, "expected CHSH ≈ 2√2 ≈ 2.828, got {}", s);
        assert!(outcome.passed_bell_test);
    }

    /// Heavy depolarizing noise must push the CHSH value below the
    /// classical bound of 2.
    #[test]
    fn e91_high_noise_chsh_below_2() {
        let outcome = run_protocol(200, 0.9, 42);
        let s = outcome.chsh_value;
        assert!(s < 2.0, "expected CHSH < 2 with high noise, got {}", s);
        assert!(!outcome.passed_bell_test);
    }

    /// Alice and Bob each pick index 0 with probability 1/3, so the key
    /// rate should be around 1/9.
    #[test]
    fn e91_key_rate_reasonable() {
        let outcome = run_protocol(3000, 0.0, 77);
        let rate = outcome.key_rate;
        assert!(rate > 0.02, "expected reasonable key rate, got {}", rate);
    }

    /// E(0°, 22.5°) for |Φ+⟩ is cos(2·(0 − π/8)) = cos(π/4) ≈ 0.707.
    #[test]
    fn bell_phi_plus_correlation_correct() {
        let state = bell_phi_plus_4x4();
        let expected = (PI / 4.0).cos();
        let measured = compute_correlation(&state, 0.0, PI / 8.0);
        assert_abs_diff_eq!(measured, expected, epsilon = 0.01);
    }
}
425}