// runmat_runtime/builtins/array/indexing/ind2sub.rs
//! MATLAB-compatible `ind2sub` builtin with GPU-aware semantics for RunMat.
2
3use runmat_accelerate_api::HostTensorView;
4use runmat_builtins::{ResolveContext, Tensor, Type, Value};
5use runmat_macros::runtime_builtin;
6
7use super::common::{
8    build_strides, dims_from_tokens, materialize_value, parse_dims, total_elements,
9};
10use crate::builtins::array::type_resolvers::size_vector_len;
11use crate::builtins::common::arg_tokens::tokens_from_context;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17use crate::{build_runtime_error, make_cell, RuntimeError};
18
// GPU execution spec registered for `ind2sub`.
//
// Declares a custom "indexing" op for f32/f64 inputs with MATLAB broadcast
// semantics. Providers opt in through the `Custom("ind2sub")` hook, and
// results are produced as fresh device handles (`ResidencyPolicy::NewHandle`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ind2sub",
    op_kind: GpuOpKind::Custom("indexing"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("ind2sub")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    // NaN mode is not consumed by this op (`accepts_nan_mode: false` below);
    // non-finite indices are rejected during validation instead.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes `ind2sub` entirely on-device; other providers fall back to the host implementation and re-upload results to preserve residency.",
};
34
// Fusion spec registered for `ind2sub`: the builtin executes eagerly and is
// opted out of both elementwise and reduction fusion (`None` for each).
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ind2sub",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Index conversion is eager and does not participate in fusion today.",
};
45
46fn ind2sub_type(args: &[Type], ctx: &ResolveContext) -> Type {
47    let Some(dims) = args.first() else {
48        return Type::Unknown;
49    };
50    let length = dims_from_tokens(&tokens_from_context(ctx))
51        .map(|values| values.len())
52        .or_else(|| size_vector_len(dims));
53    Type::Cell {
54        element_type: Some(Box::new(Type::tensor())),
55        length,
56    }
57}
58
// `ind2sub` entry point: converts MATLAB column-major linear indices into a
// 1-by-N cell of per-dimension subscript arrays, preserving GPU residency of
// the inputs where possible.
#[runtime_builtin(
    name = "ind2sub",
    category = "array/indexing",
    summary = "Convert MATLAB column-major linear indices into per-dimension subscript arrays.",
    keywords = "ind2sub,linear index,subscripts,column major,gpu indexing",
    accel = "custom",
    type_resolver(ind2sub_type),
    builtin_path = "crate::builtins::array::indexing::ind2sub"
)]
async fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
    // Bring the size vector to the host, remembering whether it was GPU-resident.
    let (dims_value, dims_was_gpu) = materialize_value(dims_val, "ind2sub").await?;
    let dims = parse_dims(&dims_value, "ind2sub").await?;
    if dims.is_empty() {
        return Err(ind2sub_error("Size vector must have at least one element."));
    }

    let total = total_elements(&dims, "ind2sub")?;
    let strides = build_strides(&dims, "ind2sub")?;

    // Fast path: run the conversion on-device when the indices are already a
    // GPU tensor and the active provider supports `ind2sub`.
    if let Some(result) = try_gpu_ind2sub(&dims, &strides, total, &indices_val)? {
        return Ok(result);
    }

    // Host fallback: gather the indices and compute subscripts on the CPU.
    let (indices_value, indices_was_gpu) = materialize_value(indices_val, "ind2sub").await?;
    let indices_tensor = tensor::value_into_tensor_for("ind2sub", indices_value)
        .map_err(|message| ind2sub_error(message))?;

    let subscripts = compute_subscripts(&dims, total, &strides, &indices_tensor)?;

    // Preserve residency: if any input lived on the GPU and a provider is
    // available, re-upload the host results so callers keep device handles.
    let want_gpu = (dims_was_gpu || indices_was_gpu) && runmat_accelerate_api::provider().is_some();

    let mut outputs: Vec<Value> = Vec::with_capacity(dims.len());
    for tensor in subscripts {
        if want_gpu {
            // Test-only: lazily register the wgpu provider if none is present.
            #[cfg(all(test, feature = "wgpu"))]
            {
                if runmat_accelerate_api::provider().is_none() {
                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                    );
                }
            }
            if let Some(provider) = runmat_accelerate_api::provider() {
                let view = HostTensorView {
                    data: &tensor.data,
                    shape: &tensor.shape,
                };
                if let Ok(handle) = provider.upload(&view) {
                    outputs.push(Value::GpuTensor(handle));
                    continue;
                }
            }
        }
        // Upload not wanted (or it failed): emit a host tensor for this dimension.
        outputs.push(tensor::tensor_into_value(tensor));
    }

    // MATLAB yields one subscript array per dimension; pack them as a 1-by-N cell.
    make_cell(outputs, 1, dims.len()).map_err(|message| ind2sub_error(message))
}
117
/// Attempt the on-device `ind2sub` path.
///
/// Returns `Ok(Some(cell))` when the provider produced the subscripts,
/// `Ok(None)` when the caller should fall back to the host implementation
/// (no provider, op unsupported, host-resident indices, or sizes outside the
/// provider's u32 ABI), and `Err` for genuine failures.
fn try_gpu_ind2sub(
    dims: &[usize],
    strides: &[usize],
    total: usize,
    indices: &Value,
) -> crate::BuiltinResult<Option<Value>> {
    // wasm32 builds never take the GPU path; always defer to the host.
    #[cfg(target_arch = "wasm32")]
    {
        let _ = (dims, strides, total, indices);
        Ok(None)
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        // Test-only: a handle minted on a non-default device implies a wgpu
        // provider should be registered before we query for one below.
        #[cfg(all(test, feature = "wgpu"))]
        {
            if let Value::GpuTensor(h) = indices {
                if h.device_id != 0 {
                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                    );
                }
            }
        }
        let provider = match runmat_accelerate_api::provider() {
            Some(p) => p,
            None => return Ok(None),
        };
        if !provider.supports_ind2sub() {
            return Ok(None);
        }
        // Only GPU-resident indices take this path; host values fall back.
        let handle = match indices {
            Value::GpuTensor(handle) => handle,
            _ => return Ok(None),
        };
        if dims.len() != strides.len() {
            // Internal invariant breach (dims and strides are produced together
            // by the caller); reuses the size-vector error wording.
            return Err(ind2sub_error("Size vector must have at least one element."));
        }
        // The provider ABI carries dims/strides/counts as u32; anything larger
        // silently falls back to the host implementation.
        if dims.iter().any(|&d| d > u32::MAX as usize)
            || strides.iter().any(|&s| s > u32::MAX as usize)
            || total > u32::MAX as usize
        {
            return Ok(None);
        }
        // An empty shape denotes a scalar handle: exactly one element.
        let len = if handle.shape.is_empty() {
            1usize
        } else {
            handle.shape.iter().copied().product()
        };
        // Indexing into a zero-element array always errors (MATLAB-compatible).
        if total == 0 && len > 0 {
            return Err(ind2sub_error(
                "Index exceeds number of array elements. Index must not exceed 0.",
            ));
        }
        if len > u32::MAX as usize {
            return Ok(None);
        }
        // Scalar handles get a len-by-1 output shape; otherwise mirror the input.
        let output_shape = if handle.shape.is_empty() {
            vec![len, 1]
        } else {
            handle.shape.clone()
        };
        match provider.ind2sub(dims, strides, handle, total, len, &output_shape) {
            Ok(handles) => {
                // Exactly one output handle per dimension is required.
                if handles.len() != dims.len() {
                    return Err(ind2sub_error(
                        "ind2sub: provider returned an unexpected number of outputs.",
                    ));
                }
                let values: Vec<Value> = handles.into_iter().map(Value::GpuTensor).collect();
                make_cell(values, 1, dims.len())
                    .map(Some)
                    .map_err(|message| ind2sub_error(message))
            }
            Err(err) => Err(ind2sub_error(err.to_string())),
        }
    }
}
195
196fn compute_subscripts(
197    dims: &[usize],
198    total: usize,
199    strides: &[usize],
200    indices: &Tensor,
201) -> crate::BuiltinResult<Vec<Tensor>> {
202    if strides.len() != dims.len() {
203        return Err(ind2sub_error("Size vector must have at least one element."));
204    }
205
206    let len = indices.data.len();
207    let mut outputs: Vec<Vec<f64>> = dims.iter().map(|_| Vec::with_capacity(len)).collect();
208
209    for &value in &indices.data {
210        let idx = coerce_linear_index(value, total)?;
211        let zero_based = idx - 1;
212        for (dim_index, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
213            let coord = ((zero_based / stride) % dim) + 1;
214            outputs[dim_index].push(coord as f64);
215        }
216    }
217
218    let output_shape = if indices.shape.is_empty() {
219        vec![len, 1]
220    } else {
221        indices.shape.clone()
222    };
223
224    let mut tensors = Vec::with_capacity(dims.len());
225    for data in outputs {
226        let tensor = Tensor::new(data, output_shape.clone())
227            .map_err(|e| ind2sub_error(format!("ind2sub: {e}")))?;
228        tensors.push(tensor);
229    }
230    Ok(tensors)
231}
232
233fn coerce_linear_index(value: f64, max_index: usize) -> crate::BuiltinResult<usize> {
234    if !value.is_finite() {
235        return Err(ind2sub_error("Linear indices must be positive integers."));
236    }
237    let rounded = value.round();
238    if (rounded - value).abs() > f64::EPSILON {
239        return Err(ind2sub_error("Linear indices must be positive integers."));
240    }
241    if rounded < 1.0 {
242        return Err(ind2sub_error("Linear indices must be positive integers."));
243    }
244    if rounded > usize::MAX as f64 {
245        return Err(ind2sub_error(
246            "Index exceeds maximum supported size for this platform.",
247        ));
248    }
249    let coerced = rounded as usize;
250    if coerced > max_index {
251        return Err(ind2sub_error(format!(
252            "Index exceeds number of array elements. Index must not exceed {}.",
253            max_index
254        )));
255    }
256    Ok(coerced)
257}
258
/// Build a `RuntimeError` attributed to the `ind2sub` builtin.
fn ind2sub_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin("ind2sub").build()
}
262
#[cfg(test)]
pub(crate) mod tests {
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Tensor, Type, Value};

    // Synchronous shim over the async builtin so each test stays blocking.
    fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
        block_on(super::ind2sub_builtin(dims_val, indices_val))
    }

    // Clone a cell array's elements out of their shared pointers for assertions.
    fn cell_to_vec(cell: &runmat_builtins::CellArray) -> Vec<Value> {
        cell.data.iter().map(|ptr| (**ptr).clone()).collect()
    }

    // Scalar index into 3x4: linear 8 -> (row 2, col 3) in column-major order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_tensor_indices() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = ind2sub_builtin(Value::Tensor(dims), Value::Num(8.0)).unwrap();
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                assert_eq!(values[0], Value::Num(2.0));
                assert_eq!(values[1], Value::Num(3.0));
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    // A size vector typed as 1x3 should resolve to a length-3 cell.
    #[test]
    fn ind2sub_type_infers_cell_length() {
        let dims = Type::Tensor {
            shape: Some(vec![Some(1), Some(3)]),
        };
        assert_eq!(
            super::ind2sub_type(&[dims, Type::Num], &ResolveContext::new(Vec::new())),
            Type::Cell {
                element_type: Some(Box::new(Type::tensor())),
                length: Some(3)
            }
        );
    }

    // Vector of indices: outputs keep the 1x3 shape of the index tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_vector_indices() {
        let dims = Tensor::new(vec![3.0, 5.0], vec![1, 2]).unwrap();
        let idx = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                match &values[0] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                    }
                    other => panic!("expected tensor rows, got {other:?}"),
                }
                match &values[1] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![3.0, 3.0, 3.0]);
                    }
                    other => panic!("expected tensor cols, got {other:?}"),
                }
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    // Three dimensions: one subscript tensor per dimension, column-major decode.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_three_dimensional_indices() {
        let dims = Tensor::new(vec![2.0, 3.0, 4.0], vec![1, 3]).unwrap();
        let idx = Tensor::new(vec![3.0, 11.0], vec![1, 2]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        if let Value::Cell(cell) = result {
            let values = cell_to_vec(&cell);
            assert_eq!(values.len(), 3);
            assert_eq!(
                values[0],
                Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[1],
                Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[2],
                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap())
            );
        } else {
            panic!("expected cell output");
        }
    }

    // Index 13 into a 12-element (3x4) array must fail with the MATLAB message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_out_of_range_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(13.0)).expect_err("expected failure");
        assert!(
            err.message()
                .contains("Index exceeds number of array elements"),
            "unexpected error: {err}"
        );
    }

    // MATLAB indices are 1-based, so 0 is rejected.
    // NOTE(review): this test calls `err.contains(...)` directly while others
    // use `err.message()`/`err.to_string()` — confirm RuntimeError exposes all three.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_zero_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(0.0)).expect_err("expected failure");
        assert!(
            err.contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    // Non-integer indices are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_fractional_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(2.5)).expect_err("expected failure");
        assert!(
            err.contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    // A fractional size-vector element is rejected during dims parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_invalid_size_elements() {
        let dims = Tensor::new(vec![3.5, 4.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(5.0)).expect_err("expected fail");
        assert!(
            err.to_string()
                .contains("Size arguments must be positive integers"),
            "unexpected error: {err}"
        );
    }

    // With the test provider, GPU-resident indices yield GPU-resident outputs
    // whose gathered contents match the host computation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ind2sub_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
            let idx_tensor = Tensor::new(vec![10.0, 11.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &idx_tensor.data,
                shape: &idx_tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload indices");
            let result = ind2sub_builtin(Value::Tensor(dims), Value::GpuTensor(handle)).unwrap();
            match result {
                Value::Cell(cell) => {
                    let values = cell_to_vec(&cell);
                    assert_eq!(values.len(), 2);
                    match &values[0] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    match &values[1] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    let rows = test_support::gather(values[0].clone()).expect("gather rows");
                    assert_eq!(rows.shape, vec![2, 1]);
                    assert_eq!(rows.data, vec![1.0, 2.0]);
                    let cols = test_support::gather(values[1].clone()).expect("gather cols");
                    assert_eq!(cols.shape, vec![2, 1]);
                    assert_eq!(cols.data, vec![4.0, 4.0]);
                }
                other => panic!("expected cell output, got {other:?}"),
            }
        });
    }

    // Cross-check the wgpu provider path against the CPU path; skipped when no
    // wgpu provider can be registered on this machine.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn ind2sub_wgpu_matches_cpu() {
        let provider_init = std::panic::catch_unwind(|| {
            runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            )
        });
        if let Ok(Ok(_)) = provider_init {
            // provider successfully registered
        } else {
            return;
        }

        let dims_tensor = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let idx_tensor = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();

        let cpu = ind2sub_builtin(
            Value::Tensor(dims_tensor.clone()),
            Value::Tensor(idx_tensor.clone()),
        )
        .expect("cpu ind2sub");

        let provider = runmat_accelerate_api::provider().unwrap();
        let view = HostTensorView {
            data: &idx_tensor.data,
            shape: &idx_tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload indices");

        let gpu = ind2sub_builtin(Value::Tensor(dims_tensor), Value::GpuTensor(handle))
            .expect("gpu ind2sub");

        let cpu_values = match cpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };
        let gpu_values = match gpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };

        assert_eq!(cpu_values.len(), gpu_values.len());

        // Gather both sides to the host and compare shape and data exactly.
        for (cpu_val, gpu_val) in cpu_values.iter().zip(gpu_values.iter()) {
            let host_cpu = test_support::gather(cpu_val.clone()).expect("gather cpu");
            let host_gpu = test_support::gather(gpu_val.clone()).expect("gather gpu");
            assert_eq!(host_cpu.shape, host_gpu.shape);
            assert_eq!(host_cpu.data, host_gpu.data);
        }
    }
}