runmat_runtime/
elementwise.rs

//! Element-wise operations for matrices and scalars
//!
//! This module implements language-compatible element-wise operations (.*, ./, .^),
//! along with element-wise addition, subtraction, negation, and the power operator (^).
//! These operations work element by element on matrices and support scalar broadcasting.

use crate::matrix::matrix_power;
use runmat_builtins::{Tensor, Value};

/// Complex scalar power via the polar form: `z^w = exp(w * ln z)` with
/// `z = r * e^(i*theta)`, so the result has magnitude `exp(Re(w)*ln r - Im(w)*theta)`
/// and phase `Re(w)*theta + Im(w)*ln r`.
fn complex_pow_scalar(base_re: f64, base_im: f64, exp_re: f64, exp_im: f64) -> (f64, f64) {
    if base_re == 0.0 && base_im == 0.0 && exp_re == 0.0 && exp_im == 0.0 {
        return (1.0, 0.0);
    }
    if base_re == 0.0 && base_im == 0.0 && exp_im == 0.0 && exp_re > 0.0 {
        return (0.0, 0.0);
    }
    let r = base_re.hypot(base_im);
    if r == 0.0 {
        return (0.0, 0.0);
    }
    let theta = base_im.atan2(base_re);
    let ln_r = r.ln();
    let a = exp_re * ln_r - exp_im * theta;
    let b = exp_re * theta + exp_im * ln_r;
    let mag = a.exp();
    (mag * b.cos(), mag * b.sin())
}

fn to_host_value(v: &Value) -> Result<Value, String> {
    match v {
        Value::GpuTensor(h) => {
            if let Some(p) = runmat_accelerate_api::provider() {
                let ht = p.download(h).map_err(|e| e.to_string())?;
                Ok(Value::Tensor(
                    Tensor::new(ht.data, ht.shape).map_err(|e| e.to_string())?,
                ))
            } else {
                // Fallback: zeros tensor with same shape
                let total: usize = h.shape.iter().product();
                Ok(Value::Tensor(
                    Tensor::new(vec![0.0; total], h.shape.clone()).map_err(|e| e.to_string())?,
                ))
            }
        }
        other => Ok(other.clone()),
    }
}

/// Element-wise negation: -A
/// Supports scalars and matrices
pub fn elementwise_neg(a: &Value) -> Result<Value, String> {
    match a {
        Value::Num(x) => Ok(Value::Num(-x)),
        Value::Complex(re, im) => Ok(Value::Complex(-*re, -*im)),
        Value::Int(x) => {
            // Negate in i64, saturating on overflow, so that negating i32::MIN
            // (or i64::MIN) cannot panic.
            let v = x.to_i64().checked_neg().unwrap_or(i64::MAX);
            if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
                Ok(Value::Int(runmat_builtins::IntValue::I32(v as i32)))
            } else {
                Ok(Value::Int(runmat_builtins::IntValue::I64(v)))
            }
        }
        Value::Bool(b) => Ok(Value::Bool(!b)), // Boolean negation
        Value::Tensor(m) => {
            let data: Vec<f64> = m.data.iter().map(|x| -x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        _ => Err(format!("Negation not supported for type: -{a:?}")),
    }
}

/// Element-wise multiplication: A .* B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_mul(a: &Value, b: &Value) -> Result<Value, String> {
    // If exactly one is GPU, gather to host and recurse
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a)?;
        let bh = to_host_value(b)?;
        return elementwise_mul(&ah, &bh);
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_mul(ha, hb) {
                let ht = p.download(&hc).map_err(|e| e.to_string())?;
                return Ok(Value::Tensor(
                    Tensor::new(ht.data, ht.shape).map_err(|e| e.to_string())?,
                ));
            }
        }
    }
    match (a, b) {
        // Complex scalars
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            Ok(Value::Complex(ar * br - ai * bi, ar * bi + ai * br))
        }
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar * s, ai * s)),
        (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s * br, s * bi)),
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x * y)),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() * y)),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x * y.to_f64())),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() * y.to_f64())),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x * s).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x * scalar).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s * x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar * x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x * y)
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise multiplication: {}x{} .* {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
            for i in 0..m1.data.len() {
                let (ar, ai) = m1.data[i];
                let (br, bi) = m2.data[i];
                out.push((ar * br - ai * bi, ar * bi + ai * br));
            }
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new(out, m1.shape.clone())
                    .map_err(|e| format!(".*: {e}"))?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re * s, im * s)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (s * re, s * im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise multiplication not supported for types: {a:?} .* {b:?}"
        )),
    }
}

/// Element-wise addition: A + B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_add(a: &Value, b: &Value) -> Result<Value, String> {
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a)?;
        let bh = to_host_value(b)?;
        return elementwise_add(&ah, &bh);
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_add(ha, hb) {
                let ht = p.download(&hc).map_err(|e| e.to_string())?;
                return Ok(Value::Tensor(
                    Tensor::new(ht.data, ht.shape).map_err(|e| e.to_string())?,
                ));
            }
        }
    }
    match (a, b) {
        // Complex scalars
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => Ok(Value::Complex(ar + br, ai + bi)),
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar + s, *ai)),
        (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s + br, *bi)),
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x + y)),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() + y)),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x + y.to_f64())),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() + y.to_f64())),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x + s).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x + scalar).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s + x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar + x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for addition: {}x{} + {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x + y)
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for addition: {}x{} + {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let data: Vec<(f64, f64)> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|((ar, ai), (br, bi))| (ar + br, ai + bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m1.rows, m1.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re + s, *im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (s + re, *im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!("Addition not supported for types: {a:?} + {b:?}")),
    }
}

/// Element-wise subtraction: A - B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_sub(a: &Value, b: &Value) -> Result<Value, String> {
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a)?;
        let bh = to_host_value(b)?;
        return elementwise_sub(&ah, &bh);
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_sub(ha, hb) {
                let ht = p.download(&hc).map_err(|e| e.to_string())?;
                return Ok(Value::Tensor(
                    Tensor::new(ht.data, ht.shape).map_err(|e| e.to_string())?,
                ));
            }
        }
    }
    match (a, b) {
        // Complex scalars
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => Ok(Value::Complex(ar - br, ai - bi)),
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar - s, *ai)),
        (Value::Num(s), Value::Complex(br, bi)) => Ok(Value::Complex(s - br, -*bi)),
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x - y)),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64() - y)),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x - y.to_f64())),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64() - y.to_f64())),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x - s).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x - scalar).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s - x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar - x).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for subtraction: {}x{} - {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x - y)
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for subtraction: {}x{} - {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let data: Vec<(f64, f64)> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|((ar, ai), (br, bi))| (ar - br, ai - bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m1.rows, m1.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re - s, *im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (s - re, -*im)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Subtraction not supported for types: {a:?} - {b:?}"
        )),
    }
}

/// Element-wise division: A ./ B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_div(a: &Value, b: &Value) -> Result<Value, String> {
    if matches!(a, Value::GpuTensor(_)) ^ matches!(b, Value::GpuTensor(_)) {
        let ah = to_host_value(a)?;
        let bh = to_host_value(b)?;
        return elementwise_div(&ah, &bh);
    }
    if let Some(p) = runmat_accelerate_api::provider() {
        if let (Value::GpuTensor(ha), Value::GpuTensor(hb)) = (a, b) {
            if let Ok(hc) = p.elem_div(ha, hb) {
                let ht = p.download(&hc).map_err(|e| e.to_string())?;
                return Ok(Value::Tensor(
                    Tensor::new(ht.data, ht.shape).map_err(|e| e.to_string())?,
                ));
            }
        }
    }
    match (a, b) {
        // Complex scalars
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            let denom = br * br + bi * bi;
            if denom == 0.0 {
                return Ok(Value::Num(f64::NAN));
            }
            Ok(Value::Complex(
                (ar * br + ai * bi) / denom,
                (ai * br - ar * bi) / denom,
            ))
        }
        (Value::Complex(ar, ai), Value::Num(s)) => Ok(Value::Complex(ar / s, ai / s)),
        (Value::Num(s), Value::Complex(br, bi)) => {
            let denom = br * br + bi * bi;
            if denom == 0.0 {
                return Ok(Value::Num(f64::NAN));
            }
            Ok(Value::Complex((s * br) / denom, (-s * bi) / denom))
        }
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => {
            if *y == 0.0 {
                Ok(Value::Num(f64::INFINITY * x.signum()))
            } else {
                Ok(Value::Num(x / y))
            }
        }
        (Value::Int(x), Value::Num(y)) => {
            if *y == 0.0 {
                Ok(Value::Num(f64::INFINITY * x.to_f64().signum()))
            } else {
                Ok(Value::Num(x.to_f64() / y))
            }
        }
        (Value::Num(x), Value::Int(y)) => {
            if y.is_zero() {
                Ok(Value::Num(f64::INFINITY * x.signum()))
            } else {
                Ok(Value::Num(x / y.to_f64()))
            }
        }
        (Value::Int(x), Value::Int(y)) => {
            if y.is_zero() {
                Ok(Value::Num(f64::INFINITY * x.to_f64().signum()))
            } else {
                Ok(Value::Num(x.to_f64() / y.to_f64()))
            }
        }

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            if *s == 0.0 {
                let data: Vec<f64> = m.data.iter().map(|x| f64::INFINITY * x.signum()).collect();
                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
            } else {
                let data: Vec<f64> = m.data.iter().map(|x| x / s).collect();
                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
            }
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            if scalar == 0.0 {
                let data: Vec<f64> = m.data.iter().map(|x| f64::INFINITY * x.signum()).collect();
                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
            } else {
                let data: Vec<f64> = m.data.iter().map(|x| x / scalar).collect();
                Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
            }
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m
                .data
                .iter()
                .map(|x| {
                    if *x == 0.0 {
                        f64::INFINITY * s.signum()
                    } else {
                        s / x
                    }
                })
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m
                .data
                .iter()
                .map(|x| {
                    if *x == 0.0 {
                        f64::INFINITY * scalar.signum()
                    } else {
                        scalar / x
                    }
                })
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| {
                    if *y == 0.0 {
                        f64::INFINITY * x.signum()
                    } else {
                        x / y
                    }
                })
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensors
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise division: {}x{} ./ {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let data: Vec<(f64, f64)> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|((ar, ai), (br, bi))| {
                    let denom = br * br + bi * bi;
                    if denom == 0.0 {
                        (f64::NAN, f64::NAN)
                    } else {
                        ((ar * br + ai * bi) / denom, (ai * br - ar * bi) / denom)
                    }
                })
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m1.rows, m1.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let data: Vec<(f64, f64)> = m.data.iter().map(|(re, im)| (re / s, im / s)).collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let data: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(br, bi)| {
                    let denom = br * br + bi * bi;
                    if denom == 0.0 {
                        (f64::NAN, f64::NAN)
                    } else {
                        ((s * br) / denom, (-s * bi) / denom)
                    }
                })
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(data, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise division not supported for types: {a:?} ./ {b:?}"
        )),
    }
}

/// Regular power operation: A ^ B
/// For matrices, this is matrix exponentiation (A^n where n is an integer)
/// For scalars, this is regular exponentiation
pub fn power(a: &Value, b: &Value) -> Result<Value, String> {
    match (a, b) {
        // Scalar cases - include complex
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Num(y)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Num(x), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Int(y)) => {
            let yv = y.to_f64();
            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Int(x), Value::Complex(br, bi)) => {
            let xv = x.to_f64();
            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }

        // Scalar cases - real only
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),

        // Matrix^scalar case - matrix exponentiation
        (Value::Tensor(m), Value::Num(s)) => {
            // Check that the scalar is an integer before taking a matrix power
            if s.fract() == 0.0 {
                let n = *s as i32;
                let result = matrix_power(m, n)?;
                Ok(Value::Tensor(result))
            } else {
                Err("Matrix power requires integer exponent".to_string())
            }
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let result = matrix_power(m, s.to_i64() as i32)?;
            Ok(Value::Tensor(result))
        }

        // Complex matrix^integer case
        (Value::ComplexTensor(m), Value::Num(s)) => {
            if s.fract() == 0.0 {
                let n = *s as i32;
                let result = crate::matrix::complex_matrix_power(m, n)?;
                Ok(Value::ComplexTensor(result))
            } else {
                Err("Matrix power requires integer exponent".to_string())
            }
        }
        (Value::ComplexTensor(m), Value::Int(s)) => {
            let result = crate::matrix::complex_matrix_power(m, s.to_i64() as i32)?;
            Ok(Value::ComplexTensor(result))
        }

        // Other cases not supported for regular matrix power
        _ => Err(format!(
            "Power operation not supported for types: {a:?} ^ {b:?}"
        )),
    }
}

/// Element-wise power: A .^ B
/// Supports matrix-matrix, matrix-scalar, and scalar-matrix operations
pub fn elementwise_pow(a: &Value, b: &Value) -> Result<Value, String> {
    match (a, b) {
        // Complex scalar cases
        (Value::Complex(ar, ai), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Num(y)) => {
            let (r, i) = complex_pow_scalar(*ar, *ai, *y, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Num(x), Value::Complex(br, bi)) => {
            let (r, i) = complex_pow_scalar(*x, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        (Value::Complex(ar, ai), Value::Int(y)) => {
            let yv = y.to_f64();
            let (r, i) = complex_pow_scalar(*ar, *ai, yv, 0.0);
            Ok(Value::Complex(r, i))
        }
        (Value::Int(x), Value::Complex(br, bi)) => {
            let xv = x.to_f64();
            let (r, i) = complex_pow_scalar(xv, 0.0, *br, *bi);
            Ok(Value::Complex(r, i))
        }
        // Scalar-scalar case
        (Value::Num(x), Value::Num(y)) => Ok(Value::Num(x.powf(*y))),
        (Value::Int(x), Value::Num(y)) => Ok(Value::Num(x.to_f64().powf(*y))),
        (Value::Num(x), Value::Int(y)) => Ok(Value::Num(x.powf(y.to_f64()))),
        (Value::Int(x), Value::Int(y)) => Ok(Value::Num(x.to_f64().powf(y.to_f64()))),

        // Matrix-scalar cases (broadcasting)
        (Value::Tensor(m), Value::Num(s)) => {
            let data: Vec<f64> = m.data.iter().map(|x| x.powf(*s)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Tensor(m), Value::Int(s)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| x.powf(scalar)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Num(s), Value::Tensor(m)) => {
            let data: Vec<f64> = m.data.iter().map(|x| s.powf(*x)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }
        (Value::Int(s), Value::Tensor(m)) => {
            let scalar = s.to_f64();
            let data: Vec<f64> = m.data.iter().map(|x| scalar.powf(*x)).collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m.rows(), m.cols())?))
        }

        // Matrix-matrix case
        (Value::Tensor(m1), Value::Tensor(m2)) => {
            if m1.rows() != m2.rows() || m1.cols() != m2.cols() {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
                    m1.rows(),
                    m1.cols(),
                    m2.rows(),
                    m2.cols()
                ));
            }
            let data: Vec<f64> = m1
                .data
                .iter()
                .zip(m2.data.iter())
                .map(|(x, y)| x.powf(*y))
                .collect();
            Ok(Value::Tensor(Tensor::new_2d(data, m1.rows(), m1.cols())?))
        }

        // Complex tensor element-wise power
        (Value::ComplexTensor(m1), Value::ComplexTensor(m2)) => {
            if m1.rows != m2.rows || m1.cols != m2.cols {
                return Err(format!(
                    "Matrix dimensions must agree for element-wise power: {}x{} .^ {}x{}",
                    m1.rows, m1.cols, m2.rows, m2.cols
                ));
            }
            let mut out: Vec<(f64, f64)> = Vec::with_capacity(m1.data.len());
            for i in 0..m1.data.len() {
                let (ar, ai) = m1.data[i];
                let (br, bi) = m2.data[i];
                out.push(complex_pow_scalar(ar, ai, br, bi));
            }
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m1.rows, m1.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Num(s)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *s, 0.0))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Int(s)) => {
            let sv = s.to_f64();
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, sv, 0.0))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::ComplexTensor(m), Value::Complex(br, bi)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(ar, ai)| complex_pow_scalar(*ar, *ai, *br, *bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::Num(s), Value::ComplexTensor(m)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(br, bi)| complex_pow_scalar(*s, 0.0, *br, *bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::Int(s), Value::ComplexTensor(m)) => {
            let sv = s.to_f64();
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(br, bi)| complex_pow_scalar(sv, 0.0, *br, *bi))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }
        (Value::Complex(br, bi), Value::ComplexTensor(m)) => {
            let out: Vec<(f64, f64)> = m
                .data
                .iter()
                .map(|(er, ei)| complex_pow_scalar(*br, *bi, *er, *ei))
                .collect();
            Ok(Value::ComplexTensor(
                runmat_builtins::ComplexTensor::new_2d(out, m.rows, m.cols)?,
            ))
        }

        _ => Err(format!(
            "Element-wise power not supported for types: {a:?} .^ {b:?}"
        )),
    }
}

// Element-wise operations are not directly exposed as runtime builtins because they need
// to handle multiple types (Value enum variants). Instead, they are called directly from
// the interpreter and JIT compiler using the elementwise_* functions above.

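// As an illustration only (the real dispatch lives in the interpreter and JIT
// crates, and the `BinOp` enum below is hypothetical), a caller might route
// binary operators to these functions like so:
//
//     enum BinOp { Add, Sub, ElemMul, ElemDiv, ElemPow, Pow }
//
//     fn apply_binop(op: BinOp, a: &Value, b: &Value) -> Result<Value, String> {
//         match op {
//             BinOp::Add => elementwise_add(a, b),
//             BinOp::Sub => elementwise_sub(a, b),
//             BinOp::ElemMul => elementwise_mul(a, b),
//             BinOp::ElemDiv => elementwise_div(a, b),
//             BinOp::ElemPow => elementwise_pow(a, b),
//             BinOp::Pow => power(a, b),
//         }
//     }
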
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_elementwise_mul_scalars() {
        assert_eq!(
            elementwise_mul(&Value::Num(3.0), &Value::Num(4.0)).unwrap(),
            Value::Num(12.0)
        );
        assert_eq!(
            elementwise_mul(
                &Value::Int(runmat_builtins::IntValue::I32(3)),
                &Value::Num(4.5)
            )
            .unwrap(),
            Value::Num(13.5)
        );
    }
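
    // Extra coverage for the complex scalar branch of elementwise_mul; this uses
    // only the Value::Complex(re, im) representation already present in this module.
    #[test]
    fn test_elementwise_mul_complex_scalars() {
        // (1 + 2i) * (3 + 4i) = -5 + 10i
        assert_eq!(
            elementwise_mul(&Value::Complex(1.0, 2.0), &Value::Complex(3.0, 4.0)).unwrap(),
            Value::Complex(-5.0, 10.0)
        );
    }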

    #[test]
    fn test_elementwise_mul_matrix_scalar() {
        let matrix = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let result = elementwise_mul(&Value::Tensor(matrix), &Value::Num(2.0)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![2.0, 4.0, 6.0, 8.0]);
            assert_eq!(m.rows(), 2);
            assert_eq!(m.cols(), 2);
        } else {
            panic!("Expected matrix result");
        }
    }

    #[test]
    fn test_elementwise_mul_matrices() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let m2 = Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap();
        let result = elementwise_mul(&Value::Tensor(m1), &Value::Tensor(m2)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![2.0, 6.0, 12.0, 20.0]);
        } else {
            panic!("Expected matrix result");
        }
    }
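
    // A small extra check for elementwise_neg, exercising both the scalar and
    // the tensor branch defined earlier in this module.
    #[test]
    fn test_elementwise_neg_scalar_and_matrix() {
        assert_eq!(elementwise_neg(&Value::Num(3.5)).unwrap(), Value::Num(-3.5));

        let matrix = Tensor::new_2d(vec![1.0, -2.0, 3.0, -4.0], 2, 2).unwrap();
        if let Value::Tensor(m) = elementwise_neg(&Value::Tensor(matrix)).unwrap() {
            assert_eq!(m.data, vec![-1.0, 2.0, -3.0, 4.0]);
        } else {
            panic!("Expected matrix result");
        }
    }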

    #[test]
    fn test_elementwise_div_with_zero() {
        let result = elementwise_div(&Value::Num(5.0), &Value::Num(0.0)).unwrap();
        if let Value::Num(n) = result {
            assert!(n.is_infinite() && n.is_sign_positive());
        } else {
            panic!("Expected numeric result");
        }
    }
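
    // Extra check for the complex scalar division branch: (1 + i) / (1 - i) = i,
    // since multiplying through by the conjugate gives 2i / 2. Every intermediate
    // value here is exact in f64, so exact equality is safe.
    #[test]
    fn test_elementwise_div_complex_scalars() {
        assert_eq!(
            elementwise_div(&Value::Complex(1.0, 1.0), &Value::Complex(1.0, -1.0)).unwrap(),
            Value::Complex(0.0, 1.0)
        );
    }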

    #[test]
    fn test_elementwise_pow() {
        let matrix = Tensor::new_2d(vec![2.0, 3.0, 4.0, 5.0], 2, 2).unwrap();
        let result = elementwise_pow(&Value::Tensor(matrix), &Value::Num(2.0)).unwrap();

        if let Value::Tensor(m) = result {
            assert_eq!(m.data, vec![4.0, 9.0, 16.0, 25.0]);
        } else {
            panic!("Expected matrix result");
        }
    }
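
    // Extra check for complex_pow_scalar via the element-wise path: i^2 = -1.
    // The polar-form exp/ln round trip is inexact, so compare with a tolerance
    // instead of exact equality.
    #[test]
    fn test_elementwise_pow_complex() {
        if let Value::Complex(re, im) =
            elementwise_pow(&Value::Complex(0.0, 1.0), &Value::Num(2.0)).unwrap()
        {
            assert!((re + 1.0).abs() < 1e-12);
            assert!(im.abs() < 1e-12);
        } else {
            panic!("Expected complex result");
        }
    }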

    #[test]
    fn test_dimension_mismatch() {
        let m1 = Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
        let m2 = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();

        assert!(elementwise_mul(&Value::Tensor(m1), &Value::Tensor(m2)).is_err());
    }
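
    // Check for the complex-tensor branch of elementwise_sub: subtraction must be
    // component-wise. Assumes ComplexTensor::new_2d takes (re, im) pairs, as it
    // does elsewhere in this file.
    #[test]
    fn test_elementwise_sub_complex_tensors() {
        let m1 =
            runmat_builtins::ComplexTensor::new_2d(vec![(1.0, 2.0), (3.0, 4.0)], 1, 2).unwrap();
        let m2 =
            runmat_builtins::ComplexTensor::new_2d(vec![(0.5, 1.0), (1.0, 1.5)], 1, 2).unwrap();
        if let Value::ComplexTensor(m) =
            elementwise_sub(&Value::ComplexTensor(m1), &Value::ComplexTensor(m2)).unwrap()
        {
            assert_eq!(m.data, vec![(0.5, 1.0), (2.0, 2.5)]);
        } else {
            panic!("Expected complex tensor result");
        }
    }

    // Sketch of a matrix power check, assuming matrix_power implements repeated
    // matrix multiplication as its name and the doc comment on `power` suggest.
    // For this particular matrix the expected storage is the same whether Tensor
    // data is row-major or column-major, since transposing commutes with squaring.
    #[test]
    fn test_matrix_power_square() {
        let matrix = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        if let Value::Tensor(m) = power(&Value::Tensor(matrix), &Value::Num(2.0)).unwrap() {
            assert_eq!(m.data, vec![7.0, 10.0, 15.0, 22.0]);
        } else {
            panic!("Expected matrix result");
        }
    }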
}