1use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
5use crate::algebra::solvers::{EquationSolver, SolverResult};
6use crate::algebra::solvers::{LinearSolver, PolynomialSolver, QuadraticSolver, SystemSolver};
7use crate::calculus::ode::EducationalODESolver;
8use crate::calculus::pde::EducationalPDESolver;
9use crate::core::symbol::SymbolType;
10use crate::core::{Expression, Number, Symbol};
11use crate::educational::step_by_step::{Step, StepByStepExplanation};
12
/// Category of equation detected by [`EquationAnalyzer::analyze`].
///
/// `SmartEquationSolver` uses the variant to decide which specialised solver
/// (if any) handles the equation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EquationType {
    /// No polynomial dependence on the solve variable and no transcendental functions.
    Constant,
    /// Degree 1 in the solve variable, single symbol.
    Linear,
    /// Degree 2 in the solve variable, single symbol.
    Quadratic,
    /// Degree 3 in the solve variable, single symbol.
    Cubic,
    /// Degree 4 in the solve variable, single symbol.
    Quartic,
    /// No transcendental functions and two or more distinct symbols.
    System,
    /// Contains transcendental functions (sin/cos/tan/exp/ln/log) but was not
    /// flagged as requiring numerical methods.
    Transcendental,
    /// Requires numerical methods (polynomial degree above four, or transcendental
    /// functions mixed with polynomial terms or with each other).
    Numerical,
    /// Contains noncommutative (matrix, operator, or quaternion) symbols.
    Matrix,
    /// Ordinary differential equation (contains derivatives).
    ODE,
    /// Partial differential equation (contains partial derivatives).
    PDE,
    /// Could not be classified.
    Unknown,
}
29
/// Stateless classifier that inspects an [`Expression`] and reports its
/// [`EquationType`]; all functionality lives in associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct EquationAnalyzer;
32
33impl EquationAnalyzer {
34 pub fn analyze(equation: &Expression, variable: &Symbol) -> EquationType {
36 let has_derivatives = Self::has_derivatives(equation);
37 let has_partial_derivatives = Self::has_partial_derivatives(equation);
38
39 if has_partial_derivatives {
40 return EquationType::PDE;
41 }
42
43 if has_derivatives {
44 return EquationType::ODE;
45 }
46
47 if Self::is_matrix_equation(equation, variable) {
48 return EquationType::Matrix;
49 }
50
51 let degree = Self::find_highest_degree(equation, variable);
52 let has_transcendental = Self::has_transcendental_functions(equation);
53 let variable_count = Self::count_variables(equation);
54
55 if Self::is_numerical_equation(equation, variable, degree, has_transcendental) {
56 return EquationType::Numerical;
57 }
58
59 match (degree, has_transcendental, variable_count) {
60 (0, false, _) => EquationType::Constant,
61 (1, false, 1) => EquationType::Linear,
62 (2, false, 1) => EquationType::Quadratic,
63 (3, false, 1) => EquationType::Cubic,
64 (4, false, 1) => EquationType::Quartic,
65 (_, false, 2..) => EquationType::System,
66 (_, true, _) => EquationType::Transcendental,
67 _ => EquationType::Unknown,
68 }
69 }
70
71 fn is_numerical_equation(
72 expr: &Expression,
73 _variable: &Symbol,
74 degree: u32,
75 has_transcendental: bool,
76 ) -> bool {
77 if degree > 4 {
78 return true;
79 }
80
81 if has_transcendental && degree > 0 {
82 return true;
83 }
84
85 if has_transcendental {
86 let func_count = Self::count_transcendental_functions(expr);
87 if func_count > 1 {
88 return true;
89 }
90 }
91
92 false
93 }
94
95 fn count_transcendental_functions(expr: &Expression) -> usize {
96 match expr {
97 Expression::Function { name, args } => {
98 let current =
99 if matches!(name.as_ref(), "sin" | "cos" | "tan" | "exp" | "ln" | "log") {
100 1
101 } else {
102 0
103 };
104 current
105 + args
106 .iter()
107 .map(Self::count_transcendental_functions)
108 .sum::<usize>()
109 }
110 Expression::Add(terms) => terms.iter().map(Self::count_transcendental_functions).sum(),
111 Expression::Mul(factors) => factors
112 .iter()
113 .map(Self::count_transcendental_functions)
114 .sum(),
115 Expression::Pow(base, exp) => {
116 Self::count_transcendental_functions(base)
117 + Self::count_transcendental_functions(exp)
118 }
119 _ => 0,
120 }
121 }
122
123 fn is_matrix_equation(expr: &Expression, _variable: &Symbol) -> bool {
124 Self::has_noncommutative_symbols(expr)
125 }
126
127 fn has_noncommutative_symbols(expr: &Expression) -> bool {
128 match expr {
129 Expression::Symbol(s) => {
130 matches!(
131 s.symbol_type(),
132 SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion
133 )
134 }
135 Expression::Add(terms) | Expression::Mul(terms) => {
136 terms.iter().any(Self::has_noncommutative_symbols)
137 }
138 Expression::Pow(base, exp) => {
139 Self::has_noncommutative_symbols(base) || Self::has_noncommutative_symbols(exp)
140 }
141 Expression::Function { args, .. } => args.iter().any(Self::has_noncommutative_symbols),
142 _ => false,
143 }
144 }
145
146 fn find_highest_degree(expr: &Expression, variable: &Symbol) -> u32 {
147 match expr {
148 Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
149 match exp.as_ref() {
150 Expression::Number(Number::Integer(n)) => *n as u32,
151 _ => 1,
152 }
153 }
154 Expression::Mul(factors) => factors
155 .iter()
156 .map(|f| Self::find_highest_degree(f, variable))
157 .max()
158 .unwrap_or(0),
159 Expression::Add(terms) => terms
160 .iter()
161 .map(|t| Self::find_highest_degree(t, variable))
162 .max()
163 .unwrap_or(0),
164 _ if *expr == Expression::symbol(variable.clone()) => 1,
165 _ => 0,
166 }
167 }
168
169 fn has_transcendental_functions(expr: &Expression) -> bool {
170 match expr {
171 Expression::Function { name, args } => {
172 matches!(name.as_ref(), "sin" | "cos" | "tan" | "exp" | "ln" | "log")
173 || args.iter().any(Self::has_transcendental_functions)
174 }
175 Expression::Add(terms) => terms.iter().any(Self::has_transcendental_functions),
176 Expression::Mul(factors) => factors.iter().any(Self::has_transcendental_functions),
177 Expression::Pow(base, exp) => {
178 Self::has_transcendental_functions(base) || Self::has_transcendental_functions(exp)
179 }
180 _ => false,
181 }
182 }
183
184 fn count_variables(expr: &Expression) -> usize {
185 let mut variables = std::collections::HashSet::new();
186 Self::collect_variables(expr, &mut variables);
187 variables.len()
188 }
189
190 pub fn collect_variables(expr: &Expression, variables: &mut std::collections::HashSet<String>) {
191 match expr {
192 Expression::Symbol(s) => {
193 variables.insert(s.name().to_owned());
194 }
195 Expression::Add(terms) => {
196 for term in terms.iter() {
197 Self::collect_variables(term, variables);
198 }
199 }
200 Expression::Mul(factors) => {
201 for factor in factors.iter() {
202 Self::collect_variables(factor, variables);
203 }
204 }
205 Expression::Pow(base, exp) => {
206 Self::collect_variables(base, variables);
207 Self::collect_variables(exp, variables);
208 }
209 Expression::Function { args, .. } => {
210 for arg in args.iter() {
211 Self::collect_variables(arg, variables);
212 }
213 }
214 _ => {}
215 }
216 }
217
218 fn has_derivatives(expr: &Expression) -> bool {
219 match expr {
220 Expression::Function { name, args } => {
221 matches!(name.as_ref(), "derivative" | "diff" | "D")
222 || args.iter().any(Self::has_derivatives)
223 }
224 Expression::Symbol(s) => {
225 let name = s.name();
226 name.ends_with('\'') || name.contains("_prime")
227 }
228 Expression::Add(terms) => terms.iter().any(Self::has_derivatives),
229 Expression::Mul(factors) => factors.iter().any(Self::has_derivatives),
230 Expression::Pow(base, exp) => Self::has_derivatives(base) || Self::has_derivatives(exp),
231 _ => false,
232 }
233 }
234
235 fn has_partial_derivatives(expr: &Expression) -> bool {
236 match expr {
237 Expression::Function { name, args } => {
238 matches!(name.as_ref(), "partial" | "pdiff" | "Partial")
239 || args.iter().any(Self::has_partial_derivatives)
240 }
241 Expression::Symbol(s) => {
242 let name = s.name();
243 name.contains("partial") || name.contains("∂")
244 }
245 Expression::Add(terms) => terms.iter().any(Self::has_partial_derivatives),
246 Expression::Mul(factors) => factors.iter().any(Self::has_partial_derivatives),
247 Expression::Pow(base, exp) => {
248 Self::has_partial_derivatives(base) || Self::has_partial_derivatives(exp)
249 }
250 _ => false,
251 }
252 }
253}
254
255pub struct SmartEquationSolver {
257 linear_solver: LinearSolver,
258 quadratic_solver: QuadraticSolver,
259 system_solver: SystemSolver,
260 polynomial_solver: PolynomialSolver,
261 matrix_solver: MatrixEquationSolver,
262 ode_solver: EducationalODESolver,
263 pde_solver: EducationalPDESolver,
264}
265
266impl Default for SmartEquationSolver {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
272impl SmartEquationSolver {
273 pub fn new() -> Self {
274 Self {
275 linear_solver: LinearSolver::new(),
276 quadratic_solver: QuadraticSolver::new(),
277 system_solver: SystemSolver::new(),
278 polynomial_solver: PolynomialSolver::new(),
279 matrix_solver: MatrixEquationSolver::new(),
280 ode_solver: EducationalODESolver::new(),
281 pde_solver: EducationalPDESolver::new(),
282 }
283 }
284
285 pub fn solve_with_equation(
305 &self,
306 equation: &Expression,
307 variable: &Symbol,
308 ) -> (SolverResult, StepByStepExplanation) {
309 let mut all_steps = Vec::new();
310
311 let degree = EquationAnalyzer::find_highest_degree(equation, variable);
312 let eq_type = EquationAnalyzer::analyze(equation, variable);
313
314 let analysis_description = match eq_type {
315 EquationType::Constant => {
316 "Detected constant equation (no variables)".to_owned()
317 }
318 EquationType::Linear => {
319 format!("Detected linear equation (highest degree: {})", degree)
320 }
321 EquationType::Quadratic => {
322 format!("Detected quadratic equation (highest degree: {})", degree)
323 }
324 EquationType::Cubic => {
325 format!("Detected cubic equation (highest degree: {})", degree)
326 }
327 EquationType::Quartic => {
328 format!("Detected quartic equation (highest degree: {})", degree)
329 }
330 EquationType::System => {
331 "Detected system of equations (multiple variables)".to_owned()
332 }
333 EquationType::Transcendental => {
334 "Detected transcendental equation (contains trig/exp/log functions)".to_owned()
335 }
336 EquationType::Numerical => {
337 "Detected numerical equation (requires numerical methods - polynomial degree > 4 or mixed transcendental)".to_owned()
338 }
339 EquationType::Matrix => {
340 "Detected matrix equation (contains noncommutative symbols)".to_owned()
341 }
342 EquationType::ODE => {
343 "Detected ordinary differential equation (contains derivatives)".to_owned()
344 }
345 EquationType::PDE => {
346 "Detected partial differential equation (contains partial derivatives)".to_owned()
347 }
348 EquationType::Unknown => {
349 "Unknown equation type".to_owned()
350 }
351 };
352
353 all_steps.push(Step::new("Equation Analysis", analysis_description));
354
355 let solver_description = match eq_type {
356 EquationType::Linear => "Using linear equation solver (isolation method)",
357 EquationType::Quadratic => "Using quadratic equation solver (quadratic formula)",
358 EquationType::Cubic | EquationType::Quartic => "Using polynomial solver",
359 EquationType::System => "Using system equation solver",
360 EquationType::Numerical => {
361 "Using numerical solver (Newton-Raphson method with numerical differentiation)"
362 }
363 EquationType::Matrix => "Using matrix equation solver (left/right division)",
364 EquationType::ODE => "Using ODE solver (separable/linear/exact methods)",
365 EquationType::PDE => {
366 "Using PDE solver (method of characteristics/separation of variables)"
367 }
368 _ => "No specialized solver available for this equation type",
369 };
370
371 all_steps.push(Step::new("Solver Selection", solver_description));
372
373 let (result, solver_steps) = match eq_type {
374 EquationType::Linear => self
375 .linear_solver
376 .solve_with_explanation(equation, variable),
377 EquationType::Quadratic => self
378 .quadratic_solver
379 .solve_with_explanation(equation, variable),
380 EquationType::Cubic | EquationType::Quartic => self
381 .polynomial_solver
382 .solve_with_explanation(equation, variable),
383 EquationType::System => self
384 .system_solver
385 .solve_with_explanation(equation, variable),
386 EquationType::Numerical => self.solve_numerical(equation, variable),
387 EquationType::Matrix => self
388 .matrix_solver
389 .solve_with_explanation(equation, variable),
390 EquationType::ODE => self.ode_solver.solve_with_explanation(equation, variable),
391 EquationType::PDE => self.pde_solver.solve_with_explanation(equation, variable),
392 _ => {
393 all_steps.push(Step::new(
394 "Status",
395 "This equation type is not yet fully implemented",
396 ));
397 (SolverResult::NoSolution, StepByStepExplanation::new(vec![]))
398 }
399 };
400
401 all_steps.extend(solver_steps.steps);
402
403 (result, StepByStepExplanation::new(all_steps))
404 }
405
406 fn solve_numerical(
407 &self,
408 _equation: &Expression,
409 variable: &Symbol,
410 ) -> (SolverResult, StepByStepExplanation) {
411 let steps = vec![
412 Step::new(
413 "Numerical Method Required",
414 format!(
415 "This equation requires numerical methods to solve for {}. Newton-Raphson method integration is available.",
416 variable.name()
417 ),
418 ),
419 Step::new(
420 "Method Description",
421 "Newton-Raphson method with numerical differentiation provides robust convergence for smooth functions.",
422 ),
423 ];
424
425 (SolverResult::NoSolution, StepByStepExplanation::new(steps))
426 }
427
428 pub fn solve(&self) -> (SolverResult, StepByStepExplanation) {
430 let equation = Expression::integer(0);
431 let variables = self.extract_variables(&equation);
432 if variables.is_empty() {
433 return (SolverResult::NoSolution, StepByStepExplanation::new(vec![]));
434 }
435
436 let primary_var = &variables[0];
437 self.solve_with_equation(&equation, primary_var)
438 }
439
440 fn extract_variables(&self, equation: &Expression) -> Vec<Symbol> {
441 let mut variables = std::collections::HashSet::new();
442 EquationAnalyzer::collect_variables(equation, &mut variables);
443
444 variables
445 .into_iter()
446 .map(|name| Symbol::new(&name))
447 .collect()
448 }
449
450 pub fn solve_system(&self, equations: &[Expression], variables: &[Symbol]) -> SolverResult {
489 use crate::algebra::solvers::SystemEquationSolver;
490 self.system_solver.solve_system(equations, variables)
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Wraps a symbol as an expression so the fixtures below stay compact.
    fn sym(s: &Symbol) -> Expression {
        Expression::symbol(s.clone())
    }

    #[test]
    fn test_equation_type_detection() {
        let x = symbol!(x);

        // 2x + 3
        let linear = Expression::add(vec![
            Expression::mul(vec![Expression::integer(2), sym(&x)]),
            Expression::integer(3),
        ]);
        assert_eq!(EquationAnalyzer::analyze(&linear, &x), EquationType::Linear);

        // x^2 + 3x + 2
        let quadratic = Expression::add(vec![
            Expression::pow(sym(&x), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(3), sym(&x)]),
            Expression::integer(2),
        ]);
        assert_eq!(
            EquationAnalyzer::analyze(&quadratic, &x),
            EquationType::Quadratic
        );
    }

    #[test]
    fn test_numerical_equation_detection() {
        let x = symbol!(x);

        // x^5 - x - 1: degree above four has no closed-form solver.
        let quintic = Expression::add(vec![
            Expression::pow(sym(&x), Expression::integer(5)),
            Expression::mul(vec![Expression::integer(-1), sym(&x)]),
            Expression::integer(-1),
        ]);
        assert_eq!(
            EquationAnalyzer::analyze(&quintic, &x),
            EquationType::Numerical
        );

        // cos(x) - x: transcendental mixed with a polynomial term.
        let transcendental_mixed = Expression::add(vec![
            Expression::function("cos", vec![sym(&x)]),
            Expression::mul(vec![Expression::integer(-1), sym(&x)]),
        ]);
        assert_eq!(
            EquationAnalyzer::analyze(&transcendental_mixed, &x),
            EquationType::Numerical
        );
    }

    #[test]
    fn test_matrix_equation_detection() {
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        // A*X - B
        let product = Expression::mul(vec![sym(&a), sym(&x)]);
        let negated_rhs = Expression::mul(vec![Expression::integer(-1), sym(&b)]);
        let equation = Expression::add(vec![product, negated_rhs]);

        assert_eq!(
            EquationAnalyzer::analyze(&equation, &x),
            EquationType::Matrix
        );
    }
}