mathhook_core/algebra/solvers/matrix_equations.rs
1//! Matrix equation solver for noncommutative algebra
2//!
3//! Handles equations involving matrices, operators, and quaternions where
4//! multiplication order matters. Distinguishes between left and right division.
5//!
6//! # Mathematical Background: Why Order Matters in Matrix Equations
7//!
8//! In commutative algebra (scalars), multiplication order doesn't matter:
9//! - `a * b = b * a`
10//! - `a * x = b` can be solved as `x = b / a = b * (1/a) = (1/a) * b`
11//!
12//! But in noncommutative algebra (matrices, operators, quaternions), order is critical:
13//! - `A * B ≠ B * A` (in general)
14//! - Division must distinguish LEFT from RIGHT
15//!
16//! ## Left Division: A*X = B
17//!
18//! To solve `A*X = B` for X, we multiply both sides by A^(-1) on the LEFT:
19//!
20//! ```text
21//! A*X = B
22//! A^(-1) * (A*X) = A^(-1) * B // Multiply left by A^(-1)
23//! (A^(-1) * A) * X = A^(-1) * B // Associativity
24//! I * X = A^(-1) * B // A^(-1)*A = I
25//! X = A^(-1) * B // Solution
26//! ```
27//!
28//! ## Right Division: X*A = B
29//!
30//! To solve `X*A = B` for X, we multiply both sides by A^(-1) on the RIGHT:
31//!
32//! ```text
33//! X*A = B
34//! (X*A) * A^(-1) = B * A^(-1) // Multiply right by A^(-1)
35//! X * (A*A^(-1)) = B * A^(-1) // Associativity
36//! X * I = B * A^(-1) // A*A^(-1) = I
37//! X = B * A^(-1) // Solution
38//! ```
39//!
40//! ## Why We Can't Swap Order
41//!
42//! In general, `A^(-1) * B ≠ B * A^(-1)`, so:
43//! - Solution to `A*X = B` is `X = A^(-1)*B` (NOT `B*A^(-1)`)
44//! - Solution to `X*A = B` is `X = B*A^(-1)` (NOT `A^(-1)*B`)
45//!
46//! ## Real-World Examples
47//!
48//! **Linear Algebra**: Solving `A*x = b` for vector x
49//! - `A` is coefficient matrix
50//! - `x` is unknown vector
51//! - `b` is result vector
52//! - Solution: `x = A^(-1)*b` (left multiplication)
53//!
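//! **Quantum Mechanics**: Eigenvalue equations `H*ψ = E*ψ`
//! - `H` is Hamiltonian operator
//! - `ψ` is wavefunction (eigenstate)
//! - `E` is energy (eigenvalue, commutative)
//! - Note: `ψ` appears on both sides, so this is an eigenvalue problem rather than
//!   a division problem; this solver reports no solution for such equations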
58//!
59//! **Quaternions**: 3D rotations `q*v*conj(q)`
60//! - `q` is rotation quaternion
61//! - `v` is vector (as quaternion)
62//! - Order matters: `q*v ≠ v*q`
63
64use crate::algebra::solvers::{EquationSolver, SolverError, SolverResult};
65use crate::core::commutativity::Commutativity;
66use crate::core::{Expression, Symbol};
67use crate::educational::step_by_step::{Step, StepByStepExplanation};
68use crate::simplify::Simplify;
69
70/// Matrix equation solver specialized for noncommutative types
71///
72/// Handles equations of the form:
73/// - Left multiplication: A*X = B (solution: X = A^(-1)*B)
74/// - Right multiplication: X*A = B (solution: X = B*A^(-1))
75///
76/// # Examples
77///
78/// ```rust,ignore
79/// use mathhook_core::{symbol, expr};
80/// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
81/// use mathhook_core::algebra::solvers::EquationSolver;
82///
83/// let solver = MatrixEquationSolver::new();
84/// let A = symbol!(A; matrix);
85/// let B = symbol!(B; matrix);
86/// let X = symbol!(X; matrix);
87///
88/// // Solve A*X = B for X
89/// let equation = expr!((A*X) - B);
90/// let result = solver.solve(&equation, &X);
91/// ```
92#[derive(Debug, Clone)]
93pub struct MatrixEquationSolver {
94 pub show_steps: bool,
95}
96
97impl MatrixEquationSolver {
98 /// Create a new matrix equation solver
99 ///
100 /// # Examples
101 ///
102 /// ```rust,ignore
103 /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
104 ///
105 /// let solver = MatrixEquationSolver::new();
106 /// ```
107 pub fn new() -> Self {
108 Self { show_steps: true }
109 }
110
111 /// Create solver without step-by-step explanations (for performance)
112 ///
113 /// # Examples
114 ///
115 /// ```rust,ignore
116 /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
117 ///
118 /// let solver = MatrixEquationSolver::new_fast();
119 /// ```
120 pub fn new_fast() -> Self {
121 Self { show_steps: false }
122 }
123
124 /// Detect if equation is left division (A*X = B)
125 ///
126 /// Returns Some((A, B)) if equation is A*X = B, None otherwise
127 fn detect_left_division(
128 &self,
129 equation: &Expression,
130 variable: &Symbol,
131 ) -> Option<(Expression, Expression)> {
132 let simplified = equation.simplify();
133
134 match &simplified {
135 // Pattern: A*X - B = 0
136 Expression::Add(terms) if terms.len() == 2 => {
137 // Look for pattern: Mul(A, X) and -B
138 match (&terms[0], &terms[1]) {
139 (Expression::Mul(factors), b) if factors.len() == 2 => {
140 if let [a, Expression::Symbol(x)] = &factors[..] {
141 if x == variable && !a.contains_variable(variable) {
142 // Found A*X - B pattern
143 let neg_b =
144 Expression::mul(vec![Expression::integer(-1), b.clone()]);
145 return Some((a.clone(), neg_b.simplify()));
146 }
147 }
148 None
149 }
150 _ => None,
151 }
152 }
153 // Pattern: A*X = 0 (already simplified, b=0 implicit)
154 Expression::Mul(factors) if factors.len() == 2 => {
155 if let [a, Expression::Symbol(x)] = &factors[..] {
156 if x == variable && !a.contains_variable(variable) {
157 return Some((a.clone(), Expression::integer(0)));
158 }
159 }
160 None
161 }
162 _ => None,
163 }
164 }
165
166 /// Detect if equation is right division (X*A = B)
167 ///
168 /// Returns Some((A, B)) if equation is X*A = B, None otherwise
169 fn detect_right_division(
170 &self,
171 equation: &Expression,
172 variable: &Symbol,
173 ) -> Option<(Expression, Expression)> {
174 let simplified = equation.simplify();
175
176 match &simplified {
177 // Pattern: X*A - B = 0
178 Expression::Add(terms) if terms.len() == 2 => {
179 // Look for pattern: Mul(X, A) and -B
180 match (&terms[0], &terms[1]) {
181 (Expression::Mul(factors), b) if factors.len() == 2 => {
182 if let [Expression::Symbol(x), a] = &factors[..] {
183 if x == variable && !a.contains_variable(variable) {
184 // Found X*A - B pattern
185 let neg_b =
186 Expression::mul(vec![Expression::integer(-1), b.clone()]);
187 return Some((a.clone(), neg_b.simplify()));
188 }
189 }
190 None
191 }
192 _ => None,
193 }
194 }
195 // Pattern: X*A = 0 (already simplified, b=0 implicit)
196 Expression::Mul(factors) if factors.len() == 2 => {
197 if let [Expression::Symbol(x), a] = &factors[..] {
198 if x == variable && !a.contains_variable(variable) {
199 return Some((a.clone(), Expression::integer(0)));
200 }
201 }
202 None
203 }
204 _ => None,
205 }
206 }
207
208 /// Solve left division: A*X = B → X = A^(-1)*B
209 ///
210 /// # Arguments
211 ///
212 /// * `A` - The left coefficient matrix/operator
213 /// * `B` - The right-hand side
214 ///
215 /// # Examples
216 ///
217 /// ```rust,ignore
218 /// use mathhook_core::{symbol, expr};
219 /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
220 ///
221 /// let solver = MatrixEquationSolver::new();
222 /// let A = symbol!(A; matrix);
223 /// let B = symbol!(B; matrix);
224 ///
225 /// let solution = solver.solve_left_division(&A, &B);
226 /// // solution should be A^(-1)*B
227 /// ```
228 pub fn solve_left_division(
229 &self,
230 a: &Expression,
231 b: &Expression,
232 ) -> Result<Expression, SolverError> {
233 // Check if A is potentially singular (for matrices)
234 if self.is_zero_matrix(a) {
235 return Err(SolverError::InvalidEquation(
236 "Cannot invert zero matrix".to_owned(),
237 ));
238 }
239
240 // X = A^(-1) * B (left multiplication)
241 let a_inv = Expression::pow(a.clone(), Expression::integer(-1));
242 let solution = Expression::mul(vec![a_inv, b.clone()]);
243
244 Ok(solution.simplify())
245 }
246
247 /// Solve right division: X*A = B → X = B*A^(-1)
248 ///
249 /// # Arguments
250 ///
251 /// * `A` - The right coefficient matrix/operator
252 /// * `B` - The right-hand side
253 ///
254 /// # Examples
255 ///
256 /// ```rust,ignore
257 /// use mathhook_core::{symbol, expr};
258 /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
259 ///
260 /// let solver = MatrixEquationSolver::new();
261 /// let A = symbol!(A; matrix);
262 /// let B = symbol!(B; matrix);
263 ///
264 /// let solution = solver.solve_right_division(&A, &B);
265 /// // solution should be B*A^(-1)
266 /// ```
267 pub fn solve_right_division(
268 &self,
269 a: &Expression,
270 b: &Expression,
271 ) -> Result<Expression, SolverError> {
272 // Check if A is potentially singular (for matrices)
273 if self.is_zero_matrix(a) {
274 return Err(SolverError::InvalidEquation(
275 "Cannot invert zero matrix".to_owned(),
276 ));
277 }
278
279 // X = B * A^(-1) (right multiplication)
280 let a_inv = Expression::pow(a.clone(), Expression::integer(-1));
281 let solution = Expression::mul(vec![b.clone(), a_inv]);
282
283 Ok(solution.simplify())
284 }
285
286 /// Check if expression represents a zero matrix
287 fn is_zero_matrix(&self, expr: &Expression) -> bool {
288 match expr {
289 Expression::Number(n) if n.is_zero() => true,
290 Expression::Matrix(m) => {
291 let (rows, cols) = m.dimensions();
292 for i in 0..rows {
293 for j in 0..cols {
294 let elem = m.get_element(i, j);
295 if !elem.is_zero() {
296 return false;
297 }
298 }
299 }
300 true
301 }
302 _ => false,
303 }
304 }
305
306 /// Detect if variable appears in multiple positions (error case)
307 fn variable_appears_multiple_times(&self, expr: &Expression, variable: &Symbol) -> bool {
308 let count = expr.count_variable_occurrences(variable);
309 count > 1
310 }
311}
312
313impl Default for MatrixEquationSolver {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
319impl EquationSolver for MatrixEquationSolver {
320 fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
321 // Check if variable appears multiple times (error case for noncommutative)
322 if self.variable_appears_multiple_times(equation, variable) {
323 return SolverResult::NoSolution;
324 }
325
326 // Try left division first
327 if let Some((a, b)) = self.detect_left_division(equation, variable) {
328 match self.solve_left_division(&a, &b) {
329 Ok(solution) => return SolverResult::Single(solution),
330 Err(_) => return SolverResult::NoSolution,
331 }
332 }
333
334 // Try right division
335 if let Some((a, b)) = self.detect_right_division(equation, variable) {
336 match self.solve_right_division(&a, &b) {
337 Ok(solution) => return SolverResult::Single(solution),
338 Err(_) => return SolverResult::NoSolution,
339 }
340 }
341
342 SolverResult::NoSolution
343 }
344
345 fn solve_with_explanation(
346 &self,
347 equation: &Expression,
348 variable: &Symbol,
349 ) -> (SolverResult, StepByStepExplanation) {
350 let mut steps = vec![Step::new(
351 "Given Equation",
352 format!("Solve {} = 0 for {}", equation, variable.name),
353 )];
354
355 // Check commutativity
356 if equation.commutativity() == Commutativity::Commutative {
357 steps.push(Step::new(
358 "Analysis",
359 "All symbols are commutative - use standard linear solver instead",
360 ));
361 return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
362 }
363
364 steps.push(Step::new(
365 "Analysis",
366 "Detected noncommutative symbols (matrix/operator/quaternion)",
367 ));
368
369 // Try left division
370 if let Some((a, b)) = self.detect_left_division(equation, variable) {
371 steps.push(Step::new(
372 "Pattern",
373 format!(
374 "Identified left division: {} * {} = {}",
375 a, variable.name, b
376 ),
377 ));
378 steps.push(Step::new(
379 "Solution Method",
380 format!(
381 "{} = {}^(-1) * {} (inverse applied on LEFT)",
382 variable.name, a, b
383 ),
384 ));
385
386 match self.solve_left_division(&a, &b) {
387 Ok(solution) => {
388 steps.push(Step::new(
389 "Result",
390 format!("{} = {}", variable.name, solution),
391 ));
392 return (
393 SolverResult::Single(solution),
394 StepByStepExplanation::new(steps),
395 );
396 }
397 Err(e) => {
398 steps.push(Step::new("Error", format!("{:?}", e)));
399 return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
400 }
401 }
402 }
403
404 // Try right division
405 if let Some((a, b)) = self.detect_right_division(equation, variable) {
406 steps.push(Step::new(
407 "Pattern",
408 format!(
409 "Identified right division: {} * {} = {}",
410 variable.name, a, b
411 ),
412 ));
413 steps.push(Step::new(
414 "Solution Method",
415 format!(
416 "{} = {} * {}^(-1) (inverse applied on RIGHT)",
417 variable.name, b, a
418 ),
419 ));
420
421 match self.solve_right_division(&a, &b) {
422 Ok(solution) => {
423 steps.push(Step::new(
424 "Result",
425 format!("{} = {}", variable.name, solution),
426 ));
427 return (
428 SolverResult::Single(solution),
429 StepByStepExplanation::new(steps),
430 );
431 }
432 Err(e) => {
433 steps.push(Step::new("Error", format!("{:?}", e)));
434 return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
435 }
436 }
437 }
438
439 steps.push(Step::new(
440 "Result",
441 "Could not identify left or right division pattern",
442 ));
443 (SolverResult::NoSolution, StepByStepExplanation::new(steps))
444 }
445
446 fn can_solve(&self, equation: &Expression) -> bool {
447 equation.commutativity() != Commutativity::Commutative
448 }
449}
450
451#[cfg(test)]
452mod tests {
453 use super::*;
454 use crate::symbol;
455
456 #[test]
457 fn test_left_division_detection() {
458 let solver = MatrixEquationSolver::new();
459 let a = symbol!(A; matrix);
460 let x = symbol!(X; matrix);
461 let b = symbol!(B; matrix);
462
463 // A*X - B = 0
464 let equation = Expression::add(vec![
465 Expression::mul(vec![
466 Expression::symbol(a.clone()),
467 Expression::symbol(x.clone()),
468 ]),
469 Expression::mul(vec![Expression::integer(-1), Expression::symbol(b.clone())]),
470 ]);
471
472 let result = solver.detect_left_division(&equation, &x);
473 assert!(result.is_some());
474 }
475
476 #[test]
477 fn test_right_division_detection() {
478 let solver = MatrixEquationSolver::new();
479 let a = symbol!(A; matrix);
480 let x = symbol!(X; matrix);
481 let b = symbol!(B; matrix);
482
483 // X*A - B = 0
484 let equation = Expression::add(vec![
485 Expression::mul(vec![
486 Expression::symbol(x.clone()),
487 Expression::symbol(a.clone()),
488 ]),
489 Expression::mul(vec![Expression::integer(-1), Expression::symbol(b.clone())]),
490 ]);
491
492 let result = solver.detect_right_division(&equation, &x);
493 assert!(result.is_some());
494 }
495}