use super::{ODEError, ODEResult};
use crate::calculus::ode::first_order::linear::LinearFirstOrderSolver;
use crate::core::{Expression, Symbol};
use crate::simplify::Simplify;
/// Solver for Bernoulli ordinary differential equations: first-order equations
/// of the form `y' + p(x)·y = q(x)·yⁿ` with `n ≠ 0, 1`.
///
/// The sign convention for `p` and `q` follows [`LinearFirstOrderSolver`],
/// which this solver delegates to after substitution. This is a stateless
/// unit struct; see [`BernoulliODESolver::solve`].
pub struct BernoulliODESolver;
impl BernoulliODESolver {
    /// Solves the Bernoulli equation `y' + p·y = q·yⁿ` for `dependent` as a
    /// function of `independent`.
    ///
    /// The substitution `v = y^(1-n)` turns the equation into the linear equation
    /// `v' + (1-n)·p·v = (1-n)·q`. That equation is handed to
    /// [`LinearFirstOrderSolver`], and its solution is mapped back through
    /// `y = v^(1/(1-n))`. The derivation assumes the linear solver uses the
    /// `v' + p·v = q` convention — NOTE(review): confirm against
    /// `LinearFirstOrderSolver::solve`.
    ///
    /// Only the family obtained from the substitution is returned. The
    /// equilibrium solution `y = 0` is not reported separately.
    ///
    /// # Errors
    ///
    /// * [`ODEError::NotLinearForm`] if `n` simplifies to the integer `0` or `1`.
    ///   In those cases the equation is already linear.
    /// * Any error produced by the underlying linear solver, propagated with `?`.
    pub fn solve(
        &self,
        p: &Expression,
        q: &Expression,
        n: &Expression,
        dependent: &Symbol,
        independent: &Symbol,
    ) -> ODEResult {
        // With n = 0 or n = 1 the equation is linear. For n = 1 the
        // substitution exponent 1 - n would also be zero, making 1/(1 - n)
        // undefined.
        let reduced_n = n.simplify();
        let already_linear =
            reduced_n == Expression::integer(0) || reduced_n == Expression::integer(1);
        if already_linear {
            return Err(ODEError::NotLinearForm {
                reason: String::from("Bernoulli equation requires n ≠ 0, 1 (this is linear)"),
            });
        }

        // k = 1 - n, the power used in the substitution v = y^k.
        let negated_n = Expression::mul(vec![Expression::integer(-1), n.clone()]);
        let k = Expression::add(vec![Expression::integer(1), negated_n]).simplify();

        // Coefficients of the linear equation in v: v' + k·p·v = k·q.
        let scale_by_k =
            |coeff: &Expression| Expression::mul(vec![k.clone(), coeff.clone()]).simplify();
        let linear_p = scale_by_k(p);
        let linear_q = scale_by_k(q);
        let v = LinearFirstOrderSolver.solve(&linear_p, &linear_q, dependent, independent, None)?;

        // Undo the substitution: y = v^(1/k).
        let inverse_k = Expression::pow(k, Expression::integer(-1)).simplify();
        Ok(Expression::pow(v, inverse_k).simplify())
    }

    /// Attempts to recognise `_equation` as a Bernoulli equation in
    /// `_dependent`/`_independent`.
    ///
    /// Not implemented yet: this always returns `None`. When implemented, the
    /// tuple is presumably `(p, q, n)` in the order [`BernoulliODESolver::solve`]
    /// takes them — TODO confirm once detection exists.
    pub fn detect_form(
        &self,
        _equation: &Expression,
        _dependent: &Symbol,
        _independent: &Symbol,
    ) -> Option<(Expression, Expression, Expression)> {
        Option::None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Asserts that `solve` rejects exponent `n` with `ODEError::NotLinearForm`
    /// and a reason mentioning linearity.
    ///
    /// The match is exhaustive on purpose. A different error variant, or an
    /// unexpected success, fails the test instead of passing silently.
    fn assert_rejected_as_linear(n: &Expression) {
        let x = symbol!(x);
        let y = symbol!(y);
        let result = BernoulliODESolver.solve(&expr!(1), &expr!(x), n, &y, &x);
        match result {
            Err(ODEError::NotLinearForm { reason }) => assert!(reason.contains("linear")),
            Err(_) => panic!("expected ODEError::NotLinearForm, got a different error variant"),
            Ok(_) => panic!("expected this exponent to be rejected, but solve succeeded"),
        }
    }

    #[test]
    fn test_bernoulli_n_equals_2() {
        let x = symbol!(x);
        let y = symbol!(y);
        let solver = BernoulliODESolver;
        let solution = solver.solve(&expr!(1), &expr!(x), &expr!(2), &y, &x);
        assert!(solution.is_ok());
        let sol_str = solution.unwrap().to_string();
        assert!(sol_str.contains("exp") || sol_str.contains("x"));
    }

    #[test]
    fn test_bernoulli_n_equals_3() {
        let x = symbol!(x);
        let y = symbol!(y);
        let solver = BernoulliODESolver;
        let solution = solver.solve(&expr!(2), &expr!(1), &expr!(3), &y, &x);
        assert!(solution.is_ok());
    }

    #[test]
    fn test_bernoulli_rejects_n_equals_0() {
        assert_rejected_as_linear(&expr!(0));
    }

    #[test]
    fn test_bernoulli_rejects_n_equals_1() {
        assert_rejected_as_linear(&expr!(1));
    }

    #[test]
    fn test_bernoulli_negative_n() {
        let x = symbol!(x);
        let y = symbol!(y);
        let solver = BernoulliODESolver;
        let solution = solver.solve(&expr!(1), &expr!(1), &expr!(-1), &y, &x);
        assert!(solution.is_ok());
    }

    #[test]
    fn test_bernoulli_fractional_n() {
        let x = symbol!(x);
        let y = symbol!(y);
        let solver = BernoulliODESolver;
        // n = 1/2, built as 1 · 2^(-1).
        let n = Expression::mul(vec![
            Expression::integer(1),
            Expression::pow(Expression::integer(2), Expression::integer(-1)),
        ]);
        let solution = solver.solve(&expr!(1), &expr!(1), &n, &y, &x);
        assert!(solution.is_ok());
    }

    #[test]
    fn test_bernoulli_solution_structure() {
        let x = symbol!(x);
        let y = symbol!(y);
        let solver = BernoulliODESolver;
        let solution = solver
            .solve(&expr!(1), &expr!(1), &expr!(2), &y, &x)
            .unwrap();
        // The general solution must carry an integration constant.
        let sol_str = solution.to_string();
        assert!(sol_str.contains("C"));
    }
}