// proof_engine/quantum/harmonic.rs

1use std::f64::consts::PI;
2use super::schrodinger::Complex;
3
/// Returns the energy eigenvalue `E_n = ħω(n + ½)` of the `n`th stationary
/// state of a quantum harmonic oscillator with angular frequency `omega`
/// and reduced Planck constant `hbar`.
pub fn qho_energy(n: u32, omega: f64, hbar: f64) -> f64 {
    let quantum_number = f64::from(n) + 0.5;
    hbar * omega * quantum_number
}
8
/// Evaluates the physicists' Hermite polynomial `H_n(x)`.
///
/// Uses the three-term recurrence
/// `H_0 = 1`, `H_1 = 2x`, `H_{k+1} = 2x·H_k − 2k·H_{k−1}`,
/// which needs `O(n)` multiplications and no allocation.
pub fn hermite_polynomial(n: u32, x: f64) -> f64 {
    let two_x = 2.0 * x;
    match n {
        0 => 1.0,
        1 => two_x,
        _ => {
            // Carry the pair (H_{k-1}, H_k) forward until k == n.
            let (_, h_n) = (1..n).fold((1.0, two_x), |(h_km1, h_k), k| {
                (h_k, two_x * h_k - 2.0 * f64::from(k) * h_km1)
            });
            h_n
        }
    }
}
27
/// Evaluates the `n`th normalized QHO eigenfunction
/// `ψ_n(x) = (α/√π)^{1/2} (2^n n!)^{-1/2} H_n(ξ) e^{-ξ²/2}`,
/// where `α = sqrt(m·ω/ħ)` and `ξ = α·x`.
///
/// Multiplying a (possibly huge) Hermite polynomial by the normalization
/// `(2^n n!)^{-1/2}` breaks down for large `n`: `n!` overflows to infinity
/// (around `n ≈ 170`, and `2^n n!` even earlier), giving `0` or `NaN`.
/// This instead evaluates the normalized Hermite functions directly with
/// the stable forward recurrence
///
/// `φ_0 = π^{-1/4} e^{-ξ²/2}`, `φ_1 = √2 ξ φ_0`,
/// `φ_{k+1} = √(2/(k+1)) ξ φ_k − √(k/(k+1)) φ_{k−1}`,
///
/// and returns `√α · φ_n(ξ)`. This is mathematically identical to the closed
/// form, and every intermediate stays on the scale of the result.
pub fn qho_wavefunction(n: u32, x: f64, omega: f64, mass: f64, hbar: f64) -> f64 {
    let alpha = (mass * omega / hbar).sqrt();
    let xi = alpha * x;

    // φ_0(ξ): the normalized Gaussian ground state in dimensionless units.
    let phi_0 = PI.powf(-0.25) * (-xi * xi / 2.0).exp();
    if n == 0 {
        return alpha.sqrt() * phi_0;
    }

    let mut phi_prev = phi_0;
    let mut phi_curr = 2.0_f64.sqrt() * xi * phi_0;
    for k in 1..n {
        let kf = f64::from(k);
        let phi_next =
            (2.0 / (kf + 1.0)).sqrt() * xi * phi_curr - (kf / (kf + 1.0)).sqrt() * phi_prev;
        phi_prev = phi_curr;
        phi_curr = phi_next;
    }
    // Rescale from ξ-space back to x-space so that ∫|ψ_n|² dx = 1.
    alpha.sqrt() * phi_curr
}
41
42/// Creation (raising) operator: a+ |n> = sqrt(n+1) |n+1>
43/// Applied numerically: a+ psi(x) = (1/sqrt(2)) * (xi - d/dxi) psi(x) in scaled coords
44pub fn qho_ladder_up(psi: &[Complex], x_grid: &[f64], omega: f64, mass: f64, hbar: f64) -> Vec<Complex> {
45    let n = psi.len();
46    let alpha = (mass * omega / hbar).sqrt();
47    let dx = if n > 1 { x_grid[1] - x_grid[0] } else { 1.0 };
48    let factor = 1.0 / (2.0_f64).sqrt();
49
50    let mut result = vec![Complex::zero(); n];
51    for i in 0..n {
52        let xi = alpha * x_grid[i];
53        // Numerical derivative
54        let dpsi = if i == 0 {
55            (psi[1] - psi[0]) / dx
56        } else if i == n - 1 {
57            (psi[n - 1] - psi[n - 2]) / dx
58        } else {
59            (psi[i + 1] - psi[i - 1]) / (2.0 * dx)
60        };
61        // a+ = (1/sqrt(2)) * (xi * psi - (1/alpha) * dpsi/dx ... actually in position rep:
62        // a+ = sqrt(m*omega/(2*hbar)) * x - i*p/(sqrt(2*m*omega*hbar))
63        // = sqrt(m*omega/(2*hbar)) * x - (1/sqrt(2*m*omega*hbar)) * (-i*hbar) d/dx
64        // = (alpha/sqrt(2)) * x - (1/(alpha*sqrt(2))) * d/dx
65        let coeff_x = alpha / (2.0_f64).sqrt();
66        let coeff_d = 1.0 / (alpha * (2.0_f64).sqrt());
67        result[i] = psi[i] * coeff_x * x_grid[i] - dpsi * coeff_d;
68    }
69    result
70}
71
72/// Annihilation (lowering) operator: a |n> = sqrt(n) |n-1>
73/// a = (alpha/sqrt(2)) * x + (1/(alpha*sqrt(2))) * d/dx
74pub fn qho_ladder_down(psi: &[Complex], x_grid: &[f64], omega: f64, mass: f64, hbar: f64) -> Vec<Complex> {
75    let n = psi.len();
76    let alpha = (mass * omega / hbar).sqrt();
77    let dx = if n > 1 { x_grid[1] - x_grid[0] } else { 1.0 };
78
79    let mut result = vec![Complex::zero(); n];
80    for i in 0..n {
81        let dpsi = if i == 0 {
82            (psi[1] - psi[0]) / dx
83        } else if i == n - 1 {
84            (psi[n - 1] - psi[n - 2]) / dx
85        } else {
86            (psi[i + 1] - psi[i - 1]) / (2.0 * dx)
87        };
88        let coeff_x = alpha / (2.0_f64).sqrt();
89        let coeff_d = 1.0 / (alpha * (2.0_f64).sqrt());
90        result[i] = psi[i] * coeff_x * x_grid[i] + dpsi * coeff_d;
91    }
92    result
93}
94
95/// Time evolution of coherent state parameter: alpha(t) = alpha_0 * exp(-i*omega*t)
96pub fn coherent_state_evolution(alpha: Complex, omega: f64, t: f64) -> Complex {
97    let phase = Complex::from_polar(1.0, -omega * t);
98    alpha * phase
99}
100
101/// Probability of finding n photons in a coherent state |alpha>: P(n) = |alpha|^{2n} e^{-|alpha|^2} / n!
102pub fn number_state_probability(coherent_alpha: Complex, n: u32) -> f64 {
103    let alpha_sq = coherent_alpha.norm_sq();
104    let n_fact: f64 = (1..=n as u64).map(|k| k as f64).product::<f64>().max(1.0);
105    alpha_sq.powi(n as i32) * (-alpha_sq).exp() / n_fact
106}
107
108/// Render QHO energy levels and wavefunctions.
109pub struct QHORenderer {
110    pub width: usize,
111    pub height: usize,
112    pub n_levels: usize,
113    pub omega: f64,
114    pub mass: f64,
115    pub hbar: f64,
116}
117
118impl QHORenderer {
119    pub fn new(width: usize, height: usize, n_levels: usize) -> Self {
120        Self {
121            width,
122            height,
123            n_levels,
124            omega: 1.0,
125            mass: 1.0,
126            hbar: 1.0,
127        }
128    }
129
130    /// Render energy levels as horizontal lines with wavefunctions overlaid.
131    /// Returns a grid of (char, r, g, b).
132    pub fn render(&self) -> Vec<Vec<(char, f64, f64, f64)>> {
133        let x_min = -5.0;
134        let x_max = 5.0;
135        let e_max = qho_energy(self.n_levels as u32, self.omega, self.hbar);
136
137        let mut grid = vec![vec![(' ', 0.0, 0.0, 0.0); self.width]; self.height];
138
139        for level in 0..self.n_levels {
140            let e = qho_energy(level as u32, self.omega, self.hbar);
141            let y_frac = e / e_max;
142            let row = self.height - 1 - ((y_frac * (self.height - 1) as f64) as usize).min(self.height - 1);
143
144            // Color based on level
145            let hue = level as f64 / self.n_levels as f64;
146            let (r, g, b) = super::wavefunction::PhaseColorMap::hsv_to_rgb(hue, 0.8, 1.0);
147
148            for col in 0..self.width {
149                let x = x_min + (col as f64 / self.width as f64) * (x_max - x_min);
150                let psi = qho_wavefunction(level as u32, x, self.omega, self.mass, self.hbar);
151                let offset = (psi * 3.0) as i32;
152                let draw_row = (row as i32 - offset) as usize;
153                if draw_row < self.height {
154                    let brightness = psi.abs().min(1.0);
155                    if brightness > 0.05 {
156                        grid[draw_row][col] = ('*', r * brightness, g * brightness, b * brightness);
157                    }
158                }
159                // Draw energy level line
160                if grid[row][col].0 == ' ' {
161                    grid[row][col] = ('-', r * 0.3, g * 0.3, b * 0.3);
162                }
163            }
164        }
165
166        // Draw potential (parabola)
167        for col in 0..self.width {
168            let x = x_min + (col as f64 / self.width as f64) * (x_max - x_min);
169            let v = 0.5 * self.mass * self.omega * self.omega * x * x;
170            let y_frac = v / e_max;
171            if y_frac <= 1.0 {
172                let row = self.height - 1 - ((y_frac * (self.height - 1) as f64) as usize).min(self.height - 1);
173                if grid[row][col].0 == ' ' || grid[row][col].0 == '-' {
174                    grid[row][col] = ('|', 0.3, 0.3, 0.3);
175                }
176            }
177        }
178
179        grid
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    // Riemann-sum grid for the normalization / orthogonality checks:
    // 2000 points spaced 0.01 apart, starting at x = -10.
    const DX: f64 = 0.01;
    const N_POINTS: usize = 2000;
    const X_MIN: f64 = -10.0;

    /// Approximates ⟨m|n⟩ = ∫ ψ_m ψ_n dx (ω = m = ħ = 1) by a left Riemann sum.
    fn overlap(m: u32, n: u32) -> f64 {
        (0..N_POINTS)
            .map(|i| {
                let x = X_MIN + i as f64 * DX;
                qho_wavefunction(m, x, 1.0, 1.0, 1.0) * qho_wavefunction(n, x, 1.0, 1.0, 1.0) * DX
            })
            .sum()
    }

    /// Samples the real eigenfunction ψ_n (ω = m = ħ = 1) onto `grid` as complex amplitudes.
    fn sample(n: u32, grid: &[f64]) -> Vec<Complex> {
        grid.iter()
            .map(|&x| Complex::new(qho_wavefunction(n, x, 1.0, 1.0, 1.0), 0.0))
            .collect()
    }

    /// ∫ |ψ|² dx for samples on a uniform grid of spacing `dx`.
    fn grid_norm(psi: &[Complex], dx: f64) -> f64 {
        psi.iter().map(|c| c.norm_sq()).sum::<f64>() * dx
    }

    #[test]
    fn test_qho_energy_levels() {
        // (n, omega, expected E_n) with hbar = 1
        let cases = [(0, 1.0, 0.5), (1, 1.0, 1.5), (2, 1.0, 2.5), (5, 2.0, 11.0)];
        for &(n, omega, expected) in &cases {
            assert!((qho_energy(n, omega, 1.0) - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_hermite_polynomials() {
        // H_0 = 1, H_1 = 2x, H_2 = 4x^2 - 2, H_3 = 8x^3 - 12x
        let cases = [
            (0, 1.0, 1.0),
            (1, 1.0, 2.0),
            (2, 1.0, 2.0),
            (2, 0.0, -2.0),
            (3, 1.0, -4.0),
        ];
        for &(n, x, expected) in &cases {
            assert!((hermite_polynomial(n, x) - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_wavefunction_normalization() {
        for n in 0..5 {
            let integral = overlap(n, n);
            assert!(
                (integral - 1.0).abs() < 0.02,
                "n={}: integral={}",
                n,
                integral
            );
        }
    }

    #[test]
    fn test_orthogonality() {
        let integral = overlap(0, 1);
        assert!(integral.abs() < 0.02, "<0|1> = {}", integral);

        let integral = overlap(0, 2);
        assert!(integral.abs() < 0.02, "<0|2> = {}", integral);
    }

    #[test]
    fn test_ladder_operators() {
        let dx = 0.05;
        let x_grid: Vec<f64> = (0..512).map(|i| -12.8 + i as f64 * dx).collect();
        let psi0 = sample(0, &x_grid);

        // a|0> should be ~0
        let norm = grid_norm(&qho_ladder_down(&psi0, &x_grid, 1.0, 1.0, 1.0), dx);
        assert!(norm < 0.1, "a|0> norm = {}", norm);

        // a+|0> = |1>, so its norm should be ~1
        let norm_up = grid_norm(&qho_ladder_up(&psi0, &x_grid, 1.0, 1.0, 1.0), dx);
        assert!((norm_up - 1.0).abs() < 0.3, "a+|0> norm = {}", norm_up);
    }

    #[test]
    fn test_ground_state_uncertainty() {
        // The ground state saturates Heisenberg: dx * dp = hbar / 2 = 0.5
        use super::super::schrodinger::{uncertainty_p, uncertainty_x, WaveFunction1D};

        let dx_grid = 0.02;
        let x_grid: Vec<f64> = (0..1024).map(|i| -10.0 + i as f64 * dx_grid).collect();
        let wf = WaveFunction1D::new(sample(0, &x_grid), dx_grid, -10.0);

        let product = uncertainty_x(&wf) * uncertainty_p(&wf, 1.0);
        assert!(
            (product - 0.5).abs() < 0.15,
            "dx*dp = {} (expected 0.5)",
            product
        );
    }

    #[test]
    fn test_coherent_state_evolution() {
        // After half a period (omega * t = pi), alpha -> -alpha
        let evolved = coherent_state_evolution(Complex::new(1.0, 0.0), 1.0, PI);
        assert!((evolved.re + 1.0).abs() < 1e-10);
        assert!(evolved.im.abs() < 1e-10);
    }

    #[test]
    fn test_number_state_probability() {
        let alpha = Complex::new(2.0, 0.0);
        let probs: Vec<f64> = (0..20).map(|n| number_state_probability(alpha, n)).collect();

        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 0.01, "Total prob: {}", total);

        // Mean photon number should be |alpha|^2 = 4
        let mean: f64 = probs.iter().enumerate().map(|(n, p)| n as f64 * p).sum();
        assert!((mean - 4.0).abs() < 0.1, "Mean: {}", mean);
    }

    #[test]
    fn test_renderer() {
        let grid = QHORenderer::new(40, 20, 4).render();
        assert_eq!(grid.len(), 20);
        assert_eq!(grid[0].len(), 40);
    }
}
326        let grid = renderer.render();
327        assert_eq!(grid.len(), 20);
328        assert_eq!(grid[0].len(), 40);
329    }
330}