1use log::trace;
4use runmat_accelerate_api::{self, AccelProvider, GpuTensorHandle, HostTensorView};
5use runmat_builtins::{CharArray, ComplexTensor, LogicalArray, StringArray, Tensor, Value};
6use runmat_macros::runtime_builtin;
7
8use crate::builtins::common::{
9 gpu_helpers,
10 spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
13 },
14 tensor,
15};
16#[cfg(feature = "doc_export")]
17use crate::register_builtin_doc_text;
18use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
19
20#[cfg(feature = "doc_export")]
21pub const DOC_MD: &str = r#"---
22title: "logical"
23category: "logical"
24keywords: ["logical", "boolean conversion", "truth mask", "gpuArray", "mask array"]
25summary: "Convert scalars, arrays, and gpuArray values to MATLAB-compatible logical values."
26references: []
27gpu_support:
28 elementwise: true
29 reduction: false
30 precisions: ["f32", "f64"]
31 broadcasting: "matlab"
32 notes: "Prefers a device-side elem\\_ne(X, 0) cast when the provider supports elem_ne and zeros_like; otherwise gathers to the host, converts, and re-uploads the logical result."
33fusion:
34 elementwise: false
35 reduction: false
36 max_inputs: 1
37 constants: "inline"
38requires_feature: null
39tested:
40 unit: "builtins::logical::ops::tests"
41 integration: "builtins::logical::ops::tests::logical_gpu_roundtrip"
42---
43
44# What does the `logical` function do in MATLAB / RunMat?
45`logical(X)` converts numeric, logical, character, and gpuArray inputs into MATLAB logical values (booleans). Any non-zero (or `NaN`/`Inf`) element maps to `true`, while zero maps to `false`. Logical inputs are returned unchanged.
46
47## How does the `logical` function behave in MATLAB / RunMat?
- `logical` accepts scalars, dense arrays, N-D tensors, and gpuArrays. Shapes are preserved exactly.
49- Non-zero numeric values, `NaN`, and `Inf` map to `true`; `0` and `-0` map to `false`.
50- Complex inputs are considered `true` when either the real or imaginary component is non-zero.
51- Character arrays are converted elementwise by interpreting code points (so `'A'` becomes `true`, `'\0'` becomes `false`).
- Strings, structs, cells, objects, and other non-numeric types raise MATLAB-style errors (`"logical: conversion to logical from <type> is not possible"`).
53- Scalar results become logical scalars (`true`/`false`); higher-rank arrays produce dense logical arrays.
54
55## `logical` Function GPU Execution Behaviour
- When a GPU provider implements `elem_ne` and `zeros_like`, RunMat performs the conversion on the device: it evaluates `elem_ne(X, 0)` against a temporary zero tensor (released afterwards) into a new buffer, then marks the resulting handle as logical so predicates like `islogical` work without downloads.
- If the provider cannot service the request (missing hooks, unsupported dtype, or allocation failure), the value is transparently gathered to the host, converted, and—when a provider is still available—re-uploaded as a logical gpuArray. If that re-upload fails, a host logical array is returned instead.
58- Handles that are already flagged as logical (`gpuArray.logical`) are returned without modification.
59- Scalars remain scalars: converting a `gpuArray` scalar preserves the residency and returns a logical gpuArray scalar.
60
61## Examples of using the `logical` function in MATLAB / RunMat
62
63### Creating a logical mask from numeric data
64```matlab
65values = [0 2 -3 0];
66mask = logical(values);
67```
68Expected output:
69```matlab
70mask =
71 1×4 logical array
72 0 1 1 0
73```
74
75### Building a logical mask from a matrix
76```matlab
77M = [-4 0 8; 0 1 0];
78mask = logical(M);
79```
80Expected output:
81```matlab
82mask =
83 2×3 logical array
84 1 0 1
85 0 1 0
86```
87
88### Treating NaN and Inf values as true
89```matlab
90flags = logical([NaN Inf 0]);
91```
92Expected output:
93```matlab
94flags =
95 1×3 logical array
96 1 1 0
97```
98
99### Converting complex numbers to logical scalars
100```matlab
101z = logical(3 + 4i);
102w = logical(0 + 0i);
103```
104Expected output:
105```matlab
106z =
107 1
108w =
109 0
110```
111
112### Converting character arrays to logical values
113```matlab
114chars = ['A' 0 'C'];
115mask = logical(chars);
116```
117Expected output:
118```matlab
119mask =
120 1×3 logical array
121 1 0 1
122```
123
124### Keeping gpuArray inputs on the device
125```matlab
126G = gpuArray([0 1 2]);
127maskGPU = logical(G);
128hostMask = gather(maskGPU);
129```
130Expected output:
131```matlab
132hostMask =
133 1×3 logical array
134 0 1 1
135```
136
137### Preserving empty shapes through logical conversion
138```matlab
139emptyVec = zeros(0, 3);
140logicalEmpty = logical(emptyVec);
141```
142Expected output:
143```matlab
144logicalEmpty =
145 0×3 logical array
146 []
147```
148
149## GPU residency in RunMat (Do I need `gpuArray`?)
You rarely need to call `gpuArray` manually. When the acceleration provider is active, RunMat keeps logical conversions on the GPU by issuing `elem_ne(X, 0)` kernels (backed by `zeros_like` allocations) and flagging the handle as logical metadata. Explicit `gpuArray` calls are available for MATLAB compatibility or when you want to pin residency before interacting with external libraries. When the provider lacks the necessary hooks, RunMat falls back to the host: it gathers the data, converts it, and—if a provider is still available—re-uploads the logical mask so downstream GPU code continues to work without residency surprises.
151
152## FAQ
153
154### Which input types does `logical` support?
155Numeric, logical, complex, character, and gpuArray values are accepted. Strings, structs, cells, objects, and function handles are rejected with MATLAB-compatible error messages.
156
157### How are NaN or Inf values treated?
RunMat treats both as `true`, because they compare unequal to zero. Note that MATLAB itself raises an error when asked to convert `NaN` to logical, so code that relies on this behaviour is RunMat-specific.
159
160### How does `logical` handle complex numbers?
The result is `true` when either the real or imaginary component is non-zero (or `NaN`/`Inf`). Only `0 + 0i` converts to `false`. This is a RunMat extension: MATLAB raises an error when converting complex values to logical.
162
163### Does the builtin change array shapes?
164No. Shapes are preserved exactly, including empty dimensions and higher-rank tensors.
165
166### What happens to existing logical arrays?
167They are returned verbatim. Logical gpuArrays remain on the device without triggering new allocations.
168
169### Can I convert strings with `logical`?
170No. MATLAB rejects string inputs, and RunMat mirrors that behaviour: `"logical: conversion to logical from string is not possible"`.
171
172### What about structs, cells, or objects?
173They raise the same conversion error as MATLAB. Use functions like `~cellfun(@isempty, ...)` to derive masks instead.
174
175### Does the GPU path allocate new buffers?
Yes. The preferred path allocates a temporary zero tensor with `zeros_like`, evaluates `elem_ne` into a new output buffer, and then frees the temporary. Fallback paths gather to the host and upload a new logical gpuArray. Inputs that are already logical gpuArrays are returned without any allocation.
177
178### Where can I learn more?
See the related functions listed below and the RunMat source for implementation details.
180
181## See Also
182[`islogical`](./tests/islogical), [`gpuArray`](../../acceleration/gpu/gpuArray), [`gather`](../../acceleration/gpu/gather), [`find`](../../math/reduction/find)
183
184## Source & Feedback
185- Implementation: `crates/runmat-runtime/src/builtins/logical/ops.rs`
186- Issues & feature requests: [https://github.com/runmat-org/runmat/issues/new/choose](https://github.com/runmat-org/runmat/issues/new/choose)
187"#;
188
189pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
190 name: "logical",
191 op_kind: GpuOpKind::Elementwise,
192 supported_precisions: &[ScalarType::F32, ScalarType::F64],
193 broadcast: BroadcastSemantics::Matlab,
194 provider_hooks: &[ProviderHook::Binary {
195 name: "elem_ne",
196 commutative: true,
197 }],
198 constant_strategy: ConstantStrategy::InlineLiteral,
199 residency: ResidencyPolicy::NewHandle,
200 nan_mode: ReductionNaN::Include,
201 two_pass_threshold: None,
202 workgroup_size: None,
203 accepts_nan_mode: false,
204 notes: "Preferred path issues elem_ne(X, 0) on the device; missing hooks trigger a gather → host cast → re-upload sequence flagged as logical.",
205};
206
207register_builtin_gpu_spec!(GPU_SPEC);
208
209pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
210 name: "logical",
211 shape: ShapeRequirements::BroadcastCompatible,
212 constant_strategy: ConstantStrategy::InlineLiteral,
213 elementwise: None,
214 reduction: None,
215 emits_nan: false,
216 notes: "Fusion support will arrive alongside a dedicated WGSL template; today the builtin executes outside fusion plans.",
217};
218
219register_builtin_fusion_spec!(FUSION_SPEC);
220
// Register DOC_MD as the documentation text for `logical`; only compiled when the
// `doc_export` feature is enabled (the same feature that gates DOC_MD itself).
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("logical", DOC_MD);
223
224#[runtime_builtin(
225 name = "logical",
226 category = "logical",
227 summary = "Convert scalars, arrays, and gpuArray values to logical outputs.",
228 keywords = "logical,boolean,gpuArray,mask,conversion",
229 accel = "unary"
230)]
231fn logical_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
232 if !rest.is_empty() {
233 return Err("logical: too many input arguments".to_string());
234 }
235 convert_value_to_logical(value)
236}
237
238fn convert_value_to_logical(value: Value) -> Result<Value, String> {
239 match value {
240 Value::Bool(_) | Value::LogicalArray(_) => Ok(value),
241 Value::Num(n) => Ok(Value::Bool(n != 0.0)),
242 Value::Int(i) => Ok(Value::Bool(!i.is_zero())),
243 Value::Complex(re, im) => Ok(Value::Bool(!complex_is_zero(re, im))),
244 Value::Tensor(tensor) => logical_from_tensor(tensor),
245 Value::ComplexTensor(tensor) => logical_from_complex_tensor(tensor),
246 Value::CharArray(chars) => logical_from_char_array(chars),
247 Value::StringArray(strings) => logical_from_string_array(strings),
248 Value::GpuTensor(handle) => logical_from_gpu(handle),
249 Value::String(_) => Err(conversion_error("string")),
250 Value::Cell(_) => Err(conversion_error("cell")),
251 Value::Struct(_) => Err(conversion_error("struct")),
252 Value::Object(obj) => Err(conversion_error(&obj.class_name)),
253 Value::HandleObject(handle) => Err(conversion_error(&handle.class_name)),
254 Value::Listener(_) => Err(conversion_error("event.listener")),
255 Value::FunctionHandle(_) | Value::Closure(_) => Err(conversion_error("function_handle")),
256 Value::ClassRef(_) => Err(conversion_error("meta.class")),
257 Value::MException(_) => Err(conversion_error("MException")),
258 }
259}
260
261fn logical_from_tensor(tensor: Tensor) -> Result<Value, String> {
262 let buffer = LogicalBuffer::from_real_tensor(&tensor);
263 logical_buffer_to_host(buffer)
264}
265
266fn logical_from_complex_tensor(tensor: ComplexTensor) -> Result<Value, String> {
267 let buffer = LogicalBuffer::from_complex_tensor(&tensor);
268 logical_buffer_to_host(buffer)
269}
270
271fn logical_from_char_array(chars: CharArray) -> Result<Value, String> {
272 let buffer = LogicalBuffer::from_char_array(&chars);
273 logical_buffer_to_host(buffer)
274}
275
276fn logical_from_string_array(strings: StringArray) -> Result<Value, String> {
277 let bits: Vec<u8> = strings
278 .data
279 .iter()
280 .map(|s| if s.is_empty() { 0 } else { 1 })
281 .collect();
282 let shape = canonical_shape(&strings.shape, bits.len());
283 logical_buffer_to_host(LogicalBuffer { bits, shape })
284}
285
286fn logical_from_gpu(handle: GpuTensorHandle) -> Result<Value, String> {
287 if runmat_accelerate_api::handle_is_logical(&handle) {
288 return Ok(Value::GpuTensor(handle));
289 }
290
291 let provider = runmat_accelerate_api::provider();
292
293 if let Some(p) = provider {
294 match p.logical_islogical(&handle) {
295 Ok(true) => {
296 runmat_accelerate_api::set_handle_logical(&handle, true);
297 return Ok(Value::GpuTensor(handle));
298 }
299 Ok(false) => {}
300 Err(err) => {
301 trace!("logical: provider logical_islogical hook unavailable, falling back ({err})")
302 }
303 }
304 if let Some(result) = try_gpu_cast(p, &handle) {
305 return Ok(gpu_helpers::logical_gpu_value(result));
306 } else {
307 trace!(
308 "logical: provider elem_ne/zeros_like unavailable for buffer {} – gathering",
309 handle.buffer_id
310 );
311 }
312 }
313
314 let tensor = gpu_helpers::gather_tensor(&handle)?;
315 let buffer = LogicalBuffer::from_real_tensor(&tensor);
316 logical_buffer_to_gpu(buffer, provider)
317}
318
319fn logical_buffer_to_host(buffer: LogicalBuffer) -> Result<Value, String> {
320 let LogicalBuffer { bits, shape } = buffer;
321 if tensor::element_count(&shape) == 1 && bits.len() == 1 {
322 Ok(Value::Bool(bits[0] != 0))
323 } else {
324 LogicalArray::new(bits, shape)
325 .map(Value::LogicalArray)
326 .map_err(|e| format!("logical: {e}"))
327 }
328}
329
330fn logical_buffer_to_gpu(
331 buffer: LogicalBuffer,
332 provider: Option<&'static dyn AccelProvider>,
333) -> Result<Value, String> {
334 if let Some(p) = provider {
335 let floats: Vec<f64> = buffer
336 .bits
337 .iter()
338 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
339 .collect();
340 let view = HostTensorView {
341 data: &floats,
342 shape: &buffer.shape,
343 };
344 match p.upload(&view) {
345 Ok(handle) => Ok(gpu_helpers::logical_gpu_value(handle)),
346 Err(err) => {
347 trace!("logical: upload failed during fallback path ({err})");
348 logical_buffer_to_host(buffer)
349 }
350 }
351 } else {
352 logical_buffer_to_host(buffer)
353 }
354}
355
356fn try_gpu_cast(
357 provider: &'static dyn AccelProvider,
358 input: &GpuTensorHandle,
359) -> Option<GpuTensorHandle> {
360 let zeros = provider.zeros_like(input).ok()?;
361 let result = provider.elem_ne(input, &zeros).ok();
362 let _ = provider.free(&zeros);
363 result
364}
365
/// True only when both the real and imaginary parts compare equal to zero.
///
/// `-0.0` counts as zero; a `NaN` in either part makes the value non-zero.
fn complex_is_zero(re: f64, im: f64) -> bool {
    [re, im].iter().all(|&part| part == 0.0)
}
369
/// Build the error message used when `type_name` has no logical conversion.
fn conversion_error(type_name: &str) -> String {
    format!("logical: conversion to logical from {type_name} is not possible")
}
376
377#[derive(Clone)]
378struct LogicalBuffer {
379 bits: Vec<u8>,
380 shape: Vec<usize>,
381}
382
383impl LogicalBuffer {
384 fn from_real_tensor(tensor: &Tensor) -> Self {
385 let bits: Vec<u8> = tensor
386 .data
387 .iter()
388 .map(|&v| if v != 0.0 { 1 } else { 0 })
389 .collect();
390 let shape = canonical_shape(&tensor.shape, bits.len());
391 Self { bits, shape }
392 }
393
394 fn from_complex_tensor(tensor: &ComplexTensor) -> Self {
395 let bits: Vec<u8> = tensor
396 .data
397 .iter()
398 .map(|&(re, im)| if !complex_is_zero(re, im) { 1 } else { 0 })
399 .collect();
400 let shape = canonical_shape(&tensor.shape, bits.len());
401 Self { bits, shape }
402 }
403
404 fn from_char_array(chars: &CharArray) -> Self {
405 let bits: Vec<u8> = chars
406 .data
407 .iter()
408 .map(|&ch| if (ch as u32) != 0 { 1 } else { 0 })
409 .collect();
410 let original_shape = vec![chars.rows, chars.cols];
411 let shape = canonical_shape(&original_shape, bits.len());
412 Self { bits, shape }
413 }
414}
415
416fn canonical_shape(shape: &[usize], len: usize) -> Vec<usize> {
417 if !shape.is_empty() && tensor::element_count(shape) == len {
418 return shape.to_vec();
419 }
420 if len == 0 {
421 if shape.len() > 1 {
422 return shape.to_vec();
423 }
424 return vec![0];
425 }
426 if len == 1 {
427 vec![1, 1]
428 } else {
429 vec![len, 1]
430 }
431}
432
// Unit tests for the `logical` builtin. Host tests call `logical_builtin` directly;
// GPU tests run against the in-process provider from `test_support`, and the
// `wgpu`-gated test compares the real wgpu backend against the CPU conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CellArray, IntValue, MException, ObjectInstance, StructValue};

    // Numeric scalars collapse to logical scalars: non-zero -> true, zero -> false.
    #[test]
    fn logical_scalar_num() {
        let result = logical_builtin(Value::Num(5.0), Vec::new()).expect("logical");
        assert_eq!(result, Value::Bool(true));

        let zero_result = logical_builtin(Value::Num(0.0), Vec::new()).expect("logical");
        assert_eq!(zero_result, Value::Bool(false));
    }

    // NaN compares unequal to zero and so maps to true; -0.0 maps to false.
    #[test]
    fn logical_nan_is_true() {
        let tensor = Tensor::new(vec![0.0, f64::NAN, -0.0], vec![1, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![0, 1, 0]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Matrix inputs keep their shape and convert elementwise.
    #[test]
    fn logical_tensor_matrix() {
        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, 0.0], vec![2, 2]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![2, 2]);
                assert_eq!(array.data, vec![0, 1, 1, 0]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // A complex element is true if either its real or imaginary part is non-zero.
    #[test]
    fn logical_complex_conversion() {
        let complex =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)], vec![3, 1]).unwrap();
        let result = logical_builtin(Value::ComplexTensor(complex), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.data, vec![0, 1, 1]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Characters convert by code point; only NUL is false.
    #[test]
    fn logical_char_array_conversion() {
        let chars = CharArray::new(vec!['A', '\0', 'C'], 1, 3).unwrap();
        let result = logical_builtin(Value::CharArray(chars), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Scalar strings are rejected with the exact conversion message.
    #[test]
    fn logical_string_error() {
        let err = logical_builtin(Value::String("runmat".to_string()), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from string is not possible"
        );
    }

    // Structs are rejected; only the type name in the message is checked.
    #[test]
    fn logical_struct_error() {
        let mut st = StructValue::new();
        st.insert("field", Value::Num(1.0));
        let err = logical_builtin(Value::Struct(st), Vec::new()).unwrap_err();
        assert!(err.contains("struct"), "unexpected error message: {err}");
    }

    // Cell arrays are rejected even when their contents are numeric.
    #[test]
    fn logical_cell_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell creation");
        let err = logical_builtin(Value::Cell(cell), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from cell is not possible"
        );
    }

    // Function handles report the MATLAB class name `function_handle`.
    #[test]
    fn logical_function_handle_error() {
        let err = logical_builtin(Value::FunctionHandle("foo".into()), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from function_handle is not possible"
        );
    }

    // Objects are rejected using their own class name.
    #[test]
    fn logical_object_error() {
        let obj = ObjectInstance::new("DemoClass".to_string());
        let err = logical_builtin(Value::Object(obj), Vec::new()).unwrap_err();
        assert!(
            err.contains("DemoClass"),
            "expected class name in error, got {err}"
        );
    }

    #[test]
    fn logical_mexception_error() {
        let mex = MException::new("id:logical".into(), "message".into());
        let err = logical_builtin(Value::MException(mex), Vec::new()).unwrap_err();
        assert_eq!(
            err,
            "logical: conversion to logical from MException is not possible"
        );
    }

    // A numeric gpuArray converts to a gpuArray whose handle is flagged logical and
    // whose gathered data holds 0/1 values.
    #[test]
    fn logical_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, -2.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            let gathered = test_support::gather(result.clone()).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
            if let Value::GpuTensor(out) = result {
                assert!(runmat_accelerate_api::handle_is_logical(&out));
            } else {
                panic!("expected gpu tensor output");
            }
        });
    }

    // A handle that is already flagged logical is returned as the very same handle.
    #[test]
    fn logical_gpu_passthrough_for_logical_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            match result {
                Value::GpuTensor(out) => assert_eq!(out, handle),
                other => panic!("expected gpu tensor, got {:?}", other),
            }
        });
    }

    // Host logical scalars and arrays pass through unchanged.
    #[test]
    fn logical_bool_and_logical_inputs_passthrough() {
        let res_bool = logical_builtin(Value::Bool(true), Vec::new()).expect("logical");
        assert_eq!(res_bool, Value::Bool(true));

        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let res_array =
            logical_builtin(Value::LogicalArray(logical.clone()), Vec::new()).expect("logical");
        assert_eq!(res_array, Value::LogicalArray(logical));
    }

    // Empty inputs keep their (0x3) shape rather than collapsing to a scalar or [0].
    #[test]
    fn logical_empty_tensor_preserves_shape() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert!(array.data.is_empty());
                assert_eq!(array.shape, vec![0, 3]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Integer scalars follow the same zero / non-zero rule as doubles.
    #[test]
    fn logical_integer_scalar() {
        let res = logical_builtin(Value::Int(IntValue::I32(0)), Vec::new()).expect("logical");
        assert_eq!(res, Value::Bool(false));

        let res_nonzero =
            logical_builtin(Value::Int(IntValue::I32(-5)), Vec::new()).expect("logical");
        assert_eq!(res_nonzero, Value::Bool(true));
    }

    // The exported documentation must contain at least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }

    // With the wgpu backend registered, the GPU result (shape and 0/1 data) must match
    // the CPU conversion of the same tensor, including a NaN element.
    #[test]
    #[cfg(feature = "wgpu")]
    fn logical_wgpu_matches_cpu_conversion() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, f64::NAN], vec![2, 2]).unwrap();
        let cpu = logical_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = logical_builtin(Value::GpuTensor(handle), Vec::new()).unwrap();
        let out_handle = match gpu_value {
            Value::GpuTensor(ref h) => {
                assert!(runmat_accelerate_api::handle_is_logical(h));
                h.clone()
            }
            other => panic!("expected gpu tensor, got {other:?}"),
        };

        let gathered = test_support::gather(Value::GpuTensor(out_handle)).expect("gather");

        // Normalise the CPU result (array or scalar) to f64 data plus shape for comparison.
        let (expected, expected_shape): (Vec<f64>, Vec<usize>) = match cpu {
            Value::LogicalArray(arr) => (
                arr.data
                    .iter()
                    .map(|&b| if b != 0 { 1.0 } else { 0.0 })
                    .collect(),
                arr.shape.clone(),
            ),
            Value::Bool(flag) => (vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]),
            other => panic!("unexpected cpu result {other:?}"),
        };

        assert_eq!(gathered.shape, expected_shape);
        assert_eq!(gathered.data, expected);
    }
}