Skip to main content

runmat_runtime/builtins/common/
elementwise.rs

//! Element-wise operations for matrices and scalars.
//!
//! This module implements the language's element-wise operators (`.*`, `./`, `.^`), together
//! with unary negation (`-A`) and the matrix power operator (`^`). The element-wise operators
//! work element by element on matrices and broadcast scalar operands.
5
6use crate::builtins::common::matrix::matrix_power;
7use runmat_builtins::{Tensor, Value};
8
/// Raises `base_re + base_im·i` to the power `exp_re + exp_im·i` (principal branch).
///
/// For a non-zero base this evaluates `z^w = exp(w · ln z)` with
/// `ln z = ln|z| + i·arg(z)`, where `arg(z)` comes from `atan2`.
///
/// A zero base is handled explicitly because `ln 0` is undefined:
/// * `0^0` is `1`;
/// * `0^w` with `Re(w) > 0` is `0`;
/// * `0^w` with a real, negative `w` is `+Inf` (matching `f64::powf(0.0, w)`);
/// * every other zero-base case (imaginary or NaN exponent, ...) is `NaN + NaN·i`.
///
/// NaN components in the base propagate to a NaN result.
///
/// Returns the result as `(real, imaginary)`.
fn complex_pow_scalar(base_re: f64, base_im: f64, exp_re: f64, exp_im: f64) -> (f64, f64) {
    if base_re == 0.0 && base_im == 0.0 {
        return if exp_re == 0.0 && exp_im == 0.0 {
            (1.0, 0.0)
        } else if exp_re > 0.0 {
            (0.0, 0.0)
        } else if exp_im == 0.0 && exp_re < 0.0 {
            (f64::INFINITY, 0.0)
        } else {
            (f64::NAN, f64::NAN)
        };
    }
    // No clamping here: `hypot` is already non-negative, and clamping with `max` would turn a
    // NaN magnitude into 0 and hide NaN inputs.
    let r = base_re.hypot(base_im);
    let theta = base_im.atan2(base_re);
    let ln_r = r.ln();
    let a = exp_re * ln_r - exp_im * theta;
    let b = exp_re * theta + exp_im * ln_r;
    let mag = a.exp();
    (mag * b.cos(), mag * b.sin())
}
27
28fn scalar_real_value(value: &Value) -> Option<f64> {
29    match value {
30        Value::Num(n) => Some(*n),
31        Value::Int(i) => Some(i.to_f64()),
32        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
33        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
34        _ => None,
35    }
36}
37
38fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
39    match value {
40        Value::Complex(re, im) => Some((*re, *im)),
41        Value::ComplexTensor(t) if t.data.len() == 1 => t.data.first().copied(),
42        _ => None,
43    }
44}
45
46fn scalar_power_value(base: &Value, exponent: &Value) -> Option<Value> {
47    let base_is_complex = matches!(base, Value::Complex(_, _) | Value::ComplexTensor(_));
48    let exp_is_complex = matches!(exponent, Value::Complex(_, _) | Value::ComplexTensor(_));
49    let base_val =
50        scalar_complex_value(base).or_else(|| scalar_real_value(base).map(|v| (v, 0.0)))?;
51    let exp_val =
52        scalar_complex_value(exponent).or_else(|| scalar_real_value(exponent).map(|v| (v, 0.0)))?;
53    let (br, bi) = base_val;
54    let (er, ei) = exp_val;
55    if base_is_complex || exp_is_complex || bi != 0.0 || ei != 0.0 {
56        let (re, im) = complex_pow_scalar(br, bi, er, ei);
57        return Some(Value::Complex(re, im));
58    }
59    let pow = br.powf(er);
60    if pow.is_nan() {
61        let (re, im) = complex_pow_scalar(br, 0.0, er, 0.0);
62        Some(Value::Complex(re, im))
63    } else {
64        Some(Value::Num(pow))
65    }
66}
67
68async fn to_host_value(v: &Value) -> Result<Value, String> {
69    match v {
70        Value::GpuTensor(h) => {
71            if runmat_accelerate_api::provider_for_handle(h).is_some() {
72                let gathered = crate::dispatcher::gather_if_needed_async(v)
73                    .await
74                    .map_err(|e| e.to_string())?;
75                Ok(gathered)
76            } else {
77                // Fallback: zeros tensor with same shape
78                let total: usize = h.shape.iter().product();
79                Ok(Value::Tensor(
80                    Tensor::new(vec![0.0; total], h.shape.clone()).map_err(|e| e.to_string())?,
81                ))
82            }
83        }
84        other => Ok(other.clone()),
85    }
86}
87
88/// Element-wise negation: -A
89/// Supports scalars and matrices
90pub fn elementwise_neg(a: &Value) -> Result<Value, String> {
91    match a {
92        Value::Num(x) => Ok(Value::Num(-x)),
93        Value::Complex(re, im) => Ok(Value::Complex(-*re, -*im)),
94        Value::Int(x) => {
95            let v = x.to_i64();
96            if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
97                Ok(Value::Int(runmat_builtins::IntValue::I32(-(v as i32))))
98            } else {
99                Ok(Value::Int(runmat_builtins::IntValue::I64(-v)))
100            }
101        }
102        Value::Bool(b) => Ok(Value::Bool(!b)), // Boolean negation
103        Value::Tensor(m) => {
104            let data: Vec<f64> = m.data.iter().map(|x| -x).collect();
105            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
106        }
107        _ => Err(format!("Negation not supported for type: -{a:?}")),
108    }
109}
110
111/// Element-wise multiplication: A .* B
112/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
113#[async_recursion::async_recursion(?Send)]
114pub async fn elementwise_mul(a: &Value, b: &Value) -> Result<Value, String> {
115    // GPU+scalar: keep on device if provider supports scalar mul
116    if let Some(p) = runmat_accelerate_api::provider() {
117        match (a, b) {
118            (Value::GpuTensor(ga), Value::Num(s)) => {
119                if let Ok(hc) = p.scalar_mul(ga, *s) {
120                    return Ok(Value::GpuTensor(hc));
121                }
122            }
123            (Value::Num(s), Value::GpuTensor(gb)) => {
124                if let Ok(hc) = p.scalar_mul(gb, *s) {
125                    return Ok(Value::GpuTensor(hc));
126                }
127            }
128            (Value::GpuTensor(ga), Value::Int(i)) => {
129                if let Ok(hc) = p.scalar_mul(ga, i.to_f64()) {
130                    return Ok(Value::GpuTensor(hc));
131                }
132            }
133            (Value::Int(i), Value::GpuTensor(gb)) => {
134                if let Ok(hc) = p.scalar_mul(gb, i.to_f64()) {
135                    return Ok(Value::GpuTensor(hc));
136                }
137            }
138            _ => {}
139        }
140    }
141    // If exactly one is GPU and no scalar fast-path, gather to host and recurse
142    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
143        let ah = to_host_value(a).await?;
144        let bh = to_host_value(b).await?;
145        return elementwise_mul(&ah, &bh).await;
146    }
147    if let Some(p) = runmat_accelerate_api::provider() {
148        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
149            if let Ok(hc) = p.elem_mul(ha, hb).await {
150                return Ok(Value::GpuTensor(hc));
151            }
152        }
153    }
154    match (a, b) {
155        // Complex scalars
156        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
157            Ok(Value::Complex(ar * br - ai * bi, ar * bi + ai * br))
158        }
159        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar * s, ai * s)),
160        (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s * br, s * bi)),
161        // Scalar-scalar case
162        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x * y)),
163        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() * y)),
164        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x * y.to_f64())),
165        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() * y.to_f64())),
166
167        // Matrix-scalar cases (broadcasting)
168        (Value::Tensor(m), Value::Num(s)) => {
169            let data: Vec<f64> = m.data.iter().map(|x| x * s).collect();
170            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
171        }
172        (Value::Tensor(m), Value::Int(s)) => {
173            let scalar = s.to_f64();
174            let data: Vec<f64> = m.data.iter().map(|x| x * scalar).collect();
175            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
176        }
177        (Value::Num(s), Value::Tensor(m)) => {
178            let data: Vec<f64> = m.data.iter().map(|x| s * x).collect();
179            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
180        }
181        (Value::Int(s), Value::Tensor(m)) => {
182            let scalar = s.to_f64();
183            let data: Vec<f64> = m.data.iter().map(|x| scalar * x).collect();
184            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
185        }
186
187        // Matrix-matrix case
188        (Value::Tensor(m1), Value::Tensor(m2)) => {
189            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
190                return Err(format!(
191                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
192                    m1.rows(),
193                    m1.cols(),
194                    m2.rows(),
195                    m2.cols()
196                ));
197            }
198            let data: Vec<f64> = m1
199                .data
200                .iter()
201                .zip(m2.data.iter())
202                .map(|(x, y)| x * y)
203                .collect();
204            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
205        }
206
207        // Complex tensors
208        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
209            if m1.rows != m2.rows || m1.cols != m2.cols {
210                return Err(format!(
211                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
212                    m1.rows, m1.cols, m2.rows, m2.cols
213                ));
214            }
215            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
216            for i in 0..m1.data.len() {
217                let (ar, ai) = m1.data[i];
218                let (br, bi) = m2.data[i];
219                out.push((ar * br - ai * bi, ar * bi + ai * br));
220            }
221            Ok(Value::ComplexTensor(
222                runmat_builtins::ComplexTensor::new(out, m1.shape.clone())
223                    .map_err(|e| format!(".*: {e}"))?,
224            ))
225        }
226        (Value::ComplexTensor(m), Value::Num(s)) => {
227            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re * s, im * s)).collect();
228            Ok(Value::ComplexTensor(
229                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
230            ))
231        }
232        (Value::Num(s), Value::ComplexTensor(m)) => {
233            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (s * re, s * im)).collect();
234            Ok(Value::ComplexTensor(
235                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
236            ))
237        }
238
239        _ => Err(format!(
240            "Element-wise multiplication not supported for types: {a:?} .* {b:?}"
241        )),
242    }
243}
244
245// elementwise_add has been retired in favor of the `plus` builtin
246
247// elementwise_sub has been retired in favor of the `minus` builtin
248
249/// Element-wise division: A ./ B
250/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
251#[async_recursion::async_recursion(?Send)]
252pub async fn elementwise_div(a: &Value, b: &Value) -> Result<Value, String> {
253    // GPU+scalar: use scalar div when form is G ./ s or left-scalar s ./ G
254    if let Some(p) = runmat_accelerate_api::provider() {
255        match (a, b) {
256            (Value::GpuTensor(ga), Value::Num(s)) => {
257                if let Ok(hc) = p.scalar_div(ga, *s) {
258                    return Ok(Value::GpuTensor(hc));
259                }
260            }
261            (Value::GpuTensor(ga), Value::Int(i)) => {
262                if let Ok(hc) = p.scalar_div(ga, i.to_f64()) {
263                    return Ok(Value::GpuTensor(hc));
264                }
265            }
266            (Value::Num(s), Value::GpuTensor(gb)) => {
267                if let Ok(hc) = p.scalar_rdiv(gb, *s) {
268                    return Ok(Value::GpuTensor(hc));
269                }
270            }
271            (Value::Int(i), Value::GpuTensor(gb)) => {
272                if let Ok(hc) = p.scalar_rdiv(gb, i.to_f64()) {
273                    return Ok(Value::GpuTensor(hc));
274                }
275            }
276            _ => {}
277        }
278    }
279    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
280        let ah = to_host_value(a).await?;
281        let bh = to_host_value(b).await?;
282        return elementwise_div(&ah, &bh).await;
283    }
284    if let Some(p) = runmat_accelerate_api::provider() {
285        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
286            if let Ok(hc) = p.elem_div(ha, hb).await {
287                return Ok(Value::GpuTensor(hc));
288            }
289        }
290    }
291    match (a, b) {
292        // Complex scalars
293        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
294            let denom = br * br + bi * bi;
295            if denom == 0.0 {
296                return Ok(Value::Num(f64::NAN));
297            }
298            Ok(Value::Complex(
299                (ar * br + ai * bi) / denom,
300                (ai * br - ar * bi) / denom,
301            ))
302        }
303        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar / s, ai / s)),
304        (Value::Num(s), Value::Complex(br, bi)) => {
305            let denom = br * br + bi * bi;
306            if denom == 0.0 {
307                return Ok(Value::Num(f64::NAN));
308            }
309            Ok(Value::Complex((s * br) / denom, (-s * bi) / denom))
310        }
311        // Scalar-scalar case
312        (Value::Num(x), Value::Num(y)) => {
313            if *y == 0.0 {
314                Ok(Value::Num(f64::INFINITY * x.signum()))
315            } else {
316                Ok(Value::Num(x / y))
317            }
318        }
319        (Value::Int(x), Value::Num(y)) => {
320            if *y == 0.0 {
321                Ok(Value::Num(f64::INFINITY * x.to_f64().signum()))
322            } else {
323                Ok(Value::Num(x.to_f64() / y))
324            }
325        }
326        (Value::Num(x), Value::Int(y)) => {
327            if y.is_zero() {
328                Ok(Value::Num(f64::INFINITY * x.signum()))
329            } else {
330                Ok(Value::Num(x / y.to_f64()))
331            }
332        }
333        (Value::Int(x), Value::Int(y)) => {
334            if y.is_zero() {
335                Ok(Value::Num(f64::INFINITY * x.to_f64().signum()))
336            } else {
337                Ok(Value::Num(x.to_f64() / y.to_f64()))
338            }
339        }
340
341        // Matrix-scalar cases (broadcasting)
342        (Value::Tensor(m), Value::Num(s)) => {
343            if *s == 0.0 {
344                let data: Vec<f64> = m.data.iter().map(|x| f64::INFINITY * x.signum()).collect();
345                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
346            } else {
347                let data: Vec<f64> = m.data.iter().map(|x| x / s).collect();
348                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
349            }
350        }
351        (Value::Tensor(m), Value::Int(s)) => {
352            let scalar = s.to_f64();
353            if scalar == 0.0 {
354                let data: Vec<f64> = m.data.iter().map(|x| f64::INFINITY * x.signum()).collect();
355                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
356            } else {
357                let data: Vec<f64> = m.data.iter().map(|x| x / scalar).collect();
358                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
359            }
360        }
361        (Value::Num(s), Value::Tensor(m)) => {
362            let data: Vec<f64> = m
363                .data
364                .iter()
365                .map(|x| {
366                    if *x == 0.0 {
367                        f64::INFINITY * s.signum()
368                    } else {
369                        s / x
370                    }
371                })
372                .collect();
373            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
374        }
375        (Value::Int(s), Value::Tensor(m)) => {
376            let scalar = s.to_f64();
377            let data: Vec<f64> = m
378                .data
379                .iter()
380                .map(|x| {
381                    if *x == 0.0 {
382                        f64::INFINITY * scalar.signum()
383                    } else {
384                        scalar / x
385                    }
386                })
387                .collect();
388            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
389        }
390
391        // Matrix-matrix case
392        (Value::Tensor(m1), Value::Tensor(m2)) => {
393            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
394                return Err(format!(
395                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
396                    m1.rows(),
397                    m1.cols(),
398                    m2.rows(),
399                    m2.cols()
400                ));
401            }
402            let data: Vec<f64> = m1
403                .data
404                .iter()
405                .zip(m2.data.iter())
406                .map(|(x, y)| {
407                    if *y == 0.0 {
408                        f64::INFINITY * x.signum()
409                    } else {
410                        x / y
411                    }
412                })
413                .collect();
414            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
415        }
416
417        // Complex tensors
418        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
419            if m1.rows != m2.rows || m1.cols != m2.cols {
420                return Err(format!(
421                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
422                    m1.rows, m1.cols, m2.rows, m2.cols
423                ));
424            }
425            let data: Vec<(f64, f64)> = m1
426                .data
427                .iter()
428                .zip(m2.data.iter())
429                .map(|((ar, ai), (br, bi))| {
430                    let denom = br * br + bi * bi;
431                    if denom == 0.0 {
432                        (f64::NAN, f64::NAN)
433                    } else {
434                        ((ar * br + ai * bi) / denom, (ai * br - ar * bi) / denom)
435                    }
436                })
437                .collect();
438            Ok(Value::ComplexTensor(
439                runmat_builtins::ComplexTensor::new_2d(data, m1.rows, m1.cols)?,
440            ))
441        }
442        (Value::ComplexTensor(m), Value::Num(s)) => {
443            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re / s, im / s)).collect();
444            Ok(Value::ComplexTensor(
445                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
446            ))
447        }
448        (Value::Num(s), Value::ComplexTensor(m)) => {
449            let data: Vec<(f64, f64)> = m
450                .data
451                .iter()
452                .map(|(br, bi)| {
453                    let denom = br * br + bi * bi;
454                    if denom == 0.0 {
455                        (f64::NAN, f64::NAN)
456                    } else {
457                        ((s * br) / denom, (-s * bi) / denom)
458                    }
459                })
460                .collect();
461            Ok(Value::ComplexTensor(
462                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
463            ))
464        }
465
466        _ => Err(format!(
467            "Element-wise division not supported for types: {a:?} ./ {b:?}"
468        )),
469    }
470}
471
472/// Regular power operation: A ^ B  
473/// For matrices, this is matrix exponentiation (A^n where n is integer)
474/// For scalars, this is regular exponentiation
475pub fn power(a: &Value, b: &Value) -> Result<Value, String> {
476    if let Some(result) = scalar_power_value(a, b) {
477        return Ok(result);
478    }
479    match (a, b) {
480        // Scalar cases - include complex
481        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
482            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
483            Ok(Value::Complex(r, i))
484        }
485        (Value::Complex(ar, ai), Value::Num(y)) => {
486            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
487            Ok(Value::Complex(r, i))
488        }
489        (Value::Num(x), Value::Complex(br, bi)) => {
490            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
491            Ok(Value::Complex(r, i))
492        }
493        (Value::Complex(ar, ai), Value::Int(y)) => {
494            let yv = y.to_f64();
495            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
496            Ok(Value::Complex(r, i))
497        }
498        (Value::Int(x), Value::Complex(br, bi)) => {
499            let xv = x.to_f64();
500            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
501            Ok(Value::Complex(r, i))
502        }
503
504        // Scalar cases - real only
505        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
506        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
507        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
508        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),
509
510        // Matrix^scalar case - matrix exponentiation
511        (Value::Tensor(m), Value::Num(s)) => {
512            // Check if scalar is an integer for matrix power
513            if s.fract() == 0.0 {
514                let n = *s as i32;
515                let result = matrix_power(m, n)?;
516                Ok(Value::Tensor(result))
517            } else {
518                Err("Matrix power requires integer exponent".to_string())
519            }
520        }
521        (Value::Tensor(m), Value::Int(s)) => {
522            let result = matrix_power(m, s.to_i64() as i32)?;
523            Ok(Value::Tensor(result))
524        }
525
526        // Complex matrix^integer case
527        (Value::ComplexTensor(m), Value::Num(s)) => {
528            if s.fract() == 0.0 {
529                let n = *s as i32;
530                let result = crate::builtins::common::matrix::complex_matrix_power(m, n)?;
531                Ok(Value::ComplexTensor(result))
532            } else {
533                Err("Matrix power requires integer exponent".to_string())
534            }
535        }
536        (Value::ComplexTensor(m), Value::Int(s)) => {
537            let result =
538                crate::builtins::common::matrix::complex_matrix_power(m, s.to_i64() as i32)?;
539            Ok(Value::ComplexTensor(result))
540        }
541
542        // Other cases not supported for regular matrix power
543        _ => Err(format!(
544            "Power operation not supported for types: {a:?} ^ {b:?}"
545        )),
546    }
547}
548
549/// Element-wise power: A .^ B
550/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
551pub fn elementwise_pow(a: &Value, b: &Value) -> Result<Value, String> {
552    match (a, b) {
553        // Complex scalar cases
554        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
555            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
556            Ok(Value::Complex(r, i))
557        }
558        (Value::Complex(ar, ai), Value::Num(y)) => {
559            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
560            Ok(Value::Complex(r, i))
561        }
562        (Value::Num(x), Value::Complex(br, bi)) => {
563            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
564            Ok(Value::Complex(r, i))
565        }
566        (Value::Complex(ar, ai), Value::Int(y)) => {
567            let yv = y.to_f64();
568            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
569            Ok(Value::Complex(r, i))
570        }
571        (Value::Int(x), Value::Complex(br, bi)) => {
572            let xv = x.to_f64();
573            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
574            Ok(Value::Complex(r, i))
575        }
576        // Scalar-scalar case
577        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
578        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
579        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
580        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),
581
582        // Matrix-scalar cases (broadcasting)
583        (Value::Tensor(m), Value::Num(s)) => {
584            let data: Vec<f64> = m.data.iter().map(|x| x.powf(*s)).collect();
585            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
586        }
587        (Value::Tensor(m), Value::Int(s)) => {
588            let scalar = s.to_f64();
589            let data: Vec<f64> = m.data.iter().map(|x| x.powf(scalar)).collect();
590            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
591        }
592        (Value::Num(s), Value::Tensor(m)) => {
593            let data: Vec<f64> = m.data.iter().map(|x| s.powf(*x)).collect();
594            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
595        }
596        (Value::Int(s), Value::Tensor(m)) => {
597            let scalar = s.to_f64();
598            let data: Vec<f64> = m.data.iter().map(|x| scalar.powf(*x)).collect();
599            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
600        }
601
602        // Matrix-matrix case
603        (Value::Tensor(m1), Value::Tensor(m2)) => {
604            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
605                return Err(format!(
606                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
607                    m1.rows(),
608                    m1.cols(),
609                    m2.rows(),
610                    m2.cols()
611                ));
612            }
613            let data: Vec<f64> = m1
614                .data
615                .iter()
616                .zip(m2.data.iter())
617                .map(|(x, y)| x.powf(*y))
618                .collect();
619            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
620        }
621
622        // Complex tensor element-wise power
623        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
624            if m1.rows != m2.rows || m1.cols != m2.cols {
625                return Err(format!(
626                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
627                    m1.rows, m1.cols, m2.rows, m2.cols
628                ));
629            }
630            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
631            for i in 0..m1.data.len() {
632                let (ar, ai) = m1.data[i];
633                let (br, bi) = m2.data[i];
634                out.push(complex_pow_scalar(ar, ai, br, bi));
635            }
636            Ok(Value::ComplexTensor(
637                runmat_builtins::ComplexTensor::new_2d(out, m1.rows, m1.cols)?,
638            ))
639        }
640        (Value::ComplexTensor(m), Value::Num(s)) => {
641            let out: Vec<(f64, f64)> = m
642                .data
643                .iter()
644                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *s, 0.0))
645                .collect();
646            Ok(Value::ComplexTensor(
647                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
648            ))
649        }
650        (Value::ComplexTensor(m), Value::Int(s)) => {
651            let sv = s.to_f64();
652            let out: Vec<(f64, f64)> = m
653                .data
654                .iter()
655                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, sv, 0.0))
656                .collect();
657            Ok(Value::ComplexTensor(
658                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
659            ))
660        }
661        (Value::ComplexTensor(m), Value::Complex(br, bi)) => {
662            let out: Vec<(f64, f64)> = m
663                .data
664                .iter()
665                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *br, *bi))
666                .collect();
667            Ok(Value::ComplexTensor(
668                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
669            ))
670        }
671        (Value::Num(s), Value::ComplexTensor(m)) => {
672            let out: Vec<(f64, f64)> = m
673                .data
674                .iter()
675                .map(|(br, bi)| complex_pow_scalar(*s, 0.0, *br, *bi))
676                .collect();
677            Ok(Value::ComplexTensor(
678                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
679            ))
680        }
681        (Value::Int(s), Value::ComplexTensor(m)) => {
682            let sv = s.to_f64();
683            let out: Vec<(f64, f64)> = m
684                .data
685                .iter()
686                .map(|(br, bi)| complex_pow_scalar(sv, 0.0, *br, *bi))
687                .collect();
688            Ok(Value::ComplexTensor(
689                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
690            ))
691        }
692        (Value::Complex(br, bi), Value::ComplexTensor(m)) => {
693            let out: Vec<(f64, f64)> = m
694                .data
695                .iter()
696                .map(|(er, ei)| complex_pow_scalar(*br, *bi, *er, *ei))
697                .collect();
698            Ok(Value::ComplexTensor(
699                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
700            ))
701        }
702
703        _ => Err(format!(
704            "Element-wise power not supported for types: {a:?} .^ {b:?}"
705        )),
706    }
707}
708
709// Element-wise operations are not directly exposed as runtime builtins because they need
710// to handle multiple types (Value enum variants). Instead, they are called directly from
711// the interpreter and JIT compiler using the elementwise_* functions above.
712
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    /// Extracts the tensor from a `Value::Tensor`; any other variant fails the test.
    fn tensor_or_panic(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            _ => panic!("Expected matrix result"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_mul_scalars() {
        let product = block_on(elementwise_mul(&Value::Num(3.0), &Value::Num(4.0))).unwrap();
        assert_eq!(product, Value::Num(12.0));

        let three = Value::Int(runmat_builtins::IntValue::I32(3));
        let mixed = block_on(elementwise_mul(&three, &Value::Num(4.5))).unwrap();
        assert_eq!(mixed, Value::Num(13.5));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_mul_matrix_scalar() {
        let input = Value::Tensor(Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap());
        let scaled = tensor_or_panic(block_on(elementwise_mul(&input, &Value::Num(2.0))).unwrap());

        assert_eq!(scaled.data, vec![2.0, 4.0, 6.0, 8.0]);
        assert_eq!(scaled.rows(), 2);
        assert_eq!(scaled.cols(), 2);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_mul_matrices() {
        let lhs = Value::Tensor(Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap());
        let rhs = Value::Tensor(Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap());
        let product = tensor_or_panic(block_on(elementwise_mul(&lhs, &rhs)).unwrap());

        assert_eq!(product.data, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_div_with_zero() {
        let quotient = block_on(elementwise_div(&Value::Num(5.0), &Value::Num(0.0))).unwrap();
        match quotient {
            Value::Num(n) => assert!(n.is_infinite() && n.is_sign_positive()),
            _ => panic!("Expected numeric result"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_elementwise_pow() {
        let base = Value::Tensor(Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap());
        let squared = tensor_or_panic(elementwise_pow(&base, &Value::Num(2.0)).unwrap());

        assert_eq!(squared.data, vec![4.0, 9.0, 16.0, 25.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_dimension_mismatch() {
        let row = Value::Tensor(Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap());
        let square = Value::Tensor(Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap());

        let outcome = block_on(elementwise_mul(&row, &square));
        assert!(outcome.is_err());
    }
}