1pub(crate) mod ceil;
4pub(crate) mod fix;
5pub(crate) mod floor;
6pub(crate) mod rem;
7pub(crate) mod round;
8
9use runmat_accelerate_api::GpuTensorHandle;
10use runmat_builtins::{
11 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
12 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
13 ComplexTensor, Tensor, Value,
14};
15use runmat_macros::runtime_builtin;
16
17use crate::builtins::common::broadcast::BroadcastPlan;
18use crate::builtins::common::spec::{
19 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
20 FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
21 ResidencyPolicy, ScalarType, ShapeRequirements,
22};
23use crate::builtins::common::{gpu_helpers, tensor};
24use crate::builtins::math::type_resolvers::numeric_binary_type;
25use crate::{build_runtime_error, BuiltinResult, RuntimeError};
26
// GPU capability metadata for `mod`.
//
// There is no dedicated modulus hook: the listed provider hooks are the primitives that
// `mod_gpu_pair` chains (`elem_div` -> `unary_floor` -> `elem_mul` -> `elem_sub`) to evaluate
// `a - b .* floor(a ./ b)` on device when both operands share a device and an exact shape.
// NOTE(review): that composition does not reproduce the host-side special cases in
// `mod_real_scalar` (e.g. infinite divisors), so device and host results may differ for
// such inputs — confirm this is acceptable.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::rounding")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "mod",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Hooks are listed in the order `mod_gpu_pair` invokes them.
    provider_hooks: &[
        ProviderHook::Binary {
            name: "elem_div",
            commutative: false,
        },
        ProviderHook::Unary { name: "unary_floor" },
        ProviderHook::Binary {
            name: "elem_mul",
            commutative: false,
        },
        ProviderHook::Binary {
            name: "elem_sub",
            commutative: false,
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers can keep mod on-device by composing elem_div → unary_floor → elem_mul → elem_sub for matching shapes. Future backends may expose a dedicated elem_mod hook.",
};
57
58#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::rounding")]
59pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
60 name: "mod",
61 shape: ShapeRequirements::BroadcastCompatible,
62 constant_strategy: ConstantStrategy::InlineLiteral,
63 elementwise: Some(FusionKernelTemplate {
64 scalar_precisions: &[ScalarType::F32, ScalarType::F64],
65 wgsl_body: |ctx: &FusionExprContext| {
66 let a = ctx
67 .inputs
68 .first()
69 .ok_or(FusionError::MissingInput(0))?;
70 let b = ctx.inputs.get(1).ok_or(FusionError::MissingInput(1))?;
71 Ok(format!("{a} - {b} * floor({a} / {b})"))
72 },
73 }),
74 reduction: None,
75 emits_nan: true,
76 notes: "Fusion generates floor(a / b) followed by a - b * q; providers may substitute specialised kernels when available.",
77};
78
/// Name attached to every error raised by this builtin (via `with_builtin`) and passed to
/// the shared tensor-conversion helper.
const BUILTIN_NAME: &str = "mod";

/// Output parameter of `mod`: a single element-wise result array.
const MOD_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "R",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Element-wise modulus result.",
}];
/// Input parameters of `mod`: the dividend `A` followed by the divisor `B`.
const MOD_INPUTS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Dividend input (numeric/logical/char/complex).",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Divisor input (numeric/logical/char/complex).",
    },
];
/// The single supported call form, `R = mod(A, B)`.
const MOD_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "R = mod(A, B)",
    inputs: &MOD_INPUTS,
    outputs: &MOD_OUTPUT,
}];
/// Raised when an operand cannot be converted to a numeric array (string inputs, or a
/// failure in the generic tensor conversion).
const MOD_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MOD.INVALID_INPUT",
    identifier: Some("RunMat:mod:InvalidInput"),
    when: "Inputs cannot be interpreted as numeric, logical, char, or complex operands.",
    message: "mod: invalid input",
};
/// Raised when the operand shapes cannot be combined by `BroadcastPlan`.
const MOD_ERROR_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MOD.SIZE_MISMATCH",
    identifier: Some("RunMat:mod:SizeMismatch"),
    when: "Operands are not broadcast-compatible.",
    message: "mod: array sizes are not compatible for broadcasting",
};
/// Raised when building an output/intermediate tensor fails, or when a GPU handle reaches
/// the host conversion path.
/// NOTE(review): provider composition failures in `mod_gpu_pair` fall back to the host
/// rather than raising this error, despite what `when` says — confirm the wording.
const MOD_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MOD.INTERNAL",
    identifier: Some("RunMat:mod:Internal"),
    when: "Internal tensor conversion, allocation, or provider composition failed.",
    message: "mod: internal error",
};
/// Every error this builtin can raise, exposed through `MOD_DESCRIPTOR`.
const MOD_ERRORS: [BuiltinErrorDescriptor; 3] = [
    MOD_ERROR_INVALID_INPUT,
    MOD_ERROR_SIZE_MISMATCH,
    MOD_ERROR_INTERNAL,
];
/// Public descriptor (signatures, output mode, completion policy, errors) referenced by the
/// `#[runtime_builtin]` registration of `mod`.
pub const MOD_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &MOD_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &MOD_ERRORS,
};
138
139fn mod_error_with_detail(
140 error: &'static BuiltinErrorDescriptor,
141 detail: impl AsRef<str>,
142) -> RuntimeError {
143 mod_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
144}
145
146fn mod_error_with_message(
147 message: impl Into<String>,
148 error: &'static BuiltinErrorDescriptor,
149) -> RuntimeError {
150 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
151 if let Some(identifier) = error.identifier {
152 builder = builder.with_identifier(identifier);
153 }
154 builder.build()
155}
156
157#[runtime_builtin(
158 name = "mod",
159 category = "math/rounding",
160 summary = "MATLAB-compatible modulus a - b .* floor(a./b) with support for complex values and broadcasting.",
161 keywords = "mod,modulus,remainder,gpu",
162 accel = "binary",
163 type_resolver(numeric_binary_type),
164 descriptor(crate::builtins::math::rounding::MOD_DESCRIPTOR),
165 builtin_path = "crate::builtins::math::rounding"
166)]
167async fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
168 match (lhs, rhs) {
169 (Value::GpuTensor(a), Value::GpuTensor(b)) => mod_gpu_pair(a, b).await,
170 (Value::GpuTensor(a), other) => {
171 let gathered = gpu_helpers::gather_tensor_async(&a).await?;
172 mod_host(Value::Tensor(gathered), other)
173 }
174 (other, Value::GpuTensor(b)) => {
175 let gathered = gpu_helpers::gather_tensor_async(&b).await?;
176 mod_host(other, Value::Tensor(gathered))
177 }
178 (left, right) => mod_host(left, right),
179 }
180}
181
/// Computes `mod(a, b)` for two GPU-resident operands.
///
/// Device path: taken when both handles report the same `device_id`, a provider is found
/// for `a`, and the shapes are exactly equal. The result is then composed on device as
/// `a - b .* floor(a ./ b)` using `elem_div` -> `unary_floor` -> `elem_mul` -> `elem_sub`,
/// and returned as a resident GPU value. Intermediate buffers are released on success and
/// on every failure branch; `free` results are deliberately ignored (best effort).
///
/// Host fallback: any mismatch (including shapes that would need broadcasting) or any
/// provider error gathers both operands and runs `mod_host`.
///
/// NOTE(review): the device composition skips the host-side special cases in
/// `mod_real_scalar` (sign correction, infinite divisors), so results may differ from the
/// host path for such inputs — confirm this is acceptable.
async fn mod_gpu_pair(a: GpuTensorHandle, b: GpuTensorHandle) -> BuiltinResult<Value> {
    if a.device_id == b.device_id {
        if let Some(provider) = runmat_accelerate_api::provider_for_handle(&a) {
            // Only exact shape matches use the composed kernels; broadcasting goes to the host.
            if a.shape == b.shape {
                // q = a ./ b
                if let Ok(div) = provider.elem_div(&a, &b).await {
                    // floor(q)
                    match provider.unary_floor(&div).await {
                        // b .* floor(q)
                        Ok(floored) => match provider.elem_mul(&b, &floored).await {
                            // a - b .* floor(q)
                            Ok(mul) => match provider.elem_sub(&a, &mul).await {
                                Ok(out) => {
                                    let _ = provider.free(&div);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&mul);
                                    return Ok(gpu_helpers::resident_gpu_value(out));
                                }
                                Err(_) => {
                                    let _ = provider.free(&mul);
                                    let _ = provider.free(&floored);
                                    let _ = provider.free(&div);
                                }
                            },
                            Err(_) => {
                                let _ = provider.free(&floored);
                                let _ = provider.free(&div);
                            }
                        },
                        Err(_) => {
                            let _ = provider.free(&div);
                        }
                    }
                }
            }
        }
    }
    // Host fallback: gather both operands and reuse the CPU implementation.
    let left = gpu_helpers::gather_tensor_async(&a).await?;
    let right = gpu_helpers::gather_tensor_async(&b).await?;
    mod_host(Value::Tensor(left), Value::Tensor(right))
}
219
220fn mod_host(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
221 if let Some(result) = scalar_mod_value(&lhs, &rhs) {
222 return Ok(result);
223 }
224 let left = value_into_numeric_array(lhs)?;
225 let right = value_into_numeric_array(rhs)?;
226 match align_numeric_arrays(left, right)? {
227 NumericPair::Real(a, b) => compute_mod_real(&a, &b),
228 NumericPair::Complex(a, b) => compute_mod_complex(&a, &b),
229 }
230}
231
232fn compute_mod_real(a: &Tensor, b: &Tensor) -> BuiltinResult<Value> {
233 let plan = BroadcastPlan::new(&a.shape, &b.shape)
234 .map_err(|err| mod_error_with_detail(&MOD_ERROR_SIZE_MISMATCH, err))?;
235 if plan.is_empty() {
236 let tensor = Tensor::new(Vec::new(), plan.output_shape().to_vec())
237 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
238 return Ok(tensor::tensor_into_value(tensor));
239 }
240 let mut result = vec![0.0f64; plan.len()];
241 for (out_idx, idx_a, idx_b) in plan.iter() {
242 let aval = a.data[idx_a];
243 let bval = b.data[idx_b];
244 result[out_idx] = mod_real_scalar(aval, bval);
245 }
246 let tensor = Tensor::new(result, plan.output_shape().to_vec())
247 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
248 Ok(tensor::tensor_into_value(tensor))
249}
250
251fn compute_mod_complex(a: &ComplexTensor, b: &ComplexTensor) -> BuiltinResult<Value> {
252 let plan = BroadcastPlan::new(&a.shape, &b.shape)
253 .map_err(|err| mod_error_with_detail(&MOD_ERROR_SIZE_MISMATCH, err))?;
254 if plan.is_empty() {
255 let tensor = ComplexTensor::new(Vec::new(), plan.output_shape().to_vec())
256 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
257 return Ok(complex_tensor_into_value(tensor));
258 }
259 let mut result = vec![(0.0f64, 0.0f64); plan.len()];
260 for (out_idx, idx_a, idx_b) in plan.iter() {
261 let (ar, ai) = a.data[idx_a];
262 let (br, bi) = b.data[idx_b];
263 result[out_idx] = mod_complex_scalar(ar, ai, br, bi);
264 }
265 let tensor = ComplexTensor::new(result, plan.output_shape().to_vec())
266 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
267 Ok(complex_tensor_into_value(tensor))
268}
269
/// Real floored modulus: the result takes the sign of the divisor `b`.
///
/// Special cases:
/// * any `NaN` operand -> `NaN`;
/// * `b == 0` -> `NaN`;
/// * infinite dividend `a` -> `NaN`;
/// * finite `a` with infinite `b` -> `0` when `a == 0`, `a` when the signs agree, else `b`.
///
/// Finite operands use `a % b` (IEEE `fmod`, which is exact and cannot overflow). The
/// result is then shifted by `b` when its sign disagrees with the divisor. The textbook
/// `a - b * floor(a / b)` overflows to `-Inf` when `a / b` is not representable (e.g.
/// `mod(1e300, 1e-300)`). It also loses the exact remainder for large quotients (e.g.
/// `mod(1e20, 3)` would give `0` instead of `1`).
///
/// A zero result is always returned as `+0.0`.
///
/// NOTE(review): MATLAB defines `mod(a, 0) == a`. This implementation returns `NaN`, which
/// the existing tests pin; the GPU composition presumably also yields `NaN` there. Confirm
/// before changing.
fn mod_real_scalar(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() || b == 0.0 || a.is_infinite() {
        return f64::NAN;
    }
    if b.is_infinite() {
        // `a` is finite here: the quotient is zero (same signs) or rounds down to -1.
        if a == 0.0 {
            return 0.0;
        }
        return if a.signum() == b.signum() { a } else { b };
    }
    // `%` keeps the sign of the dividend; move into the divisor's sign range if needed.
    let mut remainder = a % b;
    if remainder != 0.0 && remainder.signum() != b.signum() {
        remainder += b;
    }
    // Collapses -0.0 (e.g. from `-4.0 % 4.0`) to +0.0.
    if remainder == 0.0 {
        0.0
    } else {
        remainder
    }
}
304
305fn mod_complex_scalar(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
306 if (ar.is_nan() || ai.is_nan()) || (br.is_nan() || bi.is_nan()) {
307 return (f64::NAN, f64::NAN);
308 }
309 if br == 0.0 && bi == 0.0 {
310 return (f64::NAN, f64::NAN);
311 }
312 if !ar.is_finite() || !ai.is_finite() {
313 return (f64::NAN, f64::NAN);
314 }
315 let (qr, qi) = complex_div(ar, ai, br, bi);
316 if !qr.is_finite() && !qi.is_finite() && br.is_finite() && bi.is_finite() {
317 return (f64::NAN, f64::NAN);
318 }
319 let (fr, fi) = (qr.floor(), qi.floor());
320 let (mulr, muli) = complex_mul(br, bi, fr, fi);
321 let (rr, ri) = (ar - mulr, ai - muli);
322 (normalize_zero(rr), normalize_zero(ri))
323}
324
325fn scalar_real_value(value: &Value) -> Option<f64> {
326 match value {
327 Value::Num(n) => Some(*n),
328 Value::Int(i) => Some(i.to_f64()),
329 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
330 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
331 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
332 Value::CharArray(ca) if ca.rows * ca.cols == 1 => {
333 Some(ca.data.first().map(|&ch| ch as u32 as f64).unwrap_or(0.0))
334 }
335 _ => None,
336 }
337}
338
339fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
340 match value {
341 Value::Complex(re, im) => Some((*re, *im)),
342 Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
343 _ => None,
344 }
345}
346
347fn scalar_mod_value(lhs: &Value, rhs: &Value) -> Option<Value> {
348 let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
349 let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
350 let (ar, ai) = left;
351 let (br, bi) = right;
352 if ai != 0.0 || bi != 0.0 {
353 let (re, im) = mod_complex_scalar(ar, ai, br, bi);
354 return Some(Value::Complex(re, im));
355 }
356 Some(Value::Num(mod_real_scalar(ar, br)))
357}
358
/// Maps `-0.0` (and `+0.0`) to `+0.0`, leaving every other value, including `NaN`, unchanged.
fn normalize_zero(value: f64) -> f64 {
    // IEEE 754 compares -0.0 equal to 0.0, so only genuine zeros take the second branch.
    if value != 0.0 {
        value
    } else {
        0.0
    }
}
366
/// Multiplies `(ar + ai·i)` by `(br + bi·i)` and returns `(re, im)`.
fn complex_mul(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    let real = ar * br - ai * bi;
    let imag = ar * bi + ai * br;
    (real, imag)
}
370
/// Divides `(ar + ai·i)` by `(br + bi·i)` and returns `(re, im)`; a zero divisor yields
/// `(NaN, NaN)`.
///
/// Uses Smith's algorithm: the quotient is scaled by the ratio of the divisor's smaller to
/// larger component instead of forming `br² + bi²`. That squared-magnitude denominator
/// overflows for large divisors (`(1e200 + 1e200i) / 1e200` became `NaN`). It also
/// underflows to zero for tiny ones (dividing by `1e-200` was treated as division by zero).
fn complex_div(ar: f64, ai: f64, br: f64, bi: f64) -> (f64, f64) {
    if br == 0.0 && bi == 0.0 {
        return (f64::NAN, f64::NAN);
    }
    if br.abs() >= bi.abs() {
        let ratio = bi / br;
        let denom = br + bi * ratio;
        ((ar + ai * ratio) / denom, (ai - ar * ratio) / denom)
    } else {
        let ratio = br / bi;
        let denom = br * ratio + bi;
        ((ar * ratio + ai) / denom, (ai * ratio - ar) / denom)
    }
}
378
379fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
380 if tensor.data.len() == 1 {
381 let (re, im) = tensor.data[0];
382 Value::Complex(re, im)
383 } else {
384 Value::ComplexTensor(tensor)
385 }
386}
387
388fn value_into_numeric_array(value: Value) -> BuiltinResult<NumericArray> {
389 match value {
390 Value::Complex(re, im) => {
391 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
392 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
393 Ok(NumericArray::Complex(tensor))
394 }
395 Value::ComplexTensor(ct) => Ok(NumericArray::Complex(ct)),
396 Value::CharArray(ca) => {
397 let data: Vec<f64> = ca.data.iter().map(|&ch| ch as u32 as f64).collect();
398 let tensor = Tensor::new(data, vec![ca.rows, ca.cols])
399 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))?;
400 Ok(NumericArray::Real(tensor))
401 }
402 Value::String(_) | Value::StringArray(_) => Err(mod_error_with_detail(
403 &MOD_ERROR_INVALID_INPUT,
404 "expected numeric input, got string",
405 )),
406 Value::GpuTensor(_) => Err(mod_error_with_detail(
407 &MOD_ERROR_INTERNAL,
408 "internal error converting GPU tensor",
409 )),
410 other => {
411 let tensor = tensor::value_into_tensor_for(BUILTIN_NAME, other)
412 .map_err(|err| mod_error_with_detail(&MOD_ERROR_INVALID_INPUT, err))?;
413 Ok(NumericArray::Real(tensor))
414 }
415 }
416}
417
/// A single operand after conversion to host numeric storage: real or complex.
enum NumericArray {
    Real(Tensor),
    Complex(ComplexTensor),
}

/// Two operands in a common representation: both real, or both complex (a real operand
/// paired with a complex one is promoted by `align_numeric_arrays`).
enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}
427
428fn align_numeric_arrays(lhs: NumericArray, rhs: NumericArray) -> BuiltinResult<NumericPair> {
429 match (lhs, rhs) {
430 (NumericArray::Real(a), NumericArray::Real(b)) => Ok(NumericPair::Real(a, b)),
431 (left, right) => {
432 let lc = into_complex(left)?;
433 let rc = into_complex(right)?;
434 Ok(NumericPair::Complex(lc, rc))
435 }
436 }
437}
438
439fn into_complex(input: NumericArray) -> BuiltinResult<ComplexTensor> {
440 match input {
441 NumericArray::Real(t) => {
442 let Tensor { data, shape, .. } = t;
443 let complex: Vec<(f64, f64)> = data.into_iter().map(|re| (re, 0.0)).collect();
444 ComplexTensor::new(complex, shape)
445 .map_err(|e| mod_error_with_detail(&MOD_ERROR_INTERNAL, e))
446 }
447 NumericArray::Complex(ct) => Ok(ct),
448 }
449}
450
// Unit tests for the `mod` builtin: descriptor/type metadata, real/complex/char/logical
// operands, broadcasting, special values, error identifiers, and GPU round-trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::RuntimeError;
    use futures::executor::block_on;
    use runmat_builtins::{
        CharArray, ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type,
    };

    /// Drives the async builtin to completion so the tests can stay synchronous.
    fn mod_builtin(lhs: Value, rhs: Value) -> BuiltinResult<Value> {
        block_on(super::mod_builtin(lhs, rhs))
    }

    /// Asserts that `error`'s message contains `needle`, printing the full message otherwise.
    fn assert_error_contains(error: RuntimeError, needle: &str) {
        assert!(
            error.message().contains(needle),
            "unexpected error: {}",
            error.message()
        );
    }

    #[test]
    fn mod_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = MOD_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"R = mod(A, B)"));
    }

    #[test]
    fn mod_type_preserves_tensor_shape() {
        let out = numeric_binary_type(
            &[
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
                Type::Tensor {
                    shape: Some(vec![Some(2), Some(3)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_and_tensor_returns_tensor() {
        let out = numeric_binary_type(
            &[
                Type::Num,
                Type::Tensor {
                    shape: Some(vec![Some(4), Some(1)]),
                },
            ],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(4), Some(1)])
            }
        );
    }

    #[test]
    fn mod_type_scalar_returns_num() {
        let out = numeric_binary_type(&[Type::Num, Type::Int], &ResolveContext::new(Vec::new()));
        assert_eq!(out, Type::Num);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_positive_values() {
        let result = mod_builtin(Value::Num(17.0), Value::Num(5.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 2.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_divisor_keeps_sign() {
        let tensor = Tensor::new(vec![-7.0, -3.0, 4.0, 9.0], vec![4, 1]).unwrap();
        let divisor = Tensor::new(vec![-4.0], vec![1, 1]).unwrap();
        let result =
            mod_builtin(Value::Tensor(tensor), Value::Tensor(divisor)).expect("mod broadcast");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.data, vec![-3.0, -3.0, 0.0, -3.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_negative_numerator_positive_divisor() {
        let result = mod_builtin(Value::Num(-3.0), Value::Num(2.0)).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_zero_divisor_returns_nan() {
        let result = mod_builtin(Value::Num(3.0), Value::Num(0.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_matrix_scalar_broadcast() {
        let matrix = Tensor::new(vec![4.5, 7.1, -2.3, 0.4], vec![2, 2]).unwrap();
        let result = mod_builtin(Value::Tensor(matrix), Value::Num(2.0)).expect("broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 2]);
                let expected = [0.5, 1.1, 1.7, 0.4];
                for (a, b) in t.data.iter().zip(expected.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_complex_operands() {
        let complex =
            ComplexTensor::new(vec![(3.0, 4.0), (-2.0, 5.0)], vec![1, 2]).expect("complex tensor");
        let divisor = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).expect("divisor");
        let result = mod_builtin(Value::ComplexTensor(complex), Value::ComplexTensor(divisor))
            .expect("complex mod");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                let expected = [(0.0, 0.0), (0.0, 1.0)];
                for ((re, im), (er, ei)) in out.data.iter().zip(expected.iter()) {
                    assert!((re - er).abs() < 1e-12);
                    assert!((im - ei).abs() < 1e-12);
                }
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_char_array_support() {
        let chars = CharArray::new("ABC".chars().collect(), 1, 3).unwrap();
        let result = mod_builtin(Value::CharArray(chars), Value::Num(5.0)).expect("mod");
        match result {
            Value::Tensor(t) => assert_eq!(t.data, vec![0.0, 1.0, 2.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_string_input_errors() {
        let err = mod_builtin(Value::from("abc"), Value::Num(3.0))
            .expect_err("string inputs should error");
        let identifier = err.identifier().map(str::to_string);
        assert_error_contains(err, "expected numeric input");
        assert_eq!(identifier.as_deref(), MOD_ERROR_INVALID_INPUT.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_logical_array_support() {
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let value =
            mod_builtin(Value::LogicalArray(logical), Value::Num(2.0)).expect("logical mod");
        match value {
            Value::Tensor(t) => assert_eq!(t.data, vec![1.0, 0.0, 1.0, 0.0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_vector_broadcasting() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap();
        let result = mod_builtin(Value::Tensor(lhs), Value::Tensor(rhs)).expect("vector broadcast");
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 3]);
                assert_eq!(t.data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_nan_inputs_propagate() {
        let result = mod_builtin(Value::Num(f64::NAN), Value::Num(3.0)).expect("mod");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_gpu_pair_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![-5.0, -3.0, 0.0, 1.0, 6.0, 9.0], vec![3, 2]).unwrap();
            let divisor = Tensor::new(vec![4.0, 4.0, 4.0, 4.0, 4.0, 4.0], vec![3, 2]).unwrap();
            let a_view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let b_view = runmat_accelerate_api::HostTensorView {
                data: &divisor.data,
                shape: &divisor.shape,
            };
            let a_handle = provider.upload(&a_view).expect("upload a");
            let b_handle = provider.upload(&b_view).expect("upload b");
            let result =
                mod_builtin(Value::GpuTensor(a_handle), Value::GpuTensor(b_handle)).expect("mod");
            let gathered = test_support::gather(result).expect("gather result");
            assert_eq!(gathered.shape, vec![3, 2]);
            assert_eq!(gathered.data, vec![3.0, 1.0, 0.0, 1.0, 2.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_int_scalar_promotes() {
        let result =
            mod_builtin(Value::Int(IntValue::I32(-7)), Value::Int(IntValue::I32(4))).expect("mod");
        match result {
            Value::Num(v) => assert!((v - 1.0).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_infinite_operand_edge_cases() {
        // Finite dividend, infinite divisor of the same sign: the dividend is returned.
        match mod_builtin(Value::Num(5.0), Value::Num(f64::INFINITY)).expect("mod") {
            Value::Num(v) => assert_eq!(v, 5.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
        // Opposite signs: the (infinite) divisor is returned.
        match mod_builtin(Value::Num(-5.0), Value::Num(f64::INFINITY)).expect("mod") {
            Value::Num(v) => assert_eq!(v, f64::INFINITY),
            other => panic!("expected scalar result, got {other:?}"),
        }
        // An infinite dividend with a finite divisor has no meaningful remainder.
        match mod_builtin(Value::Num(f64::INFINITY), Value::Num(2.0)).expect("mod") {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected NaN result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_exact_multiple_yields_positive_zero() {
        match mod_builtin(Value::Num(-4.0), Value::Num(4.0)).expect("mod") {
            Value::Num(v) => assert!(v == 0.0 && v.is_sign_positive(), "got {v}"),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mod_incompatible_shapes_error() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = mod_builtin(Value::Tensor(lhs), Value::Tensor(rhs))
            .expect_err("incompatible shapes should error");
        assert_eq!(err.identifier(), MOD_ERROR_SIZE_MISMATCH.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn mod_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let numer = Tensor::new(vec![-5.0, -3.25, 0.0, 1.75, 6.5, 9.0], vec![3, 2]).unwrap();
        let denom = Tensor::new(vec![4.0, -2.5, 3.0, 3.0, 2.0, -5.0], vec![3, 2]).unwrap();
        let cpu_value =
            mod_host(Value::Tensor(numer.clone()), Value::Tensor(denom.clone())).expect("cpu mod");

        let provider = runmat_accelerate_api::provider().expect("wgpu provider registered");
        let numer_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &numer.data,
                shape: &numer.shape,
            })
            .expect("upload numer");
        let denom_handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &denom.data,
                shape: &denom.shape,
            })
            .expect("upload denom");

        let gpu_value = block_on(mod_gpu_pair(numer_handle, denom_handle)).expect("gpu mod");
        let gpu_tensor = test_support::gather(gpu_value).expect("gather gpu result");

        let cpu_tensor = match cpu_value {
            Value::Tensor(t) => t,
            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).expect("scalar tensor"),
            other => panic!("unexpected CPU result {other:?}"),
        };

        assert_eq!(gpu_tensor.shape, cpu_tensor.shape);
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };
        for (gpu, cpu) in gpu_tensor.data.iter().zip(cpu_tensor.data.iter()) {
            assert!(
                (gpu - cpu).abs() <= tol,
                "|{gpu} - {cpu}| exceeded tolerance {tol}"
            );
        }
    }
}