mathhook_core/algebra/solvers/linear.rs
//! Linear equation solver for equations of the form `ax + b = 0`.
//! Includes step-by-step explanations for educational value.
3
4use crate::algebra::Expand;
5use crate::core::constants::EPSILON;
6use crate::core::{Commutativity, Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8// Temporarily simplified for TDD success
9use crate::algebra::solvers::{EquationSolver, SolverResult};
10use crate::simplify::Simplify;
11use num_bigint::BigInt;
12use num_rational::BigRational;
13
/// Solver for equations that are linear in one variable, i.e. `a·x + b = 0`,
/// with optional step-by-step explanations.
#[derive(Clone, Debug)]
pub struct LinearSolver {
    /// Whether step-by-step explanations are requested
    /// (`true` from `new`, `false` from `new_fast`).
    pub show_steps: bool,
}
20
21impl Default for LinearSolver {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl LinearSolver {
28 /// Create new linear solver
29 pub fn new() -> Self {
30 Self { show_steps: true }
31 }
32
33 /// Create solver without step-by-step (for performance)
34 pub fn new_fast() -> Self {
35 Self { show_steps: false }
36 }
37}
38
39impl EquationSolver for LinearSolver {
40 /// Solve linear equation ax + b = 0
41 ///
42 /// Fractional solutions are automatically simplified to lowest terms via
43 /// `BigRational::new()`, which reduces fractions using GCD. Integer solutions
44 /// (where numerator is divisible by denominator) are returned as integers.
45 ///
46 /// # Examples
47 ///
48 /// ```rust
49 /// use mathhook_core::algebra::solvers::{linear::LinearSolver, EquationSolver, SolverResult};
50 /// use mathhook_core::core::{Expression, Number};
51 /// use mathhook_core::symbol;
52 /// use num_bigint::BigInt;
53 ///
54 /// let solver = LinearSolver::new_fast();
55 /// let x = symbol!(x);
56 ///
57 /// // Example: 4x = 6 gives x = 3/2 (simplified from 6/4)
58 /// let equation = Expression::add(vec![
59 /// Expression::mul(vec![Expression::integer(4), Expression::symbol(x.clone())]),
60 /// Expression::integer(-6),
61 /// ]);
62 ///
63 /// match solver.solve(&equation, &x) {
64 /// SolverResult::Single(solution) => {
65 /// if let Expression::Number(Number::Rational(r)) = solution {
66 /// assert_eq!(r.numer(), &BigInt::from(3));
67 /// assert_eq!(r.denom(), &BigInt::from(2));
68 /// }
69 /// }
70 /// _ => panic!("Expected single solution"),
71 /// }
72 /// ```
73 #[inline(always)]
74 fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
75 // Handle Relation type (equations like x = 5)
76 let equation_expr = if let Expression::Relation(data) = equation {
77 // Convert relation to expression: left - right = 0
78 Expression::add(vec![
79 data.left.clone(),
80 Expression::mul(vec![Expression::integer(-1), data.right.clone()]),
81 ])
82 } else {
83 equation.clone()
84 };
85
86 // Check for noncommutative symbols - delegate to MatrixEquationSolver if found
87 if equation_expr.commutativity() != Commutativity::Commutative {
88 use crate::algebra::solvers::matrix_equations::MatrixEquationSolver;
89 let matrix_solver = MatrixEquationSolver::new_fast();
90 return matrix_solver.solve(&equation_expr, variable);
91 }
92
93 // Simplify and expand equation to flatten nested structures and distribute multiplication
94 let simplified_equation = equation_expr.simplify().expand();
95
96 // Check for identity equations (0 = 0) or contradictions AFTER simplification
97 if simplified_equation.is_zero() {
98 // If equation simplified to just 0, it means 0 = 0 (infinite solutions)
99 return SolverResult::InfiniteSolutions;
100 }
101 // Check for non-zero constant (contradiction)
102 if let Expression::Number(Number::Integer(n)) = simplified_equation {
103 if n != 0 {
104 return SolverResult::NoSolution;
105 }
106 }
107
108 // Check for factored form: (x - a)(x - b)...(x - n) = 0
109 if let Some(roots) = self.extract_factored_roots(&simplified_equation, variable) {
110 if roots.len() == 1 {
111 return SolverResult::Single(roots[0].clone());
112 } else if roots.len() > 1 {
113 return SolverResult::Multiple(roots);
114 }
115 }
116
117 // Extract coefficients from simplified linear equation
118 let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
119
120 // Smart solver: Analyze original equation structure before simplification
121
122 // Check if original equation has patterns like 0*x + constant
123 if let Some(special_result) = self.detect_special_linear_cases(&equation_expr, variable) {
124 return special_result;
125 }
126
127 // Extract coefficients for normal linear analysis
128 let a_simplified = a.simplify();
129 let b_simplified = b.simplify();
130
131 if a_simplified.is_zero() {
132 if b_simplified.is_zero() {
133 return SolverResult::InfiniteSolutions; // 0x + 0 = 0
134 } else {
135 return SolverResult::NoSolution; // 0x + b = 0 where b ≠ 0
136 }
137 }
138
139 // Solve ax + b = 0 → x = -b/a
140 // Fractions are automatically reduced to lowest terms by BigRational::new()
141
142 // Check if we can solve numerically
143 match (&a_simplified, &b_simplified) {
144 (
145 Expression::Number(Number::Integer(a_val)),
146 Expression::Number(Number::Integer(b_val)),
147 ) => {
148 if *a_val != 0 {
149 // Simple case: ax + b = 0 → x = -b/a
150 let result = -b_val / a_val;
151 if b_val % a_val == 0 {
152 // Integer solution: return as integer (e.g., 10/5 = 2)
153 SolverResult::Single(Expression::integer(result))
154 } else {
155 // Fractional solution: BigRational::new() automatically reduces to lowest terms
156 // Example: 6/4 → 3/2, 18/12 → 3/2
157 SolverResult::Single(Expression::Number(Number::rational(
158 BigRational::new(BigInt::from(-b_val), BigInt::from(*a_val)),
159 )))
160 }
161 } else {
162 SolverResult::NoSolution
163 }
164 }
165 _ => {
166 // General case - use simplified coefficients
167 let neg_b = b_simplified.negate().simplify();
168 let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
169
170 // Try to evaluate the solution numerically if possible
171 let final_solution = Self::try_eval_numeric_internal(&solution);
172 SolverResult::Single(final_solution)
173 }
174 }
175 }
176
177 /// Solve with step-by-step explanation
178 fn solve_with_explanation(
179 &self,
180 equation: &Expression,
181 variable: &Symbol,
182 ) -> (SolverResult, StepByStepExplanation) {
183 let simplified_equation = equation.simplify();
184 let (a, b) = self.extract_linear_coefficients(&simplified_equation, variable);
185
186 if a.is_zero() {
187 return self.handle_special_case_with_style(&b);
188 }
189
190 let a_simplified = a.simplify();
191 let b_simplified = b.simplify();
192 let neg_b = b_simplified.negate().simplify();
193 let solution = Self::divide_expressions(&neg_b, &a_simplified).simplify();
194
195 let steps = vec![
196 Step::new(
197 "Given Equation",
198 format!("We need to solve: {} = 0", equation),
199 ),
200 Step::new(
201 "Strategy",
202 format!("Isolate {} using inverse operations", variable.name),
203 ),
204 Step::new(
205 "Identify Form",
206 format!("This has form: {}·{} + {} = 0", a, variable.name, b),
207 ),
208 Step::new(
209 "Calculate",
210 format!("{} = -({}) ÷ {} = {}", variable.name, b, a, solution),
211 ),
212 Step::new("Solution", format!("{} = {}", variable.name, solution)),
213 ];
214 let explanation = StepByStepExplanation::new(steps);
215
216 (SolverResult::Single(solution), explanation)
217 }
218
219 /// Check if this solver can handle the equation
220 fn can_solve(&self, equation: &Expression) -> bool {
221 // Check if equation is linear in any variable
222 self.is_linear_equation(equation)
223 }
224}
225
226impl LinearSolver {
227 /// Handle special cases with step explanations
228 fn handle_special_case_with_style(
229 &self,
230 b: &Expression,
231 ) -> (SolverResult, StepByStepExplanation) {
232 if b.is_zero() {
233 let steps = vec![
234 Step::new("Special Case", "0x + 0 = 0 is always true"),
235 Step::new("Result", "Infinite solutions - any value of x works"),
236 ];
237 (
238 SolverResult::InfiniteSolutions,
239 StepByStepExplanation::new(steps),
240 )
241 } else {
242 let steps = vec![
243 Step::new("Special Case", format!("0x + {} = 0 means {} = 0", b, b)),
244 Step::new(
245 "Contradiction",
246 format!("But {} ≠ 0, so no solution exists", b),
247 ),
248 ];
249 (SolverResult::NoSolution, StepByStepExplanation::new(steps))
250 }
251 }
252 /// Extract coefficients a and b from equation ax + b = 0
253 #[inline(always)]
254 fn extract_linear_coefficients(
255 &self,
256 equation: &Expression,
257 variable: &Symbol,
258 ) -> (Expression, Expression) {
259 // First, flatten all nested Add expressions
260 let flattened_terms = equation.flatten_add_terms();
261
262 let mut coefficient = Expression::integer(0); // Coefficient of variable
263 let mut constant = Expression::integer(0); // Constant term
264
265 for term in flattened_terms.iter() {
266 match term {
267 Expression::Symbol(s) if s == variable => {
268 coefficient = Expression::add(vec![coefficient, Expression::integer(1)]);
269 }
270 Expression::Mul(factors) => {
271 let mut var_coeff = Expression::integer(1);
272 let mut has_variable = false;
273
274 for factor in factors.iter() {
275 match factor {
276 Expression::Symbol(s) if s == variable => {
277 has_variable = true;
278 }
279 _ => {
280 var_coeff = Expression::mul(vec![var_coeff, factor.clone()]);
281 }
282 }
283 }
284
285 if has_variable {
286 coefficient = Expression::add(vec![coefficient, var_coeff]);
287 } else {
288 constant = Expression::add(vec![constant, term.clone()]);
289 }
290 }
291 _ => {
292 // Constant term
293 constant = Expression::add(vec![constant, term.clone()]);
294 }
295 }
296 }
297 (coefficient, constant)
298 }
299
300 /// Check if equation is linear
301 fn is_linear_equation(&self, equation: &Expression) -> bool {
302 matches!(
303 equation,
304 Expression::Add(_) | Expression::Symbol(_) | Expression::Number(_)
305 )
306 }
307
308 /// Detect special linear cases before simplification
309 #[inline(always)]
310 fn detect_special_linear_cases(
311 &self,
312 equation: &Expression,
313 variable: &Symbol,
314 ) -> Option<SolverResult> {
315 match equation {
316 Expression::Add(terms) if terms.len() == 2 => {
317 // Check for patterns: 0*x + constant
318 if let [Expression::Mul(factors), constant] = &terms[..] {
319 if factors.len() == 2 {
320 if let [Expression::Number(Number::Integer(0)), var] = &factors[..] {
321 if var == &Expression::symbol(variable.clone()) {
322 // Found 0*x + constant pattern
323 match constant {
324 Expression::Number(Number::Integer(0)) => {
325 return Some(SolverResult::InfiniteSolutions);
326 // 0*x + 0 = 0
327 }
328 _ => {
329 return Some(SolverResult::NoSolution); // 0*x + nonzero = 0
330 }
331 }
332 }
333 }
334 }
335 }
336 }
337 _ => {}
338 }
339 None // No special case detected
340 }
341
342 /// Extract roots from factored polynomial form: (x - a)(x - b) = 0
343 fn extract_factored_roots(
344 &self,
345 expr: &Expression,
346 variable: &Symbol,
347 ) -> Option<Vec<Expression>> {
348 match expr {
349 Expression::Mul(factors) => {
350 let mut roots = Vec::new();
351
352 for factor in factors.iter() {
353 // Check if this factor is (x - constant) or (constant - x)
354 if let Expression::Add(terms) = factor {
355 if terms.len() == 2 {
356 // Check pattern: x + (-a) = 0 → x = a
357 if let [Expression::Symbol(s), Expression::Mul(neg_factors)] =
358 &terms[..]
359 {
360 if s == variable && neg_factors.len() == 2 {
361 if let [Expression::Number(Number::Integer(-1)), constant] =
362 &neg_factors[..]
363 {
364 roots.push(constant.clone());
365 continue;
366 }
367 }
368 }
369 // Check pattern: -a + x = 0 → x = a
370 if let [Expression::Mul(neg_factors), Expression::Symbol(s)] =
371 &terms[..]
372 {
373 if s == variable && neg_factors.len() == 2 {
374 if let [Expression::Number(Number::Integer(-1)), constant] =
375 &neg_factors[..]
376 {
377 roots.push(constant.clone());
378 continue;
379 }
380 }
381 }
382 }
383 }
384 }
385
386 if roots.is_empty() {
387 None
388 } else {
389 Some(roots)
390 }
391 }
392 _ => None,
393 }
394 }
395
396 /// Internal domain-specific optimization for linear solver
397 ///
398 /// Evaluate expressions with fraction handling for linear equation solutions.
399 /// This is a specialized version optimized for the linear solver's needs.
400 ///
401 /// Static helper function - doesn't depend on instance state.
402 #[inline(always)]
403 fn try_eval_numeric_internal(expr: &Expression) -> Expression {
404 match expr {
405 // Handle -1 * (complex expression)
406 Expression::Mul(factors) if factors.len() == 2 => {
407 if let [Expression::Number(Number::Integer(-1)), complex_expr] = &factors[..] {
408 // Evaluate the complex expression and negate it
409 let evaluated = Self::eval_exact_internal(complex_expr);
410 evaluated.negate().simplify()
411 } else {
412 expr.clone()
413 }
414 }
415 // Handle fractions that should be evaluated
416 Expression::Function { name, args }
417 if name.as_ref() == "fraction" && args.len() == 2 =>
418 {
419 Self::eval_exact_internal(expr)
420 }
421 _ => expr.clone(),
422 }
423 }
424
425 /// Internal domain-specific optimization for linear solver
426 ///
427 /// Static helper function for exact arithmetic evaluation.
428 /// Preserves exact arithmetic (integers/rationals) without instance state dependency.
429 ///
430 /// This method is kept separate from Expression::evaluate_to_f64() because it maintains
431 /// mathematical exactness. For example, 1/3 stays as Rational(1,3), not 0.333...
432 ///
433 /// # Why Not Use evaluate_to_f64()?
434 ///
435 /// - evaluate_to_f64() converts to f64 (loses precision: 1/3 → 0.333...)
436 /// - This method preserves rationals (keeps exactness: 1/3 → Rational(1,3))
437 /// - Linear equation solutions often require exact fractions (e.g., x = 2/3)
438 ///
439 /// # Automatic Fraction Simplification
440 ///
441 /// When creating rational numbers via `BigRational::new(num, den)`, fractions are
442 /// automatically reduced to lowest terms using GCD. For example:
443 /// - `BigRational::new(6, 4)` → 3/2
444 /// - `BigRational::new(18, 12)` → 3/2
445 /// - `BigRational::new(10, 5)` → 2 (returned as integer if denominator is 1)
446 ///
447 /// # Returns
448 ///
449 /// - Expression::Number(Integer) for exact integer results
450 /// - Expression::Number(Rational) for exact fractional results (automatically simplified)
451 /// - Original expression if cannot be evaluated exactly
452 #[inline(always)]
453 fn eval_exact_internal(expr: &Expression) -> Expression {
454 match expr {
455 Expression::Add(terms) => {
456 let mut total = 0i64;
457 for term in terms.iter() {
458 match Self::eval_exact_internal(term) {
459 Expression::Number(Number::Integer(n)) => total += n,
460 _ => return expr.clone(), // Can't evaluate
461 }
462 }
463 Expression::integer(total)
464 }
465 Expression::Mul(factors) => {
466 let mut product = 1i64;
467 for factor in factors.iter() {
468 match Self::eval_exact_internal(factor) {
469 Expression::Number(Number::Integer(n)) => product *= n,
470 _ => return expr.clone(), // Can't evaluate
471 }
472 }
473 Expression::integer(product)
474 }
475 // Handle fraction functions: fraction(numerator, denominator)
476 // BigRational::new() automatically reduces to lowest terms
477 Expression::Function { name, args }
478 if name.as_ref() == "fraction" && args.len() == 2 =>
479 {
480 // First evaluate the numerator and denominator
481 let num_eval = Self::eval_exact_internal(&args[0]);
482 let den_eval = Self::eval_exact_internal(&args[1]);
483
484 match (&num_eval, &den_eval) {
485 (
486 Expression::Number(Number::Float(num)),
487 Expression::Number(Number::Float(den)),
488 ) => {
489 if den.abs() >= EPSILON {
490 let result = num / den;
491 if result.fract().abs() < EPSILON {
492 Expression::integer(result as i64)
493 } else {
494 Expression::Number(Number::float(result))
495 }
496 } else {
497 expr.clone()
498 }
499 }
500 (
501 Expression::Number(Number::Integer(num)),
502 Expression::Number(Number::Integer(den)),
503 ) => {
504 if *den != 0 {
505 if num % den == 0 {
506 Expression::integer(num / den)
507 } else {
508 // BigRational::new() automatically reduces to lowest terms via GCD
509 Expression::Number(Number::rational(BigRational::new(
510 BigInt::from(*num),
511 BigInt::from(*den),
512 )))
513 }
514 } else {
515 expr.clone()
516 }
517 }
518 _ => expr.clone(),
519 }
520 }
521 Expression::Number(_) => expr.clone(),
522 _ => expr.clone(),
523 }
524 }
525
526 /// Divide two expressions (simplified division)
527 ///
528 /// Static helper function for recursive division operations.
529 /// Does not require instance state, only performs expression manipulation.
530 ///
531 /// Fractions created via `BigRational::new()` are automatically reduced
532 /// to lowest terms using GCD.
533 #[inline(always)]
534 fn divide_expressions(numerator: &Expression, denominator: &Expression) -> Expression {
535 // First simplify both expressions
536 let num_simplified = numerator.simplify();
537 let den_simplified = denominator.simplify();
538
539 match (&num_simplified, &den_simplified) {
540 // Simple integer division
541 // BigRational::new() automatically reduces to lowest terms
542 (Expression::Number(Number::Integer(n)), Expression::Number(Number::Integer(d))) => {
543 if *d != 0 {
544 if n % d == 0 {
545 Expression::integer(n / d)
546 } else {
547 // Create rational number - automatically reduced to lowest terms
548 Expression::Number(Number::rational(BigRational::new(
549 BigInt::from(*n),
550 BigInt::from(*d),
551 )))
552 }
553 } else {
554 // Division by zero - should be handled as error
555 Expression::integer(0) // Placeholder
556 }
557 }
558 // Integer divided by rational: a / (p/q) = a * (q/p)
559 (Expression::Number(Number::Integer(n)), Expression::Number(Number::Rational(r))) => {
560 // a / (p/q) = a * q / p
561 let inverted = BigRational::new(r.denom().clone(), r.numer().clone());
562 let result = BigRational::from(BigInt::from(*n)) * inverted;
563
564 // Simplify to integer if possible
565 if result.is_integer() {
566 Expression::integer(result.numer().to_string().parse().unwrap())
567 } else {
568 Expression::Number(Number::rational(result))
569 }
570 }
571 // Rational divided by integer: (p/q) / a = (p/q) / a = p/(q*a)
572 (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(d))) => {
573 if *d != 0 {
574 let result = (**r).clone() / BigRational::from(BigInt::from(*d));
575 if result.is_integer() {
576 Expression::integer(result.numer().to_string().parse().unwrap())
577 } else {
578 Expression::Number(Number::rational(result))
579 }
580 } else {
581 Expression::integer(0) // Placeholder
582 }
583 }
584 // Rational divided by rational
585 (
586 Expression::Number(Number::Rational(num_r)),
587 Expression::Number(Number::Rational(den_r)),
588 ) => {
589 let result = (**num_r).clone() / (**den_r).clone();
590 if result.is_integer() {
591 Expression::integer(result.numer().to_string().parse().unwrap())
592 } else {
593 Expression::Number(Number::rational(result))
594 }
595 }
596 // Try to simplify further - if denominator is 1, just return numerator
597 (num, Expression::Number(Number::Integer(1))) => num.clone(),
598 // Handle multiplication by -1 and other simple cases
599 (Expression::Mul(factors), den) if factors.len() == 2 => {
600 if let [Expression::Number(Number::Integer(-1)), expr] = &factors[..] {
601 // -1 * expr / den = -(expr / den)
602 let inner_div = Self::divide_expressions(expr, den);
603 Expression::mul(vec![Expression::integer(-1), inner_div]).simplify()
604 } else {
605 // General case
606 let fraction =
607 Expression::function("fraction", vec![num_simplified, den_simplified]);
608 fraction.simplify()
609 }
610 }
611 // For linear solver, try to evaluate numerically if possible
612 _ => {
613 // Return as fraction function and let it simplify
614 let fraction =
615 Expression::function("fraction", vec![num_simplified, den_simplified]);
616 fraction.simplify()
617 }
618 }
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_coefficient_extraction() {
        let var = symbol!(x);
        let linear_solver = LinearSolver::new();

        // 2x + 3 should split into coefficient 2 and constant 3.
        let two_x = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(var.clone()),
        ]);
        let expr = Expression::add(vec![two_x, Expression::integer(3)]);

        let (coeff, constant) = linear_solver.extract_linear_coefficients(&expr, &var);

        // Raw coefficients are unsimplified sums/products (e.g. `0 + 1*2`),
        // so compare their simplified forms.
        assert_eq!(coeff.simplify(), Expression::integer(2));
        assert_eq!(constant.simplify(), Expression::integer(3));
    }

    #[test]
    fn test_linear_detection() {
        let var = symbol!(x);
        let linear_solver = LinearSolver::new();

        // A sum such as x + 1 passes the structural check.
        let sum = Expression::add(vec![
            Expression::symbol(var.clone()),
            Expression::integer(1),
        ]);
        assert!(linear_solver.is_linear_equation(&sum));

        // A bare power x^2 does not.
        let square = Expression::pow(Expression::symbol(var), Expression::integer(2));
        assert!(!linear_solver.is_linear_equation(&square));
    }
}