1use runmat_accelerate_api::HostTensorView;
4use runmat_builtins::{ResolveContext, Tensor, Type, Value};
5use runmat_macros::runtime_builtin;
6
7use super::common::{
8 build_strides, dims_from_tokens, materialize_value, parse_dims, total_elements,
9};
10use crate::builtins::array::type_resolvers::size_vector_len;
11use crate::builtins::common::arg_tokens::tokens_from_context;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
15};
16use crate::builtins::common::tensor;
17use crate::{build_runtime_error, make_cell, RuntimeError};
18
// GPU execution contract for `ind2sub`, registered with the accelerate layer.
// Providers that advertise the "ind2sub" custom hook (currently the WGPU
// backend) run the conversion entirely on-device; all other providers fall
// back to the host path in `ind2sub_builtin`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ind2sub",
    // Index conversion is not a stock elementwise/reduction op; providers
    // dispatch it through a custom "indexing" kernel.
    op_kind: GpuOpKind::Custom("indexing"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("ind2sub")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Every output is a freshly-allocated handle; inputs are never mutated.
    residency: ResidencyPolicy::NewHandle,
    // NaN handling is irrelevant here (indices must be finite integers), so
    // the default Include mode is used and `accepts_nan_mode` is false.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes `ind2sub` entirely on-device; other providers fall back to the host implementation and re-upload results to preserve residency.",
};
34
// Fusion metadata for `ind2sub`. The builtin is executed eagerly and exposes
// neither an elementwise nor a reduction kernel, so the fusion planner never
// folds it into a larger fused expression.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ind2sub",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Index conversion is eager and does not participate in fusion today.",
};
45
46fn ind2sub_type(args: &[Type], ctx: &ResolveContext) -> Type {
47 let Some(dims) = args.first() else {
48 return Type::Unknown;
49 };
50 let length = dims_from_tokens(&tokens_from_context(ctx))
51 .map(|values| values.len())
52 .or_else(|| size_vector_len(dims));
53 Type::Cell {
54 element_type: Some(Box::new(Type::tensor())),
55 length,
56 }
57}
58
#[runtime_builtin(
    name = "ind2sub",
    category = "array/indexing",
    summary = "Convert MATLAB column-major linear indices into per-dimension subscript arrays.",
    keywords = "ind2sub,linear index,subscripts,column major,gpu indexing",
    accel = "custom",
    type_resolver(ind2sub_type),
    builtin_path = "crate::builtins::array::indexing::ind2sub"
)]
// Entry point for `ind2sub(sizeVec, indices)`.
//
// Returns a 1 x numel(sizeVec) cell, one subscript array per dimension.
// When the active provider supports the custom `ind2sub` hook and the
// indices are already device-resident, the whole conversion runs on-device;
// otherwise the host path computes the subscripts and, if either input was a
// GPU handle, re-uploads each output to preserve residency.
async fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
    // Bring the size vector to host memory, remembering whether it was a GPU
    // handle so output residency can be restored below.
    let (dims_value, dims_was_gpu) = materialize_value(dims_val, "ind2sub").await?;
    let dims = parse_dims(&dims_value, "ind2sub").await?;
    if dims.is_empty() {
        return Err(ind2sub_error("Size vector must have at least one element."));
    }

    let total = total_elements(&dims, "ind2sub")?;
    let strides = build_strides(&dims, "ind2sub")?;

    // Fast path: if the provider can run `ind2sub` natively on the (still
    // device-resident) indices, skip host materialization entirely.
    if let Some(result) = try_gpu_ind2sub(&dims, &strides, total, &indices_val)? {
        return Ok(result);
    }

    let (indices_value, indices_was_gpu) = materialize_value(indices_val, "ind2sub").await?;
    let indices_tensor = tensor::value_into_tensor_for("ind2sub", indices_value)
        .map_err(|message| ind2sub_error(message))?;

    let subscripts = compute_subscripts(&dims, total, &strides, &indices_tensor)?;

    // Keep outputs on the GPU when either input arrived there and a provider
    // is currently registered.
    let want_gpu = (dims_was_gpu || indices_was_gpu) && runmat_accelerate_api::provider().is_some();

    let mut outputs: Vec<Value> = Vec::with_capacity(dims.len());
    for tensor in subscripts {
        if want_gpu {
            // Test-only convenience: lazily register the wgpu provider so GPU
            // round-trip tests work without explicit setup.
            #[cfg(all(test, feature = "wgpu"))]
            {
                if runmat_accelerate_api::provider().is_none() {
                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                    );
                }
            }
            if let Some(provider) = runmat_accelerate_api::provider() {
                let view = HostTensorView {
                    data: &tensor.data,
                    shape: &tensor.shape,
                };
                // Best effort: a failed upload falls back to a host tensor
                // rather than surfacing an error.
                if let Ok(handle) = provider.upload(&view) {
                    outputs.push(Value::GpuTensor(handle));
                    continue;
                }
            }
        }
        outputs.push(tensor::tensor_into_value(tensor));
    }

    // MATLAB yields one output per dimension; modeled here as a 1 x N cell.
    make_cell(outputs, 1, dims.len()).map_err(|message| ind2sub_error(message))
}
117
// Attempt the fully on-device `ind2sub` conversion.
//
// Returns `Ok(None)` whenever a precondition fails (no provider, the provider
// lacks the hook, the indices are host-resident, or a quantity does not fit
// the device kernel), signalling the caller to take the host path. Returns
// `Ok(Some(cell))` with GPU tensor handles on success.
fn try_gpu_ind2sub(
    dims: &[usize],
    strides: &[usize],
    total: usize,
    indices: &Value,
) -> crate::BuiltinResult<Option<Value>> {
    #[cfg(target_arch = "wasm32")]
    {
        // No native provider path on wasm; always fall back to the host.
        let _ = (dims, strides, total, indices);
        Ok(None)
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        // Test-only: a handle from an unknown device may predate provider
        // registration, so try to register the wgpu provider before lookup.
        #[cfg(all(test, feature = "wgpu"))]
        {
            if let Value::GpuTensor(h) = indices {
                if h.device_id != 0 {
                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                    );
                }
            }
        }
        let provider = match runmat_accelerate_api::provider() {
            Some(p) => p,
            None => return Ok(None),
        };
        if !provider.supports_ind2sub() {
            return Ok(None);
        }
        // Only device-resident indices take the GPU path.
        let handle = match indices {
            Value::GpuTensor(handle) => handle,
            _ => return Ok(None),
        };
        // NOTE(review): this guards an internal dims/strides invariant, but
        // reuses the user-facing "size vector" message — consider a clearer
        // internal-error message here.
        if dims.len() != strides.len() {
            return Err(ind2sub_error("Size vector must have at least one element."));
        }
        // Fall back to the host when any quantity exceeds u32 — presumably
        // the device kernel indexes with 32-bit integers; confirm against the
        // provider implementation.
        if dims.iter().any(|&d| d > u32::MAX as usize)
            || strides.iter().any(|&s| s > u32::MAX as usize)
            || total > u32::MAX as usize
        {
            return Ok(None);
        }
        // A scalar (empty-shape) handle counts as one element.
        let len = if handle.shape.is_empty() {
            1usize
        } else {
            handle.shape.iter().copied().product()
        };
        // A zero-element array cannot be indexed by anything.
        if total == 0 && len > 0 {
            return Err(ind2sub_error(
                "Index exceeds number of array elements. Index must not exceed 0.",
            ));
        }
        if len > u32::MAX as usize {
            return Ok(None);
        }
        // Scalar inputs are promoted to a `len x 1` column, matching the
        // host path in `compute_subscripts`.
        let output_shape = if handle.shape.is_empty() {
            vec![len, 1]
        } else {
            handle.shape.clone()
        };
        match provider.ind2sub(dims, strides, handle, total, len, &output_shape) {
            Ok(handles) => {
                // The provider must return exactly one handle per dimension.
                if handles.len() != dims.len() {
                    return Err(ind2sub_error(
                        "ind2sub: provider returned an unexpected number of outputs.",
                    ));
                }
                let values: Vec<Value> = handles.into_iter().map(Value::GpuTensor).collect();
                make_cell(values, 1, dims.len())
                    .map(Some)
                    .map_err(|message| ind2sub_error(message))
            }
            Err(err) => Err(ind2sub_error(err.to_string())),
        }
    }
}
195
196fn compute_subscripts(
197 dims: &[usize],
198 total: usize,
199 strides: &[usize],
200 indices: &Tensor,
201) -> crate::BuiltinResult<Vec<Tensor>> {
202 if strides.len() != dims.len() {
203 return Err(ind2sub_error("Size vector must have at least one element."));
204 }
205
206 let len = indices.data.len();
207 let mut outputs: Vec<Vec<f64>> = dims.iter().map(|_| Vec::with_capacity(len)).collect();
208
209 for &value in &indices.data {
210 let idx = coerce_linear_index(value, total)?;
211 let zero_based = idx - 1;
212 for (dim_index, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
213 let coord = ((zero_based / stride) % dim) + 1;
214 outputs[dim_index].push(coord as f64);
215 }
216 }
217
218 let output_shape = if indices.shape.is_empty() {
219 vec![len, 1]
220 } else {
221 indices.shape.clone()
222 };
223
224 let mut tensors = Vec::with_capacity(dims.len());
225 for data in outputs {
226 let tensor = Tensor::new(data, output_shape.clone())
227 .map_err(|e| ind2sub_error(format!("ind2sub: {e}")))?;
228 tensors.push(tensor);
229 }
230 Ok(tensors)
231}
232
233fn coerce_linear_index(value: f64, max_index: usize) -> crate::BuiltinResult<usize> {
234 if !value.is_finite() {
235 return Err(ind2sub_error("Linear indices must be positive integers."));
236 }
237 let rounded = value.round();
238 if (rounded - value).abs() > f64::EPSILON {
239 return Err(ind2sub_error("Linear indices must be positive integers."));
240 }
241 if rounded < 1.0 {
242 return Err(ind2sub_error("Linear indices must be positive integers."));
243 }
244 if rounded > usize::MAX as f64 {
245 return Err(ind2sub_error(
246 "Index exceeds maximum supported size for this platform.",
247 ));
248 }
249 let coerced = rounded as usize;
250 if coerced > max_index {
251 return Err(ind2sub_error(format!(
252 "Index exceeds number of array elements. Index must not exceed {}.",
253 max_index
254 )));
255 }
256 Ok(coerced)
257}
258
259fn ind2sub_error(message: impl Into<String>) -> RuntimeError {
260 build_runtime_error(message).with_builtin("ind2sub").build()
261}
262
#[cfg(test)]
pub(crate) mod tests {
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Tensor, Type, Value};

    // Blocking shim over the async builtin so each test stays synchronous.
    fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
        block_on(super::ind2sub_builtin(dims_val, indices_val))
    }

    // Flatten a cell array into owned `Value`s for assertions.
    fn cell_to_vec(cell: &runmat_builtins::CellArray) -> Vec<Value> {
        cell.data.iter().map(|ptr| (**ptr).clone()).collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_tensor_indices() {
        // Linear index 8 in a 3x4 array is row 2, column 3 (column-major).
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = ind2sub_builtin(Value::Tensor(dims), Value::Num(8.0)).unwrap();
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                assert_eq!(values[0], Value::Num(2.0));
                assert_eq!(values[1], Value::Num(3.0));
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    #[test]
    fn ind2sub_type_infers_cell_length() {
        // A declared 1x3 size vector should resolve to a 3-element cell.
        let dims = Type::Tensor {
            shape: Some(vec![Some(1), Some(3)]),
        };
        assert_eq!(
            super::ind2sub_type(&[dims, Type::Num], &ResolveContext::new(Vec::new())),
            Type::Cell {
                element_type: Some(Box::new(Type::tensor())),
                length: Some(3)
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_vector_indices() {
        let dims = Tensor::new(vec![3.0, 5.0], vec![1, 2]).unwrap();
        let idx = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                // Row subscripts preserve the input's 1x3 shape.
                match &values[0] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                    }
                    other => panic!("expected tensor rows, got {other:?}"),
                }
                // Indices 7..9 all live in column 3 of the 3x5 array.
                match &values[1] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![3.0, 3.0, 3.0]);
                    }
                    other => panic!("expected tensor cols, got {other:?}"),
                }
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_three_dimensional_indices() {
        // Exercise the N-D path with a 2x3x4 array.
        let dims = Tensor::new(vec![2.0, 3.0, 4.0], vec![1, 3]).unwrap();
        let idx = Tensor::new(vec![3.0, 11.0], vec![1, 2]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        if let Value::Cell(cell) = result {
            let values = cell_to_vec(&cell);
            assert_eq!(values.len(), 3);
            assert_eq!(
                values[0],
                Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[1],
                Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[2],
                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap())
            );
        } else {
            panic!("expected cell output");
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_out_of_range_index() {
        // 13 exceeds the 12 elements of a 3x4 array.
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(13.0)).expect_err("expected failure");
        assert!(
            err.message()
                .contains("Index exceeds number of array elements"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_zero_index() {
        // MATLAB linear indices are 1-based; zero is rejected.
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(0.0)).expect_err("expected failure");
        assert!(
            err.contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_fractional_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(2.5)).expect_err("expected failure");
        assert!(
            err.contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_invalid_size_elements() {
        // Size vectors themselves must contain positive integers.
        let dims = Tensor::new(vec![3.5, 4.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(5.0)).expect_err("expected fail");
        assert!(
            err.to_string()
                .contains("Size arguments must be positive integers"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ind2sub_gpu_roundtrip() {
        // Device-resident indices should yield device-resident subscripts.
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
            let idx_tensor = Tensor::new(vec![10.0, 11.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &idx_tensor.data,
                shape: &idx_tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload indices");
            let result = ind2sub_builtin(Value::Tensor(dims), Value::GpuTensor(handle)).unwrap();
            match result {
                Value::Cell(cell) => {
                    let values = cell_to_vec(&cell);
                    assert_eq!(values.len(), 2);
                    match &values[0] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    match &values[1] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    // Gather back to host and check the actual subscripts.
                    let rows = test_support::gather(values[0].clone()).expect("gather rows");
                    assert_eq!(rows.shape, vec![2, 1]);
                    assert_eq!(rows.data, vec![1.0, 2.0]);
                    let cols = test_support::gather(values[1].clone()).expect("gather cols");
                    assert_eq!(cols.shape, vec![2, 1]);
                    assert_eq!(cols.data, vec![4.0, 4.0]);
                }
                other => panic!("expected cell output, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn ind2sub_wgpu_matches_cpu() {
        // Registration may panic on machines without a usable adapter;
        // treat any failure as "no GPU available" and skip the test.
        let provider_init = std::panic::catch_unwind(|| {
            runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            )
        });
        if !matches!(provider_init, Ok(Ok(_))) {
            return;
        }

        let dims_tensor = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let idx_tensor = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();

        let cpu = ind2sub_builtin(
            Value::Tensor(dims_tensor.clone()),
            Value::Tensor(idx_tensor.clone()),
        )
        .expect("cpu ind2sub");

        let provider = runmat_accelerate_api::provider().unwrap();
        let view = HostTensorView {
            data: &idx_tensor.data,
            shape: &idx_tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload indices");

        let gpu = ind2sub_builtin(Value::Tensor(dims_tensor), Value::GpuTensor(handle))
            .expect("gpu ind2sub");

        let cpu_values = match cpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };
        let gpu_values = match gpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };

        assert_eq!(cpu_values.len(), gpu_values.len());

        // Compare GPU results against the CPU reference element-for-element.
        for (cpu_val, gpu_val) in cpu_values.iter().zip(gpu_values.iter()) {
            let host_cpu = test_support::gather(cpu_val.clone()).expect("gather cpu");
            let host_gpu = test_support::gather(gpu_val.clone()).expect("gather gpu");
            assert_eq!(host_cpu.shape, host_gpu.shape);
            assert_eq!(host_cpu.data, host_gpu.data);
        }
    }
}