1mod explicit;
24mod implicit;
25mod pde;
26mod types;
27mod utils;
28
29pub use explicit::{EulerSolver, Rk4Solver, Rk45Solver};
31pub use implicit::{Bdf2Solver, ImplicitEulerSolver};
32pub use pde::{
33 AdvectionEquation1D, BoundaryCondition, Grid1D, Grid2D, HeatEquation1D, PdeConfig, Poisson1D,
34 WaveEquation1D,
35};
36pub use types::{OdeConfig, OdeMethod, OdeSolution, OdeSystem, StepResult};
37pub use utils::{numerical_jacobian, solve_tridiagonal};
38
39#[cfg(test)]
40mod tests {
41 use super::utils::apply_bc_1d;
42 use super::*;
43 use std::f64::consts::PI;
44
45 struct ExponentialDecay;
49
50 impl OdeSystem for ExponentialDecay {
51 fn rhs(&self, _t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
52 dydt[0] = -y[0];
53 Ok(())
54 }
55 fn dim(&self) -> usize {
56 1
57 }
58 }
59
60 struct HarmonicOscillator;
63
64 impl OdeSystem for HarmonicOscillator {
65 fn rhs(&self, _t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
66 dydt[0] = y[1]; dydt[1] = -y[0]; Ok(())
69 }
70 fn dim(&self) -> usize {
71 2
72 }
73 }
74
75 struct StiffSystem;
77
78 impl OdeSystem for StiffSystem {
79 fn rhs(&self, t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
80 dydt[0] = -1000.0 * (y[0] - t.sin()) + t.cos();
81 Ok(())
82 }
83 fn dim(&self) -> usize {
84 1
85 }
86 }
87
88 struct VanDerPol {
90 mu: f64,
91 }
92
93 impl OdeSystem for VanDerPol {
94 fn rhs(&self, _t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
95 dydt[0] = y[1];
96 dydt[1] = self.mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
97 Ok(())
98 }
99 fn dim(&self) -> usize {
100 2
101 }
102 }
103
/// Forward Euler on `y' = -y`, `y(0) = 1` with `dt = 0.001` must end within
/// `0.01` of `exp(-1)` at `t = 1`.
#[test]
fn euler_exponential_decay() {
    let system = ExponentialDecay;
    let config = OdeConfig {
        method: OdeMethod::Euler,
        t_start: 0.0,
        t_end: 1.0,
        dt: 0.001,
        ..OdeConfig::default()
    };

    let outcome = EulerSolver::solve(&system, &[1.0], &config);
    assert!(outcome.is_ok());
    let solution = outcome.ok().filter(|s| !s.states.is_empty());
    assert!(solution.is_some());
    let final_state = solution.as_ref().and_then(|s| s.states.last());
    assert!(final_state.is_some());

    // The 0.0 fallback is unreachable after the assertion above.
    let y_final = final_state.map_or(0.0, |state| state[0]);
    let expected = (-1.0_f64).exp();
    let abs_err = (y_final - expected).abs();
    assert!(
        abs_err < 0.01,
        "Euler: y(1) = {y_final}, expected {expected}"
    );
}
130
/// RK4 with `dt = 0.01` over `[0, 2*pi]` must bring the harmonic oscillator
/// back to its initial state `(1, 0)` to within `1e-6`.
#[test]
fn rk4_harmonic_oscillator() {
    let system = HarmonicOscillator;
    let config = OdeConfig {
        method: OdeMethod::Rk4,
        t_start: 0.0,
        t_end: 2.0 * PI,
        dt: 0.01,
        ..OdeConfig::default()
    };

    let outcome = Rk4Solver::solve(&system, &[1.0, 0.0], &config);
    assert!(outcome.is_ok());
    let solution = outcome.ok().filter(|s| !s.states.is_empty());
    assert!(solution.is_some());
    let final_state = solution.as_ref().and_then(|s| s.states.last());
    assert!(final_state.is_some());

    // The (0.0, 0.0) fallback is unreachable after the assertion above.
    let (y, v) = final_state.map_or((0.0, 0.0), |state| (state[0], state[1]));
    let position_err = (y - 1.0).abs();
    assert!(position_err < 1e-6, "RK4: y(2pi) = {y}, expected 1.0");
    assert!(v.abs() < 1e-6, "RK4: v(2pi) = {v}, expected 0.0");
}
153
/// Adaptive RK45 over one period of the harmonic oscillator must return to
/// `(1, 0)` within `1e-6` and report fewer than 1000 steps.
#[test]
fn rk45_adaptive_step() {
    let system = HarmonicOscillator;
    let config = OdeConfig {
        method: OdeMethod::Rk45,
        t_start: 0.0,
        t_end: 2.0 * PI,
        dt: 0.1,
        rtol: 1e-8,
        atol: 1e-10,
        max_steps: 10_000,
    };

    let outcome = Rk45Solver::solve(&system, &[1.0, 0.0], &config);
    assert!(outcome.is_ok());
    let solution = outcome.ok().filter(|s| !s.states.is_empty());
    assert!(solution.is_some());
    let solution = solution.as_ref();

    // A missing final state falls back to 0.0, which fails the position check.
    let final_state = solution.and_then(|s| s.states.last());
    let y = final_state.map_or(0.0, |state| state[0]);
    let v = final_state.map_or(0.0, |state| state[1]);
    assert!((y - 1.0).abs() < 1e-6, "RK45: y(2pi) = {y}, expected 1.0");
    assert!(v.abs() < 1e-6, "RK45: v(2pi) = {v}, expected 0.0");

    let num = solution.map_or(0, |s| s.num_steps);
    assert!(
        num < 1000,
        "RK45 should take fewer than 1000 steps, took {num}"
    );
}
183
/// Implicit Euler with `dt = 0.01` on the stiff fixture, started from
/// `y(0) = 0`, must end within `0.05` of `sin(0.1)` at `t = 0.1`.
#[test]
fn implicit_euler_stiff() {
    let system = StiffSystem;
    let config = OdeConfig {
        method: OdeMethod::ImplicitEuler,
        t_start: 0.0,
        t_end: 0.1,
        dt: 0.01,
        ..OdeConfig::default()
    };

    let outcome = ImplicitEulerSolver::solve(&system, &[0.0], &config);
    assert!(outcome.is_ok());
    let solution = outcome.ok().filter(|s| !s.states.is_empty());
    assert!(solution.is_some());

    // A missing final state becomes NaN, which fails the comparison below.
    let y = match solution.as_ref().and_then(|s| s.states.last()) {
        Some(state) => state[0],
        None => f64::NAN,
    };
    let expected = 0.1_f64.sin();
    let abs_err = (y - expected).abs();
    assert!(
        abs_err < 0.05,
        "ImplicitEuler: y(0.1) = {y}, expected {expected}"
    );
}
208
/// BDF2 on the Van der Pol oscillator (`mu = 1`) over `[0, 1]` must succeed
/// and produce only finite state values.
#[test]
fn bdf2_van_der_pol() {
    let system = VanDerPol { mu: 1.0 };
    let config = OdeConfig {
        method: OdeMethod::Bdf2,
        t_start: 0.0,
        t_end: 1.0,
        dt: 0.01,
        ..OdeConfig::default()
    };

    let outcome = Bdf2Solver::solve(&system, &[2.0, 0.0], &config);
    assert!(outcome.is_ok());
    let solution = outcome.ok().filter(|s| !s.states.is_empty());
    assert!(solution.is_some());

    let all_finite = solution.as_ref().map_or(false, |s| {
        s.states
            .iter()
            .flat_map(|state| state.iter())
            .all(|value| value.is_finite())
    });
    assert!(all_finite, "BDF2: all states should be finite");
}
230
/// Checks that RK4 converges at (roughly) fourth order on `y' = -y`,
/// `y(0) = 1`.
///
/// The problem is integrated to `t = 1` with the step size halved twice.
/// Each halving should shrink the end-point error by about `2^4 = 16`; the
/// test requires a factor greater than 10 for every consecutive pair of
/// step sizes, not just the first.
#[test]
fn ode_convergence_order() {
    let sys = ExponentialDecay;
    let t_end = 1.0;
    let exact = (-1.0_f64).exp();
    let step_sizes = [0.1, 0.05, 0.025];

    let mut errors = Vec::with_capacity(step_sizes.len());
    for &dt in &step_sizes {
        let config = OdeConfig {
            t_start: 0.0,
            t_end,
            dt,
            method: OdeMethod::Rk4,
            ..OdeConfig::default()
        };
        let sol = Rk4Solver::solve(&sys, &[1.0], &config);
        // Fail here with a clear message instead of letting a fallback value
        // produce a misleading convergence ratio further down.
        assert!(sol.is_ok(), "RK4 solve failed for dt = {dt}");
        let sol_data = sol.ok().filter(|s| !s.states.is_empty());
        let last = sol_data.as_ref().and_then(|s| s.states.last());
        assert!(last.is_some(), "RK4 returned no states for dt = {dt}");

        let y = last.map_or(f64::NAN, |s| s[0]);
        assert!(y.is_finite(), "RK4 produced non-finite y(1) for dt = {dt}");
        errors.push((y - exact).abs());
    }

    for pair in errors.windows(2) {
        let (coarse, fine) = (pair[0], pair[1]);
        // Once either error sits at round-off level the ratio carries no
        // information about the order, so that pair is skipped.
        if coarse > 1e-15 && fine > 1e-15 {
            let ratio = coarse / fine;
            assert!(
                ratio > 10.0,
                "RK4 convergence ratio should be ~16, got {ratio}"
            );
        }
    }
}
266
/// Explicit heat solve of a Gaussian bump on `[-5, 5]` with zero Dirichlet
/// boundaries: after 10 steps the peak must have dropped and the solution
/// must remain (numerically) non-negative.
#[test]
fn heat_explicit_gaussian() {
    let alpha = 0.01;
    let nx = 101;
    let grid = Grid1D::new(-5.0, 5.0, nx);
    let heat = HeatEquation1D { alpha };

    // Use 40% of the step limit reported by `stability_limit`.
    let dt = 0.4 * heat.stability_limit(grid.dx);
    let config = PdeConfig {
        grid: grid.clone(),
        dt,
        num_steps: 10,
        bc_left: BoundaryCondition::Dirichlet(0.0),
        bc_right: BoundaryCondition::Dirichlet(0.0),
    };

    // Initial condition: exp(-x^2) sampled at the grid points.
    let u0: Vec<f64> = (0..nx)
        .map(|i| {
            let x = grid.point(i);
            (-x * x).exp()
        })
        .collect();

    let result = heat.solve_explicit(&u0, &config);
    assert!(result.is_ok());
    let data = result.ok().filter(|d| !d.is_empty());
    assert!(data.is_some());
    let last = data.as_ref().and_then(|d| d.last());
    assert!(last.is_some());

    let u_final = last.as_ref().map(|u| u.as_slice()).unwrap_or(&[]);
    // A wrong-sized state must fail the test rather than silently skip the
    // physical checks below.
    assert_eq!(
        u_final.len(),
        nx,
        "final state should have one value per grid point"
    );

    let mid = nx / 2;
    assert!(u_final[mid] < u0[mid], "Heat diffusion should reduce peak");
    assert!(
        u_final.iter().all(|&v| v >= -1e-10),
        "Heat solution should remain non-negative"
    );
}
313
/// Implicit heat solve of `sin(pi * x)` on `[0, 1]` with a time step ten
/// times the limit reported by `stability_limit`: the solution must stay
/// finite and bounded (|u| < 2) after 20 steps.
#[test]
fn heat_crank_nicolson_stability() {
    let alpha = 1.0;
    let nx = 51;
    let grid = Grid1D::new(0.0, 1.0, nx);
    let heat = HeatEquation1D { alpha };

    // Deliberately exceed the explicit step limit by a factor of 10.
    let dt = 10.0 * heat.stability_limit(grid.dx);
    let config = PdeConfig {
        grid: grid.clone(),
        dt,
        num_steps: 20,
        bc_left: BoundaryCondition::Dirichlet(0.0),
        bc_right: BoundaryCondition::Dirichlet(0.0),
    };

    let u0: Vec<f64> = (0..nx)
        .map(|i| {
            let x = grid.point(i);
            (PI * x).sin()
        })
        .collect();

    let result = heat.solve_implicit(&u0, &config);
    assert!(result.is_ok());
    let data = result.ok().filter(|d| !d.is_empty());
    assert!(data.is_some());
    let last = data.as_ref().and_then(|d| d.last());
    assert!(last.is_some());

    let u_final = last.as_ref().map(|u| u.as_slice()).unwrap_or(&[]);
    assert_eq!(
        u_final.len(),
        nx,
        "final state should have one value per grid point"
    );
    // `f64::max` discards NaN operands, so blow-up to NaN has to be caught
    // explicitly before taking the maximum.
    assert!(
        u_final.iter().all(|v| v.is_finite()),
        "Crank-Nicolson produced non-finite values"
    );
    // Compare magnitudes so that a large negative excursion also counts.
    let max_val = u_final.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_val < 2.0,
        "Crank-Nicolson should be stable; max = {max_val}"
    );
}
353
/// Wave solve of a standing `sin(pi * x)` mode released from rest: the
/// discrete energy (kinetic + potential) estimated from the first two and
/// the last two snapshots must agree to within 10%.
#[test]
fn wave_energy_conservation() {
    let c = 1.0;
    let nx = 101;
    let grid = Grid1D::new(0.0, 1.0, nx);
    let dx = grid.dx;
    let wave = WaveEquation1D { c };

    // Time step chosen so that c * dt / dx = 0.5.
    let dt = 0.5 * dx / c;
    let config = PdeConfig {
        grid: grid.clone(),
        dt,
        num_steps: 50,
        bc_left: BoundaryCondition::Dirichlet(0.0),
        bc_right: BoundaryCondition::Dirichlet(0.0),
    };

    let u0: Vec<f64> = (0..nx)
        .map(|i| {
            let x = grid.point(i);
            (PI * x).sin()
        })
        .collect();
    let v0 = vec![0.0; nx];

    let result = wave.solve(&u0, &v0, &config);
    assert!(result.is_ok());
    // At least three snapshots are needed; the filter plus the assertion
    // below guarantee that for the indexing further down.
    let data = result.ok().filter(|d| d.len() > 2);
    assert!(data.is_some());
    let states = data.as_deref().unwrap_or(&[]);

    // Energy over interior points: velocity from a backward difference in
    // time, spatial gradient from a centred difference.
    let energy = |u: &[f64], u_prev: &[f64]| -> f64 {
        let mut ke = 0.0;
        let mut pe = 0.0;
        for i in 1..nx - 1 {
            let v = (u[i] - u_prev[i]) / dt;
            ke += 0.5 * v * v * dx;
            let ux = (u[i + 1] - u[i - 1]) / (2.0 * dx);
            pe += 0.5 * c * c * ux * ux * dx;
        }
        ke + pe
    };

    let e_initial = energy(&states[1], &states[0]);
    let e_final = energy(&states[states.len() - 1], &states[states.len() - 2]);

    // The initial mode carries energy; a (near-)zero value means the solver
    // output is wrong and must not cause the conservation check to be skipped.
    assert!(
        e_initial > 1e-10,
        "initial wave energy should be positive, got {e_initial}"
    );
    let rel_change = (e_final - e_initial).abs() / e_initial;
    assert!(
        rel_change < 0.1,
        "Wave energy changed by {:.2}%",
        rel_change * 100.0
    );
}
414
/// Poisson solve on `[0, 1]` with constant right-hand side `2` and zero
/// Dirichlet boundaries, compared against `u(x) = x * (1 - x)`.
///
/// The reference solution has second derivative `-2`, so this assumes the
/// solver uses the `-u'' = f` sign convention.
#[test]
fn poisson_analytical() {
    let nx = 101;
    let grid = Grid1D::new(0.0, 1.0, nx);
    let poisson = Poisson1D;

    let f_rhs: Vec<f64> = vec![2.0; nx];
    // `dt` and `num_steps` are set to zero for this steady-state solve.
    let config = PdeConfig {
        grid: grid.clone(),
        dt: 0.0,
        num_steps: 0,
        bc_left: BoundaryCondition::Dirichlet(0.0),
        bc_right: BoundaryCondition::Dirichlet(0.0),
    };

    let result = poisson.solve(&f_rhs, &config);
    assert!(result.is_ok());
    let u = result.ok().unwrap_or_default();

    // An empty or short solution would otherwise yield `max_err == 0` and
    // pass without comparing anything.
    assert_eq!(u.len(), nx, "solution should have one value per grid point");
    // `f64::max` discards NaN operands, so NaN entries must be rejected
    // explicitly before the error is accumulated.
    assert!(
        u.iter().all(|v| v.is_finite()),
        "Poisson solution contains non-finite values"
    );

    let max_err = u
        .iter()
        .enumerate()
        .map(|(i, u_val)| {
            let x = grid.point(i);
            let exact = x * (1.0 - x);
            (u_val - exact).abs()
        })
        .fold(0.0_f64, f64::max);

    assert!(
        max_err < 1e-3,
        "Poisson max error = {max_err}, expected < 1e-3"
    );
}
449
/// Upwind advection of a unit square pulse: after 10 steps at
/// `a * dt / dx = 0.5` the solution must stay finite and within the range
/// `[0, 1]` of the initial data (no new extrema).
#[test]
fn advection_upwind() {
    let a = 1.0;
    let nx = 101;
    let grid = Grid1D::new(0.0, 2.0, nx);
    let adv = AdvectionEquation1D { a };

    let dt = 0.5 * grid.dx / a.abs();
    let config = PdeConfig {
        grid: grid.clone(),
        dt,
        num_steps: 10,
        bc_left: BoundaryCondition::Dirichlet(0.0),
        bc_right: BoundaryCondition::Dirichlet(0.0),
    };

    // Square pulse: 1 on [0.5, 1.0], 0 elsewhere.
    let u0: Vec<f64> = (0..nx)
        .map(|i| {
            let x = grid.point(i);
            if (0.5..=1.0).contains(&x) { 1.0 } else { 0.0 }
        })
        .collect();

    let result = adv.solve_upwind(&u0, &config);
    assert!(result.is_ok());
    let data = result.ok().filter(|d| !d.is_empty());
    assert!(data.is_some());

    let last = data
        .as_ref()
        .and_then(|d| d.last())
        .cloned()
        .unwrap_or_default();
    // Guard against an empty state, which would trivially satisfy the
    // bounds below.
    assert_eq!(
        last.len(),
        nx,
        "final state should have one value per grid point"
    );
    // `f64::max` / `f64::min` discard NaN operands, so check finiteness first.
    assert!(
        last.iter().all(|v| v.is_finite()),
        "Upwind produced non-finite values"
    );

    let max_val = last.iter().copied().fold(0.0_f64, f64::max);
    assert!(
        max_val <= 1.0 + 1e-10,
        "Upwind should be monotone, max = {max_val}"
    );
    // Monotonicity also forbids undershoot below the initial minimum of 0.
    let min_val = last.iter().copied().fold(0.0_f64, f64::min);
    assert!(
        min_val >= -1e-10,
        "Upwind should be monotone, min = {min_val}"
    );
}
492
/// Advects a Gaussian pulse with both Lax-Wendroff and upwind and checks
/// that Lax-Wendroff keeps at least as high a peak, i.e. is no more
/// diffusive than upwind.
#[test]
fn lax_wendroff_accuracy() {
    let a = 1.0;
    let nx = 201;
    let grid = Grid1D::new(0.0, 4.0, nx);
    let adv = AdvectionEquation1D { a };

    let dt = 0.5 * grid.dx / a.abs();
    let num_steps = 20;
    let config = PdeConfig {
        grid: grid.clone(),
        dt,
        num_steps,
        bc_left: BoundaryCondition::Dirichlet(0.0),
        bc_right: BoundaryCondition::Dirichlet(0.0),
    };

    // Gaussian pulse centred at x = 1.
    let u0: Vec<f64> = (0..nx)
        .map(|i| {
            let x = grid.point(i);
            (-(x - 1.0).powi(2) / 0.1).exp()
        })
        .collect();

    let result_lw = adv.solve_lax_wendroff(&u0, &config);
    let result_up = adv.solve_upwind(&u0, &config);
    assert!(result_lw.is_ok());
    assert!(result_up.is_ok());

    let lw_last = result_lw
        .ok()
        .and_then(|d| d.last().cloned())
        .unwrap_or_default();
    let up_last = result_up
        .ok()
        .and_then(|d| d.last().cloned())
        .unwrap_or_default();

    // Two defaulted (empty) states would compare as 0 >= 0 and pass, so the
    // shape of both results is checked before comparing peaks.
    assert_eq!(
        lw_last.len(),
        nx,
        "Lax-Wendroff final state should have one value per grid point"
    );
    assert_eq!(
        up_last.len(),
        nx,
        "upwind final state should have one value per grid point"
    );
    // `f64::max` discards NaN operands; reject non-finite output explicitly.
    assert!(
        lw_last.iter().all(|v| v.is_finite()),
        "Lax-Wendroff produced non-finite values"
    );
    assert!(
        up_last.iter().all(|v| v.is_finite()),
        "upwind produced non-finite values"
    );

    let lw_max = lw_last.iter().copied().fold(0.0_f64, f64::max);
    let up_max = up_last.iter().copied().fold(0.0_f64, f64::max);

    assert!(
        lw_max >= up_max - 1e-10,
        "Lax-Wendroff should have less diffusion: LW max = {lw_max}, upwind max = {up_max}"
    );
}
541
/// `apply_bc_1d` must write Dirichlet values into the two end points, and
/// for periodic conditions copy the opposite interior neighbour into each
/// end point (`u[0] = u[nx - 2]`, `u[nx - 1] = u[1]`).
#[test]
fn boundary_condition_enforcement() {
    let nx = 11;
    let last = nx - 1;

    // Dirichlet: end points take the prescribed values.
    let left = BoundaryCondition::Dirichlet(0.0);
    let right = BoundaryCondition::Dirichlet(2.0);
    let mut dirichlet = vec![1.0_f64; nx];
    apply_bc_1d(&mut dirichlet, &left, &right, nx);
    assert!((dirichlet[0] - 0.0_f64).abs() < 1e-15);
    assert!((dirichlet[last] - 2.0_f64).abs() < 1e-15);

    // Periodic: mark the two interior neighbours and expect them to wrap.
    let mut periodic = vec![0.0_f64; nx];
    periodic[1] = 3.0;
    periodic[last - 1] = 5.0;
    apply_bc_1d(
        &mut periodic,
        &BoundaryCondition::Periodic,
        &BoundaryCondition::Periodic,
        nx,
    );
    assert!((periodic[0] - 5.0_f64).abs() < 1e-15);
    assert!((periodic[last] - 3.0_f64).abs() < 1e-15);
}
569
/// For `alpha = 0.5` and `dx = 0.1`, `stability_limit` must return `0.01`
/// (to within `1e-15`).
#[test]
fn stability_limit_calculation() {
    let dx = 0.1;
    let limit = HeatEquation1D { alpha: 0.5 }.stability_limit(dx);
    let deviation = (limit - 0.01).abs();
    assert!(deviation < 1e-15, "stability limit = {limit}");
}
578
/// `Grid1D::new` and `Grid2D::new` must record the point counts and derive
/// the uniform spacing; `Grid1D::point` must map indices to coordinates.
#[test]
fn grid_construction() {
    // 11 points on [0, 1] give a spacing of 0.1.
    let g = Grid1D::new(0.0, 1.0, 11);
    assert_eq!(g.nx, 11);
    for (actual, expected) in [
        (g.dx, 0.1),
        (g.point(0), 0.0),
        (g.point(5), 0.5),
        (g.point(10), 1.0),
    ] {
        assert!((actual - expected).abs() < 1e-15);
    }

    // 11 points on [0, 1] and 21 points on [-1, 1]: both spacings are 0.1.
    let g2 = Grid2D::new(0.0, 1.0, 11, -1.0, 1.0, 21);
    assert_eq!(g2.nx, 11);
    assert_eq!(g2.ny, 21);
    for spacing in [g2.dx, g2.dy] {
        assert!((spacing - 0.1).abs() < 1e-15);
    }
}
594
/// `numerical_jacobian` must reproduce the analytic Jacobians of the two
/// linear fixtures to within `1e-5`: `[-1]` for exponential decay and
/// `[[0, 1], [-1, 0]]` for the harmonic oscillator.
///
/// The last argument (`1e-8`) is presumably the finite-difference step.
#[test]
fn numerical_jacobian_accuracy() {
    let sys = ExponentialDecay;
    let jac = numerical_jacobian(&sys, 0.0, &[1.0], 1e-8);
    assert!(jac.is_ok());
    let j = jac.ok().unwrap_or_default();
    // Assert the shape instead of guarding on it, so that a malformed
    // Jacobian fails the test rather than skipping the value check.
    assert_eq!(j.len(), 1, "1-D system should give a 1-row Jacobian");
    assert_eq!(j[0].len(), 1, "1-D system should give a 1-column Jacobian");
    assert!(
        (j[0][0] - (-1.0)).abs() < 1e-5,
        "Jacobian J[0][0] = {}, expected -1.0",
        j[0][0]
    );

    let sys2 = HarmonicOscillator;
    let jac2 = numerical_jacobian(&sys2, 0.0, &[1.0, 0.0], 1e-8);
    assert!(jac2.is_ok());
    let j2 = jac2.ok().unwrap_or_default();
    assert_eq!(j2.len(), 2, "2-D system should give a 2-row Jacobian");
    assert_eq!(j2[0].len(), 2, "Jacobian row 0 should have 2 columns");
    assert_eq!(j2[1].len(), 2, "Jacobian row 1 should have 2 columns");
    assert!((j2[0][0]).abs() < 1e-5, "J[0][0] should be ~0");
    assert!((j2[0][1] - 1.0).abs() < 1e-5, "J[0][1] should be ~1");
    assert!((j2[1][0] - (-1.0)).abs() < 1e-5, "J[1][0] should be ~-1");
    assert!((j2[1][1]).abs() < 1e-5, "J[1][1] should be ~0");
}
622}