use std::sync::Arc;
use super::{CoreMatrixOps, Matrix};
use crate::core::Expression;
use crate::core::Number;
use crate::simplify::Simplify;
/// Matrix-level operations on symbolic expressions.
///
/// Methods that produce a new value return an [`Expression`]; predicates
/// return `bool`.
pub trait MatrixOperations {
    // ---- arithmetic -------------------------------------------------------

    /// Sum of `self` and `other`.
    fn matrix_add(&self, other: &Expression) -> Expression;

    /// Difference `self - other`.
    fn matrix_subtract(&self, other: &Expression) -> Expression;

    /// Matrix product `self * other`.
    fn matrix_multiply(&self, other: &Expression) -> Expression;

    /// Product of `self` with the scalar expression `scalar`.
    fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression;

    /// `self` raised to the power `exponent`.
    fn matrix_power(&self, exponent: &Expression) -> Expression;

    // ---- unary transforms -------------------------------------------------

    /// Transpose of `self`.
    fn matrix_transpose(&self) -> Expression;

    /// Inverse of `self`.
    fn matrix_inverse(&self) -> Expression;

    /// A simplified/optimized representation of `self`.
    fn simplify_matrix(&self) -> Expression;

    // ---- scalar-valued ----------------------------------------------------

    /// Determinant of `self`.
    fn matrix_determinant(&self) -> Expression;

    /// Trace of `self`.
    fn matrix_trace(&self) -> Expression;

    // ---- predicates -------------------------------------------------------

    /// Whether `self` is an identity matrix.
    fn is_identity_matrix(&self) -> bool;

    /// Whether `self` is a zero matrix.
    fn is_zero_matrix(&self) -> bool;

    /// Whether `self` is a diagonal matrix.
    fn is_diagonal(&self) -> bool;
}
impl Expression {
    /// Returns `(rows, cols)` when `self` is an `Expression::Matrix`, and
    /// `None` for every other kind of expression.
    pub fn matrix_dimensions(&self) -> Option<(usize, usize)> {
        if let Expression::Matrix(inner) = self {
            Some(inner.dimensions())
        } else {
            None
        }
    }

    /// Returns `true` exactly when `self` is an `Expression::Matrix` node.
    pub fn is_matrix(&self) -> bool {
        matches!(*self, Expression::Matrix(..))
    }
}
impl MatrixOperations for Expression {
fn matrix_add(&self, other: &Expression) -> Expression {
match (self, other) {
(Expression::Matrix(a), Expression::Matrix(b)) => match a.add(b) {
Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
Err(_) => Expression::function("undefined", vec![]),
},
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_subtract(&self, other: &Expression) -> Expression {
match (self, other) {
(Expression::Matrix(a), Expression::Matrix(b)) => {
let neg_b = b.scalar_multiply(&Expression::integer(-1));
match a.add(&neg_b) {
Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
Err(_) => Expression::function("undefined", vec![]),
}
}
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_multiply(&self, other: &Expression) -> Expression {
match (self, other) {
(Expression::Matrix(a), Expression::Matrix(b)) => match a.multiply(b) {
Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
Err(_) => Expression::function("undefined", vec![]),
},
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression {
match self {
Expression::Matrix(matrix) => {
let result_matrix = matrix.scalar_multiply(scalar);
let result = Expression::Matrix(Arc::new(result_matrix));
result.simplify()
}
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_determinant(&self) -> Expression {
match self {
Expression::Matrix(matrix) => matrix
.determinant()
.unwrap_or_else(|_| Expression::function("undefined", vec![])),
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_transpose(&self) -> Expression {
match self {
Expression::Matrix(matrix) => {
let transposed = matrix.transpose();
Expression::Matrix(Arc::new(transposed))
}
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_inverse(&self) -> Expression {
match self {
Expression::Matrix(matrix) => {
let inverse = matrix.inverse();
Expression::Matrix(Arc::new(inverse))
}
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_trace(&self) -> Expression {
match self {
Expression::Matrix(matrix) => matrix.trace(),
_ => Expression::function("undefined", vec![]),
}
}
fn matrix_power(&self, exponent: &Expression) -> Expression {
if !self.is_matrix() {
return Expression::function("undefined", vec![]);
}
if let Expression::Number(Number::Integer(n)) = exponent {
if *n < 0 {
let inv = self.matrix_inverse();
return inv.matrix_power(&Expression::integer(-n));
}
if *n == 0 {
if let Some((rows, cols)) = self.matrix_dimensions() {
if rows == cols {
return Expression::identity_matrix(rows);
}
}
return Expression::function("undefined", vec![]);
}
if *n == 1 {
return self.clone();
}
let mut result = self.clone();
for _ in 1..*n {
result = result.matrix_multiply(self);
}
result
} else {
Expression::function("undefined", vec![])
}
}
fn is_identity_matrix(&self) -> bool {
match self {
Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Identity(_)),
_ => false,
}
}
fn is_zero_matrix(&self) -> bool {
match self {
Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Zero(_)),
_ => false,
}
}
fn is_diagonal(&self) -> bool {
match self {
Expression::Matrix(matrix) => {
matches!(
matrix.as_ref(),
Matrix::Diagonal(_) | Matrix::Identity(_) | Matrix::Scalar(_)
)
}
_ => false,
}
}
fn simplify_matrix(&self) -> Expression {
match self {
Expression::Matrix(matrix) => {
let optimized = matrix.as_ref().clone().optimize();
Expression::Matrix(Arc::new(optimized))
}
_ => self.clone(),
}
}
}