Skip to main content

runmat_runtime/builtins/math/reduction/
gradient.rs

1//! MATLAB-compatible `gradient` builtin with scalar-spacing GPU residency.
2
3use runmat_accelerate_api::{GpuTensorHandle, GpuTensorStorage};
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    ComplexTensor, ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::random_args::complex_tensor_into_value;
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::math::type_resolvers::numeric_unary_type;
19use crate::{build_runtime_error, BuiltinResult, RuntimeError};
20
21const NAME: &str = "gradient";
22
/// Type resolver registered for the `gradient` builtin.
///
/// Delegates directly to the shared `numeric_unary_type` resolver, passing the
/// argument types and resolve context through unchanged.
fn gradient_type(args: &[Type], ctx: &ResolveContext) -> Type {
    numeric_unary_type(args, ctx)
}
26
/// Output list for the single-output call forms: one required numeric array `G`.
const GRADIENT_OUTPUT_G: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "G",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Primary gradient component.",
}];

/// Output list for the multi-output call forms: a variadic run of numeric
/// arrays `Gi`.
const GRADIENT_OUTPUT_GS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "Gi",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Variadic,
    default: None,
    description: "Gradient components ordered by MATLAB axis semantics.",
}];
42
/// Input list for `gradient(F)`: only the required input `F`.
const GRADIENT_INPUTS_F: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "F",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input scalar or array.",
}];

/// Input list for `gradient(F, h)`: required `F` plus one optional spacing `h`
/// whose documented default is `"1"`.
const GRADIENT_INPUTS_F_H: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "F",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input scalar or array.",
    },
    BuiltinParamDescriptor {
        name: "h",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: Some("1"),
        description: "Scalar spacing shared across all output dimensions.",
    },
];

/// Input list for `gradient(F, h1, h2, ...)`: required `F` followed by a
/// variadic run of spacing arguments `h_i`.
const GRADIENT_INPUTS_F_HS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "F",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input scalar or array.",
    },
    BuiltinParamDescriptor {
        name: "h_i",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Per-dimension scalar spacings (one per requested gradient component).",
    },
];
84
/// The four documented call forms of `gradient`: single- and multi-output,
/// each with and without explicit spacing arguments. Every entry pairs a
/// display label with the input/output parameter lists declared in this file.
const GRADIENT_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "G = gradient(F)",
        inputs: &GRADIENT_INPUTS_F,
        outputs: &GRADIENT_OUTPUT_G,
    },
    BuiltinSignatureDescriptor {
        label: "G = gradient(F, h)",
        inputs: &GRADIENT_INPUTS_F_H,
        outputs: &GRADIENT_OUTPUT_G,
    },
    BuiltinSignatureDescriptor {
        label: "[G1, G2, ...] = gradient(F)",
        inputs: &GRADIENT_INPUTS_F,
        outputs: &GRADIENT_OUTPUT_GS,
    },
    BuiltinSignatureDescriptor {
        label: "[G1, G2, ...] = gradient(F, h1, h2, ...)",
        inputs: &GRADIENT_INPUTS_F_HS,
        outputs: &GRADIENT_OUTPUT_GS,
    },
];
107
/// Error raised for a bad output count or malformed spacing arguments.
const GRADIENT_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GRADIENT.INVALID_ARGUMENT",
    identifier: Some("RunMat:gradient:InvalidArgument"),
    when: "Output-count or spacing argument grammar is invalid.",
    message: "gradient: invalid argument",
};

/// Error raised when the input value is not of a type `gradient` can process.
const GRADIENT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GRADIENT.INVALID_INPUT",
    identifier: Some("RunMat:gradient:InvalidInput"),
    when: "Input value cannot be converted to a supported gradient domain.",
    message: "gradient: invalid input",
};

/// Error raised for failures inside the computation itself (for example
/// tensor construction or index arithmetic overflow).
const GRADIENT_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GRADIENT.INTERNAL",
    identifier: Some("RunMat:gradient:Internal"),
    when: "Gradient execution fails due to gather, conversion, allocation, or indexing operations.",
    message: "gradient: internal failure",
};

/// All error descriptors published through `GRADIENT_DESCRIPTOR`.
const GRADIENT_ERRORS: [BuiltinErrorDescriptor; 3] = [
    GRADIENT_ERROR_INVALID_ARGUMENT,
    GRADIENT_ERROR_INVALID_INPUT,
    GRADIENT_ERROR_INTERNAL,
];
134
/// Public descriptor for the `gradient` builtin.
///
/// Bundles the call signatures and error descriptors declared in this file.
/// The output mode is `ByRequestedOutputCount` and the completion policy is
/// `Public`.
pub const GRADIENT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &GRADIENT_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &GRADIENT_ERRORS,
};
141
142fn gradient_descriptor_error_with_message(
143    message: impl Into<String>,
144    error: &'static BuiltinErrorDescriptor,
145) -> RuntimeError {
146    let mut builder = build_runtime_error(message).with_builtin(NAME);
147    if let Some(identifier) = error.identifier {
148        builder = builder.with_identifier(identifier);
149    }
150    builder.build()
151}
152
153fn gradient_descriptor_error_with_detail(
154    error: &'static BuiltinErrorDescriptor,
155    detail: impl AsRef<str>,
156) -> RuntimeError {
157    gradient_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
158}
159
/// Shorthand for an error built from `GRADIENT_ERROR_INVALID_ARGUMENT` with the
/// given detail text appended to the descriptor message.
fn gradient_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
    gradient_descriptor_error_with_detail(&GRADIENT_ERROR_INVALID_ARGUMENT, detail)
}
163
/// Shorthand for an error built from `GRADIENT_ERROR_INVALID_INPUT` with the
/// given detail text appended to the descriptor message.
fn gradient_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
    gradient_descriptor_error_with_detail(&GRADIENT_ERROR_INVALID_INPUT, detail)
}
167
/// Shorthand for an error built from `GRADIENT_ERROR_INTERNAL` with the given
/// detail text appended to the descriptor message.
fn gradient_internal_error(detail: impl AsRef<str>) -> RuntimeError {
    gradient_descriptor_error_with_detail(&GRADIENT_ERROR_INTERNAL, detail)
}
171
/// GPU execution metadata for `gradient`, registered through
/// `register_gpu_spec`.
///
/// Declares a custom op kind (`"numerical-gradient"`) served by the custom
/// provider hook `"gradient_dim"`, with F32 and F64 listed as supported
/// precisions and a `NewHandle` residency policy. See `notes` for the
/// scalar-spacing limitation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::gradient")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "gradient",
    op_kind: GpuOpKind::Custom("numerical-gradient"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("gradient_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers may keep scalar-spacing gradients on device via `gradient_dim`; coordinate-vector spacing falls back to the host in this implementation.",
};
188
/// Fusion metadata for `gradient`, registered through `register_fusion_spec`.
///
/// No elementwise or reduction fusion template is supplied (both are `None`);
/// the `notes` field records the stated reason.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::gradient")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gradient",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Gradient preserves input shape and uses edge-aware finite differences, so providers expose it through a custom sink hook.",
};
199
200#[runtime_builtin(
201    name = "gradient",
202    category = "math/reduction",
203    summary = "Compute numerical gradients.",
204    keywords = "gradient,numerical gradient,finite difference,vector field,gpu",
205    accel = "gradient",
206    type_resolver(gradient_type),
207    descriptor(crate::builtins::math::reduction::gradient::GRADIENT_DESCRIPTOR),
208    builtin_path = "crate::builtins::math::reduction::gradient"
209)]
210async fn gradient_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
211    let requested_outputs = crate::output_count::current_output_count().unwrap_or(1);
212    if requested_outputs == 0 {
213        return Ok(Value::OutputList(Vec::new()));
214    }
215
216    let available_outputs = gradient_output_dims(value_shape(&value), value_len(&value));
217    if requested_outputs > available_outputs.len() {
218        return Err(gradient_invalid_argument(format!(
219            "gradient: requested {requested_outputs} outputs, but input supports at most {}",
220            available_outputs.len()
221        )));
222    }
223
224    let spacings = parse_spacings(&rest, available_outputs.len()).await?;
225    let outputs =
226        evaluate_gradient_outputs(value, &available_outputs[..requested_outputs], &spacings)
227            .await?;
228
229    if crate::output_count::current_output_count().is_some() {
230        return Ok(Value::OutputList(outputs));
231    }
232
233    Ok(outputs
234        .into_iter()
235        .next()
236        .expect("single-output gradient result"))
237}
238
239async fn evaluate_gradient_outputs(
240    value: Value,
241    requested_dims: &[usize],
242    all_spacings: &[f64],
243) -> BuiltinResult<Vec<Value>> {
244    if let Value::GpuTensor(handle) = value {
245        return gradient_gpu_outputs(handle, requested_dims, all_spacings).await;
246    }
247
248    evaluate_host_gradient_outputs(value, requested_dims, all_spacings)
249}
250
251fn evaluate_host_gradient_outputs(
252    value: Value,
253    requested_dims: &[usize],
254    all_spacings: &[f64],
255) -> BuiltinResult<Vec<Value>> {
256    match value {
257        Value::Tensor(tensor) => {
258            let mut outputs = Vec::with_capacity(requested_dims.len());
259            for &dim in requested_dims {
260                let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
261                outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
262                    tensor.clone(),
263                    dim,
264                    spacing,
265                )?));
266            }
267            Ok(outputs)
268        }
269        Value::LogicalArray(logical) => {
270            let tensor = tensor::logical_to_tensor(&logical).map_err(gradient_invalid_input)?;
271            let mut outputs = Vec::with_capacity(requested_dims.len());
272            for &dim in requested_dims {
273                let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
274                outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
275                    tensor.clone(),
276                    dim,
277                    spacing,
278                )?));
279            }
280            Ok(outputs)
281        }
282        Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
283            let tensor =
284                tensor::value_into_tensor_for(NAME, value).map_err(gradient_invalid_input)?;
285            let mut outputs = Vec::with_capacity(requested_dims.len());
286            for &dim in requested_dims {
287                let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
288                outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
289                    tensor.clone(),
290                    dim,
291                    spacing,
292                )?));
293            }
294            Ok(outputs)
295        }
296        Value::Complex(re, im) => {
297            let tensor = ComplexTensor {
298                data: vec![(re, im)],
299                shape: vec![1, 1],
300                rows: 1,
301                cols: 1,
302            };
303            let mut outputs = Vec::with_capacity(requested_dims.len());
304            for &dim in requested_dims {
305                let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
306                outputs.push(complex_tensor_into_value(gradient_complex_tensor_host(
307                    tensor.clone(),
308                    dim,
309                    spacing,
310                )?));
311            }
312            Ok(outputs)
313        }
314        Value::ComplexTensor(tensor) => {
315            let mut outputs = Vec::with_capacity(requested_dims.len());
316            for &dim in requested_dims {
317                let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
318                outputs.push(complex_tensor_into_value(gradient_complex_tensor_host(
319                    tensor.clone(),
320                    dim,
321                    spacing,
322                )?));
323            }
324            Ok(outputs)
325        }
326        other => Err(gradient_invalid_input(format!(
327            "gradient: unsupported input type {:?}; expected numeric or logical data",
328            other
329        ))),
330    }
331}
332
333async fn gradient_gpu_outputs(
334    handle: GpuTensorHandle,
335    requested_dims: &[usize],
336    all_spacings: &[f64],
337) -> BuiltinResult<Vec<Value>> {
338    let complex_storage =
339        runmat_accelerate_api::handle_storage(&handle) == GpuTensorStorage::ComplexInterleaved;
340
341    if let Some(provider) =
342        runmat_accelerate_api::provider_for_handle(&handle).or_else(runmat_accelerate_api::provider)
343    {
344        let _guard = runmat_accelerate_api::ThreadProviderGuard::set(Some(provider));
345        let mut outputs = Vec::with_capacity(requested_dims.len());
346        for &dim in requested_dims {
347            let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
348            match provider.gradient_dim(&handle, dim.saturating_sub(1), spacing) {
349                Ok(device_result) => {
350                    if complex_storage
351                        || runmat_accelerate_api::handle_storage(&device_result)
352                            == GpuTensorStorage::ComplexInterleaved
353                    {
354                        outputs.push(gpu_helpers::complex_gpu_value(device_result));
355                    } else {
356                        outputs.push(gpu_helpers::resident_gpu_value(device_result));
357                    }
358                }
359                Err(_) => {
360                    let gathered =
361                        gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
362                    return evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings);
363                }
364            }
365        }
366        return Ok(outputs);
367    }
368
369    let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
370    evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings)
371}
372
/// Picks the spacing that applies to dimension `dim`.
///
/// A single-element `spacings` slice is shared by every dimension. Otherwise
/// the spacing sits at the same position that `dim` occupies in
/// `available_dims`.
///
/// # Panics
///
/// Panics if `spacings` has more than one element (or none) and `dim` is not
/// in `available_dims`, or if the resulting position is out of range for
/// `spacings`.
fn spacing_for_dim(dim: usize, available_dims: &[usize], spacings: &[f64]) -> f64 {
    match spacings {
        [shared] => *shared,
        per_dim => {
            let slot = available_dims
                .iter()
                .position(|&candidate| candidate == dim)
                .expect("spacing lookup requires matching dimension");
            per_dim[slot]
        }
    }
}
384
385async fn parse_spacings(args: &[Value], available_dims: usize) -> BuiltinResult<Vec<f64>> {
386    match args.len() {
387        0 => Ok(vec![1.0; available_dims]),
388        1 => {
389            let spacing = parse_scalar_spacing(&args[0]).await?;
390            Ok(vec![spacing; available_dims])
391        }
392        count if count == available_dims => {
393            let mut spacings = Vec::with_capacity(args.len());
394            for value in args {
395                spacings.push(parse_scalar_spacing(value).await?);
396            }
397            Ok(spacings)
398        }
399        _ => Err(gradient_invalid_argument(format!(
400            "gradient: expected 0, 1, or {available_dims} scalar spacing arguments"
401        ))),
402    }
403}
404
405async fn parse_scalar_spacing(value: &Value) -> BuiltinResult<f64> {
406    match value {
407        Value::Tensor(tensor) if tensor.data.is_empty() => {
408            return Err(gradient_invalid_argument(
409                "gradient: empty spacing arguments are not supported",
410            ))
411        }
412        _ => {}
413    }
414
415    let Some(spacing) = tensor::scalar_f64_from_value_async(value)
416        .await
417        .map_err(gradient_invalid_argument)?
418    else {
419        return Err(gradient_invalid_argument(
420            "gradient: only scalar spacings are supported in this implementation",
421        ));
422    };
423
424    if !spacing.is_finite() {
425        return Err(gradient_invalid_argument(
426            "gradient: spacing must be finite",
427        ));
428    }
429    if spacing == 0.0 {
430        return Err(gradient_invalid_argument(
431            "gradient: spacing must be nonzero",
432        ));
433    }
434    Ok(spacing)
435}
436
/// Returns the stored shape of an array-like value.
///
/// Real, logical, complex, and GPU tensors report their `shape` field; every
/// other value (for example scalars) reports an empty slice.
fn value_shape(value: &Value) -> &[usize] {
    match value {
        Value::Tensor(tensor) => &tensor.shape,
        Value::LogicalArray(logical) => &logical.shape,
        Value::ComplexTensor(tensor) => &tensor.shape,
        Value::GpuTensor(handle) => &handle.shape,
        // Non-array values carry no shape.
        _ => &[],
    }
}
446
447fn value_len(value: &Value) -> usize {
448    match value {
449        Value::Tensor(tensor) => tensor.data.len(),
450        Value::LogicalArray(logical) => logical.data.len(),
451        Value::ComplexTensor(tensor) => tensor.data.len(),
452        Value::GpuTensor(handle) => product(&handle.shape),
453        _ => 1,
454    }
455}
456
/// Normalises a stored shape to the at-least-2-D form used by `gradient`.
///
/// * an empty shape with no elements stays empty;
/// * an empty shape with elements, or the one-element shape `[1]`, becomes
///   `[1, 1]`;
/// * any other one-element shape `[n]` becomes the row vector `[1, n]`;
/// * shapes with two or more extents are returned unchanged.
pub fn matlab_gradient_shape(shape: &[usize], len: usize) -> Vec<usize> {
    match shape {
        [] if len == 0 => Vec::new(),
        [] | [1] => vec![1, 1],
        [extent] => vec![1, *extent],
        _ => shape.to_vec(),
    }
}
474
475fn gradient_output_dims(shape: &[usize], len: usize) -> Vec<usize> {
476    let normalized_shape = matlab_gradient_shape(shape, len);
477    let mut ext_shape = if normalized_shape.is_empty() {
478        if len == 0 {
479            vec![0, 0]
480        } else {
481            vec![1, 1]
482        }
483    } else {
484        normalized_shape
485    };
486    if ext_shape.len() == 1 {
487        ext_shape.push(1);
488    }
489
490    if ext_shape.len() <= 2 {
491        let rows = ext_shape.first().copied().unwrap_or(1);
492        let cols = ext_shape.get(1).copied().unwrap_or(1);
493        if rows == 1 && cols == 1 {
494            vec![1]
495        } else if rows == 1 {
496            vec![2]
497        } else if cols == 1 {
498            vec![1]
499        } else {
500            vec![2, 1]
501        }
502    } else {
503        let mut dims = vec![2, 1];
504        for dim in 3..=ext_shape.len() {
505            dims.push(dim);
506        }
507        dims
508    }
509}
510
511pub fn gradient_real_tensor_host(
512    tensor: Tensor,
513    dim: usize,
514    spacing: f64,
515) -> BuiltinResult<Tensor> {
516    let Tensor {
517        data, shape, dtype, ..
518    } = tensor;
519    let dim_index = dim.saturating_sub(1);
520    let mut shape = matlab_gradient_shape(&shape, data.len());
521
522    if data.is_empty() {
523        // Return early before the `push(1)` padding loop: that loop would give a
524        // shape like [1] or [1,1] whose product is 1 ≠ 0, violating Tensor's
525        // invariant. Use the normalised shape directly, falling back to [0,0] if
526        // matlab_gradient_shape returned an empty vec (untyped empty tensor).
527        let empty_shape = if shape.is_empty() { vec![0, 0] } else { shape };
528        return Tensor::new_with_dtype(Vec::new(), empty_shape, dtype)
529            .map_err(|e| gradient_internal_error(format!("gradient: {e}")));
530    }
531
532    while shape.len() <= dim_index {
533        shape.push(1);
534    }
535
536    let mut ext_shape = shape.clone();
537    while ext_shape.len() <= dim_index {
538        ext_shape.push(1);
539    }
540    let len_dim = ext_shape[dim_index];
541    let stride_before = if dim_index == 0 {
542        1usize
543    } else {
544        product(&ext_shape[..dim_index]).max(1)
545    };
546    let stride_after = if dim_index + 1 >= ext_shape.len() {
547        1usize
548    } else {
549        product(&ext_shape[dim_index + 1..]).max(1)
550    };
551
552    let mut out = vec![0.0; data.len()];
553    if len_dim > 1 {
554        let block = stride_before
555            .checked_mul(len_dim)
556            .ok_or_else(|| gradient_internal_error("gradient: block size overflow"))?;
557        for after in 0..stride_after {
558            let base = after
559                .checked_mul(block)
560                .ok_or_else(|| gradient_internal_error("gradient: indexing overflow"))?;
561            for before in 0..stride_before {
562                for k in 0..len_dim {
563                    let idx = base + before + k * stride_before;
564                    out[idx] = if k == 0 {
565                        (data[idx + stride_before] - data[idx]) / spacing
566                    } else if k + 1 == len_dim {
567                        (data[idx] - data[idx - stride_before]) / spacing
568                    } else {
569                        (data[idx + stride_before] - data[idx - stride_before]) / (2.0 * spacing)
570                    };
571                }
572            }
573        }
574    }
575
576    Tensor::new_with_dtype(out, shape, dtype)
577        .map_err(|e| gradient_internal_error(format!("gradient: {e}")))
578}
579
580pub fn gradient_complex_tensor_host(
581    tensor: ComplexTensor,
582    dim: usize,
583    spacing: f64,
584) -> BuiltinResult<ComplexTensor> {
585    let ComplexTensor { data, shape, .. } = tensor;
586    let dim_index = dim.saturating_sub(1);
587    let mut shape = matlab_gradient_shape(&shape, data.len());
588
589    if data.is_empty() {
590        // Same fix as gradient_real_tensor_host: avoid padding the shape with 1s
591        // before the early return, which would produce product ≠ 0 for empty data.
592        let empty_shape = if shape.is_empty() { vec![0, 0] } else { shape };
593        return ComplexTensor::new(Vec::new(), empty_shape)
594            .map_err(|e| gradient_internal_error(format!("gradient: {e}")));
595    }
596
597    while shape.len() <= dim_index {
598        shape.push(1);
599    }
600
601    let mut ext_shape = shape.clone();
602    while ext_shape.len() <= dim_index {
603        ext_shape.push(1);
604    }
605    let len_dim = ext_shape[dim_index];
606    let stride_before = if dim_index == 0 {
607        1usize
608    } else {
609        product(&ext_shape[..dim_index]).max(1)
610    };
611    let stride_after = if dim_index + 1 >= ext_shape.len() {
612        1usize
613    } else {
614        product(&ext_shape[dim_index + 1..]).max(1)
615    };
616
617    let mut out = vec![(0.0, 0.0); data.len()];
618    if len_dim > 1 {
619        let block = stride_before
620            .checked_mul(len_dim)
621            .ok_or_else(|| gradient_internal_error("gradient: block size overflow"))?;
622        for after in 0..stride_after {
623            let base = after
624                .checked_mul(block)
625                .ok_or_else(|| gradient_internal_error("gradient: indexing overflow"))?;
626            for before in 0..stride_before {
627                for k in 0..len_dim {
628                    let idx = base + before + k * stride_before;
629                    out[idx] = if k == 0 {
630                        scale_complex(
631                            sub_complex(data[idx + stride_before], data[idx]),
632                            1.0 / spacing,
633                        )
634                    } else if k + 1 == len_dim {
635                        scale_complex(
636                            sub_complex(data[idx], data[idx - stride_before]),
637                            1.0 / spacing,
638                        )
639                    } else {
640                        scale_complex(
641                            sub_complex(data[idx + stride_before], data[idx - stride_before]),
642                            0.5 / spacing,
643                        )
644                    };
645                }
646            }
647        }
648    }
649
650    ComplexTensor::new(out, shape).map_err(|e| gradient_internal_error(format!("gradient: {e}")))
651}
652
/// Component-wise difference `lhs - rhs` of two `(re, im)` pairs.
fn sub_complex(lhs: (f64, f64), rhs: (f64, f64)) -> (f64, f64) {
    let (lhs_re, lhs_im) = lhs;
    let (rhs_re, rhs_im) = rhs;
    (lhs_re - rhs_re, lhs_im - rhs_im)
}
656
/// Multiplies both components of an `(re, im)` pair by the real factor `scale`.
fn scale_complex(value: (f64, f64), scale: f64) -> (f64, f64) {
    let (re, im) = value;
    (re * scale, im * scale)
}
660
/// Saturating product of all extents in `dims`; the empty product is 1.
fn product(dims: &[usize]) -> usize {
    let mut total = 1usize;
    for &extent in dims {
        total = total.saturating_mul(extent);
    }
    total
}
666
667#[cfg(test)]
668mod tests {
669    use super::*;
670    use crate::builtins::common::test_support;
671    use futures::executor::block_on;
672    #[cfg(feature = "wgpu")]
673    use runmat_accelerate_api::AccelProvider;
674    #[cfg(feature = "wgpu")]
675    use runmat_accelerate_api::HostTensorView;
676    use runmat_builtins::{NumericDType, Tensor};
677
678    fn gradient_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
679        block_on(super::gradient_builtin(value, rest))
680    }
681
682    #[test]
683    fn gradient_descriptor_signatures_cover_core_forms() {
684        let labels: Vec<&str> = GRADIENT_DESCRIPTOR
685            .signatures
686            .iter()
687            .map(|sig| sig.label)
688            .collect();
689        assert!(labels.contains(&"G = gradient(F)"));
690        assert!(labels.contains(&"G = gradient(F, h)"));
691        assert!(labels.contains(&"[G1, G2, ...] = gradient(F)"));
692        assert!(labels.contains(&"[G1, G2, ...] = gradient(F, h1, h2, ...)"));
693    }
694
695    #[test]
696    fn gradient_descriptor_errors_have_stable_codes() {
697        assert!(GRADIENT_DESCRIPTOR
698            .errors
699            .iter()
700            .any(|error| error.code == GRADIENT_ERROR_INVALID_ARGUMENT.code));
701        assert!(GRADIENT_DESCRIPTOR
702            .errors
703            .iter()
704            .any(|error| error.code == GRADIENT_ERROR_INVALID_INPUT.code));
705        assert!(GRADIENT_DESCRIPTOR
706            .errors
707            .iter()
708            .any(|error| error.code == GRADIENT_ERROR_INTERNAL.code));
709    }
710
711    #[test]
712    fn gradient_row_vector_returns_horizontal_derivative() {
713        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
714        let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
715        assert_eq!(
716            result,
717            Value::Tensor(Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap())
718        );
719    }
720
721    #[test]
722    fn gradient_one_dimensional_tensor_is_treated_as_row_vector() {
723        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3]).unwrap();
724        let result =
725            gradient_builtin(Value::Tensor(tensor), vec![Value::Num(2.0)]).expect("gradient");
726        match result {
727            Value::Tensor(out) => {
728                assert_eq!(out.shape, vec![1, 3]);
729                assert_eq!(out.data, vec![1.5, 2.0, 2.5]);
730            }
731            other => panic!("expected tensor, got {other:?}"),
732        }
733    }
734
735    #[test]
736    fn gradient_matrix_outputs_follow_matlab_order() {
737        let tensor = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
738        let _guard = crate::output_count::push_output_count(Some(2));
739        let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
740        match result {
741            Value::OutputList(outputs) => {
742                let fx = test_support::gather(outputs[0].clone()).expect("fx");
743                let fy = test_support::gather(outputs[1].clone()).expect("fy");
744                assert_eq!(fx.data, vec![1.0, 1.0, 1.0, 1.0]);
745                assert_eq!(fy.data, vec![2.0, 2.0, 2.0, 2.0]);
746            }
747            other => panic!("expected output list, got {other:?}"),
748        }
749    }
750
751    #[test]
752    fn gradient_scalar_spacing_scales_output() {
753        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
754        let result =
755            gradient_builtin(Value::Tensor(tensor), vec![Value::Num(2.0)]).expect("gradient");
756        match result {
757            Value::Tensor(out) => assert_eq!(out.data, vec![1.5, 2.0, 2.5]),
758            other => panic!("expected tensor, got {other:?}"),
759        }
760    }
761
762    #[test]
763    fn gradient_preserves_single_precision_host_tensor() {
764        let tensor =
765            Tensor::new_with_dtype(vec![1.0, 4.0, 9.0], vec![1, 3], NumericDType::F32).unwrap();
766        let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
767        match result {
768            Value::Tensor(out) => assert_eq!(out.dtype, NumericDType::F32),
769            other => panic!("expected tensor, got {other:?}"),
770        }
771    }
772
773    #[test]
774    fn gradient_complex_host_supported() {
775        let tensor =
776            ComplexTensor::new(vec![(1.0, 1.0), (4.0, 3.0), (9.0, 6.0)], vec![1, 3]).unwrap();
777        let result = gradient_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("gradient");
778        match result {
779            Value::ComplexTensor(out) => {
780                assert_eq!(out.data, vec![(3.0, 2.0), (4.0, 2.5), (5.0, 3.0)]);
781            }
782            other => panic!("expected complex tensor, got {other:?}"),
783        }
784    }
785
786    #[test]
787    fn gradient_rejects_coordinate_vector_spacing_in_v1() {
788        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
789        let spacing = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
790        let err =
791            gradient_builtin(Value::Tensor(tensor), vec![Value::Tensor(spacing)]).unwrap_err();
792        assert_eq!(err.identifier(), GRADIENT_ERROR_INVALID_ARGUMENT.identifier);
793        assert!(err.message().contains("scalar"));
794    }
795
796    #[test]
797    fn gradient_rejects_too_many_outputs() {
798        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
799        let _guard = crate::output_count::push_output_count(Some(2));
800        let err = gradient_builtin(Value::Tensor(tensor), Vec::new()).unwrap_err();
801        assert_eq!(err.identifier(), GRADIENT_ERROR_INVALID_ARGUMENT.identifier);
802        assert!(err.message().contains("requested 2 outputs"));
803    }
804
805    #[test]
806    #[cfg(feature = "wgpu")]
807    fn gradient_gpu_scalar_spacing_matches_cpu_and_stays_resident() {
808        let _guard = test_support::accel_test_lock();
809        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
810            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
811        ) else {
812            return;
813        };
814        let host =
815            Tensor::new_with_dtype(vec![1.0, 4.0, 9.0], vec![1, 3], NumericDType::F32).unwrap();
816        let view = HostTensorView {
817            data: &host.data,
818            shape: &host.shape,
819        };
820        let handle = provider.upload(&view).expect("upload");
821        let result =
822            gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
823        match result {
824            Value::GpuTensor(out) => {
825                let gathered = test_support::gather(Value::GpuTensor(out)).expect("gather");
826                assert_eq!(gathered.data, vec![1.5, 2.0, 2.5]);
827                assert_eq!(gathered.dtype, NumericDType::F32);
828            }
829            other => panic!("expected gpu tensor, got {other:?}"),
830        }
831    }
832
833    #[test]
834    #[cfg(feature = "wgpu")]
835    fn gradient_gpu_one_dimensional_shape_matches_matlab_row_vector_semantics() {
836        let _guard = test_support::accel_test_lock();
837        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
838            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
839        ) else {
840            return;
841        };
842        let data = [1.0, 4.0, 9.0];
843        let shape = [3usize];
844        let view = HostTensorView {
845            data: &data,
846            shape: &shape,
847        };
848        let handle = provider.upload(&view).expect("upload");
849        let result =
850            gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
851        let gathered = test_support::gather(result).expect("gather");
852        assert_eq!(gathered.shape, vec![1, 3]);
853        assert_eq!(gathered.data, vec![1.5, 2.0, 2.5]);
854    }
855
856    #[test]
857    #[cfg(feature = "wgpu")]
858    fn gradient_gpu_multi_output_uses_output_list() {
859        let _guard = test_support::accel_test_lock();
860        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
861            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
862        ) else {
863            return;
864        };
865        let host = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
866        let view = HostTensorView {
867            data: &host.data,
868            shape: &host.shape,
869        };
870        let handle = provider.upload(&view).expect("upload");
871        let _out_guard = crate::output_count::push_output_count(Some(2));
872        let result = gradient_builtin(Value::GpuTensor(handle), Vec::new()).expect("gradient");
873        match result {
874            Value::OutputList(outputs) => {
875                assert!(matches!(outputs[0], Value::GpuTensor(_)));
876                assert!(matches!(outputs[1], Value::GpuTensor(_)));
877            }
878            other => panic!("expected output list, got {other:?}"),
879        }
880    }
881
882    #[test]
883    fn gradient_inprocess_complex_gpu_matches_cpu_and_stays_resident() {
884        test_support::with_test_provider(|provider| {
885            let host = ComplexTensor::new(
886                vec![
887                    (1.0, 1.0),
888                    (2.0, -1.0),
889                    (4.0, 3.0),
890                    (6.0, 2.0),
891                    (9.0, 6.0),
892                    (12.0, 4.0),
893                ],
894                vec![2, 3],
895            )
896            .unwrap();
897            let expected =
898                gradient_complex_tensor_host(host.clone(), 2, 2.0).expect("cpu gradient");
899            let handle = gpu_helpers::upload_complex_tensor(provider, &host).expect("upload");
900            let result = gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)])
901                .expect("gradient");
902            let Value::GpuTensor(out_handle) = result else {
903                panic!("expected complex gpu tensor");
904            };
905            assert_eq!(
906                runmat_accelerate_api::handle_storage(&out_handle),
907                GpuTensorStorage::ComplexInterleaved
908            );
909            let gathered = block_on(
910                crate::builtins::math::fft::common::gather_gpu_complex_tensor(&out_handle, NAME),
911            )
912            .expect("gather complex gradient");
913            assert_eq!(gathered.shape, expected.shape);
914            assert_eq!(gathered.data, expected.data);
915        });
916    }
917
918    #[test]
919    #[cfg(feature = "wgpu")]
920    fn gradient_gpu_complex_matches_cpu_and_stays_resident() {
921        let _guard = test_support::accel_test_lock();
922        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
923            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
924        ) else {
925            return;
926        };
927        let host = ComplexTensor::new(
928            vec![
929                (1.0, 1.0),
930                (2.0, -1.0),
931                (4.0, 3.0),
932                (6.0, 2.0),
933                (9.0, 6.0),
934                (12.0, 4.0),
935            ],
936            vec![2, 3],
937        )
938        .unwrap();
939        let expected = gradient_complex_tensor_host(host.clone(), 2, 2.0).expect("cpu gradient");
940        let handle = gpu_helpers::upload_complex_tensor(provider, &host).expect("upload");
941        let result =
942            gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
943        let Value::GpuTensor(out_handle) = result else {
944            panic!("expected complex gpu tensor");
945        };
946        assert_eq!(
947            runmat_accelerate_api::handle_storage(&out_handle),
948            GpuTensorStorage::ComplexInterleaved
949        );
950        let gathered = block_on(
951            crate::builtins::math::fft::common::gather_gpu_complex_tensor(&out_handle, NAME),
952        )
953        .expect("gather complex gradient");
954        assert_eq!(gathered.shape, expected.shape);
955        for (idx, (actual, expected)) in gathered.data.iter().zip(expected.data.iter()).enumerate()
956        {
957            assert!(
958                (actual.0 - expected.0).abs() <= 1.0e-5,
959                "real mismatch at {idx}: actual={} expected={}",
960                actual.0,
961                expected.0
962            );
963            assert!(
964                (actual.1 - expected.1).abs() <= 1.0e-5,
965                "imag mismatch at {idx}: actual={} expected={}",
966                actual.1,
967                expected.1
968            );
969        }
970    }
971}