use std::sync::Arc;
use crate::core::Expression;
use crate::error::MathError;
use crate::matrices::CoreMatrixOps;
/// Fast path for adding two expressions when both are matrices.
///
/// # Returns
///
/// * `None` when at least one operand is not an `Expression::Matrix`, so the
///   caller can fall back to its general addition logic.
/// * `Some(Ok(expr))` when both operands are matrices and the matrix `add`
///   call succeeds; the sum is wrapped in a fresh `Arc` inside
///   `Expression::Matrix`.
/// * `Some(Err(err))` when both operands are matrices but the matrix `add`
///   call reports an error; the error is passed through unchanged.
pub fn try_matrix_add(a: &Expression, b: &Expression) -> Option<Result<Expression, MathError>> {
    // Guard clause: anything other than matrix + matrix is not handled here.
    let (lhs, rhs) = match (a, b) {
        (Expression::Matrix(lhs), Expression::Matrix(rhs)) => (lhs, rhs),
        _ => return None,
    };

    let sum = lhs
        .add(rhs)
        .map(|matrix| Expression::Matrix(Arc::new(matrix)));
    Some(sum)
}
/// Fast path for multiplying two expressions when both are matrices.
///
/// # Returns
///
/// * `None` when at least one operand is not an `Expression::Matrix`, so the
///   caller can fall back to its general multiplication logic.
/// * `Some(Ok(expr))` when both operands are matrices and the matrix
///   `multiply` call succeeds; the product is wrapped in a fresh `Arc` inside
///   `Expression::Matrix`.
/// * `Some(Err(err))` when both operands are matrices but the matrix
///   `multiply` call reports an error; the error is passed through unchanged.
pub fn try_matrix_multiply(
    a: &Expression,
    b: &Expression,
) -> Option<Result<Expression, MathError>> {
    // Guard clause: anything other than matrix * matrix is not handled here.
    let (lhs, rhs) = match (a, b) {
        (Expression::Matrix(lhs), Expression::Matrix(rhs)) => (lhs, rhs),
        _ => return None,
    };

    let product = lhs
        .multiply(rhs)
        .map(|matrix| Expression::Matrix(Arc::new(matrix)));
    Some(product)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr;
    use crate::simplify::Simplify;

    /// Builds the pair of 2x2 integer matrices shared by the "compatible" tests.
    fn two_by_two_operands() -> (Expression, Expression) {
        let lhs = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let rhs = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);
        (lhs, rhs)
    }

    /// Asserts that a fast-path outcome is `Some(Ok(Expression::Matrix(_)))`
    /// whose matrix reports 2x2 dimensions.
    fn assert_two_by_two_matrix(result: Option<Result<Expression, MathError>>) {
        assert!(result.is_some());
        match result {
            None => panic!("Expected Some for matrix operands"),
            Some(Err(_)) => panic!("Expected Ok result for compatible matrices"),
            Some(Ok(Expression::Matrix(matrix))) => {
                let (row_count, col_count) = matrix.dimensions();
                assert_eq!(row_count, 2);
                assert_eq!(col_count, 2);
            }
            Some(Ok(_)) => panic!("Expected matrix result"),
        }
    }

    /// Adding two 2x2 matrices takes the fast path and yields a 2x2 matrix.
    #[test]
    fn test_matrix_add_fast_path_compatible() {
        let (lhs, rhs) = two_by_two_operands();
        assert_two_by_two_matrix(try_matrix_add(&lhs, &rhs));
    }

    /// Adding a 2x2 matrix to a 1x3 matrix takes the fast path but reports a
    /// domain error that mentions both shapes.
    #[test]
    fn test_matrix_add_fast_path_incompatible() {
        let square = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let row = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);
        let result = try_matrix_add(&square, &row);
        assert!(result.is_some());
        if let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = result
        {
            assert_eq!(operation, "matrix_addition");
            assert!(reason.contains("2x2"));
            assert!(reason.contains("1x3"));
        } else {
            panic!("Expected DomainError for incompatible dimensions");
        }
    }

    /// A matrix plus a plain integer is not handled by the fast path.
    #[test]
    fn test_matrix_add_fast_path_non_matrix() {
        let matrix = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let scalar = expr!(42);
        let result = try_matrix_add(&matrix, &scalar);
        assert!(result.is_none());
    }

    /// Multiplying two 2x2 matrices takes the fast path and yields a 2x2 matrix.
    #[test]
    fn test_matrix_multiply_fast_path_compatible() {
        let (lhs, rhs) = two_by_two_operands();
        assert_two_by_two_matrix(try_matrix_multiply(&lhs, &rhs));
    }

    /// Multiplying a 1x2 matrix by a 3x1 matrix takes the fast path but
    /// reports a domain error describing the inner-dimension mismatch.
    #[test]
    fn test_matrix_multiply_fast_path_incompatible() {
        let row = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let column = Expression::matrix(vec![vec![expr!(5)], vec![expr!(6)], vec![expr!(7)]]);
        let result = try_matrix_multiply(&row, &column);
        assert!(result.is_some());
        if let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = result
        {
            assert_eq!(operation, "matrix_multiplication");
            assert!(reason.contains("1x2"));
            assert!(reason.contains("3x1"));
            assert!(reason.contains("2 != 3"));
        } else {
            panic!("Expected DomainError for incompatible dimensions");
        }
    }

    /// A matrix times a bare symbol is not handled by the fast path.
    #[test]
    fn test_matrix_multiply_fast_path_non_matrix() {
        let matrix = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let symbol = expr!(x);
        let result = try_matrix_multiply(&matrix, &symbol);
        assert!(result.is_none());
    }

    /// Matrices with symbolic entries can be added to numeric ones.
    #[test]
    fn test_mixed_symbolic_matrix() {
        let symbolic = Expression::matrix(vec![vec![expr!(x), expr!(y)]]);
        let numeric = Expression::matrix(vec![vec![expr!(2), expr!(3)]]);
        let result = try_matrix_add(&symbolic, &numeric);
        assert!(result.is_some());
        assert!(result.unwrap().is_ok());
    }

    /// Smoke test: simplifying sums and products of square matrices up to
    /// 16x16 must complete without crashing. The results are intentionally
    /// discarded; only successful completion is checked.
    #[test]
    fn test_benchmark_operations_no_stack_overflow() {
        /// Builds a `size` x `size` matrix filled row by row with 1, 2, 3, ...
        fn create_test_matrix(size: usize) -> Expression {
            let rows: Vec<Vec<Expression>> = (0..size)
                .map(|i| {
                    (0..size)
                        .map(|j| Expression::integer((i * size + j + 1) as i64))
                        .collect::<Vec<Expression>>()
                })
                .collect();
            Expression::matrix(rows)
        }

        for size in [2, 3, 4, 8, 16] {
            let operands = vec![create_test_matrix(size), create_test_matrix(size)];
            let _add_result = Expression::add(operands.clone()).simplify();
            let _mul_result = Expression::mul(operands.clone()).simplify();
        }
    }
}