mathhook_core/algebra/
solvers.rs

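//! Equation solvers module.
//!
//! Comprehensive equation solving with step-by-step explanations. This module defines the
//! shared result type ([`SolverResult`]), error type ([`SolverError`]), and solver traits;
//! the concrete solvers live in the `linear`, `quadratic`, `polynomial`, `systems`, and
//! `matrix_equations` submodules. Follows Rust 2021 conventions and a test-driven
//! development approach.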
5
6use crate::core::{Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8use serde::{Deserialize, Serialize};
9
10// Individual solver modules
11pub mod linear;
12pub mod matrix_equations;
13pub mod polynomial;
14pub mod quadratic;
15pub mod systems;
16
17// Re-exports for easy access
18pub use linear::LinearSolver;
19pub use matrix_equations::MatrixEquationSolver;
20pub use polynomial::PolynomialSolver;
21pub use quadratic::QuadraticSolver;
22pub use systems::SystemSolver;
23
/// Unified result type returned by the equation solvers in this module.
///
/// The variants describe the shape of the solution set (one, several, none,
/// infinitely many, parametric, or only partially found). See
/// [`SolverResult::solution_count`] for a numeric summary.
///
/// NOTE(review): this type derives serde `Serialize`/`Deserialize`; for
/// index-based (non-self-describing) formats the variant order is part of the
/// serialized form, so avoid reordering variants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SolverResult {
    /// Exactly one solution was found.
    Single(Expression),
    /// Several solutions were found.
    Multiple(Vec<Expression>),
    /// No solution exists.
    NoSolution,
    /// Infinitely many solutions exist.
    InfiniteSolutions,
    /// Parametric solutions (for systems of equations).
    Parametric(Vec<Expression>),
    /// Partial solutions found (some but not all roots).
    ///
    /// Used when a solver can find some roots but not all expected roots.
    /// For example, a cubic equation may have one real root found via the rational root theorem,
    /// but the remaining complex roots cannot be computed without implementing the full cubic formula.
    Partial(Vec<Expression>),
}
43
/// Unified error type for equation solvers.
///
/// Every variant carries a human-readable detail message. The type implements
/// [`std::fmt::Display`] and [`std::error::Error`], so it can be shown to users
/// and boxed as `Box<dyn std::error::Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SolverError {
    /// The equation is malformed.
    InvalidEquation(String),
    /// The equation's type is not supported by the solver.
    UnsupportedType(String),
    /// Solving ran into numerical instability.
    NumericalInstability(String),
    /// The equation is too complex to solve.
    ComplexityLimit(String),
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SolverError::InvalidEquation(msg) => write!(f, "invalid equation: {msg}"),
            SolverError::UnsupportedType(msg) => write!(f, "unsupported equation type: {msg}"),
            SolverError::NumericalInstability(msg) => write!(f, "numerical instability: {msg}"),
            SolverError::ComplexityLimit(msg) => write!(f, "complexity limit exceeded: {msg}"),
        }
    }
}

impl std::error::Error for SolverError {}
56
57/// Common interface for equation solvers
58pub trait EquationSolver {
59    /// Solve equation for given variable
60    fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult;
61
62    /// Solve with step-by-step explanation
63    fn solve_with_explanation(
64        &self,
65        equation: &Expression,
66        variable: &Symbol,
67    ) -> (SolverResult, StepByStepExplanation);
68
69    /// Check if solver can handle this equation type
70    fn can_solve(&self, equation: &Expression) -> bool;
71}
72
73/// Trait for solving systems of equations
74pub trait SystemEquationSolver {
75    /// Solve system of equations
76    fn solve_system(&self, equations: &[Expression], variables: &[Symbol]) -> SolverResult;
77
78    /// Solve system with step-by-step explanation
79    fn solve_system_with_explanation(
80        &self,
81        equations: &[Expression],
82        variables: &[Symbol],
83    ) -> (SolverResult, StepByStepExplanation);
84}
85
86// Solver result utility methods
87
88impl SolverResult {
89    /// Check if result represents a valid solution
90    pub fn is_valid_solution(&self) -> bool {
91        match self {
92            SolverResult::NoSolution => true,
93            SolverResult::InfiniteSolutions => true,
94            SolverResult::Single(expr) => expr.is_valid_expression(),
95            SolverResult::Multiple(exprs) => exprs.iter().all(|e| e.is_valid_expression()),
96            SolverResult::Parametric(exprs) => exprs.iter().all(|e| e.is_valid_expression()),
97            SolverResult::Partial(exprs) => exprs.iter().all(|e| e.is_valid_expression()),
98        }
99    }
100
101    /// Get number of solutions
102    pub fn solution_count(&self) -> Option<usize> {
103        match self {
104            SolverResult::Single(_) => Some(1),
105            SolverResult::Multiple(exprs) => Some(exprs.len()),
106            SolverResult::Parametric(exprs) => Some(exprs.len()),
107            SolverResult::Partial(exprs) => Some(exprs.len()),
108            SolverResult::NoSolution => Some(0),
109            SolverResult::InfiniteSolutions => None,
110        }
111    }
112}
113
114// Step-by-step integration for equation solvers
115
116/// Extension trait for Expression to add solver step-by-step support
117pub trait SolverStepByStep {
118    /// Solve with complete step-by-step explanation
119    fn solve_with_steps(&self, variable: &Symbol) -> (SolverResult, StepByStepExplanation);
120
121    /// Generate step-by-step explanation for solving process
122    fn explain_solving_steps(&self, variable: &Symbol) -> StepByStepExplanation;
123}
124
125impl SolverStepByStep for Expression {
126    fn solve_with_steps(&self, _variable: &Symbol) -> (SolverResult, StepByStepExplanation) {
127        // Individual solver implementations override this method
128        let explanation = StepByStepExplanation::new(vec![
129            Step::new("Analysis", format!("Analyzing equation: {}", self)),
130            Step::new("Method", "Determining appropriate solving method"),
131            Step::new(
132                "Status",
133                "This equation type requires a specialized solver implementation",
134            ),
135        ]);
136
137        (SolverResult::NoSolution, explanation)
138    }
139
140    fn explain_solving_steps(&self, variable: &Symbol) -> StepByStepExplanation {
141        StepByStepExplanation::new(vec![
142            Step::new("Equation", format!("Given: {} = 0", self)),
143            Step::new("Variable", format!("Solve for: {}", variable.name)),
144            Step::new("Method", "Applying appropriate solving algorithm"),
145        ])
146    }
147}
148
149// Utility functions for expression validation
150
151impl Expression {
152    /// Check if expression is a valid mathematical expression
153    pub fn is_valid_expression(&self) -> bool {
154        // Basic validation - can be expanded
155        match self {
156            Expression::Number(_) | Expression::Symbol(_) => true,
157            Expression::Add(terms) | Expression::Mul(terms) => {
158                !terms.is_empty() && terms.iter().all(|t| t.is_valid_expression())
159            }
160            Expression::Pow(base, exp) => base.is_valid_expression() && exp.is_valid_expression(),
161            Expression::Function { args, .. } => args.iter().all(|a| a.is_valid_expression()),
162            // New expression types - basic validity checks
163            Expression::Complex(complex_data) => {
164                complex_data.real.is_valid_expression() && complex_data.imag.is_valid_expression()
165            }
166            Expression::Matrix(matrix) => {
167                // Validate matrix dimensions and all elements
168                let (rows, cols) = matrix.dimensions();
169                if rows == 0 || cols == 0 || rows > 1000 || cols > 1000 {
170                    return false;
171                }
172
173                // Validate each element recursively
174                for i in 0..rows {
175                    for j in 0..cols {
176                        if !matrix.get_element(i, j).is_valid_expression() {
177                            return false;
178                        }
179                    }
180                }
181                true
182            }
183            Expression::Constant(_) => true,
184            Expression::Relation(relation_data) => {
185                relation_data.left.is_valid_expression()
186                    && relation_data.right.is_valid_expression()
187            }
188            Expression::Piecewise(piecewise_data) => {
189                piecewise_data
190                    .pieces
191                    .iter()
192                    .all(|(cond, val)| cond.is_valid_expression() && val.is_valid_expression())
193                    && piecewise_data
194                        .default
195                        .as_ref()
196                        .is_none_or(|d| d.is_valid_expression())
197            }
198            Expression::Set(elements) => elements.iter().all(|e| e.is_valid_expression()),
199            Expression::Interval(interval_data) => {
200                interval_data.start.is_valid_expression() && interval_data.end.is_valid_expression()
201            }
202            // Calculus types - unified validation
203            Expression::Calculus(calculus_data) => {
204                use crate::core::expression::CalculusData;
205                match calculus_data.as_ref() {
206                    CalculusData::Derivative { expression, .. } => expression.is_valid_expression(),
207                    CalculusData::Integral { integrand, .. } => integrand.is_valid_expression(),
208                    CalculusData::Limit { expression, .. } => expression.is_valid_expression(),
209                    CalculusData::Sum {
210                        expression,
211                        start,
212                        end,
213                        ..
214                    } => {
215                        expression.is_valid_expression()
216                            && start.is_valid_expression()
217                            && end.is_valid_expression()
218                    }
219                    CalculusData::Product {
220                        expression,
221                        start,
222                        end,
223                        ..
224                    } => {
225                        expression.is_valid_expression()
226                            && start.is_valid_expression()
227                            && end.is_valid_expression()
228                    }
229                }
230            }
231            Expression::MethodCall(method_data) => {
232                method_data.object.is_valid_expression()
233                    && method_data.args.iter().all(|a| a.is_valid_expression())
234            }
235        }
236    }
237
238    /// Convert to LaTeX representation for solvers (avoid conflict)
239    pub fn solver_to_latex(&self) -> String {
240        match self {
241            Expression::Number(n) => format!("{}", n),
242            Expression::Symbol(s) => s.name.to_string(),
243            Expression::Add(terms) => {
244                let term_strs: Vec<String> = terms.iter().map(|t| format!("{}", t)).collect();
245                term_strs.join(" + ")
246            }
247            Expression::Mul(factors) => {
248                let factor_strs: Vec<String> = factors.iter().map(|f| format!("{}", f)).collect();
249                factor_strs.join(" \\cdot ")
250            }
251            Expression::Pow(base, exp) => {
252                format!("{}^{{{}}}", base, exp)
253            }
254            Expression::Function { name, args } => {
255                let arg_strs: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
256                format!("\\{}({})", name, arg_strs.join(", "))
257            }
258            // New expression types - implement later
259            _ => "\\text{unknown}".to_owned(),
260        }
261    }
262
263    pub fn flatten_add_terms(&self) -> Vec<Expression> {
264        match self {
265            Expression::Add(terms) => terms
266                .iter()
267                .flat_map(|term| term.flatten_add_terms())
268                .collect(),
269            _ => vec![self.clone()],
270        }
271    }
272
273    /// Negate an expression
274    pub fn negate(&self) -> Expression {
275        // Distribute negation for canonical form: -(a + b) = -a + -b
276        match self {
277            Expression::Add(terms) => {
278                let negated_terms: Vec<Expression> = terms.iter().map(|t| t.negate()).collect();
279                Expression::add(negated_terms)
280            }
281            Expression::Number(Number::Integer(n)) => Expression::integer(-n),
282            Expression::Number(Number::Rational(r)) => {
283                Expression::Number(Number::rational(-(**r).clone()))
284            }
285            Expression::Mul(factors) if factors.len() == 2 => {
286                if let [Expression::Number(Number::Integer(-1)), expr] = &factors[..] {
287                    // -(-expr) = expr (double negation)
288                    expr.clone()
289                } else if let [Expression::Number(Number::Integer(n)), rest] = &factors[..] {
290                    // -(n * rest) = (-n) * rest
291                    Expression::mul(vec![Expression::integer(-n), rest.clone()])
292                } else {
293                    Expression::mul(vec![Expression::integer(-1), self.clone()])
294                }
295            }
296            _ => Expression::mul(vec![Expression::integer(-1), self.clone()]),
297        }
298    }
299}