pub mod sensitivity;
use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
use crate::simplify::engine::simplify;
use std::fmt;
/// Errors produced while building or transforming an ODE system.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OdeError {
    /// The number of state variables differs from the number of right-hand sides.
    VariableCountMismatch,
    /// The input does not describe a usable first-order system
    /// (for example, lowering is asked for a scalar ODE of order 0).
    NotFirstOrder,
    /// Differentiation failed; carries the underlying message.
    DiffError(String),
}

impl fmt::Display for OdeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed messages go through `write_str`; only the variant with a
        // payload needs formatting.
        let message = match self {
            Self::VariableCountMismatch => "variable and RHS count mismatch",
            Self::NotFirstOrder => "ODE is not first-order",
            Self::DiffError(msg) => return write!(f, "differentiation error: {msg}"),
        };
        f.write_str(message)
    }
}

impl std::error::Error for OdeError {}
impl crate::errors::AlkahestError for OdeError {
    /// Error code string for each variant (`E-ODE-00x`).
    fn code(&self) -> &'static str {
        match *self {
            Self::VariableCountMismatch => "E-ODE-001",
            Self::NotFirstOrder => "E-ODE-002",
            Self::DiffError(_) => "E-ODE-003",
        }
    }

    /// Suggested fix for the error; every variant currently provides one.
    fn remediation(&self) -> Option<&'static str> {
        let hint = match *self {
            Self::VariableCountMismatch => {
                "the number of state variables must equal the number of right-hand-side expressions"
            }
            Self::NotFirstOrder => {
                "use lower_to_first_order() to reduce higher-order ODEs to first-order form"
            }
            Self::DiffError(_) => {
                "check that all functions in the ODE are differentiable; unknown functions block lowering"
            }
        };
        Some(hint)
    }
}
/// A first-order ODE system `d(state_vars[i])/dt = rhs[i]`.
///
/// `state_vars`, `derivatives` and `rhs` are parallel vectors of equal length.
#[derive(Debug, Clone)]
pub struct ODE {
    /// The state variables, one per equation.
    pub state_vars: Vec<ExprId>,
    /// Derivative symbols (e.g. `dx/dt`) created by `ODE::new`, one per state variable.
    pub derivatives: Vec<ExprId>,
    /// Right-hand-side expression for each state variable's derivative.
    pub rhs: Vec<ExprId>,
    /// The independent (time) variable.
    pub time_var: ExprId,
    /// `(variable, value)` pairs, rendered by `display` as `variable(0) = value`.
    pub initial_conditions: Vec<(ExprId, ExprId)>,
}
impl ODE {
    /// Builds a first-order system from parallel lists of state variables and
    /// right-hand sides.
    ///
    /// For each state variable a real-valued derivative symbol is obtained from
    /// `pool`. It is named `d<name>/dt` when the variable is a symbol, and
    /// `d?/dt` otherwise. The system starts with no initial conditions.
    ///
    /// # Errors
    ///
    /// Returns `OdeError::VariableCountMismatch` if `state_vars` and `rhs` have
    /// different lengths.
    pub fn new(
        state_vars: Vec<ExprId>,
        rhs: Vec<ExprId>,
        time_var: ExprId,
        pool: &ExprPool,
    ) -> Result<Self, OdeError> {
        if rhs.len() != state_vars.len() {
            return Err(OdeError::VariableCountMismatch);
        }
        let mut derivatives = Vec::with_capacity(state_vars.len());
        for &state in &state_vars {
            let label = pool.with(state, |data| {
                if let ExprData::Symbol { name, .. } = data {
                    format!("d{name}/dt")
                } else {
                    "d?/dt".to_string()
                }
            });
            derivatives.push(pool.symbol(&label, Domain::Real));
        }
        Ok(Self {
            state_vars,
            derivatives,
            rhs,
            time_var,
            initial_conditions: Vec::new(),
        })
    }

    /// Returns the system with the initial condition `var(0) = value` appended.
    pub fn with_ic(self, var: ExprId, value: ExprId) -> Self {
        let mut ode = self;
        ode.initial_conditions.push((var, value));
        ode
    }

    /// Number of state variables (equivalently, equations) in the system.
    pub fn order(&self) -> usize {
        self.state_vars.len()
    }

    /// Returns `true` when no right-hand side mentions `time_var`.
    pub fn is_autonomous(&self, pool: &ExprPool) -> bool {
        let t = self.time_var;
        !self.rhs.iter().any(|&expr| contains(expr, t, pool))
    }

    /// Returns a copy of the system in which every right-hand side has been
    /// passed through the simplifier; all other fields are copied unchanged.
    pub fn simplify_rhs(&self, pool: &ExprPool) -> ODE {
        let mut simplified = self.clone();
        for expr in simplified.rhs.iter_mut() {
            *expr = simplify(*expr, pool).value;
        }
        simplified
    }

    /// Renders the system as text.
    ///
    /// The output has one ` <derivative> = <rhs>` line per equation, followed by
    /// one ` <var>(0) = <value>` line per initial condition. Lines are joined
    /// with `\n` and there is no trailing newline.
    pub fn display(&self, pool: &ExprPool) -> String {
        let equations = self
            .derivatives
            .iter()
            .zip(&self.rhs)
            .map(|(&lhs, &rhs)| format!(" {} = {}", pool.display(lhs), pool.display(rhs)));
        let conditions = self
            .initial_conditions
            .iter()
            .map(|&(var, val)| format!(" {}(0) = {}", pool.display(var), pool.display(val)));
        equations.chain(conditions).collect::<Vec<_>>().join("\n")
    }
}
/// A single scalar ODE of order `order`, used as input to `lower_to_first_order`.
///
/// Derives `Debug` and `Clone` like its sibling `ODE`, so it can be printed in
/// diagnostics and test failures and copied freely.
#[derive(Debug, Clone)]
pub struct ScalarODE {
    /// The dependent variable `x`.
    pub var: ExprId,
    /// NOTE(review): not read by `lower_to_first_order` in this module. It is
    /// presumably meant to hold derivative symbols referenced by `rhs`; confirm
    /// the intended contract.
    pub aux_vars: Vec<ExprId>,
    /// Expression for the highest derivative of `var` (the last equation of the
    /// lowered system).
    pub rhs: ExprId,
    /// The independent (time) variable.
    pub time_var: ExprId,
    /// Differential order; lowering rejects `0`.
    pub order: usize,
}
/// Rewrites an `n`-th order scalar ODE as an `n`-dimensional first-order system.
///
/// The state vector is `[x, x_1, ..., x_{n-1}]`:
/// - `x` is `scalar_ode.var` itself.
/// - `x_i` are real-valued symbols named `<name>_<i>`, where `<name>` is the
///   variable's symbol name (or `x` if it is not a symbol).
///
/// The equations are the chain `x' = x_1`, ..., `x_{n-2}' = x_{n-1}`, closed by
/// `x_{n-1}' = scalar_ode.rhs`. For `n == 1` this is just `x' = rhs`.
///
/// NOTE(review): `scalar_ode.aux_vars` is not consulted. If `rhs` refers to
/// caller-supplied derivative symbols, they will not coincide with the
/// generated `x_i` states; confirm the intended contract.
///
/// # Errors
///
/// - Returns `OdeError::NotFirstOrder` when `order == 0`.
/// - Propagates any error from `ODE::new`.
pub fn lower_to_first_order(scalar_ode: &ScalarODE, pool: &ExprPool) -> Result<ODE, OdeError> {
    let n = scalar_ode.order;
    if n == 0 {
        return Err(OdeError::NotFirstOrder);
    }
    // Base name for the generated derivative states; non-symbol variables fall back to "x".
    let var_name = pool.with(scalar_ode.var, |d| match d {
        ExprData::Symbol { name, .. } => name.clone(),
        _ => "x".to_string(),
    });
    // The first state must be the original variable itself, not a re-created
    // `Domain::Real` symbol with the same name. Otherwise a variable that is not
    // a Real-domain symbol (or not a symbol at all) would get a different
    // ExprId, disconnected from its occurrences in `rhs`. This matches the
    // single-equation case, which always used `var` directly.
    let mut states = Vec::with_capacity(n);
    states.push(scalar_ode.var);
    states.extend((1..n).map(|i| pool.symbol(&format!("{var_name}_{i}"), Domain::Real)));
    // Each state's derivative is the next state; the last one is the scalar RHS.
    let mut rhs_vec: Vec<ExprId> = states[1..].to_vec();
    rhs_vec.push(scalar_ode.rhs);
    ODE::new(states, rhs_vec, scalar_ode.time_var, pool)
}
/// Reports whether `needle` appears anywhere in the expression rooted at `expr`.
///
/// Nodes are compared by `ExprId` equality. The search descends into `Add`,
/// `Mul` and `Func` arguments and into `Pow` base/exponent; every other node
/// kind is treated as a leaf.
fn contains(expr: ExprId, needle: ExprId, pool: &ExprPool) -> bool {
    // Depth-first search over an explicit work stack instead of recursion.
    // The visit order differs from a recursive walk, but the answer is the same.
    let mut pending = vec![expr];
    while let Some(node) = pending.pop() {
        if node == needle {
            return true;
        }
        let children = pool.with(node, |data| match data {
            ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
                args.clone()
            }
            ExprData::Pow { base, exp } => vec![*base, *exp],
            _ => Vec::new(),
        });
        pending.extend(children);
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExprPool;

    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Name of `id` if it is a symbol, `"?"` otherwise.
    fn symbol_name(pool: &ExprPool, id: ExprId) -> String {
        pool.with(id, |d| match d {
            ExprData::Symbol { name, .. } => name.clone(),
            _ => "?".to_string(),
        })
    }

    #[test]
    fn ode_new_simple() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![x], vec![x], t, &pool).unwrap();
        assert_eq!(ode.order(), 1);
        assert_eq!(ode.derivatives.len(), 1);
        assert_eq!(symbol_name(&pool, ode.derivatives[0]), "dx/dt");
        assert!(ode.initial_conditions.is_empty());
        assert!(ode.is_autonomous(&pool));
    }

    #[test]
    fn ode_is_not_autonomous_with_t() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![t, x]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        assert!(!ode.is_autonomous(&pool));
    }

    #[test]
    fn ode_detects_time_in_nested_term() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.add(vec![x, pool.mul(vec![t, x])]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        assert!(!ode.is_autonomous(&pool));
    }

    #[test]
    fn ode_mismatch_error() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let result = ODE::new(vec![x, y], vec![x], t, &pool);
        assert!(matches!(result, Err(OdeError::VariableCountMismatch)));
    }

    #[test]
    fn lower_second_order() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![pool.integer(-1_i32), x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 2,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.order(), 2);
        assert_eq!(symbol_name(&pool, sys.rhs[0]), "x_1");
        assert_eq!(sys.rhs[0], sys.state_vars[1]);
        assert_eq!(sys.rhs[1], rhs);
    }

    #[test]
    fn lower_third_order_builds_chain() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![pool.integer(-1_i32), x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 3,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.order(), 3);
        assert_eq!(symbol_name(&pool, sys.state_vars[1]), "x_1");
        assert_eq!(symbol_name(&pool, sys.state_vars[2]), "x_2");
        assert_eq!(sys.rhs[0], sys.state_vars[1]);
        assert_eq!(sys.rhs[1], sys.state_vars[2]);
        assert_eq!(sys.rhs[2], rhs);
    }

    #[test]
    fn lower_first_order_keeps_variable() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![t, x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 1,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.state_vars, vec![x]);
        assert_eq!(sys.rhs, vec![rhs]);
    }

    #[test]
    fn lower_zero_order_is_rejected() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs: x,
            time_var: t,
            order: 0,
        };
        let result = lower_to_first_order(&scalar, &pool);
        assert!(matches!(result, Err(OdeError::NotFirstOrder)));
    }

    #[test]
    fn ode_display() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![x], vec![x], t, &pool).unwrap();
        let s = ode.display(&pool);
        // The rendered equation must mention the derivative symbol itself.
        let lhs = pool.display(ode.derivatives[0]).to_string();
        assert!(s.contains(&lhs), "got: {s}");
        assert!(!s.contains('\n'), "single equation, no ICs, got: {s}");
    }

    #[test]
    fn ode_with_ic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let one = pool.integer(1_i32);
        let ode = ODE::new(vec![x], vec![x], t, &pool)
            .unwrap()
            .with_ic(x, one);
        assert_eq!(ode.initial_conditions.len(), 1);
        assert_eq!(ode.initial_conditions[0], (x, one));
        let s = ode.display(&pool);
        assert!(s.contains("(0) = "), "got: {s}");
    }

    #[test]
    fn ode_simplify_rhs() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let zero = pool.integer(0_i32);
        let rhs = pool.add(vec![x, zero]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        let simplified = ode.simplify_rhs(&pool);
        assert_eq!(simplified.rhs[0], x);
    }
}