mathhook_core/algebra/
equation_analyzer.rs

1//! Analyzes LaTeX equations and routes to appropriate solvers
2//! This is the "brain" that decides which solver to use
3
4use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
5use crate::algebra::solvers::{EquationSolver, SolverResult};
6use crate::algebra::solvers::{LinearSolver, PolynomialSolver, QuadraticSolver, SystemSolver};
7// Unused imports removed
8use crate::calculus::ode::EducationalODESolver;
9use crate::calculus::pde::EducationalPDESolver;
10use crate::core::symbol::SymbolType;
11use crate::core::{Expression, Number, Symbol};
12use crate::educational::step_by_step::{Step, StepByStepExplanation};
13
/// Types of equations our system can handle.
///
/// The type is a plain, fieldless tag, so it is cheap to copy and can be used
/// as a key in hashed collections (`Eq` + `Hash` are derived for that purpose).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EquationType {
    /// Equation without the unknown, e.g. `5 = 0`.
    Constant,
    /// First-degree equation, e.g. `2x + 3 = 0`.
    Linear,
    /// Second-degree equation, e.g. `x² + 3x + 2 = 0`.
    Quadratic,
    /// Third-degree equation, e.g. `x³ + 2x² + x + 1 = 0`.
    Cubic,
    /// Fourth-degree equation, e.g. `x⁴ + x³ + x² + x + 1 = 0`.
    Quartic,
    /// Several unknowns, e.g. `2x + 3y = 5, x - y = 1`.
    System,
    /// Contains transcendental functions, e.g. `sin(x) = 0`, `e^x = 5`.
    Transcendental,
    /// Equations requiring numerical methods.
    Numerical,
    /// Matrix equation, e.g. `A*X = B`.
    Matrix,
    /// Ordinary differential equation, e.g. `y' + 2y = x`, `y'' + 3y' + 2y = 0`.
    ODE,
    /// Partial differential equation, e.g. the heat equation `∂u/∂t = k∂²u/∂x²`.
    PDE,
    /// None of the other categories applies.
    Unknown,
}
30
/// Smart equation analyzer that determines solver routing.
///
/// This is a stateless unit type; it carries no data, so it is trivially
/// copyable and printable.
#[derive(Debug, Clone, Copy)]
pub struct EquationAnalyzer;
33
34impl EquationAnalyzer {
35    /// Analyze equation and determine type for solver dispatch
36    pub fn analyze(equation: &Expression, variable: &Symbol) -> EquationType {
37        let has_derivatives = Self::has_derivatives(equation);
38        let has_partial_derivatives = Self::has_partial_derivatives(equation);
39
40        if has_partial_derivatives {
41            return EquationType::PDE;
42        }
43
44        if has_derivatives {
45            return EquationType::ODE;
46        }
47
48        if Self::is_matrix_equation(equation, variable) {
49            return EquationType::Matrix;
50        }
51
52        let degree = Self::find_highest_degree(equation, variable);
53        let has_transcendental = Self::has_transcendental_functions(equation);
54        let variable_count = Self::count_variables(equation);
55
56        // Check if numerical methods are needed
57        if Self::is_numerical_equation(equation, variable, degree, has_transcendental) {
58            return EquationType::Numerical;
59        }
60
61        match (degree, has_transcendental, variable_count) {
62            (0, false, _) => EquationType::Constant,
63            (1, false, 1) => EquationType::Linear,
64            (2, false, 1) => EquationType::Quadratic,
65            (3, false, 1) => EquationType::Cubic,
66            (4, false, 1) => EquationType::Quartic,
67            (_, false, 2..) => EquationType::System,
68            (_, true, _) => EquationType::Transcendental,
69            _ => EquationType::Unknown,
70        }
71    }
72
73    /// Determine if equation requires numerical methods
74    ///
75    /// Numerical methods are needed when:
76    /// - Polynomial degree > 4 (no general algebraic formula)
77    /// - Transcendental functions mixed with polynomials (x = sin(x))
78    /// - Complex transcendental equations (e^x = x^2)
79    fn is_numerical_equation(
80        expr: &Expression,
81        _variable: &Symbol,
82        degree: u32,
83        has_transcendental: bool,
84    ) -> bool {
85        // Polynomial degree > 4 requires numerical methods
86        if degree > 4 {
87            return true;
88        }
89
90        // Mixed transcendental and polynomial requires numerical methods
91        if has_transcendental && degree > 0 {
92            return true;
93        }
94
95        // Complex transcendental equations
96        if has_transcendental {
97            let func_count = Self::count_transcendental_functions(expr);
98            if func_count > 1 {
99                return true;
100            }
101        }
102
103        false
104    }
105
106    /// Count transcendental functions in expression
107    fn count_transcendental_functions(expr: &Expression) -> usize {
108        match expr {
109            Expression::Function { name, args } => {
110                let current =
111                    if matches!(name.as_str(), "sin" | "cos" | "tan" | "exp" | "ln" | "log") {
112                        1
113                    } else {
114                        0
115                    };
116                current
117                    + args
118                        .iter()
119                        .map(Self::count_transcendental_functions)
120                        .sum::<usize>()
121            }
122            Expression::Add(terms) => terms.iter().map(Self::count_transcendental_functions).sum(),
123            Expression::Mul(factors) => factors
124                .iter()
125                .map(Self::count_transcendental_functions)
126                .sum(),
127            Expression::Pow(base, exp) => {
128                Self::count_transcendental_functions(base)
129                    + Self::count_transcendental_functions(exp)
130            }
131            _ => 0,
132        }
133    }
134
135    /// Check if equation is a matrix equation (contains noncommutative symbols)
136    fn is_matrix_equation(expr: &Expression, _variable: &Symbol) -> bool {
137        Self::has_noncommutative_symbols(expr)
138    }
139
140    /// Check if expression contains noncommutative symbols (matrix, operator, quaternion)
141    fn has_noncommutative_symbols(expr: &Expression) -> bool {
142        match expr {
143            Expression::Symbol(s) => {
144                matches!(
145                    s.symbol_type(),
146                    SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion
147                )
148            }
149            Expression::Add(terms) | Expression::Mul(terms) => {
150                terms.iter().any(Self::has_noncommutative_symbols)
151            }
152            Expression::Pow(base, exp) => {
153                Self::has_noncommutative_symbols(base) || Self::has_noncommutative_symbols(exp)
154            }
155            Expression::Function { args, .. } => args.iter().any(Self::has_noncommutative_symbols),
156            _ => false,
157        }
158    }
159
160    /// Find the highest degree of variable in expression
161    fn find_highest_degree(expr: &Expression, variable: &Symbol) -> u32 {
162        match expr {
163            // Direct power: x^2, x^3, etc.
164            Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
165                match exp.as_ref() {
166                    Expression::Number(Number::Integer(n)) => *n as u32,
167                    _ => 1,
168                }
169            }
170            // Multiplication: 3x^2, coefficient * x^power
171            Expression::Mul(factors) => factors
172                .iter()
173                .map(|f| Self::find_highest_degree(f, variable))
174                .max()
175                .unwrap_or(0),
176            // Addition: x^2 + 3x + 2
177            Expression::Add(terms) => terms
178                .iter()
179                .map(|t| Self::find_highest_degree(t, variable))
180                .max()
181                .unwrap_or(0),
182            // Simple variable: x (degree 1)
183            _ if *expr == Expression::symbol(variable.clone()) => 1,
184            // Constant or other variable
185            _ => 0,
186        }
187    }
188
189    /// Check for transcendental functions
190    fn has_transcendental_functions(expr: &Expression) -> bool {
191        match expr {
192            Expression::Function { name, args } => {
193                matches!(name.as_str(), "sin" | "cos" | "tan" | "exp" | "ln" | "log")
194                    || args.iter().any(Self::has_transcendental_functions)
195            }
196            Expression::Add(terms) => terms.iter().any(Self::has_transcendental_functions),
197            Expression::Mul(factors) => factors.iter().any(Self::has_transcendental_functions),
198            Expression::Pow(base, exp) => {
199                Self::has_transcendental_functions(base) || Self::has_transcendental_functions(exp)
200            }
201            _ => false,
202        }
203    }
204
205    /// Count unique variables in expression
206    fn count_variables(expr: &Expression) -> usize {
207        let mut variables = std::collections::HashSet::new();
208        Self::collect_variables(expr, &mut variables);
209        variables.len()
210    }
211
212    /// Recursively collect all variables
213    pub fn collect_variables(expr: &Expression, variables: &mut std::collections::HashSet<String>) {
214        match expr {
215            Expression::Symbol(s) => {
216                variables.insert(s.name().to_owned());
217            }
218            Expression::Add(terms) => {
219                for term in terms.iter() {
220                    Self::collect_variables(term, variables);
221                }
222            }
223            Expression::Mul(factors) => {
224                for factor in factors.iter() {
225                    Self::collect_variables(factor, variables);
226                }
227            }
228            Expression::Pow(base, exp) => {
229                Self::collect_variables(base, variables);
230                Self::collect_variables(exp, variables);
231            }
232            Expression::Function { args, .. } => {
233                for arg in args.iter() {
234                    Self::collect_variables(arg, variables);
235                }
236            }
237            _ => {}
238        }
239    }
240
241    /// Check if expression contains ordinary derivatives (y', dy/dx, etc.)
242    fn has_derivatives(expr: &Expression) -> bool {
243        match expr {
244            Expression::Function { name, args } => {
245                matches!(name.as_str(), "derivative" | "diff" | "D")
246                    || args.iter().any(Self::has_derivatives)
247            }
248            Expression::Symbol(s) => {
249                let name = s.name();
250                name.ends_with('\'') || name.contains("_prime")
251            }
252            Expression::Add(terms) => terms.iter().any(Self::has_derivatives),
253            Expression::Mul(factors) => factors.iter().any(Self::has_derivatives),
254            Expression::Pow(base, exp) => Self::has_derivatives(base) || Self::has_derivatives(exp),
255            _ => false,
256        }
257    }
258
259    /// Check if expression contains partial derivatives (∂u/∂x, ∂²u/∂x², etc.)
260    fn has_partial_derivatives(expr: &Expression) -> bool {
261        match expr {
262            Expression::Function { name, args } => {
263                matches!(name.as_str(), "partial" | "pdiff" | "Partial")
264                    || args.iter().any(Self::has_partial_derivatives)
265            }
266            Expression::Symbol(s) => {
267                let name = s.name();
268                name.contains("partial") || name.contains("∂")
269            }
270            Expression::Add(terms) => terms.iter().any(Self::has_partial_derivatives),
271            Expression::Mul(factors) => factors.iter().any(Self::has_partial_derivatives),
272            Expression::Pow(base, exp) => {
273                Self::has_partial_derivatives(base) || Self::has_partial_derivatives(exp)
274            }
275            _ => false,
276        }
277    }
278}
279
280/// Master equation solver with smart dispatch
281pub struct SmartEquationSolver {
282    linear_solver: LinearSolver,
283    quadratic_solver: QuadraticSolver,
284    system_solver: SystemSolver,
285    polynomial_solver: PolynomialSolver,
286    matrix_solver: MatrixEquationSolver,
287    ode_solver: EducationalODESolver,
288    pde_solver: EducationalPDESolver,
289}
290
291impl Default for SmartEquationSolver {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
297impl SmartEquationSolver {
298    pub fn new() -> Self {
299        Self {
300            linear_solver: LinearSolver::new(),
301            quadratic_solver: QuadraticSolver::new(),
302            system_solver: SystemSolver::new(),
303            polynomial_solver: PolynomialSolver::new(),
304            matrix_solver: MatrixEquationSolver::new(),
305            ode_solver: EducationalODESolver::new(),
306            pde_solver: EducationalPDESolver::new(),
307        }
308    }
309
310    /// Solve equation with educational explanation, including equation analysis
311    ///
312    /// This is the primary entry point for solving equations with full educational
313    /// integration. It automatically:
314    /// 1. Analyzes the equation type
315    /// 2. Explains the equation structure
316    /// 3. Selects the appropriate solver
317    /// 4. Provides step-by-step solution with explanations
318    ///
319    /// # Arguments
320    ///
321    /// * `equation` - The equation expression to solve
322    /// * `variable` - The variable to solve for
323    ///
324    /// # Returns
325    ///
326    /// A tuple containing:
327    /// - The solver result (solutions or error)
328    /// - Complete step-by-step explanation starting with equation analysis
329    pub fn solve_with_equation(
330        &mut self,
331        equation: &Expression,
332        variable: &Symbol,
333    ) -> (SolverResult, StepByStepExplanation) {
334        let mut all_steps = Vec::new();
335
336        let degree = EquationAnalyzer::find_highest_degree(equation, variable);
337        let eq_type = EquationAnalyzer::analyze(equation, variable);
338
339        let analysis_description = match eq_type {
340            EquationType::Constant => {
341                "Detected constant equation (no variables)".to_owned()
342            }
343            EquationType::Linear => {
344                format!("Detected linear equation (highest degree: {})", degree)
345            }
346            EquationType::Quadratic => {
347                format!("Detected quadratic equation (highest degree: {})", degree)
348            }
349            EquationType::Cubic => {
350                format!("Detected cubic equation (highest degree: {})", degree)
351            }
352            EquationType::Quartic => {
353                format!("Detected quartic equation (highest degree: {})", degree)
354            }
355            EquationType::System => {
356                "Detected system of equations (multiple variables)".to_owned()
357            }
358            EquationType::Transcendental => {
359                "Detected transcendental equation (contains trig/exp/log functions)".to_owned()
360            }
361            EquationType::Numerical => {
362                "Detected numerical equation (requires numerical methods - polynomial degree > 4 or mixed transcendental)".to_owned()
363            }
364            EquationType::Matrix => {
365                "Detected matrix equation (contains noncommutative symbols)".to_owned()
366            }
367            EquationType::ODE => {
368                "Detected ordinary differential equation (contains derivatives)".to_owned()
369            }
370            EquationType::PDE => {
371                "Detected partial differential equation (contains partial derivatives)".to_owned()
372            }
373            EquationType::Unknown => {
374                "Unknown equation type".to_owned()
375            }
376        };
377
378        all_steps.push(Step::new("Equation Analysis", analysis_description));
379
380        let solver_description = match eq_type {
381            EquationType::Linear => "Using linear equation solver (isolation method)",
382            EquationType::Quadratic => "Using quadratic equation solver (quadratic formula)",
383            EquationType::Cubic | EquationType::Quartic => "Using polynomial solver",
384            EquationType::System => "Using system equation solver",
385            EquationType::Numerical => {
386                "Using numerical solver (Newton-Raphson method with numerical differentiation)"
387            }
388            EquationType::Matrix => "Using matrix equation solver (left/right division)",
389            EquationType::ODE => "Using ODE solver (separable/linear/exact methods)",
390            EquationType::PDE => {
391                "Using PDE solver (method of characteristics/separation of variables)"
392            }
393            _ => "No specialized solver available for this equation type",
394        };
395
396        all_steps.push(Step::new("Solver Selection", solver_description));
397
398        let (result, solver_steps) = match eq_type {
399            EquationType::Linear => self
400                .linear_solver
401                .solve_with_explanation(equation, variable),
402            EquationType::Quadratic => self
403                .quadratic_solver
404                .solve_with_explanation(equation, variable),
405            EquationType::Cubic | EquationType::Quartic => self
406                .polynomial_solver
407                .solve_with_explanation(equation, variable),
408            EquationType::System => self
409                .system_solver
410                .solve_with_explanation(equation, variable),
411            EquationType::Numerical => self.solve_numerical(equation, variable),
412            EquationType::Matrix => self
413                .matrix_solver
414                .solve_with_explanation(equation, variable),
415            EquationType::ODE => self.ode_solver.solve_with_explanation(equation, variable),
416            EquationType::PDE => self.pde_solver.solve_with_explanation(equation, variable),
417            _ => {
418                all_steps.push(Step::new(
419                    "Status",
420                    "This equation type is not yet fully implemented",
421                ));
422                (SolverResult::NoSolution, StepByStepExplanation::new(vec![]))
423            }
424        };
425
426        all_steps.extend(solver_steps.steps);
427
428        (result, StepByStepExplanation::new(all_steps))
429    }
430
431    /// Solve numerical equations using Newton-Raphson method
432    ///
433    /// Provides integration point for numerical solving. Currently provides
434    /// educational explanation about numerical methods requirement.
435    fn solve_numerical(
436        &self,
437        _equation: &Expression,
438        variable: &Symbol,
439    ) -> (SolverResult, StepByStepExplanation) {
440        let steps = vec![
441            Step::new(
442                "Numerical Method Required",
443                format!(
444                    "This equation requires numerical methods to solve for {}. Newton-Raphson method integration is available.",
445                    variable.name()
446                ),
447            ),
448            Step::new(
449                "Method Description",
450                "Newton-Raphson method with numerical differentiation provides robust convergence for smooth functions.",
451            ),
452        ];
453
454        (SolverResult::NoSolution, StepByStepExplanation::new(steps))
455    }
456
457    /// Legacy solve method (deprecated, use solve_with_equation instead)
458    pub fn solve(&mut self) -> (SolverResult, StepByStepExplanation) {
459        let equation = Expression::integer(0);
460        let variables = self.extract_variables(&equation);
461        if variables.is_empty() {
462            return (SolverResult::NoSolution, StepByStepExplanation::new(vec![]));
463        }
464
465        let primary_var = &variables[0];
466        self.solve_with_equation(&equation, primary_var)
467    }
468
469    /// Extract variables from equation
470    fn extract_variables(&self, equation: &Expression) -> Vec<Symbol> {
471        let mut variables = std::collections::HashSet::new();
472        EquationAnalyzer::collect_variables(equation, &mut variables);
473
474        variables
475            .into_iter()
476            .map(|name| Symbol::new(&name))
477            .collect()
478    }
479
480    /// Solve system of equations using the integrated system solver
481    ///
482    /// This method exposes the system solving capability through SmartEquationSolver,
483    /// allowing for solving both linear and polynomial systems (via Gröbner basis).
484    ///
485    /// # Arguments
486    ///
487    /// * `equations` - Array of equations to solve
488    /// * `variables` - Array of variables to solve for
489    ///
490    /// # Returns
491    ///
492    /// SolverResult containing solutions, no solution, or partial solutions
493    ///
494    /// # Examples
495    ///
496    /// ```rust
497    /// use mathhook_core::algebra::equation_analyzer::SmartEquationSolver;
498    /// use mathhook_core::{symbol, Expression};
499    ///
500    /// let mut solver = SmartEquationSolver::new();
501    /// let x = symbol!(x);
502    /// let y = symbol!(y);
503    ///
504    /// // Linear system: 2x + y = 5, x - y = 1
505    /// let eq1 = Expression::add(vec![
506    ///     Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
507    ///     Expression::symbol(y.clone()),
508    ///     Expression::integer(-5),
509    /// ]);
510    /// let eq2 = Expression::add(vec![
511    ///     Expression::symbol(x.clone()),
512    ///     Expression::mul(vec![Expression::integer(-1), Expression::symbol(y.clone())]),
513    ///     Expression::integer(-1),
514    /// ]);
515    ///
516    /// let result = solver.solve_system(&[eq1, eq2], &[x, y]);
517    /// ```
518    pub fn solve_system(&mut self, equations: &[Expression], variables: &[Symbol]) -> SolverResult {
519        use crate::algebra::solvers::SystemEquationSolver;
520        self.system_solver.solve_system(equations, variables)
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Polynomials of degree 1 and 2 in a single variable are classified
    /// as linear and quadratic respectively.
    #[test]
    fn test_equation_type_detection() {
        let x = symbol!(x);

        // Linear: 2x + 3
        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let linear = Expression::add(vec![two_x, Expression::integer(3)]);
        let linear_kind = EquationAnalyzer::analyze(&linear, &x);
        assert_eq!(linear_kind, EquationType::Linear);

        // Quadratic: x^2 + 3x + 2
        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        let three_x = Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]);
        let quadratic = Expression::add(vec![x_squared, three_x, Expression::integer(2)]);
        let quadratic_kind = EquationAnalyzer::analyze(&quadratic, &x);
        assert_eq!(quadratic_kind, EquationType::Quadratic);
    }

    /// Degree-5 polynomials and mixed polynomial/transcendental equations
    /// are routed to numerical methods.
    #[test]
    fn test_numerical_equation_detection() {
        let x = symbol!(x);

        // High-degree polynomial: x^5 - x - 1 (numerical)
        let x_fifth = Expression::pow(Expression::symbol(x.clone()), Expression::integer(5));
        let minus_x = Expression::mul(vec![Expression::integer(-1), Expression::symbol(x.clone())]);
        let quintic = Expression::add(vec![x_fifth, minus_x, Expression::integer(-1)]);
        let quintic_kind = EquationAnalyzer::analyze(&quintic, &x);
        assert_eq!(quintic_kind, EquationType::Numerical);

        // Mixed transcendental: cos(x) - x (numerical)
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let minus_x = Expression::mul(vec![Expression::integer(-1), Expression::symbol(x.clone())]);
        let transcendental_mixed = Expression::add(vec![cos_x, minus_x]);
        let mixed_kind = EquationAnalyzer::analyze(&transcendental_mixed, &x);
        assert_eq!(mixed_kind, EquationType::Numerical);
    }

    /// Equations built from matrix-typed symbols are classified as matrix
    /// equations.
    #[test]
    fn test_matrix_equation_detection() {
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        // A*X - B = 0
        let a_times_x = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(x.clone()),
        ]);
        let minus_b = Expression::mul(vec![Expression::integer(-1), Expression::symbol(b.clone())]);
        let equation = Expression::add(vec![a_times_x, minus_b]);

        let kind = EquationAnalyzer::analyze(&equation, &x);
        assert_eq!(kind, EquationType::Matrix);
    }
}
598}