1use log::trace;
4use runmat_accelerate_api::{self, AccelProvider, GpuTensorHandle, HostTensorView};
5use runmat_builtins::{
6 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
7 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
8 CharArray, ComplexTensor, LogicalArray, ResolveContext, StringArray, Tensor, Type, Value,
9};
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::{
13 gpu_helpers,
14 shape::{canonical_scalar_shape, normalize_scalar_shape},
15 spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18 },
19 tensor,
20};
21use crate::builtins::logical::type_resolvers::logical_like;
22
23use crate::{build_runtime_error, BuiltinResult, RuntimeError};
24
// GPU execution metadata for `logical`: an elementwise operation whose preferred
// device path is the provider's binary `elem_ne` hook (comparison against zero).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::logical::ops")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "logical",
    op_kind: GpuOpKind::Elementwise,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    // `logical_from_gpu` issues `elem_ne(X, zeros_like(X))` through `try_gpu_cast`.
    provider_hooks: &[ProviderHook::Binary {
        name: "elem_ne",
        commutative: true,
    }],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Preferred path issues elem_ne(X, 0) on the device; missing hooks trigger a gather → host cast → re-upload sequence flagged as logical.",
};
43
// Fusion metadata for `logical`. No elementwise or reduction template is
// registered (both are `None`), so the builtin currently runs outside fusion plans.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::logical::ops")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "logical",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Fusion support will arrive alongside a dedicated WGSL template; today the builtin executes outside fusion plans.",
};
54
/// Builtin name used to tag runtime errors and to prefix GPU gather error messages.
const BUILTIN_NAME: &str = "logical";

/// Output descriptor: the single logical result `tf`.
const LOGICAL_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "tf",
    ty: BuiltinParamType::LogicalArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Logical-converted result.",
}];

/// Input descriptor: one required argument of any type; unsupported types are
/// rejected at runtime by `convert_value_to_logical`.
const LOGICAL_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input value to convert.",
}];

/// The only supported calling form, `tf = logical(A)`.
const LOGICAL_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "tf = logical(A)",
    inputs: &LOGICAL_INPUTS,
    outputs: &LOGICAL_OUTPUT,
}];

/// Raised by `logical_builtin` when extra arguments are supplied; its `message`
/// is emitted verbatim.
const LOGICAL_ERROR_TOO_MANY_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOGICAL.TOO_MANY_INPUTS",
    identifier: Some("RunMat:logical:TooManyInputs"),
    when: "More than one input argument is provided.",
    message: "logical: too many input arguments",
};

/// Raised for unsupported input types. The emitted message is built by
/// `conversion_error` and names the offending type; this `message` is the
/// generic catalogue text.
const LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOGICAL.CONVERSION_NOT_POSSIBLE",
    identifier: Some("RunMat:logical:ConversionNotPossible"),
    when: "Input type cannot be converted to logical.",
    message: "logical: conversion to logical is not possible for this input type",
};

/// Raised when the host-gather fallback in `logical_from_gpu` fails. The emitted
/// message wraps the underlying gather error rather than using this `message`.
const LOGICAL_ERROR_GPU_GATHER_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOGICAL.GPU_GATHER_FAILED",
    identifier: Some("RunMat:logical:GpuGatherFailed"),
    when: "GPU input gather fails during host fallback.",
    message: "logical: failed to gather gpuArray input",
};

/// Raised when sparse densification or `LogicalArray` construction fails. The
/// emitted message includes the underlying error text.
const LOGICAL_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOGICAL.INTERNAL",
    identifier: Some("RunMat:logical:InternalError"),
    when: "Internal logical buffer materialization fails.",
    message: "logical: internal conversion error",
};

/// Every error this builtin can raise, exposed through `LOGICAL_DESCRIPTOR`.
const LOGICAL_ERRORS: [BuiltinErrorDescriptor; 4] = [
    LOGICAL_ERROR_TOO_MANY_INPUTS,
    LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE,
    LOGICAL_ERROR_GPU_GATHER_FAILED,
    LOGICAL_ERROR_INTERNAL,
];

/// Public descriptor referenced by the `runtime_builtin` attribute on
/// `logical_builtin` (signatures, output mode, completion policy, error catalogue).
pub const LOGICAL_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &LOGICAL_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &LOGICAL_ERRORS,
};
120
121fn logical_type(args: &[Type], _context: &ResolveContext) -> Type {
122 args.first().map(logical_like).unwrap_or(Type::logical())
123}
124
125fn logical_error_with_message(
126 message: impl Into<String>,
127 error: &'static BuiltinErrorDescriptor,
128) -> RuntimeError {
129 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
130 if let Some(identifier) = error.identifier {
131 builder = builder.with_identifier(identifier);
132 }
133 builder.build()
134}
135
136#[runtime_builtin(
137 name = "logical",
138 category = "logical",
139 summary = "Convert scalars, arrays, and gpuArray values to logical outputs.",
140 keywords = "logical,boolean,gpuArray,mask,conversion",
141 accel = "unary",
142 type_resolver(logical_type),
143 descriptor(crate::builtins::logical::ops::LOGICAL_DESCRIPTOR),
144 builtin_path = "crate::builtins::logical::ops"
145)]
146async fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
147 if !rest.is_empty() {
148 return Err(logical_error_with_message(
149 LOGICAL_ERROR_TOO_MANY_INPUTS.message,
150 &LOGICAL_ERROR_TOO_MANY_INPUTS,
151 ));
152 }
153 convert_value_to_logical(value).await
154}
155
156async fn convert_value_to_logical(value: Value) -> BuiltinResult<Value> {
157 match value {
158 Value::Bool(_) | Value::LogicalArray(_) => Ok(value),
159 Value::Num(n) => Ok(Value::Bool(n != 0.0)),
160 Value::Int(i) => Ok(Value::Bool(!i.is_zero())),
161 Value::Complex(re, im) => Ok(Value::Bool(!complex_is_zero(re, im))),
162 Value::Tensor(tensor) => logical_from_tensor(tensor),
163 Value::SparseTensor(sparse) => logical_from_sparse_tensor(sparse),
164 Value::ComplexTensor(tensor) => logical_from_complex_tensor(tensor),
165 Value::CharArray(chars) => logical_from_char_array(chars),
166 Value::StringArray(strings) => logical_from_string_array(strings),
167 Value::GpuTensor(handle) => logical_from_gpu(handle).await,
168 Value::String(_) => Err(conversion_error("string")),
169 Value::Cell(_) => Err(conversion_error("cell")),
170 Value::Struct(_) => Err(conversion_error("struct")),
171 Value::Object(obj) => Err(conversion_error(&obj.class_name)),
172 Value::HandleObject(handle) => Err(conversion_error(&handle.class_name)),
173 Value::Listener(_) => Err(conversion_error("event.listener")),
174 Value::FunctionHandle(_)
175 | Value::ExternalFunctionHandle(_)
176 | Value::MethodFunctionHandle(_)
177 | Value::BoundFunctionHandle { .. }
178 | Value::Closure(_) => Err(conversion_error("function_handle")),
179 Value::ClassRef(_) => Err(conversion_error("meta.class")),
180 Value::MException(_) => Err(conversion_error("MException")),
181 Value::OutputList(_) => Err(conversion_error("OutputList")),
182 }
183}
184
185fn logical_from_tensor(tensor: Tensor) -> BuiltinResult<Value> {
186 let buffer = LogicalBuffer::from_real_tensor(&tensor);
187 logical_buffer_to_host(buffer)
188}
189
190fn logical_from_sparse_tensor(sparse: runmat_builtins::SparseTensor) -> BuiltinResult<Value> {
191 let tensor = sparse.to_dense().map_err(|err| {
192 logical_error_with_message(
193 format!("logical: failed to densify sparse input: {err}"),
194 &LOGICAL_ERROR_INTERNAL,
195 )
196 })?;
197 logical_from_tensor(tensor)
198}
199
200fn logical_from_complex_tensor(tensor: ComplexTensor) -> BuiltinResult<Value> {
201 let buffer = LogicalBuffer::from_complex_tensor(&tensor);
202 logical_buffer_to_host(buffer)
203}
204
205fn logical_from_char_array(chars: CharArray) -> BuiltinResult<Value> {
206 let buffer = LogicalBuffer::from_char_array(&chars);
207 logical_buffer_to_host(buffer)
208}
209
210fn logical_from_string_array(strings: StringArray) -> BuiltinResult<Value> {
211 let bits: Vec<u8> = strings
212 .data
213 .iter()
214 .map(|s| if s.is_empty() { 0 } else { 1 })
215 .collect();
216 let shape = canonical_shape(&strings.shape, bits.len());
217 logical_buffer_to_host(LogicalBuffer { bits, shape })
218}
219
220async fn logical_from_gpu(handle: GpuTensorHandle) -> BuiltinResult<Value> {
221 if runmat_accelerate_api::handle_is_logical(&handle) {
222 return Ok(Value::GpuTensor(handle));
223 }
224
225 let provider = runmat_accelerate_api::provider();
226
227 if let Some(p) = provider {
228 match p.logical_islogical(&handle) {
229 Ok(true) => {
230 runmat_accelerate_api::set_handle_logical(&handle, true);
231 return Ok(Value::GpuTensor(handle));
232 }
233 Ok(false) => {}
234 Err(err) => {
235 trace!("logical: provider logical_islogical hook unavailable, falling back ({err})")
236 }
237 }
238 if let Some(result) = try_gpu_cast(p, &handle).await {
239 return Ok(gpu_helpers::logical_gpu_value(result));
240 } else {
241 trace!(
242 "logical: provider elem_ne/zeros_like unavailable for buffer {} – gathering",
243 handle.buffer_id
244 );
245 }
246 }
247
248 let tensor = gpu_helpers::gather_tensor_async(&handle)
249 .await
250 .map_err(|err| {
251 logical_error_with_message(
252 format!("{BUILTIN_NAME}: {err}"),
253 &LOGICAL_ERROR_GPU_GATHER_FAILED,
254 )
255 })?;
256 let buffer = LogicalBuffer::from_real_tensor(&tensor);
257 logical_buffer_to_gpu(buffer, provider)
258}
259
260fn logical_buffer_to_host(buffer: LogicalBuffer) -> BuiltinResult<Value> {
261 let LogicalBuffer { bits, shape } = buffer;
262 if tensor::element_count(&shape) == 1 && bits.len() == 1 {
263 Ok(Value::Bool(bits[0] != 0))
264 } else {
265 LogicalArray::new(bits, shape)
266 .map(Value::LogicalArray)
267 .map_err(|e| {
268 logical_error_with_message(format!("logical: {e}"), &LOGICAL_ERROR_INTERNAL)
269 })
270 }
271}
272
273fn logical_buffer_to_gpu(
274 buffer: LogicalBuffer,
275 provider: Option<&'static dyn AccelProvider>,
276) -> BuiltinResult<Value> {
277 if let Some(p) = provider {
278 let floats: Vec<f64> = buffer
279 .bits
280 .iter()
281 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
282 .collect();
283 let view = HostTensorView {
284 data: &floats,
285 shape: &buffer.shape,
286 };
287 match p.upload(&view) {
288 Ok(handle) => Ok(gpu_helpers::logical_gpu_value(handle)),
289 Err(err) => {
290 trace!("logical: upload failed during fallback path ({err})");
291 logical_buffer_to_host(buffer)
292 }
293 }
294 } else {
295 logical_buffer_to_host(buffer)
296 }
297}
298
299async fn try_gpu_cast(
300 provider: &'static dyn AccelProvider,
301 input: &GpuTensorHandle,
302) -> Option<GpuTensorHandle> {
303 let zeros = provider.zeros_like(input).ok()?;
304 let result = provider.elem_ne(input, &zeros).await.ok();
305 let _ = provider.free(&zeros);
306 result
307}
308
/// Returns true when both the real and imaginary parts compare equal to zero.
///
/// Uses IEEE equality, so `-0.0` counts as zero and NaN never does.
fn complex_is_zero(re: f64, im: f64) -> bool {
    (re, im) == (0.0, 0.0)
}
312
313fn conversion_error(type_name: &str) -> RuntimeError {
314 logical_error_with_message(
315 format!(
316 "logical: conversion to logical from {} is not possible",
317 type_name
318 ),
319 &LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE,
320 )
321}
322
/// Host-side staging buffer holding a logical conversion result before it is
/// turned into a host value or uploaded to the device.
#[derive(Clone)]
struct LogicalBuffer {
    /// One entry per element; every construction site in this module writes only 0 or 1.
    bits: Vec<u8>,
    /// Dimensions for `bits`, as produced by `canonical_shape`.
    shape: Vec<usize>,
}
328
329impl LogicalBuffer {
330 fn from_real_tensor(tensor: &Tensor) -> Self {
331 let bits: Vec<u8> = tensor
332 .data
333 .iter()
334 .map(|&v| if v != 0.0 { 1 } else { 0 })
335 .collect();
336 let shape = canonical_shape(&tensor.shape, bits.len());
337 Self { bits, shape }
338 }
339
340 fn from_complex_tensor(tensor: &ComplexTensor) -> Self {
341 let bits: Vec<u8> = tensor
342 .data
343 .iter()
344 .map(|&(re, im)| if !complex_is_zero(re, im) { 1 } else { 0 })
345 .collect();
346 let shape = canonical_shape(&tensor.shape, bits.len());
347 Self { bits, shape }
348 }
349
350 fn from_char_array(chars: &CharArray) -> Self {
351 let bits: Vec<u8> = chars
352 .data
353 .iter()
354 .map(|&ch| if (ch as u32) != 0 { 1 } else { 0 })
355 .collect();
356 let original_shape = vec![chars.rows, chars.cols];
357 let shape = canonical_shape(&original_shape, bits.len());
358 Self { bits, shape }
359 }
360}
361
362fn canonical_shape(shape: &[usize], len: usize) -> Vec<usize> {
363 if tensor::element_count(shape) == len {
364 return normalize_scalar_shape(shape);
365 }
366 if len == 0 {
367 if shape.len() > 1 {
368 return shape.to_vec();
369 }
370 return vec![0];
371 }
372 if len == 1 {
373 canonical_scalar_shape()
374 } else {
375 vec![len, 1]
376 }
377}
378
// Unit tests for the `logical` builtin. They cover scalar and array conversions,
// error messages and identifiers for unsupported inputs, arity checking, and
// gpuArray behavior (shared test provider, plus an optional wgpu parity check).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{
        CellArray, IntValue, MException, ObjectInstance, SparseTensor, StructValue,
    };

    // Synchronous shim over the async builtin so tests can call it directly.
    fn logical_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::logical_builtin(value, rest))
    }

    // Asserts that the error message equals `expected` exactly.
    fn assert_error_message(err: &crate::RuntimeError, expected: &str) {
        assert_eq!(err.message(), expected);
    }

    // Asserts that the error message contains `expected` as a substring.
    fn assert_error_contains(err: &crate::RuntimeError, expected: &str) {
        assert!(
            err.message().contains(expected),
            "unexpected error: {}",
            err.message()
        );
    }

    // Non-zero numeric scalars map to true and zero maps to false.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_num() {
        let result = logical_builtin(Value::Num(5.0), Vec::new()).expect("logical");
        assert_eq!(result, Value::Bool(true));

        let zero_result = logical_builtin(Value::Num(0.0), Vec::new()).expect("logical");
        assert_eq!(zero_result, Value::Bool(false));
    }

    // NaN converts to true and -0.0 to false.
    // NOTE(review): MATLAB's logical() raises an error for NaN inputs; this
    // implementation pins NaN -> true. Confirm the divergence is intended.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_nan_is_true() {
        let tensor = Tensor::new(vec![0.0, f64::NAN, -0.0], vec![1, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![0, 1, 0]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Matrix inputs keep their shape and convert element-wise.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_tensor_matrix() {
        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, 0.0], vec![2, 2]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![2, 2]);
                assert_eq!(array.data, vec![0, 1, 1, 0]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Sparse inputs are densified before conversion. The constructor arguments
    // are presumably (rows, cols, col_ptr, row_idx, values), giving non-zeros at
    // (1,0) and (2,1); the expected data is consistent with that reading.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_sparse_tensor_densifies() {
        let sparse = SparseTensor::new(3, 2, vec![0, 1, 2], vec![1, 2], vec![4.0, -1.0]).unwrap();
        let result = logical_builtin(Value::SparseTensor(sparse), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![3, 2]);
                assert_eq!(array.data, vec![0, 1, 0, 0, 0, 1]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Complex elements are true unless both parts are zero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_complex_conversion() {
        let complex =
            ComplexTensor::new(vec![(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)], vec![3, 1]).unwrap();
        let result = logical_builtin(Value::ComplexTensor(complex), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.data, vec![0, 1, 1]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Char arrays convert by code point: only NUL is false.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_char_array_conversion() {
        let chars = CharArray::new(vec!['A', '\0', 'C'], 1, 3).unwrap();
        let result = logical_builtin(Value::CharArray(chars), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => assert_eq!(array.data, vec![1, 0, 1]),
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Scalar strings are rejected with the ConversionNotPossible identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_string_error() {
        let err = logical_builtin(Value::String("runmat".to_string()), Vec::new()).unwrap_err();
        assert_error_message(
            &err,
            "logical: conversion to logical from string is not possible",
        );
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    // Structs are rejected, and the message names the struct class.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_struct_error() {
        let mut st = StructValue::new();
        st.insert("field", Value::Num(1.0));
        let err = logical_builtin(Value::Struct(st), Vec::new()).unwrap_err();
        assert_error_contains(&err, "struct");
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    // Cell arrays are rejected with the exact conversion message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_cell_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell creation");
        let err = logical_builtin(Value::Cell(cell), Vec::new()).unwrap_err();
        assert_error_message(
            &err,
            "logical: conversion to logical from cell is not possible",
        );
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    // Function handles all report the `function_handle` class.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_function_handle_error() {
        let err = logical_builtin(Value::FunctionHandle("foo".into()), Vec::new()).unwrap_err();
        assert_error_message(
            &err,
            "logical: conversion to logical from function_handle is not possible",
        );
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    // Object instances report their own class name in the error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_object_error() {
        let obj = ObjectInstance::new("DemoClass".to_string());
        let err = logical_builtin(Value::Object(obj), Vec::new()).unwrap_err();
        assert_error_contains(&err, "DemoClass");
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    // MException values are rejected with the exact conversion message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_mexception_error() {
        let mex = MException::new("id:logical".into(), "message".into());
        let err = logical_builtin(Value::MException(mex), Vec::new()).unwrap_err();
        assert_error_message(
            &err,
            "logical: conversion to logical from MException is not possible",
        );
        assert_eq!(
            err.identifier(),
            LOGICAL_ERROR_CONVERSION_NOT_POSSIBLE.identifier
        );
    }

    // Extra arguments produce the TooManyInputs message and identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_too_many_inputs_error() {
        let err = logical_builtin(Value::Bool(true), vec![Value::Bool(false)]).unwrap_err();
        assert_error_message(&err, LOGICAL_ERROR_TOO_MANY_INPUTS.message);
        assert_eq!(err.identifier(), LOGICAL_ERROR_TOO_MANY_INPUTS.identifier);
    }

    // A numeric gpuArray converts on the device (or via fallback), and the
    // output handle is flagged logical.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0, -2.0], vec![3, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            let gathered = test_support::gather(result.clone()).expect("gather");
            assert_eq!(gathered.data, vec![0.0, 1.0, 1.0]);
            if let Value::GpuTensor(out) = result {
                assert!(runmat_accelerate_api::handle_is_logical(&out));
            } else {
                panic!("expected gpu tensor output");
            }
        });
    }

    // A handle already flagged logical is returned unchanged (same handle).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_gpu_passthrough_for_logical_handle() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            runmat_accelerate_api::set_handle_logical(&handle, true);
            let result =
                logical_builtin(Value::GpuTensor(handle.clone()), Vec::new()).expect("logical");
            match result {
                Value::GpuTensor(out) => assert_eq!(out, handle),
                other => panic!("expected gpu tensor, got {:?}", other),
            }
        });
    }

    // Bool and LogicalArray inputs are returned as-is.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_bool_and_logical_inputs_passthrough() {
        let res_bool = logical_builtin(Value::Bool(true), Vec::new()).expect("logical");
        assert_eq!(res_bool, Value::Bool(true));

        let logical = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        let res_array =
            logical_builtin(Value::LogicalArray(logical.clone()), Vec::new()).expect("logical");
        assert_eq!(res_array, Value::LogicalArray(logical));
    }

    // Empty inputs keep their declared (empty) shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_empty_tensor_preserves_shape() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        let result = logical_builtin(Value::Tensor(tensor), Vec::new()).expect("logical");
        match result {
            Value::LogicalArray(array) => {
                assert!(array.data.is_empty());
                assert_eq!(array.shape, vec![0, 3]);
            }
            other => panic!("expected logical array, got {:?}", other),
        }
    }

    // Integer scalars: zero is false and any non-zero (including negative) is true.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_integer_scalar() {
        let res = logical_builtin(Value::Int(IntValue::I32(0)), Vec::new()).expect("logical");
        assert_eq!(res, Value::Bool(false));

        let res_nonzero =
            logical_builtin(Value::Int(IntValue::I32(-5)), Vec::new()).expect("logical");
        assert_eq!(res_nonzero, Value::Bool(true));
    }

    // With the wgpu backend enabled, the device conversion must match the CPU
    // path element-for-element and shape-for-shape (NaN included).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn logical_wgpu_matches_cpu_conversion() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );

        let tensor = Tensor::new(vec![0.0, 2.0, -3.0, f64::NAN], vec![2, 2]).unwrap();
        let cpu = logical_builtin(Value::Tensor(tensor.clone()), Vec::new()).unwrap();

        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("upload");

        let gpu_value = logical_builtin(Value::GpuTensor(handle), Vec::new()).unwrap();
        let out_handle = match gpu_value {
            Value::GpuTensor(ref h) => {
                assert!(runmat_accelerate_api::handle_is_logical(h));
                h.clone()
            }
            other => panic!("expected gpu tensor, got {other:?}"),
        };

        let gathered = test_support::gather(Value::GpuTensor(out_handle)).expect("gather");

        // Widen the CPU result to f64 so it can be compared with the gathered tensor.
        let (expected, expected_shape): (Vec<f64>, Vec<usize>) = match cpu {
            Value::LogicalArray(arr) => (
                arr.data
                    .iter()
                    .map(|&b| if b != 0 { 1.0 } else { 0.0 })
                    .collect(),
                arr.shape.clone(),
            ),
            Value::Bool(flag) => (vec![if flag { 1.0 } else { 0.0 }], vec![1, 1]),
            other => panic!("unexpected cpu result {other:?}"),
        };

        assert_eq!(gathered.shape, expected_shape);
        assert_eq!(gathered.data, expected);
    }
}