1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
20
21use std::marker::PhantomData;
23#[cfg(feature = "simd")]
24use std::simd::{Mask, Simd};
25
26use minarrow::{Bitmask, BooleanArray, Integer, Numeric};
27
28#[cfg(not(feature = "simd"))]
29use crate::kernels::bitmask::std::{and_masks, in_mask, not_in_mask, not_mask};
30use crate::operators::ComparisonOperator;
31use minarrow::enums::error::KernelError;
32#[cfg(feature = "simd")]
33use minarrow::kernels::bitmask::simd::{
34 and_masks_simd, in_mask_simd, not_in_mask_simd, not_mask_simd,
35};
36use minarrow::utils::confirm_equal_len;
37#[cfg(feature = "simd")]
38use minarrow::utils::is_simd_aligned;
39use minarrow::{BitmaskVT, BooleanAVT, CategoricalAVT, StringAVT};
40
41#[inline(always)]
43fn new_bool_bitmask(len: usize) -> Bitmask {
44 Bitmask::new_set_all(len, false)
45}
46
47fn merge_bitmasks_to_new(a: Option<&Bitmask>, b: Option<&Bitmask>, len: usize) -> Option<Bitmask> {
49 match (a, b) {
50 (None, None) => None,
51 (Some(x), None) | (None, Some(x)) => Some(x.slice_clone(0, len)),
52 (Some(x), Some(y)) => {
53 let mut out = Bitmask::new_set_all(len, true);
54 for i in 0..len {
55 unsafe { out.set_unchecked(i, x.get_unchecked(i) && y.get_unchecked(i)) };
56 }
57 Some(out)
58 }
59 }
60}
61
/// Generates a pair of monomorphic numeric comparison kernels:
/// - `$fn_name_to` writes into a caller-provided bitmask;
/// - `$fn_name` allocates the result and wraps it in a `BooleanArray`.
///
/// `$lanes` is the SIMD lane count and `$mask_elem` is the element type of the
/// `std::simd::Mask` produced by comparing `Simd<$ty, $lanes>` vectors.
macro_rules! impl_cmp_numeric {
    ($fn_name:ident, $fn_name_to:ident, $ty:ty, $lanes:expr, $mask_elem:ty) => {
        /// Element-wise comparison of two numeric slices, writing results into `output`.
        ///
        /// Bit `i` of `output` is set to `true` wherever `lhs[i] op rhs[i]` holds and the
        /// row is valid in `mask` (when a mask is given). Only `true` bits are written,
        /// so `output` is expected to start cleared, as `new_bool_bitmask` produces.
        /// Operators with no element-wise meaning here (`In`, `NotIn`, null tests,
        /// `Between`) produce no `true` bits.
        ///
        /// With the `simd` feature, unmasked inputs that pass `is_simd_aligned` are
        /// processed a full vector at a time. Everything else uses a scalar loop.
        ///
        /// # Errors
        /// Returns an error if `lhs` and `rhs` differ in length, or if `mask` has fewer
        /// than `lhs.len()` bits of capacity.
        ///
        /// # Panics
        /// Panics if `output.capacity()` is smaller than `lhs.len()`.
        #[inline(always)]
        pub fn $fn_name_to(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let len = lhs.len();
            confirm_equal_len("compare numeric length mismatch", len, rhs.len())?;
            assert!(
                output.capacity() >= len,
                concat!(stringify!($fn_name_to), ": output capacity too small")
            );
            // The scalar path reads the mask unchecked, so validate its coverage up front.
            if let Some(m) = mask {
                if m.capacity() < len {
                    return Err(KernelError::InvalidArguments(format!(
                        concat!(
                            stringify!($fn_name_to),
                            ": mask capacity too small (expected ≥ {}, got {})"
                        ),
                        len,
                        m.capacity()
                    )));
                }
            }

            // Scalar comparison shared by the SIMD tail and the fallback loop. This keeps
            // both paths identical, including NaN behaviour for floats.
            let holds = |l: $ty, r: $ty| -> bool {
                match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    _ => false,
                }
            };

            #[cfg(feature = "simd")]
            {
                // The vector path only handles the no-null case. Masked input falls
                // through to the scalar loop below.
                if mask.is_none() && is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
                    const N: usize = $lanes;
                    let mut i = 0;
                    while i + N <= len {
                        let a = Simd::<$ty, N>::from_slice(&lhs[i..i + N]);
                        let b = Simd::<$ty, N>::from_slice(&rhs[i..i + N]);
                        let m: Mask<$mask_elem, N> = match op {
                            ComparisonOperator::Equals => a.simd_eq(b),
                            ComparisonOperator::NotEquals => a.simd_ne(b),
                            ComparisonOperator::LessThan => a.simd_lt(b),
                            ComparisonOperator::LessThanOrEqualTo => a.simd_le(b),
                            ComparisonOperator::GreaterThan => a.simd_gt(b),
                            ComparisonOperator::GreaterThanOrEqualTo => a.simd_ge(b),
                            _ => Mask::splat(false),
                        };
                        let bits = m.to_bitmask();
                        for l in 0..N {
                            if ((bits >> l) & 1) == 1 {
                                // SAFETY: `i + l < len <= output.capacity()` (asserted above).
                                unsafe { output.set_unchecked(i + l, true) };
                            }
                        }
                        i += N;
                    }
                    // Leftover elements that do not fill a whole vector.
                    for j in i..len {
                        if holds(lhs[j], rhs[j]) {
                            // SAFETY: `j < len <= output.capacity()`.
                            unsafe { output.set_unchecked(j, true) };
                        }
                    }
                    return Ok(());
                }
            }

            for i in 0..len {
                // SAFETY: `i < len <= mask.capacity()`, validated above.
                let valid = mask.map_or(true, |m| unsafe { m.get_unchecked(i) });
                if valid && holds(lhs[i], rhs[i]) {
                    // SAFETY: `i < len <= output.capacity()`.
                    unsafe { output.set_unchecked(i, true) };
                }
            }
            Ok(())
        }

        /// Allocating variant of the `_to` kernel.
        ///
        /// Compares `lhs` and `rhs` element-wise with `op` and returns a `BooleanArray`
        /// of `lhs.len()` values. The array's `null_mask` is a clone of `mask`.
        ///
        /// # Errors
        /// Same as the `_to` kernel.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            mask: Option<&Bitmask>,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let len = lhs.len();
            let mut out = new_bool_bitmask(len);
            $fn_name_to(lhs, rhs, mask, op, &mut out)?;
            Ok(BooleanArray {
                data: out.into(),
                null_mask: mask.cloned(),
                len,
                _phantom: PhantomData,
            })
        }
    };
}
195
196#[inline(always)]
201pub fn cmp_numeric_to<T: Numeric + Copy + 'static>(
202 lhs: &[T],
203 rhs: &[T],
204 mask: Option<&Bitmask>,
205 op: ComparisonOperator,
206 output: &mut Bitmask,
207) -> Result<(), KernelError> {
208 macro_rules! dispatch {
209 ($t:ty, $f:ident) => {
210 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
211 return $f(
212 unsafe { std::mem::transmute(lhs) },
213 unsafe { std::mem::transmute(rhs) },
214 mask,
215 op,
216 output,
217 );
218 }
219 };
220 }
221 dispatch!(i32, cmp_i32_to);
222 dispatch!(i64, cmp_i64_to);
223 dispatch!(u32, cmp_u32_to);
224 dispatch!(u64, cmp_u64_to);
225 dispatch!(f32, cmp_f32_to);
226 dispatch!(f64, cmp_f64_to);
227
228 unreachable!("Unsupported numeric type for compare_numeric");
229}
230
231#[inline(always)]
274pub fn cmp_numeric<T: Numeric + Copy + 'static>(
275 lhs: &[T],
276 rhs: &[T],
277 mask: Option<&Bitmask>,
278 op: ComparisonOperator,
279) -> Result<BooleanArray<()>, KernelError> {
280 let len = lhs.len();
281 let mut out = new_bool_bitmask(len);
282 cmp_numeric_to(lhs, rhs, mask, op, &mut out)?;
283 Ok(BooleanArray {
284 data: out.into(),
285 null_mask: mask.cloned(),
286 len,
287 _phantom: PhantomData,
288 })
289}
290
291#[cfg(feature = "simd")]
301pub fn cmp_bitmask_simd<const LANES: usize>(
302 lhs: BitmaskVT<'_>,
303 rhs: BitmaskVT<'_>,
304 mask: Option<BitmaskVT<'_>>,
305 op: ComparisonOperator,
306) -> Result<Bitmask, KernelError>
307where
308{
309 confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
315 let (lhs_mask, lhs_offset, len) = lhs;
316 let (rhs_mask, rhs_offset, _) = rhs;
317
318 if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
321 let mut out = match op {
322 ComparisonOperator::In => in_mask_simd::<LANES>(lhs, rhs),
323 ComparisonOperator::NotIn => not_in_mask_simd::<LANES>(lhs, rhs),
324 _ => unreachable!(),
325 };
326 if let Some(mask_slice) = mask {
327 out = and_masks_simd::<LANES>((&out, 0, out.len), mask_slice);
328 }
329 return Ok(out);
330 }
331
332 if lhs_offset % 64 != 0
334 || rhs_offset % 64 != 0
335 || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
336 {
337 return Err(KernelError::InvalidArguments(format!(
338 "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
339 lhs_offset,
340 rhs_offset,
341 mask.as_ref().map(|(_, mo, _)| mo)
342 )));
343 }
344
345 let lhs_word_start = lhs_offset / 64;
347 let rhs_word_start = rhs_offset / 64;
348 let n_words = (len + 63) / 64;
349
350 let mut out = Bitmask::new_set_all(len, false);
352
353 type Word = u64;
354 let lane_words = LANES;
355 let simd_chunks = n_words / lane_words;
356
357 let tail_words = n_words % lane_words;
358 let mut word_idx = 0;
359
360 for chunk in 0..simd_chunks {
362 let base_lhs = lhs_word_start + chunk * lane_words;
363 let base_rhs = rhs_word_start + chunk * lane_words;
364 let base_mask = mask
365 .as_ref()
366 .map(|(m, mask_word_start, _)| (m, mask_word_start + chunk * lane_words));
367
368 let mut lhs_arr = [0u64; LANES];
369 let mut rhs_arr = [0u64; LANES];
370 let mut mask_arr = [!0u64; LANES];
371
372 for lane in 0..LANES {
373 lhs_arr[lane] = unsafe { lhs_mask.word_unchecked(base_lhs + lane) };
374 rhs_arr[lane] = unsafe { rhs_mask.word_unchecked(base_rhs + lane) };
375 if let Some((m, mask_word_start)) = base_mask {
376 mask_arr[lane] = unsafe { m.word_unchecked(mask_word_start + lane) };
377 }
378 }
379 let lhs_v = Simd::<Word, LANES>::from_array(lhs_arr);
380 let rhs_v = Simd::<Word, LANES>::from_array(rhs_arr);
381 let mask_v = Simd::<Word, LANES>::from_array(mask_arr);
382
383 let cmp_v = match op {
384 ComparisonOperator::Equals => !(lhs_v ^ rhs_v),
385 ComparisonOperator::NotEquals => lhs_v ^ rhs_v,
386 ComparisonOperator::GreaterThan => lhs_v & (!rhs_v),
387 ComparisonOperator::LessThan => (!lhs_v) & rhs_v,
388 ComparisonOperator::GreaterThanOrEqualTo => lhs_v | (!rhs_v),
389 ComparisonOperator::LessThanOrEqualTo => (!lhs_v) | rhs_v,
390 _ => Simd::splat(0),
391 };
392 let result_v = cmp_v & mask_v;
393
394 for lane in 0..LANES {
395 unsafe {
396 out.set_word_unchecked(word_idx, result_v[lane]);
397 }
398 word_idx += 1;
399 }
400 }
401
402 let base_lhs = lhs_word_start + simd_chunks * lane_words;
404 let base_rhs = rhs_word_start + simd_chunks * lane_words;
405 let base_mask: Option<(&Bitmask, usize)> = mask
406 .as_ref()
407 .map(|(m, mo, _)| (*m, mo + simd_chunks * lane_words));
408
409 for tail in 0..tail_words {
410 let a = unsafe { lhs_mask.word_unchecked(base_lhs + tail) };
411 let b = unsafe { rhs_mask.word_unchecked(base_rhs + tail) };
412 let m = if let Some((m, mask_word_start)) = base_mask {
413 unsafe { m.word_unchecked(mask_word_start + tail) }
414 } else {
415 !0u64
416 };
417 let cmp = match op {
418 ComparisonOperator::Equals => !(a ^ b),
419 ComparisonOperator::NotEquals => a ^ b,
420 ComparisonOperator::GreaterThan => a & (!b),
421 ComparisonOperator::LessThan => (!a) & b,
422 ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
423 ComparisonOperator::LessThanOrEqualTo => (!a) | b,
424 _ => 0,
425 } & m;
426 unsafe {
427 out.set_word_unchecked(word_idx, cmp);
428 }
429 word_idx += 1;
430 }
431
432 out.mask_trailing_bits();
433 Ok(out)
434}
435
436pub fn cmp_bool<const LANES: usize>(
453 lhs: BooleanAVT<'_, ()>,
454 rhs: BooleanAVT<'_, ()>,
455 op: ComparisonOperator,
456) -> Result<BooleanArray<()>, KernelError>
457where
458{
459 let (lhs_arr, lhs_off, len) = lhs;
460 let (rhs_arr, rhs_off, rlen) = rhs;
461 debug_assert_eq!(len, rlen, "cmp_bool: window length mismatch");
462
463 #[cfg(feature = "simd")]
464 let merged_null_mask: Option<Bitmask> =
465 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
466 (None, None) => None,
467 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
468 (Some(a), Some(b)) => {
469 let am = (a, lhs_off, len);
470 let bm = (b, rhs_off, len);
471 Some(and_masks_simd::<LANES>(am, bm))
472 }
473 };
474
475 #[cfg(not(feature = "simd"))]
476 let merged_null_mask: Option<Bitmask> =
477 match (lhs_arr.null_mask.as_ref(), rhs_arr.null_mask.as_ref()) {
478 (None, None) => None,
479 (Some(m), None) | (None, Some(m)) => Some(m.slice_clone(lhs_off, len)),
480 (Some(a), Some(b)) => {
481 let am = (a, lhs_off, len);
482 let bm = (b, rhs_off, len);
483 Some(and_masks(am, bm))
484 }
485 };
486
487 let mask_slice = merged_null_mask.as_ref().map(|m| (m, 0, len));
488
489 let data = match op {
490 ComparisonOperator::Equals
491 | ComparisonOperator::NotEquals
492 | ComparisonOperator::LessThan
493 | ComparisonOperator::LessThanOrEqualTo
494 | ComparisonOperator::GreaterThan
495 | ComparisonOperator::GreaterThanOrEqualTo
496 | ComparisonOperator::In
497 | ComparisonOperator::NotIn => {
498 #[cfg(feature = "simd")]
499 let res = cmp_bitmask_simd::<LANES>(
500 (&lhs_arr.data, lhs_off, len),
501 (&rhs_arr.data, rhs_off, len),
502 mask_slice,
503 op,
504 )?;
505 #[cfg(not(feature = "simd"))]
506 let res = cmp_bitmask_std(
507 (&lhs_arr.data, lhs_off, len),
508 (&rhs_arr.data, rhs_off, len),
509 mask_slice,
510 op,
511 )?;
512 res
513 }
514 ComparisonOperator::IsNull => {
515 #[cfg(feature = "simd")]
516 let data = match merged_null_mask.as_ref() {
517 Some(mask) => not_mask_simd::<LANES>((mask, 0, len)),
518 None => Bitmask::new_set_all(len, false),
519 };
520 #[cfg(not(feature = "simd"))]
521 let data = match merged_null_mask.as_ref() {
522 Some(mask) => not_mask((mask, 0, len)),
523 None => Bitmask::new_set_all(len, false),
524 };
525 return Ok(BooleanArray {
526 data,
527 null_mask: None,
528 len,
529 _phantom: PhantomData,
530 });
531 }
532 ComparisonOperator::IsNotNull => {
533 let data = match merged_null_mask.as_ref() {
534 Some(mask) => mask.slice_clone(0, len),
535 None => Bitmask::new_set_all(len, true),
536 };
537 return Ok(BooleanArray {
538 data,
539 null_mask: None,
540 len,
541 _phantom: PhantomData,
542 });
543 }
544 ComparisonOperator::Between => {
545 return Err(KernelError::InvalidArguments(
546 "Set operations are not defined for Bool arrays".to_owned(),
547 ));
548 }
549 };
550
551 Ok(BooleanArray {
552 data,
553 null_mask: merged_null_mask,
554 len,
555 _phantom: PhantomData,
556 })
557}
558
559#[cfg(not(feature = "simd"))]
567pub fn cmp_bitmask_std(
568 lhs: BitmaskVT<'_>,
569 rhs: BitmaskVT<'_>,
570 mask: Option<BitmaskVT<'_>>,
571 op: ComparisonOperator,
572) -> Result<Bitmask, KernelError> {
573 confirm_equal_len("compare bool length mismatch", lhs.2, rhs.2)?;
579 let (lhs_mask, lhs_offset, len) = lhs;
580 let (rhs_mask, rhs_offset, _) = rhs;
581
582 if matches!(op, ComparisonOperator::In | ComparisonOperator::NotIn) {
585 let mut out = match op {
586 ComparisonOperator::In => in_mask(lhs, rhs),
587 ComparisonOperator::NotIn => not_in_mask(lhs, rhs),
588 _ => unreachable!(),
589 };
590 if let Some(mask_slice) = mask {
591 out = and_masks((&out, 0, out.len), mask_slice);
592 }
593 return Ok(out);
594 }
595
596 if lhs_offset % 64 != 0
598 || rhs_offset % 64 != 0
599 || mask.as_ref().map_or(false, |(_, mo, _)| mo % 64 != 0)
600 {
601 return Err(KernelError::InvalidArguments(format!(
602 "cmp_bitmask: all offsets must be 64-bit aligned (lhs: {}, rhs: {}, mask offset: {:?})",
603 lhs_offset,
604 rhs_offset,
605 mask.as_ref().map(|(_, mo, _)| mo)
606 )));
607 }
608
609 let lhs_word_start = lhs_offset / 64;
611 let rhs_word_start = rhs_offset / 64;
612 let n_words = (len + 63) / 64;
613
614 let mut out = Bitmask::new_set_all(len, false);
616
617 let words = n_words;
618 let tail = len % 64;
619 let mask_mask_opt = mask;
620
621 for w in 0..words {
623 let a = unsafe { lhs_mask.word_unchecked(lhs_word_start + w) };
624 let b = unsafe { rhs_mask.word_unchecked(rhs_word_start + w) };
625 let valid_bits =
626 mask_mask_opt
627 .as_ref()
628 .map_or(!0u64, |(mask_mask, mask_word_start, _)| unsafe {
629 mask_mask.word_unchecked(mask_word_start + w)
630 });
631 let word_cmp = match op {
632 ComparisonOperator::Equals => !(a ^ b),
633 ComparisonOperator::NotEquals => a ^ b,
634 ComparisonOperator::GreaterThan => a & (!b),
635 ComparisonOperator::LessThan => (!a) & b,
636 ComparisonOperator::GreaterThanOrEqualTo => a | (!b),
637 ComparisonOperator::LessThanOrEqualTo => (!a) | b,
638 _ => 0,
639 };
640 let final_bits = word_cmp & valid_bits;
641 unsafe {
642 out.set_word_unchecked(w, final_bits);
643 }
644 }
645
646 let base = words * 64;
649 for i in 0..tail {
650 let idx_lhs = lhs_offset + base + i;
651 let idx_rhs = rhs_offset + base + i;
652 let mask_valid =
653 mask_mask_opt
654 .as_ref()
655 .map_or(true, |(mask_mask, mask_word_start, mask_len)| unsafe {
656 let mask_idx = mask_word_start * 64 + base + i;
657 if mask_idx < *mask_len {
658 mask_mask.get_unchecked(mask_idx)
659 } else {
660 false
661 }
662 });
663 if !mask_valid {
664 continue;
665 }
666 if idx_lhs >= lhs_mask.len() || idx_rhs >= rhs_mask.len() {
667 continue;
668 }
669 let a = unsafe { lhs_mask.get_unchecked(idx_lhs) };
670 let b = unsafe { rhs_mask.get_unchecked(idx_rhs) };
671 let res = match op {
672 ComparisonOperator::Equals => a == b,
673 ComparisonOperator::NotEquals => a != b,
674 ComparisonOperator::GreaterThan => a & !b,
675 ComparisonOperator::LessThan => !a & b,
676 ComparisonOperator::GreaterThanOrEqualTo => a | !b,
677 ComparisonOperator::LessThanOrEqualTo => !a | b,
678 _ => false,
679 };
680 if res {
681 out.set(base + i, true)
682 }
683 }
684 out.mask_trailing_bits();
685 Ok(out)
686}
687
/// Generates a pair of string/categorical comparison kernels over `(array, offset, len)`
/// windows:
/// - `$fn_name_to` writes into a caller-provided bitmask;
/// - `$fn_name` allocates the result and wraps it in a `BooleanArray`.
///
/// `[$($gen)+]` supplies the generic parameter list for both functions.
macro_rules! impl_cmp_utf8_slice {
    ($fn_name:ident, $fn_name_to:ident, $lhs_slice:ty, $rhs_slice:ty, [$($gen:tt)+]) => {
        /// Compares two string-like windows element-wise with `op`.
        ///
        /// Sets bit `i` of `output` wherever `lhs[i] op rhs[i]` holds, comparing the
        /// values as `&str`. Rows that are null on either side are skipped. Only `true`
        /// bits are written, so `output` is expected to start cleared. Operators without
        /// a string meaning here (`In`, `NotIn`, null tests, `Between`) produce no
        /// `true` bits.
        ///
        /// # Errors
        /// Returns an error if the window lengths differ, or if either side's null mask
        /// has fewer than `offset + len` bits of capacity.
        ///
        /// # Panics
        /// Panics if `output.capacity()` is smaller than the window length.
        #[inline(always)]
        pub fn $fn_name_to<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
            output: &mut Bitmask,
        ) -> Result<(), KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("compare string/dict length mismatch (slice contract)", llen, rlen)?;
            assert!(output.capacity() >= llen, concat!(stringify!($fn_name_to), ": output capacity too small"));

            // Validate mask coverage before touching the masks. The per-row reads
            // below are unchecked.
            if let Some(m) = larr.null_mask.as_ref() {
                if m.capacity() < loff + llen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "lhs mask capacity too small (expected ≥ {}, got {})",
                            loff + llen,
                            m.capacity()
                        ),
                    ));
                }
            }
            if let Some(m) = rarr.null_mask.as_ref() {
                if m.capacity() < roff + rlen {
                    return Err(KernelError::InvalidArguments(
                        format!(
                            "rhs mask capacity too small (expected ≥ {}, got {})",
                            roff + rlen,
                            m.capacity()
                        ),
                    ));
                }
            }

            for i in 0..llen {
                // Read validity in place at `offset + i` rather than cloning a sliced copy.
                // SAFETY: `loff + i < loff + llen <= capacity` (likewise for rhs), checked above.
                let valid = larr.null_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(loff + i) })
                    && rarr.null_mask.as_ref().map_or(true, |m| unsafe { m.get_unchecked(roff + i) });
                if !valid {
                    continue;
                }
                let l = unsafe { larr.get_str_unchecked(loff + i) };
                let r = unsafe { rarr.get_str_unchecked(roff + i) };
                let res = match op {
                    ComparisonOperator::Equals => l == r,
                    ComparisonOperator::NotEquals => l != r,
                    ComparisonOperator::GreaterThan => l > r,
                    ComparisonOperator::LessThan => l < r,
                    ComparisonOperator::GreaterThanOrEqualTo => l >= r,
                    ComparisonOperator::LessThanOrEqualTo => l <= r,
                    _ => false,
                };
                if res {
                    output.set(i, true);
                }
            }
            Ok(())
        }

        /// Allocating variant of the `_to` kernel.
        ///
        /// Returns a `BooleanArray` of the window length. Its `null_mask` is the AND of
        /// both sides' windowed null masks, or a copy of whichever one exists.
        ///
        /// # Errors
        /// Same as the `_to` kernel, including a mismatch between the two window
        /// lengths.
        #[inline(always)]
        pub fn $fn_name<$($gen)+>(
            lhs: $lhs_slice,
            rhs: $rhs_slice,
            op: ComparisonOperator,
        ) -> Result<BooleanArray<()>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, _) = rhs;
            let mut out = new_bool_bitmask(llen);
            // Forward both windows unchanged so the `_to` kernel sees the real rhs
            // length and validates the masks before they are sliced below.
            $fn_name_to(lhs, rhs, op, &mut out)?;
            let lhs_mask = larr.null_mask.as_ref().map(|m| m.slice_clone(loff, llen));
            let rhs_mask = rarr.null_mask.as_ref().map(|m| m.slice_clone(roff, llen));
            let null_mask = merge_bitmasks_to_new(lhs_mask.as_ref(), rhs_mask.as_ref(), llen);
            Ok(BooleanArray { data: out.into(), null_mask, len: llen, _phantom: PhantomData })
        }
    };
}
778
779impl_cmp_numeric!(cmp_i32, cmp_i32_to, i32, W32, i32);
780impl_cmp_numeric!(cmp_u32, cmp_u32_to, u32, W32, i32);
781impl_cmp_numeric!(cmp_i64, cmp_i64_to, i64, W64, i64);
782impl_cmp_numeric!(cmp_u64, cmp_u64_to, u64, W64, i64);
783impl_cmp_numeric!(cmp_f32, cmp_f32_to, f32, W32, i32);
784impl_cmp_numeric!(cmp_f64, cmp_f64_to, f64, W64, i64);
785impl_cmp_utf8_slice!(cmp_str_str, cmp_str_str_to, StringAVT<'a, T>, StringAVT<'a, T>, [ 'a, T: Integer ]);
786impl_cmp_utf8_slice!(cmp_str_dict, cmp_str_dict_to, StringAVT<'a, T>, CategoricalAVT<'a, U>, [ 'a, T: Integer, U: Integer ]);
787impl_cmp_utf8_slice!(cmp_dict_str, cmp_dict_str_to, CategoricalAVT<'a, T>, StringAVT<'a, U>, [ 'a, T: Integer, U: Integer ]);
788impl_cmp_utf8_slice!(cmp_dict_dict, cmp_dict_dict_to, CategoricalAVT<'a, T>, CategoricalAVT<'a, T>, [ 'a, T: Integer ]);
789
#[cfg(test)]
mod tests {
    use minarrow::{Bitmask, BooleanArray, CategoricalArray, Integer, StringArray, vec64};

    use crate::kernels::comparison::{
        cmp_dict_dict, cmp_dict_str, cmp_i32, cmp_numeric, cmp_str_dict,
    };

    #[cfg(feature = "simd")]
    use crate::kernels::comparison::{W64, cmp_bitmask_simd};

    use crate::operators::ComparisonOperator;

    /// Builds a bitmask whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut m = Bitmask::new_set_all(bits.len(), false);
        for (i, &b) in bits.iter().enumerate() {
            m.set(i, b);
        }
        m
    }

    /// Asserts that `arr` has `expect.len()` values matching `expect` bit-for-bit.
    /// Also asserts that its null mask is present exactly when `expect_mask` is
    /// `Some`, and then matches it bit-for-bit.
    fn assert_bool(arr: &BooleanArray<()>, expect: &[bool], expect_mask: Option<&[bool]>) {
        assert_eq!(arr.len, expect.len());
        for i in 0..expect.len() {
            assert_eq!(arr.data.get(i), expect[i], "value bit {i}");
        }
        match (arr.null_mask.as_ref(), expect_mask) {
            (None, None) => {}
            (Some(m), Some(exp)) => {
                for (i, &b) in exp.iter().enumerate() {
                    assert_eq!(m.get(i), b, "null-bit {i}");
                }
            }
            _ => panic!("mask mismatch"),
        }
    }

    /// Builds a `StringArray` from string slices.
    fn str_arr<T: Integer>(v: &[&str]) -> StringArray<T> {
        StringArray::<T>::from_slice(v)
    }

    /// Builds a `CategoricalArray` from string values.
    fn dict_arr<T: Integer>(vals: &[&str]) -> CategoricalArray<T> {
        let owned: Vec<&str> = vals.to_vec();
        CategoricalArray::<T>::from_values(owned)
    }

    #[test]
    fn numeric_compare_no_nulls() {
        let a = vec64![1i32, 2, 3, 4];
        let b = vec64![1i32, 1, 4, 4];

        let eq = cmp_i32(&a, &b, None, ComparisonOperator::Equals).unwrap();
        let neq = cmp_i32(&a, &b, None, ComparisonOperator::NotEquals).unwrap();
        let lt = cmp_i32(&a, &b, None, ComparisonOperator::LessThan).unwrap();
        let le = cmp_i32(&a, &b, None, ComparisonOperator::LessThanOrEqualTo).unwrap();
        let gt = cmp_i32(&a, &b, None, ComparisonOperator::GreaterThan).unwrap();
        let ge = cmp_i32(&a, &b, None, ComparisonOperator::GreaterThanOrEqualTo).unwrap();

        assert_bool(&eq, &[true, false, false, true], None);
        assert_bool(&neq, &[false, true, true, false], None);
        assert_bool(&lt, &[false, false, true, false], None);
        assert_bool(&le, &[true, false, true, true], None);
        assert_bool(&gt, &[false, true, false, false], None);
        assert_bool(&ge, &[true, true, false, true], None);
    }

    #[test]
    fn numeric_compare_with_nulls_generic_dispatch() {
        // `cmp_numeric` must route `u64` to its kernel and carry the input mask through.
        let a = vec64![1u64, 5, 9, 10];
        let b = vec64![0u64, 5, 8, 11];
        let mask = bm(&[true, true, true, false]);

        let out = cmp_numeric(&a, &b, Some(&mask), ComparisonOperator::GreaterThan).unwrap();
        assert_bool(
            &out,
            &[true, false, true, false],
            Some(&[true, true, true, false]),
        );
    }

    #[cfg(feature = "simd")]
    #[test]
    fn bool_compare_all_ops() {
        let a = bm(&[true, false, true, false]);
        let b = bm(&[true, true, false, false]);
        let eq = cmp_bitmask_simd::<W64>(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            None,
            ComparisonOperator::Equals,
        )
        .unwrap();
        let lt = cmp_bitmask_simd::<W64>(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            None,
            ComparisonOperator::LessThan,
        )
        .unwrap();
        let gt = cmp_bitmask_simd::<W64>(
            (&a, 0, a.len()),
            (&b, 0, b.len()),
            None,
            ComparisonOperator::GreaterThan,
        )
        .unwrap();

        assert_bool(
            &BooleanArray::from_bitmask(eq, None),
            &[true, false, false, true],
            None,
        );
        assert_bool(
            &BooleanArray::from_bitmask(lt, None),
            &[false, true, false, false],
            None,
        );
        assert_bool(
            &BooleanArray::from_bitmask(gt, None),
            &[false, false, true, false],
            None,
        );
    }

    #[test]
    fn string_vs_dict_compare_with_nulls() {
        // Row 1 is null on the lhs: its value bit stays false and the mask carries through.
        let mut lhs = str_arr::<u32>(&["x", "y", "z"]);
        lhs.null_mask = Some(bm(&[true, false, true]));
        let rhs = dict_arr::<u32>(&["x", "w", "zz"]);
        let lhs_slice = (&lhs, 0, lhs.len());
        let rhs_slice = (&rhs, 0, rhs.data.len());
        let res = cmp_str_dict(lhs_slice, rhs_slice, ComparisonOperator::Equals).unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn string_vs_dict_compare_with_nulls_chunk() {
        // Same as above, but on windows at offset 1 of padded arrays.
        let mut lhs = str_arr::<u32>(&["pad", "x", "y", "z", "pad"]);
        lhs.null_mask = Some(bm(&[true, true, false, true, true]));
        let rhs = dict_arr::<u32>(&["pad", "x", "w", "zz", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res = cmp_str_dict(lhs_slice, rhs_slice, ComparisonOperator::Equals).unwrap();
        assert_bool(&res, &[true, false, false], Some(&[true, false, true]));
    }

    #[test]
    fn dict_vs_dict_compare_gt() {
        let lhs = dict_arr::<u32>(&["apple", "pear", "banana"]);
        let rhs = dict_arr::<u32>(&["ant", "pear", "apricot"]);
        let lhs_slice = (&lhs, 0, lhs.data.len());
        let rhs_slice = (&rhs, 0, rhs.data.len());
        let res = cmp_dict_dict(lhs_slice, rhs_slice, ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_dict_compare_gt_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "apple", "pear", "banana", "pad"]);
        let rhs = dict_arr::<u32>(&["pad", "ant", "pear", "apricot", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res = cmp_dict_dict(lhs_slice, rhs_slice, ComparisonOperator::GreaterThan).unwrap();
        assert_bool(&res, &[true, false, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le() {
        let lhs = dict_arr::<u32>(&["a", "b", "c"]);
        let rhs = str_arr::<u32>(&["b", "b", "d"]);
        let lhs_slice = (&lhs, 0, lhs.data.len());
        let rhs_slice = (&rhs, 0, rhs.len());
        let res =
            cmp_dict_str(lhs_slice, rhs_slice, ComparisonOperator::LessThanOrEqualTo).unwrap();
        assert_bool(&res, &[true, true, true], None);
    }

    #[test]
    fn dict_vs_string_compare_le_chunk() {
        let lhs = dict_arr::<u32>(&["pad", "a", "b", "c", "pad"]);
        let rhs = str_arr::<u32>(&["pad", "b", "b", "d", "pad"]);
        let lhs_slice = (&lhs, 1, 3);
        let rhs_slice = (&rhs, 1, 3);
        let res =
            cmp_dict_str(lhs_slice, rhs_slice, ComparisonOperator::LessThanOrEqualTo).unwrap();
        assert_bool(&res, &[true, true, true], None);
    }
}