1#[cfg(not(target_arch = "wasm32"))]
4use runmat_accelerate_api::GpuTensorHandle;
5use runmat_accelerate_api::HostTensorView;
6use runmat_builtins::{
7 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9 ResolveContext, Tensor, Type, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use super::common::{build_strides, dims_from_tokens, materialize_value, parse_dims};
14use crate::builtins::array::type_resolvers::is_scalar_type;
15use crate::builtins::common::arg_tokens::tokens_from_context;
16use crate::builtins::common::spec::{
17 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
18 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
19};
20use crate::builtins::common::tensor;
21use crate::{build_runtime_error, RuntimeError};
22use runmat_builtins::shape_rules::element_count_if_known;
23
24#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
25pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
26 name: "sub2ind",
27 op_kind: GpuOpKind::Custom("indexing"),
28 supported_precisions: &[ScalarType::F32, ScalarType::F64],
29 broadcast: BroadcastSemantics::Matlab,
30 provider_hooks: &[ProviderHook::Custom("sub2ind")],
31 constant_strategy: ConstantStrategy::InlineLiteral,
32 residency: ResidencyPolicy::NewHandle,
33 nan_mode: ReductionNaN::Include,
34 two_pass_threshold: None,
35 workgroup_size: None,
36 accepts_nan_mode: false,
37 notes: "Providers can implement the custom `sub2ind` hook to execute on device; runtimes fall back to host computation otherwise.",
38};
39
40#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::array::indexing::sub2ind")]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42 name: "sub2ind",
43 shape: ShapeRequirements::Any,
44 constant_strategy: ConstantStrategy::InlineLiteral,
45 elementwise: None,
46 reduction: None,
47 emits_nan: false,
48 notes: "Index conversion executes eagerly on the host; fusion does not apply.",
49};
50
51fn sub2ind_type(args: &[Type], ctx: &ResolveContext) -> Type {
52 if args.len() < 2 {
53 return Type::Unknown;
54 }
55 if let Some(dims) = dims_from_tokens(&tokens_from_context(ctx)) {
56 if args.len() - 1 != dims.len() {
57 return Type::Unknown;
58 }
59 }
60 let subscripts = &args[1..];
61 if subscripts.iter().all(|ty| is_scalar_type(ty)) {
62 return Type::Num;
63 }
64 for ty in subscripts {
65 if let Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } = ty {
66 if element_count_if_known(shape).unwrap_or(0) > 1 {
67 return Type::Tensor {
68 shape: Some(shape.clone()),
69 };
70 }
71 }
72 }
73 Type::tensor()
74}
75
/// Builtin name attached to every runtime error built by `sub2ind_error_with_message`.
const BUILTIN_NAME: &str = "sub2ind";
77
78const SUB2IND_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
79 name: "ind",
80 ty: BuiltinParamType::NumericArray,
81 arity: BuiltinParamArity::Required,
82 default: None,
83 description: "Column-major linear indices corresponding to provided subscripts.",
84}];
85
86const SUB2IND_INPUTS: [BuiltinParamDescriptor; 3] = [
87 BuiltinParamDescriptor {
88 name: "sz",
89 ty: BuiltinParamType::SizeArg,
90 arity: BuiltinParamArity::Required,
91 default: None,
92 description: "Size vector describing source array dimensions.",
93 },
94 BuiltinParamDescriptor {
95 name: "I1",
96 ty: BuiltinParamType::Any,
97 arity: BuiltinParamArity::Required,
98 default: None,
99 description: "First-dimension subscript values.",
100 },
101 BuiltinParamDescriptor {
102 name: "In",
103 ty: BuiltinParamType::Any,
104 arity: BuiltinParamArity::Variadic,
105 default: None,
106 description: "Remaining per-dimension subscript arrays/scalars.",
107 },
108];
109
110const SUB2IND_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
111 label: "ind = sub2ind(sz, I1, In...)",
112 inputs: &SUB2IND_INPUTS,
113 outputs: &SUB2IND_OUTPUT,
114}];
115
116const SUB2IND_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
117 code: "RM.SUB2IND.INVALID_INPUT",
118 identifier: Some("RunMat:sub2ind:InvalidInput"),
119 when: "Size vector, subscript count, or subscript types are invalid.",
120 message: "sub2ind: invalid input arguments",
121};
122
123const SUB2IND_ERROR_INDEX_BOUNDS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
124 code: "RM.SUB2IND.INDEX_BOUNDS",
125 identifier: Some("RunMat:sub2ind:IndexBounds"),
126 when: "At least one subscript lies outside bounds for its dimension.",
127 message: "sub2ind: subscript index exceeds dimension bounds",
128};
129
130const SUB2IND_ERROR_PROVIDER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
131 code: "RM.SUB2IND.PROVIDER",
132 identifier: Some("RunMat:sub2ind:ProviderError"),
133 when: "GPU provider sub2ind hook fails.",
134 message: "sub2ind: provider execution failed",
135};
136
137const SUB2IND_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
138 code: "RM.SUB2IND.INTERNAL",
139 identifier: Some("RunMat:sub2ind:InternalError"),
140 when: "Internal tensor conversion/output construction fails.",
141 message: "sub2ind: internal error",
142};
143
144const SUB2IND_ERRORS: [BuiltinErrorDescriptor; 4] = [
145 SUB2IND_ERROR_INVALID_INPUT,
146 SUB2IND_ERROR_INDEX_BOUNDS,
147 SUB2IND_ERROR_PROVIDER,
148 SUB2IND_ERROR_INTERNAL,
149];
150
151pub const SUB2IND_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
152 signatures: &SUB2IND_SIGNATURES,
153 output_mode: BuiltinOutputMode::Fixed,
154 completion_policy: BuiltinCompletionPolicy::Public,
155 errors: &SUB2IND_ERRORS,
156};
157
158fn sub2ind_error_with_message(
159 message: impl Into<String>,
160 error: &'static BuiltinErrorDescriptor,
161) -> RuntimeError {
162 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
163 if let Some(identifier) = error.identifier {
164 builder = builder.with_identifier(identifier);
165 }
166 builder.build()
167}
168
169fn sub2ind_input_error(message: impl Into<String>) -> RuntimeError {
170 sub2ind_error_with_message(message, &SUB2IND_ERROR_INVALID_INPUT)
171}
172
173fn sub2ind_bounds_error(message: impl Into<String>) -> RuntimeError {
174 sub2ind_error_with_message(message, &SUB2IND_ERROR_INDEX_BOUNDS)
175}
176
177fn sub2ind_provider_error(message: impl Into<String>) -> RuntimeError {
178 sub2ind_error_with_message(message, &SUB2IND_ERROR_PROVIDER)
179}
180
181fn sub2ind_internal_error(message: impl Into<String>) -> RuntimeError {
182 sub2ind_error_with_message(message, &SUB2IND_ERROR_INTERNAL)
183}
184
185#[runtime_builtin(
186 name = "sub2ind",
187 category = "array/indexing",
188 summary = "Convert N-D subscripts to MATLAB-style column-major linear indices.",
189 keywords = "sub2ind,linear index,column major,gpu indexing",
190 accel = "custom",
191 type_resolver(sub2ind_type),
192 descriptor(crate::builtins::array::indexing::sub2ind::SUB2IND_DESCRIPTOR),
193 builtin_path = "crate::builtins::array::indexing::sub2ind"
194)]
195async fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
196 let (dims_value, dims_was_gpu) = materialize_value(dims_val, "sub2ind").await?;
197 let dims = parse_dims(&dims_value, "sub2ind").await?;
198 if dims.is_empty() {
199 return Err(sub2ind_error("Size vector must have at least one element."));
200 }
201
202 if rest.len() != dims.len() {
203 return Err(sub2ind_error(
204 "The number of subscripts supplied must equal the number of dimensions in the size vector.",
205 ));
206 }
207
208 if let Some(value) = try_gpu_sub2ind(&dims, &rest)? {
209 return Ok(value);
210 }
211
212 let mut saw_gpu = dims_was_gpu;
213 let mut subscripts: Vec<Tensor> = Vec::with_capacity(rest.len());
214 for value in rest {
215 let (materialised, was_gpu) = materialize_value(value, "sub2ind").await?;
216 saw_gpu |= was_gpu;
217 let tensor = tensor::value_into_tensor_for("sub2ind", materialised)
218 .map_err(|message| sub2ind_error(message))?;
219 subscripts.push(tensor);
220 }
221
222 let (result_data, result_shape) = compute_indices(&dims, &subscripts)?;
223 let want_gpu_output = saw_gpu && runmat_accelerate_api::provider().is_some();
224
225 if want_gpu_output {
226 #[cfg(all(test, feature = "wgpu"))]
227 {
228 if runmat_accelerate_api::provider().is_none() {
229 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
230 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
231 );
232 }
233 }
234 let shape = result_shape.clone().unwrap_or_else(|| vec![1, 1]);
235 if let Some(provider) = runmat_accelerate_api::provider() {
236 let view = HostTensorView {
237 data: &result_data,
238 shape: &shape,
239 };
240 if let Ok(handle) = provider.upload(&view) {
241 return Ok(Value::GpuTensor(handle));
242 }
243 }
244 }
245
246 build_host_value(result_data, result_shape)
247}
248
249fn try_gpu_sub2ind(dims: &[usize], subs: &[Value]) -> crate::BuiltinResult<Option<Value>> {
250 #[cfg(target_arch = "wasm32")]
251 {
252 let _ = (dims, subs);
253 Ok(None)
254 }
255 #[cfg(not(target_arch = "wasm32"))]
256 {
257 #[cfg(all(test, feature = "wgpu"))]
258 {
259 if subs
260 .iter()
261 .any(|v| matches!(v, Value::GpuTensor(h) if h.device_id != 0))
262 {
263 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
264 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
265 );
266 }
267 }
268 let provider = match runmat_accelerate_api::provider() {
269 Some(p) => p,
270 None => return Ok(None),
271 };
272 if !subs
273 .iter()
274 .all(|value| matches!(value, Value::GpuTensor(_)))
275 {
276 return Ok(None);
277 }
278 if dims.is_empty() {
279 return Ok(None);
280 }
281
282 let mut handles: Vec<&GpuTensorHandle> = Vec::with_capacity(subs.len());
283 for value in subs {
284 if let Value::GpuTensor(handle) = value {
285 handles.push(handle);
286 }
287 }
288
289 if handles.len() != dims.len() {
290 return Err(sub2ind_error(
291 "The number of subscripts supplied must equal the number of dimensions in the size vector.",
292 ));
293 }
294
295 let mut scalar_mask: Vec<bool> = Vec::with_capacity(handles.len());
296 let mut target_shape: Option<Vec<usize>> = None;
297 let mut result_len: usize = 1;
298 let mut saw_non_scalar = false;
299
300 for handle in &handles {
301 let len = tensor::element_count(&handle.shape);
302 let is_scalar = len == 1;
303 scalar_mask.push(is_scalar);
304 if !is_scalar {
305 saw_non_scalar = true;
306 if let Some(existing) = &target_shape {
307 if existing != &handle.shape {
308 return Err(sub2ind_error("Subscript inputs must have the same size."));
309 }
310 } else {
311 target_shape = Some(handle.shape.clone());
312 result_len = len;
313 }
314 }
315 }
316
317 if !saw_non_scalar {
318 target_shape = Some(vec![1, 1]);
319 result_len = 1;
320 } else if let Some(shape) = &target_shape {
321 result_len = tensor::element_count(shape);
322 }
323
324 let strides = build_strides(dims, "sub2ind")?;
325 if dims.iter().any(|&d| d > u32::MAX as usize)
326 || strides.iter().any(|&s| s > u32::MAX as usize)
327 || result_len > u32::MAX as usize
328 {
329 return Ok(None);
330 }
331
332 let output_shape = target_shape.clone().unwrap_or_else(|| vec![1, 1]);
333 match provider.sub2ind(
334 dims,
335 &strides,
336 &handles,
337 &scalar_mask,
338 result_len,
339 &output_shape,
340 ) {
341 Ok(handle) => Ok(Some(Value::GpuTensor(handle))),
342 Err(err) => Err(sub2ind_provider_error(err.to_string())),
343 }
344 }
345}
346
347fn compute_indices(
348 dims: &[usize],
349 subscripts: &[Tensor],
350) -> crate::BuiltinResult<(Vec<f64>, Option<Vec<usize>>)> {
351 let mut target_shape: Option<Vec<usize>> = None;
352 let mut result_len: usize = 1;
353 let mut has_non_scalar = false;
354
355 for tensor in subscripts {
356 if tensor.data.len() != 1 {
357 has_non_scalar = true;
358 if let Some(shape) = &target_shape {
359 if &tensor.shape != shape {
360 return Err(sub2ind_error("Subscript inputs must have the same size."));
361 }
362 } else {
363 target_shape = Some(tensor.shape.clone());
364 result_len = tensor.data.len();
365 }
366 }
367 }
368
369 if !has_non_scalar {
370 target_shape = Some(vec![1, 1]);
372 result_len = 1;
373 }
374
375 if result_len == 0 {
376 return Ok((Vec::new(), target_shape));
377 }
378
379 let strides = build_strides(dims, "sub2ind")?;
380 let mut output = Vec::with_capacity(result_len);
381
382 for idx in 0..result_len {
383 let mut offset: usize = 0;
384 for (dim_index, (&dim, tensor)) in dims.iter().zip(subscripts.iter()).enumerate() {
385 let raw = subscript_value(tensor, idx);
386 let coerced = coerce_subscript(raw, dim_index + 1, dim)?;
387 let term = coerced
388 .checked_sub(1)
389 .and_then(|v| v.checked_mul(strides[dim_index]))
390 .ok_or_else(|| sub2ind_bounds_error("Index exceeds array dimensions."))?;
391 offset = offset
392 .checked_add(term)
393 .ok_or_else(|| sub2ind_bounds_error("Index exceeds array dimensions."))?;
394 }
395 output.push((offset + 1) as f64);
396 }
397
398 Ok((output, target_shape))
399}
400
401fn subscript_value(tensor: &Tensor, idx: usize) -> f64 {
402 if tensor.data.len() == 1 {
403 tensor.data[0]
404 } else {
405 tensor.data[idx]
406 }
407}
408
409fn coerce_subscript(value: f64, dim_number: usize, dim_size: usize) -> crate::BuiltinResult<usize> {
410 if !value.is_finite() {
411 return Err(sub2ind_error(
412 "Subscript indices must either be real positive integers or logicals.",
413 ));
414 }
415 let rounded = value.round();
416 if (rounded - value).abs() > f64::EPSILON {
417 return Err(sub2ind_error(
418 "Subscript indices must either be real positive integers or logicals.",
419 ));
420 }
421 if rounded < 1.0 {
422 return Err(sub2ind_error(
423 "Subscript indices must either be real positive integers or logicals.",
424 ));
425 }
426 if rounded > dim_size as f64 {
427 return Err(dimension_bounds_error(dim_number));
428 }
429 Ok(rounded as usize)
430}
431
432fn dimension_bounds_error(dim_number: usize) -> RuntimeError {
433 let message = match dim_number {
434 1 => format!("Index exceeds the number of rows in dimension {dim_number}."),
435 2 => format!("Index exceeds the number of columns in dimension {dim_number}."),
436 3 => format!("Index exceeds the number of pages in dimension {dim_number}."),
437 _ => "Index exceeds array dimensions.".to_string(),
438 };
439 sub2ind_bounds_error(message)
440}
441
442fn build_host_value(data: Vec<f64>, shape: Option<Vec<usize>>) -> crate::BuiltinResult<Value> {
443 let shape = shape.unwrap_or_else(|| vec![1, 1]);
444 if data.len() == 1 && tensor::element_count(&shape) == 1 {
445 Ok(Value::Num(data[0]))
446 } else {
447 let tensor = Tensor::new(data, shape).map_err(|e| {
448 sub2ind_internal_error(format!("Unable to construct sub2ind output: {e}"))
449 })?;
450 Ok(Value::Tensor(tensor))
451 }
452}
453
454fn sub2ind_error(message: impl Into<String>) -> RuntimeError {
455 sub2ind_input_error(message)
456}
457
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{IntValue, Tensor, Type, Value};

    /// Blocking shim so the tests can call the async builtin directly.
    fn sub2ind_builtin(dims_val: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::sub2ind_builtin(dims_val, rest))
    }

    /// Builds a host tensor from borrowed data and shape; panics if construction fails.
    fn host_tensor(data: &[f64], shape: &[usize]) -> Tensor {
        Tensor::new(data.to_vec(), shape.to_vec()).unwrap()
    }

    /// The `[3 4]` size vector used by most tests.
    fn size_3x4() -> Tensor {
        host_tensor(&[3.0, 4.0], &[1, 2])
    }

    /// Unwraps a host tensor result, panicking on any other value kind.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn converts_scalar_indices() {
        let subscripts = vec![Value::Num(2.0), Value::Num(3.0)];
        let linear = sub2ind_builtin(Value::Tensor(size_3x4()), subscripts).unwrap();
        assert_eq!(linear, Value::Num(8.0));
    }

    #[test]
    fn sub2ind_type_scalar_outputs_num() {
        let arg_types = [Type::Tensor { shape: None }, Type::Num, Type::Int];
        let resolved = sub2ind_type(&arg_types, &ResolveContext::new(Vec::new()));
        assert_eq!(resolved, Type::Num);
    }

    #[test]
    fn sub2ind_type_vector_outputs_tensor() {
        let column = Type::Tensor {
            shape: Some(vec![Some(3), Some(1)]),
        };
        let arg_types = [Type::Tensor { shape: None }, column.clone(), Type::Num];
        let resolved = sub2ind_type(&arg_types, &ResolveContext::new(Vec::new()));
        assert_eq!(
            resolved,
            Type::Tensor {
                shape: Some(vec![Some(3), Some(1)])
            }
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcasts_scalars_over_vectors() {
        let rows = host_tensor(&[1.0, 2.0, 3.0], &[3, 1]);
        let result = sub2ind_builtin(
            Value::Tensor(size_3x4()),
            vec![Value::Tensor(rows), Value::Num(4.0)],
        )
        .unwrap();
        let indices = expect_tensor(result);
        assert_eq!(indices.shape, vec![3, 1]);
        assert_eq!(indices.data, vec![10.0, 11.0, 12.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn handles_three_dimensions() {
        let size = host_tensor(&[2.0, 3.0, 4.0], &[1, 3]);
        let subscripts = vec![
            Value::Tensor(host_tensor(&[1.0, 1.0], &[1, 2])),
            Value::Tensor(host_tensor(&[2.0, 3.0], &[1, 2])),
            Value::Tensor(host_tensor(&[1.0, 2.0], &[1, 2])),
        ];
        let result = sub2ind_builtin(Value::Tensor(size), subscripts).unwrap();
        let indices = expect_tensor(result);
        assert_eq!(indices.shape, vec![1, 2]);
        assert_eq!(indices.data, vec![3.0, 11.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_out_of_range_subscripts() {
        let subscripts = vec![Value::Num(4.0), Value::Num(1.0)];
        let err = sub2ind_builtin(Value::Tensor(size_3x4()), subscripts).unwrap_err();
        assert!(
            err.to_string().contains("Index exceeds"),
            "expected index bounds error, got {err}"
        );
        assert_eq!(
            err.identifier(),
            super::SUB2IND_ERROR_INDEX_BOUNDS.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_shape_mismatch() {
        let two_rows = host_tensor(&[1.0, 2.0], &[2, 1]);
        let three_cols = host_tensor(&[1.0, 2.0, 3.0], &[3, 1]);
        let err = sub2ind_builtin(
            Value::Tensor(size_3x4()),
            vec![Value::Tensor(two_rows), Value::Tensor(three_cols)],
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("same size"),
            "expected size mismatch error, got {err}"
        );
        assert_eq!(
            err.identifier(),
            super::SUB2IND_ERROR_INVALID_INPUT.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rejects_non_integer_subscripts() {
        let subscripts = vec![Value::Num(1.5), Value::Num(1.0)];
        let err = sub2ind_builtin(Value::Tensor(size_3x4()), subscripts).unwrap_err();
        assert!(
            err.to_string().contains("real positive integers"),
            "expected integer coercion error, got {err}"
        );
        assert_eq!(
            err.identifier(),
            super::SUB2IND_ERROR_INVALID_INPUT.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accepts_integer_value_variants() {
        let size = Value::Tensor(host_tensor(&[3.0], &[1, 1]));
        let linear = sub2ind_builtin(size, vec![Value::Int(IntValue::I32(2))]).expect("sub2ind");
        assert_eq!(linear, Value::Num(2.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn sub2ind_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            // Uploads a host tensor, panicking with `what` on failure.
            let upload = |source: &Tensor, what: &str| {
                provider
                    .upload(&HostTensorView {
                        data: &source.data,
                        shape: &source.shape,
                    })
                    .expect(what)
            };

            let dims_handle = upload(&size_3x4(), "upload dims");
            let rows_handle = upload(&host_tensor(&[1.0, 2.0, 3.0], &[3, 1]), "upload rows");
            let cols_handle = upload(&host_tensor(&[4.0, 4.0, 4.0], &[3, 1]), "upload cols");

            let result = sub2ind_builtin(
                Value::GpuTensor(dims_handle),
                vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
            )
            .expect("sub2ind");

            match result {
                Value::GpuTensor(handle) => {
                    let gathered = test_support::gather(Value::GpuTensor(handle)).unwrap();
                    assert_eq!(gathered.shape, vec![3, 1]);
                    assert_eq!(gathered.data, vec![10.0, 11.0, 12.0]);
                }
                other => panic!("expected gpu tensor, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn sub2ind_wgpu_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let Some(provider) = runmat_accelerate_api::provider() else {
            panic!("wgpu provider not available");
        };

        let dims = size_3x4();
        let rows = host_tensor(&[1.0, 2.0, 3.0], &[3, 1]);
        let cols = host_tensor(&[4.0, 4.0, 4.0], &[3, 1]);

        // Reference result computed entirely on the host.
        let cpu = sub2ind_builtin(
            Value::Tensor(dims.clone()),
            vec![Value::Tensor(rows.clone()), Value::Tensor(cols.clone())],
        )
        .expect("cpu sub2ind");

        let upload = |source: &Tensor, what: &str| {
            provider
                .upload(&HostTensorView {
                    data: &source.data,
                    shape: &source.shape,
                })
                .expect(what)
        };
        let rows_handle = upload(&rows, "upload rows");
        let cols_handle = upload(&cols, "upload cols");

        let result = sub2ind_builtin(
            Value::Tensor(dims),
            vec![Value::GpuTensor(rows_handle), Value::GpuTensor(cols_handle)],
        )
        .expect("wgpu sub2ind");

        let gathered = test_support::gather(result).expect("gather");
        let expected = match cpu {
            Value::Tensor(t) => t,
            Value::Num(v) => Tensor::new(vec![v], vec![1, 1]).unwrap(),
            other => panic!("unexpected cpu result {other:?}"),
        };
        assert_eq!(gathered.shape, expected.shape);
        assert_eq!(gathered.data, expected.data);
    }
}