// runmat_runtime/builtins/array/indexing/sub2ind.rs
1//! MATLAB-compatible `sub2ind` builtin with GPU-aware semantics for RunMat.
2
3#[cfg(not(target_arch = "wasm32"))]
4use runmat_accelerate_api::GpuTensorHandle;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{ResolveContext, Tensor, Type, Value};
7use runmat_macros::runtime_builtin;
8
9use super::common::{build_strides, dims_from_tokens, materialize_value, parse_dims};
10use crate::builtins::array::type_resolvers::is_scalar_type;
11use crate::builtins::common::arg_tokens::tokens_from_context;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17use crate::{build_runtime_error, RuntimeError};
18use runmat_builtins::shape_rules::element_count_if_known;
19
/// GPU execution metadata for `sub2ind`.
///
/// Registered with the accelerate layer so providers can advertise a custom
/// device implementation; when no provider hook exists the runtime computes
/// the indices on the host (see `sub2ind_builtin`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sub2ind",
    // Index conversion is neither elementwise nor a reduction; providers opt
    // in through the custom "indexing" kind.
    op_kind: GpuOpKind::Custom("indexing"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("sub2ind")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The device result is always returned as a freshly allocated handle.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers can implement the custom `sub2ind` hook to execute on device; runtimes fall back to host computation otherwise.",
};
35
/// Fusion metadata for `sub2ind`.
///
/// The builtin runs eagerly (no elementwise/reduction kernel to fuse), so the
/// spec only records that fact for the fusion planner.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sub2ind",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Index conversion executes eagerly on the host; fusion does not apply.",
};
46
47fn sub2ind_type(args: &[Type], ctx: &ResolveContext) -> Type {
48    if args.len() < 2 {
49        return Type::Unknown;
50    }
51    if let Some(dims) = dims_from_tokens(&tokens_from_context(ctx)) {
52        if args.len() - 1 != dims.len() {
53            return Type::Unknown;
54        }
55    }
56    let subscripts = &args[1..];
57    if subscripts.iter().all(|ty| is_scalar_type(ty)) {
58        return Type::Num;
59    }
60    for ty in subscripts {
61        if let Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } = ty {
62            if element_count_if_known(shape).unwrap_or(0) > 1 {
63                return Type::Tensor {
64                    shape: Some(shape.clone()),
65                };
66            }
67        }
68    }
69    Type::tensor()
70}
71
72#[runtime_builtin(
73    name = "sub2ind",
74    category = "array/indexing",
75    summary = "Convert N-D subscripts into MATLAB-style column-major linear indices.",
76    keywords = "sub2ind,linear index,column major,gpu indexing",
77    accel = "custom",
78    type_resolver(sub2ind_type),
79    builtin_path = "crate::builtins::array::indexing::sub2ind"
80)]
81async fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
82    let (dims_value, dims_was_gpu) = materialize_value(dims_val, "sub2ind").await?;
83    let dims = parse_dims(&dims_value, "sub2ind").await?;
84    if dims.is_empty() {
85        return Err(sub2ind_error("Size vector must have at least one element."));
86    }
87
88    if rest.len() != dims.len() {
89        return Err(sub2ind_error(
90            "The number of subscripts supplied must equal the number of dimensions in the size vector.",
91        ));
92    }
93
94    if let Some(value) = try_gpu_sub2ind(&dims, &rest)? {
95        return Ok(value);
96    }
97
98    let mut saw_gpu = dims_was_gpu;
99    let mut subscripts: Vec<Tensor> = Vec::with_capacity(rest.len());
100    for value in rest {
101        let (materialised, was_gpu) = materialize_value(value, "sub2ind").await?;
102        saw_gpu |= was_gpu;
103        let tensor = tensor::value_into_tensor_for("sub2ind", materialised)
104            .map_err(|message| sub2ind_error(message))?;
105        subscripts.push(tensor);
106    }
107
108    let (result_data, result_shape) = compute_indices(&dims, &subscripts)?;
109    let want_gpu_output = saw_gpu && runmat_accelerate_api::provider().is_some();
110
111    if want_gpu_output {
112        #[cfg(all(test, feature = "wgpu"))]
113        {
114            if runmat_accelerate_api::provider().is_none() {
115                let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
116                    runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
117                );
118            }
119        }
120        let shape = result_shape.clone().unwrap_or_else(|| vec![1, 1]);
121        if let Some(provider) = runmat_accelerate_api::provider() {
122            let view = HostTensorView {
123                data: &result_data,
124                shape: &shape,
125            };
126            if let Ok(handle) = provider.upload(&view) {
127                return Ok(Value::GpuTensor(handle));
128            }
129        }
130    }
131
132    build_host_value(result_data, result_shape)
133}
134
/// Attempts a fully on-device `sub2ind`.
///
/// Returns `Ok(Some(_))` when a provider executed the custom hook,
/// `Ok(None)` when the host fallback should run (no provider, mixed
/// host/device inputs, or sizes exceeding the u32 device limits), and `Err`
/// for genuine argument errors (arity or shape mismatches).
fn try_gpu_sub2ind(dims: &[usize], subs: &[Value]) -> crate::BuiltinResult<Option<Value>> {
    #[cfg(target_arch = "wasm32")]
    {
        // No GPU providers on wasm; silence the unused-parameter warnings.
        let _ = (dims, subs);
        Ok(None)
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        #[cfg(all(test, feature = "wgpu"))]
        {
            // Test-only: auto-register the wgpu provider when a handle from a
            // real device (device_id != 0) shows up without one.
            if subs
                .iter()
                .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
            {
                let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                    runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                );
            }
        }
        let provider = match runmat_accelerate_api::provider() {
            Some(p) => p,
            None => return Ok(None),
        };
        // Only take the GPU path when *every* subscript is device-resident.
        if !subs
            .iter()
            .all(|value| matches!(value, Value::GpuTensor(_)))
        {
            return Ok(None);
        }
        if dims.is_empty() {
            return Ok(None);
        }

        // Collect borrowed handles; the filter above guarantees every value
        // matches, so this collects all of `subs`.
        let mut handles: Vec<&GpuTensorHandle> = Vec::with_capacity(subs.len());
        for value in subs {
            if let Value::GpuTensor(handle) = value {
                handles.push(handle);
            }
        }

        if handles.len() != dims.len() {
            return Err(sub2ind_error(
            "The number of subscripts supplied must equal the number of dimensions in the size vector.",
        ));
        }

        // Determine broadcast shape on device metadata alone: scalars (length
        // 1) broadcast, all non-scalars must share one shape.
        let mut scalar_mask: Vec<bool> = Vec::with_capacity(handles.len());
        let mut target_shape: Option<Vec<usize>> = None;
        let mut result_len: usize = 1;
        let mut saw_non_scalar = false;

        for handle in &handles {
            let len = tensor::element_count(&handle.shape);
            let is_scalar = len == 1;
            scalar_mask.push(is_scalar);
            if !is_scalar {
                saw_non_scalar = true;
                if let Some(existing) = &target_shape {
                    if existing != &handle.shape {
                        return Err(sub2ind_error("Subscript inputs must have the same size."));
                    }
                } else {
                    target_shape = Some(handle.shape.clone());
                    result_len = len;
                }
            }
        }

        if !saw_non_scalar {
            // All scalars -> 1x1 scalar output.
            target_shape = Some(vec![1, 1]);
            result_len = 1;
        } else if let Some(shape) = &target_shape {
            result_len = tensor::element_count(shape);
        }

        let strides = build_strides(dims, "sub2ind")?;
        // Device kernels index with u32; fall back to the host for anything
        // that would overflow.
        if dims.iter().any(|&d| d > u32::MAX as usize)
            || strides.iter().any(|&s| s > u32::MAX as usize)
            || result_len > u32::MAX as usize
        {
            return Ok(None);
        }

        let output_shape = target_shape.clone().unwrap_or_else(|| vec![1, 1]);
        match provider.sub2ind(
            dims,
            &strides,
            &handles,
            &scalar_mask,
            result_len,
            &output_shape,
        ) {
            Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
            Err(err) => Err(sub2ind_error(err.to_string())),
        }
    }
}
232
233fn compute_indices(
234    dims: &[usize],
235    subscripts: &[Tensor],
236) -> crate::BuiltinResult<(Vec<f64>, Option<Vec<usize>>)> {
237    let mut target_shape: Option<Vec<usize>> = None;
238    let mut result_len: usize = 1;
239    let mut has_non_scalar = false;
240
241    for tensor in subscripts {
242        if tensor.data.len() != 1 {
243            has_non_scalar = true;
244            if let Some(shape) = &target_shape {
245                if &tensor.shape != shape {
246                    return Err(sub2ind_error("Subscript inputs must have the same size."));
247                }
248            } else {
249                target_shape = Some(tensor.shape.clone());
250                result_len = tensor.data.len();
251            }
252        }
253    }
254
255    if !has_non_scalar {
256        // All scalars -> scalar output
257        target_shape = Some(vec![1, 1]);
258        result_len = 1;
259    }
260
261    if result_len == 0 {
262        return Ok((Vec::new(), target_shape));
263    }
264
265    let strides = build_strides(dims, "sub2ind")?;
266    let mut output = Vec::with_capacity(result_len);
267
268    for idx in 0..result_len {
269        let mut offset: usize = 0;
270        for (dim_index, (&dim, tensor)) in dims.iter().zip(subscripts.iter()).enumerate() {
271            let raw = subscript_value(tensor, idx);
272            let coerced = coerce_subscript(raw, dim_index + 1, dim)?;
273            let term = coerced
274                .checked_sub(1)
275                .and_then(|v| v.checked_mul(strides[dim_index]))
276                .ok_or_else(|| sub2ind_error("Index exceeds array dimensions."))?;
277            offset = offset
278                .checked_add(term)
279                .ok_or_else(|| sub2ind_error("Index exceeds array dimensions."))?;
280        }
281        output.push((offset + 1) as f64);
282    }
283
284    Ok((output, target_shape))
285}
286
287fn subscript_value(tensor: &Tensor, idx: usize) -> f64 {
288    if tensor.data.len() == 1 {
289        tensor.data[0]
290    } else {
291        tensor.data[idx]
292    }
293}
294
295fn coerce_subscript(value: f64, dim_number: usize, dim_size: usize) -> crate::BuiltinResult<usize> {
296    if !value.is_finite() {
297        return Err(sub2ind_error(
298            "Subscript indices must either be real positive integers or logicals.",
299        ));
300    }
301    let rounded = value.round();
302    if (rounded - value).abs() > f64::EPSILON {
303        return Err(sub2ind_error(
304            "Subscript indices must either be real positive integers or logicals.",
305        ));
306    }
307    if rounded < 1.0 {
308        return Err(sub2ind_error(
309            "Subscript indices must either be real positive integers or logicals.",
310        ));
311    }
312    if rounded > dim_size as f64 {
313        return Err(dimension_bounds_error(dim_number));
314    }
315    Ok(rounded as usize)
316}
317
318fn dimension_bounds_error(dim_number: usize) -> RuntimeError {
319    let message = match dim_number {
320        1 => format!("Index exceeds the number of rows in dimension {dim_number}."),
321        2 => format!("Index exceeds the number of columns in dimension {dim_number}."),
322        3 => format!("Index exceeds the number of pages in dimension {dim_number}."),
323        _ => "Index exceeds array dimensions.".to_string(),
324    };
325    sub2ind_error(message)
326}
327
328fn build_host_value(data: Vec<f64>, shape: Option<Vec<usize>>) -> crate::BuiltinResult<Value> {
329    let shape = shape.unwrap_or_else(|| vec![1, 1]);
330    if data.len() == 1 && tensor::element_count(&shape) == 1 {
331        Ok(Value::Num(data[0]))
332    } else {
333        let tensor = Tensor::new(data, shape)
334            .map_err(|e| sub2ind_error(format!("Unable to construct sub2ind output: {e}")))?;
335        Ok(Value::Tensor(tensor))
336    }
337}
338
/// Builds a `RuntimeError` tagged with the `sub2ind` builtin name.
fn sub2ind_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin("sub2ind").build()
}
342
// Host, type-resolver, and GPU round-trip tests for the `sub2ind` builtin.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor, Type, Value};

    // Synchronous shim: the builtin is async, so tests drive it to
    // completion with `block_on`.
    fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::sub2ind_builtin(dims_val, rest))
    }

    // Scalar subscripts produce a scalar linear index (column-major:
    // (2,3) in a 3x4 array -> 2 + (3-1)*3 = 8).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn converts_scalar_indices() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result =
            sub2ind_builtin(Value::Tensor(dims), vec![Value::Num(2.0), Value::Num(3.0)]).unwrap();
        assert_eq!(result, Value::Num(8.0));
    }

    // Type resolver: all-scalar subscripts resolve to Type::Num.
    #[test]
    fn sub2ind_type_scalar_outputs_num() {
        assert_eq!(
            sub2ind_type(
                &[Type::Tensor { shape: None }, Type::Num, Type::Int],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Num
        );
    }

    // Type resolver: a vector subscript propagates its known shape.
    #[test]
    fn sub2ind_type_vector_outputs_tensor() {
        let subs = Type::Tensor {
            shape: Some(vec![Some(3), Some(1)]),
        };
        assert_eq!(
            sub2ind_type(
                &[Type::Tensor { shape: None }, subs.clone(), Type::Num],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Tensor {
                shape: Some(vec![Some(3), Some(1)])
            }
        );
    }

    // A scalar column subscript broadcasts against a 3x1 row vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcasts_scalars_over_vectors() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let rows = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::Tensor(rows), Value::Num(4.0)],
        )
        .unwrap();
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![10.0, 11.0, 12.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Three-dimensional size vector with elementwise subscript vectors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_three_dimensions() {
        let dims = Tensor::new(vec![2.0, 3.0, 4.0], vec![1, 3]).unwrap();
        let row = Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap();
        let col = Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap();
        let page = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::Tensor(row), Value::Tensor(col), Value::Tensor(page)],
        )
        .unwrap();
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![3.0, 11.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Row subscript 4 in a 3-row array triggers the bounds error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_out_of_range_subscripts() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = sub2ind_builtin(Value::Tensor(dims), vec![Value::Num(4.0), Value::Num(1.0)])
            .unwrap_err();
        assert!(
            err.to_string().contains("Index exceeds"),
            "expected index bounds error, got {err}"
        );
    }

    // Non-scalar subscripts with differing shapes must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_shape_mismatch() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let rows = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let cols = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::Tensor(rows), Value::Tensor(cols)],
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("same size"),
            "expected size mismatch error, got {err}"
        );
    }

    // Fractional subscripts produce MATLAB's integer-coercion error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_integer_subscripts() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = sub2ind_builtin(Value::Tensor(dims), vec![Value::Num(1.5), Value::Num(1.0)])
            .unwrap_err();
        assert!(
            err.to_string().contains("real positive integers"),
            "expected integer coercion error, got {err}"
        );
    }

    // Integer Value variants (e.g. int32) are accepted as subscripts.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accepts_integer_value_variants() {
        let dims = Value::Tensor(Tensor::new(vec![3.0], vec![1, 1]).unwrap());
        let result = sub2ind_builtin(dims, vec![Value::Int(IntValue::I32(2))]).expect("sub2ind");
        assert_eq!(result, Value::Num(2.0));
    }

    // With the test provider registered, GPU-resident inputs should yield a
    // GPU-resident output that gathers to the expected host values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sub2ind_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
            let rows = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let cols = Tensor::new(vec![4.0, 4.0, 4.0], vec![3, 1]).unwrap();

            let dims_handle = provider
                .upload(&HostTensorView {
                    data: &dims.data,
                    shape: &dims.shape,
                })
                .expect("upload dims");
            let rows_handle = provider
                .upload(&HostTensorView {
                    data: &rows.data,
                    shape: &rows.shape,
                })
                .expect("upload rows");
            let cols_handle = provider
                .upload(&HostTensorView {
                    data: &cols.data,
                    shape: &cols.shape,
                })
                .expect("upload cols");

            let result = sub2ind_builtin(
                Value::GpuTensor(dims_handle),
                vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
            )
            .expect("sub2ind");

            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).unwrap();
                    assert_eq!(gathered.shape, vec![3, 1]);
                    assert_eq!(gathered.data, vec![10.0, 11.0, 12.0]);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    // Parity check: the wgpu provider's result must match the CPU path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sub2ind_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let Some(provider) = runmat_accelerate_api::provider() else {
            panic!("wgpu provider not available");
        };

        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let rows = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let cols = Tensor::new(vec![4.0, 4.0, 4.0], vec![3, 1]).unwrap();

        let cpu = sub2ind_builtin(
            Value::Tensor(dims.clone()),
            vec![Value::Tensor(rows.clone()), Value::Tensor(cols.clone())],
        )
        .expect("cpu sub2ind");

        let rows_handle = provider
            .upload(&HostTensorView {
                data: &rows.data,
                shape: &rows.shape,
            })
            .expect("upload rows");
        let cols_handle = provider
            .upload(&HostTensorView {
                data: &cols.data,
                shape: &cols.shape,
            })
            .expect("upload cols");

        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
        )
        .expect("wgpu sub2ind");

        let gathered = test_support::gather(result).expect("gather");
        let expected = match cpu {
            Value::Tensor(t) => t,
            Value::Num(v) => Tensor::new(vec![v], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        assert_eq!(gathered.data, expected.data);
    }
}