1use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
5use crate::algebra::solvers::{EquationSolver, SolverResult};
6use crate::algebra::solvers::{LinearSolver, PolynomialSolver, QuadraticSolver, SystemSolver};
7use crate::calculus::ode::EducationalODESolver;
9use crate::calculus::pde::EducationalPDESolver;
10use crate::core::symbol::SymbolType;
11use crate::core::{Expression, Number, Symbol};
12use crate::educational::step_by_step::{Step, StepByStepExplanation};
13
/// Category assigned to an equation by [`EquationAnalyzer::analyze`].
///
/// The category decides which specialized solver `SmartEquationSolver` dispatches to.
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the category can be used
/// as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EquationType {
    /// Degree 0 in the solve variable and no transcendental functions.
    Constant,
    /// Degree 1, exactly one distinct symbol, no transcendental functions.
    Linear,
    /// Degree 2, exactly one distinct symbol, no transcendental functions.
    Quadratic,
    /// Degree 3, exactly one distinct symbol, no transcendental functions.
    Cubic,
    /// Degree 4, exactly one distinct symbol, no transcendental functions.
    Quartic,
    /// Non-constant, no transcendental functions, two or more distinct symbols.
    System,
    /// Contains a transcendental function but was not routed to [`EquationType::Numerical`].
    Transcendental,
    /// Degree above 4, transcendental functions mixed with polynomial terms, or more than
    /// one transcendental function.
    Numerical,
    /// Contains a matrix, operator or quaternion symbol.
    Matrix,
    /// Contains an ordinary-derivative marker.
    ODE,
    /// Contains a partial-derivative marker.
    PDE,
    /// Fallback when no other category matches.
    Unknown,
}
30
/// Stateless namespace for the equation-classification routines.
///
/// All analysis entry points are associated functions; the type carries no data, so it is
/// trivially `Copy`, `Default` and `Debug`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EquationAnalyzer;
33
34impl EquationAnalyzer {
35 pub fn analyze(equation: &Expression, variable: &Symbol) -> EquationType {
37 let has_derivatives = Self::has_derivatives(equation);
38 let has_partial_derivatives = Self::has_partial_derivatives(equation);
39
40 if has_partial_derivatives {
41 return EquationType::PDE;
42 }
43
44 if has_derivatives {
45 return EquationType::ODE;
46 }
47
48 if Self::is_matrix_equation(equation, variable) {
49 return EquationType::Matrix;
50 }
51
52 let degree = Self::find_highest_degree(equation, variable);
53 let has_transcendental = Self::has_transcendental_functions(equation);
54 let variable_count = Self::count_variables(equation);
55
56 if Self::is_numerical_equation(equation, variable, degree, has_transcendental) {
58 return EquationType::Numerical;
59 }
60
61 match (degree, has_transcendental, variable_count) {
62 (0, false, _) => EquationType::Constant,
63 (1, false, 1) => EquationType::Linear,
64 (2, false, 1) => EquationType::Quadratic,
65 (3, false, 1) => EquationType::Cubic,
66 (4, false, 1) => EquationType::Quartic,
67 (_, false, 2..) => EquationType::System,
68 (_, true, _) => EquationType::Transcendental,
69 _ => EquationType::Unknown,
70 }
71 }
72
73 fn is_numerical_equation(
80 expr: &Expression,
81 _variable: &Symbol,
82 degree: u32,
83 has_transcendental: bool,
84 ) -> bool {
85 if degree > 4 {
87 return true;
88 }
89
90 if has_transcendental && degree > 0 {
92 return true;
93 }
94
95 if has_transcendental {
97 let func_count = Self::count_transcendental_functions(expr);
98 if func_count > 1 {
99 return true;
100 }
101 }
102
103 false
104 }
105
106 fn count_transcendental_functions(expr: &Expression) -> usize {
108 match expr {
109 Expression::Function { name, args } => {
110 let current =
111 if matches!(name.as_str(), "sin" | "cos" | "tan" | "exp" | "ln" | "log") {
112 1
113 } else {
114 0
115 };
116 current
117 + args
118 .iter()
119 .map(Self::count_transcendental_functions)
120 .sum::<usize>()
121 }
122 Expression::Add(terms) => terms.iter().map(Self::count_transcendental_functions).sum(),
123 Expression::Mul(factors) => factors
124 .iter()
125 .map(Self::count_transcendental_functions)
126 .sum(),
127 Expression::Pow(base, exp) => {
128 Self::count_transcendental_functions(base)
129 + Self::count_transcendental_functions(exp)
130 }
131 _ => 0,
132 }
133 }
134
135 fn is_matrix_equation(expr: &Expression, _variable: &Symbol) -> bool {
137 Self::has_noncommutative_symbols(expr)
138 }
139
140 fn has_noncommutative_symbols(expr: &Expression) -> bool {
142 match expr {
143 Expression::Symbol(s) => {
144 matches!(
145 s.symbol_type(),
146 SymbolType::Matrix | SymbolType::Operator | SymbolType::Quaternion
147 )
148 }
149 Expression::Add(terms) | Expression::Mul(terms) => {
150 terms.iter().any(Self::has_noncommutative_symbols)
151 }
152 Expression::Pow(base, exp) => {
153 Self::has_noncommutative_symbols(base) || Self::has_noncommutative_symbols(exp)
154 }
155 Expression::Function { args, .. } => args.iter().any(Self::has_noncommutative_symbols),
156 _ => false,
157 }
158 }
159
160 fn find_highest_degree(expr: &Expression, variable: &Symbol) -> u32 {
162 match expr {
163 Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
165 match exp.as_ref() {
166 Expression::Number(Number::Integer(n)) => *n as u32,
167 _ => 1,
168 }
169 }
170 Expression::Mul(factors) => factors
172 .iter()
173 .map(|f| Self::find_highest_degree(f, variable))
174 .max()
175 .unwrap_or(0),
176 Expression::Add(terms) => terms
178 .iter()
179 .map(|t| Self::find_highest_degree(t, variable))
180 .max()
181 .unwrap_or(0),
182 _ if *expr == Expression::symbol(variable.clone()) => 1,
184 _ => 0,
186 }
187 }
188
189 fn has_transcendental_functions(expr: &Expression) -> bool {
191 match expr {
192 Expression::Function { name, args } => {
193 matches!(name.as_str(), "sin" | "cos" | "tan" | "exp" | "ln" | "log")
194 || args.iter().any(Self::has_transcendental_functions)
195 }
196 Expression::Add(terms) => terms.iter().any(Self::has_transcendental_functions),
197 Expression::Mul(factors) => factors.iter().any(Self::has_transcendental_functions),
198 Expression::Pow(base, exp) => {
199 Self::has_transcendental_functions(base) || Self::has_transcendental_functions(exp)
200 }
201 _ => false,
202 }
203 }
204
205 fn count_variables(expr: &Expression) -> usize {
207 let mut variables = std::collections::HashSet::new();
208 Self::collect_variables(expr, &mut variables);
209 variables.len()
210 }
211
212 pub fn collect_variables(expr: &Expression, variables: &mut std::collections::HashSet<String>) {
214 match expr {
215 Expression::Symbol(s) => {
216 variables.insert(s.name().to_owned());
217 }
218 Expression::Add(terms) => {
219 for term in terms.iter() {
220 Self::collect_variables(term, variables);
221 }
222 }
223 Expression::Mul(factors) => {
224 for factor in factors.iter() {
225 Self::collect_variables(factor, variables);
226 }
227 }
228 Expression::Pow(base, exp) => {
229 Self::collect_variables(base, variables);
230 Self::collect_variables(exp, variables);
231 }
232 Expression::Function { args, .. } => {
233 for arg in args.iter() {
234 Self::collect_variables(arg, variables);
235 }
236 }
237 _ => {}
238 }
239 }
240
241 fn has_derivatives(expr: &Expression) -> bool {
243 match expr {
244 Expression::Function { name, args } => {
245 matches!(name.as_str(), "derivative" | "diff" | "D")
246 || args.iter().any(Self::has_derivatives)
247 }
248 Expression::Symbol(s) => {
249 let name = s.name();
250 name.ends_with('\'') || name.contains("_prime")
251 }
252 Expression::Add(terms) => terms.iter().any(Self::has_derivatives),
253 Expression::Mul(factors) => factors.iter().any(Self::has_derivatives),
254 Expression::Pow(base, exp) => Self::has_derivatives(base) || Self::has_derivatives(exp),
255 _ => false,
256 }
257 }
258
259 fn has_partial_derivatives(expr: &Expression) -> bool {
261 match expr {
262 Expression::Function { name, args } => {
263 matches!(name.as_str(), "partial" | "pdiff" | "Partial")
264 || args.iter().any(Self::has_partial_derivatives)
265 }
266 Expression::Symbol(s) => {
267 let name = s.name();
268 name.contains("partial") || name.contains("∂")
269 }
270 Expression::Add(terms) => terms.iter().any(Self::has_partial_derivatives),
271 Expression::Mul(factors) => factors.iter().any(Self::has_partial_derivatives),
272 Expression::Pow(base, exp) => {
273 Self::has_partial_derivatives(base) || Self::has_partial_derivatives(exp)
274 }
275 _ => false,
276 }
277 }
278}
279
280pub struct SmartEquationSolver {
282 linear_solver: LinearSolver,
283 quadratic_solver: QuadraticSolver,
284 system_solver: SystemSolver,
285 polynomial_solver: PolynomialSolver,
286 matrix_solver: MatrixEquationSolver,
287 ode_solver: EducationalODESolver,
288 pde_solver: EducationalPDESolver,
289}
290
291impl Default for SmartEquationSolver {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297impl SmartEquationSolver {
298 pub fn new() -> Self {
299 Self {
300 linear_solver: LinearSolver::new(),
301 quadratic_solver: QuadraticSolver::new(),
302 system_solver: SystemSolver::new(),
303 polynomial_solver: PolynomialSolver::new(),
304 matrix_solver: MatrixEquationSolver::new(),
305 ode_solver: EducationalODESolver::new(),
306 pde_solver: EducationalPDESolver::new(),
307 }
308 }
309
310 pub fn solve_with_equation(
330 &mut self,
331 equation: &Expression,
332 variable: &Symbol,
333 ) -> (SolverResult, StepByStepExplanation) {
334 let mut all_steps = Vec::new();
335
336 let degree = EquationAnalyzer::find_highest_degree(equation, variable);
337 let eq_type = EquationAnalyzer::analyze(equation, variable);
338
339 let analysis_description = match eq_type {
340 EquationType::Constant => {
341 "Detected constant equation (no variables)".to_owned()
342 }
343 EquationType::Linear => {
344 format!("Detected linear equation (highest degree: {})", degree)
345 }
346 EquationType::Quadratic => {
347 format!("Detected quadratic equation (highest degree: {})", degree)
348 }
349 EquationType::Cubic => {
350 format!("Detected cubic equation (highest degree: {})", degree)
351 }
352 EquationType::Quartic => {
353 format!("Detected quartic equation (highest degree: {})", degree)
354 }
355 EquationType::System => {
356 "Detected system of equations (multiple variables)".to_owned()
357 }
358 EquationType::Transcendental => {
359 "Detected transcendental equation (contains trig/exp/log functions)".to_owned()
360 }
361 EquationType::Numerical => {
362 "Detected numerical equation (requires numerical methods - polynomial degree > 4 or mixed transcendental)".to_owned()
363 }
364 EquationType::Matrix => {
365 "Detected matrix equation (contains noncommutative symbols)".to_owned()
366 }
367 EquationType::ODE => {
368 "Detected ordinary differential equation (contains derivatives)".to_owned()
369 }
370 EquationType::PDE => {
371 "Detected partial differential equation (contains partial derivatives)".to_owned()
372 }
373 EquationType::Unknown => {
374 "Unknown equation type".to_owned()
375 }
376 };
377
378 all_steps.push(Step::new("Equation Analysis", analysis_description));
379
380 let solver_description = match eq_type {
381 EquationType::Linear => "Using linear equation solver (isolation method)",
382 EquationType::Quadratic => "Using quadratic equation solver (quadratic formula)",
383 EquationType::Cubic | EquationType::Quartic => "Using polynomial solver",
384 EquationType::System => "Using system equation solver",
385 EquationType::Numerical => {
386 "Using numerical solver (Newton-Raphson method with numerical differentiation)"
387 }
388 EquationType::Matrix => "Using matrix equation solver (left/right division)",
389 EquationType::ODE => "Using ODE solver (separable/linear/exact methods)",
390 EquationType::PDE => {
391 "Using PDE solver (method of characteristics/separation of variables)"
392 }
393 _ => "No specialized solver available for this equation type",
394 };
395
396 all_steps.push(Step::new("Solver Selection", solver_description));
397
398 let (result, solver_steps) = match eq_type {
399 EquationType::Linear => self
400 .linear_solver
401 .solve_with_explanation(equation, variable),
402 EquationType::Quadratic => self
403 .quadratic_solver
404 .solve_with_explanation(equation, variable),
405 EquationType::Cubic | EquationType::Quartic => self
406 .polynomial_solver
407 .solve_with_explanation(equation, variable),
408 EquationType::System => self
409 .system_solver
410 .solve_with_explanation(equation, variable),
411 EquationType::Numerical => self.solve_numerical(equation, variable),
412 EquationType::Matrix => self
413 .matrix_solver
414 .solve_with_explanation(equation, variable),
415 EquationType::ODE => self.ode_solver.solve_with_explanation(equation, variable),
416 EquationType::PDE => self.pde_solver.solve_with_explanation(equation, variable),
417 _ => {
418 all_steps.push(Step::new(
419 "Status",
420 "This equation type is not yet fully implemented",
421 ));
422 (SolverResult::NoSolution, StepByStepExplanation::new(vec![]))
423 }
424 };
425
426 all_steps.extend(solver_steps.steps);
427
428 (result, StepByStepExplanation::new(all_steps))
429 }
430
431 fn solve_numerical(
436 &self,
437 _equation: &Expression,
438 variable: &Symbol,
439 ) -> (SolverResult, StepByStepExplanation) {
440 let steps = vec![
441 Step::new(
442 "Numerical Method Required",
443 format!(
444 "This equation requires numerical methods to solve for {}. Newton-Raphson method integration is available.",
445 variable.name()
446 ),
447 ),
448 Step::new(
449 "Method Description",
450 "Newton-Raphson method with numerical differentiation provides robust convergence for smooth functions.",
451 ),
452 ];
453
454 (SolverResult::NoSolution, StepByStepExplanation::new(steps))
455 }
456
457 pub fn solve(&mut self) -> (SolverResult, StepByStepExplanation) {
459 let equation = Expression::integer(0);
460 let variables = self.extract_variables(&equation);
461 if variables.is_empty() {
462 return (SolverResult::NoSolution, StepByStepExplanation::new(vec![]));
463 }
464
465 let primary_var = &variables[0];
466 self.solve_with_equation(&equation, primary_var)
467 }
468
469 fn extract_variables(&self, equation: &Expression) -> Vec<Symbol> {
471 let mut variables = std::collections::HashSet::new();
472 EquationAnalyzer::collect_variables(equation, &mut variables);
473
474 variables
475 .into_iter()
476 .map(|name| Symbol::new(&name))
477 .collect()
478 }
479
480 pub fn solve_system(&mut self, equations: &[Expression], variables: &[Symbol]) -> SolverResult {
519 use crate::algebra::solvers::SystemEquationSolver;
520 self.system_solver.solve_system(equations, variables)
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// `2x + 3` is linear and `x^2 + 3x + 2` is quadratic.
    #[test]
    fn test_equation_type_detection() {
        let x = symbol!(x);

        // 2x + 3
        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let linear = Expression::add(vec![two_x, Expression::integer(3)]);
        let linear_kind = EquationAnalyzer::analyze(&linear, &x);
        assert_eq!(linear_kind, EquationType::Linear);

        // x^2 + 3x + 2
        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        let three_x = Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]);
        let quadratic = Expression::add(vec![x_squared, three_x, Expression::integer(2)]);
        let quadratic_kind = EquationAnalyzer::analyze(&quadratic, &x);
        assert_eq!(quadratic_kind, EquationType::Quadratic);
    }

    /// A quintic and a mixed transcendental/polynomial equation both go numerical.
    #[test]
    fn test_numerical_equation_detection() {
        let x = symbol!(x);

        // x^5 - x - 1
        let x_fifth = Expression::pow(Expression::symbol(x.clone()), Expression::integer(5));
        let minus_x = Expression::mul(vec![Expression::integer(-1), Expression::symbol(x.clone())]);
        let quintic = Expression::add(vec![x_fifth, minus_x, Expression::integer(-1)]);
        let quintic_kind = EquationAnalyzer::analyze(&quintic, &x);
        assert_eq!(quintic_kind, EquationType::Numerical);

        // cos(x) - x
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let minus_x = Expression::mul(vec![Expression::integer(-1), Expression::symbol(x.clone())]);
        let transcendental_mixed = Expression::add(vec![cos_x, minus_x]);
        let mixed_kind = EquationAnalyzer::analyze(&transcendental_mixed, &x);
        assert_eq!(mixed_kind, EquationType::Numerical);
    }

    /// `A*X - B` with matrix symbols is detected as a matrix equation.
    #[test]
    fn test_matrix_equation_detection() {
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        let a_times_x = Expression::mul(vec![Expression::symbol(a), Expression::symbol(x.clone())]);
        let minus_b = Expression::mul(vec![Expression::integer(-1), Expression::symbol(b)]);
        let equation = Expression::add(vec![a_times_x, minus_b]);

        let detected = EquationAnalyzer::analyze(&equation, &x);
        assert_eq!(detected, EquationType::Matrix);
    }
}