1use runmat_accelerate_api::GpuTensorHandle;
4use runmat_builtins::{CharArray, ComplexTensor, ResolveContext, Tensor, Type, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::random_args::complex_tensor_into_value;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
11};
12use crate::builtins::common::{gpu_helpers, tensor};
13use crate::builtins::math::reduction::type_resolvers::diff_numeric_type;
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15
/// Builtin name attached to runtime errors raised by `diff`.
const NAME: &str = "diff";
17
/// Type resolver for `diff`: delegates to the shared numeric-reduction
/// resolver (referenced by name from the `runtime_builtin` attribute below).
fn diff_type(args: &[Type], ctx: &ResolveContext) -> Type {
    diff_numeric_type(args, ctx)
}
21
// GPU acceleration metadata for `diff`, registered with the accelerate layer.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "diff",
    // Finite differences are not a standard reduction, so providers expose a
    // dedicated kernel rather than a generic reduction plan.
    op_kind: GpuOpKind::Custom("finite-difference"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Dispatched through the provider's `diff_dim` hook (see `diff_gpu`).
    provider_hooks: &[ProviderHook::Custom("diff_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The result is always a fresh device tensor; the input handle is untouched.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers surface finite-difference kernels through `diff_dim`; the WGPU backend keeps tensors on the device.",
};
37
38fn diff_error(message: impl Into<String>) -> RuntimeError {
39 build_runtime_error(message).with_builtin(NAME).build()
40}
41
// Fusion metadata: `diff` is neither elementwise nor a reduction, so the
// fusion planner falls back to this runtime implementation.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "diff",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion planner currently delegates to the runtime implementation; providers can override with custom kernels.",
};
52
// Entry point for the `diff` builtin. Dispatches on the runtime value kind:
// real numeric/logical/char data is funnelled through the real-tensor host
// path, complex data keeps its complex representation, and GPU handles try
// the provider hook before gathering back to the host.
#[runtime_builtin(
    name = "diff",
    category = "math/reduction",
    summary = "Forward finite differences of scalars, vectors, matrices, or N-D tensors.",
    keywords = "diff,difference,finite difference,nth difference,gpu",
    accel = "diff",
    type_resolver(diff_type),
    builtin_path = "crate::builtins::math::reduction::diff"
)]
async fn diff_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let (order, dim) = parse_arguments(&rest)?;
    if order == 0 {
        // MATLAB semantics: diff(X, 0) returns the input unchanged.
        return Ok(value);
    }

    match value {
        Value::Tensor(tensor) => {
            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
        }
        Value::LogicalArray(logical) => {
            // Logical arrays are promoted to double before differencing.
            let tensor = tensor::logical_to_tensor(&logical).map_err(diff_error)?;
            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
        }
        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
            // Scalars become 1x1 tensors; a first difference of a scalar is empty.
            let tensor = tensor::value_into_tensor_for("diff", value).map_err(diff_error)?;
            diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
        }
        Value::Complex(re, im) => {
            // Wrap the complex scalar as a 1x1 complex tensor.
            let tensor = ComplexTensor {
                data: vec![(re, im)],
                shape: vec![1, 1],
                rows: 1,
                cols: 1,
            };
            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
        }
        Value::ComplexTensor(tensor) => {
            diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
        }
        Value::CharArray(chars) => diff_char_array(chars, order, dim),
        Value::GpuTensor(handle) => diff_gpu(handle, order, dim).await,
        other => Err(diff_error(format!(
            "diff: unsupported input type {:?}; expected numeric, logical, or character data",
            other
        ))),
    }
}
100
101fn parse_arguments(args: &[Value]) -> BuiltinResult<(usize, Option<usize>)> {
102 match args.len() {
103 0 => Ok((1, None)),
104 1 => {
105 let order = parse_order(&args[0])?;
106 Ok((order.unwrap_or(1), None))
107 }
108 2 => {
109 let order = parse_order(&args[0])?.unwrap_or(1);
110 let dim = parse_dimension_arg(&args[1])?;
111 Ok((order, dim))
112 }
113 _ => Err(diff_error("diff: unsupported arguments")),
114 }
115}
116
117fn parse_order(value: &Value) -> BuiltinResult<Option<usize>> {
118 if is_empty_array(value) {
119 return Ok(None);
120 }
121 match value {
122 Value::Int(i) => {
123 let raw = i.to_i64();
124 if raw < 0 {
125 return Err(diff_error(
126 "diff: order must be a non-negative integer scalar",
127 ));
128 }
129 Ok(Some(raw as usize))
130 }
131 Value::Num(n) => parse_numeric_order(*n).map(Some),
132 Value::Tensor(t) if t.data.len() == 1 => parse_numeric_order(t.data[0]).map(Some),
133 Value::Bool(b) => Ok(Some(if *b { 1 } else { 0 })),
134 other => Err(diff_error(format!(
135 "diff: order must be a non-negative integer scalar, got {:?}",
136 other
137 ))),
138 }
139}
140
141fn parse_numeric_order(value: f64) -> BuiltinResult<usize> {
142 if !value.is_finite() {
143 return Err(diff_error("diff: order must be finite"));
144 }
145 if value < 0.0 {
146 return Err(diff_error(
147 "diff: order must be a non-negative integer scalar",
148 ));
149 }
150 let rounded = value.round();
151 if (rounded - value).abs() > f64::EPSILON {
152 return Err(diff_error(
153 "diff: order must be a non-negative integer scalar",
154 ));
155 }
156 Ok(rounded as usize)
157}
158
159fn parse_dimension_arg(value: &Value) -> BuiltinResult<Option<usize>> {
160 if is_empty_array(value) {
161 return Ok(None);
162 }
163 match value {
164 Value::Int(_) | Value::Num(_) => tensor::parse_dimension(value, "diff")
165 .map(Some)
166 .map_err(diff_error),
167 Value::Tensor(t) if t.data.len() == 1 => {
168 tensor::parse_dimension(&Value::Num(t.data[0]), "diff")
169 .map(Some)
170 .map_err(diff_error)
171 }
172 other => Err(diff_error(format!(
173 "diff: dimension must be a positive integer scalar, got {:?}",
174 other
175 ))),
176 }
177}
178
179fn is_empty_array(value: &Value) -> bool {
180 matches!(value, Value::Tensor(t) if t.data.is_empty())
181}
182
/// GPU path for `diff`: try the provider's `diff_dim` hook first, then fall
/// back to gathering the tensor to the host and reusing the CPU kernel.
async fn diff_gpu(
    handle: GpuTensorHandle,
    order: usize,
    dim: Option<usize>,
) -> BuiltinResult<Value> {
    // Resolve the working dimension from the device-side shape when the
    // caller did not specify one.
    let working_dim = dim.unwrap_or_else(|| default_dimension(&handle.shape));
    if working_dim == 0 {
        return Err(diff_error("diff: dimension must be >= 1"));
    }

    if let Some(provider) = runmat_accelerate_api::provider() {
        // Providers take a zero-based dimension index. A provider error is
        // deliberately swallowed so the host fallback below still runs.
        if let Ok(device_result) = provider.diff_dim(&handle, order, working_dim.saturating_sub(1))
        {
            return Ok(Value::GpuTensor(device_result));
        }
    }

    // Host fallback: download the tensor and run the CPU implementation.
    let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
    diff_tensor_host(tensor, order, Some(working_dim)).map(tensor::tensor_into_value)
}
203
204fn diff_char_array(chars: CharArray, order: usize, dim: Option<usize>) -> BuiltinResult<Value> {
205 if order == 0 {
206 return Ok(Value::CharArray(chars));
207 }
208 let shape = vec![chars.rows, chars.cols];
209 let data: Vec<f64> = chars.data.iter().map(|&ch| ch as u32 as f64).collect();
210 let tensor = Tensor::new(data, shape).map_err(|e| diff_error(format!("diff: {e}")))?;
211 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
212}
213
/// Apply `order` successive forward differences to a real tensor on the host.
///
/// When `dim` is `None`, the working dimension is re-derived after a pass
/// collapses it to zero (matching MATLAB's recursive `diff(X, n)` behavior);
/// an explicitly requested dimension stays fixed across passes.
pub fn diff_tensor_host(tensor: Tensor, order: usize, dim: Option<usize>) -> BuiltinResult<Tensor> {
    let mut current = tensor;
    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
    for _ in 0..order {
        current = diff_tensor_once(current, working_dim)?;
        if current.data.is_empty() {
            // Further differencing of an empty result stays empty; stop early.
            break;
        }
        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
            working_dim = default_dimension(&current.shape);
        }
    }
    Ok(current)
}
229
/// Apply `order` successive forward differences to a complex tensor on the
/// host. Mirrors `diff_tensor_host`, including the dimension re-derivation
/// when no explicit `dim` was supplied.
fn diff_complex_tensor(
    tensor: ComplexTensor,
    order: usize,
    dim: Option<usize>,
) -> BuiltinResult<ComplexTensor> {
    let mut current = tensor;
    let mut working_dim = dim.unwrap_or_else(|| default_dimension(&current.shape));
    for _ in 0..order {
        current = diff_complex_tensor_once(current, working_dim)?;
        if current.data.is_empty() {
            // Further differencing of an empty result stays empty; stop early.
            break;
        }
        if dim.is_none() && dimension_length(&current.shape, working_dim) == 0 {
            working_dim = default_dimension(&current.shape);
        }
    }
    Ok(current)
}
248
249fn diff_tensor_once(tensor: Tensor, dim: usize) -> BuiltinResult<Tensor> {
250 let Tensor {
251 data, mut shape, ..
252 } = tensor;
253 let dim_index = dim.saturating_sub(1);
254 while shape.len() <= dim_index {
255 shape.push(1);
256 }
257 let len_dim = shape[dim_index];
258 let mut output_shape = shape.clone();
259 if len_dim <= 1 || data.is_empty() {
260 output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
261 return Tensor::new(Vec::new(), output_shape).map_err(|e| diff_error(format!("diff: {e}")));
262 }
263 output_shape[dim_index] = len_dim - 1;
264 let stride_before = product(&shape[..dim_index]);
265 let stride_after = product(&shape[dim_index + 1..]);
266 let output_len = stride_before * (len_dim - 1) * stride_after;
267 let mut out = Vec::with_capacity(output_len);
268
269 for after in 0..stride_after {
270 let after_base = after * stride_before * len_dim;
271 for before in 0..stride_before {
272 for k in 0..(len_dim - 1) {
273 let idx0 = before + after_base + k * stride_before;
274 let idx1 = idx0 + stride_before;
275 out.push(data[idx1] - data[idx0]);
276 }
277 }
278 }
279
280 Tensor::new(out, output_shape).map_err(|e| diff_error(format!("diff: {e}")))
281}
282
283fn diff_complex_tensor_once(tensor: ComplexTensor, dim: usize) -> BuiltinResult<ComplexTensor> {
284 let ComplexTensor {
285 data, mut shape, ..
286 } = tensor;
287 let dim_index = dim.saturating_sub(1);
288 while shape.len() <= dim_index {
289 shape.push(1);
290 }
291 let len_dim = shape[dim_index];
292 let mut output_shape = shape.clone();
293 if len_dim <= 1 || data.is_empty() {
294 output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
295 return ComplexTensor::new(Vec::new(), output_shape)
296 .map_err(|e| diff_error(format!("diff: {e}")));
297 }
298 output_shape[dim_index] = len_dim - 1;
299 let stride_before = product(&shape[..dim_index]);
300 let stride_after = product(&shape[dim_index + 1..]);
301 let mut out = Vec::with_capacity(stride_before * (len_dim - 1) * stride_after);
302
303 for after in 0..stride_after {
304 let after_base = after * stride_before * len_dim;
305 for before in 0..stride_before {
306 for k in 0..(len_dim - 1) {
307 let idx0 = before + after_base + k * stride_before;
308 let idx1 = idx0 + stride_before;
309 let (re0, im0) = data[idx0];
310 let (re1, im1) = data[idx1];
311 out.push((re1 - re0, im1 - im0));
312 }
313 }
314 }
315
316 ComplexTensor::new(out, output_shape).map_err(|e| diff_error(format!("diff: {e}")))
317}
318
/// First dimension (1-based) with extent greater than one, per MATLAB's
/// default-dimension rule; falls back to 1 for scalars and empty shapes.
fn default_dimension(shape: &[usize]) -> usize {
    for (idx, &extent) in shape.iter().enumerate() {
        if extent > 1 {
            return idx + 1;
        }
    }
    1
}
326
/// Extent of the 1-based dimension `dim`; dimensions beyond the shape's
/// length are implicit singletons of extent 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim.saturating_sub(1)).copied().unwrap_or(1)
}
335
/// Saturating product of all extents; the empty product is 1.
fn product(dims: &[usize]) -> usize {
    let mut acc = 1usize;
    for &extent in dims {
        acc = acc.saturating_mul(extent);
    }
    acc
}
341
/// Unit and integration tests for the `diff` builtin, covering the host,
/// complex, char, argument-parsing, and GPU provider paths.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor};

    // Synchronous convenience wrapper around the async builtin.
    fn diff_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::diff_builtin(value, rest))
    }

    #[test]
    fn diff_type_defaults_tensor() {
        let out = diff_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        // The resolver cannot know which dimension shrinks, so both extents
        // become unknown.
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![None, None])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_row_vector_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_column_vector_second_order() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 1]);
                assert_eq!(out.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_matrix_along_columns() {
        // Column-major data: first column [1, 3, 5], second column [2, 4, 6].
        let tensor = Tensor::new(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], vec![3, 2]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![1.0, 1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_handles_empty_when_order_exceeds_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(5))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape[0], 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_char_array_promotes_to_double() {
        // 'A' 'C' 'E' 'G' -> code points 65, 67, 69, 71 -> diffs of 2.
        let chars = CharArray::new("ACEG".chars().collect(), 1, 4).unwrap();
        let result = diff_builtin(Value::CharArray(chars), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 3]);
                assert_eq!(out.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_complex_tensor_preserves_type() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 1.0), (3.0, 2.0), (6.0, 5.0)], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![(2.0, 1.0), (3.0, 3.0)]);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_zero_order_returns_input() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(0))];
        let result = diff_builtin(Value::Tensor(tensor.clone()), args).expect("diff");
        assert_eq!(result, Value::Tensor(tensor));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_order_argument() {
        // An empty array in the order slot must behave like omitting it.
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
        let baseline = diff_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), vec![Value::Tensor(empty)]).expect("diff");
        assert_eq!(result, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_dimension_argument() {
        // An empty array in the dimension slot must behave like omitting it.
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![1, 4]).unwrap();
        let baseline = diff_builtin(
            Value::Tensor(tensor.clone()),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Tensor(empty)],
        )
        .expect("diff");
        assert_eq!(result, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_negative_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(-1))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert!(err.message().contains("non-negative"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_non_integer_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Num(1.5)];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert!(err.message().contains("non-negative integer"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(0))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert!(err.message().contains("dimension must be >= 1"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_gpu_provider_roundtrip() {
        // Upload, diff on the (test) provider, gather, and compare to host.
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = diff_builtin(Value::GpuTensor(handle), Vec::new()).expect("diff");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![3.0, 5.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn diff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];

        let cpu_result = diff_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("diff");
        let expected = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = diff_builtin(Value::GpuTensor(handle), args).expect("diff");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_eq!(gathered.shape, expected.shape);
        // Tolerance depends on the provider's floating-point precision.
        let tol = if matches!(
            provider.precision(),
            runmat_accelerate_api::ProviderPrecision::F32
        ) {
            1e-5
        } else {
            1e-12
        };
        for (a, b) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((a - b).abs() < tol, "|{a} - {b}| >= {tol}");
        }
    }
}