Skip to main content

proof_engine/quantum/
spin.rs

1use std::f64::consts::PI;
2use super::schrodinger::Complex;
3use glam::Vec3;
4
5/// Spin-1/2 state: |psi> = up|+> + down|->.
6#[derive(Clone, Debug)]
7pub struct SpinState {
8    pub up: Complex,
9    pub down: Complex,
10}
11
12impl SpinState {
13    pub fn new(up: Complex, down: Complex) -> Self {
14        Self { up, down }
15    }
16
17    pub fn spin_up() -> Self {
18        Self { up: Complex::one(), down: Complex::zero() }
19    }
20
21    pub fn spin_down() -> Self {
22        Self { up: Complex::zero(), down: Complex::one() }
23    }
24
25    pub fn norm_sq(&self) -> f64 {
26        self.up.norm_sq() + self.down.norm_sq()
27    }
28
29    pub fn normalize(&mut self) {
30        let n = self.norm_sq().sqrt();
31        if n > 1e-30 {
32            self.up = self.up / n;
33            self.down = self.down / n;
34        }
35    }
36}
37
38/// Get Bloch sphere angles (theta, phi) from a spin state.
39/// |psi> = cos(theta/2)|+> + e^{i*phi}*sin(theta/2)|->
40pub fn bloch_angles(state: &SpinState) -> (f64, f64) {
41    let mut s = state.clone();
42    s.normalize();
43
44    let r_up = s.up.norm();
45    let r_down = s.down.norm();
46
47    let theta = 2.0 * r_down.atan2(r_up);
48
49    // phi is the relative phase between down and up
50    let phi = if r_down > 1e-12 && r_up > 1e-12 {
51        let phase_up = s.up.arg();
52        let phase_down = s.down.arg();
53        phase_down - phase_up
54    } else {
55        0.0
56    };
57
58    (theta, phi)
59}
60
61/// Create spin state from Bloch sphere angles.
62pub fn from_bloch(theta: f64, phi: f64) -> SpinState {
63    SpinState {
64        up: Complex::new((theta / 2.0).cos(), 0.0),
65        down: Complex::from_polar((theta / 2.0).sin(), phi),
66    }
67}
68
69/// Pauli X gate: |+> <-> |->
70pub fn pauli_x(state: &SpinState) -> SpinState {
71    SpinState {
72        up: state.down,
73        down: state.up,
74    }
75}
76
77/// Pauli Y gate: |+> -> i|-> , |-> -> -i|+>
78pub fn pauli_y(state: &SpinState) -> SpinState {
79    SpinState {
80        up: Complex::new(0.0, -1.0) * state.down,
81        down: Complex::new(0.0, 1.0) * state.up,
82    }
83}
84
85/// Pauli Z gate: |+> -> |+>, |-> -> -|->
86pub fn pauli_z(state: &SpinState) -> SpinState {
87    SpinState {
88        up: state.up,
89        down: -state.down,
90    }
91}
92
93/// Rotate spin state about an axis by an angle.
94/// Uses the SU(2) rotation: exp(-i * angle/2 * n.sigma)
95pub fn rotate_spin(state: &SpinState, axis: Vec3, angle: f64) -> SpinState {
96    let n = axis.normalize();
97    let half = angle / 2.0;
98    let c = half.cos();
99    let s = half.sin();
100    let nx = n.x as f64;
101    let ny = n.y as f64;
102    let nz = n.z as f64;
103
104    // R = cos(a/2)*I - i*sin(a/2)*(nx*sx + ny*sy + nz*sz)
105    // = [[cos - i*nz*sin, (-ny - i*nx)*sin],
106    //    [(ny - i*nx)*sin, cos + i*nz*sin]]
107    let r00 = Complex::new(c, -nz * s);
108    let r01 = Complex::new(-ny * s, -nx * s);
109    let r10 = Complex::new(ny * s, -nx * s);
110    let r11 = Complex::new(c, nz * s);
111
112    SpinState {
113        up: r00 * state.up + r01 * state.down,
114        down: r10 * state.up + r11 * state.down,
115    }
116}
117
118/// Expectation value of spin along an axis: <S.n> = (hbar/2)*<psi|sigma.n|psi>.
119/// Returns the value in units of hbar/2 (i.e., between -1 and 1).
120pub fn spin_expectation(state: &SpinState, axis: Vec3) -> f64 {
121    let n = axis.normalize();
122    let nx = n.x as f64;
123    let ny = n.y as f64;
124    let nz = n.z as f64;
125
126    // sigma.n = [[nz, nx-i*ny],[nx+i*ny, -nz]]
127    let s_up = Complex::new(nz, 0.0) * state.up + Complex::new(nx, -ny) * state.down;
128    let s_down = Complex::new(nx, ny) * state.up + Complex::new(-nz, 0.0) * state.down;
129
130    let exp = state.up.conj() * s_up + state.down.conj() * s_down;
131    exp.re
132}
133
134/// Larmor precession: time evolution in a magnetic field.
135/// H = -gamma * B.sigma, evolve by exp(-iHt/hbar).
136/// For simplicity, gamma*hbar/2 = 1, so omega = |B|.
137pub fn larmor_precession(state: &SpinState, b_field: Vec3, dt: f64) -> SpinState {
138    let b_mag = b_field.length() as f64;
139    if b_mag < 1e-15 {
140        return state.clone();
141    }
142    let axis = b_field / b_field.length();
143    // Precession angle = omega * dt = |B| * dt
144    rotate_spin(state, axis, b_mag * dt)
145}
146
147/// Stern-Gerlach measurement along an axis.
148/// Returns (collapsed state, outcome +1 or -1).
149pub fn stern_gerlach(state: &SpinState, measurement_axis: Vec3, rng_val: f64) -> (SpinState, i8) {
150    let n = measurement_axis.normalize();
151
152    // Eigenstates of sigma.n:
153    // |+n> = cos(theta/2)|+> + e^{i*phi}*sin(theta/2)|->
154    // |-n> = -e^{-i*phi}*sin(theta/2)|+> + cos(theta/2)|->
155    let theta = (n.z as f64).acos();
156    let phi = (n.y as f64).atan2(n.x as f64);
157
158    let plus_n = SpinState {
159        up: Complex::new((theta / 2.0).cos(), 0.0),
160        down: Complex::from_polar((theta / 2.0).sin(), phi),
161    };
162
163    // Probability of +1 outcome
164    let overlap = state.up.conj() * plus_n.up + state.down.conj() * plus_n.down;
165    let prob_plus = overlap.norm_sq();
166
167    if rng_val < prob_plus {
168        (plus_n, 1)
169    } else {
170        let minus_n = SpinState {
171            up: Complex::from_polar(-(theta / 2.0).sin(), -phi),
172            down: Complex::new((theta / 2.0).cos(), 0.0),
173        };
174        (minus_n, -1)
175    }
176}
177
/// ASCII renderer for the Bloch sphere of a single spin-1/2 state.
///
/// Produces a square character grid with a wireframe (outline, equator, vertical
/// axis, poles) and a `*` marking the state's Bloch vector.
pub struct BlochSphereRenderer {
    /// Side length, in characters, of the square output grid.
    pub size: usize,
}
182
183impl BlochSphereRenderer {
184    pub fn new(size: usize) -> Self {
185        Self { size }
186    }
187
188    /// Render the Bloch sphere as ASCII art.
189    pub fn render(&self, state: &SpinState) -> Vec<Vec<char>> {
190        let (theta, phi) = bloch_angles(state);
191        let sx = theta.sin() * phi.cos();
192        let sy = theta.sin() * phi.sin();
193        let sz = theta.cos();
194
195        let s = self.size;
196        let mut grid = vec![vec![' '; s]; s];
197        let cx = s / 2;
198        let cy = s / 2;
199        let r = (s / 2 - 1) as f64;
200
201        // Draw circle outline
202        for i in 0..s {
203            for j in 0..s {
204                let dx = (j as f64 - cx as f64) / r;
205                let dy = (i as f64 - cy as f64) / r;
206                let dist = (dx * dx + dy * dy).sqrt();
207                if (dist - 1.0).abs() < 0.15 {
208                    grid[i][j] = '.';
209                }
210            }
211        }
212
213        // Draw equator
214        for j in 0..s {
215            let dx = (j as f64 - cx as f64) / r;
216            if dx.abs() <= 1.0 {
217                let row = cy;
218                if grid[row][j] == ' ' {
219                    grid[row][j] = '-';
220                }
221            }
222        }
223
224        // Draw vertical axis
225        for i in 0..s {
226            let dy = (i as f64 - cy as f64) / r;
227            if dy.abs() <= 1.0 {
228                if grid[i][cx] == ' ' {
229                    grid[i][cx] = '|';
230                }
231            }
232        }
233
234        // Place poles
235        grid[0][cx] = 'N'; // |+z>
236        grid[s - 1][cx] = 'S'; // |-z>
237
238        // Place state vector endpoint (project to xz plane for 2D view)
239        let state_col = cx as f64 + sx * r;
240        let state_row = cy as f64 - sz * r;
241        let sc = (state_col.round() as usize).min(s - 1);
242        let sr = (state_row.round() as usize).min(s - 1);
243        grid[sr][sc] = '*';
244
245        grid
246    }
247}
248
249/// Heisenberg spin chain with nearest-neighbor coupling.
250#[derive(Clone, Debug)]
251pub struct SpinChain {
252    pub spins: Vec<SpinState>,
253    pub coupling: f64,
254}
255
256impl SpinChain {
257    pub fn new(n: usize, coupling: f64) -> Self {
258        let spins = (0..n).map(|_| SpinState::spin_up()).collect();
259        Self { spins, coupling }
260    }
261
262    /// Simple time evolution step using Trotter decomposition.
263    /// H = J * sum_i sigma_i . sigma_{i+1}
264    pub fn step(&mut self, dt: f64) {
265        let n = self.spins.len();
266        if n < 2 {
267            return;
268        }
269
270        for i in 0..n - 1 {
271            // Approximate two-spin interaction
272            // The effective field on spin i from spin i+1 and vice versa
273            let exp_i = [
274                spin_expectation(&self.spins[i], Vec3::X),
275                spin_expectation(&self.spins[i], Vec3::Y),
276                spin_expectation(&self.spins[i], Vec3::Z),
277            ];
278            let exp_j = [
279                spin_expectation(&self.spins[i + 1], Vec3::X),
280                spin_expectation(&self.spins[i + 1], Vec3::Y),
281                spin_expectation(&self.spins[i + 1], Vec3::Z),
282            ];
283
284            // Mean-field: each spin sees the other as an effective B field
285            let b_on_i = Vec3::new(
286                exp_j[0] as f32,
287                exp_j[1] as f32,
288                exp_j[2] as f32,
289            ) * self.coupling as f32;
290            let b_on_j = Vec3::new(
291                exp_i[0] as f32,
292                exp_i[1] as f32,
293                exp_i[2] as f32,
294            ) * self.coupling as f32;
295
296            self.spins[i] = larmor_precession(&self.spins[i], b_on_i, dt);
297            self.spins[i + 1] = larmor_precession(&self.spins[i + 1], b_on_j, dt);
298        }
299    }
300
301    /// Total magnetization along z: sum of <Sz_i>.
302    pub fn magnetization_z(&self) -> f64 {
303        self.spins
304            .iter()
305            .map(|s| spin_expectation(s, Vec3::Z))
306            .sum()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `actual` is within `tol` of `expected` as a complex number, so both
    /// real and imaginary parts are checked (phase errors are not missed).
    fn assert_amp_close(actual: Complex, expected: Complex, tol: f64, label: &str) {
        let diff = (actual + (-expected)).norm();
        assert!(
            diff < tol,
            "{}: {:?} vs {:?} (|diff| = {})",
            label,
            actual,
            expected,
            diff
        );
    }

    /// Applies [`assert_amp_close`] to both amplitudes of a spinor.
    fn assert_state_close(actual: &SpinState, expected: &SpinState, tol: f64, label: &str) {
        assert_amp_close(actual.up, expected.up, tol, label);
        assert_amp_close(actual.down, expected.down, tol, label);
    }

    #[test]
    fn test_bloch_roundtrip() {
        let theta = 1.2;
        let phi = 0.7;
        let state = from_bloch(theta, phi);
        let (t2, p2) = bloch_angles(&state);
        assert!((t2 - theta).abs() < 1e-10, "theta: {} vs {}", t2, theta);
        assert!((p2 - phi).abs() < 1e-10, "phi: {} vs {}", p2, phi);
    }

    #[test]
    fn test_bloch_poles() {
        let up = SpinState::spin_up();
        let (theta, _) = bloch_angles(&up);
        assert!(theta.abs() < 1e-10, "Up: theta = {}", theta);

        let down = SpinState::spin_down();
        let (theta, _) = bloch_angles(&down);
        assert!((theta - PI).abs() < 1e-10, "Down: theta = {}", theta);
    }

    #[test]
    fn test_pauli_x_algebra() {
        // X^2 = I; complex amplitudes make sure phases are compared too.
        let state = SpinState::new(Complex::new(0.6, 0.1), Complex::new(0.3, 0.7));
        let xx = pauli_x(&pauli_x(&state));
        assert_state_close(&xx, &state, 1e-10, "X^2");
    }

    #[test]
    fn test_pauli_y_algebra() {
        // Y^2 = I
        let state = SpinState::new(Complex::new(0.6, 0.1), Complex::new(0.3, 0.7));
        let yy = pauli_y(&pauli_y(&state));
        assert_state_close(&yy, &state, 1e-10, "Y^2");
    }

    #[test]
    fn test_pauli_y_on_up() {
        // Y|+> = i|->
        let y_up = pauli_y(&SpinState::spin_up());
        assert_amp_close(y_up.up, Complex::zero(), 1e-12, "Y|+> up");
        assert_amp_close(y_up.down, Complex::new(0.0, 1.0), 1e-12, "Y|+> down");
    }

    #[test]
    fn test_pauli_z_algebra() {
        // Z^2 = I
        let state = SpinState::new(Complex::new(0.6, 0.1), Complex::new(0.3, 0.7));
        let zz = pauli_z(&pauli_z(&state));
        assert_state_close(&zz, &state, 1e-10, "Z^2");
    }

    #[test]
    fn test_pauli_anticommutation() {
        // XY = iZ, YX = -iZ => XY + YX = 0
        let state = SpinState::new(Complex::new(0.6, 0.0), Complex::new(0.0, 0.8));
        let xy = pauli_y(&pauli_x(&state));
        let yx = pauli_x(&pauli_y(&state));
        // xy + yx should be zero
        let sum_up = xy.up + yx.up;
        let sum_down = xy.down + yx.down;
        assert!(sum_up.norm() < 1e-10, "XY+YX up = {:?}", sum_up);
        assert!(sum_down.norm() < 1e-10, "XY+YX down = {:?}", sum_down);
    }

    #[test]
    fn test_spin_expectation_z() {
        let up = SpinState::spin_up();
        let ez = spin_expectation(&up, Vec3::Z);
        assert!((ez - 1.0).abs() < 1e-10);

        let down = SpinState::spin_down();
        let ez = spin_expectation(&down, Vec3::Z);
        assert!((ez - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_spin_expectation_x() {
        // |+x> = (|+> + |->)/sqrt(2)
        let s = 1.0 / 2.0_f64.sqrt();
        let plus_x = SpinState::new(Complex::new(s, 0.0), Complex::new(s, 0.0));
        let ex = spin_expectation(&plus_x, Vec3::X);
        assert!((ex - 1.0).abs() < 1e-10, "<Sx> = {}", ex);
    }

    #[test]
    fn test_spin_expectation_y() {
        // |+y> = (|+> + i|->)/sqrt(2)
        let s = 1.0 / 2.0_f64.sqrt();
        let plus_y = SpinState::new(Complex::new(s, 0.0), Complex::new(0.0, s));
        let ey = spin_expectation(&plus_y, Vec3::Y);
        assert!((ey - 1.0).abs() < 1e-10, "<Sy> = {}", ey);
    }

    #[test]
    fn test_rotation_360_identity() {
        let mut s = SpinState::new(Complex::new(0.6, 0.1), Complex::new(0.3, 0.7));
        s.normalize();
        // 4pi rotation = identity for spin-1/2 ...
        let full_turn = rotate_spin(&s, Vec3::Z, 4.0 * PI);
        assert_state_close(&full_turn, &s, 1e-8, "4pi rotation");
        // ... while a 2pi rotation flips the overall sign.
        let half_turn = rotate_spin(&s, Vec3::Z, 2.0 * PI);
        let negated = SpinState::new(-s.up, -s.down);
        assert_state_close(&half_turn, &negated, 1e-8, "2pi rotation");
    }

    #[test]
    fn test_rotation_pi_z_flips_x() {
        // Rotating |+x> by pi about z should give |-x>
        let s = 1.0 / 2.0_f64.sqrt();
        let plus_x = SpinState::new(Complex::new(s, 0.0), Complex::new(s, 0.0));
        let rotated = rotate_spin(&plus_x, Vec3::Z, PI);
        let ex = spin_expectation(&rotated, Vec3::X);
        assert!((ex - (-1.0)).abs() < 1e-8, "<Sx> after pi-z rotation: {}", ex);
    }

    #[test]
    fn test_stern_gerlach_statistics() {
        // |+z> measured along z should always give +1 and stay |+z>
        let up = SpinState::spin_up();
        for &rng in &[0.0, 0.3, 0.9, 0.999] {
            let (collapsed, outcome) = stern_gerlach(&up, Vec3::Z, rng);
            assert_eq!(outcome, 1, "rng = {}", rng);
            let ez = spin_expectation(&collapsed, Vec3::Z);
            assert!((ez - 1.0).abs() < 1e-10, "collapsed <Sz> = {}", ez);
        }
    }

    #[test]
    fn test_stern_gerlach_x_on_z_up() {
        // |+z> measured along x should give 50/50, collapsing onto |+x> or |-x>
        let up = SpinState::spin_up();
        let (plus, o1) = stern_gerlach(&up, Vec3::X, 0.2);
        let (minus, o2) = stern_gerlach(&up, Vec3::X, 0.8);
        assert_eq!(o1, 1);
        assert_eq!(o2, -1);
        assert!((spin_expectation(&plus, Vec3::X) - 1.0).abs() < 1e-10);
        assert!((spin_expectation(&minus, Vec3::X) + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_larmor_precession() {
        // Spin up in an x-field precesses about x: <Sz>(t) = cos(|B| t).
        let b = Vec3::X * 1.0;
        let mut s = SpinState::spin_up();
        for _ in 0..1000 {
            s = larmor_precession(&s, b, 0.01);
        }
        // Evolved for t = 10 (about 1.59 periods of 2*pi).
        assert!((s.norm_sq() - 1.0).abs() < 1e-6);
        let sz = spin_expectation(&s, Vec3::Z);
        assert!((sz - 10.0_f64.cos()).abs() < 1e-6, "<Sz> = {}", sz);
    }

    #[test]
    fn test_bloch_sphere_renderer() {
        let state = SpinState::spin_up();
        let renderer = BlochSphereRenderer::new(15);
        let grid = renderer.render(&state);
        assert_eq!(grid.len(), 15);
        assert!(grid.iter().all(|row| row.len() == 15));
        assert_eq!(grid[0][7], 'N');
        assert_eq!(grid[14][7], 'S');
        // Spin up points along +z: top of the circle (radius 6), centre column.
        assert_eq!(grid[1][7], '*');
    }

    #[test]
    fn test_spin_chain() {
        let mut chain = SpinChain::new(4, 1.0);
        // Flip one spin
        chain.spins[0] = SpinState::spin_down();
        let m_before = chain.magnetization_z();
        chain.step(0.01);
        let m_after = chain.magnetization_z();
        // Magnetization should be roughly conserved in Heisenberg model
        assert!((m_before - m_after).abs() < 0.5);
    }
}