mathhook_core/algebra/solvers/
linear.rs

//! Linear equation solver for equations of the form `ax + b = 0`.
//! Can also produce step-by-step explanations for educational use.
3
4use crate::algebra::Expand;
5use crate::core::constants::EPSILON;
6use crate::core::{Commutativity, Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8// Temporarily simplified for TDD success
9use crate::algebra::solvers::{EquationSolver, SolverResult};
10use crate::simplify::Simplify;
11use num_bigint::BigInt;
12use num_rational::BigRational;
13
/// Solver for linear equations of the form `a·x + b = 0`.
///
/// Build one with [`LinearSolver::new`] (step-by-step flag on) or
/// [`LinearSolver::new_fast`] (flag off).
#[derive(Clone, Debug)]
pub struct LinearSolver {
    /// Requests step-by-step explanations: `new()` sets it, `new_fast()` clears it.
    /// NOTE(review): no code in this file reads the flag — confirm callers use it.
    pub show_steps: bool,
}
20
21impl Default for LinearSolver {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl LinearSolver {
28    /// Create new linear solver
29    pub fn new() -> Self {
30        Self { show_steps: true }
31    }
32
33    /// Create solver without step-by-step (for performance)
34    pub fn new_fast() -> Self {
35        Self { show_steps: false }
36    }
37}
38
39impl EquationSolver for LinearSolver {
40    /// Solve linear equation ax + b = 0
41    ///
42    /// Fractional solutions are automatically simplified to lowest terms via
43    /// `BigRational::new()`, which reduces fractions using GCD. Integer solutions
44    /// (where numerator is divisible by denominator) are returned as integers.
45    ///
46    /// # Examples
47    ///
48    /// ```rust
49    /// use mathhook_core::algebra::solvers::{linear::LinearSolver, EquationSolver, SolverResult};
50    /// use mathhook_core::core::{Expression, Number};
51    /// use mathhook_core::symbol;
52    /// use num_bigint::BigInt;
53    ///
54    /// let solver = LinearSolver::new_fast();
55    /// let x = symbol!(x);
56    ///
57    /// // Example: 4x = 6 gives x = 3/2 (simplified from 6/4)
58    /// let equation = Expression::add(vec![
59    ///     Expression::mul(vec![Expression::integer(4), Expression::symbol(x.clone())]),
60    ///     Expression::integer(-6),
61    /// ]);
62    ///
63    /// match solver.solve(&equation, &x) {
64    ///     SolverResult::Single(solution) => {
65    ///         if let Expression::Number(Number::Rational(r)) = solution {
66    ///             assert_eq!(r.numer(), &BigInt::from(3));
67    ///             assert_eq!(r.denom(), &BigInt::from(2));
68    ///         }
69    ///     }
70    ///     _ => panic!("Expected single solution"),
71    /// }
72    /// ```
73    #[inline(always)]
74    fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
75        // Handle Relation type (equations like x = 5)
76        let equation_expr = if let Expression::Relation(data) = equation {
77            // Convert relation to expression: left - right = 0
78            Expression::add(vec![
79                data.left.clone(),
80                Expression::mul(vec![Expression::integer(-1), data.right.clone()]),
81            ])
82        } else {
83            equation.clone()
84        };
85
86        // Check for noncommutative symbols - delegate to MatrixEquationSolver if found
87        if equation_expr.commutativity() != Commutativity::Commutative {
88            use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
89            let matrix_solver = MatrixEquationSolver::new_fast();
90            return matrix_solver.solve(&equation_expr, variable);
91        }
92
93        // Simplify and expand equation to flatten nested structures and distribute multiplication
94        let simplified_equation = equation_expr.simplify().expand();
95
96        // Check for identity equations (0 = 0) or contradictions AFTER simplification
97        if simplified_equation.is_zero() {
98            // If equation simplified to just 0, it means 0 = 0 (infinite solutions)
99            return SolverResult::InfiniteSolutions;
100        }
101        // Check for non-zero constant (contradiction)
102        if let Expression::Number(Number::Integer(n)) = simplified_equation {
103            if n != 0 {
104                return SolverResult::NoSolution;
105            }
106        }
107
108        // Check for factored form: (x - a)(x - b)...(x - n) = 0
109        if let Some(roots) = self.extract_factored_roots(&simplified_equation, variable) {
110            if roots.len() == 1 {
111                return SolverResult::Single(roots[0].clone());
112            } else if roots.len() > 1 {
113                return SolverResult::Multiple(roots);
114            }
115        }
116
117        // Extract coefficients from simplified linear equation
118        let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
119
120        // Smart solver: Analyze original equation structure before simplification
121
122        // Check if original equation has patterns like 0*x + constant
123        if let Some(special_result) = self.detect_special_linear_cases(&equation_expr, variable) {
124            return special_result;
125        }
126
127        // Extract coefficients for normal linear analysis
128        let a_simplified = a.simplify();
129        let b_simplified = b.simplify();
130
131        if a_simplified.is_zero() {
132            if b_simplified.is_zero() {
133                return SolverResult::InfiniteSolutions; // 0x + 0 = 0
134            } else {
135                return SolverResult::NoSolution; // 0x + b = 0 where b ≠ 0
136            }
137        }
138
139        // Solve ax + b = 0 → x = -b/a
140        // Fractions are automatically reduced to lowest terms by BigRational::new()
141
142        // Check if we can solve numerically
143        match (&a_simplified, &b_simplified) {
144            (
145                Expression::Number(Number::Integer(a_val)),
146                Expression::Number(Number::Integer(b_val)),
147            ) => {
148                if *a_val != 0 {
149                    // Simple case: ax + b = 0 → x = -b/a
150                    let result = -b_val / a_val;
151                    if b_val % a_val == 0 {
152                        // Integer solution: return as integer (e.g., 10/5 = 2)
153                        SolverResult::Single(Expression::integer(result))
154                    } else {
155                        // Fractional solution: BigRational::new() automatically reduces to lowest terms
156                        // Example: 6/4 → 3/2, 18/12 → 3/2
157                        SolverResult::Single(Expression::Number(Number::rational(
158                            BigRational::new(BigInt::from(-b_val), BigInt::from(*a_val)),
159                        )))
160                    }
161                } else {
162                    SolverResult::NoSolution
163                }
164            }
165            _ => {
166                // General case - use simplified coefficients
167                let neg_b = b_simplified.negate().simplify();
168                let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
169
170                // Try to evaluate the solution numerically if possible
171                let final_solution = Self::try_eval_numeric_internal(&solution);
172                SolverResult::Single(final_solution)
173            }
174        }
175    }
176
177    /// Solve with step-by-step explanation
178    fn solve_with_explanation(
179        &self,
180        equation: &Expression,
181        variable: &Symbol,
182    ) -> (SolverResult, StepByStepExplanation) {
183        let simplified_equation = equation.simplify();
184        let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
185
186        if a.is_zero() {
187            return self.handle_special_case_with_style(&b);
188        }
189
190        let a_simplified = a.simplify();
191        let b_simplified = b.simplify();
192        let neg_b = b_simplified.negate().simplify();
193        let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
194
195        let steps = vec![
196            Step::new(
197                "Given Equation",
198                format!("We need to solve: {} = 0", equation),
199            ),
200            Step::new(
201                "Strategy",
202                format!("Isolate {} using inverse operations", variable.name),
203            ),
204            Step::new(
205                "Identify Form",
206                format!("This has form: {}·{} + {} = 0", a, variable.name, b),
207            ),
208            Step::new(
209                "Calculate",
210                format!("{} = -({}) ÷ {} = {}", variable.name, b, a, solution),
211            ),
212            Step::new("Solution", format!("{} = {}", variable.name, solution)),
213        ];
214        let explanation = StepByStepExplanation::new(steps);
215
216        (SolverResult::Single(solution), explanation)
217    }
218
219    /// Check if this solver can handle the equation
220    fn can_solve(&self, equation: &Expression) -> bool {
221        // Check if equation is linear in any variable
222        self.is_linear_equation(equation)
223    }
224}
225
226impl LinearSolver {
227    /// Handle special cases with step explanations
228    fn handle_special_case_with_style(
229        &self,
230        b: &Expression,
231    ) -> (SolverResult, StepByStepExplanation) {
232        if b.is_zero() {
233            let steps = vec![
234                Step::new("Special Case", "0x + 0 = 0 is always true"),
235                Step::new("Result", "Infinite solutions - any value of x works"),
236            ];
237            (
238                SolverResult::InfiniteSolutions,
239                StepByStepExplanation::new(steps),
240            )
241        } else {
242            let steps = vec![
243                Step::new("Special Case", format!("0x + {} = 0 means {} = 0", b, b)),
244                Step::new(
245                    "Contradiction",
246                    format!("But {} ≠ 0, so no solution exists", b),
247                ),
248            ];
249            (SolverResult::NoSolution, StepByStepExplanation::new(steps))
250        }
251    }
252    /// Extract coefficients a and b from equation ax + b = 0
253    #[inline(always)]
254    fn extract_linear_coefficients(
255        &self,
256        equation: &Expression,
257        variable: &Symbol,
258    ) -> (Expression, Expression) {
259        // First, flatten all nested Add expressions
260        let flattened_terms = equation.flatten_add_terms();
261
262        let mut coefficient = Expression::integer(0); // Coefficient of variable
263        let mut constant = Expression::integer(0); // Constant term
264
265        for term in flattened_terms.iter() {
266            match term {
267                Expression::Symbol(s) if s == variable => {
268                    coefficient = Expression::add(vec![coefficient, Expression::integer(1)]);
269                }
270                Expression::Mul(factors) => {
271                    let mut var_coeff = Expression::integer(1);
272                    let mut has_variable = false;
273
274                    for factor in factors.iter() {
275                        match factor {
276                            Expression::Symbol(s) if s == variable => {
277                                has_variable = true;
278                            }
279                            _ => {
280                                var_coeff = Expression::mul(vec![var_coeff, factor.clone()]);
281                            }
282                        }
283                    }
284
285                    if has_variable {
286                        coefficient = Expression::add(vec![coefficient, var_coeff]);
287                    } else {
288                        constant = Expression::add(vec![constant, term.clone()]);
289                    }
290                }
291                _ => {
292                    // Constant term
293                    constant = Expression::add(vec![constant, term.clone()]);
294                }
295            }
296        }
297        (coefficient, constant)
298    }
299
300    /// Check if equation is linear
301    fn is_linear_equation(&self, equation: &Expression) -> bool {
302        matches!(
303            equation,
304            Expression::Add(_) | Expression::Symbol(_) | Expression::Number(_)
305        )
306    }
307
308    /// Detect special linear cases before simplification
309    #[inline(always)]
310    fn detect_special_linear_cases(
311        &self,
312        equation: &Expression,
313        variable: &Symbol,
314    ) -> Option<SolverResult> {
315        match equation {
316            Expression::Add(terms) if terms.len() == 2 => {
317                // Check for patterns: 0*x + constant
318                if let [Expression::Mul(factors), constant] = &terms[..] {
319                    if factors.len() == 2 {
320                        if let [Expression::Number(Number::Integer(0)), var] = &factors[..] {
321                            if var == &Expression::symbol(variable.clone()) {
322                                // Found 0*x + constant pattern
323                                match constant {
324                                    Expression::Number(Number::Integer(0)) => {
325                                        return Some(SolverResult::InfiniteSolutions);
326                                        // 0*x + 0 = 0
327                                    }
328                                    _ => {
329                                        return Some(SolverResult::NoSolution); // 0*x + nonzero = 0
330                                    }
331                                }
332                            }
333                        }
334                    }
335                }
336            }
337            _ => {}
338        }
339        None // No special case detected
340    }
341
342    /// Extract roots from factored polynomial form: (x - a)(x - b) = 0
343    fn extract_factored_roots(
344        &self,
345        expr: &Expression,
346        variable: &Symbol,
347    ) -> Option<Vec<Expression>> {
348        match expr {
349            Expression::Mul(factors) => {
350                let mut roots = Vec::new();
351
352                for factor in factors.iter() {
353                    // Check if this factor is (x - constant) or (constant - x)
354                    if let Expression::Add(terms) = factor {
355                        if terms.len() == 2 {
356                            // Check pattern: x + (-a) = 0 → x = a
357                            if let [Expression::Symbol(s), Expression::Mul(neg_factors)] =
358                                &terms[..]
359                            {
360                                if s == variable && neg_factors.len() == 2 {
361                                    if let [Expression::Number(Number::Integer(-1)), constant] =
362                                        &neg_factors[..]
363                                    {
364                                        roots.push(constant.clone());
365                                        continue;
366                                    }
367                                }
368                            }
369                            // Check pattern: -a + x = 0 → x = a
370                            if let [Expression::Mul(neg_factors), Expression::Symbol(s)] =
371                                &terms[..]
372                            {
373                                if s == variable && neg_factors.len() == 2 {
374                                    if let [Expression::Number(Number::Integer(-1)), constant] =
375                                        &neg_factors[..]
376                                    {
377                                        roots.push(constant.clone());
378                                        continue;
379                                    }
380                                }
381                            }
382                        }
383                    }
384                }
385
386                if roots.is_empty() {
387                    None
388                } else {
389                    Some(roots)
390                }
391            }
392            _ => None,
393        }
394    }
395
396    /// Internal domain-specific optimization for linear solver
397    ///
398    /// Evaluate expressions with fraction handling for linear equation solutions.
399    /// This is a specialized version optimized for the linear solver's needs.
400    ///
401    /// Static helper function - doesn't depend on instance state.
402    #[inline(always)]
403    fn try_eval_numeric_internal(expr: &Expression) -> Expression {
404        match expr {
405            // Handle -1 * (complex expression)
406            Expression::Mul(factors) if factors.len() == 2 => {
407                if let [Expression::Number(Number::Integer(-1)), complex_expr] = &factors[..] {
408                    // Evaluate the complex expression and negate it
409                    let evaluated = Self::eval_exact_internal(complex_expr);
410                    evaluated.negate().simplify()
411                } else {
412                    expr.clone()
413                }
414            }
415            // Handle fractions that should be evaluated
416            Expression::Function { name, args }
417                if name.as_ref() == "fraction" && args.len() == 2 =>
418            {
419                Self::eval_exact_internal(expr)
420            }
421            _ => expr.clone(),
422        }
423    }
424
425    /// Internal domain-specific optimization for linear solver
426    ///
427    /// Static helper function for exact arithmetic evaluation.
428    /// Preserves exact arithmetic (integers/rationals) without instance state dependency.
429    ///
430    /// This method is kept separate from Expression::evaluate_to_f64() because it maintains
431    /// mathematical exactness. For example, 1/3 stays as Rational(1,3), not 0.333...
432    ///
433    /// # Why Not Use evaluate_to_f64()?
434    ///
435    /// - evaluate_to_f64() converts to f64 (loses precision: 1/3 → 0.333...)
436    /// - This method preserves rationals (keeps exactness: 1/3 → Rational(1,3))
437    /// - Linear equation solutions often require exact fractions (e.g., x = 2/3)
438    ///
439    /// # Automatic Fraction Simplification
440    ///
441    /// When creating rational numbers via `BigRational::new(num, den)`, fractions are
442    /// automatically reduced to lowest terms using GCD. For example:
443    /// - `BigRational::new(6, 4)` → 3/2
444    /// - `BigRational::new(18, 12)` → 3/2
445    /// - `BigRational::new(10, 5)` → 2 (returned as integer if denominator is 1)
446    ///
447    /// # Returns
448    ///
449    /// - Expression::Number(Integer) for exact integer results
450    /// - Expression::Number(Rational) for exact fractional results (automatically simplified)
451    /// - Original expression if cannot be evaluated exactly
452    #[inline(always)]
453    fn eval_exact_internal(expr: &Expression) -> Expression {
454        match expr {
455            Expression::Add(terms) => {
456                let mut total = 0i64;
457                for term in terms.iter() {
458                    match Self::eval_exact_internal(term) {
459                        Expression::Number(Number::Integer(n)) => total += n,
460                        _ => return expr.clone(), // Can't evaluate
461                    }
462                }
463                Expression::integer(total)
464            }
465            Expression::Mul(factors) => {
466                let mut product = 1i64;
467                for factor in factors.iter() {
468                    match Self::eval_exact_internal(factor) {
469                        Expression::Number(Number::Integer(n)) => product *= n,
470                        _ => return expr.clone(), // Can't evaluate
471                    }
472                }
473                Expression::integer(product)
474            }
475            // Handle fraction functions: fraction(numerator, denominator)
476            // BigRational::new() automatically reduces to lowest terms
477            Expression::Function { name, args }
478                if name.as_ref() == "fraction" && args.len() == 2 =>
479            {
480                // First evaluate the numerator and denominator
481                let num_eval = Self::eval_exact_internal(&args[0]);
482                let den_eval = Self::eval_exact_internal(&args[1]);
483
484                match (&num_eval, &den_eval) {
485                    (
486                        Expression::Number(Number::Float(num)),
487                        Expression::Number(Number::Float(den)),
488                    ) => {
489                        if den.abs() >= EPSILON {
490                            let result = num / den;
491                            if result.fract().abs() < EPSILON {
492                                Expression::integer(result as i64)
493                            } else {
494                                Expression::Number(Number::float(result))
495                            }
496                        } else {
497                            expr.clone()
498                        }
499                    }
500                    (
501                        Expression::Number(Number::Integer(num)),
502                        Expression::Number(Number::Integer(den)),
503                    ) => {
504                        if *den != 0 {
505                            if num % den == 0 {
506                                Expression::integer(num / den)
507                            } else {
508                                // BigRational::new() automatically reduces to lowest terms via GCD
509                                Expression::Number(Number::rational(BigRational::new(
510                                    BigInt::from(*num),
511                                    BigInt::from(*den),
512                                )))
513                            }
514                        } else {
515                            expr.clone()
516                        }
517                    }
518                    _ => expr.clone(),
519                }
520            }
521            Expression::Number(_) => expr.clone(),
522            _ => expr.clone(),
523        }
524    }
525
526    /// Divide two expressions (simplified division)
527    ///
528    /// Static helper function for recursive division operations.
529    /// Does not require instance state, only performs expression manipulation.
530    ///
531    /// Fractions created via `BigRational::new()` are automatically reduced
532    /// to lowest terms using GCD.
533    #[inline(always)]
534    fn divide_expressions(numerator: &Expression, denominator: &Expression) -> Expression {
535        // First simplify both expressions
536        let num_simplified = numerator.simplify();
537        let den_simplified = denominator.simplify();
538
539        match (&num_simplified, &den_simplified) {
540            // Simple integer division
541            // BigRational::new() automatically reduces to lowest terms
542            (Expression::Number(Number::Integer(n)), Expression::Number(Number::Integer(d))) => {
543                if *d != 0 {
544                    if n % d == 0 {
545                        Expression::integer(n / d)
546                    } else {
547                        // Create rational number - automatically reduced to lowest terms
548                        Expression::Number(Number::rational(BigRational::new(
549                            BigInt::from(*n),
550                            BigInt::from(*d),
551                        )))
552                    }
553                } else {
554                    // Division by zero - should be handled as error
555                    Expression::integer(0) // Placeholder
556                }
557            }
558            // Integer divided by rational: a / (p/q) = a * (q/p)
559            (Expression::Number(Number::Integer(n)), Expression::Number(Number::Rational(r))) => {
560                // a / (p/q) = a * q / p
561                let inverted = BigRational::new(r.denom().clone(), r.numer().clone());
562                let result = BigRational::from(BigInt::from(*n)) * inverted;
563
564                // Simplify to integer if possible
565                if result.is_integer() {
566                    Expression::integer(result.numer().to_string().parse().unwrap())
567                } else {
568                    Expression::Number(Number::rational(result))
569                }
570            }
571            // Rational divided by integer: (p/q) / a = (p/q) / a = p/(q*a)
572            (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(d))) => {
573                if *d != 0 {
574                    let result = (**r).clone() / BigRational::from(BigInt::from(*d));
575                    if result.is_integer() {
576                        Expression::integer(result.numer().to_string().parse().unwrap())
577                    } else {
578                        Expression::Number(Number::rational(result))
579                    }
580                } else {
581                    Expression::integer(0) // Placeholder
582                }
583            }
584            // Rational divided by rational
585            (
586                Expression::Number(Number::Rational(num_r)),
587                Expression::Number(Number::Rational(den_r)),
588            ) => {
589                let result = (**num_r).clone() / (**den_r).clone();
590                if result.is_integer() {
591                    Expression::integer(result.numer().to_string().parse().unwrap())
592                } else {
593                    Expression::Number(Number::rational(result))
594                }
595            }
596            // Try to simplify further - if denominator is 1, just return numerator
597            (num, Expression::Number(Number::Integer(1))) => num.clone(),
598            // Handle multiplication by -1 and other simple cases
599            (Expression::Mul(factors), den) if factors.len() == 2 => {
600                if let [Expression::Number(Number::Integer(-1)), expr] = &factors[..] {
601                    // -1 * expr / den = -(expr / den)
602                    let inner_div = Self::divide_expressions(expr, den);
603                    Expression::mul(vec![Expression::integer(-1), inner_div]).simplify()
604                } else {
605                    // General case
606                    let fraction =
607                        Expression::function("fraction", vec![num_simplified, den_simplified]);
608                    fraction.simplify()
609                }
610            }
611            // For linear solver, try to evaluate numerically if possible
612            _ => {
613                // Return as fraction function and let it simplify
614                let fraction =
615                    Expression::function("fraction", vec![num_simplified, den_simplified]);
616                fraction.simplify()
617            }
618        }
619    }
620}
621
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_coefficient_extraction() {
        let solver = LinearSolver::new();
        let x = symbol!(x);

        // 2x + 3 should split into coefficient 2 and constant 3.
        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let equation = Expression::add(vec![two_x, Expression::integer(3)]);

        let (coefficient, constant) = solver.extract_linear_coefficients(&equation, &x);
        // Coefficients come back unsimplified (e.g. Mul([1, 2])), so compare after simplify().
        assert_eq!(coefficient.simplify(), Expression::integer(2));
        assert_eq!(constant.simplify(), Expression::integer(3));
    }

    #[test]
    fn test_linear_detection() {
        let solver = LinearSolver::new();
        let x = symbol!(x);
        let x_expr = Expression::symbol(x.clone());

        // A sum is accepted as linear.
        let x_plus_one = Expression::add(vec![x_expr.clone(), Expression::integer(1)]);
        assert!(solver.is_linear_equation(&x_plus_one));

        // A power is rejected.
        let x_squared = Expression::pow(x_expr, Expression::integer(2));
        assert!(!solver.is_linear_equation(&x_squared));
    }
}
657}