use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use crate::symbolic::calculus::differentiate;
use crate::symbolic::calculus::integrate;
use crate::symbolic::calculus::substitute;
use crate::symbolic::calculus::substitute_expr;
use crate::symbolic::core::Expr;
use crate::symbolic::polynomial::contains_var;
use crate::symbolic::simplify::is_zero;
use crate::symbolic::simplify::pattern_match;
use crate::symbolic::simplify_dag::simplify;
use crate::symbolic::solve::solve;
use crate::symbolic::solve::solve_linear_system;
use crate::symbolic::transforms;
/// Result of [`parse_ode`]: an ODE expression (read as `expr = 0`) split with respect to one
/// unknown function.
pub struct ParsedODE {
    /// Largest derivative order that received a coefficient (0 when no term was recognised).
    pub order: u32,
    /// Coefficient per derivative order: key `k` holds the summed coefficient of the k-th
    /// derivative of the unknown function (key 0 is the function itself).
    pub coeffs: HashMap<u32, Expr>,
    /// Sum of every term that was not recognised as `coefficient * derivative`.
    pub remaining_expr: Expr,
}
/// Decomposes an ODE expression (interpreted as `equation = 0`) into the coefficients of
/// `func` and its derivatives with respect to `var`.
///
/// The equation is simplified first and then split into additive terms. Every term that
/// [`get_term_order_and_coeff`] classifies with an order `<= 100` (i.e. as
/// `coefficient * d^k func / d var^k`) adds its coefficient to `coeffs[k]`. Every other term
/// is added to `remaining_expr`.
///
/// Additive structure understood while splitting:
/// - `Add` and `AddList` (as produced by `simplify`).
/// - `Sub`.
/// - `Neg`, whose sign is distributed over its operand.
/// - `Eq`, read as `lhs - rhs`.
///
/// `Dag` nodes are expanded before inspection.
pub(crate) fn parse_ode(
    equation: &Expr,
    func: &str,
    var: &str,
) -> ParsedODE {
    // Adds (or, when `negate` is set, subtracts) `term` into `slot` and simplifies.
    fn accumulate(
        slot: &mut Expr,
        term: Expr,
        negate: bool,
    ) {
        let combined = if negate {
            Expr::new_sub(slot.clone(), term)
        } else {
            Expr::new_add(slot.clone(), term)
        };
        *slot = simplify(&combined);
    }

    // Walks the additive structure of `expr`, routing each leaf term either to `coeffs`
    // or to `remaining`. `negate` records whether the current sub-expression carries an
    // overall minus sign.
    fn collect_terms(
        expr: &Expr,
        negate: bool,
        func: &str,
        var: &str,
        coeffs: &mut HashMap<u32, Expr>,
        remaining: &mut Expr,
    ) {
        match expr {
            | Expr::Dag(node) => {
                collect_terms(
                    &node.to_expr().expect("Collect Terms"),
                    negate,
                    func,
                    var,
                    coeffs,
                    remaining,
                );
            },
            | Expr::Add(a, b) => {
                collect_terms(a, negate, func, var, coeffs, remaining);
                collect_terms(b, negate, func, var, coeffs, remaining);
            },
            | Expr::AddList(items) => {
                for item in items {
                    collect_terms(item, negate, func, var, coeffs, remaining);
                }
            },
            // `a - b` and `a = b` both contribute `a` and `-b` to the zero form.
            | Expr::Sub(a, b) | Expr::Eq(a, b) => {
                collect_terms(a, negate, func, var, coeffs, remaining);
                collect_terms(b, !negate, func, var, coeffs, remaining);
            },
            | Expr::Neg(inner) => {
                collect_terms(inner, !negate, func, var, coeffs, remaining);
            },
            | _ => {
                let (order, coeff) = get_term_order_and_coeff(expr, func, var);
                // Orders above 100 are the "not a linear derivative term" sentinel.
                if order > 100 {
                    accumulate(remaining, expr.clone(), negate);
                } else {
                    let slot = coeffs
                        .entry(order)
                        .or_insert_with(|| Expr::Constant(0.0));
                    accumulate(slot, coeff, negate);
                }
            },
        }
    }

    let mut coeffs = HashMap::new();
    let mut remaining_expr = Expr::Constant(0.0);
    collect_terms(
        &simplify(equation),
        false,
        func,
        var,
        &mut coeffs,
        &mut remaining_expr,
    );
    let order = coeffs.keys().max().copied().unwrap_or(0);
    ParsedODE {
        order,
        coeffs,
        remaining_expr,
    }
}
pub(crate) fn get_term_order_and_coeff(
expr: &Expr,
func: &str,
var: &str,
) -> (u32, Expr) {
match expr {
| Expr::Dag(node) => {
get_term_order_and_coeff(
&node.to_expr().expect(
"Get Order and \
Coeff",
),
func,
var,
)
},
| Expr::Derivative(inner, d_var) if d_var == var => {
let (order, coeff) = get_term_order_and_coeff(inner, func, var);
(order + 1, coeff)
},
| Expr::Mul(a, b) => {
let (order_a, _) = get_term_order_and_coeff(a, func, var);
let (order_b, _) = get_term_order_and_coeff(b, func, var);
if order_a > 100 && order_b <= 100 {
let (order, coeff) = get_term_order_and_coeff(b, func, var);
(order, simplify(&Expr::new_mul(coeff, a.as_ref().clone())))
} else if order_b > 100 && order_a <= 100 {
let (order, coeff) = get_term_order_and_coeff(a, func, var);
(order, simplify(&Expr::new_mul(coeff, b.as_ref().clone())))
} else {
(999, expr.clone())
}
},
| Expr::Variable(v) if v == func => (0, Expr::Constant(1.0)),
| _ => (999, expr.clone()),
}
}
/// Collects the names of integration constants occurring in `expr`.
///
/// A variable counts as a constant when its name is `C` followed only by ASCII digits
/// (`C`, `C1`, `C2`, ...). Names are appended to `constants` in the order they are first
/// met (pre-order, left operand before right), without duplicates. Only `Dag`, the binary
/// nodes `Add`/`Sub`/`Mul`/`Div`/`Power` and the unary nodes
/// `Sin`/`Cos`/`Tan`/`Exp`/`Log`/`Neg` are searched.
pub(crate) fn find_constants(
    expr: &Expr,
    constants: &mut Vec<String>,
) {
    match expr {
        | Expr::Variable(name) => {
            let is_constant_name = matches!(
                name.strip_prefix('C'),
                Some(digits) if digits.chars().all(|ch| ch.is_ascii_digit())
            );
            if is_constant_name && !constants.contains(name) {
                constants.push(name.clone());
            }
        },
        | Expr::Dag(node) => {
            let expanded = node.to_expr().expect("Found Constants");
            find_constants(&expanded, constants);
        },
        | Expr::Add(lhs, rhs)
        | Expr::Sub(lhs, rhs)
        | Expr::Mul(lhs, rhs)
        | Expr::Div(lhs, rhs)
        | Expr::Power(lhs, rhs) => {
            find_constants(lhs, constants);
            find_constants(rhs, constants);
        },
        | Expr::Sin(arg)
        | Expr::Cos(arg)
        | Expr::Tan(arg)
        | Expr::Exp(arg)
        | Expr::Log(arg)
        | Expr::Neg(arg) => find_constants(arg, constants),
        | _ => {},
    }
}
/// Records, for each function symbol differentiated with respect to `var` in `expr`, the
/// highest derivative order found.
///
/// A chain `Derivative(Derivative(f, var), var)` counts as order 2. A chain is recorded only
/// when its innermost operand is a plain `Variable`. `Derivative` nodes are never searched
/// further inside. Otherwise the search descends through `Dag`, `Add`/`Sub`/`Mul`/`Div`/`Power`
/// and `Sin`/`Cos`/`Tan`/`Exp`/`Log`/`Neg`.
pub(crate) fn find_derivatives(
    expr: &Expr,
    var: &str,
    derivatives: &mut HashMap<String, u32>,
) {
    match expr {
        | Expr::Derivative(first, outer_var) if outer_var == var => {
            // Peel off consecutive derivatives taken with respect to `var`.
            let mut depth: u32 = 1;
            let mut base: &Expr = first;
            while let Expr::Derivative(next, next_var) = base {
                if next_var != var {
                    break;
                }
                depth += 1;
                base = next;
            }
            if let Expr::Variable(name) = base {
                let best = derivatives.entry(name.clone()).or_insert(0);
                *best = (*best).max(depth);
            }
        },
        | Expr::Dag(node) => {
            let expanded = node.to_expr().expect("Found Derivatives");
            find_derivatives(&expanded, var, derivatives);
        },
        | Expr::Add(lhs, rhs)
        | Expr::Sub(lhs, rhs)
        | Expr::Mul(lhs, rhs)
        | Expr::Div(lhs, rhs)
        | Expr::Power(lhs, rhs) => {
            find_derivatives(lhs, var, derivatives);
            find_derivatives(rhs, var, derivatives);
        },
        | Expr::Sin(arg)
        | Expr::Cos(arg)
        | Expr::Tan(arg)
        | Expr::Exp(arg)
        | Expr::Log(arg)
        | Expr::Neg(arg) => find_derivatives(arg, var, derivatives),
        | _ => {},
    }
}
/// Solves a single ODE for `func` as a function of `var`.
///
/// The equation is handed to [`solve_ode_system`] as a one-equation system.
///
/// # Returns
/// - On success without initial conditions: `Eq(func, solution)`.
/// - When no solution is found: the unevaluated `Solve(ode, func)`.
/// - When `initial_conditions` are supplied and a solution was found: the result of
///   [`apply_initial_conditions`] on the solution's right-hand side. This is a bare
///   expression, not an `Eq`.
#[must_use]
pub fn solve_ode(
    ode: &Expr,
    func: &str,
    var: &str,
    initial_conditions: Option<&[(Expr, u32, Expr)]>,
) -> Expr {
    let solved = solve_ode_system(std::slice::from_ref(ode), &[func], var)
        .and_then(|mut all| all.pop());
    let general_solution_eq = match solved {
        | Some(sol) => Expr::Eq(Arc::new(Expr::Variable(func.to_string())), Arc::new(sol)),
        | None => Expr::Solve(Arc::new(ode.clone()), func.to_string()),
    };
    if let (Some(conditions), Expr::Eq(_, rhs)) = (initial_conditions, &general_solution_eq) {
        return apply_initial_conditions(rhs, var, conditions);
    }
    general_solution_eq
}
/// Solves a system of ODEs in `var` for the functions named in `funcs`.
///
/// The equations are first reduced to a first-order system by
/// [`reduce_to_first_order_system`]. That system is then solved one unknown at a time by
/// [`solve_first_order_system_sequentially`].
///
/// # Returns
/// The solutions in the same order as `funcs`, or `None` if the reduction fails or any
/// function remains unsolved.
#[must_use]
pub fn solve_ode_system(
    equations: &[Expr],
    funcs: &[&str],
    var: &str,
) -> Option<Vec<Expr>> {
    let (first_order_eqs, all_vars, original_funcs_map) =
        reduce_to_first_order_system(equations, funcs, var).ok()?;
    let var_names: Vec<&str> = all_vars.iter().map(String::as_str).collect();
    let solutions_map =
        solve_first_order_system_sequentially(&first_order_eqs, &var_names, var)?;
    // Collecting into Option<Vec<_>> stops at the first function without a solution.
    funcs
        .iter()
        .map(|&name| {
            let solved_as = original_funcs_map.get(name)?;
            solutions_map.get(solved_as).cloned()
        })
        .collect()
}
pub(crate) fn try_all_solvers(
equation: &Expr,
func: &str,
var: &str,
) -> Option<Expr> {
let eq = if let Expr::Eq(l, r) = equation {
simplify(&Expr::new_sub(l.clone(), r.clone()))
} else {
equation.clone()
};
if let Some(sol) = solve_first_order_linear_ode(&eq, func, var) {
return Some(sol);
}
if let Some(sol) = solve_separable_ode(&eq, func, var) {
return Some(sol);
}
solve_bernoulli_ode(&eq, func, var)
.or_else(|| solve_cauchy_euler_ode(&eq, func, var))
.or_else(|| solve_exact_ode(&eq, func, var))
}
/// Fixes the integration constants of `general_solution` from initial conditions.
///
/// Each condition `(x0, k, value)` yields the equation `d^k/dvar^k solution |_{var=x0} = value`.
/// The constants (`C`, `C1`, ... as found by [`find_constants`], sorted by name) are then
/// solved for with `solve_linear_system` and substituted back.
///
/// # Returns
/// - The solution unchanged when it has no constants or no conditions were given.
/// - The marker variable `"Not enough initial conditions"` when there are fewer conditions
///   than constants.
/// - The marker variable `"Could not solve for constants"` when the linear solve fails.
pub(crate) fn apply_initial_conditions(
    general_solution: &Expr,
    var: &str,
    conditions: &[(Expr, u32, Expr)],
) -> Expr {
    let mut constants = Vec::new();
    find_constants(general_solution, &mut constants);
    constants.sort();
    if constants.is_empty() || conditions.is_empty() {
        return general_solution.clone();
    }
    let eq_system: Vec<Expr> = conditions
        .iter()
        .map(|(x0, order, y_val)| {
            let derived = (0..*order).fold(general_solution.clone(), |acc, _| {
                differentiate(&acc, var)
            });
            let at_point = substitute(&derived, var, x0);
            simplify(&Expr::Eq(Arc::new(at_point), Arc::new(y_val.clone())))
        })
        .collect();
    if eq_system.len() < constants.len() {
        return Expr::Variable("Not enough initial conditions".to_string());
    }
    match solve_linear_system(&Expr::System(eq_system), &constants) {
        | Ok(values) => {
            // `zip` stops at the shorter side, so surplus constants stay symbolic.
            let specialised = constants
                .iter()
                .zip(values.iter())
                .fold(general_solution.clone(), |acc, (name, value)| {
                    substitute(&acc, name, value)
                });
            simplify(&specialised)
        },
        | Err(_) => Expr::Variable("Could not solve for constants".to_string()),
    }
}
type ReducedFirstOrderSystem = (Vec<Expr>, Vec<String>, HashMap<String, String>);
pub(crate) fn reduce_to_first_order_system(
equations: &[Expr],
funcs: &[&str],
var: &str,
) -> Result<ReducedFirstOrderSystem, String> {
let mut new_eqs = Vec::new();
let mut new_vars_map: HashMap<(String, u32), String> = HashMap::new();
let mut all_new_vars = funcs
.iter()
.map(|s| (*s).to_string())
.collect::<HashSet<_>>();
let mut original_funcs_map = HashMap::new();
for &func in funcs {
new_vars_map.insert((func.to_string(), 0), func.to_string());
original_funcs_map.insert(func.to_string(), func.to_string());
}
let mut temp_eqs = equations.to_vec();
let mut i = 0;
while i < temp_eqs.len() {
let eq = &temp_eqs[i];
let mut derivatives = HashMap::new();
find_derivatives(eq, var, &mut derivatives);
let mut eq_with_substitutions = eq.clone();
for (func, &order) in &derivatives {
if order > 1 {
for k in 1..order {
let key = (func.clone(), k);
if !new_vars_map.contains_key(&key) {
let new_var_name = format!("{func}_d{k}");
all_new_vars.insert(new_var_name.clone());
new_vars_map.insert(key.clone(), new_var_name.clone());
let prev_var_name = match new_vars_map.get(&(func.clone(), k - 1)) {
| Some(name) => name,
| None => {
return Err(
"Logic error: previous derivative not found in map".to_string()
);
},
};
let new_eq = Expr::Eq(
Arc::new(Expr::Derivative(
Arc::new(Expr::Variable(prev_var_name.clone())),
var.to_string(),
)),
Arc::new(Expr::Variable(new_var_name.clone())),
);
temp_eqs.push(new_eq);
}
}
let highest_deriv = (0..order).fold(Expr::Variable(func.clone()), |e, _| {
Expr::Derivative(Arc::new(e), var.to_string())
});
let replacement_var_name = match new_vars_map.get(&(func.clone(), order - 1)) {
| Some(name) => name,
| None => {
return Err("Logic error: highest derivative not found in map".to_string());
},
};
let replacement_expr = Expr::Derivative(
Arc::new(Expr::Variable(replacement_var_name.clone())),
var.to_string(),
);
eq_with_substitutions = substitute(
&eq_with_substitutions,
&highest_deriv.to_string(),
&replacement_expr,
);
}
}
new_eqs.push(eq_with_substitutions);
i += 1;
}
Ok((
new_eqs,
all_new_vars.into_iter().collect(),
original_funcs_map,
))
}
pub(crate) fn solve_first_order_system_sequentially(
equations: &[Expr],
funcs: &[&str],
var: &str,
) -> Option<HashMap<String, Expr>> {
let mut remaining_eqs: Vec<Expr> = equations.to_vec();
let mut solutions: HashMap<String, Expr> = HashMap::new();
let mut progress = true;
while progress && !remaining_eqs.is_empty() {
progress = false;
let mut solved_eq_indices = Vec::new();
for (i, eq) in remaining_eqs.iter().enumerate() {
let mut current_eq = eq.clone();
for (solved_func, solution_expr) in &solutions {
current_eq = substitute(¤t_eq, solved_func, solution_expr);
}
let mut remaining_funcs = Vec::new();
for &f in funcs {
if !solutions.contains_key(f) {
let mut found = false;
current_eq.pre_order_walk(&mut |e| {
if let Expr::Variable(v) = e {
if v == f {
found = true;
}
}
if let Expr::Derivative(inner, _) = e {
if let Expr::Variable(v) = &**inner {
if v == f {
found = true;
}
}
}
});
if found {
remaining_funcs.push(f);
}
}
}
if remaining_funcs.len() == 1 {
let func_to_solve = remaining_funcs[0];
let solution_eq = try_all_solvers(¤t_eq, func_to_solve, var)?;
let solution_eq = if let Expr::Dag(node) = solution_eq {
node.to_expr().expect("Unwrap solution DAG")
} else {
solution_eq
};
if let Expr::Eq(_, solution_expr) = solution_eq {
solutions.insert(func_to_solve.to_string(), solution_expr.as_ref().clone());
solved_eq_indices.push(i);
progress = true;
break;
}
}
}
for &i in solved_eq_indices.iter().rev() {
remaining_eqs.remove(i);
}
}
if solutions.len() == funcs.len() {
Some(solutions)
} else {
None
}
}
/// Splits `expr` into `(x_part, y_part)` with `expr == x_part * y_part`. Here `x_part` is
/// free of `func` and `y_part` is free of `var`.
///
/// Sub-expressions free of `func` go entirely into `x_part`; those free of `var` go entirely
/// into `y_part`. Otherwise only products, quotients, negations and `Dag` nodes are split
/// recursively. Any other mixed expression cannot be separated.
///
/// # Returns
/// `Some((x_part, y_part))` on success, `None` if `expr` is not separable this way.
pub(crate) fn separate_factors(
    expr: &Expr,
    func: &str,
    var: &str,
) -> Option<(Expr, Expr)> {
    if !contains_var(expr, func) {
        return Some((expr.clone(), Expr::Constant(1.0)));
    }
    if !contains_var(expr, var) {
        return Some((Expr::Constant(1.0), expr.clone()));
    }
    match expr {
        | Expr::Dag(node) => {
            let expanded = node.to_expr().expect("Separate Factors");
            separate_factors(&expanded, func, var)
        },
        | Expr::Neg(inner) => {
            separate_factors(inner, func, var)
                .map(|(x_part, y_part)| (simplify(&Expr::new_neg(x_part)), y_part))
        },
        | Expr::Mul(left, right) => {
            let (left_x, left_y) = separate_factors(left, func, var)?;
            let (right_x, right_y) = separate_factors(right, func, var)?;
            let x_part = simplify(&Expr::new_mul(left_x, right_x));
            let y_part = simplify(&Expr::new_mul(left_y, right_y));
            Some((x_part, y_part))
        },
        | Expr::Div(numer, denom) => {
            let (numer_x, numer_y) = separate_factors(numer, func, var)?;
            let (denom_x, denom_y) = separate_factors(denom, func, var)?;
            let x_part = simplify(&Expr::new_div(numer_x, denom_x));
            let y_part = simplify(&Expr::new_div(numer_y, denom_y));
            Some((x_part, y_part))
        },
        | _ => None,
    }
}
#[must_use]
pub fn solve_separable_ode(
equation: &Expr,
func: &str,
var: &str,
) -> Option<Expr> {
if let Expr::Dag(node) = equation {
return solve_separable_ode(&node.to_expr().expect("Unwrap DAG"), func, var);
}
let y_prime = Expr::Derivative(Arc::new(Expr::Variable(func.to_string())), var.to_string());
let (f_y, g_x) = match equation {
| Expr::Sub(lhs, rhs) if **lhs == y_prime => (Expr::Constant(1.0), rhs.as_ref().clone()),
| Expr::Sub(lhs, rhs) => {
if let Expr::Mul(a, b) = lhs.as_ref() {
if **b == y_prime {
(a.as_ref().clone(), rhs.as_ref().clone())
} else if **a == y_prime {
(b.as_ref().clone(), rhs.as_ref().clone())
} else {
return None;
}
} else {
return None;
}
},
| Expr::Add(lhs, rhs) if **lhs == y_prime => {
if let Expr::Neg(inner) = rhs.as_ref() {
(Expr::Constant(1.0), inner.as_ref().clone())
} else if let Expr::Mul(a, b) = rhs.as_ref() {
if let Expr::Constant(c) = a.as_ref() {
if (*c + 1.0).abs() < f64::EPSILON {
(Expr::Constant(1.0), b.as_ref().clone())
} else {
return None;
}
} else if let Expr::Constant(c) = b.as_ref() {
if (*c + 1.0).abs() < f64::EPSILON {
(Expr::Constant(1.0), a.as_ref().clone())
} else {
return None;
}
} else {
return None;
}
} else {
return None;
}
},
| _ => {
return None;
},
};
if let Some((g_x_sep, h_y_sep)) = separate_factors(&g_x, func, var) {
let new_f_y = simplify(&Expr::new_div(f_y, h_y_sep));
let new_g_x = g_x_sep;
if !new_g_x.to_string().contains(func) && !new_f_y.to_string().contains(var) {
let int_f_y = integrate(&new_f_y, func, None, None);
let int_g_x = integrate(&new_g_x, var, None, None);
let c = Expr::Variable("C".to_string());
return Some(simplify(&Expr::Eq(
Arc::new(int_f_y),
Arc::new(Expr::new_add(int_g_x, c)),
)));
}
}
None
}
#[must_use]
pub fn solve_first_order_linear_ode(
equation: &Expr,
func: &str,
var: &str,
) -> Option<Expr> {
if let Expr::Dag(node) = equation {
return solve_first_order_linear_ode(&node.to_expr().expect("Unwrap DAG"), func, var);
}
let parsed = parse_ode(equation, func, var);
if parsed.order != 1 {
return None;
}
let p_x = parsed.coeffs.get(&0).cloned()?;
let r_x = parsed.remaining_expr;
let q_x = simplify(&Expr::new_neg(r_x));
let y_expr = Expr::Variable(func.to_string());
let mu = Expr::new_exp(integrate(&p_x, var, None, None));
let rhs = integrate(&simplify(&Expr::new_mul(q_x, mu.clone())), var, None, None);
let c = Expr::Variable("C1".to_string());
let solution = simplify(&Expr::new_div(Expr::new_add(rhs, c), mu));
Some(Expr::Eq(Arc::new(y_expr), Arc::new(solution)))
}
/// Solves the Bernoulli equation `y' + P(x) y - Q(x) y^n = 0` for `n` other than 0 and 1.
///
/// The equation is matched structurally against `y' + (P*y - Q*y^n)` with `pattern_match`.
/// The substitution `v = y^(1-n)` gives the linear equation `v' + (1-n) P v - (1-n) Q = 0`.
/// That equation is solved by [`solve_first_order_linear_ode`] and mapped back through
/// `y = v^(1/(1-n))`.
///
/// # Returns
/// `Some(Eq(func, solution))`. Returns `None` when:
/// - the pattern does not match;
/// - `n` is not numeric, or is 0 or 1;
/// - the linear equation in `v` cannot be solved.
#[must_use]
pub fn solve_bernoulli_ode(
    equation: &Expr,
    func: &str,
    var: &str,
) -> Option<Expr> {
    if let Expr::Dag(node) = equation {
        return solve_bernoulli_ode(&node.to_expr().expect("Unwrap DAG"), func, var);
    }
    let y = Expr::Variable(func.to_string());
    let pattern = Expr::new_add(
        Expr::Derivative(Arc::new(y.clone()), var.to_string()),
        Expr::new_sub(
            Expr::new_mul(Expr::Pattern("P".to_string()), y.clone()),
            Expr::new_mul(
                Expr::Pattern("Q".to_string()),
                Expr::new_pow(y.clone(), Expr::Pattern("n".to_string())),
            ),
        ),
    );
    let bindings = pattern_match(equation, &pattern)?;
    let p_x = bindings.get("P")?;
    let q_x = bindings.get("Q")?;
    let n = bindings.get("n")?.to_f64()?;
    // n == 1 or n == 0 make the equation linear; no substitution is needed.
    let degenerate = (n - 1.0).abs() < f64::EPSILON || n.abs() < f64::EPSILON;
    if degenerate {
        return None;
    }
    let one_minus_n = 1.0 - n;
    let p_v = simplify(&Expr::new_mul(Expr::Constant(one_minus_n), p_x.clone()));
    let q_v = simplify(&Expr::new_mul(Expr::Constant(one_minus_n), q_x.clone()));
    let v = Expr::Variable("v".to_string());
    let linear_ode_v = Expr::new_add(
        Expr::Derivative(Arc::new(v.clone()), var.to_string()),
        Expr::new_sub(Expr::new_mul(p_v, v), q_v),
    );
    let v_solution = match solve_first_order_linear_ode(&linear_ode_v, "v", var)? {
        | Expr::Eq(_, rhs) => rhs,
        | _ => return None,
    };
    let y_solution = simplify(&Expr::new_pow(
        v_solution.as_ref().clone(),
        Expr::Constant(1.0 / one_minus_n),
    ));
    Some(Expr::Eq(Arc::new(y), Arc::new(y_solution)))
}
#[must_use]
pub fn solve_riccati_ode(
equation: &Expr,
func: &str,
var: &str,
y1: &Expr,
) -> Option<Expr> {
fn collect_add_terms(
expr: &Expr,
terms: &mut Vec<Expr>,
) {
match expr {
| Expr::Add(a, b) => {
collect_add_terms(a, terms);
collect_add_terms(b, terms);
},
| Expr::AddList(list) => {
for item in list {
collect_add_terms(item, terms);
}
},
| Expr::Dag(node) => {
collect_add_terms(
&node.to_expr().expect(
"Collect add \
terms",
),
terms,
);
},
| _ => {
terms.push(expr.clone());
},
}
}
if let Expr::Dag(node) = equation {
return solve_riccati_ode(&node.to_expr().expect("Unwrap DAG"), func, var, y1);
}
let eq = if let Expr::Eq(l, r) = equation {
simplify(&Expr::new_sub(l.clone(), r.clone()))
} else {
equation.clone()
};
let y_prime = Expr::Derivative(Arc::new(Expr::Variable(func.to_string())), var.to_string());
let eq_y_prime_0 = substitute_expr(&eq, &y_prime, &Expr::Constant(0.0));
let eq_y_prime_1 = substitute_expr(&eq, &y_prime, &Expr::Constant(1.0));
let a_x = simplify(&Expr::new_sub(eq_y_prime_1, eq_y_prime_0.clone()));
if contains_var(&a_x, func) {
return None;
}
let rhs_poly = simplify(&Expr::new_neg(Expr::new_div(eq_y_prime_0, a_x)));
let rhs_poly = if let Expr::Dag(node) = &rhs_poly {
node.to_expr().expect(
"Unwrap rhs_poly \
DAG",
)
} else {
rhs_poly
};
let mut coeffs = std::collections::HashMap::new();
let mut terms = Vec::new();
collect_add_terms(&rhs_poly, &mut terms);
for term in &terms {
let term_str = term.to_string();
if !term_str.contains('y') {
let entry = coeffs.entry(0).or_insert(Expr::Constant(0.0));
*entry = simplify(&Expr::new_add(entry.clone(), term.clone()));
} else if let Expr::Power(base, exp) = term {
if let (Expr::Variable(v), Some(n)) = (base.as_ref(), exp.to_f64()) {
if v == "y" {
let entry = coeffs
.entry((n as i64).try_into().unwrap_or(0))
.or_insert(Expr::Constant(0.0));
*entry = simplify(&Expr::new_add(entry.clone(), Expr::Constant(1.0)));
}
}
} else if let Expr::Variable(v) = term {
if v == "y" {
let entry = coeffs.entry(1).or_insert(Expr::Constant(0.0));
*entry = simplify(&Expr::new_add(entry.clone(), Expr::Constant(1.0)));
}
} else if let Expr::Mul(a, b) = term {
let (coeff, deg) = if let Expr::Power(base, exp) = b.as_ref() {
if let (Expr::Variable(v), Some(n)) = (base.as_ref(), exp.to_f64()) {
if v == "y" {
(a.as_ref().clone(), (n as i64).try_into().unwrap_or(0))
} else {
continue;
}
} else {
continue;
}
} else if let Expr::Variable(v) = b.as_ref() {
if v == "y" {
(a.as_ref().clone(), 1)
} else {
continue;
}
} else if let Expr::Power(base, exp) = a.as_ref() {
if let (Expr::Variable(v), Some(n)) = (base.as_ref(), exp.to_f64()) {
if v == "y" {
(b.as_ref().clone(), (n as i64).try_into().unwrap_or(0))
} else {
continue;
}
} else {
continue;
}
} else if let Expr::Variable(v) = a.as_ref() {
if v == "y" {
(b.as_ref().clone(), 1)
} else {
continue;
}
} else {
continue;
};
let entry = coeffs.entry(deg).or_insert(Expr::Constant(0.0));
*entry = simplify(&Expr::new_add(entry.clone(), coeff));
}
}
let max_deg = coeffs.keys().max().copied().unwrap_or(0);
if max_deg != 2 {
return None;
}
let _p = coeffs.get(&0).cloned().unwrap_or(Expr::Constant(0.0));
let q = coeffs.get(&1).cloned().unwrap_or(Expr::Constant(0.0));
let r = coeffs.get(&2).cloned().unwrap_or(Expr::Constant(0.0));
let v_var = "v";
let v = Expr::Variable(v_var.to_string());
let v_prime = Expr::Derivative(Arc::new(v.clone()), var.to_string());
let p_v = simplify(&Expr::new_add(
q,
Expr::new_mul(Expr::Constant(2.0), Expr::new_mul(r.clone(), y1.clone())),
));
let q_v = simplify(&Expr::new_neg(r));
let linear_ode = Expr::new_sub(Expr::new_add(v_prime, Expr::new_mul(p_v, v)), q_v);
let v_sol_eq = solve_first_order_linear_ode(&linear_ode, v_var, var)?;
let v_sol = if let Expr::Eq(_, sol) = v_sol_eq {
sol
} else {
return None;
};
let y_sol = simplify(&Expr::new_add(
y1.clone(),
Expr::new_div(Expr::Constant(1.0), v_sol),
));
Some(Expr::Eq(
Arc::new(Expr::Variable(func.to_string())),
Arc::new(y_sol),
))
}
#[must_use]
pub fn solve_cauchy_euler_ode(
equation: &Expr,
func: &str,
var: &str,
) -> Option<Expr> {
if let Expr::Dag(node) = equation {
return solve_cauchy_euler_ode(&node.to_expr().expect("Unwrap DAG"), func, var);
}
let parsed = parse_ode(equation, func, var);
if parsed.order != 2 || !is_zero(&parsed.remaining_expr) {
return None;
}
let c2 = parsed.coeffs.get(&2)?;
let c1 = parsed.coeffs.get(&1)?;
let c0 = parsed.coeffs.get(&0)?;
let x = Expr::Variable(var.to_string());
let x_sq = Expr::new_pow(x.clone(), Expr::Constant(2.0));
let a = simplify(&Expr::new_div(c2.clone(), x_sq));
let b = simplify(&Expr::new_div(c1.clone(), x.clone()));
let c = c0.clone();
if a.to_f64().is_none() || b.to_f64().is_none() || c.to_f64().is_none() {
return None;
}
let m = Expr::Variable("m".to_string());
let b_minus_a = simplify(&Expr::new_sub(b, a.clone()));
let aux_eq = Expr::new_add(
Expr::new_mul(a, Expr::new_pow(m.clone(), Expr::Constant(2.0))),
Expr::new_add(Expr::new_mul(b_minus_a, m), c),
);
let roots = solve(&aux_eq, "m");
if roots.len() != 2 {
return None;
}
let m1 = &roots[0];
let m2 = &roots[1];
let const1 = Expr::Variable("C1".to_string());
let const2 = Expr::Variable("C2".to_string());
let solution = if m1 == m2 {
simplify(&Expr::new_mul(
Expr::new_pow(x.clone(), m1.clone()),
Expr::new_add(const1, Expr::new_mul(const2, Expr::new_log(x))),
))
} else {
simplify(&Expr::new_add(
Expr::new_mul(const1, Expr::new_pow(x.clone(), m1.clone())),
Expr::new_mul(const2, Expr::new_pow(x, m2.clone())),
))
};
Some(Expr::Eq(
Arc::new(Expr::Variable(func.to_string())),
Arc::new(solution),
))
}
/// Solves a homogeneous second-order linear ODE `a y'' + b y' + c y = 0` by reduction of
/// order, given one known solution `y1`.
///
/// With `P = b / a`, the second solution is `y2 = y1 ∫ exp(-∫ P dx) / y1^2 dx`. The result is
/// `C1 y1 + C2 y2`. `y1` is not checked against the equation.
///
/// # Returns
/// `Some(Eq(func, solution))`, or `None` when the equation is not homogeneous second order
/// or lacks a `y''` or `y'` coefficient.
#[must_use]
pub fn solve_by_reduction_of_order(
    equation: &Expr,
    func: &str,
    var: &str,
    y1: &Expr,
) -> Option<Expr> {
    if let Expr::Dag(node) = equation {
        return solve_by_reduction_of_order(&node.to_expr().expect("Unwrap DAG"), func, var, y1);
    }
    let parsed = parse_ode(equation, func, var);
    let homogeneous_second_order = parsed.order == 2 && is_zero(&parsed.remaining_expr);
    if !homogeneous_second_order {
        return None;
    }
    let leading = parsed.coeffs.get(&2)?;
    let middle = parsed.coeffs.get(&1)?;
    let p_x = simplify(&Expr::new_div(middle.clone(), leading.clone()));
    // Abel's factor exp(-∫ P dx), divided by y1^2, is the integrand for y2 / y1.
    let abel_factor = Expr::new_exp(Expr::new_neg(integrate(&p_x, var, None, None)));
    let integrand = simplify(&Expr::new_div(
        abel_factor,
        Expr::new_pow(y1.clone(), Expr::Constant(2.0)),
    ));
    let y2 = simplify(&Expr::new_mul(
        y1.clone(),
        integrate(&integrand, var, None, None),
    ));
    let combination = simplify(&Expr::new_add(
        Expr::new_mul(Expr::Variable("C1".to_string()), y1.clone()),
        Expr::new_mul(Expr::Variable("C2".to_string()), y2),
    ));
    Some(Expr::Eq(
        Arc::new(Expr::Variable(func.to_string())),
        Arc::new(combination),
    ))
}
#[must_use]
pub fn solve_exact_ode(
equation: &Expr,
func: &str,
var: &str,
) -> Option<Expr> {
if let Expr::Dag(node) = equation {
return solve_exact_ode(&node.to_expr().expect("Unwrap DAG"), func, var);
}
let y = Expr::Variable(func.to_string());
let y_prime = Expr::Derivative(Arc::new(y), var.to_string());
let pattern = Expr::new_add(
Expr::Pattern("M".to_string()),
Expr::new_mul(Expr::Pattern("N".to_string()), y_prime),
);
if let Some(m) = pattern_match(equation, &pattern) {
let m_xy = m.get("M")?;
let n_xy = m.get("N")?;
let dm_dy = differentiate(m_xy, func);
let dn_dx = differentiate(n_xy, var);
if simplify(&dm_dy) != simplify(&dn_dx) {
return None;
}
let int_m_dx = integrate(m_xy, var, None, None);
let d_int_m_dy = differentiate(&int_m_dx, func);
let g_prime_y = simplify(&Expr::new_sub(n_xy.clone(), d_int_m_dy));
let g_y = integrate(&g_prime_y, func, None, None);
let f_xy = simplify(&Expr::new_add(int_m_dx, g_y));
return Some(Expr::Eq(
Arc::new(f_xy),
Arc::new(Expr::Variable("C".to_string())),
));
}
None
}
#[must_use]
pub fn solve_ode_by_series(
equation: &Expr,
func: &str,
var: &str,
x0: &Expr,
order: u32,
initial_conditions: &[(u32, Expr)],
) -> Option<Expr> {
let mut y_n_at_x0: HashMap<u32, Expr> = initial_conditions.iter().cloned().collect();
let parsed = parse_ode(equation, func, var);
let highest_order = parsed.order;
let coeff_helper = parsed.coeffs.clone();
let coeff_highest = coeff_helper.get(&highest_order)?;
let mut other_terms = Expr::Constant(0.0);
for (o, c) in parsed.coeffs {
if o < highest_order {
let deriv = (0..o).fold(Expr::Variable(func.to_string()), |e, _| {
Expr::Derivative(Arc::new(e), var.to_string())
});
other_terms = simplify(&Expr::new_add(other_terms, Expr::new_mul(c, deriv)));
}
}
other_terms = simplify(&Expr::new_add(other_terms, parsed.remaining_expr));
let highest_deriv_expr = simplify(&Expr::new_neg(Expr::new_div(
other_terms,
coeff_highest.clone(),
)));
for n in highest_order..=order {
if !y_n_at_x0.contains_key(&n) {
let mut current_expr = highest_deriv_expr.clone();
for i in 0..n {
let deriv_i_expr = (0..i).fold(Expr::Variable(func.to_string()), |e, _| {
Expr::Derivative(Arc::new(e), var.to_string())
});
if let Some(val) = y_n_at_x0.get(&i) {
current_expr = substitute(¤t_expr, &deriv_i_expr.to_string(), val);
}
}
let val_at_x0 = substitute(¤t_expr, var, x0);
y_n_at_x0.insert(n, simplify(&val_at_x0));
}
}
let mut series_sum = Expr::Constant(0.0);
for n in 0..=order {
if let Some(y_n_val) = y_n_at_x0.get(&n) {
let n_factorial = f64::from((1..=n).product::<u32>());
let coeff_term = simplify(&Expr::new_div(y_n_val.clone(), Expr::Constant(n_factorial)));
let power_term = Expr::new_pow(
Expr::new_sub(Expr::Variable(var.to_string()), x0.clone()),
Expr::Constant(f64::from(n)),
);
series_sum = simplify(&Expr::new_add(
series_sum,
Expr::new_mul(coeff_term, power_term),
));
}
}
Some(series_sum)
}
/// Solves a linear ODE with numeric constant coefficients using the Fourier transform.
///
/// The steps are:
/// - Each term `c_k * y^(k)` becomes `c_k` times `Y` passed k times through
///   `transforms::fourier_differentiation` in the frequency symbol `w`. This is presumably
///   multiplication by `i w` — TODO confirm.
/// - The non-derivative remainder becomes its transform `G(w)`.
/// - The algebraic equation `Σ_k c_k D^k[Y] + G(w) = 0` is solved for `Y`.
/// - The first solution returned is transformed back to `var`.
///
/// # Returns
/// `Some(Eq(func, solution))`, or `None` if any coefficient is not a numeric constant or no
/// solution for `Y` is found.
#[must_use]
pub fn solve_ode_by_fourier(
    equation: &Expr,
    func: &str,
    var: &str,
) -> Option<Expr> {
    let omega_var = "w";
    let parsed = parse_ode(equation, func, var);
    let g_w = transforms::fourier_transform(&parsed.remaining_expr, var, omega_var);
    let y_w = Expr::Variable("Y".to_string());
    let transformed_lhs = parsed.coeffs.into_iter().try_fold(
        Expr::Constant(0.0),
        |acc, (order, coeff)| {
            // Only numeric coefficients are supported.
            coeff.to_f64()?;
            let transformed = (0..order).fold(y_w.clone(), |t, _| {
                transforms::fourier_differentiation(&t, omega_var)
            });
            let term = simplify(&Expr::new_mul(coeff, transformed));
            Some(simplify(&Expr::new_add(acc, term)))
        },
    )?;
    let negated_source = simplify(&Expr::new_neg(g_w));
    let algebraic_eq = simplify(&Expr::new_sub(transformed_lhs, negated_source));
    let candidate = solve(&algebraic_eq, "Y").into_iter().next()?;
    let time_domain = transforms::inverse_fourier_transform(&candidate, omega_var, var);
    Some(Expr::Eq(
        Arc::new(Expr::Variable(func.to_string())),
        Arc::new(time_domain),
    ))
}