mod ceil;
mod fix;
mod floor;
mod rem;
mod round;

use runmat_accelerate_api::GpuTensorHandle;
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::broadcast::BroadcastPlan;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
    FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
    ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "mod"
category: "math/rounding"
keywords: ["mod", "modulus", "remainder", "rounding", "gpu"]
summary: "Compute the MATLAB-style modulus a - b .* floor(a./b) for scalars, matrices, N-D tensors, and complex values."
references: ["https://www.mathworks.com/help/matlab/ref/mod.html"]
gpu_support:
  elementwise: true
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "matlab"
  notes: "Composed from elem_div → unary_floor → elem_mul → elem_sub when all GPU operands share a shape; otherwise RunMat gathers to the host."
fusion:
  elementwise: true
  reduction: false
  max_inputs: 2
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::math::rounding::mod::tests"
  integration: "builtins::math::rounding::mod::tests::mod_gpu_pair_roundtrip"
---

# What does the `mod` function do in MATLAB / RunMat?
`C = mod(A, B)` returns the modulus after division such that `C` has the same sign as `B` and satisfies `A = B.*Q + C` with `Q = floor(A./B)`.
The definition holds for scalars, vectors, matrices, higher-dimensional tensors, and complex numbers.
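For example, with `A = -7` and `B = 3`, `Q = floor(-7/3) = -3`, so `C = -7 - 3*(-3) = 2`.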

## How does the `mod` function behave in MATLAB / RunMat?
- Works with MATLAB-style implicit expansion (broadcasting) between `A` and `B`.
- Returns `NaN` for elements where `B` is zero or where the operands mix non-finite values incompatibly (`Inf` modulo a finite divisor, `NaN` inputs, etc.).
- Logical and integer inputs are promoted to double precision; character arrays operate on their Unicode code points (see the short example after this list).
- Complex inputs use the MATLAB definition `mod(a, b) = a - b.*floor(a./b)` with complex division and component-wise `floor`.
- Empty arrays propagate emptiness while retaining their shapes.
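
For instance, logical inputs promote to double before the modulus is taken, so the result below is a double row vector:

```matlab
flags = mod(logical([1 0 1 0]), 2);   % flags = [1 0 1 0]
```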

## `mod` Function GPU Execution Behaviour
When both operands are GPU tensors with the same shape, RunMat composes `mod` from the provider hooks `elem_div`, `unary_floor`, `elem_mul`, and `elem_sub`.
This keeps the computation on the device when those hooks are implemented (the shipped WGPU backend and the in-process provider expose them).
For mixed residency, shape-mismatched operands, or providers that lack any of these hooks, RunMat gathers to the host, applies the CPU implementation, and returns a host-resident result.
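
In MATLAB terms, the hook composition evaluates the same two steps as the CPU kernel:

```matlab
A = [5 -3]; B = [4 4];
Q = floor(A ./ B);   % elem_div followed by unary_floor
C = A - B .* Q;      % elem_mul followed by elem_sub; C = [1 1]
```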

## Examples of using the `mod` function in MATLAB / RunMat

### Computing the modulus of positive integers

```matlab
r = mod(17, 5);
```

Expected output:

```matlab
r = 2;
```

### Modulus with negative divisors keeps the divisor's sign

```matlab
values = [-7 -3 4 9];
mods = mod(values, -4);
```

Expected output:

```matlab
mods = [-3 -3 0 -3];
```

### Broadcasting a scalar divisor across a matrix

```matlab
A = [4.5 7.1; -2.3 0.4];
result = mod(A, 2);
```

Expected output:

```matlab
result =
  [0.5 1.1;
   1.7 0.4]
```

### MATLAB-compatible modulus for complex numbers

```matlab
z = [3 + 4i, -2 + 5i];
div = 2 + 1i;
res = mod(z, div);
```

Expected output:

```matlab
res =
  [0.0 + 0.0i, 0.0 + 1.0i]
```
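
The first quotient is exact (`(3 + 4i)/(2 + 1i) = 2 + 1i`), so `floor` leaves it unchanged and the remainder vanishes; the second quotient is `0.2 + 2.4i`, whose component-wise floor is `0 + 2i`, leaving the remainder `0 + 1i`.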

### Handling zeros in the divisor

```matlab
warn = mod([2, 0, -2], [0, 0, 0]);
```

Expected output:

```matlab
warn = [NaN NaN NaN];
```

### Using `mod` with character arrays

```matlab
letters = mod('ABC', 5);
```

Expected output:

```matlab
letters = [0 1 2];
```

### Staying on the GPU when hooks are available

```matlab
G = gpuArray(-5:5);
H = mod(G, 4);
cpuCopy = gather(H);
```

Expected output:

```matlab
cpuCopy = [3 0 1 2 3 0 1 2 3 0 1];
```

## GPU residency in RunMat (Do I need `gpuArray`?)
Usually not. When the provider exposes the elementwise hooks noted above, `mod` executes entirely on the GPU.
Otherwise RunMat gathers transparently, ensuring MATLAB-compatible behaviour without manual intervention.
Explicit `gpuArray` / `gather` calls remain available for scripts that mirror MathWorks MATLAB workflows.

## FAQ

1. **How is `mod` different from `rem`?** `mod` uses `floor` and keeps the sign of the divisor, while `rem` uses `fix` and keeps the sign of the dividend (see the comparison after this list).
2. **What happens when the divisor is zero?** The result is `NaN` (or `NaN + NaNi` for complex inputs), which is exactly what the defining expression `a - b.*floor(a./b)` evaluates to in IEEE arithmetic.
3. **Does `mod` support complex numbers?** Yes. Both operands can be complex; the runtime applies MATLAB's definition with complex division and component-wise `floor`.
4. **Do GPU operands need identical shapes?** Yes. The device fast path currently requires both operands to share the same shape. Other cases fall back to the CPU implementation automatically.
5. **Are empty arrays preserved?** Yes. Empty inputs return empty outputs with the same shape.
6. **Will `mod` ever change integer classes?** Inputs promote to double precision internally; results are reported as double scalars or tensors, mirroring MATLAB's default numeric type.
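
The sign rule from the first FAQ entry, side by side:

```matlab
m = mod(-7, 3);   % m = 2   (sign follows the divisor)
r = rem(-7, 3);   % r = -1  (sign follows the dividend)
```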

## See Also
[rem](./rem), [floor](./floor), [fix](./fix), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)

## Source & Feedback
- Source: [`crates/runmat-runtime/src/builtins/math/rounding/mod.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/math/rounding/mod.rs)
- Found a bug or behavioural difference? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal repro.
"#;

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "mod",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
        ProviderHook::Unary { name: "unary_floor" },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers can keep mod on-device by composing elem_div → unary_floor → elem_mul → elem_sub for matching shapes. Future backends may expose a dedicated elem_mod hook.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "mod",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        wgsl_body: |ctx: &FusionExprContext| {
            let a = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
            let b = ctx.inputs.get(1).ok_or(FusionError::MissingInput(1))?;
            Ok(format!("{a} - {b} * floor({a} / {b})"))
        },
    }),
    reduction: None,
    emits_nan: true,
    notes: "Fusion generates floor(a / b) followed by a - b * q; providers may substitute specialised kernels when available.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("mod", DOC_MD);

#[runtime_builtin(
    name = "mod",
    category = "math/rounding",
    summary = "MATLAB-compatible modulus a - b .* floor(a./b) with support for complex values and broadcasting.",
    keywords = "mod,modulus,remainder,gpu",
    accel = "binary"
)]
fn mod_builtin(lhs: Value, rhs: Value) -> Result<Value, String> {
    match (lhs, rhs) {
        (Value::GpuTensor(a), Value::GpuTensor(b)) => mod_gpu_pair(a, b),
        (Value::GpuTensor(a), other) => {
            let gathered = gpu_helpers::gather_tensor(&a)?;
            mod_host(Value::Tensor(gathered), other)
        }
        (other, Value::GpuTensor(b)) => {
            let gathered = gpu_helpers::gather_tensor(&b)?;
            mod_host(other, Value::Tensor(gathered))
        }
        (left, right) => mod_host(left, right),
    }
}

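/// Device fast path for `mod` when both operands are GPU tensors.
///
/// When the handles live on the same device, share a shape, and the provider
/// implements `elem_div`, `unary_floor`, `elem_mul`, and `elem_sub`, the result
/// is computed entirely on-device; intermediate buffers are freed on success
/// and on every failure path. Anything else falls through to a gather followed
/// by the host implementation.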
fn mod_gpu_pair(a: GpuTensorHandle, b: GpuTensorHandle) -> Result<Value, String> {
    if a.device_id == b.device_id {
        if let Some(provider) = runmat_accelerate_api::provider_for_handle(&a) {
            if a.shape == b.shape {
                if let Ok(div) = provider.elem_div(&a, &b) {
                    match provider.unary_floor(&div) {
                        Ok(floored) => match provider.elem_mul(&b, &floored) {
                            Ok(mul) => match provider.elem_sub(&a, &mul) {
                                Ok(out) => {
                                    let _ = provider.free(&div);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&mul);
                                    return Ok(Value::GpuTensor(out));
                                }
                                Err(_) => {
                                    let _ = provider.free(&mul);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&div);
                                }
                            },
                            Err(_) => {
                                let _ = provider.free(&floored);
                                let _ = provider.free(&div);
                            }
                        },
                        Err(_) => {
                            let _ = provider.free(&div);
                        }
                    }
                }
            }
        }
    }
    let left = gpu_helpers::gather_tensor(&a)?;
    let right = gpu_helpers::gather_tensor(&b)?;
    mod_host(Value::Tensor(left), Value::Tensor(right))
}

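/// Host fallback: coerces both operands to numeric arrays, promotes to complex
/// when either side is complex, and evaluates the broadcasted modulus on the CPU.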
fn mod_host(lhs: Value, rhs: Value) -> Result<Value, String> {
    let left = value_into_numeric_array(lhs, "mod")?;
    let right = value_into_numeric_array(rhs, "mod")?;
    match align_numeric_arrays(left, right) {
        Ok(NumericPair::Real(a, b)) => compute_mod_real(&a, &b),
        Ok(NumericPair::Complex(a, b)) => compute_mod_complex(&a, &b),
        Err(err) => Err(err),
    }
}

fn compute_mod_real(a: &Tensor, b: &Tensor) -> Result<Value, String> {
    let plan = BroadcastPlan::new(&a.shape, &b.shape).map_err(|err| format!("mod: {err}"))?;
    if plan.is_empty() {
        let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
            .map_err(|e| format!("mod: {e}"))?;
        return Ok(tensor::tensor_into_value(tensor));
    }
    let mut result = vec![0.0f64; plan.len()];
    for (out_idx, idx_a, idx_b) in plan.iter() {
        let aval = a.data[idx_a];
        let bval = b.data[idx_b];
        result[out_idx] = mod_real_scalar(aval, bval);
    }
    let tensor =
        Tensor::new(result, plan.output_shape().to_vec()).map_err(|e| format!("mod: {e}"))?;
    Ok(tensor::tensor_into_value(tensor))
}

fn compute_mod_complex(a: &ComplexTensor, b: &ComplexTensor) -> Result<Value, String> {
    let plan = BroadcastPlan::new(&a.shape, &b.shape).map_err(|err| format!("mod: {err}"))?;
    if plan.is_empty() {
        let tensor = ComplexTensor::new(Vec::new(), plan.output_shape().to_vec())
            .map_err(|e| format!("mod: {e}"))?;
        return Ok(complex_tensor_into_value(tensor));
    }
    let mut result = vec![(0.0f64, 0.0f64); plan.len()];
    for (out_idx, idx_a, idx_b) in plan.iter() {
        let (ar, ai) = a.data[idx_a];
        let (br, bi) = b.data[idx_b];
        result[out_idx] = mod_complex_scalar(ar, ai, br, bi);
    }
    let tensor = ComplexTensor::new(result, plan.output_shape().to_vec())
        .map_err(|e| format!("mod: {e}"))?;
    Ok(complex_tensor_into_value(tensor))
}

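/// Scalar kernel for the real case. `NaN` operands, a zero divisor, and an
/// infinite dividend all yield `NaN`; a finite dividend with an infinite
/// divisor is returned unchanged. Otherwise the remainder is shifted by `b`
/// when its sign disagrees with the divisor's, and signed zeros are
/// normalised to `+0.0`.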
fn mod_real_scalar(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() {
        return f64::NAN;
    }
    if b == 0.0 {
        return f64::NAN;
    }
    if !a.is_finite() && b.is_finite() {
        return f64::NAN;
    }
    let quotient = (a / b).floor();
    let mut remainder = a - b * quotient;
    // Collapse a signed zero to +0.0 so the signum comparison below treats it
    // as matching the divisor's sign.
    if remainder == 0.0 {
        remainder = 0.0;
    }
    if b.is_infinite() && a.is_finite() {
        return a;
    }
    if !remainder.is_finite() && !a.is_finite() {
        return f64::NAN;
    }
    let same_sign = remainder == 0.0 || remainder.signum() == b.signum();
    if !same_sign {
        remainder += b;
    }
    if remainder == -0.0 {
        remainder = 0.0;
    }
    remainder
}

fn mod_complex_scalar(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    if (ar.is_nan() || ai.is_nan()) || (br.is_nan() || bi.is_nan()) {
        return (f64::NAN, f64::NAN);
    }
    if br == 0.0 && bi == 0.0 {
        return (f64::NAN, f64::NAN);
    }
    if !ar.is_finite() || !ai.is_finite() {
        return (f64::NAN, f64::NAN);
    }
    let (qr, qi) = complex_div(ar, ai, br, bi);
    if !qr.is_finite() && !qi.is_finite() && br.is_finite() && bi.is_finite() {
        return (f64::NAN, f64::NAN);
    }
    let (fr, fi) = (qr.floor(), qi.floor());
    let (mulr, muli) = complex_mul(br, bi, fr, fi);
    let (rr, ri) = (ar - mulr, ai - muli);
    (normalize_zero(rr), normalize_zero(ri))
}

fn normalize_zero(value: f64) -> f64 {
    if value == -0.0 {
        0.0
    } else {
        value
    }
}

fn complex_mul(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    (ar * br - ai * bi, ar * bi + ai * br)
}

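/// Textbook complex division, a / b = a * conj(b) / |b|^2. A zero denominator
/// (including underflow of |b|^2) yields NaN rather than infinities.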
fn complex_div(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let denom = br * br + bi * bi;
    if denom == 0.0 {
        return (f64::NAN, f64::NAN);
    }
    ((ar * br + ai * bi) / denom, (ai * br - ar * bi) / denom)
}

fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
    if tensor.data.len() == 1 {
        let (re, im) = tensor.data[0];
        Value::Complex(re, im)
    } else {
        Value::ComplexTensor(tensor)
    }
}

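/// Coerces a runtime `Value` into a real or complex numeric array. Character
/// arrays map to their Unicode code points, strings are rejected, and a
/// `GpuTensor` reaching this point is an internal error (callers gather first).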
fn value_into_numeric_array(value: Value, name: &str) -> Result<NumericArray, String> {
    match value {
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                .map_err(|e| format!("{name}: {e}"))?;
            Ok(NumericArray::Complex(tensor))
        }
        Value::ComplexTensor(ct) => Ok(NumericArray::Complex(ct)),
        Value::CharArray(ca) => {
            let data: Vec<f64> = ca.data.iter().map(|&ch| ch as u32 as f64).collect();
            let tensor =
                Tensor::new(data, vec![ca.rows, ca.cols]).map_err(|e| format!("{name}: {e}"))?;
            Ok(NumericArray::Real(tensor))
        }
        Value::String(_) | Value::StringArray(_) => {
            Err(format!("{name}: expected numeric input, got string"))
        }
        Value::GpuTensor(_) => Err(format!("{name}: internal error converting GPU tensor")),
        other => {
            let tensor = tensor::value_into_tensor_for(name, other)?;
            Ok(NumericArray::Real(tensor))
        }
    }
}

enum NumericArray {
    Real(Tensor),
    Complex(ComplexTensor),
}

enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}

fn align_numeric_arrays(lhs: NumericArray, rhs: NumericArray) -> Result<NumericPair, String> {
    match (lhs, rhs) {
        (NumericArray::Real(a), NumericArray::Real(b)) => Ok(NumericPair::Real(a, b)),
        (left, right) => {
            let lc = into_complex(left)?;
            let rc = into_complex(right)?;
            Ok(NumericPair::Complex(lc, rc))
        }
    }
}

fn into_complex(input: NumericArray) -> Result<ComplexTensor, String> {
    match input {
        NumericArray::Real(t) => {
            let Tensor { data, shape, .. } = t;
            let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
            ComplexTensor::new(complex, shape).map_err(|e| format!("mod: {e}"))
        }
        NumericArray::Complex(ct) => Ok(ct),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{CharArray, ComplexTensor, IntValue, LogicalArray, Tensor};

    #[test]
    fn mod_positive_values() {
        let result = mod_builtin(Value::Num(17.0), Value::Num(5.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 2.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[test]
    fn mod_negative_divisor_keeps_sign() {
        let tensor = Tensor::new(vec![-7.0, -3.0, 4.0, 9.0], vec![4, 1]).unwrap();
        let divisor = Tensor::new(vec![-4.0], vec![1, 1]).unwrap();
        let result =
            mod_builtin(Value::Tensor(tensor), Value::Tensor(divisor)).expect("mod broadcast");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, vec![-3.0, -3.0, 0.0, -3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn mod_negative_numerator_positive_divisor() {
        let result = mod_builtin(Value::Num(-3.0), Value::Num(2.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[test]
    fn mod_zero_divisor_returns_nan() {
        let result = mod_builtin(Value::Num(3.0), Value::Num(0.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN, got {other:?}"),
        }
    }

    #[test]
    fn mod_matrix_scalar_broadcast() {
        let matrix = Tensor::new(vec![4.5, 7.1, -2.3, 0.4], vec![2, 2]).unwrap();
        let result = mod_builtin(Value::Tensor(matrix), Value::Num(2.0)).expect("broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                let expected = [0.5, 1.1, 1.7, 0.4];
                for (a, b) in t.data.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn mod_complex_operands() {
        let complex =
            ComplexTensor::new(vec![(3.0, 4.0), (-2.0, 5.0)], vec![1, 2]).expect("complex tensor");
        let divisor = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).expect("divisor");
        let result = mod_builtin(Value::ComplexTensor(complex), Value::ComplexTensor(divisor))
            .expect("complex mod");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                let expected = [(0.0, 0.0), (0.0, 1.0)];
                for ((re, im), (er, ei)) in out.data.iter().zip(expected.iter()) {
                    assert!((re - er).abs() < 1e-12);
                    assert!((im - ei).abs() < 1e-12);
                }
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[test]
    fn mod_char_array_support() {
        let chars = CharArray::new("ABC".chars().collect(), 1, 3).unwrap();
        let result = mod_builtin(Value::CharArray(chars), Value::Num(5.0)).expect("mod");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn mod_string_input_errors() {
        let err = mod_builtin(Value::from("abc"), Value::Num(3.0))
            .expect_err("string inputs should error");
        assert!(err.contains("expected numeric input"));
    }

    #[test]
    fn mod_logical_array_support() {
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let value =
            mod_builtin(Value::LogicalArray(logical), Value::Num(2.0)).expect("logical mod");
        match value {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0, 1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn mod_vector_broadcasting() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap();
        let result = mod_builtin(Value::Tensor(lhs), Value::Tensor(rhs)).expect("vector broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 3]);
                assert_eq!(t.data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[test]
    fn mod_nan_inputs_propagate() {
        let result = mod_builtin(Value::Num(f64::NAN), Value::Num(3.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }
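
    // Added coverage for the non-finite rules in `mod_real_scalar`: a finite
    // dividend with an infinite divisor is returned unchanged, while an
    // infinite dividend with a finite divisor yields NaN.
    #[test]
    fn mod_infinite_operands() {
        match mod_builtin(Value::Num(3.0), Value::Num(f64::INFINITY)).expect("mod") {
            Value::Num(v) => assert_eq!(v, 3.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
        match mod_builtin(Value::Num(f64::INFINITY), Value::Num(3.0)).expect("mod") {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }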

    #[test]
    fn mod_gpu_pair_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![-5.0, -3.0, 0.0, 1.0, 6.0, 9.0], vec![3, 2]).unwrap();
            let divisor = Tensor::new(vec![4.0, 4.0, 4.0, 4.0, 4.0, 4.0], vec![3, 2]).unwrap();
            let a_view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let b_view = runmat_accelerate_api::HostTensorView {
                data: &divisor.data,
                shape: &divisor.shape,
            };
            let a_handle = provider.upload(&a_view).expect("upload a");
            let b_handle = provider.upload(&b_view).expect("upload b");
            let result =
                mod_builtin(Value::GpuTensor(a_handle), Value::GpuTensor(b_handle)).expect("mod");
            let gathered = test_support::gather(result).expect("gather result");
            assert_eq!(gathered.shape, vec![3, 2]);
            assert_eq!(gathered.data, vec![3.0, 1.0, 0.0, 1.0, 2.0, 1.0]);
        });
    }

    #[test]
    fn mod_int_scalar_promotes() {
        let result =
            mod_builtin(Value::Int(IntValue::I32(-7)), Value::Int(IntValue::I32(4))).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn mod_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let numer = Tensor::new(vec![-5.0, -3.25, 0.0, 1.75, 6.5, 9.0], vec![3, 2]).unwrap();
        let denom = Tensor::new(vec![4.0, -2.5, 3.0, 3.0, 2.0, -5.0], vec![3, 2]).unwrap();
        let cpu_value =
            mod_host(Value::Tensor(numer.clone()), Value::Tensor(denom.clone())).expect("cpu mod");

        let provider = runmat_accelerate_api::provider().expect("wgpu provider registered");
        let numer_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &numer.data,
                shape: &numer.shape,
            })
            .expect("upload numer");
        let denom_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &denom.data,
                shape: &denom.shape,
            })
            .expect("upload denom");

        let gpu_value = mod_gpu_pair(numer_handle, denom_handle).expect("gpu mod");
        let gpu_tensor = test_support::gather(gpu_value).expect("gather gpu result");

        let cpu_tensor = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).expect("scalar tensor"),
            other => panic!("unexpected CPU result {other:?}"),
        };

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (gpu, cpu) in gpu_tensor.data.iter().zip(cpu_tensor.data.iter()) {
            assert!(
                (gpu - cpu).abs() <= tol,
                "|{gpu} - {cpu}| exceeded tolerance {tol}"
            );
        }
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}