1#[cfg(not(target_arch = "wasm32"))]
4use runmat_accelerate_api::GpuTensorHandle;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{ResolveContext, Tensor, Type, Value};
7use runmat_macros::runtime_builtin;
8
9use super::common::{build_strides, dims_from_tokens, materialize_value, parse_dims};
10use crate::builtins::array::type_resolvers::is_scalar_type;
11use crate::builtins::common::arg_tokens::tokens_from_context;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17use crate::{build_runtime_error, RuntimeError};
18use runmat_builtins::shape_rules::element_count_if_known;
19
// Device-execution contract for `sub2ind`. A provider that implements the
// custom "sub2ind" hook may convert subscripts to linear indices on device;
// otherwise the runtime falls back to the host computation in this module.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "sub2ind",
    // Not a stock elementwise/reduction op; dispatch goes through a custom hook.
    op_kind: GpuOpKind::Custom("indexing"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("sub2ind")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The result is always a freshly allocated handle; inputs are not mutated.
    residency: ResidencyPolicy::NewHandle,
    // NaN handling is irrelevant for index conversion; subscripts must be
    // integral anyway, so the default Include mode is used.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Providers can implement the custom `sub2ind` hook to execute on device; runtimes fall back to host computation otherwise.",
};
35
// Fusion metadata for `sub2ind`: the builtin is neither elementwise nor a
// reduction, so it never participates in kernel fusion and always executes
// eagerly.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "sub2ind",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // Valid subscripts are finite integers, so the output never contains NaN.
    emits_nan: false,
    notes: "Index conversion executes eagerly on the host; fusion does not apply.",
};
46
47fn sub2ind_type(args: &[Type], ctx: &ResolveContext) -> Type {
48 if args.len() < 2 {
49 return Type::Unknown;
50 }
51 if let Some(dims) = dims_from_tokens(&tokens_from_context(ctx)) {
52 if args.len() - 1 != dims.len() {
53 return Type::Unknown;
54 }
55 }
56 let subscripts = &args[1..];
57 if subscripts.iter().all(|ty| is_scalar_type(ty)) {
58 return Type::Num;
59 }
60 for ty in subscripts {
61 if let Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } = ty {
62 if element_count_if_known(shape).unwrap_or(0) > 1 {
63 return Type::Tensor {
64 shape: Some(shape.clone()),
65 };
66 }
67 }
68 }
69 Type::tensor()
70}
71
// Entry point for `sub2ind`: converts N-D subscripts into 1-based,
// column-major linear indices (MATLAB semantics). GPU-resident inputs are
// first offered to the provider's custom `sub2ind` hook; otherwise all inputs
// are gathered to the host and the indices are computed there. If any input
// was GPU-resident, the host result is uploaded back so residency follows
// the inputs.
#[runtime_builtin(
    name = "sub2ind",
    category = "array/indexing",
    summary = "Convert N-D subscripts into MATLAB-style column-major linear indices.",
    keywords = "sub2ind,linear index,column major,gpu indexing",
    accel = "custom",
    type_resolver(sub2ind_type),
    builtin_path = "crate::builtins::array::indexing::sub2ind"
)]
async fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    // Bring the size vector to the host, remembering whether it was GPU-resident
    // so the output residency can mirror the inputs.
    let (dims_value, dims_was_gpu) = materialize_value(dims_val, "sub2ind").await?;
    let dims = parse_dims(&dims_value, "sub2ind").await?;
    if dims.is_empty() {
        return Err(sub2ind_error("Size vector must have at least one element."));
    }

    // One subscript argument per dimension in the size vector.
    if rest.len() != dims.len() {
        return Err(sub2ind_error(
            "The number of subscripts supplied must equal the number of dimensions in the size vector.",
        ));
    }

    // Fast path: if every subscript is already on device and a provider hook
    // exists, compute on device and return the device handle directly.
    if let Some(value) = try_gpu_sub2ind(&dims, &rest)? {
        return Ok(value);
    }

    // Host path: gather each subscript into a host tensor, tracking whether any
    // of them started on the GPU.
    let mut saw_gpu = dims_was_gpu;
    let mut subscripts: Vec<Tensor> = Vec::with_capacity(rest.len());
    for value in rest {
        let (materialised, was_gpu) = materialize_value(value, "sub2ind").await?;
        saw_gpu |= was_gpu;
        let tensor = tensor::value_into_tensor_for("sub2ind", materialised)
            .map_err(|message| sub2ind_error(message))?;
        subscripts.push(tensor);
    }

    let (result_data, result_shape) = compute_indices(&dims, &subscripts)?;
    // Only attempt a GPU upload when some input was GPU-resident AND a provider
    // is currently registered.
    let want_gpu_output = saw_gpu && runmat_accelerate_api::provider().is_some();

    if want_gpu_output {
        // Test-only: lazily register the wgpu provider. NOTE(review): in normal
        // flow `want_gpu_output` already implies a provider is present, so this
        // re-check looks like a belt-and-braces guard for test setups.
        #[cfg(all(test, feature = "wgpu"))]
        {
            if runmat_accelerate_api::provider().is_none() {
                let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                    runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                );
            }
        }
        // An absent shape means the result is scalar; represent it as 1x1.
        let shape = result_shape.clone().unwrap_or_else(|| vec![1, 1]);
        if let Some(provider) = runmat_accelerate_api::provider() {
            let view = HostTensorView {
                data: &result_data,
                shape: &shape,
            };
            // Upload failures are non-fatal: fall through to the host value.
            if let Ok(handle) = provider.upload(&view) {
                return Ok(Value::GpuTensor(handle));
            }
        }
    }

    build_host_value(result_data, result_shape)
}
134
/// Attempt the device fast path for `sub2ind`.
///
/// Returns `Ok(Some(..))` when a provider successfully computed the indices on
/// device, `Ok(None)` when the fast path does not apply (wasm target, no
/// provider, host-resident subscripts, or sizes exceeding the u32 range used
/// by device kernels), and `Err` on genuine input errors (subscript count or
/// shape mismatches, provider failures).
fn try_gpu_sub2ind(dims: &[usize], subs: &[Value]) -> crate::BuiltinResult<Option<Value>> {
    // wasm builds have no GPU provider; the host path always runs.
    #[cfg(target_arch = "wasm32")]
    {
        let _ = (dims, subs);
        Ok(None)
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        // Test-only: if a handle references a non-default device, make sure the
        // wgpu provider is registered before we look it up below.
        #[cfg(all(test, feature = "wgpu"))]
        {
            if subs
                .iter()
                .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
            {
                let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                    runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                );
            }
        }
        let provider = match runmat_accelerate_api::provider() {
            Some(p) => p,
            None => return Ok(None),
        };
        // Only take the device path when every subscript is already on device;
        // mixed host/device inputs fall back to the host computation.
        if !subs
            .iter()
            .all(|value| matches!(value, Value::GpuTensor(_)))
        {
            return Ok(None);
        }
        if dims.is_empty() {
            return Ok(None);
        }

        let mut handles: Vec<&GpuTensorHandle> = Vec::with_capacity(subs.len());
        for value in subs {
            if let Value::GpuTensor(handle) = value {
                handles.push(handle);
            }
        }

        if handles.len() != dims.len() {
            return Err(sub2ind_error(
                "The number of subscripts supplied must equal the number of dimensions in the size vector.",
            ));
        }

        // Determine the broadcast output shape: scalar handles broadcast over
        // the (single) common shape shared by all non-scalar handles.
        let mut scalar_mask: Vec<bool> = Vec::with_capacity(handles.len());
        let mut target_shape: Option<Vec<usize>> = None;
        let mut result_len: usize = 1;
        let mut saw_non_scalar = false;

        for handle in &handles {
            let len = tensor::element_count(&handle.shape);
            let is_scalar = len == 1;
            scalar_mask.push(is_scalar);
            if !is_scalar {
                saw_non_scalar = true;
                if let Some(existing) = &target_shape {
                    // All non-scalar subscripts must agree exactly on shape.
                    if existing != &handle.shape {
                        return Err(sub2ind_error("Subscript inputs must have the same size."));
                    }
                } else {
                    target_shape = Some(handle.shape.clone());
                    result_len = len;
                }
            }
        }

        if !saw_non_scalar {
            // All scalars: the output is a 1x1 scalar.
            target_shape = Some(vec![1, 1]);
            result_len = 1;
        } else if let Some(shape) = &target_shape {
            result_len = tensor::element_count(shape);
        }

        // Column-major strides for the size vector; device kernels index with
        // u32, so bail out to the host path when anything exceeds u32::MAX.
        let strides = build_strides(dims, "sub2ind")?;
        if dims.iter().any(|&d| d > u32::MAX as usize)
            || strides.iter().any(|&s| s > u32::MAX as usize)
            || result_len > u32::MAX as usize
        {
            return Ok(None);
        }

        let output_shape = target_shape.clone().unwrap_or_else(|| vec![1, 1]);
        match provider.sub2ind(
            dims,
            &strides,
            &handles,
            &scalar_mask,
            result_len,
            &output_shape,
        ) {
            Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
            Err(err) => Err(sub2ind_error(err.to_string())),
        }
    }
}
232
233fn compute_indices(
234 dims: &[usize],
235 subscripts: &[Tensor],
236) -> crate::BuiltinResult<(Vec<f64>, Option<Vec<usize>>)> {
237 let mut target_shape: Option<Vec<usize>> = None;
238 let mut result_len: usize = 1;
239 let mut has_non_scalar = false;
240
241 for tensor in subscripts {
242 if tensor.data.len() != 1 {
243 has_non_scalar = true;
244 if let Some(shape) = &target_shape {
245 if &tensor.shape != shape {
246 return Err(sub2ind_error("Subscript inputs must have the same size."));
247 }
248 } else {
249 target_shape = Some(tensor.shape.clone());
250 result_len = tensor.data.len();
251 }
252 }
253 }
254
255 if !has_non_scalar {
256 target_shape = Some(vec![1, 1]);
258 result_len = 1;
259 }
260
261 if result_len == 0 {
262 return Ok((Vec::new(), target_shape));
263 }
264
265 let strides = build_strides(dims, "sub2ind")?;
266 let mut output = Vec::with_capacity(result_len);
267
268 for idx in 0..result_len {
269 let mut offset: usize = 0;
270 for (dim_index, (&dim, tensor)) in dims.iter().zip(subscripts.iter()).enumerate() {
271 let raw = subscript_value(tensor, idx);
272 let coerced = coerce_subscript(raw, dim_index + 1, dim)?;
273 let term = coerced
274 .checked_sub(1)
275 .and_then(|v| v.checked_mul(strides[dim_index]))
276 .ok_or_else(|| sub2ind_error("Index exceeds array dimensions."))?;
277 offset = offset
278 .checked_add(term)
279 .ok_or_else(|| sub2ind_error("Index exceeds array dimensions."))?;
280 }
281 output.push((offset + 1) as f64);
282 }
283
284 Ok((output, target_shape))
285}
286
287fn subscript_value(tensor: &Tensor, idx: usize) -> f64 {
288 if tensor.data.len() == 1 {
289 tensor.data[0]
290 } else {
291 tensor.data[idx]
292 }
293}
294
295fn coerce_subscript(value: f64, dim_number: usize, dim_size: usize) -> crate::BuiltinResult<usize> {
296 if !value.is_finite() {
297 return Err(sub2ind_error(
298 "Subscript indices must either be real positive integers or logicals.",
299 ));
300 }
301 let rounded = value.round();
302 if (rounded - value).abs() > f64::EPSILON {
303 return Err(sub2ind_error(
304 "Subscript indices must either be real positive integers or logicals.",
305 ));
306 }
307 if rounded < 1.0 {
308 return Err(sub2ind_error(
309 "Subscript indices must either be real positive integers or logicals.",
310 ));
311 }
312 if rounded > dim_size as f64 {
313 return Err(dimension_bounds_error(dim_number));
314 }
315 Ok(rounded as usize)
316}
317
318fn dimension_bounds_error(dim_number: usize) -> RuntimeError {
319 let message = match dim_number {
320 1 => format!("Index exceeds the number of rows in dimension {dim_number}."),
321 2 => format!("Index exceeds the number of columns in dimension {dim_number}."),
322 3 => format!("Index exceeds the number of pages in dimension {dim_number}."),
323 _ => "Index exceeds array dimensions.".to_string(),
324 };
325 sub2ind_error(message)
326}
327
328fn build_host_value(data: Vec<f64>, shape: Option<Vec<usize>>) -> crate::BuiltinResult<Value> {
329 let shape = shape.unwrap_or_else(|| vec![1, 1]);
330 if data.len() == 1 && tensor::element_count(&shape) == 1 {
331 Ok(Value::Num(data[0]))
332 } else {
333 let tensor = Tensor::new(data, shape)
334 .map_err(|e| sub2ind_error(format!("Unable to construct sub2ind output: {e}")))?;
335 Ok(Value::Tensor(tensor))
336 }
337}
338
339fn sub2ind_error(message: impl Into<String>) -> RuntimeError {
340 build_runtime_error(message).with_builtin("sub2ind").build()
341}
342
// Unit tests for `sub2ind`: type resolution, host computation (scalars,
// broadcasting, 3-D, error paths), integer value variants, and GPU round-trips
// via the test provider and (feature-gated) the wgpu backend.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor, Type, Value};

    // Synchronous shim over the async builtin so tests can call it directly.
    fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::sub2ind_builtin(dims_val, rest))
    }

    // Scalar subscripts into a 3x4 array produce a scalar linear index:
    // (2,3) -> 2 + (3-1)*3 = 8.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn converts_scalar_indices() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result =
            sub2ind_builtin(Value::Tensor(dims), vec![Value::Num(2.0), Value::Num(3.0)]).unwrap();
        assert_eq!(result, Value::Num(8.0));
    }

    // All-scalar subscript types resolve to Type::Num.
    #[test]
    fn sub2ind_type_scalar_outputs_num() {
        assert_eq!(
            sub2ind_type(
                &[Type::Tensor { shape: None }, Type::Num, Type::Int],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Num
        );
    }

    // A non-scalar subscript with a known shape propagates that shape to the
    // resolved output type.
    #[test]
    fn sub2ind_type_vector_outputs_tensor() {
        let subs = Type::Tensor {
            shape: Some(vec![Some(3), Some(1)]),
        };
        assert_eq!(
            sub2ind_type(
                &[Type::Tensor { shape: None }, subs.clone(), Type::Num],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Tensor {
                shape: Some(vec![Some(3), Some(1)])
            }
        );
    }

    // A scalar column subscript broadcasts across a vector of row subscripts.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcasts_scalars_over_vectors() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let rows = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::Tensor(rows), Value::Num(4.0)],
        )
        .unwrap();
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 1]);
                assert_eq!(t.data, vec![10.0, 11.0, 12.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Three-dimensional size vector: page stride is rows*cols = 6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_three_dimensions() {
        let dims = Tensor::new(vec![2.0, 3.0, 4.0], vec![1, 3]).unwrap();
        let row = Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap();
        let col = Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap();
        let page = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::Tensor(row), Value::Tensor(col), Value::Tensor(page)],
        )
        .unwrap();
        match result {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![3.0, 11.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    // Row subscript 4 exceeds a 3-row dimension and must error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_out_of_range_subscripts() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = sub2ind_builtin(Value::Tensor(dims), vec![Value::Num(4.0), Value::Num(1.0)])
            .unwrap_err();
        assert!(
            err.to_string().contains("Index exceeds"),
            "expected index bounds error, got {err}"
        );
    }

    // Non-scalar subscripts with differing shapes must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_shape_mismatch() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let rows = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let cols = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let err = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::Tensor(rows), Value::Tensor(cols)],
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("same size"),
            "expected size mismatch error, got {err}"
        );
    }

    // Fractional subscripts trigger the MATLAB "real positive integers" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_integer_subscripts() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = sub2ind_builtin(Value::Tensor(dims), vec![Value::Num(1.5), Value::Num(1.0)])
            .unwrap_err();
        assert!(
            err.to_string().contains("real positive integers"),
            "expected integer coercion error, got {err}"
        );
    }

    // Value::Int subscripts are accepted just like Value::Num.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accepts_integer_value_variants() {
        let dims = Value::Tensor(Tensor::new(vec![3.0], vec![1, 1]).unwrap());
        let result = sub2ind_builtin(dims, vec![Value::Int(IntValue::I32(2))]).expect("sub2ind");
        assert_eq!(result, Value::Num(2.0));
    }

    // With the test provider registered and all-GPU inputs, the result stays on
    // device and gathers back to the expected host values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sub2ind_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
            let rows = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let cols = Tensor::new(vec![4.0, 4.0, 4.0], vec![3, 1]).unwrap();

            let dims_handle = provider
                .upload(&HostTensorView {
                    data: &dims.data,
                    shape: &dims.shape,
                })
                .expect("upload dims");
            let rows_handle = provider
                .upload(&HostTensorView {
                    data: &rows.data,
                    shape: &rows.shape,
                })
                .expect("upload rows");
            let cols_handle = provider
                .upload(&HostTensorView {
                    data: &cols.data,
                    shape: &cols.shape,
                })
                .expect("upload cols");

            let result = sub2ind_builtin(
                Value::GpuTensor(dims_handle),
                vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
            )
            .expect("sub2ind");

            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).unwrap();
                    assert_eq!(gathered.shape, vec![3, 1]);
                    assert_eq!(gathered.data, vec![10.0, 11.0, 12.0]);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    // The wgpu device path must agree exactly with the CPU computation.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sub2ind_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let Some(provider) = runmat_accelerate_api::provider() else {
            panic!("wgpu provider not available");
        };

        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let rows = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let cols = Tensor::new(vec![4.0, 4.0, 4.0], vec![3, 1]).unwrap();

        // Reference result on the CPU.
        let cpu = sub2ind_builtin(
            Value::Tensor(dims.clone()),
            vec![Value::Tensor(rows.clone()), Value::Tensor(cols.clone())],
        )
        .expect("cpu sub2ind");

        let rows_handle = provider
            .upload(&HostTensorView {
                data: &rows.data,
                shape: &rows.shape,
            })
            .expect("upload rows");
        let cols_handle = provider
            .upload(&HostTensorView {
                data: &cols.data,
                shape: &cols.shape,
            })
            .expect("upload cols");

        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
        )
        .expect("wgpu sub2ind");

        let gathered = test_support::gather(result).expect("gather");
        let expected = match cpu {
            Value::Tensor(t) => t,
            Value::Num(v) => Tensor::new(vec![v], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        assert_eq!(gathered.data, expected.data);
    }
}