1use runmat_accelerate_api::{GpuTensorHandle, GpuTensorStorage};
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::random_args::complex_tensor_into_value;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::math::type_resolvers::numeric_unary_type;
19use crate::{build_runtime_error, BuiltinResult, RuntimeError};
20
21const NAME: &str = "gradient";
22
23fn gradient_type(args: &[Type], ctx: &ResolveContext) -> Type {
24 numeric_unary_type(args, ctx)
25}
26
27const GRADIENT_OUTPUT_G: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
28 name: "G",
29 ty: BuiltinParamType::NumericArray,
30 arity: BuiltinParamArity::Required,
31 default: None,
32 description: "Primary gradient component.",
33}];
34
35const GRADIENT_OUTPUT_GS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
36 name: "Gi",
37 ty: BuiltinParamType::NumericArray,
38 arity: BuiltinParamArity::Variadic,
39 default: None,
40 description: "Gradient components ordered by MATLAB axis semantics.",
41}];
42
43const GRADIENT_INPUTS_F: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
44 name: "F",
45 ty: BuiltinParamType::Any,
46 arity: BuiltinParamArity::Required,
47 default: None,
48 description: "Input scalar or array.",
49}];
50
51const GRADIENT_INPUTS_F_H: [BuiltinParamDescriptor; 2] = [
52 BuiltinParamDescriptor {
53 name: "F",
54 ty: BuiltinParamType::Any,
55 arity: BuiltinParamArity::Required,
56 default: None,
57 description: "Input scalar or array.",
58 },
59 BuiltinParamDescriptor {
60 name: "h",
61 ty: BuiltinParamType::Any,
62 arity: BuiltinParamArity::Optional,
63 default: Some("1"),
64 description: "Scalar spacing shared across all output dimensions.",
65 },
66];
67
68const GRADIENT_INPUTS_F_HS: [BuiltinParamDescriptor; 2] = [
69 BuiltinParamDescriptor {
70 name: "F",
71 ty: BuiltinParamType::Any,
72 arity: BuiltinParamArity::Required,
73 default: None,
74 description: "Input scalar or array.",
75 },
76 BuiltinParamDescriptor {
77 name: "h_i",
78 ty: BuiltinParamType::Any,
79 arity: BuiltinParamArity::Variadic,
80 default: None,
81 description: "Per-dimension scalar spacings (one per requested gradient component).",
82 },
83];
84
85const GRADIENT_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
86 BuiltinSignatureDescriptor {
87 label: "G = gradient(F)",
88 inputs: &GRADIENT_INPUTS_F,
89 outputs: &GRADIENT_OUTPUT_G,
90 },
91 BuiltinSignatureDescriptor {
92 label: "G = gradient(F, h)",
93 inputs: &GRADIENT_INPUTS_F_H,
94 outputs: &GRADIENT_OUTPUT_G,
95 },
96 BuiltinSignatureDescriptor {
97 label: "[G1, G2, ...] = gradient(F)",
98 inputs: &GRADIENT_INPUTS_F,
99 outputs: &GRADIENT_OUTPUT_GS,
100 },
101 BuiltinSignatureDescriptor {
102 label: "[G1, G2, ...] = gradient(F, h1, h2, ...)",
103 inputs: &GRADIENT_INPUTS_F_HS,
104 outputs: &GRADIENT_OUTPUT_GS,
105 },
106];
107
108const GRADIENT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
109 code: "RM.GRADIENT.INVALID_ARGUMENT",
110 identifier: Some("RunMat:gradient:InvalidArgument"),
111 when: "Output-count or spacing argument grammar is invalid.",
112 message: "gradient: invalid argument",
113};
114
115const GRADIENT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
116 code: "RM.GRADIENT.INVALID_INPUT",
117 identifier: Some("RunMat:gradient:InvalidInput"),
118 when: "Input value cannot be converted to a supported gradient domain.",
119 message: "gradient: invalid input",
120};
121
122const GRADIENT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
123 code: "RM.GRADIENT.INTERNAL",
124 identifier: Some("RunMat:gradient:Internal"),
125 when: "Gradient execution fails due to gather, conversion, allocation, or indexing operations.",
126 message: "gradient: internal failure",
127};
128
129const GRADIENT_ERRORS: [BuiltinErrorDescriptor; 3] = [
130 GRADIENT_ERROR_INVALID_ARGUMENT,
131 GRADIENT_ERROR_INVALID_INPUT,
132 GRADIENT_ERROR_INTERNAL,
133];
134
135pub const GRADIENT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
136 signatures: &GRADIENT_SIGNATURES,
137 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
138 completion_policy: BuiltinCompletionPolicy::Public,
139 errors: &GRADIENT_ERRORS,
140};
141
142fn gradient_descriptor_error_with_message(
143 message: impl Into<String>,
144 error: &'static BuiltinErrorDescriptor,
145) -> RuntimeError {
146 let mut builder = build_runtime_error(message).with_builtin(NAME);
147 if let Some(identifier) = error.identifier {
148 builder = builder.with_identifier(identifier);
149 }
150 builder.build()
151}
152
153fn gradient_descriptor_error_with_detail(
154 error: &'static BuiltinErrorDescriptor,
155 detail: impl AsRef<str>,
156) -> RuntimeError {
157 gradient_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
158}
159
160fn gradient_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
161 gradient_descriptor_error_with_detail(&GRADIENT_ERROR_INVALID_ARGUMENT, detail)
162}
163
164fn gradient_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
165 gradient_descriptor_error_with_detail(&GRADIENT_ERROR_INVALID_INPUT, detail)
166}
167
168fn gradient_internal_error(detail: impl AsRef<str>) -> RuntimeError {
169 gradient_descriptor_error_with_detail(&GRADIENT_ERROR_INTERNAL, detail)
170}
171
172#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::gradient")]
173pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
174 name: "gradient",
175 op_kind: GpuOpKind::Custom("numerical-gradient"),
176 supported_precisions: &[ScalarType::F32, ScalarType::F64],
177 broadcast: BroadcastSemantics::Matlab,
178 provider_hooks: &[ProviderHook::Custom("gradient_dim")],
179 constant_strategy: ConstantStrategy::InlineLiteral,
180 residency: ResidencyPolicy::NewHandle,
181 nan_mode: ReductionNaN::Include,
182 two_pass_threshold: None,
183 workgroup_size: None,
184 accepts_nan_mode: false,
185 notes:
186 "Providers may keep scalar-spacing gradients on device via `gradient_dim`; coordinate-vector spacing falls back to the host in this implementation.",
187};
188
// Fusion metadata for `gradient`. It registers neither an elementwise nor a reduction
// template (both `None`). Presumably this means the fusion planner treats it as an opaque
// sink rather than inlining it; confirm against the fusion planner.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::gradient")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gradient",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Gradient preserves input shape and uses edge-aware finite differences, so providers expose it through a custom sink hook.",
};
199
200#[runtime_builtin(
201 name = "gradient",
202 category = "math/reduction",
203 summary = "Compute numerical gradients.",
204 keywords = "gradient,numerical gradient,finite difference,vector field,gpu",
205 accel = "gradient",
206 type_resolver(gradient_type),
207 descriptor(crate::builtins::math::reduction::gradient::GRADIENT_DESCRIPTOR),
208 builtin_path = "crate::builtins::math::reduction::gradient"
209)]
210async fn gradient_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
211 let requested_outputs = crate::output_count::current_output_count().unwrap_or(1);
212 if requested_outputs == 0 {
213 return Ok(Value::OutputList(Vec::new()));
214 }
215
216 let available_outputs = gradient_output_dims(value_shape(&value), value_len(&value));
217 if requested_outputs > available_outputs.len() {
218 return Err(gradient_invalid_argument(format!(
219 "gradient: requested {requested_outputs} outputs, but input supports at most {}",
220 available_outputs.len()
221 )));
222 }
223
224 let spacings = parse_spacings(&rest, available_outputs.len()).await?;
225 let outputs =
226 evaluate_gradient_outputs(value, &available_outputs[..requested_outputs], &spacings)
227 .await?;
228
229 if crate::output_count::current_output_count().is_some() {
230 return Ok(Value::OutputList(outputs));
231 }
232
233 Ok(outputs
234 .into_iter()
235 .next()
236 .expect("single-output gradient result"))
237}
238
239async fn evaluate_gradient_outputs(
240 value: Value,
241 requested_dims: &[usize],
242 all_spacings: &[f64],
243) -> BuiltinResult<Vec<Value>> {
244 if let Value::GpuTensor(handle) = value {
245 return gradient_gpu_outputs(handle, requested_dims, all_spacings).await;
246 }
247
248 evaluate_host_gradient_outputs(value, requested_dims, all_spacings)
249}
250
251fn evaluate_host_gradient_outputs(
252 value: Value,
253 requested_dims: &[usize],
254 all_spacings: &[f64],
255) -> BuiltinResult<Vec<Value>> {
256 match value {
257 Value::Tensor(tensor) => {
258 let mut outputs = Vec::with_capacity(requested_dims.len());
259 for &dim in requested_dims {
260 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
261 outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
262 tensor.clone(),
263 dim,
264 spacing,
265 )?));
266 }
267 Ok(outputs)
268 }
269 Value::LogicalArray(logical) => {
270 let tensor = tensor::logical_to_tensor(&logical).map_err(gradient_invalid_input)?;
271 let mut outputs = Vec::with_capacity(requested_dims.len());
272 for &dim in requested_dims {
273 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
274 outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
275 tensor.clone(),
276 dim,
277 spacing,
278 )?));
279 }
280 Ok(outputs)
281 }
282 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
283 let tensor =
284 tensor::value_into_tensor_for(NAME, value).map_err(gradient_invalid_input)?;
285 let mut outputs = Vec::with_capacity(requested_dims.len());
286 for &dim in requested_dims {
287 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
288 outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
289 tensor.clone(),
290 dim,
291 spacing,
292 )?));
293 }
294 Ok(outputs)
295 }
296 Value::Complex(re, im) => {
297 let tensor = ComplexTensor {
298 data: vec![(re, im)],
299 shape: vec![1, 1],
300 rows: 1,
301 cols: 1,
302 };
303 let mut outputs = Vec::with_capacity(requested_dims.len());
304 for &dim in requested_dims {
305 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
306 outputs.push(complex_tensor_into_value(gradient_complex_tensor_host(
307 tensor.clone(),
308 dim,
309 spacing,
310 )?));
311 }
312 Ok(outputs)
313 }
314 Value::ComplexTensor(tensor) => {
315 let mut outputs = Vec::with_capacity(requested_dims.len());
316 for &dim in requested_dims {
317 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
318 outputs.push(complex_tensor_into_value(gradient_complex_tensor_host(
319 tensor.clone(),
320 dim,
321 spacing,
322 )?));
323 }
324 Ok(outputs)
325 }
326 other => Err(gradient_invalid_input(format!(
327 "gradient: unsupported input type {:?}; expected numeric or logical data",
328 other
329 ))),
330 }
331}
332
333async fn gradient_gpu_outputs(
334 handle: GpuTensorHandle,
335 requested_dims: &[usize],
336 all_spacings: &[f64],
337) -> BuiltinResult<Vec<Value>> {
338 let complex_storage =
339 runmat_accelerate_api::handle_storage(&handle) == GpuTensorStorage::ComplexInterleaved;
340
341 if let Some(provider) =
342 runmat_accelerate_api::provider_for_handle(&handle).or_else(runmat_accelerate_api::provider)
343 {
344 let _guard = runmat_accelerate_api::ThreadProviderGuard::set(Some(provider));
345 let mut outputs = Vec::with_capacity(requested_dims.len());
346 for &dim in requested_dims {
347 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
348 match provider.gradient_dim(&handle, dim.saturating_sub(1), spacing) {
349 Ok(device_result) => {
350 if complex_storage
351 || runmat_accelerate_api::handle_storage(&device_result)
352 == GpuTensorStorage::ComplexInterleaved
353 {
354 outputs.push(gpu_helpers::complex_gpu_value(device_result));
355 } else {
356 outputs.push(gpu_helpers::resident_gpu_value(device_result));
357 }
358 }
359 Err(_) => {
360 let gathered =
361 gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
362 return evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings);
363 }
364 }
365 }
366 return Ok(outputs);
367 }
368
369 let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
370 evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings)
371}
372
/// Picks the spacing for 1-based dimension `dim`.
///
/// A single spacing is shared by every dimension. Otherwise `spacings[i]` belongs to
/// `available_dims[i]`.
///
/// # Panics
///
/// Panics if several spacings are given and `dim` does not appear in `available_dims`.
/// Callers always pass a dimension drawn from that list.
fn spacing_for_dim(dim: usize, available_dims: &[usize], spacings: &[f64]) -> f64 {
    match spacings {
        [shared] => *shared,
        _ => {
            let slot = available_dims
                .iter()
                .position(|&candidate| candidate == dim)
                .expect("spacing lookup requires matching dimension");
            spacings[slot]
        }
    }
}
384
385async fn parse_spacings(args: &[Value], available_dims: usize) -> BuiltinResult<Vec<f64>> {
386 match args.len() {
387 0 => Ok(vec![1.0; available_dims]),
388 1 => {
389 let spacing = parse_scalar_spacing(&args[0]).await?;
390 Ok(vec![spacing; available_dims])
391 }
392 count if count == available_dims => {
393 let mut spacings = Vec::with_capacity(args.len());
394 for value in args {
395 spacings.push(parse_scalar_spacing(value).await?);
396 }
397 Ok(spacings)
398 }
399 _ => Err(gradient_invalid_argument(format!(
400 "gradient: expected 0, 1, or {available_dims} scalar spacing arguments"
401 ))),
402 }
403}
404
405async fn parse_scalar_spacing(value: &Value) -> BuiltinResult<f64> {
406 match value {
407 Value::Tensor(tensor) if tensor.data.is_empty() => {
408 return Err(gradient_invalid_argument(
409 "gradient: empty spacing arguments are not supported",
410 ))
411 }
412 _ => {}
413 }
414
415 let Some(spacing) = tensor::scalar_f64_from_value_async(value)
416 .await
417 .map_err(gradient_invalid_argument)?
418 else {
419 return Err(gradient_invalid_argument(
420 "gradient: only scalar spacings are supported in this implementation",
421 ));
422 };
423
424 if !spacing.is_finite() {
425 return Err(gradient_invalid_argument(
426 "gradient: spacing must be finite",
427 ));
428 }
429 if spacing == 0.0 {
430 return Err(gradient_invalid_argument(
431 "gradient: spacing must be nonzero",
432 ));
433 }
434 Ok(spacing)
435}
436
437fn value_shape(value: &Value) -> &[usize] {
438 match value {
439 Value::Tensor(tensor) => &tensor.shape,
440 Value::LogicalArray(logical) => &logical.shape,
441 Value::ComplexTensor(tensor) => &tensor.shape,
442 Value::GpuTensor(handle) => &handle.shape,
443 _ => &[],
444 }
445}
446
447fn value_len(value: &Value) -> usize {
448 match value {
449 Value::Tensor(tensor) => tensor.data.len(),
450 Value::LogicalArray(logical) => logical.data.len(),
451 Value::ComplexTensor(tensor) => tensor.data.len(),
452 Value::GpuTensor(handle) => product(&handle.shape),
453 _ => 1,
454 }
455}
456
/// Normalizes a stored shape to the MATLAB-style shape used by `gradient`.
///
/// - An empty shape with no data stays empty.
/// - An empty shape with data is a `1x1` scalar.
/// - A 1-D shape `[n]` becomes the row vector `[1, n]`.
/// - Shapes of rank two or more are returned unchanged.
pub fn matlab_gradient_shape(shape: &[usize], len: usize) -> Vec<usize> {
    match shape {
        [] if len == 0 => Vec::new(),
        [] => vec![1, 1],
        [extent] => vec![1, *extent],
        _ => shape.to_vec(),
    }
}
474
475fn gradient_output_dims(shape: &[usize], len: usize) -> Vec<usize> {
476 let normalized_shape = matlab_gradient_shape(shape, len);
477 let mut ext_shape = if normalized_shape.is_empty() {
478 if len == 0 {
479 vec![0, 0]
480 } else {
481 vec![1, 1]
482 }
483 } else {
484 normalized_shape
485 };
486 if ext_shape.len() == 1 {
487 ext_shape.push(1);
488 }
489
490 if ext_shape.len() <= 2 {
491 let rows = ext_shape.first().copied().unwrap_or(1);
492 let cols = ext_shape.get(1).copied().unwrap_or(1);
493 if rows == 1 && cols == 1 {
494 vec![1]
495 } else if rows == 1 {
496 vec![2]
497 } else if cols == 1 {
498 vec![1]
499 } else {
500 vec![2, 1]
501 }
502 } else {
503 let mut dims = vec![2, 1];
504 for dim in 3..=ext_shape.len() {
505 dims.push(dim);
506 }
507 dims
508 }
509}
510
511pub fn gradient_real_tensor_host(
512 tensor: Tensor,
513 dim: usize,
514 spacing: f64,
515) -> BuiltinResult<Tensor> {
516 let Tensor {
517 data, shape, dtype, ..
518 } = tensor;
519 let dim_index = dim.saturating_sub(1);
520 let mut shape = matlab_gradient_shape(&shape, data.len());
521
522 if data.is_empty() {
523 let empty_shape = if shape.is_empty() { vec![0, 0] } else { shape };
528 return Tensor::new_with_dtype(Vec::new(), empty_shape, dtype)
529 .map_err(|e| gradient_internal_error(format!("gradient: {e}")));
530 }
531
532 while shape.len() <= dim_index {
533 shape.push(1);
534 }
535
536 let mut ext_shape = shape.clone();
537 while ext_shape.len() <= dim_index {
538 ext_shape.push(1);
539 }
540 let len_dim = ext_shape[dim_index];
541 let stride_before = if dim_index == 0 {
542 1usize
543 } else {
544 product(&ext_shape[..dim_index]).max(1)
545 };
546 let stride_after = if dim_index + 1 >= ext_shape.len() {
547 1usize
548 } else {
549 product(&ext_shape[dim_index + 1..]).max(1)
550 };
551
552 let mut out = vec![0.0; data.len()];
553 if len_dim > 1 {
554 let block = stride_before
555 .checked_mul(len_dim)
556 .ok_or_else(|| gradient_internal_error("gradient: block size overflow"))?;
557 for after in 0..stride_after {
558 let base = after
559 .checked_mul(block)
560 .ok_or_else(|| gradient_internal_error("gradient: indexing overflow"))?;
561 for before in 0..stride_before {
562 for k in 0..len_dim {
563 let idx = base + before + k * stride_before;
564 out[idx] = if k == 0 {
565 (data[idx + stride_before] - data[idx]) / spacing
566 } else if k + 1 == len_dim {
567 (data[idx] - data[idx - stride_before]) / spacing
568 } else {
569 (data[idx + stride_before] - data[idx - stride_before]) / (2.0 * spacing)
570 };
571 }
572 }
573 }
574 }
575
576 Tensor::new_with_dtype(out, shape, dtype)
577 .map_err(|e| gradient_internal_error(format!("gradient: {e}")))
578}
579
580pub fn gradient_complex_tensor_host(
581 tensor: ComplexTensor,
582 dim: usize,
583 spacing: f64,
584) -> BuiltinResult<ComplexTensor> {
585 let ComplexTensor { data, shape, .. } = tensor;
586 let dim_index = dim.saturating_sub(1);
587 let mut shape = matlab_gradient_shape(&shape, data.len());
588
589 if data.is_empty() {
590 let empty_shape = if shape.is_empty() { vec![0, 0] } else { shape };
593 return ComplexTensor::new(Vec::new(), empty_shape)
594 .map_err(|e| gradient_internal_error(format!("gradient: {e}")));
595 }
596
597 while shape.len() <= dim_index {
598 shape.push(1);
599 }
600
601 let mut ext_shape = shape.clone();
602 while ext_shape.len() <= dim_index {
603 ext_shape.push(1);
604 }
605 let len_dim = ext_shape[dim_index];
606 let stride_before = if dim_index == 0 {
607 1usize
608 } else {
609 product(&ext_shape[..dim_index]).max(1)
610 };
611 let stride_after = if dim_index + 1 >= ext_shape.len() {
612 1usize
613 } else {
614 product(&ext_shape[dim_index + 1..]).max(1)
615 };
616
617 let mut out = vec![(0.0, 0.0); data.len()];
618 if len_dim > 1 {
619 let block = stride_before
620 .checked_mul(len_dim)
621 .ok_or_else(|| gradient_internal_error("gradient: block size overflow"))?;
622 for after in 0..stride_after {
623 let base = after
624 .checked_mul(block)
625 .ok_or_else(|| gradient_internal_error("gradient: indexing overflow"))?;
626 for before in 0..stride_before {
627 for k in 0..len_dim {
628 let idx = base + before + k * stride_before;
629 out[idx] = if k == 0 {
630 scale_complex(
631 sub_complex(data[idx + stride_before], data[idx]),
632 1.0 / spacing,
633 )
634 } else if k + 1 == len_dim {
635 scale_complex(
636 sub_complex(data[idx], data[idx - stride_before]),
637 1.0 / spacing,
638 )
639 } else {
640 scale_complex(
641 sub_complex(data[idx + stride_before], data[idx - stride_before]),
642 0.5 / spacing,
643 )
644 };
645 }
646 }
647 }
648 }
649
650 ComplexTensor::new(out, shape).map_err(|e| gradient_internal_error(format!("gradient: {e}")))
651}
652
653fn sub_complex(lhs: (f64, f64), rhs: (f64, f64)) -> (f64, f64) {
654 (lhs.0 - rhs.0, lhs.1 - rhs.1)
655}
656
657fn scale_complex(value: (f64, f64), scale: f64) -> (f64, f64) {
658 (value.0 * scale, value.1 * scale)
659}
660
661fn product(dims: &[usize]) -> usize {
662 dims.iter()
663 .copied()
664 .fold(1usize, |acc, value| acc.saturating_mul(value))
665}
666
667#[cfg(test)]
668mod tests {
669 use super::*;
670 use crate::builtins::common::test_support;
671 use futures::executor::block_on;
672 #[cfg(feature = "wgpu")]
673 use runmat_accelerate_api::AccelProvider;
674 #[cfg(feature = "wgpu")]
675 use runmat_accelerate_api::HostTensorView;
676 use runmat_builtins::{NumericDType, Tensor};
677
678 fn gradient_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
679 block_on(super::gradient_builtin(value, rest))
680 }
681
682 #[test]
683 fn gradient_descriptor_signatures_cover_core_forms() {
684 let labels: Vec<&str> = GRADIENT_DESCRIPTOR
685 .signatures
686 .iter()
687 .map(|sig| sig.label)
688 .collect();
689 assert!(labels.contains(&"G = gradient(F)"));
690 assert!(labels.contains(&"G = gradient(F, h)"));
691 assert!(labels.contains(&"[G1, G2, ...] = gradient(F)"));
692 assert!(labels.contains(&"[G1, G2, ...] = gradient(F, h1, h2, ...)"));
693 }
694
695 #[test]
696 fn gradient_descriptor_errors_have_stable_codes() {
697 assert!(GRADIENT_DESCRIPTOR
698 .errors
699 .iter()
700 .any(|error| error.code == GRADIENT_ERROR_INVALID_ARGUMENT.code));
701 assert!(GRADIENT_DESCRIPTOR
702 .errors
703 .iter()
704 .any(|error| error.code == GRADIENT_ERROR_INVALID_INPUT.code));
705 assert!(GRADIENT_DESCRIPTOR
706 .errors
707 .iter()
708 .any(|error| error.code == GRADIENT_ERROR_INTERNAL.code));
709 }
710
711 #[test]
712 fn gradient_row_vector_returns_horizontal_derivative() {
713 let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
714 let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
715 assert_eq!(
716 result,
717 Value::Tensor(Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap())
718 );
719 }
720
721 #[test]
722 fn gradient_one_dimensional_tensor_is_treated_as_row_vector() {
723 let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3]).unwrap();
724 let result =
725 gradient_builtin(Value::Tensor(tensor), vec![Value::Num(2.0)]).expect("gradient");
726 match result {
727 Value::Tensor(out) => {
728 assert_eq!(out.shape, vec![1, 3]);
729 assert_eq!(out.data, vec![1.5, 2.0, 2.5]);
730 }
731 other => panic!("expected tensor, got {other:?}"),
732 }
733 }
734
735 #[test]
736 fn gradient_matrix_outputs_follow_matlab_order() {
737 let tensor = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
738 let _guard = crate::output_count::push_output_count(Some(2));
739 let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
740 match result {
741 Value::OutputList(outputs) => {
742 let fx = test_support::gather(outputs[0].clone()).expect("fx");
743 let fy = test_support::gather(outputs[1].clone()).expect("fy");
744 assert_eq!(fx.data, vec![1.0, 1.0, 1.0, 1.0]);
745 assert_eq!(fy.data, vec![2.0, 2.0, 2.0, 2.0]);
746 }
747 other => panic!("expected output list, got {other:?}"),
748 }
749 }
750
751 #[test]
752 fn gradient_scalar_spacing_scales_output() {
753 let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
754 let result =
755 gradient_builtin(Value::Tensor(tensor), vec![Value::Num(2.0)]).expect("gradient");
756 match result {
757 Value::Tensor(out) => assert_eq!(out.data, vec![1.5, 2.0, 2.5]),
758 other => panic!("expected tensor, got {other:?}"),
759 }
760 }
761
762 #[test]
763 fn gradient_preserves_single_precision_host_tensor() {
764 let tensor =
765 Tensor::new_with_dtype(vec![1.0, 4.0, 9.0], vec![1, 3], NumericDType::F32).unwrap();
766 let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
767 match result {
768 Value::Tensor(out) => assert_eq!(out.dtype, NumericDType::F32),
769 other => panic!("expected tensor, got {other:?}"),
770 }
771 }
772
773 #[test]
774 fn gradient_complex_host_supported() {
775 let tensor =
776 ComplexTensor::new(vec![(1.0, 1.0), (4.0, 3.0), (9.0, 6.0)], vec![1, 3]).unwrap();
777 let result = gradient_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("gradient");
778 match result {
779 Value::ComplexTensor(out) => {
780 assert_eq!(out.data, vec![(3.0, 2.0), (4.0, 2.5), (5.0, 3.0)]);
781 }
782 other => panic!("expected complex tensor, got {other:?}"),
783 }
784 }
785
786 #[test]
787 fn gradient_rejects_coordinate_vector_spacing_in_v1() {
788 let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
789 let spacing = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
790 let err =
791 gradient_builtin(Value::Tensor(tensor), vec![Value::Tensor(spacing)]).unwrap_err();
792 assert_eq!(err.identifier(), GRADIENT_ERROR_INVALID_ARGUMENT.identifier);
793 assert!(err.message().contains("scalar"));
794 }
795
796 #[test]
797 fn gradient_rejects_too_many_outputs() {
798 let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
799 let _guard = crate::output_count::push_output_count(Some(2));
800 let err = gradient_builtin(Value::Tensor(tensor), Vec::new()).unwrap_err();
801 assert_eq!(err.identifier(), GRADIENT_ERROR_INVALID_ARGUMENT.identifier);
802 assert!(err.message().contains("requested 2 outputs"));
803 }
804
805 #[test]
806 #[cfg(feature = "wgpu")]
807 fn gradient_gpu_scalar_spacing_matches_cpu_and_stays_resident() {
808 let _guard = test_support::accel_test_lock();
809 let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
810 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
811 ) else {
812 return;
813 };
814 let host =
815 Tensor::new_with_dtype(vec![1.0, 4.0, 9.0], vec![1, 3], NumericDType::F32).unwrap();
816 let view = HostTensorView {
817 data: &host.data,
818 shape: &host.shape,
819 };
820 let handle = provider.upload(&view).expect("upload");
821 let result =
822 gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
823 match result {
824 Value::GpuTensor(out) => {
825 let gathered = test_support::gather(Value::GpuTensor(out)).expect("gather");
826 assert_eq!(gathered.data, vec![1.5, 2.0, 2.5]);
827 assert_eq!(gathered.dtype, NumericDType::F32);
828 }
829 other => panic!("expected gpu tensor, got {other:?}"),
830 }
831 }
832
833 #[test]
834 #[cfg(feature = "wgpu")]
835 fn gradient_gpu_one_dimensional_shape_matches_matlab_row_vector_semantics() {
836 let _guard = test_support::accel_test_lock();
837 let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
838 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
839 ) else {
840 return;
841 };
842 let data = [1.0, 4.0, 9.0];
843 let shape = [3usize];
844 let view = HostTensorView {
845 data: &data,
846 shape: &shape,
847 };
848 let handle = provider.upload(&view).expect("upload");
849 let result =
850 gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
851 let gathered = test_support::gather(result).expect("gather");
852 assert_eq!(gathered.shape, vec![1, 3]);
853 assert_eq!(gathered.data, vec![1.5, 2.0, 2.5]);
854 }
855
856 #[test]
857 #[cfg(feature = "wgpu")]
858 fn gradient_gpu_multi_output_uses_output_list() {
859 let _guard = test_support::accel_test_lock();
860 let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
861 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
862 ) else {
863 return;
864 };
865 let host = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
866 let view = HostTensorView {
867 data: &host.data,
868 shape: &host.shape,
869 };
870 let handle = provider.upload(&view).expect("upload");
871 let _out_guard = crate::output_count::push_output_count(Some(2));
872 let result = gradient_builtin(Value::GpuTensor(handle), Vec::new()).expect("gradient");
873 match result {
874 Value::OutputList(outputs) => {
875 assert!(matches!(outputs[0], Value::GpuTensor(_)));
876 assert!(matches!(outputs[1], Value::GpuTensor(_)));
877 }
878 other => panic!("expected output list, got {other:?}"),
879 }
880 }
881
882 #[test]
883 fn gradient_inprocess_complex_gpu_matches_cpu_and_stays_resident() {
884 test_support::with_test_provider(|provider| {
885 let host = ComplexTensor::new(
886 vec![
887 (1.0, 1.0),
888 (2.0, -1.0),
889 (4.0, 3.0),
890 (6.0, 2.0),
891 (9.0, 6.0),
892 (12.0, 4.0),
893 ],
894 vec![2, 3],
895 )
896 .unwrap();
897 let expected =
898 gradient_complex_tensor_host(host.clone(), 2, 2.0).expect("cpu gradient");
899 let handle = gpu_helpers::upload_complex_tensor(provider, &host).expect("upload");
900 let result = gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)])
901 .expect("gradient");
902 let Value::GpuTensor(out_handle) = result else {
903 panic!("expected complex gpu tensor");
904 };
905 assert_eq!(
906 runmat_accelerate_api::handle_storage(&out_handle),
907 GpuTensorStorage::ComplexInterleaved
908 );
909 let gathered = block_on(
910 crate::builtins::math::fft::common::gather_gpu_complex_tensor(&out_handle, NAME),
911 )
912 .expect("gather complex gradient");
913 assert_eq!(gathered.shape, expected.shape);
914 assert_eq!(gathered.data, expected.data);
915 });
916 }
917
918 #[test]
919 #[cfg(feature = "wgpu")]
920 fn gradient_gpu_complex_matches_cpu_and_stays_resident() {
921 let _guard = test_support::accel_test_lock();
922 let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
923 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
924 ) else {
925 return;
926 };
927 let host = ComplexTensor::new(
928 vec![
929 (1.0, 1.0),
930 (2.0, -1.0),
931 (4.0, 3.0),
932 (6.0, 2.0),
933 (9.0, 6.0),
934 (12.0, 4.0),
935 ],
936 vec![2, 3],
937 )
938 .unwrap();
939 let expected = gradient_complex_tensor_host(host.clone(), 2, 2.0).expect("cpu gradient");
940 let handle = gpu_helpers::upload_complex_tensor(provider, &host).expect("upload");
941 let result =
942 gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
943 let Value::GpuTensor(out_handle) = result else {
944 panic!("expected complex gpu tensor");
945 };
946 assert_eq!(
947 runmat_accelerate_api::handle_storage(&out_handle),
948 GpuTensorStorage::ComplexInterleaved
949 );
950 let gathered = block_on(
951 crate::builtins::math::fft::common::gather_gpu_complex_tensor(&out_handle, NAME),
952 )
953 .expect("gather complex gradient");
954 assert_eq!(gathered.shape, expected.shape);
955 for (idx, (actual, expected)) in gathered.data.iter().zip(expected.data.iter()).enumerate()
956 {
957 assert!(
958 (actual.0 - expected.0).abs() <= 1.0e-5,
959 "real mismatch at {idx}: actual={} expected={}",
960 actual.0,
961 expected.0
962 );
963 assert!(
964 (actual.1 - expected.1).abs() <= 1.0e-5,
965 "imag mismatch at {idx}: actual={} expected={}",
966 actual.1,
967 expected.1
968 );
969 }
970 }
971}