mathhook_core/algebra/solvers/
matrix_equations.rs

1//! Matrix equation solver for noncommutative algebra
2//!
3//! Handles equations involving matrices, operators, and quaternions where
4//! multiplication order matters. Distinguishes between left and right division.
5//!
6//! # Mathematical Background: Why Order Matters in Matrix Equations
7//!
8//! In commutative algebra (scalars), multiplication order doesn't matter:
9//! - `a * b = b * a`
10//! - `a * x = b` can be solved as `x = b / a = b * (1/a) = (1/a) * b`
11//!
12//! But in noncommutative algebra (matrices, operators, quaternions), order is critical:
13//! - `A * B ≠ B * A` (in general)
14//! - Division must distinguish LEFT from RIGHT
15//!
16//! ## Left Division: A*X = B
17//!
18//! To solve `A*X = B` for X, we multiply both sides by A^(-1) on the LEFT:
19//!
20//! ```text
21//! A*X = B
22//! A^(-1) * (A*X) = A^(-1) * B    // Multiply left by A^(-1)
23//! (A^(-1) * A) * X = A^(-1) * B  // Associativity
24//! I * X = A^(-1) * B             // A^(-1)*A = I
25//! X = A^(-1) * B                 // Solution
26//! ```
27//!
28//! ## Right Division: X*A = B
29//!
30//! To solve `X*A = B` for X, we multiply both sides by A^(-1) on the RIGHT:
31//!
32//! ```text
33//! X*A = B
34//! (X*A) * A^(-1) = B * A^(-1)    // Multiply right by A^(-1)
35//! X * (A*A^(-1)) = B * A^(-1)    // Associativity
36//! X * I = B * A^(-1)             // A*A^(-1) = I
37//! X = B * A^(-1)                 // Solution
38//! ```
39//!
40//! ## Why We Can't Swap Order
41//!
42//! In general, `A^(-1) * B ≠ B * A^(-1)`, so:
43//! - Solution to `A*X = B` is `X = A^(-1)*B` (NOT `B*A^(-1)`)
44//! - Solution to `X*A = B` is `X = B*A^(-1)` (NOT `A^(-1)*B`)
45//!
46//! ## Real-World Examples
47//!
48//! **Linear Algebra**: Solving `A*x = b` for vector x
49//! - `A` is coefficient matrix
50//! - `x` is unknown vector
51//! - `b` is result vector
52//! - Solution: `x = A^(-1)*b` (left multiplication)
53//!
54//! **Quantum Mechanics**: Eigenvalue equations `H*ψ = E*ψ`
55//! - `H` is Hamiltonian operator
56//! - `ψ` is wavefunction (eigenstate)
57//! - `E` is energy (eigenvalue, commutative)
58//!
59//! **Quaternions**: 3D rotations `q*v*conj(q)`
60//! - `q` is rotation quaternion
61//! - `v` is vector (as quaternion)
62//! - Order matters: `q*v ≠ v*q`
63
64use crate::algebra::solvers::{EquationSolver, SolverError, SolverResult};
65use crate::core::commutativity::Commutativity;
66use crate::core::{Expression, Symbol};
67use crate::educational::step_by_step::{Step, StepByStepExplanation};
68use crate::simplify::Simplify;
69
/// Solver for linear equations whose unknown is noncommutative (a matrix, operator or
/// quaternion), so that `A*X` and `X*A` are different equations.
///
/// Equations are given in zero form (`lhs - rhs`, understood as `= 0`). Two shapes are
/// recognized:
/// - unknown on the right of its coefficient, `A*X = B`, solved as `X = A^(-1)*B`
/// - unknown on the left of its coefficient, `X*A = B`, solved as `X = B*A^(-1)`
///
/// # Examples
///
/// ```rust,ignore
/// use mathhook_core::{symbol, expr};
/// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
/// use mathhook_core::algebra::solvers::EquationSolver;
///
/// let solver = MatrixEquationSolver::new();
/// let A = symbol!(A; matrix);
/// let B = symbol!(B; matrix);
/// let X = symbol!(X; matrix);
///
/// // Solve A*X = B for X
/// let equation = expr!((A*X) - B);
/// let result = solver.solve(&equation, &X);
/// ```
#[derive(Clone, Debug)]
pub struct MatrixEquationSolver {
    /// Whether step-by-step explanations are requested (`true` from `new`, `false` from
    /// `new_fast`). NOTE(review): no method in this module reads this flag — confirm
    /// whether external callers rely on it.
    pub show_steps: bool,
}
96
97impl MatrixEquationSolver {
98    /// Create a new matrix equation solver
99    ///
100    /// # Examples
101    ///
102    /// ```rust,ignore
103    /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
104    ///
105    /// let solver = MatrixEquationSolver::new();
106    /// ```
107    pub fn new() -> Self {
108        Self { show_steps: true }
109    }
110
111    /// Create solver without step-by-step explanations (for performance)
112    ///
113    /// # Examples
114    ///
115    /// ```rust,ignore
116    /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
117    ///
118    /// let solver = MatrixEquationSolver::new_fast();
119    /// ```
120    pub fn new_fast() -> Self {
121        Self { show_steps: false }
122    }
123
124    /// Detect if equation is left division (A*X = B)
125    ///
126    /// Returns Some((A, B)) if equation is A*X = B, None otherwise
127    fn detect_left_division(
128        &self,
129        equation: &Expression,
130        variable: &Symbol,
131    ) -> Option<(Expression, Expression)> {
132        let simplified = equation.simplify();
133
134        match &simplified {
135            // Pattern: A*X - B = 0
136            Expression::Add(terms) if terms.len() == 2 => {
137                // Look for pattern: Mul(A, X) and -B
138                match (&terms[0], &terms[1]) {
139                    (Expression::Mul(factors), b) if factors.len() == 2 => {
140                        if let [a, Expression::Symbol(x)] = &factors[..] {
141                            if x == variable && !a.contains_variable(variable) {
142                                // Found A*X - B pattern
143                                let neg_b =
144                                    Expression::mul(vec![Expression::integer(-1), b.clone()]);
145                                return Some((a.clone(), neg_b.simplify()));
146                            }
147                        }
148                        None
149                    }
150                    _ => None,
151                }
152            }
153            // Pattern: A*X = 0 (already simplified, b=0 implicit)
154            Expression::Mul(factors) if factors.len() == 2 => {
155                if let [a, Expression::Symbol(x)] = &factors[..] {
156                    if x == variable && !a.contains_variable(variable) {
157                        return Some((a.clone(), Expression::integer(0)));
158                    }
159                }
160                None
161            }
162            _ => None,
163        }
164    }
165
166    /// Detect if equation is right division (X*A = B)
167    ///
168    /// Returns Some((A, B)) if equation is X*A = B, None otherwise
169    fn detect_right_division(
170        &self,
171        equation: &Expression,
172        variable: &Symbol,
173    ) -> Option<(Expression, Expression)> {
174        let simplified = equation.simplify();
175
176        match &simplified {
177            // Pattern: X*A - B = 0
178            Expression::Add(terms) if terms.len() == 2 => {
179                // Look for pattern: Mul(X, A) and -B
180                match (&terms[0], &terms[1]) {
181                    (Expression::Mul(factors), b) if factors.len() == 2 => {
182                        if let [Expression::Symbol(x), a] = &factors[..] {
183                            if x == variable && !a.contains_variable(variable) {
184                                // Found X*A - B pattern
185                                let neg_b =
186                                    Expression::mul(vec![Expression::integer(-1), b.clone()]);
187                                return Some((a.clone(), neg_b.simplify()));
188                            }
189                        }
190                        None
191                    }
192                    _ => None,
193                }
194            }
195            // Pattern: X*A = 0 (already simplified, b=0 implicit)
196            Expression::Mul(factors) if factors.len() == 2 => {
197                if let [Expression::Symbol(x), a] = &factors[..] {
198                    if x == variable && !a.contains_variable(variable) {
199                        return Some((a.clone(), Expression::integer(0)));
200                    }
201                }
202                None
203            }
204            _ => None,
205        }
206    }
207
208    /// Solve left division: A*X = B → X = A^(-1)*B
209    ///
210    /// # Arguments
211    ///
212    /// * `A` - The left coefficient matrix/operator
213    /// * `B` - The right-hand side
214    ///
215    /// # Examples
216    ///
217    /// ```rust,ignore
218    /// use mathhook_core::{symbol, expr};
219    /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
220    ///
221    /// let solver = MatrixEquationSolver::new();
222    /// let A = symbol!(A; matrix);
223    /// let B = symbol!(B; matrix);
224    ///
225    /// let solution = solver.solve_left_division(&A, &B);
226    /// // solution should be A^(-1)*B
227    /// ```
228    pub fn solve_left_division(
229        &self,
230        a: &Expression,
231        b: &Expression,
232    ) -> Result<Expression, SolverError> {
233        // Check if A is potentially singular (for matrices)
234        if self.is_zero_matrix(a) {
235            return Err(SolverError::InvalidEquation(
236                "Cannot invert zero matrix".to_owned(),
237            ));
238        }
239
240        // X = A^(-1) * B (left multiplication)
241        let a_inv = Expression::pow(a.clone(), Expression::integer(-1));
242        let solution = Expression::mul(vec![a_inv, b.clone()]);
243
244        Ok(solution.simplify())
245    }
246
247    /// Solve right division: X*A = B → X = B*A^(-1)
248    ///
249    /// # Arguments
250    ///
251    /// * `A` - The right coefficient matrix/operator
252    /// * `B` - The right-hand side
253    ///
254    /// # Examples
255    ///
256    /// ```rust,ignore
257    /// use mathhook_core::{symbol, expr};
258    /// use mathhook_core::algebra::solvers::matrix_equations::MatrixEquationSolver;
259    ///
260    /// let solver = MatrixEquationSolver::new();
261    /// let A = symbol!(A; matrix);
262    /// let B = symbol!(B; matrix);
263    ///
264    /// let solution = solver.solve_right_division(&A, &B);
265    /// // solution should be B*A^(-1)
266    /// ```
267    pub fn solve_right_division(
268        &self,
269        a: &Expression,
270        b: &Expression,
271    ) -> Result<Expression, SolverError> {
272        // Check if A is potentially singular (for matrices)
273        if self.is_zero_matrix(a) {
274            return Err(SolverError::InvalidEquation(
275                "Cannot invert zero matrix".to_owned(),
276            ));
277        }
278
279        // X = B * A^(-1) (right multiplication)
280        let a_inv = Expression::pow(a.clone(), Expression::integer(-1));
281        let solution = Expression::mul(vec![b.clone(), a_inv]);
282
283        Ok(solution.simplify())
284    }
285
286    /// Check if expression represents a zero matrix
287    fn is_zero_matrix(&self, expr: &Expression) -> bool {
288        match expr {
289            Expression::Number(n) if n.is_zero() => true,
290            Expression::Matrix(m) => {
291                let (rows, cols) = m.dimensions();
292                for i in 0..rows {
293                    for j in 0..cols {
294                        let elem = m.get_element(i, j);
295                        if !elem.is_zero() {
296                            return false;
297                        }
298                    }
299                }
300                true
301            }
302            _ => false,
303        }
304    }
305
306    /// Detect if variable appears in multiple positions (error case)
307    fn variable_appears_multiple_times(&self, expr: &Expression, variable: &Symbol) -> bool {
308        let count = expr.count_variable_occurrences(variable);
309        count > 1
310    }
311}
312
313impl Default for MatrixEquationSolver {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
319impl EquationSolver for MatrixEquationSolver {
320    fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
321        // Check if variable appears multiple times (error case for noncommutative)
322        if self.variable_appears_multiple_times(equation, variable) {
323            return SolverResult::NoSolution;
324        }
325
326        // Try left division first
327        if let Some((a, b)) = self.detect_left_division(equation, variable) {
328            match self.solve_left_division(&a, &b) {
329                Ok(solution) => return SolverResult::Single(solution),
330                Err(_) => return SolverResult::NoSolution,
331            }
332        }
333
334        // Try right division
335        if let Some((a, b)) = self.detect_right_division(equation, variable) {
336            match self.solve_right_division(&a, &b) {
337                Ok(solution) => return SolverResult::Single(solution),
338                Err(_) => return SolverResult::NoSolution,
339            }
340        }
341
342        SolverResult::NoSolution
343    }
344
345    fn solve_with_explanation(
346        &self,
347        equation: &Expression,
348        variable: &Symbol,
349    ) -> (SolverResult, StepByStepExplanation) {
350        let mut steps = vec![Step::new(
351            "Given Equation",
352            format!("Solve {} = 0 for {}", equation, variable.name),
353        )];
354
355        // Check commutativity
356        if equation.commutativity() == Commutativity::Commutative {
357            steps.push(Step::new(
358                "Analysis",
359                "All symbols are commutative - use standard linear solver instead",
360            ));
361            return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
362        }
363
364        steps.push(Step::new(
365            "Analysis",
366            "Detected noncommutative symbols (matrix/operator/quaternion)",
367        ));
368
369        // Try left division
370        if let Some((a, b)) = self.detect_left_division(equation, variable) {
371            steps.push(Step::new(
372                "Pattern",
373                format!(
374                    "Identified left division: {} * {} = {}",
375                    a, variable.name, b
376                ),
377            ));
378            steps.push(Step::new(
379                "Solution Method",
380                format!(
381                    "{} = {}^(-1) * {} (inverse applied on LEFT)",
382                    variable.name, a, b
383                ),
384            ));
385
386            match self.solve_left_division(&a, &b) {
387                Ok(solution) => {
388                    steps.push(Step::new(
389                        "Result",
390                        format!("{} = {}", variable.name, solution),
391                    ));
392                    return (
393                        SolverResult::Single(solution),
394                        StepByStepExplanation::new(steps),
395                    );
396                }
397                Err(e) => {
398                    steps.push(Step::new("Error", format!("{:?}", e)));
399                    return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
400                }
401            }
402        }
403
404        // Try right division
405        if let Some((a, b)) = self.detect_right_division(equation, variable) {
406            steps.push(Step::new(
407                "Pattern",
408                format!(
409                    "Identified right division: {} * {} = {}",
410                    variable.name, a, b
411                ),
412            ));
413            steps.push(Step::new(
414                "Solution Method",
415                format!(
416                    "{} = {} * {}^(-1) (inverse applied on RIGHT)",
417                    variable.name, b, a
418                ),
419            ));
420
421            match self.solve_right_division(&a, &b) {
422                Ok(solution) => {
423                    steps.push(Step::new(
424                        "Result",
425                        format!("{} = {}", variable.name, solution),
426                    ));
427                    return (
428                        SolverResult::Single(solution),
429                        StepByStepExplanation::new(steps),
430                    );
431                }
432                Err(e) => {
433                    steps.push(Step::new("Error", format!("{:?}", e)));
434                    return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
435                }
436            }
437        }
438
439        steps.push(Step::new(
440            "Result",
441            "Could not identify left or right division pattern",
442        ));
443        (SolverResult::NoSolution, StepByStepExplanation::new(steps))
444    }
445
446    fn can_solve(&self, equation: &Expression) -> bool {
447        equation.commutativity() != Commutativity::Commutative
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Build the zero form `lhs*rhs - constant` of the equation `lhs*rhs = constant`.
    fn product_minus(lhs: &Symbol, rhs: &Symbol, constant: &Symbol) -> Expression {
        Expression::add(vec![
            Expression::mul(vec![
                Expression::symbol(lhs.clone()),
                Expression::symbol(rhs.clone()),
            ]),
            Expression::mul(vec![
                Expression::integer(-1),
                Expression::symbol(constant.clone()),
            ]),
        ])
    }

    #[test]
    fn test_left_division_detection() {
        let solver = MatrixEquationSolver::new();
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        // A*X - B = 0
        let equation = product_minus(&a, &x, &b);
        assert!(solver.detect_left_division(&equation, &x).is_some());
    }

    #[test]
    fn test_right_division_detection() {
        let solver = MatrixEquationSolver::new();
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        // X*A - B = 0
        let equation = product_minus(&x, &a, &b);
        assert!(solver.detect_right_division(&equation, &x).is_some());
    }

    #[test]
    fn test_division_directions_are_not_interchangeable() {
        let solver = MatrixEquationSolver::new();
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        // X*A - B = 0 is a right division, never a left one ...
        assert!(solver
            .detect_left_division(&product_minus(&x, &a, &b), &x)
            .is_none());
        // ... and A*X - B = 0 is a left division, never a right one.
        assert!(solver
            .detect_right_division(&product_minus(&a, &x, &b), &x)
            .is_none());
    }

    #[test]
    fn test_solve_left_division_yields_single_solution() {
        let solver = MatrixEquationSolver::new();
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);
        let b = symbol!(B; matrix);

        let result = solver.solve(&product_minus(&a, &x, &b), &x);
        assert!(matches!(result, SolverResult::Single(_)));
    }

    #[test]
    fn test_solve_rejects_repeated_unknown() {
        let solver = MatrixEquationSolver::new();
        let a = symbol!(A; matrix);
        let x = symbol!(X; matrix);

        // A*X - X = 0: the unknown occurs twice, so no single division isolates it.
        let result = solver.solve(&product_minus(&a, &x, &x), &x);
        assert!(matches!(result, SolverResult::NoSolution));
    }
}