polars_compute/comparisons/
list.rs1use arrow::array::{
2 Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3 ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,
4};
5use arrow::bitmap::Bitmap;
6use arrow::legacy::utils::CustomIterTools;
7use arrow::types::{Offset, days_ms, i256, months_days_ns};
8use polars_utils::float16::pf16;
9
10use super::TotalEqKernel;
11
/// Row-wise comparison of two `ListArray`s that have the same length and dtype.
///
/// For every row `i` the expansion yields a `bool`:
/// * `$invalid_rv` if either side is null at `i`;
/// * `$ineq_len_rv` if the two lists at `i` have different lengths;
/// * otherwise `$true_op($op(&l, &r))`, where `l`/`r` are the row's slices of the two
///   child `values` arrays, downcast to their concrete array type.
///
/// The per-row results are gathered with `collect_trusted`, so the surrounding function's
/// return type selects the output container (`Bitmap` for the callers in this file).
///
/// Every macro argument is evaluated exactly once.
///
/// # Panics
/// * If the two arrays differ in length or dtype.
/// * For child physical types that are not supported: `MonthDayMillis`, `Union`, `Map`,
///   and `FixedSizeList` when the `dtype-array` feature is disabled.
macro_rules! compare {
    (
        $lhs:expr, $rhs:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        // Evaluate the operands once. Everything below uses these bindings so that the
        // argument expressions are never re-evaluated per row.
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Row-invariant state, hoisted out of the per-row closure.
        let lhs_validity = lhs.validity();
        let rhs_validity = rhs.validity();
        let lhs_offsets = lhs.offsets();
        let rhs_offsets = rhs.offsets();

        // Expands the row loop for one concrete child array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                let lhs_values: &$T = lhs.values().as_any().downcast_ref().unwrap();
                let rhs_values: &$T = rhs.values().as_any().downcast_ref().unwrap();

                (0..lhs.len())
                    .map(|i| {
                        let lval = lhs_validity.is_none_or(|v| v.get(i).unwrap());
                        let rval = rhs_validity.is_none_or(|v| v.get(i).unwrap());

                        if !lval || !rval {
                            return $invalid_rv;
                        }

                        // SAFETY: `i < lhs.len() == rhs.len()` (asserted above). This relies on
                        // the ListArray invariant `len() == offsets().len_proxy()` (assumed,
                        // not visible here), which makes `i` a valid row index for both.
                        let (lstart, lend) = unsafe { lhs_offsets.start_end_unchecked(i) };
                        let (rstart, rend) = unsafe { rhs_offsets.start_end_unchecked(i) };

                        if lend - lstart != rend - rstart {
                            return $ineq_len_rv;
                        }

                        let mut lhs_row = lhs_values.clone();
                        lhs_row.slice(lstart, lend - lstart);
                        let mut rhs_row = rhs_values.clone();
                        rhs_row.slice(rstart, rend - rstart);

                        $true_op($op(&lhs_row, &rhs_row))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.values().dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },
            PH::Primitive(PR::MonthDayMillis) => unimplemented!(),

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),
        }
    }};
}
113
/// Compares every row of a list array against a single list scalar.
///
/// Arguments:
/// * `$lhs` — the list array's child values (used as `&dyn Array` by the callers here);
/// * `$rhs` — the scalar list's values, which must have the same dtype as `$lhs`;
/// * `$offsets` / `$validity` — the list array's offsets and optional validity bitmap.
///
/// For every row `i` the expansion yields a `bool`:
/// * `$invalid_rv` if the row is null;
/// * `$ineq_len_rv` if the row's length differs from the scalar's length;
/// * otherwise `$true_op($op(&row, scalar))`, where `row` is the slice of `$lhs` for row `i`.
///
/// The per-row results are gathered with `collect_trusted`, so the surrounding function's
/// return type selects the output container (`Bitmap` for the callers in this file).
///
/// Every macro argument is evaluated exactly once.
///
/// # Panics
/// * If `$lhs` and `$rhs` have different dtypes.
/// * For physical types that are not supported: `MonthDayMillis`, `Union`, `Map`, and
///   `FixedSizeList` when the `dtype-array` feature is disabled.
macro_rules! compare_broadcast {
    (
        $lhs:expr, $rhs:expr,
        $offsets:expr, $validity:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        // Evaluate each argument once. Previously `$offsets` and `$validity` were
        // re-evaluated for every row inside the closure.
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.dtype(), rhs.dtype());

        let offsets = $offsets;
        let validity = $validity;

        // Expands the row loop for one concrete array type `$T`.
        macro_rules! call_binary {
            ($T:ty) => {{
                let values: &$T = lhs.as_any().downcast_ref().unwrap();
                let scalar: &$T = rhs.as_any().downcast_ref().unwrap();

                (0..offsets.len_proxy())
                    .map(move |i| {
                        if !validity.is_none_or(|v| v.get(i).unwrap()) {
                            return $invalid_rv;
                        }

                        // SAFETY: `i` ranges over `0..offsets.len_proxy()`, so it is a valid
                        // row index into `offsets`.
                        let (start, end) = unsafe { offsets.start_end_unchecked(i) };

                        if end - start != scalar.len() {
                            return $ineq_len_rv;
                        }

                        let mut row: $T = values.clone();
                        <$T>::slice(&mut row, start, end - start);

                        $true_op($op(&row, scalar))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },
            PH::Primitive(PR::MonthDayMillis) => unimplemented!(),

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),
        }
    }};
}
214impl<O: Offset> TotalEqKernel for ListArray<O> {
215 type Scalar = Box<dyn Array>;
216
217 fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
218 compare!(
219 self,
220 other,
221 TotalEqKernel::tot_eq_missing_kernel,
222 |bm: Bitmap| bm.unset_bits() == 0,
223 false,
224 true
225 )
226 }
227
228 fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
229 compare!(
230 self,
231 other,
232 TotalEqKernel::tot_ne_missing_kernel,
233 |bm: Bitmap| bm.set_bits() > 0,
234 true,
235 false
236 )
237 }
238
239 fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
240 compare_broadcast!(
241 self.values().as_ref(),
242 other.as_ref(),
243 self.offsets(),
244 self.validity(),
245 TotalEqKernel::tot_eq_missing_kernel,
246 |bm: Bitmap| bm.unset_bits() == 0,
247 false,
248 true
249 )
250 }
251
252 fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
253 compare_broadcast!(
254 self.values().as_ref(),
255 other.as_ref(),
256 self.offsets(),
257 self.validity(),
258 TotalEqKernel::tot_ne_missing_kernel,
259 |bm: Bitmap| bm.set_bits() > 0,
260 true,
261 false
262 )
263 }
264}