// Source path: runmat_runtime/builtins/array/indexing/ind2sub.rs
//! MATLAB-compatible `ind2sub` builtin with GPU-aware semantics for RunMat.

use runmat_accelerate_api::HostTensorView;
use runmat_builtins::{
    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
    ResolveContext, Tensor, Type, Value,
};
use runmat_macros::runtime_builtin;

use super::common::{
    build_strides, dims_from_tokens, materialize_value, parse_dims, total_elements,
};
use crate::builtins::array::type_resolvers::size_vector_len;
use crate::builtins::common::arg_tokens::tokens_from_context;
use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::tensor;
use crate::{build_runtime_error, make_cell, RuntimeError};

23#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
24pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
25    name: "ind2sub",
26    op_kind: GpuOpKind::Custom("indexing"),
27    supported_precisions: &[ScalarType::F32, ScalarType::F64],
28    broadcast: BroadcastSemantics::Matlab,
29    provider_hooks: &[ProviderHook::Custom("ind2sub")],
30    constant_strategy: ConstantStrategy::InlineLiteral,
31    residency: ResidencyPolicy::NewHandle,
32    nan_mode: ReductionNaN::Include,
33    two_pass_threshold: None,
34    workgroup_size: None,
35    accepts_nan_mode: false,
36    notes: "WGPU provider executes `ind2sub` entirely on-device; other providers fall back to the host implementation and re-upload results to preserve residency.",
37};
38
39#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
40pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
41    name: "ind2sub",
42    shape: ShapeRequirements::Any,
43    constant_strategy: ConstantStrategy::InlineLiteral,
44    elementwise: None,
45    reduction: None,
46    emits_nan: false,
47    notes: "Index conversion is eager and does not participate in fusion today.",
48};
49
50fn ind2sub_type(args: &[Type], ctx: &ResolveContext) -> Type {
51    let Some(dims) = args.first() else {
52        return Type::Unknown;
53    };
54    let length = dims_from_tokens(&tokens_from_context(ctx))
55        .map(|values| values.len())
56        .or_else(|| size_vector_len(dims));
57    Type::Cell {
58        element_type: Some(Box::new(Type::tensor())),
59        length,
60    }
61}
62
63const BUILTIN_NAME: &str = "ind2sub";
64
65const IND2SUB_OUTPUT_CELL: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
66    name: "subs",
67    ty: BuiltinParamType::Any,
68    arity: BuiltinParamArity::Required,
69    default: None,
70    description: "Cell array containing one subscript output per dimension.",
71}];
72
73const IND2SUB_INPUTS: [BuiltinParamDescriptor; 2] = [
74    BuiltinParamDescriptor {
75        name: "sz",
76        ty: BuiltinParamType::SizeArg,
77        arity: BuiltinParamArity::Required,
78        default: None,
79        description: "Size vector describing source array dimensions.",
80    },
81    BuiltinParamDescriptor {
82        name: "ind",
83        ty: BuiltinParamType::Any,
84        arity: BuiltinParamArity::Required,
85        default: None,
86        description: "Linear indices to convert into per-dimension subscripts.",
87    },
88];
89
90const IND2SUB_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
91    label: "subs = ind2sub(sz, ind)",
92    inputs: &IND2SUB_INPUTS,
93    outputs: &IND2SUB_OUTPUT_CELL,
94}];
95
96const IND2SUB_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
97    code: "RM.IND2SUB.INVALID_INPUT",
98    identifier: Some("RunMat:ind2sub:InvalidInput"),
99    when: "Size vector or linear index inputs are malformed or unsupported.",
100    message: "ind2sub: invalid input arguments",
101};
102
103const IND2SUB_ERROR_INDEX_BOUNDS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
104    code: "RM.IND2SUB.INDEX_BOUNDS",
105    identifier: Some("RunMat:ind2sub:IndexBounds"),
106    when: "At least one provided linear index exceeds array element bounds.",
107    message: "ind2sub: index exceeds array bounds",
108};
109
110const IND2SUB_ERROR_PROVIDER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
111    code: "RM.IND2SUB.PROVIDER",
112    identifier: Some("RunMat:ind2sub:ProviderError"),
113    when: "Provider-side ind2sub execution fails or returns malformed outputs.",
114    message: "ind2sub: provider execution failed",
115};
116
117const IND2SUB_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
118    code: "RM.IND2SUB.INTERNAL",
119    identifier: Some("RunMat:ind2sub:InternalError"),
120    when: "Internal tensor/materialization logic fails while building outputs.",
121    message: "ind2sub: internal error",
122};
123
124const IND2SUB_ERRORS: [BuiltinErrorDescriptor; 4] = [
125    IND2SUB_ERROR_INVALID_INPUT,
126    IND2SUB_ERROR_INDEX_BOUNDS,
127    IND2SUB_ERROR_PROVIDER,
128    IND2SUB_ERROR_INTERNAL,
129];
130
131pub const IND2SUB_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
132    signatures: &IND2SUB_SIGNATURES,
133    output_mode: BuiltinOutputMode::Fixed,
134    completion_policy: BuiltinCompletionPolicy::Public,
135    errors: &IND2SUB_ERRORS,
136};
137
138fn ind2sub_error_with_message(
139    message: impl Into<String>,
140    error: &'static BuiltinErrorDescriptor,
141) -> RuntimeError {
142    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
143    if let Some(identifier) = error.identifier {
144        builder = builder.with_identifier(identifier);
145    }
146    builder.build()
147}
148
149fn ind2sub_input_error(message: impl Into<String>) -> RuntimeError {
150    ind2sub_error_with_message(message, &IND2SUB_ERROR_INVALID_INPUT)
151}
152
153fn ind2sub_internal_error(message: impl Into<String>) -> RuntimeError {
154    ind2sub_error_with_message(message, &IND2SUB_ERROR_INTERNAL)
155}
156
157fn ind2sub_provider_error(message: impl Into<String>) -> RuntimeError {
158    ind2sub_error_with_message(message, &IND2SUB_ERROR_PROVIDER)
159}
160
161#[runtime_builtin(
162    name = "ind2sub",
163    category = "array/indexing",
164    summary = "Convert linear indices to subscripts.",
165    keywords = "ind2sub,linear index,subscripts,column major,gpu indexing",
166    accel = "custom",
167    type_resolver(ind2sub_type),
168    descriptor(crate::builtins::array::indexing::ind2sub::IND2SUB_DESCRIPTOR),
169    builtin_path = "crate::builtins::array::indexing::ind2sub"
170)]
171async fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
172    let (dims_value, dims_was_gpu) = materialize_value(dims_val, "ind2sub").await?;
173    let dims = parse_dims(&dims_value, "ind2sub").await?;
174    if dims.is_empty() {
175        return Err(ind2sub_error("Size vector must have at least one element."));
176    }
177
178    let total = total_elements(&dims, "ind2sub")?;
179    let strides = build_strides(&dims, "ind2sub")?;
180
181    if let Some(result) = try_gpu_ind2sub(&dims, &strides, total, &indices_val)? {
182        return Ok(result);
183    }
184
185    let (indices_value, indices_was_gpu) = materialize_value(indices_val, "ind2sub").await?;
186    let indices_tensor = tensor::value_into_tensor_for("ind2sub", indices_value)
187        .map_err(|message| ind2sub_error(message))?;
188
189    let subscripts = compute_subscripts(&dims, total, &strides, &indices_tensor)?;
190
191    let want_gpu = (dims_was_gpu || indices_was_gpu) && runmat_accelerate_api::provider().is_some();
192
193    let mut outputs: Vec<Value> = Vec::with_capacity(dims.len());
194    for tensor in subscripts {
195        if want_gpu {
196            #[cfg(all(test, feature = "wgpu"))]
197            {
198                if runmat_accelerate_api::provider().is_none() {
199                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
200                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
201                    );
202                }
203            }
204            if let Some(provider) = runmat_accelerate_api::provider() {
205                let view = HostTensorView {
206                    data: &tensor.data,
207                    shape: &tensor.shape,
208                };
209                if let Ok(handle) = provider.upload(&view) {
210                    outputs.push(Value::GpuTensor(handle));
211                    continue;
212                }
213            }
214        }
215        outputs.push(tensor::tensor_into_value(tensor));
216    }
217
218    make_cell(outputs, 1, dims.len()).map_err(|message| ind2sub_error(message))
219}
220
221fn try_gpu_ind2sub(
222    dims: &[usize],
223    strides: &[usize],
224    total: usize,
225    indices: &Value,
226) -> crate::BuiltinResult<Option<Value>> {
227    #[cfg(target_arch = "wasm32")]
228    {
229        let _ = (dims, strides, total, indices);
230        Ok(None)
231    }
232    #[cfg(not(target_arch = "wasm32"))]
233    {
234        #[cfg(all(test, feature = "wgpu"))]
235        {
236            if let Value::GpuTensor(h) = indices {
237                if h.device_id != 0 {
238                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
239                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
240                    );
241                }
242            }
243        }
244        let provider = match runmat_accelerate_api::provider() {
245            Some(p) => p,
246            None => return Ok(None),
247        };
248        if !provider.supports_ind2sub() {
249            return Ok(None);
250        }
251        let handle = match indices {
252            Value::GpuTensor(handle) => handle,
253            _ => return Ok(None),
254        };
255        if dims.len() != strides.len() {
256            return Err(ind2sub_error("Size vector must have at least one element."));
257        }
258        if dims.iter().any(|&d| d > u32::MAX as usize)
259            || strides.iter().any(|&s| s > u32::MAX as usize)
260            || total > u32::MAX as usize
261        {
262            return Ok(None);
263        }
264        let len = if handle.shape.is_empty() {
265            1usize
266        } else {
267            handle.shape.iter().copied().product()
268        };
269        if total == 0 && len > 0 {
270            return Err(ind2sub_error(
271                "Index exceeds number of array elements. Index must not exceed 0.",
272            ));
273        }
274        if len > u32::MAX as usize {
275            return Ok(None);
276        }
277        let output_shape = if handle.shape.is_empty() {
278            vec![len, 1]
279        } else {
280            handle.shape.clone()
281        };
282        match provider.ind2sub(dims, strides, handle, total, len, &output_shape) {
283            Ok(handles) => {
284                if handles.len() != dims.len() {
285                    return Err(ind2sub_provider_error(
286                        "ind2sub: provider returned an unexpected number of outputs.",
287                    ));
288                }
289                let values: Vec<Value> = handles.into_iter().map(Value::GpuTensor).collect();
290                make_cell(values, 1, dims.len())
291                    .map(Some)
292                    .map_err(|message| ind2sub_error(message))
293            }
294            Err(err) => Err(ind2sub_provider_error(err.to_string())),
295        }
296    }
297}
298
299fn compute_subscripts(
300    dims: &[usize],
301    total: usize,
302    strides: &[usize],
303    indices: &Tensor,
304) -> crate::BuiltinResult<Vec<Tensor>> {
305    if strides.len() != dims.len() {
306        return Err(ind2sub_error("Size vector must have at least one element."));
307    }
308
309    let len = indices.data.len();
310    let mut outputs: Vec<Vec<f64>> = dims.iter().map(|_| Vec::with_capacity(len)).collect();
311
312    for &value in &indices.data {
313        let idx = coerce_linear_index(value, total)?;
314        let zero_based = idx - 1;
315        for (dim_index, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
316            let coord = ((zero_based / stride) % dim) + 1;
317            outputs[dim_index].push(coord as f64);
318        }
319    }
320
321    let output_shape = if indices.shape.is_empty() {
322        vec![len, 1]
323    } else {
324        indices.shape.clone()
325    };
326
327    let mut tensors = Vec::with_capacity(dims.len());
328    for data in outputs {
329        let tensor = Tensor::new(data, output_shape.clone())
330            .map_err(|e| ind2sub_internal_error(format!("ind2sub: {e}")))?;
331        tensors.push(tensor);
332    }
333    Ok(tensors)
334}
335
336fn coerce_linear_index(value: f64, max_index: usize) -> crate::BuiltinResult<usize> {
337    if !value.is_finite() {
338        return Err(ind2sub_error("Linear indices must be positive integers."));
339    }
340    let rounded = value.round();
341    if (rounded - value).abs() > f64::EPSILON {
342        return Err(ind2sub_error("Linear indices must be positive integers."));
343    }
344    if rounded < 1.0 {
345        return Err(ind2sub_error("Linear indices must be positive integers."));
346    }
347    if rounded > usize::MAX as f64 {
348        return Err(ind2sub_error(
349            "Index exceeds maximum supported size for this platform.",
350        ));
351    }
352    let coerced = rounded as usize;
353    if coerced > max_index {
354        return Err(ind2sub_error_with_message(
355            format!(
356                "Index exceeds number of array elements. Index must not exceed {}.",
357                max_index
358            ),
359            &IND2SUB_ERROR_INDEX_BOUNDS,
360        ));
361    }
362    Ok(coerced)
363}
364
365fn ind2sub_error(message: impl Into<String>) -> RuntimeError {
366    ind2sub_input_error(message)
367}
368
#[cfg(test)]
pub(crate) mod tests {
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Tensor, Type, Value};

    /// Drive the async builtin to completion synchronously for assertions.
    fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
        block_on(super::ind2sub_builtin(dims_val, indices_val))
    }

    /// Clone every element out of a cell so tests can compare values directly.
    fn cell_to_vec(cell: &runmat_builtins::CellArray) -> Vec<Value> {
        cell.data.iter().map(|ptr| (**ptr).clone()).collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_tensor_indices() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = ind2sub_builtin(Value::Tensor(dims), Value::Num(8.0)).unwrap();
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                assert_eq!(values[0], Value::Num(2.0));
                assert_eq!(values[1], Value::Num(3.0));
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    #[test]
    fn ind2sub_type_infers_cell_length() {
        let dims = Type::Tensor {
            shape: Some(vec![Some(1), Some(3)]),
        };
        assert_eq!(
            super::ind2sub_type(&[dims, Type::Num], &ResolveContext::new(Vec::new())),
            Type::Cell {
                element_type: Some(Box::new(Type::tensor())),
                length: Some(3)
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_vector_indices() {
        let dims = Tensor::new(vec![3.0, 5.0], vec![1, 2]).unwrap();
        let idx = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                match &values[0] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                    }
                    other => panic!("expected tensor rows, got {other:?}"),
                }
                match &values[1] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![3.0, 3.0, 3.0]);
                    }
                    other => panic!("expected tensor cols, got {other:?}"),
                }
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_integer_linear_index_identifier() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(1.25))
            .expect_err("expected non-integer index error");
        assert_eq!(
            err.identifier(),
            super::IND2SUB_ERROR_INVALID_INPUT.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_out_of_bounds_linear_index_identifier() {
        let dims = Tensor::new(vec![2.0, 2.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(9.0))
            .expect_err("expected out-of-bounds index error");
        assert_eq!(
            err.identifier(),
            super::IND2SUB_ERROR_INDEX_BOUNDS.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_three_dimensional_indices() {
        let dims = Tensor::new(vec![2.0, 3.0, 4.0], vec![1, 3]).unwrap();
        let idx = Tensor::new(vec![3.0, 11.0], vec![1, 2]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        if let Value::Cell(cell) = result {
            let values = cell_to_vec(&cell);
            assert_eq!(values.len(), 3);
            assert_eq!(
                values[0],
                Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[1],
                Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[2],
                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap())
            );
        } else {
            panic!("expected cell output");
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_out_of_range_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(13.0)).expect_err("expected failure");
        assert!(
            err.message()
                .contains("Index exceeds number of array elements"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_zero_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(0.0)).expect_err("expected failure");
        // Inspect the message text, consistent with the other message-based tests.
        assert!(
            err.message()
                .contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_fractional_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(2.5)).expect_err("expected failure");
        assert!(
            err.message()
                .contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_invalid_size_elements() {
        let dims = Tensor::new(vec![3.5, 4.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(5.0)).expect_err("expected fail");
        assert!(
            err.to_string()
                .contains("Size arguments must be positive integers"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ind2sub_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
            let idx_tensor = Tensor::new(vec![10.0, 11.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &idx_tensor.data,
                shape: &idx_tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload indices");
            let result = ind2sub_builtin(Value::Tensor(dims), Value::GpuTensor(handle)).unwrap();
            match result {
                Value::Cell(cell) => {
                    let values = cell_to_vec(&cell);
                    assert_eq!(values.len(), 2);
                    match &values[0] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    match &values[1] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    let rows = test_support::gather(values[0].clone()).expect("gather rows");
                    assert_eq!(rows.shape, vec![2, 1]);
                    assert_eq!(rows.data, vec![1.0, 2.0]);
                    let cols = test_support::gather(values[1].clone()).expect("gather cols");
                    assert_eq!(cols.shape, vec![2, 1]);
                    assert_eq!(cols.data, vec![4.0, 4.0]);
                }
                other => panic!("expected cell output, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn ind2sub_wgpu_matches_cpu() {
        // Skip quietly when no WGPU adapter is available (registration fails or panics).
        let provider_init = std::panic::catch_unwind(|| {
            runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            )
        });
        if !matches!(provider_init, Ok(Ok(_))) {
            return;
        }

        let dims_tensor = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let idx_tensor = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();

        let cpu = ind2sub_builtin(
            Value::Tensor(dims_tensor.clone()),
            Value::Tensor(idx_tensor.clone()),
        )
        .expect("cpu ind2sub");

        let provider = runmat_accelerate_api::provider().unwrap();
        let view = HostTensorView {
            data: &idx_tensor.data,
            shape: &idx_tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload indices");

        let gpu = ind2sub_builtin(Value::Tensor(dims_tensor), Value::GpuTensor(handle))
            .expect("gpu ind2sub");

        let cpu_values = match cpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };
        let gpu_values = match gpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };

        assert_eq!(cpu_values.len(), gpu_values.len());

        for (cpu_val, gpu_val) in cpu_values.iter().zip(gpu_values.iter()) {
            let host_cpu = test_support::gather(cpu_val.clone()).expect("gather cpu");
            let host_gpu = test_support::gather(gpu_val.clone()).expect("gather gpu");
            assert_eq!(host_cpu.shape, host_gpu.shape);
            assert_eq!(host_cpu.data, host_gpu.data);
        }
    }
}