1use polars_ops::chunked_array::array::*;
2
3use super::*;
4use crate::{map, map_as_slice};
5
6#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
7#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8pub enum ArrayFunction {
9 Length,
10 Min,
11 Max,
12 Sum,
13 ToList,
14 Unique(bool),
15 NUnique,
16 Std(u8),
17 Var(u8),
18 Median,
19 #[cfg(feature = "array_any_all")]
20 Any,
21 #[cfg(feature = "array_any_all")]
22 All,
23 Sort(SortOptions),
24 Reverse,
25 ArgMin,
26 ArgMax,
27 Get(bool),
28 Join(bool),
29 #[cfg(feature = "is_in")]
30 Contains,
31 #[cfg(feature = "array_count")]
32 CountMatches,
33 Shift,
34 Explode {
35 skip_empty: bool,
36 },
37 Concat,
38}
39
40impl ArrayFunction {
41 pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
42 use ArrayFunction::*;
43 match self {
44 Concat => Ok(Field::new(
45 mapper
46 .args()
47 .first()
48 .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
49 concat_arr_output_dtype(
50 &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
51 )?,
52 )),
53 Length => mapper.with_dtype(IDX_DTYPE),
54 Min | Max => mapper.map_to_list_and_array_inner_dtype(),
55 Sum => mapper.nested_sum_type(),
56 ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
57 Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
58 NUnique => mapper.with_dtype(IDX_DTYPE),
59 Std(_) => mapper.map_to_float_dtype(),
60 Var(_) => mapper.map_to_float_dtype(),
61 Median => mapper.map_to_float_dtype(),
62 #[cfg(feature = "array_any_all")]
63 Any | All => mapper.with_dtype(DataType::Boolean),
64 Sort(_) => mapper.with_same_dtype(),
65 Reverse => mapper.with_same_dtype(),
66 ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
67 Get(_) => mapper.map_to_list_and_array_inner_dtype(),
68 Join(_) => mapper.with_dtype(DataType::String),
69 #[cfg(feature = "is_in")]
70 Contains => mapper.with_dtype(DataType::Boolean),
71 #[cfg(feature = "array_count")]
72 CountMatches => mapper.with_dtype(IDX_DTYPE),
73 Shift => mapper.with_same_dtype(),
74 Explode { .. } => mapper.try_map_to_array_inner_dtype(),
75 }
76 }
77
78 pub fn function_options(&self) -> FunctionOptions {
79 use ArrayFunction as A;
80 match self {
81 #[cfg(feature = "array_any_all")]
82 A::Any | A::All => FunctionOptions::elementwise(),
83 #[cfg(feature = "is_in")]
84 A::Contains => FunctionOptions::elementwise(),
85 #[cfg(feature = "array_count")]
86 A::CountMatches => FunctionOptions::elementwise(),
87 A::Length
88 | A::Min
89 | A::Max
90 | A::Sum
91 | A::ToList
92 | A::Unique(_)
93 | A::NUnique
94 | A::Std(_)
95 | A::Var(_)
96 | A::Median
97 | A::Sort(_)
98 | A::Reverse
99 | A::ArgMin
100 | A::ArgMax
101 | A::Concat
102 | A::Get(_)
103 | A::Join(_)
104 | A::Shift => FunctionOptions::elementwise(),
105 A::Explode { .. } => FunctionOptions::row_separable(),
106 }
107 }
108}
109
110fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
111 if let DataType::Array(inner, _) = datatype {
112 Ok(DataType::List(inner.clone()))
113 } else {
114 polars_bail!(ComputeError: "expected array dtype")
115 }
116}
117
118impl Display for ArrayFunction {
119 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120 use ArrayFunction::*;
121 let name = match self {
122 Concat => "concat",
123 Length => "length",
124 Min => "min",
125 Max => "max",
126 Sum => "sum",
127 ToList => "to_list",
128 Unique(_) => "unique",
129 NUnique => "n_unique",
130 Std(_) => "std",
131 Var(_) => "var",
132 Median => "median",
133 #[cfg(feature = "array_any_all")]
134 Any => "any",
135 #[cfg(feature = "array_any_all")]
136 All => "all",
137 Sort(_) => "sort",
138 Reverse => "reverse",
139 ArgMin => "arg_min",
140 ArgMax => "arg_max",
141 Get(_) => "get",
142 Join(_) => "join",
143 #[cfg(feature = "is_in")]
144 Contains => "contains",
145 #[cfg(feature = "array_count")]
146 CountMatches => "count_matches",
147 Shift => "shift",
148 Explode { .. } => "explode",
149 };
150 write!(f, "arr.{name}")
151 }
152}
153
154impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
155 fn from(func: ArrayFunction) -> Self {
156 use ArrayFunction::*;
157 match func {
158 Concat => map_as_slice!(concat_arr),
159 Length => map!(length),
160 Min => map!(min),
161 Max => map!(max),
162 Sum => map!(sum),
163 ToList => map!(to_list),
164 Unique(stable) => map!(unique, stable),
165 NUnique => map!(n_unique),
166 Std(ddof) => map!(std, ddof),
167 Var(ddof) => map!(var, ddof),
168 Median => map!(median),
169 #[cfg(feature = "array_any_all")]
170 Any => map!(any),
171 #[cfg(feature = "array_any_all")]
172 All => map!(all),
173 Sort(options) => map!(sort, options),
174 Reverse => map!(reverse),
175 ArgMin => map!(arg_min),
176 ArgMax => map!(arg_max),
177 Get(null_on_oob) => map_as_slice!(get, null_on_oob),
178 Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
179 #[cfg(feature = "is_in")]
180 Contains => map_as_slice!(contains),
181 #[cfg(feature = "array_count")]
182 CountMatches => map_as_slice!(count_matches),
183 Shift => map_as_slice!(shift),
184 Explode { skip_empty } => map_as_slice!(explode, skip_empty),
185 }
186 }
187}
188
189pub(super) fn length(s: &Column) -> PolarsResult<Column> {
190 let array = s.array()?;
191 let width = array.width();
192 let width = IdxSize::try_from(width)
193 .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;
194
195 let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());
196 if let Some(validity) = array.rechunk_validity() {
197 let mut series = c.into_materialized_series().clone();
198
199 let chunks = unsafe { series.chunks_mut() };
201 assert_eq!(chunks.len(), 1);
202
203 chunks[0] = chunks[0].with_validity(Some(validity));
204
205 series.compute_len();
206 c = series.into_column();
207 }
208
209 Ok(c)
210}
211
212pub(super) fn max(s: &Column) -> PolarsResult<Column> {
213 Ok(s.array()?.array_max().into())
214}
215
216pub(super) fn min(s: &Column) -> PolarsResult<Column> {
217 Ok(s.array()?.array_min().into())
218}
219
220pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
221 s.array()?.array_sum().map(Column::from)
222}
223
224pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
225 s.array()?.array_std(ddof).map(Column::from)
226}
227
228pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
229 s.array()?.array_var(ddof).map(Column::from)
230}
231pub(super) fn median(s: &Column) -> PolarsResult<Column> {
232 s.array()?.array_median().map(Column::from)
233}
234
235pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {
236 let ca = s.array()?;
237 let out = if stable {
238 ca.array_unique_stable()
239 } else {
240 ca.array_unique()
241 };
242 out.map(|ca| ca.into_column())
243}
244
245pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
246 Ok(s.array()?.array_n_unique()?.into_column())
247}
248
249pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {
250 let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?;
251 s.cast(&list_dtype)
252}
253
/// Applies `array_any` to the array column `s`.
#[cfg(feature = "array_any_all")]
pub(super) fn any(s: &Column) -> PolarsResult<Column> {
    let flags = s.array()?.array_any()?;
    Ok(Column::from(flags))
}
258
/// Applies `array_all` to the array column `s`.
#[cfg(feature = "array_any_all")]
pub(super) fn all(s: &Column) -> PolarsResult<Column> {
    let flags = s.array()?.array_all()?;
    Ok(Column::from(flags))
}
263
264pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
265 Ok(s.array()?.array_sort(options)?.into_column())
266}
267
268pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
269 Ok(s.array()?.array_reverse().into_column())
270}
271
272pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
273 Ok(s.array()?.array_arg_min().into_column())
274}
275
276pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
277 Ok(s.array()?.array_arg_max().into_column())
278}
279
280pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
281 let ca = s[0].array()?;
282 let index = s[1].cast(&DataType::Int64)?;
283 let index = index.i64().unwrap();
284 ca.array_get(index, null_on_oob).map(Column::from)
285}
286
287pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
288 let ca = s[0].array()?;
289 let separator = s[1].str()?;
290 ca.array_join(separator, ignore_nulls).map(Column::from)
291}
292
/// Tests the items in `s[1]` for membership in the array column `s[0]` via `is_in`.
///
/// The result is named after the array column.
///
/// # Errors
/// Returns a `SchemaMismatch` when `s[0]` does not have an array dtype.
#[cfg(feature = "is_in")]
pub(super) fn contains(s: &[Column]) -> PolarsResult<Column> {
    let (array, item) = (&s[0], &s[1]);
    polars_ensure!(
        matches!(array.dtype(), DataType::Array(_, _)),
        SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", array.dtype(),
    );

    let needles = item.as_materialized_series();
    let haystack = array.as_materialized_series();
    let mask = is_in(needles, haystack, true)?;
    Ok(mask.with_name(array.name().clone()).into_column())
}
308
/// Applies `array_count_matches` to the array column `args[0]` with the single value held by
/// `args[1]`.
///
/// # Errors
/// Returns a `ComputeError` when `args[1]` does not contain exactly one element.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let (haystack, needle) = (&args[0], &args[1]);
    polars_ensure!(
        needle.len() == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        needle.len()
    );

    let ca = haystack.array()?;
    // The length check above ensures index 0 is in bounds.
    let value = needle.get(0).unwrap();
    let counts = ca.array_count_matches(value)?;
    Ok(Column::from(counts))
}
322
323pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
324 let ca = s[0].array()?;
325 let n = &s[1];
326
327 ca.array_shift(n.as_materialized_series()).map(Column::from)
328}
329
330fn explode(c: &[Column], skip_empty: bool) -> PolarsResult<Column> {
331 c[0].explode(skip_empty)
332}
333
334fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
335 let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
336
337 polars_ops::series::concat_arr::concat_arr(args, &dtype)
338}
339
340fn concat_arr_output_dtype(
343 inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
344) -> PolarsResult<DataType> {
345 #[allow(clippy::len_zero)]
346 if inputs.len() == 0 {
347 panic!();
349 }
350
351 let mut inputs = inputs.map(|(name, dtype)| {
352 let (inner_dtype, width) = match dtype {
353 DataType::Array(inner, width) => (inner.as_ref(), *width),
354 dt => (dt, 1),
355 };
356 (name, dtype, inner_dtype, width)
357 });
358 let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
359
360 for (col_name, dtype, inner_dtype, width) in inputs {
361 out_width += width;
362
363 if inner_dtype != first_inner_dtype {
364 polars_bail!(
365 SchemaMismatch:
366 "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
367 input column (name: {}, dtype: {}), got {} instead for column {}",
368 first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
369 )
370 }
371 }
372
373 Ok(DataType::Array(
374 Box::new(first_inner_dtype.clone()),
375 out_width,
376 ))
377}