// runmat_runtime/elementwise.rs

1//! Element-wise operations for matrices and scalars
2//!
3//! This module implements language-compatible element-wise operations (.*,  ./,  .^)
4//! These operations work element-by-element on matrices and support scalar broadcasting.
5
6use crate::matrix::matrix_power;
7use runmat_builtins::{Tensor, Value};
8
/// Complex power: computes `(base_re + i*base_im) ^ (exp_re + i*exp_im)`.
///
/// Uses the principal branch: `z^w = exp(w * Log z)` with
/// `Log z = ln|z| + i*Arg z`, `Arg z` in `(-pi, pi]`. Zero-base special
/// cases follow MATLAB conventions: `0^0 = 1` and `0^p = 0` for real `p > 0`.
fn complex_pow_scalar(base_re: f64, base_im: f64, exp_re: f64, exp_im: f64) -> (f64, f64) {
    // 0^0 is defined as 1 (MATLAB convention).
    if base_re == 0.0 && base_im == 0.0 && exp_re == 0.0 && exp_im == 0.0 {
        return (1.0, 0.0);
    }
    // 0^p for real, strictly positive p is exactly 0.
    if base_re == 0.0 && base_im == 0.0 && exp_im == 0.0 && exp_re > 0.0 {
        return (0.0, 0.0);
    }
    // hypot is non-negative by definition; the previous `.max(0.0)` clamp
    // was redundant and has been removed.
    let r = base_re.hypot(base_im);
    if r == 0.0 {
        // Remaining zero-base cases (negative/complex exponent): keep the
        // original behavior of returning 0 rather than Inf/NaN.
        return (0.0, 0.0);
    }
    let theta = base_im.atan2(base_re);
    let ln_r = r.ln();
    // w * Log z = a + i*b
    let a = exp_re * ln_r - exp_im * theta;
    let b = exp_re * theta + exp_im * ln_r;
    let mag = a.exp();
    (mag * b.cos(), mag * b.sin())
}
27
28fn scalar_real_value(value: &Value) -> Option<f64> {
29    match value {
30        Value::Num(n) => Some(*n),
31        Value::Int(i) => Some(i.to_f64()),
32        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
33        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
34        _ => None,
35    }
36}
37
38fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
39    match value {
40        Value::Complex(re, im) => Some((*re, *im)),
41        Value::ComplexTensor(t) if t.data.len() == 1 => t.data.first().copied(),
42        _ => None,
43    }
44}
45
46fn scalar_power_value(base: &Value, exponent: &Value) -> Option<Value> {
47    let base_is_complex = matches!(base, Value::Complex(_, _) | Value::ComplexTensor(_));
48    let exp_is_complex = matches!(exponent, Value::Complex(_, _) | Value::ComplexTensor(_));
49    let base_val =
50        scalar_complex_value(base).or_else(|| scalar_real_value(base).map(|v| (v, 0.0)))?;
51    let exp_val =
52        scalar_complex_value(exponent).or_else(|| scalar_real_value(exponent).map(|v| (v, 0.0)))?;
53    let (br, bi) = base_val;
54    let (er, ei) = exp_val;
55    if base_is_complex || exp_is_complex || bi != 0.0 || ei != 0.0 {
56        let (re, im) = complex_pow_scalar(br, bi, er, ei);
57        return Some(Value::Complex(re, im));
58    }
59    let pow = br.powf(er);
60    if pow.is_nan() {
61        let (re, im) = complex_pow_scalar(br, 0.0, er, 0.0);
62        Some(Value::Complex(re, im))
63    } else {
64        Some(Value::Num(pow))
65    }
66}
67
68async fn to_host_value(v: &Value) -> Result<Value, String> {
69    match v {
70        Value::GpuTensor(h) => {
71            if runmat_accelerate_api::provider_for_handle(h).is_some() {
72                let gathered = crate::dispatcher::gather_if_needed_async(v)
73                    .await
74                    .map_err(|e| e.to_string())?;
75                Ok(gathered)
76            } else {
77                // Fallback: zeros tensor with same shape
78                let total: usize = h.shape.iter().product();
79                Ok(Value::Tensor(
80                    Tensor::new(vec![0.0; total], h.shape.clone()).map_err(|e| e.to_string())?,
81                ))
82            }
83        }
84        other => Ok(other.clone()),
85    }
86}
87
88/// Element-wise negation: -A
89/// Supports scalars and matrices
90pub fn elementwise_neg(a: &Value) -> Result<Value, String> {
91    match a {
92        Value::Num(x) => Ok(Value::Num(-x)),
93        Value::Complex(re, im) => Ok(Value::Complex(-*re, -*im)),
94        Value::Int(x) => {
95            let v = x.to_i64();
96            if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
97                Ok(Value::Int(runmat_builtins::IntValue::I32(-(v as i32))))
98            } else {
99                Ok(Value::Int(runmat_builtins::IntValue::I64(-v)))
100            }
101        }
102        Value::Bool(b) => Ok(Value::Bool(!b)), // Boolean negation
103        Value::Tensor(m) => {
104            let data: Vec<f64> = m.data.iter().map(|x| -x).collect();
105            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
106        }
107        _ => Err(format!("Negation not supported for type: -{a:?}")),
108    }
109}
110
/// Element-wise multiplication: A .* B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
///
/// Dispatch order matters — each stage is a short-circuiting fast path:
/// 1. GPU tensor paired with a real/int scalar: try the provider's on-device
///    scalar multiply so the tensor never leaves the device.
/// 2. Exactly one GPU operand (scalar fast path failed or didn't apply):
///    gather to host and recurse on host values.
/// 3. Both operands on GPU: try the provider's element-wise kernel.
/// 4. Host fallback over the supported Value combinations.
#[async_recursion::async_recursion(?Send)]
pub async fn elementwise_mul(a: &Value, b: &Value) -> Result<Value, String> {
    // GPU+scalar: keep on device if provider supports scalar mul
    if let Some(p) = runmat_accelerate_api::provider() {
        match (a, b) {
            (Value::GpuTensor(ga), Value::Num(s)) => {
                // A failed kernel launch silently falls through to host paths.
                if let Ok(hc) = p.scalar_mul(ga, *s) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::Num(s), Value::GpuTensor(gb)) => {
                if let Ok(hc) = p.scalar_mul(gb, *s) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::GpuTensor(ga), Value::Int(i)) => {
                if let Ok(hc) = p.scalar_mul(ga, i.to_f64()) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            (Value::Int(i), Value::GpuTensor(gb)) => {
                if let Ok(hc) = p.scalar_mul(gb, i.to_f64()) {
                    return Ok(Value::GpuTensor(hc));
                }
            }
            _ => {}
        }
    }
    // If exactly one is GPU and no scalar fast-path, gather to host and recurse
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a).await?;
        let bh = to_host_value(b).await?;
        return elementwise_mul(&ah, &bh).await;
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_mul(ha, hb).await {
                return Ok(Value::GpuTensor(hc));
            }
        }
    }
    match (a, b) {
        // Complex scalars: (ar+ai*i)(br+bi*i) = (ar*br - ai*bi) + (ar*bi + ai*br)i
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            Ok(Value::Complex(ar * br - ai * bi, ar * bi + ai * br))
        }
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar * s, ai * s)),
        (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s * br, s * bi)),
        // Scalar-scalar case (integer operands are promoted to double)
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x * y)),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() * y)),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x * y.to_f64())),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() * y.to_f64())),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x * s).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x * scalar).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s * x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar * x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        // NOTE(review): the shape check and reconstruction use rows()/cols();
        // presumably tensors reaching this path are 2-D — confirm behavior
        // for N-d tensors.
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x * y)
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors (element-wise complex product; original shape kept)
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
            for i in 0..m1.data.len() {
                let (ar, ai) = m1.data[i];
                let (br, bi) = m2.data[i];
                out.push((ar * br - ai * bi, ar * bi + ai * br));
            }
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new(out, m1.shape.clone())
                    .map_err(|e| format!(".*: {e}"))?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re * s, im * s)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (s * re, s * im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise multiplication not supported for types: {a:?} .* {b:?}"
        )),
    }
}
244
245// elementwise_add has been retired in favor of the `plus` builtin
246
247// elementwise_sub has been retired in favor of the `minus` builtin
248
249/// Element-wise division: A ./ B
250/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
251#[async_recursion::async_recursion(?Send)]
252pub async fn elementwise_div(a: &Value, b: &Value) -> Result<Value, String> {
253    // GPU+scalar: use scalar div when form is G ./ s or left-scalar s ./ G
254    if let Some(p) = runmat_accelerate_api::provider() {
255        match (a, b) {
256            (Value::GpuTensor(ga), Value::Num(s)) => {
257                if let Ok(hc) = p.scalar_div(ga, *s) {
258                    return Ok(Value::GpuTensor(hc));
259                }
260            }
261            (Value::GpuTensor(ga), Value::Int(i)) => {
262                if let Ok(hc) = p.scalar_div(ga, i.to_f64()) {
263                    return Ok(Value::GpuTensor(hc));
264                }
265            }
266            (Value::Num(s), Value::GpuTensor(gb)) => {
267                if let Ok(hc) = p.scalar_rdiv(gb, *s) {
268                    return Ok(Value::GpuTensor(hc));
269                }
270            }
271            (Value::Int(i), Value::GpuTensor(gb)) => {
272                if let Ok(hc) = p.scalar_rdiv(gb, i.to_f64()) {
273                    return Ok(Value::GpuTensor(hc));
274                }
275            }
276            _ => {}
277        }
278    }
279    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
280        let ah = to_host_value(a).await?;
281        let bh = to_host_value(b).await?;
282        return elementwise_div(&ah, &bh).await;
283    }
284    if let Some(p) = runmat_accelerate_api::provider() {
285        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
286            if let Ok(hc) = p.elem_div(ha, hb).await {
287                return Ok(Value::GpuTensor(hc));
288            }
289        }
290    }
291    match (a, b) {
292        // Complex scalars
293        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
294            let denom = br * br + bi * bi;
295            if denom == 0.0 {
296                return Ok(Value::Num(f64::NAN));
297            }
298            Ok(Value::Complex(
299                (ar * br + ai * bi) / denom,
300                (ai * br - ar * bi) / denom,
301            ))
302        }
303        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar / s, ai / s)),
304        (Value::Num(s), Value::Complex(br, bi)) => {
305            let denom = br * br + bi * bi;
306            if denom == 0.0 {
307                return Ok(Value::Num(f64::NAN));
308            }
309            Ok(Value::Complex((s * br) / denom, (-s * bi) / denom))
310        }
311        // Scalar-scalar case
312        (Value::Num(x), Value::Num(y)) => {
313            if *y == 0.0 {
314                Ok(Value::Num(f64::INFINITY * x.signum()))
315            } else {
316                Ok(Value::Num(x / y))
317            }
318        }
319        (Value::Int(x), Value::Num(y)) => {
320            if *y == 0.0 {
321                Ok(Value::Num(f64::INFINITY * x.to_f64().signum()))
322            } else {
323                Ok(Value::Num(x.to_f64() / y))
324            }
325        }
326        (Value::Num(x), Value::Int(y)) => {
327            if y.is_zero() {
328                Ok(Value::Num(f64::INFINITY * x.signum()))
329            } else {
330                Ok(Value::Num(x / y.to_f64()))
331            }
332        }
333        (Value::Int(x), Value::Int(y)) => {
334            if y.is_zero() {
335                Ok(Value::Num(f64::INFINITY * x.to_f64().signum()))
336            } else {
337                Ok(Value::Num(x.to_f64() / y.to_f64()))
338            }
339        }
340
341        // Matrix-scalar cases (broadcasting)
342        (Value::Tensor(m), Value::Num(s)) => {
343            if *s == 0.0 {
344                let data: Vec<f64> = m.data.iter().map(|x| f64::INFINITY * x.signum()).collect();
345                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
346            } else {
347                let data: Vec<f64> = m.data.iter().map(|x| x / s).collect();
348                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
349            }
350        }
351        (Value::Tensor(m), Value::Int(s)) => {
352            let scalar = s.to_f64();
353            if scalar == 0.0 {
354                let data: Vec<f64> = m.data.iter().map(|x| f64::INFINITY * x.signum()).collect();
355                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
356            } else {
357                let data: Vec<f64> = m.data.iter().map(|x| x / scalar).collect();
358                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
359            }
360        }
361        (Value::Num(s), Value::Tensor(m)) => {
362            let data: Vec<f64> = m
363                .data
364                .iter()
365                .map(|x| {
366                    if *x == 0.0 {
367                        f64::INFINITY * s.signum()
368                    } else {
369                        s / x
370                    }
371                })
372                .collect();
373            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
374        }
375        (Value::Int(s), Value::Tensor(m)) => {
376            let scalar = s.to_f64();
377            let data: Vec<f64> = m
378                .data
379                .iter()
380                .map(|x| {
381                    if *x == 0.0 {
382                        f64::INFINITY * scalar.signum()
383                    } else {
384                        scalar / x
385                    }
386                })
387                .collect();
388            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
389        }
390
391        // Matrix-matrix case
392        (Value::Tensor(m1), Value::Tensor(m2)) => {
393            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
394                return Err(format!(
395                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
396                    m1.rows(),
397                    m1.cols(),
398                    m2.rows(),
399                    m2.cols()
400                ));
401            }
402            let data: Vec<f64> = m1
403                .data
404                .iter()
405                .zip(m2.data.iter())
406                .map(|(x, y)| {
407                    if *y == 0.0 {
408                        f64::INFINITY * x.signum()
409                    } else {
410                        x / y
411                    }
412                })
413                .collect();
414            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
415        }
416
417        // Complex tensors
418        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
419            if m1.rows != m2.rows || m1.cols != m2.cols {
420                return Err(format!(
421                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
422                    m1.rows, m1.cols, m2.rows, m2.cols
423                ));
424            }
425            let data: Vec<(f64, f64)> = m1
426                .data
427                .iter()
428                .zip(m2.data.iter())
429                .map(|((ar, ai), (br, bi))| {
430                    let denom = br * br + bi * bi;
431                    if denom == 0.0 {
432                        (f64::NAN, f64::NAN)
433                    } else {
434                        ((ar * br + ai * bi) / denom, (ai * br - ar * bi) / denom)
435                    }
436                })
437                .collect();
438            Ok(Value::ComplexTensor(
439                runmat_builtins::ComplexTensor::new_2d(data, m1.rows, m1.cols)?,
440            ))
441        }
442        (Value::ComplexTensor(m), Value::Num(s)) => {
443            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re / s, im / s)).collect();
444            Ok(Value::ComplexTensor(
445                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
446            ))
447        }
448        (Value::Num(s), Value::ComplexTensor(m)) => {
449            let data: Vec<(f64, f64)> = m
450                .data
451                .iter()
452                .map(|(br, bi)| {
453                    let denom = br * br + bi * bi;
454                    if denom == 0.0 {
455                        (f64::NAN, f64::NAN)
456                    } else {
457                        ((s * br) / denom, (-s * bi) / denom)
458                    }
459                })
460                .collect();
461            Ok(Value::ComplexTensor(
462                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
463            ))
464        }
465
466        _ => Err(format!(
467            "Element-wise division not supported for types: {a:?} ./ {b:?}"
468        )),
469    }
470}
471
472/// Regular power operation: A ^ B  
473/// For matrices, this is matrix exponentiation (A^n where n is integer)
474/// For scalars, this is regular exponentiation
475pub fn power(a: &Value, b: &Value) -> Result<Value, String> {
476    if let Some(result) = scalar_power_value(a, b) {
477        return Ok(result);
478    }
479    match (a, b) {
480        // Scalar cases - include complex
481        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
482            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
483            Ok(Value::Complex(r, i))
484        }
485        (Value::Complex(ar, ai), Value::Num(y)) => {
486            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
487            Ok(Value::Complex(r, i))
488        }
489        (Value::Num(x), Value::Complex(br, bi)) => {
490            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
491            Ok(Value::Complex(r, i))
492        }
493        (Value::Complex(ar, ai), Value::Int(y)) => {
494            let yv = y.to_f64();
495            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
496            Ok(Value::Complex(r, i))
497        }
498        (Value::Int(x), Value::Complex(br, bi)) => {
499            let xv = x.to_f64();
500            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
501            Ok(Value::Complex(r, i))
502        }
503
504        // Scalar cases - real only
505        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
506        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
507        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
508        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),
509
510        // Matrix^scalar case - matrix exponentiation
511        (Value::Tensor(m), Value::Num(s)) => {
512            // Check if scalar is an integer for matrix power
513            if s.fract() == 0.0 {
514                let n = *s as i32;
515                let result = matrix_power(m, n)?;
516                Ok(Value::Tensor(result))
517            } else {
518                Err("Matrix power requires integer exponent".to_string())
519            }
520        }
521        (Value::Tensor(m), Value::Int(s)) => {
522            let result = matrix_power(m, s.to_i64() as i32)?;
523            Ok(Value::Tensor(result))
524        }
525
526        // Complex matrix^integer case
527        (Value::ComplexTensor(m), Value::Num(s)) => {
528            if s.fract() == 0.0 {
529                let n = *s as i32;
530                let result = crate::matrix::complex_matrix_power(m, n)?;
531                Ok(Value::ComplexTensor(result))
532            } else {
533                Err("Matrix power requires integer exponent".to_string())
534            }
535        }
536        (Value::ComplexTensor(m), Value::Int(s)) => {
537            let result = crate::matrix::complex_matrix_power(m, s.to_i64() as i32)?;
538            Ok(Value::ComplexTensor(result))
539        }
540
541        // Other cases not supported for regular matrix power
542        _ => Err(format!(
543            "Power operation not supported for types: {a:?} ^ {b:?}"
544        )),
545    }
546}
547
548/// Element-wise power: A .^ B
549/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
550pub fn elementwise_pow(a: &Value, b: &Value) -> Result<Value, String> {
551    match (a, b) {
552        // Complex scalar cases
553        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
554            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
555            Ok(Value::Complex(r, i))
556        }
557        (Value::Complex(ar, ai), Value::Num(y)) => {
558            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
559            Ok(Value::Complex(r, i))
560        }
561        (Value::Num(x), Value::Complex(br, bi)) => {
562            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
563            Ok(Value::Complex(r, i))
564        }
565        (Value::Complex(ar, ai), Value::Int(y)) => {
566            let yv = y.to_f64();
567            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
568            Ok(Value::Complex(r, i))
569        }
570        (Value::Int(x), Value::Complex(br, bi)) => {
571            let xv = x.to_f64();
572            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
573            Ok(Value::Complex(r, i))
574        }
575        // Scalar-scalar case
576        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
577        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
578        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
579        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),
580
581        // Matrix-scalar cases (broadcasting)
582        (Value::Tensor(m), Value::Num(s)) => {
583            let data: Vec<f64> = m.data.iter().map(|x| x.powf(*s)).collect();
584            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
585        }
586        (Value::Tensor(m), Value::Int(s)) => {
587            let scalar = s.to_f64();
588            let data: Vec<f64> = m.data.iter().map(|x| x.powf(scalar)).collect();
589            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
590        }
591        (Value::Num(s), Value::Tensor(m)) => {
592            let data: Vec<f64> = m.data.iter().map(|x| s.powf(*x)).collect();
593            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
594        }
595        (Value::Int(s), Value::Tensor(m)) => {
596            let scalar = s.to_f64();
597            let data: Vec<f64> = m.data.iter().map(|x| scalar.powf(*x)).collect();
598            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
599        }
600
601        // Matrix-matrix case
602        (Value::Tensor(m1), Value::Tensor(m2)) => {
603            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
604                return Err(format!(
605                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
606                    m1.rows(),
607                    m1.cols(),
608                    m2.rows(),
609                    m2.cols()
610                ));
611            }
612            let data: Vec<f64> = m1
613                .data
614                .iter()
615                .zip(m2.data.iter())
616                .map(|(x, y)| x.powf(*y))
617                .collect();
618            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
619        }
620
621        // Complex tensor element-wise power
622        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
623            if m1.rows != m2.rows || m1.cols != m2.cols {
624                return Err(format!(
625                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
626                    m1.rows, m1.cols, m2.rows, m2.cols
627                ));
628            }
629            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
630            for i in 0..m1.data.len() {
631                let (ar, ai) = m1.data[i];
632                let (br, bi) = m2.data[i];
633                out.push(complex_pow_scalar(ar, ai, br, bi));
634            }
635            Ok(Value::ComplexTensor(
636                runmat_builtins::ComplexTensor::new_2d(out, m1.rows, m1.cols)?,
637            ))
638        }
639        (Value::ComplexTensor(m), Value::Num(s)) => {
640            let out: Vec<(f64, f64)> = m
641                .data
642                .iter()
643                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *s, 0.0))
644                .collect();
645            Ok(Value::ComplexTensor(
646                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
647            ))
648        }
649        (Value::ComplexTensor(m), Value::Int(s)) => {
650            let sv = s.to_f64();
651            let out: Vec<(f64, f64)> = m
652                .data
653                .iter()
654                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, sv, 0.0))
655                .collect();
656            Ok(Value::ComplexTensor(
657                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
658            ))
659        }
660        (Value::ComplexTensor(m), Value::Complex(br, bi)) => {
661            let out: Vec<(f64, f64)> = m
662                .data
663                .iter()
664                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *br, *bi))
665                .collect();
666            Ok(Value::ComplexTensor(
667                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
668            ))
669        }
670        (Value::Num(s), Value::ComplexTensor(m)) => {
671            let out: Vec<(f64, f64)> = m
672                .data
673                .iter()
674                .map(|(br, bi)| complex_pow_scalar(*s, 0.0, *br, *bi))
675                .collect();
676            Ok(Value::ComplexTensor(
677                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
678            ))
679        }
680        (Value::Int(s), Value::ComplexTensor(m)) => {
681            let sv = s.to_f64();
682            let out: Vec<(f64, f64)> = m
683                .data
684                .iter()
685                .map(|(br, bi)| complex_pow_scalar(sv, 0.0, *br, *bi))
686                .collect();
687            Ok(Value::ComplexTensor(
688                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
689            ))
690        }
691        (Value::Complex(br, bi), Value::ComplexTensor(m)) => {
692            let out: Vec<(f64, f64)> = m
693                .data
694                .iter()
695                .map(|(er, ei)| complex_pow_scalar(*br, *bi, *er, *ei))
696                .collect();
697            Ok(Value::ComplexTensor(
698                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
699            ))
700        }
701
702        _ => Err(format!(
703            "Element-wise power not supported for types: {a:?} .^ {b:?}"
704        )),
705    }
706}
707
708// Element-wise operations are not directly exposed as runtime builtins because they need
709// to handle multiple types (Value enum variants). Instead, they are called directly from
710// the interpreter and JIT compiler using the elementwise_* functions above.
711
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    // Scalar .* scalar: doubles, and int operands promoted to double.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_mul_scalars() {
        assert_eq!(
            block_on(elementwise_mul(&Value::Num(3.0), &Value::Num(4.0))).unwrap(),
            Value::Num(12.0)
        );
        assert_eq!(
            block_on(elementwise_mul(
                &Value::Int(runmat_builtins::IntValue::I32(3)),
                &Value::Num(4.5)
            ))
            .unwrap(),
            Value::Num(13.5)
        );
    }

    // Matrix .* scalar broadcasting preserves shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_mul_matrix_scalar() {
        let matrix = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let result = block_on(elementwise_mul(&Value::Tensor(matrix), &Value::Num(2.0))).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![2.0, 4.0, 6.0, 8.0]);
            assert_eq!(m.rows(), 2);
            assert_eq!(m.cols(), 2);
        } else {
            panic!("Expected matrix result");
        }
    }

    // Matrix .* matrix of equal shape multiplies element-by-element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_mul_matrices() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let m2 = Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap();
        let result = block_on(elementwise_mul(&Value::Tensor(m1), &Value::Tensor(m2))).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![2.0, 6.0, 12.0, 20.0]);
        } else {
            panic!("Expected matrix result");
        }
    }

    // Positive scalar divided by zero yields +Inf (MATLAB semantics).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_div_with_zero() {
        let result = block_on(elementwise_div(&Value::Num(5.0), &Value::Num(0.0))).unwrap();
        if let Value::Num(n) = result {
            assert!(n.is_infinite() && n.is_sign_positive());
        } else {
            panic!("Expected numeric result");
        }
    }

    // Matrix .^ scalar squares every element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_pow() {
        let matrix = Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap();
        let result = elementwise_pow(&Value::Tensor(matrix), &Value::Num(2.0)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![4.0, 9.0, 16.0, 25.0]);
        } else {
            panic!("Expected matrix result");
        }
    }

    // Mismatched shapes must be rejected with an error, not panic.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_dimension_mismatch() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
        let m2 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();

        assert!(block_on(elementwise_mul(&Value::Tensor(m1), &Value::Tensor(m2))).is_err());
    }
}