polars_plan/dsl/function_expr/
array.rs

1use polars_ops::chunked_array::array::*;
2
3use super::*;
4use crate::{map, map_as_slice};
5
/// Functions operating on fixed-width `Array` columns. They are displayed with an `arr.` prefix
/// (see the `Display` impl for this type).
///
/// NOTE: keep the variant order stable. The derived `Hash` and, with the `serde` feature,
/// index-based serialization formats may depend on it.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArrayFunction {
    /// Width of each array, as an `IDX_DTYPE` column (null where the input row is null).
    Length,
    /// Per-row minimum of the array elements.
    Min,
    /// Per-row maximum of the array elements.
    Max,
    /// Per-row sum of the array elements.
    Sum,
    /// Cast the `Array` column to a `List` column with the same inner dtype.
    ToList,
    /// Per-row unique elements, returned as a `List`. The flag is `stable`; it selects
    /// `array_unique_stable` over `array_unique`.
    Unique(bool),
    /// Per-row number of unique elements (`IDX_DTYPE`).
    NUnique,
    /// Per-row standard deviation; the `u8` is forwarded as `ddof`.
    Std(u8),
    /// Per-row variance; the `u8` is forwarded as `ddof`.
    Var(u8),
    /// Per-row median (float output).
    Median,
    /// Per-row boolean "any"; requires the `array_any_all` feature.
    #[cfg(feature = "array_any_all")]
    Any,
    /// Per-row boolean "all"; requires the `array_any_all` feature.
    #[cfg(feature = "array_any_all")]
    All,
    /// Sort the elements inside each array using the given options.
    Sort(SortOptions),
    /// Reverse the elements inside each array.
    Reverse,
    /// Index of the per-row minimum (`IDX_DTYPE`).
    ArgMin,
    /// Index of the per-row maximum (`IDX_DTYPE`).
    ArgMax,
    /// Element at the index given by the second input (cast to `Int64`).
    /// The flag is `null_on_oob`.
    Get(bool),
    /// Join each array's string elements using the separator given by the second input.
    /// The flag is `ignore_nulls`.
    Join(bool),
    /// Whether each array contains the item given by the second input (boolean output).
    #[cfg(feature = "is_in")]
    Contains,
    /// Count occurrences of the single value produced by the second input (`IDX_DTYPE`).
    #[cfg(feature = "array_count")]
    CountMatches,
    /// Shift the elements inside each array by the amount given by the second input.
    Shift,
    /// Explode each array into separate rows; `skip_empty` is forwarded to `Column::explode`.
    Explode {
        skip_empty: bool,
    },
    /// Concatenate all inputs into a single array column. Non-array inputs count as width 1.
    Concat,
}
39
40impl ArrayFunction {
41    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
42        use ArrayFunction::*;
43        match self {
44            Concat => Ok(Field::new(
45                mapper
46                    .args()
47                    .first()
48                    .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
49                concat_arr_output_dtype(
50                    &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
51                )?,
52            )),
53            Length => mapper.with_dtype(IDX_DTYPE),
54            Min | Max => mapper.map_to_list_and_array_inner_dtype(),
55            Sum => mapper.nested_sum_type(),
56            ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
57            Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
58            NUnique => mapper.with_dtype(IDX_DTYPE),
59            Std(_) => mapper.map_to_float_dtype(),
60            Var(_) => mapper.map_to_float_dtype(),
61            Median => mapper.map_to_float_dtype(),
62            #[cfg(feature = "array_any_all")]
63            Any | All => mapper.with_dtype(DataType::Boolean),
64            Sort(_) => mapper.with_same_dtype(),
65            Reverse => mapper.with_same_dtype(),
66            ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
67            Get(_) => mapper.map_to_list_and_array_inner_dtype(),
68            Join(_) => mapper.with_dtype(DataType::String),
69            #[cfg(feature = "is_in")]
70            Contains => mapper.with_dtype(DataType::Boolean),
71            #[cfg(feature = "array_count")]
72            CountMatches => mapper.with_dtype(IDX_DTYPE),
73            Shift => mapper.with_same_dtype(),
74            Explode { .. } => mapper.try_map_to_array_inner_dtype(),
75        }
76    }
77
78    pub fn function_options(&self) -> FunctionOptions {
79        use ArrayFunction as A;
80        match self {
81            #[cfg(feature = "array_any_all")]
82            A::Any | A::All => FunctionOptions::elementwise(),
83            #[cfg(feature = "is_in")]
84            A::Contains => FunctionOptions::elementwise(),
85            #[cfg(feature = "array_count")]
86            A::CountMatches => FunctionOptions::elementwise(),
87            A::Length
88            | A::Min
89            | A::Max
90            | A::Sum
91            | A::ToList
92            | A::Unique(_)
93            | A::NUnique
94            | A::Std(_)
95            | A::Var(_)
96            | A::Median
97            | A::Sort(_)
98            | A::Reverse
99            | A::ArgMin
100            | A::ArgMax
101            | A::Concat
102            | A::Get(_)
103            | A::Join(_)
104            | A::Shift => FunctionOptions::elementwise(),
105            A::Explode { .. } => FunctionOptions::row_separable(),
106        }
107    }
108}
109
110fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
111    if let DataType::Array(inner, _) = datatype {
112        Ok(DataType::List(inner.clone()))
113    } else {
114        polars_bail!(ComputeError: "expected array dtype")
115    }
116}
117
118impl Display for ArrayFunction {
119    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120        use ArrayFunction::*;
121        let name = match self {
122            Concat => "concat",
123            Length => "length",
124            Min => "min",
125            Max => "max",
126            Sum => "sum",
127            ToList => "to_list",
128            Unique(_) => "unique",
129            NUnique => "n_unique",
130            Std(_) => "std",
131            Var(_) => "var",
132            Median => "median",
133            #[cfg(feature = "array_any_all")]
134            Any => "any",
135            #[cfg(feature = "array_any_all")]
136            All => "all",
137            Sort(_) => "sort",
138            Reverse => "reverse",
139            ArgMin => "arg_min",
140            ArgMax => "arg_max",
141            Get(_) => "get",
142            Join(_) => "join",
143            #[cfg(feature = "is_in")]
144            Contains => "contains",
145            #[cfg(feature = "array_count")]
146            CountMatches => "count_matches",
147            Shift => "shift",
148            Explode { .. } => "explode",
149        };
150        write!(f, "arr.{name}")
151    }
152}
153
154impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
155    fn from(func: ArrayFunction) -> Self {
156        use ArrayFunction::*;
157        match func {
158            Concat => map_as_slice!(concat_arr),
159            Length => map!(length),
160            Min => map!(min),
161            Max => map!(max),
162            Sum => map!(sum),
163            ToList => map!(to_list),
164            Unique(stable) => map!(unique, stable),
165            NUnique => map!(n_unique),
166            Std(ddof) => map!(std, ddof),
167            Var(ddof) => map!(var, ddof),
168            Median => map!(median),
169            #[cfg(feature = "array_any_all")]
170            Any => map!(any),
171            #[cfg(feature = "array_any_all")]
172            All => map!(all),
173            Sort(options) => map!(sort, options),
174            Reverse => map!(reverse),
175            ArgMin => map!(arg_min),
176            ArgMax => map!(arg_max),
177            Get(null_on_oob) => map_as_slice!(get, null_on_oob),
178            Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
179            #[cfg(feature = "is_in")]
180            Contains => map_as_slice!(contains),
181            #[cfg(feature = "array_count")]
182            CountMatches => map_as_slice!(count_matches),
183            Shift => map_as_slice!(shift),
184            Explode { skip_empty } => map_as_slice!(explode, skip_empty),
185        }
186    }
187}
188
189pub(super) fn length(s: &Column) -> PolarsResult<Column> {
190    let array = s.array()?;
191    let width = array.width();
192    let width = IdxSize::try_from(width)
193        .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;
194
195    let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
196    if let Some(validity) = array.rechunk_validity() {
197        let mut series = c.into_materialized_series().clone();
198
199        // SAFETY: We keep datatypes intact and call compute_len afterwards.
200        let chunks = unsafe { series.chunks_mut() };
201        assert_eq!(chunks.len(), 1);
202
203        chunks[0] = chunks[0].with_validity(Some(validity));
204
205        series.compute_len();
206        c = series.into_column();
207    }
208
209    Ok(c)
210}
211
212pub(super) fn max(s: &Column) -> PolarsResult<Column> {
213    Ok(s.array()?.array_max().into())
214}
215
216pub(super) fn min(s: &Column) -> PolarsResult<Column> {
217    Ok(s.array()?.array_min().into())
218}
219
220pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
221    s.array()?.array_sum().map(Column::from)
222}
223
224pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
225    s.array()?.array_std(ddof).map(Column::from)
226}
227
228pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
229    s.array()?.array_var(ddof).map(Column::from)
230}
231pub(super) fn median(s: &Column) -> PolarsResult<Column> {
232    s.array()?.array_median().map(Column::from)
233}
234
235pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
236    let ca = s.array()?;
237    let out = if stable {
238        ca.array_unique_stable()
239    } else {
240        ca.array_unique()
241    };
242    out.map(|ca| ca.into_column())
243}
244
245pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
246    Ok(s.array()?.array_n_unique()?.into_column())
247}
248
249pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
250    let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?;
251    s.cast(&list_dtype)
252}
253
/// Per-row boolean "any" over an `Array` column.
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    Ok(Column::from(ca.array_any()?))
}
258
/// Per-row boolean "all" over an `Array` column.
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let ca = s.array()?;
    Ok(Column::from(ca.array_all()?))
}
263
264pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
265    Ok(s.array()?.array_sort(options)?.into_column())
266}
267
268pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
269    Ok(s.array()?.array_reverse().into_column())
270}
271
272pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
273    Ok(s.array()?.array_arg_min().into_column())
274}
275
276pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
277    Ok(s.array()?.array_arg_max().into_column())
278}
279
280pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
281    let ca = s[0].array()?;
282    let index = s[1].cast(&DataType::Int64)?;
283    let index = index.i64().unwrap();
284    ca.array_get(index, null_on_oob).map(Column::from)
285}
286
287pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
288    let ca = s[0].array()?;
289    let separator = s[1].str()?;
290    ca.array_join(separator, ignore_nulls).map(Column::from)
291}
292
/// Checks, per row, whether the array in `s[0]` contains the item in `s[1]`.
///
/// The output is a boolean column named after the array column.
///
/// # Errors
/// `SchemaMismatch` if `s[0]` is not an `Array` column.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column]) -> PolarsResult<Column> {
    let haystack = &s[0];
    let needle = &s[1];
    polars_ensure!(matches!(haystack.dtype(), DataType::Array(..)),
        SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", haystack.dtype(),
    );
    let found = is_in(
        needle.as_materialized_series(),
        haystack.as_materialized_series(),
        true,
    )?;
    Ok(found.with_name(haystack.name().clone()).into_column())
}
308
/// Counts, per row, how many elements of the array in `args[0]` equal the value in `args[1]`.
///
/// # Errors
/// `ComputeError` if `args[1]` does not contain exactly one element.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let (s, element) = (&args[0], &args[1]);
    let n_elements = element.len();
    polars_ensure!(
        n_elements == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        n_elements
    );
    let arrays = s.array()?;
    // Exactly one element is present (checked above).
    let target = element.get(0).unwrap();
    let counts = arrays.array_count_matches(target)?;
    Ok(Column::from(counts))
}
322
323pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
324    let ca = s[0].array()?;
325    let n = &s[1];
326
327    ca.array_shift(n.as_materialized_series()).map(Column::from)
328}
329
330fn explode(c: &[Column], skip_empty: bool) -> PolarsResult<Column> {
331    c[0].explode(skip_empty)
332}
333
334fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
335    let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
336
337    polars_ops::series::concat_arr::concat_arr(args, &dtype)
338}
339
340/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input
341/// dtypes are compatible.
342fn concat_arr_output_dtype(
343    inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
344) -> PolarsResult<DataType> {
345    #[allow(clippy::len_zero)]
346    if inputs.len() == 0 {
347        // should not be reachable - we did not set ALLOW_EMPTY_INPUTS
348        panic!();
349    }
350
351    let mut inputs = inputs.map(|(name, dtype)| {
352        let (inner_dtype, width) = match dtype {
353            DataType::Array(inner, width) => (inner.as_ref(), *width),
354            dt => (dt, 1),
355        };
356        (name, dtype, inner_dtype, width)
357    });
358    let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
359
360    for (col_name, dtype, inner_dtype, width) in inputs {
361        out_width += width;
362
363        if inner_dtype != first_inner_dtype {
364            polars_bail!(
365                SchemaMismatch:
366                "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
367                input column (name: {}, dtype: {}), got {} instead for column {}",
368                first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
369            )
370        }
371    }
372
373    Ok(DataType::Array(
374        Box::new(first_inner_dtype.clone()),
375        out_width,
376    ))
377}