Skip to main content

runmat_runtime/builtins/acceleration/gpu/
gpuarray.rs

//! MATLAB-compatible `gpuArray` builtin that uploads host data to the active accelerator.
//!
//! The implementation mirrors MATLAB semantics and supports optional size
//! arguments, `'like'` prototypes, and explicit class (dtype) tags. When no
//! acceleration provider is registered, the builtin raises a MATLAB-style error
//! so callers know device residency could not be established.
7
8use crate::builtins::acceleration::gpu::type_resolvers::gpuarray_type;
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::builtins::common::{gpu_helpers, tensor};
14use runmat_accelerate_api::{GpuTensorHandle, HostTensorView, ProviderPrecision};
15use runmat_builtins::{
16    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
17    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
18    CharArray, IntValue, Tensor, Value,
19};
20use runmat_macros::runtime_builtin;
21
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
24const BUILTIN_NAME: &str = "gpuArray";
25
26const GPUARRAY_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "G",
28    ty: BuiltinParamType::Any,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "GPU-resident handle containing uploaded/converted data.",
32}];
33
34const GPUARRAY_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
35    name: "X",
36    ty: BuiltinParamType::Any,
37    arity: BuiltinParamArity::Required,
38    default: None,
39    description: "Input value to upload or recast on GPU.",
40}];
41
42const GPUARRAY_INPUTS_DIMS: [BuiltinParamDescriptor; 2] = [
43    BuiltinParamDescriptor {
44        name: "X",
45        ty: BuiltinParamType::Any,
46        arity: BuiltinParamArity::Required,
47        default: None,
48        description: "Input value to upload or recast on GPU.",
49    },
50    BuiltinParamDescriptor {
51        name: "dim",
52        ty: BuiltinParamType::SizeArg,
53        arity: BuiltinParamArity::Variadic,
54        default: None,
55        description: "Reshape dimensions (scalar dims or a single size vector tensor).",
56    },
57];
58
59const GPUARRAY_INPUTS_DTYPE: [BuiltinParamDescriptor; 2] = [
60    BuiltinParamDescriptor {
61        name: "X",
62        ty: BuiltinParamType::Any,
63        arity: BuiltinParamArity::Required,
64        default: None,
65        description: "Input value to upload or recast on GPU.",
66    },
67    BuiltinParamDescriptor {
68        name: "dtype",
69        ty: BuiltinParamType::StringScalar,
70        arity: BuiltinParamArity::Required,
71        default: Some("\"double\""),
72        description: "Class tag such as `single`, `int32`, `uint8`, `logical`, or `double`.",
73    },
74];
75
76const GPUARRAY_INPUTS_LIKE: [BuiltinParamDescriptor; 3] = [
77    BuiltinParamDescriptor {
78        name: "X",
79        ty: BuiltinParamType::Any,
80        arity: BuiltinParamArity::Required,
81        default: None,
82        description: "Input value to upload or recast on GPU.",
83    },
84    BuiltinParamDescriptor {
85        name: "like",
86        ty: BuiltinParamType::StringScalar,
87        arity: BuiltinParamArity::Required,
88        default: None,
89        description: "Literal keyword `\"like\"`.",
90    },
91    BuiltinParamDescriptor {
92        name: "prototype",
93        ty: BuiltinParamType::LikePrototype,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Prototype value whose class drives output conversion.",
97    },
98];
99
100const GPUARRAY_INPUTS_DIMS_OPTIONS: [BuiltinParamDescriptor; 3] = [
101    BuiltinParamDescriptor {
102        name: "X",
103        ty: BuiltinParamType::Any,
104        arity: BuiltinParamArity::Required,
105        default: None,
106        description: "Input value to upload or recast on GPU.",
107    },
108    BuiltinParamDescriptor {
109        name: "dim",
110        ty: BuiltinParamType::SizeArg,
111        arity: BuiltinParamArity::Variadic,
112        default: None,
113        description: "Reshape dimensions (scalar dims or a single size vector tensor).",
114    },
115    BuiltinParamDescriptor {
116        name: "option",
117        ty: BuiltinParamType::Any,
118        arity: BuiltinParamArity::Variadic,
119        default: None,
120        description: "Class tags and/or `\"like\", prototype` qualifiers.",
121    },
122];
123
124const GPUARRAY_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
125    BuiltinSignatureDescriptor {
126        label: "G = gpuArray(X)",
127        inputs: &GPUARRAY_INPUTS_BASE,
128        outputs: &GPUARRAY_OUTPUT,
129    },
130    BuiltinSignatureDescriptor {
131        label: "G = gpuArray(X, dim, ...)",
132        inputs: &GPUARRAY_INPUTS_DIMS,
133        outputs: &GPUARRAY_OUTPUT,
134    },
135    BuiltinSignatureDescriptor {
136        label: "G = gpuArray(X, dtype)",
137        inputs: &GPUARRAY_INPUTS_DTYPE,
138        outputs: &GPUARRAY_OUTPUT,
139    },
140    BuiltinSignatureDescriptor {
141        label: "G = gpuArray(X, \"like\", prototype)",
142        inputs: &GPUARRAY_INPUTS_LIKE,
143        outputs: &GPUARRAY_OUTPUT,
144    },
145    BuiltinSignatureDescriptor {
146        label: "G = gpuArray(X, dim, ..., option, ...)",
147        inputs: &GPUARRAY_INPUTS_DIMS_OPTIONS,
148        outputs: &GPUARRAY_OUTPUT,
149    },
150];
151
152const GPUARRAY_ERROR_NO_PROVIDER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
153    code: "RM.GPUARRAY.NO_PROVIDER",
154    identifier: Some("RunMat:gpuArray:NoProvider"),
155    when: "No acceleration provider is registered for host/device transfers.",
156    message: "gpuArray: no acceleration provider registered",
157};
158
159const GPUARRAY_ERROR_OPTION_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
160    code: "RM.GPUARRAY.OPTION_ARGUMENT",
161    identifier: Some("RunMat:gpuArray:OptionArgument"),
162    when: "Option tail contains non-text values where class tags/keywords are expected.",
163    message: "gpuArray: invalid option argument",
164};
165
166const GPUARRAY_ERROR_LIKE_MISSING: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
167    code: "RM.GPUARRAY.LIKE_MISSING",
168    identifier: Some("RunMat:gpuArray:LikeMissingPrototype"),
169    when: "Keyword `like` is supplied without a following prototype value.",
170    message: "gpuArray: expected a prototype value after 'like'",
171};
172
173const GPUARRAY_ERROR_LIKE_DUPLICATE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
174    code: "RM.GPUARRAY.LIKE_DUPLICATE",
175    identifier: Some("RunMat:gpuArray:LikeDuplicate"),
176    when: "Keyword `like` appears more than once.",
177    message: "gpuArray: duplicate 'like' qualifier",
178};
179
180const GPUARRAY_ERROR_CODISTRIBUTED_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
181    code: "RM.GPUARRAY.CODISTRIBUTED_UNSUPPORTED",
182    identifier: Some("RunMat:gpuArray:CodistributedUnsupported"),
183    when: "Distributed/codistributed qualifiers are requested.",
184    message: "gpuArray: codistributed arrays are not supported yet",
185};
186
187const GPUARRAY_ERROR_CONFLICTING_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
188    code: "RM.GPUARRAY.CONFLICTING_TYPE",
189    identifier: Some("RunMat:gpuArray:ConflictingTypeQualifiers"),
190    when: "Multiple incompatible class qualifiers are supplied.",
191    message: "gpuArray: conflicting type qualifiers supplied",
192};
193
194const GPUARRAY_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
195    code: "RM.GPUARRAY.UNKNOWN_OPTION",
196    identifier: Some("RunMat:gpuArray:UnknownOption"),
197    when: "Text option is not a recognized class/keyword token.",
198    message: "gpuArray: unrecognised option",
199};
200
201const GPUARRAY_ERROR_SIZE_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
202    code: "RM.GPUARRAY.SIZE_ARGUMENT",
203    identifier: Some("RunMat:gpuArray:InvalidSizeArgument"),
204    when: "Size arguments are malformed (not finite integers, negative, or invalid combinations).",
205    message: "gpuArray: invalid size argument",
206};
207
208const GPUARRAY_ERROR_LIKE_PROTOTYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
209    code: "RM.GPUARRAY.LIKE_PROTOTYPE",
210    identifier: Some("RunMat:gpuArray:InvalidLikePrototype"),
211    when: "`like` prototype is unsupported for type inference.",
212    message: "gpuArray: invalid 'like' prototype",
213};
214
215const GPUARRAY_ERROR_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
216    code: "RM.GPUARRAY.INPUT_TYPE",
217    identifier: Some("RunMat:gpuArray:UnsupportedInputType"),
218    when: "Input value type cannot be uploaded/coerced to supported gpuArray storage.",
219    message: "gpuArray: unsupported input type",
220};
221
222const GPUARRAY_ERROR_CONVERSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
223    code: "RM.GPUARRAY.CONVERSION",
224    identifier: Some("RunMat:gpuArray:ConversionFailed"),
225    when: "Requested dtype conversion cannot be performed (for example NaN->logical).",
226    message: "gpuArray: conversion failed",
227};
228
229const GPUARRAY_ERROR_RESHAPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
230    code: "RM.GPUARRAY.RESHAPE",
231    identifier: Some("RunMat:gpuArray:ReshapeMismatch"),
232    when: "Requested shape does not preserve the element count.",
233    message: "gpuArray: cannot reshape gpuArray into requested size",
234};
235
236const GPUARRAY_ERROR_PROVIDER_IO: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
237    code: "RM.GPUARRAY.PROVIDER_IO",
238    identifier: Some("RunMat:gpuArray:ProviderIO"),
239    when: "Provider upload/download interaction fails.",
240    message: "gpuArray: provider I/O failed",
241};
242
243const GPUARRAY_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
244    code: "RM.GPUARRAY.INTERNAL",
245    identifier: Some("RunMat:gpuArray:InternalError"),
246    when: "Internal tensor/container conversion fails.",
247    message: "gpuArray: internal error",
248};
249
250const GPUARRAY_ERRORS: [BuiltinErrorDescriptor; 14] = [
251    GPUARRAY_ERROR_NO_PROVIDER,
252    GPUARRAY_ERROR_OPTION_ARGUMENT,
253    GPUARRAY_ERROR_LIKE_MISSING,
254    GPUARRAY_ERROR_LIKE_DUPLICATE,
255    GPUARRAY_ERROR_CODISTRIBUTED_UNSUPPORTED,
256    GPUARRAY_ERROR_CONFLICTING_TYPE,
257    GPUARRAY_ERROR_UNKNOWN_OPTION,
258    GPUARRAY_ERROR_SIZE_ARGUMENT,
259    GPUARRAY_ERROR_LIKE_PROTOTYPE,
260    GPUARRAY_ERROR_INPUT_TYPE,
261    GPUARRAY_ERROR_CONVERSION,
262    GPUARRAY_ERROR_RESHAPE,
263    GPUARRAY_ERROR_PROVIDER_IO,
264    GPUARRAY_ERROR_INTERNAL,
265];
266
267pub const GPUARRAY_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
268    signatures: &GPUARRAY_SIGNATURES,
269    output_mode: BuiltinOutputMode::Fixed,
270    completion_policy: BuiltinCompletionPolicy::Public,
271    errors: &GPUARRAY_ERRORS,
272};
273
274fn gpu_array_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
275    gpu_array_error_with_message(error.message, error)
276}
277
278fn gpu_array_error_with_message(
279    message: impl Into<String>,
280    error: &'static BuiltinErrorDescriptor,
281) -> RuntimeError {
282    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
283    if let Some(identifier) = error.identifier {
284        builder = builder.with_identifier(identifier);
285    }
286    builder.build()
287}
288
289fn gpu_array_error_with_detail(
290    error: &'static BuiltinErrorDescriptor,
291    detail: impl AsRef<str>,
292) -> RuntimeError {
293    gpu_array_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
294}
295
296#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::gpuarray")]
297pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
298    name: "gpuArray",
299    op_kind: GpuOpKind::Custom("upload"),
300    supported_precisions: &[ScalarType::F32, ScalarType::F64],
301    broadcast: BroadcastSemantics::None,
302    provider_hooks: &[ProviderHook::Custom("upload")],
303    constant_strategy: ConstantStrategy::InlineLiteral,
304    residency: ResidencyPolicy::NewHandle,
305    nan_mode: ReductionNaN::Include,
306    two_pass_threshold: None,
307    workgroup_size: None,
308    accepts_nan_mode: false,
309    notes: "Invokes the provider `upload` hook, reuploading gpuArray inputs when dtype conversion is requested. Handles class strings, size vectors, and `'like'` prototypes.",
310};
311
312#[runmat_macros::register_fusion_spec(
313    builtin_path = "crate::builtins::acceleration::gpu::gpuarray"
314)]
315pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
316    name: "gpuArray",
317    shape: ShapeRequirements::Any,
318    constant_strategy: ConstantStrategy::InlineLiteral,
319    elementwise: None,
320    reduction: None,
321    emits_nan: false,
322    notes:
323        "Acts as a residency boundary; fusion graphs never cross explicit host↔device transfers.",
324};
325
326#[runtime_builtin(
327    name = "gpuArray",
328    category = "acceleration/gpu",
329    summary = "Move data to the GPU as gpuArray values.",
330    keywords = "gpuArray,gpu,accelerate,upload,dtype,like",
331    examples = "G = gpuArray([1 2 3], 'single');",
332    accel = "array_construct",
333    type_resolver(gpuarray_type),
334    descriptor(crate::builtins::acceleration::gpu::gpuarray::GPUARRAY_DESCRIPTOR),
335    builtin_path = "crate::builtins::acceleration::gpu::gpuarray"
336)]
337async fn gpu_array_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
338    let options = parse_options(&rest)?;
339    let incoming_precision = match &value {
340        Value::GpuTensor(handle) => runmat_accelerate_api::handle_precision(handle),
341        _ => None,
342    };
343    let dtype = resolve_dtype(&value, &options)?;
344    let dims = options.dims.clone();
345
346    let prepared = match value {
347        Value::GpuTensor(handle) => convert_device_value(handle, dtype).await?,
348        other => upload_host_value(other, dtype)?,
349    };
350
351    let mut handle = prepared.handle;
352
353    if let Some(dims) = dims.as_ref() {
354        apply_dims(&mut handle, dims)?;
355    }
356
357    let provider_precision = runmat_accelerate_api::provider()
358        .map(|p| p.precision())
359        .unwrap_or(ProviderPrecision::F64);
360    let requested_precision = match dtype {
361        DataClass::Single => Some(ProviderPrecision::F32),
362        _ => None,
363    };
364    let final_precision = requested_precision
365        .or(incoming_precision)
366        .unwrap_or(provider_precision);
367    runmat_accelerate_api::set_handle_precision(&handle, final_precision);
368
369    runmat_accelerate_api::set_handle_logical(&handle, prepared.logical);
370
371    Ok(Value::GpuTensor(handle))
372}
373
/// Element classes that `gpuArray` can produce or convert to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DataClass {
    Double,
    Single,
    Logical,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
}

impl DataClass {
    /// Maps a class tag to its [`DataClass`].
    ///
    /// Matching is exact and case-sensitive; callers trim and lower-case first.
    /// Unknown tags yield `None`. So does `"gpuarray"`, which option parsing
    /// treats as a compatibility no-op rather than a class.
    fn from_tag(tag: &str) -> Option<Self> {
        const TAGS: [(&str, DataClass); 15] = [
            ("double", DataClass::Double),
            ("single", DataClass::Single),
            ("float32", DataClass::Single),
            ("logical", DataClass::Logical),
            ("bool", DataClass::Logical),
            ("boolean", DataClass::Logical),
            ("int8", DataClass::Int8),
            ("int16", DataClass::Int16),
            ("int32", DataClass::Int32),
            ("int", DataClass::Int32),
            ("int64", DataClass::Int64),
            ("uint8", DataClass::UInt8),
            ("uint16", DataClass::UInt16),
            ("uint32", DataClass::UInt32),
            ("uint64", DataClass::UInt64),
        ];
        TAGS.iter()
            .find(|(name, _)| *name == tag)
            .map(|&(_, class)| class)
    }
}
408
409#[derive(Debug, Default)]
410struct ParsedOptions {
411    dims: Option<Vec<usize>>,
412    explicit_dtype: Option<DataClass>,
413    prototype: Option<Value>,
414}
415
416fn parse_options(rest: &[Value]) -> BuiltinResult<ParsedOptions> {
417    let (index_after_dims, dims) = parse_size_arguments(rest)?;
418    let mut options = ParsedOptions {
419        dims,
420        ..ParsedOptions::default()
421    };
422
423    let mut idx = index_after_dims;
424    while idx < rest.len() {
425        let tag = value_to_lower_string(&rest[idx]).ok_or_else(|| {
426            gpu_array_error_with_message(
427                format!(
428                "gpuArray: unexpected argument {:?}; expected a class string or the keyword 'like'",
429                rest[idx]
430                ),
431                &GPUARRAY_ERROR_OPTION_ARGUMENT,
432            )
433        })?;
434
435        match tag.as_str() {
436            "like" => {
437                idx += 1;
438                if idx >= rest.len() {
439                    return Err(gpu_array_error(&GPUARRAY_ERROR_LIKE_MISSING));
440                }
441                if options.prototype.is_some() {
442                    return Err(gpu_array_error(&GPUARRAY_ERROR_LIKE_DUPLICATE));
443                }
444                options.prototype = Some(rest[idx].clone());
445            }
446            "distributed" | "codistributed" => {
447                return Err(gpu_array_error(&GPUARRAY_ERROR_CODISTRIBUTED_UNSUPPORTED));
448            }
449            tag => {
450                if let Some(class) = DataClass::from_tag(tag) {
451                    if let Some(existing) = options.explicit_dtype {
452                        if existing != class {
453                            return Err(gpu_array_error(&GPUARRAY_ERROR_CONFLICTING_TYPE));
454                        }
455                    } else {
456                        options.explicit_dtype = Some(class);
457                    }
458                } else if tag != "gpuarray" {
459                    return Err(gpu_array_error_with_detail(
460                        &GPUARRAY_ERROR_UNKNOWN_OPTION,
461                        format!("unrecognised option '{tag}'"),
462                    ));
463                }
464            }
465        }
466
467        idx += 1;
468    }
469
470    Ok(options)
471}
472
473fn parse_size_arguments(rest: &[Value]) -> BuiltinResult<(usize, Option<Vec<usize>>)> {
474    let mut idx = 0;
475    let mut dims: Vec<usize> = Vec::new();
476    let mut vector_consumed = false;
477
478    while idx < rest.len() {
479        // Stop at textual qualifiers only; numeric values continue parsing as size args.
480        match &rest[idx] {
481            Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => break,
482            _ => {}
483        }
484
485        match &rest[idx] {
486            Value::Int(i) => {
487                dims.push(int_to_dim(i)?);
488            }
489            Value::Num(n) => {
490                dims.push(float_to_dim(*n)?);
491            }
492            Value::Tensor(t) => {
493                if vector_consumed || !dims.is_empty() {
494                    return Err(gpu_array_error_with_message(
495                        "gpuArray: size vectors cannot be combined with scalar dimensions",
496                        &GPUARRAY_ERROR_SIZE_ARGUMENT,
497                    ));
498                }
499                dims = tensor_to_dims(t)?;
500                vector_consumed = true;
501            }
502            _ => break,
503        }
504        idx += 1;
505    }
506
507    let dims_option = if dims.is_empty() { None } else { Some(dims) };
508    Ok((idx, dims_option))
509}
510
511fn value_to_lower_string(value: &Value) -> Option<String> {
512    crate::builtins::common::tensor::value_to_string(value).map(|s| s.trim().to_ascii_lowercase())
513}
514
515fn int_to_dim(value: &IntValue) -> BuiltinResult<usize> {
516    let raw = value.to_i64();
517    if raw < 0 {
518        return Err(gpu_array_error_with_message(
519            "gpuArray: size arguments must be non-negative integers",
520            &GPUARRAY_ERROR_SIZE_ARGUMENT,
521        ));
522    }
523    Ok(raw as usize)
524}
525
526fn float_to_dim(value: f64) -> BuiltinResult<usize> {
527    if !value.is_finite() {
528        return Err(gpu_array_error_with_message(
529            "gpuArray: size arguments must be finite integers",
530            &GPUARRAY_ERROR_SIZE_ARGUMENT,
531        ));
532    }
533    let rounded = value.round();
534    if (rounded - value).abs() > f64::EPSILON {
535        return Err(gpu_array_error_with_message(
536            "gpuArray: size arguments must be integers",
537            &GPUARRAY_ERROR_SIZE_ARGUMENT,
538        ));
539    }
540    if rounded < 0.0 {
541        return Err(gpu_array_error_with_message(
542            "gpuArray: size arguments must be non-negative",
543            &GPUARRAY_ERROR_SIZE_ARGUMENT,
544        ));
545    }
546    Ok(rounded as usize)
547}
548
549fn tensor_to_dims(tensor: &Tensor) -> BuiltinResult<Vec<usize>> {
550    let mut dims = Vec::with_capacity(tensor.data.len());
551    for value in &tensor.data {
552        dims.push(float_to_dim(*value)?);
553    }
554    Ok(dims)
555}
556
557fn resolve_dtype(value: &Value, options: &ParsedOptions) -> BuiltinResult<DataClass> {
558    if let Some(explicit) = options.explicit_dtype {
559        return Ok(explicit);
560    }
561    if let Some(prototype) = options.prototype.as_ref() {
562        return infer_dtype_from_prototype(prototype);
563    }
564    if value_defaults_to_logical(value) {
565        return Ok(DataClass::Logical);
566    }
567    Ok(DataClass::Double)
568}
569
570fn infer_dtype_from_prototype(proto: &Value) -> BuiltinResult<DataClass> {
571    match proto {
572        Value::GpuTensor(handle) => {
573            if runmat_accelerate_api::handle_is_logical(handle) {
574                Ok(DataClass::Logical)
575            } else {
576                Ok(DataClass::Double)
577            }
578        }
579        Value::LogicalArray(_) | Value::Bool(_) => Ok(DataClass::Logical),
580        Value::Int(int) => Ok(match int {
581            IntValue::I8(_) => DataClass::Int8,
582            IntValue::I16(_) => DataClass::Int16,
583            IntValue::I32(_) => DataClass::Int32,
584            IntValue::I64(_) => DataClass::Int64,
585            IntValue::U8(_) => DataClass::UInt8,
586            IntValue::U16(_) => DataClass::UInt16,
587            IntValue::U32(_) => DataClass::UInt32,
588            IntValue::U64(_) => DataClass::UInt64,
589        }),
590        Value::Tensor(_) | Value::Num(_) => Ok(DataClass::Double),
591        Value::CharArray(_) => Ok(DataClass::Double),
592        Value::String(_) => Err(gpu_array_error_with_message(
593            "gpuArray: 'like' does not accept MATLAB string scalars; convert to char() first",
594            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
595        )),
596        Value::StringArray(_) => Err(gpu_array_error_with_message(
597            "gpuArray: 'like' does not accept string arrays; convert to char arrays first",
598            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
599        )),
600        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(gpu_array_error_with_message(
601            "gpuArray: complex prototypes are not supported yet; provide real-valued inputs",
602            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
603        )),
604        other => Err(gpu_array_error_with_message(
605            format!(
606                "gpuArray: unsupported 'like' prototype type {other:?}; expected numeric or logical values"
607            ),
608            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
609        )),
610    }
611}
612
613fn value_defaults_to_logical(value: &Value) -> bool {
614    match value {
615        Value::LogicalArray(_) | Value::Bool(_) => true,
616        Value::GpuTensor(handle) => runmat_accelerate_api::handle_is_logical(handle),
617        _ => false,
618    }
619}
620
621struct PreparedHandle {
622    handle: GpuTensorHandle,
623    logical: bool,
624}
625
626fn upload_host_value(value: Value, dtype: DataClass) -> BuiltinResult<PreparedHandle> {
627    #[cfg(all(test, feature = "wgpu"))]
628    {
629        if runmat_accelerate_api::provider().is_none() {
630            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
631                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
632            );
633        }
634    }
635    let provider = runmat_accelerate_api::provider()
636        .ok_or_else(|| gpu_array_error(&GPUARRAY_ERROR_NO_PROVIDER))?;
637
638    let tensor = coerce_host_value(value)?;
639    let (mut tensor, logical) = cast_tensor(tensor, dtype)?;
640
641    let view = HostTensorView {
642        data: &tensor.data,
643        shape: &tensor.shape,
644    };
645    let new_handle = provider.upload(&view).map_err(|err| {
646        gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_PROVIDER_IO)
647    })?;
648
649    tensor.data.clear();
650
651    Ok(PreparedHandle {
652        handle: new_handle,
653        logical,
654    })
655}
656
657async fn convert_device_value(
658    handle: GpuTensorHandle,
659    dtype: DataClass,
660) -> BuiltinResult<PreparedHandle> {
661    let was_logical = runmat_accelerate_api::handle_is_logical(&handle);
662    match dtype {
663        DataClass::Double => {
664            return Ok(PreparedHandle {
665                handle,
666                logical: false,
667            });
668        }
669        DataClass::Logical => {
670            if was_logical {
671                return Ok(PreparedHandle {
672                    handle,
673                    logical: true,
674                });
675            }
676        }
677        _ => {}
678    }
679
680    let provider = runmat_accelerate_api::provider()
681        .ok_or_else(|| gpu_array_error(&GPUARRAY_ERROR_NO_PROVIDER))?;
682    let tensor = gpu_helpers::gather_tensor_async(&handle)
683        .await
684        .map_err(|err| {
685            gpu_array_error_with_message(err.to_string(), &GPUARRAY_ERROR_PROVIDER_IO)
686        })?;
687    let (mut tensor, logical) = cast_tensor(tensor, dtype)?;
688
689    let view = HostTensorView {
690        data: &tensor.data,
691        shape: &tensor.shape,
692    };
693    let new_handle = provider.upload(&view).map_err(|err| {
694        gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_PROVIDER_IO)
695    })?;
696
697    provider.free(&handle).ok();
698    tensor.data.clear();
699
700    Ok(PreparedHandle {
701        handle: new_handle,
702        logical,
703    })
704}
705
706fn coerce_host_value(value: Value) -> BuiltinResult<Tensor> {
707    match value {
708        Value::Tensor(t) => Ok(t),
709        Value::LogicalArray(logical) => tensor::logical_to_tensor(&logical)
710            .map_err(|err| {
711                gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
712            }),
713        Value::Bool(flag) => Tensor::new(vec![if flag { 1.0 } else { 0.0 }], vec![1, 1])
714            .map_err(|err| {
715                gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
716            }),
717        Value::Num(n) => Tensor::new(vec![n], vec![1, 1])
718            .map_err(|err| {
719                gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
720            }),
721        Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1])
722            .map_err(|err| {
723                gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
724            }),
725        Value::CharArray(ca) => char_array_to_tensor(&ca),
726        Value::String(text) => {
727            let ca = CharArray::new_row(&text);
728            char_array_to_tensor(&ca)
729        }
730        Value::StringArray(_) => Err(gpu_array_error_with_message(
731            "gpuArray: string arrays are not supported yet; convert to char arrays with CHAR first",
732            &GPUARRAY_ERROR_INPUT_TYPE,
733        )),
734        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(gpu_array_error_with_message(
735            "gpuArray: complex inputs are not supported yet; split real and imaginary parts before uploading",
736            &GPUARRAY_ERROR_INPUT_TYPE,
737        )),
738        other => Err(gpu_array_error_with_detail(
739            &GPUARRAY_ERROR_INPUT_TYPE,
740            format!("unsupported input type for GPU transfer: {other:?}"),
741        )),
742    }
743}
744
745fn cast_tensor(mut tensor: Tensor, dtype: DataClass) -> BuiltinResult<(Tensor, bool)> {
746    let logical = match dtype {
747        DataClass::Logical => {
748            convert_to_logical(&mut tensor.data)?;
749            true
750        }
751        DataClass::Single => {
752            convert_to_single(&mut tensor.data);
753            false
754        }
755        DataClass::Int8 => {
756            convert_to_int_range(&mut tensor.data, i8::MIN as f64, i8::MAX as f64);
757            false
758        }
759        DataClass::Int16 => {
760            convert_to_int_range(&mut tensor.data, i16::MIN as f64, i16::MAX as f64);
761            false
762        }
763        DataClass::Int32 => {
764            convert_to_int_range(&mut tensor.data, i32::MIN as f64, i32::MAX as f64);
765            false
766        }
767        DataClass::Int64 => {
768            convert_to_int_range(&mut tensor.data, i64::MIN as f64, i64::MAX as f64);
769            false
770        }
771        DataClass::UInt8 => {
772            convert_to_int_range(&mut tensor.data, 0.0, u8::MAX as f64);
773            false
774        }
775        DataClass::UInt16 => {
776            convert_to_int_range(&mut tensor.data, 0.0, u16::MAX as f64);
777            false
778        }
779        DataClass::UInt32 => {
780            convert_to_int_range(&mut tensor.data, 0.0, u32::MAX as f64);
781            false
782        }
783        DataClass::UInt64 => {
784            convert_to_int_range(&mut tensor.data, 0.0, u64::MAX as f64);
785            false
786        }
787        DataClass::Double => false,
788    };
789
790    Ok((tensor, logical))
791}
792
793fn convert_to_logical(data: &mut [f64]) -> BuiltinResult<()> {
794    for value in data.iter_mut() {
795        if value.is_nan() {
796            return Err(gpu_array_error_with_message(
797                "gpuArray: cannot convert NaN to logical",
798                &GPUARRAY_ERROR_CONVERSION,
799            ));
800        }
801        *value = if *value != 0.0 { 1.0 } else { 0.0 };
802    }
803    Ok(())
804}
805
/// Rounds each element to the nearest `f32` value, keeping `f64` storage.
///
/// Values beyond the `f32` range become infinities.
fn convert_to_single(data: &mut [f64]) {
    data.iter_mut()
        .for_each(|slot| *slot = f64::from(*slot as f32));
}
811
/// Converts each element to an integer value saturated to `[min, max]`.
///
/// Conversion rules, matching MATLAB's integer-class conversion:
/// - Values are rounded half away from zero (`f64::round`).
/// - Out-of-range values, including ±Inf, saturate to the nearest bound.
/// - NaN becomes 0 (clamped into the range). Previously NaN became `min`, so
///   `int8(NaN)` yielded -128.
///
/// Callers must pass `min <= max`; `f64::clamp` panics otherwise.
fn convert_to_int_range(data: &mut [f64], min: f64, max: f64) {
    for value in data.iter_mut() {
        *value = if value.is_nan() {
            0.0_f64.clamp(min, max)
        } else {
            // `clamp` also saturates ±Inf, since `round` leaves infinities unchanged.
            value.round().clamp(min, max)
        };
    }
}
826
827fn apply_dims(handle: &mut GpuTensorHandle, dims: &[usize]) -> BuiltinResult<()> {
828    let new_elems: usize = dims.iter().product();
829    let current_elems: usize = if handle.shape.is_empty() {
830        new_elems
831    } else {
832        handle.shape.iter().product()
833    };
834    if new_elems != current_elems {
835        return Err(gpu_array_error_with_message(
836            format!(
837                "gpuArray: cannot reshape gpuArray of {current_elems} elements into size {:?}",
838                dims
839            ),
840            &GPUARRAY_ERROR_RESHAPE,
841        ));
842    }
843    handle.shape = dims.to_vec();
844    Ok(())
845}
846
847fn char_array_to_tensor(ca: &CharArray) -> BuiltinResult<Tensor> {
848    let rows = ca.rows;
849    let cols = ca.cols;
850    if rows == 0 || cols == 0 {
851        return Tensor::new(Vec::new(), vec![rows, cols]).map_err(|err| {
852            gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
853        });
854    }
855    let mut data = vec![0.0; rows * cols];
856    // Store in row-major to preserve the original character order when interpreted with column-major indexing
857    for row in 0..rows {
858        for col in 0..cols {
859            let idx_char = row * cols + col;
860            let ch = ca.data[idx_char];
861            data[row * cols + col] = ch as u32 as f64;
862        }
863    }
864    Tensor::new(data, vec![rows, cols]).map_err(|err| {
865        gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
866    })
867}
868
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `gpuArray` builtin. Most tests run inside
    //! `test_support::with_test_provider`; one optional round-trip exercises a
    //! real wgpu provider behind the `wgpu` feature.

    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray, ResolveContext, Type};

    /// Runs the async builtin to completion on the current thread.
    fn call(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(gpu_array_builtin(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_transfers_numeric_tensor() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
            let result = call(Value::Tensor(tensor.clone()), Vec::new()).expect("gpuArray upload");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(handle.shape, tensor.shape);
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather values");
            assert_eq!(gathered.shape, tensor.shape);
            assert_eq!(gathered.data, tensor.data);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_marks_logical_inputs() {
        test_support::with_test_provider(|_| {
            let logical =
                LogicalArray::new(vec![1, 0, 1, 1], vec![2, 2]).expect("logical construction");
            let result =
                call(Value::LogicalArray(logical.clone()), Vec::new()).expect("gpuArray logical");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather logical");
            assert_eq!(gathered.shape, logical.shape);
            assert_eq!(gathered.data, vec![1.0, 0.0, 1.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_handles_scalar_bool() {
        test_support::with_test_provider(|_| {
            let result = call(Value::Bool(true), Vec::new()).expect("gpuArray bool");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather bool");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data, vec![1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_supports_char_arrays() {
        test_support::with_test_provider(|_| {
            let chars = CharArray::new("row1row2".chars().collect(), 2, 4).unwrap();
            let original: Vec<char> = chars.data.clone();
            let result =
                call(Value::CharArray(chars), Vec::new()).expect("gpuArray char array upload");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather chars");
            assert_eq!(gathered.shape, vec![2, 4]);
            let mut recovered = Vec::new();
            for col in 0..4 {
                for row in 0..2 {
                    let idx = row + col * 2;
                    let code = gathered.data[idx];
                    let ch = char::from_u32(code as u32)
                        .expect("valid unicode scalar from numeric code");
                    recovered.push(ch);
                }
            }
            assert_eq!(recovered, original);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_converts_strings() {
        test_support::with_test_provider(|_| {
            let result = call(Value::String("gpu".into()), Vec::new()).expect("gpuArray string");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather string");
            assert_eq!(gathered.shape, vec![1, 3]);
            let expected: Vec<f64> = "gpu".chars().map(|ch| ch as u32 as f64).collect();
            assert_eq!(gathered.data, expected);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_passthrough_existing_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![5.0, 6.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let cloned = handle.clone();
            let result =
                call(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpuArray passthrough");
            let Value::GpuTensor(returned) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(returned.buffer_id, cloned.buffer_id);
            assert_eq!(returned.shape, cloned.shape);
            // `returned` aliases the same buffer (asserted above), so one free
            // releases it, matching the cleanup done by the sibling tests.
            provider.free(&handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_casts_to_int32() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.2, -3.7, 123456.0], vec![3, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("int32")]).expect("gpuArray int32");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather int32");
            assert_eq!(gathered.data, vec![1.0, -4.0, 123456.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_casts_to_uint8() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![-12.0, 12.8, 300.4, f64::INFINITY], vec![4, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("uint8")]).expect("gpuArray uint8");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather uint8");
            assert_eq!(gathered.data, vec![0.0, 13.0, 255.0, 255.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_single_precision_rounds() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.23456789, -9.87654321], vec![2, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("single")]).expect("gpuArray single");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather single");
            let expected = [1.234_567_9_f32 as f64, (-9.876_543_f32) as f64];
            // `zip` stops at the shorter side, so check the length first or an
            // empty/short result would pass vacuously.
            assert_eq!(gathered.data.len(), expected.len());
            for (observed, expected) in gathered.data.iter().zip(expected.iter()) {
                assert!(
                    (observed - expected).abs() < 1e-6,
                    "observed {observed}, expected {expected}"
                );
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_like_infers_logical() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![0.0, 2.0, -3.0], vec![3, 1]).unwrap();
            let logical_proto =
                LogicalArray::new(vec![0, 1, 0], vec![3, 1]).expect("logical proto");
            let result = call(
                Value::Tensor(tensor),
                vec![Value::from("like"), Value::LogicalArray(logical_proto)],
            )
            .expect("gpuArray like logical");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered = test_support::gather(Value::GpuTensor(handle.clone())).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_like_requires_argument() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let err = call(Value::Tensor(tensor), vec![Value::from("like")]).unwrap_err();
            assert_eq!(err.to_string(), GPUARRAY_ERROR_LIKE_MISSING.message);
            assert_eq!(err.identifier(), GPUARRAY_ERROR_LIKE_MISSING.identifier);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_unknown_option_errors() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let err = call(Value::Tensor(tensor), vec![Value::from("mystery")]).unwrap_err();
            assert!(err
                .to_string()
                .contains(GPUARRAY_ERROR_UNKNOWN_OPTION.message));
            assert_eq!(err.identifier(), GPUARRAY_ERROR_UNKNOWN_OPTION.identifier);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_to_logical_reuploads() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![2.0, 0.0, -5.5], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("logical")],
            )
            .expect("gpuArray logical cast");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&new_handle));
            let gathered =
                test_support::gather(Value::GpuTensor(new_handle.clone())).expect("gather");
            assert_eq!(gathered.data, vec![1.0, 0.0, 1.0]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_logical_to_double_clears_flag() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("double")],
            )
            .expect("gpuArray double cast");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(!runmat_accelerate_api::handle_is_logical(&new_handle));
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_applies_size_arguments() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let result = call(
                Value::Tensor(tensor),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .expect("gpuArray reshape");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(handle.shape, vec![2, 2]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_size_arguments_update_shape() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .expect("gpuArray gpu reshape");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(new_handle.shape, vec![2, 2]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_size_mismatch_errors() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let err = call(
                Value::Tensor(tensor),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .unwrap_err();
            assert!(err.to_string().contains("cannot reshape"));
            assert_eq!(err.identifier(), GPUARRAY_ERROR_RESHAPE.identifier);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gpu_array_wgpu_roundtrip() {
        use runmat_accelerate_api::AccelProvider;

        match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) {
            Ok(provider) => {
                let tensor = Tensor::new(vec![1.0, 2.5, 3.5], vec![3, 1]).unwrap();
                let result = call(Value::Tensor(tensor.clone()), vec![Value::from("int32")])
                    .expect("wgpu upload");
                let Value::GpuTensor(handle) = result else {
                    panic!("expected gpu tensor");
                };
                let gathered =
                    test_support::gather(Value::GpuTensor(handle.clone())).expect("wgpu gather");
                assert_eq!(gathered.shape, vec![3, 1]);
                assert_eq!(gathered.data, vec![1.0, 3.0, 4.0]);
                provider.free(&handle).ok();
            }
            Err(err) => {
                tracing::warn!("Skipping gpu_array_wgpu_roundtrip: {err}");
            }
        }
        runmat_accelerate::simple_provider::register_inprocess_provider();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_accepts_int_scalars() {
        test_support::with_test_provider(|_| {
            let value = Value::Int(IntValue::I32(7));
            let result = call(value, Vec::new()).expect("gpuArray int");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather int");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data, vec![7.0]);
        });
    }

    #[test]
    fn gpuarray_type_for_logical_is_logical() {
        assert_eq!(
            gpuarray_type(&[Type::logical()], &ResolveContext::new(Vec::new())),
            Type::logical()
        );
    }
}