Skip to main content

simd_kernels/kernels/
logical.rs

1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Logical Operations Kernels Module** - *Boolean Logic and Set Operations*
6//!
7//! Logical operation kernels providing efficient boolean algebra, set membership testing,
8//! and range operations with SIMD acceleration and null-aware semantics. Critical foundation
9//! for query execution, filtering predicates, and analytical data processing workflows.
10//!
11//! ## Core Operations
12//! - **Boolean algebra**: AND, OR, XOR, NOT operations on boolean arrays with bitmask optimisation
13//! - **Set membership**: IN and NOT IN operations with hash-based lookup optimisation
14//! - **Range operations**: BETWEEN predicates for numeric and string data types
15//! - **Pattern matching**: String pattern matching with optimised prefix/suffix detection
16//! - **Null-aware logic**: Three-valued logic implementation following SQL semantics
17//! - **Compound predicates**: Efficient evaluation of complex multi-condition expressions
18
19include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::collections::HashSet;
22use std::hash::Hash;
23use std::marker::PhantomData;
24#[cfg(feature = "simd")]
25use std::simd::{Mask, Simd, cmp::SimdPartialEq, cmp::SimdPartialOrd, num::SimdFloat};
26
27use minarrow::kernels::arithmetic::string::MAX_DICT_CHECK;
28use minarrow::traits::type_unions::Float;
29use minarrow::{
30    Array, Bitmask, BooleanAVT, BooleanArray, CategoricalAVT, Integer, MaskedArray, Numeric,
31    NumericArray, StringAVT, TextArray, Vec64,
32};
33
34#[cfg(not(feature = "simd"))]
35use crate::kernels::bitmask::dispatch::{and_masks, or_masks, xor_masks};
36use crate::operators::LogicalOperator;
37use minarrow::enums::error::KernelError;
38#[cfg(feature = "simd")]
39use minarrow::kernels::bitmask::simd::{and_masks_simd, or_masks_simd, xor_masks_simd};
40use minarrow::utils::confirm_mask_capacity;
41
42#[cfg(feature = "simd")]
43use minarrow::utils::is_simd_aligned;
44use std::any::TypeId;
45
46/// Builds the Boolean result buffer.
47/// `len` – number of rows that will be written.
48#[inline(always)]
49fn new_bool_buffer(len: usize) -> Bitmask {
50    Bitmask::new_set_all(len, false)
51}
52
53// Between
54
/// Generates a pair of BETWEEN kernels for one numeric element type.
///
/// * `$name_to` writes matches into a caller-supplied [`Bitmask`].
/// * `$name` allocates the result and wraps it in a `BooleanArray`.
///
/// `$mask_elem` / `$lanes` select the SIMD mask element type and lane count used by the
/// vectorised fast path when the `simd` feature is enabled.
macro_rules! impl_between_numeric {
    ($name:ident, $name_to:ident, $ty:ty, $mask_elem:ty, $lanes:expr) => {
        /// Zero-allocation BETWEEN: sets `output[i]` for every row where `min <= lhs[i] <= max`.
        ///
        /// `rhs` is either a single `[min, max]` pair shared by all rows, or one pair per row
        /// laid out as `[min0, max0, min1, max1, ...]`. Rows whose mask bit is cleared are
        /// skipped when `has_nulls` is true. Bits are only ever set, never cleared.
        ///
        /// # Errors
        /// `KernelError::InvalidArguments` when `rhs` has neither length 2 nor `2 * lhs.len()`,
        /// or when the mask capacity is smaller than `lhs.len()`.
        ///
        /// # Panics
        /// If `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let shared_bounds = rhs.len() == 2;
            if !shared_bounds && rhs.len() != 2 * len {
                return Err(KernelError::InvalidArguments(
                    format!("between: RHS must have len 2 or 2×LHS (got lhs: {}, rhs: {})", len, rhs.len())
                ));
            }

            match mask {
                Some(m) if m.capacity() < len => {
                    return Err(KernelError::InvalidArguments(
                        format!("between: mask (Bitmask) capacity must be ≥ len (got capacity: {}, len: {})", m.capacity(), len)
                    ));
                }
                _ => {}
            }
            assert!(output.capacity() >= len, concat!(stringify!($name_to), ": output capacity too small"));

            // Vectorised path: only for aligned inputs, no nulls, and a single shared bound pair.
            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && !has_nulls && shared_bounds {
                    const N: usize = $lanes;
                    type V = Simd<$ty, N>;
                    type M = Mask<$mask_elem, N>;

                    let (lo, hi) = (rhs[0], rhs[1]);
                    let lo_v = V::splat(lo);
                    let hi_v = V::splat(hi);

                    // Whole N-wide chunks first.
                    for (chunk_idx, chunk) in lhs.chunks_exact(N).enumerate() {
                        let base = chunk_idx * N;
                        let values = V::from_slice(chunk);
                        let hits: M = values.simd_ge(lo_v) & values.simd_le(hi_v);
                        let bits = hits.to_bitmask();

                        for lane in 0..N {
                            if ((bits >> lane) & 1) == 1 {
                                output.set(base + lane, true);
                            }
                        }
                    }

                    // Remaining rows that do not fill a full vector.
                    let tail_start = len - len % N;
                    for row in tail_start..len {
                        let value = lhs[row];
                        if value >= lo && value <= hi {
                            output.set(row, true);
                        }
                    }

                    return Ok(());
                }
                // Otherwise continue with the scalar implementation below.
            }

            // The capacity check above guarantees `row < m.capacity()` for every row visited.
            let row_is_valid = |row: usize| -> bool {
                !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(row) })
            };

            if shared_bounds {
                let lo = rhs[0];
                let hi = rhs[1];
                for (row, &value) in lhs.iter().enumerate() {
                    if row_is_valid(row) && value >= lo && value <= hi {
                        output.set(row, true);
                    }
                }
            } else {
                // One `[min, max]` pair per row.
                for (row, (&value, pair)) in lhs.iter().zip(rhs.chunks_exact(2)).enumerate() {
                    if row_is_valid(row) && value >= pair[0] && value <= pair[1] {
                        output.set(row, true);
                    }
                }
            }

            Ok(())
        }

        /// Allocating BETWEEN: returns a `BooleanArray` whose data bit `i` is set when
        /// `lhs[i]` lies within its bounds, and whose null mask is a clone of `mask`.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool
        ) -> Result<BooleanArray<()>, KernelError> {
            let row_count = lhs.len();
            let mut bits = new_bool_buffer(row_count);
            $name_to(lhs, rhs, mask, has_nulls, &mut bits)?;
            let result = BooleanArray {
                data: bits.into(),
                null_mask: mask.cloned(),
                len: row_count,
                _phantom: PhantomData
            };
            Ok(result)
        }
    };
}
172
173// floats
174
175#[inline(always)]
176fn between_generic<T: Numeric + Copy + std::cmp::PartialOrd>(
177    lhs: &[T],
178    rhs: &[T],
179    mask: Option<&Bitmask>,
180    has_nulls: bool,
181) -> Result<BooleanArray<()>, KernelError> {
182    let len = lhs.len();
183    let mut out = new_bool_buffer(len);
184    let _ = confirm_mask_capacity(len, mask)?;
185    if rhs.len() == 2 {
186        let (min, max) = (rhs[0], rhs[1]);
187        for i in 0..len {
188            if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
189                && lhs[i] >= min
190                && lhs[i] <= max
191            {
192                out.set(i, true);
193            }
194        }
195    } else {
196        for i in 0..len {
197            let min = rhs[i * 2];
198            let max = rhs[i * 2 + 1];
199            if (!has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(i) }))
200                && lhs[i] >= min
201                && lhs[i] <= max
202            {
203                out.set(i, true);
204            }
205        }
206    }
207
208    Ok(BooleanArray {
209        data: out.into(),
210        null_mask: mask.cloned(),
211        len,
212        _phantom: PhantomData,
213    })
214}
215
216// In and Not In
217
/// Generates a pair of IN (set-membership) kernels for one integer element type.
///
/// * `$name_to` writes matches into a caller-supplied [`Bitmask`].
/// * `$name` allocates the result and wraps it in a `BooleanArray`.
///
/// `$lanes` / `$mask_elem` select the SIMD lane count and mask element type used by the
/// vectorised path when the `simd` feature is enabled.
macro_rules! impl_in_int {
    ($name:ident, $name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Sets `output[i]` for every valid row whose `lhs[i]` occurs in `rhs`.
        /// A row is valid when `has_nulls` is false, when no mask is supplied, or when its
        /// mask bit is set. Bits are only ever set, never cleared.
        ///
        /// `has_nulls == true` with `mask == None` is treated as "all rows valid" on every
        /// code path (the SIMD path used to panic on that combination).
        ///
        /// # Errors
        /// Propagates the error from `confirm_mask_capacity` for an undersized mask.
        ///
        /// # Panics
        /// If `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($name_to), ": output capacity too small")
            );

            // Single source of truth for validity: `Some` only when nulls must be honoured.
            let validity: Option<&Bitmask> = if has_nulls { mask } else { None };

            #[cfg(feature = "simd")]
            {
                // Small RHS sets on 64-byte aligned inputs are compared lane-wise against
                // each candidate instead of building a hash set.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && rhs.len() <= 16 {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    let mut i = 0;
                    match validity {
                        None => {
                            while i + $lanes <= len {
                                let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                let mut hits = Mask::<$mask_elem, $lanes>::splat(false);
                                for &v in rhs {
                                    hits |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                }
                                let bm = hits.to_bitmask();
                                for l in 0..$lanes {
                                    if ((bm >> l) & 1) == 1 {
                                        output.set(i + l, true);
                                    }
                                }
                                i += $lanes;
                            }
                            // Scalar tail.
                            for j in i..len {
                                if rhs.contains(&lhs[j]) {
                                    output.set(j, true);
                                }
                            }
                        }
                        Some(mb) => {
                            let mask_bytes = mb.as_bytes();
                            while i + $lanes <= len {
                                let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                // Lanes whose validity bit is set.
                                let lane_mask =
                                    bitmask_to_simd_mask::<$lanes, $mask_elem>(mask_bytes, i, len);
                                let mut hits = Mask::<$mask_elem, $lanes>::splat(false);
                                for &v in rhs {
                                    hits |= x.simd_eq(Simd::<$ty, $lanes>::splat(v));
                                }
                                // Keep only lanes that are both valid and members of RHS.
                                let valid_hits = lane_mask & hits;
                                let bm = valid_hits.to_bitmask();
                                for l in 0..$lanes {
                                    if ((bm >> l) & 1) == 1 {
                                        output.set(i + l, true);
                                    }
                                }
                                i += $lanes;
                            }
                            // Scalar tail.
                            for j in i..len {
                                // SAFETY: j < len <= mask capacity (checked above).
                                if unsafe { mb.get_unchecked(j) } && rhs.contains(&lhs[j]) {
                                    output.set(j, true);
                                }
                            }
                        }
                    }
                    return Ok(());
                }
                // Unaligned input or large RHS: continue with the hash-set path below.
            }

            // Scalar fallback (null-aware and large-RHS)
            let set: std::collections::HashSet<$ty> = rhs.iter().copied().collect();
            for (i, x) in lhs.iter().enumerate() {
                // SAFETY: i < len <= mask capacity (checked above).
                if validity.map_or(true, |m| unsafe { m.get_unchecked(i) }) && set.contains(x) {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Test membership of LHS integer values in RHS set, producing boolean result array.
        ///
        /// The returned array's null mask is a clone of `mask`.
        #[inline(always)]
        pub fn $name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            $name_to(lhs, rhs, mask, has_nulls, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
337
/// Implements SIMD/Scalar IN kernel for floats, handling NaN semantics and optional null mask.
///
/// Unlike IEEE equality, a NaN on the left matches a NaN in the RHS set.
/// * `$fn_name_to` writes matches into a caller-supplied [`Bitmask`].
/// * `$fn_name` allocates the result and wraps it in a `BooleanArray`.
macro_rules! impl_in_float {
    (
        $fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty
    ) => {
        /// Zero-allocation variant: writes directly to caller's output buffer.
        ///
        /// Sets `output[i]` for every valid row whose `lhs[i]` equals some RHS value, or is
        /// NaN while RHS contains a NaN. A row is valid when `has_nulls` is false, when no
        /// mask is supplied, or when its mask bit is set. Bits are only ever set.
        ///
        /// `has_nulls == true` with `mask == None` is treated as "all rows valid" on every
        /// code path (the SIMD path used to panic on that combination).
        ///
        /// # Errors
        /// Propagates the error from `confirm_mask_capacity` for an undersized mask.
        ///
        /// # Panics
        /// If `output.capacity() < lhs.len()`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            let _ = confirm_mask_capacity(len, mask)?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );

            // Single source of truth for validity: `Some` only when nulls must be honoured.
            let validity: Option<&Bitmask> = if has_nulls { mask } else { None };

            #[cfg(feature = "simd")]
            {
                // Small RHS sets on 64-byte aligned inputs are compared lane-wise.
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && rhs.len() <= 16 {
                    use crate::utils::bitmask_to_simd_mask;
                    use core::simd::{Mask, Simd};

                    let mut i = 0;
                    match validity {
                        None => {
                            while i + $lanes <= len {
                                let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                for &v in rhs {
                                    let needle = Simd::<$ty, $lanes>::splat(v);
                                    // Equal, or both NaN.
                                    m |= x.simd_eq(needle) | (x.is_nan() & needle.is_nan());
                                }
                                let bm = m.to_bitmask();
                                for l in 0..$lanes {
                                    if ((bm >> l) & 1) == 1 {
                                        output.set(i + l, true);
                                    }
                                }
                                i += $lanes;
                            }
                            // Scalar tail.
                            for j in i..len {
                                let x = lhs[j];
                                if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                    output.set(j, true);
                                }
                            }
                        }
                        Some(mb) => {
                            let mask_bytes = mb.as_bytes();
                            while i + $lanes <= len {
                                let x = Simd::<$ty, $lanes>::from_slice(&lhs[i..i + $lanes]);
                                // Lanes whose validity bit is set.
                                let lane_mask =
                                    bitmask_to_simd_mask::<$lanes, $mask_elem>(mask_bytes, i, len);
                                let mut m = Mask::<$mask_elem, $lanes>::splat(false);
                                for &v in rhs {
                                    let needle = Simd::<$ty, $lanes>::splat(v);
                                    m |= x.simd_eq(needle) | (x.is_nan() & needle.is_nan());
                                }
                                // Drop matches on null rows.
                                let m = m & lane_mask;
                                let bm = m.to_bitmask();
                                for l in 0..$lanes {
                                    if ((bm >> l) & 1) == 1 {
                                        output.set(i + l, true);
                                    }
                                }
                                i += $lanes;
                            }
                            // Scalar tail.
                            for j in i..len {
                                // SAFETY: j < len <= mask capacity (checked above).
                                if unsafe { mb.get_unchecked(j) } {
                                    let x = lhs[j];
                                    if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                                        output.set(j, true);
                                    }
                                }
                            }
                        }
                    }
                    return Ok(());
                }
                // Unaligned input or large RHS: continue with the scalar path below.
            }

            // Scalar fallback
            for (i, &x) in lhs.iter().enumerate() {
                // SAFETY: i < len <= mask capacity (checked above).
                if !validity.map_or(true, |m| unsafe { m.get_unchecked(i) }) {
                    continue;
                }
                if rhs.iter().any(|&v| x == v || (x.is_nan() && v.is_nan())) {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Test membership of LHS floating-point values in RHS set with NaN handling.
        ///
        /// The returned array's null mask is a clone of `mask`.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            has_nulls: bool,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_buffer(len);
            $fn_name_to(lhs, rhs, mask, has_nulls, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
464
465// Correct MaskElement types per std::simd
466#[cfg(feature = "extended_numeric_types")]
467impl_in_int!(in_i8, in_i8_to, i8, W8, i8);
468#[cfg(feature = "extended_numeric_types")]
469impl_in_int!(in_u8, in_u8_to, u8, W8, i8);
470#[cfg(feature = "extended_numeric_types")]
471impl_in_int!(in_i16, in_i16_to, i16, W16, i16);
472#[cfg(feature = "extended_numeric_types")]
473impl_in_int!(in_u16, in_u16_to, u16, W16, i16);
474impl_in_int!(in_i32, in_i32_to, i32, W32, i32);
475impl_in_int!(in_u32, in_u32_to, u32, W32, i32);
476impl_in_int!(in_i64, in_i64_to, i64, W64, i64);
477impl_in_int!(in_u64, in_u64_to, u64, W64, i64);
478impl_in_float!(in_f32, in_f32_to, f32, W32, i32);
479impl_in_float!(in_f64, in_f64_to, f64, W64, i64);
480
481#[cfg(feature = "extended_numeric_types")]
482impl_between_numeric!(between_i8, between_i8_to, i8, i8, W8);
483#[cfg(feature = "extended_numeric_types")]
484impl_between_numeric!(between_u8, between_u8_to, u8, i8, W8);
485#[cfg(feature = "extended_numeric_types")]
486impl_between_numeric!(between_i16, between_i16_to, i16, i16, W16);
487#[cfg(feature = "extended_numeric_types")]
488impl_between_numeric!(between_u16, between_u16_to, u16, i16, W16);
489
490impl_between_numeric!(between_i32, between_i32_to, i32, i32, W32);
491impl_between_numeric!(between_u32, between_u32_to, u32, i32, W32);
492impl_between_numeric!(between_i64, between_i64_to, i64, i64, W64);
493impl_between_numeric!(between_u64, between_u64_to, u64, i64, W64);
494impl_between_numeric!(between_f32, between_f32_to, f32, i32, W32);
495impl_between_numeric!(between_f64, between_f64_to, f64, i64, W64);
496
497// String and dictionary
498
499/// Test if LHS string values fall lexicographically between RHS min/max bounds.
500#[inline(always)]
501pub fn cmp_str_between<'a, T: Integer>(
502    lhs: StringAVT<'a, T>,
503    rhs: StringAVT<'a, T>,
504) -> Result<BooleanArray<()>, KernelError> {
505    let (larr, loff, llen) = lhs;
506    let (rarr, roff, rlen) = rhs;
507
508    if rlen < 2 {
509        return Err(KernelError::InvalidArguments(format!(
510            "str_between: RHS must contain at least two values (got {})",
511            rlen
512        )));
513    }
514    let min = rarr.get(roff).unwrap_or("");
515    let max = rarr.get(roff + 1).unwrap_or("");
516    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
517    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
518
519    let mut out = new_bool_buffer(llen);
520
521    for i in 0..llen {
522        if mask
523            .as_ref()
524            .map_or(true, |m| unsafe { m.get_unchecked(i) })
525        {
526            let s = unsafe { larr.get_str_unchecked(loff + i) };
527            if s >= min && s <= max {
528                unsafe { out.set_unchecked(i, true) };
529            }
530        }
531    }
532
533    Ok(BooleanArray {
534        data: out.into(),
535        null_mask: mask,
536        len: llen,
537        _phantom: PhantomData,
538    })
539}
540
541#[inline(always)]
542/// Test membership of LHS string values in RHS string set.
543pub fn cmp_str_in<'a, T: Integer>(
544    lhs: StringAVT<'a, T>,
545    rhs: StringAVT<'a, T>,
546) -> Result<BooleanArray<()>, KernelError> {
547    let (larr, loff, llen) = lhs;
548    let (rarr, roff, rlen) = rhs;
549
550    let set: HashSet<&str> = (0..rlen)
551        .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
552        .collect();
553
554    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
555    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
556
557    let mut out = new_bool_buffer(llen);
558
559    for i in 0..llen {
560        if mask
561            .as_ref()
562            .map_or(true, |m| unsafe { m.get_unchecked(i) })
563        {
564            let s = unsafe { larr.get_str_unchecked(loff + i) };
565            if set.contains(s) {
566                unsafe { out.set_unchecked(i, true) };
567            }
568        }
569    }
570    Ok(BooleanArray {
571        data: out.into(),
572        null_mask: mask,
573        len: llen,
574        _phantom: PhantomData,
575    })
576}
577
578// Public functions
579
580/// Test if values fall between min/max bounds for comparable numeric types.
581pub fn cmp_between<T: PartialOrd + Copy + Numeric>(
582    lhs: &[T],
583    rhs: &[T],
584) -> Result<BooleanArray<()>, KernelError> {
585    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
586        return between_i32(
587            unsafe { std::mem::transmute(lhs) },
588            unsafe { std::mem::transmute(rhs) },
589            None,
590            false,
591        );
592    }
593    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
594        return between_u32(
595            unsafe { std::mem::transmute(lhs) },
596            unsafe { std::mem::transmute(rhs) },
597            None,
598            false,
599        );
600    }
601    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
602        return between_i64(
603            unsafe { std::mem::transmute(lhs) },
604            unsafe { std::mem::transmute(rhs) },
605            None,
606            false,
607        );
608    }
609    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
610        return between_u64(
611            unsafe { std::mem::transmute(lhs) },
612            unsafe { std::mem::transmute(rhs) },
613            None,
614            false,
615        );
616    }
617    // Fallback – floats or any other PartialOrd type
618    between_generic(lhs, rhs, None, false)
619}
620
621/// Mask-aware variant
622#[inline(always)]
623pub fn cmp_between_mask<T: PartialOrd + Copy + Numeric>(
624    lhs: &[T],
625    rhs: &[T],
626    mask: Option<&Bitmask>,
627) -> Result<BooleanArray<()>, KernelError> {
628    let has_nulls = mask.is_some();
629    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
630        return between_i32(
631            unsafe { std::mem::transmute(lhs) },
632            unsafe { std::mem::transmute(rhs) },
633            mask,
634            has_nulls,
635        );
636    }
637    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
638        return between_u32(
639            unsafe { std::mem::transmute(lhs) },
640            unsafe { std::mem::transmute(rhs) },
641            mask,
642            has_nulls,
643        );
644    }
645    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
646        return between_i64(
647            unsafe { std::mem::transmute(lhs) },
648            unsafe { std::mem::transmute(rhs) },
649            mask,
650            has_nulls,
651        );
652    }
653    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
654        return between_u64(
655            unsafe { std::mem::transmute(lhs) },
656            unsafe { std::mem::transmute(rhs) },
657            mask,
658            has_nulls,
659        );
660    }
661    between_generic(lhs, rhs, mask, has_nulls)
662}
663
664// In and Not In
665
666/// Test membership in set for hashable types using hash-based lookup.
667pub fn cmp_in<T: Eq + Hash + Copy + 'static>(
668    lhs: &[T],
669    rhs: &[T],
670) -> Result<BooleanArray<()>, KernelError> {
671    // i32
672    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
673        return in_i32(
674            unsafe { std::mem::transmute(lhs) },
675            unsafe { std::mem::transmute(rhs) },
676            None,
677            false,
678        );
679    }
680    // u32
681    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
682        return in_u32(
683            unsafe { std::mem::transmute(lhs) },
684            unsafe { std::mem::transmute(rhs) },
685            None,
686            false,
687        );
688    }
689    // i64
690    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
691        return in_i64(
692            unsafe { std::mem::transmute(lhs) },
693            unsafe { std::mem::transmute(rhs) },
694            None,
695            false,
696        );
697    }
698    // u64
699    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
700        return in_u64(
701            unsafe { std::mem::transmute(lhs) },
702            unsafe { std::mem::transmute(rhs) },
703            None,
704            false,
705        );
706    }
707    // i16
708    #[cfg(feature = "extended_numeric_types")]
709    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i16>() {
710        return in_i16(
711            unsafe { std::mem::transmute(lhs) },
712            unsafe { std::mem::transmute(rhs) },
713            None,
714            false,
715        );
716    }
717    // u16
718    #[cfg(feature = "extended_numeric_types")]
719    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u16>() {
720        return in_u16(
721            unsafe { std::mem::transmute(lhs) },
722            unsafe { std::mem::transmute(rhs) },
723            None,
724            false,
725        );
726    }
727    // i8
728    #[cfg(feature = "extended_numeric_types")]
729    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i8>() {
730        return in_i8(
731            unsafe { std::mem::transmute(lhs) },
732            unsafe { std::mem::transmute(rhs) },
733            None,
734            false,
735        );
736    }
737    // u8
738    #[cfg(feature = "extended_numeric_types")]
739    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u8>() {
740        return in_u8(
741            unsafe { std::mem::transmute(lhs) },
742            unsafe { std::mem::transmute(rhs) },
743            None,
744            false,
745        );
746    }
747    return Err(KernelError::UnsupportedType(
748        "cmp_in: unsupported type for SIMD in".into(),
749    ));
750}
751
752/// Mask-aware variant
753#[inline(always)]
754pub fn cmp_in_mask<T: Eq + Hash + Copy + 'static>(
755    lhs: &[T],
756    rhs: &[T],
757    mask: Option<&Bitmask>,
758) -> Result<BooleanArray<()>, KernelError> {
759    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
760        return in_i32(
761            unsafe { std::mem::transmute(lhs) },
762            unsafe { std::mem::transmute(rhs) },
763            mask,
764            mask.is_some(),
765        );
766    }
767    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
768        return in_u32(
769            unsafe { std::mem::transmute(lhs) },
770            unsafe { std::mem::transmute(rhs) },
771            mask,
772            mask.is_some(),
773        );
774    }
775    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
776        return in_i64(
777            unsafe { std::mem::transmute(lhs) },
778            unsafe { std::mem::transmute(rhs) },
779            mask,
780            mask.is_some(),
781        );
782    }
783    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
784        return in_u64(
785            unsafe { std::mem::transmute(lhs) },
786            unsafe { std::mem::transmute(rhs) },
787            mask,
788            mask.is_some(),
789        );
790    }
791    return Err(KernelError::UnsupportedType(
792        "cmp_in_mask: unsupported type (expected integer type)".into(),
793    ));
794}
795
796/// SIMD-aware, type-specific dispatch for cmp_in_f_mask and cmp_in_f
797#[inline(always)]
798pub fn cmp_in_f_mask<T: Float + Copy>(
799    lhs: &[T],
800    rhs: &[T],
801    mask: Option<&Bitmask>,
802) -> Result<BooleanArray<()>, KernelError> {
803    if TypeId::of::<T>() == TypeId::of::<f32>() {
804        let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
805        let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
806        in_f32(lhs, rhs, mask, mask.is_some())
807    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
808        let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
809        let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
810        in_f64(lhs, rhs, mask, mask.is_some())
811    } else {
812        unreachable!("cmp_in_f_mask: Only f32/f64 supported for Float kernels")
813    }
814}
815
816#[inline(always)]
817/// Test membership in set for floating-point types with NaN handling.
818pub fn cmp_in_f<T: Float + Copy>(lhs: &[T], rhs: &[T]) -> Result<BooleanArray<()>, KernelError> {
819    if TypeId::of::<T>() == TypeId::of::<f32>() {
820        let lhs = unsafe { &*(lhs as *const [T] as *const [f32]) };
821        let rhs = unsafe { &*(rhs as *const [T] as *const [f32]) };
822        in_f32(lhs, rhs, None, false)
823    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
824        let lhs = unsafe { &*(lhs as *const [T] as *const [f64]) };
825        let rhs = unsafe { &*(rhs as *const [T] as *const [f64]) };
826        in_f64(lhs, rhs, None, false)
827    } else {
828        unreachable!("cmp_in_f: Only f32/f64 supported for Float kernels")
829    }
830}
831
832// String and dictionary
833
834/// Test if floating-point values fall between bounds with NaN handling.
835pub fn cmp_between_f<T: PartialOrd + Copy + Float + Numeric>(
836    lhs: &[T],
837    rhs: &[T],
838) -> Result<BooleanArray<()>, KernelError> {
839    between_generic(lhs, rhs, None, false)
840}
841
842/// Test if dictionary/categorical values fall between lexicographic bounds.
843pub fn cmp_dict_between<'a, T: Integer>(
844    lhs: CategoricalAVT<'a, T>,
845    rhs: CategoricalAVT<'a, T>,
846) -> Result<BooleanArray<()>, KernelError> {
847    let (larr, loff, llen) = lhs;
848    let (rarr, roff, _rlen) = rhs;
849
850    let min = rarr.get(roff).unwrap_or("");
851    let max = rarr.get(roff + 1).unwrap_or("");
852    let mask = larr.null_mask.as_ref();
853    let _ = confirm_mask_capacity(larr.data.len(), mask)?;
854    let has_nulls = mask.is_some();
855
856    let mut out = new_bool_buffer(llen);
857    for i in 0..llen {
858        let li = loff + i;
859        if !has_nulls || mask.map_or(true, |m| unsafe { m.get_unchecked(li) }) {
860            let s = unsafe { larr.get_str_unchecked(li) };
861            if s > min && s <= max {
862                unsafe { out.set_unchecked(i, true) };
863            }
864        }
865    }
866    Ok(BooleanArray {
867        data: out.into(),
868        null_mask: mask.cloned(),
869        len: llen,
870        _phantom: PhantomData,
871    })
872}
873
874/// Returns `true` for each row in `lhs` whose string value also appears
875/// anywhere in `rhs`, respecting null masks on both sides.
876/// Returns `true` for each row in `lhs` whose string value also appears
877/// anywhere in `rhs`, respecting null masks on both sides.
878pub fn cmp_dict_in<'a, T: Integer + Hash>(
879    lhs: CategoricalAVT<'a, T>,
880    rhs: CategoricalAVT<'a, T>,
881) -> Result<BooleanArray<()>, KernelError> {
882    let (larr, loff, llen) = lhs;
883    let (rarr, roff, rlen) = rhs;
884    let mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
885    let _ = confirm_mask_capacity(llen, mask.as_ref())?;
886
887    let mut out = Bitmask::new_set_all(llen, false);
888
889    if (larr.unique_values.len() == rarr.unique_values.len())
890        && (larr.unique_values.len() <= MAX_DICT_CHECK)
891    {
892        let mut same_dict = true;
893        for (a, b) in larr.unique_values.iter().zip(rarr.unique_values.iter()) {
894            if a != b {
895                same_dict = false;
896                break;
897            }
898        }
899
900        if same_dict {
901            let rhs_codes: HashSet<T> = rarr.data[roff..roff + rlen].iter().copied().collect();
902            for i in 0..llen {
903                if mask
904                    .as_ref()
905                    .map_or(true, |m| unsafe { m.get_unchecked(i) })
906                {
907                    let code = larr.data[loff + i];
908                    if rhs_codes.contains(&code) {
909                        unsafe { out.set_unchecked(i, true) };
910                    }
911                }
912            }
913            return Ok(BooleanArray {
914                data: out.into(),
915                null_mask: mask,
916                len: llen,
917                _phantom: PhantomData,
918            });
919        }
920    }
921
922    let rhs_strings: HashSet<&str> = (0..rlen)
923        .filter(|&i| {
924            rarr.null_mask
925                .as_ref()
926                .map_or(true, |m| unsafe { m.get_unchecked(roff + i) })
927        })
928        .map(|i| unsafe { rarr.get_str_unchecked(roff + i) })
929        .collect();
930
931    for i in 0..llen {
932        if mask
933            .as_ref()
934            .map_or(true, |m| unsafe { m.get_unchecked(i) })
935        {
936            let s = unsafe { larr.get_str_unchecked(loff + i) };
937            if rhs_strings.contains(s) {
938                unsafe { out.set_unchecked(i, true) };
939            }
940        }
941    }
942
943    Ok(BooleanArray {
944        data: out.into(),
945        null_mask: mask,
946        len: llen,
947        _phantom: PhantomData,
948    })
949}
950
951// Is Null and Not null predicates
952
953/// Generate boolean mask indicating null elements in any array type.
954pub fn is_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
955    let not_null = is_not_null_array(arr)?;
956    Ok(!not_null)
957}
958/// Generate boolean mask indicating non-null elements in any array type.
959pub fn is_not_null_array(arr: &Array) -> Result<BooleanArray<()>, KernelError> {
960    let len = arr.len();
961    let mut data = Bitmask::new_set_all(len, false);
962
963    if let Some(mask) = arr.null_mask() {
964        data = mask.clone();
965    } else {
966        data.fill(true);
967    }
968    Ok(BooleanArray {
969        data,
970        null_mask: None,
971        len,
972        _phantom: PhantomData,
973    })
974}
975
976// Array in, between , not in
977/// Test membership of array elements in values set, dispatching by array type.
978pub fn in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
979    match (input, values) {
980        (
981            Array::NumericArray(NumericArray::Int32(a)),
982            Array::NumericArray(NumericArray::Int32(b)),
983        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
984        (
985            Array::NumericArray(NumericArray::Int64(a)),
986            Array::NumericArray(NumericArray::Int64(b)),
987        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
988        (
989            Array::NumericArray(NumericArray::UInt32(a)),
990            Array::NumericArray(NumericArray::UInt32(b)),
991        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
992        (
993            Array::NumericArray(NumericArray::UInt64(a)),
994            Array::NumericArray(NumericArray::UInt64(b)),
995        ) => cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref()),
996        (
997            Array::NumericArray(NumericArray::Float32(a)),
998            Array::NumericArray(NumericArray::Float32(b)),
999        ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1000        (
1001            Array::NumericArray(NumericArray::Float64(a)),
1002            Array::NumericArray(NumericArray::Float64(b)),
1003        ) => cmp_in_f_mask(&a.data, &b.data, a.null_mask.as_ref()),
1004        (Array::TextArray(TextArray::String32(a)), Array::TextArray(TextArray::String32(b))) => {
1005            cmp_str_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len()))
1006        }
1007        (Array::BooleanArray(a), Array::BooleanArray(b)) => {
1008            cmp_in_mask(&a.data, &b.data, a.null_mask.as_ref())
1009        }
1010        (
1011            Array::TextArray(TextArray::Categorical32(a)),
1012            Array::TextArray(TextArray::Categorical32(b)),
1013        ) => cmp_dict_in((**a).tuple_ref(0, a.len()), (**b).tuple_ref(0, b.len())),
1014        _ => unimplemented!(),
1015    }
1016}
1017
1018#[inline(always)]
1019/// Test non-membership of array elements in values set, dispatching by array type.
1020pub fn not_in_array(input: &Array, values: &Array) -> Result<BooleanArray<()>, KernelError> {
1021    let result = in_array(input, values)?;
1022    Ok(!result)
1023}
1024
1025/// Test if array elements fall between min/max bounds, dispatching by array type.
1026pub fn between_array(input: &Array, min: &Array, max: &Array) -> Result<Array, KernelError> {
1027    macro_rules! between_case {
1028        ($variant:ident, $cmp:ident) => {{
1029            let arr = match input {
1030                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1031                _ => unreachable!(),
1032            };
1033            let mins = match min {
1034                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1035                _ => unreachable!(),
1036            };
1037            let maxs = match max {
1038                Array::NumericArray(NumericArray::$variant(arr)) => arr,
1039                _ => unreachable!(),
1040            };
1041            let rhs: Vec64<_> = mins
1042                .data
1043                .iter()
1044                .zip(&maxs.data)
1045                .flat_map(|(&lo, &hi)| [lo, hi])
1046                .collect();
1047            Ok(Array::BooleanArray(
1048                $cmp(
1049                    &arr.data,
1050                    &rhs,
1051                    arr.null_mask.as_ref(),
1052                    arr.null_mask.is_some(),
1053                )?
1054                .into(),
1055            ))
1056        }};
1057    }
1058
1059    match (input, min, max) {
1060        (
1061            Array::NumericArray(NumericArray::Int32(..)),
1062            Array::NumericArray(NumericArray::Int32(..)),
1063            Array::NumericArray(NumericArray::Int32(..)),
1064        ) => between_case!(Int32, between_i32),
1065        (
1066            Array::NumericArray(NumericArray::Int64(..)),
1067            Array::NumericArray(NumericArray::Int64(..)),
1068            Array::NumericArray(NumericArray::Int64(..)),
1069        ) => between_case!(Int64, between_i64),
1070        (
1071            Array::NumericArray(NumericArray::UInt32(..)),
1072            Array::NumericArray(NumericArray::UInt32(..)),
1073            Array::NumericArray(NumericArray::UInt32(..)),
1074        ) => between_case!(UInt32, between_u32),
1075        (
1076            Array::NumericArray(NumericArray::UInt64(..)),
1077            Array::NumericArray(NumericArray::UInt64(..)),
1078            Array::NumericArray(NumericArray::UInt64(..)),
1079        ) => between_case!(UInt64, between_u64),
1080        (
1081            Array::NumericArray(NumericArray::Float32(..)),
1082            Array::NumericArray(NumericArray::Float32(..)),
1083            Array::NumericArray(NumericArray::Float32(..)),
1084        ) => between_case!(Float32, between_generic),
1085        (
1086            Array::NumericArray(NumericArray::Float64(..)),
1087            Array::NumericArray(NumericArray::Float64(..)),
1088            Array::NumericArray(NumericArray::Float64(..)),
1089        ) => between_case!(Float64, between_generic),
1090        _ => Err(KernelError::UnsupportedType(
1091            "Unsupported Type Error.".to_string(),
1092        )),
1093    }
1094}
1095
1096/// Bitwise NOT of a bit-packed boolean mask window.
1097/// Offset is a bit offset; len is in bits.
1098/// Requires offset % 64 == 0 for word-level SIMD processing.
1099#[inline]
1100pub fn not_bool<const LANES: usize>(
1101    src: BooleanAVT<'_, ()>,
1102) -> Result<BooleanArray<()>, KernelError>
1103where
1104{
1105    let (arr, offset, len) = src;
1106
1107    if offset % 64 != 0 {
1108        return Err(KernelError::InvalidArguments(format!(
1109            "not_bool: offset must be 64-bit aligned (got offset={})",
1110            offset
1111        )));
1112    }
1113
1114    let null_mask = arr.null_mask.as_ref().map(|nm| nm.slice_clone(offset, len));
1115
1116    let data = if null_mask.is_none() {
1117        #[cfg(feature = "simd")]
1118        {
1119            minarrow::kernels::bitmask::simd::not_mask_simd::<LANES>((&arr.data, offset, len))
1120        }
1121        #[cfg(not(feature = "simd"))]
1122        {
1123            minarrow::kernels::bitmask::std::not_mask((&arr.data, offset, len))
1124        }
1125    } else {
1126        // clone window – no modifications
1127        arr.data.slice_clone(offset, len)
1128    };
1129
1130    Ok(BooleanArray {
1131        data,
1132        null_mask,
1133        len,
1134        _phantom: core::marker::PhantomData,
1135    })
1136}
1137
1138/// Logical AND/OR/XOR of two bit-packed boolean masks over a window.
1139/// Offsets are bit offsets. Length is in bits.
1140/// Panics if offsets are not 64-bit aligned.
1141pub fn apply_logical_bool<const LANES: usize>(
1142    lhs: BooleanAVT<'_, ()>,
1143    rhs: BooleanAVT<'_, ()>,
1144    op: LogicalOperator,
1145) -> Result<BooleanArray<()>, KernelError>
1146where
1147{
1148    let (lhs_arr, lhs_off, len) = lhs;
1149    let (rhs_arr, rhs_off, rlen) = rhs;
1150
1151    if len != rlen {
1152        return Err(KernelError::LengthMismatch(format!(
1153            "logical_bool: window length mismatch (lhs: {}, rhs: {})",
1154            len, rlen
1155        )));
1156    }
1157    if lhs_off % 64 != 0 || rhs_off % 64 != 0 {
1158        return Err(KernelError::InvalidArguments(format!(
1159            "logical_bool: offsets must be 64-bit aligned (lhs: {}, rhs: {})",
1160            lhs_off, rhs_off
1161        )));
1162    }
1163
1164    // Apply bitmask kernel for the logical operation.
1165
1166    #[cfg(feature = "simd")]
1167    let data = match op {
1168        LogicalOperator::And => {
1169            and_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1170        }
1171        LogicalOperator::Or => {
1172            or_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1173        }
1174        LogicalOperator::Xor => {
1175            xor_masks_simd::<LANES>((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1176        }
1177    };
1178
1179    // Merge validity (null) masks using AND
1180    #[cfg(feature = "simd")]
1181    let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1182        (None, None) => None,
1183        (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1184        (Some(a), Some(b)) => Some(and_masks_simd::<LANES>(
1185            (a, lhs_off, len),
1186            (b, rhs_off, len),
1187        )),
1188    };
1189
1190    #[cfg(not(feature = "simd"))]
1191    let data = match op {
1192        LogicalOperator::And => {
1193            and_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1194        }
1195        LogicalOperator::Or => {
1196            or_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1197        }
1198        LogicalOperator::Xor => {
1199            xor_masks((&lhs_arr.data, lhs_off, len), (&rhs_arr.data, rhs_off, len))
1200        }
1201    };
1202
1203    #[cfg(not(feature = "simd"))]
1204    let null_mask = match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
1205        (None, None) => None,
1206        (Some(a), None) | (None, Some(a)) => Some(a.slice_clone(lhs_off, len)),
1207        (Some(a), Some(b)) => Some(and_masks((a, lhs_off, len), (b, rhs_off, len))),
1208    };
1209
1210    Ok(BooleanArray {
1211        data,
1212        null_mask,
1213        len,
1214        _phantom: PhantomData,
1215    })
1216}
1217
1218#[cfg(test)]
1219mod tests {
1220    use minarrow::structs::variants::categorical::CategoricalArray;
1221    use minarrow::structs::variants::float::FloatArray;
1222    use minarrow::structs::variants::integer::IntegerArray;
1223    use minarrow::structs::variants::string::StringArray;
1224    use minarrow::{Array, Bitmask, BooleanArray, vec64};
1225
1226    use super::*;
1227
1228    // --- helpers ---
1229
1230    fn bm(bits: &[bool]) -> Bitmask {
1231        let mut m = Bitmask::new_set_all(bits.len(), false);
1232        for (i, &b) in bits.iter().enumerate() {
1233            m.set(i, b);
1234        }
1235        m
1236    }
1237
1238    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
1239        assert_eq!(arr.len, expect.len(), "length mismatch");
1240        for i in 0..expect.len() {
1241            assert_eq!(arr.data.get(i), expect[i], "val @ {i}");
1242        }
1243        match (expect_mask, &arr.null_mask) {
1244            (None, None) => {}
1245            (Some(exp), Some(mask)) => {
1246                for (i, &b) in exp.iter().enumerate() {
1247                    assert_eq!(mask.get(i), b, "mask @ {i}");
1248                }
1249            }
1250            (None, Some(mask)) => {
1251                // all mask bits should be true
1252                for i in 0..arr.len {
1253                    assert!(mask.get(i), "unexpected false mask @ {i}");
1254                }
1255            }
1256            (Some(_), None) => panic!("expected null mask"),
1257        }
1258    }
1259
1260    fn i32_arr(data: &[i32]) -> IntegerArray<i32> {
1261        IntegerArray::from_slice(data)
1262    }
1263    fn f32_arr(data: &[f32]) -> FloatArray<f32> {
1264        FloatArray::from_slice(data)
1265    }
1266    fn str_arr<T: Integer>(vals: &[&str]) -> StringArray<T> {
1267        StringArray::<T>::from_slice(vals)
1268    }
1269    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
1270        let owned: Vec<&str> = vals.to_vec();
1271        CategoricalArray::<T>::from_values(owned)
1272    }
1273    //  BETWEEN
1274
1275    #[test]
1276    fn between_i32_scalar_rhs() {
1277        let lhs = vec64![1, 3, 5, 7];
1278        let rhs = vec64![2, 6];
1279        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1280        assert_bool(&out, &[false, true, true, false], None);
1281    }
1282
1283    #[test]
1284    fn between_i32_per_row_rhs() {
1285        let lhs = vec64![5, 9, 2, 8];
1286        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9]; // min/max for each row
1287        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1288        assert_bool(&out, &[true, false, true, true], None);
1289    }
1290
1291    #[test]
1292    fn between_i32_nulls_propagate() {
1293        let lhs = vec64![5, 9, 2, 8];
1294        let rhs = vec64![0, 10, 0, 4, 2, 2, 8, 9];
1295        let mask = bm(&[true, false, true, true]);
1296        let out = between_i32(&lhs, &rhs, Some(&mask), true).unwrap();
1297        assert_bool(
1298            &out,
1299            &[true, false, true, true],
1300            Some(&[true, false, true, true]),
1301        );
1302    }
1303
1304    #[cfg(feature = "extended_numeric_types")]
1305    #[test]
1306    fn between_i16_works() {
1307        let lhs = vec64![10i16, 12, 99];
1308        let rhs = vec64![10i16, 12];
1309        let out = in_i16(&lhs, &rhs, None, false).unwrap();
1310        assert_bool(&out, &[true, true, false], None);
1311    }
1312
1313    #[test]
1314    fn between_f64_scalar_and_nulls() {
1315        let lhs = vec64![1.0, 5.0, 8.0, 20.0];
1316        let rhs = vec64![4.0, 10.0];
1317        let mask = bm(&[true, false, true, true]);
1318        let out = between_f64(&lhs, &rhs, Some(&mask), true).unwrap();
1319        assert_bool(
1320            &out,
1321            &[false, false, true, false],
1322            Some(&[true, false, true, true]),
1323        );
1324    }
1325
1326    #[test]
1327    fn between_f32_generic_dispatch() {
1328        let lhs = vec64![0.1f32, 0.5, 1.2, -1.0];
1329        let rhs = vec64![0.0, 1.0];
1330        let out = cmp_between(&lhs, &rhs).unwrap();
1331        assert_bool(&out, &[true, true, false, false], None);
1332    }
1333
1334    #[test]
1335    fn between_masked_dispatch() {
1336        let lhs = vec64![1i32, 2, 3];
1337        let rhs = vec64![0, 2];
1338        let mask = bm(&[true, false, true]);
1339        let out = cmp_between_mask(&lhs, &rhs, Some(&mask)).unwrap();
1340        assert_bool(&out, &[true, false, false], Some(&[true, false, true]));
1341    }
1342
1343    // IN
1344
1345    #[test]
1346    fn in_i32_small_rhs() {
1347        let lhs = vec64![1, 2, 3, 4, 5];
1348        let rhs = vec64![2, 4];
1349        let out = in_i32(&lhs, &rhs, None, false).unwrap();
1350        assert_bool(&out, &[false, true, false, true, false], None);
1351    }
1352
1353    #[test]
1354    fn in_i32_with_nulls() {
1355        let lhs = vec64![7, 8, 9];
1356        let rhs = vec64![8];
1357        let mask = bm(&[true, false, true]);
1358        let out = in_i32(&lhs, &rhs, Some(&mask), true).unwrap();
1359        assert_bool(&out, &[false, false, false], Some(&[true, false, true]));
1360    }
1361
1362    #[test]
1363    fn in_i64_large_rhs() {
1364        let lhs = vec64![1i64, 2, 3, 7, 8, 15];
1365        let rhs: Vec<i64> = (2..10).collect();
1366        let out = in_i64(&lhs, &rhs, None, false).unwrap();
1367        assert_bool(&out, &[false, true, true, true, true, false], None);
1368    }
1369
1370    #[cfg(feature = "extended_numeric_types")]
1371    #[test]
1372    fn in_u8_small_rhs() {
1373        let lhs = vec64![1u8, 2, 3, 4];
1374        let rhs = vec64![2u8, 3];
1375        let out = in_u8(&lhs, &rhs, None, false).unwrap();
1376        assert_bool(&out, &[false, true, true, false], None);
1377    }
1378
1379    #[test]
1380    fn in_float_nan_and_normal() {
1381        let lhs = vec64![1.0f32, f32::NAN, 7.0];
1382        let rhs = vec64![f32::NAN, 7.0];
1383        let out = in_f32(&lhs, &rhs, None, false).unwrap();
1384        assert_bool(&out, &[false, true, true], None);
1385    }
1386
1387    // BETWEEN / IN
1388
1389    #[test]
1390    fn string_between() {
1391        let lhs = str_arr::<u32>(&["aa", "bb", "zz"]);
1392        let rhs = str_arr::<u32>(&["b", "y"]);
1393        let lhs_slice = (&lhs, 0, lhs.len());
1394        let rhs_slice = (&rhs, 0, rhs.len());
1395        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1396        assert_bool(&out, &[false, true, false], None);
1397    }
1398
1399    #[test]
1400    fn string_between_chunk() {
1401        let lhs = str_arr::<u32>(&["0", "aa", "bb", "zz", "9"]);
1402        let rhs = str_arr::<u32>(&["a", "b", "y", "z"]);
1403        // Windowed: skip first/last for lhs; use a window for rhs
1404        let lhs_slice = (&lhs, 1, 3); // ["aa", "bb", "zz"]
1405        let rhs_slice = (&rhs, 1, 2); // ["b", "y"]
1406        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1407        assert_bool(&out, &[false, true, false], None);
1408    }
1409
1410    #[test]
1411    fn string_in_basic() {
1412        let lhs = str_arr::<u32>(&["x", "y", "z"]);
1413        let rhs = str_arr::<u32>(&["y", "a"]);
1414        let lhs_slice = (&lhs, 0, lhs.len());
1415        let rhs_slice = (&rhs, 0, rhs.len());
1416        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1417        assert_bool(&out, &[false, true, false], None);
1418    }
1419
1420    #[test]
1421    fn string_in_basic_chunk() {
1422        let lhs = str_arr::<u32>(&["0", "x", "y", "z", "9"]);
1423        let rhs = str_arr::<u32>(&["b", "y", "a", "c"]);
1424        let lhs_slice = (&lhs, 1, 3); // ["x", "y", "z"]
1425        let rhs_slice = (&rhs, 1, 2); // ["y", "a"]
1426        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1427        assert_bool(&out, &[false, true, false], None);
1428    }
1429
1430    #[test]
1431    fn dict_between() {
1432        let lhs = dict_arr::<u32>(&["cat", "dog", "emu"]);
1433        let rhs = dict_arr::<u32>(&["cobra", "dove"]);
1434        let lhs_slice = (&lhs, 0, lhs.len());
1435        let rhs_slice = (&rhs, 0, rhs.len());
1436        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
1437        assert_bool(&out, &[false, true, false], None);
1438    }
1439
1440    #[test]
1441    fn dict_between_chunk() {
1442        let lhs = dict_arr::<u32>(&["a", "cat", "dog", "emu", "z"]);
1443        let rhs = dict_arr::<u32>(&["a", "cobra", "dove", "zz"]);
1444        let lhs_slice = (&lhs, 1, 3); // ["cat", "dog", "emu"]
1445        let rhs_slice = (&rhs, 1, 2); // ["cobra", "dove"]
1446        let out = cmp_dict_between(lhs_slice, rhs_slice).unwrap();
1447        assert_bool(&out, &[false, true, false], None);
1448    }
1449
1450    #[test]
1451    fn dict_in_membership() {
1452        let lhs = dict_arr::<u32>(&["aa", "bb", "cc"]);
1453        let rhs = dict_arr::<u32>(&["bb", "dd"]);
1454        let lhs_slice = (&lhs, 0, lhs.len());
1455        let rhs_slice = (&rhs, 0, rhs.len());
1456        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1457        assert_bool(&out, &[false, true, false], None);
1458    }
1459
1460    #[test]
1461    fn dict_in_membership_chunk() {
1462        let lhs = dict_arr::<u32>(&["0", "aa", "bb", "cc", "9"]);
1463        let rhs = dict_arr::<u32>(&["a", "bb", "dd", "zz"]);
1464        let lhs_slice = (&lhs, 1, 3); // ["aa", "bb", "cc"]
1465        let rhs_slice = (&rhs, 1, 2); // ["bb", "dd"]
1466        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1467        assert_bool(&out, &[false, true, false], None);
1468    }
1469
1470    #[test]
1471    fn string_between_nulls() {
1472        let mut lhs = str_arr::<u32>(&["foo", "bar", "baz"]);
1473        lhs.null_mask = Some(bm(&[true, false, true]));
1474        let rhs = str_arr::<u32>(&["a", "zzz"]);
1475        let lhs_slice = (&lhs, 0, lhs.len());
1476        let rhs_slice = (&rhs, 0, rhs.len());
1477        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1478        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
1479    }
1480
1481    #[test]
1482    fn string_between_nulls_chunk() {
1483        let mut lhs = str_arr::<u32>(&["0", "foo", "bar", "baz", "z"]);
1484        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
1485        let rhs = str_arr::<u32>(&["0", "a", "zzz", "9"]);
1486        let lhs_slice = (&lhs, 1, 3); // ["foo", "bar", "baz"]
1487        let rhs_slice = (&rhs, 1, 2); // ["a", "zzz"]
1488        let out = cmp_str_between(lhs_slice, rhs_slice).unwrap();
1489        assert_bool(&out, &[true, false, true], Some(&[true, false, true]));
1490    }
1491
1492    #[test]
1493    fn dict_in_nulls() {
1494        let mut lhs = dict_arr::<u32>(&["one", "two", "three"]);
1495        lhs.null_mask = Some(bm(&[false, true, true]));
1496        let rhs = dict_arr::<u32>(&["two", "four"]);
1497        let lhs_slice = (&lhs, 0, lhs.len());
1498        let rhs_slice = (&rhs, 0, rhs.len());
1499        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1500        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
1501    }
1502
1503    #[test]
1504    fn dict_in_nulls_chunk() {
1505        let mut lhs = dict_arr::<u32>(&["x", "one", "two", "three", "z"]);
1506        lhs.null_mask = Some(bm(&[true, false, true, true, true]));
1507        let rhs = dict_arr::<u32>(&["a", "two", "four", "b"]);
1508        let lhs_slice = (&lhs, 1, 3); // ["one", "two", "three"]
1509        let rhs_slice = (&rhs, 1, 2); // ["two", "four"]
1510        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1511        assert_bool(&out, &[false, true, false], Some(&[false, true, true]));
1512    }
1513
1514    // Boolean/Null
1515
1516    #[test]
1517    fn is_null_and_is_not_null() {
1518        let mut arr = i32_arr(&[1, 2, 0]);
1519        arr.null_mask = Some(bm(&[true, false, true]));
1520        let array = Array::from_int32(arr.clone());
1521
1522        let not_null = is_not_null_array(&array).unwrap();
1523        let is_null = is_null_array(&array).unwrap();
1524
1525        assert_bool(&not_null, &[true, false, true], None);
1526        assert_bool(&is_null, &[false, true, false], None);
1527    }
1528
1529    #[test]
1530    fn is_null_not_null_dense() {
1531        let arr = i32_arr(&[1, 2, 3]);
1532        let array = Array::from_int32(arr.clone());
1533        let is_null = is_null_array(&array).unwrap();
1534        assert_bool(&is_null, &[false, false, false], None);
1535        let not_null = is_not_null_array(&array).unwrap();
1536        assert_bool(&not_null, &[true, true, true], None);
1537    }
1538
1539    //  Array dispatch in_array, not_in_array, between_array ----
1540
1541    #[test]
1542    fn in_array_int32_dispatch() {
1543        let inp = Array::from_int32(i32_arr(&[10, 20, 30]));
1544        let vals = Array::from_int32(i32_arr(&[20, 40]));
1545        let out = in_array(&inp, &vals).unwrap();
1546        assert_bool(&out, &[false, true, false], None);
1547
1548        let out_not = not_in_array(&inp, &vals).unwrap();
1549        assert_bool(&out_not, &[true, false, true], None);
1550    }
1551
1552    #[test]
1553    fn in_array_f32_dispatch() {
1554        let inp = Array::from_float32(f32_arr(&[1.0, f32::NAN, 7.0]));
1555        let vals = Array::from_float32(f32_arr(&[f32::NAN, 7.0]));
1556        let out = in_array(&inp, &vals).unwrap();
1557        assert_bool(&out, &[false, true, true], None);
1558    }
1559
1560    #[test]
1561    fn in_array_string_dispatch() {
1562        let inp = Array::from_string32(str_arr::<u32>(&["a", "b", "c"]));
1563        let vals = Array::from_string32(str_arr::<u32>(&["b", "d"]));
1564        let out = in_array(&inp, &vals).unwrap();
1565        assert_bool(&out, &[false, true, false], None);
1566    }
1567
1568    #[test]
1569    fn in_array_dictionary_dispatch() {
1570        let inp = Array::from_categorical32(dict_arr::<u32>(&["aa", "bb", "cc"]));
1571        let vals = Array::from_categorical32(dict_arr::<u32>(&["bb", "cc"]));
1572        let out = in_array(&inp, &vals).unwrap();
1573        assert_bool(&out, &[false, true, true], None);
1574    }
1575
1576    #[test]
1577    fn between_array_int32_rows() {
1578        let inp = Array::from_int32(i32_arr(&[5, 15, 25]));
1579        let min = Array::from_int32(i32_arr(&[0, 10, 20]));
1580        let max = Array::from_int32(i32_arr(&[10, 20, 30]));
1581
1582        let out = between_array(&inp, &min, &max).unwrap();
1583        match out {
1584            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
1585            _ => panic!("expected Bool array"),
1586        }
1587    }
1588
1589    #[test]
1590    fn between_array_float_generic() {
1591        let inp = Array::from_float32(f32_arr(&[0.5, 1.5, 2.5]));
1592        let min = Array::from_float32(f32_arr(&[0.0, 1.0, 2.0]));
1593        let max = Array::from_float32(f32_arr(&[1.0, 2.0, 3.0]));
1594
1595        let out = between_array(&inp, &min, &max).unwrap();
1596        match out {
1597            Array::BooleanArray(b) => assert_bool(&b, &[true, true, true], None),
1598            _ => panic!("expected Bool"),
1599        }
1600    }
1601
1602    #[test]
1603    fn between_array_type_mismatch() {
1604        let inp = Array::from_int32(i32_arr(&[1, 2, 3]));
1605        let min = Array::from_float32(f32_arr(&[0.0, 0.0, 0.0]));
1606        let max = Array::from_float32(f32_arr(&[5.0, 5.0, 5.0]));
1607        let err = between_array(&inp, &min, &max).unwrap_err();
1608        match err {
1609            KernelError::UnsupportedType(_) => {}
1610            _ => panic!("Expected UnsupportedType error"),
1611        }
1612    }
1613
1614    // all integer types, short and long
1615
1616    #[test]
1617    fn in_integers_various_types() {
1618        #[cfg(feature = "extended_numeric_types")]
1619        {
1620            let u8_lhs = vec64![1u8, 2, 3, 5];
1621            let u8_rhs = vec64![3u8, 5, 8];
1622            let out = in_u8(&u8_lhs, &u8_rhs, None, false).unwrap();
1623            assert_bool(&out, &[false, false, true, true], None);
1624
1625            let u16_lhs = vec64![100u16, 200, 300];
1626            let u16_rhs = vec64![200u16, 500];
1627            let out = in_u16(&u16_lhs, &u16_rhs, None, false).unwrap();
1628            assert_bool(&out, &[false, true, false], None);
1629
1630            let i16_lhs = vec64![10i16, 15, 42];
1631            let i16_rhs = vec64![15i16, 42, 77];
1632            let out = in_i16(&i16_lhs, &i16_rhs, None, false).unwrap();
1633            assert_bool(&out, &[false, true, true], None);
1634        }
1635
1636        let u32_lhs = vec64![0u32, 1, 2, 9];
1637        let u32_rhs = vec64![9u32, 1];
1638        let out = in_u32(&u32_lhs, &u32_rhs, None, false).unwrap();
1639        assert_bool(&out, &[false, true, false, true], None);
1640
1641        let i64_lhs = vec64![1i64, 9, 10];
1642        let i64_rhs = vec64![2i64, 10, 20];
1643        let out = in_i64(&i64_lhs, &i64_rhs, None, false).unwrap();
1644        assert_bool(&out, &[false, false, true], None);
1645
1646        let u64_lhs = vec64![1u64, 2, 3, 4];
1647        let u64_rhs = vec64![2u64, 4, 8];
1648        let out = in_u64(&u64_lhs, &u64_rhs, None, false).unwrap();
1649        assert_bool(&out, &[false, true, false, true], None);
1650    }
1651
    // Edge case: empty inputs
1653
1654    #[test]
1655    fn between_and_in_empty_inputs() {
1656        // Between, scalar rhs (for numeric arrays, no slice tuple needed)
1657        let lhs: [i32; 0] = [];
1658        let rhs = vec64![0, 1];
1659        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1660        assert_eq!(out.len, 0);
1661
1662        // In, any rhs (for numeric arrays, no slice tuple needed)
1663        let lhs: [i32; 0] = [];
1664        let rhs = vec64![1, 2, 3];
1665        let out = in_i32(&lhs, &rhs, None, false).unwrap();
1666        assert_eq!(out.len, 0);
1667
1668        // String, in (slice API)
1669        let lhs = str_arr::<u32>(&[]);
1670        let rhs = str_arr::<u32>(&["a", "b"]);
1671        let lhs_slice = (&lhs, 0, lhs.len());
1672        let rhs_slice = (&rhs, 0, rhs.len());
1673        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1674        assert_eq!(out.len, 0);
1675    }
1676
1677    #[test]
1678    fn between_and_in_empty_inputs_chunk() {
1679        // Only applies to the string in version
1680        let lhs = str_arr::<u32>(&["x", "y"]);
1681        let rhs = str_arr::<u32>(&["a", "b", "c"]);
1682        let lhs_slice = (&lhs, 1, 0); // zero-length window
1683        let rhs_slice = (&rhs, 1, 2); // window ["b", "c"]
1684        let out = cmp_str_in(lhs_slice, rhs_slice).unwrap();
1685        assert_eq!(out.len, 0);
1686    }
1687
1688    #[test]
1689    fn between_per_row_bounds_on_last_row() {
1690        // Coverage: last row per-row
1691        let lhs = vec64![0i32, 10, 20, 30];
1692        let rhs = vec64![0, 5, 5, 15, 15, 25, 25, 35];
1693        let out = between_i32(&lhs, &rhs, None, false).unwrap();
1694        assert_bool(&out, &[true, true, true, true], None);
1695    }
1696
1697    #[test]
1698    fn test_cmp_dict_in_force_fallback() {
1699        // lhs and rhs have different unique_values lengths
1700        let mut lhs = dict_arr::<u32>(&["a", "b", "c", "a"]);
1701        lhs.unique_values = vec64!["a".to_string(), "b".to_string(), "c".to_string()]; // len=3
1702        let mut rhs = dict_arr::<u32>(&["b", "x", "y", "z"]);
1703        rhs.unique_values = vec64![
1704            "b".to_string(),
1705            "x".to_string(),
1706            "y".to_string(),
1707            "z".to_string()
1708        ]; // len=4
1709        lhs.null_mask = Some(bm(&[true, true, true, true]));
1710        let lhs_slice = (&lhs, 0, lhs.len());
1711        let rhs_slice = (&rhs, 0, rhs.len());
1712        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1713        // should fall back to string-matching: only "b" matches
1714        assert_bool(
1715            &out,
1716            &[false, true, false, false],
1717            Some(&[true, true, true, true]),
1718        );
1719    }
1720
1721    #[test]
1722    fn test_cmp_dict_in_force_fallback_chunk() {
1723        let mut lhs = dict_arr::<u32>(&["z", "a", "b", "c", "a", "q"]);
1724        lhs.unique_values = vec64![
1725            "z".to_string(),
1726            "a".to_string(),
1727            "b".to_string(),
1728            "c".to_string(),
1729            "q".to_string()
1730        ];
1731        let mut rhs = dict_arr::<u32>(&["x", "b", "x", "y", "z"]);
1732        rhs.unique_values = vec64![
1733            "x".to_string(),
1734            "b".to_string(),
1735            "y".to_string(),
1736            "z".to_string()
1737        ];
1738        lhs.null_mask = Some(bm(&[true, true, true, true, true, true]));
1739        // Window: pick ["a", "b", "c", "a"] and ["b", "x", "y", "z"]
1740        let lhs_slice = (&lhs, 1, 4);
1741        let rhs_slice = (&rhs, 1, 4);
1742        let out = cmp_dict_in(lhs_slice, rhs_slice).unwrap();
1743        // Only "b" matches (index 1 of window)
1744        assert_bool(
1745            &out,
1746            &[false, true, false, false],
1747            Some(&[true, true, true, true]),
1748        );
1749    }
1750
1751    #[test]
1752    fn test_in_array_empty_rhs() {
1753        let arr = Array::from_int32(i32_arr(&[1, 2, 3]));
1754        let empty = Array::from_int32(i32_arr(&[]));
1755        let out = in_array(&arr, &empty).unwrap();
1756        // must be all false, and mask preserved (no mask => all bits true)
1757        assert_bool(&out, &[false, false, false], None);
1758    }
1759}