Skip to main content

numerics_ode/
euler.rs

//! Euler method (first-order explicit).
//!
//! Solves the initial value problem
//! ```text
//! dy/dx = f(x, y),  y(x₀) = y₀
//! ```
//! using the forward Euler scheme on a uniform grid `x_k = x₀ + k·h`:
//! ```text
//! y_{k+1} = y_k + h · f(x_k, y_k)
//! ```
//!
//! Global truncation error: O(h), where `h` is the step size. The local
//! error committed in a single step is O(h²).
13
/// Solve an IVP using the forward Euler method.
///
/// The interval from `x0` to `x_end` is split into `n` equal steps of size
/// `h = (x_end - x0) / n`. If `x_end < x0`, `h` is negative and the solver
/// integrates backwards.
///
/// `f` is generic (with a `?Sized` bound), so it accepts closures and function
/// items, which the compiler can inline into the stepping loop. It also still
/// accepts `&dyn Fn(f64, f64) -> f64` trait objects.
///
/// # Arguments
///
/// * `f`     — Right-hand side `dy/dx = f(x, y)`.
/// * `x0`    — Initial independent variable.
/// * `y0`    — Initial value.
/// * `x_end` — Terminal value of the independent variable.
/// * `n`     — Number of steps (≥ 1).
///
/// # Returns
///
/// `(xs, ys)` — `n + 1` grid points and the Euler approximation at each one:
/// - `xs[k] = x0 + k·h` for `k < n`, and `xs[n]` is exactly `x_end`;
/// - `ys[0] == y0`.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn solve<F>(f: &F, x0: f64, y0: f64, x_end: f64, n: usize) -> (Vec<f64>, Vec<f64>)
where
    F: Fn(f64, f64) -> f64 + ?Sized,
{
    assert!(n >= 1, "number of steps must be ≥ 1");
    let h = (x_end - x0) / n as f64;
    let mut xs = Vec::with_capacity(n + 1);
    let mut ys = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0);
    let mut x = x0;
    let mut y = y0;
    for k in 1..=n {
        y += h * f(x, y);
        // Derive x_k from the step index instead of accumulating `x += h`.
        // Repeated addition compounds rounding error (ten steps of 0.1 sum to
        // 0.9999999999999999), which skews both the returned grid and the
        // points where `f` is evaluated. The final point is pinned to `x_end`.
        x = if k == n { x_end } else { x0 + k as f64 * h };
        xs.push(x);
        ys.push(y);
    }
    (xs, ys)
}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// dy/dx = 0  →  y = y0 (constant)
    #[test]
    fn constant_rhs() {
        let f = |_x: f64, _y: f64| 0.0_f64;
        let (_xs, ys) = solve(&f, 0.0, 5.0, 1.0, 10);
        assert_eq!(ys.len(), 11);
        for &y in &ys {
            assert!((y - 5.0).abs() < 1e-12, "y = {y}, expected 5.0");
        }
    }

    /// dy/dx = 1  →  y = x + y0.
    ///
    /// Euler is exact for a constant slope, so only floating-point rounding
    /// separates the result from the exact value; the tolerance reflects that.
    #[test]
    fn linear_rhs() {
        let f = |_x: f64, _y: f64| 1.0;
        let (_, ys) = solve(&f, 0.0, 0.0, 2.0, 200);
        assert!((ys[200] - 2.0).abs() < 1e-10, "got {}", ys[200]);
    }

    /// dy/dx = y  →  y = e^x
    #[test]
    fn exponential_growth() {
        let f = |_x: f64, y: f64| y;
        let (_, ys) = solve(&f, 0.0, 1.0, 1.0, 10_000);
        let exact = 1.0_f64.exp();
        assert!((ys[10_000] - exact).abs() < 2e-4, "got {}, exact {}", ys[10_000], exact);
    }

    /// Euler is first order: error ∝ h
    #[test]
    fn convergence_order_euler() {
        let f = |_x: f64, y: f64| y;
        let exact = 1.0_f64.exp();
        let e1 = (solve(&f, 0.0, 1.0, 1.0, 100).1[100] - exact).abs();
        let e2 = (solve(&f, 0.0, 1.0, 1.0, 200).1[200] - exact).abs();
        let ratio = e1 / e2;
        // Halving h should roughly halve the error for a first-order method.
        assert!(ratio > 1.6 && ratio < 2.6, "convergence ratio = {ratio}, expected ~2");
    }

    /// dy/dx = -y  →  y = e^{-x}
    #[test]
    fn exponential_decay() {
        let f = |_x: f64, y: f64| -y;
        let (_, ys) = solve(&f, 0.0, 1.0, 1.0, 10_000);
        let exact = (-1.0_f64).exp();
        assert!((ys[10_000] - exact).abs() < 1e-4);
    }

    /// dy/dx = 2x  →  y = x² + y0
    #[test]
    fn quadratic_rhs() {
        let f = |x: f64, _y: f64| 2.0 * x;
        let (_, ys) = solve(&f, 0.0, 0.0, 3.0, 10_000);
        let exact = 9.0;
        assert!((ys[10_000] - exact).abs() < 1e-2);
    }

    #[test]
    #[should_panic(expected = "number of steps must be ≥ 1")]
    fn panics_on_zero_steps() {
        let f = |_x, _y| 0.0;
        solve(&f, 0.0, 1.0, 1.0, 0);
    }

    #[test]
    fn single_step() {
        let f = |_x, y| y;
        let (xs, ys) = solve(&f, 0.0, 1.0, 1.0, 1);
        assert_eq!(xs.len(), 2);
        assert_eq!(ys.len(), 2);
        // Euler single step: y = 1 + 1*1 = 2
        assert!((ys[1] - 2.0).abs() < 1e-12);
    }

    /// Hand-checked recurrence: two steps of dy/dx = y from (0, 1) with h = 0.5
    /// give y = 1.5 and then y = 1.5 + 0.5·1.5 = 2.25.
    #[test]
    fn two_steps_match_hand_computation() {
        let f = |_x: f64, y: f64| y;
        let (xs, ys) = solve(&f, 0.0, 1.0, 1.0, 2);
        assert_eq!(xs, vec![0.0, 0.5, 1.0]);
        assert_eq!(ys, vec![1.0, 1.5, 2.25]);
    }

    /// The returned grid has `n + 1` points, starts at `x0`, ends at `x_end`,
    /// and is uniformly spaced up to rounding.
    #[test]
    fn grid_shape_and_endpoints() {
        let f = |_x: f64, y: f64| y;
        let (xs, ys) = solve(&f, 0.5, 1.0, 2.5, 40);
        assert_eq!(xs.len(), 41);
        assert_eq!(ys.len(), 41);
        assert_eq!(xs[0], 0.5);
        assert_eq!(ys[0], 1.0);
        assert!((xs[40] - 2.5).abs() < 1e-12, "last x = {}", xs[40]);
        let h = 2.0 / 40.0;
        for w in xs.windows(2) {
            assert!((w[1] - w[0] - h).abs() < 1e-12, "step {} != {h}", w[1] - w[0]);
        }
    }

    /// The right-hand side must be evaluated at exactly the first `n` returned
    /// grid points, in order.
    #[test]
    fn rhs_is_evaluated_on_returned_grid() {
        let seen = std::cell::RefCell::new(Vec::new());
        let g = |x: f64, y: f64| {
            seen.borrow_mut().push(x);
            y
        };
        let (xs, _) = solve(&g, 0.0, 1.0, 1.0, 10);
        assert_eq!(seen.borrow().as_slice(), &xs[..10]);
    }

    #[test]
    fn negative_direction() {
        // integrate backwards from x=1 to x=0
        let f = |_x, y| y;
        let (_, ys) = solve(&f, 1.0, 1.0_f64.exp(), 0.0, 10_000);
        // exact: y(0) = 1
        assert!((ys[10_000] - 1.0).abs() < 1e-3);
    }

    /// dy/dx = cos(x) → y = sin(x) + C
    #[test]
    fn trigonometric_rhs() {
        let f = |x: f64, _y: f64| x.cos();
        let (_, ys) = solve(&f, 0.0, 0.0, std::f64::consts::PI / 2.0, 10_000);
        let exact = 1.0;
        assert!((ys[10_000] - exact).abs() < 1e-3);
    }
}