1use crate::error::{QuantRS2Error, QuantRS2Result};
13use crate::networking::channel::measure_computational;
14use scirs2_core::ndarray::Array2;
15use scirs2_core::random::prelude::*;
16use scirs2_core::random::ChaCha20Rng;
17use scirs2_core::Complex64;
18use std::f64::consts::PI;
19
/// Expands a 64-bit seed into the 32-byte seed array used to initialise the RNG.
///
/// The little-endian encoding of `seed` is written four times back to back,
/// so bytes `0..8`, `8..16`, `16..24` and `24..32` are all identical.
fn seed_from_u64(seed: u64) -> [u8; 32] {
    let word = seed.to_le_bytes();
    let mut expanded = [0u8; 32];
    // 32 is an exact multiple of 8, so every chunk is a full 8-byte window.
    for window in expanded.chunks_exact_mut(word.len()) {
        window.copy_from_slice(&word);
    }
    expanded
}
32
33fn bell_phi_plus_4x4() -> Array2<Complex64> {
40 let v = 0.5_f64;
41 let mut rho = Array2::<Complex64>::zeros((4, 4));
42 rho[[0, 0]] = Complex64::new(v, 0.0);
43 rho[[0, 3]] = Complex64::new(v, 0.0);
44 rho[[3, 0]] = Complex64::new(v, 0.0);
45 rho[[3, 3]] = Complex64::new(v, 0.0);
46 rho
47}
48
49fn partial_trace_2q(rho4: &Array2<Complex64>, keep_first: bool) -> Array2<Complex64> {
53 let mut out = Array2::<Complex64>::zeros((2, 2));
54 if keep_first {
55 for a in 0..2_usize {
56 for a2 in 0..2_usize {
57 for b in 0..2_usize {
58 out[[a, a2]] += rho4[[2 * a + b, 2 * a2 + b]];
59 }
60 }
61 }
62 } else {
63 for b in 0..2_usize {
64 for b2 in 0..2_usize {
65 for a in 0..2_usize {
66 out[[b, b2]] += rho4[[2 * a + b, 2 * a + b2]];
67 }
68 }
69 }
70 }
71 out
72}
73
74fn apply_depolarizing_2q(rho4: &mut Array2<Complex64>, p: f64) {
79 if p <= 0.0 {
80 return;
81 }
82 apply_depolarizing_qubit_a(rho4, p);
83 apply_depolarizing_qubit_b(rho4, p);
84}
85
86fn apply_depolarizing_qubit_a(rho4: &mut Array2<Complex64>, p: f64) {
87 let rho_orig = rho4.clone();
88 let scale_id = Complex64::new(1.0 - p, 0.0);
89 let scale_p = Complex64::new(p / 3.0, 0.0);
90
91 rho4.mapv_inplace(|v| v * scale_id);
92
93 let mut t1 = Array2::<Complex64>::zeros((4, 4));
95 for i in 0..4 {
96 for j in 0..4 {
97 t1[[i, j]] = rho_orig[[i ^ 2, j ^ 2]];
98 }
99 }
100 add_scaled(rho4, &t1, scale_p);
101
102 let phase_a = |i: usize| -> Complex64 {
104 if (i >> 1) & 1 == 0 {
105 Complex64::new(0.0, 1.0)
106 } else {
107 Complex64::new(0.0, -1.0)
108 }
109 };
110 let mut t2 = Array2::<Complex64>::zeros((4, 4));
111 for i in 0..4 {
112 for j in 0..4 {
113 t2[[i, j]] = phase_a(i) * rho_orig[[i ^ 2, j ^ 2]] * phase_a(j).conj();
114 }
115 }
116 add_scaled(rho4, &t2, scale_p);
117
118 let sign_a = |i: usize| -> f64 {
120 if (i >> 1) & 1 == 0 {
121 1.0
122 } else {
123 -1.0
124 }
125 };
126 let mut t3 = Array2::<Complex64>::zeros((4, 4));
127 for i in 0..4 {
128 for j in 0..4 {
129 t3[[i, j]] = Complex64::new(sign_a(i) * sign_a(j), 0.0) * rho_orig[[i, j]];
130 }
131 }
132 add_scaled(rho4, &t3, scale_p);
133}
134
135fn apply_depolarizing_qubit_b(rho4: &mut Array2<Complex64>, p: f64) {
136 let rho_orig = rho4.clone();
137 let scale_id = Complex64::new(1.0 - p, 0.0);
138 let scale_p = Complex64::new(p / 3.0, 0.0);
139
140 rho4.mapv_inplace(|v| v * scale_id);
141
142 let mut t1 = Array2::<Complex64>::zeros((4, 4));
144 for i in 0..4 {
145 for j in 0..4 {
146 t1[[i, j]] = rho_orig[[i ^ 1, j ^ 1]];
147 }
148 }
149 add_scaled(rho4, &t1, scale_p);
150
151 let phase_b = |i: usize| -> Complex64 {
153 if i & 1 == 0 {
154 Complex64::new(0.0, 1.0)
155 } else {
156 Complex64::new(0.0, -1.0)
157 }
158 };
159 let mut t2 = Array2::<Complex64>::zeros((4, 4));
160 for i in 0..4 {
161 for j in 0..4 {
162 t2[[i, j]] = phase_b(i) * rho_orig[[i ^ 1, j ^ 1]] * phase_b(j).conj();
163 }
164 }
165 add_scaled(rho4, &t2, scale_p);
166
167 let sign_b = |i: usize| -> f64 {
169 if i & 1 == 0 {
170 1.0
171 } else {
172 -1.0
173 }
174 };
175 let mut t3 = Array2::<Complex64>::zeros((4, 4));
176 for i in 0..4 {
177 for j in 0..4 {
178 t3[[i, j]] = Complex64::new(sign_b(i) * sign_b(j), 0.0) * rho_orig[[i, j]];
179 }
180 }
181 add_scaled(rho4, &t3, scale_p);
182}
183
184fn add_scaled(dest: &mut Array2<Complex64>, src: &Array2<Complex64>, scale: Complex64) {
185 for i in 0..dest.nrows() {
186 for j in 0..dest.ncols() {
187 dest[[i, j]] += scale * src[[i, j]];
188 }
189 }
190}
191
192fn measure_in_angle(rho: &Array2<Complex64>, theta: f64, rand_val: f64) -> bool {
201 let c = theta.cos();
202 let s = theta.sin();
203 let p_plus = (c * c * rho[[0, 0]].re + 2.0 * c * s * rho[[0, 1]].re + s * s * rho[[1, 1]].re)
205 .clamp(0.0, 1.0);
206 rand_val >= p_plus
208}
209
210fn compute_correlation(rho4: &Array2<Complex64>, theta_a: f64, theta_b: f64) -> f64 {
212 let ca = theta_a.cos();
213 let sa = theta_a.sin();
214 let cb = theta_b.cos();
215 let sb = theta_b.sin();
216
217 let m_plus_a = [
218 [Complex64::new(ca * ca, 0.0), Complex64::new(ca * sa, 0.0)],
219 [Complex64::new(ca * sa, 0.0), Complex64::new(sa * sa, 0.0)],
220 ];
221 let m_minus_a = [
222 [Complex64::new(sa * sa, 0.0), Complex64::new(-ca * sa, 0.0)],
223 [Complex64::new(-ca * sa, 0.0), Complex64::new(ca * ca, 0.0)],
224 ];
225 let m_plus_b = [
226 [Complex64::new(cb * cb, 0.0), Complex64::new(cb * sb, 0.0)],
227 [Complex64::new(cb * sb, 0.0), Complex64::new(sb * sb, 0.0)],
228 ];
229 let m_minus_b = [
230 [Complex64::new(sb * sb, 0.0), Complex64::new(-cb * sb, 0.0)],
231 [Complex64::new(-cb * sb, 0.0), Complex64::new(cb * cb, 0.0)],
232 ];
233
234 let p_pp = trace_joint(&m_plus_a, &m_plus_b, rho4);
235 let p_pm = trace_joint(&m_plus_a, &m_minus_b, rho4);
236 let p_mp = trace_joint(&m_minus_a, &m_plus_b, rho4);
237 let p_mm = trace_joint(&m_minus_a, &m_minus_b, rho4);
238
239 (p_pp + p_mm - p_pm - p_mp).clamp(-1.0, 1.0)
240}
241
242fn trace_joint(
244 m_a: &[[Complex64; 2]; 2],
245 m_b: &[[Complex64; 2]; 2],
246 rho4: &Array2<Complex64>,
247) -> f64 {
248 let mut result = Complex64::new(0.0, 0.0);
249 for ia in 0..2 {
250 for ib in 0..2 {
251 for ja in 0..2 {
252 for jb in 0..2 {
253 let i = 2 * ia + ib;
254 let j = 2 * ja + jb;
255 result += m_a[ia][ja] * m_b[ib][jb] * rho4[[i, j]];
256 }
257 }
258 }
259 }
260 result.re.clamp(0.0, 1.0)
261}
262
/// Configuration for a simulated run of the E91 entanglement-based
/// key-distribution protocol.
#[derive(Debug, Clone)]
pub struct E91Protocol {
    /// Number of entangled pairs (measurement rounds) to simulate.
    pub n_pairs: usize,
    /// Depolarizing probability applied to each qubit of every pair.
    /// `E91Protocol::new` clamps this to `[0, 1]`.
    pub noise: f64,
    /// Seed from which the run's random number generator is initialised.
    pub rng_seed: u64,
}
277
/// Outcome of a simulated E91 run.
#[derive(Debug, Clone)]
pub struct E91Result {
    /// Sifted key: Alice's measurement bits from the rounds in which both
    /// parties selected their first (index 0) measurement angle.
    pub key: Vec<bool>,
    /// Absolute value of the CHSH combination of the four correlations
    /// computed by the run.
    pub chsh_value: f64,
    /// `true` when `chsh_value` is strictly greater than 2.
    pub passed_bell_test: bool,
    /// Length of `key` divided by the number of simulated pairs.
    pub key_rate: f64,
}
290
291impl E91Protocol {
292 pub fn new(n_pairs: usize, noise: f64, rng_seed: u64) -> Self {
294 Self {
295 n_pairs,
296 noise: noise.clamp(0.0, 1.0),
297 rng_seed,
298 }
299 }
300
301 pub fn run(&self) -> QuantRS2Result<E91Result> {
303 if self.n_pairs == 0 {
304 return Err(QuantRS2Error::InvalidInput(
305 "n_pairs must be > 0".to_string(),
306 ));
307 }
308
309 let mut rng = ChaCha20Rng::from_seed(seed_from_u64(self.rng_seed));
310
311 let alice_angles = [0.0_f64, PI / 4.0, PI / 2.0];
313 let bob_angles = [PI / 8.0, 3.0 * PI / 8.0, 5.0 * PI / 8.0];
315
316 let mut alice_results: Vec<(usize, bool)> = Vec::with_capacity(self.n_pairs);
318 let mut bob_results: Vec<(usize, bool)> = Vec::with_capacity(self.n_pairs);
319
320 for _ in 0..self.n_pairs {
321 let mut rho4 = bell_phi_plus_4x4();
322 apply_depolarizing_2q(&mut rho4, self.noise);
323
324 let ai = rng.random_range(0..3_usize);
325 let bi = rng.random_range(0..3_usize);
326
327 let rho_a = partial_trace_2q(&rho4, true);
328 let rho_b = partial_trace_2q(&rho4, false);
329
330 let r_a: f64 = rng.random();
331 let r_b: f64 = rng.random();
332 let a_bit = measure_in_angle(&rho_a, alice_angles[ai], r_a);
333 let b_bit = measure_in_angle(&rho_b, bob_angles[bi], r_b);
334
335 alice_results.push((ai, a_bit));
336 bob_results.push((bi, b_bit));
337 }
338
339 let mut representative = bell_phi_plus_4x4();
341 apply_depolarizing_2q(&mut representative, self.noise);
342
343 let e00 = compute_correlation(&representative, alice_angles[0], bob_angles[0]);
345 let e01 = compute_correlation(&representative, alice_angles[0], bob_angles[1]);
346 let e10 = compute_correlation(&representative, alice_angles[1], bob_angles[0]);
347 let e11 = compute_correlation(&representative, alice_angles[1], bob_angles[1]);
348 let chsh_value = (e00 - e01 + e10 + e11).abs();
349
350 let mut key: Vec<bool> = Vec::new();
352 for k in 0..self.n_pairs {
353 let (ai, a_bit) = alice_results[k];
354 let (bi, _b_bit) = bob_results[k];
355 if ai == 0 && bi == 0 {
356 key.push(a_bit);
357 }
358 }
359
360 let key_rate = key.len() as f64 / self.n_pairs as f64;
361 let passed_bell_test = chsh_value > 2.0;
362
363 Ok(E91Result {
364 key,
365 chsh_value,
366 passed_bell_test,
367 key_rate,
368 })
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a protocol with the given parameters and runs it, panicking
    /// with the shared message if the run fails.
    fn run_protocol(n_pairs: usize, noise: f64, seed: u64) -> E91Result {
        E91Protocol::new(n_pairs, noise, seed)
            .run()
            .expect("e91 run")
    }

    /// Without noise the CHSH value should be close to the quantum maximum.
    #[test]
    fn e91_ideal_chsh_near_2sqrt2() {
        let outcome = run_protocol(200, 0.0, 42);
        assert!(
            outcome.chsh_value > 2.5,
            "expected CHSH ≈ 2√2 ≈ 2.828, got {}",
            outcome.chsh_value
        );
        assert!(outcome.passed_bell_test);
    }

    /// Strong depolarizing noise must push the CHSH value below 2.
    #[test]
    fn e91_high_noise_chsh_below_2() {
        let outcome = run_protocol(200, 0.9, 42);
        assert!(
            outcome.chsh_value < 2.0,
            "expected CHSH < 2 with high noise, got {}",
            outcome.chsh_value
        );
        assert!(!outcome.passed_bell_test);
    }

    /// A long run should keep a non-trivial fraction of rounds for the key.
    #[test]
    fn e91_key_rate_reasonable() {
        let outcome = run_protocol(3000, 0.0, 77);
        assert!(
            outcome.key_rate > 0.02,
            "expected reasonable key rate, got {}",
            outcome.key_rate
        );
    }

    /// For |Φ⁺⟩ the correlation at angles (0, π/8) equals cos(π/4).
    #[test]
    fn bell_phi_plus_correlation_correct() {
        let bell = bell_phi_plus_4x4();
        let expected = (PI / 4.0).cos();
        let correlation = compute_correlation(&bell, 0.0, PI / 8.0);
        assert_abs_diff_eq!(correlation, expected, epsilon = 0.01);
    }
}