runmat_runtime/builtins/math/rounding/mod.rs

//! MATLAB-compatible `mod` builtin plus rounding helpers for RunMat.

pub(crate) mod ceil;
pub(crate) mod fix;
pub(crate) mod floor;
pub(crate) mod rem;
pub(crate) mod round;

use runmat_accelerate_api::GpuTensorHandle;
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::broadcast::BroadcastPlan;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
    FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
    ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
use crate::builtins::math::type_resolvers::numeric_binary_type;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};

#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::rounding")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "mod",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
        ProviderHook::Unary { name: "unary_floor" },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers can keep mod on-device by composing elem_div → unary_floor → elem_mul → elem_sub for matching shapes. Future backends may expose a dedicated elem_mod hook.",
};
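
// A worked sketch of the on-device decomposition described in the notes above,
// using hypothetical values (provider pseudocode, not a real API trace):
//   q = elem_div(a, b)    // -7 / 4 = -1.75
//   f = unary_floor(q)    // floor(-1.75) = -2
//   m = elem_mul(b, f)    // 4 * -2 = -8
//   r = elem_sub(a, m)    // -7 - (-8) = 1, i.e. mod(-7, 4) == 1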

#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::rounding")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "mod",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        wgsl_body: |ctx: &FusionExprContext| {
            let a = ctx
                .inputs
                .first()
                .ok_or(FusionError::MissingInput(0))?;
            let b = ctx.inputs.get(1).ok_or(FusionError::MissingInput(1))?;
            Ok(format!("{a} - {b} * floor({a} / {b})"))
        },
    }),
    reduction: None,
    emits_nan: true,
    notes: "Fusion computes q = floor(a / b) and then a - b * q; providers may substitute specialised kernels when available.",
};
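
// Example of the expression the template above emits: for inputs named `lhs`
// and `rhs` (placeholder names; the real identifiers come from
// `FusionExprContext::inputs`), the fused WGSL body is
// `lhs - rhs * floor(lhs / rhs)`.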

const BUILTIN_NAME: &str = "mod";

fn builtin_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message)
        .with_builtin(BUILTIN_NAME)
        .build()
}

#[runtime_builtin(
    name = "mod",
    category = "math/rounding",
    summary = "MATLAB-compatible modulus a - b .* floor(a./b) with support for complex values and broadcasting.",
    keywords = "mod,modulus,remainder,gpu",
    accel = "binary",
    type_resolver(numeric_binary_type),
    builtin_path = "crate::builtins::math::rounding"
)]
async fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
    match (lhs, rhs) {
        (Value::GpuTensor(a), Value::GpuTensor(b)) => mod_gpu_pair(a, b).await,
        (Value::GpuTensor(a), other) => {
            let gathered = gpu_helpers::gather_tensor_async(&a).await?;
            mod_host(Value::Tensor(gathered), other)
        }
        (other, Value::GpuTensor(b)) => {
            let gathered = gpu_helpers::gather_tensor_async(&b).await?;
            mod_host(other, Value::Tensor(gathered))
        }
        (left, right) => mod_host(left, right),
    }
}

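/// Device path for two resident tensors: when both handles live on the same
/// device and share a shape, composes elem_div → unary_floor → elem_mul →
/// elem_sub on the provider, freeing intermediates as it goes; any failure
/// falls through to a host gather and the CPU implementation below.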
async fn mod_gpu_pair(a: GpuTensorHandle, b: GpuTensorHandle) -> BuiltinResult<Value> {
    if a.device_id == b.device_id {
        if let Some(provider) = runmat_accelerate_api::provider_for_handle(&a) {
            if a.shape == b.shape {
                if let Ok(div) = provider.elem_div(&a, &b).await {
                    match provider.unary_floor(&div).await {
                        Ok(floored) => match provider.elem_mul(&b, &floored).await {
                            Ok(mul) => match provider.elem_sub(&a, &mul).await {
                                Ok(out) => {
                                    let _ = provider.free(&div);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&mul);
                                    return Ok(gpu_helpers::resident_gpu_value(out));
                                }
                                Err(_) => {
                                    let _ = provider.free(&mul);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&div);
                                }
                            },
                            Err(_) => {
                                let _ = provider.free(&floored);
                                let _ = provider.free(&div);
                            }
                        },
                        Err(_) => {
                            let _ = provider.free(&div);
                        }
                    }
                }
            }
        }
    }
    let left = gpu_helpers::gather_tensor_async(&a).await?;
    let right = gpu_helpers::gather_tensor_async(&b).await?;
    mod_host(Value::Tensor(left), Value::Tensor(right))
}

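/// Host path: tries the cheap scalar shortcut first, then coerces both
/// operands to numeric arrays and evaluates over a broadcast plan, promoting
/// to complex when either side is complex.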
fn mod_host(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
    if let Some(result) = scalar_mod_value(&lhs, &rhs) {
        return Ok(result);
    }
    let left = value_into_numeric_array(lhs, "mod")?;
    let right = value_into_numeric_array(rhs, "mod")?;
    match align_numeric_arrays(left, right)? {
        NumericPair::Real(a, b) => compute_mod_real(&a, &b),
        NumericPair::Complex(a, b) => compute_mod_complex(&a, &b),
    }
}

fn compute_mod_real(a: &Tensor, b: &Tensor) -> BuiltinResult<Value> {
    let plan = BroadcastPlan::new(&a.shape, &b.shape)
        .map_err(|err| builtin_error(format!("mod: {err}")))?;
    if plan.is_empty() {
        let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
            .map_err(|e| builtin_error(format!("mod: {e}")))?;
        return Ok(tensor::tensor_into_value(tensor));
    }
    let mut result = vec![0.0f64; plan.len()];
    for (out_idx, idx_a, idx_b) in plan.iter() {
        let aval = a.data[idx_a];
        let bval = b.data[idx_b];
        result[out_idx] = mod_real_scalar(aval, bval);
    }
    let tensor = Tensor::new(result, plan.output_shape().to_vec())
        .map_err(|e| builtin_error(format!("mod: {e}")))?;
    Ok(tensor::tensor_into_value(tensor))
}

fn compute_mod_complex(a: &ComplexTensor, b: &ComplexTensor) -> BuiltinResult<Value> {
    let plan = BroadcastPlan::new(&a.shape, &b.shape)
        .map_err(|err| builtin_error(format!("mod: {err}")))?;
    if plan.is_empty() {
        let tensor = ComplexTensor::new(Vec::new(), plan.output_shape().to_vec())
            .map_err(|e| builtin_error(format!("mod: {e}")))?;
        return Ok(complex_tensor_into_value(tensor));
    }
    let mut result = vec![(0.0f64, 0.0f64); plan.len()];
    for (out_idx, idx_a, idx_b) in plan.iter() {
        let (ar, ai) = a.data[idx_a];
        let (br, bi) = b.data[idx_b];
        result[out_idx] = mod_complex_scalar(ar, ai, br, bi);
    }
    let tensor = ComplexTensor::new(result, plan.output_shape().to_vec())
        .map_err(|e| builtin_error(format!("mod: {e}")))?;
    Ok(complex_tensor_into_value(tensor))
}

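/// Scalar kernel for real `mod`. MATLAB semantics: the result takes the sign
/// of the divisor, so `mod(-7.0, 4.0) == 1.0` while `mod(7.0, -4.0) == -1.0`,
/// and a zero divisor yields NaN rather than an error.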
fn mod_real_scalar(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() {
        return f64::NAN;
    }
    if b == 0.0 {
        return f64::NAN;
    }
    if !a.is_finite() && b.is_finite() {
        return f64::NAN;
    }
    let quotient = (a / b).floor();
    let mut remainder = a - b * quotient;
    if remainder == 0.0 {
        // Not a no-op: `-0.0 == 0.0` in IEEE 754, so this canonicalises a
        // negative zero to +0.0.
        remainder = 0.0;
    }
    if b.is_infinite() && a.is_finite() {
        // MATLAB sign-correction: mod(a, ±Inf) returns a when signs match, ±Inf otherwise.
        if a == 0.0 {
            return 0.0;
        }
        return if a.signum() == b.signum() { a } else { b };
    }
    if !remainder.is_finite() && !a.is_finite() {
        return f64::NAN;
    }
    let same_sign = remainder == 0.0 || remainder.signum() == b.signum();
    if !same_sign {
        remainder += b;
    }
    if remainder == -0.0 {
        remainder = 0.0;
    }
    remainder
}

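/// Complex analogue of `mod_real_scalar`: floors the real and imaginary parts
/// of the complex quotient independently, then forms `a - b * floor(a / b)`
/// with complex arithmetic.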
fn mod_complex_scalar(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    if (ar.is_nan() || ai.is_nan()) || (br.is_nan() || bi.is_nan()) {
        return (f64::NAN, f64::NAN);
    }
    if br == 0.0 && bi == 0.0 {
        return (f64::NAN, f64::NAN);
    }
    if !ar.is_finite() || !ai.is_finite() {
        return (f64::NAN, f64::NAN);
    }
    let (qr, qi) = complex_div(ar, ai, br, bi);
    if !qr.is_finite() && !qi.is_finite() && br.is_finite() && bi.is_finite() {
        return (f64::NAN, f64::NAN);
    }
    let (fr, fi) = (qr.floor(), qi.floor());
    let (mulr, muli) = complex_mul(br, bi, fr, fi);
    let (rr, ri) = (ar - mulr, ai - muli);
    (normalize_zero(rr), normalize_zero(ri))
}

fn scalar_real_value(value: &Value) -> Option<f64> {
    match value {
        Value::Num(n) => Some(*n),
        Value::Int(i) => Some(i.to_f64()),
        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
        Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
            Some(ca.data.first().map(|&ch| ch as u32 as f64).unwrap_or(0.0))
        }
        _ => None,
    }
}

fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
    match value {
        Value::Complex(re, im) => Some((*re, *im)),
        Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
        _ => None,
    }
}

fn scalar_mod_value(lhs: &Value, rhs: &Value) -> Option<Value> {
    let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
    let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
    let (ar, ai) = left;
    let (br, bi) = right;
    if ai != 0.0 || bi != 0.0 {
        let (re, im) = mod_complex_scalar(ar, ai, br, bi);
        return Some(Value::Complex(re, im));
    }
    Some(Value::Num(mod_real_scalar(ar, br)))
}

fn normalize_zero(value: f64) -> f64 {
    // `-0.0 == 0.0` in IEEE 754, so this maps both signed zeros to +0.0.
    if value == -0.0 {
        0.0
    } else {
        value
    }
}

fn complex_mul(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    (ar * br - ai * bi, ar * bi + ai * br)
}

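/// Textbook complex division (no overflow-avoiding scaling such as Smith's
/// method); a zero denominator maps to NaN to match the real path.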
fn complex_div(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let denom = br * br + bi * bi;
    if denom == 0.0 {
        return (f64::NAN, f64::NAN);
    }
    ((ar * br + ai * bi) / denom, (ai * br - ar * bi) / denom)
}

fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
    if tensor.data.len() == 1 {
        let (re, im) = tensor.data[0];
        Value::Complex(re, im)
    } else {
        Value::ComplexTensor(tensor)
    }
}

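/// Coerces a runtime `Value` into a real or complex numeric array: chars map
/// to their code points, strings are rejected, and GPU handles should have
/// been gathered before reaching this point.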
fn value_into_numeric_array(value: Value, name: &str) -> BuiltinResult<NumericArray> {
    match value {
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                .map_err(|e| builtin_error(format!("{name}: {e}")))?;
            Ok(NumericArray::Complex(tensor))
        }
        Value::ComplexTensor(ct) => Ok(NumericArray::Complex(ct)),
        Value::CharArray(ca) => {
            let data: Vec<f64> = ca.data.iter().map(|&ch| ch as u32 as f64).collect();
            let tensor = Tensor::new(data, vec![ca.rows, ca.cols])
                .map_err(|e| builtin_error(format!("{name}: {e}")))?;
            Ok(NumericArray::Real(tensor))
        }
        Value::String(_) | Value::StringArray(_) => Err(builtin_error(format!(
            "{name}: expected numeric input, got string"
        ))),
        Value::GpuTensor(_) => Err(builtin_error(format!(
            "{name}: internal error converting GPU tensor"
        ))),
        other => {
            let tensor = tensor::value_into_tensor_for(name, other).map_err(builtin_error)?;
            Ok(NumericArray::Real(tensor))
        }
    }
}

enum NumericArray {
    Real(Tensor),
    Complex(ComplexTensor),
}

enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}

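/// Keeps two real operands on the cheaper real path; any complex operand
/// promotes both sides to complex tensors.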
fn align_numeric_arrays(lhs: NumericArray, rhs: NumericArray) -> BuiltinResult<NumericPair> {
    match (lhs, rhs) {
        (NumericArray::Real(a), NumericArray::Real(b)) => Ok(NumericPair::Real(a, b)),
        (left, right) => {
            let lc = into_complex(left)?;
            let rc = into_complex(right)?;
            Ok(NumericPair::Complex(lc, rc))
        }
    }
}

fn into_complex(input: NumericArray) -> BuiltinResult<ComplexTensor> {
    match input {
        NumericArray::Real(t) => {
            let Tensor { data, shape, .. } = t;
            let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
            ComplexTensor::new(complex, shape).map_err(|e| builtin_error(format!("mod: {e}")))
        }
        NumericArray::Complex(ct) => Ok(ct),
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::RuntimeError;
    use futures::executor::block_on;
    use runmat_builtins::{
        CharArray, ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type,
    };

    fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
        block_on(super::mod_builtin(lhs, rhs))
    }

    fn assert_error_contains(error: RuntimeError, needle: &str) {
        assert!(
            error.message().contains(needle),
            "unexpected error: {}",
            error.message()
        );
    }

    #[test]
    fn mod_type_preserves_tensor_shape() {
        let out = numeric_binary_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_and_tensor_returns_tensor() {
        let out = numeric_binary_type(
            &[
                Type::Num,
                Type::Tensor {
                    shape: Some(vec![Some(4), Some(1)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(4), Some(1)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_returns_num() {
        let out = numeric_binary_type(&[Type::Num, Type::Int], &ResolveContext::new(Vec::new()));
        assert_eq!(out, Type::Num);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_positive_values() {
        let result = mod_builtin(Value::Num(17.0), Value::Num(5.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 2.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_divisor_keeps_sign() {
        let tensor = Tensor::new(vec![-7.0, -3.0, 4.0, 9.0], vec![4, 1]).unwrap();
        let divisor = Tensor::new(vec![-4.0], vec![1, 1]).unwrap();
        let result =
            mod_builtin(Value::Tensor(tensor), Value::Tensor(divisor)).expect("mod broadcast");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, vec![-3.0, -3.0, 0.0, -3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_numerator_positive_divisor() {
        let result = mod_builtin(Value::Num(-3.0), Value::Num(2.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_zero_divisor_returns_nan() {
        let result = mod_builtin(Value::Num(3.0), Value::Num(0.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_matrix_scalar_broadcast() {
        let matrix = Tensor::new(vec![4.5, 7.1, -2.3, 0.4], vec![2, 2]).unwrap();
        let result = mod_builtin(Value::Tensor(matrix), Value::Num(2.0)).expect("broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                let expected = [0.5, 1.1, 1.7, 0.4];
                for (a, b) in t.data.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_complex_operands() {
        let complex =
            ComplexTensor::new(vec![(3.0, 4.0), (-2.0, 5.0)], vec![1, 2]).expect("complex tensor");
        let divisor = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).expect("divisor");
        let result = mod_builtin(Value::ComplexTensor(complex), Value::ComplexTensor(divisor))
            .expect("complex mod");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                let expected = [(0.0, 0.0), (0.0, 1.0)];
                for ((re, im), (er, ei)) in out.data.iter().zip(expected.iter()) {
                    assert!((re - er).abs() < 1e-12);
                    assert!((im - ei).abs() < 1e-12);
                }
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_char_array_support() {
        let chars = CharArray::new("ABC".chars().collect(), 1, 3).unwrap();
        let result = mod_builtin(Value::CharArray(chars), Value::Num(5.0)).expect("mod");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_string_input_errors() {
        let err = mod_builtin(Value::from("abc"), Value::Num(3.0))
            .expect_err("string inputs should error");
        assert_error_contains(err, "expected numeric input");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_logical_array_support() {
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let value =
            mod_builtin(Value::LogicalArray(logical), Value::Num(2.0)).expect("logical mod");
        match value {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0, 1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_vector_broadcasting() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap();
        let result = mod_builtin(Value::Tensor(lhs), Value::Tensor(rhs)).expect("vector broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 3]);
                assert_eq!(t.data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_nan_inputs_propagate() {
        let result = mod_builtin(Value::Num(f64::NAN), Value::Num(3.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }

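    // Illustrative addition (not in the original suite): exercises the
    // infinite-divisor rule documented in `mod_real_scalar`, where
    // mod(a, ±Inf) returns a when the signs agree and the divisor otherwise.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_infinite_divisor_sign_rule() {
        match mod_builtin(Value::Num(3.0), Value::Num(f64::INFINITY)).expect("mod") {
            Value::Num(v) => assert_eq!(v, 3.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
        match mod_builtin(Value::Num(3.0), Value::Num(f64::NEG_INFINITY)).expect("mod") {
            Value::Num(v) => assert!(v.is_infinite() && v < 0.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }
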
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_gpu_pair_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![-5.0, -3.0, 0.0, 1.0, 6.0, 9.0], vec![3, 2]).unwrap();
            let divisor = Tensor::new(vec![4.0, 4.0, 4.0, 4.0, 4.0, 4.0], vec![3, 2]).unwrap();
            let a_view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let b_view = runmat_accelerate_api::HostTensorView {
                data: &divisor.data,
                shape: &divisor.shape,
            };
            let a_handle = provider.upload(&a_view).expect("upload a");
            let b_handle = provider.upload(&b_view).expect("upload b");
            let result =
                mod_builtin(Value::GpuTensor(a_handle), Value::GpuTensor(b_handle)).expect("mod");
            let gathered = test_support::gather(result).expect("gather result");
            assert_eq!(gathered.shape, vec![3, 2]);
            assert_eq!(gathered.data, vec![3.0, 1.0, 0.0, 1.0, 2.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_int_scalar_promotes() {
        let result =
            mod_builtin(Value::Int(IntValue::I32(-7)), Value::Int(IntValue::I32(4))).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn mod_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let numer = Tensor::new(vec![-5.0, -3.25, 0.0, 1.75, 6.5, 9.0], vec![3, 2]).unwrap();
        let denom = Tensor::new(vec![4.0, -2.5, 3.0, 3.0, 2.0, -5.0], vec![3, 2]).unwrap();
        let cpu_value =
            mod_host(Value::Tensor(numer.clone()), Value::Tensor(denom.clone())).expect("cpu mod");

        let provider = runmat_accelerate_api::provider().expect("wgpu provider registered");
        let numer_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &numer.data,
                shape: &numer.shape,
            })
            .expect("upload numer");
        let denom_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &denom.data,
                shape: &denom.shape,
            })
            .expect("upload denom");

        let gpu_value = block_on(mod_gpu_pair(numer_handle, denom_handle)).expect("gpu mod");
        let gpu_tensor = test_support::gather(gpu_value).expect("gather gpu result");

        let cpu_tensor = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).expect("scalar tensor"),
            other => panic!("unexpected CPU result {other:?}"),
        };

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (gpu, cpu) in gpu_tensor.data.iter().zip(cpu_tensor.data.iter()) {
            assert!(
                (gpu - cpu).abs() <= tol,
                "|{gpu} - {cpu}| exceeded tolerance {tol}"
            );
        }
    }
}