Skip to main content

runmat_runtime/builtins/array/indexing/
sub2ind.rs

1//! MATLAB-compatible `sub2ind` builtin with GPU-aware semantics for RunMat.
2
3#[cfg(not(target_arch = "wasm32"))]
4use runmat_accelerate_api::GpuTensorHandle;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{
7    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9    ResolveContext, Tensor, Type, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use super::common::{build_strides, dims_from_tokens, materialize_value, parse_dims};
14use crate::builtins::array::type_resolvers::is_scalar_type;
15use crate::builtins::common::arg_tokens::tokens_from_context;
16use crate::builtins::common::spec::{
17    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::tensor;
21use crate::{build_runtime_error, RuntimeError};
22use runmat_builtins::shape_rules::element_count_if_known;
23
24#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
25pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
26    name: "sub2ind",
27    op_kind: GpuOpKind::Custom("indexing"),
28    supported_precisions: &[ScalarType::F32, ScalarType::F64],
29    broadcast: BroadcastSemantics::Matlab,
30    provider_hooks: &[ProviderHook::Custom("sub2ind")],
31    constant_strategy: ConstantStrategy::InlineLiteral,
32    residency: ResidencyPolicy::NewHandle,
33    nan_mode: ReductionNaN::Include,
34    two_pass_threshold: None,
35    workgroup_size: None,
36    accepts_nan_mode: false,
37    notes: "Providers can implement the custom `sub2ind` hook to execute on device; runtimes fall back to host computation otherwise.",
38};
39
40#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42    name: "sub2ind",
43    shape: ShapeRequirements::Any,
44    constant_strategy: ConstantStrategy::InlineLiteral,
45    elementwise: None,
46    reduction: None,
47    emits_nan: false,
48    notes: "Index conversion executes eagerly on the host; fusion does not apply.",
49};
50
51fn sub2ind_type(args: &[Type], ctx: &ResolveContext) -> Type {
52    if args.len() < 2 {
53        return Type::Unknown;
54    }
55    if let Some(dims) = dims_from_tokens(&tokens_from_context(ctx)) {
56        if args.len() - 1 != dims.len() {
57            return Type::Unknown;
58        }
59    }
60    let subscripts = &args[1..];
61    if subscripts.iter().all(|ty| is_scalar_type(ty)) {
62        return Type::Num;
63    }
64    for ty in subscripts {
65        if let Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } = ty {
66            if element_count_if_known(shape).unwrap_or(0) > 1 {
67                return Type::Tensor {
68                    shape: Some(shape.clone()),
69                };
70            }
71        }
72    }
73    Type::tensor()
74}
75
76const BUILTIN_NAME: &str = "sub2ind";
77
78const SUB2IND_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
79    name: "ind",
80    ty: BuiltinParamType::NumericArray,
81    arity: BuiltinParamArity::Required,
82    default: None,
83    description: "Column-major linear indices corresponding to provided subscripts.",
84}];
85
86const SUB2IND_INPUTS: [BuiltinParamDescriptor; 3] = [
87    BuiltinParamDescriptor {
88        name: "sz",
89        ty: BuiltinParamType::SizeArg,
90        arity: BuiltinParamArity::Required,
91        default: None,
92        description: "Size vector describing source array dimensions.",
93    },
94    BuiltinParamDescriptor {
95        name: "I1",
96        ty: BuiltinParamType::Any,
97        arity: BuiltinParamArity::Required,
98        default: None,
99        description: "First-dimension subscript values.",
100    },
101    BuiltinParamDescriptor {
102        name: "In",
103        ty: BuiltinParamType::Any,
104        arity: BuiltinParamArity::Variadic,
105        default: None,
106        description: "Remaining per-dimension subscript arrays/scalars.",
107    },
108];
109
110const SUB2IND_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
111    label: "ind = sub2ind(sz, I1, In...)",
112    inputs: &SUB2IND_INPUTS,
113    outputs: &SUB2IND_OUTPUT,
114}];
115
116const SUB2IND_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
117    code: "RM.SUB2IND.INVALID_INPUT",
118    identifier: Some("RunMat:sub2ind:InvalidInput"),
119    when: "Size vector, subscript count, or subscript types are invalid.",
120    message: "sub2ind: invalid input arguments",
121};
122
123const SUB2IND_ERROR_INDEX_BOUNDS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
124    code: "RM.SUB2IND.INDEX_BOUNDS",
125    identifier: Some("RunMat:sub2ind:IndexBounds"),
126    when: "At least one subscript lies outside bounds for its dimension.",
127    message: "sub2ind: subscript index exceeds dimension bounds",
128};
129
130const SUB2IND_ERROR_PROVIDER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
131    code: "RM.SUB2IND.PROVIDER",
132    identifier: Some("RunMat:sub2ind:ProviderError"),
133    when: "GPU provider sub2ind hook fails.",
134    message: "sub2ind: provider execution failed",
135};
136
137const SUB2IND_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
138    code: "RM.SUB2IND.INTERNAL",
139    identifier: Some("RunMat:sub2ind:InternalError"),
140    when: "Internal tensor conversion/output construction fails.",
141    message: "sub2ind: internal error",
142};
143
144const SUB2IND_ERRORS: [BuiltinErrorDescriptor; 4] = [
145    SUB2IND_ERROR_INVALID_INPUT,
146    SUB2IND_ERROR_INDEX_BOUNDS,
147    SUB2IND_ERROR_PROVIDER,
148    SUB2IND_ERROR_INTERNAL,
149];
150
151pub const SUB2IND_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152    signatures: &SUB2IND_SIGNATURES,
153    output_mode: BuiltinOutputMode::Fixed,
154    completion_policy: BuiltinCompletionPolicy::Public,
155    errors: &SUB2IND_ERRORS,
156};
157
158fn sub2ind_error_with_message(
159    message: impl Into<String>,
160    error: &'static BuiltinErrorDescriptor,
161) -> RuntimeError {
162    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
163    if let Some(identifier) = error.identifier {
164        builder = builder.with_identifier(identifier);
165    }
166    builder.build()
167}
168
169fn sub2ind_input_error(message: impl Into<String>) -> RuntimeError {
170    sub2ind_error_with_message(message, &SUB2IND_ERROR_INVALID_INPUT)
171}
172
173fn sub2ind_bounds_error(message: impl Into<String>) -> RuntimeError {
174    sub2ind_error_with_message(message, &SUB2IND_ERROR_INDEX_BOUNDS)
175}
176
177fn sub2ind_provider_error(message: impl Into<String>) -> RuntimeError {
178    sub2ind_error_with_message(message, &SUB2IND_ERROR_PROVIDER)
179}
180
181fn sub2ind_internal_error(message: impl Into<String>) -> RuntimeError {
182    sub2ind_error_with_message(message, &SUB2IND_ERROR_INTERNAL)
183}
184
185#[runtime_builtin(
186    name = "sub2ind",
187    category = "array/indexing",
188    summary = "Convert N-D subscripts to MATLAB-style column-major linear indices.",
189    keywords = "sub2ind,linear index,column major,gpu indexing",
190    accel = "custom",
191    type_resolver(sub2ind_type),
192    descriptor(crate::builtins::array::indexing::sub2ind::SUB2IND_DESCRIPTOR),
193    builtin_path = "crate::builtins::array::indexing::sub2ind"
194)]
195async fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
196    let (dims_value, dims_was_gpu) = materialize_value(dims_val, "sub2ind").await?;
197    let dims = parse_dims(&dims_value, "sub2ind").await?;
198    if dims.is_empty() {
199        return Err(sub2ind_error("Size vector must have at least one element."));
200    }
201
202    if rest.len() != dims.len() {
203        return Err(sub2ind_error(
204            "The number of subscripts supplied must equal the number of dimensions in the size vector.",
205        ));
206    }
207
208    if let Some(value) = try_gpu_sub2ind(&dims, &rest)? {
209        return Ok(value);
210    }
211
212    let mut saw_gpu = dims_was_gpu;
213    let mut subscripts: Vec<Tensor> = Vec::with_capacity(rest.len());
214    for value in rest {
215        let (materialised, was_gpu) = materialize_value(value, "sub2ind").await?;
216        saw_gpu |= was_gpu;
217        let tensor = tensor::value_into_tensor_for("sub2ind", materialised)
218            .map_err(|message| sub2ind_error(message))?;
219        subscripts.push(tensor);
220    }
221
222    let (result_data, result_shape) = compute_indices(&dims, &subscripts)?;
223    let want_gpu_output = saw_gpu && runmat_accelerate_api::provider().is_some();
224
225    if want_gpu_output {
226        #[cfg(all(test, feature = "wgpu"))]
227        {
228            if runmat_accelerate_api::provider().is_none() {
229                let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
230                    runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
231                );
232            }
233        }
234        let shape = result_shape.clone().unwrap_or_else(|| vec![1, 1]);
235        if let Some(provider) = runmat_accelerate_api::provider() {
236            let view = HostTensorView {
237                data: &result_data,
238                shape: &shape,
239            };
240            if let Ok(handle) = provider.upload(&view) {
241                return Ok(Value::GpuTensor(handle));
242            }
243        }
244    }
245
246    build_host_value(result_data, result_shape)
247}
248
249fn try_gpu_sub2ind(dims: &[usize], subs: &[Value]) -> crate::BuiltinResult<Option<Value>> {
250    #[cfg(target_arch = "wasm32")]
251    {
252        let _ = (dims, subs);
253        Ok(None)
254    }
255    #[cfg(not(target_arch = "wasm32"))]
256    {
257        #[cfg(all(test, feature = "wgpu"))]
258        {
259            if subs
260                .iter()
261                .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
262            {
263                let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
264                    runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
265                );
266            }
267        }
268        let provider = match runmat_accelerate_api::provider() {
269            Some(p) => p,
270            None => return Ok(None),
271        };
272        if !subs
273            .iter()
274            .all(|value| matches!(value, Value::GpuTensor(_)))
275        {
276            return Ok(None);
277        }
278        if dims.is_empty() {
279            return Ok(None);
280        }
281
282        let mut handles: Vec<&GpuTensorHandle> = Vec::with_capacity(subs.len());
283        for value in subs {
284            if let Value::GpuTensor(handle) = value {
285                handles.push(handle);
286            }
287        }
288
289        if handles.len() != dims.len() {
290            return Err(sub2ind_error(
291            "The number of subscripts supplied must equal the number of dimensions in the size vector.",
292        ));
293        }
294
295        let mut scalar_mask: Vec<bool> = Vec::with_capacity(handles.len());
296        let mut target_shape: Option<Vec<usize>> = None;
297        let mut result_len: usize = 1;
298        let mut saw_non_scalar = false;
299
300        for handle in &handles {
301            let len = tensor::element_count(&handle.shape);
302            let is_scalar = len == 1;
303            scalar_mask.push(is_scalar);
304            if !is_scalar {
305                saw_non_scalar = true;
306                if let Some(existing) = &target_shape {
307                    if existing != &handle.shape {
308                        return Err(sub2ind_error("Subscript inputs must have the same size."));
309                    }
310                } else {
311                    target_shape = Some(handle.shape.clone());
312                    result_len = len;
313                }
314            }
315        }
316
317        if !saw_non_scalar {
318            target_shape = Some(vec![1, 1]);
319            result_len = 1;
320        } else if let Some(shape) = &target_shape {
321            result_len = tensor::element_count(shape);
322        }
323
324        let strides = build_strides(dims, "sub2ind")?;
325        if dims.iter().any(|&d| d > u32::MAX as usize)
326            || strides.iter().any(|&s| s > u32::MAX as usize)
327            || result_len > u32::MAX as usize
328        {
329            return Ok(None);
330        }
331
332        let output_shape = target_shape.clone().unwrap_or_else(|| vec![1, 1]);
333        match provider.sub2ind(
334            dims,
335            &strides,
336            &handles,
337            &scalar_mask,
338            result_len,
339            &output_shape,
340        ) {
341            Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
342            Err(err) => Err(sub2ind_provider_error(err.to_string())),
343        }
344    }
345}
346
347fn compute_indices(
348    dims: &[usize],
349    subscripts: &[Tensor],
350) -> crate::BuiltinResult<(Vec<f64>, Option<Vec<usize>>)> {
351    let mut target_shape: Option<Vec<usize>> = None;
352    let mut result_len: usize = 1;
353    let mut has_non_scalar = false;
354
355    for tensor in subscripts {
356        if tensor.data.len() != 1 {
357            has_non_scalar = true;
358            if let Some(shape) = &target_shape {
359                if &tensor.shape != shape {
360                    return Err(sub2ind_error("Subscript inputs must have the same size."));
361                }
362            } else {
363                target_shape = Some(tensor.shape.clone());
364                result_len = tensor.data.len();
365            }
366        }
367    }
368
369    if !has_non_scalar {
370        // All scalars -> scalar output
371        target_shape = Some(vec![1, 1]);
372        result_len = 1;
373    }
374
375    if result_len == 0 {
376        return Ok((Vec::new(), target_shape));
377    }
378
379    let strides = build_strides(dims, "sub2ind")?;
380    let mut output = Vec::with_capacity(result_len);
381
382    for idx in 0..result_len {
383        let mut offset: usize = 0;
384        for (dim_index, (&dim, tensor)) in dims.iter().zip(subscripts.iter()).enumerate() {
385            let raw = subscript_value(tensor, idx);
386            let coerced = coerce_subscript(raw, dim_index + 1, dim)?;
387            let term = coerced
388                .checked_sub(1)
389                .and_then(|v| v.checked_mul(strides[dim_index]))
390                .ok_or_else(|| sub2ind_bounds_error("Index exceeds array dimensions."))?;
391            offset = offset
392                .checked_add(term)
393                .ok_or_else(|| sub2ind_bounds_error("Index exceeds array dimensions."))?;
394        }
395        output.push((offset + 1) as f64);
396    }
397
398    Ok((output, target_shape))
399}
400
401fn subscript_value(tensor: &Tensor, idx: usize) -> f64 {
402    if tensor.data.len() == 1 {
403        tensor.data[0]
404    } else {
405        tensor.data[idx]
406    }
407}
408
409fn coerce_subscript(value: f64, dim_number: usize, dim_size: usize) -> crate::BuiltinResult<usize> {
410    if !value.is_finite() {
411        return Err(sub2ind_error(
412            "Subscript indices must either be real positive integers or logicals.",
413        ));
414    }
415    let rounded = value.round();
416    if (rounded - value).abs() > f64::EPSILON {
417        return Err(sub2ind_error(
418            "Subscript indices must either be real positive integers or logicals.",
419        ));
420    }
421    if rounded < 1.0 {
422        return Err(sub2ind_error(
423            "Subscript indices must either be real positive integers or logicals.",
424        ));
425    }
426    if rounded > dim_size as f64 {
427        return Err(dimension_bounds_error(dim_number));
428    }
429    Ok(rounded as usize)
430}
431
432fn dimension_bounds_error(dim_number: usize) -> RuntimeError {
433    let message = match dim_number {
434        1 => format!("Index exceeds the number of rows in dimension {dim_number}."),
435        2 => format!("Index exceeds the number of columns in dimension {dim_number}."),
436        3 => format!("Index exceeds the number of pages in dimension {dim_number}."),
437        _ => "Index exceeds array dimensions.".to_string(),
438    };
439    sub2ind_bounds_error(message)
440}
441
442fn build_host_value(data: Vec<f64>, shape: Option<Vec<usize>>) -> crate::BuiltinResult<Value> {
443    let shape = shape.unwrap_or_else(|| vec![1, 1]);
444    if data.len() == 1 && tensor::element_count(&shape) == 1 {
445        Ok(Value::Num(data[0]))
446    } else {
447        let tensor = Tensor::new(data, shape).map_err(|e| {
448            sub2ind_internal_error(format!("Unable to construct sub2ind output: {e}"))
449        })?;
450        Ok(Value::Tensor(tensor))
451    }
452}
453
454fn sub2ind_error(message: impl Into<String>) -> RuntimeError {
455    sub2ind_input_error(message)
456}
457
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor, Type, Value};

    /// Synchronous wrapper that drives the async builtin to completion.
    fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::sub2ind_builtin(dims_val, rest))
    }

    /// `1 x n` tensor fixture.
    fn row(values: &[f64]) -> Tensor {
        Tensor::new(values.to_vec(), vec![1, values.len()]).unwrap()
    }

    /// `n x 1` tensor fixture.
    fn column(values: &[f64]) -> Tensor {
        Tensor::new(values.to_vec(), vec![values.len(), 1]).unwrap()
    }

    /// Unwraps a tensor result, panicking on any other value kind.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn converts_scalar_indices() {
        let subscripts = vec![Value::Num(2.0), Value::Num(3.0)];
        let result = sub2ind_builtin(Value::Tensor(row(&[3.0, 4.0])), subscripts).unwrap();
        assert_eq!(result, Value::Num(8.0));
    }

    #[test]
    fn sub2ind_type_scalar_outputs_num() {
        let args = [Type::Tensor { shape: None }, Type::Num, Type::Int];
        let resolved = sub2ind_type(&args, &ResolveContext::new(Vec::new()));
        assert_eq!(resolved, Type::Num);
    }

    #[test]
    fn sub2ind_type_vector_outputs_tensor() {
        let subs = Type::Tensor {
            shape: Some(vec![Some(3), Some(1)]),
        };
        let args = [Type::Tensor { shape: None }, subs.clone(), Type::Num];
        let resolved = sub2ind_type(&args, &ResolveContext::new(Vec::new()));
        assert_eq!(
            resolved,
            Type::Tensor {
                shape: Some(vec![Some(3), Some(1)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcasts_scalars_over_vectors() {
        let rows = column(&[1.0, 2.0, 3.0]);
        let result = sub2ind_builtin(
            Value::Tensor(row(&[3.0, 4.0])),
            vec![Value::Tensor(rows), Value::Num(4.0)],
        )
        .unwrap();
        let t = expect_tensor(result);
        assert_eq!(t.shape, vec![3, 1]);
        assert_eq!(t.data, vec![10.0, 11.0, 12.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_three_dimensions() {
        let subscripts = vec![
            Value::Tensor(row(&[1.0, 1.0])),
            Value::Tensor(row(&[2.0, 3.0])),
            Value::Tensor(row(&[1.0, 2.0])),
        ];
        let result = sub2ind_builtin(Value::Tensor(row(&[2.0, 3.0, 4.0])), subscripts).unwrap();
        let t = expect_tensor(result);
        assert_eq!(t.shape, vec![1, 2]);
        assert_eq!(t.data, vec![3.0, 11.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_out_of_range_subscripts() {
        let err = sub2ind_builtin(
            Value::Tensor(row(&[3.0, 4.0])),
            vec![Value::Num(4.0), Value::Num(1.0)],
        )
        .unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("Index exceeds"),
            "expected index bounds error, got {err}"
        );
        assert_eq!(
            err.identifier(),
            super::SUB2IND_ERROR_INDEX_BOUNDS.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_shape_mismatch() {
        let rows = column(&[1.0, 2.0]);
        let cols = column(&[1.0, 2.0, 3.0]);
        let err = sub2ind_builtin(
            Value::Tensor(row(&[3.0, 4.0])),
            vec![Value::Tensor(rows), Value::Tensor(cols)],
        )
        .unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("same size"),
            "expected size mismatch error, got {err}"
        );
        assert_eq!(
            err.identifier(),
            super::SUB2IND_ERROR_INVALID_INPUT.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_integer_subscripts() {
        let err = sub2ind_builtin(
            Value::Tensor(row(&[3.0, 4.0])),
            vec![Value::Num(1.5), Value::Num(1.0)],
        )
        .unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("real positive integers"),
            "expected integer coercion error, got {err}"
        );
        assert_eq!(
            err.identifier(),
            super::SUB2IND_ERROR_INVALID_INPUT.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accepts_integer_value_variants() {
        let dims = Value::Tensor(row(&[3.0]));
        let result = sub2ind_builtin(dims, vec![Value::Int(IntValue::I32(2))]).expect("sub2ind");
        assert_eq!(result, Value::Num(2.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sub2ind_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let dims = row(&[3.0, 4.0]);
            let rows = column(&[1.0, 2.0, 3.0]);
            let cols = column(&[4.0, 4.0, 4.0]);

            // Upload the size vector and both subscripts, in that order.
            let [dims_handle, rows_handle, cols_handle] = [
                (&dims, "upload dims"),
                (&rows, "upload rows"),
                (&cols, "upload cols"),
            ]
            .map(|(tensor, what)| {
                provider
                    .upload(&HostTensorView {
                        data: &tensor.data,
                        shape: &tensor.shape,
                    })
                    .expect(what)
            });

            let result = sub2ind_builtin(
                Value::GpuTensor(dims_handle),
                vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
            )
            .expect("sub2ind");

            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).unwrap();
                    assert_eq!(gathered.shape, vec![3, 1]);
                    assert_eq!(gathered.data, vec![10.0, 11.0, 12.0]);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sub2ind_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let Some(provider) = runmat_accelerate_api::provider() else {
            panic!("wgpu provider not available");
        };

        let dims = row(&[3.0, 4.0]);
        let rows = column(&[1.0, 2.0, 3.0]);
        let cols = column(&[4.0, 4.0, 4.0]);

        // Host reference result.
        let cpu = sub2ind_builtin(
            Value::Tensor(dims.clone()),
            vec![Value::Tensor(rows.clone()), Value::Tensor(cols.clone())],
        )
        .expect("cpu sub2ind");

        let [rows_handle, cols_handle] =
            [(&rows, "upload rows"), (&cols, "upload cols")].map(|(tensor, what)| {
                provider
                    .upload(&HostTensorView {
                        data: &tensor.data,
                        shape: &tensor.shape,
                    })
                    .expect(what)
            });

        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
        )
        .expect("wgpu sub2ind");

        let gathered = test_support::gather(result).expect("gather");
        let expected = match cpu {
            Value::Tensor(t) => t,
            Value::Num(v) => Tensor::new(vec![v], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        assert_eq!(gathered.data, expected.data);
    }
}