Skip to main content

polars_compute/comparisons/
list.rs

1use arrow::array::{
2    Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3    ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,
4};
5use arrow::bitmap::Bitmap;
6use arrow::legacy::utils::CustomIterTools;
7use arrow::types::{Offset, days_ms, i256, months_days_ns};
8use polars_utils::float16::pf16;
9
10use super::TotalEqKernel;
11
/// Row-wise comparison of two equally long `ListArray`s.
///
/// For every row `i` this produces one bit of the output `Bitmap`:
/// - `$invalid_rv` if row `i` is null in `$lhs` or in `$rhs`;
/// - `$ineq_len_rv` if the two sub-lists at row `i` have different lengths;
/// - otherwise `$true_op($op(&lhs_sub, &rhs_sub))`, where `lhs_sub` / `rhs_sub` are the
///   row's slices of the child values arrays and `$op` is a kernel returning a `Bitmap`
///   that `$true_op` reduces to a single `bool`.
///
/// The child values arrays are downcast to the concrete array type selected by the
/// physical type of `$lhs`'s values dtype. `$lhs` and `$rhs` are each evaluated exactly
/// once, so passing a non-trivial expression does not re-run it for every row.
///
/// # Panics
///
/// - If `$lhs` and `$rhs` differ in length or dtype.
/// - If the child type is `MonthDayMillis`, `Union` or `Map`, or is `FixedSizeList` while
///   the `dtype-array` feature is disabled.
macro_rules! compare {
    (
        $lhs:expr, $rhs:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        // Bind every argument once; the per-row code below only touches these bindings.
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Row-invariant lookups, hoisted out of the per-row closure.
        let lhs_validity = lhs.validity();
        let rhs_validity = rhs.validity();
        let lhs_offsets = lhs.offsets();
        let rhs_offsets = rhs.offsets();

        // Expands to the per-row loop for one concrete child array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                let lhs_values: &$T = lhs.values().as_any().downcast_ref().unwrap();
                let rhs_values: &$T = rhs.values().as_any().downcast_ref().unwrap();

                (0..lhs.len())
                    .map(|i| {
                        let lval = lhs_validity.is_none_or(|v| v.get(i).unwrap());
                        let rval = rhs_validity.is_none_or(|v| v.get(i).unwrap());

                        if !lval || !rval {
                            return $invalid_rv;
                        }

                        // SAFETY: ListArray's invariant offsets.len_proxy() == len, and
                        // i < lhs.len() == rhs.len() (asserted above).
                        let (lstart, lend) = unsafe { lhs_offsets.start_end_unchecked(i) };
                        let (rstart, rend) = unsafe { rhs_offsets.start_end_unchecked(i) };

                        if lend - lstart != rend - rstart {
                            return $ineq_len_rv;
                        }

                        // Clone the child arrays and narrow each to this row's range.
                        let mut lhs_values = lhs_values.clone();
                        lhs_values.slice(lstart, lend - lstart);
                        let mut rhs_values = rhs_values.clone();
                        rhs_values.slice(rstart, rend - rstart);

                        $true_op($op(&lhs_values, &rhs_values))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.values().dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },
            PH::Primitive(PR::MonthDayMillis) => {
                unimplemented!("Comparison of MonthDayMillis arrays is not yet supported")
            },

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),
        }
    }};
}
113
/// Compares every sub-list of a list array against one single list (broadcast).
///
/// - `$lhs` is the list array's child values array.
/// - `$offsets` and `$validity` are the list array's offsets and validity.
/// - `$rhs` is the single list to compare against, given as an array whose dtype must equal
///   `$lhs`'s dtype.
///
/// For every row `i` of the list array this produces one output bit:
/// - `$invalid_rv` if row `i` is null according to `$validity`;
/// - `$ineq_len_rv` if the sub-list at row `i` is not exactly as long as `$rhs`;
/// - otherwise `$true_op($op(&sub, rhs))`, where `sub` is the row's slice of `$lhs`.
///
/// Every argument is evaluated exactly once. `$validity` is expected to be a copyable
/// `Option` of a bitmap reference (the call sites pass `ListArray::validity()`).
///
/// # Panics
///
/// - If `$lhs` and `$rhs` differ in dtype.
/// - If the child type is `MonthDayMillis`, `Union` or `Map`, or is `FixedSizeList` while
///   the `dtype-array` feature is disabled.
macro_rules! compare_broadcast {
    (
        $lhs:expr, $rhs:expr,
        $offsets:expr, $validity:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        // Bind every argument once; the per-row closure only captures these bindings.
        let lhs = $lhs;
        let rhs = $rhs;
        let offsets = $offsets;
        let validity = $validity;

        assert_eq!(lhs.dtype(), rhs.dtype());

        // Expands to the per-row loop for one concrete child array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                let values: &$T = lhs.as_any().downcast_ref().unwrap();
                let scalar: &$T = rhs.as_any().downcast_ref().unwrap();

                let length = offsets.len_proxy();

                (0..length)
                    .map(move |i| {
                        let is_valid = validity.is_none_or(|v| v.get(i).unwrap());

                        if !is_valid {
                            return $invalid_rv;
                        }

                        // SAFETY: i < offsets.len_proxy(), which is exactly the range
                        // start_end_unchecked accepts.
                        let (start, end) = unsafe { offsets.start_end_unchecked(i) };

                        if end - start != scalar.len() {
                            return $ineq_len_rv;
                        }

                        // @TODO: I feel like there is a better way to do this.
                        let mut values: $T = values.clone();
                        <$T>::slice(&mut values, start, end - start);

                        $true_op($op(&values, scalar))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },
            PH::Primitive(PR::MonthDayMillis) => {
                unimplemented!("Comparison of MonthDayMillis arrays is not yet supported")
            },

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),
        }
    }};
}
213
214impl<O: Offset> TotalEqKernel for ListArray<O> {
215    type Scalar = Box<dyn Array>;
216
217    fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
218        compare!(
219            self,
220            other,
221            TotalEqKernel::tot_eq_missing_kernel,
222            |bm: Bitmap| bm.unset_bits() == 0,
223            false,
224            true
225        )
226    }
227
228    fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
229        compare!(
230            self,
231            other,
232            TotalEqKernel::tot_ne_missing_kernel,
233            |bm: Bitmap| bm.set_bits() > 0,
234            true,
235            false
236        )
237    }
238
239    fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
240        compare_broadcast!(
241            self.values().as_ref(),
242            other.as_ref(),
243            self.offsets(),
244            self.validity(),
245            TotalEqKernel::tot_eq_missing_kernel,
246            |bm: Bitmap| bm.unset_bits() == 0,
247            false,
248            true
249        )
250    }
251
252    fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
253        compare_broadcast!(
254            self.values().as_ref(),
255            other.as_ref(),
256            self.offsets(),
257            self.validity(),
258            TotalEqKernel::tot_ne_missing_kernel,
259            |bm: Bitmap| bm.set_bits() > 0,
260            true,
261            false
262        )
263    }
264}