1pub(crate) mod ceil;
4pub(crate) mod fix;
5pub(crate) mod floor;
6pub(crate) mod rem;
7pub(crate) mod round;
8
9use runmat_accelerate_api::GpuTensorHandle;
10use runmat_builtins::{ComplexTensor, Tensor, Value};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::broadcast::BroadcastPlan;
14use crate::builtins::common::spec::{
15 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
16 FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
17 ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::{gpu_helpers, tensor};
20use crate::builtins::math::type_resolvers::numeric_binary_type;
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
23#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::rounding")]
24pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
25 name: "mod",
26 op_kind: GpuOpKind::Elementwise,
27 supported_precisions: &[ScalarType::F32, ScalarType::F64],
28 broadcast: BroadcastSemantics::Matlab,
29 provider_hooks: &[
30 ProviderHook::Binary {
31 name: "elem_div",
32 commutative: false,
33 },
34 ProviderHook::Unary { name: "unary_floor" },
35 ProviderHook::Binary {
36 name: "elem_mul",
37 commutative: false,
38 },
39 ProviderHook::Binary {
40 name: "elem_sub",
41 commutative: false,
42 },
43 ],
44 constant_strategy: ConstantStrategy::InlineLiteral,
45 residency: ResidencyPolicy::NewHandle,
46 nan_mode: ReductionNaN::Include,
47 two_pass_threshold: None,
48 workgroup_size: None,
49 accepts_nan_mode: false,
50 notes:
51 "Providers can keep mod on-device by composing elem_div → unary_floor → elem_mul → elem_sub for matching shapes. Future backends may expose a dedicated elem_mod hook.",
52};
53
54#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::rounding")]
55pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
56 name: "mod",
57 shape: ShapeRequirements::BroadcastCompatible,
58 constant_strategy: ConstantStrategy::InlineLiteral,
59 elementwise: Some(FusionKernelTemplate {
60 scalar_precisions: &[ScalarType::F32, ScalarType::F64],
61 wgsl_body: |ctx: &FusionExprContext| {
62 let a = ctx
63 .inputs
64 .first()
65 .ok_or(FusionError::MissingInput(0))?;
66 let b = ctx.inputs.get(1).ok_or(FusionError::MissingInput(1))?;
67 Ok(format!("{a} - {b} * floor({a} / {b})"))
68 },
69 }),
70 reduction: None,
71 emits_nan: true,
72 notes: "Fusion generates floor(a / b) followed by a - b * q; providers may substitute specialised kernels when available.",
73};
74
75const BUILTIN_NAME: &str = "mod";
76
77fn builtin_error(message: impl Into<String>) -> RuntimeError {
78 build_runtime_error(message)
79 .with_builtin(BUILTIN_NAME)
80 .build()
81}
82
83#[runtime_builtin(
84 name = "mod",
85 category = "math/rounding",
86 summary = "MATLAB-compatible modulus a - b .* floor(a./b) with support for complex values and broadcasting.",
87 keywords = "mod,modulus,remainder,gpu",
88 accel = "binary",
89 type_resolver(numeric_binary_type),
90 builtin_path = "crate::builtins::math::rounding"
91)]
92async fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
93 match (lhs, rhs) {
94 (Value::GpuTensor(a), Value::GpuTensor(b)) => mod_gpu_pair(a, b).await,
95 (Value::GpuTensor(a), other) => {
96 let gathered = gpu_helpers::gather_tensor_async(&a).await?;
97 mod_host(Value::Tensor(gathered), other)
98 }
99 (other, Value::GpuTensor(b)) => {
100 let gathered = gpu_helpers::gather_tensor_async(&b).await?;
101 mod_host(other, Value::Tensor(gathered))
102 }
103 (left, right) => mod_host(left, right),
104 }
105}
106
107async fn mod_gpu_pair(a: GpuTensorHandle, b: GpuTensorHandle) -> BuiltinResult<Value> {
108 if a.device_id == b.device_id {
109 if let Some(provider) = runmat_accelerate_api::provider_for_handle(&a) {
110 if a.shape == b.shape {
111 if let Ok(div) = provider.elem_div(&a, &b).await {
112 match provider.unary_floor(&div).await {
113 Ok(floored) => match provider.elem_mul(&b, &floored).await {
114 Ok(mul) => match provider.elem_sub(&a, &mul).await {
115 Ok(out) => {
116 let _ = provider.free(&div);
117 let _ = provider.free(&floored);
118 let _ = provider.free(&mul);
119 return Ok(gpu_helpers::resident_gpu_value(out));
120 }
121 Err(_) => {
122 let _ = provider.free(&mul);
123 let _ = provider.free(&floored);
124 let _ = provider.free(&div);
125 }
126 },
127 Err(_) => {
128 let _ = provider.free(&floored);
129 let _ = provider.free(&div);
130 }
131 },
132 Err(_) => {
133 let _ = provider.free(&div);
134 }
135 }
136 }
137 }
138 }
139 }
140 let left = gpu_helpers::gather_tensor_async(&a).await?;
141 let right = gpu_helpers::gather_tensor_async(&b).await?;
142 mod_host(Value::Tensor(left), Value::Tensor(right))
143}
144
145fn mod_host(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
146 if let Some(result) = scalar_mod_value(&lhs, &rhs) {
147 return Ok(result);
148 }
149 let left = value_into_numeric_array(lhs, "mod")?;
150 let right = value_into_numeric_array(rhs, "mod")?;
151 match align_numeric_arrays(left, right)? {
152 NumericPair::Real(a, b) => compute_mod_real(&a, &b),
153 NumericPair::Complex(a, b) => compute_mod_complex(&a, &b),
154 }
155}
156
157fn compute_mod_real(a: &Tensor, b: &Tensor) -> BuiltinResult<Value> {
158 let plan = BroadcastPlan::new(&a.shape, &b.shape)
159 .map_err(|err| builtin_error(format!("mod: {err}")))?;
160 if plan.is_empty() {
161 let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
162 .map_err(|e| builtin_error(format!("mod: {e}")))?;
163 return Ok(tensor::tensor_into_value(tensor));
164 }
165 let mut result = vec![0.0f64; plan.len()];
166 for (out_idx, idx_a, idx_b) in plan.iter() {
167 let aval = a.data[idx_a];
168 let bval = b.data[idx_b];
169 result[out_idx] = mod_real_scalar(aval, bval);
170 }
171 let tensor = Tensor::new(result, plan.output_shape().to_vec())
172 .map_err(|e| builtin_error(format!("mod: {e}")))?;
173 Ok(tensor::tensor_into_value(tensor))
174}
175
176fn compute_mod_complex(a: &ComplexTensor, b: &ComplexTensor) -> BuiltinResult<Value> {
177 let plan = BroadcastPlan::new(&a.shape, &b.shape)
178 .map_err(|err| builtin_error(format!("mod: {err}")))?;
179 if plan.is_empty() {
180 let tensor = ComplexTensor::new(Vec::new(), plan.output_shape().to_vec())
181 .map_err(|e| builtin_error(format!("mod: {e}")))?;
182 return Ok(complex_tensor_into_value(tensor));
183 }
184 let mut result = vec![(0.0f64, 0.0f64); plan.len()];
185 for (out_idx, idx_a, idx_b) in plan.iter() {
186 let (ar, ai) = a.data[idx_a];
187 let (br, bi) = b.data[idx_b];
188 result[out_idx] = mod_complex_scalar(ar, ai, br, bi);
189 }
190 let tensor = ComplexTensor::new(result, plan.output_shape().to_vec())
191 .map_err(|e| builtin_error(format!("mod: {e}")))?;
192 Ok(complex_tensor_into_value(tensor))
193}
194
/// Real modulus with MATLAB's sign convention: `a - b * floor(a / b)`.
///
/// A nonzero result has the sign of `b` and magnitude below `|b|`.
///
/// Special values:
/// * any NaN operand, `b == 0`, or an infinite `a` → NaN;
/// * finite `a` with infinite `b` → `0` when `a == 0`, `a` when the signs agree,
///   otherwise `b`;
/// * a zero result is always `+0.0`, never `-0.0`.
///
/// The remainder is taken with `%` (IEEE `fmod`), which is exact for finite
/// operands and cannot overflow. The textbook `a - b * floor(a / b)` overflows to
/// `-inf` once `a / b` exceeds `f64::MAX` (e.g. `mod(1e308, 1e-300)`). It also
/// loses exactness for large quotients (`mod(1e17, 3)` came out as `0`, not `1`).
fn mod_real_scalar(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() {
        return f64::NAN;
    }
    // NOTE(review): MATLAB documents `mod(a, 0) == a`. This returns NaN instead,
    // which matches the device/fusion formula under IEEE arithmetic and the
    // existing `mod_zero_divisor_returns_nan` test. Confirm the intended
    // convention before changing it.
    if b == 0.0 {
        return f64::NAN;
    }
    // An infinite dividend has no meaningful remainder (finite or infinite `b`).
    if a.is_infinite() {
        return f64::NAN;
    }
    if b.is_infinite() {
        // `a` is finite here.
        if a == 0.0 {
            return 0.0;
        }
        return if a.signum() == b.signum() { a } else { b };
    }
    // `%` is exact and carries the sign of `a`. Shift by `b` when the signs differ
    // so the result takes the sign of the divisor (floor semantics).
    let mut remainder = a % b;
    if remainder != 0.0 && remainder.signum() != b.signum() {
        remainder += b;
    }
    // Normalize -0.0 to +0.0.
    if remainder == 0.0 {
        0.0
    } else {
        remainder
    }
}
229
230fn mod_complex_scalar(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
231 if (ar.is_nan() || ai.is_nan()) || (br.is_nan() || bi.is_nan()) {
232 return (f64::NAN, f64::NAN);
233 }
234 if br == 0.0 && bi == 0.0 {
235 return (f64::NAN, f64::NAN);
236 }
237 if !ar.is_finite() || !ai.is_finite() {
238 return (f64::NAN, f64::NAN);
239 }
240 let (qr, qi) = complex_div(ar, ai, br, bi);
241 if !qr.is_finite() && !qi.is_finite() && br.is_finite() && bi.is_finite() {
242 return (f64::NAN, f64::NAN);
243 }
244 let (fr, fi) = (qr.floor(), qi.floor());
245 let (mulr, muli) = complex_mul(br, bi, fr, fi);
246 let (rr, ri) = (ar - mulr, ai - muli);
247 (normalize_zero(rr), normalize_zero(ri))
248}
249
250fn scalar_real_value(value: &Value) -> Option<f64> {
251 match value {
252 Value::Num(n) => Some(*n),
253 Value::Int(i) => Some(i.to_f64()),
254 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
255 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
256 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
257 Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
258 Some(ca.data.first().map(|&ch| ch as u32 as f64).unwrap_or(0.0))
259 }
260 _ => None,
261 }
262}
263
264fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
265 match value {
266 Value::Complex(re, im) => Some((*re, *im)),
267 Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
268 _ => None,
269 }
270}
271
272fn scalar_mod_value(lhs: &Value, rhs: &Value) -> Option<Value> {
273 let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
274 let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
275 let (ar, ai) = left;
276 let (br, bi) = right;
277 if ai != 0.0 || bi != 0.0 {
278 let (re, im) = mod_complex_scalar(ar, ai, br, bi);
279 return Some(Value::Complex(re, im));
280 }
281 Some(Value::Num(mod_real_scalar(ar, br)))
282}
283
/// Maps `-0.0` to `+0.0`. Every other value, including NaN, passes through unchanged.
fn normalize_zero(value: f64) -> f64 {
    match value {
        // `-0.0 == 0.0` under IEEE comparison, so both zeros land here.
        v if v == 0.0 => 0.0,
        v => v,
    }
}
291
/// Multiplies `(ar + ai·i) * (br + bi·i)` and returns `(re, im)`.
fn complex_mul(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let re = ar * br - ai * bi;
    let im = ar * bi + ai * br;
    (re, im)
}
295
/// Divides `(ar + ai·i) / (br + bi·i)` using the textbook conjugate formula.
///
/// Returns `(NaN, NaN)` when `|b|²` evaluates to zero. That covers an exact zero
/// divisor, and also a divisor whose squared magnitude underflows (components
/// below roughly `1e-154`). The formula is kept because it is exact for small
/// integer-valued operands, which the modulus relies on when flooring the quotient.
fn complex_div(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let norm_sq = br * br + bi * bi;
    if norm_sq != 0.0 {
        let re = (ar * br + ai * bi) / norm_sq;
        let im = (ai * br - ar * bi) / norm_sq;
        (re, im)
    } else {
        (f64::NAN, f64::NAN)
    }
}
303
304fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
305 if tensor.data.len() == 1 {
306 let (re, im) = tensor.data[0];
307 Value::Complex(re, im)
308 } else {
309 Value::ComplexTensor(tensor)
310 }
311}
312
313fn value_into_numeric_array(value: Value, name: &str) -> BuiltinResult<NumericArray> {
314 match value {
315 Value::Complex(re, im) => {
316 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
317 .map_err(|e| builtin_error(format!("{name}: {e}")))?;
318 Ok(NumericArray::Complex(tensor))
319 }
320 Value::ComplexTensor(ct) => Ok(NumericArray::Complex(ct)),
321 Value::CharArray(ca) => {
322 let data: Vec<f64> = ca.data.iter().map(|&ch| ch as u32 as f64).collect();
323 let tensor = Tensor::new(data, vec![ca.rows, ca.cols])
324 .map_err(|e| builtin_error(format!("{name}: {e}")))?;
325 Ok(NumericArray::Real(tensor))
326 }
327 Value::String(_) | Value::StringArray(_) => Err(builtin_error(format!(
328 "{name}: expected numeric input, got string"
329 ))),
330 Value::GpuTensor(_) => Err(builtin_error(format!(
331 "{name}: internal error converting GPU tensor"
332 ))),
333 other => {
334 let tensor =
335 tensor::value_into_tensor_for(name, other).map_err(|err| builtin_error(err))?;
336 Ok(NumericArray::Real(tensor))
337 }
338 }
339}
340
341enum NumericArray {
342 Real(Tensor),
343 Complex(ComplexTensor),
344}
345
346enum NumericPair {
347 Real(Tensor, Tensor),
348 Complex(ComplexTensor, ComplexTensor),
349}
350
351fn align_numeric_arrays(lhs: NumericArray, rhs: NumericArray) -> BuiltinResult<NumericPair> {
352 match (lhs, rhs) {
353 (NumericArray::Real(a), NumericArray::Real(b)) => Ok(NumericPair::Real(a, b)),
354 (left, right) => {
355 let lc = into_complex(left)?;
356 let rc = into_complex(right)?;
357 Ok(NumericPair::Complex(lc, rc))
358 }
359 }
360}
361
362fn into_complex(input: NumericArray) -> BuiltinResult<ComplexTensor> {
363 match input {
364 NumericArray::Real(t) => {
365 let Tensor { data, shape, .. } = t;
366 let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
367 ComplexTensor::new(complex, shape).map_err(|e| builtin_error(format!("mod: {e}")))
368 }
369 NumericArray::Complex(ct) => Ok(ct),
370 }
371}
372
// Unit tests for `mod`: type resolution, real/complex/char/logical host paths,
// broadcasting, special values (NaN, Inf, signed zero), and GPU round-trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::RuntimeError;
    use futures::executor::block_on;
    use runmat_builtins::{
        CharArray, ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type,
    };

    /// Synchronous wrapper around the async builtin.
    fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
        block_on(super::mod_builtin(lhs, rhs))
    }

    fn assert_error_contains(error: RuntimeError, needle: &str) {
        assert!(
            error.message().contains(needle),
            "unexpected error: {}",
            error.message()
        );
    }

    #[test]
    fn mod_type_preserves_tensor_shape() {
        let out = numeric_binary_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_and_tensor_returns_tensor() {
        let out = numeric_binary_type(
            &[
                Type::Num,
                Type::Tensor {
                    shape: Some(vec![Some(4), Some(1)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(4), Some(1)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_returns_num() {
        let out = numeric_binary_type(&[Type::Num, Type::Int], &ResolveContext::new(Vec::new()));
        assert_eq!(out, Type::Num);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_positive_values() {
        let result = mod_builtin(Value::Num(17.0), Value::Num(5.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 2.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_divisor_keeps_sign() {
        let tensor = Tensor::new(vec![-7.0, -3.0, 4.0, 9.0], vec![4, 1]).unwrap();
        let divisor = Tensor::new(vec![-4.0], vec![1, 1]).unwrap();
        let result =
            mod_builtin(Value::Tensor(tensor), Value::Tensor(divisor)).expect("mod broadcast");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, vec![-3.0, -3.0, 0.0, -3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_numerator_positive_divisor() {
        let result = mod_builtin(Value::Num(-3.0), Value::Num(2.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_zero_divisor_returns_nan() {
        let result = mod_builtin(Value::Num(3.0), Value::Num(0.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_matrix_scalar_broadcast() {
        let matrix = Tensor::new(vec![4.5, 7.1, -2.3, 0.4], vec![2, 2]).unwrap();
        let result = mod_builtin(Value::Tensor(matrix), Value::Num(2.0)).expect("broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                let expected = [0.5, 1.1, 1.7, 0.4];
                for (a, b) in t.data.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_complex_operands() {
        let complex =
            ComplexTensor::new(vec![(3.0, 4.0), (-2.0, 5.0)], vec![1, 2]).expect("complex tensor");
        let divisor = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).expect("divisor");
        let result = mod_builtin(Value::ComplexTensor(complex), Value::ComplexTensor(divisor))
            .expect("complex mod");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                let expected = [(0.0, 0.0), (0.0, 1.0)];
                for ((re, im), (er, ei)) in out.data.iter().zip(expected.iter()) {
                    assert!((re - er).abs() < 1e-12);
                    assert!((im - ei).abs() < 1e-12);
                }
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_char_array_support() {
        let chars = CharArray::new("ABC".chars().collect(), 1, 3).unwrap();
        let result = mod_builtin(Value::CharArray(chars), Value::Num(5.0)).expect("mod");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_string_input_errors() {
        let err = mod_builtin(Value::from("abc"), Value::Num(3.0))
            .expect_err("string inputs should error");
        assert_error_contains(err, "expected numeric input");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_logical_array_support() {
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let value =
            mod_builtin(Value::LogicalArray(logical), Value::Num(2.0)).expect("logical mod");
        match value {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0, 1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_vector_broadcasting() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap();
        let result = mod_builtin(Value::Tensor(lhs), Value::Tensor(rhs)).expect("vector broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 3]);
                assert_eq!(t.data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_nan_inputs_propagate() {
        let result = mod_builtin(Value::Num(f64::NAN), Value::Num(3.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }

    // Finite dividend with an infinite divisor: 0 stays 0, matching signs return
    // the dividend, and mismatched signs return the divisor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_infinite_divisor_follows_sign_rule() {
        let cases = [
            (5.0, f64::INFINITY, 5.0),
            (-5.0, f64::INFINITY, f64::INFINITY),
            (5.0, f64::NEG_INFINITY, f64::NEG_INFINITY),
            (-5.0, f64::NEG_INFINITY, -5.0),
            (0.0, f64::INFINITY, 0.0),
        ];
        for &(a, b, expected) in cases.iter() {
            match mod_builtin(Value::Num(a), Value::Num(b)).expect("mod") {
                Value::Num(v) => assert_eq!(v, expected, "mod({a}, {b})"),
                other => panic!("expected scalar result, got {other:?}"),
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_infinite_dividend_returns_nan() {
        let result = mod_builtin(Value::Num(f64::INFINITY), Value::Num(3.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }

    // Exact multiples (and a negative-zero dividend) must yield +0.0, not -0.0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_exact_multiple_returns_positive_zero() {
        for &(a, b) in [(4.0, -4.0), (-4.0, 4.0), (-0.0, 3.0)].iter() {
            match mod_builtin(Value::Num(a), Value::Num(b)).expect("mod") {
                Value::Num(v) => {
                    assert_eq!(v, 0.0, "mod({a}, {b})");
                    assert!(v.is_sign_positive(), "mod({a}, {b}) produced -0");
                }
                other => panic!("expected scalar result, got {other:?}"),
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_complex_zero_divisor_returns_nan() {
        let result = mod_builtin(Value::Complex(1.0, 2.0), Value::Num(0.0)).expect("mod");
        match result {
            Value::Complex(re, im) => {
                assert!(re.is_nan());
                assert!(im.is_nan());
            }
            other => panic!("expected complex NaN result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_gpu_pair_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![-5.0, -3.0, 0.0, 1.0, 6.0, 9.0], vec![3, 2]).unwrap();
            let divisor = Tensor::new(vec![4.0, 4.0, 4.0, 4.0, 4.0, 4.0], vec![3, 2]).unwrap();
            let a_view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let b_view = runmat_accelerate_api::HostTensorView {
                data: &divisor.data,
                shape: &divisor.shape,
            };
            let a_handle = provider.upload(&a_view).expect("upload a");
            let b_handle = provider.upload(&b_view).expect("upload b");
            let result =
                mod_builtin(Value::GpuTensor(a_handle), Value::GpuTensor(b_handle)).expect("mod");
            let gathered = test_support::gather(result).expect("gather result");
            assert_eq!(gathered.shape, vec![3, 2]);
            assert_eq!(gathered.data, vec![3.0, 1.0, 0.0, 1.0, 2.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_int_scalar_promotes() {
        let result =
            mod_builtin(Value::Int(IntValue::I32(-7)), Value::Int(IntValue::I32(4))).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn mod_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let numer = Tensor::new(vec![-5.0, -3.25, 0.0, 1.75, 6.5, 9.0], vec![3, 2]).unwrap();
        let denom = Tensor::new(vec![4.0, -2.5, 3.0, 3.0, 2.0, -5.0], vec![3, 2]).unwrap();
        let cpu_value =
            mod_host(Value::Tensor(numer.clone()), Value::Tensor(denom.clone())).expect("cpu mod");

        let provider = runmat_accelerate_api::provider().expect("wgpu provider registered");
        let numer_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &numer.data,
                shape: &numer.shape,
            })
            .expect("upload numer");
        let denom_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &denom.data,
                shape: &denom.shape,
            })
            .expect("upload denom");

        let gpu_value = block_on(mod_gpu_pair(numer_handle, denom_handle)).expect("gpu mod");
        let gpu_tensor = test_support::gather(gpu_value).expect("gather gpu result");

        let cpu_tensor = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).expect("scalar tensor"),
            other => panic!("unexpected CPU result {other:?}"),
        };

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (gpu, cpu) in gpu_tensor.data.iter().zip(cpu_tensor.data.iter()) {
            assert!(
                (gpu - cpu).abs() <= tol,
                "|{gpu} - {cpu}| exceeded tolerance {tol}"
            );
        }
    }
}