polars_plan/dsl/function_expr/
list.rs

1use arrow::legacy::utils::CustomIterTools;
2use polars_ops::chunked_array::list::*;
3
4use super::*;
5use crate::{map, map_as_slice, wrap};
6
7#[derive(Clone, Eq, PartialEq, Hash, Debug)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub enum ListFunction {
10    Concat,
11    #[cfg(feature = "is_in")]
12    Contains {
13        nulls_equal: bool,
14    },
15    #[cfg(feature = "list_drop_nulls")]
16    DropNulls,
17    #[cfg(feature = "list_sample")]
18    Sample {
19        is_fraction: bool,
20        with_replacement: bool,
21        shuffle: bool,
22        seed: Option<u64>,
23    },
24    Slice,
25    Shift,
26    Get(bool),
27    #[cfg(feature = "list_gather")]
28    Gather(bool),
29    #[cfg(feature = "list_gather")]
30    GatherEvery,
31    #[cfg(feature = "list_count")]
32    CountMatches,
33    Sum,
34    Length,
35    Max,
36    Min,
37    Mean,
38    Median,
39    Std(u8),
40    Var(u8),
41    ArgMin,
42    ArgMax,
43    #[cfg(feature = "diff")]
44    Diff {
45        n: i64,
46        null_behavior: NullBehavior,
47    },
48    Sort(SortOptions),
49    Reverse,
50    Unique(bool),
51    NUnique,
52    #[cfg(feature = "list_sets")]
53    SetOperation(SetOperation),
54    #[cfg(feature = "list_any_all")]
55    Any,
56    #[cfg(feature = "list_any_all")]
57    All,
58    Join(bool),
59    #[cfg(feature = "dtype-array")]
60    ToArray(usize),
61    #[cfg(feature = "list_to_struct")]
62    ToStruct(ListToStructArgs),
63}
64
65impl ListFunction {
66    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
67        use ListFunction::*;
68        match self {
69            Concat => mapper.map_to_list_supertype(),
70            #[cfg(feature = "is_in")]
71            Contains { nulls_equal: _ } => mapper.with_dtype(DataType::Boolean),
72            #[cfg(feature = "list_drop_nulls")]
73            DropNulls => mapper.with_same_dtype(),
74            #[cfg(feature = "list_sample")]
75            Sample { .. } => mapper.with_same_dtype(),
76            Slice => mapper.with_same_dtype(),
77            Shift => mapper.with_same_dtype(),
78            Get(_) => mapper.map_to_list_and_array_inner_dtype(),
79            #[cfg(feature = "list_gather")]
80            Gather(_) => mapper.with_same_dtype(),
81            #[cfg(feature = "list_gather")]
82            GatherEvery => mapper.with_same_dtype(),
83            #[cfg(feature = "list_count")]
84            CountMatches => mapper.with_dtype(IDX_DTYPE),
85            Sum => mapper.nested_sum_type(),
86            Min => mapper.map_to_list_and_array_inner_dtype(),
87            Max => mapper.map_to_list_and_array_inner_dtype(),
88            Mean => mapper.nested_mean_median_type(),
89            Median => mapper.nested_mean_median_type(),
90            Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration..
91            Var(_) => mapper.map_to_float_dtype(),
92            ArgMin => mapper.with_dtype(IDX_DTYPE),
93            ArgMax => mapper.with_dtype(IDX_DTYPE),
94            #[cfg(feature = "diff")]
95            Diff { .. } => mapper.map_dtype(|dt| {
96                let inner_dt = match dt.inner_dtype().unwrap() {
97                    #[cfg(feature = "dtype-datetime")]
98                    DataType::Datetime(tu, _) => DataType::Duration(*tu),
99                    #[cfg(feature = "dtype-date")]
100                    DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
101                    #[cfg(feature = "dtype-time")]
102                    DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
103                    DataType::UInt64 | DataType::UInt32 => DataType::Int64,
104                    DataType::UInt16 => DataType::Int32,
105                    DataType::UInt8 => DataType::Int16,
106                    inner_dt => inner_dt.clone(),
107                };
108                DataType::List(Box::new(inner_dt))
109            }),
110            Sort(_) => mapper.with_same_dtype(),
111            Reverse => mapper.with_same_dtype(),
112            Unique(_) => mapper.with_same_dtype(),
113            Length => mapper.with_dtype(IDX_DTYPE),
114            #[cfg(feature = "list_sets")]
115            SetOperation(_) => mapper.with_same_dtype(),
116            #[cfg(feature = "list_any_all")]
117            Any => mapper.with_dtype(DataType::Boolean),
118            #[cfg(feature = "list_any_all")]
119            All => mapper.with_dtype(DataType::Boolean),
120            Join(_) => mapper.with_dtype(DataType::String),
121            #[cfg(feature = "dtype-array")]
122            ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
123            NUnique => mapper.with_dtype(IDX_DTYPE),
124            #[cfg(feature = "list_to_struct")]
125            ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)),
126        }
127    }
128
129    pub fn function_options(&self) -> FunctionOptions {
130        use ListFunction as L;
131        match self {
132            L::Concat => FunctionOptions::elementwise(),
133            #[cfg(feature = "is_in")]
134            L::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),
135            #[cfg(feature = "list_sample")]
136            L::Sample { .. } => FunctionOptions::elementwise(),
137            #[cfg(feature = "list_gather")]
138            L::Gather(_) => FunctionOptions::elementwise(),
139            #[cfg(feature = "list_gather")]
140            L::GatherEvery => FunctionOptions::elementwise(),
141            #[cfg(feature = "list_sets")]
142            L::SetOperation(_) => FunctionOptions::elementwise()
143                .with_casting_rules(CastingRules::Supertype(SuperTypeOptions {
144                    flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
145                }))
146                .with_flags(|f| f & !FunctionFlags::RETURNS_SCALAR),
147            #[cfg(feature = "diff")]
148            L::Diff { .. } => FunctionOptions::elementwise(),
149            #[cfg(feature = "list_drop_nulls")]
150            L::DropNulls => FunctionOptions::elementwise(),
151            #[cfg(feature = "list_count")]
152            L::CountMatches => FunctionOptions::elementwise(),
153            L::Sum
154            | L::Slice
155            | L::Shift
156            | L::Get(_)
157            | L::Length
158            | L::Max
159            | L::Min
160            | L::Mean
161            | L::Median
162            | L::Std(_)
163            | L::Var(_)
164            | L::ArgMin
165            | L::ArgMax
166            | L::Sort(_)
167            | L::Reverse
168            | L::Unique(_)
169            | L::Join(_)
170            | L::NUnique => FunctionOptions::elementwise(),
171            #[cfg(feature = "list_any_all")]
172            L::Any | L::All => FunctionOptions::elementwise(),
173            #[cfg(feature = "dtype-array")]
174            L::ToArray(_) => FunctionOptions::elementwise(),
175            #[cfg(feature = "list_to_struct")]
176            L::ToStruct(ListToStructArgs::FixedWidth(_)) => FunctionOptions::elementwise(),
177            #[cfg(feature = "list_to_struct")]
178            L::ToStruct(ListToStructArgs::InferWidth { .. }) => FunctionOptions::groupwise(),
179        }
180    }
181}
182
/// Maps a `List(inner)` dtype to `Array(inner, width)`.
///
/// # Errors
///
/// Returns a `ComputeError` when `datatype` is not a `List`.
#[cfg(feature = "dtype-array")]
fn map_list_dtype_to_array_dtype(datatype: &DataType, width: usize) -> PolarsResult<DataType> {
    match datatype {
        DataType::List(inner) => Ok(DataType::Array(inner.clone(), width)),
        _ => polars_bail!(ComputeError: "expected List dtype"),
    }
}
191
192impl Display for ListFunction {
193    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
194        use ListFunction::*;
195
196        let name = match self {
197            Concat => "concat",
198            #[cfg(feature = "is_in")]
199            Contains { nulls_equal: _ } => "contains",
200            #[cfg(feature = "list_drop_nulls")]
201            DropNulls => "drop_nulls",
202            #[cfg(feature = "list_sample")]
203            Sample { is_fraction, .. } => {
204                if *is_fraction {
205                    "sample_fraction"
206                } else {
207                    "sample_n"
208                }
209            },
210            Slice => "slice",
211            Shift => "shift",
212            Get(_) => "get",
213            #[cfg(feature = "list_gather")]
214            Gather(_) => "gather",
215            #[cfg(feature = "list_gather")]
216            GatherEvery => "gather_every",
217            #[cfg(feature = "list_count")]
218            CountMatches => "count_matches",
219            Sum => "sum",
220            Min => "min",
221            Max => "max",
222            Mean => "mean",
223            Median => "median",
224            Std(_) => "std",
225            Var(_) => "var",
226            ArgMin => "arg_min",
227            ArgMax => "arg_max",
228            #[cfg(feature = "diff")]
229            Diff { .. } => "diff",
230            Length => "length",
231            Sort(_) => "sort",
232            Reverse => "reverse",
233            Unique(is_stable) => {
234                if *is_stable {
235                    "unique_stable"
236                } else {
237                    "unique"
238                }
239            },
240            NUnique => "n_unique",
241            #[cfg(feature = "list_sets")]
242            SetOperation(s) => return write!(f, "list.{s}"),
243            #[cfg(feature = "list_any_all")]
244            Any => "any",
245            #[cfg(feature = "list_any_all")]
246            All => "all",
247            Join(_) => "join",
248            #[cfg(feature = "dtype-array")]
249            ToArray(_) => "to_array",
250            #[cfg(feature = "list_to_struct")]
251            ToStruct(_) => "to_struct",
252        };
253        write!(f, "list.{name}")
254    }
255}
256
257impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
258    fn from(func: ListFunction) -> Self {
259        use ListFunction::*;
260        match func {
261            Concat => wrap!(concat),
262            #[cfg(feature = "is_in")]
263            Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),
264            #[cfg(feature = "list_drop_nulls")]
265            DropNulls => map!(drop_nulls),
266            #[cfg(feature = "list_sample")]
267            Sample {
268                is_fraction,
269                with_replacement,
270                shuffle,
271                seed,
272            } => {
273                if is_fraction {
274                    map_as_slice!(sample_fraction, with_replacement, shuffle, seed)
275                } else {
276                    map_as_slice!(sample_n, with_replacement, shuffle, seed)
277                }
278            },
279            Slice => wrap!(slice),
280            Shift => map_as_slice!(shift),
281            Get(null_on_oob) => wrap!(get, null_on_oob),
282            #[cfg(feature = "list_gather")]
283            Gather(null_on_oob) => map_as_slice!(gather, null_on_oob),
284            #[cfg(feature = "list_gather")]
285            GatherEvery => map_as_slice!(gather_every),
286            #[cfg(feature = "list_count")]
287            CountMatches => map_as_slice!(count_matches),
288            Sum => map!(sum),
289            Length => map!(length),
290            Max => map!(max),
291            Min => map!(min),
292            Mean => map!(mean),
293            Median => map!(median),
294            Std(ddof) => map!(std, ddof),
295            Var(ddof) => map!(var, ddof),
296            ArgMin => map!(arg_min),
297            ArgMax => map!(arg_max),
298            #[cfg(feature = "diff")]
299            Diff { n, null_behavior } => map!(diff, n, null_behavior),
300            Sort(options) => map!(sort, options),
301            Reverse => map!(reverse),
302            Unique(is_stable) => map!(unique, is_stable),
303            #[cfg(feature = "list_sets")]
304            SetOperation(s) => map_as_slice!(set_operation, s),
305            #[cfg(feature = "list_any_all")]
306            Any => map!(lst_any),
307            #[cfg(feature = "list_any_all")]
308            All => map!(lst_all),
309            Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
310            #[cfg(feature = "dtype-array")]
311            ToArray(width) => map!(to_array, width),
312            NUnique => map!(n_unique),
313            #[cfg(feature = "list_to_struct")]
314            ToStruct(args) => map!(to_struct, &args),
315        }
316    }
317}
318
/// Implements `list.contains`: checks `args[1]` against the list column `args[0]`
/// through `is_in`, and names the result after the list column.
///
/// # Errors
///
/// Returns a `SchemaMismatch` error when `args[0]` is not a `List`, and
/// propagates any error from `is_in`.
#[cfg(feature = "is_in")]
pub(super) fn contains(args: &mut [Column], nulls_equal: bool) -> PolarsResult<Column> {
    let haystack = &args[0];
    let needle = &args[1];
    polars_ensure!(matches!(haystack.dtype(), DataType::List(_)),
        SchemaMismatch: "invalid series dtype: expected `List`, got `{}`", haystack.dtype(),
    );

    let needle_series = needle.as_materialized_series();
    let haystack_series = haystack.as_materialized_series();
    let mut mask = polars_ops::prelude::is_in(needle_series, haystack_series, nulls_equal)?;
    mask.rename(haystack.name().clone());
    Ok(mask.into_column())
}
334
/// Implements `list.drop_nulls` by applying `lst_drop_nulls` to the list column.
///
/// # Errors
///
/// Fails when `s` cannot be viewed as a list column.
#[cfg(feature = "list_drop_nulls")]
pub(super) fn drop_nulls(s: &Column) -> PolarsResult<Column> {
    let ca = s.list()?;
    let without_nulls = ca.lst_drop_nulls();
    Ok(without_nulls.into_column())
}
340
/// Implements `list.sample_n`: samples from the list column `s[0]`, taking the
/// sample sizes from `s[1]`.
///
/// # Errors
///
/// Fails when `s[0]` is not a list column or when `lst_sample_n` fails.
#[cfg(feature = "list_sample")]
pub(super) fn sample_n(
    s: &[Column],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Column> {
    let lists = s[0].list()?;
    let counts = s[1].as_materialized_series();
    let sampled = lists.lst_sample_n(counts, with_replacement, shuffle, seed)?;
    Ok(sampled.into_column())
}
353
/// Implements `list.sample_fraction`: samples from the list column `s[0]`,
/// taking the fractions from `s[1]`.
///
/// # Errors
///
/// Fails when `s[0]` is not a list column or when `lst_sample_fraction` fails.
#[cfg(feature = "list_sample")]
pub(super) fn sample_fraction(
    s: &[Column],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Column> {
    let lists = s[0].list()?;
    let fractions = s[1].as_materialized_series();
    let sampled = lists.lst_sample_fraction(fractions, with_replacement, shuffle, seed)?;
    Ok(sampled.into_column())
}
371
372fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> {
373    polars_ensure!(
374        slice_len == ca_len,
375        ComputeError:
376        "shape of the slice '{}' argument: {} does not match that of the list column: {}",
377        name, slice_len, ca_len
378    );
379    Ok(())
380}
381
382pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
383    let list = s[0].list()?;
384    let periods = &s[1];
385
386    list.lst_shift(periods).map(|ok| ok.into_column())
387}
388
389pub(super) fn slice(args: &mut [Column]) -> PolarsResult<Option<Column>> {
390    let s = &args[0];
391    let list_ca = s.list()?;
392    let offset_s = &args[1];
393    let length_s = &args[2];
394
395    let mut out: ListChunked = match (offset_s.len(), length_s.len()) {
396        (1, 1) => {
397            let offset = offset_s.get(0).unwrap().try_extract::<i64>()?;
398            let slice_len = length_s
399                .get(0)
400                .unwrap()
401                .extract::<usize>()
402                .unwrap_or(usize::MAX);
403            return Ok(Some(list_ca.lst_slice(offset, slice_len).into_column()));
404        },
405        (1, length_slice_len) => {
406            check_slice_arg_shape(length_slice_len, list_ca.len(), "length")?;
407            let offset = offset_s.get(0).unwrap().try_extract::<i64>()?;
408            // cast to i64 as it is more likely that it is that dtype
409            // instead of usize/u64 (we never need that max length)
410            let length_ca = length_s.cast(&DataType::Int64)?;
411            let length_ca = length_ca.i64().unwrap();
412
413            list_ca
414                .amortized_iter()
415                .zip(length_ca)
416                .map(|(opt_s, opt_length)| match (opt_s, opt_length) {
417                    (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)),
418                    _ => None,
419                })
420                .collect_trusted()
421        },
422        (offset_len, 1) => {
423            check_slice_arg_shape(offset_len, list_ca.len(), "offset")?;
424            let length_slice = length_s
425                .get(0)
426                .unwrap()
427                .extract::<usize>()
428                .unwrap_or(usize::MAX);
429            let offset_ca = offset_s.cast(&DataType::Int64)?;
430            let offset_ca = offset_ca.i64().unwrap();
431            list_ca
432                .amortized_iter()
433                .zip(offset_ca)
434                .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) {
435                    (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)),
436                    _ => None,
437                })
438                .collect_trusted()
439        },
440        _ => {
441            check_slice_arg_shape(offset_s.len(), list_ca.len(), "offset")?;
442            check_slice_arg_shape(length_s.len(), list_ca.len(), "length")?;
443            let offset_ca = offset_s.cast(&DataType::Int64)?;
444            let offset_ca = offset_ca.i64()?;
445            // cast to i64 as it is more likely that it is that dtype
446            // instead of usize/u64 (we never need that max length)
447            let length_ca = length_s.cast(&DataType::Int64)?;
448            let length_ca = length_ca.i64().unwrap();
449
450            list_ca
451                .amortized_iter()
452                .zip(offset_ca)
453                .zip(length_ca)
454                .map(
455                    |((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) {
456                        (Some(s), Some(offset), Some(length)) => {
457                            Some(s.as_ref().slice(offset, length as usize))
458                        },
459                        _ => None,
460                    },
461                )
462                .collect_trusted()
463        },
464    };
465    out.rename(s.name().clone());
466    Ok(Some(out.into_column()))
467}
468
469pub(super) fn concat(s: &mut [Column]) -> PolarsResult<Option<Column>> {
470    let mut first = std::mem::take(&mut s[0]);
471    let other = &s[1..];
472
473    // TODO! don't auto cast here, but implode beforehand.
474    let mut first_ca = match first.try_list() {
475        Some(ca) => ca,
476        None => {
477            first = first
478                .reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)])
479                .unwrap();
480            first.list().unwrap()
481        },
482    }
483    .clone();
484
485    if first_ca.len() == 1 && !other.is_empty() {
486        let max_len = other.iter().map(|s| s.len()).max().unwrap();
487        if max_len > 1 {
488            first_ca = first_ca.new_from_index(0, max_len)
489        }
490    }
491
492    first_ca.lst_concat(other).map(|ca| Some(ca.into_column()))
493}
494
495pub(super) fn get(s: &mut [Column], null_on_oob: bool) -> PolarsResult<Option<Column>> {
496    let ca = s[0].list()?;
497    let index = s[1].cast(&DataType::Int64)?;
498    let index = index.i64().unwrap();
499
500    match index.len() {
501        1 => {
502            let index = index.get(0);
503            if let Some(index) = index {
504                ca.lst_get(index, null_on_oob).map(Column::from).map(Some)
505            } else {
506                Ok(Some(Column::full_null(
507                    ca.name().clone(),
508                    ca.len(),
509                    ca.inner_dtype(),
510                )))
511            }
512        },
513        len if len == ca.len() => {
514            let tmp = ca.rechunk();
515            let arr = tmp.downcast_as_array();
516            let offsets = arr.offsets().as_slice();
517            let take_by = if ca.null_count() == 0 {
518                index
519                    .iter()
520                    .enumerate()
521                    .map(|(i, opt_idx)| match opt_idx {
522                        Some(idx) => {
523                            let (start, end) = unsafe {
524                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
525                            };
526                            let offset = if idx >= 0 { start + idx } else { end + idx };
527                            if offset >= end || offset < start || start == end {
528                                if null_on_oob {
529                                    Ok(None)
530                                } else {
531                                    polars_bail!(ComputeError: "get index is out of bounds");
532                                }
533                            } else {
534                                Ok(Some(offset as IdxSize))
535                            }
536                        },
537                        None => Ok(None),
538                    })
539                    .collect::<Result<IdxCa, _>>()?
540            } else {
541                index
542                    .iter()
543                    .zip(arr.validity().unwrap())
544                    .enumerate()
545                    .map(|(i, (opt_idx, valid))| match (valid, opt_idx) {
546                        (true, Some(idx)) => {
547                            let (start, end) = unsafe {
548                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
549                            };
550                            let offset = if idx >= 0 { start + idx } else { end + idx };
551                            if offset >= end || offset < start || start == end {
552                                if null_on_oob {
553                                    Ok(None)
554                                } else {
555                                    polars_bail!(ComputeError: "get index is out of bounds");
556                                }
557                            } else {
558                                Ok(Some(offset as IdxSize))
559                            }
560                        },
561                        _ => Ok(None),
562                    })
563                    .collect::<Result<IdxCa, _>>()?
564            };
565            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
566            unsafe { s.take_unchecked(&take_by) }
567                .cast(ca.inner_dtype())
568                .map(Column::from)
569                .map(Some)
570        },
571        _ if ca.len() == 1 => {
572            if ca.null_count() > 0 {
573                return Ok(Some(Column::full_null(
574                    ca.name().clone(),
575                    index.len(),
576                    ca.inner_dtype(),
577                )));
578            }
579            let tmp = ca.rechunk();
580            let arr = tmp.downcast_as_array();
581            let offsets = arr.offsets().as_slice();
582            let start = offsets[0];
583            let end = offsets[1];
584            let out_of_bounds = |offset| offset >= end || offset < start || start == end;
585            let take_by: IdxCa = index
586                .iter()
587                .map(|opt_idx| match opt_idx {
588                    Some(idx) => {
589                        let offset = if idx >= 0 { start + idx } else { end + idx };
590                        if out_of_bounds(offset) {
591                            if null_on_oob {
592                                Ok(None)
593                            } else {
594                                polars_bail!(ComputeError: "get index is out of bounds");
595                            }
596                        } else {
597                            let Ok(offset) = IdxSize::try_from(offset) else {
598                                polars_bail!(ComputeError: "get index is out of bounds");
599                            };
600                            Ok(Some(offset))
601                        }
602                    },
603                    None => Ok(None),
604                })
605                .collect::<Result<IdxCa, _>>()?;
606
607            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
608            unsafe { s.take_unchecked(&take_by) }
609                .cast(ca.inner_dtype())
610                .map(Column::from)
611                .map(Some)
612        },
613        len => polars_bail!(
614            ComputeError:
615            "`list.get` expression got an index array of length {} while the list has {} elements",
616            len, ca.len()
617        ),
618    }
619}
620
/// Implements `list.gather`: gathers from the list column `args[0]` using the
/// indices in `args[1]`.
///
/// A single primitive-numeric index with `null_on_oob` set takes a fast path
/// through `lst_get`, whose result is reshaped back into lists of width 1.
///
/// # Errors
///
/// Fails when `args[0]` is not a list column or when the underlying
/// get/gather/reshape call fails.
#[cfg(feature = "list_gather")]
pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
    let src = &args[0];
    let idx = &args[1];
    let ca = src.list()?;

    let scalar_fast_path = idx.len() == 1 && idx.dtype().is_primitive_numeric() && null_on_oob;
    if !scalar_fast_path {
        return ca
            .lst_gather(idx.as_materialized_series(), null_on_oob)
            .map(Column::from);
    }

    let idx = idx.get(0)?.try_extract::<i64>()?;
    let picked = ca.lst_get(idx, null_on_oob).map(Column::from)?;
    // make sure we return a list
    picked.reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)])
}
638
/// Implements `list.gather_every` on `args[0]`, with `n` taken from `args[1]`
/// and `offset` from `args[2]` (both strictly cast to the index dtype).
///
/// # Errors
///
/// Fails when a strict cast fails, when `args[0]` is not a list column, or when
/// `lst_gather_every` fails.
#[cfg(feature = "list_gather")]
pub(super) fn gather_every(args: &[Column]) -> PolarsResult<Column> {
    let src = &args[0];
    let n_col = args[1].strict_cast(&IDX_DTYPE)?;
    let offset_col = args[2].strict_cast(&IDX_DTYPE)?;

    let lists = src.list()?;
    let n = n_col.idx()?;
    let offset = offset_col.idx()?;
    let gathered = lists.lst_gather_every(n, offset)?;
    Ok(Column::from(gathered))
}
649
/// Implements `list.count_matches`: counts occurrences of the single value in
/// `args[1]` within each row of the list column `args[0]`.
///
/// # Errors
///
/// Returns a `ComputeError` when `args[1]` does not hold exactly one element;
/// also fails when `args[0]` is not a list column or `list_count_matches` fails.
#[cfg(feature = "list_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let lists = &args[0];
    let needle = &args[1];
    if needle.len() != 1 {
        polars_bail!(
            ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}",
            needle.len()
        );
    }
    let ca = lists.list()?;
    let value = needle.get(0).unwrap();
    list_count_matches(ca, value).map(Column::from)
}
662
663pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
664    s.list()?.lst_sum().map(Column::from)
665}
666
667pub(super) fn length(s: &Column) -> PolarsResult<Column> {
668    Ok(s.list()?.lst_lengths().into_column())
669}
670
671pub(super) fn max(s: &Column) -> PolarsResult<Column> {
672    s.list()?.lst_max().map(Column::from)
673}
674
675pub(super) fn min(s: &Column) -> PolarsResult<Column> {
676    s.list()?.lst_min().map(Column::from)
677}
678
679pub(super) fn mean(s: &Column) -> PolarsResult<Column> {
680    Ok(s.list()?.lst_mean().into())
681}
682
683pub(super) fn median(s: &Column) -> PolarsResult<Column> {
684    Ok(s.list()?.lst_median().into())
685}
686
687pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
688    Ok(s.list()?.lst_std(ddof).into())
689}
690
691pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
692    Ok(s.list()?.lst_var(ddof).into())
693}
694
695pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
696    Ok(s.list()?.lst_arg_min().into_column())
697}
698
699pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
700    Ok(s.list()?.lst_arg_max().into_column())
701}
702
/// Implements `list.diff` by applying `lst_diff` with `n` and `null_behavior`
/// to the list column.
///
/// # Errors
///
/// Fails when `s` is not a list column or when `lst_diff` fails.
#[cfg(feature = "diff")]
pub(super) fn diff(s: &Column, n: i64, null_behavior: NullBehavior) -> PolarsResult<Column> {
    let ca = s.list()?;
    let differences = ca.lst_diff(n, null_behavior)?;
    Ok(differences.into_column())
}
707
708pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
709    Ok(s.list()?.lst_sort(options)?.into_column())
710}
711
712pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
713    Ok(s.list()?.lst_reverse().into_column())
714}
715
716pub(super) fn unique(s: &Column, is_stable: bool) -> PolarsResult<Column> {
717    if is_stable {
718        Ok(s.list()?.lst_unique_stable()?.into_column())
719    } else {
720        Ok(s.list()?.lst_unique()?.into_column())
721    }
722}
723
/// Implements the `list` set operations between `s[0]` and `s[1]`.
///
/// When either column is empty the result is decided without running the
/// kernel, and it always carries the name of `s[0]`:
/// * intersection returns the empty side;
/// * difference returns `s[0]`;
/// * union and symmetric difference return the non-empty side (`s[1]` when
///   `s[0]` is empty, otherwise `s[0]`).
///
/// # Errors
///
/// Fails when a column is not a list column or when `list_set_operation` fails.
#[cfg(feature = "list_sets")]
pub(super) fn set_operation(s: &[Column], set_type: SetOperation) -> PolarsResult<Column> {
    let lhs = &s[0];
    let rhs = &s[1];

    let lhs_is_empty = lhs.is_empty();
    if lhs_is_empty || rhs.is_empty() {
        // `rhs` under the name of `lhs`, so the output name never depends on
        // which side ends up being returned.
        let rhs_named_as_lhs = || rhs.clone().with_name(lhs.name().clone());
        let out = match set_type {
            SetOperation::Intersection if lhs_is_empty => lhs.clone(),
            SetOperation::Intersection => rhs_named_as_lhs(),
            SetOperation::Difference => lhs.clone(),
            SetOperation::Union | SetOperation::SymmetricDifference if lhs_is_empty => {
                rhs_named_as_lhs()
            },
            SetOperation::Union | SetOperation::SymmetricDifference => lhs.clone(),
        };
        return Ok(out);
    }

    let result = list_set_operation(lhs.list()?, rhs.list()?, set_type)?;
    Ok(result.into_column())
}
751
/// Implements `list.any` by applying `lst_any` to the list column.
///
/// # Errors
///
/// Fails when `s` is not a list column or when `lst_any` fails.
#[cfg(feature = "list_any_all")]
pub(super) fn lst_any(s: &Column) -> PolarsResult<Column> {
    let ca = s.list()?;
    ca.lst_any().map(Column::from)
}
756
/// Implements `list.all` by applying `lst_all` to the list column.
///
/// # Errors
///
/// Fails when `s` is not a list column or when `lst_all` fails.
#[cfg(feature = "list_any_all")]
pub(super) fn lst_all(s: &Column) -> PolarsResult<Column> {
    let ca = s.list()?;
    ca.lst_all().map(Column::from)
}
761
762pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
763    let ca = s[0].list()?;
764    let separator = s[1].str()?;
765    Ok(ca.lst_join(separator, ignore_nulls)?.into_column())
766}
767
/// Implements `list.to_array` by casting the column to the `Array` dtype with
/// the same inner dtype and the given `width`.
///
/// # Errors
///
/// Fails when `s` is not a `List` column or when the cast fails.
#[cfg(feature = "dtype-array")]
pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult<Column> {
    let target = map_list_dtype_to_array_dtype(s.dtype(), width)?;
    s.cast(&target)
}
773
/// Implements `list.to_struct` by applying `to_struct` with `args` to the list
/// column.
///
/// # Errors
///
/// Fails when `s` is not a list column or when the conversion fails.
#[cfg(feature = "list_to_struct")]
pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult<Column> {
    let ca = s.list()?;
    let structs = ca.to_struct(args)?;
    Ok(structs.into_series().into())
}
778
779pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
780    Ok(s.list()?.lst_n_unique()?.into_column())
781}