use crate::physics::PhysicalState;
use crate::solver;
use crate::solver::{Scenario, SimulationResult, Solver, SolverConfiguration, SolverType};
/// Explicit (forward) Euler time integrator.
///
/// The solver holds no state of its own: everything it needs comes from the
/// [`Scenario`] and [`SolverConfiguration`] handed to [`Solver::solve`], so a
/// single instance can be copied and reused freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct EulerSolver;

impl EulerSolver {
    /// Creates a forward Euler solver; equivalent to [`EulerSolver::default`].
    pub fn new() -> Self {
        EulerSolver
    }
}
impl Solver for EulerSolver {
    /// Integrates `scenario` forward in time with the explicit (forward) Euler method.
    ///
    /// Starting from the scenario's initial condition `y_0`, every step applies
    /// `y_{k+1} = y_k + dt * f(y_k)`, where `f` is `scenario.model.compute_physics`
    /// and `dt = total_time / time_steps`. Before each evaluation of `f`, the state's
    /// `"time"` metadata is set to the time of the current step, which makes the
    /// current time visible to `compute_physics`.
    ///
    /// The returned [`SimulationResult`] holds `time_steps + 1` time points and
    /// states (the initial state included), the final state, and the metadata
    /// entries `"solver"`, `"time steps"`, `"dt"` and `"total time"`. The last time
    /// point is exactly `total_time`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when:
    /// - `config.validate()` or `scenario.validate()` fails;
    /// - `config.solver_type` is not [`SolverType::TimeEvolution`];
    /// - the scenario's conditions provide no initial condition;
    /// - [`solver::validate_state`] rejects the state produced by any step.
    fn solve(
        &self,
        scenario: &Scenario,
        config: &SolverConfiguration,
    ) -> Result<SimulationResult, String> {
        config.validate()?;
        scenario.validate()?;

        let (total_time, time_steps) = match &config.solver_type {
            SolverType::TimeEvolution {
                total_time,
                time_steps,
            } => (*total_time, *time_steps),
            other => {
                return Err(format!(
                    "EulerSolver only supports TimeEvolution configuration, got {}",
                    other.name()
                ));
            }
        };

        let dt = total_time / (time_steps as f64);
        // Time of the k-th grid point. `total_time * (k / n)` has the same accuracy as
        // `k * dt` but lands exactly on `total_time` for k == n, whereas `n * (T / n)`
        // can be off by an ulp (e.g. `49.0 * (1.0 / 49.0) != 1.0` in f64).
        let time_at = |k: usize| total_time * (k as f64 / time_steps as f64);

        let mut state = scenario
            .conditions
            .initial_condition()
            .ok_or_else(|| "No initial condition found in domain boundaries".to_string())?
            .clone();

        let mut time_points = Vec::with_capacity(time_steps + 1);
        let mut state_trajectory = Vec::with_capacity(time_steps + 1);
        time_points.push(0.0);
        state_trajectory.push(state.clone());

        for step in 0..time_steps {
            // Expose the time of the state being advanced to `compute_physics`.
            state.set_metadata("time".to_string(), time_at(step));
            let physics = scenario.model.compute_physics(&state);
            // The previous state is not needed after the update, so move it into the
            // addition instead of cloning it.
            state = state + physics * dt;
            state_trajectory.push(state.clone());
            time_points.push(time_at(step + 1));
            solver::validate_state(&state, step + 1)?;
        }

        let mut result = SimulationResult::new(time_points, state_trajectory, state);
        result.add_metadata("solver", self.name());
        result.add_metadata("time steps", time_steps.to_string());
        result.add_metadata("dt", dt.to_string());
        result.add_metadata("total time", total_time.to_string());
        Ok(result)
    }

    /// Human-readable solver name; also recorded as the `"solver"` result metadata.
    fn name(&self) -> &'static str {
        "Forward Euler"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physics::{PhysicalData, PhysicalModel, PhysicalQuantity, PhysicalState};
    use crate::solver::boundary::DomainBoundaries;
    use serde::{Deserialize, Serialize};

    /// Test model whose concentration derivative is `-decay_rate * y`, starting from a
    /// uniform concentration of 1.0; the exact solution is `exp(-decay_rate * t)`.
    #[derive(Deserialize, Serialize)]
    struct ExponentialDecay {
        points: usize,
        decay_rate: f64,
    }

    #[typetag::serde]
    impl PhysicalModel for ExponentialDecay {
        fn points(&self) -> usize {
            self.points
        }

        fn compute_physics(&self, state: &PhysicalState) -> PhysicalState {
            let mut result = state.clone();
            if let Some(conc) = result.get_mut(PhysicalQuantity::Concentration) {
                conc.apply(|y| -self.decay_rate * y);
            }
            result
        }

        fn setup_initial_state(&self) -> PhysicalState {
            PhysicalState::new(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, 1.0),
            )
        }

        fn name(&self) -> &'static str {
            "Exponential Decay"
        }
    }

    /// Test model with a constant concentration derivative `growth_rate`, starting
    /// from zero; the exact solution `growth_rate * t` is integrated exactly by Euler.
    #[derive(Deserialize, Serialize)]
    struct ConstantGrowth {
        points: usize,
        growth_rate: f64,
    }

    #[typetag::serde]
    impl PhysicalModel for ConstantGrowth {
        fn points(&self) -> usize {
            self.points
        }

        fn compute_physics(&self, _state: &PhysicalState) -> PhysicalState {
            PhysicalState::new(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, self.growth_rate),
            )
        }

        fn setup_initial_state(&self) -> PhysicalState {
            PhysicalState::new(
                PhysicalQuantity::Concentration,
                PhysicalData::uniform_vector(self.points, 0.0),
            )
        }

        fn name(&self) -> &'static str {
            "Constant Growth"
        }
    }

    #[test]
    fn test_euler_solver_creation() {
        let solver = EulerSolver::new();
        assert_eq!(solver.name(), "Forward Euler");
    }

    #[test]
    fn test_euler_solver_default() {
        let solver = EulerSolver::default();
        assert_eq!(solver.name(), "Forward Euler");
    }

    #[test]
    fn test_euler_accepts_time_evolution() {
        let solver = EulerSolver::new();
        let config = SolverConfiguration::time_evolution(10.0, 100);
        let model = Box::new(ConstantGrowth {
            points: 10,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        // `expect` surfaces the solver's error message if this ever fails.
        solver
            .solve(&scenario, &config)
            .expect("TimeEvolution configuration should be accepted");
    }

    #[test]
    fn test_euler_solver_iterative_failed() {
        let solver = EulerSolver::new();
        let config = SolverConfiguration::iterative(1e-6, 100);
        let model = Box::new(ConstantGrowth {
            points: 10,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only supports TimeEvolution"));
    }

    #[test]
    fn test_euler_solver_analytical_failed() {
        let solver = EulerSolver::new();
        let config = SolverConfiguration::analytical(0.5);
        let model = Box::new(ConstantGrowth {
            points: 10,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("only supports TimeEvolution"));
    }

    #[test]
    fn test_euler_constant_growth() {
        let solver = EulerSolver::new();
        let growth_rate = 2.0;
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 10.0;
        let config = SolverConfiguration::time_evolution(total_time, 100);
        let result = solver.solve(&scenario, &config).unwrap();
        assert!((result.time_points.last().unwrap() - total_time).abs() < 1e-10);
        let final_state = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        // Euler integrates a constant derivative exactly: y(T) = rate * T.
        let expected_concentration = growth_rate * total_time;
        let actual_concentration = final_state.as_vector()[0];
        assert!((actual_concentration - expected_concentration).abs() < 1e-10);
    }

    #[test]
    fn test_euler_exponential_decay() {
        let solver = EulerSolver::new();
        let decay_rate: f64 = 0.1;
        let model = Box::new(ExponentialDecay {
            points: 5,
            decay_rate,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time: f64 = 10.0;
        let time_step = 1000;
        let config = SolverConfiguration::time_evolution(total_time, time_step);
        let result = solver.solve(&scenario, &config).unwrap();
        let expected = (-decay_rate * total_time).exp();
        let final_concentration = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        let actual_concentration = final_concentration.as_vector()[0];
        let error = (actual_concentration - expected).abs();
        assert!(error < 0.01, "Error {} too large for dt=0.01", error);
    }

    /// Halving dt should roughly halve the global error of a first-order method.
    #[test]
    fn test_euler_convergence() {
        let solver = EulerSolver::new();
        let decay_rate: f64 = 0.5;
        let total_time: f64 = 5.0;
        let exact = (-decay_rate * total_time).exp();
        let step_counts: [usize; 4] = [100, 200, 400, 800];

        let mut errors: Vec<f64> = Vec::with_capacity(step_counts.len());
        for &steps in &step_counts {
            let model = Box::new(ExponentialDecay {
                points: 3,
                decay_rate,
            });
            let initial = model.setup_initial_state();
            let scenario = Scenario::new(model, DomainBoundaries::temporal(initial));
            let config = SolverConfiguration::time_evolution(total_time, steps);
            let result = solver.solve(&scenario, &config).unwrap();
            let final_state = result
                .final_state
                .get(PhysicalQuantity::Concentration)
                .unwrap();
            let actual = final_state.as_vector()[0];
            errors.push((actual - exact).abs());
        }

        for (i, pair) in errors.windows(2).enumerate() {
            let ratio = pair[0] / pair[1];
            assert!(
                ratio > 1.8 && ratio < 2.2,
                "Convergence ratio {} is not first order at refinement {}",
                ratio,
                i
            );
        }
    }

    #[test]
    fn test_euler_trajectory_length() {
        let solver = EulerSolver::new();
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let time_steps = 100;
        let config = SolverConfiguration::time_evolution(10.0, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(result.time_points.len(), time_steps + 1);
        assert_eq!(result.state_trajectory.len(), time_steps + 1);
    }

    #[test]
    fn test_euler_time_points() {
        let solver = EulerSolver::new();
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 20.0;
        let time_steps = 100;
        let dt = total_time / (time_steps as f64);
        let config = SolverConfiguration::time_evolution(total_time, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(result.time_points.len(), result.state_trajectory.len());
        assert!((result.time_points[0] - 0.0).abs() <= 1e-10);
        let final_time = *result.time_points.last().unwrap();
        assert!(
            (final_time - total_time).abs() <= 1e-14,
            "Final time {} should be very close to {}. Difference {:e}",
            final_time,
            total_time,
            (final_time - total_time).abs()
        );
        for pair in result.time_points.windows(2) {
            let spacing = pair[1] - pair[0];
            assert!(
                (spacing - dt).abs() <= 1e-12,
                "Time step {} differs from mathematical dt {} by more than 1e-12",
                spacing,
                dt
            );
        }
    }

    #[test]
    fn test_euler_time_precision() {
        let solver = EulerSolver::new();
        let model = Box::new(ConstantGrowth {
            points: 3,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 10.0;
        let time_steps = 100;
        let config = SolverConfiguration::time_evolution(total_time, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        let final_time = *result.time_points.last().unwrap();
        assert!(
            (final_time - total_time).abs() < 1e-14,
            "Direct calculation maintains precision: {} ≈ {} (error: {:e})",
            final_time,
            total_time,
            (final_time - total_time).abs()
        );
    }

    #[test]
    fn test_euler_metadata() {
        let solver = EulerSolver::new();
        let model = Box::new(ConstantGrowth {
            points: 5,
            growth_rate: 1.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let total_time = 100.0;
        let time_steps = 500;
        let config = SolverConfiguration::time_evolution(total_time, time_steps);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(
            result.metadata.get("solver"),
            Some(&"Forward Euler".to_string())
        );
        assert_eq!(result.metadata.get("time steps"), Some(&"500".to_string()));
        assert_eq!(result.metadata.get("total time"), Some(&"100".to_string()));
        let dt_str = result.metadata.get("dt").unwrap();
        let dt: f64 = dt_str.parse().unwrap();
        assert!((dt - 0.2).abs() < 1e-10);
    }

    #[test]
    fn test_euler_detects_nan() {
        /// Model whose derivative is NaN everywhere, so the first step yields NaN.
        #[derive(Deserialize, Serialize)]
        struct NaNModel {
            points: usize,
        }

        #[typetag::serde]
        impl PhysicalModel for NaNModel {
            fn points(&self) -> usize {
                self.points
            }

            fn compute_physics(&self, _state: &PhysicalState) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, f64::NAN),
                )
            }

            fn setup_initial_state(&self) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 1.0),
                )
            }

            fn name(&self) -> &str {
                "NaN Model"
            }
        }

        let solver = EulerSolver::new();
        let model = Box::new(NaNModel { points: 5 });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(10.0, 10);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.contains("NaN"));
    }

    #[test]
    fn test_euler_detects_inf() {
        /// Model whose derivative is +inf everywhere, so the first step yields infinity.
        #[derive(Deserialize, Serialize)]
        struct InfModel {
            points: usize,
        }

        #[typetag::serde]
        impl PhysicalModel for InfModel {
            fn points(&self) -> usize {
                self.points
            }

            fn compute_physics(&self, _state: &PhysicalState) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, f64::INFINITY),
                )
            }

            fn setup_initial_state(&self) -> PhysicalState {
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 1.0),
                )
            }

            fn name(&self) -> &str {
                "Inf Model"
            }
        }

        let solver = EulerSolver::new();
        let model = Box::new(InfModel { points: 5 });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(10.0, 10);
        let result = solver.solve(&scenario, &config);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.contains("Infinity"));
    }

    #[test]
    fn test_euler_single_step() {
        let solver = EulerSolver::new();
        let model = Box::new(ConstantGrowth {
            points: 3,
            growth_rate: 5.0,
        });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(1.0, 1);
        let result = solver.solve(&scenario, &config).unwrap();
        assert_eq!(result.time_points.len(), 2);
        assert_eq!(result.state_trajectory.len(), 2);
        let final_state = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        assert!((final_state.as_vector()[0] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_euler_multiple_quantity() {
        /// Model with constant derivatives for two quantities at once
        /// (concentration +1.0, temperature +0.1 per unit time).
        #[derive(Deserialize, Serialize)]
        struct MultiQuantityModel {
            points: usize,
        }

        #[typetag::serde]
        impl PhysicalModel for MultiQuantityModel {
            fn points(&self) -> usize {
                self.points
            }

            fn compute_physics(&self, _state: &PhysicalState) -> PhysicalState {
                let mut result = PhysicalState::empty();
                result.set(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 1.0),
                );
                result.set(
                    PhysicalQuantity::Temperature,
                    PhysicalData::uniform_vector(self.points, 0.1),
                );
                result
            }

            fn setup_initial_state(&self) -> PhysicalState {
                let mut state = PhysicalState::empty();
                state.set(
                    PhysicalQuantity::Concentration,
                    PhysicalData::uniform_vector(self.points, 0.0),
                );
                state.set(
                    PhysicalQuantity::Temperature,
                    PhysicalData::uniform_vector(self.points, -273.15),
                );
                state
            }

            fn name(&self) -> &str {
                "Multiple Physical Quantity Model"
            }
        }

        let solver = EulerSolver::new();
        let model = Box::new(MultiQuantityModel { points: 5 });
        let initial = model.setup_initial_state();
        let boundaries = DomainBoundaries::temporal(initial);
        let scenario = Scenario::new(model, boundaries);
        let config = SolverConfiguration::time_evolution(10.0, 100);
        let result = solver.solve(&scenario, &config).unwrap();
        let final_concentration = result
            .final_state
            .get(PhysicalQuantity::Concentration)
            .unwrap();
        assert!((final_concentration.as_vector()[0] - 10.0).abs() < 1e-10);
        let final_temperature = result
            .final_state
            .get(PhysicalQuantity::Temperature)
            .unwrap();
        // -273.15 + 0.1 * 10.0 = -272.15
        assert!((final_temperature.as_vector()[0] + 272.15).abs() < 1e-10);
    }
}