Skip to main content

runmat_runtime/builtins/math/reduction/
diff.rs

1//! MATLAB-compatible `diff` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::GpuTensorHandle;
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    CharArray, ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{gpu_helpers, tensor};
17use crate::builtins::math::reduction::type_resolvers::diff_numeric_type;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
20const NAME: &str = "diff";
21
22fn diff_type(args: &[Type], ctx: &ResolveContext) -> Type {
23    diff_numeric_type(args, ctx)
24}
25
26const DIFF_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "B",
28    ty: BuiltinParamType::NumericArray,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "Finite differences along the selected dimension.",
32}];
33
34const DIFF_INPUTS_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
35    name: "X",
36    ty: BuiltinParamType::Any,
37    arity: BuiltinParamArity::Required,
38    default: None,
39    description: "Input scalar or array.",
40}];
41
42const DIFF_INPUTS_X_N: [BuiltinParamDescriptor; 2] = [
43    BuiltinParamDescriptor {
44        name: "X",
45        ty: BuiltinParamType::Any,
46        arity: BuiltinParamArity::Required,
47        default: None,
48        description: "Input scalar or array.",
49    },
50    BuiltinParamDescriptor {
51        name: "n",
52        ty: BuiltinParamType::Any,
53        arity: BuiltinParamArity::Optional,
54        default: Some("1"),
55        description: "Difference order (non-negative integer scalar or empty placeholder).",
56    },
57];
58
59const DIFF_INPUTS_X_N_DIM: [BuiltinParamDescriptor; 3] = [
60    BuiltinParamDescriptor {
61        name: "X",
62        ty: BuiltinParamType::Any,
63        arity: BuiltinParamArity::Required,
64        default: None,
65        description: "Input scalar or array.",
66    },
67    BuiltinParamDescriptor {
68        name: "n",
69        ty: BuiltinParamType::Any,
70        arity: BuiltinParamArity::Optional,
71        default: Some("1"),
72        description: "Difference order (non-negative integer scalar or empty placeholder).",
73    },
74    BuiltinParamDescriptor {
75        name: "dim",
76        ty: BuiltinParamType::Any,
77        arity: BuiltinParamArity::Optional,
78        default: Some("[]"),
79        description: "Reduction dimension (positive integer scalar or empty placeholder).",
80    },
81];
82
83const DIFF_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
84    BuiltinSignatureDescriptor {
85        label: "B = diff(X)",
86        inputs: &DIFF_INPUTS_X,
87        outputs: &DIFF_OUTPUT_B,
88    },
89    BuiltinSignatureDescriptor {
90        label: "B = diff(X, n)",
91        inputs: &DIFF_INPUTS_X_N,
92        outputs: &DIFF_OUTPUT_B,
93    },
94    BuiltinSignatureDescriptor {
95        label: "B = diff(X, n, dim)",
96        inputs: &DIFF_INPUTS_X_N_DIM,
97        outputs: &DIFF_OUTPUT_B,
98    },
99];
100
101const DIFF_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
102    code: "RM.DIFF.INVALID_ARGUMENT",
103    identifier: Some("RunMat:diff:InvalidArgument"),
104    when: "Argument count/order/dimension/order grammar is invalid.",
105    message: "diff: invalid argument",
106};
107
108const DIFF_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
109    code: "RM.DIFF.INVALID_INPUT",
110    identifier: Some("RunMat:diff:InvalidInput"),
111    when: "Input value cannot be converted to a supported diff domain.",
112    message: "diff: invalid input",
113};
114
115const DIFF_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
116    code: "RM.DIFF.INTERNAL",
117    identifier: Some("RunMat:diff:Internal"),
118    when: "Finite-difference execution fails due to conversion, gather, allocation, or reshape operations.",
119    message: "diff: internal failure",
120};
121
122const DIFF_ERRORS: [BuiltinErrorDescriptor; 3] = [
123    DIFF_ERROR_INVALID_ARGUMENT,
124    DIFF_ERROR_INVALID_INPUT,
125    DIFF_ERROR_INTERNAL,
126];
127
128pub const DIFF_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
129    signatures: &DIFF_SIGNATURES,
130    output_mode: BuiltinOutputMode::Fixed,
131    completion_policy: BuiltinCompletionPolicy::Public,
132    errors: &DIFF_ERRORS,
133};
134
// GPU capability spec for `diff`, registered through `register_gpu_spec`.
// The runtime calls the provider's `diff_dim` hook (see `diff_gpu`) with a single
// zero-based dimension; f32 and f64 precisions are declared as supported.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "diff",
    op_kind: GpuOpKind::Custom("finite-difference"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("diff_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Results are returned as a fresh device handle rather than reusing the input.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers surface finite-difference kernels through `diff_dim`; the WGPU backend keeps tensors on the device.",
};
150
151fn diff_descriptor_error_with_message(
152    message: impl Into<String>,
153    error: &'static BuiltinErrorDescriptor,
154) -> RuntimeError {
155    let mut builder = build_runtime_error(message).with_builtin(NAME);
156    if let Some(identifier) = error.identifier {
157        builder = builder.with_identifier(identifier);
158    }
159    builder.build()
160}
161
162fn diff_descriptor_error_with_detail(
163    error: &'static BuiltinErrorDescriptor,
164    detail: impl AsRef<str>,
165) -> RuntimeError {
166    diff_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
167}
168
169fn diff_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
170    diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_ARGUMENT, detail)
171}
172
173fn diff_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
174    diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_INPUT, detail)
175}
176
177fn diff_internal_error(detail: impl AsRef<str>) -> RuntimeError {
178    diff_descriptor_error_with_detail(&DIFF_ERROR_INTERNAL, detail)
179}
180
// Fusion-planner spec for `diff`, registered through `register_fusion_spec`.
// No elementwise or reduction template is provided (both `None`), so `diff` is not
// fused; per the notes, execution goes through the runtime implementation.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "diff",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion planner currently delegates to the runtime implementation; providers can override with custom kernels.",
};
191
192#[runtime_builtin(
193    name = "diff",
194    category = "math/reduction",
195    summary = "Compute forward finite differences.",
196    keywords = "diff,difference,finite difference,nth difference,gpu",
197    accel = "diff",
198    type_resolver(diff_type),
199    descriptor(crate::builtins::math::reduction::diff::DIFF_DESCRIPTOR),
200    builtin_path = "crate::builtins::math::reduction::diff"
201)]
202async fn diff_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
203    let (order, dim) = parse_arguments(&rest)?;
204    if order == 0 {
205        return Ok(value);
206    }
207
208    match value {
209        Value::Tensor(tensor) => {
210            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
211        }
212        Value::LogicalArray(logical) => {
213            let tensor = tensor::logical_to_tensor(&logical).map_err(diff_invalid_input)?;
214            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
215        }
216        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
217            let tensor =
218                tensor::value_into_tensor_for("diff", value).map_err(diff_invalid_input)?;
219            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
220        }
221        Value::Complex(re, im) => {
222            let tensor = ComplexTensor {
223                data: vec![(re, im)],
224                shape: vec![1, 1],
225                rows: 1,
226                cols: 1,
227            };
228            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
229        }
230        Value::ComplexTensor(tensor) => {
231            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
232        }
233        Value::CharArray(chars) => diff_char_array(chars, order, dim),
234        Value::GpuTensor(handle) => diff_gpu(handle, order, dim).await,
235        other => Err(diff_invalid_input(format!(
236            "diff: unsupported input type {:?}; expected numeric, logical, or character data",
237            other
238        ))),
239    }
240}
241
242fn parse_arguments(args: &[Value]) -> BuiltinResult<(usize, Option<usize>)> {
243    match args.len() {
244        0 => Ok((1, None)),
245        1 => {
246            let order = parse_order(&args[0])?;
247            Ok((order.unwrap_or(1), None))
248        }
249        2 => {
250            let order = parse_order(&args[0])?.unwrap_or(1);
251            let dim = parse_dimension_arg(&args[1])?;
252            Ok((order, dim))
253        }
254        _ => Err(diff_invalid_argument("diff: unsupported arguments")),
255    }
256}
257
258fn parse_order(value: &Value) -> BuiltinResult<Option<usize>> {
259    if is_empty_array(value) {
260        return Ok(None);
261    }
262    match value {
263        Value::Int(i) => {
264            let raw = i.to_i64();
265            if raw < 0 {
266                return Err(diff_invalid_argument(
267                    "diff: order must be a non-negative integer scalar",
268                ));
269            }
270            Ok(Some(raw as usize))
271        }
272        Value::Num(n) => parse_numeric_order(*n).map(Some),
273        Value::Tensor(t) if t.data.len() == 1 => parse_numeric_order(t.data[0]).map(Some),
274        Value::Bool(b) => Ok(Some(if *b { 1 } else { 0 })),
275        other => Err(diff_invalid_argument(format!(
276            "diff: order must be a non-negative integer scalar, got {:?}",
277            other
278        ))),
279    }
280}
281
282fn parse_numeric_order(value: f64) -> BuiltinResult<usize> {
283    if !value.is_finite() {
284        return Err(diff_invalid_argument("diff: order must be finite"));
285    }
286    if value < 0.0 {
287        return Err(diff_invalid_argument(
288            "diff: order must be a non-negative integer scalar",
289        ));
290    }
291    let rounded = value.round();
292    if (rounded - value).abs() > f64::EPSILON {
293        return Err(diff_invalid_argument(
294            "diff: order must be a non-negative integer scalar",
295        ));
296    }
297    Ok(rounded as usize)
298}
299
300fn parse_dimension_arg(value: &Value) -> BuiltinResult<Option<usize>> {
301    if is_empty_array(value) {
302        return Ok(None);
303    }
304    match value {
305        Value::Int(_) | Value::Num(_) => tensor::parse_dimension(value, "diff")
306            .map(Some)
307            .map_err(diff_invalid_argument),
308        Value::Tensor(t) if t.data.len() == 1 => {
309            tensor::parse_dimension(&Value::Num(t.data[0]), "diff")
310                .map(Some)
311                .map_err(diff_invalid_argument)
312        }
313        other => Err(diff_invalid_argument(format!(
314            "diff: dimension must be a positive integer scalar, got {:?}",
315            other
316        ))),
317    }
318}
319
320fn is_empty_array(value: &Value) -> bool {
321    matches!(value, Value::Tensor(t) if t.data.is_empty())
322}
323
324async fn diff_gpu(
325    handle: GpuTensorHandle,
326    order: usize,
327    dim: Option<usize>,
328) -> BuiltinResult<Value> {
329    let working_dim = dim.unwrap_or_else(|| default_dimension(&handle.shape));
330    if working_dim == 0 {
331        return Err(diff_invalid_argument("diff: dimension must be >= 1"));
332    }
333
334    if let Some(provider) = runmat_accelerate_api::provider() {
335        if let Ok(device_result) = provider.diff_dim(&handle, order, working_dim.saturating_sub(1))
336        {
337            return Ok(Value::GpuTensor(device_result));
338        }
339    }
340
341    let tensor = gpu_helpers::gather_tensor_async(&handle)
342        .await
343        .map_err(|e| diff_internal_error(format!("diff: {e}")))?;
344    diff_tensor_host(tensor, order, Some(working_dim)).map(tensor::tensor_into_value)
345}
346
347fn diff_char_array(chars: CharArray, order: usize, dim: Option<usize>) -> BuiltinResult<Value> {
348    if order == 0 {
349        return Ok(Value::CharArray(chars));
350    }
351    let shape = vec![chars.rows, chars.cols];
352    let data: Vec<f64> = chars.data.iter().map(|&ch| ch as u32 as f64).collect();
353    let tensor = Tensor::new(data, shape).map_err(|e| diff_internal_error(format!("diff: {e}")))?;
354    diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
355}
356
357pub fn diff_tensor_host(tensor: Tensor, order: usize, dim: Option<usize>) -> BuiltinResult<Tensor> {
358    let mut current = tensor;
359    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
360    for _ in 0..order {
361        current = diff_tensor_once(current, working_dim)?;
362        if current.data.is_empty() {
363            break;
364        }
365        // Preserve explicit dimension if the caller provided one; update when defaulting and shape shrinks.
366        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
367            working_dim = default_dimension(&current.shape);
368        }
369    }
370    Ok(current)
371}
372
373fn diff_complex_tensor(
374    tensor: ComplexTensor,
375    order: usize,
376    dim: Option<usize>,
377) -> BuiltinResult<ComplexTensor> {
378    let mut current = tensor;
379    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
380    for _ in 0..order {
381        current = diff_complex_tensor_once(current, working_dim)?;
382        if current.data.is_empty() {
383            break;
384        }
385        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
386            working_dim = default_dimension(&current.shape);
387        }
388    }
389    Ok(current)
390}
391
392fn diff_tensor_once(tensor: Tensor, dim: usize) -> BuiltinResult<Tensor> {
393    let Tensor {
394        data, mut shape, ..
395    } = tensor;
396    let dim_index = dim.saturating_sub(1);
397    while shape.len() <= dim_index {
398        shape.push(1);
399    }
400    let len_dim = shape[dim_index];
401    let mut output_shape = shape.clone();
402    if len_dim <= 1 || data.is_empty() {
403        output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
404        return Tensor::new(Vec::new(), output_shape)
405            .map_err(|e| diff_internal_error(format!("diff: {e}")));
406    }
407    output_shape[dim_index] = len_dim - 1;
408    let stride_before = product(&shape[..dim_index]);
409    let stride_after = product(&shape[dim_index + 1..]);
410    let output_len = stride_before * (len_dim - 1) * stride_after;
411    let mut out = Vec::with_capacity(output_len);
412
413    for after in 0..stride_after {
414        let after_base = after * stride_before * len_dim;
415        for before in 0..stride_before {
416            for k in 0..(len_dim - 1) {
417                let idx0 = before + after_base + k * stride_before;
418                let idx1 = idx0 + stride_before;
419                out.push(data[idx1] - data[idx0]);
420            }
421        }
422    }
423
424    Tensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
425}
426
427fn diff_complex_tensor_once(tensor: ComplexTensor, dim: usize) -> BuiltinResult<ComplexTensor> {
428    let ComplexTensor {
429        data, mut shape, ..
430    } = tensor;
431    let dim_index = dim.saturating_sub(1);
432    while shape.len() <= dim_index {
433        shape.push(1);
434    }
435    let len_dim = shape[dim_index];
436    let mut output_shape = shape.clone();
437    if len_dim <= 1 || data.is_empty() {
438        output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
439        return ComplexTensor::new(Vec::new(), output_shape)
440            .map_err(|e| diff_internal_error(format!("diff: {e}")));
441    }
442    output_shape[dim_index] = len_dim - 1;
443    let stride_before = product(&shape[..dim_index]);
444    let stride_after = product(&shape[dim_index + 1..]);
445    let mut out = Vec::with_capacity(stride_before * (len_dim - 1) * stride_after);
446
447    for after in 0..stride_after {
448        let after_base = after * stride_before * len_dim;
449        for before in 0..stride_before {
450            for k in 0..(len_dim - 1) {
451                let idx0 = before + after_base + k * stride_before;
452                let idx1 = idx0 + stride_before;
453                let (re0, im0) = data[idx0];
454                let (re1, im1) = data[idx1];
455                out.push((re1 - re0, im1 - im0));
456            }
457        }
458    }
459
460    ComplexTensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
461}
462
/// Returns the 1-based index of the first dimension whose extent is not 1.
///
/// This is MATLAB's default for `diff`. Zero-length dimensions count as
/// non-singleton, so an empty `0×3` input is differenced along dimension 1 (giving
/// `0×3`) rather than dimension 2. Falls back to dimension 1 when every extent is 1
/// (scalars) or `shape` is empty.
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&extent| extent != 1)
        .map_or(1, |idx| idx + 1)
}
470
/// Returns the extent of the 1-based dimension `dim`, treating any dimension beyond
/// the stored shape as singleton (length 1). `dim == 0` is read as dimension 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim.saturating_sub(1)).copied().unwrap_or(1)
}
479
/// Multiplies the extents in `dims`, saturating at `usize::MAX` instead of overflowing.
/// An empty slice yields 1.
fn product(dims: &[usize]) -> usize {
    let mut total: usize = 1;
    for &extent in dims {
        total = total.saturating_mul(extent);
    }
    total
}
485
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for the `diff` builtin: descriptor metadata, host numeric paths,
    // argument validation, and (with a test/WGPU provider) the GPU path.
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor};

    /// Synchronous shim: drives the async builtin with `block_on` so tests can call it
    /// directly and get a `BuiltinResult`.
    fn diff_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::diff_builtin(value, rest))
    }

    // For a 2x3 tensor input, the type resolver keeps rank 2 but reports unknown extents.
    #[test]
    fn diff_type_defaults_tensor() {
        let out = diff_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![None, None])
            }
        );
    }

    // All three advertised call forms must be present in the descriptor.
    #[test]
    fn diff_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = DIFF_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"B = diff(X)"));
        assert!(labels.contains(&"B = diff(X, n)"));
        assert!(labels.contains(&"B = diff(X, n, dim)"));
    }

    // Each error descriptor defined in this module must be listed in the public descriptor.
    #[test]
    fn diff_descriptor_errors_have_stable_codes() {
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_ARGUMENT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_INPUT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INTERNAL.code));
    }

    // A 1x3 row vector differences along dimension 2 by default.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_row_vector_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Second-order difference of squares 1,4,9,16 is the constant 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_column_vector_second_order() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 1]);
                assert_eq!(out.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Explicit dim = 2 on a 3x2 matrix (column-major data) differences across columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_matrix_along_columns() {
        let tensor = Tensor::new(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], vec![3, 2]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![1.0, 1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An order larger than the data supports yields an empty result; only the leading
    // extent and emptiness are checked.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_handles_empty_when_order_exceeds_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(5))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape[0], 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Characters are differenced as code points and the result is a numeric tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_char_array_promotes_to_double() {
        let chars = CharArray::new("ACEG".chars().collect(), 1, 4).unwrap();
        let result = diff_builtin(Value::CharArray(chars), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 3]);
                assert_eq!(out.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Complex input stays complex; real and imaginary parts are differenced separately.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_complex_tensor_preserves_type() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 1.0), (3.0, 2.0), (6.0, 5.0)], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![(2.0, 1.0), (3.0, 3.0)]);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // Order 0 returns the input value unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_zero_order_returns_input() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(0))];
        let result = diff_builtin(Value::Tensor(tensor.clone()), args).expect("diff");
        assert_eq!(result, Value::Tensor(tensor));
    }

    // An empty `[]` order placeholder behaves like omitting the order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_order_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
        let baseline = diff_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), vec![Value::Tensor(empty)]).expect("diff");
        assert_eq!(result, baseline);
    }

    // An empty `[]` dimension placeholder behaves like omitting the dimension.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_dimension_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![1, 4]).unwrap();
        let baseline = diff_builtin(
            Value::Tensor(tensor.clone()),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Tensor(empty)],
        )
        .expect("diff");
        assert_eq!(result, baseline);
    }

    // Negative integer orders are rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_negative_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(-1))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative"));
    }

    // Fractional orders are rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_non_integer_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Num(1.5)];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative integer"));
    }

    // Dimension 0 is rejected; the message text comes from the shared dimension parser.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(0))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("dimension must be >= 1"));
    }

    // Uploads to the test provider, runs diff on the GPU handle, and gathers the result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = diff_builtin(Value::GpuTensor(handle), Vec::new()).expect("diff");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![3.0, 5.0]);
        });
    }

    // WGPU results must match the CPU path within a precision-dependent tolerance.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn diff_wgpu_matches_cpu() {
        // Registration result is ignored (`let _`); the provider is fetched below.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];

        let cpu_result = diff_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("diff");
        let expected = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = diff_builtin(Value::GpuTensor(handle), args).expect("diff");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_eq!(gathered.shape, expected.shape);
        // Looser tolerance when the provider computes in single precision.
        let tol = if matches!(
            provider.precision(),
            runmat_accelerate_api::ProviderPrecision::F32
        ) {
            1e-5
        } else {
            1e-12
        };
        for (a, b) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((a - b).abs() < tol, "|{a} - {b}| >= {tol}");
        }
    }
}