use crate::calculus::integrals::Integration;
use crate::calculus::ode::first_order::{ODEError, ODEResult};
use crate::core::{Expression, Symbol};
use crate::simplify::Simplify;
use std::collections::HashMap;
pub struct SeparableODESolver;
impl SeparableODESolver {
pub fn new() -> Self {
Self
}
pub fn solve(
&self,
rhs: &Expression,
dependent: &Symbol,
independent: &Symbol,
initial_condition: Option<(Expression, Expression)>,
) -> ODEResult {
let (g_x, h_y) = self.separate(rhs, dependent, independent)?;
let integrand_y = Expression::pow(h_y, Expression::integer(-1));
let integral_y = integrand_y.integrate(dependent.clone(), 0);
let integral_x = g_x.integrate(independent.clone(), 0);
let c1 = Symbol::new("C1");
let general_solution = Expression::add(vec![
integral_y,
Expression::mul(vec![Expression::integer(-1), integral_x]),
Expression::symbol(c1),
])
.simplify();
if let Some((x0, y0)) = initial_condition {
self.apply_initial_condition(&general_solution, dependent, independent, x0, y0)
} else {
Ok(general_solution)
}
}
pub fn is_separable(&self, rhs: &Expression, dependent: &Symbol, independent: &Symbol) -> bool {
self.separate(rhs, dependent, independent).is_ok()
}
fn separate(
&self,
rhs: &Expression,
dependent: &Symbol,
independent: &Symbol,
) -> Result<(Expression, Expression), ODEError> {
if !rhs.contains_variable(dependent) {
return Ok((rhs.clone(), Expression::integer(1)));
}
if !rhs.contains_variable(independent) {
return Ok((Expression::integer(1), rhs.clone()));
}
if let Expression::Mul(factors) = rhs {
let mut x_factors = Vec::new();
let mut y_factors = Vec::new();
for factor in factors.iter() {
if factor.contains_variable(dependent) && factor.contains_variable(independent) {
return Err(ODEError::UnknownType {
equation: rhs.clone(),
reason: "Cannot separate variables - factor contains both x and y"
.to_owned(),
});
} else if factor.contains_variable(independent) {
x_factors.push(factor.clone());
} else if factor.contains_variable(dependent) {
y_factors.push(factor.clone());
} else {
x_factors.push(factor.clone());
}
}
let g_x = if x_factors.is_empty() {
Expression::integer(1)
} else {
Expression::mul(x_factors)
};
let h_y = if y_factors.is_empty() {
Expression::integer(1)
} else {
Expression::mul(y_factors)
};
return Ok((g_x, h_y));
}
Err(ODEError::UnknownType {
equation: rhs.clone(),
reason: "Cannot factor into g(x)*h(y)".to_owned(),
})
}
fn apply_initial_condition(
&self,
general_solution: &Expression,
dependent: &Symbol,
independent: &Symbol,
x0: Expression,
y0: Expression,
) -> ODEResult {
let mut subs = HashMap::new();
subs.insert(independent.name().to_owned(), x0);
subs.insert(dependent.name().to_owned(), y0);
let substituted = general_solution.substitute(&subs);
let simplified = substituted.simplify();
let c1_value = simplified;
let mut c_subs = HashMap::new();
c_subs.insert("C1".to_owned(), c1_value);
let particular_solution = general_solution.substitute(&c_subs).simplify();
Ok(particular_solution)
}
}
impl Default for SeparableODESolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_is_separable_simple_cases() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solver = SeparableODESolver::new();
        // Pure x, pure y and the product x*y separate; the sum x + y does not.
        assert!(solver.is_separable(&expr!(x), &y, &x));
        assert!(solver.is_separable(&expr!(y), &y, &x));
        assert!(solver.is_separable(&expr!(x * y), &y, &x));
        assert!(!solver.is_separable(&expr!(x + y), &y, &x));
    }

    #[test]
    fn test_separate_simple_linear() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solver = SeparableODESolver::new();
        let result = solver.separate(&expr!(x), &y, &x);
        assert!(result.is_ok());
        // dy/dx = x  =>  g(x) = x, h(y) = 1.
        if let Ok((g_x, h_y)) = result {
            assert_eq!(g_x, expr!(x));
            assert_eq!(h_y, Expression::integer(1));
        }
    }

    #[test]
    fn test_separate_product() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solver = SeparableODESolver::new();
        let result = solver.separate(&expr!(x * y), &y, &x);
        assert!(result.is_ok());
        // dy/dx = x*y  =>  g(x) = x, h(y) = y.
        if let Ok((g_x, h_y)) = result {
            assert_eq!(g_x, expr!(x));
            assert_eq!(h_y, expr!(y));
        }
    }

    #[test]
    fn test_separate_non_separable() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solver = SeparableODESolver::new();
        let result = solver.separate(&expr!(x + y), &y, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_solve_simple_linear() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solution = SeparableODESolver::new().solve(&expr!(x), &y, &x, None);
        assert!(
            solution.is_ok(),
            "Failed to solve dy/dx = x: {:?}",
            solution.err()
        );
    }

    #[test]
    fn test_solve_exponential() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solution = SeparableODESolver::new().solve(&expr!(y), &y, &x, None);
        assert!(
            solution.is_ok(),
            "Failed to solve dy/dx = y: {:?}",
            solution.err()
        );
    }

    #[test]
    fn test_solve_product() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solution = SeparableODESolver::new().solve(&expr!(x * y), &y, &x, None);
        assert!(
            solution.is_ok(),
            "Failed to solve dy/dx = x*y: {:?}",
            solution.err()
        );
    }

    #[test]
    fn test_solve_with_initial_condition() {
        let (x, y) = (symbol!(x), symbol!(y));
        // Initial condition y(0) = 1.
        let ic = Some((expr!(0), expr!(1)));
        let solution = SeparableODESolver::new().solve(&expr!(x), &y, &x, ic);
        assert!(
            solution.is_ok(),
            "Failed to solve with IC: {:?}",
            solution.err()
        );
    }

    #[test]
    fn test_non_separable_fails() {
        let (x, y) = (symbol!(x), symbol!(y));
        let solution = SeparableODESolver::new().solve(&expr!(x + y), &y, &x, None);
        assert!(solution.is_err(), "Should not solve non-separable ODE");
    }
}