1use crate::builtins::acceleration::gpu::type_resolvers::gpuarray_type;
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::builtins::common::{gpu_helpers, tensor};
14use runmat_accelerate_api::{GpuTensorHandle, HostTensorView, ProviderPrecision};
15use runmat_builtins::{CharArray, IntValue, Tensor, Value};
16use runmat_macros::runtime_builtin;
17
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
// Error message used whenever an operation needs an acceleration provider
// but none has been registered.
const ERR_NO_PROVIDER: &str = "gpuArray: no acceleration provider registered";

/// Builds a `RuntimeError` attributed to the `gpuArray` builtin so all
/// errors from this module surface with a consistent origin.
fn gpu_array_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message)
        .with_builtin("gpuArray")
        .build()
}
27
/// GPU execution spec for `gpuArray`: a custom "upload" provider hook that
/// always produces a new device handle and takes no part in broadcasting.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::gpuarray")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "gpuArray",
    op_kind: GpuOpKind::Custom("upload"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("upload")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Invokes the provider `upload` hook, reuploading gpuArray inputs when dtype conversion is requested. Handles class strings, size vectors, and `'like'` prototypes.",
};
43
/// Fusion spec for `gpuArray`: it is neither elementwise nor a reduction,
/// so it acts purely as a boundary in fusion graphs.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::gpuarray"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gpuArray",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Acts as a residency boundary; fusion graphs never cross explicit host↔device transfers.",
};
57
58#[runtime_builtin(
59 name = "gpuArray",
60 category = "acceleration/gpu",
61 summary = "Move data to the GPU and return a gpuArray handle.",
62 keywords = "gpuArray,gpu,accelerate,upload,dtype,like",
63 examples = "G = gpuArray([1 2 3], 'single');",
64 accel = "array_construct",
65 type_resolver(gpuarray_type),
66 builtin_path = "crate::builtins::acceleration::gpu::gpuarray"
67)]
68async fn gpu_array_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
69 let options = parse_options(&rest)?;
70 let incoming_precision = match &value {
71 Value::GpuTensor(handle) => runmat_accelerate_api::handle_precision(handle),
72 _ => None,
73 };
74 let dtype = resolve_dtype(&value, &options)?;
75 let dims = options.dims.clone();
76
77 let prepared = match value {
78 Value::GpuTensor(handle) => convert_device_value(handle, dtype).await?,
79 other => upload_host_value(other, dtype)?,
80 };
81
82 let mut handle = prepared.handle;
83
84 if let Some(dims) = dims.as_ref() {
85 apply_dims(&mut handle, dims)?;
86 }
87
88 let provider_precision = runmat_accelerate_api::provider()
89 .map(|p| p.precision())
90 .unwrap_or(ProviderPrecision::F64);
91 let requested_precision = match dtype {
92 DataClass::Single => Some(ProviderPrecision::F32),
93 _ => None,
94 };
95 let final_precision = requested_precision
96 .or(incoming_precision)
97 .unwrap_or(provider_precision);
98 runmat_accelerate_api::set_handle_precision(&handle, final_precision);
99
100 runmat_accelerate_api::set_handle_logical(&handle, prepared.logical);
101
102 Ok(Value::GpuTensor(handle))
103}
104
/// Host-side classification of the MATLAB classes accepted by `gpuArray`.
///
/// Data is always stored as `f64` on upload; the class only controls how
/// values are coerced (rounding, clamping, logical conversion) beforehand.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DataClass {
    Double,
    Single,
    Logical,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
}

impl DataClass {
    /// Maps a lowercase class tag to a `DataClass`.
    ///
    /// Returns `None` for unknown tags, including `"gpuarray"`: the caller
    /// treats that specific tag as a harmless no-op option rather than a
    /// dtype, so it deliberately does not map to a class here.
    fn from_tag(tag: &str) -> Option<Self> {
        match tag {
            "double" => Some(Self::Double),
            "single" | "float32" => Some(Self::Single),
            "logical" | "bool" | "boolean" => Some(Self::Logical),
            "int8" => Some(Self::Int8),
            "int16" => Some(Self::Int16),
            "int32" | "int" => Some(Self::Int32),
            "int64" => Some(Self::Int64),
            "uint8" => Some(Self::UInt8),
            "uint16" => Some(Self::UInt16),
            "uint32" => Some(Self::UInt32),
            "uint64" => Some(Self::UInt64),
            // `"gpuarray"` and any unrecognised tag both yield `None`; the
            // previously separate `"gpuarray" => None` arm was redundant.
            _ => None,
        }
    }
}
139
/// Options parsed from the trailing arguments of a `gpuArray` call.
#[derive(Debug, Default)]
struct ParsedOptions {
    // Requested output shape from leading size arguments, if any.
    dims: Option<Vec<usize>>,
    // Class requested via an explicit tag such as 'single' or 'int32'.
    explicit_dtype: Option<DataClass>,
    // Prototype value supplied after the 'like' keyword.
    prototype: Option<Value>,
}
146
147fn parse_options(rest: &[Value]) -> BuiltinResult<ParsedOptions> {
148 let (index_after_dims, dims) = parse_size_arguments(rest)?;
149 let mut options = ParsedOptions {
150 dims,
151 ..ParsedOptions::default()
152 };
153
154 let mut idx = index_after_dims;
155 while idx < rest.len() {
156 let tag = value_to_lower_string(&rest[idx]).ok_or_else(|| {
157 gpu_array_error(format!(
158 "gpuArray: unexpected argument {:?}; expected a class string or the keyword 'like'",
159 rest[idx]
160 ))
161 })?;
162
163 match tag.as_str() {
164 "like" => {
165 idx += 1;
166 if idx >= rest.len() {
167 return Err(gpu_array_error(
168 "gpuArray: expected a prototype value after 'like'",
169 ));
170 }
171 if options.prototype.is_some() {
172 return Err(gpu_array_error("gpuArray: duplicate 'like' qualifier"));
173 }
174 options.prototype = Some(rest[idx].clone());
175 }
176 "distributed" | "codistributed" => {
177 return Err(gpu_array_error(
178 "gpuArray: codistributed arrays are not supported yet",
179 ));
180 }
181 tag => {
182 if let Some(class) = DataClass::from_tag(tag) {
183 if let Some(existing) = options.explicit_dtype {
184 if existing != class {
185 return Err(gpu_array_error(
186 "gpuArray: conflicting type qualifiers supplied",
187 ));
188 }
189 } else {
190 options.explicit_dtype = Some(class);
191 }
192 } else if tag != "gpuarray" {
193 return Err(gpu_array_error(format!(
194 "gpuArray: unrecognised option '{tag}'",
195 )));
196 }
197 }
198 }
199
200 idx += 1;
201 }
202
203 Ok(options)
204}
205
206fn parse_size_arguments(rest: &[Value]) -> BuiltinResult<(usize, Option<Vec<usize>>)> {
207 let mut idx = 0;
208 let mut dims: Vec<usize> = Vec::new();
209 let mut vector_consumed = false;
210
211 while idx < rest.len() {
212 match &rest[idx] {
214 Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => break,
215 _ => {}
216 }
217
218 match &rest[idx] {
219 Value::Int(i) => {
220 dims.push(int_to_dim(i)?);
221 }
222 Value::Num(n) => {
223 dims.push(float_to_dim(*n)?);
224 }
225 Value::Tensor(t) => {
226 if vector_consumed || !dims.is_empty() {
227 return Err(gpu_array_error(
228 "gpuArray: size vectors cannot be combined with scalar dimensions",
229 ));
230 }
231 dims = tensor_to_dims(t)?;
232 vector_consumed = true;
233 }
234 _ => break,
235 }
236 idx += 1;
237 }
238
239 let dims_option = if dims.is_empty() { None } else { Some(dims) };
240 Ok((idx, dims_option))
241}
242
243fn value_to_lower_string(value: &Value) -> Option<String> {
244 crate::builtins::common::tensor::value_to_string(value).map(|s| s.trim().to_ascii_lowercase())
245}
246
247fn int_to_dim(value: &IntValue) -> BuiltinResult<usize> {
248 let raw = value.to_i64();
249 if raw < 0 {
250 return Err(gpu_array_error(
251 "gpuArray: size arguments must be non-negative integers",
252 ));
253 }
254 Ok(raw as usize)
255}
256
257fn float_to_dim(value: f64) -> BuiltinResult<usize> {
258 if !value.is_finite() {
259 return Err(gpu_array_error(
260 "gpuArray: size arguments must be finite integers",
261 ));
262 }
263 let rounded = value.round();
264 if (rounded - value).abs() > f64::EPSILON {
265 return Err(gpu_array_error("gpuArray: size arguments must be integers"));
266 }
267 if rounded < 0.0 {
268 return Err(gpu_array_error(
269 "gpuArray: size arguments must be non-negative",
270 ));
271 }
272 Ok(rounded as usize)
273}
274
275fn tensor_to_dims(tensor: &Tensor) -> BuiltinResult<Vec<usize>> {
276 let mut dims = Vec::with_capacity(tensor.data.len());
277 for value in &tensor.data {
278 dims.push(float_to_dim(*value)?);
279 }
280 Ok(dims)
281}
282
283fn resolve_dtype(value: &Value, options: &ParsedOptions) -> BuiltinResult<DataClass> {
284 if let Some(explicit) = options.explicit_dtype {
285 return Ok(explicit);
286 }
287 if let Some(prototype) = options.prototype.as_ref() {
288 return infer_dtype_from_prototype(prototype);
289 }
290 if value_defaults_to_logical(value) {
291 return Ok(DataClass::Logical);
292 }
293 Ok(DataClass::Double)
294}
295
/// Determines the target class from a `'like'` prototype value.
///
/// Integer prototypes keep their integer class, logical prototypes map to
/// `Logical`, float-like and char prototypes map to `Double`. String and
/// complex prototypes are rejected with explanatory errors.
fn infer_dtype_from_prototype(proto: &Value) -> BuiltinResult<DataClass> {
    match proto {
        Value::GpuTensor(handle) => {
            // Device prototypes only carry a logical flag here; everything
            // non-logical is treated as double.
            if runmat_accelerate_api::handle_is_logical(handle) {
                Ok(DataClass::Logical)
            } else {
                Ok(DataClass::Double)
            }
        }
        Value::LogicalArray(_) | Value::Bool(_) => Ok(DataClass::Logical),
        Value::Int(int) => Ok(match int {
            IntValue::I8(_) => DataClass::Int8,
            IntValue::I16(_) => DataClass::Int16,
            IntValue::I32(_) => DataClass::Int32,
            IntValue::I64(_) => DataClass::Int64,
            IntValue::U8(_) => DataClass::UInt8,
            IntValue::U16(_) => DataClass::UInt16,
            IntValue::U32(_) => DataClass::UInt32,
            IntValue::U64(_) => DataClass::UInt64,
        }),
        Value::Tensor(_) | Value::Num(_) => Ok(DataClass::Double),
        // Chars behave like their numeric code points, i.e. double.
        Value::CharArray(_) => Ok(DataClass::Double),
        Value::String(_) => Err(gpu_array_error(
            "gpuArray: 'like' does not accept MATLAB string scalars; convert to char() first",
        )),
        Value::StringArray(_) => Err(gpu_array_error(
            "gpuArray: 'like' does not accept string arrays; convert to char arrays first",
        )),
        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(gpu_array_error(
            "gpuArray: complex prototypes are not supported yet; provide real-valued inputs",
        )),
        other => Err(gpu_array_error(format!(
            "gpuArray: unsupported 'like' prototype type {other:?}; expected numeric or logical values"
        ))),
    }
}
332
333fn value_defaults_to_logical(value: &Value) -> bool {
334 match value {
335 Value::LogicalArray(_) | Value::Bool(_) => true,
336 Value::GpuTensor(handle) => runmat_accelerate_api::handle_is_logical(handle),
337 _ => false,
338 }
339}
340
/// Result of preparing data for the device: the uploaded handle plus
/// whether its contents should be flagged as logical.
struct PreparedHandle {
    handle: GpuTensorHandle,
    logical: bool,
}
345
/// Coerces a host value to the requested class and uploads it through the
/// registered acceleration provider.
///
/// Returns the new device handle together with the logical flag produced
/// by the cast. Fails when no provider is registered or the upload errors.
fn upload_host_value(value: Value, dtype: DataClass) -> BuiltinResult<PreparedHandle> {
    // In unit tests with the wgpu feature, lazily register a wgpu provider
    // so the upload path can run without external setup.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if runmat_accelerate_api::provider().is_none() {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    let provider =
        runmat_accelerate_api::provider().ok_or_else(|| gpu_array_error(ERR_NO_PROVIDER))?;

    let tensor = coerce_host_value(value)?;
    let (mut tensor, logical) = cast_tensor(tensor, dtype)?;

    let view = HostTensorView {
        data: &tensor.data,
        shape: &tensor.shape,
    };
    let new_handle = provider
        .upload(&view)
        .map_err(|err| gpu_array_error(format!("gpuArray: {err}")))?;

    // Release the host copy eagerly; the data now lives on the device.
    tensor.data.clear();

    Ok(PreparedHandle {
        handle: new_handle,
        logical,
    })
}
376
/// Converts an existing device handle to `dtype`.
///
/// Fast paths reuse the handle without touching device memory: conversion
/// to double always does (the caller clears the logical flag via
/// `logical: false`), and conversion to logical does when the handle is
/// already logical. Every other conversion gathers to host, casts, and
/// reuploads, freeing the superseded handle best-effort.
async fn convert_device_value(
    handle: GpuTensorHandle,
    dtype: DataClass,
) -> BuiltinResult<PreparedHandle> {
    let was_logical = runmat_accelerate_api::handle_is_logical(&handle);
    match dtype {
        DataClass::Double => {
            return Ok(PreparedHandle {
                handle,
                logical: false,
            });
        }
        DataClass::Logical => {
            if was_logical {
                return Ok(PreparedHandle {
                    handle,
                    logical: true,
                });
            }
        }
        _ => {}
    }

    let provider =
        runmat_accelerate_api::provider().ok_or_else(|| gpu_array_error(ERR_NO_PROVIDER))?;
    let tensor = gpu_helpers::gather_tensor_async(&handle)
        .await
        .map_err(|err| gpu_array_error(err.to_string()))?;
    let (mut tensor, logical) = cast_tensor(tensor, dtype)?;

    let view = HostTensorView {
        data: &tensor.data,
        shape: &tensor.shape,
    };
    let new_handle = provider
        .upload(&view)
        .map_err(|err| gpu_array_error(format!("gpuArray: {err}")))?;

    // The old handle is superseded: free it best-effort, then drop the
    // host copy eagerly.
    provider.free(&handle).ok();
    tensor.data.clear();

    Ok(PreparedHandle {
        handle: new_handle,
        logical,
    })
}
423
/// Converts a host value into an `f64` tensor prior to casting/upload.
///
/// Scalars become 1x1 tensors, logicals become 0.0/1.0, chars and MATLAB
/// string scalars become code-point tensors. String arrays, complex
/// values, and everything else are rejected.
fn coerce_host_value(value: Value) -> BuiltinResult<Tensor> {
    match value {
        Value::Tensor(t) => Ok(t),
        Value::LogicalArray(logical) => tensor::logical_to_tensor(&logical)
            .map_err(|err| gpu_array_error(format!("gpuArray: {err}"))),
        Value::Bool(flag) => Tensor::new(vec![if flag { 1.0 } else { 0.0 }], vec![1, 1])
            .map_err(|err| gpu_array_error(format!("gpuArray: {err}"))),
        Value::Num(n) => Tensor::new(vec![n], vec![1, 1])
            .map_err(|err| gpu_array_error(format!("gpuArray: {err}"))),
        Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1])
            .map_err(|err| gpu_array_error(format!("gpuArray: {err}"))),
        Value::CharArray(ca) => char_array_to_tensor(&ca),
        Value::String(text) => {
            // A string scalar is treated as a single-row char array.
            let ca = CharArray::new_row(&text);
            char_array_to_tensor(&ca)
        }
        Value::StringArray(_) => Err(gpu_array_error(
            "gpuArray: string arrays are not supported yet; convert to char arrays with CHAR first",
        )),
        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(gpu_array_error(
            "gpuArray: complex inputs are not supported yet; split real and imaginary parts before uploading",
        )),
        other => Err(gpu_array_error(format!(
            "gpuArray: unsupported input type for GPU transfer: {other:?}"
        ))),
    }
}
451
452fn cast_tensor(mut tensor: Tensor, dtype: DataClass) -> BuiltinResult<(Tensor, bool)> {
453 let logical = match dtype {
454 DataClass::Logical => {
455 convert_to_logical(&mut tensor.data)?;
456 true
457 }
458 DataClass::Single => {
459 convert_to_single(&mut tensor.data);
460 false
461 }
462 DataClass::Int8 => {
463 convert_to_int_range(&mut tensor.data, i8::MIN as f64, i8::MAX as f64);
464 false
465 }
466 DataClass::Int16 => {
467 convert_to_int_range(&mut tensor.data, i16::MIN as f64, i16::MAX as f64);
468 false
469 }
470 DataClass::Int32 => {
471 convert_to_int_range(&mut tensor.data, i32::MIN as f64, i32::MAX as f64);
472 false
473 }
474 DataClass::Int64 => {
475 convert_to_int_range(&mut tensor.data, i64::MIN as f64, i64::MAX as f64);
476 false
477 }
478 DataClass::UInt8 => {
479 convert_to_int_range(&mut tensor.data, 0.0, u8::MAX as f64);
480 false
481 }
482 DataClass::UInt16 => {
483 convert_to_int_range(&mut tensor.data, 0.0, u16::MAX as f64);
484 false
485 }
486 DataClass::UInt32 => {
487 convert_to_int_range(&mut tensor.data, 0.0, u32::MAX as f64);
488 false
489 }
490 DataClass::UInt64 => {
491 convert_to_int_range(&mut tensor.data, 0.0, u64::MAX as f64);
492 false
493 }
494 DataClass::Double => false,
495 };
496
497 Ok((tensor, logical))
498}
499
500fn convert_to_logical(data: &mut [f64]) -> BuiltinResult<()> {
501 for value in data.iter_mut() {
502 if value.is_nan() {
503 return Err(gpu_array_error("gpuArray: cannot convert NaN to logical"));
504 }
505 *value = if *value != 0.0 { 1.0 } else { 0.0 };
506 }
507 Ok(())
508}
509
/// Rounds every element through `f32` precision in place.
fn convert_to_single(data: &mut [f64]) {
    data.iter_mut().for_each(|slot| *slot = f64::from(*slot as f32));
}
515
/// Saturating integer cast in place: NaN maps to `min`, infinities
/// saturate to the matching bound, finite values round half-away-from-zero
/// and clamp into `[min, max]`.
fn convert_to_int_range(data: &mut [f64], min: f64, max: f64) {
    for slot in data.iter_mut() {
        *slot = if slot.is_nan() {
            min
        } else if slot.is_infinite() {
            if slot.is_sign_negative() {
                min
            } else {
                max
            }
        } else {
            slot.round().clamp(min, max)
        };
    }
}
530
531fn apply_dims(handle: &mut GpuTensorHandle, dims: &[usize]) -> BuiltinResult<()> {
532 let new_elems: usize = dims.iter().product();
533 let current_elems: usize = if handle.shape.is_empty() {
534 new_elems
535 } else {
536 handle.shape.iter().product()
537 };
538 if new_elems != current_elems {
539 return Err(gpu_array_error(format!(
540 "gpuArray: cannot reshape gpuArray of {current_elems} elements into size {:?}",
541 dims
542 )));
543 }
544 handle.shape = dims.to_vec();
545 Ok(())
546}
547
548fn char_array_to_tensor(ca: &CharArray) -> BuiltinResult<Tensor> {
549 let rows = ca.rows;
550 let cols = ca.cols;
551 if rows == 0 || cols == 0 {
552 return Tensor::new(Vec::new(), vec![rows, cols])
553 .map_err(|err| gpu_array_error(format!("gpuArray: {err}")));
554 }
555 let mut data = vec![0.0; rows * cols];
556 for row in 0..rows {
558 for col in 0..cols {
559 let idx_char = row * cols + col;
560 let ch = ca.data[idx_char];
561 data[row * cols + col] = ch as u32 as f64;
562 }
563 }
564 Tensor::new(data, vec![rows, cols]).map_err(|err| gpu_array_error(format!("gpuArray: {err}")))
565}
566
567#[cfg(test)]
568pub(crate) mod tests {
569 use super::*;
570 use crate::builtins::common::test_support;
571 use futures::executor::block_on;
572 use runmat_accelerate_api::HostTensorView;
573 use runmat_builtins::{IntValue, LogicalArray, ResolveContext, Type};
574
    // Drives the async builtin synchronously for the tests below.
    fn call(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(gpu_array_builtin(value, rest))
    }
578
    // Round-trip: uploading a numeric tensor preserves shape and data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_transfers_numeric_tensor() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
            let result = call(Value::Tensor(tensor.clone()), Vec::new()).expect("gpuArray upload");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(handle.shape, tensor.shape);
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather values");
            assert_eq!(gathered.shape, tensor.shape);
            assert_eq!(gathered.data, tensor.data);
        });
    }
595
    // Logical arrays default to the logical class and keep 0/1 values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_marks_logical_inputs() {
        test_support::with_test_provider(|_| {
            let logical =
                LogicalArray::new(vec![1, 0, 1, 1], vec![2, 2]).expect("logical construction");
            let result =
                call(Value::LogicalArray(logical.clone()), Vec::new()).expect("gpuArray logical");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather logical");
            assert_eq!(gathered.shape, logical.shape);
            assert_eq!(gathered.data, vec![1.0, 0.0, 1.0, 1.0]);
        });
    }
614
    // A scalar bool uploads as a 1x1 logical tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_handles_scalar_bool() {
        test_support::with_test_provider(|_| {
            let result = call(Value::Bool(true), Vec::new()).expect("gpuArray bool");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather bool");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data, vec![1.0]);
        });
    }
630
    // Char arrays upload as their code points and can be reconstructed
    // from the gathered data in linear order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_supports_char_arrays() {
        test_support::with_test_provider(|_| {
            let chars = CharArray::new("row1row2".chars().collect(), 2, 4).unwrap();
            let original: Vec<char> = chars.data.clone();
            let result =
                call(Value::CharArray(chars), Vec::new()).expect("gpuArray char array upload");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather chars");
            assert_eq!(gathered.shape, vec![2, 4]);
            let mut recovered = Vec::new();
            for col in 0..4 {
                for row in 0..2 {
                    let idx = row + col * 2;
                    let code = gathered.data[idx];
                    let ch = char::from_u32(code as u32)
                        .expect("valid unicode scalar from numeric code");
                    recovered.push(ch);
                }
            }
            assert_eq!(recovered, original);
        });
    }
658
    // A MATLAB string scalar uploads like a single-row char array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_converts_strings() {
        test_support::with_test_provider(|_| {
            let result = call(Value::String("gpu".into()), Vec::new()).expect("gpuArray string");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather string");
            assert_eq!(gathered.shape, vec![1, 3]);
            let expected: Vec<f64> = "gpu".chars().map(|ch| ch as u32 as f64).collect();
            assert_eq!(gathered.data, expected);
        });
    }
674
    // gpuArray of an existing handle with no dtype change reuses the
    // buffer (double fast path in convert_device_value).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_passthrough_existing_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![5.0, 6.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let cloned = handle.clone();
            let result =
                call(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpuArray passthrough");
            let Value::GpuTensor(returned) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(returned.buffer_id, cloned.buffer_id);
            assert_eq!(returned.shape, cloned.shape);
        });
    }
695
    // 'int32' rounds half-away-from-zero before upload.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_casts_to_int32() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.2, -3.7, 123456.0], vec![3, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("int32")]).expect("gpuArray int32");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather int32");
            assert_eq!(gathered.data, vec![1.0, -4.0, 123456.0]);
        });
    }
711
    // 'uint8' saturates: negatives clamp to 0, overflow and +inf to 255.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_casts_to_uint8() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![-12.0, 12.8, 300.4, f64::INFINITY], vec![4, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("uint8")]).expect("gpuArray uint8");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather uint8");
            assert_eq!(gathered.data, vec![0.0, 13.0, 255.0, 255.0]);
        });
    }
727
    // 'single' rounds values through f32 precision before upload.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_single_precision_rounds() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.23456789, -9.87654321], vec![2, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("single")]).expect("gpuArray single");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather single");
            let expected = [1.234_567_9_f32 as f64, (-9.876_543_f32) as f64];
            for (observed, expected) in gathered.data.iter().zip(expected.iter()) {
                assert!((observed - expected).abs() < 1e-6);
            }
        });
    }
746
    // A logical 'like' prototype forces logical conversion of the input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_like_infers_logical() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![0.0, 2.0, -3.0], vec![3, 1]).unwrap();
            let logical_proto =
                LogicalArray::new(vec![0, 1, 0], vec![3, 1]).expect("logical proto");
            let result = call(
                Value::Tensor(tensor),
                vec![Value::from("like"), Value::LogicalArray(logical_proto)],
            )
            .expect("gpuArray like logical");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered = test_support::gather(Value::GpuTensor(handle.clone())).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
        });
    }
767
    // 'like' with no following prototype is an error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_like_requires_argument() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let err = call(Value::Tensor(tensor), vec![Value::from("like")])
                .unwrap_err()
                .to_string();
            assert!(err.contains("expected a prototype value"));
        });
    }
779
    // Unrecognised string options are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_unknown_option_errors() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let err = call(Value::Tensor(tensor), vec![Value::from("mystery")])
                .unwrap_err()
                .to_string();
            assert!(err.contains("unrecognised option"));
        });
    }
791
    // Converting a non-logical device handle to logical gathers, casts,
    // and uploads into a fresh logical handle.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_to_logical_reuploads() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![2.0, 0.0, -5.5], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("logical")],
            )
            .expect("gpuArray logical cast");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&new_handle));
            let gathered =
                test_support::gather(Value::GpuTensor(new_handle.clone())).expect("gather");
            assert_eq!(gathered.data, vec![1.0, 0.0, 1.0]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }
818
    // Requesting 'double' on a logical device handle clears the flag.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_logical_to_double_clears_flag() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("double")],
            )
            .expect("gpuArray double cast");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(!runmat_accelerate_api::handle_is_logical(&new_handle));
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }
843
    // Leading scalar size arguments reshape the uploaded tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_applies_size_arguments() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let result = call(
                Value::Tensor(tensor),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .expect("gpuArray reshape");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(handle.shape, vec![2, 2]);
        });
    }
860
    // Size arguments also apply when the input is already on the device.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_size_arguments_update_shape() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .expect("gpuArray gpu reshape");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(new_handle.shape, vec![2, 2]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }
884
    // Size arguments with a mismatched element count are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_size_mismatch_errors() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let err = call(
                Value::Tensor(tensor),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .unwrap_err()
            .to_string();
            assert!(err.contains("cannot reshape"));
        });
    }
899
    // End-to-end round-trip against a real wgpu provider when available;
    // skips (with a warning) when no adapter can be created.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gpu_array_wgpu_roundtrip() {
        use runmat_accelerate_api::AccelProvider;

        match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) {
            Ok(provider) => {
                let tensor = Tensor::new(vec![1.0, 2.5, 3.5], vec![3, 1]).unwrap();
                let result = call(Value::Tensor(tensor.clone()), vec![Value::from("int32")])
                    .expect("wgpu upload");
                let Value::GpuTensor(handle) = result else {
                    panic!("expected gpu tensor");
                };
                let gathered =
                    test_support::gather(Value::GpuTensor(handle.clone())).expect("wgpu gather");
                assert_eq!(gathered.shape, vec![3, 1]);
                assert_eq!(gathered.data, vec![1.0, 3.0, 4.0]);
                provider.free(&handle).ok();
            }
            Err(err) => {
                tracing::warn!("Skipping gpu_array_wgpu_roundtrip: {err}");
            }
        }
        // Restore the in-process provider for subsequent tests.
        runmat_accelerate::simple_provider::register_inprocess_provider();
    }
928
    // Integer scalars upload as 1x1 double tensors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_accepts_int_scalars() {
        test_support::with_test_provider(|_| {
            let value = Value::Int(IntValue::I32(7));
            let result = call(value, Vec::new()).expect("gpuArray int");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather int");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data, vec![7.0]);
        });
    }
944
    // The type resolver propagates logical-ness through gpuArray.
    #[test]
    fn gpuarray_type_for_logical_is_logical() {
        assert_eq!(
            gpuarray_type(&[Type::logical()], &ResolveContext::new(Vec::new())),
            Type::logical()
        );
    }
952}