use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
use crate::algebra::solvers::{EquationSolver, SolverResult};
use crate::algebra::solvers::{LinearSolver, PolynomialSolver, QuadraticSolver, SystemSolver};
use crate::calculus::ode::EducationalODESolver;
use crate::calculus::pde::EducationalPDESolver;
use crate::core::symbol::SymbolType;
use crate::core::{Expression, Number, Symbol};
use crate::educational::step_by_step::{Step, StepByStepExplanation};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EquationType {
Constant,
Linear,
Quadratic,
Cubic,
Quartic,
System,
Transcendental,
Numerical,
Matrix,
ODE,
PDE,
Unknown,
}
pub struct EquationAnalyzer;
impl EquationAnalyzer {
pub fn analyze(equation: &Expression, variable: &Symbol) -> EquationType {
let has_derivatives = Self::has_derivatives(equation);
let has_partial_derivatives = Self::has_partial_derivatives(equation);
if has_partial_derivatives {
return EquationType::PDE;
}
if has_derivatives {
return EquationType::ODE;
}
if Self::is_matrix_equation(equation, variable) {
return EquationType::Matrix;
}
let degree = Self::find_highest_degree(equation, variable);
let has_transcendental = Self::has_transcendental_functions(equation);
let variable_count = Self::count_variables(equation);
if Self::is_numerical_equation(equation, variable, degree, has_transcendental) {
return EquationType::Numerical;
}
match (degree, has_transcendental, variable_count) {
(0, false, _) => EquationType::Constant,
(1, false, 1) => EquationType::Linear,
(2, false, 1) => EquationType::Quadratic,
(3, false, 1) => EquationType::Cubic,
(4, false, 1) => EquationType::Quartic,
(_, false, 2..) => EquationType::System,
(_, true, _) => EquationType::Transcendental,
_ => EquationType::Unknown,
}
}
fn is_numerical_equation(
expr: &Expression,
_variable: &Symbol,
degree: u32,
has_transcendental: bool,
) -> bool {
if degree > 4 {
return true;
}
if has_transcendental && degree > 0 {
return true;
}
if has_transcendental {
let func_count = Self::count_transcendental_functions(expr);
if func_count > 1 {
return true;
}
}
false
}
fn count_transcendental_functions(expr: &Expression) -> usize {
match expr {
Expression::Function { name, args } => {
let current =
if matches!(name.as_ref(), "sin" | "cos" | "tan" | "exp" | "ln" | "log") {
1
} else {
0
};
current
+ args
.iter()
.map(Self::count_transcendental_functions)
.sum::<usize>()
}
Expression::Add(terms) => terms.iter().map(Self::count_transcendental_functions).sum(),
Expression::Mul(factors) => factors
.iter()
.map(Self::count_transcendental_functions)
.sum(),
Expression::Pow(base, exp) => {
Self::count_transcendental_functions(base)
+ Self::count_transcendental_functions(exp)
}
_ => 0,
}
}
fn is_matrix_equation(expr: &Expression, _variable: &Symbol) -> bool {
Self::has_noncommutative_symbols(expr)
}
fn has_noncommutative_symbols(expr: &Expression) -> bool {
match expr {
Expression::Symbol(s) => {
matches!(
s.symbol_type(),
SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion
)
}
Expression::Add(terms) | Expression::Mul(terms) => {
terms.iter().any(Self::has_noncommutative_symbols)
}
Expression::Pow(base, exp) => {
Self::has_noncommutative_symbols(base) || Self::has_noncommutative_symbols(exp)
}
Expression::Function { args, .. } => args.iter().any(Self::has_noncommutative_symbols),
_ => false,
}
}
fn find_highest_degree(expr: &Expression, variable: &Symbol) -> u32 {
match expr {
Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
match exp.as_ref() {
Expression::Number(Number::Integer(n)) => *n as u32,
_ => 1,
}
}
Expression::Mul(factors) => factors
.iter()
.map(|f| Self::find_highest_degree(f, variable))
.max()
.unwrap_or(0),
Expression::Add(terms) => terms
.iter()
.map(|t| Self::find_highest_degree(t, variable))
.max()
.unwrap_or(0),
_ if *expr == Expression::symbol(variable.clone()) => 1,
_ => 0,
}
}
fn has_transcendental_functions(expr: &Expression) -> bool {
match expr {
Expression::Function { name, args } => {
matches!(name.as_ref(), "sin" | "cos" | "tan" | "exp" | "ln" | "log")
|| args.iter().any(Self::has_transcendental_functions)
}
Expression::Add(terms) => terms.iter().any(Self::has_transcendental_functions),
Expression::Mul(factors) => factors.iter().any(Self::has_transcendental_functions),
Expression::Pow(base, exp) => {
Self::has_transcendental_functions(base) || Self::has_transcendental_functions(exp)
}
_ => false,
}
}
fn count_variables(expr: &Expression) -> usize {
let mut variables = std::collections::HashSet::new();
Self::collect_variables(expr, &mut variables);
variables.len()
}
pub fn collect_variables(expr: &Expression, variables: &mut std::collections::HashSet<String>) {
match expr {
Expression::Symbol(s) => {
variables.insert(s.name().to_owned());
}
Expression::Add(terms) => {
for term in terms.iter() {
Self::collect_variables(term, variables);
}
}
Expression::Mul(factors) => {
for factor in factors.iter() {
Self::collect_variables(factor, variables);
}
}
Expression::Pow(base, exp) => {
Self::collect_variables(base, variables);
Self::collect_variables(exp, variables);
}
Expression::Function { args, .. } => {
for arg in args.iter() {
Self::collect_variables(arg, variables);
}
}
_ => {}
}
}
fn has_derivatives(expr: &Expression) -> bool {
match expr {
Expression::Function { name, args } => {
matches!(name.as_ref(), "derivative" | "diff" | "D")
|| args.iter().any(Self::has_derivatives)
}
Expression::Symbol(s) => {
let name = s.name();
name.ends_with('\'') || name.contains("_prime")
}
Expression::Add(terms) => terms.iter().any(Self::has_derivatives),
Expression::Mul(factors) => factors.iter().any(Self::has_derivatives),
Expression::Pow(base, exp) => Self::has_derivatives(base) || Self::has_derivatives(exp),
_ => false,
}
}
fn has_partial_derivatives(expr: &Expression) -> bool {
match expr {
Expression::Function { name, args } => {
matches!(name.as_ref(), "partial" | "pdiff" | "Partial")
|| args.iter().any(Self::has_partial_derivatives)
}
Expression::Symbol(s) => {
let name = s.name();
name.contains("partial") || name.contains("∂")
}
Expression::Add(terms) => terms.iter().any(Self::has_partial_derivatives),
Expression::Mul(factors) => factors.iter().any(Self::has_partial_derivatives),
Expression::Pow(base, exp) => {
Self::has_partial_derivatives(base) || Self::has_partial_derivatives(exp)
}
_ => false,
}
}
}
pub struct SmartEquationSolver {
linear_solver: LinearSolver,
quadratic_solver: QuadraticSolver,
system_solver: SystemSolver,
polynomial_solver: PolynomialSolver,
matrix_solver: MatrixEquationSolver,
ode_solver: EducationalODESolver,
pde_solver: EducationalPDESolver,
}
impl Default for SmartEquationSolver {
fn default() -> Self {
Self::new()
}
}
impl SmartEquationSolver {
pub fn new() -> Self {
Self {
linear_solver: LinearSolver::new(),
quadratic_solver: QuadraticSolver::new(),
system_solver: SystemSolver::new(),
polynomial_solver: PolynomialSolver::new(),
matrix_solver: MatrixEquationSolver::new(),
ode_solver: EducationalODESolver::new(),
pde_solver: EducationalPDESolver::new(),
}
}
pub fn solve_with_equation(
&self,
equation: &Expression,
variable: &Symbol,
) -> (SolverResult, StepByStepExplanation) {
let mut all_steps = Vec::new();
let degree = EquationAnalyzer::find_highest_degree(equation, variable);
let eq_type = EquationAnalyzer::analyze(equation, variable);
let analysis_description = match eq_type {
EquationType::Constant => {
"Detected constant equation (no variables)".to_owned()
}
EquationType::Linear => {
format!("Detected linear equation (highest degree: {})", degree)
}
EquationType::Quadratic => {
format!("Detected quadratic equation (highest degree: {})", degree)
}
EquationType::Cubic => {
format!("Detected cubic equation (highest degree: {})", degree)
}
EquationType::Quartic => {
format!("Detected quartic equation (highest degree: {})", degree)
}
EquationType::System => {
"Detected system of equations (multiple variables)".to_owned()
}
EquationType::Transcendental => {
"Detected transcendental equation (contains trig/exp/log functions)".to_owned()
}
EquationType::Numerical => {
"Detected numerical equation (requires numerical methods - polynomial degree > 4 or mixed transcendental)".to_owned()
}
EquationType::Matrix => {
"Detected matrix equation (contains noncommutative symbols)".to_owned()
}
EquationType::ODE => {
"Detected ordinary differential equation (contains derivatives)".to_owned()
}
EquationType::PDE => {
"Detected partial differential equation (contains partial derivatives)".to_owned()
}
EquationType::Unknown => {
"Unknown equation type".to_owned()
}
};
all_steps.push(Step::new("Equation Analysis", analysis_description));
let solver_description = match eq_type {
EquationType::Linear => "Using linear equation solver (isolation method)",
EquationType::Quadratic => "Using quadratic equation solver (quadratic formula)",
EquationType::Cubic | EquationType::Quartic => "Using polynomial solver",
EquationType::System => "Using system equation solver",
EquationType::Numerical => {
"Using numerical solver (Newton-Raphson method with numerical differentiation)"
}
EquationType::Matrix => "Using matrix equation solver (left/right division)",
EquationType::ODE => "Using ODE solver (separable/linear/exact methods)",
EquationType::PDE => {
"Using PDE solver (method of characteristics/separation of variables)"
}
_ => "No specialized solver available for this equation type",
};
all_steps.push(Step::new("Solver Selection", solver_description));
let (result, solver_steps) = match eq_type {
EquationType::Linear => self
.linear_solver
.solve_with_explanation(equation, variable),
EquationType::Quadratic => self
.quadratic_solver
.solve_with_explanation(equation, variable),
EquationType::Cubic | EquationType::Quartic => self
.polynomial_solver
.solve_with_explanation(equation, variable),
EquationType::System => self
.system_solver
.solve_with_explanation(equation, variable),
EquationType::Numerical => self.solve_numerical(equation, variable),
EquationType::Matrix => self
.matrix_solver
.solve_with_explanation(equation, variable),
EquationType::ODE => self.ode_solver.solve_with_explanation(equation, variable),
EquationType::PDE => self.pde_solver.solve_with_explanation(equation, variable),
_ => {
all_steps.push(Step::new(
"Status",
"This equation type is not yet fully implemented",
));
(SolverResult::NoSolution, StepByStepExplanation::new(vec![]))
}
};
all_steps.extend(solver_steps.steps);
(result, StepByStepExplanation::new(all_steps))
}
fn solve_numerical(
&self,
_equation: &Expression,
variable: &Symbol,
) -> (SolverResult, StepByStepExplanation) {
let steps = vec![
Step::new(
"Numerical Method Required",
format!(
"This equation requires numerical methods to solve for {}. Newton-Raphson method integration is available.",
variable.name()
),
),
Step::new(
"Method Description",
"Newton-Raphson method with numerical differentiation provides robust convergence for smooth functions.",
),
];
(SolverResult::NoSolution, StepByStepExplanation::new(steps))
}
pub fn solve(&self) -> (SolverResult, StepByStepExplanation) {
let equation = Expression::integer(0);
let variables = self.extract_variables(&equation);
if variables.is_empty() {
return (SolverResult::NoSolution, StepByStepExplanation::new(vec![]));
}
let primary_var = &variables[0];
self.solve_with_equation(&equation, primary_var)
}
fn extract_variables(&self, equation: &Expression) -> Vec<Symbol> {
let mut variables = std::collections::HashSet::new();
EquationAnalyzer::collect_variables(equation, &mut variables);
variables
.into_iter()
.map(|name| Symbol::new(&name))
.collect()
}
pub fn solve_system(&self, equations: &[Expression], variables: &[Symbol]) -> SolverResult {
use crate::algebra::solvers::SystemEquationSolver;
self.system_solver.solve_system(equations, variables)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::symbol;
#[test]
fn test_equation_type_detection() {
let x = symbol!(x);
let linear = Expression::add(vec![
Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
Expression::integer(3),
]);
assert_eq!(EquationAnalyzer::analyze(&linear, &x), EquationType::Linear);
let quadratic = Expression::add(vec![
Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]),
Expression::integer(2),
]);
assert_eq!(
EquationAnalyzer::analyze(&quadratic, &x),
EquationType::Quadratic
);
}
#[test]
fn test_numerical_equation_detection() {
let x = symbol!(x);
let quintic = Expression::add(vec![
Expression::pow(Expression::symbol(x.clone()), Expression::integer(5)),
Expression::mul(vec![Expression::integer(-1), Expression::symbol(x.clone())]),
Expression::integer(-1),
]);
assert_eq!(
EquationAnalyzer::analyze(&quintic, &x),
EquationType::Numerical
);
let transcendental_mixed = Expression::add(vec![
Expression::function("cos", vec![Expression::symbol(x.clone())]),
Expression::mul(vec![Expression::integer(-1), Expression::symbol(x.clone())]),
]);
assert_eq!(
EquationAnalyzer::analyze(&transcendental_mixed, &x),
EquationType::Numerical
);
}
#[test]
fn test_matrix_equation_detection() {
let a = symbol!(A; matrix);
let x = symbol!(X; matrix);
let b = symbol!(B; matrix);
let equation = Expression::add(vec![
Expression::mul(vec![
Expression::symbol(a.clone()),
Expression::symbol(x.clone()),
]),
Expression::mul(vec![Expression::integer(-1), Expression::symbol(b.clone())]),
]);
assert_eq!(
EquationAnalyzer::analyze(&equation, &x),
EquationType::Matrix
);
}
}