1use runmat_accelerate_api::HostTensorView;
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 ResolveContext, Tensor, Type, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use super::common::{
12 build_strides, dims_from_tokens, materialize_value, parse_dims, total_elements,
13};
14use crate::builtins::array::type_resolvers::size_vector_len;
15use crate::builtins::common::arg_tokens::tokens_from_context;
16use crate::builtins::common::spec::{
17 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::tensor;
21use crate::{build_runtime_error, make_cell, RuntimeError};
22
// GPU capability descriptor for `ind2sub`, registered through `register_gpu_spec`.
// Declares a custom "indexing" op with an `ind2sub` provider hook for F32/F64 data.
// Residency policy is `NewHandle` (presumably: outputs are fresh device handles —
// confirm against the accelerate API docs). `notes` explains the provider fallback.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ind2sub",
    op_kind: GpuOpKind::Custom("indexing"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Custom("ind2sub")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    // Not a reduction: no two-pass threshold and no explicit workgroup size.
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "WGPU provider executes `ind2sub` entirely on-device; other providers fall back to the host implementation and re-upload results to preserve residency.",
};
38
// Fusion descriptor for `ind2sub`, registered through `register_fusion_spec`.
// There is no elementwise or reduction template, so the builtin always runs eagerly
// and is never folded into a fused kernel.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::ind2sub")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ind2sub",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Index conversion is eager and does not participate in fusion today.",
};
49
50fn ind2sub_type(args: &[Type], ctx: &ResolveContext) -> Type {
51 let Some(dims) = args.first() else {
52 return Type::Unknown;
53 };
54 let length = dims_from_tokens(&tokens_from_context(ctx))
55 .map(|values| values.len())
56 .or_else(|| size_vector_len(dims));
57 Type::Cell {
58 element_type: Some(Box::new(Type::tensor())),
59 length,
60 }
61}
62
63const BUILTIN_NAME: &str = "ind2sub";
64
65const IND2SUB_OUTPUT_CELL: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
66 name: "subs",
67 ty: BuiltinParamType::Any,
68 arity: BuiltinParamArity::Required,
69 default: None,
70 description: "Cell array containing one subscript output per dimension.",
71}];
72
73const IND2SUB_INPUTS: [BuiltinParamDescriptor; 2] = [
74 BuiltinParamDescriptor {
75 name: "sz",
76 ty: BuiltinParamType::SizeArg,
77 arity: BuiltinParamArity::Required,
78 default: None,
79 description: "Size vector describing source array dimensions.",
80 },
81 BuiltinParamDescriptor {
82 name: "ind",
83 ty: BuiltinParamType::Any,
84 arity: BuiltinParamArity::Required,
85 default: None,
86 description: "Linear indices to convert into per-dimension subscripts.",
87 },
88];
89
90const IND2SUB_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
91 label: "subs = ind2sub(sz, ind)",
92 inputs: &IND2SUB_INPUTS,
93 outputs: &IND2SUB_OUTPUT_CELL,
94}];
95
96const IND2SUB_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
97 code: "RM.IND2SUB.INVALID_INPUT",
98 identifier: Some("RunMat:ind2sub:InvalidInput"),
99 when: "Size vector or linear index inputs are malformed or unsupported.",
100 message: "ind2sub: invalid input arguments",
101};
102
103const IND2SUB_ERROR_INDEX_BOUNDS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
104 code: "RM.IND2SUB.INDEX_BOUNDS",
105 identifier: Some("RunMat:ind2sub:IndexBounds"),
106 when: "At least one provided linear index exceeds array element bounds.",
107 message: "ind2sub: index exceeds array bounds",
108};
109
110const IND2SUB_ERROR_PROVIDER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
111 code: "RM.IND2SUB.PROVIDER",
112 identifier: Some("RunMat:ind2sub:ProviderError"),
113 when: "Provider-side ind2sub execution fails or returns malformed outputs.",
114 message: "ind2sub: provider execution failed",
115};
116
117const IND2SUB_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
118 code: "RM.IND2SUB.INTERNAL",
119 identifier: Some("RunMat:ind2sub:InternalError"),
120 when: "Internal tensor/materialization logic fails while building outputs.",
121 message: "ind2sub: internal error",
122};
123
124const IND2SUB_ERRORS: [BuiltinErrorDescriptor; 4] = [
125 IND2SUB_ERROR_INVALID_INPUT,
126 IND2SUB_ERROR_INDEX_BOUNDS,
127 IND2SUB_ERROR_PROVIDER,
128 IND2SUB_ERROR_INTERNAL,
129];
130
131pub const IND2SUB_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
132 signatures: &IND2SUB_SIGNATURES,
133 output_mode: BuiltinOutputMode::Fixed,
134 completion_policy: BuiltinCompletionPolicy::Public,
135 errors: &IND2SUB_ERRORS,
136};
137
138fn ind2sub_error_with_message(
139 message: impl Into<String>,
140 error: &'static BuiltinErrorDescriptor,
141) -> RuntimeError {
142 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
143 if let Some(identifier) = error.identifier {
144 builder = builder.with_identifier(identifier);
145 }
146 builder.build()
147}
148
149fn ind2sub_input_error(message: impl Into<String>) -> RuntimeError {
150 ind2sub_error_with_message(message, &IND2SUB_ERROR_INVALID_INPUT)
151}
152
153fn ind2sub_internal_error(message: impl Into<String>) -> RuntimeError {
154 ind2sub_error_with_message(message, &IND2SUB_ERROR_INTERNAL)
155}
156
157fn ind2sub_provider_error(message: impl Into<String>) -> RuntimeError {
158 ind2sub_error_with_message(message, &IND2SUB_ERROR_PROVIDER)
159}
160
161#[runtime_builtin(
162 name = "ind2sub",
163 category = "array/indexing",
164 summary = "Convert linear indices to subscripts.",
165 keywords = "ind2sub,linear index,subscripts,column major,gpu indexing",
166 accel = "custom",
167 type_resolver(ind2sub_type),
168 descriptor(crate::builtins::array::indexing::ind2sub::IND2SUB_DESCRIPTOR),
169 builtin_path = "crate::builtins::array::indexing::ind2sub"
170)]
171async fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
172 let (dims_value, dims_was_gpu) = materialize_value(dims_val, "ind2sub").await?;
173 let dims = parse_dims(&dims_value, "ind2sub").await?;
174 if dims.is_empty() {
175 return Err(ind2sub_error("Size vector must have at least one element."));
176 }
177
178 let total = total_elements(&dims, "ind2sub")?;
179 let strides = build_strides(&dims, "ind2sub")?;
180
181 if let Some(result) = try_gpu_ind2sub(&dims, &strides, total, &indices_val)? {
182 return Ok(result);
183 }
184
185 let (indices_value, indices_was_gpu) = materialize_value(indices_val, "ind2sub").await?;
186 let indices_tensor = tensor::value_into_tensor_for("ind2sub", indices_value)
187 .map_err(|message| ind2sub_error(message))?;
188
189 let subscripts = compute_subscripts(&dims, total, &strides, &indices_tensor)?;
190
191 let want_gpu = (dims_was_gpu || indices_was_gpu) && runmat_accelerate_api::provider().is_some();
192
193 let mut outputs: Vec<Value> = Vec::with_capacity(dims.len());
194 for tensor in subscripts {
195 if want_gpu {
196 #[cfg(all(test, feature = "wgpu"))]
197 {
198 if runmat_accelerate_api::provider().is_none() {
199 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
200 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
201 );
202 }
203 }
204 if let Some(provider) = runmat_accelerate_api::provider() {
205 let view = HostTensorView {
206 data: &tensor.data,
207 shape: &tensor.shape,
208 };
209 if let Ok(handle) = provider.upload(&view) {
210 outputs.push(Value::GpuTensor(handle));
211 continue;
212 }
213 }
214 }
215 outputs.push(tensor::tensor_into_value(tensor));
216 }
217
218 make_cell(outputs, 1, dims.len()).map_err(|message| ind2sub_error(message))
219}
220
221fn try_gpu_ind2sub(
222 dims: &[usize],
223 strides: &[usize],
224 total: usize,
225 indices: &Value,
226) -> crate::BuiltinResult<Option<Value>> {
227 #[cfg(target_arch = "wasm32")]
228 {
229 let _ = (dims, strides, total, indices);
230 Ok(None)
231 }
232 #[cfg(not(target_arch = "wasm32"))]
233 {
234 #[cfg(all(test, feature = "wgpu"))]
235 {
236 if let Value::GpuTensor(h) = indices {
237 if h.device_id != 0 {
238 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
239 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
240 );
241 }
242 }
243 }
244 let provider = match runmat_accelerate_api::provider() {
245 Some(p) => p,
246 None => return Ok(None),
247 };
248 if !provider.supports_ind2sub() {
249 return Ok(None);
250 }
251 let handle = match indices {
252 Value::GpuTensor(handle) => handle,
253 _ => return Ok(None),
254 };
255 if dims.len() != strides.len() {
256 return Err(ind2sub_error("Size vector must have at least one element."));
257 }
258 if dims.iter().any(|&d| d > u32::MAX as usize)
259 || strides.iter().any(|&s| s > u32::MAX as usize)
260 || total > u32::MAX as usize
261 {
262 return Ok(None);
263 }
264 let len = if handle.shape.is_empty() {
265 1usize
266 } else {
267 handle.shape.iter().copied().product()
268 };
269 if total == 0 && len > 0 {
270 return Err(ind2sub_error(
271 "Index exceeds number of array elements. Index must not exceed 0.",
272 ));
273 }
274 if len > u32::MAX as usize {
275 return Ok(None);
276 }
277 let output_shape = if handle.shape.is_empty() {
278 vec![len, 1]
279 } else {
280 handle.shape.clone()
281 };
282 match provider.ind2sub(dims, strides, handle, total, len, &output_shape) {
283 Ok(handles) => {
284 if handles.len() != dims.len() {
285 return Err(ind2sub_provider_error(
286 "ind2sub: provider returned an unexpected number of outputs.",
287 ));
288 }
289 let values: Vec<Value> = handles.into_iter().map(Value::GpuTensor).collect();
290 make_cell(values, 1, dims.len())
291 .map(Some)
292 .map_err(|message| ind2sub_error(message))
293 }
294 Err(err) => Err(ind2sub_provider_error(err.to_string())),
295 }
296 }
297}
298
299fn compute_subscripts(
300 dims: &[usize],
301 total: usize,
302 strides: &[usize],
303 indices: &Tensor,
304) -> crate::BuiltinResult<Vec<Tensor>> {
305 if strides.len() != dims.len() {
306 return Err(ind2sub_error("Size vector must have at least one element."));
307 }
308
309 let len = indices.data.len();
310 let mut outputs: Vec<Vec<f64>> = dims.iter().map(|_| Vec::with_capacity(len)).collect();
311
312 for &value in &indices.data {
313 let idx = coerce_linear_index(value, total)?;
314 let zero_based = idx - 1;
315 for (dim_index, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
316 let coord = ((zero_based / stride) % dim) + 1;
317 outputs[dim_index].push(coord as f64);
318 }
319 }
320
321 let output_shape = if indices.shape.is_empty() {
322 vec![len, 1]
323 } else {
324 indices.shape.clone()
325 };
326
327 let mut tensors = Vec::with_capacity(dims.len());
328 for data in outputs {
329 let tensor = Tensor::new(data, output_shape.clone())
330 .map_err(|e| ind2sub_internal_error(format!("ind2sub: {e}")))?;
331 tensors.push(tensor);
332 }
333 Ok(tensors)
334}
335
336fn coerce_linear_index(value: f64, max_index: usize) -> crate::BuiltinResult<usize> {
337 if !value.is_finite() {
338 return Err(ind2sub_error("Linear indices must be positive integers."));
339 }
340 let rounded = value.round();
341 if (rounded - value).abs() > f64::EPSILON {
342 return Err(ind2sub_error("Linear indices must be positive integers."));
343 }
344 if rounded < 1.0 {
345 return Err(ind2sub_error("Linear indices must be positive integers."));
346 }
347 if rounded > usize::MAX as f64 {
348 return Err(ind2sub_error(
349 "Index exceeds maximum supported size for this platform.",
350 ));
351 }
352 let coerced = rounded as usize;
353 if coerced > max_index {
354 return Err(ind2sub_error_with_message(
355 format!(
356 "Index exceeds number of array elements. Index must not exceed {}.",
357 max_index
358 ),
359 &IND2SUB_ERROR_INDEX_BOUNDS,
360 ));
361 }
362 Ok(coerced)
363}
364
365fn ind2sub_error(message: impl Into<String>) -> RuntimeError {
366 ind2sub_input_error(message)
367}
368
// Unit tests for `ind2sub`, covering the host path, error identifiers and
// messages, type inference, and the GPU paths (test provider and optional wgpu).
#[cfg(test)]
pub(crate) mod tests {
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Tensor, Type, Value};

    // Synchronous wrapper around the async builtin for use in plain tests.
    fn ind2sub_builtin(dims_val: Value, indices_val: Value) -> crate::BuiltinResult<Value> {
        block_on(super::ind2sub_builtin(dims_val, indices_val))
    }

    // Clones every element out of a cell array for easy assertions.
    fn cell_to_vec(cell: &runmat_builtins::CellArray) -> Vec<Value> {
        cell.data.iter().map(|ptr| (**ptr).clone()).collect()
    }

    // Scalar index 8 in a 3x4 array is (row 2, col 3) in column-major order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_tensor_indices() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let result = ind2sub_builtin(Value::Tensor(dims), Value::Num(8.0)).unwrap();
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                assert_eq!(values[0], Value::Num(2.0));
                assert_eq!(values[1], Value::Num(3.0));
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    // A 1x3 size-vector type yields a cell type of length 3.
    #[test]
    fn ind2sub_type_infers_cell_length() {
        let dims = Type::Tensor {
            shape: Some(vec![Some(1), Some(3)]),
        };
        assert_eq!(
            super::ind2sub_type(&[dims, Type::Num], &ResolveContext::new(Vec::new())),
            Type::Cell {
                element_type: Some(Box::new(Type::tensor())),
                length: Some(3)
            }
        );
    }

    // Outputs keep the 1x3 shape of the index vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_vector_indices() {
        let dims = Tensor::new(vec![3.0, 5.0], vec![1, 2]).unwrap();
        let idx = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        match result {
            Value::Cell(cell) => {
                let values = cell_to_vec(&cell);
                assert_eq!(values.len(), 2);
                match &values[0] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
                    }
                    other => panic!("expected tensor rows, got {other:?}"),
                }
                match &values[1] {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![1, 3]);
                        assert_eq!(t.data, vec![3.0, 3.0, 3.0]);
                    }
                    other => panic!("expected tensor cols, got {other:?}"),
                }
            }
            other => panic!("expected cell output, got {other:?}"),
        }
    }

    // Non-integer indices carry the InvalidInput identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_integer_linear_index_identifier() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(1.25))
            .expect_err("expected non-integer index error");
        assert_eq!(
            err.identifier(),
            super::IND2SUB_ERROR_INVALID_INPUT.identifier
        );
    }

    // Indices past numel carry the IndexBounds identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_out_of_bounds_linear_index_identifier() {
        let dims = Tensor::new(vec![2.0, 2.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(9.0))
            .expect_err("expected out-of-bounds index error");
        assert_eq!(
            err.identifier(),
            super::IND2SUB_ERROR_INDEX_BOUNDS.identifier
        );
    }

    // 2x3x4 array: index 3 -> (1,2,1), index 11 -> (1,3,2).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn recovers_three_dimensional_indices() {
        let dims = Tensor::new(vec![2.0, 3.0, 4.0], vec![1, 3]).unwrap();
        let idx = Tensor::new(vec![3.0, 11.0], vec![1, 2]).unwrap();
        let result =
            ind2sub_builtin(Value::Tensor(dims), Value::Tensor(idx)).expect("ind2sub result");
        if let Value::Cell(cell) = result {
            let values = cell_to_vec(&cell);
            assert_eq!(values.len(), 3);
            assert_eq!(
                values[0],
                Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[1],
                Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap())
            );
            assert_eq!(
                values[2],
                Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap())
            );
        } else {
            panic!("expected cell output");
        }
    }

    // Index 13 in a 12-element array reports the bounds message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_out_of_range_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(13.0)).expect_err("expected failure");
        assert!(
            err.message()
                .contains("Index exceeds number of array elements"),
            "unexpected error: {err}"
        );
    }

    // Index 0 is rejected as non-positive.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_zero_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(0.0)).expect_err("expected failure");
        assert!(
            err.contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    // Fractional indices are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_fractional_index() {
        let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let err =
            ind2sub_builtin(Value::Tensor(dims), Value::Num(2.5)).expect_err("expected failure");
        assert!(
            err.contains("Linear indices must be positive integers"),
            "unexpected error: {err}"
        );
    }

    // Non-integer size elements are rejected by size parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn errors_on_invalid_size_elements() {
        let dims = Tensor::new(vec![3.5, 4.0], vec![1, 2]).unwrap();
        let err = ind2sub_builtin(Value::Tensor(dims), Value::Num(5.0)).expect_err("expected fail");
        assert!(
            err.to_string()
                .contains("Size arguments must be positive integers"),
            "unexpected error: {err}"
        );
    }

    // GPU-resident indices produce GPU-resident outputs with the right values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ind2sub_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
            let idx_tensor = Tensor::new(vec![10.0, 11.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &idx_tensor.data,
                shape: &idx_tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload indices");
            let result = ind2sub_builtin(Value::Tensor(dims), Value::GpuTensor(handle)).unwrap();
            match result {
                Value::Cell(cell) => {
                    let values = cell_to_vec(&cell);
                    assert_eq!(values.len(), 2);
                    match &values[0] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    match &values[1] {
                        Value::GpuTensor(_) => {}
                        other => panic!("expected gpu tensor output, got {other:?}"),
                    }
                    let rows = test_support::gather(values[0].clone()).expect("gather rows");
                    assert_eq!(rows.shape, vec![2, 1]);
                    assert_eq!(rows.data, vec![1.0, 2.0]);
                    let cols = test_support::gather(values[1].clone()).expect("gather cols");
                    assert_eq!(cols.shape, vec![2, 1]);
                    assert_eq!(cols.data, vec![4.0, 4.0]);
                }
                other => panic!("expected cell output, got {other:?}"),
            }
        });
    }

    // With the wgpu feature, the device result must match the host result exactly.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn ind2sub_wgpu_matches_cpu() {
        // Registration may fail or panic when no adapter is available.
        let provider_init = std::panic::catch_unwind(|| {
            runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            )
        });
        if let Ok(Ok(_)) = provider_init {
        } else {
            // No usable wgpu provider on this machine: skip the comparison.
            return;
        }

        let dims_tensor = Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
        let idx_tensor = Tensor::new(vec![7.0, 8.0, 9.0], vec![1, 3]).unwrap();

        let cpu = ind2sub_builtin(
            Value::Tensor(dims_tensor.clone()),
            Value::Tensor(idx_tensor.clone()),
        )
        .expect("cpu ind2sub");

        let provider = runmat_accelerate_api::provider().unwrap();
        let view = HostTensorView {
            data: &idx_tensor.data,
            shape: &idx_tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload indices");

        let gpu = ind2sub_builtin(Value::Tensor(dims_tensor), Value::GpuTensor(handle))
            .expect("gpu ind2sub");

        let cpu_values = match cpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };
        let gpu_values = match gpu {
            Value::Cell(cell) => cell_to_vec(&cell),
            other => panic!("expected cell output, got {other:?}"),
        };

        assert_eq!(cpu_values.len(), gpu_values.len());

        // Gather both sides to host tensors and compare shape and data.
        for (cpu_val, gpu_val) in cpu_values.iter().zip(gpu_values.iter()) {
            let host_cpu = test_support::gather(cpu_val.clone()).expect("gather cpu");
            let host_gpu = test_support::gather(gpu_val.clone()).expect("gather gpu");
            assert_eq!(host_cpu.shape, host_gpu.shape);
            assert_eq!(host_cpu.data, host_gpu.data);
        }
    }
}