// runmat_runtime/builtins/logical/ops.rs
1//! MATLAB-compatible `logical` builtin with GPU-aware semantics for RunMat.
2
3use log::trace;
4use runmat_accelerate_api::{self, AccelProvider, GpuTensorHandle, HostTensorView};
5use runmat_builtins::{
6    CharArray, ComplexTensor, LogicalArray, ResolveContext, StringArray, Tensor, Type, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::{
11    gpu_helpers,
12    shape::{canonical_scalar_shape, normalize_scalar_shape},
13    spec::{
14        BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15        ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16    },
17    tensor,
18};
19use crate::builtins::logical::type_resolvers::logical_like;
20
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
23#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::ops")]
/// GPU execution contract for the `logical` builtin, consumed by the
/// acceleration layer when planning device execution.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::ops")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "logical",
    // Pure element-wise cast; no reductions or cross-element communication.
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // Providers realize the cast as `X ~= 0`, so a binary not-equal hook is enough.
    provider_hooks: &[ProviderHook::Binary {
        name: "elem_ne",
        commutative: true,
    }],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // The result lands in a fresh device buffer; the input handle is untouched.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Preferred path issues elem_ne(X, 0) on the device; missing hooks trigger a gather → host cast → re-upload sequence flagged as logical.",
};
41
42#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::ops")]
/// Fusion-planner contract for `logical`. No elementwise/reduction templates
/// are registered yet, so the builtin currently runs outside fusion plans.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::ops")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "logical",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // Output is a 0/1 mask, so the op can never introduce NaN values.
    emits_nan: false,
    notes: "Fusion support will arrive alongside a dedicated WGSL template; today the builtin executes outside fusion plans.",
};
52
/// Canonical builtin name attached to runtime errors and trace messages.
const BUILTIN_NAME: &str = "logical";
54
55fn logical_type(args: &[Type], _context: &ResolveContext) -> Type {
56    args.first().map(logical_like).unwrap_or(Type::logical())
57}
58
59fn logical_error(message: impl Into<String>) -> RuntimeError {
60    build_runtime_error(message)
61        .with_builtin(BUILTIN_NAME)
62        .build()
63}
64
65#[runtime_builtin(
66    name = "logical",
67    category = "logical",
68    summary = "Convert scalars, arrays, and gpuArray values to logical outputs.",
69    keywords = "logical,boolean,gpuArray,mask,conversion",
70    accel = "unary",
71    type_resolver(logical_type),
72    builtin_path = "crate::builtins::logical::ops"
73)]
74async fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
75    if !rest.is_empty() {
76        return Err(logical_error("logical: too many input arguments"));
77    }
78    convert_value_to_logical(value).await
79}
80
/// Core dispatch: convert a runtime `Value` into its logical equivalent.
///
/// Scalars become `Value::Bool`, array-like inputs become logical arrays,
/// and gpuArray inputs are handled by `logical_from_gpu` (which keeps data on
/// the device when the provider allows). Types with no MATLAB-defined logical
/// conversion yield a `conversion_error`.
async fn convert_value_to_logical(value: Value) -> BuiltinResult<Value> {
    match value {
        // Already logical: pass through unchanged.
        Value::Bool(_) | Value::LogicalArray(_) => Ok(value),
        Value::Num(n) => Ok(Value::Bool(n != 0.0)),
        Value::Int(i) => Ok(Value::Bool(!i.is_zero())),
        // A complex scalar is true unless both components are exactly zero.
        Value::Complex(re, im) => Ok(Value::Bool(!complex_is_zero(re, im))),
        Value::Tensor(tensor) => logical_from_tensor(tensor),
        Value::ComplexTensor(tensor) => logical_from_complex_tensor(tensor),
        Value::CharArray(chars) => logical_from_char_array(chars),
        Value::StringArray(strings) => logical_from_string_array(strings),
        Value::GpuTensor(handle) => logical_from_gpu(handle).await,
        // No defined conversion for the remaining value kinds.
        Value::String(_) => Err(conversion_error("string")),
        Value::Cell(_) => Err(conversion_error("cell")),
        Value::Struct(_) => Err(conversion_error("struct")),
        Value::Object(obj) => Err(conversion_error(&obj.class_name)),
        Value::HandleObject(handle) => Err(conversion_error(&handle.class_name)),
        Value::Listener(_) => Err(conversion_error("event.listener")),
        Value::FunctionHandle(_) | Value::Closure(_) => Err(conversion_error("function_handle")),
        Value::ClassRef(_) => Err(conversion_error("meta.class")),
        Value::MException(_) => Err(conversion_error("MException")),
        Value::OutputList(_) => Err(conversion_error("OutputList")),
    }
}
104
105fn logical_from_tensor(tensor: Tensor) -> BuiltinResult<Value> {
106    let buffer = LogicalBuffer::from_real_tensor(&tensor);
107    logical_buffer_to_host(buffer)
108}
109
110fn logical_from_complex_tensor(tensor: ComplexTensor) -> BuiltinResult<Value> {
111    let buffer = LogicalBuffer::from_complex_tensor(&tensor);
112    logical_buffer_to_host(buffer)
113}
114
115fn logical_from_char_array(chars: CharArray) -> BuiltinResult<Value> {
116    let buffer = LogicalBuffer::from_char_array(&chars);
117    logical_buffer_to_host(buffer)
118}
119
120fn logical_from_string_array(strings: StringArray) -> BuiltinResult<Value> {
121    let bits: Vec<u8> = strings
122        .data
123        .iter()
124        .map(|s| if s.is_empty() { 0 } else { 1 })
125        .collect();
126    let shape = canonical_shape(&strings.shape, bits.len());
127    logical_buffer_to_host(LogicalBuffer { bits, shape })
128}
129
/// Convert a device-resident tensor to logical, staying on the GPU when the
/// provider supports it and falling back to a host round-trip otherwise.
///
/// Fallback ladder: (1) handle already flagged logical → passthrough;
/// (2) provider reports the buffer is logical → flag and passthrough;
/// (3) provider supports `zeros_like` + `elem_ne` → compute `X ~= 0` on
/// device; (4) gather to host, cast there, and try to re-upload.
async fn logical_from_gpu(handle: GpuTensorHandle) -> BuiltinResult<Value> {
    if runmat_accelerate_api::handle_is_logical(&handle) {
        return Ok(Value::GpuTensor(handle));
    }

    let provider = runmat_accelerate_api::provider();

    if let Some(p) = provider {
        match p.logical_islogical(&handle) {
            Ok(true) => {
                // Buffer already holds logical data; just record it on the handle.
                runmat_accelerate_api::set_handle_logical(&handle, true);
                return Ok(Value::GpuTensor(handle));
            }
            Ok(false) => {}
            Err(err) => {
                // Hook missing on this provider — continue down the ladder.
                trace!("logical: provider logical_islogical hook unavailable, falling back ({err})")
            }
        }
        // Preferred path: compute the cast entirely on the device.
        if let Some(result) = try_gpu_cast(p, &handle).await {
            return Ok(gpu_helpers::logical_gpu_value(result));
        } else {
            trace!(
                "logical: provider elem_ne/zeros_like unavailable for buffer {} – gathering",
                handle.buffer_id
            );
        }
    }

    // Last resort: pull the data to the host, cast, and re-upload if possible.
    let tensor = gpu_helpers::gather_tensor_async(&handle)
        .await
        .map_err(|err| logical_error(format!("{BUILTIN_NAME}: {err}")))?;
    let buffer = LogicalBuffer::from_real_tensor(&tensor);
    logical_buffer_to_gpu(buffer, provider)
}
164
165fn logical_buffer_to_host(buffer: LogicalBuffer) -> BuiltinResult<Value> {
166    let LogicalBuffer { bits, shape } = buffer;
167    if tensor::element_count(&shape) == 1 && bits.len() == 1 {
168        Ok(Value::Bool(bits[0] != 0))
169    } else {
170        LogicalArray::new(bits, shape)
171            .map(Value::LogicalArray)
172            .map_err(|e| logical_error(format!("logical: {e}")))
173    }
174}
175
176fn logical_buffer_to_gpu(
177    buffer: LogicalBuffer,
178    provider: Option<&'static dyn AccelProvider>,
179) -> BuiltinResult<Value> {
180    if let Some(p) = provider {
181        let floats: Vec<f64> = buffer
182            .bits
183            .iter()
184            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
185            .collect();
186        let view = HostTensorView {
187            data: &floats,
188            shape: &buffer.shape,
189        };
190        match p.upload(&view) {
191            Ok(handle) => Ok(gpu_helpers::logical_gpu_value(handle)),
192            Err(err) => {
193                trace!("logical: upload failed during fallback path ({err})");
194                logical_buffer_to_host(buffer)
195            }
196        }
197    } else {
198        logical_buffer_to_host(buffer)
199    }
200}
201
202async fn try_gpu_cast(
203    provider: &'static dyn AccelProvider,
204    input: &GpuTensorHandle,
205) -> Option<GpuTensorHandle> {
206    let zeros = provider.zeros_like(input).ok()?;
207    let result = provider.elem_ne(input, &zeros).await.ok();
208    let _ = provider.free(&zeros);
209    result
210}
211
/// True when both the real and imaginary parts are exactly zero (±0.0).
/// A NaN component compares unequal to zero, so it counts as nonzero.
fn complex_is_zero(re: f64, im: f64) -> bool {
    !(re != 0.0 || im != 0.0)
}
215
216fn conversion_error(type_name: &str) -> RuntimeError {
217    logical_error(format!(
218        "logical: conversion to logical from {} is not possible",
219        type_name
220    ))
221}
222
/// Intermediate host representation of a logical result: a flat 0/1 bit
/// vector plus its canonicalized shape.
#[derive(Clone)]
struct LogicalBuffer {
    // One byte per element; 0 = false, anything else = true.
    bits: Vec<u8>,
    // Shape normalized via `canonical_shape`.
    shape: Vec<usize>,
}
228
229impl LogicalBuffer {
230    fn from_real_tensor(tensor: &Tensor) -> Self {
231        let bits: Vec<u8> = tensor
232            .data
233            .iter()
234            .map(|&v| if v != 0.0 { 1 } else { 0 })
235            .collect();
236        let shape = canonical_shape(&tensor.shape, bits.len());
237        Self { bits, shape }
238    }
239
240    fn from_complex_tensor(tensor: &ComplexTensor) -> Self {
241        let bits: Vec<u8> = tensor
242            .data
243            .iter()
244            .map(|&(re, im)| if !complex_is_zero(re, im) { 1 } else { 0 })
245            .collect();
246        let shape = canonical_shape(&tensor.shape, bits.len());
247        Self { bits, shape }
248    }
249
250    fn from_char_array(chars: &CharArray) -> Self {
251        let bits: Vec<u8> = chars
252            .data
253            .iter()
254            .map(|&ch| if (ch as u32) != 0 { 1 } else { 0 })
255            .collect();
256        let original_shape = vec![chars.rows, chars.cols];
257        let shape = canonical_shape(&original_shape, bits.len());
258        Self { bits, shape }
259    }
260}
261
262fn canonical_shape(shape: &[usize], len: usize) -> Vec<usize> {
263    if tensor::element_count(shape) == len {
264        return normalize_scalar_shape(shape);
265    }
266    if len == 0 {
267        if shape.len() > 1 {
268            return shape.to_vec();
269        }
270        return vec![0];
271    }
272    if len == 1 {
273        canonical_scalar_shape()
274    } else {
275        vec![len, 1]
276    }
277}
278
279#[cfg(test)]
280pub(crate) mod tests {
281    use super::*;
282    use crate::builtins::common::test_support;
283    use futures::executor::block_on;
284    use runmat_accelerate_api::HostTensorView;
285    use runmat_builtins::{CellArray, IntValue, MException, ObjectInstance, StructValue};
286
    /// Synchronous shim so tests can drive the async builtin directly.
    fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::logical_builtin(value, rest))
    }
290
    /// Assert an exact runtime-error message.
    fn assert_error_message(err: crate::RuntimeError, expected: &str) {
        assert_eq!(err.message(), expected);
    }
294
    /// Assert a runtime-error message contains the given substring.
    fn assert_error_contains(err: crate::RuntimeError, expected: &str) {
        assert!(
            err.message().contains(expected),
            "unexpected error: {}",
            err.message()
        );
    }
302
    /// Numeric scalars collapse to `Value::Bool`: nonzero → true, zero → false.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_num() {
        let result = logical_builtin(Value::Num(5.0), Vec::new()).expect("logical");
        assert_eq!(result, Value::Bool(true));

        let zero_result = logical_builtin(Value::Num(0.0), Vec::new()).expect("logical");
        assert_eq!(zero_result, Value::Bool(false));
    }
312
    /// NaN is nonzero and therefore true; -0.0 compares equal to zero → false.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_nan_is_true() {
        let tensor = Tensor::new(vec![0.0, f64::NAN, -0.0], vec![1, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![0, 1, 0]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }
323
    /// A 2x2 tensor converts element-wise and keeps its shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_tensor_matrix() {
        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, 0.0], vec![2, 2]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![2, 2]);
                assert_eq!(array.data, vec![0, 1, 1, 0]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }
337
    /// A complex element is true when either component is nonzero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_complex_conversion() {
        let complex =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)], vec![3, 1]).unwrap();
        let result = logical_builtin(Value::ComplexTensor(complex), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.data, vec![0, 1, 1]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }
351
    /// Chars convert by code point: only NUL ('\0') maps to false.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_char_array_conversion() {
        let chars = CharArray::new(vec!['A', '\0', 'C'], 1, 3).unwrap();
        let result = logical_builtin(Value::CharArray(chars), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }
362
    /// Scalar strings have no logical conversion and report a typed error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_string_error() {
        let err = logical_builtin(Value::String("runmat".to_string()), Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from string is not possible",
        );
    }
372
    /// Structs reject conversion; the error names the offending type.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_struct_error() {
        let mut st = StructValue::new();
        st.insert("field", Value::Num(1.0));
        let err = logical_builtin(Value::Struct(st), Vec::new()).unwrap_err();
        assert_error_contains(err, "struct");
    }
381
    /// Cell arrays reject conversion with the standard message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_cell_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell creation");
        let err = logical_builtin(Value::Cell(cell), Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from cell is not possible",
        );
    }
392
    /// Function handles reject conversion with the standard message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_function_handle_error() {
        let err = logical_builtin(Value::FunctionHandle("foo".into()), Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from function_handle is not possible",
        );
    }
402
    /// Class instances reject conversion; the error carries the class name.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_object_error() {
        let obj = ObjectInstance::new("DemoClass".to_string());
        let err = logical_builtin(Value::Object(obj), Vec::new()).unwrap_err();
        assert_error_contains(err, "DemoClass");
    }
410
    /// MException values reject conversion with the standard message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_mexception_error() {
        let mex = MException::new("id:logical".into(), "message".into());
        let err = logical_builtin(Value::MException(mex), Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from MException is not possible",
        );
    }
421
    /// Uploading a numeric tensor and converting should produce a device
    /// handle flagged logical whose gathered contents are the 0/1 mask.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, -2.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            let gathered = test_support::gather(result.clone()).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
            // The output must stay on the device and carry the logical flag.
            if let Value::GpuTensor(out) = result {
                assert!(runmat_accelerate_api::handle_is_logical(&out));
            } else {
                panic!("expected gpu tensor output");
            }
        });
    }
443
    /// A handle already flagged logical is returned unchanged (no recompute).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_passthrough_for_logical_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            match result {
                Value::GpuTensor(out) => assert_eq!(out, handle),
                other => panic!("expected gpu tensor, got {:?}", other),
            }
        });
    }
463
    /// Bool scalars and logical arrays pass through without modification.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_bool_and_logical_inputs_passthrough() {
        let res_bool = logical_builtin(Value::Bool(true), Vec::new()).expect("logical");
        assert_eq!(res_bool, Value::Bool(true));

        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let res_array =
            logical_builtin(Value::LogicalArray(logical.clone()), Vec::new()).expect("logical");
        assert_eq!(res_array, Value::LogicalArray(logical));
    }
475
    /// Empty multi-dimensional shapes (0x3) survive the conversion intact.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_empty_tensor_preserves_shape() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert!(array.data.is_empty());
                assert_eq!(array.shape, vec![0, 3]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }
489
    /// Integer scalars: zero → false, any nonzero (including negative) → true.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_integer_scalar() {
        let res = logical_builtin(Value::Int(IntValue::I32(0)), Vec::new()).expect("logical");
        assert_eq!(res, Value::Bool(false));

        let res_nonzero =
            logical_builtin(Value::Int(IntValue::I32(-5)), Vec::new()).expect("logical");
        assert_eq!(res_nonzero, Value::Bool(true));
    }
500
501    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
502    #[test]
503    #[cfg(feature = "wgpu")]
504    fn logical_wgpu_matches_cpu_conversion() {
505        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
506            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
507        );
508
509        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, f64::NAN], vec![2, 2]).unwrap();
510        let cpu = logical_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();
511
512        let view = runmat_accelerate_api::HostTensorView {
513            data: &tensor.data,
514            shape: &tensor.shape,
515        };
516        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
517        let handle = provider.upload(&view).expect("upload");
518
519        let gpu_value = logical_builtin(Value::GpuTensor(handle), Vec::new()).unwrap();
520        let out_handle = match gpu_value {
521            Value::GpuTensor(ref h) => {
522                assert!(runmat_accelerate_api::handle_is_logical(h));
523                h.clone()
524            }
525            other => panic!("expected gpu tensor, got {other:?}"),
526        };
527
528        let gathered = test_support::gather(Value::GpuTensor(out_handle)).expect("gather");
529
530        let (expected, expected_shape): (Vec<f64>, Vec<usize>) = match cpu {
531            Value::LogicalArray(arr) => (
532                arr.data
533                    .iter()
534                    .map(|&b| if b != 0 { 1.0 } else { 0.0 })
535                    .collect(),
536                arr.shape.clone(),
537            ),
538            Value::Bool(flag) => (vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]),
539            other => panic!("unexpected cpu result {other:?}"),
540        };
541
542        assert_eq!(gathered.shape, expected_shape);
543        assert_eq!(gathered.data, expected);
544    }
545}