1use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8
9impl NumOpsDispatchInner for ListType {
10 fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
11 NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
12 }
13
14 fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
15 NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
16 }
17
18 fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
19 NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
20 }
21
22 fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
23 NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
24 }
25
26 fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
27 NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
28 }
29}
30
31#[derive(Clone)]
32pub struct NumericListOp(NumericOp);
33
34impl NumericListOp {
35 pub fn add() -> Self {
36 Self(NumericOp::Add)
37 }
38
39 pub fn sub() -> Self {
40 Self(NumericOp::Sub)
41 }
42
43 pub fn mul() -> Self {
44 Self(NumericOp::Mul)
45 }
46
47 pub fn div() -> Self {
48 Self(NumericOp::Div)
49 }
50
51 pub fn rem() -> Self {
52 Self(NumericOp::Rem)
53 }
54
55 pub fn floor_div() -> Self {
56 Self(NumericOp::FloorDiv)
57 }
58}
59
60impl NumericListOp {
61 #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
62 pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
63 feature_gated!("list_arithmetic", {
64 use std::borrow::Cow;
65
66 use either::Either;
67
68 let lhs = lhs
72 .trim_lists_to_normalized_offsets()
73 .map_or(Cow::Borrowed(lhs), Cow::Owned);
74 let rhs = rhs
75 .trim_lists_to_normalized_offsets()
76 .map_or(Cow::Borrowed(rhs), Cow::Owned);
77
78 let lhs = lhs.rechunk();
79 let rhs = rhs.rechunk();
80
81 let binary_op_exec = match ListNumericOpHelper::try_new(
82 self.clone(),
83 lhs.name().clone(),
84 lhs.dtype(),
85 rhs.dtype(),
86 lhs.len(),
87 rhs.len(),
88 {
89 let (a, b) = lhs.list_offsets_and_validities_recursive();
90 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
91 (a, b, lhs.clone())
92 },
93 {
94 let (a, b) = rhs.list_offsets_and_validities_recursive();
95 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
96 (a, b, rhs.clone())
97 },
98 lhs.rechunk_validity(),
99 rhs.rechunk_validity(),
100 )? {
101 Either::Left(v) => v,
102 Either::Right(ca) => return Ok(ca.into_series()),
103 };
104
105 Ok(binary_op_exec.finish()?.into_series())
106 })
107 }
108}
109
110#[cfg(feature = "list_arithmetic")]
111use inner::ListNumericOpHelper;
112
113#[cfg(feature = "list_arithmetic")]
114mod inner {
115 use arrow::bitmap::Bitmap;
116 use arrow::compute::utils::combine_validities_and;
117 use arrow::offset::OffsetsBuffer;
118 use either::Either;
119 use list_utils::with_match_pl_num_arith;
120 use num_traits::Zero;
121 use polars_compute::arithmetic::pl_num::PlNumArithmetic;
122 use polars_utils::float::IsFloat;
123
124 use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
125 use super::super::*;
126
127 pub(super) struct ListNumericOpHelper {
130 op: NumericListOp,
131 output_name: PlSmallStr,
132 op_apply_type: BinaryOpApplyType,
133 broadcast: Broadcast,
134 output_dtype: DataType,
135 output_primitive_dtype: DataType,
136 output_len: usize,
137 outer_validity: Bitmap,
140 data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
142 data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
143 list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
144 swapped: bool,
145 }
146
147 impl ListNumericOpHelper {
150 #[allow(clippy::too_many_arguments)]
166 pub(super) fn try_new(
167 op: NumericListOp,
168 output_name: PlSmallStr,
169 dtype_lhs: &DataType,
170 dtype_rhs: &DataType,
171 len_lhs: usize,
172 len_rhs: usize,
173 data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
174 data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
175 validity_lhs: Option<Bitmap>,
176 validity_rhs: Option<Bitmap>,
177 ) -> PolarsResult<Either<Self, ListChunked>> {
178 let prim_dtype_lhs = dtype_lhs.leaf_dtype();
179 let prim_dtype_rhs = dtype_rhs.leaf_dtype();
180
181 let output_primitive_dtype =
182 op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;
183
184 fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
185 match dtype {
186 DataType::List(inner) => is_list_type_at_all_levels(inner),
187 dt if dt.is_supported_list_arithmetic_input() => true,
188 _ => false,
189 }
190 }
191
192 let op_err_msg = |err_reason: &str| {
193 polars_err!(
194 InvalidOperation:
195 "cannot {} columns: {}: (left: {}, right: {})",
196 op.0.name(), err_reason, dtype_lhs, dtype_rhs,
197 )
198 };
199
200 let ensure_list_type_at_all_levels = |dtype: &DataType| {
201 if !is_list_type_at_all_levels(dtype) {
202 Err(op_err_msg("dtype was not list on all nesting levels"))
203 } else {
204 Ok(())
205 }
206 };
207
208 let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
209 (l @ DataType::List(a), r @ DataType::List(b)) => {
210 if ![a, b]
215 .into_iter()
216 .all(|x| x.is_supported_list_arithmetic_input())
217 {
218 polars_bail!(
219 InvalidOperation:
220 "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
221 op.0.name(), l, r,
222 );
223 }
224 (BinaryOpApplyType::ListToList, l)
225 },
226 (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
227 ensure_list_type_at_all_levels(list_dtype)?;
228 (BinaryOpApplyType::ListToPrimitive, list_dtype)
229 },
230 (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
231 ensure_list_type_at_all_levels(list_dtype)?;
232 (BinaryOpApplyType::PrimitiveToList, list_dtype)
233 },
234 (l, r) => polars_bail!(
235 InvalidOperation:
236 "{} operation not supported for dtypes: {} != {}",
237 op.0.name(), l, r,
238 ),
239 };
240
241 let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());
242
243 let (broadcast, output_len) = match (len_lhs, len_rhs) {
244 (l, r) if l == r => (Broadcast::NoBroadcast, l),
245 (1, v) => (Broadcast::Left, v),
246 (v, 1) => (Broadcast::Right, v),
247 (l, r) => polars_bail!(
248 ShapeMismatch:
249 "cannot {} two columns of differing lengths: {} != {}",
250 op.0.name(), l, r
251 ),
252 };
253
254 let DataType::List(output_inner_dtype) = &output_dtype else {
255 unreachable!()
256 };
257
258 if output_len == 0
265 || (matches!(
266 &op_apply_type,
267 BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
268 ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
269 || (matches!(
270 &op_apply_type,
271 BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
272 ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
273 {
274 return Ok(Either::Right(ListChunked::full_null_with_dtype(
275 output_name.clone(),
276 output_len,
277 output_inner_dtype.as_ref(),
278 )));
279 }
280
281 let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
286 (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
288 combine_validities_and(l.as_ref(), r.as_ref())
289 },
290 (
292 BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
293 Broadcast::NoBroadcast | Broadcast::Right,
294 v,
295 _,
296 )
297 | (
298 BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
299 Broadcast::NoBroadcast | Broadcast::Left,
300 _,
301 v,
302 ) => v,
303 _ => None,
304 }
305 .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));
306
307 Ok(Either::Left(Self {
308 op,
309 output_name,
310 op_apply_type,
311 broadcast,
312 output_dtype: output_dtype.clone(),
313 output_primitive_dtype,
314 output_len,
315 outer_validity,
316 data_lhs,
317 data_rhs,
318 list_to_prim_lhs: None,
319 swapped: false,
320 }))
321 }
322
323 pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
324 self.swapped = true;
339
340 match (&self.op_apply_type, &self.broadcast) {
341 (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
342 | (BinaryOpApplyType::ListToList, Broadcast::Right)
343 | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
344 | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
345 self.swapped = false;
346 self._finish_impl_dispatch()
347 },
348 (BinaryOpApplyType::ListToList, Broadcast::Left) => {
349 self.broadcast = Broadcast::Right;
350
351 std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
352 self._finish_impl_dispatch()
353 },
354 (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
355 self.list_to_prim_lhs
356 .replace(Self::materialize_broadcasted_list(
357 &mut self.data_lhs,
358 self.output_len,
359 &self.output_primitive_dtype,
360 ));
361
362 self.broadcast = Broadcast::NoBroadcast;
363
364 self.swapped = false;
367 self._finish_impl_dispatch()
368 },
369 (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
370 self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
371
372 std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
373 self._finish_impl_dispatch()
374 },
375 (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
376 self.list_to_prim_lhs
381 .replace(Self::materialize_broadcasted_list(
382 &mut self.data_rhs,
383 self.output_len,
384 &self.output_primitive_dtype,
385 ));
386
387 self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
388 self.broadcast = Broadcast::NoBroadcast;
389
390 std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
391 self._finish_impl_dispatch()
392 },
393 (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
394 self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
395 self.broadcast = Broadcast::Right;
396
397 std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
398 self._finish_impl_dispatch()
399 },
400 }
401 }
402
403 fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
404 let output_dtype = self.output_dtype.clone();
405 let output_len = self.output_len;
406
407 let prim_lhs = self
408 .data_lhs
409 .2
410 .get_leaf_array()
411 .cast(&self.output_primitive_dtype)?
412 .rechunk();
413 let prim_rhs = self
414 .data_rhs
415 .2
416 .get_leaf_array()
417 .cast(&self.output_primitive_dtype)?
418 .rechunk();
419
420 debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
421 let prim_dtype = prim_lhs.dtype();
422 debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);
423
424 let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
426 self._finish_impl::<$T>(prim_lhs, prim_rhs)
427 })?;
428
429 debug_assert_eq!(out.dtype(), &output_dtype);
430 assert_eq!(out.len(), output_len);
431
432 Ok(out)
433 }
434
435 fn _finish_impl<T: PolarsNumericType>(
437 &mut self,
438 prim_s_lhs: Series,
439 prim_s_rhs: Series,
440 ) -> PolarsResult<ListChunked>
441 where
442 T::Native: PlNumArithmetic,
443 PrimitiveArray<T::Native>:
444 polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
445 T::Native: Zero + IsFloat,
446 {
447 #[inline(never)]
448 fn check_mismatch_pos(
449 mismatch_pos: usize,
450 offsets_lhs: &OffsetsBuffer<i64>,
451 offsets_rhs: &OffsetsBuffer<i64>,
452 ) -> PolarsResult<()> {
453 if mismatch_pos < offsets_lhs.len_proxy() {
454 let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
456 0
457 } else {
458 mismatch_pos
459 });
460 polars_bail!(
461 ShapeMismatch:
462 "list lengths differed at index {}: {} != {}",
463 mismatch_pos,
464 offsets_lhs.length_at(mismatch_pos), len_r
465 )
466 }
467 Ok(())
468 }
469
470 let mut arr_lhs = {
471 let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
472 assert_eq!(ca.chunks().len(), 1);
473 ca.downcast_get(0).unwrap().clone()
474 };
475
476 let mut arr_rhs = {
477 let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
478 assert_eq!(ca.chunks().len(), 1);
479 ca.downcast_get(0).unwrap().clone()
480 };
481
482 match (&self.op_apply_type, &self.broadcast) {
483 (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
486 _ if self.list_to_prim_lhs.is_none() => {
487 self.op.0.prepare_numeric_op_side_validities::<T>(
488 &mut arr_lhs,
489 &mut arr_rhs,
490 self.swapped,
491 )
492 },
493 (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
494 },
496 _ => unreachable!(),
497 }
498
499 let out = match (&self.op_apply_type, &self.broadcast) {
508 (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
509 let offsets_lhs = &self.data_lhs.0[0];
510 let offsets_rhs = &self.data_rhs.0[0];
511
512 assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());
513
514 let n_values = arr_lhs.len();
516 let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
517 let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
518
519 let mut mismatch_pos = 0;
522
523 with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
524 for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
525 .offset_and_length_iter()
526 .zip(offsets_rhs.offset_and_length_iter())
527 .enumerate()
528 {
529 if
530 (mismatch_pos == i)
531 & (
532 (lhs_len == rhs_len)
533 | unsafe { !self.outer_validity.get_bit_unchecked(i) }
534 )
535 {
536 mismatch_pos += 1;
537 }
538
539 let len: usize = lhs_len.min(rhs_len);
542
543 for i in 0..len {
544 let l_idx = i + lhs_start;
545 let r_idx = i + rhs_start;
546
547 let l = unsafe { arr_lhs.value_unchecked(l_idx) };
548 let r = unsafe { arr_rhs.value_unchecked(r_idx) };
549 let v = $OP(l, r);
550
551 unsafe { out_ptr.add(l_idx).write(v) };
552 }
553 }
554 });
555
556 check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
557
558 unsafe { out_vec.set_len(n_values) };
559
560 #[inline(never)]
562 fn combine_validities_list_to_list_no_broadcast(
563 offsets_lhs: &OffsetsBuffer<i64>,
564 offsets_rhs: &OffsetsBuffer<i64>,
565 validity_lhs: Option<&Bitmap>,
566 validity_rhs: Option<&Bitmap>,
567 len_lhs: usize,
568 ) -> Option<Bitmap> {
569 match (validity_lhs, validity_rhs) {
570 (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
571 (Some(v), None) => return Some(v.clone()),
572 (None, Some(v)) => {
573 Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
574 },
575 (None, None) => None,
576 }
577 .map(|(mut validity_out, validity_rhs)| {
578 for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
579 .offset_and_length_iter()
580 .zip(offsets_rhs.offset_and_length_iter())
581 {
582 let len: usize = lhs_len.min(rhs_len);
583
584 for i in 0..len {
585 let l_idx = i + lhs_start;
586 let r_idx = i + rhs_start;
587
588 let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
589 let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
590 let is_valid = l_valid & r_valid;
591
592 unsafe { validity_out.set_unchecked(l_idx, is_valid) };
594 }
595 }
596
597 validity_out.freeze()
598 })
599 }
600
601 let leaf_validity = combine_validities_list_to_list_no_broadcast(
602 offsets_lhs,
603 offsets_rhs,
604 arr_lhs.validity(),
605 arr_rhs.validity(),
606 arr_lhs.len(),
607 );
608
609 let arr =
610 PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
611
612 let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
613 assert_eq!(offsets.len(), 1);
614
615 self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
616 },
617 (BinaryOpApplyType::ListToList, Broadcast::Right) => {
618 let offsets_lhs = &self.data_lhs.0[0];
619 let offsets_rhs = &self.data_rhs.0[0];
620
621 let n_values = arr_lhs.len();
623 let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
624 let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
625
626 assert_eq!(offsets_rhs.len_proxy(), 1);
627 let rhs_start = *offsets_rhs.first() as usize;
628 let width = offsets_rhs.range() as usize;
629
630 let mut mismatch_pos = 0;
631
632 with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
633 for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
634 if ((lhs_len == width) & (mismatch_pos == i))
635 | unsafe { !self.outer_validity.get_bit_unchecked(i) }
636 {
637 mismatch_pos += 1;
638 }
639
640 let len: usize = lhs_len.min(width);
641
642 for i in 0..len {
643 let l_idx = i + lhs_start;
644 let r_idx = i + rhs_start;
645
646 let l = unsafe { arr_lhs.value_unchecked(l_idx) };
647 let r = unsafe { arr_rhs.value_unchecked(r_idx) };
648 let v = $OP(l, r);
649
650 unsafe {
651 out_ptr.add(l_idx).write(v);
652 }
653 }
654 }
655 });
656
657 check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
658
659 unsafe { out_vec.set_len(n_values) };
660
661 #[inline(never)]
662 fn combine_validities_list_to_list_broadcast_right(
663 offsets_lhs: &OffsetsBuffer<i64>,
664 validity_lhs: Option<&Bitmap>,
665 validity_rhs: Option<&Bitmap>,
666 len_lhs: usize,
667 width: usize,
668 rhs_start: usize,
669 ) -> Option<Bitmap> {
670 match (validity_lhs, validity_rhs) {
671 (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
672 (Some(v), None) => return Some(v.clone()),
673 (None, Some(v)) => {
674 Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
675 },
676 (None, None) => None,
677 }
678 .map(|(mut validity_out, validity_rhs)| {
679 for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
680 let len: usize = lhs_len.min(width);
681
682 for i in 0..len {
683 let l_idx = i + lhs_start;
684 let r_idx = i + rhs_start;
685
686 let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
687 let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
688 let is_valid = l_valid & r_valid;
689
690 unsafe { validity_out.set_unchecked(l_idx, is_valid) };
692 }
693 }
694
695 validity_out.freeze()
696 })
697 }
698
699 let leaf_validity = combine_validities_list_to_list_broadcast_right(
700 offsets_lhs,
701 arr_lhs.validity(),
702 arr_rhs.validity(),
703 arr_lhs.len(),
704 width,
705 rhs_start,
706 );
707
708 let arr =
709 PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
710
711 let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
712 assert_eq!(offsets.len(), 1);
713
714 self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
715 },
716 (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
717 if self.list_to_prim_lhs.is_none() =>
718 {
719 let offsets_lhs = self.data_lhs.0.as_slice();
720
721 let n_values = arr_lhs.len();
726 let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
727 let out_ptr = out_vec.as_mut_ptr();
728
729 with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
730 for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
731 {
732 let r = unsafe { arr_rhs.value_unchecked(i) };
733 for l_idx in l_range {
734 unsafe {
735 let l = arr_lhs.value_unchecked(l_idx);
736 let v = $OP(l, r);
737 out_ptr.add(l_idx).write(v);
738 }
739 }
740 }
741 });
742
743 unsafe { out_vec.set_len(n_values) }
744
745 let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
746 offsets_lhs,
747 arr_lhs.validity(),
748 arr_rhs.validity(),
749 arr_lhs.len(),
750 );
751
752 let arr =
753 PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
754
755 let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
756 self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
757 },
758 (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
762 let offsets_lhs = self.data_lhs.0.as_slice();
763
764 let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
765 let arr = arr
766 .as_any_mut()
767 .downcast_mut::<PrimitiveArray<T::Native>>()
768 .unwrap();
769 let mut arr_lhs = std::mem::take(arr);
770
771 self.op.0.prepare_numeric_op_side_validities::<T>(
772 &mut arr_lhs,
773 &mut arr_rhs,
774 self.swapped,
775 );
776
777 let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
778 assert_eq!(arr_lhs_mut_slice.len(), n_values);
779
780 with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
781 for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
782 {
783 let r = unsafe { arr_rhs.value_unchecked(i) };
784 for l_idx in l_range {
785 unsafe {
786 let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
787 *l = $OP(*l, r);
788 }
789 }
790 }
791 });
792
793 let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
794 offsets_lhs,
795 arr_lhs.validity(),
796 arr_rhs.validity(),
797 arr_lhs.len(),
798 );
799
800 let arr = arr_lhs.with_validity(leaf_validity);
801
802 let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
803 self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
804 },
805 (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
806 assert_eq!(arr_rhs.len(), 1);
807
808 let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
809 let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
811 return Ok(self.finish_offsets_and_validities(
812 Box::new(
813 arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
814 false,
815 arr_lhs.len(),
816 ))),
817 ),
818 offsets,
819 validities,
820 ));
821 };
822
823 let arr = self
824 .op
825 .0
826 .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
827 let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
828
829 self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
830 },
831 v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
832 | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
833 | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
834 | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
835 | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
836 if cfg!(debug_assertions) {
837 panic!("operation was not re-written: {:?}", v)
838 } else {
839 unreachable!()
840 }
841 },
842 };
843
844 Ok(out)
845 }
846
847 fn finish_offsets_and_validities(
850 &mut self,
851 leaf_array: Box<dyn Array>,
852 offsets: Vec<OffsetsBuffer<i64>>,
853 validities: Vec<Option<Bitmap>>,
854 ) -> ListChunked {
855 assert!(!offsets.is_empty());
856 assert_eq!(offsets.len(), validities.len());
857 let mut results = leaf_array;
858
859 let mut iter = offsets.into_iter().zip(validities).rev();
860
861 while iter.len() > 1 {
862 let (offsets, validity) = iter.next().unwrap();
863 let dtype = LargeListArray::default_datatype(results.dtype().clone());
864 results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
865 }
866
867 let (offsets, _) = iter.next().unwrap();
869 let validity = std::mem::take(&mut self.outer_validity);
870 let dtype = LargeListArray::default_datatype(results.dtype().clone());
871 let results = LargeListArray::new(dtype, offsets, results, Some(validity));
872
873 ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
874 }
875
876 fn materialize_broadcasted_list(
877 side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
878 output_len: usize,
879 output_primitive_dtype: &DataType,
880 ) -> (Box<dyn Array>, usize) {
881 let s = &side_data.2;
882 assert_eq!(s.len(), 1);
883
884 let expected_n_values = {
885 let offsets = s.list_offsets_and_validities_recursive().0;
886 output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
887 };
888
889 let ca = s.list().unwrap();
890 let ca = ca
892 .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
893 .unwrap();
894 assert!(output_len > 1); let ca = ca.new_from_index(0, output_len).rechunk();
896
897 let s = ca.into_series();
898
899 *side_data = {
900 let (a, b) = s.list_offsets_and_validities_recursive();
901 (a, b, Series::default())
903 };
904
905 let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
906 assert_eq!(n_values, expected_n_values);
907
908 let mut s = s.get_leaf_array();
909 let v = unsafe { s.chunks_mut() };
910
911 assert_eq!(v.len(), 1);
912 (v.swap_remove(0), n_values)
913 }
914 }
915
916 #[inline(never)]
918 fn combine_validities_list_to_primitive_no_broadcast(
919 offsets_lhs: &[OffsetsBuffer<i64>],
920 validity_lhs: Option<&Bitmap>,
921 validity_rhs: Option<&Bitmap>,
922 len_lhs: usize,
923 ) -> Option<Bitmap> {
924 match (validity_lhs, validity_rhs) {
925 (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
926 (Some(v), None) => return Some(v.clone()),
927 (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
930 (None, None) => None,
931 }
932 .map(|(mut validity_out, validity_rhs)| {
933 for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
934 let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
935 for l_idx in l_range {
936 let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
937 let is_valid = l_valid & r_valid;
938
939 unsafe { validity_out.set_unchecked(l_idx, is_valid) };
941 }
942 }
943
944 validity_out.freeze()
945 })
946 }
947}