use std::fmt::{Display, Formatter};
use std::hash::Hash;
use arrow::array::{
Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray,
PrimitiveArray, Utf8ViewArray,
};
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;
use arrow::offset::OffsetsBuffer;
use arrow::types::NativeType;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_type;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Output sink for the set kernels: something that can absorb a stream of `K` items.
trait MaterializeValues<K> {
    /// Appends every item yielded by `values` to the sink and returns a `usize`.
    ///
    /// The implementations in this file return the sink's length after the append, which
    /// the kernels use as the next list offset.
    fn extend_buf<I>(&mut self, values: I) -> usize
    where
        I: Iterator<Item = K>;
}
/// Lets a primitive builder absorb nullable values as-is.
impl<T: NativeType> MaterializeValues<Option<T>> for MutablePrimitiveArray<T> {
    fn extend_buf<I>(&mut self, values: I) -> usize
    where
        I: Iterator<Item = Option<T>>,
    {
        self.extend(values);
        // The builder's length after the append is what callers receive.
        self.len()
    }
}
impl<T> MaterializeValues<TotalOrdWrap<Option<T>>> for MutablePrimitiveArray<T>
where
T: NativeType,
{
fn extend_buf<I: Iterator<Item = TotalOrdWrap<Option<T>>>>(&mut self, values: I) -> usize {
self.extend(values.map(|x| x.0));
self.len()
}
}
/// Lets a binary-view builder absorb nullable byte slices.
impl<'a> MaterializeValues<Option<&'a [u8]>> for MutablePlBinary {
    fn extend_buf<I>(&mut self, values: I) -> usize
    where
        I: Iterator<Item = Option<&'a [u8]>>,
    {
        self.extend(values);
        // The builder's length after the append is what callers receive.
        self.len()
    }
}
/// Performs one set operation between the items of `a` and `b` and appends the result to `out`.
///
/// * `set` - scratch set; it is cleared on entry and then filled from `a`.
/// * `set2` - right-hand scratch set, only used by `Intersection` and `SymmetricDifference`.
///   When `broadcast_rhs` is `true` it is left untouched and read as-is, so the caller must
///   have filled it already; in that case `b` is not consumed by those two operations.
/// * `out` - sink that receives the resulting values.
///
/// Returns whatever `out.extend_buf` returns for the appended values.
// One argument per piece of reusable state; bundling them would only add a throwaway struct.
#[allow(clippy::too_many_arguments)]
fn set_operation<I, J, K, R>(
    set: &mut PlIndexSet<K>,
    set2: &mut PlIndexSet<K>,
    a: &mut I,
    b: &mut J,
    out: &mut R,
    set_op: SetOperation,
    broadcast_rhs: bool,
) -> usize
where
    I: Iterator<Item = K>,
    J: Iterator<Item = K>,
    K: Eq + Hash + Copy,
    R: MaterializeValues<K>,
{
    // Drop whatever a previous call left behind.
    set.clear();

    match set_op {
        SetOperation::Union => {
            // Left values first, then right values; the set ignores repeats.
            set.extend(a);
            set.extend(b);
            out.extend_buf(set.drain(..))
        },
        SetOperation::Difference => {
            set.extend(a);
            // Strike every right-hand value out of the left-hand set.
            b.for_each(|rhs_value| {
                set.swap_remove(&rhs_value);
            });
            out.extend_buf(set.drain(..))
        },
        SetOperation::Intersection => {
            set.extend(a);
            if !broadcast_rhs {
                set2.clear();
                set2.extend(b);
            }
            let shared = set.intersection(set2).copied();
            out.extend_buf(shared)
        },
        SetOperation::SymmetricDifference => {
            if !broadcast_rhs {
                set2.clear();
                set2.extend(b);
            }
            set.extend(a);
            let exclusive = set.symmetric_difference(set2).copied();
            out.extend_buf(exclusive)
        },
    }
}
/// Copies the value out of an optional reference and converts the resulting `Option<T>`
/// into its `ToTotalOrd` item type, the key type the set kernels hash and compare.
fn copied_wrapper_opt<T>(v: Option<&T>) -> <Option<T> as ToTotalOrd>::TotalOrdItem
where
    T: Copy + TotalEq + TotalHash,
{
    let owned: Option<T> = v.copied();
    owned.to_total_ord()
}
// The set operations the list kernels in this file can perform.
//
// NOTE(review): these are deliberately plain `//` comments. `schemars` may copy `///` doc
// comments into the schema derived under the `dsl-schema` feature - confirm before
// converting them.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum SetOperation {
    // Values present on both sides.
    Intersection,
    // Values present on either side.
    Union,
    // Left-hand values with the right-hand values removed.
    Difference,
    // Values present on exactly one side.
    SymmetricDifference,
}

/// Formats the operation as its lowercase snake_case name.
impl Display for SetOperation {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Intersection => "intersection",
            Self::Union => "union",
            Self::Difference => "difference",
            Self::SymmetricDifference => "symmetric_difference",
        })
    }
}
/// Applies `set_op` list by list to two lists-of-primitives and builds the resulting list array.
///
/// `a` / `b` are the flat child values and `offsets_a` / `offsets_b` delimit the individual
/// lists inside them. A side whose offsets slice has exactly two entries (i.e. a single list)
/// is treated as broadcast: that one list is replayed against every list of the other side.
/// `validity` is handed to the returned array unchanged.
///
/// Values pass through `copied_wrapper_opt` before entering the index sets, so the sets are
/// keyed on `<Option<T> as ToTotalOrd>::TotalOrdItem`.
///
/// This function has no `Err` path of its own; it always returns `Ok`.
///
/// # Panics
///
/// Panics if either offsets slice has fewer than two entries (both are indexed at `[1]`), or
/// if one of the "iterator fully consumed" assertions inside the loop fails.
fn primitive<T>(
a: &PrimitiveArray<T>,
b: &PrimitiveArray<T>,
offsets_a: &[i64],
offsets_b: &[i64],
set_op: SetOperation,
validity: Option<Bitmap>,
) -> PolarsResult<ListArray<i64>>
where
T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd,
<Option<T> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,
{
// Two offsets describe exactly one list; such a side is broadcast.
let broadcast_lhs = offsets_a.len() == 2;
let broadcast_rhs = offsets_b.len() == 2;
// Scratch sets reused for every output list.
let mut set = Default::default();
let mut set2: PlIndexSet<<Option<T> as ToTotalOrd>::TotalOrdItem> = Default::default();
// Initial capacity: the larger of the two inputs' final offsets.
let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(
*offsets_a.last().unwrap(),
*offsets_b.last().unwrap(),
) as usize);
let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
offsets.push(0i64);
// The longer offsets slice decides how many output lists are produced.
let offsets_slice = if offsets_a.len() > offsets_b.len() {
offsets_a
} else {
offsets_b
};
// Bounds of the first list on each side; these are the bounds used when broadcasting.
let first_a = offsets_a[0];
let second_a = offsets_a[1];
let first_b = offsets_b[0];
let second_b = offsets_b[1];
if broadcast_rhs {
// Fill the right-hand set once up front; `set_operation` does not refill `set2`
// when `broadcast_rhs` is true.
set2.extend(
b.into_iter()
.skip(first_b as usize)
.take(second_b as usize - first_b as usize)
.map(copied_wrapper_opt),
);
}
// Shared cursors positioned at the first list of each side. A non-broadcast side advances
// its cursor list by list; a broadcast side only ever reads from clones of it.
let mut iter_a = a.into_iter().skip(first_a as usize);
let mut iter_b = b.into_iter().skip(first_b as usize);
for i in 1..offsets_slice.len() {
// Out-of-range lookups (the shorter, broadcast side) fall back to the first list's bounds.
let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;
let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;
// Fresh clones every iteration so a broadcast list can be read again from its start.
let mut iter_a_broadcast = iter_a.clone();
let mut iter_b_broadcast = iter_b.clone();
// Both branches use `by_ref().take(..).map(..)` so they have the same iterator type.
let mut iter_a = if broadcast_lhs {
iter_a_broadcast
.by_ref()
.take(second_a as usize - first_a as usize)
.map(copied_wrapper_opt)
} else {
// Borrowing the shared cursor means consuming this window advances it to the next list.
iter_a
.by_ref()
.take(end_a - start_a)
.map(copied_wrapper_opt)
};
let mut iter_b = if broadcast_rhs {
iter_b_broadcast
.by_ref()
.take(second_b as usize - first_b as usize)
.map(copied_wrapper_opt)
} else {
iter_b
.by_ref()
.take(end_b - start_b)
.map(copied_wrapper_opt)
};
let offset = set_operation(
&mut set,
&mut set2,
&mut iter_a,
&mut iter_b,
&mut values_out,
set_op,
broadcast_rhs,
);
// The window must be fully consumed, otherwise the shared cursor would be misaligned
// for the next list.
assert!(iter_a.next().is_none());
// With a broadcast right-hand side only `Union` and `Difference` read `iter_b`;
// the other operations use the pre-filled `set2` and leave it untouched.
if !broadcast_rhs || matches!(set_op, SetOperation::Union | SetOperation::Difference) {
assert!(iter_b.next().is_none());
};
// For this sink `extend_buf` returns the builder length after the append, i.e. the
// end offset of the list just written.
offsets.push(offset as i64);
}
// SAFETY (as far as visible here): `offsets` starts at 0 and every later entry is the
// builder's length after an append, so the sequence never decreases.
// NOTE(review): confirm this is all `OffsetsBuffer::new_unchecked` requires.
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
let dtype = ListArray::<i64>::default_datatype(values_out.dtype().clone());
let values: PrimitiveArray<T> = values_out.into();
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
}
/// Applies `set_op` list by list to two lists of binary-view values and builds the resulting
/// list array.
///
/// `a` / `b` are the flat child values and `offsets_a` / `offsets_b` delimit the individual
/// lists inside them. A side whose offsets slice has exactly two entries (i.e. a single list)
/// is treated as broadcast: that one list is replayed against every list of the other side.
/// `validity` is handed to the returned array unchanged.
///
/// When `as_utf8` is `true` the output values are reinterpreted as a UTF-8 view array
/// without validation; otherwise they stay a binary view array.
///
/// This function has no `Err` path of its own; it always returns `Ok`.
///
/// # Panics
///
/// Panics if either offsets slice has fewer than two entries (both are indexed at `[1]`), or
/// if one of the "iterator fully consumed" assertions inside the loop fails.
fn binary(
a: &BinaryViewArray,
b: &BinaryViewArray,
offsets_a: &[i64],
offsets_b: &[i64],
set_op: SetOperation,
validity: Option<Bitmap>,
as_utf8: bool,
) -> PolarsResult<ListArray<i64>> {
// Two offsets describe exactly one list; such a side is broadcast.
let broadcast_lhs = offsets_a.len() == 2;
let broadcast_rhs = offsets_b.len() == 2;
// Scratch sets reused for every output list; they hold borrowed byte slices.
let mut set: PlIndexSet<Option<&[u8]>> = Default::default();
let mut set2: PlIndexSet<Option<&[u8]>> = Default::default();
// Initial capacity: the larger of the two inputs' final offsets.
let mut values_out = MutablePlBinary::with_capacity(std::cmp::max(
*offsets_a.last().unwrap(),
*offsets_b.last().unwrap(),
) as usize);
let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
offsets.push(0i64);
// The longer offsets slice decides how many output lists are produced.
let offsets_slice = if offsets_a.len() > offsets_b.len() {
offsets_a
} else {
offsets_b
};
// Bounds of the first list on each side; these are the bounds used when broadcasting.
let first_a = offsets_a[0];
let second_a = offsets_a[1];
let first_b = offsets_b[0];
let second_b = offsets_b[1];
if broadcast_rhs {
// Fill the right-hand set once up front; `set_operation` does not refill `set2`
// when `broadcast_rhs` is true.
set2.extend(
b.into_iter()
.skip(first_b as usize)
.take(second_b as usize - first_b as usize),
);
}
// Shared cursors positioned at the first list of each side. A non-broadcast side advances
// its cursor list by list; a broadcast side only ever reads from clones of it.
let mut iter_a = a.into_iter().skip(first_a as usize);
let mut iter_b = b.into_iter().skip(first_b as usize);
for i in 1..offsets_slice.len() {
// Out-of-range lookups (the shorter, broadcast side) fall back to the first list's bounds.
let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;
let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;
// Fresh clones every iteration so a broadcast list can be read again from its start.
let mut iter_a_broadcast = iter_a.clone();
let mut iter_b_broadcast = iter_b.clone();
// Both branches use `by_ref().take(..)` so they have the same iterator type.
let mut iter_a = if broadcast_lhs {
iter_a_broadcast
.by_ref()
.take(second_a as usize - first_a as usize)
} else {
// Borrowing the shared cursor means consuming this window advances it to the next list.
iter_a.by_ref().take(end_a - start_a)
};
let mut iter_b = if broadcast_rhs {
iter_b_broadcast
.by_ref()
.take(second_b as usize - first_b as usize)
} else {
iter_b.by_ref().take(end_b - start_b)
};
let offset = set_operation(
&mut set,
&mut set2,
&mut iter_a,
&mut iter_b,
&mut values_out,
set_op,
broadcast_rhs,
);
// The window must be fully consumed, otherwise the shared cursor would be misaligned
// for the next list.
assert!(iter_a.next().is_none());
// With a broadcast right-hand side only `Union` and `Difference` read `iter_b`;
// the other operations use the pre-filled `set2` and leave it untouched.
if !broadcast_rhs || matches!(set_op, SetOperation::Union | SetOperation::Difference) {
assert!(iter_b.next().is_none());
};
// For this sink `extend_buf` returns the builder length after the append, i.e. the
// end offset of the list just written.
offsets.push(offset as i64);
}
// SAFETY (as far as visible here): `offsets` starts at 0 and every later entry is the
// builder's length after an append, so the sequence never decreases.
// NOTE(review): confirm this is all `OffsetsBuffer::new_unchecked` requires.
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
let values = values_out.freeze();
if as_utf8 {
// SAFETY: relies on the caller passing `as_utf8 = true` only when every input value is
// valid UTF-8; the only caller in this file does so for values taken from `Utf8ViewArray`s.
let values = unsafe { values.to_utf8view_unchecked() };
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
} else {
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
}
}
/// Runs `set_op` on one pair of list chunks by dispatching on the child value type.
///
/// String and binary views go through `binary`; every other non-boolean type is routed
/// through `with_match_physical_numeric_type!` to `primitive`. The output validity is the
/// AND-combination of both inputs' validities.
///
/// # Errors
///
/// Returns `InvalidOperation` for boolean child values.
///
/// # Panics
///
/// Panics if the two child dtypes differ, or if a child array is not of the concrete array
/// type implied by its dtype.
fn array_set_operation(
    a: &ListArray<i64>,
    b: &ListArray<i64>,
    set_op: SetOperation,
) -> PolarsResult<ListArray<i64>> {
    let (offsets_a, offsets_b) = (a.offsets().as_slice(), b.offsets().as_slice());
    let (values_a, values_b) = (a.values(), b.values());

    assert_eq!(values_a.dtype(), values_b.dtype());
    let validity = combine_validities_and(a.validity(), b.validity());

    // Type-erased views used for the concrete downcasts below.
    let (any_a, any_b) = (values_a.as_any(), values_b.as_any());

    match values_b.dtype() {
        ArrowDataType::Boolean => {
            polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations")
        },
        ArrowDataType::BinaryView => {
            let lhs = any_a.downcast_ref::<BinaryViewArray>().unwrap();
            let rhs = any_b.downcast_ref::<BinaryViewArray>().unwrap();
            binary(lhs, rhs, offsets_a, offsets_b, set_op, validity, false)
        },
        ArrowDataType::Utf8View => {
            // Strings are processed as raw bytes and turned back into UTF-8 by `binary`.
            let lhs = any_a.downcast_ref::<Utf8ViewArray>().unwrap().to_binview();
            let rhs = any_b.downcast_ref::<Utf8ViewArray>().unwrap().to_binview();
            binary(&lhs, &rhs, offsets_a, offsets_b, set_op, validity, true)
        },
        physical => {
            with_match_physical_numeric_type!(DataType::from_arrow_dtype(physical), |$T| {
                let lhs = any_a.downcast_ref::<PrimitiveArray<$T>>().unwrap();
                let rhs = any_b.downcast_ref::<PrimitiveArray<$T>>().unwrap();
                primitive(lhs, rhs, offsets_a, offsets_b, set_op, validity)
            })
        },
    }
}
/// Applies `set_op` between the lists of `a` and `b`, row by row.
///
/// The two columns must have the same dtype, and either the same length or one of them must
/// have length 1.
///
/// # Errors
///
/// * `ShapeMismatch` if the lengths differ and neither is 1.
/// * `InvalidOperation` if the dtypes differ.
/// * Any error reported by the per-chunk kernel.
pub fn list_set_operation(
    a: &ListChunked,
    b: &ListChunked,
    set_op: SetOperation,
) -> PolarsResult<ListChunked> {
    let (len_a, len_b) = (a.len(), b.len());
    polars_ensure!(len_a == len_b || len_b == 1 || len_a == 1, ShapeMismatch: "column lengths don't match");
    polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype());

    // Work on clones so the inputs' chunk layout is left alone.
    let (mut lhs, mut rhs) = (a.clone(), b.clone());

    // Differing lengths mean one side has length 1; rechunk both in that case.
    if len_a != len_b {
        lhs.rechunk_mut();
        rhs.rechunk_mut();
    }

    // The kernels index the first two offsets of every chunk, so drop empty chunks first.
    lhs.prune_empty_chunks();
    rhs.prune_empty_chunks();

    // SAFETY: NOTE(review) - the contract of `try_binary_unchecked_same_type` is not visible
    // here. The dtype check above ensures both inputs share a dtype; confirm that is sufficient.
    unsafe {
        arity::try_binary_unchecked_same_type(
            &lhs,
            &rhs,
            |lhs_chunk, rhs_chunk| {
                array_set_operation(lhs_chunk, rhs_chunk, set_op).map(|merged| merged.boxed())
            },
            false,
            false,
        )
    }
}