// polars_core/series/arithmetic/list.rs

//! Allow arithmetic operations for `ListChunked`.
3
4use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8
9impl NumOpsDispatchInner for ListType {
10    fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
11        NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
12    }
13
14    fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
15        NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
16    }
17
18    fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
19        NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
20    }
21
22    fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
23        NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
24    }
25
26    fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
27        NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
28    }
29}
30
31#[derive(Clone)]
32pub struct NumericListOp(NumericOp);
33
34impl NumericListOp {
35    pub fn add() -> Self {
36        Self(NumericOp::Add)
37    }
38
39    pub fn sub() -> Self {
40        Self(NumericOp::Sub)
41    }
42
43    pub fn mul() -> Self {
44        Self(NumericOp::Mul)
45    }
46
47    pub fn div() -> Self {
48        Self(NumericOp::Div)
49    }
50
51    pub fn rem() -> Self {
52        Self(NumericOp::Rem)
53    }
54
55    pub fn floor_div() -> Self {
56        Self(NumericOp::FloorDiv)
57    }
58}
59
60impl NumericListOp {
61    #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
62    pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
63        feature_gated!("list_arithmetic", {
64            use std::borrow::Cow;
65
66            use either::Either;
67
68            // `trim_to_normalized_offsets` ensures we don't perform excessive
69            // memory allocation / compute on memory regions that have been
70            // sliced out.
71            let lhs = lhs
72                .trim_lists_to_normalized_offsets()
73                .map_or(Cow::Borrowed(lhs), Cow::Owned);
74            let rhs = rhs
75                .trim_lists_to_normalized_offsets()
76                .map_or(Cow::Borrowed(rhs), Cow::Owned);
77
78            let lhs = lhs.rechunk();
79            let rhs = rhs.rechunk();
80
81            let binary_op_exec = match ListNumericOpHelper::try_new(
82                self.clone(),
83                lhs.name().clone(),
84                lhs.dtype(),
85                rhs.dtype(),
86                lhs.len(),
87                rhs.len(),
88                {
89                    let (a, b) = lhs.list_offsets_and_validities_recursive();
90                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
91                    (a, b, lhs.clone())
92                },
93                {
94                    let (a, b) = rhs.list_offsets_and_validities_recursive();
95                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
96                    (a, b, rhs.clone())
97                },
98                lhs.rechunk_validity(),
99                rhs.rechunk_validity(),
100            )? {
101                Either::Left(v) => v,
102                Either::Right(ca) => return Ok(ca.into_series()),
103            };
104
105            Ok(binary_op_exec.finish()?.into_series())
106        })
107    }
108}
109
110#[cfg(feature = "list_arithmetic")]
111use inner::ListNumericOpHelper;
112
113#[cfg(feature = "list_arithmetic")]
114mod inner {
115    use arrow::bitmap::Bitmap;
116    use arrow::compute::utils::combine_validities_and;
117    use arrow::offset::OffsetsBuffer;
118    use either::Either;
119    use list_utils::with_match_pl_num_arith;
120    use num_traits::Zero;
121    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
122    use polars_utils::float::IsFloat;
123
124    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
125    use super::super::*;
126
127    /// Utility to perform a binary operation between the primitive values of
128    /// 2 columns, where at least one of the columns is a `ListChunked` type.
129    pub(super) struct ListNumericOpHelper {
130        op: NumericListOp,
131        output_name: PlSmallStr,
132        op_apply_type: BinaryOpApplyType,
133        broadcast: Broadcast,
134        output_dtype: DataType,
135        output_primitive_dtype: DataType,
136        output_len: usize,
137        /// Outer validity of the result, we always materialize this to reduce the
138        /// amount of code paths we need.
139        outer_validity: Bitmap,
140        // The series are stored as they are used for list broadcasting.
141        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
142        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
143        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
144        swapped: bool,
145    }
146
147    /// This lets us separate some logic into `new()` to reduce the amount of
148    /// monomorphized code.
149    impl ListNumericOpHelper {
150        /// Checks that:
151        /// * Dtypes are compatible:
152        ///   * list<->primitive | primitive<->list
153        ///   * list<->list both contain primitives (e.g. List<Int8>)
154        /// * Primitive dtypes match
155        /// * Lengths are compatible:
156        ///   * 1<->n | n<->1
157        ///   * n<->n
158        /// * Both sides have at least 1 non-NULL outer row.
159        ///
160        /// Does not check:
161        /// * Whether the offsets are aligned for list<->list, this will be checked during execution.
162        ///
163        /// This returns an `Either` which may contain the final result to simplify
164        /// the implementation.
165        #[allow(clippy::too_many_arguments)]
166        pub(super) fn try_new(
167            op: NumericListOp,
168            output_name: PlSmallStr,
169            dtype_lhs: &DataType,
170            dtype_rhs: &DataType,
171            len_lhs: usize,
172            len_rhs: usize,
173            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
174            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
175            validity_lhs: Option<Bitmap>,
176            validity_rhs: Option<Bitmap>,
177        ) -> PolarsResult<Either<Self, ListChunked>> {
178            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
179            let prim_dtype_rhs = dtype_rhs.leaf_dtype();
180
181            let output_primitive_dtype =
182                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;
183
184            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
185                match dtype {
186                    DataType::List(inner) => is_list_type_at_all_levels(inner),
187                    dt if dt.is_supported_list_arithmetic_input() => true,
188                    _ => false,
189                }
190            }
191
192            let op_err_msg = |err_reason: &str| {
193                polars_err!(
194                    InvalidOperation:
195                    "cannot {} columns: {}: (left: {}, right: {})",
196                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
197                )
198            };
199
200            let ensure_list_type_at_all_levels = |dtype: &DataType| {
201                if !is_list_type_at_all_levels(dtype) {
202                    Err(op_err_msg("dtype was not list on all nesting levels"))
203                } else {
204                    Ok(())
205                }
206            };
207
208            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
209                (l @ DataType::List(a), r @ DataType::List(b)) => {
210                    // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user
211                    // directly adds 2 series together it bypasses the DSL.
212                    // This is currently duplicated code and should be replaced one day with an assert after Series ops get
213                    // checked properly.
214                    if ![a, b]
215                        .into_iter()
216                        .all(|x| x.is_supported_list_arithmetic_input())
217                    {
218                        polars_bail!(
219                            InvalidOperation:
220                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
221                            op.0.name(), l, r,
222                        );
223                    }
224                    (BinaryOpApplyType::ListToList, l)
225                },
226                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
227                    ensure_list_type_at_all_levels(list_dtype)?;
228                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
229                },
230                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
231                    ensure_list_type_at_all_levels(list_dtype)?;
232                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
233                },
234                (l, r) => polars_bail!(
235                    InvalidOperation:
236                    "{} operation not supported for dtypes: {} != {}",
237                    op.0.name(), l, r,
238                ),
239            };
240
241            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());
242
243            let (broadcast, output_len) = match (len_lhs, len_rhs) {
244                (l, r) if l == r => (Broadcast::NoBroadcast, l),
245                (1, v) => (Broadcast::Left, v),
246                (v, 1) => (Broadcast::Right, v),
247                (l, r) => polars_bail!(
248                    ShapeMismatch:
249                    "cannot {} two columns of differing lengths: {} != {}",
250                    op.0.name(), l, r
251                ),
252            };
253
254            let DataType::List(output_inner_dtype) = &output_dtype else {
255                unreachable!()
256            };
257
258            // # NULL semantics
259            // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]]
260            //   * Essentially as if the NULL primitive was added to every primitive in the row of the list column.
261            // * NULL (List[Int64]) + 1   (Int64)       => NULL
262            // * NULL (List[Int64]) + [1] (List[Int64]) => NULL
263
264            if output_len == 0
265                || (matches!(
266                    &op_apply_type,
267                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
268                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
269                || (matches!(
270                    &op_apply_type,
271                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
272                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
273            {
274                return Ok(Either::Right(ListChunked::full_null_with_dtype(
275                    output_name.clone(),
276                    output_len,
277                    output_inner_dtype.as_ref(),
278                )));
279            }
280
281            // At this point:
282            // * All unit length list columns have a valid outer value.
283
284            // The outer validity is just the validity of any non-broadcasting lists.
285            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
286                // Both lists with same length, we combine the validity.
287                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
288                    combine_validities_and(l.as_ref(), r.as_ref())
289                },
290                // Match all other combinations that have non-broadcasting lists.
291                (
292                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
293                    Broadcast::NoBroadcast | Broadcast::Right,
294                    v,
295                    _,
296                )
297                | (
298                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
299                    Broadcast::NoBroadcast | Broadcast::Left,
300                    _,
301                    v,
302                ) => v,
303                _ => None,
304            }
305            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));
306
307            Ok(Either::Left(Self {
308                op,
309                output_name,
310                op_apply_type,
311                broadcast,
312                output_dtype: output_dtype.clone(),
313                output_primitive_dtype,
314                output_len,
315                outer_validity,
316                data_lhs,
317                data_rhs,
318                list_to_prim_lhs: None,
319                swapped: false,
320            }))
321        }
322
323        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
324            // We have physical codepaths for a subset of the possible combinations of broadcasting and
325            // column types. The remaining combinations are handled by dispatching to the physical
326            // codepaths after operand swapping and/or materialized broadcasting.
327            //
328            // # Physical impl table
329            // Legend
330            // * |  N  | // impl "N"
331            // * | [N] | // dispatches to impl "N"
332            //
333            //                  |  L  |  N  |  R  | // Broadcast (L)eft, (N)oBroadcast, (R)ight
334            // ListToList       | [1] |  0  |  1  |
335            // ListToPrimitive  | [2] |  2  |  3  | // list broadcasting just materializes and dispatches to NoBroadcast
336            // PrimitiveToList  | [3] | [2] | [2] |
337
338            self.swapped = true;
339
340            match (&self.op_apply_type, &self.broadcast) {
341                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
342                | (BinaryOpApplyType::ListToList, Broadcast::Right)
343                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
344                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
345                    self.swapped = false;
346                    self._finish_impl_dispatch()
347                },
348                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
349                    self.broadcast = Broadcast::Right;
350
351                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
352                    self._finish_impl_dispatch()
353                },
354                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
355                    self.list_to_prim_lhs
356                        .replace(Self::materialize_broadcasted_list(
357                            &mut self.data_lhs,
358                            self.output_len,
359                            &self.output_primitive_dtype,
360                        ));
361
362                    self.broadcast = Broadcast::NoBroadcast;
363
364                    // This does not swap! We are just dispatching to `NoBroadcast`
365                    // after materializing the broadcasted list array.
366                    self.swapped = false;
367                    self._finish_impl_dispatch()
368                },
369                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
370                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
371
372                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
373                    self._finish_impl_dispatch()
374                },
375                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
376                    // We materialize the list columns with `new_from_index`, as otherwise we'd have to
377                    // implement logic that broadcasts the offsets and validities across multiple levels
378                    // of nesting. But we will re-use the materialized memory to store the result.
379
380                    self.list_to_prim_lhs
381                        .replace(Self::materialize_broadcasted_list(
382                            &mut self.data_rhs,
383                            self.output_len,
384                            &self.output_primitive_dtype,
385                        ));
386
387                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
388                    self.broadcast = Broadcast::NoBroadcast;
389
390                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
391                    self._finish_impl_dispatch()
392                },
393                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
394                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
395                    self.broadcast = Broadcast::Right;
396
397                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
398                    self._finish_impl_dispatch()
399                },
400            }
401        }
402
403        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
404            let output_dtype = self.output_dtype.clone();
405            let output_len = self.output_len;
406
407            let prim_lhs = self
408                .data_lhs
409                .2
410                .get_leaf_array()
411                .cast(&self.output_primitive_dtype)?
412                .rechunk();
413            let prim_rhs = self
414                .data_rhs
415                .2
416                .get_leaf_array()
417                .cast(&self.output_primitive_dtype)?
418                .rechunk();
419
420            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
421            let prim_dtype = prim_lhs.dtype();
422            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);
423
424            // Safety: Leaf dtypes have been checked to be numeric by `try_new()`
425            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
426                self._finish_impl::<$T>(prim_lhs, prim_rhs)
427            })?;
428
429            debug_assert_eq!(out.dtype(), &output_dtype);
430            assert_eq!(out.len(), output_len);
431
432            Ok(out)
433        }
434
435        /// Internal use only - contains physical impls.
436        fn _finish_impl<T: PolarsNumericType>(
437            &mut self,
438            prim_s_lhs: Series,
439            prim_s_rhs: Series,
440        ) -> PolarsResult<ListChunked>
441        where
442            T::Native: PlNumArithmetic,
443            PrimitiveArray<T::Native>:
444                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
445            T::Native: Zero + IsFloat,
446        {
447            #[inline(never)]
448            fn check_mismatch_pos(
449                mismatch_pos: usize,
450                offsets_lhs: &OffsetsBuffer<i64>,
451                offsets_rhs: &OffsetsBuffer<i64>,
452            ) -> PolarsResult<()> {
453                if mismatch_pos < offsets_lhs.len_proxy() {
454                    // RHS could be broadcasted
455                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
456                        0
457                    } else {
458                        mismatch_pos
459                    });
460                    polars_bail!(
461                        ShapeMismatch:
462                        "list lengths differed at index {}: {} != {}",
463                        mismatch_pos,
464                        offsets_lhs.length_at(mismatch_pos), len_r
465                    )
466                }
467                Ok(())
468            }
469
470            let mut arr_lhs = {
471                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
472                assert_eq!(ca.chunks().len(), 1);
473                ca.downcast_get(0).unwrap().clone()
474            };
475
476            let mut arr_rhs = {
477                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
478                assert_eq!(ca.chunks().len(), 1);
479                ca.downcast_get(0).unwrap().clone()
480            };
481
482            match (&self.op_apply_type, &self.broadcast) {
483                // We skip for this because it dispatches to `ArithmeticKernel`, which handles the
484                // validities for us.
485                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
486                _ if self.list_to_prim_lhs.is_none() => {
487                    self.op.0.prepare_numeric_op_side_validities::<T>(
488                        &mut arr_lhs,
489                        &mut arr_rhs,
490                        self.swapped,
491                    )
492                },
493                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
494                    // `self.list_to_prim_lhs` is `Some(_)`, this is handled later.
495                },
496                _ => unreachable!(),
497            }
498
499            //
500            // General notes
501            // * Lists can be:
502            //   * Sliced, in which case the primitive/leaf array needs to be indexed starting from an
503            //     offset instead of 0.
504            //   * Masked, in which case the masked rows are permitted to have non-matching widths.
505            //
506
507            let out = match (&self.op_apply_type, &self.broadcast) {
508                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
509                    let offsets_lhs = &self.data_lhs.0[0];
510                    let offsets_rhs = &self.data_rhs.0[0];
511
512                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());
513
514                    // Output primitive (and optional validity) are aligned to the LHS input.
515                    let n_values = arr_lhs.len();
516                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
517                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
518
519                    // Counter that stops being incremented at the first row position with mismatching
520                    // list lengths.
521                    let mut mismatch_pos = 0;
522
523                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
524                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
525                            .offset_and_length_iter()
526                            .zip(offsets_rhs.offset_and_length_iter())
527                            .enumerate()
528                        {
529                            if
530                                (mismatch_pos == i)
531                                & (
532                                    (lhs_len == rhs_len)
533                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
534                                )
535                            {
536                                mismatch_pos += 1;
537                            }
538
539                            // Both sides are lists, we restrict the index to the min length to avoid
540                            // OOB memory access.
541                            let len: usize = lhs_len.min(rhs_len);
542
543                            for i in 0..len {
544                                let l_idx = i + lhs_start;
545                                let r_idx = i + rhs_start;
546
547                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
548                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
549                                let v = $OP(l, r);
550
551                                unsafe { out_ptr.add(l_idx).write(v) };
552                            }
553                        }
554                    });
555
556                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
557
558                    unsafe { out_vec.set_len(n_values) };
559
560                    /// Reduce monomorphization
561                    #[inline(never)]
562                    fn combine_validities_list_to_list_no_broadcast(
563                        offsets_lhs: &OffsetsBuffer<i64>,
564                        offsets_rhs: &OffsetsBuffer<i64>,
565                        validity_lhs: Option<&Bitmap>,
566                        validity_rhs: Option<&Bitmap>,
567                        len_lhs: usize,
568                    ) -> Option<Bitmap> {
569                        match (validity_lhs, validity_rhs) {
570                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
571                            (Some(v), None) => return Some(v.clone()),
572                            (None, Some(v)) => {
573                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
574                            },
575                            (None, None) => None,
576                        }
577                        .map(|(mut validity_out, validity_rhs)| {
578                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
579                                .offset_and_length_iter()
580                                .zip(offsets_rhs.offset_and_length_iter())
581                            {
582                                let len: usize = lhs_len.min(rhs_len);
583
584                                for i in 0..len {
585                                    let l_idx = i + lhs_start;
586                                    let r_idx = i + rhs_start;
587
588                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
589                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
590                                    let is_valid = l_valid & r_valid;
591
592                                    // Size and alignment of validity vec are based on LHS.
593                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
594                                }
595                            }
596
597                            validity_out.freeze()
598                        })
599                    }
600
601                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
602                        offsets_lhs,
603                        offsets_rhs,
604                        arr_lhs.validity(),
605                        arr_rhs.validity(),
606                        arr_lhs.len(),
607                    );
608
609                    let arr =
610                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
611
612                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
613                    assert_eq!(offsets.len(), 1);
614
615                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
616                },
617                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
618                    let offsets_lhs = &self.data_lhs.0[0];
619                    let offsets_rhs = &self.data_rhs.0[0];
620
621                    // Output primitive (and optional validity) are aligned to the LHS input.
622                    let n_values = arr_lhs.len();
623                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
624                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
625
626                    assert_eq!(offsets_rhs.len_proxy(), 1);
627                    let rhs_start = *offsets_rhs.first() as usize;
628                    let width = offsets_rhs.range() as usize;
629
630                    let mut mismatch_pos = 0;
631
632                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
633                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
634                            if ((lhs_len == width) & (mismatch_pos == i))
635                                | unsafe { !self.outer_validity.get_bit_unchecked(i) }
636                            {
637                                mismatch_pos += 1;
638                            }
639
640                            let len: usize = lhs_len.min(width);
641
642                            for i in 0..len {
643                                let l_idx = i + lhs_start;
644                                let r_idx = i + rhs_start;
645
646                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
647                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
648                                let v = $OP(l, r);
649
650                                unsafe {
651                                    out_ptr.add(l_idx).write(v);
652                                }
653                            }
654                        }
655                    });
656
657                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
658
659                    unsafe { out_vec.set_len(n_values) };
660
661                    #[inline(never)]
662                    fn combine_validities_list_to_list_broadcast_right(
663                        offsets_lhs: &OffsetsBuffer<i64>,
664                        validity_lhs: Option<&Bitmap>,
665                        validity_rhs: Option<&Bitmap>,
666                        len_lhs: usize,
667                        width: usize,
668                        rhs_start: usize,
669                    ) -> Option<Bitmap> {
670                        match (validity_lhs, validity_rhs) {
671                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
672                            (Some(v), None) => return Some(v.clone()),
673                            (None, Some(v)) => {
674                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
675                            },
676                            (None, None) => None,
677                        }
678                        .map(|(mut validity_out, validity_rhs)| {
679                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
680                                let len: usize = lhs_len.min(width);
681
682                                for i in 0..len {
683                                    let l_idx = i + lhs_start;
684                                    let r_idx = i + rhs_start;
685
686                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
687                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
688                                    let is_valid = l_valid & r_valid;
689
690                                    // Size and alignment of validity vec are based on LHS.
691                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
692                                }
693                            }
694
695                            validity_out.freeze()
696                        })
697                    }
698
699                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
700                        offsets_lhs,
701                        arr_lhs.validity(),
702                        arr_rhs.validity(),
703                        arr_lhs.len(),
704                        width,
705                        rhs_start,
706                    );
707
708                    let arr =
709                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
710
711                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
712                    assert_eq!(offsets.len(), 1);
713
714                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
715                },
716                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
717                    if self.list_to_prim_lhs.is_none() =>
718                {
719                    let offsets_lhs = self.data_lhs.0.as_slice();
720
721                    // Notes
722                    // * Primitive indexing starts from 0
723                    // * Output is aligned to LHS array
724
725                    let n_values = arr_lhs.len();
726                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
727                    let out_ptr = out_vec.as_mut_ptr();
728
729                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
730                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
731                        {
732                            let r = unsafe { arr_rhs.value_unchecked(i) };
733                            for l_idx in l_range {
734                                unsafe {
735                                    let l = arr_lhs.value_unchecked(l_idx);
736                                    let v = $OP(l, r);
737                                    out_ptr.add(l_idx).write(v);
738                                }
739                            }
740                        }
741                    });
742
743                    unsafe { out_vec.set_len(n_values) }
744
745                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
746                        offsets_lhs,
747                        arr_lhs.validity(),
748                        arr_rhs.validity(),
749                        arr_lhs.len(),
750                    );
751
752                    let arr =
753                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
754
755                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
756                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
757                },
758                // If we are dispatched here, it means that the LHS array is a unique allocation created
759                // after a unit-length list column was broadcasted, so this codepath mutably stores the
760                // results back into the LHS array to save memory.
761                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
762                    let offsets_lhs = self.data_lhs.0.as_slice();
763
764                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
765                    let arr = arr
766                        .as_any_mut()
767                        .downcast_mut::<PrimitiveArray<T::Native>>()
768                        .unwrap();
769                    let mut arr_lhs = std::mem::take(arr);
770
771                    self.op.0.prepare_numeric_op_side_validities::<T>(
772                        &mut arr_lhs,
773                        &mut arr_rhs,
774                        self.swapped,
775                    );
776
777                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
778                    assert_eq!(arr_lhs_mut_slice.len(), n_values);
779
780                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
781                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
782                        {
783                            let r = unsafe { arr_rhs.value_unchecked(i) };
784                            for l_idx in l_range {
785                                unsafe {
786                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
787                                    *l = $OP(*l, r);
788                                }
789                            }
790                        }
791                    });
792
793                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
794                        offsets_lhs,
795                        arr_lhs.validity(),
796                        arr_rhs.validity(),
797                        arr_lhs.len(),
798                    );
799
800                    let arr = arr_lhs.with_validity(leaf_validity);
801
802                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
803                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
804                },
805                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
806                    assert_eq!(arr_rhs.len(), 1);
807
808                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
809                        // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL.
810                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
811                        return Ok(self.finish_offsets_and_validities(
812                            Box::new(
813                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
814                                    false,
815                                    arr_lhs.len(),
816                                ))),
817                            ),
818                            offsets,
819                            validities,
820                        ));
821                    };
822
823                    let arr = self
824                        .op
825                        .0
826                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
827                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
828
829                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
830                },
831                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
832                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
833                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
834                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
835                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
836                    if cfg!(debug_assertions) {
837                        panic!("operation was not re-written: {:?}", v)
838                    } else {
839                        unreachable!()
840                    }
841                },
842            };
843
844            Ok(out)
845        }
846
847        /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every
848        /// level.
849        fn finish_offsets_and_validities(
850            &mut self,
851            leaf_array: Box<dyn Array>,
852            offsets: Vec<OffsetsBuffer<i64>>,
853            validities: Vec<Option<Bitmap>>,
854        ) -> ListChunked {
855            assert!(!offsets.is_empty());
856            assert_eq!(offsets.len(), validities.len());
857            let mut results = leaf_array;
858
859            let mut iter = offsets.into_iter().zip(validities).rev();
860
861            while iter.len() > 1 {
862                let (offsets, validity) = iter.next().unwrap();
863                let dtype = LargeListArray::default_datatype(results.dtype().clone());
864                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
865            }
866
867            // The combined outer validity is pre-computed during `try_new()`
868            let (offsets, _) = iter.next().unwrap();
869            let validity = std::mem::take(&mut self.outer_validity);
870            let dtype = LargeListArray::default_datatype(results.dtype().clone());
871            let results = LargeListArray::new(dtype, offsets, results, Some(validity));
872
873            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
874        }
875
876        fn materialize_broadcasted_list(
877            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
878            output_len: usize,
879            output_primitive_dtype: &DataType,
880        ) -> (Box<dyn Array>, usize) {
881            let s = &side_data.2;
882            assert_eq!(s.len(), 1);
883
884            let expected_n_values = {
885                let offsets = s.list_offsets_and_validities_recursive().0;
886                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
887            };
888
889            let ca = s.list().unwrap();
890            // Remember to cast the leaf primitives to the supertype.
891            let ca = ca
892                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
893                .unwrap();
894            assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data.
895            let ca = ca.new_from_index(0, output_len).rechunk();
896
897            let s = ca.into_series();
898
899            *side_data = {
900                let (a, b) = s.list_offsets_and_validities_recursive();
901                // `Series::default()`: This field in the tuple is no longer used.
902                (a, b, Series::default())
903            };
904
905            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
906            assert_eq!(n_values, expected_n_values);
907
908            let mut s = s.get_leaf_array();
909            let v = unsafe { s.chunks_mut() };
910
911            assert_eq!(v.len(), 1);
912            (v.swap_remove(0), n_values)
913        }
914    }
915
916    /// Used in 2 places, so it's outside here.
917    #[inline(never)]
918    fn combine_validities_list_to_primitive_no_broadcast(
919        offsets_lhs: &[OffsetsBuffer<i64>],
920        validity_lhs: Option<&Bitmap>,
921        validity_rhs: Option<&Bitmap>,
922        len_lhs: usize,
923    ) -> Option<Bitmap> {
924        match (validity_lhs, validity_rhs) {
925            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
926            (Some(v), None) => return Some(v.clone()),
927            // Materialize a full-true validity to re-use the codepath, as we still
928            // need to spread the bits from the RHS to the correct positions.
929            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
930            (None, None) => None,
931        }
932        .map(|(mut validity_out, validity_rhs)| {
933            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
934                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
935                for l_idx in l_range {
936                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
937                    let is_valid = l_valid & r_valid;
938
939                    // Size and alignment of validity vec are based on LHS.
940                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
941                }
942            }
943
944            validity_out.freeze()
945        })
946    }
947}