1use log::trace;
4use runmat_accelerate_api::{self, AccelProvider, GpuTensorHandle, HostTensorView};
5use runmat_builtins::{
6 CharArray, ComplexTensor, LogicalArray, ResolveContext, StringArray, Tensor, Type, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::{
11 gpu_helpers,
12 shape::{canonical_scalar_shape, normalize_scalar_shape},
13 spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
16 },
17 tensor,
18};
19use crate::builtins::logical::type_resolvers::logical_like;
20
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
/// GPU execution metadata for the `logical` builtin.
///
/// Describes an elementwise operation over `F32`/`F64` inputs with MATLAB-style broadcasting.
/// The single provider hook is the commutative binary `elem_ne`; the `notes` field records the
/// intended device path (`elem_ne(X, 0)`) and the gather/re-upload fallback used when hooks are
/// missing.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::ops")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "logical",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[ProviderHook::Binary {
        name: "elem_ne",
        commutative: true,
    }],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Preferred path issues elem_ne(X, 0) on the device; missing hooks trigger a gather → host cast → re-upload sequence flagged as logical.",
};
41
/// Fusion metadata for the `logical` builtin.
///
/// Both `elementwise` and `reduction` are `None`: per the `notes` field, the builtin currently
/// executes outside fusion plans until a dedicated WGSL template exists.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::ops")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "logical",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion support will arrive alongside a dedicated WGSL template; today the builtin executes outside fusion plans.",
};
52
/// Name of this builtin; attached to every error built by `logical_error` and used as the
/// message prefix when gathering a GPU tensor fails.
const BUILTIN_NAME: &str = "logical";
54
55fn logical_type(args: &[Type], _context: &ResolveContext) -> Type {
56 args.first().map(logical_like).unwrap_or(Type::logical())
57}
58
59fn logical_error(message: impl Into<String>) -> RuntimeError {
60 build_runtime_error(message)
61 .with_builtin(BUILTIN_NAME)
62 .build()
63}
64
65#[runtime_builtin(
66 name = "logical",
67 category = "logical",
68 summary = "Convert scalars, arrays, and gpuArray values to logical outputs.",
69 keywords = "logical,boolean,gpuArray,mask,conversion",
70 accel = "unary",
71 type_resolver(logical_type),
72 builtin_path = "crate::builtins::logical::ops"
73)]
74async fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
75 if !rest.is_empty() {
76 return Err(logical_error("logical: too many input arguments"));
77 }
78 convert_value_to_logical(value).await
79}
80
81async fn convert_value_to_logical(value: Value) -> BuiltinResult<Value> {
82 match value {
83 Value::Bool(_) | Value::LogicalArray(_) => Ok(value),
84 Value::Num(n) => Ok(Value::Bool(n != 0.0)),
85 Value::Int(i) => Ok(Value::Bool(!i.is_zero())),
86 Value::Complex(re, im) => Ok(Value::Bool(!complex_is_zero(re, im))),
87 Value::Tensor(tensor) => logical_from_tensor(tensor),
88 Value::ComplexTensor(tensor) => logical_from_complex_tensor(tensor),
89 Value::CharArray(chars) => logical_from_char_array(chars),
90 Value::StringArray(strings) => logical_from_string_array(strings),
91 Value::GpuTensor(handle) => logical_from_gpu(handle).await,
92 Value::String(_) => Err(conversion_error("string")),
93 Value::Cell(_) => Err(conversion_error("cell")),
94 Value::Struct(_) => Err(conversion_error("struct")),
95 Value::Object(obj) => Err(conversion_error(&obj.class_name)),
96 Value::HandleObject(handle) => Err(conversion_error(&handle.class_name)),
97 Value::Listener(_) => Err(conversion_error("event.listener")),
98 Value::FunctionHandle(_) | Value::Closure(_) => Err(conversion_error("function_handle")),
99 Value::ClassRef(_) => Err(conversion_error("meta.class")),
100 Value::MException(_) => Err(conversion_error("MException")),
101 Value::OutputList(_) => Err(conversion_error("OutputList")),
102 }
103}
104
105fn logical_from_tensor(tensor: Tensor) -> BuiltinResult<Value> {
106 let buffer = LogicalBuffer::from_real_tensor(&tensor);
107 logical_buffer_to_host(buffer)
108}
109
110fn logical_from_complex_tensor(tensor: ComplexTensor) -> BuiltinResult<Value> {
111 let buffer = LogicalBuffer::from_complex_tensor(&tensor);
112 logical_buffer_to_host(buffer)
113}
114
115fn logical_from_char_array(chars: CharArray) -> BuiltinResult<Value> {
116 let buffer = LogicalBuffer::from_char_array(&chars);
117 logical_buffer_to_host(buffer)
118}
119
120fn logical_from_string_array(strings: StringArray) -> BuiltinResult<Value> {
121 let bits: Vec<u8> = strings
122 .data
123 .iter()
124 .map(|s| if s.is_empty() { 0 } else { 1 })
125 .collect();
126 let shape = canonical_shape(&strings.shape, bits.len());
127 logical_buffer_to_host(LogicalBuffer { bits, shape })
128}
129
130async fn logical_from_gpu(handle: GpuTensorHandle) -> BuiltinResult<Value> {
131 if runmat_accelerate_api::handle_is_logical(&handle) {
132 return Ok(Value::GpuTensor(handle));
133 }
134
135 let provider = runmat_accelerate_api::provider();
136
137 if let Some(p) = provider {
138 match p.logical_islogical(&handle) {
139 Ok(true) => {
140 runmat_accelerate_api::set_handle_logical(&handle, true);
141 return Ok(Value::GpuTensor(handle));
142 }
143 Ok(false) => {}
144 Err(err) => {
145 trace!("logical: provider logical_islogical hook unavailable, falling back ({err})")
146 }
147 }
148 if let Some(result) = try_gpu_cast(p, &handle).await {
149 return Ok(gpu_helpers::logical_gpu_value(result));
150 } else {
151 trace!(
152 "logical: provider elem_ne/zeros_like unavailable for buffer {} – gathering",
153 handle.buffer_id
154 );
155 }
156 }
157
158 let tensor = gpu_helpers::gather_tensor_async(&handle)
159 .await
160 .map_err(|err| logical_error(format!("{BUILTIN_NAME}: {err}")))?;
161 let buffer = LogicalBuffer::from_real_tensor(&tensor);
162 logical_buffer_to_gpu(buffer, provider)
163}
164
165fn logical_buffer_to_host(buffer: LogicalBuffer) -> BuiltinResult<Value> {
166 let LogicalBuffer { bits, shape } = buffer;
167 if tensor::element_count(&shape) == 1 && bits.len() == 1 {
168 Ok(Value::Bool(bits[0] != 0))
169 } else {
170 LogicalArray::new(bits, shape)
171 .map(Value::LogicalArray)
172 .map_err(|e| logical_error(format!("logical: {e}")))
173 }
174}
175
176fn logical_buffer_to_gpu(
177 buffer: LogicalBuffer,
178 provider: Option<&'static dyn AccelProvider>,
179) -> BuiltinResult<Value> {
180 if let Some(p) = provider {
181 let floats: Vec<f64> = buffer
182 .bits
183 .iter()
184 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
185 .collect();
186 let view = HostTensorView {
187 data: &floats,
188 shape: &buffer.shape,
189 };
190 match p.upload(&view) {
191 Ok(handle) => Ok(gpu_helpers::logical_gpu_value(handle)),
192 Err(err) => {
193 trace!("logical: upload failed during fallback path ({err})");
194 logical_buffer_to_host(buffer)
195 }
196 }
197 } else {
198 logical_buffer_to_host(buffer)
199 }
200}
201
202async fn try_gpu_cast(
203 provider: &'static dyn AccelProvider,
204 input: &GpuTensorHandle,
205) -> Option<GpuTensorHandle> {
206 let zeros = provider.zeros_like(input).ok()?;
207 let result = provider.elem_ne(input, &zeros).await.ok();
208 let _ = provider.free(&zeros);
209 result
210}
211
/// Returns `true` only when both the real and imaginary parts compare equal to zero.
///
/// A NaN in either part makes the result `false`, since NaN never compares equal to zero.
fn complex_is_zero(re: f64, im: f64) -> bool {
    !(re != 0.0 || im != 0.0)
}
215
216fn conversion_error(type_name: &str) -> RuntimeError {
217 logical_error(format!(
218 "logical: conversion to logical from {} is not possible",
219 type_name
220 ))
221}
222
223#[derive(Clone)]
224struct LogicalBuffer {
225 bits: Vec<u8>,
226 shape: Vec<usize>,
227}
228
229impl LogicalBuffer {
230 fn from_real_tensor(tensor: &Tensor) -> Self {
231 let bits: Vec<u8> = tensor
232 .data
233 .iter()
234 .map(|&v| if v != 0.0 { 1 } else { 0 })
235 .collect();
236 let shape = canonical_shape(&tensor.shape, bits.len());
237 Self { bits, shape }
238 }
239
240 fn from_complex_tensor(tensor: &ComplexTensor) -> Self {
241 let bits: Vec<u8> = tensor
242 .data
243 .iter()
244 .map(|&(re, im)| if !complex_is_zero(re, im) { 1 } else { 0 })
245 .collect();
246 let shape = canonical_shape(&tensor.shape, bits.len());
247 Self { bits, shape }
248 }
249
250 fn from_char_array(chars: &CharArray) -> Self {
251 let bits: Vec<u8> = chars
252 .data
253 .iter()
254 .map(|&ch| if (ch as u32) != 0 { 1 } else { 0 })
255 .collect();
256 let original_shape = vec![chars.rows, chars.cols];
257 let shape = canonical_shape(&original_shape, bits.len());
258 Self { bits, shape }
259 }
260}
261
262fn canonical_shape(shape: &[usize], len: usize) -> Vec<usize> {
263 if tensor::element_count(shape) == len {
264 return normalize_scalar_shape(shape);
265 }
266 if len == 0 {
267 if shape.len() > 1 {
268 return shape.to_vec();
269 }
270 return vec![0];
271 }
272 if len == 1 {
273 canonical_scalar_shape()
274 } else {
275 vec![len, 1]
276 }
277}
278
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CellArray, IntValue, MException, ObjectInstance, StructValue};

    /// Drives the async builtin to completion so tests can call it synchronously.
    fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        let pending = super::logical_builtin(value, rest);
        block_on(pending)
    }

    /// Asserts that the error message equals `expected` exactly.
    fn assert_error_message(err: crate::RuntimeError, expected: &str) {
        assert_eq!(err.message(), expected);
    }

    /// Asserts that the error message contains `expected` somewhere.
    fn assert_error_contains(err: crate::RuntimeError, expected: &str) {
        let message = err.message();
        assert!(message.contains(expected), "unexpected error: {}", message);
    }

    /// Unwraps a logical array result, panicking on any other variant.
    fn expect_logical_array(value: Value) -> LogicalArray {
        match value {
            Value::LogicalArray(array) => array,
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_num() {
        let nonzero = logical_builtin(Value::Num(5.0), Vec::new()).expect("logical");
        assert_eq!(nonzero, Value::Bool(true));

        let zero = logical_builtin(Value::Num(0.0), Vec::new()).expect("logical");
        assert_eq!(zero, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_nan_is_true() {
        let input = Tensor::new(vec![0.0, f64::NAN, -0.0], vec![1, 3]).unwrap();
        let output = logical_builtin(Value::Tensor(input), Vec::new()).expect("logical");
        let array = expect_logical_array(output);
        assert_eq!(array.data, vec![0, 1, 0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_tensor_matrix() {
        let input = Tensor::new(vec![0.0, 2.0, -3.0, 0.0], vec![2, 2]).unwrap();
        let output = logical_builtin(Value::Tensor(input), Vec::new()).expect("logical");
        let array = expect_logical_array(output);
        assert_eq!(array.shape, vec![2, 2]);
        assert_eq!(array.data, vec![0, 1, 1, 0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_complex_conversion() {
        let input =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)], vec![3, 1]).unwrap();
        let output = logical_builtin(Value::ComplexTensor(input), Vec::new()).expect("logical");
        let array = expect_logical_array(output);
        assert_eq!(array.data, vec![0, 1, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_char_array_conversion() {
        let input = CharArray::new(vec!['A', '\0', 'C'], 1, 3).unwrap();
        let output = logical_builtin(Value::CharArray(input), Vec::new()).expect("logical");
        let array = expect_logical_array(output);
        assert_eq!(array.data, vec![1, 0, 1]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_string_error() {
        let input = Value::String("runmat".to_string());
        let err = logical_builtin(input, Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from string is not possible",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_struct_error() {
        let mut fields = StructValue::new();
        fields.insert("field", Value::Num(1.0));
        let err = logical_builtin(Value::Struct(fields), Vec::new()).unwrap_err();
        assert_error_contains(err, "struct");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_cell_error() {
        let input = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell creation");
        let err = logical_builtin(Value::Cell(input), Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from cell is not possible",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_function_handle_error() {
        let input = Value::FunctionHandle("foo".into());
        let err = logical_builtin(input, Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from function_handle is not possible",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_object_error() {
        let instance = ObjectInstance::new("DemoClass".to_string());
        let err = logical_builtin(Value::Object(instance), Vec::new()).unwrap_err();
        assert_error_contains(err, "DemoClass");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_mexception_error() {
        let exception = MException::new("id:logical".into(), "message".into());
        let err = logical_builtin(Value::MException(exception), Vec::new()).unwrap_err();
        assert_error_message(
            err,
            "logical: conversion to logical from MException is not possible",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let source = Tensor::new(vec![0.0, 1.0, -2.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &source.data,
                shape: &source.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let output =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");

            // The gathered data must be the 0/1 mask of the uploaded values.
            let gathered = test_support::gather(output.clone()).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);

            // The result must stay on the device and be flagged as logical.
            match output {
                Value::GpuTensor(out) => {
                    assert!(runmat_accelerate_api::handle_is_logical(&out));
                }
                _ => panic!("expected gpu tensor output"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_passthrough_for_logical_handle() {
        test_support::with_test_provider(|provider| {
            let source = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &source.data,
                shape: &source.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);

            let output =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            match output {
                Value::GpuTensor(out) => assert_eq!(out, handle),
                other => panic!("expected gpu tensor, got {:?}", other),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_bool_and_logical_inputs_passthrough() {
        let from_bool = logical_builtin(Value::Bool(true), Vec::new()).expect("logical");
        assert_eq!(from_bool, Value::Bool(true));

        let mask = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let from_array =
            logical_builtin(Value::LogicalArray(mask.clone()), Vec::new()).expect("logical");
        assert_eq!(from_array, Value::LogicalArray(mask));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_empty_tensor_preserves_shape() {
        let input = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let output = logical_builtin(Value::Tensor(input), Vec::new()).expect("logical");
        let array = expect_logical_array(output);
        assert!(array.data.is_empty());
        assert_eq!(array.shape, vec![0, 3]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_integer_scalar() {
        let from_zero =
            logical_builtin(Value::Int(IntValue::I32(0)), Vec::new()).expect("logical");
        assert_eq!(from_zero, Value::Bool(false));

        let from_negative =
            logical_builtin(Value::Int(IntValue::I32(-5)), Vec::new()).expect("logical");
        assert_eq!(from_negative, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn logical_wgpu_matches_cpu_conversion() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let source = Tensor::new(vec![0.0, 2.0, -3.0, f64::NAN], vec![2, 2]).unwrap();
        let cpu = logical_builtin(Value::Tensor(source.clone()), Vec::new()).unwrap();

        let view = runmat_accelerate_api::HostTensorView {
            data: &source.data,
            shape: &source.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = logical_builtin(Value::GpuTensor(handle), Vec::new()).unwrap();
        let out_handle = match gpu_value {
            Value::GpuTensor(ref h) => {
                assert!(runmat_accelerate_api::handle_is_logical(h));
                h.clone()
            }
            other => panic!("expected gpu tensor, got {other:?}"),
        };

        let gathered = test_support::gather(Value::GpuTensor(out_handle)).expect("gather");

        // Re-express the CPU result as f64 data plus shape so it can be compared directly.
        let (expected, expected_shape): (Vec<f64>, Vec<usize>) = match cpu {
            Value::LogicalArray(arr) => {
                let data: Vec<f64> = arr
                    .data
                    .iter()
                    .map(|&bit| if bit != 0 { 1.0 } else { 0.0 })
                    .collect();
                (data, arr.shape.clone())
            }
            Value::Bool(flag) => (vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]),
            other => panic!("unexpected cpu result {other:?}"),
        };

        assert_eq!(gathered.shape, expected_shape);
        assert_eq!(gathered.data, expected);
    }
}