1use runmat_accelerate_api::GpuTensorHandle;
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 CharArray, ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{gpu_helpers, tensor};
17use crate::builtins::math::reduction::type_resolvers::diff_numeric_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
/// Builtin name attached to every runtime error raised by `diff`.
const NAME: &str = "diff";

/// Type resolver registered for `diff`; delegates entirely to the shared
/// `diff_numeric_type` resolver from the reduction type-resolver module.
fn diff_type(args: &[Type], ctx: &ResolveContext) -> Type {
    diff_numeric_type(args, ctx)
}
25
/// Output descriptor shared by every `diff` call form.
const DIFF_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Finite differences along the selected dimension.",
}];

/// Inputs for the `B = diff(X)` form.
const DIFF_INPUTS_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "X",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input scalar or array.",
}];

/// Inputs for the `B = diff(X, n)` form; `n` defaults to 1.
const DIFF_INPUTS_X_N: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input scalar or array.",
    },
    BuiltinParamDescriptor {
        name: "n",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("1"),
        description: "Difference order (non-negative integer scalar or empty placeholder).",
    },
];

/// Inputs for the `B = diff(X, n, dim)` form; `dim` defaults to `[]`
/// (automatic dimension selection).
const DIFF_INPUTS_X_N_DIM: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input scalar or array.",
    },
    BuiltinParamDescriptor {
        name: "n",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("1"),
        description: "Difference order (non-negative integer scalar or empty placeholder).",
    },
    BuiltinParamDescriptor {
        name: "dim",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("[]"),
        description: "Reduction dimension (positive integer scalar or empty placeholder).",
    },
];

/// The three public call forms advertised for `diff`.
const DIFF_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
    BuiltinSignatureDescriptor {
        label: "B = diff(X)",
        inputs: &DIFF_INPUTS_X,
        outputs: &DIFF_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = diff(X, n)",
        inputs: &DIFF_INPUTS_X_N,
        outputs: &DIFF_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = diff(X, n, dim)",
        inputs: &DIFF_INPUTS_X_N_DIM,
        outputs: &DIFF_OUTPUT_B,
    },
];
100
/// Error raised for malformed `n` / `dim` arguments or an unsupported argument count.
const DIFF_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.DIFF.INVALID_ARGUMENT",
    identifier: Some("RunMat:diff:InvalidArgument"),
    when: "Argument count/order/dimension/order grammar is invalid.",
    message: "diff: invalid argument",
};

/// Error raised when the input `X` cannot be converted into a supported domain.
const DIFF_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.DIFF.INVALID_INPUT",
    identifier: Some("RunMat:diff:InvalidInput"),
    when: "Input value cannot be converted to a supported diff domain.",
    message: "diff: invalid input",
};

/// Error raised when tensor construction or a GPU gather fails mid-computation.
const DIFF_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.DIFF.INTERNAL",
    identifier: Some("RunMat:diff:Internal"),
    when: "Finite-difference execution fails due to conversion, gather, allocation, or reshape operations.",
    message: "diff: internal failure",
};

/// Every error descriptor `diff` can surface, exposed through `DIFF_DESCRIPTOR`.
const DIFF_ERRORS: [BuiltinErrorDescriptor; 3] = [
    DIFF_ERROR_INVALID_ARGUMENT,
    DIFF_ERROR_INVALID_INPUT,
    DIFF_ERROR_INTERNAL,
];

/// Public descriptor referenced by the `runtime_builtin` attribute on `diff_builtin`.
pub const DIFF_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &DIFF_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &DIFF_ERRORS,
};
134
/// GPU capability metadata for `diff`, registered through `register_gpu_spec`.
///
/// The only provider hook is the custom `diff_dim` entry point, which
/// `diff_gpu` calls before falling back to a host gather.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "diff",
    op_kind: GpuOpKind::Custom("finite-difference"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("diff_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers surface finite-difference kernels through `diff_dim`; the WGPU backend keeps tensors on the device.",
};
150
151fn diff_descriptor_error_with_message(
152 message: impl Into<String>,
153 error: &'static BuiltinErrorDescriptor,
154) -> RuntimeError {
155 let mut builder = build_runtime_error(message).with_builtin(NAME);
156 if let Some(identifier) = error.identifier {
157 builder = builder.with_identifier(identifier);
158 }
159 builder.build()
160}
161
162fn diff_descriptor_error_with_detail(
163 error: &'static BuiltinErrorDescriptor,
164 detail: impl AsRef<str>,
165) -> RuntimeError {
166 diff_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
167}
168
169fn diff_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
170 diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_ARGUMENT, detail)
171}
172
173fn diff_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
174 diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_INPUT, detail)
175}
176
177fn diff_internal_error(detail: impl AsRef<str>) -> RuntimeError {
178 diff_descriptor_error_with_detail(&DIFF_ERROR_INTERNAL, detail)
179}
180
/// Fusion-planner metadata for `diff`, registered through `register_fusion_spec`.
///
/// Neither an elementwise nor a reduction template is provided
/// (`elementwise: None`, `reduction: None`).
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "diff",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion planner currently delegates to the runtime implementation; providers can override with custom kernels.",
};
191
192#[runtime_builtin(
193 name = "diff",
194 category = "math/reduction",
195 summary = "Compute forward finite differences.",
196 keywords = "diff,difference,finite difference,nth difference,gpu",
197 accel = "diff",
198 type_resolver(diff_type),
199 descriptor(crate::builtins::math::reduction::diff::DIFF_DESCRIPTOR),
200 builtin_path = "crate::builtins::math::reduction::diff"
201)]
202async fn diff_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
203 let (order, dim) = parse_arguments(&rest)?;
204 if order == 0 {
205 return Ok(value);
206 }
207
208 match value {
209 Value::Tensor(tensor) => {
210 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
211 }
212 Value::LogicalArray(logical) => {
213 let tensor = tensor::logical_to_tensor(&logical).map_err(diff_invalid_input)?;
214 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
215 }
216 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
217 let tensor =
218 tensor::value_into_tensor_for("diff", value).map_err(diff_invalid_input)?;
219 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
220 }
221 Value::Complex(re, im) => {
222 let tensor = ComplexTensor {
223 data: vec![(re, im)],
224 shape: vec![1, 1],
225 rows: 1,
226 cols: 1,
227 };
228 diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
229 }
230 Value::ComplexTensor(tensor) => {
231 diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
232 }
233 Value::CharArray(chars) => diff_char_array(chars, order, dim),
234 Value::GpuTensor(handle) => diff_gpu(handle, order, dim).await,
235 other => Err(diff_invalid_input(format!(
236 "diff: unsupported input type {:?}; expected numeric, logical, or character data",
237 other
238 ))),
239 }
240}
241
242fn parse_arguments(args: &[Value]) -> BuiltinResult<(usize, Option<usize>)> {
243 match args.len() {
244 0 => Ok((1, None)),
245 1 => {
246 let order = parse_order(&args[0])?;
247 Ok((order.unwrap_or(1), None))
248 }
249 2 => {
250 let order = parse_order(&args[0])?.unwrap_or(1);
251 let dim = parse_dimension_arg(&args[1])?;
252 Ok((order, dim))
253 }
254 _ => Err(diff_invalid_argument("diff: unsupported arguments")),
255 }
256}
257
258fn parse_order(value: &Value) -> BuiltinResult<Option<usize>> {
259 if is_empty_array(value) {
260 return Ok(None);
261 }
262 match value {
263 Value::Int(i) => {
264 let raw = i.to_i64();
265 if raw < 0 {
266 return Err(diff_invalid_argument(
267 "diff: order must be a non-negative integer scalar",
268 ));
269 }
270 Ok(Some(raw as usize))
271 }
272 Value::Num(n) => parse_numeric_order(*n).map(Some),
273 Value::Tensor(t) if t.data.len() == 1 => parse_numeric_order(t.data[0]).map(Some),
274 Value::Bool(b) => Ok(Some(if *b { 1 } else { 0 })),
275 other => Err(diff_invalid_argument(format!(
276 "diff: order must be a non-negative integer scalar, got {:?}",
277 other
278 ))),
279 }
280}
281
282fn parse_numeric_order(value: f64) -> BuiltinResult<usize> {
283 if !value.is_finite() {
284 return Err(diff_invalid_argument("diff: order must be finite"));
285 }
286 if value < 0.0 {
287 return Err(diff_invalid_argument(
288 "diff: order must be a non-negative integer scalar",
289 ));
290 }
291 let rounded = value.round();
292 if (rounded - value).abs() > f64::EPSILON {
293 return Err(diff_invalid_argument(
294 "diff: order must be a non-negative integer scalar",
295 ));
296 }
297 Ok(rounded as usize)
298}
299
300fn parse_dimension_arg(value: &Value) -> BuiltinResult<Option<usize>> {
301 if is_empty_array(value) {
302 return Ok(None);
303 }
304 match value {
305 Value::Int(_) | Value::Num(_) => tensor::parse_dimension(value, "diff")
306 .map(Some)
307 .map_err(diff_invalid_argument),
308 Value::Tensor(t) if t.data.len() == 1 => {
309 tensor::parse_dimension(&Value::Num(t.data[0]), "diff")
310 .map(Some)
311 .map_err(diff_invalid_argument)
312 }
313 other => Err(diff_invalid_argument(format!(
314 "diff: dimension must be a positive integer scalar, got {:?}",
315 other
316 ))),
317 }
318}
319
320fn is_empty_array(value: &Value) -> bool {
321 matches!(value, Value::Tensor(t) if t.data.is_empty())
322}
323
324async fn diff_gpu(
325 handle: GpuTensorHandle,
326 order: usize,
327 dim: Option<usize>,
328) -> BuiltinResult<Value> {
329 let working_dim = dim.unwrap_or_else(|| default_dimension(&handle.shape));
330 if working_dim == 0 {
331 return Err(diff_invalid_argument("diff: dimension must be >= 1"));
332 }
333
334 if let Some(provider) = runmat_accelerate_api::provider() {
335 if let Ok(device_result) = provider.diff_dim(&handle, order, working_dim.saturating_sub(1))
336 {
337 return Ok(Value::GpuTensor(device_result));
338 }
339 }
340
341 let tensor = gpu_helpers::gather_tensor_async(&handle)
342 .await
343 .map_err(|e| diff_internal_error(format!("diff: {e}")))?;
344 diff_tensor_host(tensor, order, Some(working_dim)).map(tensor::tensor_into_value)
345}
346
347fn diff_char_array(chars: CharArray, order: usize, dim: Option<usize>) -> BuiltinResult<Value> {
348 if order == 0 {
349 return Ok(Value::CharArray(chars));
350 }
351 let shape = vec![chars.rows, chars.cols];
352 let data: Vec<f64> = chars.data.iter().map(|&ch| ch as u32 as f64).collect();
353 let tensor = Tensor::new(data, shape).map_err(|e| diff_internal_error(format!("diff: {e}")))?;
354 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
355}
356
357pub fn diff_tensor_host(tensor: Tensor, order: usize, dim: Option<usize>) -> BuiltinResult<Tensor> {
358 let mut current = tensor;
359 let mut working_dim = dim.unwrap_or_else(|| default_dimension(¤t.shape));
360 for _ in 0..order {
361 current = diff_tensor_once(current, working_dim)?;
362 if current.data.is_empty() {
363 break;
364 }
365 if dim.is_none() && dimension_length(¤t.shape, working_dim) == 0 {
367 working_dim = default_dimension(¤t.shape);
368 }
369 }
370 Ok(current)
371}
372
373fn diff_complex_tensor(
374 tensor: ComplexTensor,
375 order: usize,
376 dim: Option<usize>,
377) -> BuiltinResult<ComplexTensor> {
378 let mut current = tensor;
379 let mut working_dim = dim.unwrap_or_else(|| default_dimension(¤t.shape));
380 for _ in 0..order {
381 current = diff_complex_tensor_once(current, working_dim)?;
382 if current.data.is_empty() {
383 break;
384 }
385 if dim.is_none() && dimension_length(¤t.shape, working_dim) == 0 {
386 working_dim = default_dimension(¤t.shape);
387 }
388 }
389 Ok(current)
390}
391
392fn diff_tensor_once(tensor: Tensor, dim: usize) -> BuiltinResult<Tensor> {
393 let Tensor {
394 data, mut shape, ..
395 } = tensor;
396 let dim_index = dim.saturating_sub(1);
397 while shape.len() <= dim_index {
398 shape.push(1);
399 }
400 let len_dim = shape[dim_index];
401 let mut output_shape = shape.clone();
402 if len_dim <= 1 || data.is_empty() {
403 output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
404 return Tensor::new(Vec::new(), output_shape)
405 .map_err(|e| diff_internal_error(format!("diff: {e}")));
406 }
407 output_shape[dim_index] = len_dim - 1;
408 let stride_before = product(&shape[..dim_index]);
409 let stride_after = product(&shape[dim_index + 1..]);
410 let output_len = stride_before * (len_dim - 1) * stride_after;
411 let mut out = Vec::with_capacity(output_len);
412
413 for after in 0..stride_after {
414 let after_base = after * stride_before * len_dim;
415 for before in 0..stride_before {
416 for k in 0..(len_dim - 1) {
417 let idx0 = before + after_base + k * stride_before;
418 let idx1 = idx0 + stride_before;
419 out.push(data[idx1] - data[idx0]);
420 }
421 }
422 }
423
424 Tensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
425}
426
427fn diff_complex_tensor_once(tensor: ComplexTensor, dim: usize) -> BuiltinResult<ComplexTensor> {
428 let ComplexTensor {
429 data, mut shape, ..
430 } = tensor;
431 let dim_index = dim.saturating_sub(1);
432 while shape.len() <= dim_index {
433 shape.push(1);
434 }
435 let len_dim = shape[dim_index];
436 let mut output_shape = shape.clone();
437 if len_dim <= 1 || data.is_empty() {
438 output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
439 return ComplexTensor::new(Vec::new(), output_shape)
440 .map_err(|e| diff_internal_error(format!("diff: {e}")));
441 }
442 output_shape[dim_index] = len_dim - 1;
443 let stride_before = product(&shape[..dim_index]);
444 let stride_after = product(&shape[dim_index + 1..]);
445 let mut out = Vec::with_capacity(stride_before * (len_dim - 1) * stride_after);
446
447 for after in 0..stride_after {
448 let after_base = after * stride_before * len_dim;
449 for before in 0..stride_before {
450 for k in 0..(len_dim - 1) {
451 let idx0 = before + after_base + k * stride_before;
452 let idx1 = idx0 + stride_before;
453 let (re0, im0) = data[idx0];
454 let (re1, im1) = data[idx1];
455 out.push((re1 - re0, im1 - im0));
456 }
457 }
458 }
459
460 ComplexTensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
461}
462
/// Returns the 1-based default working dimension.
///
/// This is the first dimension whose extent is not 1, or 1 when every
/// dimension is singleton (or `shape` is empty).
///
/// MATLAB's `diff` operates along "the first array dimension whose size does
/// not equal 1", so a zero-length dimension counts as non-singleton. A 1x0
/// input is differenced along dimension 2 and a 0x3 input along dimension 1.
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&extent| extent != 1)
        .map_or(1, |index| index + 1)
}
470
/// Returns the extent of the 1-based dimension `dim` in `shape`.
///
/// Dimensions beyond the stored rank are implicitly singleton (1). A `dim` of 0
/// is treated like 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim.saturating_sub(1)).copied().unwrap_or(1)
}
479
/// Multiplies all extents together, saturating at `usize::MAX` instead of
/// overflowing. An empty slice yields 1.
fn product(dims: &[usize]) -> usize {
    let mut total = 1usize;
    for &extent in dims {
        total = total.saturating_mul(extent);
    }
    total
}
485
// Unit tests for `diff`: type resolution, descriptor metadata, host
// differencing of real, char and complex inputs, argument parsing and
// validation, and GPU provider round-trips.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor};

    // Synchronous shim that drives the async builtin to completion with `block_on`.
    fn diff_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::diff_builtin(value, rest))
    }

    // A concrete 2x3 tensor input resolves to a rank-2 tensor type with unknown extents.
    #[test]
    fn diff_type_defaults_tensor() {
        let out = diff_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![None, None])
            }
        );
    }

    // The descriptor advertises all three call forms.
    #[test]
    fn diff_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = DIFF_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"B = diff(X)"));
        assert!(labels.contains(&"B = diff(X, n)"));
        assert!(labels.contains(&"B = diff(X, n, dim)"));
    }

    // Every error descriptor constant is reachable from the public descriptor.
    #[test]
    fn diff_descriptor_errors_have_stable_codes() {
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_ARGUMENT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_INPUT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INTERNAL.code));
    }

    // A 1x3 row vector is differenced along dimension 2 by default.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_row_vector_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Second-order difference of squares is constant (2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_column_vector_second_order() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 1]);
                assert_eq!(out.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Explicit dim = 2 on a 3x2 matrix differences the two columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_matrix_along_columns() {
        let tensor = Tensor::new(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], vec![3, 2]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![1.0, 1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An order larger than the vector length produces an empty result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_handles_empty_when_order_exceeds_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(5))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape[0], 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Char input yields numeric differences of the character codes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_char_array_promotes_to_double() {
        let chars = CharArray::new("ACEG".chars().collect(), 1, 4).unwrap();
        let result = diff_builtin(Value::CharArray(chars), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 3]);
                assert_eq!(out.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Complex input stays complex; real and imaginary parts are differenced separately.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_complex_tensor_preserves_type() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 1.0), (3.0, 2.0), (6.0, 5.0)], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![(2.0, 1.0), (3.0, 3.0)]);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // diff(X, 0) returns X unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_zero_order_returns_input() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(0))];
        let result = diff_builtin(Value::Tensor(tensor.clone()), args).expect("diff");
        assert_eq!(result, Value::Tensor(tensor));
    }

    // An empty `[]` order behaves like the default order of 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_order_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
        let baseline = diff_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), vec![Value::Tensor(empty)]).expect("diff");
        assert_eq!(result, baseline);
    }

    // An empty `[]` dim behaves like omitting dim.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_dimension_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![1, 4]).unwrap();
        let baseline = diff_builtin(
            Value::Tensor(tensor.clone()),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Tensor(empty)],
        )
        .expect("diff");
        assert_eq!(result, baseline);
    }

    // Negative orders are rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_negative_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(-1))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative"));
    }

    // Fractional orders are rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_non_integer_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Num(1.5)];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative integer"));
    }

    // dim = 0 is rejected; the message text comes from `tensor::parse_dimension`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(0))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("dimension must be >= 1"));
    }

    // Upload to the test provider, difference on-device (or via fallback), then gather.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = diff_builtin(Value::GpuTensor(handle), Vec::new()).expect("diff");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![3.0, 5.0]);
        });
    }

    // The WGPU provider must agree with the CPU path within a precision-dependent tolerance.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn diff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];

        let cpu_result = diff_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("diff");
        let expected = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = diff_builtin(Value::GpuTensor(handle), args).expect("diff");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_eq!(gathered.shape, expected.shape);
        // Looser tolerance when the provider computes in single precision.
        let tol = if matches!(
            provider.precision(),
            runmat_accelerate_api::ProviderPrecision::F32
        ) {
            1e-5
        } else {
            1e-12
        };
        for (a, b) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((a - b).abs() < tol, "|{a} - {b}| >= {tol}");
        }
    }
}