1pub mod sensitivity;
10
11use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
12use crate::simplify::engine::simplify;
13use std::fmt;
14
/// Failure modes reported while building or lowering ODE systems.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OdeError {
    /// The number of state variables differs from the number of right-hand sides.
    VariableCountMismatch,
    /// The equation is not in first-order form.
    NotFirstOrder,
    /// Differentiation failed; the payload carries the underlying message.
    DiffError(String),
}
25
26impl fmt::Display for OdeError {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 OdeError::VariableCountMismatch => write!(f, "variable and RHS count mismatch"),
30 OdeError::NotFirstOrder => write!(f, "ODE is not first-order"),
31 OdeError::DiffError(msg) => write!(f, "differentiation error: {msg}"),
32 }
33 }
34}
35
// Empty impl: every `std::error::Error` method keeps its default implementation.
impl std::error::Error for OdeError {}
37
38impl crate::errors::AlkahestError for OdeError {
39 fn code(&self) -> &'static str {
40 match self {
41 OdeError::VariableCountMismatch => "E-ODE-001",
42 OdeError::NotFirstOrder => "E-ODE-002",
43 OdeError::DiffError(_) => "E-ODE-003",
44 }
45 }
46
47 fn remediation(&self) -> Option<&'static str> {
48 match self {
49 OdeError::VariableCountMismatch => Some(
50 "the number of state variables must equal the number of right-hand-side expressions",
51 ),
52 OdeError::NotFirstOrder => Some(
53 "use lower_to_first_order() to reduce higher-order ODEs to first-order form",
54 ),
55 OdeError::DiffError(_) => Some(
56 "check that all functions in the ODE are differentiable; unknown functions block lowering",
57 ),
58 }
59 }
60}
61
62#[derive(Clone, Debug)]
73pub struct ODE {
74 pub state_vars: Vec<ExprId>,
76 pub derivatives: Vec<ExprId>,
78 pub rhs: Vec<ExprId>,
80 pub time_var: ExprId,
82 pub initial_conditions: Vec<(ExprId, ExprId)>,
84}
85
86impl ODE {
87 pub fn new(
96 state_vars: Vec<ExprId>,
97 rhs: Vec<ExprId>,
98 time_var: ExprId,
99 pool: &ExprPool,
100 ) -> Result<Self, OdeError> {
101 if state_vars.len() != rhs.len() {
102 return Err(OdeError::VariableCountMismatch);
103 }
104 let derivatives: Vec<ExprId> = state_vars
105 .iter()
106 .map(|&v| {
107 let name = pool.with(v, |d| match d {
108 ExprData::Symbol { name, .. } => format!("d{name}/dt"),
109 _ => "d?/dt".to_string(),
110 });
111 pool.symbol(&name, Domain::Real)
112 })
113 .collect();
114 Ok(ODE {
115 state_vars,
116 derivatives,
117 rhs,
118 time_var,
119 initial_conditions: vec![],
120 })
121 }
122
123 pub fn with_ic(mut self, var: ExprId, value: ExprId) -> Self {
125 self.initial_conditions.push((var, value));
126 self
127 }
128
129 pub fn order(&self) -> usize {
131 self.state_vars.len()
132 }
133
134 pub fn is_autonomous(&self, pool: &ExprPool) -> bool {
136 self.rhs
137 .iter()
138 .all(|&rhs| !contains(rhs, self.time_var, pool))
139 }
140
141 pub fn simplify_rhs(&self, pool: &ExprPool) -> ODE {
143 let rhs: Vec<ExprId> = self.rhs.iter().map(|&r| simplify(r, pool).value).collect();
144 ODE {
145 state_vars: self.state_vars.clone(),
146 derivatives: self.derivatives.clone(),
147 rhs,
148 time_var: self.time_var,
149 initial_conditions: self.initial_conditions.clone(),
150 }
151 }
152
153 pub fn display(&self, pool: &ExprPool) -> String {
155 let mut lines: Vec<String> = self
156 .derivatives
157 .iter()
158 .zip(self.rhs.iter())
159 .map(|(&d, &r)| format!(" {} = {}", pool.display(d), pool.display(r)))
160 .collect();
161 for (v, val) in &self.initial_conditions {
162 lines.push(format!(
163 " {}(0) = {}",
164 pool.display(*v),
165 pool.display(*val)
166 ));
167 }
168 lines.join("\n")
169 }
170}
171
172pub struct ScalarODE {
178 pub var: ExprId,
180 pub aux_vars: Vec<ExprId>,
182 pub rhs: ExprId,
184 pub time_var: ExprId,
186 pub order: usize,
188}
189
190pub fn lower_to_first_order(scalar_ode: &ScalarODE, pool: &ExprPool) -> Result<ODE, OdeError> {
203 let n = scalar_ode.order;
204 if n == 0 {
205 return Err(OdeError::NotFirstOrder);
206 }
207 if n == 1 {
208 return ODE::new(
210 vec![scalar_ode.var],
211 vec![scalar_ode.rhs],
212 scalar_ode.time_var,
213 pool,
214 );
215 }
216
217 let var_name = pool.with(scalar_ode.var, |d| match d {
219 ExprData::Symbol { name, .. } => name.clone(),
220 _ => "x".to_string(),
221 });
222 let aux: Vec<ExprId> = (0..n)
223 .map(|i| {
224 let suffix = if i == 0 {
225 var_name.clone()
226 } else {
227 format!("{var_name}_{i}")
228 };
229 pool.symbol(&suffix, Domain::Real)
230 })
231 .collect();
232
233 let mut rhs_vec: Vec<ExprId> = (0..n - 1).map(|i| aux[i + 1]).collect();
235 rhs_vec.push(scalar_ode.rhs);
236
237 ODE::new(aux, rhs_vec, scalar_ode.time_var, pool)
238}
239
240fn contains(expr: ExprId, needle: ExprId, pool: &ExprPool) -> bool {
245 if expr == needle {
246 return true;
247 }
248 let children = pool.with(expr, |data| match data {
249 ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => args.clone(),
250 ExprData::Pow { base, exp } => vec![*base, *exp],
251 _ => vec![],
252 });
253 children.into_iter().any(|c| contains(c, needle, pool))
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExprPool;

    /// Fresh expression pool for each test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn ode_new_simple() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![x], vec![x], t, &pool).unwrap();
        assert_eq!(ode.order(), 1);
        assert!(ode.is_autonomous(&pool));
    }

    #[test]
    fn ode_is_not_autonomous_with_t() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![t, x]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        assert!(!ode.is_autonomous(&pool));
    }

    #[test]
    fn ode_mismatch_error() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let result = ODE::new(vec![x, y], vec![x], t, &pool);
        // Pin the exact variant, not just "some error".
        assert!(matches!(result, Err(OdeError::VariableCountMismatch)));
    }

    #[test]
    fn lower_zero_order_is_rejected() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs: x,
            time_var: t,
            order: 0,
        };
        let result = lower_to_first_order(&scalar, &pool);
        assert!(matches!(result, Err(OdeError::NotFirstOrder)));
    }

    #[test]
    fn lower_first_order_passes_through() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![pool.integer(-1_i32), x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 1,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.state_vars, vec![x]);
        assert_eq!(sys.rhs, vec![rhs]);
    }

    #[test]
    fn lower_second_order() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![pool.integer(-1_i32), x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 2,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.order(), 2);
        // The first equation's right-hand side is the generated auxiliary symbol.
        let first_rhs_name = pool.with(sys.rhs[0], |d| match d {
            ExprData::Symbol { name, .. } => name.clone(),
            _ => "?".to_string(),
        });
        assert_eq!(first_rhs_name, "x_1");
    }

    #[test]
    fn ode_display() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![x], vec![x], t, &pool).unwrap();
        let s = ode.display(&pool);
        assert!(s.contains("dx/dt") || s.contains("d"), "got: {s}");
    }

    #[test]
    fn ode_with_ic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let one = pool.integer(1_i32);
        let ode = ODE::new(vec![x], vec![x], t, &pool)
            .unwrap()
            .with_ic(x, one);
        assert_eq!(ode.initial_conditions.len(), 1);
        assert_eq!(ode.initial_conditions[0], (x, one));
    }

    #[test]
    fn ode_simplify_rhs() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let zero = pool.integer(0_i32);
        let rhs = pool.add(vec![x, zero]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        let simplified = ode.simplify_rhs(&pool);
        assert_eq!(simplified.rhs[0], x);
    }
}