use crate::algebra::solvers::{EquationSolver, SolverError, SolverResult};
use crate::core::commutativity::Commutativity;
use crate::core::{Expression, Symbol};
use crate::educational::step_by_step::{Step, StepByStepExplanation};
use crate::simplify::Simplify;
/// Solver for linear equations in noncommutative symbols (e.g. matrices), of the form
/// `A*X + B = 0` (left division) or `X*A + B = 0` (right division).
///
/// The struct holds only a `bool` flag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixEquationSolver {
    /// Whether step-by-step output is requested.
    ///
    /// NOTE(review): no method in this module reads this flag;
    /// `solve_with_explanation` always records steps. Confirm whether callers rely on it.
    pub show_steps: bool,
}
impl MatrixEquationSolver {
pub fn new() -> Self {
Self { show_steps: true }
}
pub fn new_fast() -> Self {
Self { show_steps: false }
}
fn detect_left_division(
&self,
equation: &Expression,
variable: &Symbol,
) -> Option<(Expression, Expression)> {
let simplified = equation.simplify();
match &simplified {
Expression::Add(terms) if terms.len() == 2 => {
match (&terms[0], &terms[1]) {
(Expression::Mul(factors), b) if factors.len() == 2 => {
if let [a, Expression::Symbol(x)] = &factors[..] {
if x == variable && !a.contains_variable(variable) {
let neg_b =
Expression::mul(vec![Expression::integer(-1), b.clone()]);
return Some((a.clone(), neg_b.simplify()));
}
}
None
}
_ => None,
}
}
Expression::Mul(factors) if factors.len() == 2 => {
if let [a, Expression::Symbol(x)] = &factors[..] {
if x == variable && !a.contains_variable(variable) {
return Some((a.clone(), Expression::integer(0)));
}
}
None
}
_ => None,
}
}
fn detect_right_division(
&self,
equation: &Expression,
variable: &Symbol,
) -> Option<(Expression, Expression)> {
let simplified = equation.simplify();
match &simplified {
Expression::Add(terms) if terms.len() == 2 => {
match (&terms[0], &terms[1]) {
(Expression::Mul(factors), b) if factors.len() == 2 => {
if let [Expression::Symbol(x), a] = &factors[..] {
if x == variable && !a.contains_variable(variable) {
let neg_b =
Expression::mul(vec![Expression::integer(-1), b.clone()]);
return Some((a.clone(), neg_b.simplify()));
}
}
None
}
_ => None,
}
}
Expression::Mul(factors) if factors.len() == 2 => {
if let [Expression::Symbol(x), a] = &factors[..] {
if x == variable && !a.contains_variable(variable) {
return Some((a.clone(), Expression::integer(0)));
}
}
None
}
_ => None,
}
}
pub fn solve_left_division(
&self,
a: &Expression,
b: &Expression,
) -> Result<Expression, SolverError> {
if self.is_zero_matrix(a) {
return Err(SolverError::InvalidEquation(
"Cannot invert zero matrix".to_owned(),
));
}
let a_inv = Expression::pow(a.clone(), Expression::integer(-1));
let solution = Expression::mul(vec![a_inv, b.clone()]);
Ok(solution.simplify())
}
pub fn solve_right_division(
&self,
a: &Expression,
b: &Expression,
) -> Result<Expression, SolverError> {
if self.is_zero_matrix(a) {
return Err(SolverError::InvalidEquation(
"Cannot invert zero matrix".to_owned(),
));
}
let a_inv = Expression::pow(a.clone(), Expression::integer(-1));
let solution = Expression::mul(vec![b.clone(), a_inv]);
Ok(solution.simplify())
}
fn is_zero_matrix(&self, expr: &Expression) -> bool {
match expr {
Expression::Number(n) if n.is_zero() => true,
Expression::Matrix(m) => {
let (rows, cols) = m.dimensions();
for i in 0..rows {
for j in 0..cols {
let elem = m.get_element(i, j);
if !elem.is_zero() {
return false;
}
}
}
true
}
_ => false,
}
}
fn variable_appears_multiple_times(&self, expr: &Expression, variable: &Symbol) -> bool {
let count = expr.count_variable_occurrences(variable);
count > 1
}
}
impl Default for MatrixEquationSolver {
fn default() -> Self {
Self::new()
}
}
impl EquationSolver for MatrixEquationSolver {
    /// Solves `equation = 0` for `variable`. Two forms are handled:
    ///
    /// - `A*X + B` (left division): `X = A^(-1) * (-B)`.
    /// - `X*A + B` (right division): `X = (-B) * A^(-1)`.
    ///
    /// Returns `SolverResult::NoSolution` in three cases: `variable` occurs more than
    /// once, neither form matches, or the matched coefficient is recognisably zero.
    fn solve(&self, equation: &Expression, variable: &Symbol) -> SolverResult {
        // Only single-occurrence linear forms are supported.
        if self.variable_appears_multiple_times(equation, variable) {
            return SolverResult::NoSolution;
        }
        // Left division takes priority. Right division is only tried when the left
        // pattern does not match at all.
        let attempt = if let Some((a, b)) = self.detect_left_division(equation, variable) {
            self.solve_left_division(&a, &b)
        } else if let Some((a, b)) = self.detect_right_division(equation, variable) {
            self.solve_right_division(&a, &b)
        } else {
            return SolverResult::NoSolution;
        };
        match attempt {
            Ok(solution) => SolverResult::Single(solution),
            Err(_) => SolverResult::NoSolution,
        }
    }

    /// Like `solve`, but records each decision as a [`Step`].
    ///
    /// Unlike `solve`, it rejects up front any equation whose symbols are all commutative.
    fn solve_with_explanation(
        &self,
        equation: &Expression,
        variable: &Symbol,
    ) -> (SolverResult, StepByStepExplanation) {
        let mut steps = vec![Step::new(
            "Given Equation",
            format!("Solve {} = 0 for {}", equation, variable.name),
        )];
        if equation.commutativity() == Commutativity::Commutative {
            steps.push(Step::new(
                "Analysis",
                "All symbols are commutative - use standard linear solver instead",
            ));
            return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
        }
        steps.push(Step::new(
            "Analysis",
            "Detected noncommutative symbols (matrix/operator/quaternion)",
        ));
        // Mirror `solve`. Without this guard, `A*X + X*C` could be reported as solved
        // by an expression that still contains the variable.
        if self.variable_appears_multiple_times(equation, variable) {
            steps.push(Step::new(
                "Result",
                format!(
                    "{} appears more than once - only single-occurrence linear forms are supported",
                    variable.name
                ),
            ));
            return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
        }
        let attempt = if let Some((a, b)) = self.detect_left_division(equation, variable) {
            steps.push(Step::new(
                "Pattern",
                format!(
                    "Identified left division: {} * {} = {}",
                    a, variable.name, b
                ),
            ));
            steps.push(Step::new(
                "Solution Method",
                format!(
                    "{} = {}^(-1) * {} (inverse applied on LEFT)",
                    variable.name, a, b
                ),
            ));
            self.solve_left_division(&a, &b)
        } else if let Some((a, b)) = self.detect_right_division(equation, variable) {
            steps.push(Step::new(
                "Pattern",
                format!(
                    "Identified right division: {} * {} = {}",
                    variable.name, a, b
                ),
            ));
            steps.push(Step::new(
                "Solution Method",
                format!(
                    "{} = {} * {}^(-1) (inverse applied on RIGHT)",
                    variable.name, b, a
                ),
            ));
            self.solve_right_division(&a, &b)
        } else {
            steps.push(Step::new(
                "Result",
                "Could not identify left or right division pattern",
            ));
            return (SolverResult::NoSolution, StepByStepExplanation::new(steps));
        };
        let result = match attempt {
            Ok(solution) => {
                steps.push(Step::new(
                    "Result",
                    format!("{} = {}", variable.name, solution),
                ));
                SolverResult::Single(solution)
            }
            Err(e) => {
                steps.push(Step::new("Error", format!("{:?}", e)));
                SolverResult::NoSolution
            }
        };
        (result, StepByStepExplanation::new(steps))
    }

    /// This solver only claims equations that contain at least one noncommutative symbol.
    fn can_solve(&self, equation: &Expression) -> bool {
        equation.commutativity() != Commutativity::Commutative
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Builds `lhs * rhs + (-1) * target`, i.e. `lhs * rhs = target` moved into the
    /// `expr = 0` form the solver consumes.
    fn product_equals(lhs: &Symbol, rhs: &Symbol, target: &Symbol) -> Expression {
        let product = Expression::mul(vec![
            Expression::symbol(lhs.clone()),
            Expression::symbol(rhs.clone()),
        ]);
        let negated_target =
            Expression::mul(vec![Expression::integer(-1), Expression::symbol(target.clone())]);
        Expression::add(vec![product, negated_target])
    }

    #[test]
    fn test_left_division_detection() {
        let (a, x, b) = (symbol!(A; matrix), symbol!(X; matrix), symbol!(B; matrix));
        let equation = product_equals(&a, &x, &b);
        let detected = MatrixEquationSolver::new().detect_left_division(&equation, &x);
        assert!(detected.is_some());
    }

    #[test]
    fn test_right_division_detection() {
        let (a, x, b) = (symbol!(A; matrix), symbol!(X; matrix), symbol!(B; matrix));
        let equation = product_equals(&x, &a, &b);
        let detected = MatrixEquationSolver::new().detect_right_division(&equation, &x);
        assert!(detected.is_some());
    }
}