polars_plan/plans/aexpr/function_expr/
array.rs

1use polars_core::utils::slice_offsets;
2use polars_ops::chunked_array::array::*;
3
4use super::*;
5use crate::{map, map_as_slice};
6
7#[derive(Clone, Eq, PartialEq, Hash, Debug)]
8#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
9pub enum IRArrayFunction {
10    Length,
11    Min,
12    Max,
13    Sum,
14    ToList,
15    Unique(bool),
16    NUnique,
17    Std(u8),
18    Var(u8),
19    Mean,
20    Median,
21    #[cfg(feature = "array_any_all")]
22    Any,
23    #[cfg(feature = "array_any_all")]
24    All,
25    Sort(SortOptions),
26    Reverse,
27    ArgMin,
28    ArgMax,
29    Get(bool),
30    Join(bool),
31    #[cfg(feature = "is_in")]
32    Contains {
33        nulls_equal: bool,
34    },
35    #[cfg(feature = "array_count")]
36    CountMatches,
37    Shift,
38    Explode {
39        skip_empty: bool,
40    },
41    Concat,
42    Slice(i64, i64),
43    #[cfg(feature = "array_to_struct")]
44    ToStruct(Option<DslNameGenerator>),
45}
46
47impl IRArrayFunction {
48    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
49        use IRArrayFunction::*;
50        match self {
51            Concat => Ok(Field::new(
52                mapper
53                    .args()
54                    .first()
55                    .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
56                concat_arr_output_dtype(
57                    &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
58                )?,
59            )),
60            Length => mapper.with_dtype(IDX_DTYPE),
61            Min | Max => mapper.map_to_list_and_array_inner_dtype(),
62            Sum => mapper.nested_sum_type(),
63            ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
64            Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
65            NUnique => mapper.with_dtype(IDX_DTYPE),
66            Std(_) => mapper.moment_dtype(),
67            Var(_) => mapper.var_dtype(),
68            Mean => mapper.moment_dtype(),
69            Median => mapper.moment_dtype(),
70            #[cfg(feature = "array_any_all")]
71            Any | All => mapper.with_dtype(DataType::Boolean),
72            Sort(_) => mapper.with_same_dtype(),
73            Reverse => mapper.with_same_dtype(),
74            ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
75            Get(_) => mapper.map_to_list_and_array_inner_dtype(),
76            Join(_) => mapper.with_dtype(DataType::String),
77            #[cfg(feature = "is_in")]
78            Contains { nulls_equal: _ } => mapper.with_dtype(DataType::Boolean),
79            #[cfg(feature = "array_count")]
80            CountMatches => mapper.with_dtype(IDX_DTYPE),
81            Shift => mapper.with_same_dtype(),
82            Explode { .. } => mapper.try_map_to_array_inner_dtype(),
83            Slice(offset, length) => {
84                mapper.try_map_dtype(map_to_array_fixed_length(offset, length))
85            },
86            #[cfg(feature = "array_to_struct")]
87            ToStruct(name_generator) => mapper.try_map_dtype(|dtype| {
88                let DataType::Array(inner, width) = dtype else {
89                    polars_bail!(InvalidOperation: "expected Array type, got: {dtype}")
90                };
91
92                (0..*width)
93                    .map(|i| {
94                        let name = match name_generator {
95                            None => arr_default_struct_name_gen(i),
96                            Some(ng) => PlSmallStr::from_string(ng.call(i)?),
97                        };
98                        Ok(Field::new(name, inner.as_ref().clone()))
99                    })
100                    .collect::<PolarsResult<Vec<Field>>>()
101                    .map(DataType::Struct)
102            }),
103        }
104    }
105
106    pub fn function_options(&self) -> FunctionOptions {
107        use IRArrayFunction as A;
108        match self {
109            #[cfg(feature = "array_any_all")]
110            A::Any | A::All => FunctionOptions::elementwise(),
111            #[cfg(feature = "is_in")]
112            A::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),
113            #[cfg(feature = "array_count")]
114            A::CountMatches => FunctionOptions::elementwise(),
115            A::Concat => FunctionOptions::elementwise()
116                .with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
117            A::Length
118            | A::Min
119            | A::Max
120            | A::Sum
121            | A::ToList
122            | A::Unique(_)
123            | A::NUnique
124            | A::Std(_)
125            | A::Var(_)
126            | A::Mean
127            | A::Median
128            | A::Sort(_)
129            | A::Reverse
130            | A::ArgMin
131            | A::ArgMax
132            | A::Get(_)
133            | A::Join(_)
134            | A::Shift
135            | A::Slice(_, _) => FunctionOptions::elementwise(),
136            A::Explode { .. } => FunctionOptions::row_separable(),
137            #[cfg(feature = "array_to_struct")]
138            A::ToStruct(_) => FunctionOptions::elementwise(),
139        }
140    }
141}
142
143fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
144    if let DataType::Array(inner, _) = datatype {
145        Ok(DataType::List(inner.clone()))
146    } else {
147        polars_bail!(ComputeError: "expected array dtype")
148    }
149}
150
151fn map_to_array_fixed_length(
152    offset: &i64,
153    length: &i64,
154) -> impl FnOnce(&DataType) -> PolarsResult<DataType> {
155    move |datatype: &DataType| {
156        if let DataType::Array(inner, array_len) = datatype {
157            let length: usize = if *length < 0 {
158                (*array_len as i64 + *length).max(0)
159            } else {
160                *length
161            }.try_into().map_err(|_| {
162                polars_err!(OutOfBounds: "length must be a non-negative integer, got: {}", length)
163            })?;
164            let (_, slice_offset) = slice_offsets(*offset, length, *array_len);
165            Ok(DataType::Array(inner.clone(), slice_offset))
166        } else {
167            polars_bail!(ComputeError: "expected array dtype, got {}", datatype);
168        }
169    }
170}
171
172impl Display for IRArrayFunction {
173    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174        use IRArrayFunction::*;
175        let name = match self {
176            Concat => "concat",
177            Length => "length",
178            Min => "min",
179            Max => "max",
180            Sum => "sum",
181            ToList => "to_list",
182            Unique(_) => "unique",
183            NUnique => "n_unique",
184            Std(_) => "std",
185            Var(_) => "var",
186            Mean => "mean",
187            Median => "median",
188            #[cfg(feature = "array_any_all")]
189            Any => "any",
190            #[cfg(feature = "array_any_all")]
191            All => "all",
192            Sort(_) => "sort",
193            Reverse => "reverse",
194            ArgMin => "arg_min",
195            ArgMax => "arg_max",
196            Get(_) => "get",
197            Join(_) => "join",
198            #[cfg(feature = "is_in")]
199            Contains { nulls_equal: _ } => "contains",
200            #[cfg(feature = "array_count")]
201            CountMatches => "count_matches",
202            Shift => "shift",
203            Slice(_, _) => "slice",
204            Explode { .. } => "explode",
205            #[cfg(feature = "array_to_struct")]
206            ToStruct(_) => "to_struct",
207        };
208        write!(f, "arr.{name}")
209    }
210}
211
212impl From<IRArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
213    fn from(func: IRArrayFunction) -> Self {
214        use IRArrayFunction::*;
215        match func {
216            Concat => map_as_slice!(concat_arr),
217            Length => map!(length),
218            Min => map!(min),
219            Max => map!(max),
220            Sum => map!(sum),
221            ToList => map!(to_list),
222            Unique(stable) => map!(unique, stable),
223            NUnique => map!(n_unique),
224            Std(ddof) => map!(std, ddof),
225            Var(ddof) => map!(var, ddof),
226            Mean => map!(mean),
227            Median => map!(median),
228            #[cfg(feature = "array_any_all")]
229            Any => map!(any),
230            #[cfg(feature = "array_any_all")]
231            All => map!(all),
232            Sort(options) => map!(sort, options),
233            Reverse => map!(reverse),
234            ArgMin => map!(arg_min),
235            ArgMax => map!(arg_max),
236            Get(null_on_oob) => map_as_slice!(get, null_on_oob),
237            Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
238            #[cfg(feature = "is_in")]
239            Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
240            #[cfg(feature = "array_count")]
241            CountMatches => map_as_slice!(count_matches),
242            Shift => map_as_slice!(shift),
243            Explode { skip_empty } => map_as_slice!(explode, skip_empty),
244            Slice(offset, length) => map!(slice, offset, length),
245            #[cfg(feature = "array_to_struct")]
246            ToStruct(ng) => map!(arr_to_struct, ng.clone()),
247        }
248    }
249}
250
251pub(super) fn length(s: &Column) -> PolarsResult<Column> {
252    let array = s.array()?;
253    let width = array.width();
254    let width = IdxSize::try_from(width)
255        .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;
256
257    let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
258    if let Some(validity) = array.rechunk_validity() {
259        let mut series = c.into_materialized_series().clone();
260
261        // SAFETY: We keep datatypes intact and call compute_len afterwards.
262        let chunks = unsafe { series.chunks_mut() };
263        assert_eq!(chunks.len(), 1);
264
265        chunks[0] = chunks[0].with_validity(Some(validity));
266
267        series.compute_len();
268        c = series.into_column();
269    }
270
271    Ok(c)
272}
273
274pub(super) fn max(s: &Column) -> PolarsResult<Column> {
275    Ok(s.array()?.array_max().into())
276}
277
278pub(super) fn min(s: &Column) -> PolarsResult<Column> {
279    Ok(s.array()?.array_min().into())
280}
281
282pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
283    s.array()?.array_sum().map(Column::from)
284}
285
286pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
287    s.array()?.array_std(ddof).map(Column::from)
288}
289
290pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
291    s.array()?.array_var(ddof).map(Column::from)
292}
293
294pub(super) fn mean(s: &Column) -> PolarsResult<Column> {
295    s.array()?.array_mean().map(Column::from)
296}
297
298pub(super) fn median(s: &Column) -> PolarsResult<Column> {
299    s.array()?.array_median().map(Column::from)
300}
301
302pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
303    let ca = s.array()?;
304    let out = if stable {
305        ca.array_unique_stable()
306    } else {
307        ca.array_unique()
308    };
309    out.map(|ca| ca.into_column())
310}
311
312pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
313    Ok(s.array()?.array_n_unique()?.into_column())
314}
315
316pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
317    let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?;
318    s.cast(&list_dtype)
319}
320
/// Applies `array_any` to the array column `s` and returns the result as a [`Column`].
///
/// # Errors
/// Fails if `s.array()` or `array_any` fails.
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let flags = s.array()?.array_any()?;
    Ok(Column::from(flags))
}
325
/// Applies `array_all` to the array column `s` and returns the result as a [`Column`].
///
/// # Errors
/// Fails if `s.array()` or `array_all` fails.
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let flags = s.array()?.array_all()?;
    Ok(Column::from(flags))
}
330
331pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
332    Ok(s.array()?.array_sort(options)?.into_column())
333}
334
335pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
336    Ok(s.array()?.array_reverse().into_column())
337}
338
339pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
340    Ok(s.array()?.array_arg_min().into_column())
341}
342
343pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
344    Ok(s.array()?.array_arg_max().into_column())
345}
346
347pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
348    let ca = s[0].array()?;
349    let index = s[1].cast(&DataType::Int64)?;
350    let index = index.i64().unwrap();
351    ca.array_get(index, null_on_oob).map(Column::from)
352}
353
354pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
355    let ca = s[0].array()?;
356    let separator = s[1].str()?;
357    ca.array_join(separator, ignore_nulls).map(Column::from)
358}
359
/// Runs `is_in` with `s[1]` as the items and the array column `s[0]` as the container.
///
/// The result is renamed to the name of `s[0]`. `nulls_equal` is forwarded to `is_in`.
///
/// # Errors
/// Returns `SchemaMismatch` if `s[0]` is not of `Array` dtype, and propagates errors from
/// `is_in`.
///
/// # Panics
/// Panics if `s` holds fewer than two columns.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column], nulls_equal: bool) -> PolarsResult<Column> {
    let (array, item) = (&s[0], &s[1]);
    if !matches!(array.dtype(), DataType::Array(_, _)) {
        polars_bail!(
            SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", array.dtype(),
        );
    }

    let haystack = array.as_materialized_series();
    let needles = item.as_materialized_series();
    let mut mask = is_in(needles, haystack, nulls_equal)?;
    mask.rename(array.name().clone());
    Ok(mask.into_column())
}
375
/// Applies `array_count_matches` to the array column `args[0]`, using the single value held
/// by `args[1]` as the element.
///
/// # Errors
/// Returns `ComputeError` if `args[1]` does not hold exactly one element, and propagates
/// errors from `args[0].array()` and `array_count_matches`.
///
/// # Panics
/// Panics if `args` holds fewer than two columns.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let (haystack, element) = (&args[0], &args[1]);
    if element.len() != 1 {
        polars_bail!(
            ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
            element.len()
        );
    }

    let arrays = haystack.array()?;
    // `element` has exactly one value, so index 0 is always in bounds.
    let needle = element.get(0).unwrap();
    let counts = arrays.array_count_matches(needle)?;
    Ok(Column::from(counts))
}
389
390pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
391    let ca = s[0].array()?;
392    let n = &s[1];
393
394    ca.array_shift(n.as_materialized_series()).map(Column::from)
395}
396
397pub(super) fn slice(s: &Column, offset: i64, length: i64) -> PolarsResult<Column> {
398    let ca = s.array()?;
399    ca.array_slice(offset, length).map(Column::from)
400}
401
402fn explode(c: &[Column], skip_empty: bool) -> PolarsResult<Column> {
403    c[0].explode(skip_empty)
404}
405
406fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
407    let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
408
409    polars_ops::series::concat_arr::concat_arr(args, &dtype)
410}
411
412/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input
413/// dtypes are compatible.
414fn concat_arr_output_dtype(
415    inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
416) -> PolarsResult<DataType> {
417    #[allow(clippy::len_zero)]
418    if inputs.len() == 0 {
419        // should not be reachable - we did not set ALLOW_EMPTY_INPUTS
420        panic!();
421    }
422
423    let mut inputs = inputs.map(|(name, dtype)| {
424        let (inner_dtype, width) = match dtype {
425            DataType::Array(inner, width) => (inner.as_ref(), *width),
426            dt => (dt, 1),
427        };
428        (name, dtype, inner_dtype, width)
429    });
430    let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
431
432    for (col_name, dtype, inner_dtype, width) in inputs {
433        out_width += width;
434
435        if inner_dtype != first_inner_dtype {
436            polars_bail!(
437                SchemaMismatch:
438                "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
439                input column (name: {}, dtype: {}), got {} instead for column {}",
440                first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
441            )
442        }
443    }
444
445    Ok(DataType::Array(
446        Box::new(first_inner_dtype.clone()),
447        out_width,
448    ))
449}
450
/// Converts the array column `s` into a struct column via `to_struct`.
///
/// When a `name_generator` is given it is wrapped in an `Arc`'d closure that converts each
/// generated name into a `PlSmallStr`; otherwise `None` is passed to `to_struct`.
///
/// # Errors
/// Fails if `s.array()` or `to_struct` fails.
#[cfg(feature = "array_to_struct")]
fn arr_to_struct(s: &Column, name_generator: Option<DslNameGenerator>) -> PolarsResult<Column> {
    let field_namer = name_generator.map(|generator| {
        Arc::new(move |idx| generator.call(idx).map(PlSmallStr::from)) as Arc<_>
    });
    let structs = s.array()?.to_struct(field_namer)?;
    Ok(structs.into_column())
}