1use polars_core::utils::slice_offsets;
2use polars_ops::chunked_array::array::*;
3
4use super::*;
5use crate::{map, map_as_slice};
6
/// IR representation of the array-namespace (`arr.*`) functions.
///
/// Each variant is turned into an executable UDF by the `From<IRArrayFunction>` impl in this
/// module, and its output field is resolved by `IRArrayFunction::get_field`.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IRArrayFunction {
    /// Width of each array, returned as an `IdxSize` column (null where the array is null).
    Length,
    /// Minimum of each array (dispatches to `array_min`).
    Min,
    /// Maximum of each array (dispatches to `array_max`).
    Max,
    /// Sum of each array (dispatches to `array_sum`).
    Sum,
    /// Casts the array column to a list column with the same inner dtype.
    ToList,
    /// Distinct values of each array; `true` selects `array_unique_stable`, `false` `array_unique`.
    Unique(bool),
    /// Number of distinct values in each array.
    NUnique,
    /// Standard deviation of each array; the payload is the `ddof` passed to `array_std`.
    Std(u8),
    /// Variance of each array; the payload is the `ddof` passed to `array_var`.
    Var(u8),
    /// Mean of each array (dispatches to `array_mean`).
    Mean,
    /// Median of each array (dispatches to `array_median`).
    Median,
    /// Boolean `any` over each array (dispatches to `array_any`).
    #[cfg(feature = "array_any_all")]
    Any,
    /// Boolean `all` over each array (dispatches to `array_all`).
    #[cfg(feature = "array_any_all")]
    All,
    /// Sorts the elements of each array with the given options.
    Sort(SortOptions),
    /// Reverses the elements of each array.
    Reverse,
    /// Index of the minimum element of each array.
    ArgMin,
    /// Index of the maximum element of each array.
    ArgMax,
    /// Element at an index taken from a second input (cast to `Int64`); the flag is `null_on_oob`.
    Get(bool),
    /// Joins elements using a string separator from a second input; the flag is `ignore_nulls`.
    Join(bool),
    /// Whether each array contains the item given by a second input (implemented via `is_in`).
    #[cfg(feature = "is_in")]
    Contains {
        /// Forwarded to `is_in`; controls whether nulls compare equal.
        nulls_equal: bool,
    },
    /// Counts matches of a single-element second input within each array.
    #[cfg(feature = "array_count")]
    CountMatches,
    /// Shifts the elements of each array by an amount taken from a second input.
    Shift,
    /// Explodes the array column into its elements (row-separable, not elementwise).
    Explode {
        /// Forwarded to `Column::explode`.
        skip_empty: bool,
    },
    /// Concatenates all inputs into one array column; non-array inputs count as width 1.
    Concat,
    /// Slices each array with `(offset, length)`.
    Slice(i64, i64),
    /// Converts each array into a struct; `None` uses `arr_default_struct_name_gen` for field names.
    #[cfg(feature = "array_to_struct")]
    ToStruct(Option<DslNameGenerator>),
}
46
47impl IRArrayFunction {
48 pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
49 use IRArrayFunction::*;
50 match self {
51 Concat => Ok(Field::new(
52 mapper
53 .args()
54 .first()
55 .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
56 concat_arr_output_dtype(
57 &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
58 )?,
59 )),
60 Length => mapper.with_dtype(IDX_DTYPE),
61 Min | Max => mapper.map_to_list_and_array_inner_dtype(),
62 Sum => mapper.nested_sum_type(),
63 ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
64 Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
65 NUnique => mapper.with_dtype(IDX_DTYPE),
66 Std(_) => mapper.moment_dtype(),
67 Var(_) => mapper.var_dtype(),
68 Mean => mapper.moment_dtype(),
69 Median => mapper.moment_dtype(),
70 #[cfg(feature = "array_any_all")]
71 Any | All => mapper.with_dtype(DataType::Boolean),
72 Sort(_) => mapper.with_same_dtype(),
73 Reverse => mapper.with_same_dtype(),
74 ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
75 Get(_) => mapper.map_to_list_and_array_inner_dtype(),
76 Join(_) => mapper.with_dtype(DataType::String),
77 #[cfg(feature = "is_in")]
78 Contains { nulls_equal: _ } => mapper.with_dtype(DataType::Boolean),
79 #[cfg(feature = "array_count")]
80 CountMatches => mapper.with_dtype(IDX_DTYPE),
81 Shift => mapper.with_same_dtype(),
82 Explode { .. } => mapper.try_map_to_array_inner_dtype(),
83 Slice(offset, length) => {
84 mapper.try_map_dtype(map_to_array_fixed_length(offset, length))
85 },
86 #[cfg(feature = "array_to_struct")]
87 ToStruct(name_generator) => mapper.try_map_dtype(|dtype| {
88 let DataType::Array(inner, width) = dtype else {
89 polars_bail!(InvalidOperation: "expected Array type, got: {dtype}")
90 };
91
92 (0..*width)
93 .map(|i| {
94 let name = match name_generator {
95 None => arr_default_struct_name_gen(i),
96 Some(ng) => PlSmallStr::from_string(ng.call(i)?),
97 };
98 Ok(Field::new(name, inner.as_ref().clone()))
99 })
100 .collect::<PolarsResult<Vec<Field>>>()
101 .map(DataType::Struct)
102 }),
103 }
104 }
105
106 pub fn function_options(&self) -> FunctionOptions {
107 use IRArrayFunction as A;
108 match self {
109 #[cfg(feature = "array_any_all")]
110 A::Any | A::All => FunctionOptions::elementwise(),
111 #[cfg(feature = "is_in")]
112 A::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),
113 #[cfg(feature = "array_count")]
114 A::CountMatches => FunctionOptions::elementwise(),
115 A::Concat => FunctionOptions::elementwise()
116 .with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
117 A::Length
118 | A::Min
119 | A::Max
120 | A::Sum
121 | A::ToList
122 | A::Unique(_)
123 | A::NUnique
124 | A::Std(_)
125 | A::Var(_)
126 | A::Mean
127 | A::Median
128 | A::Sort(_)
129 | A::Reverse
130 | A::ArgMin
131 | A::ArgMax
132 | A::Get(_)
133 | A::Join(_)
134 | A::Shift
135 | A::Slice(_, _) => FunctionOptions::elementwise(),
136 A::Explode { .. } => FunctionOptions::row_separable(),
137 #[cfg(feature = "array_to_struct")]
138 A::ToStruct(_) => FunctionOptions::elementwise(),
139 }
140 }
141}
142
143fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
144 if let DataType::Array(inner, _) = datatype {
145 Ok(DataType::List(inner.clone()))
146 } else {
147 polars_bail!(ComputeError: "expected array dtype")
148 }
149}
150
151fn map_to_array_fixed_length(
152 offset: &i64,
153 length: &i64,
154) -> impl FnOnce(&DataType) -> PolarsResult<DataType> {
155 move |datatype: &DataType| {
156 if let DataType::Array(inner, array_len) = datatype {
157 let length: usize = if *length < 0 {
158 (*array_len as i64 + *length).max(0)
159 } else {
160 *length
161 }.try_into().map_err(|_| {
162 polars_err!(OutOfBounds: "length must be a non-negative integer, got: {}", length)
163 })?;
164 let (_, slice_offset) = slice_offsets(*offset, length, *array_len);
165 Ok(DataType::Array(inner.clone(), slice_offset))
166 } else {
167 polars_bail!(ComputeError: "expected array dtype, got {}", datatype);
168 }
169 }
170}
171
172impl Display for IRArrayFunction {
173 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174 use IRArrayFunction::*;
175 let name = match self {
176 Concat => "concat",
177 Length => "length",
178 Min => "min",
179 Max => "max",
180 Sum => "sum",
181 ToList => "to_list",
182 Unique(_) => "unique",
183 NUnique => "n_unique",
184 Std(_) => "std",
185 Var(_) => "var",
186 Mean => "mean",
187 Median => "median",
188 #[cfg(feature = "array_any_all")]
189 Any => "any",
190 #[cfg(feature = "array_any_all")]
191 All => "all",
192 Sort(_) => "sort",
193 Reverse => "reverse",
194 ArgMin => "arg_min",
195 ArgMax => "arg_max",
196 Get(_) => "get",
197 Join(_) => "join",
198 #[cfg(feature = "is_in")]
199 Contains { nulls_equal: _ } => "contains",
200 #[cfg(feature = "array_count")]
201 CountMatches => "count_matches",
202 Shift => "shift",
203 Slice(_, _) => "slice",
204 Explode { .. } => "explode",
205 #[cfg(feature = "array_to_struct")]
206 ToStruct(_) => "to_struct",
207 };
208 write!(f, "arr.{name}")
209 }
210}
211
212impl From<IRArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
213 fn from(func: IRArrayFunction) -> Self {
214 use IRArrayFunction::*;
215 match func {
216 Concat => map_as_slice!(concat_arr),
217 Length => map!(length),
218 Min => map!(min),
219 Max => map!(max),
220 Sum => map!(sum),
221 ToList => map!(to_list),
222 Unique(stable) => map!(unique, stable),
223 NUnique => map!(n_unique),
224 Std(ddof) => map!(std, ddof),
225 Var(ddof) => map!(var, ddof),
226 Mean => map!(mean),
227 Median => map!(median),
228 #[cfg(feature = "array_any_all")]
229 Any => map!(any),
230 #[cfg(feature = "array_any_all")]
231 All => map!(all),
232 Sort(options) => map!(sort, options),
233 Reverse => map!(reverse),
234 ArgMin => map!(arg_min),
235 ArgMax => map!(arg_max),
236 Get(null_on_oob) => map_as_slice!(get, null_on_oob),
237 Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
238 #[cfg(feature = "is_in")]
239 Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
240 #[cfg(feature = "array_count")]
241 CountMatches => map_as_slice!(count_matches),
242 Shift => map_as_slice!(shift),
243 Explode { skip_empty } => map_as_slice!(explode, skip_empty),
244 Slice(offset, length) => map!(slice, offset, length),
245 #[cfg(feature = "array_to_struct")]
246 ToStruct(ng) => map!(arr_to_struct, ng.clone()),
247 }
248 }
249}
250
/// Returns the (fixed) width of every array in `s` as an `IdxSize` column.
///
/// Without a validity mask the result stays a cheap scalar column. With one, the column is
/// materialized and the input's validity is attached, so null arrays yield a null length.
///
/// # Errors
/// Errors if `s` is not an array column, or if the width does not fit in `IdxSize`.
pub(super) fn length(s: &Column) -> PolarsResult<Column> {
    let array = s.array()?;
    let width = array.width();
    let width = IdxSize::try_from(width)
        .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;

    // Every row has the same width, so start from a scalar column of `array.len()` rows.
    let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
    // Only materialize when the input exposes a validity mask (presumably `None` when
    // there are no nulls — confirm against `rechunk_validity`).
    if let Some(validity) = array.rechunk_validity() {
        let mut series = c.into_materialized_series().clone();

        // SAFETY (review note, confirm): the single chunk is replaced by the same array with a
        // validity bitmap taken from `array`, assumed to be `array.len()` long. The dtype is
        // unchanged, and `compute_len` below refreshes the cached metadata.
        let chunks = unsafe { series.chunks_mut() };
        // A materialized scalar column is expected to consist of exactly one chunk.
        assert_eq!(chunks.len(), 1);

        chunks[0] = chunks[0].with_validity(Some(validity));

        // Chunks were mutated directly, so recompute the cached length metadata.
        series.compute_len();
        c = series.into_column();
    }

    Ok(c)
}
273
274pub(super) fn max(s: &Column) -> PolarsResult<Column> {
275 Ok(s.array()?.array_max().into())
276}
277
278pub(super) fn min(s: &Column) -> PolarsResult<Column> {
279 Ok(s.array()?.array_min().into())
280}
281
282pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
283 s.array()?.array_sum().map(Column::from)
284}
285
286pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
287 s.array()?.array_std(ddof).map(Column::from)
288}
289
290pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
291 s.array()?.array_var(ddof).map(Column::from)
292}
293
294pub(super) fn mean(s: &Column) -> PolarsResult<Column> {
295 s.array()?.array_mean().map(Column::from)
296}
297
298pub(super) fn median(s: &Column) -> PolarsResult<Column> {
299 s.array()?.array_median().map(Column::from)
300}
301
302pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
303 let ca = s.array()?;
304 let out = if stable {
305 ca.array_unique_stable()
306 } else {
307 ca.array_unique()
308 };
309 out.map(|ca| ca.into_column())
310}
311
312pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
313 Ok(s.array()?.array_n_unique()?.into_column())
314}
315
316pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
317 let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?;
318 s.cast(&list_dtype)
319}
320
/// Boolean `any` over each array in `s` (via `array_any`).
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let flags = s.array()?.array_any()?;
    Ok(Column::from(flags))
}
325
/// Boolean `all` over each array in `s` (via `array_all`).
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let flags = s.array()?.array_all()?;
    Ok(Column::from(flags))
}
330
331pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
332 Ok(s.array()?.array_sort(options)?.into_column())
333}
334
335pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
336 Ok(s.array()?.array_reverse().into_column())
337}
338
339pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
340 Ok(s.array()?.array_arg_min().into_column())
341}
342
343pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
344 Ok(s.array()?.array_arg_max().into_column())
345}
346
347pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
348 let ca = s[0].array()?;
349 let index = s[1].cast(&DataType::Int64)?;
350 let index = index.i64().unwrap();
351 ca.array_get(index, null_on_oob).map(Column::from)
352}
353
354pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
355 let ca = s[0].array()?;
356 let separator = s[1].str()?;
357 ca.array_join(separator, ignore_nulls).map(Column::from)
358}
359
/// Whether each array in `s[0]` contains the corresponding item from `s[1]`.
///
/// Implemented with `is_in(item, array, nulls_equal)`; the result is named after the
/// array column.
///
/// # Errors
/// `SchemaMismatch` if `s[0]` is not an array column.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column], nulls_equal: bool) -> PolarsResult<Column> {
    let (array, item) = (&s[0], &s[1]);
    let is_array = matches!(array.dtype(), DataType::Array(..));
    polars_ensure!(is_array,
        SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", array.dtype(),
    );

    let mut found = is_in(
        item.as_materialized_series(),
        array.as_materialized_series(),
        nulls_equal,
    )?;
    found.rename(array.name().clone());
    Ok(found.into_column())
}
375
/// Counts, per array in `args[0]`, how many elements match the single value in `args[1]`.
///
/// # Errors
/// `ComputeError` if `args[1]` does not have exactly one element, or if `args[0]` is not
/// an array column.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let (s, element) = (&args[0], &args[1]);
    let n_elements = element.len();
    polars_ensure!(
        n_elements == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        n_elements
    );
    let ca = s.array()?;
    // Length was checked to be exactly 1 above, so index 0 exists.
    let needle = element.get(0).unwrap();
    let counts = ca.array_count_matches(needle)?;
    Ok(Column::from(counts))
}
389
390pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
391 let ca = s[0].array()?;
392 let n = &s[1];
393
394 ca.array_shift(n.as_materialized_series()).map(Column::from)
395}
396
397pub(super) fn slice(s: &Column, offset: i64, length: i64) -> PolarsResult<Column> {
398 let ca = s.array()?;
399 ca.array_slice(offset, length).map(Column::from)
400}
401
402fn explode(c: &[Column], skip_empty: bool) -> PolarsResult<Column> {
403 c[0].explode(skip_empty)
404}
405
406fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
407 let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
408
409 polars_ops::series::concat_arr::concat_arr(args, &dtype)
410}
411
412fn concat_arr_output_dtype(
415 inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
416) -> PolarsResult<DataType> {
417 #[allow(clippy::len_zero)]
418 if inputs.len() == 0 {
419 panic!();
421 }
422
423 let mut inputs = inputs.map(|(name, dtype)| {
424 let (inner_dtype, width) = match dtype {
425 DataType::Array(inner, width) => (inner.as_ref(), *width),
426 dt => (dt, 1),
427 };
428 (name, dtype, inner_dtype, width)
429 });
430 let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
431
432 for (col_name, dtype, inner_dtype, width) in inputs {
433 out_width += width;
434
435 if inner_dtype != first_inner_dtype {
436 polars_bail!(
437 SchemaMismatch:
438 "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
439 input column (name: {}, dtype: {}), got {} instead for column {}",
440 first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
441 )
442 }
443 }
444
445 Ok(DataType::Array(
446 Box::new(first_inner_dtype.clone()),
447 out_width,
448 ))
449}
450
451#[cfg(feature = "array_to_struct")]
452fn arr_to_struct(s: &Column, name_generator: Option<DslNameGenerator>) -> PolarsResult<Column> {
453 let name_generator =
454 name_generator.map(|f| Arc::new(move |i| f.call(i).map(PlSmallStr::from)) as Arc<_>);
455 s.array()?
456 .to_struct(name_generator)
457 .map(IntoColumn::into_column)
458}