Skip to main content

runmat_runtime/builtins/logical/tests/
isgpuarray.rs

1//! MATLAB-compatible `isgpuarray` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    ResolveContext, Type, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::BuiltinResult;
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::tests::isgpuarray")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18    name: "isgpuarray",
19    op_kind: GpuOpKind::Custom("metadata"),
20    supported_precisions: &[ScalarType::F32, ScalarType::F64],
21    broadcast: BroadcastSemantics::None,
22    provider_hooks: &[],
23    constant_strategy: ConstantStrategy::InlineLiteral,
24    residency: ResidencyPolicy::GatherImmediately,
25    nan_mode: ReductionNaN::Include,
26    two_pass_threshold: None,
27    workgroup_size: None,
28    accepts_nan_mode: false,
29    notes: "Reports whether the value is a gpuArray without gathering device buffers.",
30};
31
32#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::tests::isgpuarray")]
33pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
34    name: "isgpuarray",
35    shape: ShapeRequirements::Any,
36    constant_strategy: ConstantStrategy::InlineLiteral,
37    elementwise: None,
38    reduction: None,
39    emits_nan: false,
40    notes: "Metadata query that executes outside of fusion pipelines.",
41};
42
43const ISGPUARRAY_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
44    name: "tf",
45    ty: BuiltinParamType::LogicalArray,
46    arity: BuiltinParamArity::Required,
47    default: None,
48    description: "True when input is a gpuArray handle.",
49}];
50
51const ISGPUARRAY_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
52    name: "A",
53    ty: BuiltinParamType::Any,
54    arity: BuiltinParamArity::Required,
55    default: None,
56    description: "Input value to test.",
57}];
58
59const ISGPUARRAY_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
60    label: "tf = isgpuarray(A)",
61    inputs: &ISGPUARRAY_INPUTS,
62    outputs: &ISGPUARRAY_OUTPUT,
63}];
64
65const ISGPUARRAY_ERRORS: [BuiltinErrorDescriptor; 0] = [];
66
67pub const ISGPUARRAY_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
68    signatures: &ISGPUARRAY_SIGNATURES,
69    output_mode: BuiltinOutputMode::Fixed,
70    completion_policy: BuiltinCompletionPolicy::Public,
71    errors: &ISGPUARRAY_ERRORS,
72};
73
74#[runtime_builtin(
75    name = "isgpuarray",
76    category = "logical/tests",
77    summary = "Return true when a value is stored as a gpuArray handle.",
78    keywords = "isgpuarray,gpuarray,gpu,type,logical",
79    accel = "metadata",
80    type_resolver(bool_scalar_type),
81    descriptor(crate::builtins::logical::tests::isgpuarray::ISGPUARRAY_DESCRIPTOR),
82    builtin_path = "crate::builtins::logical::tests::isgpuarray"
83)]
84async fn isgpuarray_builtin(value: Value) -> BuiltinResult<Value> {
85    Ok(Value::Bool(matches!(value, Value::GpuTensor(_))))
86}
87
88fn bool_scalar_type(_: &[Type], _context: &ResolveContext) -> Type {
89    Type::Bool
90}
91
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{Tensor, Value};

    /// Drives the async builtin to completion on the current thread.
    fn run_isgpuarray(value: Value) -> BuiltinResult<Value> {
        block_on(super::isgpuarray_builtin(value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_gpu_values_report_false() {
        assert_eq!(run_isgpuarray(Value::Num(1.0)).unwrap(), Value::Bool(false));
        // A logical scalar is not a gpuArray, even though the result is logical.
        assert_eq!(run_isgpuarray(Value::Bool(true)).unwrap(), Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn host_tensors_report_false() {
        // Same data as the GPU test, but left in host memory: must not count.
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        assert_eq!(run_isgpuarray(Value::Tensor(tensor)).unwrap(), Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_handles_report_true() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = run_isgpuarray(Value::GpuTensor(handle.clone()));
            // Free the device buffer before asserting so a failing assertion does
            // not leak the handle into other tests that may share the provider.
            provider.free(&handle).ok();
            assert_eq!(result.expect("isgpuarray"), Value::Bool(true));
        });
    }
}