polars_plan/dsl/function_expr/
list.rs

1use arrow::legacy::utils::CustomIterTools;
2use polars_ops::chunked_array::list::*;
3
4use super::*;
5use crate::{map, map_as_slice, wrap};
6
7#[derive(Clone, Eq, PartialEq, Hash, Debug)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub enum ListFunction {
10    Concat,
11    #[cfg(feature = "is_in")]
12    Contains,
13    #[cfg(feature = "list_drop_nulls")]
14    DropNulls,
15    #[cfg(feature = "list_sample")]
16    Sample {
17        is_fraction: bool,
18        with_replacement: bool,
19        shuffle: bool,
20        seed: Option<u64>,
21    },
22    Slice,
23    Shift,
24    Get(bool),
25    #[cfg(feature = "list_gather")]
26    Gather(bool),
27    #[cfg(feature = "list_gather")]
28    GatherEvery,
29    #[cfg(feature = "list_count")]
30    CountMatches,
31    Sum,
32    Length,
33    Max,
34    Min,
35    Mean,
36    Median,
37    Std(u8),
38    Var(u8),
39    ArgMin,
40    ArgMax,
41    #[cfg(feature = "diff")]
42    Diff {
43        n: i64,
44        null_behavior: NullBehavior,
45    },
46    Sort(SortOptions),
47    Reverse,
48    Unique(bool),
49    NUnique,
50    #[cfg(feature = "list_sets")]
51    SetOperation(SetOperation),
52    #[cfg(feature = "list_any_all")]
53    Any,
54    #[cfg(feature = "list_any_all")]
55    All,
56    Join(bool),
57    #[cfg(feature = "dtype-array")]
58    ToArray(usize),
59    #[cfg(feature = "list_to_struct")]
60    ToStruct(ListToStructArgs),
61}
62
63impl ListFunction {
64    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
65        use ListFunction::*;
66        match self {
67            Concat => mapper.map_to_list_supertype(),
68            #[cfg(feature = "is_in")]
69            Contains => mapper.with_dtype(DataType::Boolean),
70            #[cfg(feature = "list_drop_nulls")]
71            DropNulls => mapper.with_same_dtype(),
72            #[cfg(feature = "list_sample")]
73            Sample { .. } => mapper.with_same_dtype(),
74            Slice => mapper.with_same_dtype(),
75            Shift => mapper.with_same_dtype(),
76            Get(_) => mapper.map_to_list_and_array_inner_dtype(),
77            #[cfg(feature = "list_gather")]
78            Gather(_) => mapper.with_same_dtype(),
79            #[cfg(feature = "list_gather")]
80            GatherEvery => mapper.with_same_dtype(),
81            #[cfg(feature = "list_count")]
82            CountMatches => mapper.with_dtype(IDX_DTYPE),
83            Sum => mapper.nested_sum_type(),
84            Min => mapper.map_to_list_and_array_inner_dtype(),
85            Max => mapper.map_to_list_and_array_inner_dtype(),
86            Mean => mapper.nested_mean_median_type(),
87            Median => mapper.nested_mean_median_type(),
88            Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration..
89            Var(_) => mapper.map_to_float_dtype(),
90            ArgMin => mapper.with_dtype(IDX_DTYPE),
91            ArgMax => mapper.with_dtype(IDX_DTYPE),
92            #[cfg(feature = "diff")]
93            Diff { .. } => mapper.map_dtype(|dt| {
94                let inner_dt = match dt.inner_dtype().unwrap() {
95                    #[cfg(feature = "dtype-datetime")]
96                    DataType::Datetime(tu, _) => DataType::Duration(*tu),
97                    #[cfg(feature = "dtype-date")]
98                    DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
99                    #[cfg(feature = "dtype-time")]
100                    DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
101                    DataType::UInt64 | DataType::UInt32 => DataType::Int64,
102                    DataType::UInt16 => DataType::Int32,
103                    DataType::UInt8 => DataType::Int16,
104                    inner_dt => inner_dt.clone(),
105                };
106                DataType::List(Box::new(inner_dt))
107            }),
108            Sort(_) => mapper.with_same_dtype(),
109            Reverse => mapper.with_same_dtype(),
110            Unique(_) => mapper.with_same_dtype(),
111            Length => mapper.with_dtype(IDX_DTYPE),
112            #[cfg(feature = "list_sets")]
113            SetOperation(_) => mapper.with_same_dtype(),
114            #[cfg(feature = "list_any_all")]
115            Any => mapper.with_dtype(DataType::Boolean),
116            #[cfg(feature = "list_any_all")]
117            All => mapper.with_dtype(DataType::Boolean),
118            Join(_) => mapper.with_dtype(DataType::String),
119            #[cfg(feature = "dtype-array")]
120            ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
121            NUnique => mapper.with_dtype(IDX_DTYPE),
122            #[cfg(feature = "list_to_struct")]
123            ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)),
124        }
125    }
126
127    pub fn function_options(&self) -> FunctionOptions {
128        use ListFunction as L;
129        match self {
130            L::Concat => FunctionOptions::elementwise(),
131            #[cfg(feature = "is_in")]
132            L::Contains => FunctionOptions::elementwise(),
133            #[cfg(feature = "list_sample")]
134            L::Sample { .. } => FunctionOptions::elementwise(),
135            #[cfg(feature = "list_gather")]
136            L::Gather(_) => FunctionOptions::elementwise(),
137            #[cfg(feature = "list_gather")]
138            L::GatherEvery => FunctionOptions::elementwise(),
139            #[cfg(feature = "list_sets")]
140            L::SetOperation(_) => FunctionOptions::elementwise()
141                .with_casting_rules(CastingRules::Supertype(SuperTypeOptions {
142                    flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
143                }))
144                .with_flags(|f| f & !FunctionFlags::RETURNS_SCALAR),
145            #[cfg(feature = "diff")]
146            L::Diff { .. } => FunctionOptions::elementwise(),
147            #[cfg(feature = "list_drop_nulls")]
148            L::DropNulls => FunctionOptions::elementwise(),
149            #[cfg(feature = "list_count")]
150            L::CountMatches => FunctionOptions::elementwise(),
151            L::Sum
152            | L::Slice
153            | L::Shift
154            | L::Get(_)
155            | L::Length
156            | L::Max
157            | L::Min
158            | L::Mean
159            | L::Median
160            | L::Std(_)
161            | L::Var(_)
162            | L::ArgMin
163            | L::ArgMax
164            | L::Sort(_)
165            | L::Reverse
166            | L::Unique(_)
167            | L::Join(_)
168            | L::NUnique => FunctionOptions::elementwise(),
169            #[cfg(feature = "list_any_all")]
170            L::Any | L::All => FunctionOptions::elementwise(),
171            #[cfg(feature = "dtype-array")]
172            L::ToArray(_) => FunctionOptions::elementwise(),
173            #[cfg(feature = "list_to_struct")]
174            L::ToStruct(ListToStructArgs::FixedWidth(_)) => FunctionOptions::elementwise(),
175            #[cfg(feature = "list_to_struct")]
176            L::ToStruct(ListToStructArgs::InferWidth { .. }) => FunctionOptions::groupwise(),
177        }
178    }
179}
180
/// Maps `List(inner)` to `Array(inner, width)`; any other dtype is a compute error.
#[cfg(feature = "dtype-array")]
fn map_list_dtype_to_array_dtype(datatype: &DataType, width: usize) -> PolarsResult<DataType> {
    match datatype {
        DataType::List(inner) => Ok(DataType::Array(inner.clone(), width)),
        _ => polars_bail!(ComputeError: "expected List dtype"),
    }
}
189
190impl Display for ListFunction {
191    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192        use ListFunction::*;
193
194        let name = match self {
195            Concat => "concat",
196            #[cfg(feature = "is_in")]
197            Contains => "contains",
198            #[cfg(feature = "list_drop_nulls")]
199            DropNulls => "drop_nulls",
200            #[cfg(feature = "list_sample")]
201            Sample { is_fraction, .. } => {
202                if *is_fraction {
203                    "sample_fraction"
204                } else {
205                    "sample_n"
206                }
207            },
208            Slice => "slice",
209            Shift => "shift",
210            Get(_) => "get",
211            #[cfg(feature = "list_gather")]
212            Gather(_) => "gather",
213            #[cfg(feature = "list_gather")]
214            GatherEvery => "gather_every",
215            #[cfg(feature = "list_count")]
216            CountMatches => "count_matches",
217            Sum => "sum",
218            Min => "min",
219            Max => "max",
220            Mean => "mean",
221            Median => "median",
222            Std(_) => "std",
223            Var(_) => "var",
224            ArgMin => "arg_min",
225            ArgMax => "arg_max",
226            #[cfg(feature = "diff")]
227            Diff { .. } => "diff",
228            Length => "length",
229            Sort(_) => "sort",
230            Reverse => "reverse",
231            Unique(is_stable) => {
232                if *is_stable {
233                    "unique_stable"
234                } else {
235                    "unique"
236                }
237            },
238            NUnique => "n_unique",
239            #[cfg(feature = "list_sets")]
240            SetOperation(s) => return write!(f, "list.{s}"),
241            #[cfg(feature = "list_any_all")]
242            Any => "any",
243            #[cfg(feature = "list_any_all")]
244            All => "all",
245            Join(_) => "join",
246            #[cfg(feature = "dtype-array")]
247            ToArray(_) => "to_array",
248            #[cfg(feature = "list_to_struct")]
249            ToStruct(_) => "to_struct",
250        };
251        write!(f, "list.{name}")
252    }
253}
254
255impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
256    fn from(func: ListFunction) -> Self {
257        use ListFunction::*;
258        match func {
259            Concat => wrap!(concat),
260            #[cfg(feature = "is_in")]
261            Contains => wrap!(contains),
262            #[cfg(feature = "list_drop_nulls")]
263            DropNulls => map!(drop_nulls),
264            #[cfg(feature = "list_sample")]
265            Sample {
266                is_fraction,
267                with_replacement,
268                shuffle,
269                seed,
270            } => {
271                if is_fraction {
272                    map_as_slice!(sample_fraction, with_replacement, shuffle, seed)
273                } else {
274                    map_as_slice!(sample_n, with_replacement, shuffle, seed)
275                }
276            },
277            Slice => wrap!(slice),
278            Shift => map_as_slice!(shift),
279            Get(null_on_oob) => wrap!(get, null_on_oob),
280            #[cfg(feature = "list_gather")]
281            Gather(null_on_oob) => map_as_slice!(gather, null_on_oob),
282            #[cfg(feature = "list_gather")]
283            GatherEvery => map_as_slice!(gather_every),
284            #[cfg(feature = "list_count")]
285            CountMatches => map_as_slice!(count_matches),
286            Sum => map!(sum),
287            Length => map!(length),
288            Max => map!(max),
289            Min => map!(min),
290            Mean => map!(mean),
291            Median => map!(median),
292            Std(ddof) => map!(std, ddof),
293            Var(ddof) => map!(var, ddof),
294            ArgMin => map!(arg_min),
295            ArgMax => map!(arg_max),
296            #[cfg(feature = "diff")]
297            Diff { n, null_behavior } => map!(diff, n, null_behavior),
298            Sort(options) => map!(sort, options),
299            Reverse => map!(reverse),
300            Unique(is_stable) => map!(unique, is_stable),
301            #[cfg(feature = "list_sets")]
302            SetOperation(s) => map_as_slice!(set_operation, s),
303            #[cfg(feature = "list_any_all")]
304            Any => map!(lst_any),
305            #[cfg(feature = "list_any_all")]
306            All => map!(lst_all),
307            Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
308            #[cfg(feature = "dtype-array")]
309            ToArray(width) => map!(to_array, width),
310            NUnique => map!(n_unique),
311            #[cfg(feature = "list_to_struct")]
312            ToStruct(args) => map!(to_struct, &args),
313        }
314    }
315}
316
/// `list.contains`: for each list in `args[0]`, whether the matching item in `args[1]` is in it.
///
/// The result is named after the list column.
#[cfg(feature = "is_in")]
pub(super) fn contains(args: &mut [Column]) -> PolarsResult<Option<Column>> {
    let (list, item) = (&args[0], &args[1]);
    polars_ensure!(matches!(list.dtype(), DataType::List(_)),
        SchemaMismatch: "invalid series dtype: expected `List`, got `{}`", list.dtype(),
    );
    let mut mask = polars_ops::prelude::is_in(
        item.as_materialized_series(),
        list.as_materialized_series(),
        true,
    )?;
    mask.rename(list.name().clone());
    Ok(Some(mask.into_column()))
}
334
/// `list.drop_nulls`: removes null elements from every list.
#[cfg(feature = "list_drop_nulls")]
pub(super) fn drop_nulls(s: &Column) -> PolarsResult<Column> {
    let without_nulls = s.list()?.lst_drop_nulls();
    Ok(without_nulls.into_column())
}
340
/// `list.sample` by count: `s[0]` is the list column, `s[1]` the number of elements to draw.
#[cfg(feature = "list_sample")]
pub(super) fn sample_n(
    s: &[Column],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Column> {
    let ca = s[0].list()?;
    let counts = s[1].as_materialized_series();
    let sampled = ca.lst_sample_n(counts, with_replacement, shuffle, seed)?;
    Ok(sampled.into_column())
}
353
/// `list.sample` by fraction: `s[0]` is the list column, `s[1]` the fraction to draw.
#[cfg(feature = "list_sample")]
pub(super) fn sample_fraction(
    s: &[Column],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Column> {
    let ca = s[0].list()?;
    let fractions = s[1].as_materialized_series();
    let sampled = ca.lst_sample_fraction(fractions, with_replacement, shuffle, seed)?;
    Ok(sampled.into_column())
}
371
372fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> {
373    polars_ensure!(
374        slice_len == ca_len,
375        ComputeError:
376        "shape of the slice '{}' argument: {} does not match that of the list column: {}",
377        name, slice_len, ca_len
378    );
379    Ok(())
380}
381
382pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
383    let list = s[0].list()?;
384    let periods = &s[1];
385
386    list.lst_shift(periods).map(|ok| ok.into_column())
387}
388
389pub(super) fn slice(args: &mut [Column]) -> PolarsResult<Option<Column>> {
390    let s = &args[0];
391    let list_ca = s.list()?;
392    let offset_s = &args[1];
393    let length_s = &args[2];
394
395    let mut out: ListChunked = match (offset_s.len(), length_s.len()) {
396        (1, 1) => {
397            let offset = offset_s.get(0).unwrap().try_extract::<i64>()?;
398            let slice_len = length_s
399                .get(0)
400                .unwrap()
401                .extract::<usize>()
402                .unwrap_or(usize::MAX);
403            return Ok(Some(list_ca.lst_slice(offset, slice_len).into_column()));
404        },
405        (1, length_slice_len) => {
406            check_slice_arg_shape(length_slice_len, list_ca.len(), "length")?;
407            let offset = offset_s.get(0).unwrap().try_extract::<i64>()?;
408            // cast to i64 as it is more likely that it is that dtype
409            // instead of usize/u64 (we never need that max length)
410            let length_ca = length_s.cast(&DataType::Int64)?;
411            let length_ca = length_ca.i64().unwrap();
412
413            list_ca
414                .amortized_iter()
415                .zip(length_ca)
416                .map(|(opt_s, opt_length)| match (opt_s, opt_length) {
417                    (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)),
418                    _ => None,
419                })
420                .collect_trusted()
421        },
422        (offset_len, 1) => {
423            check_slice_arg_shape(offset_len, list_ca.len(), "offset")?;
424            let length_slice = length_s
425                .get(0)
426                .unwrap()
427                .extract::<usize>()
428                .unwrap_or(usize::MAX);
429            let offset_ca = offset_s.cast(&DataType::Int64)?;
430            let offset_ca = offset_ca.i64().unwrap();
431            list_ca
432                .amortized_iter()
433                .zip(offset_ca)
434                .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) {
435                    (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)),
436                    _ => None,
437                })
438                .collect_trusted()
439        },
440        _ => {
441            check_slice_arg_shape(offset_s.len(), list_ca.len(), "offset")?;
442            check_slice_arg_shape(length_s.len(), list_ca.len(), "length")?;
443            let offset_ca = offset_s.cast(&DataType::Int64)?;
444            let offset_ca = offset_ca.i64()?;
445            // cast to i64 as it is more likely that it is that dtype
446            // instead of usize/u64 (we never need that max length)
447            let length_ca = length_s.cast(&DataType::Int64)?;
448            let length_ca = length_ca.i64().unwrap();
449
450            list_ca
451                .amortized_iter()
452                .zip(offset_ca)
453                .zip(length_ca)
454                .map(
455                    |((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) {
456                        (Some(s), Some(offset), Some(length)) => {
457                            Some(s.as_ref().slice(offset, length as usize))
458                        },
459                        _ => None,
460                    },
461                )
462                .collect_trusted()
463        },
464    };
465    out.rename(s.name().clone());
466    Ok(Some(out.into_column()))
467}
468
469pub(super) fn concat(s: &mut [Column]) -> PolarsResult<Option<Column>> {
470    let mut first = std::mem::take(&mut s[0]);
471    let other = &s[1..];
472
473    // TODO! don't auto cast here, but implode beforehand.
474    let mut first_ca = match first.try_list() {
475        Some(ca) => ca,
476        None => {
477            first = first
478                .reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)])
479                .unwrap();
480            first.list().unwrap()
481        },
482    }
483    .clone();
484
485    if first_ca.len() == 1 && !other.is_empty() {
486        let max_len = other.iter().map(|s| s.len()).max().unwrap();
487        if max_len > 1 {
488            first_ca = first_ca.new_from_index(0, max_len)
489        }
490    }
491
492    first_ca.lst_concat(other).map(|ca| Some(ca.into_column()))
493}
494
495pub(super) fn get(s: &mut [Column], null_on_oob: bool) -> PolarsResult<Option<Column>> {
496    let ca = s[0].list()?;
497    let index = s[1].cast(&DataType::Int64)?;
498    let index = index.i64().unwrap();
499
500    match index.len() {
501        1 => {
502            let index = index.get(0);
503            if let Some(index) = index {
504                ca.lst_get(index, null_on_oob).map(Column::from).map(Some)
505            } else {
506                Ok(Some(Column::full_null(
507                    ca.name().clone(),
508                    ca.len(),
509                    ca.inner_dtype(),
510                )))
511            }
512        },
513        len if len == ca.len() => {
514            let tmp = ca.rechunk();
515            let arr = tmp.downcast_as_array();
516            let offsets = arr.offsets().as_slice();
517            let take_by = if ca.null_count() == 0 {
518                index
519                    .iter()
520                    .enumerate()
521                    .map(|(i, opt_idx)| match opt_idx {
522                        Some(idx) => {
523                            let (start, end) = unsafe {
524                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
525                            };
526                            let offset = if idx >= 0 { start + idx } else { end + idx };
527                            if offset >= end || offset < start || start == end {
528                                if null_on_oob {
529                                    Ok(None)
530                                } else {
531                                    polars_bail!(ComputeError: "get index is out of bounds");
532                                }
533                            } else {
534                                Ok(Some(offset as IdxSize))
535                            }
536                        },
537                        None => Ok(None),
538                    })
539                    .collect::<Result<IdxCa, _>>()?
540            } else {
541                index
542                    .iter()
543                    .zip(arr.validity().unwrap())
544                    .enumerate()
545                    .map(|(i, (opt_idx, valid))| match (valid, opt_idx) {
546                        (true, Some(idx)) => {
547                            let (start, end) = unsafe {
548                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
549                            };
550                            let offset = if idx >= 0 { start + idx } else { end + idx };
551                            if offset >= end || offset < start || start == end {
552                                if null_on_oob {
553                                    Ok(None)
554                                } else {
555                                    polars_bail!(ComputeError: "get index is out of bounds");
556                                }
557                            } else {
558                                Ok(Some(offset as IdxSize))
559                            }
560                        },
561                        _ => Ok(None),
562                    })
563                    .collect::<Result<IdxCa, _>>()?
564            };
565            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
566            unsafe { s.take_unchecked(&take_by) }
567                .cast(ca.inner_dtype())
568                .map(Column::from)
569                .map(Some)
570        },
571        _ if ca.len() == 1 => {
572            if ca.null_count() > 0 {
573                return Ok(Some(Column::full_null(
574                    ca.name().clone(),
575                    index.len(),
576                    ca.inner_dtype(),
577                )));
578            }
579            let tmp = ca.rechunk();
580            let arr = tmp.downcast_as_array();
581            let offsets = arr.offsets().as_slice();
582            let start = offsets[0];
583            let end = offsets[1];
584            let out_of_bounds = |offset| offset >= end || offset < start || start == end;
585            let take_by: IdxCa = index
586                .iter()
587                .map(|opt_idx| match opt_idx {
588                    Some(idx) => {
589                        let offset = if idx >= 0 { start + idx } else { end + idx };
590                        if out_of_bounds(offset) {
591                            if null_on_oob {
592                                Ok(None)
593                            } else {
594                                polars_bail!(ComputeError: "get index is out of bounds");
595                            }
596                        } else {
597                            let Ok(offset) = IdxSize::try_from(offset) else {
598                                polars_bail!(ComputeError: "get index is out of bounds");
599                            };
600                            Ok(Some(offset))
601                        }
602                    },
603                    None => Ok(None),
604                })
605                .collect::<Result<IdxCa, _>>()?;
606
607            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
608            unsafe { s.take_unchecked(&take_by) }
609                .cast(ca.inner_dtype())
610                .map(Column::from)
611                .map(Some)
612        },
613        len => polars_bail!(
614            ComputeError:
615            "`list.get` expression got an index array of length {} while the list has {} elements",
616            len, ca.len()
617        ),
618    }
619}
620
/// `list.gather`: takes the elements at the indices in `args[1]` from each list in `args[0]`.
///
/// If `null_on_oob` is set and `args[1]` is a single numeric value, `lst_get` is used as a
/// fast path and its result is reshaped back into single-element lists.
#[cfg(feature = "list_gather")]
pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult<Column> {
    let idx = &args[1];
    let ca = args[0].list()?;

    let use_fast_path = idx.len() == 1 && idx.dtype().is_primitive_numeric() && null_on_oob;
    if !use_fast_path {
        return ca
            .lst_gather(idx.as_materialized_series(), null_on_oob)
            .map(Column::from);
    }

    let scalar_idx = idx.get(0)?.try_extract::<i64>()?;
    let picked = ca.lst_get(scalar_idx, null_on_oob).map(Column::from)?;
    // The result must still be a list column.
    picked.reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)])
}
638
/// `list.gather_every`: takes every `args[1]`-th element of each list, starting at `args[2]`.
///
/// Both arguments are strictly cast to the index dtype.
#[cfg(feature = "list_gather")]
pub(super) fn gather_every(args: &[Column]) -> PolarsResult<Column> {
    let step = args[1].strict_cast(&IDX_DTYPE)?;
    let start = args[2].strict_cast(&IDX_DTYPE)?;
    let ca = args[0].list()?;
    ca.lst_gather_every(step.idx()?, start.idx()?)
        .map(Column::from)
}
649
/// `list.count_matches`: counts occurrences of the single value in `args[1]` in each list.
#[cfg(feature = "list_count")]
pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
    let (lists, element) = (&args[0], &args[1]);
    polars_ensure!(
        element.len() == 1,
        ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}",
        element.len()
    );
    let ca = lists.list()?;
    let needle = element.get(0).unwrap();
    list_count_matches(ca, needle).map(Column::from)
}
662
663pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
664    s.list()?.lst_sum().map(Column::from)
665}
666
667pub(super) fn length(s: &Column) -> PolarsResult<Column> {
668    Ok(s.list()?.lst_lengths().into_column())
669}
670
671pub(super) fn max(s: &Column) -> PolarsResult<Column> {
672    s.list()?.lst_max().map(Column::from)
673}
674
675pub(super) fn min(s: &Column) -> PolarsResult<Column> {
676    s.list()?.lst_min().map(Column::from)
677}
678
679pub(super) fn mean(s: &Column) -> PolarsResult<Column> {
680    Ok(s.list()?.lst_mean().into())
681}
682
683pub(super) fn median(s: &Column) -> PolarsResult<Column> {
684    Ok(s.list()?.lst_median().into())
685}
686
687pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {
688    Ok(s.list()?.lst_std(ddof).into())
689}
690
691pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {
692    Ok(s.list()?.lst_var(ddof).into())
693}
694
695pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {
696    Ok(s.list()?.lst_arg_min().into_column())
697}
698
699pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {
700    Ok(s.list()?.lst_arg_max().into_column())
701}
702
/// `list.diff`: `n`-step differences within each list, with nulls handled per `null_behavior`.
#[cfg(feature = "diff")]
pub(super) fn diff(s: &Column, n: i64, null_behavior: NullBehavior) -> PolarsResult<Column> {
    s.list()?
        .lst_diff(n, null_behavior)
        .map(|ca| ca.into_column())
}
707
708pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {
709    Ok(s.list()?.lst_sort(options)?.into_column())
710}
711
712pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
713    Ok(s.list()?.lst_reverse().into_column())
714}
715
716pub(super) fn unique(s: &Column, is_stable: bool) -> PolarsResult<Column> {
717    if is_stable {
718        Ok(s.list()?.lst_unique_stable()?.into_column())
719    } else {
720        Ok(s.list()?.lst_unique()?.into_column())
721    }
722}
723
/// Applies `set_type` element-wise between the lists in `s[0]` and `s[1]`.
///
/// If either input is empty, the list kernel is skipped and one of the inputs is returned
/// directly. Any returned right-hand input is renamed to the left-hand name.
#[cfg(feature = "list_sets")]
pub(super) fn set_operation(s: &[Column], set_type: SetOperation) -> PolarsResult<Column> {
    let (lhs, rhs) = (&s[0], &s[1]);

    if !lhs.is_empty() && !rhs.is_empty() {
        return list_set_operation(lhs.list()?, rhs.list()?, set_type).map(|ca| ca.into_column());
    }

    let out = match set_type {
        SetOperation::Difference => lhs.clone(),
        SetOperation::Intersection if lhs.is_empty() => lhs.clone(),
        SetOperation::Intersection => rhs.clone().with_name(lhs.name().clone()),
        SetOperation::Union | SetOperation::SymmetricDifference if lhs.is_empty() => {
            rhs.clone().with_name(lhs.name().clone())
        },
        SetOperation::Union | SetOperation::SymmetricDifference => lhs.clone(),
    };
    Ok(out)
}
751
/// `list.any`: whether any element of each list is true.
#[cfg(feature = "list_any_all")]
pub(super) fn lst_any(s: &Column) -> PolarsResult<Column> {
    let ca = s.list()?;
    ca.lst_any().map(Into::into)
}
756
/// `list.all`: whether all elements of each list are true.
#[cfg(feature = "list_any_all")]
pub(super) fn lst_all(s: &Column) -> PolarsResult<Column> {
    let ca = s.list()?;
    ca.lst_all().map(Into::into)
}
761
762pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {
763    let ca = s[0].list()?;
764    let separator = s[1].str()?;
765    Ok(ca.lst_join(separator, ignore_nulls)?.into_column())
766}
767
/// `list.to_array`: casts a list column to a fixed-width array column of `width`.
#[cfg(feature = "dtype-array")]
pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult<Column> {
    map_list_dtype_to_array_dtype(s.dtype(), width).and_then(|dtype| s.cast(&dtype))
}
773
/// `list.to_struct`: converts each list into a struct as configured by `args`.
#[cfg(feature = "list_to_struct")]
pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult<Column> {
    let structs = s.list()?.to_struct(args)?;
    Ok(structs.into_series().into())
}
778
779pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
780    Ok(s.list()?.lst_n_unique()?.into_column())
781}