polars_plan/plans/aexpr/function_expr/
array.rs

1use polars_ops::chunked_array::array::*;
2
3use super::*;
4use crate::{map, map_as_slice};
5
6#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
7#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
8pub enum IRArrayFunction {
9    Length,
10    Min,
11    Max,
12    Sum,
13    ToList,
14    Unique(bool),
15    NUnique,
16    Std(u8),
17    Var(u8),
18    Median,
19    #[cfg(feature = "array_any_all")]
20    Any,
21    #[cfg(feature = "array_any_all")]
22    All,
23    Sort(SortOptions),
24    Reverse,
25    ArgMin,
26    ArgMax,
27    Get(bool),
28    Join(bool),
29    #[cfg(feature = "is_in")]
30    Contains {
31        nulls_equal: bool,
32    },
33    #[cfg(feature = "array_count")]
34    CountMatches,
35    Shift,
36    Explode {
37        skip_empty: bool,
38    },
39    Concat,
40}
41
42impl IRArrayFunction {
43    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
44        use IRArrayFunction::*;
45        match self {
46            Concat => Ok(Field::new(
47                mapper
48                    .args()
49                    .first()
50                    .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
51                concat_arr_output_dtype(
52                    &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
53                )?,
54            )),
55            Length => mapper.with_dtype(IDX_DTYPE),
56            Min | Max => mapper.map_to_list_and_array_inner_dtype(),
57            Sum => mapper.nested_sum_type(),
58            ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
59            Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
60            NUnique => mapper.with_dtype(IDX_DTYPE),
61            Std(_) => mapper.map_to_float_dtype(),
62            Var(_) => mapper.map_to_float_dtype(),
63            Median => mapper.map_to_float_dtype(),
64            #[cfg(feature = "array_any_all")]
65            Any | All => mapper.with_dtype(DataType::Boolean),
66            Sort(_) => mapper.with_same_dtype(),
67            Reverse => mapper.with_same_dtype(),
68            ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
69            Get(_) => mapper.map_to_list_and_array_inner_dtype(),
70            Join(_) => mapper.with_dtype(DataType::String),
71            #[cfg(feature = "is_in")]
72            Contains { nulls_equal: _ } => mapper.with_dtype(DataType::Boolean),
73            #[cfg(feature = "array_count")]
74            CountMatches => mapper.with_dtype(IDX_DTYPE),
75            Shift => mapper.with_same_dtype(),
76            Explode { .. } => mapper.try_map_to_array_inner_dtype(),
77        }
78    }
79
80    pub fn function_options(&self) -> FunctionOptions {
81        use IRArrayFunction as A;
82        match self {
83            #[cfg(feature = "array_any_all")]
84            A::Any | A::All => FunctionOptions::elementwise(),
85            #[cfg(feature = "is_in")]
86            A::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),
87            #[cfg(feature = "array_count")]
88            A::CountMatches => FunctionOptions::elementwise(),
89            A::Length
90            | A::Min
91            | A::Max
92            | A::Sum
93            | A::ToList
94            | A::Unique(_)
95            | A::NUnique
96            | A::Std(_)
97            | A::Var(_)
98            | A::Median
99            | A::Sort(_)
100            | A::Reverse
101            | A::ArgMin
102            | A::ArgMax
103            | A::Concat
104            | A::Get(_)
105            | A::Join(_)
106            | A::Shift => FunctionOptions::elementwise(),
107            A::Explode { .. } => FunctionOptions::row_separable(),
108        }
109    }
110}
111
112fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
113    if let DataType::Array(inner, _) = datatype {
114        Ok(DataType::List(inner.clone()))
115    } else {
116        polars_bail!(ComputeError: "expected array dtype")
117    }
118}
119
120impl Display for IRArrayFunction {
121    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122        use IRArrayFunction::*;
123        let name = match self {
124            Concat => "concat",
125            Length => "length",
126            Min => "min",
127            Max => "max",
128            Sum => "sum",
129            ToList => "to_list",
130            Unique(_) => "unique",
131            NUnique => "n_unique",
132            Std(_) => "std",
133            Var(_) => "var",
134            Median => "median",
135            #[cfg(feature = "array_any_all")]
136            Any => "any",
137            #[cfg(feature = "array_any_all")]
138            All => "all",
139            Sort(_) => "sort",
140            Reverse => "reverse",
141            ArgMin => "arg_min",
142            ArgMax => "arg_max",
143            Get(_) => "get",
144            Join(_) => "join",
145            #[cfg(feature = "is_in")]
146            Contains { nulls_equal: _ } => "contains",
147            #[cfg(feature = "array_count")]
148            CountMatches => "count_matches",
149            Shift => "shift",
150            Explode { .. } => "explode",
151        };
152        write!(f, "arr.{name}")
153    }
154}
155
156impl From<IRArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
157    fn from(func: IRArrayFunction) -> Self {
158        use IRArrayFunction::*;
159        match func {
160            Concat => map_as_slice!(concat_arr),
161            Length => map!(length),
162            Min => map!(min),
163            Max => map!(max),
164            Sum => map!(sum),
165            ToList => map!(to_list),
166            Unique(stable) => map!(unique, stable),
167            NUnique => map!(n_unique),
168            Std(ddof) => map!(std, ddof),
169            Var(ddof) => map!(var, ddof),
170            Median => map!(median),
171            #[cfg(feature = "array_any_all")]
172            Any => map!(any),
173            #[cfg(feature = "array_any_all")]
174            All => map!(all),
175            Sort(options) => map!(sort, options),
176            Reverse => map!(reverse),
177            ArgMin => map!(arg_min),
178            ArgMax => map!(arg_max),
179            Get(null_on_oob) => map_as_slice!(get, null_on_oob),
180            Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
181            #[cfg(feature = "is_in")]
182            Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
183            #[cfg(feature = "array_count")]
184            CountMatches => map_as_slice!(count_matches),
185            Shift => map_as_slice!(shift),
186            Explode { skip_empty } => map_as_slice!(explode, skip_empty),
187        }
188    }
189}
190
191pub(super) fn length(s: &Column) -> PolarsResult<Column> {
192    let array = s.array()?;
193    let width = array.width();
194    let width = IdxSize::try_from(width)
195        .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;
196
197    let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
198    if let Some(validity) = array.rechunk_validity() {
199        let mut series = c.into_materialized_series().clone();
200
201        // SAFETY: We keep datatypes intact and call compute_len afterwards.
202        let chunks = unsafe { series.chunks_mut() };
203        assert_eq!(chunks.len(), 1);
204
205        chunks[0] = chunks[0].with_validity(Some(validity));
206
207        series.compute_len();
208        c = series.into_column();
209    }
210
211    Ok(c)
212}
213
214pub(super) fn max(s: &Column) -> PolarsResult<Column> {
215    Ok(s.array()?.array_max().into())
216}
217
218pub(super) fn min(s: &Column) -> PolarsResult<Column> {
219    Ok(s.array()?.array_min().into())
220}
221
222pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
223    s.array()?.array_sum().map(Column::from)
224}
225
226pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
227    s.array()?.array_std(ddof).map(Column::from)
228}
229
230pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
231    s.array()?.array_var(ddof).map(Column::from)
232}
233pub(super) fn median(s: &Column) -> PolarsResult<Column> {
234    s.array()?.array_median().map(Column::from)
235}
236
237pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
238    let ca = s.array()?;
239    let out = if stable {
240        ca.array_unique_stable()
241    } else {
242        ca.array_unique()
243    };
244    out.map(|ca| ca.into_column())
245}
246
247pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
248    Ok(s.array()?.array_n_unique()?.into_column())
249}
250
251pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
252    let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?;
253    s.cast(&list_dtype)
254}
255
/// Executor for `arr.any`; delegates to `array_any` and propagates its error.
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let out = s.array()?.array_any()?;
    Ok(Column::from(out))
}
260
/// Executor for `arr.all`; delegates to `array_all` and propagates its error.
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let out = s.array()?.array_all()?;
    Ok(Column::from(out))
}
265
266pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
267    Ok(s.array()?.array_sort(options)?.into_column())
268}
269
270pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
271    Ok(s.array()?.array_reverse().into_column())
272}
273
274pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
275    Ok(s.array()?.array_arg_min().into_column())
276}
277
278pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
279    Ok(s.array()?.array_arg_max().into_column())
280}
281
282pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
283    let ca = s[0].array()?;
284    let index = s[1].cast(&DataType::Int64)?;
285    let index = index.i64().unwrap();
286    ca.array_get(index, null_on_oob).map(Column::from)
287}
288
289pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
290    let ca = s[0].array()?;
291    let separator = s[1].str()?;
292    ca.array_join(separator, ignore_nulls).map(Column::from)
293}
294
/// Executor for `arr.contains`.
///
/// `s[0]` is the array column and `s[1]` the item column. The check is delegated to `is_in`
/// with the item as the left-hand side, and the result is renamed to the array column's name.
///
/// # Errors
/// Returns `SchemaMismatch` when `s[0]` is not of `Array` dtype; errors from `is_in` are
/// propagated.
///
/// # Panics
/// Panics if `s` has fewer than two columns.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column], nulls_equal: bool) -> PolarsResult<Column> {
    let haystack = &s[0];
    let needle = &s[1];
    polars_ensure!(matches!(haystack.dtype(), DataType::Array(_, _)),
        SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", haystack.dtype(),
    );
    is_in(
        needle.as_materialized_series(),
        haystack.as_materialized_series(),
        nulls_equal,
    )
    .map(|mut mask| {
        mask.rename(haystack.name().clone());
        mask.into_column()
    })
}
310
/// Executor for `arr.count_matches`.
///
/// `args[0]` is the array column; `args[1]` must hold exactly one element, which is the value
/// passed to `array_count_matches`.
///
/// # Errors
/// Returns a `ComputeError` when `args[1]` does not have length 1.
///
/// # Panics
/// Panics if `args` has fewer than two columns.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let haystack = &args[0];
    let element = &args[1];
    polars_ensure!(
        element.len() == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        element.len()
    );
    let ca = haystack.array()?;
    // Length was checked to be exactly 1 above, so index 0 is in bounds.
    let needle = element.get(0).unwrap();
    let counts = ca.array_count_matches(needle)?;
    Ok(Column::from(counts))
}
324
325pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
326    let ca = s[0].array()?;
327    let n = &s[1];
328
329    ca.array_shift(n.as_materialized_series()).map(Column::from)
330}
331
332fn explode(c: &[Column], skip_empty: bool) -> PolarsResult<Column> {
333    c[0].explode(skip_empty)
334}
335
336fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
337    let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
338
339    polars_ops::series::concat_arr::concat_arr(args, &dtype)
340}
341
342/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input
343/// dtypes are compatible.
344fn concat_arr_output_dtype(
345    inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
346) -> PolarsResult<DataType> {
347    #[allow(clippy::len_zero)]
348    if inputs.len() == 0 {
349        // should not be reachable - we did not set ALLOW_EMPTY_INPUTS
350        panic!();
351    }
352
353    let mut inputs = inputs.map(|(name, dtype)| {
354        let (inner_dtype, width) = match dtype {
355            DataType::Array(inner, width) => (inner.as_ref(), *width),
356            dt => (dt, 1),
357        };
358        (name, dtype, inner_dtype, width)
359    });
360    let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
361
362    for (col_name, dtype, inner_dtype, width) in inputs {
363        out_width += width;
364
365        if inner_dtype != first_inner_dtype {
366            polars_bail!(
367                SchemaMismatch:
368                "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
369                input column (name: {}, dtype: {}), got {} instead for column {}",
370                first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
371            )
372        }
373    }
374
375    Ok(DataType::Array(
376        Box::new(first_inner_dtype.clone()),
377        out_width,
378    ))
379}