1use runmat_accelerate_api::{GpuTensorHandle, GpuTensorStorage};
4use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::gpu_helpers;
8use crate::builtins::common::random_args::complex_tensor_into_value;
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::builtins::common::tensor;
14use crate::builtins::math::type_resolvers::numeric_unary_type;
15use crate::{build_runtime_error, BuiltinResult, RuntimeError};
16
/// Builtin name used for error attribution and spec registration.
const NAME: &str = "gradient";
18
19fn gradient_error(message: impl Into<String>) -> RuntimeError {
20 build_runtime_error(message).with_builtin(NAME).build()
21}
22
/// Type-resolver hook for the `gradient` builtin: the output has the same
/// numeric class as the input, so this delegates to the shared unary resolver.
fn gradient_type(args: &[Type], ctx: &ResolveContext) -> Type {
    numeric_unary_type(args, ctx)
}
26
/// GPU acceleration spec for `gradient`, registered with the accelerate layer
/// so providers can service scalar-spacing gradients on device via the custom
/// `gradient_dim` hook (see `notes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::gradient")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "gradient",
    op_kind: GpuOpKind::Custom("numerical-gradient"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("gradient_dim")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Each output is materialized as a fresh device handle.
    residency: ResidencyPolicy::NewHandle,
    // NaN handling is fixed; the builtin does not accept a nanflag argument.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers may keep scalar-spacing gradients on device via `gradient_dim`; coordinate-vector spacing falls back to the host in this implementation.",
};
43
/// Fusion spec for `gradient`. The op is neither elementwise nor a reduction
/// (`elementwise`/`reduction` are `None`), so it participates in fusion plans
/// only through the custom sink described in `notes`.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::gradient")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gradient",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Gradient preserves input shape and uses edge-aware finite differences, so providers expose it through a custom sink hook.",
};
54
/// Entry point for the `gradient` builtin.
///
/// `value` is the array to differentiate; `rest` holds optional spacing
/// arguments (0, 1, or one-per-output scalars). Output count follows MATLAB
/// `nargout` semantics as recorded in the task-local output-count state.
#[runtime_builtin(
    name = "gradient",
    category = "math/reduction",
    summary = "Numerical gradients using central differences with MATLAB-compatible output ordering.",
    keywords = "gradient,numerical gradient,finite difference,vector field,gpu",
    accel = "gradient",
    type_resolver(gradient_type),
    builtin_path = "crate::builtins::math::reduction::gradient"
)]
async fn gradient_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    // Caller-requested output count; default to a single output when the
    // interpreter did not record one.
    let requested_outputs = crate::output_count::current_output_count().unwrap_or(1);
    if requested_outputs == 0 {
        // Zero requested outputs: nothing to compute.
        return Ok(Value::OutputList(Vec::new()));
    }

    // Dimensions (1-based, MATLAB order: x, y, z, ...) this input can produce.
    let available_outputs = gradient_output_dims(value_shape(&value), value_len(&value));
    if requested_outputs > available_outputs.len() {
        return Err(gradient_error(format!(
            "gradient: requested {requested_outputs} outputs, but input supports at most {}",
            available_outputs.len()
        )));
    }

    // One scalar spacing per available dimension (shared or positional).
    let spacings = parse_spacings(&rest, available_outputs.len()).await?;
    let outputs =
        evaluate_gradient_outputs(value, &available_outputs[..requested_outputs], &spacings)
            .await?;

    // With an explicit output count, always wrap in an output list; otherwise
    // unwrap the single result.
    if crate::output_count::current_output_count().is_some() {
        return Ok(Value::OutputList(outputs));
    }

    Ok(outputs
        .into_iter()
        .next()
        .expect("single-output gradient result"))
}
92
93async fn evaluate_gradient_outputs(
94 value: Value,
95 requested_dims: &[usize],
96 all_spacings: &[f64],
97) -> BuiltinResult<Vec<Value>> {
98 if let Value::GpuTensor(handle) = value {
99 return gradient_gpu_outputs(handle, requested_dims, all_spacings).await;
100 }
101
102 evaluate_host_gradient_outputs(value, requested_dims, all_spacings)
103}
104
105fn evaluate_host_gradient_outputs(
106 value: Value,
107 requested_dims: &[usize],
108 all_spacings: &[f64],
109) -> BuiltinResult<Vec<Value>> {
110 match value {
111 Value::Tensor(tensor) => {
112 let mut outputs = Vec::with_capacity(requested_dims.len());
113 for &dim in requested_dims {
114 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
115 outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
116 tensor.clone(),
117 dim,
118 spacing,
119 )?));
120 }
121 Ok(outputs)
122 }
123 Value::LogicalArray(logical) => {
124 let tensor = tensor::logical_to_tensor(&logical).map_err(gradient_error)?;
125 let mut outputs = Vec::with_capacity(requested_dims.len());
126 for &dim in requested_dims {
127 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
128 outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
129 tensor.clone(),
130 dim,
131 spacing,
132 )?));
133 }
134 Ok(outputs)
135 }
136 Value::Num(_) | Value::Int(_) | Value::Bool(_) => {
137 let tensor = tensor::value_into_tensor_for(NAME, value).map_err(gradient_error)?;
138 let mut outputs = Vec::with_capacity(requested_dims.len());
139 for &dim in requested_dims {
140 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
141 outputs.push(tensor::tensor_into_value(gradient_real_tensor_host(
142 tensor.clone(),
143 dim,
144 spacing,
145 )?));
146 }
147 Ok(outputs)
148 }
149 Value::Complex(re, im) => {
150 let tensor = ComplexTensor {
151 data: vec![(re, im)],
152 shape: vec![1, 1],
153 rows: 1,
154 cols: 1,
155 };
156 let mut outputs = Vec::with_capacity(requested_dims.len());
157 for &dim in requested_dims {
158 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
159 outputs.push(complex_tensor_into_value(gradient_complex_tensor_host(
160 tensor.clone(),
161 dim,
162 spacing,
163 )?));
164 }
165 Ok(outputs)
166 }
167 Value::ComplexTensor(tensor) => {
168 let mut outputs = Vec::with_capacity(requested_dims.len());
169 for &dim in requested_dims {
170 let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
171 outputs.push(complex_tensor_into_value(gradient_complex_tensor_host(
172 tensor.clone(),
173 dim,
174 spacing,
175 )?));
176 }
177 Ok(outputs)
178 }
179 other => Err(gradient_error(format!(
180 "gradient: unsupported input type {:?}; expected numeric or logical data",
181 other
182 ))),
183 }
184}
185
/// Evaluate gradient outputs for a GPU-resident tensor.
///
/// Strategy: complex interleaved storage is always gathered to the host;
/// otherwise each requested dimension is attempted through the provider's
/// `gradient_dim` hook (keeping results device-resident). Any hook failure,
/// or the absence of a provider, falls back to gathering the input and
/// recomputing every requested output on the host.
async fn gradient_gpu_outputs(
    handle: GpuTensorHandle,
    requested_dims: &[usize],
    all_spacings: &[f64],
) -> BuiltinResult<Vec<Value>> {
    // Complex device tensors are not serviced by the custom hook; gather.
    if runmat_accelerate_api::handle_storage(&handle) == GpuTensorStorage::ComplexInterleaved {
        let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
        return evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings);
    }

    if let Some(provider) = runmat_accelerate_api::provider() {
        let mut outputs = Vec::with_capacity(requested_dims.len());
        for &dim in requested_dims {
            let spacing = spacing_for_dim(dim, requested_dims, all_spacings);
            // The provider hook takes a 0-based dimension index.
            match provider.gradient_dim(&handle, dim.saturating_sub(1), spacing) {
                Ok(device_result) => outputs.push(gpu_helpers::resident_gpu_value(device_result)),
                Err(_) => {
                    // Discard any partial device results and recompute all
                    // requested outputs on the host.
                    let gathered =
                        gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
                    return evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings);
                }
            }
        }
        return Ok(outputs);
    }

    // No provider registered: compute on the host.
    let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle)).await?;
    evaluate_host_gradient_outputs(gathered, requested_dims, all_spacings)
}
215
/// Resolve the spacing that applies to output dimension `dim`.
///
/// A single entry in `spacings` is shared by every dimension; otherwise the
/// spacing is matched positionally against `available_dims`.
fn spacing_for_dim(dim: usize, available_dims: &[usize], spacings: &[f64]) -> f64 {
    if let [shared] = spacings {
        return *shared;
    }

    let position = available_dims
        .iter()
        .position(|candidate| *candidate == dim)
        .expect("spacing lookup requires matching dimension");
    spacings[position]
}
227
228async fn parse_spacings(args: &[Value], available_dims: usize) -> BuiltinResult<Vec<f64>> {
229 match args.len() {
230 0 => Ok(vec![1.0; available_dims]),
231 1 => {
232 let spacing = parse_scalar_spacing(&args[0]).await?;
233 Ok(vec![spacing; available_dims])
234 }
235 count if count == available_dims => {
236 let mut spacings = Vec::with_capacity(args.len());
237 for value in args {
238 spacings.push(parse_scalar_spacing(value).await?);
239 }
240 Ok(spacings)
241 }
242 _ => Err(gradient_error(format!(
243 "gradient: expected 0, 1, or {available_dims} scalar spacing arguments"
244 ))),
245 }
246}
247
248async fn parse_scalar_spacing(value: &Value) -> BuiltinResult<f64> {
249 match value {
250 Value::Tensor(tensor) if tensor.data.is_empty() => {
251 return Err(gradient_error(
252 "gradient: empty spacing arguments are not supported",
253 ))
254 }
255 _ => {}
256 }
257
258 let Some(spacing) = tensor::scalar_f64_from_value_async(value)
259 .await
260 .map_err(gradient_error)?
261 else {
262 return Err(gradient_error(
263 "gradient: only scalar spacings are supported in this implementation",
264 ));
265 };
266
267 if !spacing.is_finite() {
268 return Err(gradient_error("gradient: spacing must be finite"));
269 }
270 if spacing == 0.0 {
271 return Err(gradient_error("gradient: spacing must be nonzero"));
272 }
273 Ok(spacing)
274}
275
276fn value_shape(value: &Value) -> &[usize] {
277 match value {
278 Value::Tensor(tensor) => &tensor.shape,
279 Value::LogicalArray(logical) => &logical.shape,
280 Value::ComplexTensor(tensor) => &tensor.shape,
281 Value::GpuTensor(handle) => &handle.shape,
282 _ => &[],
283 }
284}
285
286fn value_len(value: &Value) -> usize {
287 match value {
288 Value::Tensor(tensor) => tensor.data.len(),
289 Value::LogicalArray(logical) => logical.data.len(),
290 Value::ComplexTensor(tensor) => tensor.data.len(),
291 Value::GpuTensor(handle) => product(&handle.shape),
292 _ => 1,
293 }
294}
295
/// Normalize a raw shape to MATLAB's view of it for `gradient`:
/// scalars and length-1 shapes become 1x1, other 1-D shapes become row
/// vectors, and 2-D+ shapes pass through unchanged. `len` disambiguates the
/// empty-shape case (a shapeless non-empty value is treated as a scalar).
pub fn matlab_gradient_shape(shape: &[usize], len: usize) -> Vec<usize> {
    match shape {
        [] if len == 0 => Vec::new(),
        [] | [1] => vec![1, 1],
        [n] => vec![1, *n],
        _ => shape.to_vec(),
    }
}
313
314fn gradient_output_dims(shape: &[usize], len: usize) -> Vec<usize> {
315 let normalized_shape = matlab_gradient_shape(shape, len);
316 let mut ext_shape = if normalized_shape.is_empty() {
317 if len == 0 {
318 vec![0, 0]
319 } else {
320 vec![1, 1]
321 }
322 } else {
323 normalized_shape
324 };
325 if ext_shape.len() == 1 {
326 ext_shape.push(1);
327 }
328
329 if ext_shape.len() <= 2 {
330 let rows = ext_shape.first().copied().unwrap_or(1);
331 let cols = ext_shape.get(1).copied().unwrap_or(1);
332 if rows == 1 && cols == 1 {
333 vec![1]
334 } else if rows == 1 {
335 vec![2]
336 } else if cols == 1 {
337 vec![1]
338 } else {
339 vec![2, 1]
340 }
341 } else {
342 let mut dims = vec![2, 1];
343 for dim in 3..=ext_shape.len() {
344 dims.push(dim);
345 }
346 dims
347 }
348}
349
350pub fn gradient_real_tensor_host(
351 tensor: Tensor,
352 dim: usize,
353 spacing: f64,
354) -> BuiltinResult<Tensor> {
355 let Tensor {
356 data, shape, dtype, ..
357 } = tensor;
358 let dim_index = dim.saturating_sub(1);
359 let mut shape = matlab_gradient_shape(&shape, data.len());
360
361 if data.is_empty() {
362 let empty_shape = if shape.is_empty() { vec![0, 0] } else { shape };
367 return Tensor::new_with_dtype(Vec::new(), empty_shape, dtype)
368 .map_err(|e| gradient_error(format!("gradient: {e}")));
369 }
370
371 while shape.len() <= dim_index {
372 shape.push(1);
373 }
374
375 let mut ext_shape = shape.clone();
376 while ext_shape.len() <= dim_index {
377 ext_shape.push(1);
378 }
379 let len_dim = ext_shape[dim_index];
380 let stride_before = if dim_index == 0 {
381 1usize
382 } else {
383 product(&ext_shape[..dim_index]).max(1)
384 };
385 let stride_after = if dim_index + 1 >= ext_shape.len() {
386 1usize
387 } else {
388 product(&ext_shape[dim_index + 1..]).max(1)
389 };
390
391 let mut out = vec![0.0; data.len()];
392 if len_dim > 1 {
393 let block = stride_before
394 .checked_mul(len_dim)
395 .ok_or_else(|| gradient_error("gradient: block size overflow"))?;
396 for after in 0..stride_after {
397 let base = after
398 .checked_mul(block)
399 .ok_or_else(|| gradient_error("gradient: indexing overflow"))?;
400 for before in 0..stride_before {
401 for k in 0..len_dim {
402 let idx = base + before + k * stride_before;
403 out[idx] = if k == 0 {
404 (data[idx + stride_before] - data[idx]) / spacing
405 } else if k + 1 == len_dim {
406 (data[idx] - data[idx - stride_before]) / spacing
407 } else {
408 (data[idx + stride_before] - data[idx - stride_before]) / (2.0 * spacing)
409 };
410 }
411 }
412 }
413 }
414
415 Tensor::new_with_dtype(out, shape, dtype).map_err(|e| gradient_error(format!("gradient: {e}")))
416}
417
418pub fn gradient_complex_tensor_host(
419 tensor: ComplexTensor,
420 dim: usize,
421 spacing: f64,
422) -> BuiltinResult<ComplexTensor> {
423 let ComplexTensor { data, shape, .. } = tensor;
424 let dim_index = dim.saturating_sub(1);
425 let mut shape = matlab_gradient_shape(&shape, data.len());
426
427 if data.is_empty() {
428 let empty_shape = if shape.is_empty() { vec![0, 0] } else { shape };
431 return ComplexTensor::new(Vec::new(), empty_shape)
432 .map_err(|e| gradient_error(format!("gradient: {e}")));
433 }
434
435 while shape.len() <= dim_index {
436 shape.push(1);
437 }
438
439 let mut ext_shape = shape.clone();
440 while ext_shape.len() <= dim_index {
441 ext_shape.push(1);
442 }
443 let len_dim = ext_shape[dim_index];
444 let stride_before = if dim_index == 0 {
445 1usize
446 } else {
447 product(&ext_shape[..dim_index]).max(1)
448 };
449 let stride_after = if dim_index + 1 >= ext_shape.len() {
450 1usize
451 } else {
452 product(&ext_shape[dim_index + 1..]).max(1)
453 };
454
455 let mut out = vec![(0.0, 0.0); data.len()];
456 if len_dim > 1 {
457 let block = stride_before
458 .checked_mul(len_dim)
459 .ok_or_else(|| gradient_error("gradient: block size overflow"))?;
460 for after in 0..stride_after {
461 let base = after
462 .checked_mul(block)
463 .ok_or_else(|| gradient_error("gradient: indexing overflow"))?;
464 for before in 0..stride_before {
465 for k in 0..len_dim {
466 let idx = base + before + k * stride_before;
467 out[idx] = if k == 0 {
468 scale_complex(
469 sub_complex(data[idx + stride_before], data[idx]),
470 1.0 / spacing,
471 )
472 } else if k + 1 == len_dim {
473 scale_complex(
474 sub_complex(data[idx], data[idx - stride_before]),
475 1.0 / spacing,
476 )
477 } else {
478 scale_complex(
479 sub_complex(data[idx + stride_before], data[idx - stride_before]),
480 0.5 / spacing,
481 )
482 };
483 }
484 }
485 }
486 }
487
488 ComplexTensor::new(out, shape).map_err(|e| gradient_error(format!("gradient: {e}")))
489}
490
/// Componentwise subtraction of `(re, im)` pairs: `lhs - rhs`.
fn sub_complex(lhs: (f64, f64), rhs: (f64, f64)) -> (f64, f64) {
    let (lhs_re, lhs_im) = lhs;
    let (rhs_re, rhs_im) = rhs;
    (lhs_re - rhs_re, lhs_im - rhs_im)
}
494
/// Multiply both components of an `(re, im)` pair by a real scale factor.
fn scale_complex(value: (f64, f64), scale: f64) -> (f64, f64) {
    let (re, im) = value;
    (re * scale, im * scale)
}
498
/// Product of all dimensions, saturating at `usize::MAX` instead of
/// overflowing. The empty slice yields 1.
fn product(dims: &[usize]) -> usize {
    let mut total = 1usize;
    for &dim in dims {
        total = total.saturating_mul(dim);
    }
    total
}
504
// Unit tests: host-path coverage plus wgpu-gated checks that device results
// match the CPU implementation and stay GPU-resident.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{NumericDType, Tensor};

    // Synchronous shim over the async builtin for test convenience.
    fn gradient_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::gradient_builtin(value, rest))
    }

    // One-sided differences at the edges, central difference in the middle.
    #[test]
    fn gradient_row_vector_returns_horizontal_derivative() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
        assert_eq!(
            result,
            Value::Tensor(Tensor::new(vec![3.0, 4.0, 5.0], vec![1, 3]).unwrap())
        );
    }

    // 1-D shapes normalize to MATLAB row-vector output shape.
    #[test]
    fn gradient_one_dimensional_tensor_is_treated_as_row_vector() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![3]).unwrap();
        let result =
            gradient_builtin(Value::Tensor(tensor), vec![Value::Num(2.0)]).expect("gradient");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 3]);
                assert_eq!(out.data, vec![1.5, 2.0, 2.5]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Two requested outputs come back as [FX, FY]: x along dim 2, y along dim 1.
    #[test]
    fn gradient_matrix_outputs_follow_matlab_order() {
        let tensor = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let _guard = crate::output_count::push_output_count(Some(2));
        let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
        match result {
            Value::OutputList(outputs) => {
                let fx = test_support::gather(outputs[0].clone()).expect("fx");
                let fy = test_support::gather(outputs[1].clone()).expect("fy");
                assert_eq!(fx.data, vec![1.0, 1.0, 1.0, 1.0]);
                assert_eq!(fy.data, vec![2.0, 2.0, 2.0, 2.0]);
            }
            other => panic!("expected output list, got {other:?}"),
        }
    }

    // A scalar spacing argument divides every difference by that spacing.
    #[test]
    fn gradient_scalar_spacing_scales_output() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let result =
            gradient_builtin(Value::Tensor(tensor), vec![Value::Num(2.0)]).expect("gradient");
        match result {
            Value::Tensor(out) => assert_eq!(out.data, vec![1.5, 2.0, 2.5]),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Output dtype follows the input dtype (F32 in, F32 out).
    #[test]
    fn gradient_preserves_single_precision_host_tensor() {
        let tensor =
            Tensor::new_with_dtype(vec![1.0, 4.0, 9.0], vec![1, 3], NumericDType::F32).unwrap();
        let result = gradient_builtin(Value::Tensor(tensor), Vec::new()).expect("gradient");
        match result {
            Value::Tensor(out) => assert_eq!(out.dtype, NumericDType::F32),
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Complex inputs are differenced componentwise on the host.
    #[test]
    fn gradient_complex_host_supported() {
        let tensor =
            ComplexTensor::new(vec![(1.0, 1.0), (4.0, 3.0), (9.0, 6.0)], vec![1, 3]).unwrap();
        let result = gradient_builtin(Value::ComplexTensor(tensor), Vec::new()).expect("gradient");
        match result {
            Value::ComplexTensor(out) => {
                assert_eq!(out.data, vec![(3.0, 2.0), (4.0, 2.5), (5.0, 3.0)]);
            }
            other => panic!("expected complex tensor, got {other:?}"),
        }
    }

    // Coordinate-vector spacing is unsupported in v1 and must error clearly.
    #[test]
    fn gradient_rejects_coordinate_vector_spacing_in_v1() {
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0], vec![1, 3]).unwrap();
        let spacing = Tensor::new(vec![0.0, 1.0, 2.0], vec![1, 3]).unwrap();
        let err =
            gradient_builtin(Value::Tensor(tensor), vec![Value::Tensor(spacing)]).unwrap_err();
        assert!(err.message().contains("scalar"));
    }

    // A column vector supports only one output; asking for two must fail.
    #[test]
    fn gradient_rejects_too_many_outputs() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let _guard = crate::output_count::push_output_count(Some(2));
        let err = gradient_builtin(Value::Tensor(tensor), Vec::new()).unwrap_err();
        assert!(err.message().contains("requested 2 outputs"));
    }

    // Device path matches the CPU result, keeps residency, and preserves dtype.
    #[test]
    #[cfg(feature = "wgpu")]
    fn gradient_gpu_scalar_spacing_matches_cpu_and_stays_resident() {
        let _guard = test_support::accel_test_lock();
        // Skip silently when no wgpu adapter is available on this machine.
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };
        let host =
            Tensor::new_with_dtype(vec![1.0, 4.0, 9.0], vec![1, 3], NumericDType::F32).unwrap();
        let view = HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
        match result {
            Value::GpuTensor(out) => {
                let gathered = test_support::gather(Value::GpuTensor(out)).expect("gather");
                assert_eq!(gathered.data, vec![1.5, 2.0, 2.5]);
                assert_eq!(gathered.dtype, NumericDType::F32);
            }
            other => panic!("expected gpu tensor, got {other:?}"),
        }
    }

    // 1-D device tensors also normalize to MATLAB row-vector semantics.
    #[test]
    #[cfg(feature = "wgpu")]
    fn gradient_gpu_one_dimensional_shape_matches_matlab_row_vector_semantics() {
        let _guard = test_support::accel_test_lock();
        // Skip silently when no wgpu adapter is available on this machine.
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };
        let data = [1.0, 4.0, 9.0];
        let shape = [3usize];
        let view = HostTensorView {
            data: &data,
            shape: &shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            gradient_builtin(Value::GpuTensor(handle), vec![Value::Num(2.0)]).expect("gradient");
        let gathered = test_support::gather(result).expect("gather");
        assert_eq!(gathered.shape, vec![1, 3]);
        assert_eq!(gathered.data, vec![1.5, 2.0, 2.5]);
    }

    // Multi-output GPU requests return an output list of device handles.
    #[test]
    #[cfg(feature = "wgpu")]
    fn gradient_gpu_multi_output_uses_output_list() {
        let _guard = test_support::accel_test_lock();
        // Skip silently when no wgpu adapter is available on this machine.
        let Ok(provider) = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) else {
            return;
        };
        let host = Tensor::new(vec![1.0, 3.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let view = HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let _out_guard = crate::output_count::push_output_count(Some(2));
        let result = gradient_builtin(Value::GpuTensor(handle), Vec::new()).expect("gradient");
        match result {
            Value::OutputList(outputs) => {
                assert!(matches!(outputs[0], Value::GpuTensor(_)));
                assert!(matches!(outputs[1], Value::GpuTensor(_)));
            }
            other => panic!("expected output list, got {other:?}"),
        }
    }
}