use crate::physics::traits::PhysicalState;
use crate::solver::scenario::Scenario;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, string::ToString};
/// The kind of solver to run, together with the parameters it needs.
///
/// Serialized with serde's default (externally tagged) enum representation,
/// e.g. `{"TimeEvolution":{"total_time":10.0,"time_steps":1000}}`. Variant and
/// field order are part of that format, so do not reorder them.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum SolverType {
/// Time-evolution solver: a total simulated time and a number of time steps.
TimeEvolution { total_time: f64, time_steps: usize },
/// Iterative solver with a `tolerance` and an iteration cap `max_iterations`.
/// NOTE(review): presumably used as convergence/stop criteria — confirm in
/// the solver implementations.
Iterative {
tolerance: f64,
max_iterations: usize,
},
/// Analytical solver; `evaluation_time` is optional.
Analytical { evaluation_time: Option<f64> },
/// Grid-based solver.
SpatialDiscretization {
/// Number of grid points.
grid_points: usize,
/// Optional number of time steps.
time_steps: Option<usize>,
},
/// User-defined solver: a name (returned by `SolverType::name`) and a map of
/// named numeric parameters.
Custom(String, HashMap<String, f64>),
}
impl SolverType {
pub fn name(&self) -> &str {
match self {
SolverType::TimeEvolution { .. } => "TimeEvolution",
SolverType::Iterative { .. } => "Iterative",
SolverType::Analytical { .. } => "Analytical",
SolverType::SpatialDiscretization { .. } => "SpatialDiscretization",
SolverType::Custom(name, _) => name,
}
}
pub fn validate(&self) -> Result<(), String> {
match self {
SolverType::TimeEvolution {
total_time,
time_steps,
} => {
if *total_time <= 0.0 {
return Err("Total time must be positive".to_string());
}
if *time_steps == 0 {
return Err("TimeSteps must be greater than 0".to_string());
}
Ok(())
}
SolverType::Iterative {
tolerance,
max_iterations,
} => {
if *tolerance <= 0.0 {
return Err("Tolerance must be positive".to_string());
}
if *max_iterations == 0 {
return Err("Maximum iterations must be positive".to_string());
}
Ok(())
}
SolverType::Analytical { evaluation_time: _ } => Ok(()),
SolverType::SpatialDiscretization {
grid_points,
time_steps,
} => {
if *grid_points == 0 {
return Err("Grid Points cannot be null (no grid)".to_string());
}
if let Some(steps) = time_steps
&& *steps == 0
{
return Err("Steps must be greater than 0".to_string());
}
Ok(())
}
SolverType::Custom(_, parameters) => {
for (key, value) in parameters {
if !value.is_finite() {
return Err(format!("Parameter {} is not finite", key));
}
}
Ok(())
}
}
}
}
/// A solver choice plus optional run settings.
///
/// Built with `SolverConfiguration::new` or one of its shorthand constructors,
/// and checked with `SolverConfiguration::validate`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SolverConfiguration {
/// Which solver to run, with its parameters.
pub solver_type: SolverType,
/// Optional step setting; `None` when absent from serialized input.
/// NOTE(review): its meaning (e.g. an output/recording stride) is defined by
/// the solvers, not here — confirm. `validate()` does not check it, and zero
/// is accepted.
#[serde(default)]
pub step: Option<usize>,
}
impl SolverConfiguration {
pub fn new(solver_type: SolverType) -> Self {
Self {
solver_type,
step: None,
}
}
pub fn with_step(mut self, n: usize) -> Self {
self.step = Some(n);
self
}
pub fn time_evolution(total_time: f64, time_steps: usize) -> Self {
Self::new(SolverType::TimeEvolution {
total_time,
time_steps,
})
}
pub fn iterative(tolerance: f64, max_iterations: usize) -> Self {
Self::new(SolverType::Iterative {
tolerance,
max_iterations,
})
}
pub fn analytical(evaluation_time: f64) -> Self {
Self::new(SolverType::Analytical {
evaluation_time: Some(evaluation_time),
})
}
pub fn spatial_discretization(grid_points: usize, time_steps: usize) -> Self {
Self::new(SolverType::SpatialDiscretization {
grid_points,
time_steps: Some(time_steps),
})
}
pub fn validate(&self) -> Result<(), String> {
self.solver_type.validate()
}
}
/// Output of a `Solver::solve` run.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SimulationResult {
/// Time values at which states were recorded; `SimulationResult::len`
/// reports the length of this vector.
/// NOTE(review): presumably index-aligned with `state_trajectory`; nothing
/// here enforces equal lengths.
pub time_points: Vec<f64>,
/// States recorded over the run.
pub state_trajectory: Vec<PhysicalState>,
/// State at the end of the run.
pub final_state: PhysicalState,
/// Free-form string annotations; empty on construction and added via
/// `SimulationResult::add_metadata`.
pub metadata: HashMap<String, String>,
}
impl SimulationResult {
pub fn new(
time_points: Vec<f64>,
state_trajectory: Vec<PhysicalState>,
final_state: PhysicalState,
) -> Self {
Self {
time_points,
state_trajectory,
final_state,
metadata: HashMap::new(),
}
}
pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.metadata.insert(key.into(), value.into());
}
pub fn len(&self) -> usize {
self.time_points.len()
}
pub fn is_empty(&self) -> bool {
self.time_points.is_empty()
}
}
/// A solver that runs a `Scenario` under a `SolverConfiguration` and produces
/// a `SimulationResult`.
pub trait Solver {
/// Runs this solver on `scenario` with the settings in `config`.
///
/// # Errors
///
/// Returns a human-readable message on failure; the failure conditions are
/// defined by each implementation.
fn solve(
&self,
scenario: &Scenario,
config: &SolverConfiguration,
) -> Result<SimulationResult, String>;
/// Returns this solver's name.
fn name(&self) -> &str;
}
#[cfg(test)]
mod tests {
    // `serde_json` is reachable through the extern prelude; no `use` needed.
    use super::*;

    #[test]
    fn test_solver_type_name() {
        let time_ev = SolverType::TimeEvolution {
            total_time: 10.0,
            time_steps: 100,
        };
        assert_eq!(time_ev.name(), "TimeEvolution");
        let iterative = SolverType::Iterative {
            tolerance: 1e-6,
            max_iterations: 100,
        };
        assert_eq!(iterative.name(), "Iterative");
    }

    #[test]
    fn test_solver_type_name_remaining_variants() {
        let analytical = SolverType::Analytical {
            evaluation_time: None,
        };
        assert_eq!(analytical.name(), "Analytical");
        let spatial = SolverType::SpatialDiscretization {
            grid_points: 10,
            time_steps: None,
        };
        assert_eq!(spatial.name(), "SpatialDiscretization");
        // Custom solvers report the name they were built with.
        let custom = SolverType::Custom("MySolver".to_string(), HashMap::new());
        assert_eq!(custom.name(), "MySolver");
    }

    #[test]
    fn test_solver_type_validate_time_evolution() {
        let valid = SolverType::TimeEvolution {
            total_time: 10.0,
            time_steps: 1000,
        };
        assert!(valid.validate().is_ok());
        let invalid_time = SolverType::TimeEvolution {
            total_time: -1.0,
            time_steps: 1000,
        };
        assert!(invalid_time.validate().is_err());
        let invalid_steps = SolverType::TimeEvolution {
            total_time: 10.0,
            time_steps: 0,
        };
        assert!(invalid_steps.validate().is_err());
    }

    #[test]
    fn test_solver_type_validate_iterative() {
        let valid = SolverType::Iterative {
            tolerance: 1e-6,
            max_iterations: 100,
        };
        assert!(valid.validate().is_ok());
        let invalid_tol = SolverType::Iterative {
            tolerance: -1e-6,
            max_iterations: 100,
        };
        assert!(invalid_tol.validate().is_err());
        let invalid_iterations = SolverType::Iterative {
            tolerance: 1e-6,
            max_iterations: 0,
        };
        assert!(invalid_iterations.validate().is_err());
    }

    #[test]
    fn test_solver_type_validate_analytical() {
        let without_time = SolverType::Analytical {
            evaluation_time: None,
        };
        assert!(without_time.validate().is_ok());
        assert!(SolverConfiguration::analytical(5.0).validate().is_ok());
    }

    #[test]
    fn test_solver_type_validate_spatial_discretization() {
        let no_steps = SolverType::SpatialDiscretization {
            grid_points: 100,
            time_steps: None,
        };
        assert!(no_steps.validate().is_ok());
        let with_steps = SolverType::SpatialDiscretization {
            grid_points: 100,
            time_steps: Some(1000),
        };
        assert!(with_steps.validate().is_ok());
        let no_grid = SolverType::SpatialDiscretization {
            grid_points: 0,
            time_steps: Some(1000),
        };
        assert!(no_grid.validate().is_err());
        let zero_steps = SolverType::SpatialDiscretization {
            grid_points: 100,
            time_steps: Some(0),
        };
        assert!(zero_steps.validate().is_err());
    }

    #[test]
    fn test_solver_type_validate_custom() {
        let mut params = HashMap::new();
        params.insert("alpha".to_string(), 0.5);
        let valid = SolverType::Custom("c".to_string(), params.clone());
        assert!(valid.validate().is_ok());
        params.insert("beta".to_string(), f64::NAN);
        let invalid = SolverType::Custom("c".to_string(), params);
        assert_eq!(
            invalid.validate(),
            Err("Parameter beta is not finite".to_string())
        );
    }

    #[test]
    fn test_solver_configuration_factory_methods() {
        let time_ev = SolverConfiguration::time_evolution(10.0, 1000);
        assert!(matches!(
            time_ev.solver_type,
            SolverType::TimeEvolution { .. }
        ));
        let iterative = SolverConfiguration::iterative(1e-6, 100);
        assert!(matches!(
            iterative.solver_type,
            SolverType::Iterative { .. }
        ));
        let analytical = SolverConfiguration::analytical(5.0);
        assert!(matches!(
            analytical.solver_type,
            SolverType::Analytical { .. }
        ));
    }

    #[test]
    fn test_solver_configuration_validate() {
        let valid = SolverConfiguration::time_evolution(10.0, 1000);
        assert!(valid.validate().is_ok());
        let invalid = SolverConfiguration::time_evolution(-1.0, 1000);
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_step_defaults_to_none_on_all_factories() {
        assert!(
            SolverConfiguration::time_evolution(10.0, 1000)
                .step
                .is_none()
        );
        assert!(SolverConfiguration::iterative(1e-6, 100).step.is_none());
        assert!(SolverConfiguration::analytical(5.0).step.is_none());
        assert!(
            SolverConfiguration::spatial_discretization(100, 1000)
                .step
                .is_none()
        );
    }

    #[test]
    fn test_with_step_sets_value() {
        let config = SolverConfiguration::time_evolution(600.0, 10000).with_step(50);
        assert_eq!(config.step, Some(50));
    }

    #[test]
    fn test_with_step_preserves_solver_type() {
        let config = SolverConfiguration::time_evolution(600.0, 10000).with_step(50);
        assert!(matches!(
            config.solver_type,
            SolverType::TimeEvolution { .. }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_with_step_zero_accepted() {
        let config = SolverConfiguration::time_evolution(10.0, 1000).with_step(0);
        assert_eq!(config.step, Some(0));
    }

    #[test]
    fn test_step_serde_default_when_field_absent() {
        let json = r#"{"solver_type":{"TimeEvolution":{"total_time":10.0,"time_steps":1000}}}"#;
        let config: SolverConfiguration =
            serde_json::from_str(json).expect("deserialisation must succeed when step is absent");
        assert!(config.step.is_none(), "step must default to None");
    }

    #[test]
    fn test_step_serde_null() {
        let json = r#"{"solver_type":{"TimeEvolution":{"total_time":10.0,"time_steps":1000}},"step":null}"#;
        let config: SolverConfiguration =
            serde_json::from_str(json).expect("deserialisation must succeed for step: null");
        assert!(config.step.is_none(), "step: null must yield None");
    }

    #[test]
    fn test_step_serde_with_value() {
        let json =
            r#"{"solver_type":{"TimeEvolution":{"total_time":10.0,"time_steps":1000}},"step":10}"#;
        let config: SolverConfiguration =
            serde_json::from_str(json).expect("deserialisation must succeed for step: 10");
        assert_eq!(config.step, Some(10));
    }

    #[test]
    fn test_step_serde_round_trip() {
        let original = SolverConfiguration::time_evolution(600.0, 10000).with_step(25);
        let serialised = serde_json::to_string(&original).expect("serialisation must succeed");
        let restored: SolverConfiguration =
            serde_json::from_str(&serialised).expect("deserialisation must succeed");
        assert_eq!(restored.step, Some(25));
        assert!(matches!(
            restored.solver_type,
            SolverType::TimeEvolution { .. }
        ));
    }
}