1use runmat_accelerate_api::GpuTensorHandle;
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 CharArray, ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::random_args::complex_tensor_into_value;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::{gpu_helpers, tensor};
17use crate::builtins::math::reduction::type_resolvers::diff_numeric_type;
18use crate::builtins::math::symbolic::{
19 symbolic_expr_to_value, symbolic_variable_name_from_value, value_to_symbolic_scalar,
20};
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
23const NAME: &str = "diff";
24
25fn diff_type(args: &[Type], ctx: &ResolveContext) -> Type {
26 diff_numeric_type(args, ctx)
27}
28
29const DIFF_OUTPUT_B: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
30 name: "B",
31 ty: BuiltinParamType::NumericArray,
32 arity: BuiltinParamArity::Required,
33 default: None,
34 description: "Finite differences along the selected dimension.",
35}];
36
37const DIFF_INPUTS_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
38 name: "X",
39 ty: BuiltinParamType::Any,
40 arity: BuiltinParamArity::Required,
41 default: None,
42 description: "Input scalar or array.",
43}];
44
45const DIFF_INPUTS_X_N: [BuiltinParamDescriptor; 2] = [
46 BuiltinParamDescriptor {
47 name: "X",
48 ty: BuiltinParamType::Any,
49 arity: BuiltinParamArity::Required,
50 default: None,
51 description: "Input scalar or array.",
52 },
53 BuiltinParamDescriptor {
54 name: "n",
55 ty: BuiltinParamType::Any,
56 arity: BuiltinParamArity::Optional,
57 default: Some("1"),
58 description: "Difference order (non-negative integer scalar or empty placeholder).",
59 },
60];
61
62const DIFF_INPUTS_X_N_DIM: [BuiltinParamDescriptor; 3] = [
63 BuiltinParamDescriptor {
64 name: "X",
65 ty: BuiltinParamType::Any,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Input scalar or array.",
69 },
70 BuiltinParamDescriptor {
71 name: "n",
72 ty: BuiltinParamType::Any,
73 arity: BuiltinParamArity::Optional,
74 default: Some("1"),
75 description: "Difference order (non-negative integer scalar or empty placeholder).",
76 },
77 BuiltinParamDescriptor {
78 name: "dim",
79 ty: BuiltinParamType::Any,
80 arity: BuiltinParamArity::Optional,
81 default: Some("[]"),
82 description: "Reduction dimension (positive integer scalar or empty placeholder).",
83 },
84];
85
86const DIFF_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
87 BuiltinSignatureDescriptor {
88 label: "B = diff(X)",
89 inputs: &DIFF_INPUTS_X,
90 outputs: &DIFF_OUTPUT_B,
91 },
92 BuiltinSignatureDescriptor {
93 label: "B = diff(X, n)",
94 inputs: &DIFF_INPUTS_X_N,
95 outputs: &DIFF_OUTPUT_B,
96 },
97 BuiltinSignatureDescriptor {
98 label: "B = diff(X, n, dim)",
99 inputs: &DIFF_INPUTS_X_N_DIM,
100 outputs: &DIFF_OUTPUT_B,
101 },
102];
103
104const DIFF_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
105 code: "RM.DIFF.INVALID_ARGUMENT",
106 identifier: Some("RunMat:diff:InvalidArgument"),
107 when: "Argument count/order/dimension/order grammar is invalid.",
108 message: "diff: invalid argument",
109};
110
111const DIFF_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
112 code: "RM.DIFF.INVALID_INPUT",
113 identifier: Some("RunMat:diff:InvalidInput"),
114 when: "Input value cannot be converted to a supported diff domain.",
115 message: "diff: invalid input",
116};
117
118const DIFF_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
119 code: "RM.DIFF.INTERNAL",
120 identifier: Some("RunMat:diff:Internal"),
121 when: "Finite-difference execution fails due to conversion, gather, allocation, or reshape operations.",
122 message: "diff: internal failure",
123};
124
125const DIFF_ERRORS: [BuiltinErrorDescriptor; 3] = [
126 DIFF_ERROR_INVALID_ARGUMENT,
127 DIFF_ERROR_INVALID_INPUT,
128 DIFF_ERROR_INTERNAL,
129];
130
131pub const DIFF_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
132 signatures: &DIFF_SIGNATURES,
133 output_mode: BuiltinOutputMode::Fixed,
134 completion_policy: BuiltinCompletionPolicy::Public,
135 errors: &DIFF_ERRORS,
136};
137
/// GPU metadata for `diff`, registered through `register_gpu_spec`.
///
/// The only provider hook listed is `diff_dim`, which is the hook `diff_gpu` calls before
/// falling back to gathering the tensor to the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "diff",
    // Tagged as a custom op; the kernel is reached through the `diff_dim` hook below.
    op_kind: GpuOpKind::Custom("finite-difference"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("diff_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Results come back in a new device handle rather than reusing the input's.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers surface finite-difference kernels through `diff_dim`; the WGPU backend keeps tensors on the device.",
};
153
154fn diff_descriptor_error_with_message(
155 message: impl Into<String>,
156 error: &'static BuiltinErrorDescriptor,
157) -> RuntimeError {
158 let mut builder = build_runtime_error(message).with_builtin(NAME);
159 if let Some(identifier) = error.identifier {
160 builder = builder.with_identifier(identifier);
161 }
162 builder.build()
163}
164
165fn diff_descriptor_error_with_detail(
166 error: &'static BuiltinErrorDescriptor,
167 detail: impl AsRef<str>,
168) -> RuntimeError {
169 diff_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
170}
171
172fn diff_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
173 diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_ARGUMENT, detail)
174}
175
176fn diff_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
177 diff_descriptor_error_with_detail(&DIFF_ERROR_INVALID_INPUT, detail)
178}
179
180fn diff_internal_error(detail: impl AsRef<str>) -> RuntimeError {
181 diff_descriptor_error_with_detail(&DIFF_ERROR_INTERNAL, detail)
182}
183
/// Fusion metadata for `diff`, registered through `register_fusion_spec`.
///
/// No elementwise or reduction template is supplied; per the notes below, the fusion planner
/// hands `diff` to the runtime implementation.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::diff")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "diff",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No fusable templates: see the notes string.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion planner currently delegates to the runtime implementation; providers can override with custom kernels.",
};
194
195#[runtime_builtin(
196 name = "diff",
197 category = "math/reduction",
198 summary = "Compute forward finite differences.",
199 keywords = "diff,difference,finite difference,nth difference,gpu",
200 accel = "diff",
201 type_resolver(diff_type),
202 descriptor(crate::builtins::math::reduction::diff::DIFF_DESCRIPTOR),
203 builtin_path = "crate::builtins::math::reduction::diff"
204)]
205async fn diff_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
206 if let Value::Symbolic(expr) = value {
207 return diff_symbolic(expr, &rest);
208 }
209
210 let (order, dim) = parse_arguments(&rest)?;
211 if order == 0 {
212 return Ok(value);
213 }
214
215 match value {
216 Value::Tensor(tensor) => {
217 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
218 }
219 Value::LogicalArray(logical) => {
220 let tensor = tensor::logical_to_tensor(&logical).map_err(diff_invalid_input)?;
221 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
222 }
223 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
224 let tensor =
225 tensor::value_into_tensor_for("diff", value).map_err(diff_invalid_input)?;
226 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
227 }
228 Value::Complex(re, im) => {
229 let tensor = ComplexTensor {
230 data: vec![(re, im)],
231 shape: vec![1, 1],
232 rows: 1,
233 cols: 1,
234 };
235 diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
236 }
237 Value::ComplexTensor(tensor) => {
238 diff_complex_tensor(tensor, order, dim).map(complex_tensor_into_value)
239 }
240 Value::CharArray(chars) => diff_char_array(chars, order, dim),
241 Value::GpuTensor(handle) => diff_gpu(handle, order, dim).await,
242 other => Err(diff_invalid_input(format!(
243 "diff: unsupported input type {:?}; expected numeric, logical, or character data",
244 other
245 ))),
246 }
247}
248
249fn diff_symbolic(expr: runmat_builtins::SymbolicExpr, args: &[Value]) -> BuiltinResult<Value> {
250 let (variable, order) = parse_symbolic_diff_args(&expr, args)?;
251 Ok(symbolic_expr_to_value(
252 runmat_builtins::SymbolicExpr::derivative_expr(expr, variable, order),
253 ))
254}
255
256fn parse_symbolic_diff_args(
257 expr: &runmat_builtins::SymbolicExpr,
258 args: &[Value],
259) -> BuiltinResult<(String, u32)> {
260 match args.len() {
261 0 => Ok((infer_symbolic_diff_variable(expr)?, 1)),
262 1 => {
263 if let Some(variable) = symbolic_variable_name_from_value(&args[0]) {
264 Ok((variable, 1))
265 } else {
266 Ok((
267 infer_symbolic_diff_variable(expr)?,
268 parse_symbolic_order(&args[0])?,
269 ))
270 }
271 }
272 2 => {
273 if let Some(variable) = symbolic_variable_name_from_value(&args[0]) {
274 Ok((variable, parse_symbolic_order(&args[1])?))
275 } else if let Some(variable) = symbolic_variable_name_from_value(&args[1]) {
276 Ok((variable, parse_symbolic_order(&args[0])?))
277 } else {
278 Err(diff_invalid_argument(
279 "diff: symbolic differentiation expects a variable and optional order",
280 ))
281 }
282 }
283 _ => Err(diff_invalid_argument(
284 "diff: symbolic differentiation supports at most two trailing arguments",
285 )),
286 }
287}
288
289fn infer_symbolic_diff_variable(expr: &runmat_builtins::SymbolicExpr) -> BuiltinResult<String> {
290 let variables = expr.variables();
291 if variables.len() == 1 {
292 Ok(variables.into_iter().next().unwrap_or_default())
293 } else if variables.is_empty() {
294 Ok(String::new())
295 } else {
296 Err(diff_invalid_argument(
297 "diff: symbolic differentiation variable is ambiguous",
298 ))
299 }
300}
301
302fn parse_symbolic_order(value: &Value) -> BuiltinResult<u32> {
303 let expr = value_to_symbolic_scalar(value).ok_or_else(|| {
304 diff_invalid_argument("diff: symbolic differentiation order must be a scalar integer")
305 })?;
306 let Some(order) = expr.constant_value() else {
307 return Err(diff_invalid_argument(
308 "diff: symbolic differentiation order must be numeric",
309 ));
310 };
311 if !order.is_finite() || order < 0.0 || (order.round() - order).abs() > 1.0e-12 {
312 return Err(diff_invalid_argument(
313 "diff: symbolic differentiation order must be a nonnegative integer",
314 ));
315 }
316 if order > u32::MAX as f64 {
317 return Err(diff_invalid_argument(
318 "diff: symbolic differentiation order is too large",
319 ));
320 }
321 Ok(order as u32)
322}
323
324fn parse_arguments(args: &[Value]) -> BuiltinResult<(usize, Option<usize>)> {
325 match args.len() {
326 0 => Ok((1, None)),
327 1 => {
328 let order = parse_order(&args[0])?;
329 Ok((order.unwrap_or(1), None))
330 }
331 2 => {
332 let order = parse_order(&args[0])?.unwrap_or(1);
333 let dim = parse_dimension_arg(&args[1])?;
334 Ok((order, dim))
335 }
336 _ => Err(diff_invalid_argument("diff: unsupported arguments")),
337 }
338}
339
340fn parse_order(value: &Value) -> BuiltinResult<Option<usize>> {
341 if is_empty_array(value) {
342 return Ok(None);
343 }
344 match value {
345 Value::Int(i) => {
346 let raw = i.to_i64();
347 if raw < 0 {
348 return Err(diff_invalid_argument(
349 "diff: order must be a non-negative integer scalar",
350 ));
351 }
352 Ok(Some(raw as usize))
353 }
354 Value::Num(n) => parse_numeric_order(*n).map(Some),
355 Value::Tensor(t) if t.data.len() == 1 => parse_numeric_order(t.data[0]).map(Some),
356 Value::Bool(b) => Ok(Some(if *b { 1 } else { 0 })),
357 other => Err(diff_invalid_argument(format!(
358 "diff: order must be a non-negative integer scalar, got {:?}",
359 other
360 ))),
361 }
362}
363
364fn parse_numeric_order(value: f64) -> BuiltinResult<usize> {
365 if !value.is_finite() {
366 return Err(diff_invalid_argument("diff: order must be finite"));
367 }
368 if value < 0.0 {
369 return Err(diff_invalid_argument(
370 "diff: order must be a non-negative integer scalar",
371 ));
372 }
373 let rounded = value.round();
374 if (rounded - value).abs() > f64::EPSILON {
375 return Err(diff_invalid_argument(
376 "diff: order must be a non-negative integer scalar",
377 ));
378 }
379 Ok(rounded as usize)
380}
381
382fn parse_dimension_arg(value: &Value) -> BuiltinResult<Option<usize>> {
383 if is_empty_array(value) {
384 return Ok(None);
385 }
386 match value {
387 Value::Int(_) | Value::Num(_) => tensor::parse_dimension(value, "diff")
388 .map(Some)
389 .map_err(diff_invalid_argument),
390 Value::Tensor(t) if t.data.len() == 1 => {
391 tensor::parse_dimension(&Value::Num(t.data[0]), "diff")
392 .map(Some)
393 .map_err(diff_invalid_argument)
394 }
395 other => Err(diff_invalid_argument(format!(
396 "diff: dimension must be a positive integer scalar, got {:?}",
397 other
398 ))),
399 }
400}
401
402fn is_empty_array(value: &Value) -> bool {
403 matches!(value, Value::Tensor(t) if t.data.is_empty())
404}
405
406async fn diff_gpu(
407 handle: GpuTensorHandle,
408 order: usize,
409 dim: Option<usize>,
410) -> BuiltinResult<Value> {
411 let working_dim = dim.unwrap_or_else(|| default_dimension(&handle.shape));
412 if working_dim == 0 {
413 return Err(diff_invalid_argument("diff: dimension must be >= 1"));
414 }
415
416 if let Some(provider) = runmat_accelerate_api::provider() {
417 if let Ok(device_result) = provider.diff_dim(&handle, order, working_dim.saturating_sub(1))
418 {
419 return Ok(Value::GpuTensor(device_result));
420 }
421 }
422
423 let tensor = gpu_helpers::gather_tensor_async(&handle)
424 .await
425 .map_err(|e| diff_internal_error(format!("diff: {e}")))?;
426 diff_tensor_host(tensor, order, Some(working_dim)).map(tensor::tensor_into_value)
427}
428
429fn diff_char_array(chars: CharArray, order: usize, dim: Option<usize>) -> BuiltinResult<Value> {
430 if order == 0 {
431 return Ok(Value::CharArray(chars));
432 }
433 let shape = vec![chars.rows, chars.cols];
434 let data: Vec<f64> = chars.data.iter().map(|&ch| ch as u32 as f64).collect();
435 let tensor = Tensor::new(data, shape).map_err(|e| diff_internal_error(format!("diff: {e}")))?;
436 diff_tensor_host(tensor, order, dim).map(tensor::tensor_into_value)
437}
438
439pub fn diff_tensor_host(tensor: Tensor, order: usize, dim: Option<usize>) -> BuiltinResult<Tensor> {
440 let mut current = tensor;
441 let mut working_dim = dim.unwrap_or_else(|| default_dimension(¤t.shape));
442 for _ in 0..order {
443 current = diff_tensor_once(current, working_dim)?;
444 if current.data.is_empty() {
445 break;
446 }
447 if dim.is_none() && dimension_length(¤t.shape, working_dim) == 0 {
449 working_dim = default_dimension(¤t.shape);
450 }
451 }
452 Ok(current)
453}
454
455fn diff_complex_tensor(
456 tensor: ComplexTensor,
457 order: usize,
458 dim: Option<usize>,
459) -> BuiltinResult<ComplexTensor> {
460 let mut current = tensor;
461 let mut working_dim = dim.unwrap_or_else(|| default_dimension(¤t.shape));
462 for _ in 0..order {
463 current = diff_complex_tensor_once(current, working_dim)?;
464 if current.data.is_empty() {
465 break;
466 }
467 if dim.is_none() && dimension_length(¤t.shape, working_dim) == 0 {
468 working_dim = default_dimension(¤t.shape);
469 }
470 }
471 Ok(current)
472}
473
474fn diff_tensor_once(tensor: Tensor, dim: usize) -> BuiltinResult<Tensor> {
475 let Tensor {
476 data, mut shape, ..
477 } = tensor;
478 let dim_index = dim.saturating_sub(1);
479 while shape.len() <= dim_index {
480 shape.push(1);
481 }
482 let len_dim = shape[dim_index];
483 let mut output_shape = shape.clone();
484 if len_dim <= 1 || data.is_empty() {
485 output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
486 return Tensor::new(Vec::new(), output_shape)
487 .map_err(|e| diff_internal_error(format!("diff: {e}")));
488 }
489 output_shape[dim_index] = len_dim - 1;
490 let stride_before = product(&shape[..dim_index]);
491 let stride_after = product(&shape[dim_index + 1..]);
492 let output_len = stride_before * (len_dim - 1) * stride_after;
493 let mut out = Vec::with_capacity(output_len);
494
495 for after in 0..stride_after {
496 let after_base = after * stride_before * len_dim;
497 for before in 0..stride_before {
498 for k in 0..(len_dim - 1) {
499 let idx0 = before + after_base + k * stride_before;
500 let idx1 = idx0 + stride_before;
501 out.push(data[idx1] - data[idx0]);
502 }
503 }
504 }
505
506 Tensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
507}
508
509fn diff_complex_tensor_once(tensor: ComplexTensor, dim: usize) -> BuiltinResult<ComplexTensor> {
510 let ComplexTensor {
511 data, mut shape, ..
512 } = tensor;
513 let dim_index = dim.saturating_sub(1);
514 while shape.len() <= dim_index {
515 shape.push(1);
516 }
517 let len_dim = shape[dim_index];
518 let mut output_shape = shape.clone();
519 if len_dim <= 1 || data.is_empty() {
520 output_shape[dim_index] = output_shape[dim_index].saturating_sub(1);
521 return ComplexTensor::new(Vec::new(), output_shape)
522 .map_err(|e| diff_internal_error(format!("diff: {e}")));
523 }
524 output_shape[dim_index] = len_dim - 1;
525 let stride_before = product(&shape[..dim_index]);
526 let stride_after = product(&shape[dim_index + 1..]);
527 let mut out = Vec::with_capacity(stride_before * (len_dim - 1) * stride_after);
528
529 for after in 0..stride_after {
530 let after_base = after * stride_before * len_dim;
531 for before in 0..stride_before {
532 for k in 0..(len_dim - 1) {
533 let idx0 = before + after_base + k * stride_before;
534 let idx1 = idx0 + stride_before;
535 let (re0, im0) = data[idx0];
536 let (re1, im1) = data[idx1];
537 out.push((re1 - re0, im1 - im0));
538 }
539 }
540 }
541
542 ComplexTensor::new(out, output_shape).map_err(|e| diff_internal_error(format!("diff: {e}")))
543}
544
/// 1-based index of the first dimension longer than 1, or 1 when every extent is <= 1.
fn default_dimension(shape: &[usize]) -> usize {
    for (index, &extent) in shape.iter().enumerate() {
        if extent > 1 {
            return index + 1;
        }
    }
    1
}
552
/// Extent of the 1-based dimension `dim`. Dimensions past the rank are singleton (1), and
/// `dim == 0` is read as dimension 1.
fn dimension_length(shape: &[usize], dim: usize) -> usize {
    shape.get(dim.saturating_sub(1)).copied().unwrap_or(1)
}
561
/// Product of all extents, saturating at `usize::MAX` instead of overflowing; empty is 1.
fn product(dims: &[usize]) -> usize {
    let mut total = 1usize;
    for &extent in dims {
        total = total.saturating_mul(extent);
    }
    total
}
567
/// Unit tests for the `diff` builtin: type resolution, descriptor metadata, symbolic
/// differentiation, host finite differences, argument validation, and provider-backed paths.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, SymbolicExpr, Tensor};

    // Blocking shim over the async builtin so each test can call it synchronously.
    fn diff_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::diff_builtin(value, rest))
    }

    // The type resolver maps a 2x3 tensor input to a tensor with unknown extents.
    #[test]
    fn diff_type_defaults_tensor() {
        let out = diff_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![None, None])
            }
        );
    }

    // Every documented call form appears among the descriptor's signature labels.
    #[test]
    fn diff_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = DIFF_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"B = diff(X)"));
        assert!(labels.contains(&"B = diff(X, n)"));
        assert!(labels.contains(&"B = diff(X, n, dim)"));
    }

    // Each error descriptor defined in this module is registered on the builtin descriptor.
    #[test]
    fn diff_descriptor_errors_have_stable_codes() {
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_ARGUMENT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INVALID_INPUT.code));
        assert!(DIFF_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == DIFF_ERROR_INTERNAL.code));
    }

    // A symbolic variable as the only trailing argument selects the differentiation variable.
    #[test]
    fn diff_symbolic_function_with_explicit_variable() {
        let y = SymbolicExpr::function_reference("Y", vec!["X".to_string()]);

        let result = diff_builtin(
            Value::Symbolic(y),
            vec![Value::Symbolic(SymbolicExpr::variable("X"))],
        )
        .expect("diff");

        assert_eq!(result.to_string(), "diff(Y(X), X)");
    }

    // `diff(f, n, var)` is accepted with the order placed before the variable.
    #[test]
    fn diff_symbolic_function_accepts_order_before_variable() {
        let y = SymbolicExpr::function_reference("Y", vec!["X".to_string()]);

        let result = diff_builtin(
            Value::Symbolic(y),
            vec![
                Value::Int(IntValue::I32(2)),
                Value::Symbolic(SymbolicExpr::variable("X")),
            ],
        )
        .expect("diff");

        assert_eq!(result.to_string(), "diff(Y(X), X, 2)");
    }

    // A 1x3 row vector is differenced along dimension 2 by default.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_row_vector_default_dimension() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // The second difference of a column of consecutive squares is constant (2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_column_vector_second_order() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![2, 1]);
                assert_eq!(out.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // With dim = 2, a 3x2 matrix yields column 2 minus column 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_matrix_along_columns() {
        let tensor = Tensor::new(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], vec![3, 2]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![1.0, 1.0, 1.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // An order larger than the vector length produces an empty result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_handles_empty_when_order_exceeds_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(5))];
        let result = diff_builtin(Value::Tensor(tensor), args).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape[0], 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Character input is differenced via its code points and returned as a double tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_char_array_promotes_to_double() {
        let chars = CharArray::new("ACEG".chars().collect(), 1, 4).unwrap();
        let result = diff_builtin(Value::CharArray(chars), Vec::new()).expect("diff");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 3]);
                assert_eq!(out.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Complex input stays complex, with real and imaginary parts differenced separately.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_complex_tensor_preserves_type() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 1.0), (3.0, 2.0), (6.0, 5.0)], vec![1, 3]).unwrap();
        let result = diff_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("diff");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![(2.0, 1.0), (3.0, 3.0)]);
            }
            other => panic!("expected complex tensor result, got {other:?}"),
        }
    }

    // Order 0 returns the input value unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_zero_order_returns_input() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(0))];
        let result = diff_builtin(Value::Tensor(tensor.clone()), args).expect("diff");
        assert_eq!(result, Value::Tensor(tensor));
    }

    // `[]` as the order behaves exactly like omitting it.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_order_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
        let baseline = diff_builtin(Value::Tensor(tensor.clone()), Vec::new()).expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(Value::Tensor(tensor), vec![Value::Tensor(empty)]).expect("diff");
        assert_eq!(result, baseline);
    }

    // `[]` as the dimension behaves exactly like omitting it.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_accepts_empty_dimension_argument() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![1, 4]).unwrap();
        let baseline = diff_builtin(
            Value::Tensor(tensor.clone()),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("diff");
        let empty = Tensor::new(vec![], vec![0, 0]).unwrap();
        let result = diff_builtin(
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Tensor(empty)],
        )
        .expect("diff");
        assert_eq!(result, baseline);
    }

    // A negative order raises the invalid-argument error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_negative_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(-1))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative"));
    }

    // A fractional order raises the invalid-argument error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_non_integer_order() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Num(1.5)];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("non-negative integer"));
    }

    // dim = 0 is rejected with the invalid-argument identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(0))];
        let err = diff_builtin(Value::Tensor(tensor), args).unwrap_err();
        assert_eq!(err.identifier(), DIFF_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("dimension must be >= 1"));
    }

    // A tensor uploaded to the test provider is differenced and gathers to the host answer.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn diff_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = diff_builtin(Value::GpuTensor(handle), Vec::new()).expect("diff");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data, vec![3.0, 5.0]);
        });
    }

    // With the `wgpu` feature, the WGPU provider matches the CPU result within a tolerance
    // chosen from the provider's precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn diff_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
        let args = vec![Value::Int(IntValue::I32(2))];

        let cpu_result = diff_builtin(Value::Tensor(tensor.clone()), args.clone()).expect("diff");
        let expected = match cpu_result {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        };

        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let gpu_value = diff_builtin(Value::GpuTensor(handle), args).expect("diff");
        let gathered = test_support::gather(gpu_value).expect("gather");

        assert_eq!(gathered.shape, expected.shape);
        let tol = if matches!(
            provider.precision(),
            runmat_accelerate_api::ProviderPrecision::F32
        ) {
            1e-5
        } else {
            1e-12
        };
        for (a, b) in gathered.data.iter().zip(expected.data.iter()) {
            assert!((a - b).abs() < tol, "|{a} - {b}| >= {tol}");
        }
    }
}