Skip to main content

runmat_runtime/builtins/acceleration/gpu/
gpuarray.rs

1//! MATLAB-compatible `gpuArray` builtin that uploads host data to the active accelerator.
2//!
3//! The implementation mirrors MathWorks MATLAB semantics, including optional
4//! size arguments, `'like'` prototypes, and explicit dtype toggles. When no
5//! acceleration provider is registered the builtin surfaces a MATLAB-style
6//! error, ensuring callers know residency could not be established.
7
8use crate::builtins::acceleration::gpu::type_resolvers::gpuarray_type;
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
12};
13use crate::builtins::common::{gpu_helpers, tensor};
14use runmat_accelerate_api::{GpuTensorHandle, HostTensorView, ProviderPrecision};
15use runmat_builtins::{
16    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
17    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
18    CharArray, ComplexTensor, IntValue, Tensor, Value,
19};
20use runmat_macros::runtime_builtin;
21
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
/// Name this builtin is registered under; attached to every runtime error raised in this file.
const BUILTIN_NAME: &str = "gpuArray";
25
26const GPUARRAY_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "G",
28    ty: BuiltinParamType::Any,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "GPU-resident handle containing uploaded/converted data.",
32}];
33
34const GPUARRAY_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
35    name: "X",
36    ty: BuiltinParamType::Any,
37    arity: BuiltinParamArity::Required,
38    default: None,
39    description: "Input value to upload or recast on GPU.",
40}];
41
42const GPUARRAY_INPUTS_DIMS: [BuiltinParamDescriptor; 2] = [
43    BuiltinParamDescriptor {
44        name: "X",
45        ty: BuiltinParamType::Any,
46        arity: BuiltinParamArity::Required,
47        default: None,
48        description: "Input value to upload or recast on GPU.",
49    },
50    BuiltinParamDescriptor {
51        name: "dim",
52        ty: BuiltinParamType::SizeArg,
53        arity: BuiltinParamArity::Variadic,
54        default: None,
55        description: "Reshape dimensions (scalar dims or a single size vector tensor).",
56    },
57];
58
59const GPUARRAY_INPUTS_DTYPE: [BuiltinParamDescriptor; 2] = [
60    BuiltinParamDescriptor {
61        name: "X",
62        ty: BuiltinParamType::Any,
63        arity: BuiltinParamArity::Required,
64        default: None,
65        description: "Input value to upload or recast on GPU.",
66    },
67    BuiltinParamDescriptor {
68        name: "dtype",
69        ty: BuiltinParamType::StringScalar,
70        arity: BuiltinParamArity::Required,
71        default: Some("\"double\""),
72        description: "Class tag such as `single`, `int32`, `uint8`, `logical`, or `double`.",
73    },
74];
75
76const GPUARRAY_INPUTS_LIKE: [BuiltinParamDescriptor; 3] = [
77    BuiltinParamDescriptor {
78        name: "X",
79        ty: BuiltinParamType::Any,
80        arity: BuiltinParamArity::Required,
81        default: None,
82        description: "Input value to upload or recast on GPU.",
83    },
84    BuiltinParamDescriptor {
85        name: "like",
86        ty: BuiltinParamType::StringScalar,
87        arity: BuiltinParamArity::Required,
88        default: None,
89        description: "Literal keyword `\"like\"`.",
90    },
91    BuiltinParamDescriptor {
92        name: "prototype",
93        ty: BuiltinParamType::LikePrototype,
94        arity: BuiltinParamArity::Required,
95        default: None,
96        description: "Prototype value whose class drives output conversion.",
97    },
98];
99
100const GPUARRAY_INPUTS_DIMS_OPTIONS: [BuiltinParamDescriptor; 3] = [
101    BuiltinParamDescriptor {
102        name: "X",
103        ty: BuiltinParamType::Any,
104        arity: BuiltinParamArity::Required,
105        default: None,
106        description: "Input value to upload or recast on GPU.",
107    },
108    BuiltinParamDescriptor {
109        name: "dim",
110        ty: BuiltinParamType::SizeArg,
111        arity: BuiltinParamArity::Variadic,
112        default: None,
113        description: "Reshape dimensions (scalar dims or a single size vector tensor).",
114    },
115    BuiltinParamDescriptor {
116        name: "option",
117        ty: BuiltinParamType::Any,
118        arity: BuiltinParamArity::Variadic,
119        default: None,
120        description: "Class tags and/or `\"like\", prototype` qualifiers.",
121    },
122];
123
124const GPUARRAY_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
125    BuiltinSignatureDescriptor {
126        label: "G = gpuArray(X)",
127        inputs: &GPUARRAY_INPUTS_BASE,
128        outputs: &GPUARRAY_OUTPUT,
129    },
130    BuiltinSignatureDescriptor {
131        label: "G = gpuArray(X, dim, ...)",
132        inputs: &GPUARRAY_INPUTS_DIMS,
133        outputs: &GPUARRAY_OUTPUT,
134    },
135    BuiltinSignatureDescriptor {
136        label: "G = gpuArray(X, dtype)",
137        inputs: &GPUARRAY_INPUTS_DTYPE,
138        outputs: &GPUARRAY_OUTPUT,
139    },
140    BuiltinSignatureDescriptor {
141        label: "G = gpuArray(X, \"like\", prototype)",
142        inputs: &GPUARRAY_INPUTS_LIKE,
143        outputs: &GPUARRAY_OUTPUT,
144    },
145    BuiltinSignatureDescriptor {
146        label: "G = gpuArray(X, dim, ..., option, ...)",
147        inputs: &GPUARRAY_INPUTS_DIMS_OPTIONS,
148        outputs: &GPUARRAY_OUTPUT,
149    },
150];
151
// Error catalogue for `gpuArray`. When an error is raised, `gpu_array_error_with_message`
// attaches the descriptor's `identifier`. `message` is the default text used by
// `gpu_array_error` and the prefix used by `gpu_array_error_with_detail`. `when` documents
// the triggering condition.
const GPUARRAY_ERROR_NO_PROVIDER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.NO_PROVIDER",
    identifier: Some("RunMat:gpuArray:NoProvider"),
    when: "No acceleration provider is registered for host/device transfers.",
    message: "gpuArray: no acceleration provider registered",
};

const GPUARRAY_ERROR_OPTION_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.OPTION_ARGUMENT",
    identifier: Some("RunMat:gpuArray:OptionArgument"),
    when: "Option tail contains non-text values where class tags/keywords are expected.",
    message: "gpuArray: invalid option argument",
};

const GPUARRAY_ERROR_LIKE_MISSING: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.LIKE_MISSING",
    identifier: Some("RunMat:gpuArray:LikeMissingPrototype"),
    when: "Keyword `like` is supplied without a following prototype value.",
    message: "gpuArray: expected a prototype value after 'like'",
};

const GPUARRAY_ERROR_LIKE_DUPLICATE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.LIKE_DUPLICATE",
    identifier: Some("RunMat:gpuArray:LikeDuplicate"),
    when: "Keyword `like` appears more than once.",
    message: "gpuArray: duplicate 'like' qualifier",
};

const GPUARRAY_ERROR_CODISTRIBUTED_UNSUPPORTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.CODISTRIBUTED_UNSUPPORTED",
    identifier: Some("RunMat:gpuArray:CodistributedUnsupported"),
    when: "Distributed/codistributed qualifiers are requested.",
    message: "gpuArray: codistributed arrays are not supported yet",
};

const GPUARRAY_ERROR_CONFLICTING_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.CONFLICTING_TYPE",
    identifier: Some("RunMat:gpuArray:ConflictingTypeQualifiers"),
    when: "Multiple incompatible class qualifiers are supplied.",
    message: "gpuArray: conflicting type qualifiers supplied",
};

const GPUARRAY_ERROR_UNKNOWN_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.UNKNOWN_OPTION",
    identifier: Some("RunMat:gpuArray:UnknownOption"),
    when: "Text option is not a recognized class/keyword token.",
    message: "gpuArray: unrecognised option",
};

const GPUARRAY_ERROR_SIZE_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.SIZE_ARGUMENT",
    identifier: Some("RunMat:gpuArray:InvalidSizeArgument"),
    when: "Size arguments are malformed (not finite integers, negative, or invalid combinations).",
    message: "gpuArray: invalid size argument",
};

const GPUARRAY_ERROR_LIKE_PROTOTYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.LIKE_PROTOTYPE",
    identifier: Some("RunMat:gpuArray:InvalidLikePrototype"),
    when: "`like` prototype is unsupported for type inference.",
    message: "gpuArray: invalid 'like' prototype",
};

const GPUARRAY_ERROR_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.INPUT_TYPE",
    identifier: Some("RunMat:gpuArray:UnsupportedInputType"),
    when: "Input value type cannot be uploaded/coerced to supported gpuArray storage.",
    message: "gpuArray: unsupported input type",
};

const GPUARRAY_ERROR_CONVERSION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.CONVERSION",
    identifier: Some("RunMat:gpuArray:ConversionFailed"),
    when: "Requested dtype conversion cannot be performed (for example NaN->logical).",
    message: "gpuArray: conversion failed",
};

const GPUARRAY_ERROR_RESHAPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.RESHAPE",
    identifier: Some("RunMat:gpuArray:ReshapeMismatch"),
    when: "Requested shape does not preserve the element count.",
    message: "gpuArray: cannot reshape gpuArray into requested size",
};

const GPUARRAY_ERROR_PROVIDER_IO: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.PROVIDER_IO",
    identifier: Some("RunMat:gpuArray:ProviderIO"),
    when: "Provider upload/download interaction fails.",
    message: "gpuArray: provider I/O failed",
};

const GPUARRAY_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.GPUARRAY.INTERNAL",
    identifier: Some("RunMat:gpuArray:InternalError"),
    when: "Internal tensor/container conversion fails.",
    message: "gpuArray: internal error",
};

/// Error descriptors advertised for this builtin through `GPUARRAY_DESCRIPTOR`.
const GPUARRAY_ERRORS: [BuiltinErrorDescriptor; 14] = [
    GPUARRAY_ERROR_NO_PROVIDER,
    GPUARRAY_ERROR_OPTION_ARGUMENT,
    GPUARRAY_ERROR_LIKE_MISSING,
    GPUARRAY_ERROR_LIKE_DUPLICATE,
    GPUARRAY_ERROR_CODISTRIBUTED_UNSUPPORTED,
    GPUARRAY_ERROR_CONFLICTING_TYPE,
    GPUARRAY_ERROR_UNKNOWN_OPTION,
    GPUARRAY_ERROR_SIZE_ARGUMENT,
    GPUARRAY_ERROR_LIKE_PROTOTYPE,
    GPUARRAY_ERROR_INPUT_TYPE,
    GPUARRAY_ERROR_CONVERSION,
    GPUARRAY_ERROR_RESHAPE,
    GPUARRAY_ERROR_PROVIDER_IO,
    GPUARRAY_ERROR_INTERNAL,
];
266
/// Public metadata for `gpuArray`: its call signatures, fixed output mode, public completion
/// policy and error catalogue. Referenced by the `descriptor(...)` argument of the
/// `runtime_builtin` attribute on `gpu_array_builtin`.
pub const GPUARRAY_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &GPUARRAY_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &GPUARRAY_ERRORS,
};
273
274fn gpu_array_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
275    gpu_array_error_with_message(error.message, error)
276}
277
278fn gpu_array_error_with_message(
279    message: impl Into<String>,
280    error: &'static BuiltinErrorDescriptor,
281) -> RuntimeError {
282    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
283    if let Some(identifier) = error.identifier {
284        builder = builder.with_identifier(identifier);
285    }
286    builder.build()
287}
288
289fn gpu_array_error_with_detail(
290    error: &'static BuiltinErrorDescriptor,
291    detail: impl AsRef<str>,
292) -> RuntimeError {
293    gpu_array_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
294}
295
// GPU capability spec for `gpuArray`, registered through `register_gpu_spec`.
// - Op and hook: a custom "upload" op backed by the provider's custom "upload" hook.
// - Precisions: F32 and F64.
// - Residency: every call yields a new device handle (`ResidencyPolicy::NewHandle`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::gpuarray")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "gpuArray",
    op_kind: GpuOpKind::Custom("upload"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("upload")],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Invokes the provider `upload` hook, including complex interleaved uploads, and reuploads gpuArray inputs when dtype conversion is requested. Handles class strings, size vectors, and `'like'` prototypes.",
};
311
// Fusion spec for `gpuArray`: it contributes no elementwise or reduction template. As its
// notes state, it marks a residency boundary that fusion graphs do not cross.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::gpuarray"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gpuArray",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Acts as a residency boundary; fusion graphs never cross explicit host↔device transfers.",
};
325
326#[runtime_builtin(
327    name = "gpuArray",
328    category = "acceleration/gpu",
329    summary = "Move data to the GPU as gpuArray values.",
330    keywords = "gpuArray,gpu,accelerate,upload,dtype,like",
331    examples = "G = gpuArray([1 2 3], 'single');",
332    accel = "array_construct",
333    type_resolver(gpuarray_type),
334    descriptor(crate::builtins::acceleration::gpu::gpuarray::GPUARRAY_DESCRIPTOR),
335    builtin_path = "crate::builtins::acceleration::gpu::gpuarray"
336)]
337async fn gpu_array_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
338    let options = parse_options(&rest)?;
339    let incoming_precision = match &value {
340        Value::GpuTensor(handle) => runmat_accelerate_api::handle_precision(handle),
341        _ => None,
342    };
343    let dtype = resolve_dtype(&value, &options)?;
344    let dims = options.dims.clone();
345
346    let prepared = match value {
347        Value::GpuTensor(handle) => convert_device_value(handle, dtype).await?,
348        other => upload_host_value(other, dtype)?,
349    };
350
351    let mut handle = prepared.handle;
352
353    if let Some(dims) = dims.as_ref() {
354        apply_dims(&mut handle, dims)?;
355    }
356
357    let provider_precision = runmat_accelerate_api::provider()
358        .map(|p| p.precision())
359        .unwrap_or(ProviderPrecision::F64);
360    let requested_precision = match dtype {
361        DataClass::Single => Some(ProviderPrecision::F32),
362        _ => None,
363    };
364    let final_precision = requested_precision
365        .or(incoming_precision)
366        .unwrap_or(provider_precision);
367    runmat_accelerate_api::set_handle_precision(&handle, final_precision);
368
369    runmat_accelerate_api::set_handle_logical(&handle, prepared.logical);
370
371    Ok(Value::GpuTensor(handle))
372}
373
/// MATLAB class that uploaded `gpuArray` data is converted to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DataClass {
    Double,
    Single,
    Logical,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
}

impl DataClass {
    /// Maps an (already lower-cased) class tag, or one of its aliases, to a `DataClass`.
    ///
    /// Aliases: `float32` means single, `bool`/`boolean` mean logical, and `int` means
    /// int32. Unknown tags return `None`. This includes the compatibility keyword
    /// `"gpuarray"`, which is deliberately not a class.
    fn from_tag(tag: &str) -> Option<Self> {
        const TAGS: &[(&str, DataClass)] = &[
            ("double", DataClass::Double),
            ("single", DataClass::Single),
            ("float32", DataClass::Single),
            ("logical", DataClass::Logical),
            ("bool", DataClass::Logical),
            ("boolean", DataClass::Logical),
            ("int8", DataClass::Int8),
            ("int16", DataClass::Int16),
            ("int32", DataClass::Int32),
            ("int", DataClass::Int32),
            ("int64", DataClass::Int64),
            ("uint8", DataClass::UInt8),
            ("uint16", DataClass::UInt16),
            ("uint32", DataClass::UInt32),
            ("uint64", DataClass::UInt64),
        ];

        TAGS.iter()
            .find(|(name, _)| *name == tag)
            .map(|&(_, class)| class)
    }
}
408
/// Options parsed from the arguments that follow `X` in `gpuArray(X, ...)`.
#[derive(Debug, Default)]
struct ParsedOptions {
    /// Requested shape from scalar dims or a size vector; `None` when no non-empty size
    /// arguments were supplied.
    dims: Option<Vec<usize>>,
    /// Class tag supplied explicitly (e.g. `'single'`); `resolve_dtype` prefers it over `prototype`.
    explicit_dtype: Option<DataClass>,
    /// Value that followed the `'like'` keyword, used to infer the class.
    prototype: Option<Value>,
}
415
416fn parse_options(rest: &[Value]) -> BuiltinResult<ParsedOptions> {
417    let (index_after_dims, dims) = parse_size_arguments(rest)?;
418    let mut options = ParsedOptions {
419        dims,
420        ..ParsedOptions::default()
421    };
422
423    let mut idx = index_after_dims;
424    while idx < rest.len() {
425        let tag = value_to_lower_string(&rest[idx]).ok_or_else(|| {
426            gpu_array_error_with_message(
427                format!(
428                "gpuArray: unexpected argument {:?}; expected a class string or the keyword 'like'",
429                rest[idx]
430                ),
431                &GPUARRAY_ERROR_OPTION_ARGUMENT,
432            )
433        })?;
434
435        match tag.as_str() {
436            "like" => {
437                idx += 1;
438                if idx >= rest.len() {
439                    return Err(gpu_array_error(&GPUARRAY_ERROR_LIKE_MISSING));
440                }
441                if options.prototype.is_some() {
442                    return Err(gpu_array_error(&GPUARRAY_ERROR_LIKE_DUPLICATE));
443                }
444                options.prototype = Some(rest[idx].clone());
445            }
446            "distributed" | "codistributed" => {
447                return Err(gpu_array_error(&GPUARRAY_ERROR_CODISTRIBUTED_UNSUPPORTED));
448            }
449            tag => {
450                if let Some(class) = DataClass::from_tag(tag) {
451                    if let Some(existing) = options.explicit_dtype {
452                        if existing != class {
453                            return Err(gpu_array_error(&GPUARRAY_ERROR_CONFLICTING_TYPE));
454                        }
455                    } else {
456                        options.explicit_dtype = Some(class);
457                    }
458                } else if tag != "gpuarray" {
459                    return Err(gpu_array_error_with_detail(
460                        &GPUARRAY_ERROR_UNKNOWN_OPTION,
461                        format!("unrecognised option '{tag}'"),
462                    ));
463                }
464            }
465        }
466
467        idx += 1;
468    }
469
470    Ok(options)
471}
472
473fn parse_size_arguments(rest: &[Value]) -> BuiltinResult<(usize, Option<Vec<usize>>)> {
474    let mut idx = 0;
475    let mut dims: Vec<usize> = Vec::new();
476    let mut vector_consumed = false;
477
478    while idx < rest.len() {
479        // Stop at textual qualifiers only; numeric values continue parsing as size args.
480        match &rest[idx] {
481            Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => break,
482            _ => {}
483        }
484
485        match &rest[idx] {
486            Value::Int(i) => {
487                dims.push(int_to_dim(i)?);
488            }
489            Value::Num(n) => {
490                dims.push(float_to_dim(*n)?);
491            }
492            Value::Tensor(t) => {
493                if vector_consumed || !dims.is_empty() {
494                    return Err(gpu_array_error_with_message(
495                        "gpuArray: size vectors cannot be combined with scalar dimensions",
496                        &GPUARRAY_ERROR_SIZE_ARGUMENT,
497                    ));
498                }
499                dims = tensor_to_dims(t)?;
500                vector_consumed = true;
501            }
502            _ => break,
503        }
504        idx += 1;
505    }
506
507    let dims_option = if dims.is_empty() { None } else { Some(dims) };
508    Ok((idx, dims_option))
509}
510
511fn value_to_lower_string(value: &Value) -> Option<String> {
512    crate::builtins::common::tensor::value_to_string(value).map(|s| s.trim().to_ascii_lowercase())
513}
514
515fn int_to_dim(value: &IntValue) -> BuiltinResult<usize> {
516    let raw = value.to_i64();
517    if raw < 0 {
518        return Err(gpu_array_error_with_message(
519            "gpuArray: size arguments must be non-negative integers",
520            &GPUARRAY_ERROR_SIZE_ARGUMENT,
521        ));
522    }
523    Ok(raw as usize)
524}
525
526fn float_to_dim(value: f64) -> BuiltinResult<usize> {
527    if !value.is_finite() {
528        return Err(gpu_array_error_with_message(
529            "gpuArray: size arguments must be finite integers",
530            &GPUARRAY_ERROR_SIZE_ARGUMENT,
531        ));
532    }
533    let rounded = value.round();
534    if (rounded - value).abs() > f64::EPSILON {
535        return Err(gpu_array_error_with_message(
536            "gpuArray: size arguments must be integers",
537            &GPUARRAY_ERROR_SIZE_ARGUMENT,
538        ));
539    }
540    if rounded < 0.0 {
541        return Err(gpu_array_error_with_message(
542            "gpuArray: size arguments must be non-negative",
543            &GPUARRAY_ERROR_SIZE_ARGUMENT,
544        ));
545    }
546    Ok(rounded as usize)
547}
548
549fn tensor_to_dims(tensor: &Tensor) -> BuiltinResult<Vec<usize>> {
550    let mut dims = Vec::with_capacity(tensor.data.len());
551    for value in &tensor.data {
552        dims.push(float_to_dim(*value)?);
553    }
554    Ok(dims)
555}
556
557fn resolve_dtype(value: &Value, options: &ParsedOptions) -> BuiltinResult<DataClass> {
558    if let Some(explicit) = options.explicit_dtype {
559        return Ok(explicit);
560    }
561    if let Some(prototype) = options.prototype.as_ref() {
562        return infer_dtype_from_prototype(prototype);
563    }
564    if value_defaults_to_logical(value) {
565        return Ok(DataClass::Logical);
566    }
567    Ok(DataClass::Double)
568}
569
570fn infer_dtype_from_prototype(proto: &Value) -> BuiltinResult<DataClass> {
571    match proto {
572        Value::GpuTensor(handle) => {
573            if runmat_accelerate_api::handle_is_logical(handle) {
574                Ok(DataClass::Logical)
575            } else {
576                Ok(DataClass::Double)
577            }
578        }
579        Value::LogicalArray(_) | Value::Bool(_) => Ok(DataClass::Logical),
580        Value::Int(int) => Ok(match int {
581            IntValue::I8(_) => DataClass::Int8,
582            IntValue::I16(_) => DataClass::Int16,
583            IntValue::I32(_) => DataClass::Int32,
584            IntValue::I64(_) => DataClass::Int64,
585            IntValue::U8(_) => DataClass::UInt8,
586            IntValue::U16(_) => DataClass::UInt16,
587            IntValue::U32(_) => DataClass::UInt32,
588            IntValue::U64(_) => DataClass::UInt64,
589        }),
590        Value::Tensor(_) | Value::Num(_) => Ok(DataClass::Double),
591        Value::CharArray(_) => Ok(DataClass::Double),
592        Value::String(_) => Err(gpu_array_error_with_message(
593            "gpuArray: 'like' does not accept MATLAB string scalars; convert to char() first",
594            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
595        )),
596        Value::StringArray(_) => Err(gpu_array_error_with_message(
597            "gpuArray: 'like' does not accept string arrays; convert to char arrays first",
598            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
599        )),
600        Value::Complex(_, _) | Value::ComplexTensor(_) => Ok(DataClass::Double),
601        other => Err(gpu_array_error_with_message(
602            format!(
603                "gpuArray: unsupported 'like' prototype type {other:?}; expected numeric or logical values"
604            ),
605            &GPUARRAY_ERROR_LIKE_PROTOTYPE,
606        )),
607    }
608}
609
610fn value_defaults_to_logical(value: &Value) -> bool {
611    match value {
612        Value::LogicalArray(_) | Value::Bool(_) => true,
613        Value::GpuTensor(handle) => runmat_accelerate_api::handle_is_logical(handle),
614        _ => false,
615    }
616}
617
/// Device handle produced by an upload or conversion, plus the logical flag to tag it with.
struct PreparedHandle {
    handle: GpuTensorHandle,
    /// `true` when the data was prepared for the `logical` class.
    logical: bool,
}
622
623fn upload_host_value(value: Value, dtype: DataClass) -> BuiltinResult<PreparedHandle> {
624    #[cfg(all(test, feature = "wgpu"))]
625    {
626        if runmat_accelerate_api::provider().is_none() {
627            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
628                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
629            );
630        }
631    }
632    let provider = runmat_accelerate_api::provider()
633        .ok_or_else(|| gpu_array_error(&GPUARRAY_ERROR_NO_PROVIDER))?;
634
635    match value {
636        Value::Complex(re, im) => {
637            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(|err| {
638                gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
639            })?;
640            upload_complex_host_value(provider, tensor, dtype)
641        }
642        Value::ComplexTensor(tensor) => upload_complex_host_value(provider, tensor, dtype),
643        value => upload_real_host_value(provider, value, dtype),
644    }
645}
646
647fn upload_real_host_value(
648    provider: &dyn runmat_accelerate_api::AccelProvider,
649    value: Value,
650    dtype: DataClass,
651) -> BuiltinResult<PreparedHandle> {
652    let tensor = coerce_host_value(value)?;
653    let (mut tensor, logical) = cast_tensor(tensor, dtype)?;
654
655    let view = HostTensorView {
656        data: &tensor.data,
657        shape: &tensor.shape,
658    };
659    let new_handle = provider.upload(&view).map_err(|err| {
660        gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_PROVIDER_IO)
661    })?;
662
663    tensor.data.clear();
664
665    Ok(PreparedHandle {
666        handle: new_handle,
667        logical,
668    })
669}
670
671fn upload_complex_host_value(
672    provider: &dyn runmat_accelerate_api::AccelProvider,
673    mut tensor: ComplexTensor,
674    dtype: DataClass,
675) -> BuiltinResult<PreparedHandle> {
676    match dtype {
677        DataClass::Double => {}
678        DataClass::Single => {
679            for (re, im) in &mut tensor.data {
680                *re = (*re as f32) as f64;
681                *im = (*im as f32) as f64;
682            }
683        }
684        _ => {
685            return Err(gpu_array_error_with_message(
686                "gpuArray: complex inputs can only be uploaded as double or single precision",
687                &GPUARRAY_ERROR_INPUT_TYPE,
688            ));
689        }
690    }
691
692    let handle = gpu_helpers::upload_complex_tensor(provider, &tensor).map_err(|err| {
693        gpu_array_error_with_message(err.to_string(), &GPUARRAY_ERROR_PROVIDER_IO)
694    })?;
695    let precision = match dtype {
696        DataClass::Double => runmat_accelerate_api::ProviderPrecision::F64,
697        DataClass::Single => runmat_accelerate_api::ProviderPrecision::F32,
698        _ => unreachable!("complex dtype was validated above"),
699    };
700    runmat_accelerate_api::set_handle_precision(&handle, precision);
701    Ok(PreparedHandle {
702        handle,
703        logical: false,
704    })
705}
706
/// Converts an existing gpuArray handle to `dtype`.
///
/// The input handle is returned unchanged when no conversion is needed:
/// - `Double` on a real handle, or on a complex handle not tagged F32;
/// - `Logical` on a handle already flagged logical.
///
/// Otherwise the data is gathered to the host, converted, and uploaded as a new handle, and
/// the original handle is freed. Complex data is re-uploaded through
/// `upload_complex_host_value`; real data goes through `cast_tensor`.
///
/// # Errors
/// Returns a no-provider error when neither the handle's provider nor a global provider is
/// available. Gather/upload failures become provider-I/O errors; conversion errors are
/// propagated.
async fn convert_device_value(
    handle: GpuTensorHandle,
    dtype: DataClass,
) -> BuiltinResult<PreparedHandle> {
    let was_logical = runmat_accelerate_api::handle_is_logical(&handle);
    let was_complex = runmat_accelerate_api::handle_storage(&handle)
        == runmat_accelerate_api::GpuTensorStorage::ComplexInterleaved;
    let current_precision = runmat_accelerate_api::handle_precision(&handle);
    // Fast paths that reuse the incoming handle without touching its data.
    match dtype {
        DataClass::Double => {
            // Only an F32-tagged complex handle is re-uploaded for `double`. Every other
            // handle is reused as-is; a logical input is reported as non-logical
            // (`logical: false`).
            if !(was_complex
                && current_precision == Some(runmat_accelerate_api::ProviderPrecision::F32))
            {
                return Ok(PreparedHandle {
                    handle,
                    logical: false,
                });
            }
        }
        DataClass::Logical => {
            if was_logical {
                return Ok(PreparedHandle {
                    handle,
                    logical: true,
                });
            }
        }
        _ => {}
    }

    // Prefer the provider that owns this handle, falling back to the globally active one.
    let provider = runmat_accelerate_api::provider_for_handle(&handle)
        .or_else(runmat_accelerate_api::provider)
        .ok_or_else(|| gpu_array_error(&GPUARRAY_ERROR_NO_PROVIDER))?;
    if was_complex {
        let gathered = gpu_helpers::gather_value_async(&Value::GpuTensor(handle.clone()))
            .await
            .map_err(|err| {
                gpu_array_error_with_message(err.to_string(), &GPUARRAY_ERROR_PROVIDER_IO)
            })?;
        let Value::ComplexTensor(tensor) = gathered else {
            return Err(gpu_array_error_with_message(
                "gpuArray: expected complex gpuArray data during conversion",
                &GPUARRAY_ERROR_PROVIDER_IO,
            ));
        };
        // Rejects non-floating classes for complex data before anything is freed.
        let prepared = upload_complex_host_value(provider, tensor, dtype)?;
        // NOTE(review): the input handle is freed here, and any error from `free` is
        // discarded. Confirm that no other Value (e.g. the caller's variable) still refers
        // to it.
        provider.free(&handle).ok();
        return Ok(prepared);
    }

    let tensor = gpu_helpers::gather_tensor_async(&handle)
        .await
        .map_err(|err| {
            gpu_array_error_with_message(err.to_string(), &GPUARRAY_ERROR_PROVIDER_IO)
        })?;
    let (mut tensor, logical) = cast_tensor(tensor, dtype)?;

    let view = HostTensorView {
        data: &tensor.data,
        shape: &tensor.shape,
    };
    let new_handle = provider.upload(&view).map_err(|err| {
        gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_PROVIDER_IO)
    })?;

    // The old handle is released only after the replacement upload succeeded; any error
    // from `free` is discarded (same NOTE(review) as the complex path above).
    provider.free(&handle).ok();
    // Empties the host staging buffer; `tensor` is dropped on return anyway.
    tensor.data.clear();

    Ok(PreparedHandle {
        handle: new_handle,
        logical,
    })
}
780
781fn coerce_host_value(value: Value) -> BuiltinResult<Tensor> {
782    match value {
783        Value::Tensor(t) => Ok(t),
784        Value::LogicalArray(logical) => tensor::logical_to_tensor(&logical).map_err(|err| {
785            gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
786        }),
787        Value::Bool(flag) => {
788            Tensor::new(vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]).map_err(|err| {
789                gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
790            })
791        }
792        Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|err| {
793            gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
794        }),
795        Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|err| {
796            gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
797        }),
798        Value::CharArray(ca) => char_array_to_tensor(&ca),
799        Value::String(text) => {
800            let ca = CharArray::new_row(&text);
801            char_array_to_tensor(&ca)
802        }
803        Value::StringArray(_) => Err(gpu_array_error_with_message(
804            "gpuArray: string arrays are not supported yet; convert to char arrays with CHAR first",
805            &GPUARRAY_ERROR_INPUT_TYPE,
806        )),
807        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(gpu_array_error_with_message(
808            "gpuArray: internal complex upload routing failed",
809            &GPUARRAY_ERROR_INTERNAL,
810        )),
811        other => Err(gpu_array_error_with_detail(
812            &GPUARRAY_ERROR_INPUT_TYPE,
813            format!("unsupported input type for GPU transfer: {other:?}"),
814        )),
815    }
816}
817
818fn cast_tensor(mut tensor: Tensor, dtype: DataClass) -> BuiltinResult<(Tensor, bool)> {
819    let logical = match dtype {
820        DataClass::Logical => {
821            convert_to_logical(&mut tensor.data)?;
822            true
823        }
824        DataClass::Single => {
825            convert_to_single(&mut tensor.data);
826            false
827        }
828        DataClass::Int8 => {
829            convert_to_int_range(&mut tensor.data, i8::MIN as f64, i8::MAX as f64);
830            false
831        }
832        DataClass::Int16 => {
833            convert_to_int_range(&mut tensor.data, i16::MIN as f64, i16::MAX as f64);
834            false
835        }
836        DataClass::Int32 => {
837            convert_to_int_range(&mut tensor.data, i32::MIN as f64, i32::MAX as f64);
838            false
839        }
840        DataClass::Int64 => {
841            convert_to_int_range(&mut tensor.data, i64::MIN as f64, i64::MAX as f64);
842            false
843        }
844        DataClass::UInt8 => {
845            convert_to_int_range(&mut tensor.data, 0.0, u8::MAX as f64);
846            false
847        }
848        DataClass::UInt16 => {
849            convert_to_int_range(&mut tensor.data, 0.0, u16::MAX as f64);
850            false
851        }
852        DataClass::UInt32 => {
853            convert_to_int_range(&mut tensor.data, 0.0, u32::MAX as f64);
854            false
855        }
856        DataClass::UInt64 => {
857            convert_to_int_range(&mut tensor.data, 0.0, u64::MAX as f64);
858            false
859        }
860        DataClass::Double => false,
861    };
862
863    Ok((tensor, logical))
864}
865
866fn convert_to_logical(data: &mut [f64]) -> BuiltinResult<()> {
867    for value in data.iter_mut() {
868        if value.is_nan() {
869            return Err(gpu_array_error_with_message(
870                "gpuArray: cannot convert NaN to logical",
871                &GPUARRAY_ERROR_CONVERSION,
872            ));
873        }
874        *value = if *value != 0.0 { 1.0 } else { 0.0 };
875    }
876    Ok(())
877}
878
/// Rounds every element to the nearest `f32` value in place, keeping `f64` storage.
///
/// Values beyond the `f32` range become infinite, matching an `f64 -> f32` cast.
fn convert_to_single(data: &mut [f64]) {
    data.iter_mut().for_each(|x| *x = f64::from(*x as f32));
}
884
/// Rounds and saturates every element into the integer range `[min, max]`, in place.
///
/// The conversion follows MATLAB integer-class semantics:
/// - finite values round half away from zero (`f64::round`), then clamp to the range;
/// - `+Inf` / `-Inf` saturate to `max` / `min`;
/// - NaN becomes 0 for every integer class (not `min`, which would give e.g.
///   `int8(NaN) == -128`).
///
/// Results never carry a negative zero, because integer classes have none.
///
/// # Panics
/// Panics (via `f64::clamp`) if `min > max` or either bound is NaN.
fn convert_to_int_range(data: &mut [f64], min: f64, max: f64) {
    for value in data.iter_mut() {
        let rounded = if value.is_nan() { 0.0 } else { value.round() };
        // `clamp` saturates the infinities as well. Adding +0.0 turns a `-0.0`
        // produced by rounding (e.g. -0.2) into +0.0.
        *value = rounded.clamp(min, max) + 0.0;
    }
}
899
900fn apply_dims(handle: &mut GpuTensorHandle, dims: &[usize]) -> BuiltinResult<()> {
901    let new_elems: usize = dims.iter().product();
902    let current_elems: usize = if handle.shape.is_empty() {
903        new_elems
904    } else {
905        handle.shape.iter().product()
906    };
907    if new_elems != current_elems {
908        return Err(gpu_array_error_with_message(
909            format!(
910                "gpuArray: cannot reshape gpuArray of {current_elems} elements into size {:?}",
911                dims
912            ),
913            &GPUARRAY_ERROR_RESHAPE,
914        ));
915    }
916    handle.shape = dims.to_vec();
917    Ok(())
918}
919
920fn char_array_to_tensor(ca: &CharArray) -> BuiltinResult<Tensor> {
921    let rows = ca.rows;
922    let cols = ca.cols;
923    if rows == 0 || cols == 0 {
924        return Tensor::new(Vec::new(), vec![rows, cols]).map_err(|err| {
925            gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
926        });
927    }
928    let mut data = vec![0.0; rows * cols];
929    // Store in row-major to preserve the original character order when interpreted with column-major indexing
930    for row in 0..rows {
931        for col in 0..cols {
932            let idx_char = row * cols + col;
933            let ch = ca.data[idx_char];
934            data[row * cols + col] = ch as u32 as f64;
935        }
936    }
937    Tensor::new(data, vec![rows, cols]).map_err(|err| {
938        gpu_array_error_with_message(format!("gpuArray: {err}"), &GPUARRAY_ERROR_INTERNAL)
939    })
940}
941
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::{GpuTensorStorage, HostTensorView};
    use runmat_builtins::{ComplexTensor, IntValue, LogicalArray, ResolveContext, Type};

    /// Invokes the async `gpuArray` builtin synchronously.
    fn call(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(gpu_array_builtin(value, rest))
    }

    /// Gathers `value` to the host and expects a complex tensor back.
    fn gather_complex(value: Value) -> ComplexTensor {
        match block_on(crate::dispatcher::gather_if_needed_async(&value)).expect("gather complex") {
            Value::ComplexTensor(tensor) => tensor,
            other => panic!("expected ComplexTensor, got {other:?}"),
        }
    }

    /// Asserts two complex sequences have equal length and agree element-wise within 1e-12.
    fn assert_complex_close(actual: &[(f64, f64)], expected: &[(f64, f64)]) {
        assert_eq!(actual.len(), expected.len());
        for (idx, ((ar, ai), (er, ei))) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (ar - er).abs() < 1e-12 && (ai - ei).abs() < 1e-12,
                "at {idx}: expected ({er}, {ei}), got ({ar}, {ai})"
            );
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_transfers_numeric_tensor() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
            let result = call(Value::Tensor(tensor.clone()), Vec::new()).expect("gpuArray upload");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(handle.shape, tensor.shape);
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather values");
            assert_eq!(gathered.shape, tensor.shape);
            assert_eq!(gathered.data, tensor.data);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_marks_logical_inputs() {
        test_support::with_test_provider(|_| {
            let logical =
                LogicalArray::new(vec![1, 0, 1, 1], vec![2, 2]).expect("logical construction");
            let result =
                call(Value::LogicalArray(logical.clone()), Vec::new()).expect("gpuArray logical");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather logical");
            assert_eq!(gathered.shape, logical.shape);
            assert_eq!(gathered.data, vec![1.0, 0.0, 1.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_uploads_complex_tensor() {
        test_support::with_test_provider(|_| {
            let complex = ComplexTensor::new(vec![(1.0, -2.0), (3.5, 4.25)], vec![1, 2]).unwrap();
            let result =
                call(Value::ComplexTensor(complex.clone()), Vec::new()).expect("gpuArray complex");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(
                runmat_accelerate_api::handle_storage(&handle),
                GpuTensorStorage::ComplexInterleaved
            );
            let gathered = gather_complex(Value::GpuTensor(handle.clone()));
            assert_eq!(gathered.shape, complex.shape);
            assert_complex_close(&gathered.data, &complex.data);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_handles_scalar_bool() {
        test_support::with_test_provider(|_| {
            let result = call(Value::Bool(true), Vec::new()).expect("gpuArray bool");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather bool");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data, vec![1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_supports_char_arrays() {
        test_support::with_test_provider(|_| {
            let chars = CharArray::new("row1row2".chars().collect(), 2, 4).unwrap();
            let original: Vec<char> = chars.data.clone();
            let result =
                call(Value::CharArray(chars), Vec::new()).expect("gpuArray char array upload");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather chars");
            assert_eq!(gathered.shape, vec![2, 4]);
            let mut recovered = Vec::new();
            for col in 0..4 {
                for row in 0..2 {
                    let idx = row + col * 2;
                    let code = gathered.data[idx];
                    let ch = char::from_u32(code as u32)
                        .expect("valid unicode scalar from numeric code");
                    recovered.push(ch);
                }
            }
            assert_eq!(recovered, original);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_converts_strings() {
        test_support::with_test_provider(|_| {
            let result = call(Value::String("gpu".into()), Vec::new()).expect("gpuArray string");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather string");
            assert_eq!(gathered.shape, vec![1, 3]);
            let expected: Vec<f64> = "gpu".chars().map(|ch| ch as u32 as f64).collect();
            assert_eq!(gathered.data, expected);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_passthrough_existing_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![5.0, 6.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let cloned = handle.clone();
            let result =
                call(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpuArray passthrough");
            let Value::GpuTensor(returned) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(returned.buffer_id, cloned.buffer_id);
            assert_eq!(returned.shape, cloned.shape);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_passthrough_existing_complex_handle() {
        test_support::with_test_provider(|provider| {
            let complex = ComplexTensor::new(vec![(2.0, 3.0), (-4.0, 5.5)], vec![2, 1]).unwrap();
            let handle = gpu_helpers::upload_complex_tensor(provider, &complex).unwrap();
            let result =
                call(Value::GpuTensor(handle.clone()), Vec::new()).expect("gpuArray passthrough");
            let Value::GpuTensor(returned) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(returned.buffer_id, handle.buffer_id);
            assert_eq!(
                runmat_accelerate_api::handle_storage(&returned),
                GpuTensorStorage::ComplexInterleaved
            );
            let gathered = gather_complex(Value::GpuTensor(returned.clone()));
            assert_complex_close(&gathered.data, &complex.data);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_complex_gpu_to_single_reuploads_complex_handle() {
        test_support::with_test_provider(|provider| {
            let complex = ComplexTensor::new(
                vec![(1.234_567_89, -2.345_678_91), (3.456_789_12, 4.567_891_23)],
                vec![1, 2],
            )
            .unwrap();
            let handle = gpu_helpers::upload_complex_tensor(provider, &complex).unwrap();
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("single")],
            )
            .expect("gpuArray complex single");
            let Value::GpuTensor(returned) = result else {
                panic!("expected gpu tensor");
            };
            assert_ne!(returned.buffer_id, handle.buffer_id);
            assert_eq!(
                runmat_accelerate_api::handle_storage(&returned),
                GpuTensorStorage::ComplexInterleaved
            );
            assert_eq!(
                runmat_accelerate_api::handle_precision(&returned),
                Some(runmat_accelerate_api::ProviderPrecision::F32)
            );
            let gathered = gather_complex(Value::GpuTensor(returned.clone()));
            let expected = complex
                .data
                .iter()
                .map(|(re, im)| ((*re as f32) as f64, (*im as f32) as f64))
                .collect::<Vec<_>>();
            assert_complex_close(&gathered.data, &expected);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_casts_to_int32() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.2, -3.7, 123456.0], vec![3, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("int32")]).expect("gpuArray int32");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather int32");
            assert_eq!(gathered.data, vec![1.0, -4.0, 123456.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_casts_to_uint8() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![-12.0, 12.8, 300.4, f64::INFINITY], vec![4, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("uint8")]).expect("gpuArray uint8");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather uint8");
            assert_eq!(gathered.data, vec![0.0, 13.0, 255.0, 255.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_single_precision_rounds() {
        test_support::with_test_provider(|_| {
            let input = vec![1.23456789, -9.87654321];
            let tensor = Tensor::new(input.clone(), vec![2, 1]).unwrap();
            let result =
                call(Value::Tensor(tensor), vec![Value::from("single")]).expect("gpuArray single");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather single");
            // Check shape and length first: `zip` alone would silently accept a short result.
            assert_eq!(gathered.shape, vec![2, 1]);
            assert_eq!(gathered.data.len(), input.len());
            for (observed, original) in gathered.data.iter().zip(input.iter()) {
                let expected = (*original as f32) as f64;
                assert!(
                    (observed - expected).abs() < 1e-6,
                    "expected {expected}, got {observed}"
                );
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_like_infers_logical() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![0.0, 2.0, -3.0], vec![3, 1]).unwrap();
            let logical_proto =
                LogicalArray::new(vec![0, 1, 0], vec![3, 1]).expect("logical proto");
            let result = call(
                Value::Tensor(tensor),
                vec![Value::from("like"), Value::LogicalArray(logical_proto)],
            )
            .expect("gpuArray like logical");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&handle));
            let gathered = test_support::gather(Value::GpuTensor(handle.clone())).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_like_requires_argument() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let err = call(Value::Tensor(tensor), vec![Value::from("like")]).unwrap_err();
            assert_eq!(err.to_string(), GPUARRAY_ERROR_LIKE_MISSING.message);
            assert_eq!(err.identifier(), GPUARRAY_ERROR_LIKE_MISSING.identifier);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_unknown_option_errors() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let err = call(Value::Tensor(tensor), vec![Value::from("mystery")]).unwrap_err();
            assert!(err
                .to_string()
                .contains(GPUARRAY_ERROR_UNKNOWN_OPTION.message));
            assert_eq!(err.identifier(), GPUARRAY_ERROR_UNKNOWN_OPTION.identifier);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_to_logical_reuploads() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![2.0, 0.0, -5.5], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("logical")],
            )
            .expect("gpuArray logical cast");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(runmat_accelerate_api::handle_is_logical(&new_handle));
            let gathered =
                test_support::gather(Value::GpuTensor(new_handle.clone())).expect("gather");
            assert_eq!(gathered.data, vec![1.0, 0.0, 1.0]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_logical_to_double_clears_flag() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from("double")],
            )
            .expect("gpuArray double cast");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert!(!runmat_accelerate_api::handle_is_logical(&new_handle));
            // Clearing the logical flag must not disturb the stored values.
            let gathered =
                test_support::gather(Value::GpuTensor(new_handle.clone())).expect("gather double");
            assert_eq!(gathered.data, vec![1.0, 0.0]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_applies_size_arguments() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let result = call(
                Value::Tensor(tensor),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .expect("gpuArray reshape");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(handle.shape, vec![2, 2]);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_gpu_size_arguments_update_shape() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = call(
                Value::GpuTensor(handle.clone()),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .expect("gpuArray gpu reshape");
            let Value::GpuTensor(new_handle) = result else {
                panic!("expected gpu tensor");
            };
            assert_eq!(new_handle.shape, vec![2, 2]);
            provider.free(&handle).ok();
            provider.free(&new_handle).ok();
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_size_mismatch_errors() {
        test_support::with_test_provider(|_| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let err = call(
                Value::Tensor(tensor),
                vec![Value::from(2i32), Value::from(2i32)],
            )
            .unwrap_err();
            assert!(err.to_string().contains("cannot reshape"));
            assert_eq!(err.identifier(), GPUARRAY_ERROR_RESHAPE.identifier);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gpu_array_wgpu_roundtrip() {
        use runmat_accelerate_api::AccelProvider;

        match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) {
            Ok(provider) => {
                let tensor = Tensor::new(vec![1.0, 2.5, 3.5], vec![3, 1]).unwrap();
                let result = call(Value::Tensor(tensor.clone()), vec![Value::from("int32")])
                    .expect("wgpu upload");
                let Value::GpuTensor(handle) = result else {
                    panic!("expected gpu tensor");
                };
                let gathered =
                    test_support::gather(Value::GpuTensor(handle.clone())).expect("wgpu gather");
                assert_eq!(gathered.shape, vec![3, 1]);
                assert_eq!(gathered.data, vec![1.0, 3.0, 4.0]);
                provider.free(&handle).ok();
            }
            Err(err) => {
                tracing::warn!("Skipping gpu_array_wgpu_roundtrip: {err}");
            }
        }
        runmat_accelerate::simple_provider::register_inprocess_provider();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gpu_array_wgpu_complex_roundtrip() {
        use runmat_accelerate_api::AccelProvider;

        match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        ) {
            Ok(provider) => {
                let complex =
                    ComplexTensor::new(vec![(1.25, -0.5), (-3.0, 4.0)], vec![1, 2]).unwrap();
                let result =
                    call(Value::ComplexTensor(complex.clone()), Vec::new()).expect("wgpu upload");
                let Value::GpuTensor(handle) = result else {
                    panic!("expected gpu tensor");
                };
                assert_eq!(
                    runmat_accelerate_api::handle_storage(&handle),
                    GpuTensorStorage::ComplexInterleaved
                );
                let gathered = gather_complex(Value::GpuTensor(handle.clone()));
                assert_eq!(gathered.shape, vec![1, 2]);
                assert_complex_close(&gathered.data, &complex.data);
                provider.free(&handle).ok();
            }
            Err(err) => {
                tracing::warn!("Skipping gpu_array_wgpu_complex_roundtrip: {err}");
            }
        }
        runmat_accelerate::simple_provider::register_inprocess_provider();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_array_accepts_int_scalars() {
        test_support::with_test_provider(|_| {
            let value = Value::Int(IntValue::I32(7));
            let result = call(value, Vec::new()).expect("gpuArray int");
            let Value::GpuTensor(handle) = result else {
                panic!("expected gpu tensor");
            };
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather int");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data, vec![7.0]);
        });
    }

    #[test]
    fn gpuarray_type_for_logical_is_logical() {
        assert_eq!(
            gpuarray_type(&[Type::logical()], &ResolveContext::new(Vec::new())),
            Type::logical()
        );
    }
}