1use std::f64::consts::PI;
2use super::schrodinger::Complex;
3
4#[derive(Clone, Debug)]
6pub struct QubitState {
7 pub alpha: Complex,
8 pub beta: Complex,
9}
10
11impl QubitState {
12 pub fn new(alpha: Complex, beta: Complex) -> Self {
13 Self { alpha, beta }
14 }
15
16 pub fn zero() -> Self {
17 Self { alpha: Complex::one(), beta: Complex::zero() }
18 }
19
20 pub fn one() -> Self {
21 Self { alpha: Complex::zero(), beta: Complex::one() }
22 }
23
24 pub fn norm_sq(&self) -> f64 {
25 self.alpha.norm_sq() + self.beta.norm_sq()
26 }
27
28 pub fn normalize(&mut self) {
29 let n = self.norm_sq().sqrt();
30 if n > 1e-30 {
31 self.alpha = self.alpha / n;
32 self.beta = self.beta / n;
33 }
34 }
35}
36
37#[derive(Clone, Debug)]
39pub struct TwoQubitState {
40 pub amplitudes: [Complex; 4],
41}
42
43impl TwoQubitState {
44 pub fn new(amplitudes: [Complex; 4]) -> Self {
45 Self { amplitudes }
46 }
47
48 pub fn product(a: &QubitState, b: &QubitState) -> Self {
50 Self {
51 amplitudes: [
52 a.alpha * b.alpha, a.alpha * b.beta, a.beta * b.alpha, a.beta * b.beta, ],
57 }
58 }
59
60 pub fn norm_sq(&self) -> f64 {
61 self.amplitudes.iter().map(|c| c.norm_sq()).sum()
62 }
63
64 pub fn normalize(&mut self) {
65 let n = self.norm_sq().sqrt();
66 if n > 1e-30 {
67 for a in &mut self.amplitudes {
68 *a = *a / n;
69 }
70 }
71 }
72}
73
74pub fn bell_state(which: u8) -> TwoQubitState {
80 let s = 1.0 / 2.0_f64.sqrt();
81 match which {
82 0 => TwoQubitState::new([
83 Complex::new(s, 0.0), Complex::zero(),
84 Complex::zero(), Complex::new(s, 0.0),
85 ]),
86 1 => TwoQubitState::new([
87 Complex::new(s, 0.0), Complex::zero(),
88 Complex::zero(), Complex::new(-s, 0.0),
89 ]),
90 2 => TwoQubitState::new([
91 Complex::zero(), Complex::new(s, 0.0),
92 Complex::new(s, 0.0), Complex::zero(),
93 ]),
94 _ => TwoQubitState::new([
95 Complex::zero(), Complex::new(s, 0.0),
96 Complex::new(-s, 0.0), Complex::zero(),
97 ]),
98 }
99}
100
101pub fn measure_qubit(state: &TwoQubitState, which: usize, rng_val: f64) -> (u8, QubitState) {
104 let a = &state.amplitudes;
105 if which == 0 {
106 let p0 = a[0].norm_sq() + a[1].norm_sq(); if rng_val < p0 {
109 let mut q = QubitState::new(a[0], a[1]);
111 q.normalize();
112 (0, q)
113 } else {
114 let mut q = QubitState::new(a[2], a[3]);
115 q.normalize();
116 (1, q)
117 }
118 } else {
119 let p0 = a[0].norm_sq() + a[2].norm_sq();
121 if rng_val < p0 {
122 let mut q = QubitState::new(a[0], a[2]);
123 q.normalize();
124 (0, q)
125 } else {
126 let mut q = QubitState::new(a[1], a[3]);
127 q.normalize();
128 (1, q)
129 }
130 }
131}
132
133#[derive(Clone, Debug)]
135pub struct DensityMatrix2x2 {
136 pub rho: [[Complex; 2]; 2],
137}
138
139impl DensityMatrix2x2 {
140 pub fn trace(&self) -> f64 {
141 (self.rho[0][0] + self.rho[1][1]).re
142 }
143
144 pub fn purity(&self) -> f64 {
145 let mut sum = Complex::zero();
146 for i in 0..2 {
147 for j in 0..2 {
148 sum += self.rho[i][j] * self.rho[j][i];
149 }
150 }
151 sum.re
152 }
153
154 pub fn is_mixed(&self) -> bool {
155 self.purity() < 1.0 - 1e-6
156 }
157}
158
159pub fn partial_trace(state: &TwoQubitState, trace_out: usize) -> DensityMatrix2x2 {
161 let a = &state.amplitudes;
162 if trace_out == 1 {
163 let rho00 = a[0] * a[0].conj() + a[1] * a[1].conj();
165 let rho01 = a[0] * a[2].conj() + a[1] * a[3].conj();
166 let rho10 = a[2] * a[0].conj() + a[3] * a[1].conj();
167 let rho11 = a[2] * a[2].conj() + a[3] * a[3].conj();
168 DensityMatrix2x2 { rho: [[rho00, rho01], [rho10, rho11]] }
169 } else {
170 let rho00 = a[0] * a[0].conj() + a[2] * a[2].conj();
172 let rho01 = a[0] * a[1].conj() + a[2] * a[3].conj();
173 let rho10 = a[1] * a[0].conj() + a[3] * a[2].conj();
174 let rho11 = a[1] * a[1].conj() + a[3] * a[3].conj();
175 DensityMatrix2x2 { rho: [[rho00, rho01], [rho10, rho11]] }
176 }
177}
178
179pub fn concurrence(state: &TwoQubitState) -> f64 {
182 let a = state.amplitudes[0];
183 let b = state.amplitudes[1];
184 let c = state.amplitudes[2];
185 let d = state.amplitudes[3];
186 2.0 * (a * d - b * c).norm()
187}
188
189pub fn chsh_correlation(
193 state: &TwoQubitState,
194 a1: f64,
195 a2: f64,
196 b1: f64,
197 b2: f64,
198) -> f64 {
199 let e = |a_angle: f64, b_angle: f64| -> f64 {
200 let ca = a_angle.cos();
203 let sa = a_angle.sin();
204 let cb = b_angle.cos();
205 let sb = b_angle.sin();
206
207 let amp = &state.amplitudes;
208 let mut result = Complex::zero();
212 let a_mat = [[ca, sa], [sa, -ca]];
213 let b_mat = [[cb, sb], [sb, -cb]];
214
215 for i in 0..2 {
216 for j in 0..2 {
217 let bra_idx = i * 2 + j;
218 for k in 0..2 {
219 for l in 0..2 {
220 let ket_idx = k * 2 + l;
221 let coeff = a_mat[i][k] * b_mat[j][l];
222 result += amp[bra_idx].conj() * amp[ket_idx] * coeff;
223 }
224 }
225 }
226 }
227 result.re
228 };
229
230 let s = e(a1, b1) - e(a1, b2) + e(a2, b1) + e(a2, b2);
231 s.abs()
232}
233
234pub struct EntanglementRenderer {
236 pub width: usize,
237}
238
239impl EntanglementRenderer {
240 pub fn new(width: usize) -> Self {
241 Self { width }
242 }
243
244 pub fn render(&self, state: &TwoQubitState, measured: Option<(u8, u8)>) -> Vec<(char, f64, f64, f64)> {
246 let mut result = Vec::with_capacity(self.width);
247 let mid = self.width / 2;
248
249 for i in 0..self.width {
250 if let Some((m0, m1)) = measured {
251 if i < mid {
253 let ch = if m0 == 0 { '0' } else { '1' };
254 result.push((ch, 0.0, 1.0, 0.0));
255 } else {
256 let ch = if m1 == 0 { '0' } else { '1' };
257 result.push((ch, 1.0, 0.0, 0.0));
258 }
259 } else {
260 if i == mid - 2 || i == mid + 1 {
262 let prob = if i < mid {
263 state.amplitudes[0].norm_sq() + state.amplitudes[1].norm_sq()
264 } else {
265 state.amplitudes[0].norm_sq() + state.amplitudes[2].norm_sq()
266 };
267 let brightness = prob.min(1.0);
268 result.push(('*', brightness, brightness, 0.5));
269 } else if i == mid - 1 || i == mid {
270 result.push(('~', 0.3, 0.3, 0.8)); } else {
272 result.push((' ', 0.0, 0.0, 0.0));
273 }
274 }
275 }
276 result
277 }
278}
279
280#[derive(Clone, Debug)]
282pub struct GHZState {
283 pub n_qubits: usize,
284 pub amplitudes: Vec<Complex>,
285}
286
287impl GHZState {
288 pub fn new(n_qubits: usize) -> Self {
289 let size = 1 << n_qubits;
290 let mut amplitudes = vec![Complex::zero(); size];
291 let s = 1.0 / 2.0_f64.sqrt();
292 amplitudes[0] = Complex::new(s, 0.0); amplitudes[size - 1] = Complex::new(s, 0.0); Self { n_qubits, amplitudes }
295 }
296
297 pub fn norm_sq(&self) -> f64 {
298 self.amplitudes.iter().map(|c| c.norm_sq()).sum()
299 }
300
301 pub fn measure(&self, rng_val: f64) -> Vec<u8> {
303 let n = self.amplitudes.len();
304 let mut cumulative = 0.0;
305 let mut outcome = 0;
306 for i in 0..n {
307 cumulative += self.amplitudes[i].norm_sq();
308 if rng_val < cumulative {
309 outcome = i;
310 break;
311 }
312 }
313 (0..self.n_qubits)
315 .map(|bit| ((outcome >> (self.n_qubits - 1 - bit)) & 1) as u8)
316 .collect()
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for checks that should hold up to floating-point rounding.
    const EPS: f64 = 1e-10;

    /// CHSH measurement angles `(a1, a2, b1, b2)` shared by both CHSH tests.
    const CHSH_ANGLES: (f64, f64, f64, f64) = (0.0, PI / 2.0, PI / 4.0, -PI / 4.0);

    #[test]
    fn test_bell_states_normalized() {
        for which in 0..4u8 {
            let norm = bell_state(which).norm_sq();
            assert!((norm - 1.0).abs() < EPS, "Bell state {} norm: {}", which, norm);
        }
    }

    #[test]
    fn test_bell_state_maximally_entangled() {
        for which in 0..4u8 {
            let c = concurrence(&bell_state(which));
            assert!((c - 1.0).abs() < EPS, "Bell state {} concurrence: {}", which, c);
        }
    }

    #[test]
    fn test_product_state_not_entangled() {
        let state = TwoQubitState::product(&QubitState::zero(), &QubitState::zero());
        let c = concurrence(&state);
        assert!(c < EPS, "Product state concurrence: {}", c);
    }

    #[test]
    fn test_measurement_correlation() {
        // Both qubits of |Φ+⟩ always agree, so the survivor must match the outcome.
        let (outcome, remaining) = measure_qubit(&bell_state(0), 0, 0.1);
        match outcome {
            0 => assert!(remaining.alpha.norm_sq() > 0.9),
            _ => assert!(remaining.beta.norm_sq() > 0.9),
        }
    }

    #[test]
    fn test_partial_trace_bell_gives_mixed() {
        let rho = partial_trace(&bell_state(0), 1);
        assert!(rho.is_mixed(), "Partial trace of Bell state should be mixed");
        let purity = rho.purity();
        assert!((purity - 0.5).abs() < EPS, "Purity: {}", purity);
    }

    #[test]
    fn test_partial_trace_product_gives_pure() {
        let h = 1.0 / 2.0_f64.sqrt();
        let plus = QubitState::new(Complex::new(h, 0.0), Complex::new(h, 0.0));
        let state = TwoQubitState::product(&plus, &QubitState::zero());
        let purity = partial_trace(&state, 1).purity();
        assert!((purity - 1.0).abs() < EPS, "Product state purity: {}", purity);
    }

    #[test]
    fn test_chsh_violation() {
        let (a1, a2, b1, b2) = CHSH_ANGLES;
        let s = chsh_correlation(&bell_state(0), a1, a2, b1, b2);
        assert!(s > 2.0, "CHSH S = {} should violate Bell inequality (> 2)", s);
        assert!((s - 2.0 * 2.0_f64.sqrt()).abs() < 0.3, "S = {} should be ~2.828", s);
    }

    #[test]
    fn test_chsh_classical_bound() {
        let (a1, a2, b1, b2) = CHSH_ANGLES;
        let state = TwoQubitState::product(&QubitState::zero(), &QubitState::zero());
        let s = chsh_correlation(&state, a1, a2, b1, b2);
        assert!(s <= 2.1, "Product state S = {} should be <= 2", s);
    }

    #[test]
    fn test_ghz_state() {
        let ghz = GHZState::new(3);
        assert_eq!(ghz.amplitudes.len(), 8);
        let norm = ghz.norm_sq();
        assert!((norm - 1.0).abs() < 1e-10);

        // A GHZ measurement always yields all-zeros or all-ones.
        let result_0 = ghz.measure(0.1);
        assert!(result_0.iter().all(|&b| b == 0) || result_0.iter().all(|&b| b == 1));
        let result_1 = ghz.measure(0.9);
        assert!(result_1.iter().all(|&b| b == 0) || result_1.iter().all(|&b| b == 1));
    }

    #[test]
    fn test_renderer() {
        let renderer = EntanglementRenderer::new(20);
        let state = bell_state(0);
        assert_eq!(renderer.render(&state, None).len(), 20);
        assert_eq!(renderer.render(&state, Some((0, 0))).len(), 20);
    }
}