use crate::context::calculator::ContextCalculator;
/// Strategy for choosing the integration step size.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum StepControl {
    /// A single constant step size for the whole run.
    Fixed {
        /// The step size.
        dt: f64,
    },
    /// A variable step size.
    ///
    /// NOTE(review): these values are only stored here — nothing in this module
    /// validates them (e.g. `0 < dt_min <= dt_init <= dt_max`, non-negative
    /// tolerances); confirm the adaptive integrator checks them.
    Adaptive {
        /// Step size reported by [`StepControl::dt_initial`].
        dt_init: f64,
        /// Intended lower bound on the step size (by name; not enforced here).
        dt_min: f64,
        /// Intended upper bound on the step size (by name; not enforced here).
        dt_max: f64,
        /// Relative tolerance, presumably for step-size error control.
        rtol: f64,
        /// Absolute tolerance, presumably for step-size error control.
        atol: f64,
    },
}

impl StepControl {
    /// Returns the step size to start with: `dt` for [`StepControl::Fixed`],
    /// `dt_init` for [`StepControl::Adaptive`].
    pub fn dt_initial(&self) -> f64 {
        match self {
            Self::Fixed { dt } => *dt,
            Self::Adaptive { dt_init, .. } => *dt_init,
        }
    }

    /// Returns `true` if this is [`StepControl::Fixed`].
    pub fn is_fixed(&self) -> bool {
        matches!(self, Self::Fixed { .. })
    }

    /// Returns `true` if this is [`StepControl::Adaptive`].
    pub fn is_adaptive(&self) -> bool {
        matches!(self, Self::Adaptive { .. })
    }
}
/// Time-integration scheme used by the solver.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IntegratorKind {
    /// Euler method (explicit, per [`IntegratorKind::is_explicit`]).
    Euler,
    /// Fourth-order Runge–Kutta (explicit, per [`IntegratorKind::is_explicit`]).
    RK4,
}

impl IntegratorKind {
    /// Returns `true` when the scheme is explicit.
    ///
    /// Every variant currently defined is explicit.
    pub fn is_explicit(&self) -> bool {
        // Each variant is listed on its own so that adding a new (e.g. implicit)
        // scheme is a compile error here rather than a silent `false`/`true`.
        match self {
            Self::Euler => true,
            Self::RK4 => true,
        }
    }
}
/// Time-horizon and step-size settings for a simulation run.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TimeConfiguration {
    /// Final time. `n_steps_estimate` divides it by the fixed step, i.e. treats it
    /// as the horizon length (assumes integration starts at t = 0 — confirm).
    pub t_end: f64,
    /// How the step size is chosen.
    pub step_control: StepControl,
    /// Optional save cadence; `None` unless set with [`TimeConfiguration::saving_every`].
    /// NOTE(review): presumably `Some(n)` means "save every n steps" — confirm in the solver.
    pub save_every: Option<usize>,
}

impl TimeConfiguration {
    /// Relative tolerance for treating `t_end / dt` as an exact integer.
    ///
    /// Absorbs floating-point noise such as `(0.1 + 0.2) / 0.1 == 3.0000000000000004`,
    /// which would otherwise round up to one extra step.
    const STEP_COUNT_REL_TOL: f64 = 1e-12;

    /// Creates a configuration ending at `t_end` with no save cadence set.
    pub fn new(t_end: f64, step_control: StepControl) -> Self {
        Self {
            t_end,
            step_control,
            save_every: None,
        }
    }

    /// Builder: sets `save_every` to `Some(n)` and returns the updated configuration.
    ///
    /// `n == 0` is stored unchanged as `Some(0)`.
    /// NOTE(review): confirm how the solver treats `Some(0)`.
    #[must_use]
    pub fn saving_every(mut self, n: usize) -> Self {
        self.save_every = Some(n);
        self
    }

    /// Estimates how many steps are needed to reach `t_end`.
    ///
    /// For [`StepControl::Fixed`] this is `ceil(t_end / dt)`. A ratio within
    /// floating-point rounding error of an integer is first snapped to that integer,
    /// so `t_end = 0.1 + 0.2`, `dt = 0.1` gives 3 steps rather than 4.
    ///
    /// Returns `0` ("no estimate") in these cases:
    /// - `dt` is not a positive number (zero, negative, or NaN);
    /// - `t_end / dt` is not finite or not positive;
    /// - always for [`StepControl::Adaptive`].
    pub fn n_steps_estimate(&self) -> usize {
        match &self.step_control {
            StepControl::Fixed { dt } => Self::fixed_step_count(self.t_end, *dt),
            StepControl::Adaptive { .. } => 0,
        }
    }

    /// Number of fixed steps of size `dt` needed to cover `t_end`, tolerant of rounding noise.
    fn fixed_step_count(t_end: f64, dt: f64) -> usize {
        if dt.is_nan() || dt <= 0.0 {
            return 0;
        }
        let ratio = t_end / dt;
        // An infinite horizon used to saturate to `usize::MAX`; report "no estimate"
        // instead. Negative or zero horizons need no steps.
        if !ratio.is_finite() || ratio <= 0.0 {
            return 0;
        }
        let nearest = ratio.round();
        let steps = if (ratio - nearest).abs() <= nearest * Self::STEP_COUNT_REL_TOL {
            nearest
        } else {
            ratio.ceil()
        };
        // `steps` is finite and >= 1 here; `as` saturates for absurdly large values.
        steps as usize
    }
}
/// Full solver setup: time stepping, integration scheme, and an ordered chain of
/// context calculators.
#[non_exhaustive]
pub struct SolverConfiguration {
    /// Time horizon and step-size control.
    pub time: TimeConfiguration,
    /// Integration scheme.
    pub integrator: IntegratorKind,
    /// Context calculators, kept in insertion order.
    pub calculators: Vec<Box<dyn ContextCalculator>>,
}

impl SolverConfiguration {
    /// Creates a configuration whose calculator chain is empty.
    pub fn new(time: TimeConfiguration, integrator: IntegratorKind) -> Self {
        let calculators = Vec::new();
        Self {
            time,
            integrator,
            calculators,
        }
    }

    /// Builder: appends one calculator to the end of the chain.
    pub fn with_calculator(self, calc: Box<dyn ContextCalculator>) -> Self {
        let mut updated = self;
        updated.calculators.push(calc);
        updated
    }

    /// Builder: appends every calculator in `calcs`, preserving their order.
    pub fn with_calculators(mut self, mut calcs: Vec<Box<dyn ContextCalculator>>) -> Self {
        // `append` moves all elements over in order and leaves `calcs` empty.
        self.calculators.append(&mut calcs);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::compute::ComputeContext;
    use crate::context::error::OxiflowError;
    use crate::context::value::ContextValue;
    use crate::context::variable::ContextVariable;
    use crate::model::traits::RequiresContext;

    /// Adaptive step control shared by several tests.
    fn sample_adaptive() -> StepControl {
        StepControl::Adaptive {
            dt_init: 0.01,
            dt_min: 1e-6,
            dt_max: 1.0,
            rtol: 1e-4,
            atol: 1e-6,
        }
    }

    /// Fixed-step time configuration ending at `t_end`.
    fn fixed_time(t_end: f64, dt: f64) -> TimeConfiguration {
        TimeConfiguration::new(t_end, StepControl::Fixed { dt })
    }

    #[test]
    fn fixed_dt_initial() {
        assert_eq!(StepControl::Fixed { dt: 0.05 }.dt_initial(), 0.05);
    }

    #[test]
    fn adaptive_dt_initial() {
        assert_eq!(sample_adaptive().dt_initial(), 0.01);
    }

    #[test]
    fn is_fixed_and_is_adaptive() {
        let fixed = StepControl::Fixed { dt: 0.01 };
        assert!(fixed.is_fixed());
        assert!(!fixed.is_adaptive());

        let adaptive = sample_adaptive();
        assert!(adaptive.is_adaptive());
        assert!(!adaptive.is_fixed());
    }

    #[test]
    fn euler_and_rk4_are_explicit() {
        for kind in [IntegratorKind::Euler, IntegratorKind::RK4] {
            assert!(kind.is_explicit());
        }
    }

    #[test]
    fn integrator_equality() {
        assert_eq!(IntegratorKind::Euler, IntegratorKind::Euler);
        assert_ne!(IntegratorKind::Euler, IntegratorKind::RK4);
    }

    #[test]
    fn n_steps_estimate_fixed() {
        assert_eq!(fixed_time(10.0, 0.01).n_steps_estimate(), 1000);
    }

    #[test]
    fn n_steps_estimate_adaptive_is_zero() {
        let tc = TimeConfiguration::new(10.0, sample_adaptive());
        assert_eq!(tc.n_steps_estimate(), 0);
    }

    #[test]
    fn saving_every_builder() {
        let tc = fixed_time(100.0, 0.1).saving_every(10);
        assert_eq!(tc.save_every, Some(10));
    }

    #[test]
    fn default_save_every_is_none() {
        assert_eq!(fixed_time(1.0, 0.1).save_every, None);
    }

    #[test]
    fn new_config_has_no_calculators() {
        let cfg = SolverConfiguration::new(fixed_time(1.0, 0.1), IntegratorKind::Euler);
        assert!(cfg.calculators.is_empty());
    }

    #[test]
    fn with_calculator_adds_to_chain() {
        /// Minimal calculator that just reports the current time.
        #[derive(Debug)]
        struct DummyCalc;

        impl RequiresContext for DummyCalc {
            fn required_variables(&self) -> Vec<ContextVariable> {
                vec![]
            }
        }

        impl crate::context::calculator::ContextCalculator for DummyCalc {
            fn provides(&self) -> ContextVariable {
                ContextVariable::Time
            }

            fn compute(
                &self,
                _: &ContextValue,
                ctx: &ComputeContext,
            ) -> Result<ContextValue, OxiflowError> {
                Ok(ContextValue::Scalar(ctx.time()))
            }
        }

        let cfg = SolverConfiguration::new(fixed_time(1.0, 0.1), IntegratorKind::RK4)
            .with_calculator(Box::new(DummyCalc));
        assert_eq!(cfg.calculators.len(), 1);
        assert_eq!(cfg.integrator, IntegratorKind::RK4);
    }

    #[test]
    fn time_configuration_accessible() {
        let cfg = SolverConfiguration::new(fixed_time(600.0, 0.5), IntegratorKind::Euler);
        assert_eq!(cfg.time.t_end, 600.0);
        assert_eq!(cfg.time.step_control.dt_initial(), 0.5);
    }
}