runmat_runtime/
matrix.rs

//! Matrix operations for MATLAB-compatible arithmetic
//!
//! Implements element-wise and matrix operations following MATLAB semantics.
4
5use crate::builtins::common::linalg;
6use runmat_builtins::Tensor;
7use runmat_macros::runtime_builtin;
8
9/// Matrix addition: C = A + B
10pub fn matrix_add(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
11    if a.rows() != b.rows() || a.cols() != b.cols() {
12        return Err(format!(
13            "Matrix dimensions must agree: {}x{} + {}x{}",
14            a.rows, a.cols, b.rows, b.cols
15        ));
16    }
17
18    let data: Vec<f64> = a
19        .data
20        .iter()
21        .zip(b.data.iter())
22        .map(|(x, y)| x + y)
23        .collect();
24
25    Tensor::new_2d(data, a.rows(), a.cols())
26}
27
28/// Matrix subtraction: C = A - B
29pub fn matrix_sub(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
30    if a.rows() != b.rows() || a.cols() != b.cols() {
31        return Err(format!(
32            "Matrix dimensions must agree: {}x{} - {}x{}",
33            a.rows, a.cols, b.rows, b.cols
34        ));
35    }
36
37    let data: Vec<f64> = a
38        .data
39        .iter()
40        .zip(b.data.iter())
41        .map(|(x, y)| x - y)
42        .collect();
43
44    Tensor::new_2d(data, a.rows(), a.cols())
45}
46
47/// Matrix multiplication: C = A * B
48pub fn matrix_mul(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
49    linalg::matmul_real(a, b)
50}
51
52/// GPU-aware matmul entry: if both inputs are GpuTensor handles, call provider; otherwise fall back to CPU.
53pub fn value_matmul(
54    a: &runmat_builtins::Value,
55    b: &runmat_builtins::Value,
56) -> Result<runmat_builtins::Value, String> {
57    crate::builtins::math::linalg::ops::mtimes::mtimes_eval(a, b)
58}
59
60fn complex_matrix_mul(
61    a: &runmat_builtins::ComplexTensor,
62    b: &runmat_builtins::ComplexTensor,
63) -> Result<runmat_builtins::ComplexTensor, String> {
64    linalg::matmul_complex(a, b)
65}
66
67/// Scalar multiplication: C = A * s
68pub fn matrix_scalar_mul(a: &Tensor, scalar: f64) -> Tensor {
69    linalg::scalar_mul_real(a, scalar)
70}
71
72/// Matrix transpose: C = A'
73pub fn matrix_transpose(a: &Tensor) -> Tensor {
74    let mut data = vec![0.0; a.rows() * a.cols()];
75    for i in 0..a.rows() {
76        for j in 0..a.cols() {
77            // dst(j,i) = src(i,j)
78            data[j * a.rows() + i] = a.data[i + j * a.rows()];
79        }
80    }
81    Tensor::new_2d(data, a.cols(), a.rows()).unwrap() // Always valid
82}
83
84/// Matrix power: C = A^n (for positive integer n)
85/// This computes A * A * ... * A (n times) via repeated multiplication
86pub fn matrix_power(a: &Tensor, n: i32) -> Result<Tensor, String> {
87    if a.rows() != a.cols() {
88        return Err(format!(
89            "Matrix must be square for matrix power: {}x{}",
90            a.rows(),
91            a.cols()
92        ));
93    }
94
95    if n < 0 {
96        return Err("Negative matrix powers not supported yet".to_string());
97    }
98
99    if n == 0 {
100        // A^0 = I (identity matrix)
101        return Ok(matrix_eye(a.rows));
102    }
103
104    if n == 1 {
105        // A^1 = A
106        return Ok(a.clone());
107    }
108
109    // Compute A^n via repeated multiplication
110    // Use binary exponentiation for efficiency
111    let mut result = matrix_eye(a.rows());
112    let mut base = a.clone();
113    let mut exp = n as u32;
114
115    while exp > 0 {
116        if exp % 2 == 1 {
117            result = matrix_mul(&result, &base)?;
118        }
119        base = matrix_mul(&base, &base)?;
120        exp /= 2;
121    }
122
123    Ok(result)
124}
125
126/// Complex matrix power: C = A^n (for positive integer n)
127/// Uses binary exponentiation with complex matrix multiply
128pub fn complex_matrix_power(
129    a: &runmat_builtins::ComplexTensor,
130    n: i32,
131) -> Result<runmat_builtins::ComplexTensor, String> {
132    if a.rows != a.cols {
133        return Err(format!(
134            "Matrix must be square for matrix power: {}x{}",
135            a.rows, a.cols
136        ));
137    }
138    if n < 0 {
139        return Err("Negative matrix powers not supported yet".to_string());
140    }
141    if n == 0 {
142        return Ok(complex_matrix_eye(a.rows));
143    }
144    if n == 1 {
145        return Ok(a.clone());
146    }
147    let mut result = complex_matrix_eye(a.rows);
148    let mut base = a.clone();
149    let mut exp = n as u32;
150    while exp > 0 {
151        if exp % 2 == 1 {
152            result = complex_matrix_mul(&result, &base)?;
153        }
154        base = complex_matrix_mul(&base, &base)?;
155        exp /= 2;
156    }
157    Ok(result)
158}
159
160fn complex_matrix_eye(n: usize) -> runmat_builtins::ComplexTensor {
161    let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); n * n];
162    for i in 0..n {
163        data[i * n + i] = (1.0, 0.0);
164    }
165    runmat_builtins::ComplexTensor::new_2d(data, n, n).unwrap()
166}
167
168/// Create identity matrix
169pub fn matrix_eye(n: usize) -> Tensor {
170    let mut data = vec![0.0; n * n];
171    for i in 0..n {
172        data[i * n + i] = 1.0;
173    }
174    Tensor::new_2d(data, n, n).unwrap() // Always valid
175}
176
177// Simple built-in function for testing matrix operations
178#[runtime_builtin(name = "matrix_zeros")]
179fn matrix_zeros_builtin(rows: i32, cols: i32) -> Result<Tensor, String> {
180    if rows < 0 || cols < 0 {
181        return Err("Matrix dimensions must be non-negative".to_string());
182    }
183    Ok(Tensor::zeros(vec![rows as usize, cols as usize]))
184}
185
186#[runtime_builtin(name = "matrix_ones")]
187fn matrix_ones_builtin(rows: i32, cols: i32) -> Result<Tensor, String> {
188    if rows < 0 || cols < 0 {
189        return Err("Matrix dimensions must be non-negative".to_string());
190    }
191    Ok(Tensor::ones(vec![rows as usize, cols as usize]))
192}
193
194#[runtime_builtin(name = "matrix_eye")]
195fn matrix_eye_builtin(n: i32) -> Result<Tensor, String> {
196    if n < 0 {
197        return Err("Matrix size must be non-negative".to_string());
198    }
199    Ok(matrix_eye(n as usize))
200}
201
202#[runtime_builtin(name = "matrix_transpose")]
203fn matrix_transpose_builtin(a: Tensor) -> Result<Tensor, String> {
204    Ok(matrix_transpose(&a))
205}