Skip to main content

runmat_runtime/builtins/math/reduction/
diff.rs

1//! MATLAB-compatible `diff` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::GpuTensorHandle;
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    CharArray, ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{gpu_helpers, tensor};
17use crate::builtins::math::reduction::type_resolvers::diff_numeric_type;
18use crate::builtins::math::symbolic::{
19    symbolic_expr_to_value, symbolic_variable_name_from_value, value_to_symbolic_scalar,
20};
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
/// Builtin name attached (via `with_builtin`) to every runtime error built in this module.
const NAME: &str = "diff";
24
/// Type resolver for `diff`, referenced by the `type_resolver(diff_type)` entry of the
/// builtin registration in this file.
///
/// Forwards `args` and `ctx` unchanged to `diff_numeric_type` and returns its result.
fn diff_type(args: &[Type], ctx: &ResolveContext) -> Type {
    diff_numeric_type(args, ctx)
}
28
/// Output parameter list shared by every `diff` signature: the single required result `B`.
const DIFF_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Finite differences along the selected dimension.",
}];
36
/// Input parameter list for the one-argument form `diff(X)`.
const DIFF_INPUTS_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "X",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input scalar or array.",
}];
44
/// Input parameter list for the two-argument form `diff(X, n)`.
///
/// `n` is declared optional with a documented default of `"1"`.
const DIFF_INPUTS_X_N: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input scalar or array.",
    },
    BuiltinParamDescriptor {
        name: "n",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("1"),
        description: "Difference order (non-negative integer scalar or empty placeholder).",
    },
];
61
/// Input parameter list for the three-argument form `diff(X, n, dim)`.
///
/// Both `n` (documented default `"1"`) and `dim` (documented default `"[]"`) are optional.
const DIFF_INPUTS_X_N_DIM: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input scalar or array.",
    },
    BuiltinParamDescriptor {
        name: "n",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("1"),
        description: "Difference order (non-negative integer scalar or empty placeholder).",
    },
    BuiltinParamDescriptor {
        name: "dim",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("[]"),
        description: "Reduction dimension (positive integer scalar or empty placeholder).",
    },
];
85
/// The three documented call forms of `diff`, each pairing an input list with the shared
/// output list `DIFF_OUTPUT_B`.
const DIFF_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
    BuiltinSignatureDescriptor {
        label: "B = diff(X)",
        inputs: &DIFF_INPUTS_X,
        outputs: &DIFF_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = diff(X, n)",
        inputs: &DIFF_INPUTS_X_N,
        outputs: &DIFF_OUTPUT_B,
    },
    BuiltinSignatureDescriptor {
        label: "B = diff(X, n, dim)",
        inputs: &DIFF_INPUTS_X_N_DIM,
        outputs: &DIFF_OUTPUT_B,
    },
];
103
/// Error descriptor used by `diff_invalid_argument` for rejected trailing arguments
/// (argument count, difference order, dimension).
const DIFF_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.DIFF.INVALID_ARGUMENT",
    identifier: Some("RunMat:diff:InvalidArgument"),
    when: "Argument count/order/dimension/order grammar is invalid.",
    message: "diff: invalid argument",
};
110
/// Error descriptor used by `diff_invalid_input` when the first argument cannot be
/// converted to a value `diff` can operate on.
const DIFF_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.DIFF.INVALID_INPUT",
    identifier: Some("RunMat:diff:InvalidInput"),
    when: "Input value cannot be converted to a supported diff domain.",
    message: "diff: invalid input",
};
117
/// Error descriptor used by `diff_internal_error` for failures while executing the
/// difference itself (tensor construction, GPU gather).
const DIFF_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.DIFF.INTERNAL",
    identifier: Some("RunMat:diff:Internal"),
    when: "Finite-difference execution fails due to conversion, gather, allocation, or reshape operations.",
    message: "diff: internal failure",
};
124
/// Every error descriptor `diff` can raise, exposed through `DIFF_DESCRIPTOR.errors`.
const DIFF_ERRORS: [BuiltinErrorDescriptor; 3] = [
    DIFF_ERROR_INVALID_ARGUMENT,
    DIFF_ERROR_INVALID_INPUT,
    DIFF_ERROR_INTERNAL,
];
130
/// Public descriptor for the `diff` builtin: its signatures and error catalogue, with a
/// fixed output mode and public completion policy.
///
/// Referenced by path from the `descriptor(...)` entry of the builtin registration.
pub const DIFF_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &DIFF_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &DIFF_ERRORS,
};
137
/// GPU execution spec for `diff`, registered through `register_gpu_spec`.
///
/// Declares a custom `"finite-difference"` operation served by the `diff_dim` provider
/// hook, for `F32`/`F64` precisions, whose result lives in a new handle.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "diff",
    op_kind: GpuOpKind::Custom("finite-difference"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("diff_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers surface finite-difference kernels through `diff_dim`; the WGPU backend keeps tensors on the device.",
};
153
154fn diff_descriptor_error_with_message(
155    message: impl Into<String>,
156    error: &'static BuiltinErrorDescriptor,
157) -> RuntimeError {
158    let mut builder = build_runtime_error(message).with_builtin(NAME);
159    if let Some(identifier) = error.identifier {
160        builder = builder.with_identifier(identifier);
161    }
162    builder.build()
163}
164
165fn diff_descriptor_error_with_detail(
166    error: &'static BuiltinErrorDescriptor,
167    detail: impl AsRef<str>,
168) -> RuntimeError {
169    diff_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
170}
171
/// Builds the "invalid argument" error (`DIFF_ERROR_INVALID_ARGUMENT`) with `detail`
/// appended to the descriptor's message.
fn diff_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
    diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_ARGUMENT, detail)
}
175
/// Builds the "invalid input" error (`DIFF_ERROR_INVALID_INPUT`) with `detail` appended
/// to the descriptor's message.
fn diff_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
    diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_INPUT, detail)
}
179
/// Builds the "internal failure" error (`DIFF_ERROR_INTERNAL`) with `detail` appended to
/// the descriptor's message.
fn diff_internal_error(detail: impl AsRef<str>) -> RuntimeError {
    diff_descriptor_error_with_detail(&DIFF_ERROR_INTERNAL, detail)
}
183
/// Fusion spec for `diff`, registered through `register_fusion_spec`.
///
/// Supplies no elementwise or reduction fusion template (`None` for both); per the
/// `notes` field the planner delegates to the runtime implementation.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "diff",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion planner currently delegates to the runtime implementation; providers can override with custom kernels.",
};
194
195#[runtime_builtin(
196    name = "diff",
197    category = "math/reduction",
198    summary = "Compute forward finite differences.",
199    keywords = "diff,difference,finite difference,nth difference,gpu",
200    accel = "diff",
201    type_resolver(diff_type),
202    descriptor(crate::builtins::math::reduction::diff::DIFF_DESCRIPTOR),
203    builtin_path = "crate::builtins::math::reduction::diff"
204)]
205async fn diff_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
206    if let Value::Symbolic(expr) = value {
207        return diff_symbolic(expr, &rest);
208    }
209
210    let (order, dim) = parse_arguments(&rest)?;
211    if order == 0 {
212        return Ok(value);
213    }
214
215    match value {
216        Value::Tensor(tensor) => {
217            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
218        }
219        Value::LogicalArray(logical) => {
220            let tensor = tensor::logical_to_tensor(&logical).map_err(diff_invalid_input)?;
221            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
222        }
223        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
224            let tensor =
225                tensor::value_into_tensor_for("diff", value).map_err(diff_invalid_input)?;
226            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
227        }
228        Value::Complex(re, im) => {
229            let tensor = ComplexTensor {
230                data: vec![(re, im)],
231                shape: vec![1, 1],
232                rows: 1,
233                cols: 1,
234            };
235            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
236        }
237        Value::ComplexTensor(tensor) => {
238            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
239        }
240        Value::CharArray(chars) => diff_char_array(chars, order, dim),
241        Value::GpuTensor(handle) => diff_gpu(handle, order, dim).await,
242        other => Err(diff_invalid_input(format!(
243            "diff: unsupported input type {:?}; expected numeric, logical, or character data",
244            other
245        ))),
246    }
247}
248
249fn diff_symbolic(expr: runmat_builtins::SymbolicExpr, args: &[Value]) -> BuiltinResult<Value> {
250    let (variable, order) = parse_symbolic_diff_args(&expr, args)?;
251    Ok(symbolic_expr_to_value(
252        runmat_builtins::SymbolicExpr::derivative_expr(expr, variable, order),
253    ))
254}
255
256fn parse_symbolic_diff_args(
257    expr: &runmat_builtins::SymbolicExpr,
258    args: &[Value],
259) -> BuiltinResult<(String, u32)> {
260    match args.len() {
261        0 => Ok((infer_symbolic_diff_variable(expr)?, 1)),
262        1 => {
263            if let Some(variable) = symbolic_variable_name_from_value(&args[0]) {
264                Ok((variable, 1))
265            } else {
266                Ok((
267                    infer_symbolic_diff_variable(expr)?,
268                    parse_symbolic_order(&args[0])?,
269                ))
270            }
271        }
272        2 => {
273            if let Some(variable) = symbolic_variable_name_from_value(&args[0]) {
274                Ok((variable, parse_symbolic_order(&args[1])?))
275            } else if let Some(variable) = symbolic_variable_name_from_value(&args[1]) {
276                Ok((variable, parse_symbolic_order(&args[0])?))
277            } else {
278                Err(diff_invalid_argument(
279                    "diff: symbolic differentiation expects a variable and optional order",
280                ))
281            }
282        }
283        _ => Err(diff_invalid_argument(
284            "diff: symbolic differentiation supports at most two trailing arguments",
285        )),
286    }
287}
288
289fn infer_symbolic_diff_variable(expr: &runmat_builtins::SymbolicExpr) -> BuiltinResult<String> {
290    let variables = expr.variables();
291    if variables.len() == 1 {
292        Ok(variables.into_iter().next().unwrap_or_default())
293    } else if variables.is_empty() {
294        Ok(String::new())
295    } else {
296        Err(diff_invalid_argument(
297            "diff: symbolic differentiation variable is ambiguous",
298        ))
299    }
300}
301
302fn parse_symbolic_order(value: &Value) -> BuiltinResult<u32> {
303    let expr = value_to_symbolic_scalar(value).ok_or_else(|| {
304        diff_invalid_argument("diff: symbolic differentiation order must be a scalar integer")
305    })?;
306    let Some(order) = expr.constant_value() else {
307        return Err(diff_invalid_argument(
308            "diff: symbolic differentiation order must be numeric",
309        ));
310    };
311    if !order.is_finite() || order < 0.0 || (order.round() - order).abs() > 1.0e-12 {
312        return Err(diff_invalid_argument(
313            "diff: symbolic differentiation order must be a nonnegative integer",
314        ));
315    }
316    if order > u32::MAX as f64 {
317        return Err(diff_invalid_argument(
318            "diff: symbolic differentiation order is too large",
319        ));
320    }
321    Ok(order as u32)
322}
323
324fn parse_arguments(args: &[Value]) -> BuiltinResult<(usize, Option<usize>)> {
325    match args.len() {
326        0 => Ok((1, None)),
327        1 => {
328            let order = parse_order(&args[0])?;
329            Ok((order.unwrap_or(1), None))
330        }
331        2 => {
332            let order = parse_order(&args[0])?.unwrap_or(1);
333            let dim = parse_dimension_arg(&args[1])?;
334            Ok((order, dim))
335        }
336        _ => Err(diff_invalid_argument("diff: unsupported arguments")),
337    }
338}
339
340fn parse_order(value: &Value) -> BuiltinResult<Option<usize>> {
341    if is_empty_array(value) {
342        return Ok(None);
343    }
344    match value {
345        Value::Int(i) => {
346            let raw = i.to_i64();
347            if raw < 0 {
348                return Err(diff_invalid_argument(
349                    "diff: order must be a non-negative integer scalar",
350                ));
351            }
352            Ok(Some(raw as usize))
353        }
354        Value::Num(n) => parse_numeric_order(*n).map(Some),
355        Value::Tensor(t) if t.data.len() == 1 => parse_numeric_order(t.data[0]).map(Some),
356        Value::Bool(b) => Ok(Some(if *b { 1 } else { 0 })),
357        other => Err(diff_invalid_argument(format!(
358            "diff: order must be a non-negative integer scalar, got {:?}",
359            other
360        ))),
361    }
362}
363
364fn parse_numeric_order(value: f64) -> BuiltinResult<usize> {
365    if !value.is_finite() {
366        return Err(diff_invalid_argument("diff: order must be finite"));
367    }
368    if value < 0.0 {
369        return Err(diff_invalid_argument(
370            "diff: order must be a non-negative integer scalar",
371        ));
372    }
373    let rounded = value.round();
374    if (rounded - value).abs() > f64::EPSILON {
375        return Err(diff_invalid_argument(
376            "diff: order must be a non-negative integer scalar",
377        ));
378    }
379    Ok(rounded as usize)
380}
381
382fn parse_dimension_arg(value: &Value) -> BuiltinResult<Option<usize>> {
383    if is_empty_array(value) {
384        return Ok(None);
385    }
386    match value {
387        Value::Int(_) | Value::Num(_) => tensor::parse_dimension(value, "diff")
388            .map(Some)
389            .map_err(diff_invalid_argument),
390        Value::Tensor(t) if t.data.len() == 1 => {
391            tensor::parse_dimension(&Value::Num(t.data[0]), "diff")
392                .map(Some)
393                .map_err(diff_invalid_argument)
394        }
395        other => Err(diff_invalid_argument(format!(
396            "diff: dimension must be a positive integer scalar, got {:?}",
397            other
398        ))),
399    }
400}
401
402fn is_empty_array(value: &Value) -> bool {
403    matches!(value, Value::Tensor(t) if t.data.is_empty())
404}
405
406async fn diff_gpu(
407    handle: GpuTensorHandle,
408    order: usize,
409    dim: Option<usize>,
410) -> BuiltinResult<Value> {
411    let working_dim = dim.unwrap_or_else(|| default_dimension(&handle.shape));
412    if working_dim == 0 {
413        return Err(diff_invalid_argument("diff: dimension must be >= 1"));
414    }
415
416    if let Some(provider) = runmat_accelerate_api::provider() {
417        if let Ok(device_result) = provider.diff_dim(&handle, order, working_dim.saturating_sub(1))
418        {
419            return Ok(Value::GpuTensor(device_result));
420        }
421    }
422
423    let tensor = gpu_helpers::gather_tensor_async(&handle)
424        .await
425        .map_err(|e| diff_internal_error(format!("diff: {e}")))?;
426    diff_tensor_host(tensor, order, Some(working_dim)).map(tensor::tensor_into_value)
427}
428
429fn diff_char_array(chars: CharArray, order: usize, dim: Option<usize>) -> BuiltinResult<Value> {
430    if order == 0 {
431        return Ok(Value::CharArray(chars));
432    }
433    let shape = vec![chars.rows, chars.cols];
434    let data: Vec<f64> = chars.data.iter().map(|&ch| ch as u32 as f64).collect();
435    let tensor = Tensor::new(data, shape).map_err(|e| diff_internal_error(format!("diff: {e}")))?;
436    diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
437}
438
439pub fn diff_tensor_host(tensor: Tensor, order: usize, dim: Option<usize>) -> BuiltinResult<Tensor> {
440    let mut current = tensor;
441    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
442    for _ in 0..order {
443        current = diff_tensor_once(current, working_dim)?;
444        if current.data.is_empty() {
445            break;
446        }
447        // Preserve explicit dimension if the caller provided one; update when defaulting and shape shrinks.
448        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
449            working_dim = default_dimension(&current.shape);
450        }
451    }
452    Ok(current)
453}
454
455fn diff_complex_tensor(
456    tensor: ComplexTensor,
457    order: usize,
458    dim: Option<usize>,
459) -> BuiltinResult<ComplexTensor> {
460    let mut current = tensor;
461    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
462    for _ in 0..order {
463        current = diff_complex_tensor_once(current, working_dim)?;
464        if current.data.is_empty() {
465            break;
466        }
467        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
468            working_dim = default_dimension(&current.shape);
469        }
470    }
471    Ok(current)
472}
473
474fn diff_tensor_once(tensor: Tensor, dim: usize) -> BuiltinResult<Tensor> {
475    let Tensor {
476        data, mut shape, ..
477    } = tensor;
478    let dim_index = dim.saturating_sub(1);
479    while shape.len() <= dim_index {
480        shape.push(1);
481    }
482    let len_dim = shape[dim_index];
483    let mut output_shape = shape.clone();
484    if len_dim <= 1 || data.is_empty() {
485        output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
486        return Tensor::new(Vec::new(), output_shape)
487            .map_err(|e| diff_internal_error(format!("diff: {e}")));
488    }
489    output_shape[dim_index] = len_dim - 1;
490    let stride_before = product(&shape[..dim_index]);
491    let stride_after = product(&shape[dim_index + 1..]);
492    let output_len = stride_before * (len_dim - 1) * stride_after;
493    let mut out = Vec::with_capacity(output_len);
494
495    for after in 0..stride_after {
496        let after_base = after * stride_before * len_dim;
497        for before in 0..stride_before {
498            for k in 0..(len_dim - 1) {
499                let idx0 = before + after_base + k * stride_before;
500                let idx1 = idx0 + stride_before;
501                out.push(data[idx1] - data[idx0]);
502            }
503        }
504    }
505
506    Tensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
507}
508
509fn diff_complex_tensor_once(tensor: ComplexTensor, dim: usize) -> BuiltinResult<ComplexTensor> {
510    let ComplexTensor {
511        data, mut shape, ..
512    } = tensor;
513    let dim_index = dim.saturating_sub(1);
514    while shape.len() <= dim_index {
515        shape.push(1);
516    }
517    let len_dim = shape[dim_index];
518    let mut output_shape = shape.clone();
519    if len_dim <= 1 || data.is_empty() {
520        output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
521        return ComplexTensor::new(Vec::new(), output_shape)
522            .map_err(|e| diff_internal_error(format!("diff: {e}")));
523    }
524    output_shape[dim_index] = len_dim - 1;
525    let stride_before = product(&shape[..dim_index]);
526    let stride_after = product(&shape[dim_index + 1..]);
527    let mut out = Vec::with_capacity(stride_before * (len_dim - 1) * stride_after);
528
529    for after in 0..stride_after {
530        let after_base = after * stride_before * len_dim;
531        for before in 0..stride_before {
532            for k in 0..(len_dim - 1) {
533                let idx0 = before + after_base + k * stride_before;
534                let idx1 = idx0 + stride_before;
535                let (re0, im0) = data[idx0];
536                let (re1, im1) = data[idx1];
537                out.push((re1 - re0, im1 - im0));
538            }
539        }
540    }
541
542    ComplexTensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
543}
544
/// Returns the 1-based index of the first dimension whose length exceeds 1, or 1 when
/// no such dimension exists (including an empty shape).
fn default_dimension(shape: &[usize]) -> usize {
    for (axis, &extent) in shape.iter().enumerate() {
        if extent > 1 {
            return axis + 1;
        }
    }
    1
}
552
/// Returns the length of the 1-based dimension `dim` in `shape`; dimensions beyond the
/// shape's rank count as length 1. A `dim` of 0 is treated like dimension 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim.saturating_sub(1)).copied().unwrap_or(1)
}
561
/// Multiplies all extents in `dims`, saturating at `usize::MAX` instead of overflowing.
/// An empty slice yields 1.
fn product(dims: &[usize]) -> usize {
    let mut total = 1usize;
    for &extent in dims {
        total = total.saturating_mul(extent);
    }
    total
}
567
// Unit tests for the `diff` builtin: descriptor metadata, symbolic dispatch, host
// numeric paths, argument validation, and GPU provider parity.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, SymbolicExpr, Tensor};

    // Synchronous wrapper so tests can call the async builtin directly.
    fn diff_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::diff_builtin(value, rest))
    }

    // The type resolver maps a 2x3 tensor type to a rank-2 tensor type with unknown extents.
    #[test]
    fn diff_type_defaults_tensor() {
        let out = diff_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![None, None])
            }
        );
    }

    // The descriptor advertises all three documented call forms.
    #[test]
    fn diff_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = DIFF_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"B = diff(X)"));
        assert!(labels.contains(&"B = diff(X, n)"));
        assert!(labels.contains(&"B = diff(X, n, dim)"));
    }

    // Every error descriptor is reachable through the public descriptor by code.
    #[test]
    fn diff_descriptor_errors_have_stable_codes() {
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_ARGUMENT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_INPUT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INTERNAL.code));
    }

    // A symbolic function differentiated with an explicit variable uses order 1.
    #[test]
    fn diff_symbolic_function_with_explicit_variable() {
        let y = SymbolicExpr::function_reference("Y", vec!["X".to_string()]);

        let result = diff_builtin(
            Value::Symbolic(y),
            vec![Value::Symbolic(SymbolicExpr::variable("X"))],
        )
        .expect("diff");

        assert_eq!(result.to_string(), "diff(Y(X), X)");
    }

    // The order may precede the variable in the symbolic two-argument form.
    #[test]
    fn diff_symbolic_function_accepts_order_before_variable() {
        let y = SymbolicExpr::function_reference("Y", vec!["X".to_string()]);

        let result = diff_builtin(
            Value::Symbolic(y),
            vec![
                Value::Int(IntValue::I32(2)),
                Value::Symbolic(SymbolicExpr::variable("X")),
            ],
        )
        .expect("diff");

        assert_eq!(result.to_string(), "diff(Y(X), X, 2)");
    }

    // A 1x3 row vector is differenced along dimension 2 by default.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_row_vector_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Second-order difference of the squares 1, 4, 9, 16 is constant 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_column_vector_second_order() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 1]);
                assert_eq!(out.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Explicit dimension 2 on a 3x2 matrix differences across the two columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_matrix_along_columns() {
        let tensor = Tensor::new(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], vec![3, 2]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![1.0, 1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An order larger than the vector length yields an empty result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_handles_empty_when_order_exceeds_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(5))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape[0], 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Character input produces a double tensor of code differences.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_char_array_promotes_to_double() {
        let chars = CharArray::new("ACEG".chars().collect(), 1, 4).unwrap();
        let result = diff_builtin(Value::CharArray(chars), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 3]);
                assert_eq!(out.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Complex input stays complex; real and imaginary parts are differenced separately.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_complex_tensor_preserves_type() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 1.0), (3.0, 2.0), (6.0, 5.0)], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![(2.0, 1.0), (3.0, 3.0)]);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // Order 0 returns the input value unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_zero_order_returns_input() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(0))];
        let result = diff_builtin(Value::Tensor(tensor.clone()), args).expect("diff");
        assert_eq!(result, Value::Tensor(tensor));
    }

    // An empty order placeholder behaves like omitting the order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_order_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
        let baseline = diff_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), vec![Value::Tensor(empty)]).expect("diff");
        assert_eq!(result, baseline);
    }

    // An empty dimension placeholder behaves like omitting the dimension.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_dimension_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![1, 4]).unwrap();
        let baseline = diff_builtin(
            Value::Tensor(tensor.clone()),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Tensor(empty)],
        )
        .expect("diff");
        assert_eq!(result, baseline);
    }

    // A negative order is rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_negative_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(-1))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative"));
    }

    // A fractional order is rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_non_integer_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Num(1.5)];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative integer"));
    }

    // Dimension 0 is rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(0))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("dimension must be >= 1"));
    }

    // A tensor uploaded through the test provider is differenced and gathered back.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = diff_builtin(Value::GpuTensor(handle), Vec::new()).expect("diff");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![3.0, 5.0]);
        });
    }

    // With the `wgpu` feature, the GPU result must match the host result within a
    // tolerance chosen from the provider's reported precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn diff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];

        let cpu_result = diff_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("diff");
        let expected = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = diff_builtin(Value::GpuTensor(handle), args).expect("diff");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_eq!(gathered.shape, expected.shape);
        let tol = if matches!(
            provider.precision(),
            runmat_accelerate_api::ProviderPrecision::F32
        ) {
            1e-5
        } else {
            1e-12
        };
        for (a, b) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((a - b).abs() < tol, "|{a} - {b}| >= {tol}");
        }
    }
}