mathhook_core/algebra/solvers/
quadratic.rs1use crate::algebra::solvers::{EquationSolver, SolverResult};
5use crate::core::constants::EPSILON;
6use crate::core::{Expression, Number, Symbol};
7use crate::educational::step_by_step::{Step, StepByStepExplanation};
8use crate::formatter::latex::LaTeXFormatter;
10use crate::simplify::Simplify;
11use num_bigint::BigInt;
12use num_rational::BigRational;
13
/// Solver for equations of the form `a·x² + b·x + c = 0`.
///
/// When the leading coefficient vanishes the equation is handled as a linear one, and when
/// both `a` and `b` vanish it is reported as having no or infinitely many solutions.
/// The solver itself carries no state; it is a unit struct.
#[derive(Clone, Debug)]
pub struct QuadraticSolver;

impl Default for QuadraticSolver {
    /// Same as [`QuadraticSolver::new`].
    fn default() -> Self {
        QuadraticSolver::new()
    }
}

impl QuadraticSolver {
    /// Creates a quadratic solver.
    pub fn new() -> Self {
        QuadraticSolver
    }
}
29
30impl EquationSolver for QuadraticSolver {
31 #[inline(always)]
32 fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
33 let simplified_equation = equation.simplify();
35
36 let (a, b, c) = self.extract_quadratic_coefficients(&simplified_equation, variable);
38
39 let a_simplified = a.simplify();
41 let b_simplified = b.simplify();
42 let c_simplified = c.simplify();
43
44 if a_simplified.is_zero() {
45 if b_simplified.is_zero() {
47 if c_simplified.is_zero() {
48 return SolverResult::InfiniteSolutions; } else {
50 return SolverResult::NoSolution; }
52 } else {
53 return self.solve_linear(&b_simplified, &c_simplified);
55 }
56 }
57
58 self.solve_quadratic_formula(&a_simplified, &b_simplified, &c_simplified)
60 }
61
62 fn solve_with_explanation(
63 &self,
64 equation: &Expression,
65 variable: &Symbol,
66 ) -> (SolverResult, StepByStepExplanation) {
67 let mut steps = Vec::new();
68
69 let simplified_equation = equation.simplify();
70 let equation_latex = simplified_equation
71 .to_latex(None)
72 .unwrap_or_else(|_| "equation".to_owned());
73
74 steps.push(Step::new(
75 "Given Equation",
76 format!("Solve: {} = 0", equation_latex),
77 ));
78
79 let (a, b, c) = self.extract_quadratic_coefficients(&simplified_equation, variable);
80 let a_simplified = a.simplify();
81 let b_simplified = b.simplify();
82 let c_simplified = c.simplify();
83
84 let a_latex = a_simplified
85 .to_latex(None)
86 .unwrap_or_else(|_| "a".to_owned());
87 let b_latex = b_simplified
88 .to_latex(None)
89 .unwrap_or_else(|_| "b".to_owned());
90 let c_latex = c_simplified
91 .to_latex(None)
92 .unwrap_or_else(|_| "c".to_owned());
93
94 steps.push(Step::new(
95 "Extract Coefficients",
96 format!(
97 "Identified coefficients: a = {}, b = {}, c = {}",
98 a_latex, b_latex, c_latex
99 ),
100 ));
101
102 if a_simplified.is_zero() {
103 steps.push(Step::new(
104 "Special Case",
105 "Coefficient a = 0, this is actually a linear equation",
106 ));
107
108 if b_simplified.is_zero() {
109 steps.push(Step::new(
110 "Degenerate Case",
111 if c_simplified.is_zero() {
112 "0 = 0 is always true (infinite solutions)"
113 } else {
114 "Non-zero constant = 0 has no solution"
115 },
116 ));
117 } else {
118 steps.push(Step::new(
119 "Linear Solution",
120 format!("Solving linear equation: {}x + {} = 0", b_latex, c_latex),
121 ));
122 }
123
124 let result = self.solve(equation, variable);
125 return (result, StepByStepExplanation::new(steps));
126 }
127
128 steps.push(Step::new(
129 "Quadratic Formula",
130 "Applying quadratic formula: x = (-b ± √(b² - 4ac)) / (2a)",
131 ));
132
133 let discriminant = match (&a_simplified, &b_simplified, &c_simplified) {
134 (
135 Expression::Number(Number::Integer(a_val)),
136 Expression::Number(Number::Integer(b_val)),
137 Expression::Number(Number::Integer(c_val)),
138 ) => b_val * b_val - 4 * a_val * c_val,
139 _ => 0,
140 };
141
142 steps.push(Step::new(
143 "Compute Discriminant",
144 format!("Discriminant Δ = b² - 4ac = {}", discriminant),
145 ));
146
147 if discriminant > 0 {
148 steps.push(Step::new(
149 "Discriminant Analysis",
150 "Δ > 0: Equation has two distinct real solutions",
151 ));
152 } else if discriminant == 0 {
153 steps.push(Step::new(
154 "Discriminant Analysis",
155 "Δ = 0: Equation has one repeated real solution",
156 ));
157 } else {
158 steps.push(Step::new(
159 "Discriminant Analysis",
160 "Δ < 0: Equation has two complex conjugate solutions",
161 ));
162 }
163
164 let result = self.solve_quadratic_formula(&a_simplified, &b_simplified, &c_simplified);
165
166 match &result {
167 SolverResult::Single(sol) => {
168 let sol_latex = sol.to_latex(None).unwrap_or_else(|_| "solution".to_owned());
169 steps.push(Step::new("Solution", format!("x = {}", sol_latex)));
170 }
171 SolverResult::Multiple(sols) => {
172 let sols_latex: Vec<String> = sols
173 .iter()
174 .map(|s| s.to_latex(None).unwrap_or_else(|_| "solution".to_owned()))
175 .collect();
176 steps.push(Step::new(
177 "Solutions",
178 format!("x₁ = {}, x₂ = {}", sols_latex[0], sols_latex[1]),
179 ));
180 }
181 _ => {
182 steps.push(Step::new("Result", format!("{:?}", result)));
183 }
184 }
185
186 (result, StepByStepExplanation::new(steps))
187 }
188
189 fn can_solve(&self, equation: &Expression) -> bool {
190 self.is_quadratic_equation(equation)
192 }
193}
194
195impl QuadraticSolver {
196 fn extract_quadratic_coefficients(
198 &self,
199 equation: &Expression,
200 variable: &Symbol,
201 ) -> (Expression, Expression, Expression) {
202 let flattened_terms = equation.flatten_add_terms();
204
205 let mut a_coeff = Expression::integer(0);
206 let mut b_coeff = Expression::integer(0);
207 let mut c_coeff = Expression::integer(0);
208
209 for term in flattened_terms.iter() {
210 match term {
211 Expression::Pow(base, exp) if **base == Expression::symbol(variable.clone()) => {
213 if let Expression::Number(Number::Integer(2)) = **exp {
214 a_coeff = Expression::add(vec![a_coeff, Expression::integer(1)]);
215 }
216 }
217 Expression::Mul(factors) => {
219 let mut has_x_squared = false;
220 let mut has_x_linear = false;
221 let mut coeff = Expression::integer(1);
222
223 for factor in factors.iter() {
224 if let Expression::Pow(base, exp) = factor {
225 if **base == Expression::symbol(variable.clone()) {
226 if let Expression::Number(Number::Integer(2)) = **exp {
227 has_x_squared = true;
228 } else if let Expression::Number(Number::Integer(1)) = **exp {
229 has_x_linear = true;
231 }
232 }
233 } else if *factor == Expression::symbol(variable.clone()) {
234 has_x_linear = true;
236 } else {
237 coeff = Expression::mul(vec![coeff, factor.clone()]);
238 }
239 }
240
241 if has_x_squared {
242 a_coeff = Expression::add(vec![a_coeff, coeff]);
243 } else if has_x_linear {
244 b_coeff = Expression::add(vec![b_coeff, coeff]);
245 } else {
246 c_coeff = Expression::add(vec![c_coeff, term.clone()]);
248 }
249 }
250 _ if *term == Expression::symbol(variable.clone()) => {
252 b_coeff = Expression::add(vec![b_coeff, Expression::integer(1)]);
253 }
254 _ => {
256 c_coeff = Expression::add(vec![c_coeff, term.clone()]);
257 }
258 }
259 }
260
261 (a_coeff, b_coeff, c_coeff)
262 }
263
264 fn solve_linear(&self, b: &Expression, c: &Expression) -> SolverResult {
266 match (b, c) {
267 (
268 Expression::Number(Number::Integer(b_val)),
269 Expression::Number(Number::Integer(c_val)),
270 ) => {
271 if *b_val != 0 {
272 let result = -c_val / b_val;
273 if c_val % b_val == 0 {
274 SolverResult::Single(Expression::integer(result))
275 } else {
276 SolverResult::Single(Expression::Number(Number::rational(
277 BigRational::new(BigInt::from(-c_val), BigInt::from(*b_val)),
278 )))
279 }
280 } else {
281 SolverResult::NoSolution
282 }
283 }
284 _ => {
285 let neg_c = Expression::mul(vec![Expression::integer(-1), c.clone()]);
287 let result = Expression::div(neg_c, b.clone());
288 SolverResult::Single(result)
289 }
290 }
291 }
292
293 fn solve_quadratic_formula(
295 &self,
296 a: &Expression,
297 b: &Expression,
298 c: &Expression,
299 ) -> SolverResult {
300 match (a, b, c) {
301 (
302 Expression::Number(Number::Integer(a_val)),
303 Expression::Number(Number::Integer(b_val)),
304 Expression::Number(Number::Integer(c_val)),
305 ) => {
306 let discriminant = b_val * b_val - 4 * a_val * c_val;
308
309 if discriminant > 0 {
310 let sqrt_discriminant = (discriminant as f64).sqrt();
312 let solution1 = (-b_val as f64 + sqrt_discriminant) / (2.0 * *a_val as f64);
313 let solution2 = (-b_val as f64 - sqrt_discriminant) / (2.0 * *a_val as f64);
314
315 let sol1 = if solution1.fract().abs() < EPSILON {
317 Expression::integer(solution1 as i64)
318 } else {
319 Expression::Number(Number::float(solution1))
320 };
321 let sol2 = if solution2.fract().abs() < EPSILON {
322 Expression::integer(solution2 as i64)
323 } else {
324 Expression::Number(Number::float(solution2))
325 };
326
327 SolverResult::Multiple(vec![sol1, sol2])
328 } else if discriminant == 0 {
329 let solution = -b_val as f64 / (2.0 * *a_val as f64);
331 let sol = if solution.fract().abs() < EPSILON {
332 Expression::integer(solution as i64)
333 } else {
334 Expression::Number(Number::float(solution))
335 };
336 SolverResult::Single(sol)
337 } else {
338 let sqrt_abs_discriminant = ((-discriminant) as f64).sqrt();
340 let real_part = -b_val as f64 / (2.0 * *a_val as f64);
341 let imag_part = sqrt_abs_discriminant / (2.0 * *a_val as f64);
342
343 let solution1 = Expression::complex(
345 Expression::Number(Number::float(real_part)),
346 Expression::Number(Number::float(imag_part)),
347 );
348 let solution2 = Expression::complex(
349 Expression::Number(Number::float(real_part)),
350 Expression::Number(Number::float(-imag_part)),
351 );
352
353 SolverResult::Multiple(vec![solution1, solution2])
354 }
355 }
356 _ => {
357 let b_squared = Expression::pow(b.clone(), Expression::integer(2));
360 let four_a_c = Expression::mul(vec![Expression::integer(4), a.clone(), c.clone()]);
361 let discriminant = Expression::add(vec![
362 b_squared,
363 Expression::mul(vec![Expression::integer(-1), four_a_c]),
364 ]);
365
366 let discriminant_simplified = discriminant.simplify();
368
369 let two_a = Expression::mul(vec![Expression::integer(2), a.clone()]);
371
372 let sqrt_discriminant = Expression::function("sqrt", vec![discriminant_simplified]);
374
375 let neg_b = Expression::mul(vec![Expression::integer(-1), b.clone()]);
377 let solution1 = Expression::div(
378 Expression::add(vec![neg_b.clone(), sqrt_discriminant.clone()]),
379 two_a.clone(),
380 );
381
382 let solution2 = Expression::div(
383 Expression::add(vec![
384 neg_b,
385 Expression::mul(vec![Expression::integer(-1), sqrt_discriminant]),
386 ]),
387 two_a,
388 );
389
390 SolverResult::Multiple(vec![solution1, solution2])
391 }
392 }
393 }
394
395 fn is_quadratic_equation(&self, _equation: &Expression) -> bool {
397 true
399 }
400}