Skip to main content

runmat_runtime/builtins/math/reduction/
diff.rs

1//! MATLAB-compatible `diff` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::GpuTensorHandle;
4use runmat_builtins::{CharArray, ComplexTensor, ResolveContext, Tensor, Type, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::random_args::complex_tensor_into_value;
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, tensor};
13use crate::builtins::math::reduction::type_resolvers::diff_numeric_type;
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15
/// Builtin name attached to every runtime error raised by this module (see `diff_error`).
const NAME: &str = "diff";
17
/// Type resolver registered for the `diff` builtin.
///
/// Delegates unchanged to the shared `diff_numeric_type` resolver.
fn diff_type(args: &[Type], ctx: &ResolveContext) -> Type {
    diff_numeric_type(args, ctx)
}
21
22#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::diff")]
23pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
24    name: "diff",
25    op_kind: GpuOpKind::Custom("finite-difference"),
26    supported_precisions: &[ScalarType::F32, ScalarType::F64],
27    broadcast: BroadcastSemantics::Matlab,
28    provider_hooks: &[ProviderHook::Custom("diff_dim")],
29    constant_strategy: ConstantStrategy::InlineLiteral,
30    residency: ResidencyPolicy::NewHandle,
31    nan_mode: ReductionNaN::Include,
32    two_pass_threshold: None,
33    workgroup_size: None,
34    accepts_nan_mode: false,
35    notes: "Providers surface finite-difference kernels through `diff_dim`; the WGPU backend keeps tensors on the device.",
36};
37
38fn diff_error(message: impl Into<String>) -> RuntimeError {
39    build_runtime_error(message).with_builtin(NAME).build()
40}
41
42#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::diff")]
43pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
44    name: "diff",
45    shape: ShapeRequirements::BroadcastCompatible,
46    constant_strategy: ConstantStrategy::InlineLiteral,
47    elementwise: None,
48    reduction: None,
49    emits_nan: false,
50    notes: "Fusion planner currently delegates to the runtime implementation; providers can override with custom kernels.",
51};
52
53#[runtime_builtin(
54    name = "diff",
55    category = "math/reduction",
56    summary = "Forward finite differences of scalars, vectors, matrices, or N-D tensors.",
57    keywords = "diff,difference,finite difference,nth difference,gpu",
58    accel = "diff",
59    type_resolver(diff_type),
60    builtin_path = "crate::builtins::math::reduction::diff"
61)]
62async fn diff_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
63    let (order, dim) = parse_arguments(&rest)?;
64    if order == 0 {
65        return Ok(value);
66    }
67
68    match value {
69        Value::Tensor(tensor) => {
70            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
71        }
72        Value::LogicalArray(logical) => {
73            let tensor = tensor::logical_to_tensor(&logical).map_err(diff_error)?;
74            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
75        }
76        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
77            let tensor = tensor::value_into_tensor_for("diff", value).map_err(diff_error)?;
78            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
79        }
80        Value::Complex(re, im) => {
81            let tensor = ComplexTensor {
82                data: vec![(re, im)],
83                shape: vec![1, 1],
84                rows: 1,
85                cols: 1,
86            };
87            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
88        }
89        Value::ComplexTensor(tensor) => {
90            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
91        }
92        Value::CharArray(chars) => diff_char_array(chars, order, dim),
93        Value::GpuTensor(handle) => diff_gpu(handle, order, dim).await,
94        other => Err(diff_error(format!(
95            "diff: unsupported input type {:?}; expected numeric, logical, or character data",
96            other
97        ))),
98    }
99}
100
101fn parse_arguments(args: &[Value]) -> BuiltinResult<(usize, Option<usize>)> {
102    match args.len() {
103        0 => Ok((1, None)),
104        1 => {
105            let order = parse_order(&args[0])?;
106            Ok((order.unwrap_or(1), None))
107        }
108        2 => {
109            let order = parse_order(&args[0])?.unwrap_or(1);
110            let dim = parse_dimension_arg(&args[1])?;
111            Ok((order, dim))
112        }
113        _ => Err(diff_error("diff: unsupported arguments")),
114    }
115}
116
117fn parse_order(value: &Value) -> BuiltinResult<Option<usize>> {
118    if is_empty_array(value) {
119        return Ok(None);
120    }
121    match value {
122        Value::Int(i) => {
123            let raw = i.to_i64();
124            if raw < 0 {
125                return Err(diff_error(
126                    "diff: order must be a non-negative integer scalar",
127                ));
128            }
129            Ok(Some(raw as usize))
130        }
131        Value::Num(n) => parse_numeric_order(*n).map(Some),
132        Value::Tensor(t) if t.data.len() == 1 => parse_numeric_order(t.data[0]).map(Some),
133        Value::Bool(b) => Ok(Some(if *b { 1 } else { 0 })),
134        other => Err(diff_error(format!(
135            "diff: order must be a non-negative integer scalar, got {:?}",
136            other
137        ))),
138    }
139}
140
141fn parse_numeric_order(value: f64) -> BuiltinResult<usize> {
142    if !value.is_finite() {
143        return Err(diff_error("diff: order must be finite"));
144    }
145    if value < 0.0 {
146        return Err(diff_error(
147            "diff: order must be a non-negative integer scalar",
148        ));
149    }
150    let rounded = value.round();
151    if (rounded - value).abs() > f64::EPSILON {
152        return Err(diff_error(
153            "diff: order must be a non-negative integer scalar",
154        ));
155    }
156    Ok(rounded as usize)
157}
158
159fn parse_dimension_arg(value: &Value) -> BuiltinResult<Option<usize>> {
160    if is_empty_array(value) {
161        return Ok(None);
162    }
163    match value {
164        Value::Int(_) | Value::Num(_) => tensor::parse_dimension(value, "diff")
165            .map(Some)
166            .map_err(diff_error),
167        Value::Tensor(t) if t.data.len() == 1 => {
168            tensor::parse_dimension(&Value::Num(t.data[0]), "diff")
169                .map(Some)
170                .map_err(diff_error)
171        }
172        other => Err(diff_error(format!(
173            "diff: dimension must be a positive integer scalar, got {:?}",
174            other
175        ))),
176    }
177}
178
179fn is_empty_array(value: &Value) -> bool {
180    matches!(value, Value::Tensor(t) if t.data.is_empty())
181}
182
183async fn diff_gpu(
184    handle: GpuTensorHandle,
185    order: usize,
186    dim: Option<usize>,
187) -> BuiltinResult<Value> {
188    let working_dim = dim.unwrap_or_else(|| default_dimension(&handle.shape));
189    if working_dim == 0 {
190        return Err(diff_error("diff: dimension must be >= 1"));
191    }
192
193    if let Some(provider) = runmat_accelerate_api::provider() {
194        if let Ok(device_result) = provider.diff_dim(&handle, order, working_dim.saturating_sub(1))
195        {
196            return Ok(Value::GpuTensor(device_result));
197        }
198    }
199
200    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
201    diff_tensor_host(tensor, order, Some(working_dim)).map(tensor::tensor_into_value)
202}
203
204fn diff_char_array(chars: CharArray, order: usize, dim: Option<usize>) -> BuiltinResult<Value> {
205    if order == 0 {
206        return Ok(Value::CharArray(chars));
207    }
208    let shape = vec![chars.rows, chars.cols];
209    let data: Vec<f64> = chars.data.iter().map(|&ch| ch as u32 as f64).collect();
210    let tensor = Tensor::new(data, shape).map_err(|e| diff_error(format!("diff: {e}")))?;
211    diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
212}
213
214pub fn diff_tensor_host(tensor: Tensor, order: usize, dim: Option<usize>) -> BuiltinResult<Tensor> {
215    let mut current = tensor;
216    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
217    for _ in 0..order {
218        current = diff_tensor_once(current, working_dim)?;
219        if current.data.is_empty() {
220            break;
221        }
222        // Preserve explicit dimension if the caller provided one; update when defaulting and shape shrinks.
223        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
224            working_dim = default_dimension(&current.shape);
225        }
226    }
227    Ok(current)
228}
229
230fn diff_complex_tensor(
231    tensor: ComplexTensor,
232    order: usize,
233    dim: Option<usize>,
234) -> BuiltinResult<ComplexTensor> {
235    let mut current = tensor;
236    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
237    for _ in 0..order {
238        current = diff_complex_tensor_once(current, working_dim)?;
239        if current.data.is_empty() {
240            break;
241        }
242        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
243            working_dim = default_dimension(&current.shape);
244        }
245    }
246    Ok(current)
247}
248
249fn diff_tensor_once(tensor: Tensor, dim: usize) -> BuiltinResult<Tensor> {
250    let Tensor {
251        data, mut shape, ..
252    } = tensor;
253    let dim_index = dim.saturating_sub(1);
254    while shape.len() <= dim_index {
255        shape.push(1);
256    }
257    let len_dim = shape[dim_index];
258    let mut output_shape = shape.clone();
259    if len_dim <= 1 || data.is_empty() {
260        output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
261        return Tensor::new(Vec::new(), output_shape).map_err(|e| diff_error(format!("diff: {e}")));
262    }
263    output_shape[dim_index] = len_dim - 1;
264    let stride_before = product(&shape[..dim_index]);
265    let stride_after = product(&shape[dim_index + 1..]);
266    let output_len = stride_before * (len_dim - 1) * stride_after;
267    let mut out = Vec::with_capacity(output_len);
268
269    for after in 0..stride_after {
270        let after_base = after * stride_before * len_dim;
271        for before in 0..stride_before {
272            for k in 0..(len_dim - 1) {
273                let idx0 = before + after_base + k * stride_before;
274                let idx1 = idx0 + stride_before;
275                out.push(data[idx1] - data[idx0]);
276            }
277        }
278    }
279
280    Tensor::new(out, output_shape).map_err(|e| diff_error(format!("diff: {e}")))
281}
282
283fn diff_complex_tensor_once(tensor: ComplexTensor, dim: usize) -> BuiltinResult<ComplexTensor> {
284    let ComplexTensor {
285        data, mut shape, ..
286    } = tensor;
287    let dim_index = dim.saturating_sub(1);
288    while shape.len() <= dim_index {
289        shape.push(1);
290    }
291    let len_dim = shape[dim_index];
292    let mut output_shape = shape.clone();
293    if len_dim <= 1 || data.is_empty() {
294        output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
295        return ComplexTensor::new(Vec::new(), output_shape)
296            .map_err(|e| diff_error(format!("diff: {e}")));
297    }
298    output_shape[dim_index] = len_dim - 1;
299    let stride_before = product(&shape[..dim_index]);
300    let stride_after = product(&shape[dim_index + 1..]);
301    let mut out = Vec::with_capacity(stride_before * (len_dim - 1) * stride_after);
302
303    for after in 0..stride_after {
304        let after_base = after * stride_before * len_dim;
305        for before in 0..stride_before {
306            for k in 0..(len_dim - 1) {
307                let idx0 = before + after_base + k * stride_before;
308                let idx1 = idx0 + stride_before;
309                let (re0, im0) = data[idx0];
310                let (re1, im1) = data[idx1];
311                out.push((re1 - re0, im1 - im0));
312            }
313        }
314    }
315
316    ComplexTensor::new(out, output_shape).map_err(|e| diff_error(format!("diff: {e}")))
317}
318
/// Returns the 1-based default working dimension for `shape`.
///
/// This is the first dimension whose size is not 1, following MATLAB's
/// "first array dimension whose size does not equal 1" rule; a zero-length
/// dimension therefore counts as non-singleton. Falls back to 1 when every
/// dimension is 1 or the shape is empty.
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&extent| extent != 1)
        .map_or(1, |idx| idx + 1)
}
326
/// Returns the extent of `shape` along the 1-based dimension `dim`.
///
/// Dimensions beyond the rank of `shape` have length 1. A `dim` of 0 is
/// treated like dimension 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim.saturating_sub(1)).copied().unwrap_or(1)
}
335
/// Multiplies all entries of `dims`, saturating at `usize::MAX` instead of
/// overflowing. An empty slice yields 1.
fn product(dims: &[usize]) -> usize {
    let mut total = 1usize;
    for &extent in dims {
        total = total.saturating_mul(extent);
    }
    total
}
341
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor};

    /// Blocking wrapper around the async builtin under test.
    fn diff_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::diff_builtin(value, rest))
    }

    /// Wraps an `i32` as an integer argument value.
    fn int_arg(n: i32) -> Value {
        Value::Int(IntValue::I32(n))
    }

    /// Unwraps a real tensor result, panicking on any other value kind.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(out) => out,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// Unwraps a complex tensor result, panicking on any other value kind.
    fn expect_complex_tensor(value: Value) -> ComplexTensor {
        match value {
            Value::ComplexTensor(out) => out,
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[test]
    fn diff_type_defaults_tensor() {
        let input = Type::Tensor {
            shape: Some(vec![Some(2), Some(3)]),
        };
        let expected = Type::Tensor {
            shape: Some(vec![None, None]),
        };
        let out = diff_type(&[input], &ResolveContext::new(Vec::new()));
        assert_eq!(out, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_row_vector_default_dimension() {
        let input = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let out = expect_tensor(diff_builtin(Value::Tensor(input), Vec::new()).expect("diff"));
        assert_eq!(out.shape, vec![1, 2]);
        assert_eq!(out.data, vec![3.0, 5.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_column_vector_second_order() {
        let input = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let out =
            expect_tensor(diff_builtin(Value::Tensor(input), vec![int_arg(2)]).expect("diff"));
        assert_eq!(out.shape, vec![2, 1]);
        assert_eq!(out.data, vec![2.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_matrix_along_columns() {
        let input = Tensor::new(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], vec![3, 2]).unwrap();
        let args = vec![int_arg(1), int_arg(2)];
        let out = expect_tensor(diff_builtin(Value::Tensor(input), args).expect("diff"));
        assert_eq!(out.shape, vec![3, 1]);
        assert_eq!(out.data, vec![1.0, 1.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_handles_empty_when_order_exceeds_dimension() {
        let input = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let out =
            expect_tensor(diff_builtin(Value::Tensor(input), vec![int_arg(5)]).expect("diff"));
        assert_eq!(out.shape[0], 0);
        assert!(out.data.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_char_array_promotes_to_double() {
        let chars = CharArray::new("ACEG".chars().collect(), 1, 4).unwrap();
        let out = expect_tensor(diff_builtin(Value::CharArray(chars), Vec::new()).expect("diff"));
        assert_eq!(out.shape, vec![1, 3]);
        assert_eq!(out.data, vec![2.0, 2.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_complex_tensor_preserves_type() {
        let input =
            ComplexTensor::new(vec![(1.0, 1.0), (3.0, 2.0), (6.0, 5.0)], vec![1, 3]).unwrap();
        let out = expect_complex_tensor(
            diff_builtin(Value::ComplexTensor(input), Vec::new()).expect("diff"),
        );
        assert_eq!(out.shape, vec![1, 2]);
        assert_eq!(out.data, vec![(2.0, 1.0), (3.0, 3.0)]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_zero_order_returns_input() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let result = diff_builtin(Value::Tensor(input.clone()), vec![int_arg(0)]).expect("diff");
        assert_eq!(result, Value::Tensor(input));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_order_argument() {
        let input = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
        let baseline = diff_builtin(Value::Tensor(input.clone()), Vec::new()).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(Value::Tensor(input), vec![Value::Tensor(empty)]).expect("diff");
        assert_eq!(result, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_dimension_argument() {
        let input = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![1, 4]).unwrap();
        let baseline =
            diff_builtin(Value::Tensor(input.clone()), vec![int_arg(1)]).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(
            Value::Tensor(input),
            vec![int_arg(1), Value::Tensor(empty)],
        )
        .expect("diff");
        assert_eq!(result, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_negative_order() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = diff_builtin(Value::Tensor(input), vec![int_arg(-1)]).unwrap_err();
        assert!(err.message().contains("non-negative"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_non_integer_order() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = diff_builtin(Value::Tensor(input), vec![Value::Num(1.5)]).unwrap_err();
        assert!(err.message().contains("non-negative integer"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_invalid_dimension() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![int_arg(1), int_arg(0)];
        let err = diff_builtin(Value::Tensor(input), args).unwrap_err();
        assert!(err.message().contains("dimension must be >= 1"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let input = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &input.data,
                shape: &input.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = diff_builtin(Value::GpuTensor(handle), Vec::new()).expect("diff");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![3.0, 5.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn diff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let input = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![int_arg(2)];

        // Host reference result.
        let expected =
            expect_tensor(diff_builtin(Value::Tensor(input.clone()), args.clone()).expect("diff"));

        // Same computation starting from a device-resident tensor.
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &input.data,
            shape: &input.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = diff_builtin(Value::GpuTensor(handle), args).expect("diff");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_eq!(gathered.shape, expected.shape);
        // Single-precision providers get a looser tolerance.
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
            _ => 1e-12,
        };
        for (a, b) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((a - b).abs() < tol, "|{a} - {b}| >= {tol}");
        }
    }
}