//! Element-wise operations for matrices and scalars.
//!
//! This module implements language-compatible element-wise operations (`.*`, `./`, `.^`).
//! These operations work element-by-element on matrices and support scalar broadcasting.
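//!
//! A minimal usage sketch (kept as an `ignore` doctest; the
//! `runmat_runtime::elementwise` path is assumed from this file's location):
//!
//! ```ignore
//! use runmat_builtins::{Tensor, Value};
//! use runmat_runtime::elementwise::elementwise_mul;
//!
//! // Scalar broadcasting: every element of the 2x2 matrix is doubled.
//! let m = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
//! let doubled = elementwise_mul(&Value::Tensor(m), &Value::Num(2.0)).unwrap();
//! // doubled is now Value::Tensor([2 4; 6 8])
//! ```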

use crate::matrix::matrix_power;
use runmat_builtins::{Tensor, Value};

/// Complex power via the principal branch: `z^w = exp(w * ln z)`, with
/// `ln z = ln|z| + i*arg(z)`. Special-cases `0^0 = 1` and `0^w = 0` for
/// positive real exponents.
fn complex_pow_scalar(base_re: f64, base_im: f64, exp_re: f64, exp_im: f64) -> (f64, f64) {
    if base_re == 0.0 && base_im == 0.0 && exp_re == 0.0 && exp_im == 0.0 {
        return (1.0, 0.0);
    }
    if base_re == 0.0 && base_im == 0.0 && exp_im == 0.0 && exp_re > 0.0 {
        return (0.0, 0.0);
    }
    // `.max(0.0)` also routes a NaN modulus into the zero-base branch below.
    let r = (base_re.hypot(base_im)).max(0.0);
    if r == 0.0 {
        return (0.0, 0.0);
    }
    let theta = base_im.atan2(base_re);
    let ln_r = r.ln();
    // w * ln z = (exp_re + i*exp_im)(ln_r + i*theta) = a + i*b
    let a = exp_re * ln_r - exp_im * theta;
    let b = exp_re * theta + exp_im * ln_r;
    let mag = a.exp();
    (mag * b.cos(), mag * b.sin())
}

/// Gather a value to host memory: GPU tensors are downloaded through their
/// provider; all other values are cloned as-is.
fn to_host_value(v: &Value) -> Result<Value, String> {
    match v {
        Value::GpuTensor(h) => {
            if let Some(p) = runmat_accelerate_api::provider_for_handle(h) {
                let ht = p.download(h).map_err(|e| e.to_string())?;
                Ok(Value::Tensor(
                    Tensor::new(ht.data, ht.shape).map_err(|e| e.to_string())?,
                ))
            } else {
                // Degenerate fallback when no provider is registered:
                // a zeros tensor with the same shape.
                let total: usize = h.shape.iter().product();
                Ok(Value::Tensor(
                    Tensor::new(vec![0.0; total], h.shape.clone()).map_err(|e| e.to_string())?,
                ))
            }
        }
        other => Ok(other.clone()),
    }
}

/// Element-wise negation: -A
/// Supports scalars and matrices
pub fn elementwise_neg(a: &Value) -> Result<Value, String> {
    match a {
        Value::Num(x) => Ok(Value::Num(-x)),
        Value::Complex(re, im) => Ok(Value::Complex(-*re, -*im)),
        Value::Int(x) => {
            // Negate in i64 first so that -i32::MIN does not overflow i32.
            let v = x.to_i64();
            let neg = v
                .checked_neg()
                .ok_or_else(|| "integer overflow in negation".to_string())?;
            if let Ok(n) = i32::try_from(neg) {
                Ok(Value::Int(runmat_builtins::IntValue::I32(n)))
            } else {
                Ok(Value::Int(runmat_builtins::IntValue::I64(neg)))
            }
        }
        Value::Bool(b) => Ok(Value::Bool(!b)), // Boolean negation
        Value::Tensor(m) => {
            let data: Vec<f64> = m.data.iter().map(|x| -x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        _ => Err(format!("Negation not supported for type: -{a:?}")),
    }
}

/// Element-wise multiplication: A .* B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_mul(a: &Value, b: &Value) -> Result<Value, String> {
    // GPU+scalar: keep on device if provider supports scalar mul
    if let Some(p) = runmat_accelerate_api::provider() {
        match (a, b) {
            (Value::GpuTensor(ga), Value::Num(s)) => {
                if let Ok(hc) = p.scalar_mul(ga, *s) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::Num(s), Value::GpuTensor(gb)) => {
                if let Ok(hc) = p.scalar_mul(gb, *s) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::GpuTensor(ga), Value::Int(i)) => {
                if let Ok(hc) = p.scalar_mul(ga, i.to_f64()) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::Int(i), Value::GpuTensor(gb)) => {
                if let Ok(hc) = p.scalar_mul(gb, i.to_f64()) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            _ => {}
        }
    }
    // If exactly one is GPU and no scalar fast-path, gather to host and recurse
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a)?;
        let bh = to_host_value(b)?;
        return elementwise_mul(&ah, &bh);
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_mul(ha, hb) {
                return Ok(Value::GpuTensor(hc));
            }
        }
    }
    match (a, b) {
        // Complex scalars
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            Ok(Value::Complex(ar * br - ai * bi, ar * bi + ai * br))
        }
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar * s, ai * s)),
        (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s * br, s * bi)),
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x * y)),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() * y)),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x * y.to_f64())),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() * y.to_f64())),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x * s).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x * scalar).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s * x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar * x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x * y)
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
            for i in 0..m1.data.len() {
                let (ar, ai) = m1.data[i];
                let (br, bi) = m2.data[i];
                out.push((ar * br - ai * bi, ar * bi + ai * br));
            }
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new(out, m1.shape.clone())
                    .map_err(|e| format!(".*: {e}"))?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re * s, im * s)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (s * re, s * im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise multiplication not supported for types: {a:?} .* {b:?}"
        )),
    }
}

// elementwise_add has been retired in favor of the `plus` builtin

// elementwise_sub has been retired in favor of the `minus` builtin

/// Element-wise division: A ./ B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
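///
/// Real division follows IEEE-754, which matches MATLAB: `x ./ 0` is `±Inf`
/// for nonzero `x`, and `0 ./ 0` is `NaN`. A sketch (`ignore` doctest, paths
/// assumed as in the module example above):
///
/// ```ignore
/// // 1 ./ [1 2 4] == [1 0.5 0.25]
/// let v = Tensor::new_2d(vec![1.0, 2.0, 4.0], 1, 3).unwrap();
/// let r = elementwise_div(&Value::Num(1.0), &Value::Tensor(v)).unwrap();
/// ```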
pub fn elementwise_div(a: &Value, b: &Value) -> Result<Value, String> {
    // GPU+scalar: use scalar div for the form G ./ s, and scalar rdiv for s ./ G
    if let Some(p) = runmat_accelerate_api::provider() {
        match (a, b) {
            (Value::GpuTensor(ga), Value::Num(s)) => {
                if let Ok(hc) = p.scalar_div(ga, *s) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::GpuTensor(ga), Value::Int(i)) => {
                if let Ok(hc) = p.scalar_div(ga, i.to_f64()) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::Num(s), Value::GpuTensor(gb)) => {
                if let Ok(hc) = p.scalar_rdiv(gb, *s) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::Int(i), Value::GpuTensor(gb)) => {
                if let Ok(hc) = p.scalar_rdiv(gb, i.to_f64()) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            _ => {}
        }
    }
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a)?;
        let bh = to_host_value(b)?;
        return elementwise_div(&ah, &bh);
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_div(ha, hb) {
                return Ok(Value::GpuTensor(hc));
            }
        }
    }
    match (a, b) {
        // Complex scalars
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            let denom = br * br + bi * bi;
            if denom == 0.0 {
                // Keep the result complex, matching the complex-tensor path
                return Ok(Value::Complex(f64::NAN, f64::NAN));
            }
            Ok(Value::Complex(
                (ar * br + ai * bi) / denom,
                (ai * br - ar * bi) / denom,
            ))
        }
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar / s, ai / s)),
        (Value::Num(s), Value::Complex(br, bi)) => {
            let denom = br * br + bi * bi;
            if denom == 0.0 {
                return Ok(Value::Complex(f64::NAN, f64::NAN));
            }
            Ok(Value::Complex((s * br) / denom, (-s * bi) / denom))
        }
        // Scalar-scalar cases: IEEE-754 division already matches MATLAB
        // semantics (x/0 is +/-Inf for nonzero x; 0/0 is NaN)
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x / y)),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() / y)),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x / y.to_f64())),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() / y.to_f64())),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x / s).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x / scalar).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s / x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar / x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x / y)
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let data: Vec<(f64, f64)> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|((ar, ai), (br, bi))| {
                    let denom = br * br + bi * bi;
                    if denom == 0.0 {
                        (f64::NAN, f64::NAN)
                    } else {
                        ((ar * br + ai * bi) / denom, (ai * br - ar * bi) / denom)
                    }
                })
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m1.rows, m1.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re / s, im / s)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(br, bi)| {
                    let denom = br * br + bi * bi;
                    if denom == 0.0 {
                        (f64::NAN, f64::NAN)
                    } else {
                        ((s * br) / denom, (-s * bi) / denom)
                    }
                })
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise division not supported for types: {a:?} ./ {b:?}"
        )),
    }
}

/// Regular power operation: A ^ B
/// For matrices, this is matrix exponentiation (A^n where n is an integer);
/// for scalars, this is regular exponentiation.
pub fn power(a: &Value, b: &Value) -> Result<Value, String> {
    match (a, b) {
        // Scalar cases - complex
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Num(y)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Num(x), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Int(y)) => {
            let yv = y.to_f64();
            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Int(x), Value::Complex(br, bi)) => {
            let xv = x.to_f64();
            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }

        // Scalar cases - real only
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),

        // Matrix^scalar case - matrix exponentiation
        (Value::Tensor(m), Value::Num(s)) => {
            // Matrix power requires an integer exponent
            if s.fract() == 0.0 {
                let n = *s as i32;
                let result = matrix_power(m, n)?;
                Ok(Value::Tensor(result))
            } else {
                Err("Matrix power requires integer exponent".to_string())
            }
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let result = matrix_power(m, s.to_i64() as i32)?;
            Ok(Value::Tensor(result))
        }

        // Complex matrix^integer case
        (Value::ComplexTensor(m), Value::Num(s)) => {
            if s.fract() == 0.0 {
                let n = *s as i32;
                let result = crate::matrix::complex_matrix_power(m, n)?;
                Ok(Value::ComplexTensor(result))
            } else {
                Err("Matrix power requires integer exponent".to_string())
            }
        }
        (Value::ComplexTensor(m), Value::Int(s)) => {
            let result = crate::matrix::complex_matrix_power(m, s.to_i64() as i32)?;
            Ok(Value::ComplexTensor(result))
        }

        // Other cases not supported for regular matrix power
        _ => Err(format!(
            "Power operation not supported for types: {a:?} ^ {b:?}"
        )),
    }
}

/// Element-wise power: A .^ B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_pow(a: &Value, b: &Value) -> Result<Value, String> {
    match (a, b) {
        // Complex scalar cases
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Num(y)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Num(x), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Int(y)) => {
            let yv = y.to_f64();
            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Int(x), Value::Complex(br, bi)) => {
            let xv = x.to_f64();
            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x.powf(*s)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x.powf(scalar)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s.powf(*x)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar.powf(*x)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x.powf(*y))
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensor element-wise power
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
            for i in 0..m1.data.len() {
                let (ar, ai) = m1.data[i];
                let (br, bi) = m2.data[i];
                out.push(complex_pow_scalar(ar, ai, br, bi));
            }
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m1.rows, m1.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *s, 0.0))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Int(s)) => {
            let sv = s.to_f64();
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, sv, 0.0))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Complex(br, bi)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *br, *bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(br, bi)| complex_pow_scalar(*s, 0.0, *br, *bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::Int(s), Value::ComplexTensor(m)) => {
            let sv = s.to_f64();
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(br, bi)| complex_pow_scalar(sv, 0.0, *br, *bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::Complex(br, bi), Value::ComplexTensor(m)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(er, ei)| complex_pow_scalar(*br, *bi, *er, *ei))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise power not supported for types: {a:?} .^ {b:?}"
        )),
    }
}

// Element-wise operations are not directly exposed as runtime builtins because they need
// to handle multiple types (Value enum variants). Instead, they are called directly from
// the interpreter and JIT compiler using the elementwise_* functions above.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_elementwise_mul_scalars() {
        assert_eq!(
            elementwise_mul(&Value::Num(3.0), &Value::Num(4.0)).unwrap(),
            Value::Num(12.0)
        );
        assert_eq!(
            elementwise_mul(
                &Value::Int(runmat_builtins::IntValue::I32(3)),
                &Value::Num(4.5)
            )
            .unwrap(),
            Value::Num(13.5)
        );
    }

    #[test]
    fn test_elementwise_mul_matrix_scalar() {
        let matrix = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let result = elementwise_mul(&Value::Tensor(matrix), &Value::Num(2.0)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![2.0, 4.0, 6.0, 8.0]);
            assert_eq!(m.rows(), 2);
            assert_eq!(m.cols(), 2);
        } else {
            panic!("Expected matrix result");
        }
    }

    #[test]
    fn test_elementwise_mul_matrices() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let m2 = Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap();
        let result = elementwise_mul(&Value::Tensor(m1), &Value::Tensor(m2)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![2.0, 6.0, 12.0, 20.0]);
        } else {
            panic!("Expected matrix result");
        }
    }

    #[test]
    fn test_elementwise_div_with_zero() {
        let result = elementwise_div(&Value::Num(5.0), &Value::Num(0.0)).unwrap();
        if let Value::Num(n) = result {
            assert!(n.is_infinite() && n.is_sign_positive());
        } else {
            panic!("Expected numeric result");
        }
    }
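
    // Real division uses plain IEEE-754 semantics, which match MATLAB:
    // 0/0 is NaN and dividing by negative zero flips the infinity's sign.
    #[test]
    fn test_elementwise_div_ieee_edge_cases() {
        if let Value::Num(n) = elementwise_div(&Value::Num(0.0), &Value::Num(0.0)).unwrap() {
            assert!(n.is_nan());
        } else {
            panic!("Expected numeric result");
        }
        if let Value::Num(n) = elementwise_div(&Value::Num(5.0), &Value::Num(-0.0)).unwrap() {
            assert!(n.is_infinite() && n.is_sign_negative());
        } else {
            panic!("Expected numeric result");
        }
    }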

    #[test]
    fn test_elementwise_pow() {
        let matrix = Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap();
        let result = elementwise_pow(&Value::Tensor(matrix), &Value::Num(2.0)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![4.0, 9.0, 16.0, 25.0]);
        } else {
            panic!("Expected matrix result");
        }
    }
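
    // Contrast `power` (matrix exponentiation) with `elementwise_pow`:
    // A^2 multiplies the matrix by itself, while A.^2 squares each element.
    // For this 2x2 input the flattened A*A data is [7, 10, 15, 22] under
    // either row- or column-major storage.
    #[test]
    fn test_matrix_power_vs_elementwise_pow() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        if let Value::Tensor(r) = power(&Value::Tensor(m1), &Value::Num(2.0)).unwrap() {
            assert_eq!(r.data, vec![7.0, 10.0, 15.0, 22.0]);
        } else {
            panic!("Expected matrix result");
        }
        let m2 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        if let Value::Tensor(r) = elementwise_pow(&Value::Tensor(m2), &Value::Num(2.0)).unwrap() {
            assert_eq!(r.data, vec![1.0, 4.0, 9.0, 16.0]);
        } else {
            panic!("Expected matrix result");
        }
    }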

    #[test]
    fn test_dimension_mismatch() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
        let m2 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();

        assert!(elementwise_mul(&Value::Tensor(m1), &Value::Tensor(m2)).is_err());
    }
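
    // Spot-check complex_pow_scalar through the public API: i^2 should be -1,
    // up to floating-point rounding in the principal-branch formula.
    #[test]
    fn test_elementwise_pow_complex() {
        let result = elementwise_pow(&Value::Complex(0.0, 1.0), &Value::Num(2.0)).unwrap();
        if let Value::Complex(re, im) = result {
            assert!((re + 1.0).abs() < 1e-12);
            assert!(im.abs() < 1e-12);
        } else {
            panic!("Expected complex result");
        }
    }

    // Negating i32::MIN must not overflow; the result is promoted to I64.
    #[test]
    fn test_neg_i32_min_promotes() {
        let result =
            elementwise_neg(&Value::Int(runmat_builtins::IntValue::I32(i32::MIN))).unwrap();
        assert_eq!(
            result,
            Value::Int(runmat_builtins::IntValue::I64(-(i32::MIN as i64)))
        );
    }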
}