runmat_runtime/builtins/logical/
ops.rs

1//! MATLAB-compatible `logical` builtin with GPU-aware semantics for RunMat.
2
3use log::trace;
4use runmat_accelerate_api::{self, AccelProvider, GpuTensorHandle, HostTensorView};
5use runmat_builtins::{CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value};
6use runmat_macros::runtime_builtin;
7
8use crate::builtins::common::{
9    gpu_helpers,
10    spec::{
11        BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12        ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13    },
14    tensor,
15};
16#[cfg(feature = "doc_export")]
17use crate::register_builtin_doc_text;
18use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
19
20#[cfg(feature = "doc_export")]
21pub const DOC_MD: &str = r#"---
22title: "logical"
23category: "logical"
24keywords: ["logical", "boolean conversion", "truth mask", "gpuArray", "mask array"]
25summary: "Convert scalars, arrays, and gpuArray values to MATLAB-compatible logical values."
26references: []
27gpu_support:
28  elementwise: true
29  reduction: false
30  precisions: ["f32", "f64"]
31  broadcasting: "matlab"
32  notes: "Prefers a device-side elem\\_ne(X, 0) cast when the provider supports elem_ne and zeros_like; otherwise gathers to the host, converts, and re-uploads the logical result."
33fusion:
34  elementwise: false
35  reduction: false
36  max_inputs: 1
37  constants: "inline"
38requires_feature: null
39tested:
40  unit: "builtins::logical::ops::tests"
41  integration: "builtins::logical::ops::tests::logical_gpu_roundtrip"
42---
43
44# What does the `logical` function do in MATLAB / RunMat?
45`logical(X)` converts numeric, logical, character, and gpuArray inputs into MATLAB logical values (booleans). Any non-zero (or `NaN`/`Inf`) element maps to `true`, while zero maps to `false`. Logical inputs are returned unchanged.
46
47## How does the `logical` function behave in MATLAB / RunMat?
48- `logical` accepts scalars, dense arrays, N-D tensors, and gpuArrays. Shapes are preserved bit-for-bit.
49- Non-zero numeric values, `NaN`, and `Inf` map to `true`; `0` and `-0` map to `false`.
50- Complex inputs are considered `true` when either the real or imaginary component is non-zero.
51- Character arrays are converted elementwise by interpreting code points (so `'A'` becomes `true`, `'\0'` becomes `false`).
- Strings, structs, cells, objects, and other non-numeric types raise MATLAB-compatible errors (`"logical: conversion to logical from <type> is not possible"`).
53- Scalar results become logical scalars (`true`/`false`); higher-rank arrays produce dense logical arrays.
54
55## `logical` Function GPU Execution Behaviour
- When a GPU provider implements `elem_ne` and `zeros_like`, RunMat performs the conversion on the device by evaluating `elem_ne(X, 0)` against a temporary zero tensor, then marks the resulting handle as logical so predicates like `islogical` work without downloads.
57- If the provider cannot service the request (missing hooks, unsupported dtype, or allocation failure), the value is transparently gathered to the host, converted, and—when a provider is still available—re-uploaded as a logical gpuArray. The fallback is documented so users understand potential host/device transitions.
58- Handles that are already flagged as logical (`gpuArray.logical`) are returned without modification.
59- Scalars remain scalars: converting a `gpuArray` scalar preserves the residency and returns a logical gpuArray scalar.
60
61## Examples of using the `logical` function in MATLAB / RunMat
62
63### Creating a logical mask from numeric data
64```matlab
65values = [0 2 -3 0];
66mask = logical(values);
67```
68Expected output:
69```matlab
70mask =
71  1×4 logical array
72     0     1     1     0
73```
74
75### Building a logical mask from a matrix
76```matlab
77M = [-4 0 8; 0 1 0];
78mask = logical(M);
79```
80Expected output:
81```matlab
82mask =
83  2×3 logical array
84     1     0     1
85     0     1     0
86```
87
88### Treating NaN and Inf values as true
89```matlab
90flags = logical([NaN Inf 0]);
91```
92Expected output:
93```matlab
94flags =
95  1×3 logical array
96     1     1     0
97```
98
99### Converting complex numbers to logical scalars
100```matlab
101z = logical(3 + 4i);
102w = logical(0 + 0i);
103```
104Expected output:
105```matlab
106z =
107     1
108w =
109     0
110```
111
112### Converting character arrays to logical values
113```matlab
114chars = ['A' 0 'C'];
115mask = logical(chars);
116```
117Expected output:
118```matlab
119mask =
120  1×3 logical array
121     1     0     1
122```
123
124### Keeping gpuArray inputs on the device
125```matlab
126G = gpuArray([0 1 2]);
127maskGPU = logical(G);
128hostMask = gather(maskGPU);
129```
130Expected output:
131```matlab
132hostMask =
133  1×3 logical array
134     0     1     1
135```
136
137### Preserving empty shapes through logical conversion
138```matlab
139emptyVec = zeros(0, 3);
140logicalEmpty = logical(emptyVec);
141```
142Expected output:
143```matlab
144logicalEmpty =
145  0×3 logical array
146     []
147```
148
149## GPU residency in RunMat (Do I need `gpuArray`?)
You rarely need to call `gpuArray` manually. When the acceleration provider is active, RunMat keeps logical conversions on the GPU by issuing `elem_ne(X, 0)` kernels (backed by `zeros_like` allocations) and flagging the resulting handle as logical. Explicit `gpuArray` calls are available for MATLAB compatibility or when you want to pin residency before interacting with external libraries. When the provider lacks the necessary hooks, RunMat falls back: it gathers the data, converts it on the host, and—if a provider is still available—re-uploads the logical mask so downstream GPU code continues to work without residency surprises.
151
152## FAQ
153
154### Which input types does `logical` support?
155Numeric, logical, complex, character, and gpuArray values are accepted. Strings, structs, cells, objects, and function handles are rejected with MATLAB-compatible error messages.
156
157### How are NaN or Inf values treated?
158They evaluate to `true`. MATLAB defines logical conversion as “non-zero”, and `NaN` / `Inf` both satisfy that rule.
159
160### How does `logical` handle complex numbers?
161The result is `true` when either the real or imaginary component is non-zero (or `NaN`/`Inf`). Only `0 + 0i` converts to `false`.
162
163### Does the builtin change array shapes?
164No. Shapes are preserved exactly, including empty dimensions and higher-rank tensors.
165
166### What happens to existing logical arrays?
167They are returned verbatim. Logical gpuArrays remain on the device without triggering new allocations.
168
169### Can I convert strings with `logical`?
170No. MATLAB rejects string inputs, and RunMat mirrors that behaviour: `"logical: conversion to logical from string is not possible"`.
171
172### What about structs, cells, or objects?
173They raise the same conversion error as MATLAB. Use functions like `~cellfun(@isempty, ...)` to derive masks instead.
174
175### Does the GPU path allocate new buffers?
Yes, unless the input is already a logical gpuArray. The preferred path allocates a temporary zero tensor with `zeros_like`, evaluates `elem_ne` against it, frees the temporary, and returns the `elem_ne` result as a logical gpuArray handle. Fallback paths allocate a new gpuArray after gathering to the host.
177
178### Where can I learn more?
See the links under "See Also" and "Source & Feedback" below, along with the RunMat source, for implementation details.
180
181## See Also
182[`islogical`](./tests/islogical), [`gpuArray`](../../acceleration/gpu/gpuArray), [`gather`](../../acceleration/gpu/gather), [`find`](../../math/reduction/find)
183
184## Source & Feedback
185- Implementation: `crates/runmat-runtime/src/builtins/logical/ops.rs`
186- Issues & feature requests: [https://github.com/runmat-org/runmat/issues/new/choose](https://github.com/runmat-org/runmat/issues/new/choose)
187"#;
188
/// GPU execution metadata for the `logical` builtin.
///
/// Describes the builtin as an elementwise operation over `f32`/`f64` data
/// whose result is reported under a new-handle residency policy. Only the
/// binary `elem_ne` hook is declared here; the device cast implemented in this
/// file (`try_gpu_cast`) additionally calls the provider's `zeros_like`.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "logical",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Binary {
        name: "elem_ne",
        commutative: true,
    }],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Preferred path issues elem_ne(X, 0) on the device; missing hooks trigger a gather → host cast → re-upload sequence flagged as logical.",
};

// Hand the spec to the project's GPU-spec registration macro.
register_builtin_gpu_spec!(GPU_SPEC);
208
/// Fusion metadata for the `logical` builtin.
///
/// No elementwise or reduction fusion template is supplied (`None` for both);
/// per `notes`, the builtin currently executes outside fusion plans.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "logical",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion support will arrive alongside a dedicated WGSL template; today the builtin executes outside fusion plans.",
};

// Hand the spec to the project's fusion-spec registration macro.
register_builtin_fusion_spec!(FUSION_SPEC);
220
// Only when the `doc_export` feature is enabled: associate the Markdown in
// `DOC_MD` with the builtin name "logical".
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("logical", DOC_MD);
223
224#[runtime_builtin(
225    name = "logical",
226    category = "logical",
227    summary = "Convert scalars, arrays, and gpuArray values to logical outputs.",
228    keywords = "logical,boolean,gpuArray,mask,conversion",
229    accel = "unary"
230)]
231fn logical_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
232    if !rest.is_empty() {
233        return Err("logical: too many input arguments".to_string());
234    }
235    convert_value_to_logical(value)
236}
237
238fn convert_value_to_logical(value: Value) -> Result<Value, String> {
239    match value {
240        Value::Bool(_) | Value::LogicalArray(_) => Ok(value),
241        Value::Num(n) => Ok(Value::Bool(n != 0.0)),
242        Value::Int(i) => Ok(Value::Bool(!i.is_zero())),
243        Value::Complex(re, im) => Ok(Value::Bool(!complex_is_zero(re, im))),
244        Value::Tensor(tensor) => logical_from_tensor(tensor),
245        Value::ComplexTensor(tensor) => logical_from_complex_tensor(tensor),
246        Value::CharArray(chars) => logical_from_char_array(chars),
247        Value::StringArray(strings) => logical_from_string_array(strings),
248        Value::GpuTensor(handle) => logical_from_gpu(handle),
249        Value::String(_) => Err(conversion_error("string")),
250        Value::Cell(_) => Err(conversion_error("cell")),
251        Value::Struct(_) => Err(conversion_error("struct")),
252        Value::Object(obj) => Err(conversion_error(&obj.class_name)),
253        Value::HandleObject(handle) => Err(conversion_error(&handle.class_name)),
254        Value::Listener(_) => Err(conversion_error("event.listener")),
255        Value::FunctionHandle(_) | Value::Closure(_) => Err(conversion_error("function_handle")),
256        Value::ClassRef(_) => Err(conversion_error("meta.class")),
257        Value::MException(_) => Err(conversion_error("MException")),
258    }
259}
260
261fn logical_from_tensor(tensor: Tensor) -> Result<Value, String> {
262    let buffer = LogicalBuffer::from_real_tensor(&tensor);
263    logical_buffer_to_host(buffer)
264}
265
266fn logical_from_complex_tensor(tensor: ComplexTensor) -> Result<Value, String> {
267    let buffer = LogicalBuffer::from_complex_tensor(&tensor);
268    logical_buffer_to_host(buffer)
269}
270
271fn logical_from_char_array(chars: CharArray) -> Result<Value, String> {
272    let buffer = LogicalBuffer::from_char_array(&chars);
273    logical_buffer_to_host(buffer)
274}
275
276fn logical_from_string_array(strings: StringArray) -> Result<Value, String> {
277    let bits: Vec<u8> = strings
278        .data
279        .iter()
280        .map(|s| if s.is_empty() { 0 } else { 1 })
281        .collect();
282    let shape = canonical_shape(&strings.shape, bits.len());
283    logical_buffer_to_host(LogicalBuffer { bits, shape })
284}
285
286fn logical_from_gpu(handle: GpuTensorHandle) -> Result<Value, String> {
287    if runmat_accelerate_api::handle_is_logical(&handle) {
288        return Ok(Value::GpuTensor(handle));
289    }
290
291    let provider = runmat_accelerate_api::provider();
292
293    if let Some(p) = provider {
294        match p.logical_islogical(&handle) {
295            Ok(true) => {
296                runmat_accelerate_api::set_handle_logical(&handle, true);
297                return Ok(Value::GpuTensor(handle));
298            }
299            Ok(false) => {}
300            Err(err) => {
301                trace!("logical: provider logical_islogical hook unavailable, falling back ({err})")
302            }
303        }
304        if let Some(result) = try_gpu_cast(p, &handle) {
305            return Ok(gpu_helpers::logical_gpu_value(result));
306        } else {
307            trace!(
308                "logical: provider elem_ne/zeros_like unavailable for buffer {} – gathering",
309                handle.buffer_id
310            );
311        }
312    }
313
314    let tensor = gpu_helpers::gather_tensor(&handle)?;
315    let buffer = LogicalBuffer::from_real_tensor(&tensor);
316    logical_buffer_to_gpu(buffer, provider)
317}
318
319fn logical_buffer_to_host(buffer: LogicalBuffer) -> Result<Value, String> {
320    let LogicalBuffer { bits, shape } = buffer;
321    if tensor::element_count(&shape) == 1 && bits.len() == 1 {
322        Ok(Value::Bool(bits[0] != 0))
323    } else {
324        LogicalArray::new(bits, shape)
325            .map(Value::LogicalArray)
326            .map_err(|e| format!("logical: {e}"))
327    }
328}
329
330fn logical_buffer_to_gpu(
331    buffer: LogicalBuffer,
332    provider: Option<&'static dyn AccelProvider>,
333) -> Result<Value, String> {
334    if let Some(p) = provider {
335        let floats: Vec<f64> = buffer
336            .bits
337            .iter()
338            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
339            .collect();
340        let view = HostTensorView {
341            data: &floats,
342            shape: &buffer.shape,
343        };
344        match p.upload(&view) {
345            Ok(handle) => Ok(gpu_helpers::logical_gpu_value(handle)),
346            Err(err) => {
347                trace!("logical: upload failed during fallback path ({err})");
348                logical_buffer_to_host(buffer)
349            }
350        }
351    } else {
352        logical_buffer_to_host(buffer)
353    }
354}
355
356fn try_gpu_cast(
357    provider: &'static dyn AccelProvider,
358    input: &GpuTensorHandle,
359) -> Option<GpuTensorHandle> {
360    let zeros = provider.zeros_like(input).ok()?;
361    let result = provider.elem_ne(input, &zeros).ok();
362    let _ = provider.free(&zeros);
363    result
364}
365
/// Returns `true` only when both the real and imaginary parts compare equal
/// to zero (so `-0.0` counts as zero, while `NaN` and infinities do not).
fn complex_is_zero(re: f64, im: f64) -> bool {
    [re, im].iter().all(|&part| part == 0.0)
}
369
/// Builds the error message reported when `type_name` cannot be converted to
/// logical.
fn conversion_error(type_name: &str) -> String {
    format!("logical: conversion to logical from {type_name} is not possible")
}
376
/// Host-side intermediate produced by the conversion helpers before the
/// result is emitted as a host value or uploaded to a provider.
#[derive(Clone)]
struct LogicalBuffer {
    /// One byte per element; the constructors in this file write `1` for
    /// true and `0` for false, and consumers test for non-zero.
    bits: Vec<u8>,
    /// Dimensions paired with `bits`, as produced by `canonical_shape`.
    shape: Vec<usize>,
}
382
383impl LogicalBuffer {
384    fn from_real_tensor(tensor: &Tensor) -> Self {
385        let bits: Vec<u8> = tensor
386            .data
387            .iter()
388            .map(|&v| if v != 0.0 { 1 } else { 0 })
389            .collect();
390        let shape = canonical_shape(&tensor.shape, bits.len());
391        Self { bits, shape }
392    }
393
394    fn from_complex_tensor(tensor: &ComplexTensor) -> Self {
395        let bits: Vec<u8> = tensor
396            .data
397            .iter()
398            .map(|&(re, im)| if !complex_is_zero(re, im) { 1 } else { 0 })
399            .collect();
400        let shape = canonical_shape(&tensor.shape, bits.len());
401        Self { bits, shape }
402    }
403
404    fn from_char_array(chars: &CharArray) -> Self {
405        let bits: Vec<u8> = chars
406            .data
407            .iter()
408            .map(|&ch| if (ch as u32) != 0 { 1 } else { 0 })
409            .collect();
410        let original_shape = vec![chars.rows, chars.cols];
411        let shape = canonical_shape(&original_shape, bits.len());
412        Self { bits, shape }
413    }
414}
415
416fn canonical_shape(shape: &[usize], len: usize) -> Vec<usize> {
417    if !shape.is_empty() && tensor::element_count(shape) == len {
418        return shape.to_vec();
419    }
420    if len == 0 {
421        if shape.len() > 1 {
422            return shape.to_vec();
423        }
424        return vec![0];
425    }
426    if len == 1 {
427        vec![1, 1]
428    } else {
429        vec![len, 1]
430    }
431}
432
// Unit tests for the `logical` builtin: host conversions, rejected input
// types, and provider-backed gpuArray paths.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CellArray, IntValue, MException, ObjectInstance, StructValue};

    // Non-zero doubles become `true`, zero becomes `false`.
    #[test]
    fn logical_scalar_num() {
        let result = logical_builtin(Value::Num(5.0), Vec::new()).expect("logical");
        assert_eq!(result, Value::Bool(true));

        let zero_result = logical_builtin(Value::Num(0.0), Vec::new()).expect("logical");
        assert_eq!(zero_result, Value::Bool(false));
    }

    // NaN converts to true, while both 0.0 and -0.0 convert to false.
    #[test]
    fn logical_nan_is_true() {
        let tensor = Tensor::new(vec![0.0, f64::NAN, -0.0], vec![1, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![0, 1, 0]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // A 2x2 tensor keeps its shape and is converted element by element.
    #[test]
    fn logical_tensor_matrix() {
        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, 0.0], vec![2, 2]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![2, 2]);
                assert_eq!(array.data, vec![0, 1, 1, 0]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Complex elements are true when either component is non-zero.
    #[test]
    fn logical_complex_conversion() {
        let complex =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)], vec![3, 1]).unwrap();
        let result = logical_builtin(Value::ComplexTensor(complex), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.data, vec![0, 1, 1]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Characters convert by code: only '\0' is false.
    #[test]
    fn logical_char_array_conversion() {
        let chars = CharArray::new(vec!['A', '\0', 'C'], 1, 3).unwrap();
        let result = logical_builtin(Value::CharArray(chars), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Scalar strings are rejected with the exact conversion message.
    #[test]
    fn logical_string_error() {
        let err = logical_builtin(Value::String("runmat".to_string()), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from string is not possible"
        );
    }

    // Structs are rejected; the message mentions the type name.
    #[test]
    fn logical_struct_error() {
        let mut st = StructValue::new();
        st.insert("field", Value::Num(1.0));
        let err = logical_builtin(Value::Struct(st), Vec::new()).unwrap_err();
        assert!(err.contains("struct"), "unexpected error message: {err}");
    }

    // Cells are rejected with the exact conversion message.
    #[test]
    fn logical_cell_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell creation");
        let err = logical_builtin(Value::Cell(cell), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from cell is not possible"
        );
    }

    // Function handles are rejected with the exact conversion message.
    #[test]
    fn logical_function_handle_error() {
        let err = logical_builtin(Value::FunctionHandle("foo".into()), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from function_handle is not possible"
        );
    }

    // Objects are rejected; the message carries the object's class name.
    #[test]
    fn logical_object_error() {
        let obj = ObjectInstance::new("DemoClass".to_string());
        let err = logical_builtin(Value::Object(obj), Vec::new()).unwrap_err();
        assert!(
            err.contains("DemoClass"),
            "expected class name in error, got {err}"
        );
    }

    // MException values are rejected with the exact conversion message.
    #[test]
    fn logical_mexception_error() {
        let mex = MException::new("id:logical".into(), "message".into());
        let err = logical_builtin(Value::MException(mex), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from MException is not possible"
        );
    }

    // An uploaded tensor converts to a GPU result that gathers to 0/1 values
    // and whose handle is flagged as logical.
    #[test]
    fn logical_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, -2.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            let gathered = test_support::gather(result.clone()).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
            if let Value::GpuTensor(out) = result {
                assert!(runmat_accelerate_api::handle_is_logical(&out));
            } else {
                panic!("expected gpu tensor output");
            }
        });
    }

    // A handle already flagged as logical comes back as the same handle.
    #[test]
    fn logical_gpu_passthrough_for_logical_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            match result {
                Value::GpuTensor(out) => assert_eq!(out, handle),
                other => panic!("expected gpu tensor, got {:?}", other),
            }
        });
    }

    // Logical scalars and logical arrays are returned unchanged.
    #[test]
    fn logical_bool_and_logical_inputs_passthrough() {
        let res_bool = logical_builtin(Value::Bool(true), Vec::new()).expect("logical");
        assert_eq!(res_bool, Value::Bool(true));

        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let res_array =
            logical_builtin(Value::LogicalArray(logical.clone()), Vec::new()).expect("logical");
        assert_eq!(res_array, Value::LogicalArray(logical));
    }

    // An empty 0x3 tensor yields an empty logical array with the same shape.
    #[test]
    fn logical_empty_tensor_preserves_shape() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert!(array.data.is_empty());
                assert_eq!(array.shape, vec![0, 3]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Integer scalars: zero is false, a negative non-zero value is true.
    #[test]
    fn logical_integer_scalar() {
        let res = logical_builtin(Value::Int(IntValue::I32(0)), Vec::new()).expect("logical");
        assert_eq!(res, Value::Bool(false));

        let res_nonzero =
            logical_builtin(Value::Int(IntValue::I32(-5)), Vec::new()).expect("logical");
        assert_eq!(res_nonzero, Value::Bool(true));
    }

    // With `doc_export` enabled, the embedded documentation must yield at
    // least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }

    // With `wgpu` enabled, the GPU result must match the host conversion of
    // the same tensor in both shape and data.
    #[test]
    #[cfg(feature = "wgpu")]
    fn logical_wgpu_matches_cpu_conversion() {
        // The registration result is ignored here.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, f64::NAN], vec![2, 2]).unwrap();
        let cpu = logical_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = logical_builtin(Value::GpuTensor(handle), Vec::new()).unwrap();
        let out_handle = match gpu_value {
            Value::GpuTensor(ref h) => {
                assert!(runmat_accelerate_api::handle_is_logical(h));
                h.clone()
            }
            other => panic!("expected gpu tensor, got {other:?}"),
        };

        let gathered = test_support::gather(Value::GpuTensor(out_handle)).expect("gather");

        // Express the host result as f64 data plus shape for comparison.
        let (expected, expected_shape): (Vec<f64>, Vec<usize>) = match cpu {
            Value::LogicalArray(arr) => (
                arr.data
                    .iter()
                    .map(|&b| if b != 0 { 1.0 } else { 0.0 })
                    .collect(),
                arr.shape.clone(),
            ),
            Value::Bool(flag) => (vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]),
            other => panic!("unexpected cpu result {other:?}"),
        };

        assert_eq!(gathered.shape, expected_shape);
        assert_eq!(gathered.data, expected);
    }
}