use super::expression::{simplify, SymbolicExpression, Variable};
use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array2, ArrayView1};
use std::collections::HashMap;
/// Shorthand for [`SymbolicExpression::var`]: a symbolic reference to the variable `name`.
#[allow(dead_code)]
fn var<F: IntegrateFloat>(name: &str) -> SymbolicExpression<F> {
    SymbolicExpression::<F>::var(name)
}
/// Shorthand for [`SymbolicExpression::indexedvar`]: a symbolic reference to the `index`-th
/// component of the variable family `name`.
#[allow(dead_code)]
fn indexed_var<F: IntegrateFloat>(name: &str, index: usize) -> SymbolicExpression<F> {
    SymbolicExpression::<F>::indexedvar(name, index)
}
/// Shorthand for [`SymbolicExpression::constant`]: a symbolic literal holding `value`.
#[allow(dead_code)]
fn constant<F: IntegrateFloat>(value: F) -> SymbolicExpression<F> {
    SymbolicExpression::<F>::constant(value)
}
/// Jacobian of an ODE right-hand side held in symbolic form.
///
/// As produced by [`generate_jacobian`], entry `(i, j)` of `elements` is the derivative of the
/// `i`-th right-hand-side expression with respect to `state_vars[j]`; the matrix has one row per
/// expression and one column per state variable.
pub struct SymbolicJacobian<F: IntegrateFloat> {
// Matrix of symbolic derivative expressions.
pub elements: Array2<SymbolicExpression<F>>,
// State variables; `evaluate` binds `state_vars[i]` to the `i`-th component of the state vector.
pub state_vars: Vec<Variable>,
// Optional time variable; when present, `evaluate` binds it to the supplied time `t`.
pub time_var: Option<Variable>,
}
impl<F: IntegrateFloat> SymbolicJacobian<F> {
    /// Bundles a matrix of derivative expressions with the variables it is expressed in.
    ///
    /// No validation is performed: the column count of `elements` is not checked against
    /// `state_vars.len()`.
    pub fn new(
        elements: Array2<SymbolicExpression<F>>,
        state_vars: Vec<Variable>,
        time_var: Option<Variable>,
    ) -> Self {
        Self {
            state_vars,
            time_var,
            elements,
        }
    }

    /// Numerically evaluates every entry of the Jacobian at time `t` and state `y`.
    ///
    /// `y[i]` is bound to `state_vars[i]`. If a time variable is configured, it is bound to `t`.
    /// The time binding is inserted after the states, so it takes precedence on a name clash.
    ///
    /// # Errors
    ///
    /// - Returns `IntegrateError::DimensionMismatch` when `y.len()` differs from the number of
    ///   state variables.
    /// - Propagates the first error raised while evaluating an entry, scanning in row-major order.
    pub fn evaluate(&self, t: F, y: ArrayView1<F>) -> IntegrateResult<Array2<F>> {
        let expected = self.state_vars.len();
        if y.len() != expected {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} states, got {}",
                expected,
                y.len()
            )));
        }

        // Lengths are equal here, so the zip pairs every state variable with its value.
        let mut bindings: HashMap<Variable, F> = self
            .state_vars
            .iter()
            .cloned()
            .zip(y.iter().copied())
            .collect();
        if let Some(t_var) = &self.time_var {
            bindings.insert(t_var.clone(), t);
        }

        // Both iterators walk their arrays in logical (row-major) order, so slots and entries
        // stay aligned.
        let mut numeric = Array2::zeros(self.elements.raw_dim());
        for (slot, entry) in numeric.iter_mut().zip(self.elements.iter()) {
            *slot = entry.evaluate(&bindings)?;
        }
        Ok(numeric)
    }

    /// Replaces every entry in place with the result of the module-level `simplify` function
    /// (imported from `super::expression`).
    pub fn simplify(&mut self) {
        self.elements
            .map_inplace(|entry| *entry = simplify(&*entry));
    }
}
/// Symbolically differentiates each right-hand-side expression with respect to each state
/// variable.
///
/// Entry `(i, j)` of the result is `expressions[i].differentiate(&state_vars[j])`. The matrix
/// has one row per expression and one column per state variable; it is not required to be
/// square. `time_var` is stored unchanged on the returned Jacobian.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` if `expressions` or `state_vars` is empty.
#[allow(dead_code)]
pub fn generate_jacobian<F: IntegrateFloat>(
    expressions: &[SymbolicExpression<F>],
    state_vars: &[Variable],
    time_var: Option<Variable>,
) -> IntegrateResult<SymbolicJacobian<F>> {
    if expressions.is_empty() || state_vars.is_empty() {
        return Err(IntegrateError::ValueError(
            "Empty expressions or state variables".to_string(),
        ));
    }

    let partials = Array2::from_shape_fn((expressions.len(), state_vars.len()), |(row, col)| {
        expressions[row].differentiate(&state_vars[col])
    });

    Ok(SymbolicJacobian::new(partials, state_vars.to_vec(), time_var))
}
/// Incrementally collects an ODE system's right-hand-side expressions and state/time variables,
/// then derives the symbolic Jacobian via `build_jacobian`.
pub struct SymbolicODEBuilder<F: IntegrateFloat> {
// Right-hand-side expressions in the order they were added with `add_equation`.
expressions: Vec<SymbolicExpression<F>>,
// State variables; these become the Jacobian's columns.
state_vars: Vec<Variable>,
// Optional time variable forwarded to the generated Jacobian.
time_var: Option<Variable>,
}
impl<F: IntegrateFloat> SymbolicODEBuilder<F> {
    /// Creates an empty builder with no equations, no state variables and no time variable.
    pub fn new() -> Self {
        Self {
            expressions: vec![],
            state_vars: vec![],
            time_var: None,
        }
    }

    /// Uses `n` indexed state variables named `"y"`, with indices `0..n`.
    ///
    /// Replaces any previously configured state variables.
    pub fn with_state_vars(self, n: usize) -> Self {
        let state_vars = (0..n).map(|k| Variable::indexed("y", k)).collect();
        Self { state_vars, ..self }
    }

    /// Uses one named state variable per entry of `names`, in order.
    ///
    /// Replaces any previously configured state variables.
    pub fn with_named_vars(self, names: Vec<String>) -> Self {
        let state_vars = names.into_iter().map(Variable::new).collect();
        Self { state_vars, ..self }
    }

    /// Declares a time variable named `"t"`, which the Jacobian binds to the evaluation time.
    pub fn with_time(self) -> Self {
        Self {
            time_var: Some(Variable::new("t")),
            ..self
        }
    }

    /// Appends one right-hand-side expression and returns `self` for chaining.
    pub fn add_equation(&mut self, rhs: SymbolicExpression<F>) -> &mut Self {
        self.expressions.push(rhs);
        self
    }

    /// Differentiates the collected equations with respect to the state variables.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `generate_jacobian`, for example when no equations or
    /// no state variables have been configured.
    pub fn build_jacobian(&self) -> IntegrateResult<SymbolicJacobian<F>> {
        let time_var = self.time_var.clone();
        generate_jacobian(
            self.expressions.as_slice(),
            self.state_vars.as_slice(),
            time_var,
        )
    }
}
impl<F: IntegrateFloat> Default for SymbolicODEBuilder<F> {
fn default() -> Self {
Self::new()
}
}
/// Builds the symbolic Jacobian of the Van der Pol oscillator.
///
/// The system is `y0' = y1` and `y1' = mu * (1 - y0^2) * y1 - y0`. The states are the indexed
/// variables `"y"`/0 and `"y"`/1.
///
/// # Errors
///
/// Propagates any error from [`SymbolicODEBuilder::build_jacobian`].
///
/// # Panics
///
/// Panics if the constant `2.0` cannot be converted to `F`.
#[allow(dead_code)]
pub fn example_van_der_pol<F: IntegrateFloat>(mu: F) -> IntegrateResult<SymbolicJacobian<F>> {
    use SymbolicExpression::*;

    let position = Var(Variable::indexed("y", 0));
    let velocity = Var(Variable::indexed("y", 1));

    // (1 - y0^2)
    let exponent = Constant(F::from(2.0).expect("Failed to convert constant to float"));
    let position_squared = Pow(Box::new(position.clone()), Box::new(exponent));
    let stiffness = Sub(Box::new(Constant(F::one())), Box::new(position_squared));

    // mu * (1 - y0^2) * y1 - y0
    let scaled = Mul(Box::new(Constant(mu)), Box::new(stiffness));
    let damping_term = Mul(Box::new(scaled), Box::new(velocity.clone()));
    let acceleration = Sub(Box::new(damping_term), Box::new(position));

    let mut builder = SymbolicODEBuilder::new().with_state_vars(2);
    builder.add_equation(velocity);
    builder.add_equation(acceleration);
    builder.build_jacobian()
}
/// Builds the symbolic Jacobian of Robertson's stiff chemical kinetics problem.
///
/// The system is:
///
/// - `y1' = -0.04 * y1 + 1e4 * y2 * y3`
/// - `y2' = 0.04 * y1 - 1e4 * y2 * y3 - 3e7 * y2^2`
/// - `y3' = 3e7 * y2^2`
///
/// The states are the indexed variables `"y"`/0..2.
///
/// # Errors
///
/// Propagates any error from [`SymbolicODEBuilder::build_jacobian`].
///
/// # Panics
///
/// Panics if a rate constant cannot be converted to `F`.
#[allow(dead_code)]
pub fn example_stiff_chemical<F: IntegrateFloat>() -> IntegrateResult<SymbolicJacobian<F>> {
    let s1 = SymbolicExpression::indexedvar("y", 0);
    let s2 = SymbolicExpression::indexedvar("y", 1);
    let s3 = SymbolicExpression::indexedvar("y", 2);

    let rate = |value: f64| {
        SymbolicExpression::constant(F::from(value).expect("Failed to convert constant to float"))
    };
    let slow = rate(0.04);
    let medium = rate(1e4);
    let fast = rate(3e7);

    let ds1 = -slow.clone() * s1.clone() + medium.clone() * s2.clone() * s3.clone();
    let ds2 = slow * s1 - medium * s2.clone() * s3 - fast.clone() * s2.clone() * s2.clone();
    let ds3 = fast * s2.clone() * s2;

    let mut builder = SymbolicODEBuilder::new().with_state_vars(3);
    builder.add_equation(ds1);
    builder.add_equation(ds2);
    builder.add_equation(ds3);
    builder.build_jacobian()
}
/// Builds the symbolic Jacobian of a predator–prey model whose prey growth is seasonally forced.
///
/// The system is:
///
/// - `x' = 1.5 * x * (1 + 0.1 * sin(2*pi*t)) - 0.5 * x * y`
/// - `y' = -0.75 * y + 0.25 * x * y`
///
/// `x` and `y` are the indexed variables `"y"`/0 and `"y"`/1, and `t` is the builder's time
/// variable.
///
/// # Errors
///
/// Propagates any error from [`SymbolicODEBuilder::build_jacobian`].
///
/// # Panics
///
/// Panics if a coefficient cannot be converted to `F`.
#[allow(dead_code)]
pub fn example_seasonal_predator_prey<F: IntegrateFloat>() -> IntegrateResult<SymbolicJacobian<F>> {
    let prey = indexed_var("y", 0);
    let predator = indexed_var("y", 1);
    let time = var("t");

    let coeff = |value: f64| constant(F::from(value).expect("Failed to convert constant to float"));
    let growth = coeff(1.5);
    let amplitude = coeff(0.1);
    let predation = coeff(0.5);
    let mortality = coeff(0.75);
    let conversion = coeff(0.25);
    let two_pi = constant(F::from(std::f64::consts::TAU).expect("Failed to convert to float"));

    let season =
        constant(F::one()) + amplitude * SymbolicExpression::Sin(Box::new(two_pi * time));
    let prey_rate =
        growth * prey.clone() * season - predation * prey.clone() * predator.clone();
    let predator_rate = -mortality * predator.clone() + conversion * prey * predator;

    let mut builder = SymbolicODEBuilder::new().with_state_vars(2).with_time();
    builder.add_equation(prey_rate);
    builder.add_equation(predator_rate);
    builder.build_jacobian()
}
/// Wraps a symbolic Jacobian in a closure with the `(t, y) -> IntegrateResult<Array2<F>>`
/// shape expected for a numeric Jacobian callback.
///
/// The closure owns its own clone of `symbolic_jacobian` and forwards to
/// [`SymbolicJacobian::evaluate`]. It is only compiled with the `autodiff` feature.
///
/// NOTE(review): the `Func` type parameter is not used by the body. Because it does not appear
/// in the argument list, callers must name it explicitly (for example via turbofish).
#[cfg(feature = "autodiff")]
#[allow(dead_code)]
pub fn create_autodiff_jacobian<F, Func>(
    symbolic_jacobian: &SymbolicJacobian<F>,
) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array2<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> IntegrateResult<ArrayView1<F>>,
{
    let owned = SymbolicJacobian::clone(symbolic_jacobian);
    move |t: F, y: ArrayView1<F>| owned.evaluate(t, y)
}
impl<F: IntegrateFloat> Clone for SymbolicJacobian<F> {
    /// Clones the expression matrix, the state-variable list and the optional time variable.
    fn clone(&self) -> Self {
        let Self {
            elements,
            state_vars,
            time_var,
        } = self;
        Self::new(elements.clone(), state_vars.clone(), time_var.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        generate_jacobian,
        SymbolicExpression::{Neg, Var},
        Variable,
    };
    use scirs2_core::ndarray::ArrayView1;

    /// d(-y)/dy is the constant -1, independent of the evaluation point.
    #[test]
    fn test_simple_jacobian() {
        let y = Var(Variable::new("y"));
        let expr = Neg(Box::new(y));
        let jacobian =
            generate_jacobian(&[expr], &[Variable::new("y")], None).expect("Operation failed");
        let j = jacobian
            .evaluate(0.0, ArrayView1::from(&[1.0]))
            .expect("Operation failed");
        assert_eq!(j.dim(), (1, 1));
        assert!((j[[0, 0]] + 1.0_f64).abs() < 1e-10);
    }

    /// A state vector whose length differs from the number of state variables is rejected.
    #[test]
    fn test_evaluate_rejects_wrong_state_count() {
        let expr = Neg(Box::new(Var(Variable::new("y"))));
        let jacobian =
            generate_jacobian(&[expr], &[Variable::new("y")], None).expect("Operation failed");
        assert!(jacobian
            .evaluate(0.0_f64, ArrayView1::from(&[1.0_f64, 2.0]))
            .is_err());
    }

    /// Empty expression lists and empty state-variable lists are both errors.
    #[test]
    fn test_generate_jacobian_rejects_empty_input() {
        let expr = Var(Variable::new("y"));
        assert!(generate_jacobian::<f64>(&[], &[Variable::new("y")], None).is_err());
        assert!(generate_jacobian::<f64>(&[expr], &[], None).is_err());
    }

    /// J = [[0, 1], [-2*mu*y0*y1 - 1, mu*(1 - y0^2)]]; at mu = 1, y = (2, 3) this is
    /// [[0, 1], [-13, -3]].
    #[test]
    fn test_van_der_pol_jacobian_values() {
        let jacobian = example_van_der_pol(1.0_f64).expect("Operation failed");
        let j = jacobian
            .evaluate(0.0, ArrayView1::from(&[2.0, 3.0]))
            .expect("Operation failed");
        assert_eq!(j.dim(), (2, 2));
        let expected = [[0.0, 1.0], [-13.0, -3.0]];
        for (row, values) in expected.iter().enumerate() {
            for (col, &want) in values.iter().enumerate() {
                let got = j[[row, col]];
                assert!(
                    (got - want).abs() < 1e-9,
                    "entry ({row}, {col}) = {got}, expected {want}"
                );
            }
        }
    }
}