#[cfg(test)]
mod tests {
    //! Unit tests for the linear-algebra operations of `BasicComputeEngine`:
    //! matrix add / multiply / determinant / inverse / transpose / trace and
    //! vector dot / cross / norm, including the dimension-mismatch error paths.

    use crate::core::{Expression, Number, UnaryOperator};
    use crate::engine::{ComputeEngine, compute::BasicComputeEngine};
    use std::collections::HashMap;

    /// Builds a numeric literal expression; `num(3)` is shorthand for
    /// `Expression::number(Number::from(3))`.
    ///
    /// Generic over the literal type so integer literals resolve against
    /// `Number`'s `From` impls exactly as a direct `Number::from(..)` call would.
    fn num<T>(value: T) -> Expression
    where
        Number: From<T>,
    {
        Expression::number(Number::from(value))
    }

    /// The 2x2 matrix `[[1, 2], [3, 4]]` (determinant -2, trace 5).
    fn sample_matrix_a() -> Expression {
        Expression::matrix(vec![vec![num(1), num(2)], vec![num(3), num(4)]]).unwrap()
    }

    /// The 2x2 matrix `[[5, 6], [7, 8]]`, used as the right-hand operand.
    fn sample_matrix_b() -> Expression {
        Expression::matrix(vec![vec![num(5), num(6)], vec![num(7), num(8)]]).unwrap()
    }

    /// Evaluates `expr` with no variable bindings and asserts the value equals `expected`.
    ///
    /// Panics (failing the test) if evaluation returns an error or the value differs.
    fn assert_evaluates_to(engine: &BasicComputeEngine, expr: &Expression, expected: &Number) {
        let vars = HashMap::new();
        let actual = engine.evaluate(expr, &vars).unwrap();
        assert_eq!(actual, *expected);
    }

    /// Asserts that `result` is an `Expression::Matrix` with exactly the shape of
    /// `expected` and that every entry evaluates to the corresponding number.
    ///
    /// Panics (failing the test) if `result` is not a matrix, a dimension differs,
    /// or any entry fails to evaluate or has the wrong value.
    fn assert_matrix_evaluates_to(
        engine: &BasicComputeEngine,
        result: Expression,
        expected: &[&[Number]],
    ) {
        match result {
            Expression::Matrix(rows) => {
                assert_eq!(rows.len(), expected.len());
                for (row, expected_row) in rows.iter().zip(expected) {
                    assert_eq!(row.len(), expected_row.len());
                    for (cell, want) in row.iter().zip(expected_row.iter()) {
                        assert_evaluates_to(engine, cell, want);
                    }
                }
            }
            _ => panic!("结果不是矩阵类型"),
        }
    }

    #[test]
    fn test_matrix_add_basic() {
        let engine = BasicComputeEngine::new();
        let sum = engine
            .matrix_add(&sample_matrix_a(), &sample_matrix_b())
            .unwrap();
        // Element-wise: [[1+5, 2+6], [3+7, 4+8]].
        assert_matrix_evaluates_to(
            &engine,
            sum,
            &[
                &[Number::from(6), Number::from(8)],
                &[Number::from(10), Number::from(12)],
            ],
        );
    }

    #[test]
    fn test_matrix_multiply_basic() {
        let engine = BasicComputeEngine::new();
        let product = engine
            .matrix_multiply(&sample_matrix_a(), &sample_matrix_b())
            .unwrap();
        // [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]].
        assert_matrix_evaluates_to(
            &engine,
            product,
            &[
                &[Number::from(19), Number::from(22)],
                &[Number::from(43), Number::from(50)],
            ],
        );
    }

    #[test]
    fn test_matrix_determinant_2x2() {
        let engine = BasicComputeEngine::new();
        let det = engine.matrix_determinant(&sample_matrix_a()).unwrap();
        // 1*4 - 2*3.
        assert_evaluates_to(&engine, &det, &Number::from(-2));
    }

    #[test]
    fn test_matrix_determinant_3x3() {
        let engine = BasicComputeEngine::new();
        let matrix = Expression::matrix(vec![
            vec![num(1), num(2), num(3)],
            vec![num(4), num(5), num(6)],
            vec![num(7), num(8), num(9)],
        ])
        .unwrap();
        let det = engine.matrix_determinant(&matrix).unwrap();
        // The rows are linearly dependent (row3 = 2*row2 - row1), so the matrix is singular.
        assert_evaluates_to(&engine, &det, &Number::from(0));
    }

    #[test]
    fn test_matrix_inverse_2x2() {
        let engine = BasicComputeEngine::new();
        let matrix = sample_matrix_a();
        let inverse = engine.matrix_inverse(&matrix).unwrap();

        // Shape check, and every entry must evaluate to a concrete number.
        match &inverse {
            Expression::Matrix(elements) => {
                assert_eq!(elements.len(), 2);
                assert_eq!(elements[0].len(), 2);
                let vars = HashMap::new();
                for cell in elements.iter().flatten() {
                    assert!(engine.evaluate(cell, &vars).is_ok());
                }
            }
            _ => panic!("结果不是矩阵类型"),
        }

        // The inverse of this matrix has fractional entries ([[-2, 1], [3/2, -1/2]]),
        // so verify the defining property instead: A * A^-1 must be the identity.
        let product = engine.matrix_multiply(&matrix, &inverse).unwrap();
        assert_matrix_evaluates_to(
            &engine,
            product,
            &[
                &[Number::from(1), Number::from(0)],
                &[Number::from(0), Number::from(1)],
            ],
        );
    }

    #[test]
    fn test_vector_dot_product() {
        let engine = BasicComputeEngine::new();
        let vector_a = Expression::vector(vec![num(1), num(2), num(3)]).unwrap();
        let vector_b = Expression::vector(vec![num(4), num(5), num(6)]).unwrap();
        let dot = engine.vector_dot(&vector_a, &vector_b).unwrap();
        // 1*4 + 2*5 + 3*6.
        assert_evaluates_to(&engine, &dot, &Number::from(32));
    }

    #[test]
    fn test_vector_cross_product() {
        let engine = BasicComputeEngine::new();
        let unit_i = Expression::vector(vec![num(1), num(0), num(0)]).unwrap();
        let unit_j = Expression::vector(vec![num(0), num(1), num(0)]).unwrap();
        // i x j = k.
        match engine.vector_cross(&unit_i, &unit_j).unwrap() {
            Expression::Vector(elements) => {
                assert_eq!(elements.len(), 3);
                assert_evaluates_to(&engine, &elements[0], &Number::from(0));
                assert_evaluates_to(&engine, &elements[1], &Number::from(0));
                assert_evaluates_to(&engine, &elements[2], &Number::from(1));
            }
            _ => panic!("结果不是向量类型"),
        }
    }

    #[test]
    fn test_vector_norm() {
        let engine = BasicComputeEngine::new();
        let vector = Expression::vector(vec![num(3), num(4), num(0)]).unwrap();
        let norm = engine.vector_norm(&vector).unwrap();
        // NOTE(review): only the shape of the result is checked (a unary operation,
        // presumably a square root); the numeric value 5 is not asserted.
        assert!(matches!(norm, Expression::UnaryOp { .. }));
    }

    #[test]
    fn test_matrix_transpose() {
        let engine = BasicComputeEngine::new();
        let vars = HashMap::new();

        // `evaluate` yields a scalar `Number`, so evaluating the transpose of a
        // 2x3 matrix is expected to be rejected.
        let matrix_2x3 = Expression::matrix(vec![
            vec![num(1), num(2), num(3)],
            vec![num(4), num(5), num(6)],
        ])
        .unwrap();
        let transpose_2x3 = Expression::unary_op(UnaryOperator::Transpose, matrix_2x3);
        assert!(engine.evaluate(&transpose_2x3, &vars).is_err());

        // The transpose of a 1x1 matrix evaluates to its single entry.
        let matrix_1x1 = Expression::matrix(vec![vec![num(42)]]).unwrap();
        let transpose_1x1 = Expression::unary_op(UnaryOperator::Transpose, matrix_1x1);
        assert_evaluates_to(&engine, &transpose_1x1, &Number::from(42));

        // Simplifying the transpose of a 2x1 matrix may either produce the 1x2
        // matrix [[1, 2]] or leave the transpose node untouched; both are accepted.
        let matrix_2x1 = Expression::matrix(vec![vec![num(1)], vec![num(2)]]).unwrap();
        let transpose_2x1 = Expression::unary_op(UnaryOperator::Transpose, matrix_2x1);
        let simplified = engine.simplify(&transpose_2x1).unwrap();
        if matches!(simplified, Expression::Matrix(_)) {
            assert_matrix_evaluates_to(
                &engine,
                simplified,
                &[&[Number::from(1), Number::from(2)]],
            );
        } else {
            assert!(matches!(simplified, Expression::UnaryOp { .. }));
        }
    }

    #[test]
    fn test_matrix_trace() {
        let engine = BasicComputeEngine::new();
        let trace_expr = Expression::unary_op(UnaryOperator::Trace, sample_matrix_a());
        // 1 + 4.
        assert_evaluates_to(&engine, &trace_expr, &Number::from(5));
    }

    #[test]
    fn test_matrix_dimension_mismatch() {
        let engine = BasicComputeEngine::new();
        // 1x2 and 3x1: shapes differ (no addition) and the inner dimensions
        // 2 vs 3 differ (no multiplication).
        let matrix_1x2 = Expression::matrix(vec![vec![num(1), num(2)]]).unwrap();
        let matrix_3x1 =
            Expression::matrix(vec![vec![num(3)], vec![num(4)], vec![num(5)]]).unwrap();

        assert!(engine.matrix_add(&matrix_1x2, &matrix_3x1).is_err());
        assert!(engine.matrix_multiply(&matrix_1x2, &matrix_3x1).is_err());
    }

    #[test]
    fn test_vector_dimension_mismatch() {
        let engine = BasicComputeEngine::new();
        let vector_2 = Expression::vector(vec![num(1), num(2)]).unwrap();
        let vector_3 = Expression::vector(vec![num(3), num(4), num(5)]).unwrap();

        assert!(engine.vector_dot(&vector_2, &vector_3).is_err());
        // NOTE(review): a 2-vector is not a valid cross-product operand on its own,
        // so this assertion does not isolate the length-mismatch check.
        assert!(engine.vector_cross(&vector_2, &vector_3).is_err());
    }

    #[test]
    fn test_non_square_matrix_operations() {
        let engine = BasicComputeEngine::new();
        // Determinant, inverse and trace are only defined for square matrices.
        let matrix_2x3 = Expression::matrix(vec![
            vec![num(1), num(2), num(3)],
            vec![num(4), num(5), num(6)],
        ])
        .unwrap();

        assert!(engine.matrix_determinant(&matrix_2x3).is_err());
        assert!(engine.matrix_inverse(&matrix_2x3).is_err());

        let trace_expr = Expression::unary_op(UnaryOperator::Trace, matrix_2x3);
        let vars = HashMap::new();
        assert!(engine.evaluate(&trace_expr, &vars).is_err());
    }
}