// runmat_runtime/builtins/math/rounding/mod.rs
1//! MATLAB-compatible `mod` builtin plus rounding helpers for RunMat.
2
3pub(crate) mod ceil;
4pub(crate) mod fix;
5pub(crate) mod floor;
6pub(crate) mod rem;
7pub(crate) mod round;
8
9use runmat_accelerate_api::GpuTensorHandle;
10use runmat_builtins::{
11    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
12    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
13    ComplexTensor, Tensor, Value,
14};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::broadcast::BroadcastPlan;
18use crate::builtins::common::spec::{
19    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
20    FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
21    ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::{gpu_helpers, tensor};
24use crate::builtins::math::type_resolvers::numeric_binary_type;
25use crate::{build_runtime_error, BuiltinResult, RuntimeError};
26
/// GPU metadata for `mod`.
///
/// There is no dedicated modulus hook. Instead the spec lists the four element-wise hooks
/// that `mod_gpu_pair` chains on-device (`elem_div`, `unary_floor`, `elem_mul`, `elem_sub`)
/// to evaluate `a - b .* floor(a ./ b)` for operands of identical shape.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::rounding")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "mod",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Listed in the order the composition invokes them.
    provider_hooks: &[
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
        ProviderHook::Unary { name: "unary_floor" },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers can keep mod on-device by composing elem_div → unary_floor → elem_mul → elem_sub for matching shapes. Future backends may expose a dedicated elem_mod hook.",
};
57
58#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::rounding")]
59pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
60    name: "mod",
61    shape: ShapeRequirements::BroadcastCompatible,
62    constant_strategy: ConstantStrategy::InlineLiteral,
63    elementwise: Some(FusionKernelTemplate {
64        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
65        wgsl_body: |ctx: &FusionExprContext| {
66            let a = ctx
67                .inputs
68                .first()
69                .ok_or(FusionError::MissingInput(0))?;
70            let b = ctx.inputs.get(1).ok_or(FusionError::MissingInput(1))?;
71            Ok(format!("{a} - {b} * floor({a} / {b})"))
72        },
73    }),
74    reduction: None,
75    emits_nan: true,
76    notes: "Fusion generates floor(a / b) followed by a - b * q; providers may substitute specialised kernels when available.",
77};
78
/// Builtin name attached to every runtime error raised here and passed to tensor conversion helpers.
const BUILTIN_NAME: &str = "mod";
80
81const MOD_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
82    name: "R",
83    ty: BuiltinParamType::NumericArray,
84    arity: BuiltinParamArity::Required,
85    default: None,
86    description: "Element-wise modulus result.",
87}];
88const MOD_INPUTS: [BuiltinParamDescriptor; 2] = [
89    BuiltinParamDescriptor {
90        name: "A",
91        ty: BuiltinParamType::Any,
92        arity: BuiltinParamArity::Required,
93        default: None,
94        description: "Dividend input (numeric/logical/char/complex).",
95    },
96    BuiltinParamDescriptor {
97        name: "B",
98        ty: BuiltinParamType::Any,
99        arity: BuiltinParamArity::Required,
100        default: None,
101        description: "Divisor input (numeric/logical/char/complex).",
102    },
103];
104const MOD_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
105    label: "R = mod(A, B)",
106    inputs: &MOD_INPUTS,
107    outputs: &MOD_OUTPUT,
108}];
109const MOD_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
110    code: "RM.MOD.INVALID_INPUT",
111    identifier: Some("RunMat:mod:InvalidInput"),
112    when: "Inputs cannot be interpreted as numeric, logical, char, or complex operands.",
113    message: "mod: invalid input",
114};
115const MOD_ERROR_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
116    code: "RM.MOD.SIZE_MISMATCH",
117    identifier: Some("RunMat:mod:SizeMismatch"),
118    when: "Operands are not broadcast-compatible.",
119    message: "mod: array sizes are not compatible for broadcasting",
120};
121const MOD_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
122    code: "RM.MOD.INTERNAL",
123    identifier: Some("RunMat:mod:Internal"),
124    when: "Internal tensor conversion, allocation, or provider composition failed.",
125    message: "mod: internal error",
126};
127const MOD_ERRORS: [BuiltinErrorDescriptor; 3] = [
128    MOD_ERROR_INVALID_INPUT,
129    MOD_ERROR_SIZE_MISMATCH,
130    MOD_ERROR_INTERNAL,
131];
132pub const MOD_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
133    signatures: &MOD_SIGNATURES,
134    output_mode: BuiltinOutputMode::Fixed,
135    completion_policy: BuiltinCompletionPolicy::Public,
136    errors: &MOD_ERRORS,
137};
138
139fn mod_error_with_detail(
140    error: &'static BuiltinErrorDescriptor,
141    detail: impl AsRef<str>,
142) -> RuntimeError {
143    mod_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
144}
145
146fn mod_error_with_message(
147    message: impl Into<String>,
148    error: &'static BuiltinErrorDescriptor,
149) -> RuntimeError {
150    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
151    if let Some(identifier) = error.identifier {
152        builder = builder.with_identifier(identifier);
153    }
154    builder.build()
155}
156
157#[runtime_builtin(
158    name = "mod",
159    category = "math/rounding",
160    summary = "MATLAB-compatible modulus a - b .* floor(a./b) with support for complex values and broadcasting.",
161    keywords = "mod,modulus,remainder,gpu",
162    accel = "binary",
163    type_resolver(numeric_binary_type),
164    descriptor(crate::builtins::math::rounding::MOD_DESCRIPTOR),
165    builtin_path = "crate::builtins::math::rounding"
166)]
167async fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
168    match (lhs, rhs) {
169        (Value::GpuTensor(a), Value::GpuTensor(b)) => mod_gpu_pair(a, b).await,
170        (Value::GpuTensor(a), other) => {
171            let gathered = gpu_helpers::gather_tensor_async(&a).await?;
172            mod_host(Value::Tensor(gathered), other)
173        }
174        (other, Value::GpuTensor(b)) => {
175            let gathered = gpu_helpers::gather_tensor_async(&b).await?;
176            mod_host(other, Value::Tensor(gathered))
177        }
178        (left, right) => mod_host(left, right),
179    }
180}
181
/// Computes `mod(a, b)` when both operands are GPU handles.
///
/// The on-device path is used only when all of these hold:
/// - both handles report the same `device_id`;
/// - a provider is registered for `a`;
/// - the two shapes are identical.
///
/// On that path the result is `a - b .* floor(a ./ b)`, built from the provider hooks
/// `elem_div` → `unary_floor` → `elem_mul` → `elem_sub`. If any precondition fails, or any
/// provider step returns an error, both operands are gathered to the host and evaluated by
/// `mod_host`, which also handles broadcasting.
async fn mod_gpu_pair(a: GpuTensorHandle, b: GpuTensorHandle) -> BuiltinResult<Value> {
    if a.device_id == b.device_id {
        if let Some(provider) = runmat_accelerate_api::provider_for_handle(&a) {
            // The composed kernels are plain element-wise ops, so only identical shapes qualify.
            if a.shape == b.shape {
                // div = a ./ b. If this fails there is nothing to free; fall through to the host.
                if let Ok(div) = provider.elem_div(&a, &b).await {
                    // floored = floor(div)
                    match provider.unary_floor(&div).await {
                        // mul = b .* floored
                        Ok(floored) => match provider.elem_mul(&b, &floored).await {
                            // out = a - mul
                            Ok(mul) => match provider.elem_sub(&a, &mul).await {
                                Ok(out) => {
                                    // Intermediates are released best-effort; free errors are ignored.
                                    let _ = provider.free(&div);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&mul);
                                    return Ok(gpu_helpers::resident_gpu_value(out));
                                }
                                Err(_) => {
                                    let _ = provider.free(&mul);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&div);
                                }
                            },
                            Err(_) => {
                                let _ = provider.free(&floored);
                                let _ = provider.free(&div);
                            }
                        },
                        Err(_) => {
                            let _ = provider.free(&div);
                        }
                    }
                }
            }
        }
    }
    // Host fallback: download both operands and reuse the CPU implementation.
    let left = gpu_helpers::gather_tensor_async(&a).await?;
    let right = gpu_helpers::gather_tensor_async(&b).await?;
    mod_host(Value::Tensor(left), Value::Tensor(right))
}
219
220fn mod_host(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
221    if let Some(result) = scalar_mod_value(&lhs, &rhs) {
222        return Ok(result);
223    }
224    let left = value_into_numeric_array(lhs)?;
225    let right = value_into_numeric_array(rhs)?;
226    match align_numeric_arrays(left, right)? {
227        NumericPair::Real(a, b) => compute_mod_real(&a, &b),
228        NumericPair::Complex(a, b) => compute_mod_complex(&a, &b),
229    }
230}
231
232fn compute_mod_real(a: &Tensor, b: &Tensor) -> BuiltinResult<Value> {
233    let plan = BroadcastPlan::new(&a.shape, &b.shape)
234        .map_err(|err| mod_error_with_detail(&MOD_ERROR_SIZE_MISMATCH, err))?;
235    if plan.is_empty() {
236        let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
237            .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
238        return Ok(tensor::tensor_into_value(tensor));
239    }
240    let mut result = vec![0.0f64; plan.len()];
241    for (out_idx, idx_a, idx_b) in plan.iter() {
242        let aval = a.data[idx_a];
243        let bval = b.data[idx_b];
244        result[out_idx] = mod_real_scalar(aval, bval);
245    }
246    let tensor = Tensor::new(result, plan.output_shape().to_vec())
247        .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
248    Ok(tensor::tensor_into_value(tensor))
249}
250
251fn compute_mod_complex(a: &ComplexTensor, b: &ComplexTensor) -> BuiltinResult<Value> {
252    let plan = BroadcastPlan::new(&a.shape, &b.shape)
253        .map_err(|err| mod_error_with_detail(&MOD_ERROR_SIZE_MISMATCH, err))?;
254    if plan.is_empty() {
255        let tensor = ComplexTensor::new(Vec::new(), plan.output_shape().to_vec())
256            .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
257        return Ok(complex_tensor_into_value(tensor));
258    }
259    let mut result = vec![(0.0f64, 0.0f64); plan.len()];
260    for (out_idx, idx_a, idx_b) in plan.iter() {
261        let (ar, ai) = a.data[idx_a];
262        let (br, bi) = b.data[idx_b];
263        result[out_idx] = mod_complex_scalar(ar, ai, br, bi);
264    }
265    let tensor = ComplexTensor::new(result, plan.output_shape().to_vec())
266        .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
267    Ok(complex_tensor_into_value(tensor))
268}
269
/// Real-valued `mod(a, b)` using MATLAB's floored-division convention: a non-zero result
/// takes the sign of the divisor `b`.
///
/// Special cases:
/// * `NaN` in either operand, or `b == 0`, yields `NaN`.
///   NOTE(review): MATLAB documents `mod(a, 0) == a`. Returning `NaN` here matches what the
///   on-device composition `a - b * floor(a / b)` produces for `b == 0`, so confirm intent
///   before changing either path.
/// * An infinite dividend yields `NaN`.
/// * A finite `a` with an infinite `b` yields:
///   - `0` when `a == 0`;
///   - `a` when `a` and `b` have the same sign;
///   - `b` otherwise.
/// * A zero result is normalised to `+0.0`.
///
/// For finite operands the truncated remainder comes from `%`, which computes IEEE `fmod`
/// exactly. It is then shifted by `b` when its sign disagrees with `b`'s.
///
/// This avoids two failures of the `a - b * floor(a / b)` form:
/// - It loses exactness once `|a / b|` exceeds 2^53. For example, `mod(2^53 + 2, 3)`
///   gave `2` instead of `1`.
/// - It returns `-Inf` when `a / b` overflows, e.g. `mod(1e308, 1e-10)`.
fn mod_real_scalar(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() || b == 0.0 {
        return f64::NAN;
    }
    if a.is_infinite() {
        return f64::NAN;
    }
    if b.is_infinite() {
        // MATLAB sign-correction: mod(a, ±Inf) returns a when signs match, ±Inf otherwise.
        if a == 0.0 {
            return 0.0;
        }
        return if a.signum() == b.signum() { a } else { b };
    }
    // Exact truncated remainder: |truncated| < |b| and it carries the sign of `a`.
    let truncated = a % b;
    if truncated == 0.0 {
        // Also catches -0.0, normalising it to +0.0.
        return 0.0;
    }
    if (truncated < 0.0) == (b < 0.0) {
        truncated
    } else {
        // Move from truncated to floored division so the result carries b's sign.
        truncated + b
    }
}
304
305fn mod_complex_scalar(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
306    if (ar.is_nan() || ai.is_nan()) || (br.is_nan() || bi.is_nan()) {
307        return (f64::NAN, f64::NAN);
308    }
309    if br == 0.0 && bi == 0.0 {
310        return (f64::NAN, f64::NAN);
311    }
312    if !ar.is_finite() || !ai.is_finite() {
313        return (f64::NAN, f64::NAN);
314    }
315    let (qr, qi) = complex_div(ar, ai, br, bi);
316    if !qr.is_finite() && !qi.is_finite() && br.is_finite() && bi.is_finite() {
317        return (f64::NAN, f64::NAN);
318    }
319    let (fr, fi) = (qr.floor(), qi.floor());
320    let (mulr, muli) = complex_mul(br, bi, fr, fi);
321    let (rr, ri) = (ar - mulr, ai - muli);
322    (normalize_zero(rr), normalize_zero(ri))
323}
324
325fn scalar_real_value(value: &Value) -> Option<f64> {
326    match value {
327        Value::Num(n) => Some(*n),
328        Value::Int(i) => Some(i.to_f64()),
329        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
330        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
331        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
332        Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
333            Some(ca.data.first().map(|&ch| ch as u32 as f64).unwrap_or(0.0))
334        }
335        _ => None,
336    }
337}
338
339fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
340    match value {
341        Value::Complex(re, im) => Some((*re, *im)),
342        Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
343        _ => None,
344    }
345}
346
347fn scalar_mod_value(lhs: &Value, rhs: &Value) -> Option<Value> {
348    let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
349    let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
350    let (ar, ai) = left;
351    let (br, bi) = right;
352    if ai != 0.0 || bi != 0.0 {
353        let (re, im) = mod_complex_scalar(ar, ai, br, bi);
354        return Some(Value::Complex(re, im));
355    }
356    Some(Value::Num(mod_real_scalar(ar, br)))
357}
358
/// Maps both signed zeros to `+0.0`; every other value (including NaN) passes through unchanged.
fn normalize_zero(value: f64) -> f64 {
    match value == 0.0 {
        true => 0.0,
        false => value,
    }
}
366
/// Complex product `(ar + i·ai) * (br + i·bi)`, returned as `(real, imag)`.
fn complex_mul(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let real = ar * br - ai * bi;
    let imag = ar * bi + ai * br;
    (real, imag)
}
370
/// Complex quotient `(ar + i·ai) / (br + i·bi)` computed with Smith's algorithm.
///
/// Dividing through by the larger divisor component avoids forming `br² + bi²`. That
/// denominator overflows for `|b|` above roughly 1e154, which turned the quotient into 0.
/// It underflows to zero for tiny `|b|`, which turned the quotient into NaN.
///
/// Returns `(NaN, NaN)` in two cases:
/// - the divisor is zero;
/// - the divisor has a non-finite component. This matches the textbook formula, where
///   `x * ∞ / ∞` always evaluates to NaN, so `mod` results for infinite divisors are
///   unchanged.
fn complex_div(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let zero_divisor = br == 0.0 && bi == 0.0;
    if zero_divisor || !br.is_finite() || !bi.is_finite() {
        return (f64::NAN, f64::NAN);
    }
    if br.abs() >= bi.abs() {
        // |bi / br| <= 1, so neither the ratio nor the scaled denominator can overflow.
        let ratio = bi / br;
        let denom = br + bi * ratio;
        ((ar + ai * ratio) / denom, (ai - ar * ratio) / denom)
    } else {
        let ratio = br / bi;
        let denom = bi + br * ratio;
        ((ar * ratio + ai) / denom, (ai * ratio - ar) / denom)
    }
}
378
379fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
380    if tensor.data.len() == 1 {
381        let (re, im) = tensor.data[0];
382        Value::Complex(re, im)
383    } else {
384        Value::ComplexTensor(tensor)
385    }
386}
387
388fn value_into_numeric_array(value: Value) -> BuiltinResult<NumericArray> {
389    match value {
390        Value::Complex(re, im) => {
391            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
392                .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
393            Ok(NumericArray::Complex(tensor))
394        }
395        Value::ComplexTensor(ct) => Ok(NumericArray::Complex(ct)),
396        Value::CharArray(ca) => {
397            let data: Vec<f64> = ca.data.iter().map(|&ch| ch as u32 as f64).collect();
398            let tensor = Tensor::new(data, vec![ca.rows, ca.cols])
399                .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
400            Ok(NumericArray::Real(tensor))
401        }
402        Value::String(_) | Value::StringArray(_) => Err(mod_error_with_detail(
403            &MOD_ERROR_INVALID_INPUT,
404            "expected numeric input, got string",
405        )),
406        Value::GpuTensor(_) => Err(mod_error_with_detail(
407            &MOD_ERROR_INTERNAL,
408            "internal error converting GPU tensor",
409        )),
410        other => {
411            let tensor = tensor::value_into_tensor_for(BUILTIN_NAME, other)
412                .map_err(|err| mod_error_with_detail(&MOD_ERROR_INVALID_INPUT, err))?;
413            Ok(NumericArray::Real(tensor))
414        }
415    }
416}
417
/// A single operand after host conversion, stored either as real or as complex data.
enum NumericArray {
    /// Real-valued data (numeric, logical, char, or integer inputs after conversion).
    Real(Tensor),
    /// Complex-valued data (complex scalars or complex tensors).
    Complex(ComplexTensor),
}
422
/// Both operands promoted to a common representation: real only when both operands are
/// real, otherwise complex.
enum NumericPair {
    /// Dividend and divisor, both real.
    Real(Tensor, Tensor),
    /// Dividend and divisor, both complex; real inputs get zero imaginary parts.
    Complex(ComplexTensor, ComplexTensor),
}
427
428fn align_numeric_arrays(lhs: NumericArray, rhs: NumericArray) -> BuiltinResult<NumericPair> {
429    match (lhs, rhs) {
430        (NumericArray::Real(a), NumericArray::Real(b)) => Ok(NumericPair::Real(a, b)),
431        (left, right) => {
432            let lc = into_complex(left)?;
433            let rc = into_complex(right)?;
434            Ok(NumericPair::Complex(lc, rc))
435        }
436    }
437}
438
439fn into_complex(input: NumericArray) -> BuiltinResult<ComplexTensor> {
440    match input {
441        NumericArray::Real(t) => {
442            let Tensor { data, shape, .. } = t;
443            let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
444            ComplexTensor::new(complex, shape)
445                .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))
446        }
447        NumericArray::Complex(ct) => Ok(ct),
448    }
449}
450
/// Unit tests for the `mod` builtin. They cover:
/// - descriptor metadata and type resolution;
/// - real, complex, char, logical, and integer operands;
/// - broadcasting and NaN handling;
/// - the GPU-pair path.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::RuntimeError;
    use futures::executor::block_on;
    use runmat_builtins::{
        CharArray, ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type,
    };

    // Synchronous shim over the async builtin so tests can call it directly.
    fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
        block_on(super::mod_builtin(lhs, rhs))
    }

    // Asserts that the error message contains `needle`, printing the full message on failure.
    fn assert_error_contains(error: RuntimeError, needle: &str) {
        assert!(
            error.message().contains(needle),
            "unexpected error: {}",
            error.message()
        );
    }

    #[test]
    fn mod_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = MOD_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"R = mod(A, B)"));
    }

    // Type resolution: equal tensor shapes resolve to that shape.
    #[test]
    fn mod_type_preserves_tensor_shape() {
        let out = numeric_binary_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    // Type resolution: a scalar paired with a tensor resolves to the tensor's shape.
    #[test]
    fn mod_type_scalar_and_tensor_returns_tensor() {
        let out = numeric_binary_type(
            &[
                Type::Num,
                Type::Tensor {
                    shape: Some(vec![Some(4), Some(1)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(4), Some(1)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_returns_num() {
        let out = numeric_binary_type(&[Type::Num, Type::Int], &ResolveContext::new(Vec::new()));
        assert_eq!(out, Type::Num);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_positive_values() {
        let result = mod_builtin(Value::Num(17.0), Value::Num(5.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 2.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // Non-zero results take the sign of the (negative) divisor; a 1x1 divisor broadcasts.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_divisor_keeps_sign() {
        let tensor = Tensor::new(vec![-7.0, -3.0, 4.0, 9.0], vec![4, 1]).unwrap();
        let divisor = Tensor::new(vec![-4.0], vec![1, 1]).unwrap();
        let result =
            mod_builtin(Value::Tensor(tensor), Value::Tensor(divisor)).expect("mod broadcast");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, vec![-3.0, -3.0, 0.0, -3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_numerator_positive_divisor() {
        let result = mod_builtin(Value::Num(-3.0), Value::Num(2.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // Pins the current zero-divisor behaviour (NaN).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_zero_divisor_returns_nan() {
        let result = mod_builtin(Value::Num(3.0), Value::Num(0.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_matrix_scalar_broadcast() {
        let matrix = Tensor::new(vec![4.5, 7.1, -2.3, 0.4], vec![2, 2]).unwrap();
        let result = mod_builtin(Value::Tensor(matrix), Value::Num(2.0)).expect("broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                let expected = [0.5, 1.1, 1.7, 0.4];
                for (a, b) in t.data.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // (3+4i) mod (2+i) = 0 and (-2+5i) mod (2+i) = i under component-wise floor of the quotient.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_complex_operands() {
        let complex =
            ComplexTensor::new(vec![(3.0, 4.0), (-2.0, 5.0)], vec![1, 2]).expect("complex tensor");
        let divisor = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).expect("divisor");
        let result = mod_builtin(Value::ComplexTensor(complex), Value::ComplexTensor(divisor))
            .expect("complex mod");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                let expected = [(0.0, 0.0), (0.0, 1.0)];
                for ((re, im), (er, ei)) in out.data.iter().zip(expected.iter()) {
                    assert!((re - er).abs() < 1e-12);
                    assert!((im - ei).abs() < 1e-12);
                }
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // Chars are treated as their code points: 'A' = 65, 'B' = 66, 'C' = 67.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_char_array_support() {
        let chars = CharArray::new("ABC".chars().collect(), 1, 3).unwrap();
        let result = mod_builtin(Value::CharArray(chars), Value::Num(5.0)).expect("mod");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_string_input_errors() {
        let err = mod_builtin(Value::from("abc"), Value::Num(3.0))
            .expect_err("string inputs should error");
        let identifier = err.identifier().map(str::to_string);
        assert_error_contains(err, "expected numeric input");
        assert_eq!(identifier.as_deref(), MOD_ERROR_INVALID_INPUT.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_logical_array_support() {
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let value =
            mod_builtin(Value::LogicalArray(logical), Value::Num(2.0)).expect("logical mod");
        match value {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0, 1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // A 2x1 column against a 1x3 row broadcasts to 2x3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_vector_broadcasting() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap();
        let result = mod_builtin(Value::Tensor(lhs), Value::Tensor(rhs)).expect("vector broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 3]);
                assert_eq!(t.data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_nan_inputs_propagate() {
        let result = mod_builtin(Value::Num(f64::NAN), Value::Num(3.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }

    // Both operands uploaded through the test provider with matching shapes; the result is gathered back.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_gpu_pair_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![-5.0, -3.0, 0.0, 1.0, 6.0, 9.0], vec![3, 2]).unwrap();
            let divisor = Tensor::new(vec![4.0, 4.0, 4.0, 4.0, 4.0, 4.0], vec![3, 2]).unwrap();
            let a_view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let b_view = runmat_accelerate_api::HostTensorView {
                data: &divisor.data,
                shape: &divisor.shape,
            };
            let a_handle = provider.upload(&a_view).expect("upload a");
            let b_handle = provider.upload(&b_view).expect("upload b");
            let result =
                mod_builtin(Value::GpuTensor(a_handle), Value::GpuTensor(b_handle)).expect("mod");
            let gathered = test_support::gather(result).expect("gather result");
            assert_eq!(gathered.shape, vec![3, 2]);
            assert_eq!(gathered.data, vec![3.0, 1.0, 0.0, 1.0, 2.0, 1.0]);
        });
    }

    // Integer scalars are converted to f64 and yield a `Value::Num`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_int_scalar_promotes() {
        let result =
            mod_builtin(Value::Int(IntValue::I32(-7)), Value::Int(IntValue::I32(4))).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    // Compares the wgpu provider path against the host implementation within a precision-dependent tolerance.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn mod_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let numer = Tensor::new(vec![-5.0, -3.25, 0.0, 1.75, 6.5, 9.0], vec![3, 2]).unwrap();
        let denom = Tensor::new(vec![4.0, -2.5, 3.0, 3.0, 2.0, -5.0], vec![3, 2]).unwrap();
        let cpu_value =
            mod_host(Value::Tensor(numer.clone()), Value::Tensor(denom.clone())).expect("cpu mod");

        let provider = runmat_accelerate_api::provider().expect("wgpu provider registered");
        let numer_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &numer.data,
                shape: &numer.shape,
            })
            .expect("upload numer");
        let denom_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &denom.data,
                shape: &denom.shape,
            })
            .expect("upload denom");

        let gpu_value = block_on(mod_gpu_pair(numer_handle, denom_handle)).expect("gpu mod");
        let gpu_tensor = test_support::gather(gpu_value).expect("gather gpu result");

        let cpu_tensor = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).expect("scalar tensor"),
            other => panic!("unexpected CPU result {other:?}"),
        };

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (gpu, cpu) in gpu_tensor.data.iter().zip(cpu_tensor.data.iter()) {
            assert!(
                (gpu - cpu).abs() <= tol,
                "|{gpu} - {cpu}| exceeded tolerance {tol}"
            );
        }
    }
}