Skip to main content

runmat_runtime/
matrix.rs

1//! Matrix operations for MATLAB-compatible arithmetic
2//!
3//! Implements element-wise and matrix operations following MATLAB semantics.
4
5use crate::builtins::common::linalg;
6use crate::BuiltinResult;
7use runmat_builtins::{Tensor, Value};
8use runmat_macros::runtime_builtin;
9
10/// Matrix addition: C = A + B
11pub fn matrix_add(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
12    if a.rows() != b.rows() || a.cols() != b.cols() {
13        return Err(format!(
14            "Matrix dimensions must agree: {}x{} + {}x{}",
15            a.rows, a.cols, b.rows, b.cols
16        ));
17    }
18
19    let data: Vec<f64> = a
20        .data
21        .iter()
22        .zip(b.data.iter())
23        .map(|(x, y)| x + y)
24        .collect();
25
26    Tensor::new_2d(data, a.rows(), a.cols())
27}
28
29/// Matrix subtraction: C = A - B
30pub fn matrix_sub(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
31    if a.rows() != b.rows() || a.cols() != b.cols() {
32        return Err(format!(
33            "Matrix dimensions must agree: {}x{} - {}x{}",
34            a.rows, a.cols, b.rows, b.cols
35        ));
36    }
37
38    let data: Vec<f64> = a
39        .data
40        .iter()
41        .zip(b.data.iter())
42        .map(|(x, y)| x - y)
43        .collect();
44
45    Tensor::new_2d(data, a.rows(), a.cols())
46}
47
/// Matrix multiplication: C = A * B
///
/// Thin wrapper over the shared real-valued helper; dimension validation
/// presumably happens inside `linalg::matmul_real` — see that helper.
pub fn matrix_mul(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
    linalg::matmul_real(a, b)
}
52
/// GPU-aware matmul entry: if both inputs are GpuTensor handles, call provider; otherwise fall back to CPU.
///
/// Delegates to the `mtimes` builtin evaluator, which owns the actual
/// device-vs-host dispatch logic.
pub async fn value_matmul(
    a: &runmat_builtins::Value,
    b: &runmat_builtins::Value,
) -> BuiltinResult<runmat_builtins::Value> {
    crate::builtins::math::linalg::ops::mtimes::mtimes_eval(a, b).await
}
60
/// Complex matrix multiplication: C = A * B for complex-valued tensors.
///
/// Thin wrapper over the shared complex matmul helper; used internally by
/// `complex_matrix_power`.
fn complex_matrix_mul(
    a: &runmat_builtins::ComplexTensor,
    b: &runmat_builtins::ComplexTensor,
) -> Result<runmat_builtins::ComplexTensor, String> {
    linalg::matmul_complex(a, b)
}
67
/// Scalar multiplication: C = A * s
///
/// Infallible (no shape constraints), so it returns the tensor directly
/// rather than a `Result`. Delegates to the shared real-valued helper.
pub fn matrix_scalar_mul(a: &Tensor, scalar: f64) -> Tensor {
    linalg::scalar_mul_real(a, scalar)
}
72
73/// Matrix power: C = A^n (for positive integer n)
74/// This computes A * A * ... * A (n times) via repeated multiplication
75pub fn matrix_power(a: &Tensor, n: i32) -> Result<Tensor, String> {
76    if a.rows() != a.cols() {
77        return Err(format!(
78            "Matrix must be square for matrix power: {}x{}",
79            a.rows(),
80            a.cols()
81        ));
82    }
83
84    if n < 0 {
85        return Err("Negative matrix powers not supported yet".to_string());
86    }
87
88    if n == 0 {
89        // A^0 = I (identity matrix)
90        return Ok(matrix_eye(a.rows));
91    }
92
93    if n == 1 {
94        // A^1 = A
95        return Ok(a.clone());
96    }
97
98    // Compute A^n via repeated multiplication
99    // Use binary exponentiation for efficiency
100    let mut result = matrix_eye(a.rows());
101    let mut base = a.clone();
102    let mut exp = n as u32;
103
104    while exp > 0 {
105        if exp % 2 == 1 {
106            result = matrix_mul(&result, &base)?;
107        }
108        base = matrix_mul(&base, &base)?;
109        exp /= 2;
110    }
111
112    Ok(result)
113}
114
115/// Complex matrix power: C = A^n (for positive integer n)
116/// Uses binary exponentiation with complex matrix multiply
117pub fn complex_matrix_power(
118    a: &runmat_builtins::ComplexTensor,
119    n: i32,
120) -> Result<runmat_builtins::ComplexTensor, String> {
121    if a.rows != a.cols {
122        return Err(format!(
123            "Matrix must be square for matrix power: {}x{}",
124            a.rows, a.cols
125        ));
126    }
127    if n < 0 {
128        return Err("Negative matrix powers not supported yet".to_string());
129    }
130    if n == 0 {
131        return Ok(complex_matrix_eye(a.rows));
132    }
133    if n == 1 {
134        return Ok(a.clone());
135    }
136    let mut result = complex_matrix_eye(a.rows);
137    let mut base = a.clone();
138    let mut exp = n as u32;
139    while exp > 0 {
140        if exp % 2 == 1 {
141            result = complex_matrix_mul(&result, &base)?;
142        }
143        base = complex_matrix_mul(&base, &base)?;
144        exp /= 2;
145    }
146    Ok(result)
147}
148
149fn complex_matrix_eye(n: usize) -> runmat_builtins::ComplexTensor {
150    let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); n * n];
151    for i in 0..n {
152        data[i * n + i] = (1.0, 0.0);
153    }
154    runmat_builtins::ComplexTensor::new_2d(data, n, n).unwrap()
155}
156
157/// Create identity matrix
158pub fn matrix_eye(n: usize) -> Tensor {
159    let mut data = vec![0.0; n * n];
160    for i in 0..n {
161        data[i * n + i] = 1.0;
162    }
163    Tensor::new_2d(data, n, n).unwrap() // Always valid
164}
165
166// Simple built-in function for testing matrix operations
167#[runtime_builtin(name = "matrix_zeros", builtin_path = "crate::matrix")]
168async fn matrix_zeros_builtin(rows: i32, cols: i32) -> crate::BuiltinResult<Tensor> {
169    if rows < 0 || cols < 0 {
170        return Err(("Matrix dimensions must be non-negative".to_string()).into());
171    }
172    Ok(Tensor::zeros(vec![rows as usize, cols as usize]))
173}
174
175#[runtime_builtin(name = "matrix_ones", builtin_path = "crate::matrix")]
176async fn matrix_ones_builtin(rows: i32, cols: i32) -> crate::BuiltinResult<Tensor> {
177    if rows < 0 || cols < 0 {
178        return Err(("Matrix dimensions must be non-negative".to_string()).into());
179    }
180    Ok(Tensor::ones(vec![rows as usize, cols as usize]))
181}
182
183#[runtime_builtin(name = "matrix_eye", builtin_path = "crate::matrix")]
184async fn matrix_eye_builtin(n: i32) -> crate::BuiltinResult<Tensor> {
185    if n < 0 {
186        return Err(("Matrix size must be non-negative".to_string()).into());
187    }
188    Ok(matrix_eye(n as usize))
189}
190
191#[runtime_builtin(name = "matrix_transpose", builtin_path = "crate::matrix")]
192async fn matrix_transpose_builtin(a: Tensor) -> crate::BuiltinResult<Tensor> {
193    let args = [Value::Tensor(a)];
194    let result = crate::call_builtin_async("transpose", &args).await?;
195    match result {
196        Value::Tensor(tensor) => Ok(tensor),
197        other => Err((format!("matrix_transpose: expected tensor, got {other:?}")).into()),
198    }
199}