// runmat_runtime/builtins/logical/ops.rs

//! MATLAB-compatible `logical` builtin with GPU-aware semantics for RunMat.
2
3use log::trace;
4use runmat_accelerate_api::{self, AccelProvider, GpuTensorHandle, HostTensorView};
5use runmat_builtins::{
6    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
7    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
8    CharArray, ComplexTensor, LogicalArray, ResolveContext, StringArray, Tensor, Type, Value,
9};
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::{
13    gpu_helpers,
14    shape::{canonical_scalar_shape, normalize_scalar_shape},
15    spec::{
16        BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17        ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18    },
19    tensor,
20};
21use crate::builtins::logical::type_resolvers::logical_like;
22
23use crate::{build_runtime_error, BuiltinResult, RuntimeError};
24
// GPU execution spec registered for the `logical` builtin.
//
// Declares an elementwise operation supporting F32 and F64 whose single
// provider hook is the commutative binary `elem_ne`. As the `notes` field
// records, the preferred path is `elem_ne(X, 0)` on the device, with a
// gather -> host cast -> re-upload sequence when the hooks are missing.
25#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::ops")]
26pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
27    name: "logical",
28    op_kind: GpuOpKind::Elementwise,
29    supported_precisions: &[ScalarType::F32, ScalarType::F64],
30    broadcast: BroadcastSemantics::Matlab,
    // The only provider hook this builtin asks for: a binary, commutative `elem_ne`.
31    provider_hooks: &[ProviderHook::Binary {
32        name: "elem_ne",
33        commutative: true,
34    }],
35    constant_strategy: ConstantStrategy::InlineLiteral,
36    residency: ResidencyPolicy::NewHandle,
37    nan_mode: ReductionNaN::Include,
38    two_pass_threshold: None,
39    workgroup_size: None,
40    accepts_nan_mode: false,
41    notes: "Preferred path issues elem_ne(X, 0) on the device; missing hooks trigger a gather → host cast → re-upload sequence flagged as logical.",
42};
43
// Fusion spec registered for the `logical` builtin.
//
// Neither an elementwise nor a reduction template is supplied (both are
// `None`); per `notes`, the builtin currently executes outside fusion plans.
44#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::ops")]
45pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
46    name: "logical",
47    shape: ShapeRequirements::BroadcastCompatible,
48    constant_strategy: ConstantStrategy::InlineLiteral,
49    elementwise: None,
50    reduction: None,
51    emits_nan: false,
52    notes: "Fusion support will arrive alongside a dedicated WGSL template; today the builtin executes outside fusion plans.",
53};
54
/// Builtin name attached to errors via `with_builtin` and used as the prefix
/// of the GPU gather-failure message.
55const BUILTIN_NAME: &str = "logical";
56
/// Output parameter list: the single required logical result `tf`.
57const LOGICAL_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
58    name: "tf",
59    ty: BuiltinParamType::LogicalArray,
60    arity: BuiltinParamArity::Required,
61    default: None,
62    description: "Logical-converted result.",
63}];
64
/// Input parameter list: the single required input `A`, declared with type `Any`.
65const LOGICAL_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
66    name: "A",
67    ty: BuiltinParamType::Any,
68    arity: BuiltinParamArity::Required,
69    default: None,
70    description: "Input value to convert.",
71}];
72
/// The only advertised call form, `tf = logical(A)`, wiring the input and
/// output lists above together.
73const LOGICAL_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
74    label: "tf = logical(A)",
75    inputs: &LOGICAL_INPUTS,
76    outputs: &LOGICAL_OUTPUT,
77}];
78
/// Returned by `logical_builtin` when extra arguments follow the first one.
/// Its `message` is used verbatim as the runtime error text.
79const LOGICAL_ERROR_TOO_MANY_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
80    code: "RM.LOGICAL.TOO_MANY_INPUTS",
81    identifier: Some("RunMat:logical:TooManyInputs"),
82    when: "More than one input argument is provided.",
83    message: "logical: too many input arguments",
84};
85
/// Used by `conversion_error` for value kinds that have no logical conversion.
/// The runtime message is built per type name there; only `identifier` is
/// taken from this descriptor.
86const LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
87    code: "RM.LOGICAL.CONVERSION_NOT_POSSIBLE",
88    identifier: Some("RunMat:logical:ConversionNotPossible"),
89    when: "Input type cannot be converted to logical.",
90    message: "logical: conversion to logical is not possible for this input type",
91};
92
/// Used by `logical_from_gpu` when gathering the device tensor to the host fails.
93const LOGICAL_ERROR_GPU_GATHER_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
94    code: "RM.LOGICAL.GPU_GATHER_FAILED",
95    identifier: Some("RunMat:logical:GpuGatherFailed"),
96    when: "GPU input gather fails during host fallback.",
97    message: "logical: failed to gather gpuArray input",
98};
99
/// Used when a sparse input cannot be densified or when `LogicalArray::new`
/// rejects the converted buffer.
100const LOGICAL_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
101    code: "RM.LOGICAL.INTERNAL",
102    identifier: Some("RunMat:logical:InternalError"),
103    when: "Internal logical buffer materialization fails.",
104    message: "logical: internal conversion error",
105};
106
/// All error descriptors published through `LOGICAL_DESCRIPTOR`.
107const LOGICAL_ERRORS: [BuiltinErrorDescriptor; 4] = [
108    LOGICAL_ERROR_TOO_MANY_INPUTS,
109    LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE,
110    LOGICAL_ERROR_GPU_GATHER_FAILED,
111    LOGICAL_ERROR_INTERNAL,
112];
113
/// Public descriptor for the `logical` builtin, bundling its signature list
/// and error catalogue. It is referenced by path from the `runtime_builtin`
/// attribute on `logical_builtin`.
114pub const LOGICAL_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
115    signatures: &LOGICAL_SIGNATURES,
116    output_mode: BuiltinOutputMode::Fixed,
117    completion_policy: BuiltinCompletionPolicy::Public,
118    errors: &LOGICAL_ERRORS,
119};
120
121fn logical_type(args: &[Type], _context: &ResolveContext) -> Type {
122    args.first().map(logical_like).unwrap_or(Type::logical())
123}
124
125fn logical_error_with_message(
126    message: impl Into<String>,
127    error: &'static BuiltinErrorDescriptor,
128) -> RuntimeError {
129    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
130    if let Some(identifier) = error.identifier {
131        builder = builder.with_identifier(identifier);
132    }
133    builder.build()
134}
135
136#[runtime_builtin(
137    name = "logical",
138    category = "logical",
139    summary = "Convert scalars, arrays, and gpuArray values to logical outputs.",
140    keywords = "logical,boolean,gpuArray,mask,conversion",
141    accel = "unary",
142    type_resolver(logical_type),
143    descriptor(crate::builtins::logical::ops::LOGICAL_DESCRIPTOR),
144    builtin_path = "crate::builtins::logical::ops"
145)]
146async fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
147    if !rest.is_empty() {
148        return Err(logical_error_with_message(
149            LOGICAL_ERROR_TOO_MANY_INPUTS.message,
150            &LOGICAL_ERROR_TOO_MANY_INPUTS,
151        ));
152    }
153    convert_value_to_logical(value).await
154}
155
156async fn convert_value_to_logical(value: Value) -> BuiltinResult<Value> {
157    match value {
158        Value::Bool(_) | Value::LogicalArray(_) => Ok(value),
159        Value::Num(n) => Ok(Value::Bool(n != 0.0)),
160        Value::Int(i) => Ok(Value::Bool(!i.is_zero())),
161        Value::Complex(re, im) => Ok(Value::Bool(!complex_is_zero(re, im))),
162        Value::Tensor(tensor) => logical_from_tensor(tensor),
163        Value::SparseTensor(sparse) => logical_from_sparse_tensor(sparse),
164        Value::ComplexTensor(tensor) => logical_from_complex_tensor(tensor),
165        Value::CharArray(chars) => logical_from_char_array(chars),
166        Value::StringArray(strings) => logical_from_string_array(strings),
167        Value::GpuTensor(handle) => logical_from_gpu(handle).await,
168        Value::String(_) => Err(conversion_error("string")),
169        Value::Cell(_) => Err(conversion_error("cell")),
170        Value::Struct(_) => Err(conversion_error("struct")),
171        Value::Object(obj) => Err(conversion_error(&obj.class_name)),
172        Value::HandleObject(handle) => Err(conversion_error(&handle.class_name)),
173        Value::Listener(_) => Err(conversion_error("event.listener")),
174        Value::FunctionHandle(_)
175        | Value::ExternalFunctionHandle(_)
176        | Value::MethodFunctionHandle(_)
177        | Value::BoundFunctionHandle { .. }
178        | Value::Closure(_) => Err(conversion_error("function_handle")),
179        Value::ClassRef(_) => Err(conversion_error("meta.class")),
180        Value::MException(_) => Err(conversion_error("MException")),
181        Value::OutputList(_) => Err(conversion_error("OutputList")),
182    }
183}
184
185fn logical_from_tensor(tensor: Tensor) -> BuiltinResult<Value> {
186    let buffer = LogicalBuffer::from_real_tensor(&tensor);
187    logical_buffer_to_host(buffer)
188}
189
190fn logical_from_sparse_tensor(sparse: runmat_builtins::SparseTensor) -> BuiltinResult<Value> {
191    let tensor = sparse.to_dense().map_err(|err| {
192        logical_error_with_message(
193            format!("logical: failed to densify sparse input: {err}"),
194            &LOGICAL_ERROR_INTERNAL,
195        )
196    })?;
197    logical_from_tensor(tensor)
198}
199
200fn logical_from_complex_tensor(tensor: ComplexTensor) -> BuiltinResult<Value> {
201    let buffer = LogicalBuffer::from_complex_tensor(&tensor);
202    logical_buffer_to_host(buffer)
203}
204
205fn logical_from_char_array(chars: CharArray) -> BuiltinResult<Value> {
206    let buffer = LogicalBuffer::from_char_array(&chars);
207    logical_buffer_to_host(buffer)
208}
209
210fn logical_from_string_array(strings: StringArray) -> BuiltinResult<Value> {
211    let bits: Vec<u8> = strings
212        .data
213        .iter()
214        .map(|s| if s.is_empty() { 0 } else { 1 })
215        .collect();
216    let shape = canonical_shape(&strings.shape, bits.len());
217    logical_buffer_to_host(LogicalBuffer { bits, shape })
218}
219
220async fn logical_from_gpu(handle: GpuTensorHandle) -> BuiltinResult<Value> {
221    if runmat_accelerate_api::handle_is_logical(&handle) {
222        return Ok(Value::GpuTensor(handle));
223    }
224
225    let provider = runmat_accelerate_api::provider();
226
227    if let Some(p) = provider {
228        match p.logical_islogical(&handle) {
229            Ok(true) => {
230                runmat_accelerate_api::set_handle_logical(&handle, true);
231                return Ok(Value::GpuTensor(handle));
232            }
233            Ok(false) => {}
234            Err(err) => {
235                trace!("logical: provider logical_islogical hook unavailable, falling back ({err})")
236            }
237        }
238        if let Some(result) = try_gpu_cast(p, &handle).await {
239            return Ok(gpu_helpers::logical_gpu_value(result));
240        } else {
241            trace!(
242                "logical: provider elem_ne/zeros_like unavailable for buffer {} – gathering",
243                handle.buffer_id
244            );
245        }
246    }
247
248    let tensor = gpu_helpers::gather_tensor_async(&handle)
249        .await
250        .map_err(|err| {
251            logical_error_with_message(
252                format!("{BUILTIN_NAME}: {err}"),
253                &LOGICAL_ERROR_GPU_GATHER_FAILED,
254            )
255        })?;
256    let buffer = LogicalBuffer::from_real_tensor(&tensor);
257    logical_buffer_to_gpu(buffer, provider)
258}
259
260fn logical_buffer_to_host(buffer: LogicalBuffer) -> BuiltinResult<Value> {
261    let LogicalBuffer { bits, shape } = buffer;
262    if tensor::element_count(&shape) == 1 && bits.len() == 1 {
263        Ok(Value::Bool(bits[0] != 0))
264    } else {
265        LogicalArray::new(bits, shape)
266            .map(Value::LogicalArray)
267            .map_err(|e| {
268                logical_error_with_message(format!("logical: {e}"), &LOGICAL_ERROR_INTERNAL)
269            })
270    }
271}
272
273fn logical_buffer_to_gpu(
274    buffer: LogicalBuffer,
275    provider: Option<&'static dyn AccelProvider>,
276) -> BuiltinResult<Value> {
277    if let Some(p) = provider {
278        let floats: Vec<f64> = buffer
279            .bits
280            .iter()
281            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
282            .collect();
283        let view = HostTensorView {
284            data: &floats,
285            shape: &buffer.shape,
286        };
287        match p.upload(&view) {
288            Ok(handle) => Ok(gpu_helpers::logical_gpu_value(handle)),
289            Err(err) => {
290                trace!("logical: upload failed during fallback path ({err})");
291                logical_buffer_to_host(buffer)
292            }
293        }
294    } else {
295        logical_buffer_to_host(buffer)
296    }
297}
298
299async fn try_gpu_cast(
300    provider: &'static dyn AccelProvider,
301    input: &GpuTensorHandle,
302) -> Option<GpuTensorHandle> {
303    let zeros = provider.zeros_like(input).ok()?;
304    let result = provider.elem_ne(input, &zeros).await.ok();
305    let _ = provider.free(&zeros);
306    result
307}
308
/// Returns `true` only when both the real and imaginary parts compare equal to
/// zero. A NaN in either part therefore counts as nonzero.
fn complex_is_zero(re: f64, im: f64) -> bool {
    [re, im].iter().all(|part| *part == 0.0)
}
312
313fn conversion_error(type_name: &str) -> RuntimeError {
314    logical_error_with_message(
315        format!(
316            "logical: conversion to logical from {} is not possible",
317            type_name
318        ),
319        &LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE,
320    )
321}
322
/// Intermediate host-side result of a logical conversion.
///
/// `Debug`, `PartialEq` and `Eq` are derived in addition to `Clone` so buffers
/// can be printed in diagnostics and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LogicalBuffer {
    /// One byte per element: 0 for false, 1 for true.
    bits: Vec<u8>,
    /// Dimensions associated with `bits`.
    shape: Vec<usize>,
}
328
329impl LogicalBuffer {
330    fn from_real_tensor(tensor: &Tensor) -> Self {
331        let bits: Vec<u8> = tensor
332            .data
333            .iter()
334            .map(|&v| if v != 0.0 { 1 } else { 0 })
335            .collect();
336        let shape = canonical_shape(&tensor.shape, bits.len());
337        Self { bits, shape }
338    }
339
340    fn from_complex_tensor(tensor: &ComplexTensor) -> Self {
341        let bits: Vec<u8> = tensor
342            .data
343            .iter()
344            .map(|&(re, im)| if !complex_is_zero(re, im) { 1 } else { 0 })
345            .collect();
346        let shape = canonical_shape(&tensor.shape, bits.len());
347        Self { bits, shape }
348    }
349
350    fn from_char_array(chars: &CharArray) -> Self {
351        let bits: Vec<u8> = chars
352            .data
353            .iter()
354            .map(|&ch| if (ch as u32) != 0 { 1 } else { 0 })
355            .collect();
356        let original_shape = vec![chars.rows, chars.cols];
357        let shape = canonical_shape(&original_shape, bits.len());
358        Self { bits, shape }
359    }
360}
361
362fn canonical_shape(shape: &[usize], len: usize) -> Vec<usize> {
363    if tensor::element_count(shape) == len {
364        return normalize_scalar_shape(shape);
365    }
366    if len == 0 {
367        if shape.len() > 1 {
368            return shape.to_vec();
369        }
370        return vec![0];
371    }
372    if len == 1 {
373        canonical_scalar_shape()
374    } else {
375        vec![len, 1]
376    }
377}
378
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{
        CellArray, IntValue, MException, ObjectInstance, SparseTensor, StructValue,
    };

    /// Drives the async builtin to completion so the tests can stay synchronous.
    fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::logical_builtin(value, rest))
    }

    /// Calls `logical` with a single argument and unwraps the successful result.
    fn convert(value: Value) -> Value {
        logical_builtin(value, Vec::new()).expect("logical")
    }

    /// Calls `logical` with a single argument and unwraps the expected failure.
    fn convert_err(value: Value) -> crate::RuntimeError {
        logical_builtin(value, Vec::new()).unwrap_err()
    }

    /// Unwraps a host logical array, panicking on any other value kind.
    fn expect_logical_array(value: Value) -> LogicalArray {
        match value {
            Value::LogicalArray(array) => array,
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    fn assert_error_message(err: &crate::RuntimeError, expected: &str) {
        assert_eq!(err.message(), expected);
    }

    fn assert_error_contains(err: &crate::RuntimeError, expected: &str) {
        assert!(
            err.message().contains(expected),
            "unexpected error: {}",
            err.message()
        );
    }

    /// Checks that the error carries the conversion-not-possible identifier.
    fn assert_conversion_identifier(err: &crate::RuntimeError) {
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_num() {
        assert_eq!(convert(Value::Num(5.0)), Value::Bool(true));
        assert_eq!(convert(Value::Num(0.0)), Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_nan_is_true() {
        let input = Tensor::new(vec![0.0, f64::NAN, -0.0], vec![1, 3]).unwrap();
        let array = expect_logical_array(convert(Value::Tensor(input)));
        assert_eq!(array.data, vec![0, 1, 0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_tensor_matrix() {
        let input = Tensor::new(vec![0.0, 2.0, -3.0, 0.0], vec![2, 2]).unwrap();
        let array = expect_logical_array(convert(Value::Tensor(input)));
        assert_eq!(array.shape, vec![2, 2]);
        assert_eq!(array.data, vec![0, 1, 1, 0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_sparse_tensor_densifies() {
        let input = SparseTensor::new(3, 2, vec![0, 1, 2], vec![1, 2], vec![4.0, -1.0]).unwrap();
        let array = expect_logical_array(convert(Value::SparseTensor(input)));
        assert_eq!(array.shape, vec![3, 2]);
        assert_eq!(array.data, vec![0, 1, 0, 0, 0, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_complex_conversion() {
        let input =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)], vec![3, 1]).unwrap();
        let array = expect_logical_array(convert(Value::ComplexTensor(input)));
        assert_eq!(array.data, vec![0, 1, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_char_array_conversion() {
        let input = CharArray::new(vec!['A', '\0', 'C'], 1, 3).unwrap();
        let array = expect_logical_array(convert(Value::CharArray(input)));
        assert_eq!(array.data, vec![1, 0, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_string_error() {
        let err = convert_err(Value::String("runmat".to_string()));
        assert_error_message(
            &err,
            "logical: conversion to logical from string is not possible",
        );
        assert_conversion_identifier(&err);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_struct_error() {
        let mut fields = StructValue::new();
        fields.insert("field", Value::Num(1.0));
        let err = convert_err(Value::Struct(fields));
        assert_error_contains(&err, "struct");
        assert_conversion_identifier(&err);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_cell_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell creation");
        let err = convert_err(Value::Cell(cell));
        assert_error_message(
            &err,
            "logical: conversion to logical from cell is not possible",
        );
        assert_conversion_identifier(&err);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_function_handle_error() {
        let err = convert_err(Value::FunctionHandle("foo".into()));
        assert_error_message(
            &err,
            "logical: conversion to logical from function_handle is not possible",
        );
        assert_conversion_identifier(&err);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_object_error() {
        let instance = ObjectInstance::new("DemoClass".to_string());
        let err = convert_err(Value::Object(instance));
        assert_error_contains(&err, "DemoClass");
        assert_conversion_identifier(&err);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_mexception_error() {
        let exception = MException::new("id:logical".into(), "message".into());
        let err = convert_err(Value::MException(exception));
        assert_error_message(
            &err,
            "logical: conversion to logical from MException is not possible",
        );
        assert_conversion_identifier(&err);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_too_many_inputs_error() {
        let err = logical_builtin(Value::Bool(true), vec![Value::Bool(false)]).unwrap_err();
        assert_error_message(&err, LOGICAL_ERROR_TOO_MANY_INPUTS.message);
        assert_eq!(err.identifier(), LOGICAL_ERROR_TOO_MANY_INPUTS.identifier);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let source = Tensor::new(vec![0.0, 1.0, -2.0], vec![3, 1]).unwrap();
            let handle = provider
                .upload(&HostTensorView {
                    data: &source.data,
                    shape: &source.shape,
                })
                .expect("upload");

            let result = convert(Value::GpuTensor(handle.clone()));
            let gathered = test_support::gather(result.clone()).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);

            match result {
                Value::GpuTensor(out) => {
                    assert!(runmat_accelerate_api::handle_is_logical(&out));
                }
                _ => panic!("expected gpu tensor output"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_passthrough_for_logical_handle() {
        test_support::with_test_provider(|provider| {
            let source = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let handle = provider
                .upload(&HostTensorView {
                    data: &source.data,
                    shape: &source.shape,
                })
                .expect("upload");
            // Pre-flag the handle so the builtin must return it untouched.
            runmat_accelerate_api::set_handle_logical(&handle, true);

            match convert(Value::GpuTensor(handle.clone())) {
                Value::GpuTensor(out) => assert_eq!(out, handle),
                other => panic!("expected gpu tensor, got {:?}", other),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_bool_and_logical_inputs_passthrough() {
        assert_eq!(convert(Value::Bool(true)), Value::Bool(true));

        let mask = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let passed_through = convert(Value::LogicalArray(mask.clone()));
        assert_eq!(passed_through, Value::LogicalArray(mask));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_empty_tensor_preserves_shape() {
        let input = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let array = expect_logical_array(convert(Value::Tensor(input)));
        assert!(array.data.is_empty());
        assert_eq!(array.shape, vec![0, 3]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_integer_scalar() {
        assert_eq!(convert(Value::Int(IntValue::I32(0))), Value::Bool(false));
        assert_eq!(convert(Value::Int(IntValue::I32(-5))), Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn logical_wgpu_matches_cpu_conversion() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let source = Tensor::new(vec![0.0, 2.0, -3.0, f64::NAN], vec![2, 2]).unwrap();
        let cpu = logical_builtin(Value::Tensor(source.clone()), Vec::new()).unwrap();

        let view = runmat_accelerate_api::HostTensorView {
            data: &source.data,
            shape: &source.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = logical_builtin(Value::GpuTensor(handle), Vec::new()).unwrap();
        let out_handle = match &gpu_value {
            Value::GpuTensor(h) => {
                assert!(runmat_accelerate_api::handle_is_logical(h));
                h.clone()
            }
            other => panic!("expected gpu tensor, got {other:?}"),
        };

        let gathered = test_support::gather(Value::GpuTensor(out_handle)).expect("gather");

        // Express the CPU result as f64 data plus shape so it can be compared
        // with the gathered GPU tensor.
        let (expected, expected_shape): (Vec<f64>, Vec<usize>) = match cpu {
            Value::LogicalArray(arr) => {
                let data = arr
                    .data
                    .iter()
                    .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                    .collect();
                (data, arr.shape.clone())
            }
            Value::Bool(flag) => (vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]),
            other => panic!("unexpected cpu result {other:?}"),
        };

        assert_eq!(gathered.shape, expected_shape);
        assert_eq!(gathered.data, expected);
    }
}
689
690        assert_eq!(gathered.shape, expected_shape);
691        assert_eq!(gathered.data, expected);
692    }
693}