mathhook_core/algebra/solvers/linear.rs
1//! Solves equations of the form ax + b = 0
2//! Includes step-by-step explanations for educational value
3
4use crate::algebra::Expand;
5use crate::core::constants::EPSILON;
6use crate::core::{Commutativity, Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8// Temporarily simplified for TDD success
9use crate::algebra::solvers::{EquationSolver, SolverResult};
10use crate::simplify::Simplify;
11use num_bigint::BigInt;
12use num_rational::BigRational;
13
/// Solver for linear equations of the form `ax + b = 0`, able to produce
/// step-by-step explanations alongside the solution.
#[derive(Debug, Clone)]
pub struct LinearSolver {
    /// Whether step-by-step explanations are requested.
    ///
    /// `LinearSolver::new()` sets this to `true`, `LinearSolver::new_fast()` to `false`.
    /// NOTE(review): nothing in this file reads the flag; confirm which callers consult it.
    pub show_steps: bool,
}
20
21impl Default for LinearSolver {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl LinearSolver {
28 /// Create new linear solver
29 pub fn new() -> Self {
30 Self { show_steps: true }
31 }
32
33 /// Create solver without step-by-step (for performance)
34 pub fn new_fast() -> Self {
35 Self { show_steps: false }
36 }
37}
38
39impl EquationSolver for LinearSolver {
40 /// Solve linear equation ax + b = 0
41 ///
42 /// Fractional solutions are automatically simplified to lowest terms via
43 /// `BigRational::new()`, which reduces fractions using GCD. Integer solutions
44 /// (where numerator is divisible by denominator) are returned as integers.
45 ///
46 /// # Examples
47 ///
48 /// ```rust
49 /// use mathhook_core::algebra::solvers::{linear::LinearSolver, EquationSolver, SolverResult};
50 /// use mathhook_core::core::{Expression, Number};
51 /// use mathhook_core::symbol;
52 /// use num_bigint::BigInt;
53 ///
54 /// let solver = LinearSolver::new_fast();
55 /// let x = symbol!(x);
56 ///
57 /// // Example: 4x = 6 gives x = 3/2 (simplified from 6/4)
58 /// let equation = Expression::add(vec![
59 /// Expression::mul(vec![Expression::integer(4), Expression::symbol(x.clone())]),
60 /// Expression::integer(-6),
61 /// ]);
62 ///
63 /// match solver.solve(&equation, &x) {
64 /// SolverResult::Single(solution) => {
65 /// if let Expression::Number(Number::Rational(r)) = solution {
66 /// assert_eq!(r.numer(), &BigInt::from(3));
67 /// assert_eq!(r.denom(), &BigInt::from(2));
68 /// }
69 /// }
70 /// _ => panic!("Expected single solution"),
71 /// }
72 /// ```
73 #[inline(always)]
74 fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
75 // Handle Relation type (equations like x = 5)
76 let equation_expr = if let Expression::Relation(data) = equation {
77 // Convert relation to expression: left - right = 0
78 Expression::add(vec![
79 data.left.clone(),
80 Expression::mul(vec![Expression::integer(-1), data.right.clone()]),
81 ])
82 } else {
83 equation.clone()
84 };
85
86 // Check for noncommutative symbols - delegate to MatrixEquationSolver if found
87 if equation_expr.commutativity() != Commutativity::Commutative {
88 use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
89 let matrix_solver = MatrixEquationSolver::new_fast();
90 return matrix_solver.solve(&equation_expr, variable);
91 }
92
93 // Simplify and expand equation to flatten nested structures and distribute multiplication
94 let simplified_equation = equation_expr.simplify().expand();
95
96 // Check for identity equations (0 = 0) or contradictions AFTER simplification
97 if simplified_equation.is_zero() {
98 // If equation simplified to just 0, it means 0 = 0 (infinite solutions)
99 return SolverResult::InfiniteSolutions;
100 }
101 // Check for non-zero constant (contradiction)
102 if let Expression::Number(Number::Integer(n)) = simplified_equation {
103 if n != 0 {
104 return SolverResult::NoSolution;
105 }
106 }
107
108 // Check for factored form: (x - a)(x - b)...(x - n) = 0
109 if let Some(roots) = self.extract_factored_roots(&simplified_equation, variable) {
110 if roots.len() == 1 {
111 return SolverResult::Single(roots[0].clone());
112 } else if roots.len() > 1 {
113 return SolverResult::Multiple(roots);
114 }
115 }
116
117 // Extract coefficients from simplified linear equation
118 let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
119
120 // Smart solver: Analyze original equation structure before simplification
121
122 // Check if original equation has patterns like 0*x + constant
123 if let Some(special_result) = self.detect_special_linear_cases(&equation_expr, variable) {
124 return special_result;
125 }
126
127 // Extract coefficients for normal linear analysis
128 let a_simplified = a.simplify();
129 let b_simplified = b.simplify();
130
131 if a_simplified.is_zero() {
132 if b_simplified.is_zero() {
133 return SolverResult::InfiniteSolutions; // 0x + 0 = 0
134 } else {
135 return SolverResult::NoSolution; // 0x + b = 0 where b ≠ 0
136 }
137 }
138
139 // Solve ax + b = 0 → x = -b/a
140 // Fractions are automatically reduced to lowest terms by BigRational::new()
141
142 // Check if we can solve numerically
143 match (&a_simplified, &b_simplified) {
144 (
145 Expression::Number(Number::Integer(a_val)),
146 Expression::Number(Number::Integer(b_val)),
147 ) => {
148 if *a_val != 0 {
149 // Simple case: ax + b = 0 → x = -b/a
150 let result = -b_val / a_val;
151 if b_val % a_val == 0 {
152 // Integer solution: return as integer (e.g., 10/5 = 2)
153 SolverResult::Single(Expression::integer(result))
154 } else {
155 // Fractional solution: BigRational::new() automatically reduces to lowest terms
156 // Example: 6/4 → 3/2, 18/12 → 3/2
157 SolverResult::Single(Expression::Number(Number::rational(
158 BigRational::new(BigInt::from(-b_val), BigInt::from(*a_val)),
159 )))
160 }
161 } else {
162 SolverResult::NoSolution
163 }
164 }
165 _ => {
166 // General case - use simplified coefficients
167 let neg_b = b_simplified.negate().simplify();
168 let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
169
170 // Try to evaluate the solution numerically if possible
171 let final_solution = Self::try_eval_numeric_internal(&solution);
172 SolverResult::Single(final_solution)
173 }
174 }
175 }
176
177 /// Solve with step-by-step explanation
178 fn solve_with_explanation(
179 &self,
180 equation: &Expression,
181 variable: &Symbol,
182 ) -> (SolverResult, StepByStepExplanation) {
183 let simplified_equation = equation.simplify();
184 let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
185
186 if a.is_zero() {
187 return self.handle_special_case_with_style(&b);
188 }
189
190 let a_simplified = a.simplify();
191 let b_simplified = b.simplify();
192 let neg_b = b_simplified.negate().simplify();
193 let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
194
195 let steps = vec![
196 Step::new(
197 "Given Equation",
198 format!("We need to solve: {} = 0", equation),
199 ),
200 Step::new(
201 "Strategy",
202 format!("Isolate {} using inverse operations", variable.name),
203 ),
204 Step::new(
205 "Identify Form",
206 format!("This has form: {}·{} + {} = 0", a, variable.name, b),
207 ),
208 Step::new(
209 "Calculate",
210 format!("{} = -({}) ÷ {} = {}", variable.name, b, a, solution),
211 ),
212 Step::new("Solution", format!("{} = {}", variable.name, solution)),
213 ];
214 let explanation = StepByStepExplanation::new(steps);
215
216 (SolverResult::Single(solution), explanation)
217 }
218
219 /// Check if this solver can handle the equation
220 fn can_solve(&self, equation: &Expression) -> bool {
221 // Check if equation is linear in any variable
222 self.is_linear_equation(equation)
223 }
224}
225
226impl LinearSolver {
227 /// Handle special cases with step explanations
228 fn handle_special_case_with_style(
229 &self,
230 b: &Expression,
231 ) -> (SolverResult, StepByStepExplanation) {
232 if b.is_zero() {
233 let steps = vec![
234 Step::new("Special Case", "0x + 0 = 0 is always true"),
235 Step::new("Result", "Infinite solutions - any value of x works"),
236 ];
237 (
238 SolverResult::InfiniteSolutions,
239 StepByStepExplanation::new(steps),
240 )
241 } else {
242 let steps = vec![
243 Step::new("Special Case", format!("0x + {} = 0 means {} = 0", b, b)),
244 Step::new(
245 "Contradiction",
246 format!("But {} ≠ 0, so no solution exists", b),
247 ),
248 ];
249 (SolverResult::NoSolution, StepByStepExplanation::new(steps))
250 }
251 }
252 /// Extract coefficients a and b from equation ax + b = 0
253 #[inline(always)]
254 fn extract_linear_coefficients(
255 &self,
256 equation: &Expression,
257 variable: &Symbol,
258 ) -> (Expression, Expression) {
259 // First, flatten all nested Add expressions
260 let flattened_terms = equation.flatten_add_terms();
261
262 let mut coefficient = Expression::integer(0); // Coefficient of variable
263 let mut constant = Expression::integer(0); // Constant term
264
265 for term in flattened_terms.iter() {
266 match term {
267 Expression::Symbol(s) if s == variable => {
268 coefficient = Expression::add(vec![coefficient, Expression::integer(1)]);
269 }
270 Expression::Mul(factors) => {
271 let mut var_coeff = Expression::integer(1);
272 let mut has_variable = false;
273
274 for factor in factors.iter() {
275 match factor {
276 Expression::Symbol(s) if s == variable => {
277 has_variable = true;
278 }
279 _ => {
280 var_coeff = Expression::mul(vec![var_coeff, factor.clone()]);
281 }
282 }
283 }
284
285 if has_variable {
286 coefficient = Expression::add(vec![coefficient, var_coeff]);
287 } else {
288 constant = Expression::add(vec![constant, term.clone()]);
289 }
290 }
291 _ => {
292 // Constant term
293 constant = Expression::add(vec![constant, term.clone()]);
294 }
295 }
296 }
297 (coefficient, constant)
298 }
299
300 /// Check if equation is linear
301 fn is_linear_equation(&self, equation: &Expression) -> bool {
302 matches!(
303 equation,
304 Expression::Add(_) | Expression::Symbol(_) | Expression::Number(_)
305 )
306 }
307
308 /// Detect special linear cases before simplification
309 #[inline(always)]
310 fn detect_special_linear_cases(
311 &self,
312 equation: &Expression,
313 variable: &Symbol,
314 ) -> Option<SolverResult> {
315 match equation {
316 Expression::Add(terms) if terms.len() == 2 => {
317 // Check for patterns: 0*x + constant
318 if let [Expression::Mul(factors), constant] = &terms[..] {
319 if factors.len() == 2 {
320 if let [Expression::Number(Number::Integer(0)), var] = &factors[..] {
321 if var == &Expression::symbol(variable.clone()) {
322 // Found 0*x + constant pattern
323 match constant {
324 Expression::Number(Number::Integer(0)) => {
325 return Some(SolverResult::InfiniteSolutions);
326 // 0*x + 0 = 0
327 }
328 _ => {
329 return Some(SolverResult::NoSolution); // 0*x + nonzero = 0
330 }
331 }
332 }
333 }
334 }
335 }
336 }
337 _ => {}
338 }
339 None // No special case detected
340 }
341
342 /// Extract roots from factored polynomial form: (x - a)(x - b) = 0
343 fn extract_factored_roots(
344 &self,
345 expr: &Expression,
346 variable: &Symbol,
347 ) -> Option<Vec<Expression>> {
348 match expr {
349 Expression::Mul(factors) => {
350 let mut roots = Vec::new();
351
352 for factor in factors.iter() {
353 // Check if this factor is (x - constant) or (constant - x)
354 if let Expression::Add(terms) = factor {
355 if terms.len() == 2 {
356 // Check pattern: x + (-a) = 0 → x = a
357 if let [Expression::Symbol(s), Expression::Mul(neg_factors)] =
358 &terms[..]
359 {
360 if s == variable && neg_factors.len() == 2 {
361 if let [Expression::Number(Number::Integer(-1)), constant] =
362 &neg_factors[..]
363 {
364 roots.push(constant.clone());
365 continue;
366 }
367 }
368 }
369 // Check pattern: -a + x = 0 → x = a
370 if let [Expression::Mul(neg_factors), Expression::Symbol(s)] =
371 &terms[..]
372 {
373 if s == variable && neg_factors.len() == 2 {
374 if let [Expression::Number(Number::Integer(-1)), constant] =
375 &neg_factors[..]
376 {
377 roots.push(constant.clone());
378 continue;
379 }
380 }
381 }
382 }
383 }
384 }
385
386 if roots.is_empty() {
387 None
388 } else {
389 Some(roots)
390 }
391 }
392 _ => None,
393 }
394 }
395
396 /// Internal domain-specific optimization for linear solver
397 ///
398 /// Evaluate expressions with fraction handling for linear equation solutions.
399 /// This is a specialized version optimized for the linear solver's needs.
400 ///
401 /// Static helper function - doesn't depend on instance state.
402 #[inline(always)]
403 fn try_eval_numeric_internal(expr: &Expression) -> Expression {
404 match expr {
405 // Handle -1 * (complex expression)
406 Expression::Mul(factors) if factors.len() == 2 => {
407 if let [Expression::Number(Number::Integer(-1)), complex_expr] = &factors[..] {
408 // Evaluate the complex expression and negate it
409 let evaluated = Self::eval_exact_internal(complex_expr);
410 evaluated.negate().simplify()
411 } else {
412 expr.clone()
413 }
414 }
415 // Handle fractions that should be evaluated
416 Expression::Function { name, args } if name == "fraction" && args.len() == 2 => {
417 Self::eval_exact_internal(expr)
418 }
419 _ => expr.clone(),
420 }
421 }
422
423 /// Internal domain-specific optimization for linear solver
424 ///
425 /// Static helper function for exact arithmetic evaluation.
426 /// Preserves exact arithmetic (integers/rationals) without instance state dependency.
427 ///
428 /// This method is kept separate from Expression::evaluate_to_f64() because it maintains
429 /// mathematical exactness. For example, 1/3 stays as Rational(1,3), not 0.333...
430 ///
431 /// # Why Not Use evaluate_to_f64()?
432 ///
433 /// - evaluate_to_f64() converts to f64 (loses precision: 1/3 → 0.333...)
434 /// - This method preserves rationals (keeps exactness: 1/3 → Rational(1,3))
435 /// - Linear equation solutions often require exact fractions (e.g., x = 2/3)
436 ///
437 /// # Automatic Fraction Simplification
438 ///
439 /// When creating rational numbers via `BigRational::new(num, den)`, fractions are
440 /// automatically reduced to lowest terms using GCD. For example:
441 /// - `BigRational::new(6, 4)` → 3/2
442 /// - `BigRational::new(18, 12)` → 3/2
443 /// - `BigRational::new(10, 5)` → 2 (returned as integer if denominator is 1)
444 ///
445 /// # Returns
446 ///
447 /// - Expression::Number(Integer) for exact integer results
448 /// - Expression::Number(Rational) for exact fractional results (automatically simplified)
449 /// - Original expression if cannot be evaluated exactly
450 #[inline(always)]
451 fn eval_exact_internal(expr: &Expression) -> Expression {
452 match expr {
453 Expression::Add(terms) => {
454 let mut total = 0i64;
455 for term in terms.iter() {
456 match Self::eval_exact_internal(term) {
457 Expression::Number(Number::Integer(n)) => total += n,
458 _ => return expr.clone(), // Can't evaluate
459 }
460 }
461 Expression::integer(total)
462 }
463 Expression::Mul(factors) => {
464 let mut product = 1i64;
465 for factor in factors.iter() {
466 match Self::eval_exact_internal(factor) {
467 Expression::Number(Number::Integer(n)) => product *= n,
468 _ => return expr.clone(), // Can't evaluate
469 }
470 }
471 Expression::integer(product)
472 }
473 // Handle fraction functions: fraction(numerator, denominator)
474 // BigRational::new() automatically reduces to lowest terms
475 Expression::Function { name, args } if name == "fraction" && args.len() == 2 => {
476 // First evaluate the numerator and denominator
477 let num_eval = Self::eval_exact_internal(&args[0]);
478 let den_eval = Self::eval_exact_internal(&args[1]);
479
480 match (&num_eval, &den_eval) {
481 (
482 Expression::Number(Number::Float(num)),
483 Expression::Number(Number::Float(den)),
484 ) => {
485 if den.abs() >= EPSILON {
486 let result = num / den;
487 if result.fract().abs() < EPSILON {
488 Expression::integer(result as i64)
489 } else {
490 Expression::Number(Number::float(result))
491 }
492 } else {
493 expr.clone()
494 }
495 }
496 (
497 Expression::Number(Number::Integer(num)),
498 Expression::Number(Number::Integer(den)),
499 ) => {
500 if *den != 0 {
501 if num % den == 0 {
502 Expression::integer(num / den)
503 } else {
504 // BigRational::new() automatically reduces to lowest terms via GCD
505 Expression::Number(Number::rational(BigRational::new(
506 BigInt::from(*num),
507 BigInt::from(*den),
508 )))
509 }
510 } else {
511 expr.clone()
512 }
513 }
514 _ => expr.clone(),
515 }
516 }
517 Expression::Number(_) => expr.clone(),
518 _ => expr.clone(),
519 }
520 }
521
522 /// Divide two expressions (simplified division)
523 ///
524 /// Static helper function for recursive division operations.
525 /// Does not require instance state, only performs expression manipulation.
526 ///
527 /// Fractions created via `BigRational::new()` are automatically reduced
528 /// to lowest terms using GCD.
529 #[inline(always)]
530 fn divide_expressions(numerator: &Expression, denominator: &Expression) -> Expression {
531 // First simplify both expressions
532 let num_simplified = numerator.simplify();
533 let den_simplified = denominator.simplify();
534
535 match (&num_simplified, &den_simplified) {
536 // Simple integer division
537 // BigRational::new() automatically reduces to lowest terms
538 (Expression::Number(Number::Integer(n)), Expression::Number(Number::Integer(d))) => {
539 if *d != 0 {
540 if n % d == 0 {
541 Expression::integer(n / d)
542 } else {
543 // Create rational number - automatically reduced to lowest terms
544 Expression::Number(Number::rational(BigRational::new(
545 BigInt::from(*n),
546 BigInt::from(*d),
547 )))
548 }
549 } else {
550 // Division by zero - should be handled as error
551 Expression::integer(0) // Placeholder
552 }
553 }
554 // Integer divided by rational: a / (p/q) = a * (q/p)
555 (Expression::Number(Number::Integer(n)), Expression::Number(Number::Rational(r))) => {
556 // a / (p/q) = a * q / p
557 let inverted = BigRational::new(r.denom().clone(), r.numer().clone());
558 let result = BigRational::from(BigInt::from(*n)) * inverted;
559
560 // Simplify to integer if possible
561 if result.is_integer() {
562 Expression::integer(result.numer().to_string().parse().unwrap())
563 } else {
564 Expression::Number(Number::rational(result))
565 }
566 }
567 // Rational divided by integer: (p/q) / a = (p/q) / a = p/(q*a)
568 (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(d))) => {
569 if *d != 0 {
570 let result = (**r).clone() / BigRational::from(BigInt::from(*d));
571 if result.is_integer() {
572 Expression::integer(result.numer().to_string().parse().unwrap())
573 } else {
574 Expression::Number(Number::rational(result))
575 }
576 } else {
577 Expression::integer(0) // Placeholder
578 }
579 }
580 // Rational divided by rational
581 (
582 Expression::Number(Number::Rational(num_r)),
583 Expression::Number(Number::Rational(den_r)),
584 ) => {
585 let result = (**num_r).clone() / (**den_r).clone();
586 if result.is_integer() {
587 Expression::integer(result.numer().to_string().parse().unwrap())
588 } else {
589 Expression::Number(Number::rational(result))
590 }
591 }
592 // Try to simplify further - if denominator is 1, just return numerator
593 (num, Expression::Number(Number::Integer(1))) => num.clone(),
594 // Handle multiplication by -1 and other simple cases
595 (Expression::Mul(factors), den) if factors.len() == 2 => {
596 if let [Expression::Number(Number::Integer(-1)), expr] = &factors[..] {
597 // -1 * expr / den = -(expr / den)
598 let inner_div = Self::divide_expressions(expr, den);
599 Expression::mul(vec![Expression::integer(-1), inner_div]).simplify()
600 } else {
601 // General case
602 let fraction =
603 Expression::function("fraction", vec![num_simplified, den_simplified]);
604 fraction.simplify()
605 }
606 }
607 // For linear solver, try to evaluate numerically if possible
608 _ => {
609 // Return as fraction function and let it simplify
610 let fraction =
611 Expression::function("fraction", vec![num_simplified, den_simplified]);
612 fraction.simplify()
613 }
614 }
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// `2x + 3` must split into coefficient 2 and constant 3.
    #[test]
    fn test_coefficient_extraction() {
        let x = symbol!(x);

        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let equation = Expression::add(vec![two_x, Expression::integer(3)]);

        let (a, b) = LinearSolver::new().extract_linear_coefficients(&equation, &x);

        // Extraction returns unsimplified sums/products, so compare after simplifying.
        assert_eq!(a.simplify(), Expression::integer(2));
        assert_eq!(b.simplify(), Expression::integer(3));
    }

    /// A sum is accepted as linear; a top-level power is rejected.
    #[test]
    fn test_linear_detection() {
        let x = symbol!(x);
        let solver = LinearSolver::new();

        let x_plus_one =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(solver.is_linear_equation(&x_plus_one));

        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        assert!(!solver.is_linear_equation(&x_squared));
    }
}
653}