use crate::builtins::common::linalg;
use crate::BuiltinResult;
use runmat_builtins::Tensor;
/// Element-wise sum `a + b` of two real matrices with identical dimensions.
///
/// # Errors
///
/// Returns `Err` when the operands' row/column counts or their underlying
/// element counts differ, or when [`Tensor::new_2d`] rejects the result.
pub fn matrix_add(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
    // The element-count comparison protects the `zip` below, which would
    // otherwise silently truncate to the shorter buffer if rows/cols agree
    // but the backing data does not.
    if a.rows() != b.rows() || a.cols() != b.cols() || a.data.len() != b.data.len() {
        return Err(format!(
            "Matrix dimensions must agree: {}x{} + {}x{}",
            a.rows, a.cols, b.rows, b.cols
        ));
    }
    let data: Vec<f64> = a.data.iter().zip(&b.data).map(|(x, y)| x + y).collect();
    Tensor::new_2d(data, a.rows(), a.cols())
}
/// Element-wise difference `a - b` of two real matrices with identical dimensions.
///
/// # Errors
///
/// Returns `Err` when the operands' row/column counts or their underlying
/// element counts differ, or when [`Tensor::new_2d`] rejects the result.
pub fn matrix_sub(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
    // The element-count comparison protects the `zip` below, which would
    // otherwise silently truncate to the shorter buffer if rows/cols agree
    // but the backing data does not.
    if a.rows() != b.rows() || a.cols() != b.cols() || a.data.len() != b.data.len() {
        return Err(format!(
            "Matrix dimensions must agree: {}x{} - {}x{}",
            a.rows, a.cols, b.rows, b.cols
        ));
    }
    let data: Vec<f64> = a.data.iter().zip(&b.data).map(|(x, y)| x - y).collect();
    Tensor::new_2d(data, a.rows(), a.cols())
}
pub fn matrix_mul(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
linalg::matmul_real(a, b)
}
/// Multiplies two runtime values by delegating to the `mtimes` evaluator.
///
/// This is a thin async adapter: both operands are passed through untouched
/// and the evaluator's result (success or error) is returned as-is.
pub async fn value_matmul(
    a: &runmat_builtins::Value,
    b: &runmat_builtins::Value,
) -> BuiltinResult<runmat_builtins::Value> {
    use crate::builtins::math::linalg::ops::mtimes::mtimes_eval;
    mtimes_eval(a, b).await
}
fn complex_matrix_mul(
a: &runmat_builtins::ComplexTensor,
b: &runmat_builtins::ComplexTensor,
) -> Result<runmat_builtins::ComplexTensor, String> {
linalg::matmul_complex(a, b)
}
pub fn matrix_scalar_mul(a: &Tensor, scalar: f64) -> Tensor {
linalg::scalar_mul_real(a, scalar)
}
/// Raises a square real matrix to a non-negative integer power `a^n`.
///
/// Uses exponentiation by squaring, so at most `2 * floor(log2(n))` matrix
/// products are performed. `n == 0` yields the identity of matching size and
/// `n == 1` yields a copy of `a`.
///
/// # Errors
///
/// Returns `Err` if `a` is not square, if `n` is negative (not supported yet),
/// or if any underlying [`matrix_mul`] call fails.
pub fn matrix_power(a: &Tensor, n: i32) -> Result<Tensor, String> {
    if a.rows() != a.cols() {
        return Err(format!(
            "Matrix must be square for matrix power: {}x{}",
            a.rows(),
            a.cols()
        ));
    }
    if n < 0 {
        return Err("Negative matrix powers not supported yet".to_string());
    }
    if n == 0 {
        return Ok(matrix_eye(a.rows()));
    }
    // `n > 0` here, so this is a lossless conversion to `u32`.
    let mut exp = n.unsigned_abs();
    let mut base = a.clone();
    // Consume the low zero bits first so that the lowest set bit seeds the
    // result directly. This avoids starting from an explicit identity matrix,
    // which would cost one extra product.
    while exp & 1 == 0 {
        base = matrix_mul(&base, &base)?;
        exp >>= 1;
    }
    let mut result = base.clone();
    exp >>= 1;
    // Square only while more bits remain; this skips the trailing, unused
    // squaring a plain square-and-multiply loop performs.
    while exp > 0 {
        base = matrix_mul(&base, &base)?;
        if exp & 1 == 1 {
            result = matrix_mul(&result, &base)?;
        }
        exp >>= 1;
    }
    Ok(result)
}
/// Raises a square complex matrix to a non-negative integer power `a^n`.
///
/// Uses exponentiation by squaring, so at most `2 * floor(log2(n))` matrix
/// products are performed. `n == 0` yields the complex identity of matching
/// size and `n == 1` yields a copy of `a`.
///
/// # Errors
///
/// Returns `Err` if `a` is not square, if `n` is negative (not supported yet),
/// or if any underlying complex matrix product fails.
pub fn complex_matrix_power(
    a: &runmat_builtins::ComplexTensor,
    n: i32,
) -> Result<runmat_builtins::ComplexTensor, String> {
    if a.rows != a.cols {
        return Err(format!(
            "Matrix must be square for matrix power: {}x{}",
            a.rows, a.cols
        ));
    }
    if n < 0 {
        return Err("Negative matrix powers not supported yet".to_string());
    }
    if n == 0 {
        return Ok(complex_matrix_eye(a.rows));
    }
    // `n > 0` here, so this is a lossless conversion to `u32`.
    let mut exp = n.unsigned_abs();
    let mut base = a.clone();
    // Consume the low zero bits first so that the lowest set bit seeds the
    // result directly. This avoids starting from an explicit identity matrix,
    // which would cost one extra product.
    while exp & 1 == 0 {
        base = complex_matrix_mul(&base, &base)?;
        exp >>= 1;
    }
    let mut result = base.clone();
    exp >>= 1;
    // Square only while more bits remain; this skips the trailing, unused
    // squaring a plain square-and-multiply loop performs.
    while exp > 0 {
        base = complex_matrix_mul(&base, &base)?;
        if exp & 1 == 1 {
            result = complex_matrix_mul(&result, &base)?;
        }
        exp >>= 1;
    }
    Ok(result)
}
/// Builds an `n x n` complex identity matrix: `(1.0, 0.0)` on the diagonal and
/// `(0.0, 0.0)` everywhere else.
///
/// # Panics
///
/// Panics if `ComplexTensor::new_2d` rejects the buffer. An `n * n` buffer for
/// an `n x n` shape is expected to always be accepted.
fn complex_matrix_eye(n: usize) -> runmat_builtins::ComplexTensor {
    let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); n * n];
    // In a flattened square buffer the diagonal entries sit every `n + 1`
    // slots, whichever of row- or column-major order is used.
    data.iter_mut()
        .step_by(n + 1)
        .for_each(|z| *z = (1.0, 0.0));
    runmat_builtins::ComplexTensor::new_2d(data, n, n)
        .expect("an n*n buffer always fits an n x n shape")
}
/// Builds an `n x n` real identity matrix: `1.0` on the diagonal and `0.0`
/// everywhere else.
///
/// # Panics
///
/// Panics if [`Tensor::new_2d`] rejects the buffer. An `n * n` buffer for an
/// `n x n` shape is expected to always be accepted.
pub fn matrix_eye(n: usize) -> Tensor {
    let mut data = vec![0.0; n * n];
    // In a flattened square buffer the diagonal entries sit every `n + 1`
    // slots, whichever of row- or column-major order is used.
    data.iter_mut().step_by(n + 1).for_each(|x| *x = 1.0);
    Tensor::new_2d(data, n, n).expect("an n*n buffer always fits an n x n shape")
}