Skip to main content

proof_engine/quantum/
schrodinger.rs

1use std::f64::consts::PI;
2
/// A complex number `re + i·im` stored as two `f64` components.
///
/// Arithmetic is provided through the `std::ops` traits: complex `+ - * /`,
/// negation, the compound-assignment forms, and scaling by a real `f64` from
/// either side. `Default` is `0 + 0i`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Complex {
    /// Real part.
    pub re: f64,
    /// Imaginary part.
    pub im: f64,
}

impl Complex {
    /// Builds a complex number from its real and imaginary parts.
    pub fn new(re: f64, im: f64) -> Self {
        Complex { re, im }
    }

    /// The additive identity `0 + 0i`.
    pub fn zero() -> Self {
        Complex::new(0.0, 0.0)
    }

    /// The multiplicative identity `1 + 0i`.
    pub fn one() -> Self {
        Complex::new(1.0, 0.0)
    }

    /// The imaginary unit `0 + 1i`.
    pub fn i() -> Self {
        Complex::new(0.0, 1.0)
    }

    /// Squared modulus `re² + im²`; cheaper than [`Complex::norm`].
    pub fn norm_sq(&self) -> f64 {
        let Complex { re, im } = *self;
        re * re + im * im
    }

    /// Modulus `|z|`, the square root of [`Complex::norm_sq`].
    pub fn norm(&self) -> f64 {
        f64::sqrt(self.norm_sq())
    }

    /// Complex conjugate `re - i·im`.
    pub fn conj(&self) -> Self {
        Complex::new(self.re, -self.im)
    }

    /// Complex exponential `e^re · (cos im + i·sin im)`.
    pub fn exp(self) -> Self {
        Complex::from_polar(self.re.exp(), self.im)
    }

    /// Builds `r·(cos θ + i·sin θ)` from a modulus `r` and angle `theta` (radians).
    pub fn from_polar(r: f64, theta: f64) -> Self {
        Complex::new(r * theta.cos(), r * theta.sin())
    }

    /// Phase angle, computed as `atan2(im, re)`.
    pub fn arg(&self) -> f64 {
        f64::atan2(self.im, self.re)
    }

    /// Multiplies both components by the real factor `s`.
    pub fn scale(self, s: f64) -> Self {
        Complex::new(self.re * s, self.im * s)
    }
}

impl std::ops::Add for Complex {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        Complex::new(self.re + other.re, self.im + other.im)
    }
}

impl std::ops::AddAssign for Complex {
    fn add_assign(&mut self, other: Self) {
        *self = *self + other;
    }
}

impl std::ops::Sub for Complex {
    type Output = Self;
    fn sub(self, other: Self) -> Self {
        Complex::new(self.re - other.re, self.im - other.im)
    }
}

impl std::ops::SubAssign for Complex {
    fn sub_assign(&mut self, other: Self) {
        *self = *self - other;
    }
}

impl std::ops::Mul for Complex {
    type Output = Self;
    fn mul(self, other: Self) -> Self {
        let (a, b, c, d) = (self.re, self.im, other.re, other.im);
        // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        Complex::new(a * c - b * d, a * d + b * c)
    }
}

impl std::ops::MulAssign for Complex {
    fn mul_assign(&mut self, other: Self) {
        *self = *self * other;
    }
}

impl std::ops::Div for Complex {
    type Output = Self;
    fn div(self, other: Self) -> Self {
        let (a, b, c, d) = (self.re, self.im, other.re, other.im);
        // Multiply by the conjugate of the denominator: (a + bi)(c - di) / (c² + d²).
        let denom = other.norm_sq();
        Complex::new((a * c + b * d) / denom, (b * c - a * d) / denom)
    }
}

impl std::ops::Neg for Complex {
    type Output = Self;
    fn neg(self) -> Self {
        Complex::new(-self.re, -self.im)
    }
}

impl std::ops::Mul<f64> for Complex {
    type Output = Self;
    fn mul(self, factor: f64) -> Self {
        self.scale(factor)
    }
}

impl std::ops::Mul<Complex> for f64 {
    type Output = Complex;
    fn mul(self, z: Complex) -> Complex {
        z.scale(self)
    }
}

impl std::ops::Div<f64> for Complex {
    type Output = Self;
    fn div(self, divisor: f64) -> Self {
        Complex::new(self.re / divisor, self.im / divisor)
    }
}
154
155/// 1D wave function on a uniform grid.
156#[derive(Clone, Debug)]
157pub struct WaveFunction1D {
158    pub psi: Vec<Complex>,
159    pub dx: f64,
160    pub x_min: f64,
161}
162
163impl WaveFunction1D {
164    pub fn new(psi: Vec<Complex>, dx: f64, x_min: f64) -> Self {
165        Self { psi, dx, x_min }
166    }
167
168    pub fn n(&self) -> usize {
169        self.psi.len()
170    }
171
172    pub fn x_at(&self, i: usize) -> f64 {
173        self.x_min + i as f64 * self.dx
174    }
175
176    pub fn x_max(&self) -> f64 {
177        self.x_min + (self.n() - 1) as f64 * self.dx
178    }
179
180    pub fn probability_density(&self) -> Vec<f64> {
181        self.psi.iter().map(|c| c.norm_sq()).collect()
182    }
183
184    pub fn norm_squared(&self) -> f64 {
185        self.psi.iter().map(|c| c.norm_sq()).sum::<f64>() * self.dx
186    }
187
188    pub fn normalize(&mut self) {
189        let n = self.norm_squared().sqrt();
190        if n > 1e-30 {
191            for c in &mut self.psi {
192                *c = *c / n;
193            }
194        }
195    }
196}
197
198/// Solve tridiagonal system Ax = d where A has diagonals (a, b, c).
199/// a[0] and c[n-1] are unused. Modifies d in-place and returns solution.
200fn solve_tridiagonal(a: &[Complex], b: &[Complex], c: &[Complex], d: &mut [Complex]) -> Vec<Complex> {
201    let n = d.len();
202    let mut cp = vec![Complex::zero(); n];
203    let mut dp = vec![Complex::zero(); n];
204
205    cp[0] = c[0] / b[0];
206    dp[0] = d[0] / b[0];
207
208    for i in 1..n {
209        let m = b[i] - a[i] * cp[i - 1];
210        cp[i] = if i < n - 1 { c[i] / m } else { Complex::zero() };
211        dp[i] = (d[i] - a[i] * dp[i - 1]) / m;
212    }
213
214    let mut x = vec![Complex::zero(); n];
215    x[n - 1] = dp[n - 1];
216    for i in (0..n - 1).rev() {
217        x[i] = dp[i] - cp[i] * x[i + 1];
218    }
219    x
220}
221
222/// 1D Schrodinger equation solver using Crank-Nicolson and split-operator methods.
223#[derive(Clone)]
224pub struct SchrodingerSolver1D {
225    pub psi: WaveFunction1D,
226    pub potential: Vec<f64>,
227    pub mass: f64,
228    pub hbar: f64,
229    pub dt: f64,
230}
231
232impl SchrodingerSolver1D {
233    pub fn new(psi: WaveFunction1D, potential: Vec<f64>, mass: f64, hbar: f64, dt: f64) -> Self {
234        Self { psi, potential, mass, hbar, dt }
235    }
236
237    /// Crank-Nicolson time step (implicit, unitary).
238    pub fn step(&mut self) {
239        let n = self.psi.n();
240        let dx = self.psi.dx;
241        let r = Complex::new(0.0, self.hbar * self.dt / (4.0 * self.mass * dx * dx));
242
243        let mut a_lower = vec![Complex::zero(); n];
244        let mut a_diag = vec![Complex::zero(); n];
245        let mut a_upper = vec![Complex::zero(); n];
246        let mut rhs = vec![Complex::zero(); n];
247
248        for j in 0..n {
249            let v_term = Complex::new(0.0, self.dt * self.potential[j] / (2.0 * self.hbar));
250
251            // LHS: (1 + iHdt/2)
252            a_diag[j] = Complex::one() + r * 2.0 + v_term;
253            if j > 0 {
254                a_lower[j] = -r;
255            }
256            if j < n - 1 {
257                a_upper[j] = -r;
258            }
259
260            // RHS: (1 - iHdt/2) * psi
261            let psi_j = self.psi.psi[j];
262            let psi_left = if j > 0 { self.psi.psi[j - 1] } else { Complex::zero() };
263            let psi_right = if j < n - 1 { self.psi.psi[j + 1] } else { Complex::zero() };
264
265            rhs[j] = (Complex::one() - r * 2.0 - v_term) * psi_j
266                + r * psi_left
267                + r * psi_right;
268        }
269
270        self.psi.psi = solve_tridiagonal(&a_lower, &a_diag, &a_upper, &mut rhs);
271    }
272
273    /// Split-operator FFT method: kinetic in k-space, potential in x-space.
274    pub fn step_split_operator(&mut self) {
275        let n = self.psi.n();
276        let dx = self.psi.dx;
277        let dt = self.dt;
278        let hbar = self.hbar;
279        let mass = self.mass;
280
281        // Half potential step in x-space
282        for j in 0..n {
283            let phase = -self.potential[j] * dt / (2.0 * hbar);
284            let exp_v = Complex::from_polar(1.0, phase);
285            self.psi.psi[j] = self.psi.psi[j] * exp_v;
286        }
287
288        // Full kinetic step in k-space
289        let mut psi_k = dft(&self.psi.psi);
290        let dk = 2.0 * PI / (n as f64 * dx);
291        for j in 0..n {
292            let k = if j <= n / 2 {
293                j as f64 * dk
294            } else {
295                (j as f64 - n as f64) * dk
296            };
297            let phase = -hbar * k * k * dt / (2.0 * mass);
298            let exp_t = Complex::from_polar(1.0, phase);
299            psi_k[j] = psi_k[j] * exp_t;
300        }
301        self.psi.psi = idft(&psi_k);
302
303        // Half potential step in x-space
304        for j in 0..n {
305            let phase = -self.potential[j] * dt / (2.0 * hbar);
306            let exp_v = Complex::from_polar(1.0, phase);
307            self.psi.psi[j] = self.psi.psi[j] * exp_v;
308        }
309    }
310}
311
312/// Discrete Fourier Transform.
313pub fn dft(input: &[Complex]) -> Vec<Complex> {
314    let n = input.len();
315    let mut output = vec![Complex::zero(); n];
316    for k in 0..n {
317        let mut sum = Complex::zero();
318        for j in 0..n {
319            let angle = -2.0 * PI * (k as f64) * (j as f64) / (n as f64);
320            sum += input[j] * Complex::from_polar(1.0, angle);
321        }
322        output[k] = sum;
323    }
324    output
325}
326
327/// Inverse Discrete Fourier Transform.
328pub fn idft(input: &[Complex]) -> Vec<Complex> {
329    let n = input.len();
330    let mut output = vec![Complex::zero(); n];
331    for j in 0..n {
332        let mut sum = Complex::zero();
333        for k in 0..n {
334            let angle = 2.0 * PI * (k as f64) * (j as f64) / (n as f64);
335            sum += input[k] * Complex::from_polar(1.0, angle);
336        }
337        output[j] = sum / n as f64;
338    }
339    output
340}
341
/// Compute the lowest `n_states` energy eigenvalues of the time-independent
/// Schrodinger equation `-ħ²/(2m) ψ'' + V ψ = E ψ` on a uniform grid.
///
/// The equation is discretised with the three-point second difference and
/// `ψ = 0` on the first and last grid points. The eigenvalues are therefore
/// those of the symmetric tridiagonal Hamiltonian on the `potential.len() - 2`
/// interior points, which are the energies at which the shooting residual
/// changes sign.
///
/// Each eigenvalue is isolated by bisection on a Sturm count between
/// Gershgorin bounds. The Sturm count is the number of eigenvalues below a
/// trial energy, equal to the node count of the shooting solution. Unlike a
/// fixed-step energy scan, this cannot skip low-lying states when the potential
/// spans a huge range (e.g. hard-wall barriers of height `1e6`). It also cannot
/// merge or miss closely spaced levels.
///
/// Returns eigenvalues in ascending order. At most `potential.len() - 2` are
/// returned, and none when the grid has fewer than three points or
/// `n_states == 0`.
pub fn energy_eigenvalues(potential: &[f64], n_states: usize, dx: f64, mass: f64, hbar: f64) -> Vec<f64> {
    // Upper bound on halvings; f64 brackets collapse to adjacent floats long before this.
    const MAX_BISECTIONS: usize = 200;

    let interior = potential.len().saturating_sub(2);
    let wanted = n_states.min(interior);
    if wanted == 0 {
        return Vec::new();
    }

    let v_min = potential.iter().copied().fold(f64::INFINITY, f64::min);
    let v_max = potential.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Gershgorin discs of the discrete Hamiltonian: every eigenvalue lies in
    // [v_min, v_max + 4·ħ²/(2m·dx²)].
    let e_floor = v_min;
    let e_ceil = v_max + 2.0 * hbar * hbar / (mass * dx * dx);

    (0..wanted)
        .map(|k| {
            // Invariant: at most k eigenvalues lie below `lo`, more than k below `hi`,
            // so the (k+1)-th lowest eigenvalue stays inside [lo, hi].
            let (mut lo, mut hi) = (e_floor, e_ceil);
            for _ in 0..MAX_BISECTIONS {
                let mid = 0.5 * (lo + hi);
                if mid <= lo || mid >= hi {
                    break; // bracket can no longer shrink in f64
                }
                if count_eigenvalues_below(potential, mid, dx, mass, hbar) > k {
                    hi = mid;
                } else {
                    lo = mid;
                }
            }
            0.5 * (lo + hi)
        })
        .collect()
}

/// Number of eigenvalues of the discrete Hamiltonian (with `ψ = 0` at both grid
/// ends) that lie strictly below `energy`.
///
/// Runs the shooting recurrence
/// `ψ[i+1] = (2 + dx²·(2m/ħ²)·(V[i] - E))·ψ[i] - ψ[i-1]` in ratio form
/// `q[i] = ψ[i+1]/ψ[i]`, which cannot overflow however steep the potential.
/// It counts sign changes (`q < 0`). The ratios are the pivots of an LDLᵀ
/// factorisation of `dx²·(2m/ħ²)·(H - E)`, so by Sylvester's law of inertia the
/// count equals the number of eigenvalues below `energy`.
fn count_eigenvalues_below(potential: &[f64], energy: f64, dx: f64, mass: f64, hbar: f64) -> usize {
    let scale = dx * dx * 2.0 * mass / (hbar * hbar);
    let interior = potential.len().saturating_sub(2);
    let mut ratio = f64::INFINITY; // ψ[1]/ψ[0] with ψ[0] = 0
    let mut count = 0;
    for &v in potential.iter().skip(1).take(interior) {
        ratio = 2.0 + scale * (v - energy) - 1.0 / ratio;
        if ratio == 0.0 {
            // Exact zero pivot: perturb it so the next step stays finite.
            ratio = f64::MIN_POSITIVE;
        }
        if ratio < 0.0 {
            count += 1;
        }
    }
    count
}
376
/// Shooting residual: integrate the discretised Schrodinger equation
/// `ψ'' = (2m/ħ²)(V - E) ψ` from the left edge (`ψ[0] = 0`, `ψ[1] = 1e-10`)
/// across the grid and return `ψ` at the last grid point.
///
/// The result changes sign as `energy` crosses an eigenvalue of the discrete
/// problem with `ψ = 0` at both ends. Grids with fewer than three points have
/// nothing to integrate and return the starting amplitude `1e-10`. Previously
/// an empty grid underflowed `n - 1` and panicked.
// Kept as a standalone residual even when no caller in this module uses it.
#[allow(dead_code)]
fn shoot(potential: &[f64], energy: f64, dx: f64, mass: f64, hbar: f64) -> f64 {
    let coeff = 2.0 * mass / (hbar * hbar);
    let mut psi_prev = 0.0_f64;
    let mut psi_curr = 1e-10_f64;

    // Interior points 1..n-1 each advance the solution by one grid step.
    let last = potential.len().saturating_sub(1);
    for &v in potential.iter().take(last).skip(1) {
        let k_sq = coeff * (v - energy);
        let psi_next = 2.0 * psi_curr - psi_prev + dx * dx * k_sq * psi_curr;
        psi_prev = psi_curr;
        psi_curr = psi_next;
    }
    psi_curr
}
393
394/// Compute energy eigenstates using the shooting method.
395pub fn energy_eigenstates(
396    potential: &[f64],
397    n_states: usize,
398    dx: f64,
399    x_min: f64,
400    mass: f64,
401    hbar: f64,
402) -> Vec<WaveFunction1D> {
403    let eigenvalues = energy_eigenvalues(potential, n_states, dx, mass, hbar);
404    let n = potential.len();
405    let coeff = 2.0 * mass / (hbar * hbar);
406
407    eigenvalues
408        .iter()
409        .map(|&energy| {
410            let mut psi = vec![0.0_f64; n];
411            psi[1] = 1e-10;
412            for i in 1..n - 1 {
413                let k_sq = coeff * (potential[i] - energy);
414                psi[i + 1] = 2.0 * psi[i] - psi[i - 1] + dx * dx * k_sq * psi[i];
415            }
416            let psi_c: Vec<Complex> = psi.iter().map(|&v| Complex::new(v, 0.0)).collect();
417            let mut wf = WaveFunction1D::new(psi_c, dx, x_min);
418            wf.normalize();
419            wf
420        })
421        .collect()
422}
423
424/// 2D Schrodinger solver using ADI (alternating direction implicit) method.
425#[derive(Clone)]
426pub struct SchrodingerSolver2D {
427    pub psi: Vec<Vec<Complex>>,
428    pub potential: Vec<Vec<f64>>,
429    pub nx: usize,
430    pub ny: usize,
431    pub dx: f64,
432    pub dy: f64,
433    pub dt: f64,
434    pub mass: f64,
435    pub hbar: f64,
436}
437
438impl SchrodingerSolver2D {
439    pub fn new(
440        psi: Vec<Vec<Complex>>,
441        potential: Vec<Vec<f64>>,
442        nx: usize,
443        ny: usize,
444        dx: f64,
445        dy: f64,
446        dt: f64,
447        mass: f64,
448        hbar: f64,
449    ) -> Self {
450        Self { psi, potential, nx, ny, dx, dy, dt, mass, hbar }
451    }
452
453    /// ADI time step: half step implicit in x, half step implicit in y.
454    pub fn step_2d(&mut self) {
455        let nx = self.nx;
456        let ny = self.ny;
457        let rx = Complex::new(0.0, self.hbar * self.dt / (4.0 * self.mass * self.dx * self.dx));
458        let ry = Complex::new(0.0, self.hbar * self.dt / (4.0 * self.mass * self.dy * self.dy));
459
460        // Half step: implicit in x, explicit in y
461        let mut psi_half = vec![vec![Complex::zero(); ny]; nx];
462        for j in 0..ny {
463            let mut a = vec![Complex::zero(); nx];
464            let mut b = vec![Complex::zero(); nx];
465            let mut c = vec![Complex::zero(); nx];
466            let mut d = vec![Complex::zero(); nx];
467
468            for i in 0..nx {
469                let v_term = Complex::new(0.0, self.dt * self.potential[i][j] / (4.0 * self.hbar));
470                b[i] = Complex::one() + rx * 2.0 + v_term;
471                if i > 0 { a[i] = -rx; }
472                if i < nx - 1 { c[i] = -rx; }
473
474                let psi_ij = self.psi[i][j];
475                let psi_up = if j > 0 { self.psi[i][j - 1] } else { Complex::zero() };
476                let psi_down = if j < ny - 1 { self.psi[i][j + 1] } else { Complex::zero() };
477
478                d[i] = (Complex::one() - ry * 2.0 - v_term) * psi_ij
479                    + ry * psi_up
480                    + ry * psi_down;
481            }
482            let sol = solve_tridiagonal(&a, &b, &c, &mut d);
483            for i in 0..nx {
484                psi_half[i][j] = sol[i];
485            }
486        }
487
488        // Half step: implicit in y, explicit in x
489        for i in 0..nx {
490            let mut a = vec![Complex::zero(); ny];
491            let mut b = vec![Complex::zero(); ny];
492            let mut c = vec![Complex::zero(); ny];
493            let mut d = vec![Complex::zero(); ny];
494
495            for j in 0..ny {
496                let v_term = Complex::new(0.0, self.dt * self.potential[i][j] / (4.0 * self.hbar));
497                b[j] = Complex::one() + ry * 2.0 + v_term;
498                if j > 0 { a[j] = -ry; }
499                if j < ny - 1 { c[j] = -ry; }
500
501                let psi_ij = psi_half[i][j];
502                let psi_left = if i > 0 { psi_half[i - 1][j] } else { Complex::zero() };
503                let psi_right = if i < nx - 1 { psi_half[i + 1][j] } else { Complex::zero() };
504
505                d[j] = (Complex::one() - rx * 2.0 - v_term) * psi_ij
506                    + rx * psi_left
507                    + rx * psi_right;
508            }
509            let sol = solve_tridiagonal(&a, &b, &c, &mut d);
510            for j in 0..ny {
511                self.psi[i][j] = sol[j];
512            }
513        }
514    }
515}
516
517/// Probability density |psi|^2 for 1D wave function.
518pub fn probability_density_1d(psi: &[Complex]) -> Vec<f64> {
519    psi.iter().map(|c| c.norm_sq()).collect()
520}
521
522/// Probability density |psi|^2 for 2D wave function.
523pub fn probability_density_2d(psi: &[Vec<Complex>]) -> Vec<Vec<f64>> {
524    psi.iter()
525        .map(|row| row.iter().map(|c| c.norm_sq()).collect())
526        .collect()
527}
528
529/// Normalize a 1D wave function so that integral |psi|^2 dx = 1.
530pub fn normalize(psi: &mut [Complex], dx: f64) {
531    let norm_sq: f64 = psi.iter().map(|c| c.norm_sq()).sum::<f64>() * dx;
532    let norm = norm_sq.sqrt();
533    if norm > 1e-30 {
534        for c in psi.iter_mut() {
535            *c = *c / norm;
536        }
537    }
538}
539
540/// Expectation value of position: <x> = integral psi* x psi dx.
541pub fn expectation_x(psi: &WaveFunction1D) -> f64 {
542    let mut sum = 0.0;
543    for i in 0..psi.n() {
544        let x = psi.x_at(i);
545        sum += psi.psi[i].norm_sq() * x;
546    }
547    sum * psi.dx
548}
549
550/// Expectation value of momentum: <p> = -i hbar integral psi* dpsi/dx dx.
551pub fn expectation_p(psi: &WaveFunction1D, hbar: f64) -> f64 {
552    let n = psi.n();
553    let dx = psi.dx;
554    let mut sum = Complex::zero();
555    for i in 1..n - 1 {
556        let dpsi = (psi.psi[i + 1] - psi.psi[i - 1]) / (2.0 * dx);
557        sum += psi.psi[i].conj() * dpsi;
558    }
559    let result = Complex::new(0.0, -hbar) * sum * dx;
560    result.re
561}
562
563/// Uncertainty in position: sqrt(<x^2> - <x>^2).
564pub fn uncertainty_x(psi: &WaveFunction1D) -> f64 {
565    let ex = expectation_x(psi);
566    let mut ex2 = 0.0;
567    for i in 0..psi.n() {
568        let x = psi.x_at(i);
569        ex2 += psi.psi[i].norm_sq() * x * x;
570    }
571    ex2 *= psi.dx;
572    (ex2 - ex * ex).max(0.0).sqrt()
573}
574
575/// Uncertainty in momentum: sqrt(<p^2> - <p>^2).
576pub fn uncertainty_p(psi: &WaveFunction1D, hbar: f64) -> f64 {
577    let ep = expectation_p(psi, hbar);
578    let n = psi.n();
579    let dx = psi.dx;
580
581    // <p^2> = -hbar^2 integral psi* d^2psi/dx^2 dx
582    let mut sum = 0.0;
583    for i in 1..n - 1 {
584        let d2psi = (psi.psi[i + 1] - psi.psi[i] * 2.0 + psi.psi[i - 1]) / (dx * dx);
585        let integrand = psi.psi[i].conj() * d2psi;
586        sum += integrand.re;
587    }
588    let ep2 = -hbar * hbar * sum * dx;
589    (ep2 - ep * ep).max(0.0).sqrt()
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples the unnormalised Gaussian `exp(-x² / spread)` on `n` points
    /// starting at `x_min` with spacing `dx`.
    fn gaussian(n: usize, dx: f64, x_min: f64, spread: f64) -> Vec<Complex> {
        (0..n)
            .map(|i| {
                let x = x_min + i as f64 * dx;
                Complex::new((-x * x / spread).exp(), 0.0)
            })
            .collect()
    }

    #[test]
    fn test_complex_arithmetic() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);
        let sum = a + b;
        assert!((sum.re - 4.0).abs() < 1e-10);
        assert!((sum.im - 6.0).abs() < 1e-10);

        let prod = a * b;
        assert!((prod.re - (-5.0)).abs() < 1e-10);
        assert!((prod.im - 10.0).abs() < 1e-10);

        let div = a / b;
        assert!((div.re - 11.0 / 25.0).abs() < 1e-10);
        assert!((div.im - 2.0 / 25.0).abs() < 1e-10);

        // Division must undo multiplication.
        let back = prod / b;
        assert!((back - a).norm() < 1e-12);
    }

    #[test]
    fn test_complex_exp() {
        let z = Complex::new(0.0, PI);
        let result = z.exp();
        assert!((result.re - (-1.0)).abs() < 1e-10);
        assert!(result.im.abs() < 1e-10);
    }

    #[test]
    fn test_complex_from_polar() {
        let c = Complex::from_polar(2.0, PI / 4.0);
        assert!((c.re - 2.0_f64.sqrt()).abs() < 1e-10);
        assert!((c.im - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_normalization() {
        let (n, dx, x_min, sigma) = (200, 0.1, -10.0, 1.0);
        let mut wf = WaveFunction1D::new(gaussian(n, dx, x_min, 2.0 * sigma * sigma), dx, x_min);
        wf.normalize();
        // Normalisation is exact up to rounding.
        let norm = wf.norm_squared();
        assert!((norm - 1.0).abs() < 1e-12, "norm = {}", norm);
    }

    #[test]
    fn test_crank_nicolson_preserves_norm() {
        let (n, dx, x_min, sigma) = (128, 0.1, -6.4, 1.0);
        let mut wf = WaveFunction1D::new(gaussian(n, dx, x_min, 2.0 * sigma * sigma), dx, x_min);
        wf.normalize();
        let mut solver = SchrodingerSolver1D::new(wf, vec![0.0; n], 1.0, 1.0, 0.001);
        let norm_before = solver.psi.norm_squared();
        for _ in 0..50 {
            solver.step();
        }
        let norm_after = solver.psi.norm_squared();
        // Crank-Nicolson is the Cayley transform of a Hermitian matrix, so the
        // discrete norm is conserved up to rounding.
        assert!(
            (norm_after - norm_before).abs() < 1e-9,
            "norm drifted: {} -> {}",
            norm_before,
            norm_after
        );
    }

    #[test]
    fn test_split_operator_preserves_norm() {
        let (n, dx, x_min, sigma) = (64, 0.2, -6.4, 1.0);
        let mut wf = WaveFunction1D::new(gaussian(n, dx, x_min, 2.0 * sigma * sigma), dx, x_min);
        wf.normalize();
        let mut solver = SchrodingerSolver1D::new(wf, vec![0.0; n], 1.0, 1.0, 0.001);
        let norm_before = solver.psi.norm_squared();
        for _ in 0..20 {
            solver.step_split_operator();
        }
        let norm_after = solver.psi.norm_squared();
        // Every factor is a unit-modulus phase or a (scaled) unitary DFT.
        assert!(
            (norm_after - norm_before).abs() < 1e-9,
            "norm drifted: {} -> {}",
            norm_before,
            norm_after
        );
    }

    #[test]
    fn test_uncertainty_principle() {
        let (n, dx, x_min, sigma, hbar) = (512, 0.05, -12.8, 1.0, 1.0);
        let mut wf = WaveFunction1D::new(gaussian(n, dx, x_min, 4.0 * sigma * sigma), dx, x_min);
        wf.normalize();
        let product = uncertainty_x(&wf) * uncertainty_p(&wf, hbar);
        // Heisenberg: dx * dp >= hbar/2, with equality for a Gaussian (small
        // slack for the finite-difference derivative).
        assert!(product >= hbar / 2.0 - 1e-3, "Uncertainty product {} < hbar/2", product);
        assert!(
            (product - hbar / 2.0).abs() < 1e-2,
            "Gaussian should saturate the bound, got {}",
            product
        );
    }

    #[test]
    fn test_dft_idft_roundtrip() {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(-1.0, 0.0),
            Complex::new(0.0, -1.0),
        ];
        let transformed = dft(&input);
        // x[j] = i^j is a single Fourier mode, so all weight lands in bin 1.
        let expected = [Complex::zero(), Complex::new(4.0, 0.0), Complex::zero(), Complex::zero()];
        for (got, want) in transformed.iter().zip(expected.iter()) {
            assert!((*got - *want).norm() < 1e-10, "dft bin {:?} != {:?}", got, want);
        }
        let recovered = idft(&transformed);
        for (a, b) in input.iter().zip(recovered.iter()) {
            assert!((a.re - b.re).abs() < 1e-10);
            assert!((a.im - b.im).abs() < 1e-10);
        }
    }

    #[test]
    fn test_tridiagonal_solver() {
        // [[2, -1, 0], [-1, 2, -1], [0, -1, 2]] x = [1, 0, 1] has solution x = [1, 1, 1].
        let a = [Complex::zero(), Complex::new(-1.0, 0.0), Complex::new(-1.0, 0.0)];
        let b = [Complex::new(2.0, 0.0), Complex::new(2.0, 0.0), Complex::new(2.0, 0.0)];
        let c = [Complex::new(-1.0, 0.0), Complex::new(-1.0, 0.0), Complex::zero()];
        let mut d = [Complex::new(1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)];
        let x = solve_tridiagonal(&a, &b, &c, &mut d);
        assert_eq!(x.len(), 3);
        for xi in &x {
            assert!((xi.re - 1.0).abs() < 1e-10 && xi.im.abs() < 1e-10, "x = {:?}", x);
        }
        // Verify every row of A x = d, not just the first.
        let rows = [
            b[0] * x[0] + c[0] * x[1],
            a[1] * x[0] + b[1] * x[1] + c[1] * x[2],
            a[2] * x[1] + b[2] * x[2],
        ];
        let rhs = [1.0, 0.0, 1.0];
        for (r, want) in rows.iter().zip(rhs.iter()) {
            assert!((r.re - want).abs() < 1e-10 && r.im.abs() < 1e-10);
        }
    }

    #[test]
    fn test_infinite_well_eigenvalues() {
        let n = 200;
        let l = 1.0;
        let dx = l / (n as f64 - 1.0);
        let hbar = 1.0;
        let mass = 0.5; // so hbar^2/(2m) = 1
        let potential: Vec<f64> = (0..n)
            .map(|i| {
                let x = i as f64 * dx;
                if x < 0.01 || x > l - 0.01 { 1e6 } else { 0.0 }
            })
            .collect();
        let evals = energy_eigenvalues(&potential, 3, dx, mass, hbar);
        // Finding fewer than two levels is itself a failure, not a skip.
        assert!(evals.len() >= 2, "found only {} eigenvalues", evals.len());
        // E_n = n^2 pi^2 hbar^2 / (2 m L^2) = n^2 pi^2, so E_2/E_1 should be close to 4.
        let ratio = evals[1] / evals[0];
        assert!((ratio - 4.0).abs() < 0.1, "Ratio: {}", ratio);
        assert!((evals[0] - PI * PI).abs() < 0.5, "Ground state: {}", evals[0]);
    }

    #[test]
    fn test_2d_solver_runs() {
        let nx = 16;
        let ny = 16;
        let psi = vec![vec![Complex::zero(); ny]; nx];
        let potential = vec![vec![0.0; ny]; nx];
        let mut solver = SchrodingerSolver2D::new(psi, potential, nx, ny, 0.1, 0.1, 0.001, 1.0, 1.0);
        // Place a unit peak in the centre.
        solver.psi[nx / 2][ny / 2] = Complex::new(1.0, 0.0);
        solver.step_2d();
        // With V = 0 the x- and y-factors commute, so one ADI step is unitary.
        let total: f64 = probability_density_2d(&solver.psi).iter().flatten().sum();
        assert!((total - 1.0).abs() < 1e-10, "total probability {}", total);
        // Some amplitude must have spread away from the peak.
        let center_prob = solver.psi[nx / 2][ny / 2].norm_sq();
        assert!(center_prob < 1.0, "center probability {}", center_prob);
    }
}