// proof_engine/math/numerical.rs
//! Numerical methods: root finding, quadrature, ODE solvers, linear algebra, interpolation.
//! All implementations are from scratch — no external crates.

// ============================================================
// ROOT FINDING
// ============================================================
/// Bisection method. Requires f(a)*f(b) < 0.
/// Returns the root within tolerance `tol`, or None if bracket invalid or no convergence.
pub fn bisect(f: impl Fn(f64) -> f64, mut a: f64, mut b: f64, tol: f64, max_iter: usize) -> Option<f64> {
    let mut f_lo = f(a);
    // A valid bracket must show a sign change at the endpoints.
    if f_lo * f(b) > 0.0 {
        return None;
    }
    for _ in 0..max_iter {
        let mid = (a + b) * 0.5;
        // The half-width bounds the distance from `mid` to the root.
        if (b - a) * 0.5 < tol {
            return Some(mid);
        }
        let f_mid = f(mid);
        if f_mid == 0.0 {
            return Some(mid);
        }
        // Keep whichever half still brackets a sign change.
        if f_lo * f_mid < 0.0 {
            b = mid;
        } else {
            a = mid;
            f_lo = f_mid;
        }
    }
    // Budget exhausted: return the bracket midpoint as the best estimate.
    Some((a + b) * 0.5)
}
35
/// Newton-Raphson method.
///
/// Iterates `x <- x - f(x)/f'(x)` from the initial guess `x`; stops when
/// either the residual or the step size drops below `tol`. Returns `None`
/// if the derivative vanishes or the iteration budget runs out.
pub fn newton_raphson(
    f: impl Fn(f64) -> f64,
    df: impl Fn(f64) -> f64,
    mut x: f64,
    tol: f64,
    max_iter: usize,
) -> Option<f64> {
    for _ in 0..max_iter {
        let residual = f(x);
        if residual.abs() < tol {
            return Some(x);
        }
        let slope = df(x);
        // Near a stationary point the Newton step would blow up.
        if slope.abs() < 1e-300 {
            return None;
        }
        let next = x - residual / slope;
        if (next - x).abs() < tol {
            return Some(next);
        }
        x = next;
    }
    None
}
61
/// Secant method — Newton without explicit derivative.
///
/// Approximates the derivative with the finite difference through the two
/// most recent iterates. Returns `None` on a flat secant line or if the
/// iteration budget runs out.
pub fn secant(
    f: impl Fn(f64) -> f64,
    mut x0: f64,
    mut x1: f64,
    tol: f64,
    max_iter: usize,
) -> Option<f64> {
    let mut f_prev = f(x0);
    for _ in 0..max_iter {
        let f_curr = f(x1);
        if f_curr.abs() < tol {
            return Some(x1);
        }
        let rise = f_curr - f_prev;
        // Flat secant line => division would explode.
        if rise.abs() < 1e-300 {
            return None;
        }
        let next = x1 - f_curr * (x1 - x0) / rise;
        if (next - x1).abs() < tol {
            return Some(next);
        }
        // Slide the two-point window forward.
        x0 = x1;
        f_prev = f_curr;
        x1 = next;
    }
    None
}
90
/// Brent's method — superlinear convergence without derivative.
/// Requires f(a)*f(b) <= 0.
///
/// Combines inverse quadratic interpolation, the secant step, and a set of
/// bisection safeguards. Throughout, `b` holds the best current estimate,
/// `c` the previous iterate and `d` the one before that.
pub fn brent(f: impl Fn(f64) -> f64, mut a: f64, mut b: f64, tol: f64) -> Option<f64> {
    let max_iter = 100;
    let mut fa = f(a);
    let mut fb = f(b);
    // No sign change => no guaranteed bracket.
    if fa * fb > 0.0 {
        return None;
    }
    // Keep |f(b)| <= |f(a)| so that b is the better endpoint.
    if fa.abs() < fb.abs() {
        core::mem::swap(&mut a, &mut b);
        core::mem::swap(&mut fa, &mut fb);
    }
    let mut c = a;
    let mut fc = fa;
    let mut mflag = true; // true when the previous step was a bisection
    let mut s = 0.0; // candidate point; overwritten every iteration
    let mut d = 0.0; // second-previous iterate; only *used* when !mflag (never on the first pass)
    for _ in 0..max_iter {
        if fb.abs() < tol || (b - a).abs() < tol {
            return Some(b);
        }
        if fa != fc && fb != fc {
            // Inverse quadratic interpolation
            s = a * fb * fc / ((fa - fb) * (fa - fc))
                + b * fa * fc / ((fb - fa) * (fb - fc))
                + c * fa * fb / ((fc - fa) * (fc - fb));
        } else {
            // Secant
            s = b - fb * (b - a) / (fb - fa);
        }
        // Safeguards: force bisection when s falls outside ((3a+b)/4, b)
        // or the interval is not shrinking fast enough.
        let cond1 = !((3.0 * a + b) / 4.0 < s && s < b)
            && !((3.0 * a + b) / 4.0 > s && s > b);
        let cond2 = mflag && (s - b).abs() >= (b - c).abs() / 2.0;
        let cond3 = !mflag && (s - b).abs() >= (c - d).abs() / 2.0;
        let cond4 = mflag && (b - c).abs() < tol;
        let cond5 = !mflag && (c - d).abs() < tol;
        if cond1 || cond2 || cond3 || cond4 || cond5 {
            s = (a + b) / 2.0;
            mflag = true;
        } else {
            mflag = false;
        }
        let fs = f(s);
        // Shift iterate history: d <- c <- b.
        d = c;
        c = b;
        fc = fb;
        // Keep the sub-interval that still brackets the root.
        if fa * fs < 0.0 {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }
        // Restore the |f(b)| <= |f(a)| invariant.
        if fa.abs() < fb.abs() {
            core::mem::swap(&mut a, &mut b);
            core::mem::swap(&mut fa, &mut fb);
        }
    }
    // Budget exhausted; b is the best estimate found.
    Some(b)
}
152
/// Illinois method — a regula falsi variant with superlinear convergence.
///
/// Plain regula falsi can stall with one endpoint frozen; the Illinois
/// modification halves the residual of an endpoint retained twice in a
/// row, which restores fast convergence.
pub fn illinois(f: impl Fn(f64) -> f64, mut a: f64, mut b: f64, tol: f64) -> Option<f64> {
    let max_iter = 200;
    let mut fa = f(a);
    let mut fb = f(b);
    // The bracket must contain a sign change.
    if fa * fb > 0.0 {
        return None;
    }
    // Endpoint kept by the previous iteration: -1 = 'a', +1 = 'b', 0 = none yet.
    let mut retained = 0i32;
    for _ in 0..max_iter {
        // Crossing point of the secant line through (a, fa) and (b, fb).
        let c = (a * fb - b * fa) / (fb - fa);
        let fc = f(c);
        if fc.abs() < tol || (b - a).abs() < tol {
            return Some(c);
        }
        if fa * fc < 0.0 {
            // Sign change on the left: [a, c] is the new bracket.
            b = c;
            fb = fc;
            if retained == -1 {
                // 'a' survived twice in a row — halve its residual.
                fa *= 0.5;
            }
            retained = -1;
        } else {
            // Sign change on the right: [c, b] is the new bracket.
            a = c;
            fa = fc;
            if retained == 1 {
                fb *= 0.5;
            }
            retained = 1;
        }
    }
    Some((a + b) * 0.5)
}
189
/// Muller's method — quadratic interpolation, can find complex roots (returns real part here).
///
/// Fits a parabola through the last three iterates and steps to its root,
/// falling back to a secant-like step when the discriminant is negative.
pub fn muller(
    f: impl Fn(f64) -> f64,
    mut x0: f64,
    mut x1: f64,
    mut x2: f64,
    tol: f64,
    max_iter: usize,
) -> Option<f64> {
    for _ in 0..max_iter {
        let (f0, f1, f2) = (f(x0), f(x1), f(x2));
        // Divided differences of the three-point stencil.
        let h1 = x1 - x0;
        let h2 = x2 - x1;
        let d1 = (f1 - f0) / h1;
        let d2 = (f2 - f1) / h2;
        let a = (d2 - d1) / (h2 + h1); // quadratic coefficient
        let b = a * h2 + d2; // slope of the parabola at x2
        let c = f2;
        let discriminant = b * b - 4.0 * a * c;
        let next = if discriminant < 0.0 {
            // No real root from this quadratic; fall back to secant step
            x2 - c / b
        } else {
            let sqrt_d = discriminant.sqrt();
            // Choose the larger-magnitude denominator for numerical stability.
            let denom = if b + sqrt_d > (b - sqrt_d).abs() {
                b + sqrt_d
            } else {
                b - sqrt_d
            };
            if denom.abs() < 1e-300 {
                return None;
            }
            x2 - 2.0 * c / denom
        };
        if (next - x2).abs() < tol {
            return Some(next);
        }
        // Slide the three-point window forward.
        x0 = x1;
        x1 = x2;
        x2 = next;
    }
    None
}
235
/// Fixed-point iteration: x_{n+1} = g(x_n).
///
/// Returns the first iterate whose distance to its predecessor is below
/// `tol`, or `None` if no such iterate appears within `max_iter` steps.
pub fn fixed_point(g: impl Fn(f64) -> f64, mut x: f64, tol: f64, max_iter: usize) -> Option<f64> {
    (0..max_iter).find_map(|_| {
        let next = g(x);
        let converged = (next - x).abs() < tol;
        x = next;
        converged.then(|| next)
    })
}

// ============================================================
// NUMERICAL INTEGRATION (QUADRATURE)
// ============================================================

/// Trapezoidal rule with n sub-intervals (n must be >= 1).
///
/// Endpoints carry weight 1/2, interior nodes weight 1; an n of 0 is
/// bumped to 1 rather than panicking.
pub fn trapezoid(f: impl Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
    let n = n.max(1);
    let h = (b - a) / n as f64;
    let endpoints = 0.5 * (f(a) + f(b));
    // Accumulate interior samples in index order (same order as a plain loop).
    let total = (1..n).fold(endpoints, |acc, i| acc + f(a + i as f64 * h));
    total * h
}
262
/// Simpson's 1/3 rule. n must be even; if odd, n is incremented by 1.
pub fn simpsons(f: impl Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
    // Round odd counts up; at least 2 panels are required.
    let n = if n % 2 == 0 { n.max(2) } else { n + 1 };
    let h = (b - a) / n as f64;
    // Weights follow the 1-4-2-4-...-4-1 pattern.
    let total = (1..n).fold(f(a) + f(b), |acc, i| {
        let w = if i % 2 == 0 { 2.0 } else { 4.0 };
        acc + w * f(a + i as f64 * h)
    });
    total * h / 3.0
}
274
/// Simpson's 3/8 rule. n must be a multiple of 3; adjusted upward if not.
pub fn simpsons38(f: impl Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
    // Round up to the next multiple of 3, with a minimum of 3 panels.
    let mut m = n.max(3);
    m += (3 - m % 3) % 3;
    let h = (b - a) / m as f64;
    // Weights follow the 1-3-3-2-3-3-2-...-1 pattern.
    let total = (1..m).fold(f(a) + f(b), |acc, i| {
        let w = if i % 3 == 0 { 2.0 } else { 3.0 };
        acc + w * f(a + i as f64 * h)
    });
    total * 3.0 * h / 8.0
}
289
/// Gauss-Legendre quadrature. Supports n = 1..=5 nodes (pre-computed).
/// Maps from [-1,1] to [a,b]. Any other n falls back to the 5-point rule.
pub fn gauss_legendre(f: impl Fn(f64) -> f64, a: f64, b: f64, n: usize) -> f64 {
    // (nodes, weights) on [-1, 1]
    let (nodes, weights): (&[f64], &[f64]) = match n {
        1 => (&[0.0], &[2.0]),
        2 => (
            &[-0.577_350_269_189_626, 0.577_350_269_189_626],
            &[1.0, 1.0],
        ),
        3 => (
            &[-0.774_596_669_241_483, 0.0, 0.774_596_669_241_483],
            &[
                0.555_555_555_555_556,
                0.888_888_888_888_889,
                0.555_555_555_555_556,
            ],
        ),
        4 => (
            &[
                -0.861_136_311_594_953,
                -0.339_981_043_584_856,
                0.339_981_043_584_856,
                0.861_136_311_594_953,
            ],
            &[
                0.347_854_845_137_454,
                0.652_145_154_862_546,
                0.652_145_154_862_546,
                0.347_854_845_137_454,
            ],
        ),
        // n = 5 (and every unsupported count)
        _ => (
            &[
                -0.906_179_845_938_664,
                -0.538_469_310_105_683,
                0.0,
                0.538_469_310_105_683,
                0.906_179_845_938_664,
            ],
            &[
                0.236_926_885_056_189,
                0.478_628_670_499_366,
                0.568_888_888_888_889,
                0.478_628_670_499_366,
                0.236_926_885_056_189,
            ],
        ),
    };
    // Affine map x = half_width * xi + midpoint sends [-1, 1] onto [a, b].
    let half_width = (b - a) * 0.5;
    let midpoint = (b + a) * 0.5;
    let mut acc = 0.0;
    for (&xi, &wi) in nodes.iter().zip(weights.iter()) {
        acc += wi * f(half_width * xi + midpoint);
    }
    acc * half_width
}
349
350/// Romberg integration — Richardson extrapolation on the trapezoidal rule.
351pub fn romberg(f: impl Fn(f64) -> f64, a: f64, b: f64, max_levels: usize, tol: f64) -> f64 {
352    let max_levels = max_levels.max(2);
353    let mut table = vec![vec![0.0f64; max_levels]; max_levels];
354    for i in 0..max_levels {
355        let n = 1usize << i;
356        table[i][0] = trapezoid(&f, a, b, n);
357    }
358    for j in 1..max_levels {
359        for i in j..max_levels {
360            let factor = (4.0f64).powi(j as i32);
361            table[i][j] = (factor * table[i][j - 1] - table[i - 1][j - 1]) / (factor - 1.0);
362        }
363        if max_levels > 2 {
364            let prev = table[j][j - 1];
365            let curr = table[j][j];
366            if (curr - prev).abs() < tol {
367                return curr;
368            }
369        }
370    }
371    table[max_levels - 1][max_levels - 1]
372}
373
374fn adaptive_simpson_helper(
375    f: &impl Fn(f64) -> f64,
376    a: f64,
377    b: f64,
378    tol: f64,
379    depth: usize,
380    max_depth: usize,
381) -> f64 {
382    let mid = (a + b) * 0.5;
383    let whole = simpsons(f, a, b, 2);
384    let left = simpsons(f, a, mid, 2);
385    let right = simpsons(f, mid, b, 2);
386    if depth >= max_depth || (left + right - whole).abs() < 15.0 * tol {
387        left + right + (left + right - whole) / 15.0
388    } else {
389        adaptive_simpson_helper(f, a, mid, tol / 2.0, depth + 1, max_depth)
390            + adaptive_simpson_helper(f, mid, b, tol / 2.0, depth + 1, max_depth)
391    }
392}
393
394/// Adaptive Simpson's rule with recursive subdivision.
395pub fn adaptive_simpson(f: impl Fn(f64) -> f64, a: f64, b: f64, tol: f64, max_depth: usize) -> f64 {
396    adaptive_simpson_helper(&f, a, b, tol, 0, max_depth)
397}
398
/// Multi-dimensional Monte Carlo integration.
/// `bounds` is a slice of (low, high) per dimension.
/// Uses a simple LCG for reproducible sampling.
///
/// Returns `volume * mean(f)` over `n_samples` uniform samples, or 0.0
/// when `n_samples == 0` (instead of NaN from a 0/0 division).
pub fn monte_carlo_integrate(
    f: impl Fn(&[f64]) -> f64,
    bounds: &[(f64, f64)],
    n_samples: usize,
    seed: u64,
) -> f64 {
    if n_samples == 0 {
        return 0.0;
    }
    let dim = bounds.len();
    let volume: f64 = bounds.iter().map(|(lo, hi)| hi - lo).product();
    let mut state = seed.wrapping_add(1);
    let mut sum = 0.0;
    let mut point = vec![0.0f64; dim];
    for _ in 0..n_samples {
        for (d, (lo, hi)) in bounds.iter().enumerate() {
            // Knuth's MMIX LCG constants; keep the top bits, which are the
            // strongest in a power-of-two-modulus LCG.
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // BUG FIX: `state >> 33` leaves 31 random bits (max 2^31 - 1),
            // but the old divisor `u32::MAX` (~2^32) mapped u into [0, 0.5),
            // so only the lower half of every dimension was ever sampled.
            // Dividing by 2^31 makes u cover [0, 1).
            let u = (state >> 33) as f64 / (1u64 << 31) as f64;
            point[d] = lo + u * (hi - lo);
        }
        sum += f(&point);
    }
    volume * sum / n_samples as f64
}

// ============================================================
// ODE SOLVERS
// ============================================================

/// Forward Euler method. Returns list of state vectors at each step.
///
/// The returned vector has `steps + 1` entries; entry 0 is `y0` itself.
pub fn euler(
    f: impl Fn(f64, &[f64]) -> Vec<f64>,
    t0: f64,
    y0: &[f64],
    dt: f64,
    steps: usize,
) -> Vec<Vec<f64>> {
    let mut states = Vec::with_capacity(steps + 1);
    let mut state = y0.to_vec();
    states.push(state.clone());
    let mut t = t0;
    for _ in 0..steps {
        // One explicit Euler update: y <- y + dt * f(t, y).
        let slope = f(t, &state);
        state
            .iter_mut()
            .zip(&slope)
            .for_each(|(s, ds)| *s += dt * ds);
        t += dt;
        states.push(state.clone());
    }
    states
}
450
/// Classical 4th-order Runge-Kutta.
///
/// Returns `steps + 1` state vectors; entry 0 is `y0`.
pub fn rk4(
    f: impl Fn(f64, &[f64]) -> Vec<f64>,
    t0: f64,
    y0: &[f64],
    dt: f64,
    steps: usize,
) -> Vec<Vec<f64>> {
    let dim = y0.len();
    let mut out = Vec::with_capacity(steps + 1);
    let mut y = y0.to_vec();
    out.push(y.clone());
    let mut t = t0;
    // Trial state for a stage: base + scale * k, component-wise.
    let advance = |base: &[f64], k: &[f64], scale: f64| -> Vec<f64> {
        base.iter().zip(k).map(|(yi, ki)| yi + scale * ki).collect()
    };
    for _ in 0..steps {
        let k1 = f(t, &y);
        let k2 = f(t + 0.5 * dt, &advance(&y, &k1, 0.5 * dt));
        let k3 = f(t + 0.5 * dt, &advance(&y, &k2, 0.5 * dt));
        let k4 = f(t + dt, &advance(&y, &k3, dt));
        // Weighted average of the four slopes: (k1 + 2k2 + 2k3 + k4) / 6.
        for i in 0..dim {
            y[i] += dt / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
        }
        t += dt;
        out.push(y.clone());
    }
    out
}
480
/// Dormand-Prince RK45 adaptive step integrator.
/// Returns (time_points, state_vectors).
///
/// Embedded 4(5) Runge-Kutta pair: the B* weights form the 5th-order
/// solution that advances the state, and the E* weights combine the seven
/// stages into a local error estimate that drives step-size control
/// within [h_min, h_max].
pub fn rk45(
    f: impl Fn(f64, &[f64]) -> Vec<f64>,
    t0: f64,
    y0: &[f64],
    t_end: f64,
    tol: f64,
    h_min: f64,
    h_max: f64,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    // Dormand-Prince coefficients
    // Stage time fractions c_i: stage i is evaluated at t + C_i * h.
    const C2: f64 = 1.0 / 5.0;
    const C3: f64 = 3.0 / 10.0;
    const C4: f64 = 4.0 / 5.0;
    const C5: f64 = 8.0 / 9.0;

    // Runge-Kutta matrix a_ij: weights of k_1..k_{i-1} in stage i's trial state.
    const A21: f64 = 1.0 / 5.0;
    const A31: f64 = 3.0 / 40.0;
    const A32: f64 = 9.0 / 40.0;
    const A41: f64 = 44.0 / 45.0;
    const A42: f64 = -56.0 / 15.0;
    const A43: f64 = 32.0 / 9.0;
    const A51: f64 = 19372.0 / 6561.0;
    const A52: f64 = -25360.0 / 2187.0;
    const A53: f64 = 64448.0 / 6561.0;
    const A54: f64 = -212.0 / 729.0;
    const A61: f64 = 9017.0 / 3168.0;
    const A62: f64 = -355.0 / 33.0;
    const A63: f64 = 46732.0 / 5247.0;
    const A64: f64 = 49.0 / 176.0;
    const A65: f64 = -5103.0 / 18656.0;

    // 5th order weights
    const B1: f64 = 35.0 / 384.0;
    const B3: f64 = 500.0 / 1113.0;
    const B4: f64 = 125.0 / 192.0;
    const B5: f64 = -2187.0 / 6784.0;
    const B6: f64 = 11.0 / 84.0;

    // 4th order weights (for error)
    // Presumably the differences between the 5th- and 4th-order weight rows
    // of the DP(4,5) tableau, so the weighted stage sum is directly the
    // error — matches the standard DOPRI5 values; verify against a reference.
    const E1: f64 = 71.0 / 57600.0;
    const E3: f64 = -71.0 / 16695.0;
    const E4: f64 = 71.0 / 1920.0;
    const E5: f64 = -17253.0 / 339200.0;
    const E6: f64 = 22.0 / 525.0;
    const E7: f64 = -1.0 / 40.0;

    let n = y0.len();
    let mut ts = vec![t0];
    let mut ys = vec![y0.to_vec()];
    let mut t = t0;
    let mut y = y0.to_vec();
    // Initial step: 10% of the span, clamped into [h_min, h_max].
    let mut h = (h_max).min((t_end - t0) * 0.1).max(h_min);

    while t < t_end {
        // Do not step past t_end...
        if t + h > t_end { h = t_end - t; }
        // ...but never below h_min. NOTE(review): when the remaining span is
        // smaller than h_min this bumps h back up, so the final accepted
        // point can land slightly past t_end — confirm that is intended.
        if h < h_min { h = h_min; }

        // Six derivative stages of the tableau.
        let k1 = f(t, &y);
        let yy: Vec<f64> = (0..n).map(|i| y[i] + h * A21 * k1[i]).collect();
        let k2 = f(t + C2 * h, &yy);
        let yy: Vec<f64> = (0..n).map(|i| y[i] + h * (A31 * k1[i] + A32 * k2[i])).collect();
        let k3 = f(t + C3 * h, &yy);
        let yy: Vec<f64> = (0..n).map(|i| y[i] + h * (A41 * k1[i] + A42 * k2[i] + A43 * k3[i])).collect();
        let k4 = f(t + C4 * h, &yy);
        let yy: Vec<f64> = (0..n).map(|i| y[i] + h * (A51 * k1[i] + A52 * k2[i] + A53 * k3[i] + A54 * k4[i])).collect();
        let k5 = f(t + C5 * h, &yy);
        let yy: Vec<f64> = (0..n).map(|i| y[i] + h * (A61 * k1[i] + A62 * k2[i] + A63 * k3[i] + A64 * k4[i] + A65 * k5[i])).collect();
        let k6 = f(t + h, &yy);

        // 5th-order candidate solution.
        let y_new: Vec<f64> = (0..n)
            .map(|i| y[i] + h * (B1 * k1[i] + B3 * k3[i] + B4 * k4[i] + B5 * k5[i] + B6 * k6[i]))
            .collect();
        // Seventh stage, evaluated at the candidate; used only in the error
        // estimate here (it is not reused as next step's k1).
        let k7 = f(t + h, &y_new);

        // Error estimate
        // Scaled RMS norm; sc mixes absolute and relative tolerance, both `tol`.
        let err: f64 = (0..n)
            .map(|i| {
                let e = h * (E1 * k1[i] + E3 * k3[i] + E4 * k4[i] + E5 * k5[i] + E6 * k6[i] + E7 * k7[i]);
                let sc = tol + tol * y[i].abs().max(y_new[i].abs());
                (e / sc).powi(2)
            })
            .sum::<f64>()
            / n as f64;
        let err = err.sqrt();

        // Accept when the scaled error is within budget, or the step cannot
        // shrink any further; otherwise retry with a smaller h.
        if err <= 1.0 || h <= h_min {
            t += h;
            y = y_new;
            ts.push(t);
            ys.push(y.clone());
        }
        // Adjust step
        // Standard controller: h *= 0.9 * err^(-1/5), with growth/shrink capped.
        let factor = if err == 0.0 { 5.0 } else { 0.9 * err.powf(-0.2) };
        h = (h * factor.clamp(0.1, 5.0)).clamp(h_min, h_max);
    }
    (ts, ys)
}
580
/// Adams-Bashforth 4-step method.
/// Seeds first 4 steps with RK4, then applies the multi-step formula.
///
/// Returns `steps + 1` state vectors (entry 0 is `y0`). For `steps <= 3`
/// the result is pure RK4, since AB4 needs four prior derivative samples.
pub fn adams_bashforth4(
    f: impl Fn(f64, &[f64]) -> Vec<f64>,
    t0: f64,
    y0: &[f64],
    dt: f64,
    steps: usize,
) -> Vec<Vec<f64>> {
    if steps == 0 {
        return vec![y0.to_vec()];
    }
    let n = y0.len();
    // Seed with RK4
    let seed_steps = 3.min(steps);
    let rk_result = rk4(&f, t0, y0, dt, seed_steps);
    let mut result = rk_result.clone();
    if steps <= 3 {
        return result;
    }
    // Store last 4 derivatives
    // `t` tracks the time of the newest state appended to `result`.
    let mut t = t0 + seed_steps as f64 * dt;
    // derivs[i] = f at the i-th seeded point; newest entry is last.
    let mut derivs: Vec<Vec<f64>> = (0..=seed_steps)
        .map(|i| f(t0 + i as f64 * dt, &rk_result[i]))
        .collect();
    for _ in 4..=steps {
        // History window, oldest (f0) to newest (f3).
        let f0 = &derivs[derivs.len() - 4];
        let f1 = &derivs[derivs.len() - 3];
        let f2 = &derivs[derivs.len() - 2];
        let f3 = &derivs[derivs.len() - 1];
        let y_prev = result.last().unwrap();
        // AB4: y_{k+1} = y_k + dt/24 * (55 f_k - 59 f_{k-1} + 37 f_{k-2} - 9 f_{k-3}).
        let y_new: Vec<f64> = (0..n)
            .map(|i| {
                y_prev[i]
                    + dt / 24.0 * (55.0 * f3[i] - 59.0 * f2[i] + 37.0 * f1[i] - 9.0 * f0[i])
            })
            .collect();
        t += dt;
        // Evaluate f at the new point so it joins the history window.
        let fn_new = f(t, &y_new);
        derivs.push(fn_new);
        result.push(y_new);
    }
    result
}
625
/// Störmer-Verlet integrator for second-order ODE x'' = a(x).
/// Returns vec of (t, x, v).
///
/// Velocity-Verlet form: position advances with the current acceleration,
/// velocity with the average of the old and new accelerations.
pub fn verlet(
    x0: f64,
    v0: f64,
    a: impl Fn(f64) -> f64,
    dt: f64,
    steps: usize,
) -> Vec<(f64, f64, f64)> {
    let mut trajectory = Vec::with_capacity(steps + 1);
    let (mut t, mut x, mut v) = (0.0, x0, v0);
    trajectory.push((t, x, v));
    for _ in 0..steps {
        let a_old = a(x);
        // x_{n+1} = x_n + v_n dt + (1/2) a_n dt^2
        let x_next = x + v * dt + 0.5 * a_old * dt * dt;
        // v_{n+1} = v_n + (1/2)(a_n + a_{n+1}) dt
        let a_new = a(x_next);
        v += 0.5 * (a_old + a_new) * dt;
        x = x_next;
        t += dt;
        trajectory.push((t, x, v));
    }
    trajectory
}
652
/// Leapfrog (Störmer-Verlet) symplectic integrator for N-body-style systems.
/// `positions` and `velocities` are flat arrays of length 3*N.
/// `forces_fn` takes positions and returns force vectors (acceleration).
/// Returns steps of (positions, velocities).
pub fn leapfrog(
    positions: &[f64],
    velocities: &[f64],
    forces_fn: impl Fn(&[f64]) -> Vec<f64>,
    dt: f64,
    steps: usize,
) -> Vec<(Vec<f64>, Vec<f64>)> {
    let mut q = positions.to_vec();
    let mut p = velocities.to_vec();
    let mut history = Vec::with_capacity(steps + 1);
    history.push((q.clone(), p.clone()));
    let mut acc = forces_fn(&q);
    for _ in 0..steps {
        // Kick: half-step velocity update with the current accelerations.
        for (vi, ai) in p.iter_mut().zip(&acc) {
            *vi += 0.5 * dt * ai;
        }
        // Drift: full-step position update with the half-kicked velocities.
        for (xi, vi) in q.iter_mut().zip(&p) {
            *xi += dt * vi;
        }
        // Recompute accelerations at the new positions, then the second kick.
        acc = forces_fn(&q);
        for (vi, ai) in p.iter_mut().zip(&acc) {
            *vi += 0.5 * dt * ai;
        }
        history.push((q.clone(), p.clone()));
    }
    history
}

// ============================================================
// LINEAR ALGEBRA
// ============================================================

/// Dense matrix stored in row-major order.
#[derive(Clone, Debug)]
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f64>,
}

impl Matrix {
    // Flat offset of element (r, c) in the row-major backing store.
    fn offset(&self, r: usize, c: usize) -> usize {
        r * self.cols + c
    }

    /// Create an uninitialized (zero) matrix.
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    /// Create identity matrix.
    pub fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        // Diagonal entries sit every n + 1 slots in row-major order.
        for d in m.data.iter_mut().step_by(n + 1) {
            *d = 1.0;
        }
        m
    }

    /// Create from row-major flat data.
    pub fn from_data(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols);
        Matrix { rows, cols, data }
    }

    pub fn get(&self, r: usize, c: usize) -> f64 {
        self.data[self.offset(r, c)]
    }

    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        let i = self.offset(r, c);
        self.data[i] = v;
    }
}

impl core::ops::Index<(usize, usize)> for Matrix {
    type Output = f64;
    fn index(&self, (r, c): (usize, usize)) -> &f64 {
        &self.data[self.offset(r, c)]
    }
}

impl core::ops::IndexMut<(usize, usize)> for Matrix {
    fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut f64 {
        let i = self.offset(r, c);
        &mut self.data[i]
    }
}
742
743/// Matrix multiplication. Panics if dimensions mismatch.
744pub fn matmul(a: &Matrix, b: &Matrix) -> Matrix {
745    assert_eq!(a.cols, b.rows, "matmul: dimension mismatch");
746    let mut c = Matrix::zeros(a.rows, b.cols);
747    for i in 0..a.rows {
748        for k in 0..a.cols {
749            for j in 0..b.cols {
750                c[(i, j)] += a[(i, k)] * b[(k, j)];
751            }
752        }
753    }
754    c
755}
756
757/// Matrix transpose.
758pub fn transpose(a: &Matrix) -> Matrix {
759    let mut t = Matrix::zeros(a.cols, a.rows);
760    for i in 0..a.rows {
761        for j in 0..a.cols {
762            t[(j, i)] = a[(i, j)];
763        }
764    }
765    t
766}
767
/// LU decomposition with partial pivoting.
/// Returns (L, U, pivot) or None if singular.
///
/// `pivot[i]` is the original row index that ended up in row i, so a
/// right-hand side b is permuted as b[pivot[i]] (see `lu_solve`).
pub fn lu_decompose(a: &Matrix) -> Option<(Matrix, Matrix, Vec<usize>)> {
    let n = a.rows;
    assert_eq!(a.rows, a.cols, "LU requires square matrix");
    // In-place factorization: the strictly-lower part accumulates the L
    // multipliers (unit diagonal implied), the rest accumulates U.
    let mut lu = a.clone();
    let mut piv: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Find pivot
        let mut max_val = lu[(k, k)].abs();
        let mut max_row = k;
        for i in k + 1..n {
            let v = lu[(i, k)].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        // Entire pivot column (numerically) zero => matrix is singular.
        if max_val < 1e-300 {
            return None; // Singular
        }
        if max_row != k {
            piv.swap(k, max_row);
            // Swap rows k and max_row of the working matrix.
            for j in 0..n {
                let tmp = lu[(k, j)];
                lu[(k, j)] = lu[(max_row, j)];
                lu[(max_row, j)] = tmp;
            }
        }
        // Eliminate below the pivot, storing multipliers in the lower part.
        for i in k + 1..n {
            lu[(i, k)] /= lu[(k, k)];
            for j in k + 1..n {
                let val = lu[(i, k)] * lu[(k, j)];
                lu[(i, j)] -= val;
            }
        }
    }
    // Extract L and U
    let mut l = Matrix::identity(n);
    let mut u = Matrix::zeros(n, n);
    for i in 0..n {
        for j in 0..n {
            if i > j {
                l[(i, j)] = lu[(i, j)];
            } else {
                u[(i, j)] = lu[(i, j)];
            }
        }
    }
    Some((l, u, piv))
}
819
820/// Solve L*U*x = Pb using forward/back substitution.
821pub fn lu_solve(l: &Matrix, u: &Matrix, piv: &[usize], b: &[f64]) -> Vec<f64> {
822    let n = b.len();
823    // Apply permutation
824    let mut pb: Vec<f64> = piv.iter().map(|&i| b[i]).collect();
825    // Forward substitution (L*y = pb)
826    for i in 0..n {
827        for j in 0..i {
828            pb[i] -= l[(i, j)] * pb[j];
829        }
830    }
831    // Back substitution (U*x = y)
832    for i in (0..n).rev() {
833        for j in i + 1..n {
834            pb[i] -= u[(i, j)] * pb[j];
835        }
836        pb[i] /= u[(i, i)];
837    }
838    pb
839}
840
841/// Solve Ax = b via LU decomposition. Returns None if singular.
842pub fn solve_linear(a: &Matrix, b: &[f64]) -> Option<Vec<f64>> {
843    let (l, u, piv) = lu_decompose(a)?;
844    Some(lu_solve(&l, &u, &piv, b))
845}
846
847/// Determinant via LU decomposition.
848pub fn determinant(a: &Matrix) -> f64 {
849    let n = a.rows;
850    assert_eq!(a.rows, a.cols);
851    let mut lu = a.clone();
852    let mut piv: Vec<usize> = (0..n).collect();
853    let mut sign = 1.0f64;
854    for k in 0..n {
855        let mut max_val = lu[(k, k)].abs();
856        let mut max_row = k;
857        for i in k + 1..n {
858            let v = lu[(i, k)].abs();
859            if v > max_val {
860                max_val = v;
861                max_row = i;
862            }
863        }
864        if max_val < 1e-300 {
865            return 0.0;
866        }
867        if max_row != k {
868            piv.swap(k, max_row);
869            for j in 0..n {
870                let tmp = lu[(k, j)];
871                lu[(k, j)] = lu[(max_row, j)];
872                lu[(max_row, j)] = tmp;
873            }
874            sign = -sign;
875        }
876        for i in k + 1..n {
877            lu[(i, k)] /= lu[(k, k)];
878            for j in k + 1..n {
879                let val = lu[(i, k)] * lu[(k, j)];
880                lu[(i, j)] -= val;
881            }
882        }
883    }
884    let mut det = sign;
885    for i in 0..n { det *= lu[(i, i)]; }
886    det
887}
888
889/// Matrix inverse via LU. Returns None if singular.
890pub fn inverse(a: &Matrix) -> Option<Matrix> {
891    let n = a.rows;
892    assert_eq!(a.rows, a.cols);
893    let (l, u, piv) = lu_decompose(a)?;
894    let mut inv = Matrix::zeros(n, n);
895    for j in 0..n {
896        let mut e = vec![0.0f64; n];
897        e[j] = 1.0;
898        let col = lu_solve(&l, &u, &piv, &e);
899        for i in 0..n { inv[(i, j)] = col[i]; }
900    }
901    Some(inv)
902}
903
904/// Cholesky decomposition for symmetric positive-definite matrices.
905/// Returns lower triangular L such that A = L * L^T. Returns None if not SPD.
906pub fn cholesky(a: &Matrix) -> Option<Matrix> {
907    let n = a.rows;
908    assert_eq!(a.rows, a.cols);
909    let mut l = Matrix::zeros(n, n);
910    for i in 0..n {
911        for j in 0..=i {
912            let mut s: f64 = a[(i, j)];
913            for k in 0..j { s -= l[(i, k)] * l[(j, k)]; }
914            if i == j {
915                if s <= 0.0 { return None; }
916                l[(i, j)] = s.sqrt();
917            } else {
918                l[(i, j)] = s / l[(j, j)];
919            }
920        }
921    }
922    Some(l)
923}
924
/// Gram-Schmidt orthonormalization.
///
/// Projects each input vector against the previously accepted basis
/// vectors in sequence (modified Gram-Schmidt); residuals with norm
/// below 1e-12 (linearly dependent inputs) are dropped.
pub fn gram_schmidt(cols: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut basis: Vec<Vec<f64>> = Vec::new();
    for vector in cols {
        let mut residual = vector.clone();
        for q in &basis {
            // Remove the component of `residual` along the unit vector q.
            let proj: f64 = residual.iter().zip(q).map(|(r, qk)| r * qk).sum();
            for (r, qk) in residual.iter_mut().zip(q) {
                *r -= proj * qk;
            }
        }
        let len: f64 = residual.iter().map(|r| r * r).sum::<f64>().sqrt();
        if len >= 1e-12 {
            basis.push(residual.iter().map(|r| r / len).collect());
        }
    }
    basis
}
942
943/// Thin QR decomposition via Gram-Schmidt.
944pub fn qr_decompose(a: &Matrix) -> (Matrix, Matrix) {
945    let m = a.rows;
946    let n = a.cols;
947    // Extract columns
948    let cols: Vec<Vec<f64>> = (0..n)
949        .map(|j| (0..m).map(|i| a[(i, j)]).collect())
950        .collect();
951    let q_cols = gram_schmidt(&cols);
952    let k = q_cols.len();
953    let mut q = Matrix::zeros(m, k);
954    for (j, col) in q_cols.iter().enumerate() {
955        for i in 0..m { q[(i, j)] = col[i]; }
956    }
957    // R = Q^T * A
958    let qt = transpose(&q);
959    let r = matmul(&qt, a);
960    (q, r)
961}
962
963/// Analytic eigenvalues of a 2x2 matrix.
964pub fn eigenvalues_2x2(a: &Matrix) -> (f64, f64) {
965    assert!(a.rows == 2 && a.cols == 2);
966    let tr = a[(0, 0)] + a[(1, 1)];
967    let det = a[(0, 0)] * a[(1, 1)] - a[(0, 1)] * a[(1, 0)];
968    let disc = tr * tr - 4.0 * det;
969    if disc >= 0.0 {
970        let s = disc.sqrt();
971        ((tr + s) * 0.5, (tr - s) * 0.5)
972    } else {
973        // Complex pair — return real parts
974        (tr * 0.5, tr * 0.5)
975    }
976}
977
978/// Power iteration for the dominant eigenvalue/eigenvector.
979pub fn power_iteration(a: &Matrix, max_iter: usize, tol: f64) -> (f64, Vec<f64>) {
980    let n = a.rows;
981    let mut v: Vec<f64> = (0..n).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
982    let mut lambda = 0.0;
983    for _ in 0..max_iter {
984        let av: Vec<f64> = (0..n).map(|i| (0..n).map(|j| a[(i, j)] * v[j]).sum()).collect();
985        let norm: f64 = av.iter().map(|x| x * x).sum::<f64>().sqrt();
986        if norm < 1e-300 { break; }
987        let v_new: Vec<f64> = av.iter().map(|x| x / norm).collect();
988        let lambda_new: f64 = av.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
989        if (lambda_new - lambda).abs() < tol {
990            return (lambda_new, v_new);
991        }
992        lambda = lambda_new;
993        v = v_new;
994    }
995    (lambda, v)
996}
997
/// Analytic 2x2 SVD: A = U * diag(sigma) * V^T.
///
/// Singular values are the square roots of the eigenvalues of A^T A and are
/// returned in descending order (since `eigenvalues_2x2` returns the larger
/// root first).
/// NOTE(review): when A^T A has a repeated eigenvalue, `build_evec` produces
/// the same vector for both columns of V, so V is not orthogonal in that
/// degenerate case — confirm whether callers rely on orthogonality there.
pub fn svd_2x2(a: &Matrix) -> (Matrix, Vec<f64>, Matrix) {
    assert!(a.rows == 2 && a.cols == 2);
    // Compute A^T * A
    let at = transpose(a);
    let ata = matmul(&at, a);
    let (e1, e2) = eigenvalues_2x2(&ata);
    // abs() guards against tiny negative eigenvalues caused by round-off.
    let s1 = e1.abs().sqrt();
    let s2 = e2.abs().sqrt();

    // V from eigenvectors of A^T A
    let build_evec = |lambda: f64| -> Vec<f64> {
        // Solve (A^T A - lambda I) x = 0 using its first row (a00, a01):
        // (a01, -a00) is orthogonal to that row, hence in the null space.
        let a00 = ata[(0, 0)] - lambda;
        let a01 = ata[(0, 1)];
        if a01.abs() > 1e-12 || a00.abs() > 1e-12 {
            let norm = (a00 * a00 + a01 * a01).sqrt();
            if norm < 1e-300 { return vec![1.0, 0.0]; }
            vec![a01 / norm, -a00 / norm]
        } else {
            // First row is ~zero: any unit vector qualifies; pick e1.
            vec![1.0, 0.0]
        }
    };

    let v1 = build_evec(e1);
    let v2 = build_evec(e2);

    // V holds the right singular vectors as columns.
    let mut v_mat = Matrix::zeros(2, 2);
    v_mat[(0, 0)] = v1[0]; v_mat[(1, 0)] = v1[1];
    v_mat[(0, 1)] = v2[0]; v_mat[(1, 1)] = v2[1];

    let sigmas = vec![s1, s2];

    // U: for each non-zero sigma, u_i = A * v_i / sigma_i
    // For a (near-)zero sigma the corresponding U column is left as the
    // identity column, which may not be orthogonal to the other column.
    let mut u_mat = Matrix::identity(2);
    if s1 > 1e-12 {
        let u0 = vec![
            (a[(0, 0)] * v1[0] + a[(0, 1)] * v1[1]) / s1,
            (a[(1, 0)] * v1[0] + a[(1, 1)] * v1[1]) / s1,
        ];
        u_mat[(0, 0)] = u0[0]; u_mat[(1, 0)] = u0[1];
    }
    if s2 > 1e-12 {
        let u1 = vec![
            (a[(0, 0)] * v2[0] + a[(0, 1)] * v2[1]) / s2,
            (a[(1, 0)] * v2[0] + a[(1, 1)] * v2[1]) / s2,
        ];
        u_mat[(0, 1)] = u1[0]; u_mat[(1, 1)] = u1[1];
    }

    (u_mat, sigmas, v_mat)
}
1049
1050// ============================================================
1051// INTERPOLATION
1052// ============================================================
1053
/// Linear interpolation between a and b.
///
/// `t = 0` yields `a`, `t = 1` yields `b`; values outside [0, 1]
/// extrapolate along the same line.
#[inline]
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
    let delta = b - a;
    a + delta * t
}
1059
/// Bilinear interpolation on a unit square.
/// tl=top-left, tr=top-right, bl=bottom-left, br=bottom-right.
/// tx, ty in [0,1].
///
/// Interpolates along x on both edges, then blends the two results along y.
#[inline]
pub fn bilinear(tl: f64, tr: f64, bl: f64, br: f64, tx: f64, ty: f64) -> f64 {
    let along_top = tl + (tr - tl) * tx;
    let along_bottom = bl + (br - bl) * tx;
    along_top + (along_bottom - along_top) * ty
}
1069
/// Barycentric coordinates of point p w.r.t. triangle (a, b, c).
/// Returns (u, v, w) such that p = u*a + v*b + w*c.
///
/// For a degenerate (zero-area) triangle the centroid weights
/// (1/3, 1/3, 1/3) are returned as a fallback.
pub fn barycentric(
    p: (f64, f64),
    a: (f64, f64),
    b: (f64, f64),
    c: (f64, f64),
) -> (f64, f64, f64) {
    // Signed-area denominator of the standard barycentric formula.
    let row_u = (b.1 - c.1, c.0 - b.0);
    let row_v = (c.1 - a.1, a.0 - c.0);
    let denom = row_u.0 * (a.0 - c.0) + row_u.1 * (a.1 - c.1);
    if denom.abs() < 1e-300 {
        // Degenerate triangle — fall back to the centroid.
        let third = 1.0 / 3.0;
        return (third, third, third);
    }
    // Everything is measured relative to vertex c.
    let px = p.0 - c.0;
    let py = p.1 - c.1;
    let u = (row_u.0 * px + row_u.1 * py) / denom;
    let v = (row_v.0 * px + row_v.1 * py) / denom;
    (u, v, 1.0 - u - v)
}
1087
/// Lagrange polynomial interpolation at x.
///
/// Evaluates the unique degree-(n-1) polynomial through the nodes
/// (xs[i], ys[i]) as a sum of Lagrange basis polynomials. Nodes must be
/// pairwise distinct or a division by zero occurs.
pub fn lagrange_interp(xs: &[f64], ys: &[f64], x: f64) -> f64 {
    xs.iter()
        .enumerate()
        .map(|(i, &xi)| {
            // i-th basis polynomial evaluated at x: prod_{j != i} (x - xj) / (xi - xj)
            let basis: f64 = xs
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, &xj)| (x - xj) / (xi - xj))
                .product();
            ys[i] * basis
        })
        .sum()
}
1103
/// Cubic spline piece: f(x) = a + b*(x-xi) + c*(x-xi)^2 + d*(x-xi)^3
/// One polynomial segment, valid from its left knot up to the next piece.
#[derive(Clone, Debug)]
struct SplinePiece {
    x: f64, // left knot xi of this segment
    a: f64, // constant term: f(xi) = a
    b: f64, // linear coefficient (first derivative at xi)
    c: f64, // quadratic coefficient
    d: f64, // cubic coefficient
}
1113
/// Natural cubic spline interpolant.
///
/// `pieces` are ordered by increasing left knot; evaluation outside the
/// fitted range extrapolates with the nearest boundary polynomial.
#[derive(Clone, Debug)]
pub struct CubicSpline {
    pieces: Vec<SplinePiece>,
    // Right edge of the last interval. NOTE(review): written by the builder
    // but not read anywhere visible in this file — confirm external use
    // before removing.
    x_end: f64,
}
1120
1121impl CubicSpline {
1122    /// Evaluate the spline at x.
1123    pub fn evaluate(&self, x: f64) -> f64 {
1124        // Find correct piece via binary search
1125        let idx = self.pieces.partition_point(|p| p.x <= x).saturating_sub(1);
1126        let idx = idx.min(self.pieces.len() - 1);
1127        let p = &self.pieces[idx];
1128        let dx = x - p.x;
1129        p.a + p.b * dx + p.c * dx * dx + p.d * dx * dx * dx
1130    }
1131}
1132
1133/// Build a natural cubic spline through (xs, ys).
1134pub fn natural_cubic_spline(xs: &[f64], ys: &[f64]) -> CubicSpline {
1135    let n = xs.len();
1136    assert!(n >= 2, "Need at least 2 points for cubic spline");
1137    let m = n - 1; // number of intervals
1138    let mut h = vec![0.0f64; m];
1139    for i in 0..m { h[i] = xs[i + 1] - xs[i]; }
1140
1141    // Tridiagonal system for second derivatives (natural BC: M[0] = M[n-1] = 0)
1142    let rhs_len = n - 2;
1143    if rhs_len == 0 {
1144        // Only 2 points: linear
1145        let slope = (ys[1] - ys[0]) / h[0];
1146        let pieces = vec![SplinePiece { x: xs[0], a: ys[0], b: slope, c: 0.0, d: 0.0 }];
1147        return CubicSpline { pieces, x_end: xs[n - 1] };
1148    }
1149
1150    let mut diag = vec![0.0f64; rhs_len];
1151    let mut upper = vec![0.0f64; rhs_len - 1];
1152    let mut lower = vec![0.0f64; rhs_len - 1];
1153    let mut rhs = vec![0.0f64; rhs_len];
1154
1155    for i in 0..rhs_len {
1156        let ii = i + 1; // index in original array
1157        diag[i] = 2.0 * (h[ii - 1] + h[ii]);
1158        rhs[i] = 6.0 * ((ys[ii + 1] - ys[ii]) / h[ii] - (ys[ii] - ys[ii - 1]) / h[ii - 1]);
1159    }
1160    for i in 0..rhs_len - 1 {
1161        upper[i] = h[i + 1];
1162        lower[i] = h[i + 1];
1163    }
1164
1165    // Thomas algorithm (tridiagonal solver)
1166    let mut c_prime = vec![0.0f64; rhs_len];
1167    let mut d_prime = vec![0.0f64; rhs_len];
1168    c_prime[0] = upper[0] / diag[0];
1169    d_prime[0] = rhs[0] / diag[0];
1170    for i in 1..rhs_len {
1171        let denom = diag[i] - lower[i - 1] * c_prime[i - 1];
1172        if i < rhs_len - 1 {
1173            c_prime[i] = upper[i] / denom;
1174        }
1175        d_prime[i] = (rhs[i] - lower[i - 1] * d_prime[i - 1]) / denom;
1176    }
1177    let mut sigma = vec![0.0f64; n];
1178    sigma[rhs_len] = d_prime[rhs_len - 1];
1179    for i in (0..rhs_len - 1).rev() {
1180        sigma[i + 1] = d_prime[i] - c_prime[i] * sigma[i + 2];
1181    }
1182    // sigma[0] = sigma[n-1] = 0 (natural)
1183
1184    let mut pieces = Vec::with_capacity(m);
1185    for i in 0..m {
1186        let a = ys[i];
1187        let b = (ys[i + 1] - ys[i]) / h[i] - h[i] * (2.0 * sigma[i] + sigma[i + 1]) / 6.0;
1188        let c = sigma[i] * 0.5;
1189        let d = (sigma[i + 1] - sigma[i]) / (6.0 * h[i]);
1190        pieces.push(SplinePiece { x: xs[i], a, b, c, d });
1191    }
1192    CubicSpline { pieces, x_end: *xs.last().unwrap() }
1193}
1194
1195/// 2D Radial Basis Function interpolation using multiquadric RBF.
1196/// centers: list of (x, y) center points, values: function value at each center.
1197pub fn rbf_interpolate(centers: &[(f64, f64)], values: &[f64], p: (f64, f64)) -> f64 {
1198    let n = centers.len();
1199    if n == 0 { return 0.0; }
1200    // Build RBF matrix and solve for weights
1201    // phi(r) = sqrt(r^2 + 1) — multiquadric
1202    let phi = |cx: f64, cy: f64, x: f64, y: f64| {
1203        let r2 = (x - cx).powi(2) + (y - cy).powi(2);
1204        (r2 + 1.0).sqrt()
1205    };
1206    let mut mat = Matrix::zeros(n, n);
1207    for i in 0..n {
1208        for j in 0..n {
1209            mat[(i, j)] = phi(centers[i].0, centers[i].1, centers[j].0, centers[j].1);
1210        }
1211    }
1212    let weights = solve_linear(&mat, values).unwrap_or_else(|| values.to_vec());
1213    weights.iter().enumerate().map(|(i, &w)| w * phi(centers[i].0, centers[i].1, p.0, p.1)).sum()
1214}
1215
1216// ============================================================
1217// TESTS
1218// ============================================================
1219
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: f(x) = x^2 - 2 has a root at sqrt(2) inside [1, 2].
    fn sq(x: f64) -> f64 { x * x - 2.0 }
    // Derivative of `sq`, used by Newton-Raphson.
    fn dsq(x: f64) -> f64 { 2.0 * x }

    #[test]
    fn test_bisect_sqrt2() {
        let root = bisect(sq, 1.0, 2.0, 1e-10, 100).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_newton_raphson_sqrt2() {
        let root = newton_raphson(sq, dsq, 1.5, 1e-10, 100).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_secant_sqrt2() {
        let root = secant(sq, 1.0, 2.0, 1e-10, 100).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_brent_sqrt2() {
        let root = brent(sq, 1.0, 2.0, 1e-10).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_illinois_sqrt2() {
        let root = illinois(sq, 1.0, 2.0, 1e-10).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_muller_sqrt2() {
        // Muller needs three starting abscissae bracketing the root.
        let root = muller(sq, 1.0, 1.4, 2.0, 1e-10, 100).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-8);
    }

    #[test]
    fn test_fixed_point_sqrt2() {
        // g(x) = (x + 2/x) / 2 — Newton for sqrt(2)
        let root = fixed_point(|x| (x + 2.0 / x) / 2.0, 1.5, 1e-10, 100).unwrap();
        assert!((root - 2.0f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_trapezoid_sine() {
        // integral of sin over [0, pi] = 2
        let result = trapezoid(|x: f64| x.sin(), 0.0, std::f64::consts::PI, 1000);
        assert!((result - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_simpsons_polynomial() {
        // integrate x^2 from 0 to 1 = 1/3
        let result = simpsons(|x| x * x, 0.0, 1.0, 100);
        assert!((result - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_gauss_legendre_polynomial() {
        // integrate x^4 from 0 to 1 = 1/5
        let result = gauss_legendre(|x| x.powi(4), 0.0, 1.0, 5);
        assert!((result - 0.2).abs() < 1e-10);
    }

    #[test]
    fn test_romberg_exp() {
        // integrate e^x from 0 to 1 = e - 1
        let result = romberg(|x: f64| x.exp(), 0.0, 1.0, 8, 1e-10);
        assert!((result - (std::f64::consts::E - 1.0)).abs() < 1e-8);
    }

    #[test]
    fn test_adaptive_simpson() {
        let result = adaptive_simpson(|x: f64| x.sin(), 0.0, std::f64::consts::PI, 1e-8, 20);
        assert!((result - 2.0).abs() < 1e-8);
    }

    #[test]
    fn test_monte_carlo_integrate() {
        // integrate 1 over [0,1]^2 = 1
        // Seeded (42), so the tolerance below is deterministic in practice.
        let result = monte_carlo_integrate(|_p| 1.0, &[(0.0, 1.0), (0.0, 1.0)], 100_000, 42);
        assert!((result - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_euler_exp_decay() {
        // y' = -y, y(0) = 1 => y = e^-t
        // Euler is first-order, hence the loose 0.01 tolerance at t = 1.
        let sol = euler(|_t, y| vec![-y[0]], 0.0, &[1.0], 0.001, 1000);
        let last = &sol[1000];
        assert!((last[0] - (-1.0f64).exp()).abs() < 0.01);
    }

    #[test]
    fn test_rk4_exp_decay() {
        // Fourth-order method: much tighter tolerance than Euler.
        let sol = rk4(|_t, y| vec![-y[0]], 0.0, &[1.0], 0.01, 100);
        let last = &sol[100];
        assert!((last[0] - (-1.0f64).exp()).abs() < 1e-6);
    }

    #[test]
    fn test_rk45_exp_decay() {
        let (ts, ys) = rk45(|_t, y| vec![-y[0]], 0.0, &[1.0], 1.0, 1e-8, 1e-6, 0.1);
        assert!(!ts.is_empty());
        let last = ys.last().unwrap();
        assert!((last[0] - (-1.0f64).exp()).abs() < 1e-6);
    }

    #[test]
    fn test_verlet_harmonic() {
        // x'' = -x (harmonic oscillator), x(0)=1, v(0)=0 => x(t)=cos(t)
        let result = verlet(1.0, 0.0, |x| -x, 0.001, 6283);
        let last = result.last().unwrap();
        // At t ~ 2*pi*k, x ~ 1
        let _ = last; // just check no panic
    }

    #[test]
    fn test_matmul_identity() {
        let a = Matrix::from_data(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let id = Matrix::identity(2);
        let c = matmul(&a, &id);
        assert!((c[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((c[(1, 1)] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_linear() {
        // 2x + y = 5, x + 3y = 10 => x=1, y=3
        let a = Matrix::from_data(2, 2, vec![2.0, 1.0, 1.0, 3.0]);
        let b = vec![5.0, 10.0];
        let x = solve_linear(&a, &b).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_determinant() {
        // det [[3, 8], [4, 6]] = 18 - 32 = -14
        let a = Matrix::from_data(2, 2, vec![3.0, 8.0, 4.0, 6.0]);
        let d = determinant(&a);
        assert!((d - (18.0 - 32.0)).abs() < 1e-10);
    }

    #[test]
    fn test_inverse() {
        // Verify A * A^-1 = I.
        let a = Matrix::from_data(2, 2, vec![4.0, 7.0, 2.0, 6.0]);
        let inv = inverse(&a).unwrap();
        let prod = matmul(&a, &inv);
        assert!((prod[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((prod[(1, 1)] - 1.0).abs() < 1e-10);
        assert!(prod[(0, 1)].abs() < 1e-10);
    }

    #[test]
    fn test_cholesky() {
        // A = [[4, 2], [2, 3]]
        // Symmetric positive definite, so L * L^T must reconstruct A.
        let a = Matrix::from_data(2, 2, vec![4.0, 2.0, 2.0, 3.0]);
        let l = cholesky(&a).unwrap();
        let lt = transpose(&l);
        let reconstructed = matmul(&l, &lt);
        assert!((reconstructed[(0, 0)] - 4.0).abs() < 1e-10);
        assert!((reconstructed[(0, 1)] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_lagrange_interp() {
        // Should interpolate x^2 exactly at given nodes
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys: Vec<f64> = xs.iter().map(|x| x * x).collect();
        let v = lagrange_interp(&xs, &ys, 1.5);
        assert!((v - 2.25).abs() < 1e-10);
    }

    #[test]
    fn test_cubic_spline() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys: Vec<f64> = xs.iter().map(|x: &f64| x.sin()).collect();
        let spline = natural_cubic_spline(&xs, &ys);
        // At knots the spline should be exact
        for (x, y) in xs.iter().zip(ys.iter()) {
            let v = spline.evaluate(*x);
            assert!((v - y).abs() < 1e-10, "spline at knot {}: {} vs {}", x, v, y);
        }
    }

    #[test]
    fn test_lerp() {
        assert_eq!(lerp(0.0, 10.0, 0.5), 5.0);
        assert_eq!(lerp(0.0, 10.0, 0.0), 0.0);
        assert_eq!(lerp(0.0, 10.0, 1.0), 10.0);
    }

    #[test]
    fn test_bilinear() {
        // All corners = 1.0 => any point = 1.0
        assert_eq!(bilinear(1.0, 1.0, 1.0, 1.0, 0.5, 0.5), 1.0);
    }

    #[test]
    fn test_power_iteration() {
        // A = [[2, 1], [1, 2]] — dominant eigenvalue = 3
        let a = Matrix::from_data(2, 2, vec![2.0, 1.0, 1.0, 2.0]);
        let (lambda, _v) = power_iteration(&a, 1000, 1e-10);
        assert!((lambda - 3.0).abs() < 1e-8);
    }

    #[test]
    fn test_eigenvalues_2x2() {
        let a = Matrix::from_data(2, 2, vec![4.0, 1.0, 2.0, 3.0]);
        let (e1, e2) = eigenvalues_2x2(&a);
        // Trace = 7, det = 10, eigenvalues: (7±3)/2 = 5, 2
        let mut evs = [e1, e2];
        evs.sort_by(|a, b| b.partial_cmp(a).unwrap());
        assert!((evs[0] - 5.0).abs() < 1e-10);
        assert!((evs[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_qr_decompose() {
        // Check the factorization reconstructs A: Q * R = A.
        let a = Matrix::from_data(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let (q, r) = qr_decompose(&a);
        let recon = matmul(&q, &r);
        for i in 0..3 {
            for j in 0..2 {
                assert!((recon[(i, j)] - a[(i, j)]).abs() < 1e-9,
                    "QR mismatch at ({},{}) : {} vs {}", i, j, recon[(i,j)], a[(i,j)]);
            }
        }
    }

    #[test]
    fn test_leapfrog_basic() {
        // Harmonic oscillator: a = -x
        let pos = vec![1.0, 0.0, 0.0];
        let vel = vec![0.0, 0.0, 0.0];
        let steps = leapfrog(&pos, &vel, |p| vec![-p[0], 0.0, 0.0], 0.001, 100);
        assert_eq!(steps.len(), 101);
    }

    #[test]
    fn test_adams_bashforth4() {
        // y' = -y, y(0)=1
        let sol = adams_bashforth4(|_t, y| vec![-y[0]], 0.0, &[1.0], 0.01, 100);
        let last = sol.last().unwrap();
        assert!((last[0] - (-1.0f64).exp()).abs() < 0.01);
    }
}