use crate::ast::{BinaryOp, Expression, Function, UnaryOp, Variable};
use crate::integration::{integrate, IntegrationError};
/// Errors produced by the ODE solvers in this module.
///
/// Variants carrying a `String` hold a detail message that `Display`
/// appends after a fixed, variant-specific prefix.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum ODEError {
    /// The equation does not have the shape a solver expects.
    NotInExpectedForm(String),
    /// No available method could solve the equation.
    CannotSolve(String),
    /// A symbolic integration step failed.
    IntegrationFailed(IntegrationError),
    /// The initial condition(s) could not be applied.
    InitialConditionError(String),
    /// The right-hand side could not be split into `g(x) · h(y)`.
    NotSeparable,
    /// The right-hand side could not be written as `Q(x) - P(x)·y`.
    NotLinear,
    /// The characteristic polynomial could not be solved.
    CharacteristicEquationError(String),
    /// Non-constant coefficients (not constructed anywhere in this file).
    NonConstantCoefficients(String),
    /// Boundary-value failure (not constructed anywhere in this file).
    BoundaryValueError(String),
    /// Resonance detected (not constructed anywhere in this file).
    ResonanceDetected(String),
}

impl From<IntegrationError> for ODEError {
    /// Wraps an integration failure so `?` can propagate it out of the solvers.
    fn from(err: IntegrationError) -> Self {
        Self::IntegrationFailed(err)
    }
}

impl std::fmt::Display for ODEError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotInExpectedForm(msg) => write!(f, "ODE not in expected form: {}", msg),
            Self::CannotSolve(msg) => write!(f, "Cannot solve ODE: {}", msg),
            Self::IntegrationFailed(err) => write!(f, "Integration failed: {}", err),
            Self::InitialConditionError(msg) => write!(f, "Initial condition error: {}", msg),
            // Fixed messages need no formatting machinery.
            Self::NotSeparable => f.write_str("ODE is not separable"),
            Self::NotLinear => f.write_str("ODE is not first-order linear"),
            Self::CharacteristicEquationError(msg) => {
                write!(f, "Characteristic equation error: {}", msg)
            }
            Self::NonConstantCoefficients(msg) => write!(f, "Non-constant coefficients: {}", msg),
            Self::BoundaryValueError(msg) => write!(f, "Boundary value error: {}", msg),
            Self::ResonanceDetected(msg) => write!(f, "Resonance detected: {}", msg),
        }
    }
}

impl std::error::Error for ODEError {}
/// A first-order ODE in explicit form `d(dependent)/d(independent) = rhs`.
#[derive(Debug, Clone)]
pub struct FirstOrderODE {
    /// Name of the unknown function (e.g. `y`).
    pub dependent: String,
    /// Name of the variable the unknown depends on (e.g. `x`).
    pub independent: String,
    /// Right-hand side; may mention both variables.
    pub rhs: Expression,
}

impl FirstOrderODE {
    /// Builds an ODE from borrowed variable names and an owned right-hand side.
    pub fn new(dependent: &str, independent: &str, rhs: Expression) -> Self {
        let dependent = String::from(dependent);
        let independent = String::from(independent);
        Self {
            dependent,
            independent,
            rhs,
        }
    }

    /// Returns `true` when `try_separate` can split the right-hand side as `g(x) · h(y)`.
    pub fn is_separable(&self) -> bool {
        let (x, y) = (self.independent.as_str(), self.dependent.as_str());
        try_separate(&self.rhs, x, y).is_some()
    }

    /// Returns `true` when `extract_linear_coefficients` recognises the right-hand side
    /// as linear in the dependent variable.
    pub fn is_linear(&self) -> bool {
        let (x, y) = (self.independent.as_str(), self.dependent.as_str());
        extract_linear_coefficients(&self.rhs, x, y).is_some()
    }
}
/// Result of a first-order solve (`solve_separable`, `solve_linear`, `solve_ivp`).
#[derive(Debug, Clone)]
pub struct ODESolution {
    /// The solution expression.
    ///
    /// From the general solvers it contains the integration constant `C`.
    /// It is usually an explicit expression for the dependent variable.
    /// When `solve_separable` cannot isolate the dependent variable it is
    /// instead the implicit expression `∫dy/h(y) - (∫g dx + C)`, which is zero
    /// along solutions. `solve_ivp` stores the particular solution here, with
    /// `C` substituted.
    pub general_solution: Expression,
    /// Human-readable name of the method that produced the solution.
    pub method: String,
    /// Human-readable derivation, one entry per step, in order.
    pub steps: Vec<String>,
}
/// Solves a separable ODE `dy/dx = g(x) · h(y)` by separation of variables.
///
/// Both sides of `dy / h(y) = g(x) dx` are integrated. If the left integral
/// has a shape `try_solve_implicit_for_y` can invert, the result is an
/// explicit expression for the dependent variable. Otherwise the solution is
/// returned in implicit form `∫dy/h(y) - (∫g(x)dx + C)`, an expression equal
/// to zero along solution curves (the steps say so explicitly). The
/// integration constant is the variable `C`.
///
/// # Errors
///
/// * [`ODEError::NotSeparable`] if the right-hand side cannot be split.
/// * [`ODEError::IntegrationFailed`] if either integral fails.
#[must_use = "solving returns a result that should be used"]
pub fn solve_separable(ode: &FirstOrderODE) -> Result<ODESolution, ODEError> {
    let mut steps = Vec::new();
    steps.push(format!(
        "Given ODE: d{}/d{} = {}",
        ode.dependent, ode.independent, ode.rhs
    ));
    let (g_x, h_y) =
        try_separate(&ode.rhs, &ode.independent, &ode.dependent).ok_or(ODEError::NotSeparable)?;
    steps.push(format!(
        "Separating: d{}/d{} = ({}) * ({})",
        ode.dependent, ode.independent, g_x, h_y
    ));
    steps.push(format!(
        "Rearranging: (1/({})) d{} = ({}) d{}",
        h_y, ode.dependent, g_x, ode.independent
    ));
    let left_integral = match &h_y {
        // h(y) = 1: ∫ 1 dy is the dependent variable itself.
        Expression::Integer(1) => Expression::Variable(Variable::new(&ode.dependent)),
        // Quotient separations report h(y) = 1/d(y); integrate d(y) directly
        // rather than handing the integrator the nested 1/(1/d(y)).
        Expression::Binary(BinaryOp::Div, num, denom)
            if matches!(num.as_ref(), Expression::Integer(1)) =>
        {
            integrate(denom, &ode.dependent)?
        }
        _ => {
            let one_over_h_y = Expression::Binary(
                BinaryOp::Div,
                Box::new(Expression::Integer(1)),
                Box::new(h_y.clone()),
            );
            integrate(&one_over_h_y, &ode.dependent)?
        }
    };
    let right_integral = integrate(&g_x, &ode.independent)?;
    steps.push(format!(
        "Integrating left side: ∫(1/({})) d{} = {}",
        h_y, ode.dependent, left_integral
    ));
    steps.push(format!(
        "Integrating right side: ∫({}) d{} = {} + C",
        g_x, ode.independent, right_integral
    ));
    let c = Expression::Variable(Variable::new("C"));
    let rhs_with_c = Expression::Binary(BinaryOp::Add, Box::new(right_integral), Box::new(c));
    let (solution, summary) =
        match try_solve_implicit_for_y(&left_integral, &rhs_with_c, &ode.dependent) {
            Some(explicit) => {
                let text = format!("General solution: {} = {}", ode.dependent, explicit);
                (explicit, text)
            }
            None => {
                // The dependent variable could not be isolated: the honest
                // statement is `left - (right + C) = 0`, not `y = left - (right + C)`.
                let implicit = Expression::Binary(
                    BinaryOp::Sub,
                    Box::new(left_integral),
                    Box::new(rhs_with_c),
                );
                let text = format!("General solution (implicit): {} = 0", implicit);
                (implicit, text)
            }
        };
    steps.push(summary);
    Ok(ODESolution {
        general_solution: solution,
        method: "Separation of variables".to_string(),
        steps,
    })
}
/// Solves a first-order linear ODE `dy/dx + P(x)·y = Q(x)` with an integrating factor.
///
/// With `μ(x) = e^(∫P dx)`, the general solution is `y = (∫μ·Q dx + C) / μ`.
/// It is returned after `simplify()`. The integration constant is the variable `C`.
///
/// # Errors
///
/// * [`ODEError::NotLinear`] if the right-hand side is not linear in the dependent variable.
/// * [`ODEError::IntegrationFailed`] if either integral fails.
#[must_use = "solving returns a result that should be used"]
pub fn solve_linear(ode: &FirstOrderODE) -> Result<ODESolution, ODEError> {
    let x = &ode.independent;
    let y = &ode.dependent;
    let mut steps = vec![format!("Given ODE: d{}/d{} = {}", y, x, ode.rhs)];

    let (p_x, q_x) = extract_linear_coefficients(&ode.rhs, x, y).ok_or(ODEError::NotLinear)?;
    steps.push(format!(
        "Standard form: d{}/d{} + ({}) * {} = {}",
        y, x, p_x, y, q_x
    ));

    let p_integral = integrate(&p_x, x)?;
    steps.push(format!(
        "Integrating factor: μ({}) = e^(∫{} d{}) = e^({})",
        x, p_x, x, p_integral
    ));
    let mu = Expression::Function(Function::Exp, vec![p_integral]);

    let integrand =
        Expression::Binary(BinaryOp::Mul, Box::new(mu.clone()), Box::new(q_x.clone())).simplify();
    let mu_q_integral = integrate(&integrand, x)?;
    steps.push(format!(
        "Integrating: ∫μ({}) * ({}) d{} = {}",
        x, q_x, x, mu_q_integral
    ));

    let numerator = Expression::Binary(
        BinaryOp::Add,
        Box::new(mu_q_integral),
        Box::new(Expression::Variable(Variable::new("C"))),
    );
    let general_solution =
        Expression::Binary(BinaryOp::Div, Box::new(numerator), Box::new(mu)).simplify();
    steps.push(format!(
        "General solution: {} = (∫μQ d{} + C) / μ = {}",
        y, x, general_solution
    ));

    Ok(ODESolution {
        general_solution,
        method: String::from("Integrating factor"),
        steps,
    })
}
/// Solves a first-order initial value problem `dy/dx = f(x, y)`, `y(x0) = y0`.
///
/// The general solution comes from [`solve_separable`] when the equation is
/// separable, otherwise from [`solve_linear`]. The constant `C` is then fixed
/// from the initial condition:
///
/// * an explicit solution `y = F(x, C)` gives the equation `F(x0, C) - y0 = 0`;
/// * an implicit solution `G(x, y, C) = 0` (the separable fallback,
///   recognisable because it still mentions the dependent variable) gives
///   `G(x0, y0, C) = 0`.
///
/// # Errors
///
/// * [`ODEError::CannotSolve`] if the ODE is neither separable nor linear.
/// * Any error from the underlying solver.
/// * [`ODEError::InitialConditionError`] if `C` cannot be isolated.
pub fn solve_ivp(
    ode: &FirstOrderODE,
    x0: &Expression,
    y0: &Expression,
) -> Result<ODESolution, ODEError> {
    let general = if ode.is_separable() {
        solve_separable(ode)?
    } else if ode.is_linear() {
        solve_linear(ode)?
    } else {
        return Err(ODEError::CannotSolve(
            "ODE is neither separable nor linear".to_string(),
        ));
    };
    let ODESolution {
        general_solution,
        method,
        mut steps,
    } = general;
    steps.push(format!(
        "Applying initial condition: {}({}) = {}",
        ode.dependent, x0, y0
    ));
    // Explicit solutions are assembled only from pieces free of the dependent
    // variable, so its presence marks the implicit form `G(x, y, C)`.
    let is_implicit = general_solution.contains_variable(&ode.dependent);
    let at_x0 = substitute_var(&general_solution, &ode.independent, x0);
    let equation = if is_implicit {
        substitute_var(&at_x0, &ode.dependent, y0)
    } else {
        Expression::Binary(BinaryOp::Sub, Box::new(at_x0), Box::new(y0.clone()))
    };
    let c_value = solve_for_constant(&equation.simplify(), "C").ok_or_else(|| {
        ODEError::InitialConditionError("Could not solve for constant C".to_string())
    })?;
    steps.push(format!("Solving for C: C = {}", c_value));
    let particular = substitute_var(&general_solution, "C", &c_value).simplify();
    steps.push(if is_implicit {
        format!("Particular solution (implicit): {} = 0", particular)
    } else {
        format!("Particular solution: {} = {}", ode.dependent, particular)
    });
    Ok(ODESolution {
        general_solution: particular,
        method: format!("{} with initial condition", method),
        steps,
    })
}
fn try_separate(expr: &Expression, x_var: &str, y_var: &str) -> Option<(Expression, Expression)> {
if let Expression::Binary(BinaryOp::Mul, left, right) = expr {
let left_has_x = left.contains_variable(x_var);
let left_has_y = left.contains_variable(y_var);
let right_has_x = right.contains_variable(x_var);
let right_has_y = right.contains_variable(y_var);
if left_has_x && !left_has_y && right_has_y && !right_has_x {
return Some((left.as_ref().clone(), right.as_ref().clone()));
}
if left_has_y && !left_has_x && right_has_x && !right_has_y {
return Some((right.as_ref().clone(), left.as_ref().clone()));
}
if (left_has_x || right_has_x) && !left_has_y && !right_has_y {
return Some((expr.clone(), Expression::Integer(1)));
}
if (left_has_y || right_has_y) && !left_has_x && !right_has_x {
return Some((Expression::Integer(1), expr.clone()));
}
}
let has_x = expr.contains_variable(x_var);
let has_y = expr.contains_variable(y_var);
if has_x && !has_y {
return Some((expr.clone(), Expression::Integer(1)));
}
if has_y && !has_x {
return Some((Expression::Integer(1), expr.clone()));
}
if !has_x && !has_y {
return Some((expr.clone(), Expression::Integer(1)));
}
if let Expression::Binary(BinaryOp::Div, num, denom) = expr {
let num_has_x = num.contains_variable(x_var);
let num_has_y = num.contains_variable(y_var);
let denom_has_x = denom.contains_variable(x_var);
let denom_has_y = denom.contains_variable(y_var);
if num_has_x && !num_has_y && denom_has_y && !denom_has_x {
let h_y = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
denom.clone(),
);
return Some((num.as_ref().clone(), h_y));
}
}
None
}
/// Recognises a right-hand side that is linear in the dependent variable.
///
/// Writes `rhs` as `a·y + q` and returns `(p, q)` with `p = -a`. These are the
/// coefficients of the standard form `dy/dx + p·y = q`; both are simplified.
///
/// Supported structure:
///
/// * sums, differences and negations;
/// * products with one `y`-free factor (the other factor may itself be linear in `y`);
/// * quotients by a `y`-free denominator.
///
/// Returns `None` for anything else, e.g. `y²`, `y·y`, `1/y` or `sin(y)`.
/// `_x_var` is unused: the coefficients may mention any variable other than `y_var`.
fn extract_linear_coefficients(
    rhs: &Expression,
    _x_var: &str,
    y_var: &str,
) -> Option<(Expression, Expression)> {
    /// Replaces `acc` with `acc + term`.
    fn accumulate(acc: &mut Expression, term: Expression) {
        let prev = std::mem::replace(acc, Expression::Integer(0));
        *acc = Expression::Binary(BinaryOp::Add, Box::new(prev), Box::new(term));
    }
    /// Replaces `acc` with `acc - term`.
    fn subtract(acc: &mut Expression, term: Expression) {
        let prev = std::mem::replace(acc, Expression::Integer(0));
        *acc = Expression::Binary(BinaryOp::Sub, Box::new(prev), Box::new(term));
    }
    /// Collects `expr` into fresh `(y coefficient, y-free part)` accumulators.
    fn collect_fresh(expr: &Expression, y_var: &str) -> Option<(Expression, Expression)> {
        let mut y_coeff = Expression::Integer(0);
        let mut const_terms = Expression::Integer(0);
        if collect_terms(expr, y_var, &mut y_coeff, &mut const_terms) {
            Some((y_coeff, const_terms))
        } else {
            None
        }
    }
    /// Adds `expr` into `y_coeff·y + const_terms`; returns `false` if `expr` is not linear in `y_var`.
    fn collect_terms(
        expr: &Expression,
        y_var: &str,
        y_coeff: &mut Expression,
        const_terms: &mut Expression,
    ) -> bool {
        match expr {
            Expression::Variable(v) if v.name == y_var => {
                accumulate(y_coeff, Expression::Integer(1));
                true
            }
            Expression::Binary(BinaryOp::Add, left, right) => {
                collect_terms(left, y_var, y_coeff, const_terms)
                    && collect_terms(right, y_var, y_coeff, const_terms)
            }
            Expression::Binary(BinaryOp::Sub, left, right) => {
                if !collect_terms(left, y_var, y_coeff, const_terms) {
                    return false;
                }
                match collect_fresh(right, y_var) {
                    Some((neg_y_coeff, neg_const)) => {
                        subtract(y_coeff, neg_y_coeff);
                        subtract(const_terms, neg_const);
                        true
                    }
                    None => false,
                }
            }
            Expression::Binary(BinaryOp::Mul, left, right) => {
                let left_has_y = left.contains_variable(y_var);
                let right_has_y = right.contains_variable(y_var);
                if left_has_y && right_has_y {
                    // y appears in both factors: the product is nonlinear.
                    return false;
                }
                if !left_has_y && !right_has_y {
                    accumulate(const_terms, expr.clone());
                    return true;
                }
                let (y_factor, other) = if left_has_y { (left, right) } else { (right, left) };
                if matches!(y_factor.as_ref(), Expression::Variable(v) if v.name == y_var) {
                    accumulate(y_coeff, other.as_ref().clone());
                    return true;
                }
                // (a·y + b)·k = (a·k)·y + b·k for a y-free factor k.
                match collect_fresh(y_factor, y_var) {
                    Some((inner_y, inner_const)) => {
                        accumulate(
                            y_coeff,
                            Expression::Binary(BinaryOp::Mul, Box::new(inner_y), other.clone()),
                        );
                        if !matches!(inner_const, Expression::Integer(0)) {
                            accumulate(
                                const_terms,
                                Expression::Binary(
                                    BinaryOp::Mul,
                                    Box::new(inner_const),
                                    other.clone(),
                                ),
                            );
                        }
                        true
                    }
                    None => false,
                }
            }
            Expression::Binary(BinaryOp::Div, num, denom)
                if num.contains_variable(y_var) && !denom.contains_variable(y_var) =>
            {
                // (a·y + b)/k = (a/k)·y + b/k for a y-free denominator k.
                match collect_fresh(num, y_var) {
                    Some((inner_y, inner_const)) => {
                        accumulate(
                            y_coeff,
                            Expression::Binary(BinaryOp::Div, Box::new(inner_y), denom.clone()),
                        );
                        if !matches!(inner_const, Expression::Integer(0)) {
                            accumulate(
                                const_terms,
                                Expression::Binary(
                                    BinaryOp::Div,
                                    Box::new(inner_const),
                                    denom.clone(),
                                ),
                            );
                        }
                        true
                    }
                    None => false,
                }
            }
            Expression::Unary(UnaryOp::Neg, inner) => match collect_fresh(inner, y_var) {
                Some((neg_y_coeff, neg_const)) => {
                    subtract(y_coeff, neg_y_coeff);
                    subtract(const_terms, neg_const);
                    true
                }
                None => false,
            },
            _ if !expr.contains_variable(y_var) => {
                accumulate(const_terms, expr.clone());
                true
            }
            _ => false,
        }
    }

    let (y_coefficient, constant_terms) = collect_fresh(rhs, y_var)?;
    let y_coeff = y_coefficient.simplify();
    let q_x = constant_terms.simplify();
    // Standard form moves the y-term to the left: p = -(coefficient of y).
    let p_x = Expression::Unary(UnaryOp::Neg, Box::new(y_coeff)).simplify();
    if p_x.contains_variable(y_var) {
        return None;
    }
    Some((p_x, q_x))
}
/// Returns a copy of `expr` with every variable named `var` replaced by `replacement`.
///
/// The substitution recurses through binary, unary, function and power nodes.
/// All other leaves, including variables with a different name, are cloned unchanged.
fn substitute_var(expr: &Expression, var: &str, replacement: &Expression) -> Expression {
    // Substitutes inside one child and boxes the result for the parent node.
    let sub = |child: &Expression| Box::new(substitute_var(child, var, replacement));
    match expr {
        Expression::Variable(v) if v.name == var => replacement.clone(),
        Expression::Binary(op, left, right) => Expression::Binary(*op, sub(&**left), sub(&**right)),
        Expression::Unary(op, operand) => Expression::Unary(*op, sub(&**operand)),
        Expression::Power(base, exponent) => Expression::Power(sub(&**base), sub(&**exponent)),
        Expression::Function(func, args) => {
            let replaced_args = args
                .iter()
                .map(|arg| substitute_var(arg, var, replacement))
                .collect();
            Expression::Function(func.clone(), replaced_args)
        }
        Expression::Variable(_)
        | Expression::Integer(_)
        | Expression::Float(_)
        | Expression::Rational(_)
        | Expression::Constant(_)
        | Expression::Complex(_) => expr.clone(),
    }
}
fn try_solve_implicit_for_y(
left: &Expression,
right: &Expression,
y_var: &str,
) -> Option<Expression> {
if matches!(left, Expression::Variable(v) if v.name == y_var) {
return Some(right.clone());
}
if let Expression::Function(Function::Ln, args) = left {
if args.len() == 1 {
if matches!(&args[0], Expression::Variable(v) if v.name == y_var) {
return Some(Expression::Function(Function::Exp, vec![right.clone()]));
}
if let Expression::Function(Function::Abs, inner_args) = &args[0] {
if inner_args.len() == 1 {
if matches!(&inner_args[0], Expression::Variable(v) if v.name == y_var) {
return Some(Expression::Function(Function::Exp, vec![right.clone()]));
}
}
}
}
}
if let Expression::Power(base, exp) = left {
if matches!(base.as_ref(), Expression::Variable(v) if v.name == y_var) {
if !exp.contains_variable(y_var) {
let one_over_n = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
exp.clone(),
);
return Some(Expression::Power(
Box::new(right.clone()),
Box::new(one_over_n),
));
}
}
}
if let Expression::Binary(BinaryOp::Div, num, denom) = left {
if matches!(num.as_ref(), Expression::Integer(1)) {
if matches!(denom.as_ref(), Expression::Variable(v) if v.name == y_var) {
return Some(Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(right.clone()),
));
}
}
}
None
}
fn solve_for_constant(equation: &Expression, const_name: &str) -> Option<Expression> {
match equation {
Expression::Binary(BinaryOp::Sub, left, right) => {
if matches!(left.as_ref(), Expression::Variable(v) if v.name == const_name) {
return Some(right.as_ref().clone());
}
if matches!(right.as_ref(), Expression::Variable(v) if v.name == const_name) {
return Some(left.as_ref().clone());
}
}
Expression::Binary(BinaryOp::Add, left, right) => {
if matches!(left.as_ref(), Expression::Variable(v) if v.name == const_name) {
return Some(Expression::Unary(UnaryOp::Neg, right.clone()));
}
if matches!(right.as_ref(), Expression::Variable(v) if v.name == const_name) {
return Some(Expression::Unary(UnaryOp::Neg, left.clone()));
}
}
_ => {}
}
if let Some(c_value) = try_numerical_solve_for_c(equation, const_name) {
return Some(c_value);
}
None
}
/// Solves an equation (implicitly `= 0`) that is linear in the constant `const_name`.
///
/// Despite its name this is a symbolic solve, not a numerical one.
/// `equation` is written as `k·C + b` and `-b / k` is returned after `simplify()`.
///
/// Supported structure:
///
/// * sums, differences and negations;
/// * products with one `C`-free factor (the other may itself be linear in `C`);
/// * quotients by a `C`-free denominator, e.g. `(I + C)/μ - y0`.
///
/// Returns `None` when:
///
/// * the equation does not mention the constant;
/// * the equation is not linear in it;
/// * the coefficient `k` simplifies to a literal zero, in which case the
///   equation does not determine `C`.
fn try_numerical_solve_for_c(equation: &Expression, const_name: &str) -> Option<Expression> {
    /// Replaces `acc` with `acc + term`.
    fn accumulate(acc: &mut Expression, term: Expression) {
        let prev = std::mem::replace(acc, Expression::Integer(0));
        *acc = Expression::Binary(BinaryOp::Add, Box::new(prev), Box::new(term));
    }
    /// Replaces `acc` with `acc - term`.
    fn subtract(acc: &mut Expression, term: Expression) {
        let prev = std::mem::replace(acc, Expression::Integer(0));
        *acc = Expression::Binary(BinaryOp::Sub, Box::new(prev), Box::new(term));
    }
    /// Extracts `expr` into fresh `(C coefficient, C-free part)` accumulators.
    fn extract_fresh(expr: &Expression, c_name: &str) -> Option<(Expression, Expression)> {
        let mut c_coeff = Expression::Integer(0);
        let mut const_part = Expression::Integer(0);
        if extract_c_terms(expr, c_name, &mut c_coeff, &mut const_part) {
            Some((c_coeff, const_part))
        } else {
            None
        }
    }
    /// Adds `expr` into `c_coeff·C + const_part`; returns `false` if `expr` is not linear in `C`.
    fn extract_c_terms(
        expr: &Expression,
        c_name: &str,
        c_coeff: &mut Expression,
        const_part: &mut Expression,
    ) -> bool {
        match expr {
            Expression::Variable(v) if v.name == c_name => {
                accumulate(c_coeff, Expression::Integer(1));
                true
            }
            Expression::Binary(BinaryOp::Add, left, right) => {
                extract_c_terms(left, c_name, c_coeff, const_part)
                    && extract_c_terms(right, c_name, c_coeff, const_part)
            }
            Expression::Binary(BinaryOp::Sub, left, right) => {
                if !extract_c_terms(left, c_name, c_coeff, const_part) {
                    return false;
                }
                match extract_fresh(right, c_name) {
                    Some((neg_c, neg_const)) => {
                        subtract(c_coeff, neg_c);
                        subtract(const_part, neg_const);
                        true
                    }
                    None => false,
                }
            }
            Expression::Binary(BinaryOp::Mul, left, right) => {
                let left_has_c = left.contains_variable(c_name);
                let right_has_c = right.contains_variable(c_name);
                if left_has_c && right_has_c {
                    // C·C-type products are not linear.
                    return false;
                }
                if !left_has_c && !right_has_c {
                    accumulate(const_part, expr.clone());
                    return true;
                }
                let (c_factor, other) = if left_has_c { (left, right) } else { (right, left) };
                if matches!(c_factor.as_ref(), Expression::Variable(v) if v.name == c_name) {
                    accumulate(c_coeff, other.as_ref().clone());
                    return true;
                }
                // (a·C + b)·k = (a·k)·C + b·k for a C-free factor k.
                match extract_fresh(c_factor, c_name) {
                    Some((inner_c, inner_const)) => {
                        accumulate(
                            c_coeff,
                            Expression::Binary(BinaryOp::Mul, Box::new(inner_c), other.clone()),
                        );
                        if !matches!(inner_const, Expression::Integer(0)) {
                            accumulate(
                                const_part,
                                Expression::Binary(
                                    BinaryOp::Mul,
                                    Box::new(inner_const),
                                    other.clone(),
                                ),
                            );
                        }
                        true
                    }
                    None => false,
                }
            }
            Expression::Binary(BinaryOp::Div, num, denom)
                if num.contains_variable(c_name) && !denom.contains_variable(c_name) =>
            {
                // (a·C + b)/k = (a/k)·C + b/k, e.g. the linear-solver shape (∫μQ + C)/μ.
                match extract_fresh(num, c_name) {
                    Some((inner_c, inner_const)) => {
                        accumulate(
                            c_coeff,
                            Expression::Binary(BinaryOp::Div, Box::new(inner_c), denom.clone()),
                        );
                        if !matches!(inner_const, Expression::Integer(0)) {
                            accumulate(
                                const_part,
                                Expression::Binary(
                                    BinaryOp::Div,
                                    Box::new(inner_const),
                                    denom.clone(),
                                ),
                            );
                        }
                        true
                    }
                    None => false,
                }
            }
            Expression::Unary(UnaryOp::Neg, inner) => match extract_fresh(inner, c_name) {
                Some((neg_c, neg_const)) => {
                    subtract(c_coeff, neg_c);
                    subtract(const_part, neg_const);
                    true
                }
                None => false,
            },
            _ if !expr.contains_variable(c_name) => {
                accumulate(const_part, expr.clone());
                true
            }
            _ => false,
        }
    }

    if !equation.contains_variable(const_name) {
        return None;
    }
    let (c_coefficient, constant_part) = extract_fresh(equation, const_name)?;
    let c_coeff = c_coefficient.simplify();
    // C cancelled out: dividing by a zero coefficient would produce garbage.
    let zero_coeff = matches!(c_coeff, Expression::Integer(0))
        || matches!(c_coeff, Expression::Float(x) if x == 0.0);
    if zero_coeff {
        return None;
    }
    let b = constant_part.simplify();
    let neg_b = Expression::Unary(UnaryOp::Neg, Box::new(b));
    let c_value = Expression::Binary(BinaryOp::Div, Box::new(neg_b), Box::new(c_coeff)).simplify();
    Some(c_value)
}
/// Classification of the roots of the characteristic polynomial `a·r² + b·r + c`.
///
/// `solve_characteristic_equation` picks the variant from the sign of the
/// discriminant `Δ = b² - 4ac`, treating values within a small tolerance of zero
/// as zero.
#[derive(Debug, Clone, PartialEq)]
pub enum RootType {
    /// `Δ > 0`: two different real roots.
    TwoDistinctReal,
    /// `Δ ≈ 0`: one real root of multiplicity two.
    RepeatedReal,
    /// `Δ < 0`: a complex-conjugate pair `α ± βi`.
    ComplexConjugate,
}
/// Roots of `a·r² + b·r + c`, as produced by `solve_characteristic_equation`.
///
/// The meaning of `r1`/`r2` depends on `root_type`:
///
/// * `TwoDistinctReal`: `r1 = (-b + √Δ)/(2a)` and `r2 = (-b - √Δ)/(2a)`.
/// * `RepeatedReal`: `r1 == r2 == -b/(2a)`.
/// * `ComplexConjugate`: the roots are `r1 ± r2·i`. Here `r1` is the real
///   part α and `r2` the imaginary part β; they are NOT two separate roots.
#[derive(Debug, Clone)]
pub struct CharacteristicRoots {
    /// First real root, or the real part α for complex roots.
    pub r1: f64,
    /// Second real root, or the imaginary part β for complex roots.
    pub r2: f64,
    /// How to interpret `r1` and `r2`.
    pub root_type: RootType,
}
/// A constant-coefficient second-order linear ODE `a·y'' + b·y' + c·y = forcing`.
#[derive(Debug, Clone)]
pub struct SecondOrderODE {
    /// Name of the unknown function (e.g. `y`).
    pub dependent: String,
    /// Name of the independent variable (e.g. `x`).
    pub independent: String,
    /// Coefficient of the second derivative.
    pub a: f64,
    /// Coefficient of the first derivative.
    pub b: f64,
    /// Coefficient of the unknown itself.
    pub c: f64,
    /// Right-hand side; the literal `0` for homogeneous equations.
    pub forcing: Expression,
}

impl SecondOrderODE {
    /// Creates an equation with an arbitrary forcing term.
    pub fn new(
        dependent: &str,
        independent: &str,
        a: f64,
        b: f64,
        c: f64,
        forcing: Expression,
    ) -> Self {
        Self {
            dependent: dependent.to_owned(),
            independent: independent.to_owned(),
            a,
            b,
            c,
            forcing,
        }
    }

    /// Creates the homogeneous equation `a·y'' + b·y' + c·y = 0`.
    pub fn homogeneous(dependent: &str, independent: &str, a: f64, b: f64, c: f64) -> Self {
        let zero = Expression::Integer(0);
        Self::new(dependent, independent, a, b, c, zero)
    }

    /// Returns `true` when the forcing term is a literal integer `0` or a float
    /// with magnitude below `1e-15`. Symbolic zeros are not recognised.
    pub fn is_homogeneous(&self) -> bool {
        match &self.forcing {
            Expression::Integer(0) => true,
            Expression::Float(x) => x.abs() < 1e-15,
            _ => false,
        }
    }
}
/// Result of solving a constant-coefficient second-order ODE.
#[derive(Debug, Clone)]
pub struct SecondOrderSolution {
    /// Solution of the homogeneous equation, containing the constants `C1` and `C2`.
    pub homogeneous_solution: Expression,
    /// Particular solution for the forcing term; `solve_second_order_homogeneous` sets `None`.
    pub particular_solution: Option<Expression>,
    /// Full general solution; `solve_second_order_homogeneous` sets it equal to
    /// `homogeneous_solution`.
    pub general_solution: Expression,
    /// Human-readable name of the method used.
    pub method: String,
    /// The characteristic roots the solution was built from.
    pub roots: CharacteristicRoots,
    /// Human-readable derivation, one entry per step, in order.
    pub steps: Vec<String>,
}
/// Finds the roots of the characteristic polynomial `a·r² + b·r + c`.
///
/// The discriminant `Δ = b² - 4ac` is classified against a tolerance
/// relative to the magnitudes of `b²` and `4ac`. This makes the classification
/// independent of the overall scale of the coefficients. Distinct real roots
/// use the numerically stable form of the quadratic formula, which avoids the
/// cancellation in `-b ± √Δ`. They are still reported with
/// `r1 = (-b + √Δ)/(2a)` and `r2 = (-b - √Δ)/(2a)`. For complex roots
/// `α ± βi`, `r1 = α` and `r2 = β ≥ 0`.
///
/// # Errors
///
/// [`ODEError::CharacteristicEquationError`] when:
///
/// * a coefficient is NaN or infinite;
/// * `|a| < 1e-15`, so the equation is not second order;
/// * the discriminant overflows.
#[must_use = "solving returns a result that should be used"]
pub fn solve_characteristic_equation(
    a: f64,
    b: f64,
    c: f64,
) -> Result<CharacteristicRoots, ODEError> {
    if !(a.is_finite() && b.is_finite() && c.is_finite()) {
        return Err(ODEError::CharacteristicEquationError(
            "Coefficients must be finite numbers".to_string(),
        ));
    }
    if a.abs() < 1e-15 {
        return Err(ODEError::CharacteristicEquationError(
            "Coefficient 'a' cannot be zero for second-order ODE".to_string(),
        ));
    }
    let discriminant = b * b - 4.0 * a * c;
    if !discriminant.is_finite() {
        return Err(ODEError::CharacteristicEquationError(
            "Discriminant is not finite (coefficients too large)".to_string(),
        ));
    }
    // Δ is the difference of b² and 4ac, so its rounding error scales with
    // their magnitude; an absolute threshold misclassifies small or large
    // coefficient sets (e.g. 1e-8·r² - 1e-8 would look like a repeated root).
    const RELATIVE_EPSILON: f64 = 1e-10;
    let tolerance = RELATIVE_EPSILON * (b * b).max((4.0 * a * c).abs());
    if discriminant > tolerance {
        let sqrt_disc = discriminant.sqrt();
        // q has the sign of -b, so b and ±√Δ never cancel; q ≠ 0 because √Δ > 0.
        // The two roots are q/a and c/q; map them back to the ± convention.
        let (r1, r2) = if b >= 0.0 {
            let q = -0.5 * (b + sqrt_disc);
            (c / q, q / a)
        } else {
            let q = -0.5 * (b - sqrt_disc);
            (q / a, c / q)
        };
        Ok(CharacteristicRoots {
            r1,
            r2,
            root_type: RootType::TwoDistinctReal,
        })
    } else if discriminant < -tolerance {
        let alpha = -b / (2.0 * a);
        // |2a| keeps β non-negative when a < 0, matching the `α ± βi` convention.
        let beta = (-discriminant).sqrt() / (2.0 * a).abs();
        Ok(CharacteristicRoots {
            r1: alpha,
            r2: beta,
            root_type: RootType::ComplexConjugate,
        })
    } else {
        let r = -b / (2.0 * a);
        Ok(CharacteristicRoots {
            r1: r,
            r2: r,
            root_type: RootType::RepeatedReal,
        })
    }
}
/// Builds `C1·e^(r1·x) + C2·e^(r2·x)` in the variable `x_var`.
fn build_solution_distinct_real(r1: f64, r2: f64, x_var: &str) -> Expression {
    // One exponential mode: constant · e^(rate · x).
    let mode = |constant: &str, rate: f64| {
        let exponent = Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Float(rate)),
            Box::new(Expression::Variable(Variable::new(x_var))),
        );
        Expression::Binary(
            BinaryOp::Mul,
            Box::new(Expression::Variable(Variable::new(constant))),
            Box::new(Expression::Function(Function::Exp, vec![exponent])),
        )
    };
    let first = mode("C1", r1);
    let second = mode("C2", r2);
    Expression::Binary(BinaryOp::Add, Box::new(first), Box::new(second))
}
/// Builds `(C1 + C2·x)·e^(r·x)` in the variable `x_var`.
fn build_solution_repeated(r: f64, x_var: &str) -> Expression {
    let var = |name: &str| Expression::Variable(Variable::new(name));
    let slope_term = Expression::Binary(BinaryOp::Mul, Box::new(var("C2")), Box::new(var(x_var)));
    let polynomial = Expression::Binary(BinaryOp::Add, Box::new(var("C1")), Box::new(slope_term));
    let exponent = Expression::Binary(
        BinaryOp::Mul,
        Box::new(Expression::Float(r)),
        Box::new(var(x_var)),
    );
    let growth = Expression::Function(Function::Exp, vec![exponent]);
    Expression::Binary(BinaryOp::Mul, Box::new(polynomial), Box::new(growth))
}
fn build_solution_complex(alpha: f64, beta: f64, x_var: &str) -> Expression {
let x = Expression::Variable(Variable::new(x_var));
let c1 = Expression::Variable(Variable::new("C1"));
let c2 = Expression::Variable(Variable::new("C2"));
let beta_x = Expression::Binary(
BinaryOp::Mul,
Box::new(Expression::Float(beta)),
Box::new(x.clone()),
);
let cos_term = Expression::Function(Function::Cos, vec![beta_x.clone()]);
let term1 = Expression::Binary(BinaryOp::Mul, Box::new(c1), Box::new(cos_term));
let sin_term = Expression::Function(Function::Sin, vec![beta_x]);
let term2 = Expression::Binary(BinaryOp::Mul, Box::new(c2), Box::new(sin_term));
let oscillatory = Expression::Binary(BinaryOp::Add, Box::new(term1), Box::new(term2));
if alpha.abs() < 1e-10 {
oscillatory
} else {
let exp_arg = Expression::Binary(
BinaryOp::Mul,
Box::new(Expression::Float(alpha)),
Box::new(x),
);
let exp_term = Expression::Function(Function::Exp, vec![exp_arg]);
Expression::Binary(BinaryOp::Mul, Box::new(exp_term), Box::new(oscillatory))
}
}
/// Solves the homogeneous equation `a·y'' + b·y' + c·y = 0` via its
/// characteristic polynomial `a·r² + b·r + c`.
///
/// The general solution depends on the roots from `solve_characteristic_equation`:
///
/// * two distinct real roots: `C1·e^(r₁x) + C2·e^(r₂x)`;
/// * a repeated root: `(C1 + C2·x)·e^(rx)`;
/// * complex roots `α ± βi`: `e^(αx)·(C1·cos(βx) + C2·sin(βx))`, with the
///   exponential omitted when `|α| < 1e-10`.
///
/// The forcing term of `ode` is not read. `particular_solution` is always
/// `None`, and `general_solution` equals `homogeneous_solution`.
///
/// # Errors
///
/// Propagates the error from `solve_characteristic_equation`.
#[must_use = "solving returns a result that should be used"]
pub fn solve_second_order_homogeneous(
    ode: &SecondOrderODE,
) -> Result<SecondOrderSolution, ODEError> {
    let y = &ode.dependent;
    let x = &ode.independent;
    let mut steps = Vec::new();
    steps.push(format!(
        "Given ODE: {}·{}'' + {}·{}' + {}·{} = 0",
        ode.a, y, ode.b, y, ode.c, y
    ));
    steps.push(format!(
        "Characteristic equation: {}·r² + {}·r + {} = 0",
        ode.a, ode.b, ode.c
    ));
    let roots = solve_characteristic_equation(ode.a, ode.b, ode.c)?;
    // Reported in the steps; computed once instead of per branch.
    let discriminant = ode.b * ode.b - 4.0 * ode.a * ode.c;
    let (method, solution) = match roots.root_type {
        RootType::TwoDistinctReal => {
            steps.push(format!(
                "Discriminant Δ = {}² - 4·{}·{} = {} > 0",
                ode.b, ode.a, ode.c, discriminant
            ));
            steps.push(format!(
                "Two distinct real roots: r₁ = {:.4}, r₂ = {:.4}",
                roots.r1, roots.r2
            ));
            steps.push(format!(
                "General solution: {} = C1·e^({:.4}·{}) + C2·e^({:.4}·{})",
                y, roots.r1, x, roots.r2, x
            ));
            (
                "Characteristic equation - distinct real roots".to_string(),
                build_solution_distinct_real(roots.r1, roots.r2, x),
            )
        }
        RootType::RepeatedReal => {
            steps.push(format!(
                "Discriminant Δ = {}² - 4·{}·{} = 0",
                ode.b, ode.a, ode.c
            ));
            steps.push(format!("Repeated root: r = {:.4}", roots.r1));
            steps.push(format!(
                "General solution: {} = (C1 + C2·{})·e^({:.4}·{})",
                y, x, roots.r1, x
            ));
            (
                "Characteristic equation - repeated root".to_string(),
                build_solution_repeated(roots.r1, x),
            )
        }
        RootType::ComplexConjugate => {
            steps.push(format!(
                "Discriminant Δ = {}² - 4·{}·{} = {} < 0",
                ode.b, ode.a, ode.c, discriminant
            ));
            steps.push(format!(
                "Complex conjugate roots: r = {:.4} ± {:.4}i",
                roots.r1, roots.r2
            ));
            // Same threshold build_solution_complex uses to drop the envelope.
            if roots.r1.abs() < 1e-10 {
                steps.push(format!(
                    "General solution: {} = C1·cos({:.4}·{}) + C2·sin({:.4}·{})",
                    y, roots.r2, x, roots.r2, x
                ));
            } else {
                steps.push(format!(
                    "General solution: {} = e^({:.4}·{})·(C1·cos({:.4}·{}) + C2·sin({:.4}·{}))",
                    y, roots.r1, x, roots.r2, x, roots.r2, x
                ));
            }
            (
                "Characteristic equation - complex conjugate roots".to_string(),
                build_solution_complex(roots.r1, roots.r2, x),
            )
        }
    };
    Ok(SecondOrderSolution {
        homogeneous_solution: solution.clone(),
        particular_solution: None,
        general_solution: solution,
        method,
        roots,
        steps,
    })
}
/// Solves the homogeneous IVP `a·y'' + b·y' + c·y = 0`, `y(x0) = y0`, `y'(x0) = yp0`.
///
/// The general solution from `solve_second_order_homogeneous` is fitted to the
/// initial data. Each 2×2 system is solved after dividing out the exponential
/// factors at `x0`. Those factors are positive but can underflow for moderate
/// `|x0|`. In the original formulation they made the determinant look zero
/// and produced a spurious "singular system" error. The reduced determinants
/// are `r₂ - r₁` (distinct roots) and `β` (complex roots), which never vanish
/// for the root types produced.
///
/// Returns the particular solution with `C1`/`C2` substituted, after `simplify()`.
///
/// # Errors
///
/// * [`ODEError::CannotSolve`] if `ode` is not homogeneous.
/// * Any error from the characteristic-equation solver.
/// * [`ODEError::InitialConditionError`] if the reduced system is degenerate,
///   or if the constants come out non-finite (overflow or non-finite inputs).
pub fn solve_second_order_ivp(
    ode: &SecondOrderODE,
    x0: f64,
    y0: f64,
    yp0: f64,
) -> Result<Expression, ODEError> {
    if !ode.is_homogeneous() {
        return Err(ODEError::CannotSolve(
            "IVP solver currently only supports homogeneous equations".to_string(),
        ));
    }
    let solution = solve_second_order_homogeneous(ode)?;
    let singular = || {
        ODEError::InitialConditionError("Cannot determine constants - singular system".to_string())
    };
    let (c1, c2) = match solution.roots.root_type {
        RootType::TwoDistinctReal => {
            let r1 = solution.roots.r1;
            let r2 = solution.roots.r2;
            let gap = r2 - r1;
            if gap == 0.0 || !gap.is_finite() {
                return Err(singular());
            }
            // With u_i = C_i·e^(r_i·x0): u1 + u2 = y0 and r1·u1 + r2·u2 = y'0.
            let u1 = (y0 * r2 - yp0) / gap;
            let u2 = (yp0 - r1 * y0) / gap;
            (u1 * (-r1 * x0).exp(), u2 * (-r2 * x0).exp())
        }
        RootType::RepeatedReal => {
            let r = solution.roots.r1;
            // decay = 1 / e^(r·x0)
            let decay = (-r * x0).exp();
            let c2 = (yp0 - r * y0) * decay;
            let c1 = y0 * decay - c2 * x0;
            (c1, c2)
        }
        RootType::ComplexConjugate => {
            let alpha = solution.roots.r1;
            let beta = solution.roots.r2;
            if beta == 0.0 || !beta.is_finite() {
                return Err(singular());
            }
            let (sin_bx0, cos_bx0) = (beta * x0).sin_cos();
            // Rows of the system with the common factor e^(α·x0) removed;
            // its determinant is β·(cos² + sin²) = β.
            let a11 = cos_bx0;
            let a12 = sin_bx0;
            let a21 = alpha * cos_bx0 - beta * sin_bx0;
            let a22 = alpha * sin_bx0 + beta * cos_bx0;
            let rescale = (-alpha * x0).exp();
            let c1 = (y0 * a22 - yp0 * a12) / beta * rescale;
            let c2 = (yp0 * a11 - y0 * a21) / beta * rescale;
            (c1, c2)
        }
    };
    if !(c1.is_finite() && c2.is_finite()) {
        return Err(ODEError::InitialConditionError(
            "Cannot determine constants - non-finite result".to_string(),
        ));
    }
    let general = solution.general_solution;
    let with_c1 = substitute_var(&general, "C1", &Expression::Float(c1));
    let with_c2 = substitute_var(&with_c1, "C2", &Expression::Float(c2));
    Ok(with_c2.simplify())
}
// Unit tests for the first-order helpers/solvers and the second-order
// constant-coefficient solver.
#[cfg(test)]
mod tests {
    use super::*;

    // Small constructors that keep the expression trees in the tests readable.
    fn var(name: &str) -> Expression {
        Expression::Variable(Variable::new(name))
    }
    fn int(n: i64) -> Expression {
        Expression::Integer(n)
    }
    fn mul(left: Expression, right: Expression) -> Expression {
        Expression::Binary(BinaryOp::Mul, Box::new(left), Box::new(right))
    }
    fn add(left: Expression, right: Expression) -> Expression {
        Expression::Binary(BinaryOp::Add, Box::new(left), Box::new(right))
    }
    // NOTE(review): `div` is not used by any test below.
    fn div(left: Expression, right: Expression) -> Expression {
        Expression::Binary(BinaryOp::Div, Box::new(left), Box::new(right))
    }
    fn neg(expr: Expression) -> Expression {
        Expression::Unary(UnaryOp::Neg, Box::new(expr))
    }

    // x·y splits into g = x and h = y.
    #[test]
    fn test_try_separate_simple_product() {
        let expr = mul(var("x"), var("y"));
        let result = try_separate(&expr, "x", "y");
        assert!(result.is_some());
        let (g_x, h_y) = result.unwrap();
        assert!(matches!(g_x, Expression::Variable(v) if v.name == "x"));
        assert!(matches!(h_y, Expression::Variable(v) if v.name == "y"));
    }

    // A pure function of x goes entirely to g, with h = 1.
    #[test]
    fn test_try_separate_only_x() {
        let x = var("x");
        let expr = Expression::Power(Box::new(x), Box::new(int(2)));
        let result = try_separate(&expr, "x", "y");
        assert!(result.is_some());
        let (g_x, h_y) = result.unwrap();
        assert!(matches!(g_x, Expression::Power(_, _)));
        assert!(matches!(h_y, Expression::Integer(1)));
    }

    // A pure function of y goes entirely to h, with g = 1.
    #[test]
    fn test_try_separate_only_y() {
        let y = var("y");
        let expr = Expression::Power(Box::new(y), Box::new(int(2)));
        let result = try_separate(&expr, "x", "y");
        assert!(result.is_some());
        let (g_x, h_y) = result.unwrap();
        assert!(matches!(g_x, Expression::Integer(1)));
        assert!(matches!(h_y, Expression::Power(_, _)));
    }

    // Constants are assigned to g.
    #[test]
    fn test_try_separate_constant() {
        let expr = int(5);
        let result = try_separate(&expr, "x", "y");
        assert!(result.is_some());
        let (g_x, h_y) = result.unwrap();
        assert!(matches!(g_x, Expression::Integer(5)));
        assert!(matches!(h_y, Expression::Integer(1)));
    }

    // x·y is separable; x + y is not.
    #[test]
    fn test_is_separable() {
        let ode = FirstOrderODE::new("y", "x", mul(var("x"), var("y")));
        assert!(ode.is_separable());
        let ode2 = FirstOrderODE::new("y", "x", add(var("x"), var("y")));
        assert!(!ode2.is_separable());
    }

    // -y + x is linear in y; y² is not.
    #[test]
    fn test_is_linear() {
        let ode = FirstOrderODE::new("y", "x", add(neg(var("y")), var("x")));
        assert!(ode.is_linear());
        let y = var("y");
        let ode2 = FirstOrderODE::new("y", "x", Expression::Power(Box::new(y), Box::new(int(2))));
        assert!(!ode2.is_linear());
    }

    // Only checks that -2·y + 3·x is recognised, not the coefficient values.
    #[test]
    fn test_extract_linear_coefficients() {
        let rhs = add(mul(int(-2), var("y")), mul(int(3), var("x")));
        let result = extract_linear_coefficients(&rhs, "x", "y");
        assert!(result.is_some());
    }

    #[test]
    fn test_solve_separable_simple() {
        let ode = FirstOrderODE::new("y", "x", var("y"));
        let result = solve_separable(&ode);
        assert!(result.is_ok());
        let solution = result.unwrap();
        assert_eq!(solution.method, "Separation of variables");
        assert!(!solution.steps.is_empty());
    }

    #[test]
    fn test_solve_separable_xy() {
        let ode = FirstOrderODE::new("y", "x", mul(var("x"), var("y")));
        let result = solve_separable(&ode);
        assert!(result.is_ok());
    }

    #[test]
    fn test_solve_linear_simple() {
        let ode = FirstOrderODE::new("y", "x", neg(var("y")));
        let result = solve_linear(&ode);
        assert!(result.is_ok());
        let solution = result.unwrap();
        assert_eq!(solution.method, "Integrating factor");
    }

    // y' = y with y(0) = 1; only success is asserted.
    #[test]
    fn test_solve_ivp() {
        let ode = FirstOrderODE::new("y", "x", var("y"));
        let result = solve_ivp(&ode, &int(0), &int(1));
        assert!(result.is_ok());
    }

    #[test]
    fn test_substitute_var() {
        let expr = add(var("x"), var("y"));
        let result = substitute_var(&expr, "x", &int(5));
        assert!(matches!(
            result,
            Expression::Binary(BinaryOp::Add, left, _) if matches!(left.as_ref(), Expression::Integer(5))
        ));
    }

    // ln(y) = x + C is inverted to an exponential.
    #[test]
    fn test_try_solve_implicit_ln_y() {
        let left = Expression::Function(Function::Ln, vec![var("y")]);
        let right = add(var("x"), var("C"));
        let result = try_solve_implicit_for_y(&left, &right, "y");
        assert!(result.is_some());
        assert!(matches!(
            result.unwrap(),
            Expression::Function(Function::Exp, _)
        ));
    }

    // r² - 1 = 0: roots +1 (r1) and -1 (r2).
    #[test]
    fn test_characteristic_equation_distinct_real() {
        let roots = solve_characteristic_equation(1.0, 0.0, -1.0).unwrap();
        assert_eq!(roots.root_type, RootType::TwoDistinctReal);
        assert!((roots.r1 - 1.0).abs() < 1e-10);
        assert!((roots.r2 - (-1.0)).abs() < 1e-10);
    }

    // r² + 1 = 0: r1 holds the real part (0), r2 the imaginary part (1).
    #[test]
    fn test_characteristic_equation_complex() {
        let roots = solve_characteristic_equation(1.0, 0.0, 1.0).unwrap();
        assert_eq!(roots.root_type, RootType::ComplexConjugate);
        assert!(roots.r1.abs() < 1e-10);
        assert!((roots.r2 - 1.0).abs() < 1e-10);
    }

    // (r - 1)² = 0: repeated root 1.
    #[test]
    fn test_characteristic_equation_repeated() {
        let roots = solve_characteristic_equation(1.0, -2.0, 1.0).unwrap();
        assert_eq!(roots.root_type, RootType::RepeatedReal);
        assert!((roots.r1 - 1.0).abs() < 1e-10);
        assert!((roots.r2 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_second_order_homogeneous_distinct_real() {
        let ode = SecondOrderODE::homogeneous("y", "x", 1.0, 0.0, -1.0);
        let solution = solve_second_order_homogeneous(&ode).unwrap();
        assert_eq!(
            solution.method,
            "Characteristic equation - distinct real roots"
        );
        assert_eq!(solution.roots.root_type, RootType::TwoDistinctReal);
        assert!(!solution.steps.is_empty());
    }

    #[test]
    fn test_second_order_homogeneous_complex() {
        let ode = SecondOrderODE::homogeneous("y", "x", 1.0, 0.0, 1.0);
        let solution = solve_second_order_homogeneous(&ode).unwrap();
        assert_eq!(
            solution.method,
            "Characteristic equation - complex conjugate roots"
        );
        assert_eq!(solution.roots.root_type, RootType::ComplexConjugate);
    }

    #[test]
    fn test_second_order_homogeneous_repeated() {
        let ode = SecondOrderODE::homogeneous("y", "x", 1.0, -2.0, 1.0);
        let solution = solve_second_order_homogeneous(&ode).unwrap();
        assert_eq!(solution.method, "Characteristic equation - repeated root");
        assert_eq!(solution.roots.root_type, RootType::RepeatedReal);
    }

    // y'' + y = 0, y(0) = 1, y'(0) = 0 has solution cos(x): 1 at 0, 0 at π/2.
    #[test]
    fn test_second_order_ivp_complex() {
        let ode = SecondOrderODE::homogeneous("y", "x", 1.0, 0.0, 1.0);
        let solution = solve_second_order_ivp(&ode, 0.0, 1.0, 0.0).unwrap();
        let mut vars = std::collections::HashMap::new();
        vars.insert("x".to_string(), 0.0);
        let result = solution.evaluate(&vars).unwrap();
        assert!((result - 1.0).abs() < 1e-10);
        vars.insert("x".to_string(), std::f64::consts::FRAC_PI_2);
        let result = solution.evaluate(&vars).unwrap();
        assert!(result.abs() < 1e-6);
    }

    // y'' - y = 0, y(0) = 1, y'(0) = 0 has solution cosh(x).
    #[test]
    fn test_second_order_ivp_distinct_real() {
        let ode = SecondOrderODE::homogeneous("y", "x", 1.0, 0.0, -1.0);
        let solution = solve_second_order_ivp(&ode, 0.0, 1.0, 0.0).unwrap();
        let mut vars = std::collections::HashMap::new();
        vars.insert("x".to_string(), 0.0);
        let result = solution.evaluate(&vars).unwrap();
        assert!((result - 1.0).abs() < 1e-10);
        vars.insert("x".to_string(), 1.0);
        let result = solution.evaluate(&vars).unwrap();
        let expected = 1.0_f64.cosh();
        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_second_order_ode_is_homogeneous() {
        let ode1 = SecondOrderODE::homogeneous("y", "x", 1.0, 2.0, 3.0);
        assert!(ode1.is_homogeneous());
        let ode2 = SecondOrderODE::new("y", "x", 1.0, 2.0, 3.0, var("x"));
        assert!(!ode2.is_homogeneous());
    }
}