use super::{ODEError, ODEResult};
use crate::calculus::derivatives::Derivative;
use crate::calculus::integrals::Integration;
use crate::core::{Expression, Symbol};
use crate::simplify::Simplify;
use crate::symbol;
pub struct ExactODESolver;
impl ExactODESolver {
pub fn is_exact(&self, m: &Expression, n: &Expression, x: &Symbol, y: &Symbol) -> bool {
let dm_dy = m.derivative(y.clone()).simplify();
let dn_dx = n.derivative(x.clone()).simplify();
dm_dy == dn_dx
}
pub fn solve(&self, m: &Expression, n: &Expression, x: &Symbol, y: &Symbol) -> ODEResult {
if !self.is_exact(m, n, x, y) {
return Err(ODEError::NotLinearForm {
reason: "ODE is not exact: ∂M/∂y ≠ ∂N/∂x".to_owned(),
});
}
let f_from_m = m.integrate(x.clone(), 0);
let df_dy = f_from_m.derivative(y.clone()).simplify();
let g_prime = Expression::add(vec![
n.clone(),
Expression::mul(vec![Expression::integer(-1), df_dy]),
])
.simplify();
let g = g_prime.integrate(y.clone(), 0);
let potential = Expression::add(vec![f_from_m, g]).simplify();
let c = Expression::symbol(symbol!(C));
let solution = Expression::add(vec![potential, c]).simplify();
Ok(solution)
}
pub fn find_integrating_factor(
&self,
m: &Expression,
n: &Expression,
x: &Symbol,
y: &Symbol,
) -> Option<Expression> {
let dm_dy = m.derivative(y.clone()).simplify();
let dn_dx = n.derivative(x.clone()).simplify();
let numerator = Expression::add(vec![
dm_dy.clone(),
Expression::mul(vec![Expression::integer(-1), dn_dx.clone()]),
])
.simplify();
let quotient = Expression::mul(vec![
numerator,
Expression::pow(n.clone(), Expression::integer(-1)),
])
.simplify();
if !quotient.contains_variable(y) {
let integral = quotient.integrate(x.clone(), 0);
return Some(Expression::function("exp", vec![integral]));
}
let numerator_y = Expression::add(vec![
dn_dx,
Expression::mul(vec![Expression::integer(-1), dm_dy]),
])
.simplify();
let quotient_y = Expression::mul(vec![
numerator_y,
Expression::pow(m.clone(), Expression::integer(-1)),
])
.simplify();
if !quotient_y.contains_variable(x) {
let integral = quotient_y.integrate(y.clone(), 0);
return Some(Expression::function("exp", vec![integral]));
}
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_exact_ode_simple() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!((2 * x) * y);
        let n = expr!(x ^ 2);
        let solver = ExactODESolver;
        assert!(solver.is_exact(&m, &n, &x, &y));
        let solution = solver.solve(&m, &n, &x, &y);
        assert!(solution.is_ok());
    }

    #[test]
    fn test_exact_ode_polynomial() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!((3 * (x ^ 2)) + y);
        let n = expr!(x);
        let solver = ExactODESolver;
        assert!(solver.is_exact(&m, &n, &x, &y));
        let solution = solver.solve(&m, &n, &x, &y);
        assert!(solution.is_ok());
    }

    #[test]
    fn test_not_exact() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!(y);
        let n = expr!(2 * x);
        let solver = ExactODESolver;
        assert!(!solver.is_exact(&m, &n, &x, &y));
    }

    /// `solve` must refuse a non-exact equation instead of returning a bogus potential.
    #[test]
    fn test_solve_rejects_non_exact() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!(y);
        let n = expr!(2 * x);
        let solver = ExactODESolver;
        assert!(solver.solve(&m, &n, &x, &y).is_err());
    }

    #[test]
    fn test_integrating_factor_x_only() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!(y);
        let n = expr!(2 * x);
        let solver = ExactODESolver;
        let mu = solver.find_integrating_factor(&m, &n, &x, &y);
        assert!(mu.is_some());
    }

    /// For an already-exact ODE, `(∂M/∂y - ∂N/∂x) / N` contains no `y`, so the
    /// x-only test applies and a factor is returned.
    #[test]
    fn test_integrating_factor_for_exact_ode() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!((2 * x) * y);
        let n = expr!(x ^ 2);
        let solver = ExactODESolver;
        assert!(solver.find_integrating_factor(&m, &n, &x, &y).is_some());
    }

    #[test]
    fn test_exact_solve_returns_implicit_solution() {
        let x = symbol!(x);
        let y = symbol!(y);
        let m = expr!((2 * x) * y);
        let n = expr!(x ^ 2);
        let solver = ExactODESolver;
        let solution = solver.solve(&m, &n, &x, &y).unwrap();
        let sol_str = solution.to_string();
        assert!(sol_str.contains("x") || sol_str.contains("y"));
        assert!(sol_str.contains("C"));
    }
}