runmat_runtime/builtins/logical/tests/
isgpuarray.rs1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 ResolveContext, Type, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13};
14use crate::BuiltinResult;
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::tests::isgpuarray")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18 name: "isgpuarray",
19 op_kind: GpuOpKind::Custom("metadata"),
20 supported_precisions: &[ScalarType::F32, ScalarType::F64],
21 broadcast: BroadcastSemantics::None,
22 provider_hooks: &[],
23 constant_strategy: ConstantStrategy::InlineLiteral,
24 residency: ResidencyPolicy::GatherImmediately,
25 nan_mode: ReductionNaN::Include,
26 two_pass_threshold: None,
27 workgroup_size: None,
28 accepts_nan_mode: false,
29 notes: "Reports whether the value is a gpuArray without gathering device buffers.",
30};
31
32#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::tests::isgpuarray")]
33pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
34 name: "isgpuarray",
35 shape: ShapeRequirements::Any,
36 constant_strategy: ConstantStrategy::InlineLiteral,
37 elementwise: None,
38 reduction: None,
39 emits_nan: false,
40 notes: "Metadata query that executes outside of fusion pipelines.",
41};
42
43const ISGPUARRAY_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
44 name: "tf",
45 ty: BuiltinParamType::LogicalArray,
46 arity: BuiltinParamArity::Required,
47 default: None,
48 description: "True when input is a gpuArray handle.",
49}];
50
51const ISGPUARRAY_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
52 name: "A",
53 ty: BuiltinParamType::Any,
54 arity: BuiltinParamArity::Required,
55 default: None,
56 description: "Input value to test.",
57}];
58
59const ISGPUARRAY_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
60 label: "tf = isgpuarray(A)",
61 inputs: &ISGPUARRAY_INPUTS,
62 outputs: &ISGPUARRAY_OUTPUT,
63}];
64
65const ISGPUARRAY_ERRORS: [BuiltinErrorDescriptor; 0] = [];
66
67pub const ISGPUARRAY_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
68 signatures: &ISGPUARRAY_SIGNATURES,
69 output_mode: BuiltinOutputMode::Fixed,
70 completion_policy: BuiltinCompletionPolicy::Public,
71 errors: &ISGPUARRAY_ERRORS,
72};
73
74#[runtime_builtin(
75 name = "isgpuarray",
76 category = "logical/tests",
77 summary = "Return true when a value is stored as a gpuArray handle.",
78 keywords = "isgpuarray,gpuarray,gpu,type,logical",
79 accel = "metadata",
80 type_resolver(bool_scalar_type),
81 descriptor(crate::builtins::logical::tests::isgpuarray::ISGPUARRAY_DESCRIPTOR),
82 builtin_path = "crate::builtins::logical::tests::isgpuarray"
83)]
84async fn isgpuarray_builtin(value: Value) -> BuiltinResult<Value> {
85 Ok(Value::Bool(matches!(value, Value::GpuTensor(_))))
86}
87
88fn bool_scalar_type(_: &[Type], _context: &ResolveContext) -> Type {
89 Type::Bool
90}
91
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{Tensor, Value};

    /// Drive the async builtin to completion on the current thread.
    fn run_isgpuarray(value: Value) -> BuiltinResult<Value> {
        block_on(super::isgpuarray_builtin(value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_gpu_values_report_false() {
        assert_eq!(run_isgpuarray(Value::Num(1.0)).unwrap(), Value::Bool(false));
    }

    /// Host-resident arrays and logicals must not be mistaken for gpuArrays,
    /// even when the tensor holds the same data that the GPU test uploads.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn host_tensors_and_logicals_report_false() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("host tensor");
        assert_eq!(
            run_isgpuarray(Value::Tensor(tensor)).expect("isgpuarray"),
            Value::Bool(false)
        );
        assert_eq!(
            run_isgpuarray(Value::Bool(true)).expect("isgpuarray"),
            Value::Bool(false)
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_handles_report_true() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).expect("host tensor");
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            // Clone so the original handle is still available for cleanup below.
            let result = run_isgpuarray(Value::GpuTensor(handle.clone())).expect("isgpuarray");
            assert_eq!(result, Value::Bool(true));
            // Best-effort cleanup; a failed free should not fail this test.
            provider.free(&handle).ok();
        });
    }
}