use crate::solver::{
Scenario, SimulationResult, Solver, SolverConfiguration, SolverType, validate_state,
};
/// Classical fourth-order Runge-Kutta (RK4) time integrator.
///
/// A field-less unit struct: it holds no configuration of its own, so it is
/// freely copyable and its `Default` value is the only value.
#[derive(Clone, Copy, Debug, Default)]
pub struct RK4Solver;
impl RK4Solver {
pub fn new() -> Self {
Self
}
}
impl Solver for RK4Solver {
    /// Integrates `scenario` over `[0, total_time]` with the classical
    /// fourth-order Runge-Kutta scheme.
    ///
    /// Only [`SolverType::TimeEvolution`] configurations are accepted. The
    /// interval is split into `time_steps` uniform steps of size
    /// `dt = total_time / time_steps`. Each step evaluates the model four times:
    /// once at `t`, twice at `t + dt/2` and once at `t + dt`. It then advances
    /// the state by `dt/6 * (k1 + 2*k2 + 2*k3 + k4)`.
    ///
    /// The returned [`SimulationResult`] contains `time_steps + 1` time points
    /// and states (the initial state included), plus the final state. Its
    /// metadata records the solver name, step count, `dt`, total time and the
    /// number of model evaluations (four per step).
    ///
    /// # Errors
    ///
    /// Returns `Err` in any of these cases:
    /// - `config` or `scenario` fails validation.
    /// - The solver type is not `TimeEvolution`.
    /// - The scenario has no initial condition.
    /// - `validate_state` rejects the state produced by any step.
    fn solve(
        &self,
        scenario: &Scenario,
        config: &SolverConfiguration,
    ) -> Result<SimulationResult, String> {
        config.validate()?;
        scenario.validate()?;

        let (total_time, time_steps) = match &config.solver_type {
            SolverType::TimeEvolution {
                total_time,
                time_steps,
            } => (*total_time, *time_steps),
            other => {
                return Err(format!(
                    "RK4Solver only supports TimeEvolution configuration, got {}",
                    other.name()
                ));
            }
        };

        let dt = total_time / (time_steps as f64);
        let half_dt = dt / 2.0;

        let mut state = match scenario.conditions.initial_condition() {
            Some(initial_state) => initial_state.clone(),
            None => return Err("No initial condition found in domain boundaries".to_string()),
        };

        let mut time_points = Vec::with_capacity(time_steps + 1);
        let mut state_trajectory = Vec::with_capacity(time_steps + 1);
        time_points.push(0.0);
        state_trajectory.push(state.clone());

        for step in 0..time_steps {
            // Times come from the step index rather than repeated addition, so
            // rounding error does not accumulate over long runs.
            let t = (step as f64) * dt;
            let t_next = (step as f64 + 1.0) * dt;

            // Each RK4 stage needs its own evaluation time (t, t+dt/2, t+dt/2,
            // t+dt). If every stage used `t`, models that read the time from the
            // context would no longer be integrated to fourth order.
            let ctx_start = crate::physics::ComputeContext::new(t, dt);
            let ctx_mid = crate::physics::ComputeContext::new(t + half_dt, dt);
            let ctx_end = crate::physics::ComputeContext::new(t_next, dt);

            let k1 = scenario.model.compute_physics(&state, &ctx_start);
            let state_k2 = state.clone() + k1.clone() * half_dt;
            let k2 = scenario.model.compute_physics(&state_k2, &ctx_mid);
            let state_k3 = state.clone() + k2.clone() * half_dt;
            let k3 = scenario.model.compute_physics(&state_k3, &ctx_mid);
            let state_k4 = state.clone() + k3.clone() * dt;
            let k4 = scenario.model.compute_physics(&state_k4, &ctx_end);

            let weighted_slope = k1 + k2 * 2.0 + k3 * 2.0 + k4;
            // `state` is consumed and immediately replaced, so no clone is needed.
            state = state + weighted_slope * (dt / 6.0);

            state_trajectory.push(state.clone());
            time_points.push(t_next);
            validate_state(&state, step + 1)?;
        }

        let final_state = state;
        let mut result = SimulationResult::new(time_points, state_trajectory, final_state);
        result.add_metadata("solver", "Runge-Kutta 4");
        result.add_metadata("time steps", time_steps.to_string());
        result.add_metadata("dt", dt.to_string());
        result.add_metadata("total time", total_time.to_string());
        // Four model evaluations (k1..k4) per step.
        result.add_metadata("function evaluations", (4 * time_steps).to_string());
        Ok(result)
    }

    /// Human-readable solver name.
    fn name(&self) -> &'static str {
        "Runge Kutta (RK4)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physics::{PhysicalData, PhysicalModel, PhysicalQuantity, PhysicalState};
    use crate::solver::boundary::DomainBoundaries;
    use serde::{Deserialize, Serialize};

    /// Test model `dC/dt = -decay_rate * C` on `Concentration`.
    ///
    /// Starts from `C = 1` at every point, so the exact solution is
    /// `exp(-decay_rate * t)`.
    #[derive(Deserialize, Serialize)]
    struct ExponentialDecay {
        points: usize,
        decay_rate: f64,
    }

    #[typetag::serde]
    impl PhysicalModel for ExponentialDecay {
        fn points(&self) -> usize {
            self.points
        }

        fn compute_physics(
            &self,
            state: &PhysicalState,
            _ctx: &crate::physics::ComputeContext,
        ) -> PhysicalState {
            let mut result = state.clone();
            if let Some(conc) = result.get_mut(PhysicalQuantity::Concentration) {
                conc.apply(|y| -self.decay_rate * y);
            }
            result
        }

        fn setup_initial_state(&self) -> PhysicalState {
            PhysicalState::new(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, 1.0),
            )
        }

        fn name(&self) -> &'static str {
            "Exponential Decay"
        }
    }

    /// Test model `dC/dt = growth_rate` starting from `C = 0`.
    ///
    /// The exact solution is `growth_rate * t`, which RK4 reproduces exactly.
    #[derive(Deserialize, Serialize)]
    struct ConstantGrowth {
        points: usize,
        growth_rate: f64,
    }

    #[typetag::serde]
    impl PhysicalModel for ConstantGrowth {
        fn points(&self) -> usize {
            self.points
        }

        fn compute_physics(
            &self,
            _state: &PhysicalState,
            _ctx: &crate::physics::ComputeContext,
        ) -> PhysicalState {
            PhysicalState::new(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, self.growth_rate),
            )
        }

        fn setup_initial_state(&self) -> PhysicalState {
            PhysicalState::new(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, 0.0),
            )
        }

        fn name(&self) -> &'static str {
            "Constant Growth"
        }
    }

    /// Test model for `x'' = -omega^2 x`, written as a first-order system.
    ///
    /// Position `x` is stored in `Concentration` and velocity `v` in
    /// `Velocity`. Starts at `x = 1`, `v = 0`.
    #[derive(Deserialize, Serialize)]
    struct HarmonicOscillator {
        points: usize,
        omega: f64,
    }

    #[typetag::serde]
    impl PhysicalModel for HarmonicOscillator {
        fn points(&self) -> usize {
            self.points
        }

        fn compute_physics(
            &self,
            state: &PhysicalState,
            _ctx: &crate::physics::ComputeContext,
        ) -> PhysicalState {
            let mut result = PhysicalState::empty();
            let y1 = state.get(PhysicalQuantity::Concentration).unwrap();
            let y2 = state.get(PhysicalQuantity::Velocity).unwrap();
            // dx/dt = v
            result.set(PhysicalQuantity::Concentration, y2.clone());
            // dv/dt = -omega^2 * x
            let mut dy2 = y1.clone();
            dy2.apply(|y| -self.omega * self.omega * y);
            result.set(PhysicalQuantity::Velocity, dy2);
            result
        }

        fn setup_initial_state(&self) -> PhysicalState {
            let mut state = PhysicalState::empty();
            state.set(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, 1.0),
            );
            state.set(
                PhysicalQuantity::Velocity,
                PhysicalData::uniform_vector(self.points, 0.0),
            );
            state
        }

        fn name(&self) -> &str {
            "Harmonic Oscillator"
        }
    }

    #[test]
    fn test_rk4_solver_creation() {
        let solver = RK4Solver::new();
        assert_eq!(solver.name(), "Runge Kutta (RK4)");
    }

    #[test]
    fn test_rk4_solver_default() {
        let solver = RK4Solver::default();
        assert_eq!(solver.name(), "Runge Kutta (RK4)");
    }

    #[test]
    fn test_rk4_accepts_time_evolution() {
        let solver = RK4Solver::new();
        let config = SolverConfiguration::time_evolution(10.0, 100);
        let model = Box::new(ConstantGrowth {
            points: 10,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rk4_rejects_iterative() {
        let solver = RK4Solver::new();
        let config = SolverConfiguration::iterative(1e-6, 100);
        let model = Box::new(ConstantGrowth {
            points: 10,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only supports TimeEvolution"));
    }

    #[test]
    fn test_rk4_rejects_analytical() {
        let solver = RK4Solver::new();
        let config = SolverConfiguration::analytical(5.0);
        let model = Box::new(ConstantGrowth {
            points: 10,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_rk4_constant_growth() {
        let solver = RK4Solver::new();
        let growth_rate = 2.0;
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 10.0;
        let config = SolverConfiguration::time_evolution(total_time, 100);
        let result = solver.solve(&scenario, &config).unwrap();
        let final_concentration = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        let expected_concentration = growth_rate * total_time;
        let calculated_concentration = final_concentration.as_vector()[0];
        assert!((calculated_concentration - expected_concentration).abs() < 1e-10);
    }

    #[test]
    fn test_rk4_exponential_decay() {
        let solver = RK4Solver::new();
        let decay_rate = 0.1;
        let model = Box::new(ExponentialDecay {
            points: 5,
            decay_rate,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 10.0;
        let time_steps = 100;
        let config = SolverConfiguration::time_evolution(total_time, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        let expected_concentration = (-decay_rate * total_time).exp();
        let final_concentration = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        let actual_concentration = final_concentration.as_vector()[0];
        let error = (actual_concentration - expected_concentration).abs();
        assert!(error < 1e-4, "Error {} is too large for RK4", error);
    }

    #[test]
    fn test_rk4_convergence() {
        let solver = RK4Solver::new();
        // With decay_rate * dt between 0.1 and 0.0125, the global truncation
        // error stays many orders of magnitude above floating-point round-off
        // at every resolution. With much smaller steps, the finest levels sit
        // at the round-off floor and the error ratios become noise.
        let decay_rate = 1.0;
        let total_time = 5.0;
        let exact_solution = (-decay_rate * total_time).exp();
        let step_counts = [50, 100, 200, 400];
        let mut errors: Vec<f64> = Vec::with_capacity(step_counts.len());
        for &steps in &step_counts {
            let model = Box::new(ExponentialDecay {
                points: 3,
                decay_rate,
            });
            let initial = model.setup_initial_state();
            let boundaries = DomainBoundaries::temporal(initial);
            let scenario = Scenario::new(model, boundaries);
            let config = SolverConfiguration::time_evolution(total_time, steps);
            let result = solver.solve(&scenario, &config).unwrap();
            let final_concentration = result
                .final_state
                .get(PhysicalQuantity::Concentration)
                .unwrap();
            errors.push((final_concentration.as_vector()[0] - exact_solution).abs());
        }
        // Halving dt should shrink a fourth-order method's error by ~2^4 = 16.
        for (i, pair) in errors.windows(2).enumerate() {
            let ratio = pair[0] / pair[1];
            assert!(
                ratio > 12.0 && ratio < 20.0,
                "Convergence ratio {} is not fourth-order at step {}",
                ratio,
                i
            );
        }
    }

    #[test]
    fn test_rk4_solver_harmonic_oscillator() {
        let solver = RK4Solver::new();
        let omega = 1.0;
        let model = Box::new(HarmonicOscillator { points: 3, omega });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        // After one full period the position returns to its initial value.
        let period = 2.0 * std::f64::consts::PI;
        let config = SolverConfiguration::time_evolution(period, 100);
        let result = solver.solve(&scenario, &config).unwrap();
        let final_position = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        let expected = 1.0;
        let actual_position = final_position.as_vector()[0];
        assert!((actual_position - expected).abs() < 0.01);
    }

    #[test]
    fn test_rk4_trajectory_length() {
        let solver = RK4Solver::new();
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let time_steps = 100;
        let config = SolverConfiguration::time_evolution(10.0, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(result.time_points.len(), time_steps + 1);
        assert_eq!(result.state_trajectory.len(), time_steps + 1);
    }

    #[test]
    fn test_rk4_time_points() {
        let solver = RK4Solver::new();
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 20.0;
        let time_steps = 100;
        let dt = total_time / (time_steps as f64);
        let config = SolverConfiguration::time_evolution(total_time, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        assert!((result.time_points[0] - 0.0).abs() < 1e-10);
        assert!((result.time_points.last().unwrap() - total_time).abs() < 1e-10);
        for pair in result.time_points.windows(2) {
            let actual_dt = pair[1] - pair[0];
            assert!((actual_dt - dt).abs() < 1e-10);
        }
    }

    #[test]
    fn test_rk4_metadata() {
        let solver = RK4Solver::new();
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 100.0;
        let time_steps = 500;
        let config = SolverConfiguration::time_evolution(total_time, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(
            result.metadata.get("solver"),
            Some(&"Runge-Kutta 4".to_string())
        );
        assert_eq!(result.metadata.get("time steps"), Some(&"500".to_string()));
        assert_eq!(result.metadata.get("total time"), Some(&"100".to_string()));
        // Four model evaluations per step.
        assert_eq!(
            result.metadata.get("function evaluations"),
            Some(&"2000".to_string())
        );
        let dt_str = result.metadata.get("dt").unwrap();
        let dt: f64 = dt_str.parse().unwrap();
        assert!((dt - 0.2).abs() < 1e-10);
    }

    #[test]
    fn test_rk4_detects_nan() {
        /// Model whose derivative is NaN everywhere, to trigger state validation.
        #[derive(Deserialize, Serialize)]
        struct NaNModel {
            points: usize,
        }

        #[typetag::serde]
        impl PhysicalModel for NaNModel {
            fn points(&self) -> usize {
                self.points
            }

            fn compute_physics(
                &self,
                _state: &PhysicalState,
                _ctx: &crate::physics::ComputeContext,
            ) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, f64::NAN),
                )
            }

            fn setup_initial_state(&self) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 1.0),
                )
            }

            fn name(&self) -> &str {
                "NaN Model"
            }
        }

        let solver = RK4Solver::new();
        let model = Box::new(NaNModel { points: 5 });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(10.0, 10);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.contains("NaN"));
    }

    #[test]
    fn test_rk4_detects_inf() {
        /// Model whose derivative is +infinity everywhere, to trigger state validation.
        #[derive(Deserialize, Serialize)]
        struct InfModel {
            points: usize,
        }

        #[typetag::serde]
        impl PhysicalModel for InfModel {
            fn points(&self) -> usize {
                self.points
            }

            fn compute_physics(
                &self,
                _state: &PhysicalState,
                _ctx: &crate::physics::ComputeContext,
            ) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, f64::INFINITY),
                )
            }

            fn setup_initial_state(&self) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 1.0),
                )
            }

            fn name(&self) -> &str {
                "Inf Model"
            }
        }

        let solver = RK4Solver::new();
        let model = Box::new(InfModel { points: 5 });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(10.0, 10);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.contains("Infinity"));
    }

    #[test]
    fn test_rk4_single_step() {
        let solver = RK4Solver::new();
        let model = Box::new(ConstantGrowth {
            points: 3,
            growth_rate: 5.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(1.0, 1);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(result.time_points.len(), 2);
        assert_eq!(result.state_trajectory.len(), 2);
        let final_conc = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        assert!((final_conc.as_vector()[0] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_rk4_multi_quantity() {
        /// Two independent constant-rate quantities: dC/dt = 1, dT/dt = 0.1.
        #[derive(Deserialize, Serialize)]
        struct MultiQuantityModel {
            points: usize,
        }

        #[typetag::serde]
        impl PhysicalModel for MultiQuantityModel {
            fn points(&self) -> usize {
                self.points
            }

            fn compute_physics(
                &self,
                _state: &PhysicalState,
                _ctx: &crate::physics::ComputeContext,
            ) -> PhysicalState {
                let mut result = PhysicalState::empty();
                result.set(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 1.0),
                );
                result.set(
                    PhysicalQuantity::Temperature,
                    PhysicalData::uniform_vector(self.points, 0.1),
                );
                result
            }

            fn setup_initial_state(&self) -> PhysicalState {
                let mut state = PhysicalState::empty();
                state.set(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 0.0),
                );
                state.set(
                    PhysicalQuantity::Temperature,
                    PhysicalData::uniform_vector(self.points, 298.0),
                );
                state
            }

            fn name(&self) -> &str {
                "Multi-Quantity Model"
            }
        }

        let solver = RK4Solver::new();
        let model = Box::new(MultiQuantityModel { points: 5 });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(10.0, 100);
        let result = solver.solve(&scenario, &config).unwrap();
        let final_conc = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        assert!((final_conc.as_vector()[0] - 10.0).abs() < 1e-10);
        let final_temp = result
            .final_state
            .get(PhysicalQuantity::Temperature)
            .unwrap();
        assert!((final_temp.as_vector()[0] - 299.0).abs() < 1e-10);
    }

    #[test]
    fn test_rk4_vs_analytical_accuracy() {
        let solver = RK4Solver::new();
        let k = 0.3;
        let model = Box::new(ExponentialDecay {
            points: 1,
            decay_rate: k,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let test_times = [1.0, 5.0, 10.0, 20.0];
        for &t in &test_times {
            let config = SolverConfiguration::time_evolution(t, 100);
            let result = solver.solve(&scenario, &config).unwrap();
            let analytical = (-k * t).exp();
            let numerical = result
                .final_state
                .get(PhysicalQuantity::Concentration)
                .unwrap()
                .as_vector()[0];
            let relative_error = ((numerical - analytical) / analytical).abs();
            assert!(
                relative_error < 0.001,
                "At t={}: relative error {} too large",
                t,
                relative_error
            );
        }
    }
}