mathhook_core/algebra/
equation_analyzer.rs

//! Analyzes equations (as `Expression` trees, e.g. parsed from LaTeX) and routes them to the appropriate solver.
//! This is the "brain" that decides which solver to use.
3
4use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
5use crate::algebra::solvers::{EquationSolver, SolverResult};
6use crate::algebra::solvers::{LinearSolver, PolynomialSolver, QuadraticSolver, SystemSolver};
7use crate::calculus::ode::EducationalODESolver;
8use crate::calculus::pde::EducationalPDESolver;
9use crate::core::symbol::SymbolType;
10use crate::core::{Expression, Number, Symbol};
11use crate::educational::step_by_step::{Step, StepByStepExplanation};
12
/// Types of equations our system can handle.
///
/// Produced by `EquationAnalyzer::analyze` and used by `SmartEquationSolver`
/// to decide which specialised solver receives the equation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EquationType {
    /// Degree 0 in the target variable and no transcendental functions.
    Constant,
    /// Degree 1 in the target variable, no transcendental functions, one distinct symbol.
    Linear,
    /// Degree 2 in the target variable, no transcendental functions, one distinct symbol.
    Quadratic,
    /// Degree 3 in the target variable, no transcendental functions, one distinct symbol.
    Cubic,
    /// Degree 4 in the target variable, no transcendental functions, one distinct symbol.
    Quartic,
    /// Two or more distinct symbols, no transcendental functions, degree 1..=4.
    System,
    /// Contains sin/cos/tan/exp/ln/log but does not meet the numerical-method criteria.
    Transcendental,
    /// Needs numerical methods: degree above 4, or transcendental functions mixed
    /// with polynomial terms / several transcendental calls.
    Numerical,
    /// Contains noncommutative symbols (matrix, operator or quaternion).
    Matrix,
    /// Contains ordinary-derivative notation.
    ODE,
    /// Contains partial-derivative notation.
    PDE,
    /// Fallback when no other category applies.
    Unknown,
}
29
30/// Smart equation analyzer that determines solver routing
31pub struct EquationAnalyzer;
32
33impl EquationAnalyzer {
34    /// Analyze equation and determine type for solver dispatch
35    pub fn analyze(equation: &Expression, variable: &Symbol) -> EquationType {
36        let has_derivatives = Self::has_derivatives(equation);
37        let has_partial_derivatives = Self::has_partial_derivatives(equation);
38
39        if has_partial_derivatives {
40            return EquationType::PDE;
41        }
42
43        if has_derivatives {
44            return EquationType::ODE;
45        }
46
47        if Self::is_matrix_equation(equation, variable) {
48            return EquationType::Matrix;
49        }
50
51        let degree = Self::find_highest_degree(equation, variable);
52        let has_transcendental = Self::has_transcendental_functions(equation);
53        let variable_count = Self::count_variables(equation);
54
55        if Self::is_numerical_equation(equation, variable, degree, has_transcendental) {
56            return EquationType::Numerical;
57        }
58
59        match (degree, has_transcendental, variable_count) {
60            (0, false, _) => EquationType::Constant,
61            (1, false, 1) => EquationType::Linear,
62            (2, false, 1) => EquationType::Quadratic,
63            (3, false, 1) => EquationType::Cubic,
64            (4, false, 1) => EquationType::Quartic,
65            (_, false, 2..) => EquationType::System,
66            (_, true, _) => EquationType::Transcendental,
67            _ => EquationType::Unknown,
68        }
69    }
70
71    fn is_numerical_equation(
72        expr: &Expression,
73        _variable: &Symbol,
74        degree: u32,
75        has_transcendental: bool,
76    ) -> bool {
77        if degree > 4 {
78            return true;
79        }
80
81        if has_transcendental && degree > 0 {
82            return true;
83        }
84
85        if has_transcendental {
86            let func_count = Self::count_transcendental_functions(expr);
87            if func_count > 1 {
88                return true;
89            }
90        }
91
92        false
93    }
94
95    fn count_transcendental_functions(expr: &Expression) -> usize {
96        match expr {
97            Expression::Function { name, args } => {
98                let current =
99                    if matches!(name.as_ref(), "sin" | "cos" | "tan" | "exp" | "ln" | "log") {
100                        1
101                    } else {
102                        0
103                    };
104                current
105                    + args
106                        .iter()
107                        .map(Self::count_transcendental_functions)
108                        .sum::<usize>()
109            }
110            Expression::Add(terms) => terms.iter().map(Self::count_transcendental_functions).sum(),
111            Expression::Mul(factors) => factors
112                .iter()
113                .map(Self::count_transcendental_functions)
114                .sum(),
115            Expression::Pow(base, exp) => {
116                Self::count_transcendental_functions(base)
117                    + Self::count_transcendental_functions(exp)
118            }
119            _ => 0,
120        }
121    }
122
123    fn is_matrix_equation(expr: &Expression, _variable: &Symbol) -> bool {
124        Self::has_noncommutative_symbols(expr)
125    }
126
127    fn has_noncommutative_symbols(expr: &Expression) -> bool {
128        match expr {
129            Expression::Symbol(s) => {
130                matches!(
131                    s.symbol_type(),
132                    SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion
133                )
134            }
135            Expression::Add(terms) | Expression::Mul(terms) => {
136                terms.iter().any(Self::has_noncommutative_symbols)
137            }
138            Expression::Pow(base, exp) => {
139                Self::has_noncommutative_symbols(base) || Self::has_noncommutative_symbols(exp)
140            }
141            Expression::Function { args, .. } => args.iter().any(Self::has_noncommutative_symbols),
142            _ => false,
143        }
144    }
145
146    fn find_highest_degree(expr: &Expression, variable: &Symbol) -> u32 {
147        match expr {
148            Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
149                match exp.as_ref() {
150                    Expression::Number(Number::Integer(n)) => *n as u32,
151                    _ => 1,
152                }
153            }
154            Expression::Mul(factors) => factors
155                .iter()
156                .map(|f| Self::find_highest_degree(f, variable))
157                .max()
158                .unwrap_or(0),
159            Expression::Add(terms) => terms
160                .iter()
161                .map(|t| Self::find_highest_degree(t, variable))
162                .max()
163                .unwrap_or(0),
164            _ if *expr == Expression::symbol(variable.clone()) => 1,
165            _ => 0,
166        }
167    }
168
169    fn has_transcendental_functions(expr: &Expression) -> bool {
170        match expr {
171            Expression::Function { name, args } => {
172                matches!(name.as_ref(), "sin" | "cos" | "tan" | "exp" | "ln" | "log")
173                    || args.iter().any(Self::has_transcendental_functions)
174            }
175            Expression::Add(terms) => terms.iter().any(Self::has_transcendental_functions),
176            Expression::Mul(factors) => factors.iter().any(Self::has_transcendental_functions),
177            Expression::Pow(base, exp) => {
178                Self::has_transcendental_functions(base) || Self::has_transcendental_functions(exp)
179            }
180            _ => false,
181        }
182    }
183
184    fn count_variables(expr: &Expression) -> usize {
185        let mut variables = std::collections::HashSet::new();
186        Self::collect_variables(expr, &mut variables);
187        variables.len()
188    }
189
190    pub fn collect_variables(expr: &Expression, variables: &mut std::collections::HashSet<String>) {
191        match expr {
192            Expression::Symbol(s) => {
193                variables.insert(s.name().to_owned());
194            }
195            Expression::Add(terms) => {
196                for term in terms.iter() {
197                    Self::collect_variables(term, variables);
198                }
199            }
200            Expression::Mul(factors) => {
201                for factor in factors.iter() {
202                    Self::collect_variables(factor, variables);
203                }
204            }
205            Expression::Pow(base, exp) => {
206                Self::collect_variables(base, variables);
207                Self::collect_variables(exp, variables);
208            }
209            Expression::Function { args, .. } => {
210                for arg in args.iter() {
211                    Self::collect_variables(arg, variables);
212                }
213            }
214            _ => {}
215        }
216    }
217
218    fn has_derivatives(expr: &Expression) -> bool {
219        match expr {
220            Expression::Function { name, args } => {
221                matches!(name.as_ref(), "derivative" | "diff" | "D")
222                    || args.iter().any(Self::has_derivatives)
223            }
224            Expression::Symbol(s) => {
225                let name = s.name();
226                name.ends_with('\'') || name.contains("_prime")
227            }
228            Expression::Add(terms) => terms.iter().any(Self::has_derivatives),
229            Expression::Mul(factors) => factors.iter().any(Self::has_derivatives),
230            Expression::Pow(base, exp) => Self::has_derivatives(base) || Self::has_derivatives(exp),
231            _ => false,
232        }
233    }
234
235    fn has_partial_derivatives(expr: &Expression) -> bool {
236        match expr {
237            Expression::Function { name, args } => {
238                matches!(name.as_ref(), "partial" | "pdiff" | "Partial")
239                    || args.iter().any(Self::has_partial_derivatives)
240            }
241            Expression::Symbol(s) => {
242                let name = s.name();
243                name.contains("partial") || name.contains("∂")
244            }
245            Expression::Add(terms) => terms.iter().any(Self::has_partial_derivatives),
246            Expression::Mul(factors) => factors.iter().any(Self::has_partial_derivatives),
247            Expression::Pow(base, exp) => {
248                Self::has_partial_derivatives(base) || Self::has_partial_derivatives(exp)
249            }
250            _ => false,
251        }
252    }
253}
254
255/// Master equation solver with smart dispatch
256pub struct SmartEquationSolver {
257    linear_solver: LinearSolver,
258    quadratic_solver: QuadraticSolver,
259    system_solver: SystemSolver,
260    polynomial_solver: PolynomialSolver,
261    matrix_solver: MatrixEquationSolver,
262    ode_solver: EducationalODESolver,
263    pde_solver: EducationalPDESolver,
264}
265
266impl Default for SmartEquationSolver {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272impl SmartEquationSolver {
273    pub fn new() -> Self {
274        Self {
275            linear_solver: LinearSolver::new(),
276            quadratic_solver: QuadraticSolver::new(),
277            system_solver: SystemSolver::new(),
278            polynomial_solver: PolynomialSolver::new(),
279            matrix_solver: MatrixEquationSolver::new(),
280            ode_solver: EducationalODESolver::new(),
281            pde_solver: EducationalPDESolver::new(),
282        }
283    }
284
285    /// Solve equation with educational explanation, including equation analysis
286    ///
287    /// This is the primary entry point for solving equations with full educational
288    /// integration. It automatically:
289    /// 1. Analyzes the equation type
290    /// 2. Explains the equation structure
291    /// 3. Selects the appropriate solver
292    /// 4. Provides step-by-step solution with explanations
293    ///
294    /// # Arguments
295    ///
296    /// * `equation` - The equation expression to solve
297    /// * `variable` - The variable to solve for
298    ///
299    /// # Returns
300    ///
301    /// A tuple containing:
302    /// - The solver result (solutions or error)
303    /// - Complete step-by-step explanation starting with equation analysis
304    pub fn solve_with_equation(
305        &self,
306        equation: &Expression,
307        variable: &Symbol,
308    ) -> (SolverResult, StepByStepExplanation) {
309        let mut all_steps = Vec::new();
310
311        let degree = EquationAnalyzer::find_highest_degree(equation, variable);
312        let eq_type = EquationAnalyzer::analyze(equation, variable);
313
314        let analysis_description = match eq_type {
315            EquationType::Constant => {
316                "Detected constant equation (no variables)".to_owned()
317            }
318            EquationType::Linear => {
319                format!("Detected linear equation (highest degree: {})", degree)
320            }
321            EquationType::Quadratic => {
322                format!("Detected quadratic equation (highest degree: {})", degree)
323            }
324            EquationType::Cubic => {
325                format!("Detected cubic equation (highest degree: {})", degree)
326            }
327            EquationType::Quartic => {
328                format!("Detected quartic equation (highest degree: {})", degree)
329            }
330            EquationType::System => {
331                "Detected system of equations (multiple variables)".to_owned()
332            }
333            EquationType::Transcendental => {
334                "Detected transcendental equation (contains trig/exp/log functions)".to_owned()
335            }
336            EquationType::Numerical => {
337                "Detected numerical equation (requires numerical methods - polynomial degree > 4 or mixed transcendental)".to_owned()
338            }
339            EquationType::Matrix => {
340                "Detected matrix equation (contains noncommutative symbols)".to_owned()
341            }
342            EquationType::ODE => {
343                "Detected ordinary differential equation (contains derivatives)".to_owned()
344            }
345            EquationType::PDE => {
346                "Detected partial differential equation (contains partial derivatives)".to_owned()
347            }
348            EquationType::Unknown => {
349                "Unknown equation type".to_owned()
350            }
351        };
352
353        all_steps.push(Step::new("Equation Analysis", analysis_description));
354
355        let solver_description = match eq_type {
356            EquationType::Linear => "Using linear equation solver (isolation method)",
357            EquationType::Quadratic => "Using quadratic equation solver (quadratic formula)",
358            EquationType::Cubic | EquationType::Quartic => "Using polynomial solver",
359            EquationType::System => "Using system equation solver",
360            EquationType::Numerical => {
361                "Using numerical solver (Newton-Raphson method with numerical differentiation)"
362            }
363            EquationType::Matrix => "Using matrix equation solver (left/right division)",
364            EquationType::ODE => "Using ODE solver (separable/linear/exact methods)",
365            EquationType::PDE => {
366                "Using PDE solver (method of characteristics/separation of variables)"
367            }
368            _ => "No specialized solver available for this equation type",
369        };
370
371        all_steps.push(Step::new("Solver Selection", solver_description));
372
373        let (result, solver_steps) = match eq_type {
374            EquationType::Linear => self
375                .linear_solver
376                .solve_with_explanation(equation, variable),
377            EquationType::Quadratic => self
378                .quadratic_solver
379                .solve_with_explanation(equation, variable),
380            EquationType::Cubic | EquationType::Quartic => self
381                .polynomial_solver
382                .solve_with_explanation(equation, variable),
383            EquationType::System => self
384                .system_solver
385                .solve_with_explanation(equation, variable),
386            EquationType::Numerical => self.solve_numerical(equation, variable),
387            EquationType::Matrix => self
388                .matrix_solver
389                .solve_with_explanation(equation, variable),
390            EquationType::ODE => self.ode_solver.solve_with_explanation(equation, variable),
391            EquationType::PDE => self.pde_solver.solve_with_explanation(equation, variable),
392            _ => {
393                all_steps.push(Step::new(
394                    "Status",
395                    "This equation type is not yet fully implemented",
396                ));
397                (SolverResult::NoSolution, StepByStepExplanation::new(vec![]))
398            }
399        };
400
401        all_steps.extend(solver_steps.steps);
402
403        (result, StepByStepExplanation::new(all_steps))
404    }
405
406    fn solve_numerical(
407        &self,
408        _equation: &Expression,
409        variable: &Symbol,
410    ) -> (SolverResult, StepByStepExplanation) {
411        let steps = vec![
412            Step::new(
413                "Numerical Method Required",
414                format!(
415                    "This equation requires numerical methods to solve for {}. Newton-Raphson method integration is available.",
416                    variable.name()
417                ),
418            ),
419            Step::new(
420                "Method Description",
421                "Newton-Raphson method with numerical differentiation provides robust convergence for smooth functions.",
422            ),
423        ];
424
425        (SolverResult::NoSolution, StepByStepExplanation::new(steps))
426    }
427
428    /// Legacy solve method (deprecated, use solve_with_equation instead)
429    pub fn solve(&self) -> (SolverResult, StepByStepExplanation) {
430        let equation = Expression::integer(0);
431        let variables = self.extract_variables(&equation);
432        if variables.is_empty() {
433            return (SolverResult::NoSolution, StepByStepExplanation::new(vec![]));
434        }
435
436        let primary_var = &variables[0];
437        self.solve_with_equation(&equation, primary_var)
438    }
439
440    fn extract_variables(&self, equation: &Expression) -> Vec<Symbol> {
441        let mut variables = std::collections::HashSet::new();
442        EquationAnalyzer::collect_variables(equation, &mut variables);
443
444        variables
445            .into_iter()
446            .map(|name| Symbol::new(&name))
447            .collect()
448    }
449
450    /// Solve system of equations using the integrated system solver
451    ///
452    /// This method exposes the system solving capability through SmartEquationSolver,
453    /// allowing for solving both linear and polynomial systems (via Grobner basis).
454    ///
455    /// # Arguments
456    ///
457    /// * `equations` - Array of equations to solve
458    /// * `variables` - Array of variables to solve for
459    ///
460    /// # Returns
461    ///
462    /// SolverResult containing solutions, no solution, or partial solutions
463    ///
464    /// # Examples
465    ///
466    /// ```rust
467    /// use mathhook_core::algebra::equation_analyzer::SmartEquationSolver;
468    /// use mathhook_core::{symbol, Expression};
469    ///
470    /// let solver = SmartEquationSolver::new();
471    /// let x = symbol!(x);
472    /// let y = symbol!(y);
473    ///
474    /// // Linear system: 2x + y = 5, x - y = 1
475    /// let eq1 = Expression::add(vec![
476    ///     Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
477    ///     Expression::symbol(y.clone()),
478    ///     Expression::integer(-5),
479    /// ]);
480    /// let eq2 = Expression::add(vec![
481    ///     Expression::symbol(x.clone()),
482    ///     Expression::mul(vec![Expression::integer(-1), Expression::symbol(y.clone())]),
483    ///     Expression::integer(-1),
484    /// ]);
485    ///
486    /// let result = solver.solve_system(&[eq1, eq2], &[x, y]);
487    /// ```
488    pub fn solve_system(&self, equations: &[Expression], variables: &[Symbol]) -> SolverResult {
489        use crate::algebra::solvers::SystemEquationSolver;
490        self.system_solver.solve_system(equations, variables)
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Wrap a symbol as an expression leaf.
    fn leaf(s: &Symbol) -> Expression {
        Expression::symbol(s.clone())
    }

    /// Build `-1 * e`, the negation form used by these fixtures.
    fn negated(e: Expression) -> Expression {
        Expression::mul(vec![Expression::integer(-1), e])
    }

    #[test]
    fn test_equation_type_detection() {
        let x = symbol!(x);

        // 2x + 3
        let linear = Expression::add(vec![
            Expression::mul(vec![Expression::integer(2), leaf(&x)]),
            Expression::integer(3),
        ]);
        assert_eq!(EquationAnalyzer::analyze(&linear, &x), EquationType::Linear);

        // x^2 + 3x + 2
        let quadratic = Expression::add(vec![
            Expression::pow(leaf(&x), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(3), leaf(&x)]),
            Expression::integer(2),
        ]);
        assert_eq!(
            EquationAnalyzer::analyze(&quadratic, &x),
            EquationType::Quadratic
        );
    }

    #[test]
    fn test_numerical_equation_detection() {
        let x = symbol!(x);

        // x^5 - x - 1: degree above four has no general closed form.
        let quintic = Expression::add(vec![
            Expression::pow(leaf(&x), Expression::integer(5)),
            negated(leaf(&x)),
            Expression::integer(-1),
        ]);
        assert_eq!(
            EquationAnalyzer::analyze(&quintic, &x),
            EquationType::Numerical
        );

        // cos(x) - x: transcendental mixed with a polynomial term.
        let transcendental_mixed = Expression::add(vec![
            Expression::function("cos", vec![leaf(&x)]),
            negated(leaf(&x)),
        ]);
        assert_eq!(
            EquationAnalyzer::analyze(&transcendental_mixed, &x),
            EquationType::Numerical
        );
    }

    #[test]
    fn test_matrix_equation_detection() {
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        // AX - B with noncommutative matrix symbols.
        let equation = Expression::add(vec![
            Expression::mul(vec![leaf(&a), leaf(&x)]),
            negated(leaf(&b)),
        ]);

        assert_eq!(
            EquationAnalyzer::analyze(&equation, &x),
            EquationType::Matrix
        );
    }
}