runmat_runtime/builtins/logical/tests/
isgpuarray.rs1use runmat_builtins::{ResolveContext, Type, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
9};
10use crate::BuiltinResult;
11
12#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::tests::isgpuarray")]
13pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
14 name: "isgpuarray",
15 op_kind: GpuOpKind::Custom("metadata"),
16 supported_precisions: &[ScalarType::F32, ScalarType::F64],
17 broadcast: BroadcastSemantics::None,
18 provider_hooks: &[],
19 constant_strategy: ConstantStrategy::InlineLiteral,
20 residency: ResidencyPolicy::GatherImmediately,
21 nan_mode: ReductionNaN::Include,
22 two_pass_threshold: None,
23 workgroup_size: None,
24 accepts_nan_mode: false,
25 notes: "Reports whether the value is a gpuArray without gathering device buffers.",
26};
27
28#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::tests::isgpuarray")]
29pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
30 name: "isgpuarray",
31 shape: ShapeRequirements::Any,
32 constant_strategy: ConstantStrategy::InlineLiteral,
33 elementwise: None,
34 reduction: None,
35 emits_nan: false,
36 notes: "Metadata query that executes outside of fusion pipelines.",
37};
38
39#[runtime_builtin(
40 name = "isgpuarray",
41 category = "logical/tests",
42 summary = "Return true when a value is stored as a gpuArray handle.",
43 keywords = "isgpuarray,gpuarray,gpu,type,logical",
44 accel = "metadata",
45 type_resolver(bool_scalar_type),
46 builtin_path = "crate::builtins::logical::tests::isgpuarray"
47)]
48async fn isgpuarray_builtin(value: Value) -> BuiltinResult<Value> {
49 Ok(Value::Bool(matches!(value, Value::GpuTensor(_))))
50}
51
52fn bool_scalar_type(_: &[Type], _context: &ResolveContext) -> Type {
53 Type::Bool
54}
55
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{Tensor, Value};

    /// Runs the async builtin to completion on the current thread.
    fn run_isgpuarray(value: Value) -> BuiltinResult<Value> {
        block_on(super::isgpuarray_builtin(value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_gpu_values_report_false() {
        assert_eq!(run_isgpuarray(Value::Num(1.0)).unwrap(), Value::Bool(false));
        // A logical input is still not a gpuArray, even when its value is `true`.
        assert_eq!(run_isgpuarray(Value::Bool(true)).unwrap(), Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn host_tensors_report_false() {
        // Same contents as the GPU case below, but kept in host memory: this
        // must not be reported as a gpuArray.
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        assert_eq!(
            run_isgpuarray(Value::Tensor(tensor)).unwrap(),
            Value::Bool(false)
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_handles_report_true() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = run_isgpuarray(Value::GpuTensor(handle.clone()));
            // Free the uploaded buffer before asserting so it is released even
            // when the check below fails (a panicking assert would skip cleanup).
            provider.free(&handle).ok();
            assert_eq!(result.expect("isgpuarray"), Value::Bool(true));
        });
    }
}