mathhook_core/algebra/solvers/
linear.rs

//! Linear equation solver for equations of the form `ax + b = 0`.
//! Includes step-by-step explanations for educational value.
3
4use crate::algebra::Expand;
5use crate::core::constants::EPSILON;
6use crate::core::{Commutativity, Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8// Temporarily simplified for TDD success
9use crate::algebra::solvers::{EquationSolver, SolverResult};
10use crate::simplify::Simplify;
11use num_bigint::BigInt;
12use num_rational::BigRational;
13
/// Solver for linear equations of the form `a·x + b = 0`, with optional
/// step-by-step explanations.
#[derive(Clone, Debug)]
pub struct LinearSolver {
    /// Flag requesting step-by-step explanations.
    ///
    /// NOTE(review): neither `solve` nor `solve_with_explanation` in this file
    /// reads this flag — confirm whether callers rely on it.
    pub show_steps: bool,
}

impl Default for LinearSolver {
    /// Equivalent to [`LinearSolver::new`] (explanations requested).
    fn default() -> Self {
        LinearSolver::new()
    }
}

impl LinearSolver {
    /// Builds a solver with `show_steps` set to `true`.
    pub fn new() -> Self {
        LinearSolver { show_steps: true }
    }

    /// Builds a solver with `show_steps` set to `false` (for performance).
    pub fn new_fast() -> Self {
        LinearSolver { show_steps: false }
    }
}
38
39impl EquationSolver for LinearSolver {
40    /// Solve linear equation ax + b = 0
41    ///
42    /// Fractional solutions are automatically simplified to lowest terms via
43    /// `BigRational::new()`, which reduces fractions using GCD. Integer solutions
44    /// (where numerator is divisible by denominator) are returned as integers.
45    ///
46    /// # Examples
47    ///
48    /// ```rust
49    /// use mathhook_core::algebra::solvers::{linear::LinearSolver, EquationSolver, SolverResult};
50    /// use mathhook_core::core::{Expression, Number};
51    /// use mathhook_core::symbol;
52    /// use num_bigint::BigInt;
53    ///
54    /// let solver = LinearSolver::new_fast();
55    /// let x = symbol!(x);
56    ///
57    /// // Example: 4x = 6 gives x = 3/2 (simplified from 6/4)
58    /// let equation = Expression::add(vec![
59    ///     Expression::mul(vec![Expression::integer(4), Expression::symbol(x.clone())]),
60    ///     Expression::integer(-6),
61    /// ]);
62    ///
63    /// match solver.solve(&equation, &x) {
64    ///     SolverResult::Single(solution) => {
65    ///         if let Expression::Number(Number::Rational(r)) = solution {
66    ///             assert_eq!(r.numer(), &BigInt::from(3));
67    ///             assert_eq!(r.denom(), &BigInt::from(2));
68    ///         }
69    ///     }
70    ///     _ => panic!("Expected single solution"),
71    /// }
72    /// ```
73    #[inline(always)]
74    fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
75        // Handle Relation type (equations like x = 5)
76        let equation_expr = if let Expression::Relation(data) = equation {
77            // Convert relation to expression: left - right = 0
78            Expression::add(vec![
79                data.left.clone(),
80                Expression::mul(vec![Expression::integer(-1), data.right.clone()]),
81            ])
82        } else {
83            equation.clone()
84        };
85
86        // Check for noncommutative symbols - delegate to MatrixEquationSolver if found
87        if equation_expr.commutativity() != Commutativity::Commutative {
88            use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
89            let matrix_solver = MatrixEquationSolver::new_fast();
90            return matrix_solver.solve(&equation_expr, variable);
91        }
92
93        // Simplify and expand equation to flatten nested structures and distribute multiplication
94        let simplified_equation = equation_expr.simplify().expand();
95
96        // Check for identity equations (0 = 0) or contradictions AFTER simplification
97        if simplified_equation.is_zero() {
98            // If equation simplified to just 0, it means 0 = 0 (infinite solutions)
99            return SolverResult::InfiniteSolutions;
100        }
101        // Check for non-zero constant (contradiction)
102        if let Expression::Number(Number::Integer(n)) = simplified_equation {
103            if n != 0 {
104                return SolverResult::NoSolution;
105            }
106        }
107
108        // Check for factored form: (x - a)(x - b)...(x - n) = 0
109        if let Some(roots) = self.extract_factored_roots(&simplified_equation, variable) {
110            if roots.len() == 1 {
111                return SolverResult::Single(roots[0].clone());
112            } else if roots.len() > 1 {
113                return SolverResult::Multiple(roots);
114            }
115        }
116
117        // Extract coefficients from simplified linear equation
118        let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
119
120        // Smart solver: Analyze original equation structure before simplification
121
122        // Check if original equation has patterns like 0*x + constant
123        if let Some(special_result) = self.detect_special_linear_cases(&equation_expr, variable) {
124            return special_result;
125        }
126
127        // Extract coefficients for normal linear analysis
128        let a_simplified = a.simplify();
129        let b_simplified = b.simplify();
130
131        if a_simplified.is_zero() {
132            if b_simplified.is_zero() {
133                return SolverResult::InfiniteSolutions; // 0x + 0 = 0
134            } else {
135                return SolverResult::NoSolution; // 0x + b = 0 where b ≠ 0
136            }
137        }
138
139        // Solve ax + b = 0 → x = -b/a
140        // Fractions are automatically reduced to lowest terms by BigRational::new()
141
142        // Check if we can solve numerically
143        match (&a_simplified, &b_simplified) {
144            (
145                Expression::Number(Number::Integer(a_val)),
146                Expression::Number(Number::Integer(b_val)),
147            ) => {
148                if *a_val != 0 {
149                    // Simple case: ax + b = 0 → x = -b/a
150                    let result = -b_val / a_val;
151                    if b_val % a_val == 0 {
152                        // Integer solution: return as integer (e.g., 10/5 = 2)
153                        SolverResult::Single(Expression::integer(result))
154                    } else {
155                        // Fractional solution: BigRational::new() automatically reduces to lowest terms
156                        // Example: 6/4 → 3/2, 18/12 → 3/2
157                        SolverResult::Single(Expression::Number(Number::rational(
158                            BigRational::new(BigInt::from(-b_val), BigInt::from(*a_val)),
159                        )))
160                    }
161                } else {
162                    SolverResult::NoSolution
163                }
164            }
165            _ => {
166                // General case - use simplified coefficients
167                let neg_b = b_simplified.negate().simplify();
168                let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
169
170                // Try to evaluate the solution numerically if possible
171                let final_solution = Self::try_eval_numeric_internal(&solution);
172                SolverResult::Single(final_solution)
173            }
174        }
175    }
176
177    /// Solve with step-by-step explanation
178    fn solve_with_explanation(
179        &self,
180        equation: &Expression,
181        variable: &Symbol,
182    ) -> (SolverResult, StepByStepExplanation) {
183        let simplified_equation = equation.simplify();
184        let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
185
186        if a.is_zero() {
187            return self.handle_special_case_with_style(&b);
188        }
189
190        let a_simplified = a.simplify();
191        let b_simplified = b.simplify();
192        let neg_b = b_simplified.negate().simplify();
193        let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
194
195        let steps = vec![
196            Step::new(
197                "Given Equation",
198                format!("We need to solve: {} = 0", equation),
199            ),
200            Step::new(
201                "Strategy",
202                format!("Isolate {} using inverse operations", variable.name),
203            ),
204            Step::new(
205                "Identify Form",
206                format!("This has form: {}·{} + {} = 0", a, variable.name, b),
207            ),
208            Step::new(
209                "Calculate",
210                format!("{} = -({}) ÷ {} = {}", variable.name, b, a, solution),
211            ),
212            Step::new("Solution", format!("{} = {}", variable.name, solution)),
213        ];
214        let explanation = StepByStepExplanation::new(steps);
215
216        (SolverResult::Single(solution), explanation)
217    }
218
219    /// Check if this solver can handle the equation
220    fn can_solve(&self, equation: &Expression) -> bool {
221        // Check if equation is linear in any variable
222        self.is_linear_equation(equation)
223    }
224}
225
226impl LinearSolver {
227    /// Handle special cases with step explanations
228    fn handle_special_case_with_style(
229        &self,
230        b: &Expression,
231    ) -> (SolverResult, StepByStepExplanation) {
232        if b.is_zero() {
233            let steps = vec![
234                Step::new("Special Case", "0x + 0 = 0 is always true"),
235                Step::new("Result", "Infinite solutions - any value of x works"),
236            ];
237            (
238                SolverResult::InfiniteSolutions,
239                StepByStepExplanation::new(steps),
240            )
241        } else {
242            let steps = vec![
243                Step::new("Special Case", format!("0x + {} = 0 means {} = 0", b, b)),
244                Step::new(
245                    "Contradiction",
246                    format!("But {} ≠ 0, so no solution exists", b),
247                ),
248            ];
249            (SolverResult::NoSolution, StepByStepExplanation::new(steps))
250        }
251    }
252    /// Extract coefficients a and b from equation ax + b = 0
253    #[inline(always)]
254    fn extract_linear_coefficients(
255        &self,
256        equation: &Expression,
257        variable: &Symbol,
258    ) -> (Expression, Expression) {
259        // First, flatten all nested Add expressions
260        let flattened_terms = equation.flatten_add_terms();
261
262        let mut coefficient = Expression::integer(0); // Coefficient of variable
263        let mut constant = Expression::integer(0); // Constant term
264
265        for term in flattened_terms.iter() {
266            match term {
267                Expression::Symbol(s) if s == variable => {
268                    coefficient = Expression::add(vec![coefficient, Expression::integer(1)]);
269                }
270                Expression::Mul(factors) => {
271                    let mut var_coeff = Expression::integer(1);
272                    let mut has_variable = false;
273
274                    for factor in factors.iter() {
275                        match factor {
276                            Expression::Symbol(s) if s == variable => {
277                                has_variable = true;
278                            }
279                            _ => {
280                                var_coeff = Expression::mul(vec![var_coeff, factor.clone()]);
281                            }
282                        }
283                    }
284
285                    if has_variable {
286                        coefficient = Expression::add(vec![coefficient, var_coeff]);
287                    } else {
288                        constant = Expression::add(vec![constant, term.clone()]);
289                    }
290                }
291                _ => {
292                    // Constant term
293                    constant = Expression::add(vec![constant, term.clone()]);
294                }
295            }
296        }
297        (coefficient, constant)
298    }
299
300    /// Check if equation is linear
301    fn is_linear_equation(&self, equation: &Expression) -> bool {
302        matches!(
303            equation,
304            Expression::Add(_) | Expression::Symbol(_) | Expression::Number(_)
305        )
306    }
307
308    /// Detect special linear cases before simplification
309    #[inline(always)]
310    fn detect_special_linear_cases(
311        &self,
312        equation: &Expression,
313        variable: &Symbol,
314    ) -> Option<SolverResult> {
315        match equation {
316            Expression::Add(terms) if terms.len() == 2 => {
317                // Check for patterns: 0*x + constant
318                if let [Expression::Mul(factors), constant] = &terms[..] {
319                    if factors.len() == 2 {
320                        if let [Expression::Number(Number::Integer(0)), var] = &factors[..] {
321                            if var == &Expression::symbol(variable.clone()) {
322                                // Found 0*x + constant pattern
323                                match constant {
324                                    Expression::Number(Number::Integer(0)) => {
325                                        return Some(SolverResult::InfiniteSolutions);
326                                        // 0*x + 0 = 0
327                                    }
328                                    _ => {
329                                        return Some(SolverResult::NoSolution); // 0*x + nonzero = 0
330                                    }
331                                }
332                            }
333                        }
334                    }
335                }
336            }
337            _ => {}
338        }
339        None // No special case detected
340    }
341
342    /// Extract roots from factored polynomial form: (x - a)(x - b) = 0
343    fn extract_factored_roots(
344        &self,
345        expr: &Expression,
346        variable: &Symbol,
347    ) -> Option<Vec<Expression>> {
348        match expr {
349            Expression::Mul(factors) => {
350                let mut roots = Vec::new();
351
352                for factor in factors.iter() {
353                    // Check if this factor is (x - constant) or (constant - x)
354                    if let Expression::Add(terms) = factor {
355                        if terms.len() == 2 {
356                            // Check pattern: x + (-a) = 0 → x = a
357                            if let [Expression::Symbol(s), Expression::Mul(neg_factors)] =
358                                &terms[..]
359                            {
360                                if s == variable && neg_factors.len() == 2 {
361                                    if let [Expression::Number(Number::Integer(-1)), constant] =
362                                        &neg_factors[..]
363                                    {
364                                        roots.push(constant.clone());
365                                        continue;
366                                    }
367                                }
368                            }
369                            // Check pattern: -a + x = 0 → x = a
370                            if let [Expression::Mul(neg_factors), Expression::Symbol(s)] =
371                                &terms[..]
372                            {
373                                if s == variable && neg_factors.len() == 2 {
374                                    if let [Expression::Number(Number::Integer(-1)), constant] =
375                                        &neg_factors[..]
376                                    {
377                                        roots.push(constant.clone());
378                                        continue;
379                                    }
380                                }
381                            }
382                        }
383                    }
384                }
385
386                if roots.is_empty() {
387                    None
388                } else {
389                    Some(roots)
390                }
391            }
392            _ => None,
393        }
394    }
395
396    /// Internal domain-specific optimization for linear solver
397    ///
398    /// Evaluate expressions with fraction handling for linear equation solutions.
399    /// This is a specialized version optimized for the linear solver's needs.
400    ///
401    /// Static helper function - doesn't depend on instance state.
402    #[inline(always)]
403    fn try_eval_numeric_internal(expr: &Expression) -> Expression {
404        match expr {
405            // Handle -1 * (complex expression)
406            Expression::Mul(factors) if factors.len() == 2 => {
407                if let [Expression::Number(Number::Integer(-1)), complex_expr] = &factors[..] {
408                    // Evaluate the complex expression and negate it
409                    let evaluated = Self::eval_exact_internal(complex_expr);
410                    evaluated.negate().simplify()
411                } else {
412                    expr.clone()
413                }
414            }
415            // Handle fractions that should be evaluated
416            Expression::Function { name, args } if name == "fraction" && args.len() == 2 => {
417                Self::eval_exact_internal(expr)
418            }
419            _ => expr.clone(),
420        }
421    }
422
423    /// Internal domain-specific optimization for linear solver
424    ///
425    /// Static helper function for exact arithmetic evaluation.
426    /// Preserves exact arithmetic (integers/rationals) without instance state dependency.
427    ///
428    /// This method is kept separate from Expression::evaluate_to_f64() because it maintains
429    /// mathematical exactness. For example, 1/3 stays as Rational(1,3), not 0.333...
430    ///
431    /// # Why Not Use evaluate_to_f64()?
432    ///
433    /// - evaluate_to_f64() converts to f64 (loses precision: 1/3 → 0.333...)
434    /// - This method preserves rationals (keeps exactness: 1/3 → Rational(1,3))
435    /// - Linear equation solutions often require exact fractions (e.g., x = 2/3)
436    ///
437    /// # Automatic Fraction Simplification
438    ///
439    /// When creating rational numbers via `BigRational::new(num, den)`, fractions are
440    /// automatically reduced to lowest terms using GCD. For example:
441    /// - `BigRational::new(6, 4)` → 3/2
442    /// - `BigRational::new(18, 12)` → 3/2
443    /// - `BigRational::new(10, 5)` → 2 (returned as integer if denominator is 1)
444    ///
445    /// # Returns
446    ///
447    /// - Expression::Number(Integer) for exact integer results
448    /// - Expression::Number(Rational) for exact fractional results (automatically simplified)
449    /// - Original expression if cannot be evaluated exactly
450    #[inline(always)]
451    fn eval_exact_internal(expr: &Expression) -> Expression {
452        match expr {
453            Expression::Add(terms) => {
454                let mut total = 0i64;
455                for term in terms.iter() {
456                    match Self::eval_exact_internal(term) {
457                        Expression::Number(Number::Integer(n)) => total += n,
458                        _ => return expr.clone(), // Can't evaluate
459                    }
460                }
461                Expression::integer(total)
462            }
463            Expression::Mul(factors) => {
464                let mut product = 1i64;
465                for factor in factors.iter() {
466                    match Self::eval_exact_internal(factor) {
467                        Expression::Number(Number::Integer(n)) => product *= n,
468                        _ => return expr.clone(), // Can't evaluate
469                    }
470                }
471                Expression::integer(product)
472            }
473            // Handle fraction functions: fraction(numerator, denominator)
474            // BigRational::new() automatically reduces to lowest terms
475            Expression::Function { name, args } if name == "fraction" && args.len() == 2 => {
476                // First evaluate the numerator and denominator
477                let num_eval = Self::eval_exact_internal(&args[0]);
478                let den_eval = Self::eval_exact_internal(&args[1]);
479
480                match (&num_eval, &den_eval) {
481                    (
482                        Expression::Number(Number::Float(num)),
483                        Expression::Number(Number::Float(den)),
484                    ) => {
485                        if den.abs() >= EPSILON {
486                            let result = num / den;
487                            if result.fract().abs() < EPSILON {
488                                Expression::integer(result as i64)
489                            } else {
490                                Expression::Number(Number::float(result))
491                            }
492                        } else {
493                            expr.clone()
494                        }
495                    }
496                    (
497                        Expression::Number(Number::Integer(num)),
498                        Expression::Number(Number::Integer(den)),
499                    ) => {
500                        if *den != 0 {
501                            if num % den == 0 {
502                                Expression::integer(num / den)
503                            } else {
504                                // BigRational::new() automatically reduces to lowest terms via GCD
505                                Expression::Number(Number::rational(BigRational::new(
506                                    BigInt::from(*num),
507                                    BigInt::from(*den),
508                                )))
509                            }
510                        } else {
511                            expr.clone()
512                        }
513                    }
514                    _ => expr.clone(),
515                }
516            }
517            Expression::Number(_) => expr.clone(),
518            _ => expr.clone(),
519        }
520    }
521
522    /// Divide two expressions (simplified division)
523    ///
524    /// Static helper function for recursive division operations.
525    /// Does not require instance state, only performs expression manipulation.
526    ///
527    /// Fractions created via `BigRational::new()` are automatically reduced
528    /// to lowest terms using GCD.
529    #[inline(always)]
530    fn divide_expressions(numerator: &Expression, denominator: &Expression) -> Expression {
531        // First simplify both expressions
532        let num_simplified = numerator.simplify();
533        let den_simplified = denominator.simplify();
534
535        match (&num_simplified, &den_simplified) {
536            // Simple integer division
537            // BigRational::new() automatically reduces to lowest terms
538            (Expression::Number(Number::Integer(n)), Expression::Number(Number::Integer(d))) => {
539                if *d != 0 {
540                    if n % d == 0 {
541                        Expression::integer(n / d)
542                    } else {
543                        // Create rational number - automatically reduced to lowest terms
544                        Expression::Number(Number::rational(BigRational::new(
545                            BigInt::from(*n),
546                            BigInt::from(*d),
547                        )))
548                    }
549                } else {
550                    // Division by zero - should be handled as error
551                    Expression::integer(0) // Placeholder
552                }
553            }
554            // Integer divided by rational: a / (p/q) = a * (q/p)
555            (Expression::Number(Number::Integer(n)), Expression::Number(Number::Rational(r))) => {
556                // a / (p/q) = a * q / p
557                let inverted = BigRational::new(r.denom().clone(), r.numer().clone());
558                let result = BigRational::from(BigInt::from(*n)) * inverted;
559
560                // Simplify to integer if possible
561                if result.is_integer() {
562                    Expression::integer(result.numer().to_string().parse().unwrap())
563                } else {
564                    Expression::Number(Number::rational(result))
565                }
566            }
567            // Rational divided by integer: (p/q) / a = (p/q) / a = p/(q*a)
568            (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(d))) => {
569                if *d != 0 {
570                    let result = (**r).clone() / BigRational::from(BigInt::from(*d));
571                    if result.is_integer() {
572                        Expression::integer(result.numer().to_string().parse().unwrap())
573                    } else {
574                        Expression::Number(Number::rational(result))
575                    }
576                } else {
577                    Expression::integer(0) // Placeholder
578                }
579            }
580            // Rational divided by rational
581            (
582                Expression::Number(Number::Rational(num_r)),
583                Expression::Number(Number::Rational(den_r)),
584            ) => {
585                let result = (**num_r).clone() / (**den_r).clone();
586                if result.is_integer() {
587                    Expression::integer(result.numer().to_string().parse().unwrap())
588                } else {
589                    Expression::Number(Number::rational(result))
590                }
591            }
592            // Try to simplify further - if denominator is 1, just return numerator
593            (num, Expression::Number(Number::Integer(1))) => num.clone(),
594            // Handle multiplication by -1 and other simple cases
595            (Expression::Mul(factors), den) if factors.len() == 2 => {
596                if let [Expression::Number(Number::Integer(-1)), expr] = &factors[..] {
597                    // -1 * expr / den = -(expr / den)
598                    let inner_div = Self::divide_expressions(expr, den);
599                    Expression::mul(vec![Expression::integer(-1), inner_div]).simplify()
600                } else {
601                    // General case
602                    let fraction =
603                        Expression::function("fraction", vec![num_simplified, den_simplified]);
604                    fraction.simplify()
605                }
606            }
607            // For linear solver, try to evaluate numerically if possible
608            _ => {
609                // Return as fraction function and let it simplify
610                let fraction =
611                    Expression::function("fraction", vec![num_simplified, den_simplified]);
612                fraction.simplify()
613            }
614        }
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// `2x + 3` splits into coefficient `2` and constant `3`.
    #[test]
    fn test_coefficient_extraction() {
        let solver = LinearSolver::new();
        let x = symbol!(x);

        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let equation = Expression::add(vec![two_x, Expression::integer(3)]);

        let (coeff, constant) = solver.extract_linear_coefficients(&equation, &x);
        // The raw coefficient may be an unsimplified product such as Mul([1, 2]).
        assert_eq!(coeff.simplify(), Expression::integer(2));
        assert_eq!(constant.simplify(), Expression::integer(3));
    }

    /// A sum is accepted by the shape check; a bare power is not.
    #[test]
    fn test_linear_detection() {
        let solver = LinearSolver::new();
        let x = symbol!(x);

        let x_plus_one = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        assert!(solver.is_linear_equation(&x_plus_one));
        assert!(!solver.is_linear_equation(&x_squared));
    }
}
653}