1use runmat_builtins::Tensor;
6use runmat_macros::runtime_builtin;
7
8pub fn matrix_add(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
10 if a.rows() != b.rows() || a.cols() != b.cols() {
11 return Err(format!(
12 "Matrix dimensions must agree: {}x{} + {}x{}",
13 a.rows, a.cols, b.rows, b.cols
14 ));
15 }
16
17 let data: Vec<f64> = a
18 .data
19 .iter()
20 .zip(b.data.iter())
21 .map(|(x, y)| x + y)
22 .collect();
23
24 Tensor::new_2d(data, a.rows(), a.cols())
25}
26
27pub fn matrix_sub(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
29 if a.rows() != b.rows() || a.cols() != b.cols() {
30 return Err(format!(
31 "Matrix dimensions must agree: {}x{} - {}x{}",
32 a.rows, a.cols, b.rows, b.cols
33 ));
34 }
35
36 let data: Vec<f64> = a
37 .data
38 .iter()
39 .zip(b.data.iter())
40 .map(|(x, y)| x - y)
41 .collect();
42
43 Tensor::new_2d(data, a.rows(), a.cols())
44}
45
46pub fn matrix_mul(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
48 if a.cols() != b.rows() {
51 return Err(format!(
52 "Inner matrix dimensions must agree: {}x{} * {}x{}",
53 a.rows, a.cols, b.rows, b.cols
54 ));
55 }
56
57 let rows = a.rows();
58 let cols = b.cols();
59 let mut data = vec![0.0; rows * cols];
60
61 for i in 0..rows {
62 for j in 0..cols {
63 let mut sum = 0.0;
64 for k in 0..a.cols() {
65 sum += a.data[i + k * rows] * b.data[k + j * b.rows()];
67 }
68 data[i + j * rows] = sum;
70 }
71 }
72
73 Tensor::new_2d(data, rows, cols)
74}
75
76pub fn value_matmul(
78 a: &runmat_builtins::Value,
79 b: &runmat_builtins::Value,
80) -> Result<runmat_builtins::Value, String> {
81 use runmat_builtins::Value;
82 if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
84 if let Some(p) = runmat_accelerate_api::provider() {
85 match p.matmul(ha, hb) {
86 Ok(hc) => {
87 let ht = p.download(&hc).map_err(|e| e.to_string())?;
88 return Ok(Value::Tensor(
89 runmat_builtins::Tensor::new(ht.data, ht.shape)
90 .map_err(|e| e.to_string())?,
91 ));
92 }
93 Err(_) => {
94 let ta = p.download(ha).map_err(|e| e.to_string())?;
96 let tb = p.download(hb).map_err(|e| e.to_string())?;
97 let ca = runmat_builtins::Tensor::new(ta.data, ta.shape)
98 .map_err(|e| e.to_string())?;
99 let cb = runmat_builtins::Tensor::new(tb.data, tb.shape)
100 .map_err(|e| e.to_string())?;
101 return Ok(Value::Tensor(matrix_mul(&ca, &cb)?));
102 }
103 }
104 }
105 }
106 if matches!(a, Value::GpuTensor(_)) || matches!(b, Value::GpuTensor(_)) {
108 let to_host = |v: &Value| -> Result<Value, String> {
109 match v {
110 Value::GpuTensor(h) => {
111 if let Some(p) = runmat_accelerate_api::provider() {
112 let ht = p.download(h).map_err(|e| e.to_string())?;
113 Ok(Value::Tensor(
114 runmat_builtins::Tensor::new(ht.data, ht.shape)
115 .map_err(|e| e.to_string())?,
116 ))
117 } else {
118 let total: usize = h.shape.iter().product();
119 Ok(Value::Tensor(
120 runmat_builtins::Tensor::new(vec![0.0; total], h.shape.clone())
121 .map_err(|e| e.to_string())?,
122 ))
123 }
124 }
125 other => Ok(other.clone()),
126 }
127 };
128 let ah = to_host(a)?;
129 let bh = to_host(b)?;
130 return value_matmul(&ah, &bh);
131 }
132 match (a, b) {
134 (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
136 Ok(Value::Complex(ar * br - ai * bi, ar * bi + ai * br))
137 }
138 (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar * s, ai * s)),
139 (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s * br, s * bi)),
140 (Value::Tensor(t), Value::Complex(cr, ci)) => {
141 Ok(Value::ComplexTensor(matrix_scalar_mul_complex(t, *cr, *ci)))
143 }
144 (Value::Complex(cr, ci), Value::Tensor(t)) => {
145 Ok(Value::ComplexTensor(matrix_scalar_mul_complex(t, *cr, *ci)))
147 }
148 (Value::ComplexTensor(ct), Value::Num(s)) => Ok(Value::ComplexTensor(
149 matrix_scalar_mul_complex_tensor(ct, *s, 0.0),
150 )),
151 (Value::Num(s), Value::ComplexTensor(ct)) => Ok(Value::ComplexTensor(
152 matrix_scalar_mul_complex_tensor(ct, *s, 0.0),
153 )),
154 (Value::ComplexTensor(ct), Value::Complex(cr, ci)) => Ok(Value::ComplexTensor(
155 matrix_scalar_mul_complex_tensor(ct, *cr, *ci),
156 )),
157 (Value::Complex(cr, ci), Value::ComplexTensor(ct)) => Ok(Value::ComplexTensor(
158 matrix_scalar_mul_complex_tensor(ct, *cr, *ci),
159 )),
160 (Value::Tensor(ta), Value::Tensor(tb)) => Ok(Value::Tensor(matrix_mul(ta, tb)?)),
161 (Value::ComplexTensor(ta), Value::ComplexTensor(tb)) => {
162 Ok(Value::ComplexTensor(complex_matrix_mul(ta, tb)?))
163 }
164 (Value::ComplexTensor(ta), Value::Tensor(tb)) => {
165 Ok(Value::ComplexTensor(complex_real_matrix_mul(ta, tb)?))
166 }
167 (Value::Tensor(ta), Value::ComplexTensor(tb)) => {
168 Ok(Value::ComplexTensor(real_complex_matrix_mul(ta, tb)?))
169 }
170 (Value::Tensor(ta), Value::Num(s)) => Ok(Value::Tensor(matrix_scalar_mul(ta, *s))),
171 (Value::Num(s), Value::Tensor(tb)) => Ok(Value::Tensor(matrix_scalar_mul(tb, *s))),
172 (Value::Tensor(ta), Value::Int(i)) => Ok(Value::Tensor(matrix_scalar_mul(ta, i.to_f64()))),
173 (Value::Int(i), Value::Tensor(tb)) => Ok(Value::Tensor(matrix_scalar_mul(tb, i.to_f64()))),
174 (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x * y)),
175 (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() * y)),
176 (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x * y.to_f64())),
177 (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() * y.to_f64())),
178 _ => Err("matmul: unsupported operand types".to_string()),
179 }
180}
181
182fn complex_matrix_mul(
183 a: &runmat_builtins::ComplexTensor,
184 b: &runmat_builtins::ComplexTensor,
185) -> Result<runmat_builtins::ComplexTensor, String> {
186 if a.cols != b.rows {
187 return Err(format!(
188 "Inner matrix dimensions must agree: {}x{} * {}x{}",
189 a.rows, a.cols, b.rows, b.cols
190 ));
191 }
192 let rows = a.rows;
193 let cols = b.cols;
194 let kdim = a.cols;
195 let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); rows * cols];
196 for j in 0..cols {
197 for i in 0..rows {
198 let mut acc_re = 0.0;
199 let mut acc_im = 0.0;
200 for k in 0..kdim {
201 let (ar, ai) = a.data[i + k * rows];
202 let (br, bi) = b.data[k + j * b.rows];
203 acc_re += ar * br - ai * bi;
204 acc_im += ar * bi + ai * br;
205 }
206 data[i + j * rows] = (acc_re, acc_im);
207 }
208 }
209 runmat_builtins::ComplexTensor::new_2d(data, rows, cols)
210}
211
212fn complex_real_matrix_mul(
213 a: &runmat_builtins::ComplexTensor,
214 b: &runmat_builtins::Tensor,
215) -> Result<runmat_builtins::ComplexTensor, String> {
216 if a.cols != b.rows() {
217 return Err(format!(
218 "Inner matrix dimensions must agree: {}x{} * {}x{}",
219 a.rows,
220 a.cols,
221 b.rows(),
222 b.cols()
223 ));
224 }
225 let rows = a.rows;
226 let cols = b.cols();
227 let kdim = a.cols;
228 let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); rows * cols];
229 for j in 0..cols {
230 for i in 0..rows {
231 let mut acc_re = 0.0;
232 let mut acc_im = 0.0;
233 for k in 0..kdim {
234 let (ar, ai) = a.data[i + k * rows];
235 let br = b.data[k + j * b.rows()];
236 acc_re += ar * br;
237 acc_im += ai * br;
238 }
239 data[i + j * rows] = (acc_re, acc_im);
240 }
241 }
242 runmat_builtins::ComplexTensor::new_2d(data, rows, cols)
243}
244
245fn real_complex_matrix_mul(
246 a: &runmat_builtins::Tensor,
247 b: &runmat_builtins::ComplexTensor,
248) -> Result<runmat_builtins::ComplexTensor, String> {
249 if a.cols() != b.rows {
250 return Err(format!(
251 "Inner matrix dimensions must agree: {}x{} * {}x{}",
252 a.rows(),
253 a.cols(),
254 b.rows,
255 b.cols
256 ));
257 }
258 let rows = a.rows();
259 let cols = b.cols;
260 let kdim = a.cols();
261 let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); rows * cols];
262 for j in 0..cols {
263 for i in 0..rows {
264 let mut acc_re = 0.0;
265 let mut acc_im = 0.0;
266 for k in 0..kdim {
267 let ar = a.data[i + k * rows];
268 let (br, bi) = b.data[k + j * b.rows];
269 acc_re += ar * br;
270 acc_im += ar * bi;
271 }
272 data[i + j * rows] = (acc_re, acc_im);
273 }
274 }
275 runmat_builtins::ComplexTensor::new_2d(data, rows, cols)
276}
277
278fn matrix_scalar_mul_complex(a: &Tensor, cr: f64, ci: f64) -> runmat_builtins::ComplexTensor {
279 let data: Vec<(f64, f64)> = a.data.iter().map(|&x| (x * cr, x * ci)).collect();
280 runmat_builtins::ComplexTensor::new_2d(data, a.rows(), a.cols()).unwrap()
281}
282
283fn matrix_scalar_mul_complex_tensor(
284 a: &runmat_builtins::ComplexTensor,
285 cr: f64,
286 ci: f64,
287) -> runmat_builtins::ComplexTensor {
288 let data: Vec<(f64, f64)> = a
289 .data
290 .iter()
291 .map(|&(ar, ai)| (ar * cr - ai * ci, ar * ci + ai * cr))
292 .collect();
293 runmat_builtins::ComplexTensor::new_2d(data, a.rows, a.cols).unwrap()
294}
295
296#[runtime_builtin(name = "mtimes")]
297fn mtimes_builtin(
298 a: runmat_builtins::Value,
299 b: runmat_builtins::Value,
300) -> Result<runmat_builtins::Value, String> {
301 use runmat_builtins::Value;
302 match (&a, &b) {
303 (Value::GpuTensor(_), Value::GpuTensor(_)) => value_matmul(&a, &b),
304 (Value::Tensor(ta), Value::Tensor(tb)) => Ok(Value::Tensor(matrix_mul(ta, tb)?)),
305 (Value::Tensor(ta), Value::Num(s)) => Ok(Value::Tensor(matrix_scalar_mul(ta, *s))),
306 (Value::Num(s), Value::Tensor(tb)) => Ok(Value::Tensor(matrix_scalar_mul(tb, *s))),
307 (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x * y)),
308 _ => Err("mtimes: unsupported operand types".to_string()),
309 }
310}
311
312pub fn matrix_scalar_mul(a: &Tensor, scalar: f64) -> Tensor {
314 let data: Vec<f64> = a.data.iter().map(|x| x * scalar).collect();
315 Tensor::new_2d(data, a.rows(), a.cols()).unwrap() }
317
318pub fn matrix_transpose(a: &Tensor) -> Tensor {
320 let mut data = vec![0.0; a.rows() * a.cols()];
321 for i in 0..a.rows() {
322 for j in 0..a.cols() {
323 data[j * a.rows() + i] = a.data[i + j * a.rows()];
325 }
326 }
327 Tensor::new_2d(data, a.cols(), a.rows()).unwrap() }
329
330pub fn matrix_power(a: &Tensor, n: i32) -> Result<Tensor, String> {
333 if a.rows() != a.cols() {
334 return Err(format!(
335 "Matrix must be square for matrix power: {}x{}",
336 a.rows(),
337 a.cols()
338 ));
339 }
340
341 if n < 0 {
342 return Err("Negative matrix powers not supported yet".to_string());
343 }
344
345 if n == 0 {
346 return Ok(matrix_eye(a.rows));
348 }
349
350 if n == 1 {
351 return Ok(a.clone());
353 }
354
355 let mut result = matrix_eye(a.rows());
358 let mut base = a.clone();
359 let mut exp = n as u32;
360
361 while exp > 0 {
362 if exp % 2 == 1 {
363 result = matrix_mul(&result, &base)?;
364 }
365 base = matrix_mul(&base, &base)?;
366 exp /= 2;
367 }
368
369 Ok(result)
370}
371
372pub fn complex_matrix_power(
375 a: &runmat_builtins::ComplexTensor,
376 n: i32,
377) -> Result<runmat_builtins::ComplexTensor, String> {
378 if a.rows != a.cols {
379 return Err(format!(
380 "Matrix must be square for matrix power: {}x{}",
381 a.rows, a.cols
382 ));
383 }
384 if n < 0 {
385 return Err("Negative matrix powers not supported yet".to_string());
386 }
387 if n == 0 {
388 return Ok(complex_matrix_eye(a.rows));
389 }
390 if n == 1 {
391 return Ok(a.clone());
392 }
393 let mut result = complex_matrix_eye(a.rows);
394 let mut base = a.clone();
395 let mut exp = n as u32;
396 while exp > 0 {
397 if exp % 2 == 1 {
398 result = complex_matrix_mul(&result, &base)?;
399 }
400 base = complex_matrix_mul(&base, &base)?;
401 exp /= 2;
402 }
403 Ok(result)
404}
405
406fn complex_matrix_eye(n: usize) -> runmat_builtins::ComplexTensor {
407 let mut data: Vec<(f64, f64)> = vec![(0.0, 0.0); n * n];
408 for i in 0..n {
409 data[i * n + i] = (1.0, 0.0);
410 }
411 runmat_builtins::ComplexTensor::new_2d(data, n, n).unwrap()
412}
413
414pub fn matrix_eye(n: usize) -> Tensor {
416 let mut data = vec![0.0; n * n];
417 for i in 0..n {
418 data[i * n + i] = 1.0;
419 }
420 Tensor::new_2d(data, n, n).unwrap() }
422
423#[runtime_builtin(name = "matrix_zeros")]
425fn matrix_zeros_builtin(rows: i32, cols: i32) -> Result<Tensor, String> {
426 if rows < 0 || cols < 0 {
427 return Err("Matrix dimensions must be non-negative".to_string());
428 }
429 Ok(Tensor::zeros(vec![rows as usize, cols as usize]))
430}
431
432#[runtime_builtin(name = "matrix_ones")]
433fn matrix_ones_builtin(rows: i32, cols: i32) -> Result<Tensor, String> {
434 if rows < 0 || cols < 0 {
435 return Err("Matrix dimensions must be non-negative".to_string());
436 }
437 Ok(Tensor::ones(vec![rows as usize, cols as usize]))
438}
439
440#[runtime_builtin(name = "matrix_eye")]
441fn matrix_eye_builtin(n: i32) -> Result<Tensor, String> {
442 if n < 0 {
443 return Err("Matrix size must be non-negative".to_string());
444 }
445 Ok(matrix_eye(n as usize))
446}
447
448#[runtime_builtin(name = "matrix_transpose")]
449fn matrix_transpose_builtin(a: Tensor) -> Result<Tensor, String> {
450 Ok(matrix_transpose(&a))
451}