mathhook_core/algebra/solvers/
quadratic.rs

//! Quadratic equation solver: solves equations of the form ax² + bx + c = 0.
//! Includes step-by-step explanations for educational use.
3
4use crate::algebra::solvers::{EquationSolver, SolverResult};
5use crate::core::constants::EPSILON;
6use crate::core::{Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8// Unused educational imports removed
9use crate::formatter::latex::LaTeXFormatter;
10use crate::simplify::Simplify;
11use num_bigint::BigInt;
12use num_rational::BigRational;
13
/// Solver for quadratic equations of the form `a·x² + b·x + c = 0`.
///
/// The solver carries no state, so every instance behaves identically.
#[derive(Debug, Clone, Default)]
pub struct QuadraticSolver;

impl QuadraticSolver {
    /// Create a quadratic solver; equivalent to [`QuadraticSolver::default`].
    pub fn new() -> Self {
        QuadraticSolver
    }
}
29
30impl EquationSolver for QuadraticSolver {
31    #[inline(always)]
32    fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
33        // Simplify equation first to flatten nested structures
34        let simplified_equation = equation.simplify();
35
36        // Extract coefficients from quadratic equation ax² + bx + c = 0
37        let (a, b, c) = self.extract_quadratic_coefficients(&simplified_equation, variable);
38
39        // Handle special cases
40        let a_simplified = a.simplify();
41        let b_simplified = b.simplify();
42        let c_simplified = c.simplify();
43
44        if a_simplified.is_zero() {
45            // Degenerate case - actually linear: bx + c = 0
46            if b_simplified.is_zero() {
47                if c_simplified.is_zero() {
48                    return SolverResult::InfiniteSolutions; // 0 = 0
49                } else {
50                    return SolverResult::NoSolution; // c = 0 where c ≠ 0
51                }
52            } else {
53                // Linear equation: bx + c = 0 → x = -c/b
54                return self.solve_linear(&b_simplified, &c_simplified);
55            }
56        }
57
58        // Solve using quadratic formula: x = (-b ± √(b² - 4ac)) / 2a
59        self.solve_quadratic_formula(&a_simplified, &b_simplified, &c_simplified)
60    }
61
62    fn solve_with_explanation(
63        &self,
64        equation: &Expression,
65        variable: &Symbol,
66    ) -> (SolverResult, StepByStepExplanation) {
67        let mut steps = Vec::new();
68
69        let simplified_equation = equation.simplify();
70        let equation_latex = simplified_equation
71            .to_latex(None)
72            .unwrap_or_else(|_| "equation".to_owned());
73
74        steps.push(Step::new(
75            "Given Equation",
76            format!("Solve: {} = 0", equation_latex),
77        ));
78
79        let (a, b, c) = self.extract_quadratic_coefficients(&simplified_equation, variable);
80        let a_simplified = a.simplify();
81        let b_simplified = b.simplify();
82        let c_simplified = c.simplify();
83
84        let a_latex = a_simplified
85            .to_latex(None)
86            .unwrap_or_else(|_| "a".to_owned());
87        let b_latex = b_simplified
88            .to_latex(None)
89            .unwrap_or_else(|_| "b".to_owned());
90        let c_latex = c_simplified
91            .to_latex(None)
92            .unwrap_or_else(|_| "c".to_owned());
93
94        steps.push(Step::new(
95            "Extract Coefficients",
96            format!(
97                "Identified coefficients: a = {}, b = {}, c = {}",
98                a_latex, b_latex, c_latex
99            ),
100        ));
101
102        if a_simplified.is_zero() {
103            steps.push(Step::new(
104                "Special Case",
105                "Coefficient a = 0, this is actually a linear equation",
106            ));
107
108            if b_simplified.is_zero() {
109                steps.push(Step::new(
110                    "Degenerate Case",
111                    if c_simplified.is_zero() {
112                        "0 = 0 is always true (infinite solutions)"
113                    } else {
114                        "Non-zero constant = 0 has no solution"
115                    },
116                ));
117            } else {
118                steps.push(Step::new(
119                    "Linear Solution",
120                    format!("Solving linear equation: {}x + {} = 0", b_latex, c_latex),
121                ));
122            }
123
124            let result = self.solve(equation, variable);
125            return (result, StepByStepExplanation::new(steps));
126        }
127
128        steps.push(Step::new(
129            "Quadratic Formula",
130            "Applying quadratic formula: x = (-b ± √(b² - 4ac)) / (2a)",
131        ));
132
133        let discriminant = match (&a_simplified, &b_simplified, &c_simplified) {
134            (
135                Expression::Number(Number::Integer(a_val)),
136                Expression::Number(Number::Integer(b_val)),
137                Expression::Number(Number::Integer(c_val)),
138            ) => b_val * b_val - 4 * a_val * c_val,
139            _ => 0,
140        };
141
142        steps.push(Step::new(
143            "Compute Discriminant",
144            format!("Discriminant Δ = b² - 4ac = {}", discriminant),
145        ));
146
147        if discriminant > 0 {
148            steps.push(Step::new(
149                "Discriminant Analysis",
150                "Δ > 0: Equation has two distinct real solutions",
151            ));
152        } else if discriminant == 0 {
153            steps.push(Step::new(
154                "Discriminant Analysis",
155                "Δ = 0: Equation has one repeated real solution",
156            ));
157        } else {
158            steps.push(Step::new(
159                "Discriminant Analysis",
160                "Δ < 0: Equation has two complex conjugate solutions",
161            ));
162        }
163
164        let result = self.solve_quadratic_formula(&a_simplified, &b_simplified, &c_simplified);
165
166        match &result {
167            SolverResult::Single(sol) => {
168                let sol_latex = sol.to_latex(None).unwrap_or_else(|_| "solution".to_owned());
169                steps.push(Step::new("Solution", format!("x = {}", sol_latex)));
170            }
171            SolverResult::Multiple(sols) => {
172                let sols_latex: Vec<String> = sols
173                    .iter()
174                    .map(|s| s.to_latex(None).unwrap_or_else(|_| "solution".to_owned()))
175                    .collect();
176                steps.push(Step::new(
177                    "Solutions",
178                    format!("x₁ = {}, x₂ = {}", sols_latex[0], sols_latex[1]),
179                ));
180            }
181            _ => {
182                steps.push(Step::new("Result", format!("{:?}", result)));
183            }
184        }
185
186        (result, StepByStepExplanation::new(steps))
187    }
188
189    fn can_solve(&self, equation: &Expression) -> bool {
190        // Check if equation has degree 2 in the variable
191        self.is_quadratic_equation(equation)
192    }
193}
194
195impl QuadraticSolver {
196    /// Extract coefficients a, b, c from ax² + bx + c = 0
197    fn extract_quadratic_coefficients(
198        &self,
199        equation: &Expression,
200        variable: &Symbol,
201    ) -> (Expression, Expression, Expression) {
202        // First, flatten all nested Add expressions
203        let flattened_terms = equation.flatten_add_terms();
204
205        let mut a_coeff = Expression::integer(0);
206        let mut b_coeff = Expression::integer(0);
207        let mut c_coeff = Expression::integer(0);
208
209        for term in flattened_terms.iter() {
210            match term {
211                // x² term
212                Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
213                    if let Expression::Number(Number::Integer(2)) = **exp {
214                        a_coeff = Expression::add(vec![a_coeff, Expression::integer(1)]);
215                    }
216                }
217                // ax² term
218                Expression::Mul(factors) => {
219                    let mut has_x_squared = false;
220                    let mut has_x_linear = false;
221                    let mut coeff = Expression::integer(1);
222
223                    for factor in factors.iter() {
224                        if let Expression::Pow(base, exp) = factor {
225                            if **base == Expression::symbol(variable.clone()) {
226                                if let Expression::Number(Number::Integer(2)) = **exp {
227                                    has_x_squared = true;
228                                } else if let Expression::Number(Number::Integer(1)) = **exp {
229                                    // x^1 = x (linear term)
230                                    has_x_linear = true;
231                                }
232                            }
233                        } else if *factor == Expression::symbol(variable.clone()) {
234                            // Linear term: coefficient * x
235                            has_x_linear = true;
236                        } else {
237                            coeff = Expression::mul(vec![coeff, factor.clone()]);
238                        }
239                    }
240
241                    if has_x_squared {
242                        a_coeff = Expression::add(vec![a_coeff, coeff]);
243                    } else if has_x_linear {
244                        b_coeff = Expression::add(vec![b_coeff, coeff]);
245                    } else {
246                        // No variable in this multiplication - it's a constant
247                        c_coeff = Expression::add(vec![c_coeff, term.clone()]);
248                    }
249                }
250                // x term (linear)
251                _ if *term == Expression::symbol(variable.clone()) => {
252                    b_coeff = Expression::add(vec![b_coeff, Expression::integer(1)]);
253                }
254                // Constant term
255                _ => {
256                    c_coeff = Expression::add(vec![c_coeff, term.clone()]);
257                }
258            }
259        }
260
261        (a_coeff, b_coeff, c_coeff)
262    }
263
264    /// Solve linear equation bx + c = 0 (degenerate quadratic case)
265    fn solve_linear(&self, b: &Expression, c: &Expression) -> SolverResult {
266        match (b, c) {
267            (
268                Expression::Number(Number::Integer(b_val)),
269                Expression::Number(Number::Integer(c_val)),
270            ) => {
271                if *b_val != 0 {
272                    let result = -c_val / b_val;
273                    if c_val % b_val == 0 {
274                        SolverResult::Single(Expression::integer(result))
275                    } else {
276                        SolverResult::Single(Expression::Number(Number::rational(
277                            BigRational::new(BigInt::from(-c_val), BigInt::from(*b_val)),
278                        )))
279                    }
280                } else {
281                    SolverResult::NoSolution
282                }
283            }
284            _ => {
285                // Symbolic case: x = -c/b
286                let neg_c = Expression::mul(vec![Expression::integer(-1), c.clone()]);
287                let result = Expression::div(neg_c, b.clone());
288                SolverResult::Single(result)
289            }
290        }
291    }
292
293    /// Solve using quadratic formula
294    fn solve_quadratic_formula(
295        &self,
296        a: &Expression,
297        b: &Expression,
298        c: &Expression,
299    ) -> SolverResult {
300        match (a, b, c) {
301            (
302                Expression::Number(Number::Integer(a_val)),
303                Expression::Number(Number::Integer(b_val)),
304                Expression::Number(Number::Integer(c_val)),
305            ) => {
306                // Calculate discriminant: Δ = b² - 4ac
307                let discriminant = b_val * b_val - 4 * a_val * c_val;
308
309                if discriminant > 0 {
310                    // Two real solutions
311                    let sqrt_discriminant = (discriminant as f64).sqrt();
312                    let solution1 = (-b_val as f64 + sqrt_discriminant) / (2.0 * *a_val as f64);
313                    let solution2 = (-b_val as f64 - sqrt_discriminant) / (2.0 * *a_val as f64);
314
315                    // Try to return integers if possible
316                    let sol1 = if solution1.fract().abs() < EPSILON {
317                        Expression::integer(solution1 as i64)
318                    } else {
319                        Expression::Number(Number::float(solution1))
320                    };
321                    let sol2 = if solution2.fract().abs() < EPSILON {
322                        Expression::integer(solution2 as i64)
323                    } else {
324                        Expression::Number(Number::float(solution2))
325                    };
326
327                    SolverResult::Multiple(vec![sol1, sol2])
328                } else if discriminant == 0 {
329                    // One solution (repeated root)
330                    let solution = -b_val as f64 / (2.0 * *a_val as f64);
331                    let sol = if solution.fract().abs() < EPSILON {
332                        Expression::integer(solution as i64)
333                    } else {
334                        Expression::Number(Number::float(solution))
335                    };
336                    SolverResult::Single(sol)
337                } else {
338                    // Complex solutions: x = (-b ± i√|Δ|) / 2a
339                    let sqrt_abs_discriminant = ((-discriminant) as f64).sqrt();
340                    let real_part = -b_val as f64 / (2.0 * *a_val as f64);
341                    let imag_part = sqrt_abs_discriminant / (2.0 * *a_val as f64);
342
343                    // Use Expression::complex for proper complex number representation
344                    let solution1 = Expression::complex(
345                        Expression::Number(Number::float(real_part)),
346                        Expression::Number(Number::float(imag_part)),
347                    );
348                    let solution2 = Expression::complex(
349                        Expression::Number(Number::float(real_part)),
350                        Expression::Number(Number::float(-imag_part)),
351                    );
352
353                    SolverResult::Multiple(vec![solution1, solution2])
354                }
355            }
356            _ => {
357                // Symbolic case: use quadratic formula symbolically
358                // Discriminant: b² - 4ac
359                let b_squared = Expression::pow(b.clone(), Expression::integer(2));
360                let four_a_c = Expression::mul(vec![Expression::integer(4), a.clone(), c.clone()]);
361                let discriminant = Expression::add(vec![
362                    b_squared,
363                    Expression::mul(vec![Expression::integer(-1), four_a_c]),
364                ]);
365
366                // Check if discriminant simplifies to a number
367                let discriminant_simplified = discriminant.simplify();
368
369                // Two times a for denominator
370                let two_a = Expression::mul(vec![Expression::integer(2), a.clone()]);
371
372                // Square root of discriminant
373                let sqrt_discriminant = Expression::function("sqrt", vec![discriminant_simplified]);
374
375                // Solutions: (-b ± √discriminant) / (2a)
376                let neg_b = Expression::mul(vec![Expression::integer(-1), b.clone()]);
377                let solution1 = Expression::div(
378                    Expression::add(vec![neg_b.clone(), sqrt_discriminant.clone()]),
379                    two_a.clone(),
380                );
381
382                let solution2 = Expression::div(
383                    Expression::add(vec![
384                        neg_b,
385                        Expression::mul(vec![Expression::integer(-1), sqrt_discriminant]),
386                    ]),
387                    two_a,
388                );
389
390                SolverResult::Multiple(vec![solution1, solution2])
391            }
392        }
393    }
394
395    /// Check if equation is quadratic
396    fn is_quadratic_equation(&self, _equation: &Expression) -> bool {
397        // Simplified check for now
398        true
399    }
400}