1use crate::builtins::common::linalg;
6use runmat_builtins::Tensor;
7use runmat_macros::runtime_builtin;
8
9pub fn matrix_add(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
11 if a.rows() != b.rows() || a.cols() != b.cols() {
12 return Err(format!(
13 "Matrix dimensions must agree: {}x{} + {}x{}",
14 a.rows, a.cols, b.rows, b.cols
15 ));
16 }
17
18 let data: Vec<f64> = a
19 .data
20 .iter()
21 .zip(b.data.iter())
22 .map(|(x, y)| x + y)
23 .collect();
24
25 Tensor::new_2d(data, a.rows(), a.cols())
26}
27
28pub fn matrix_sub(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
30 if a.rows() != b.rows() || a.cols() != b.cols() {
31 return Err(format!(
32 "Matrix dimensions must agree: {}x{} - {}x{}",
33 a.rows, a.cols, b.rows, b.cols
34 ));
35 }
36
37 let data: Vec<f64> = a
38 .data
39 .iter()
40 .zip(b.data.iter())
41 .map(|(x, y)| x - y)
42 .collect();
43
44 Tensor::new_2d(data, a.rows(), a.cols())
45}
46
47pub fn matrix_mul(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
49 linalg::matmul_real(a, b)
50}
51
52pub fn value_matmul(
54 a: &runmat_builtins::Value,
55 b: &runmat_builtins::Value,
56) -> Result<runmat_builtins::Value, String> {
57 crate::builtins::math::linalg::ops::mtimes::mtimes_eval(a, b)
58}
59
60fn complex_matrix_mul(
61 a: &runmat_builtins::ComplexTensor,
62 b: &runmat_builtins::ComplexTensor,
63) -> Result<runmat_builtins::ComplexTensor, String> {
64 linalg::matmul_complex(a, b)
65}
66
67pub fn matrix_scalar_mul(a: &Tensor, scalar: f64) -> Tensor {
69 linalg::scalar_mul_real(a, scalar)
70}
71
72pub fn matrix_transpose(a: &Tensor) -> Tensor {
74 let mut data = vec![0.0; a.rows() * a.cols()];
75 for i in 0..a.rows() {
76 for j in 0..a.cols() {
77 data[j * a.rows() + i] = a.data[i + j * a.rows()];
79 }
80 }
81 Tensor::new_2d(data, a.cols(), a.rows()).unwrap() }
83
84pub fn matrix_power(a: &Tensor, n: i32) -> Result<Tensor, String> {
87 if a.rows() != a.cols() {
88 return Err(format!(
89 "Matrix must be square for matrix power: {}x{}",
90 a.rows(),
91 a.cols()
92 ));
93 }
94
95 if n < 0 {
96 return Err("Negative matrix powers not supported yet".to_string());
97 }
98
99 if n == 0 {
100 return Ok(matrix_eye(a.rows));
102 }
103
104 if n == 1 {
105 return Ok(a.clone());
107 }
108
109 let mut result = matrix_eye(a.rows());
112 let mut base = a.clone();
113 let mut exp = n as u32;
114
115 while exp > 0 {
116 if exp % 2 == 1 {
117 result = matrix_mul(&result, &base)?;
118 }
119 base = matrix_mul(&base, &base)?;
120 exp /= 2;
121 }
122
123 Ok(result)
124}
125
126pub fn complex_matrix_power(
129 a: &runmat_builtins::ComplexTensor,
130 n: i32,
131) -> Result<runmat_builtins::ComplexTensor, String> {
132 if a.rows != a.cols {
133 return Err(format!(
134 "Matrix must be square for matrix power: {}x{}",
135 a.rows, a.cols
136 ));
137 }
138 if n < 0 {
139 return Err("Negative matrix powers not supported yet".to_string());
140 }
141 if n == 0 {
142 return Ok(complex_matrix_eye(a.rows));
143 }
144 if n == 1 {
145 return Ok(a.clone());
146 }
147 let mut result = complex_matrix_eye(a.rows);
148 let mut base = a.clone();
149 let mut exp = n as u32;
150 while exp > 0 {
151 if exp % 2 == 1 {
152 result = complex_matrix_mul(&result, &base)?;
153 }
154 base = complex_matrix_mul(&base, &base)?;
155 exp /= 2;
156 }
157 Ok(result)
158}
159
160fn complex_matrix_eye(n: usize) -> runmat_builtins::ComplexTensor {
161 let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); n * n];
162 for i in 0..n {
163 data[i * n + i] = (1.0, 0.0);
164 }
165 runmat_builtins::ComplexTensor::new_2d(data, n, n).unwrap()
166}
167
168pub fn matrix_eye(n: usize) -> Tensor {
170 let mut data = vec![0.0; n * n];
171 for i in 0..n {
172 data[i * n + i] = 1.0;
173 }
174 Tensor::new_2d(data, n, n).unwrap() }
176
177#[runtime_builtin(name = "matrix_zeros")]
179fn matrix_zeros_builtin(rows: i32, cols: i32) -> Result<Tensor, String> {
180 if rows < 0 || cols < 0 {
181 return Err("Matrix dimensions must be non-negative".to_string());
182 }
183 Ok(Tensor::zeros(vec![rows as usize, cols as usize]))
184}
185
186#[runtime_builtin(name = "matrix_ones")]
187fn matrix_ones_builtin(rows: i32, cols: i32) -> Result<Tensor, String> {
188 if rows < 0 || cols < 0 {
189 return Err("Matrix dimensions must be non-negative".to_string());
190 }
191 Ok(Tensor::ones(vec![rows as usize, cols as usize]))
192}
193
194#[runtime_builtin(name = "matrix_eye")]
195fn matrix_eye_builtin(n: i32) -> Result<Tensor, String> {
196 if n < 0 {
197 return Err("Matrix size must be non-negative".to_string());
198 }
199 Ok(matrix_eye(n as usize))
200}
201
202#[runtime_builtin(name = "matrix_transpose")]
203fn matrix_transpose_builtin(a: Tensor) -> Result<Tensor, String> {
204 Ok(matrix_transpose(&a))
205}