1use polars_ops::chunked_array::array::*;
2
3use super::*;
4use crate::{map, map_as_slice};
5
/// IR-level representation of the array-namespace functions (displayed as `arr.<name>`).
///
/// The output dtype of each variant is resolved by `IRArrayFunction::get_field`, and the
/// kernel that evaluates it is produced by the `From<IRArrayFunction>` conversion into a
/// column UDF.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IRArrayFunction {
    /// Fixed width of each array as an `IdxSize` value; rows that are null in the input
    /// are null in the output.
    Length,
    /// Per-row minimum (`array_min`); output is the array's inner dtype.
    Min,
    /// Per-row maximum (`array_max`); output is the array's inner dtype.
    Max,
    /// Per-row sum (`array_sum`).
    Sum,
    /// Casts `Array(inner, _)` to `List(inner)`.
    ToList,
    /// Per-row unique values as a list. The `bool` is `stable`: `true` dispatches to
    /// `array_unique_stable`, `false` to `array_unique`.
    Unique(bool),
    /// Number of unique values per row (`array_n_unique`), as `IdxSize`.
    NUnique,
    /// Per-row standard deviation; the `u8` is `ddof`. Output is a float dtype.
    Std(u8),
    /// Per-row variance; the `u8` is `ddof`. Output is a float dtype.
    Var(u8),
    /// Per-row median (`array_median`). Output is a float dtype.
    Median,
    /// Per-row `array_any`; output is `Boolean`.
    #[cfg(feature = "array_any_all")]
    Any,
    /// Per-row `array_all`; output is `Boolean`.
    #[cfg(feature = "array_any_all")]
    All,
    /// Per-row sort with the given options; output dtype is unchanged.
    Sort(SortOptions),
    /// Per-row reversal (`array_reverse`); output dtype is unchanged.
    Reverse,
    /// Per-row `array_arg_min`, as `IdxSize`.
    ArgMin,
    /// Per-row `array_arg_max`, as `IdxSize`.
    ArgMax,
    /// Element lookup by an index taken from the second input (cast to `Int64`).
    /// The `bool` is `null_on_oob`.
    Get(bool),
    /// Joins elements into a `String` using the separator column given as the second
    /// input. The `bool` is `ignore_nulls`.
    Join(bool),
    /// Membership test of the second input's items in the arrays (via `is_in`);
    /// output is `Boolean`.
    #[cfg(feature = "is_in")]
    Contains {
        /// Forwarded to `is_in` as its `nulls_equal` argument.
        nulls_equal: bool,
    },
    /// Counts matches of a value; the second input must contain exactly one element.
    /// Output is `IdxSize`.
    #[cfg(feature = "array_count")]
    CountMatches,
    /// `array_shift` by the amount given in the second input; output dtype is unchanged.
    Shift,
    /// Explodes the first input (`Column::explode`); output is the array's inner dtype.
    /// This is the only variant marked row-separable instead of element-wise.
    Explode {
        /// Forwarded to `Column::explode`.
        skip_empty: bool,
    },
    /// Concatenates all inputs into a single array column. Non-array inputs count as
    /// width 1, and all inner dtypes must match the first input's.
    Concat,
}
41
42impl IRArrayFunction {
43 pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
44 use IRArrayFunction::*;
45 match self {
46 Concat => Ok(Field::new(
47 mapper
48 .args()
49 .first()
50 .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
51 concat_arr_output_dtype(
52 &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
53 )?,
54 )),
55 Length => mapper.with_dtype(IDX_DTYPE),
56 Min | Max => mapper.map_to_list_and_array_inner_dtype(),
57 Sum => mapper.nested_sum_type(),
58 ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
59 Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
60 NUnique => mapper.with_dtype(IDX_DTYPE),
61 Std(_) => mapper.map_to_float_dtype(),
62 Var(_) => mapper.map_to_float_dtype(),
63 Median => mapper.map_to_float_dtype(),
64 #[cfg(feature = "array_any_all")]
65 Any | All => mapper.with_dtype(DataType::Boolean),
66 Sort(_) => mapper.with_same_dtype(),
67 Reverse => mapper.with_same_dtype(),
68 ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
69 Get(_) => mapper.map_to_list_and_array_inner_dtype(),
70 Join(_) => mapper.with_dtype(DataType::String),
71 #[cfg(feature = "is_in")]
72 Contains { nulls_equal: _ } => mapper.with_dtype(DataType::Boolean),
73 #[cfg(feature = "array_count")]
74 CountMatches => mapper.with_dtype(IDX_DTYPE),
75 Shift => mapper.with_same_dtype(),
76 Explode { .. } => mapper.try_map_to_array_inner_dtype(),
77 }
78 }
79
80 pub fn function_options(&self) -> FunctionOptions {
81 use IRArrayFunction as A;
82 match self {
83 #[cfg(feature = "array_any_all")]
84 A::Any | A::All => FunctionOptions::elementwise(),
85 #[cfg(feature = "is_in")]
86 A::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),
87 #[cfg(feature = "array_count")]
88 A::CountMatches => FunctionOptions::elementwise(),
89 A::Length
90 | A::Min
91 | A::Max
92 | A::Sum
93 | A::ToList
94 | A::Unique(_)
95 | A::NUnique
96 | A::Std(_)
97 | A::Var(_)
98 | A::Median
99 | A::Sort(_)
100 | A::Reverse
101 | A::ArgMin
102 | A::ArgMax
103 | A::Concat
104 | A::Get(_)
105 | A::Join(_)
106 | A::Shift => FunctionOptions::elementwise(),
107 A::Explode { .. } => FunctionOptions::row_separable(),
108 }
109 }
110}
111
112fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
113 if let DataType::Array(inner, _) = datatype {
114 Ok(DataType::List(inner.clone()))
115 } else {
116 polars_bail!(ComputeError: "expected array dtype")
117 }
118}
119
120impl Display for IRArrayFunction {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 use IRArrayFunction::*;
123 let name = match self {
124 Concat => "concat",
125 Length => "length",
126 Min => "min",
127 Max => "max",
128 Sum => "sum",
129 ToList => "to_list",
130 Unique(_) => "unique",
131 NUnique => "n_unique",
132 Std(_) => "std",
133 Var(_) => "var",
134 Median => "median",
135 #[cfg(feature = "array_any_all")]
136 Any => "any",
137 #[cfg(feature = "array_any_all")]
138 All => "all",
139 Sort(_) => "sort",
140 Reverse => "reverse",
141 ArgMin => "arg_min",
142 ArgMax => "arg_max",
143 Get(_) => "get",
144 Join(_) => "join",
145 #[cfg(feature = "is_in")]
146 Contains { nulls_equal: _ } => "contains",
147 #[cfg(feature = "array_count")]
148 CountMatches => "count_matches",
149 Shift => "shift",
150 Explode { .. } => "explode",
151 };
152 write!(f, "arr.{name}")
153 }
154}
155
156impl From<IRArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
157 fn from(func: IRArrayFunction) -> Self {
158 use IRArrayFunction::*;
159 match func {
160 Concat => map_as_slice!(concat_arr),
161 Length => map!(length),
162 Min => map!(min),
163 Max => map!(max),
164 Sum => map!(sum),
165 ToList => map!(to_list),
166 Unique(stable) => map!(unique, stable),
167 NUnique => map!(n_unique),
168 Std(ddof) => map!(std, ddof),
169 Var(ddof) => map!(var, ddof),
170 Median => map!(median),
171 #[cfg(feature = "array_any_all")]
172 Any => map!(any),
173 #[cfg(feature = "array_any_all")]
174 All => map!(all),
175 Sort(options) => map!(sort, options),
176 Reverse => map!(reverse),
177 ArgMin => map!(arg_min),
178 ArgMax => map!(arg_max),
179 Get(null_on_oob) => map_as_slice!(get, null_on_oob),
180 Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
181 #[cfg(feature = "is_in")]
182 Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
183 #[cfg(feature = "array_count")]
184 CountMatches => map_as_slice!(count_matches),
185 Shift => map_as_slice!(shift),
186 Explode { skip_empty } => map_as_slice!(explode, skip_empty),
187 }
188 }
189}
190
/// Returns the fixed width of each array as an `IdxSize` column.
///
/// The result starts as a scalar column of `array.len()` rows. When the input has a
/// validity mask, the column is materialized and that mask is applied, so null input rows
/// produce null lengths.
///
/// # Errors
/// Fails if the input is not an array column, or if the width does not fit in `IdxSize`.
pub(super) fn length(s: &Column) -> PolarsResult<Column> {
    let array = s.array()?;
    let width = array.width();
    let width = IdxSize::try_from(width)
        .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;

    // Every row has the same width, so a scalar column is enough when there are no nulls.
    let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
    if let Some(validity) = array.rechunk_validity() {
        let mut series = c.into_materialized_series().clone();

        // SAFETY: the only chunk is replaced by the same array with just its validity
        // changed, so the series' dtype stays consistent with its chunks. `compute_len`
        // is called afterwards, presumably to refresh cached length/null-count metadata.
        let chunks = unsafe { series.chunks_mut() };
        // NOTE(review): relies on the materialized scalar column being a single chunk.
        assert_eq!(chunks.len(), 1);

        chunks[0] = chunks[0].with_validity(Some(validity));

        series.compute_len();
        c = series.into_column();
    }

    Ok(c)
}
213
214pub(super) fn max(s: &Column) -> PolarsResult<Column> {
215 Ok(s.array()?.array_max().into())
216}
217
218pub(super) fn min(s: &Column) -> PolarsResult<Column> {
219 Ok(s.array()?.array_min().into())
220}
221
222pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
223 s.array()?.array_sum().map(Column::from)
224}
225
226pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
227 s.array()?.array_std(ddof).map(Column::from)
228}
229
230pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
231 s.array()?.array_var(ddof).map(Column::from)
232}
233pub(super) fn median(s: &Column) -> PolarsResult<Column> {
234 s.array()?.array_median().map(Column::from)
235}
236
237pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
238 let ca = s.array()?;
239 let out = if stable {
240 ca.array_unique_stable()
241 } else {
242 ca.array_unique()
243 };
244 out.map(|ca| ca.into_column())
245}
246
247pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
248 Ok(s.array()?.array_n_unique()?.into_column())
249}
250
251pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
252 let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?;
253 s.cast(&list_dtype)
254}
255
/// Per-row `array_any` of an array column.
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    let out = ca.array_any()?;
    Ok(Column::from(out))
}
260
/// Per-row `array_all` of an array column.
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    let out = ca.array_all()?;
    Ok(Column::from(out))
}
265
266pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
267 Ok(s.array()?.array_sort(options)?.into_column())
268}
269
270pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
271 Ok(s.array()?.array_reverse().into_column())
272}
273
274pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
275 Ok(s.array()?.array_arg_min().into_column())
276}
277
278pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
279 Ok(s.array()?.array_arg_max().into_column())
280}
281
282pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
283 let ca = s[0].array()?;
284 let index = s[1].cast(&DataType::Int64)?;
285 let index = index.i64().unwrap();
286 ca.array_get(index, null_on_oob).map(Column::from)
287}
288
289pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
290 let ca = s[0].array()?;
291 let separator = s[1].str()?;
292 ca.array_join(separator, ignore_nulls).map(Column::from)
293}
294
/// Tests whether the arrays in `s[0]` contain the items in `s[1]`, using `is_in`.
///
/// `nulls_equal` is forwarded to `is_in`. The result is named after the array column.
///
/// # Errors
/// Returns `SchemaMismatch` if `s[0]` is not an array column.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column], nulls_equal: bool) -> PolarsResult<Column> {
    let haystack = &s[0];
    let needle = &s[1];
    let is_array = matches!(haystack.dtype(), DataType::Array(_, _));
    polars_ensure!(
        is_array,
        SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", haystack.dtype(),
    );
    let mut out = is_in(
        needle.as_materialized_series(),
        haystack.as_materialized_series(),
        nulls_equal,
    )?;
    out.rename(haystack.name().clone());
    Ok(out.into_column())
}
310
/// Counts, per row, the matches of a single value (`args[1]`) in the arrays of `args[0]`.
///
/// # Errors
/// Returns `ComputeError` unless `args[1]` has exactly one element, and propagates errors
/// from reading that element or from `array_count_matches`.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let s = &args[0];
    let element = &args[1];
    polars_ensure!(
        element.len() == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        element.len()
    );
    let ca = s.array()?;
    // `get` returns a `PolarsResult`; propagate it instead of unwrapping.
    let value = element.get(0)?;
    ca.array_count_matches(value).map(Column::from)
}
324
325pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
326 let ca = s[0].array()?;
327 let n = &s[1];
328
329 ca.array_shift(n.as_materialized_series()).map(Column::from)
330}
331
332fn explode(c: &[Column], skip_empty: bool) -> PolarsResult<Column> {
333 c[0].explode(skip_empty)
334}
335
336fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
337 let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
338
339 polars_ops::series::concat_arr::concat_arr(args, &dtype)
340}
341
342fn concat_arr_output_dtype(
345 inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
346) -> PolarsResult<DataType> {
347 #[allow(clippy::len_zero)]
348 if inputs.len() == 0 {
349 panic!();
351 }
352
353 let mut inputs = inputs.map(|(name, dtype)| {
354 let (inner_dtype, width) = match dtype {
355 DataType::Array(inner, width) => (inner.as_ref(), *width),
356 dt => (dt, 1),
357 };
358 (name, dtype, inner_dtype, width)
359 });
360 let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
361
362 for (col_name, dtype, inner_dtype, width) in inputs {
363 out_width += width;
364
365 if inner_dtype != first_inner_dtype {
366 polars_bail!(
367 SchemaMismatch:
368 "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
369 input column (name: {}, dtype: {}), got {} instead for column {}",
370 first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
371 )
372 }
373 }
374
375 Ok(DataType::Array(
376 Box::new(first_inner_dtype.clone()),
377 out_width,
378 ))
379}