use arrow::array::Array;
use arrow::bitmap::Bitmap;
use num_traits::Zero;
use polars_core::prelude::*;
use polars_utils::abs_diff::AbsDiff;
use polars_utils::total_ord::TotalOrd;
use super::{
AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy,
};
/// Core as-of join kernel shared by every strategy in this file.
///
/// Walks `left` in order and asks the strategy state `S` for a candidate row of `right` for
/// each non-null left value. The output has one slot per row of `left`: the slot stores the
/// candidate's index into `right` and is marked valid only when a candidate was found *and*
/// `filter(left_value, right_value)` returned `true`. Rows without a candidate, rows rejected
/// by `filter`, and null left rows come out as null.
///
/// If either side holds nothing but nulls (this includes an empty side), an all-null result of
/// `left.len()` entries is returned immediately.
///
/// `allow_eq` is handed unchanged to `S::new`.
fn join_asof_impl<'a, T, S, F>(
    left: &'a T::Array,
    right: &'a T::Array,
    mut filter: F,
    allow_eq: bool,
) -> IdxCa
where
    T: PolarsDataType,
    S: AsofJoinState<T::Physical<'a>>,
    F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
    let n_left = left.len();
    let n_right = right.len();
    let left_nulls = left.null_count();
    let right_nulls = right.null_count();

    // Nothing can match when one side has no non-null value at all.
    if left_nulls == n_left || right_nulls == n_right {
        return IdxCa::full_null(PlSmallStr::EMPTY, n_left);
    }

    // `indices[row]` is the matched right index; bit `row` of `validity_bits` says whether
    // that slot is valid. Both start out as "no match".
    let mut indices: Vec<IdxSize> = vec![0; n_left];
    let mut validity_bits: Vec<u8> = vec![0; n_left.div_ceil(8)];
    let mut state = S::new(allow_eq);
    let right_len = n_right as IdxSize;

    if left_nulls == 0 && right_nulls == 0 {
        // Null-free inputs: read values directly, without consulting validity.
        for (row, l_val) in left.values_iter().enumerate() {
            let candidate = state.next(
                &l_val,
                // SAFETY: assumes `S::next` only probes indices below `right_len`
                // (contract of `AsofJoinState`, which is declared outside this file).
                |j| Some(unsafe { right.value_unchecked(j as usize) }),
                right_len,
            );
            let Some(r_idx) = candidate else {
                continue;
            };

            // SAFETY: assumes the index returned by `S::next` is in bounds for `right`.
            let r_val = unsafe { right.value_unchecked(r_idx as usize) };
            let keep = filter(l_val, r_val) as u8;
            // SAFETY: `row < n_left`, which is the length of `indices`, and therefore
            // `row / 8 < n_left.div_ceil(8)`, the length of `validity_bits`.
            unsafe {
                *indices.get_unchecked_mut(row) = r_idx;
                *validity_bits.get_unchecked_mut(row / 8) |= keep << (row % 8);
            }
        }
    } else {
        // At least one side contains nulls: null left rows are skipped (and stay null in the
        // output); right values are exposed to the state as `Option`s.
        for (row, l_opt) in left.iter().enumerate() {
            let Some(l_val) = l_opt else {
                continue;
            };
            let candidate = state.next(
                &l_val,
                // SAFETY: assumes `S::next` only probes indices below `right_len`
                // (contract of `AsofJoinState`, which is declared outside this file).
                |j| unsafe { right.get_unchecked(j as usize) },
                right_len,
            );
            let Some(r_idx) = candidate else {
                continue;
            };

            // SAFETY: assumes the index returned by `S::next` is in bounds for `right` and
            // refers to a non-null value.
            let r_val = unsafe { right.value_unchecked(r_idx as usize) };
            let keep = filter(l_val, r_val) as u8;
            // SAFETY: `row < n_left`, which is the length of `indices`, and therefore
            // `row / 8 < n_left.div_ceil(8)`, the length of `validity_bits`.
            unsafe {
                *indices.get_unchecked_mut(row) = r_idx;
                *validity_bits.get_unchecked_mut(row / 8) |= keep << (row % 8);
            }
        }
    }

    let validity = Bitmap::try_new(validity_bits, n_left).unwrap();
    IdxCa::from_vec_validity(PlSmallStr::EMPTY, indices, Some(validity))
}
/// As-of join kernel for the forward strategy.
///
/// Thin wrapper that runs `join_asof_impl` with `AsofJoinForwardState` choosing the candidate
/// row of `right` for each row of `left`. `filter` and `allow_eq` are passed through unchanged.
///
/// Returns one nullable index into `right` per row of `left`.
///
/// NOTE(review): the matching rule itself lives in `AsofJoinForwardState`, which is imported
/// from the parent module and not visible here.
fn join_asof_forward<'a, T, F>(
left: &'a T::Array,
right: &'a T::Array,
filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
T::Physical<'a>: TotalOrd,
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
}
/// As-of join kernel for the backward strategy.
///
/// Thin wrapper that runs `join_asof_impl` with `AsofJoinBackwardState` choosing the candidate
/// row of `right` for each row of `left`. `filter` and `allow_eq` are passed through unchanged.
///
/// Returns one nullable index into `right` per row of `left`.
///
/// NOTE(review): the matching rule itself lives in `AsofJoinBackwardState`, which is imported
/// from the parent module and not visible here.
fn join_asof_backward<'a, T, F>(
left: &'a T::Array,
right: &'a T::Array,
filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
T::Physical<'a>: TotalOrd,
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
}
/// As-of join kernel for the nearest strategy.
///
/// Thin wrapper that runs `join_asof_impl` with `AsofJoinNearestState` choosing the candidate
/// row of `right` for each row of `left`. `filter` and `allow_eq` are passed through unchanged.
///
/// Unlike the forward/backward wrappers, this one requires the physical value type to be
/// `NumericNative` rather than just `TotalOrd`.
///
/// Returns one nullable index into `right` per row of `left`.
///
/// NOTE(review): the matching rule itself lives in `AsofJoinNearestState`, which is imported
/// from the parent module and not visible here.
fn join_asof_nearest<'a, T, F>(
left: &'a T::Array,
right: &'a T::Array,
filter: F,
allow_eq: bool,
) -> IdxCa
where
T: PolarsDataType,
T::Physical<'a>: NumericNative,
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
{
join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq)
}
/// As-of join for numeric columns, with an optional distance tolerance.
///
/// `other` is first unpacked as a `ChunkedArray` of the same type as `input_ca`; if that fails
/// the error is returned. Both sides are then rechunked and handed to the kernel selected by
/// `strategy` as single arrays.
///
/// When `tolerance` is `Some`, it is extracted as `T::Native` (returning an error if the
/// extraction fails) and its absolute difference from zero becomes the bound: a match is kept
/// only if the absolute difference between the left and right values does not exceed that
/// bound. When `tolerance` is `None`, every match proposed by the strategy is kept.
///
/// Returns one nullable index into `other` per row of `input_ca`.
pub(crate) fn join_asof_numeric<T: PolarsNumericType>(
    input_ca: &ChunkedArray<T>,
    other: &Series,
    strategy: AsofStrategy,
    tolerance: Option<AnyValue<'static>>,
    allow_eq: bool,
) -> PolarsResult<IdxCa> {
    let other = input_ca.unpack_series_matching_type(other)?;

    let lhs = input_ca.rechunk();
    let rhs = other.rechunk();
    let left = lhs.downcast_as_array();
    let right = rhs.downcast_as_array();

    // The two arms repeat the strategy dispatch because each closure has its own type.
    let matched = match tolerance {
        None => {
            let keep_all = |_l: T::Native, _r: T::Native| true;
            match strategy {
                AsofStrategy::Forward => {
                    join_asof_forward::<T, _>(left, right, keep_all, allow_eq)
                },
                AsofStrategy::Backward => {
                    join_asof_backward::<T, _>(left, right, keep_all, allow_eq)
                },
                AsofStrategy::Nearest => {
                    join_asof_nearest::<T, _>(left, right, keep_all, allow_eq)
                },
            }
        },
        Some(t) => {
            // Use the distance of the tolerance from zero so its sign does not matter.
            let max_distance = t.try_extract::<T::Native>()?.abs_diff(T::Native::zero());
            let within_tolerance = |l: T::Native, r: T::Native| l.abs_diff(r) <= max_distance;
            match strategy {
                AsofStrategy::Forward => {
                    join_asof_forward::<T, _>(left, right, within_tolerance, allow_eq)
                },
                AsofStrategy::Backward => {
                    join_asof_backward::<T, _>(left, right, within_tolerance, allow_eq)
                },
                AsofStrategy::Nearest => {
                    join_asof_nearest::<T, _>(left, right, within_tolerance, allow_eq)
                },
            }
        },
    };

    Ok(matched)
}
/// As-of join for any data type whose physical values implement `TotalOrd`, without tolerance.
///
/// `other` is first unpacked as a `ChunkedArray` of the same type as `input_ca`; if that fails
/// the error is returned. Both sides are rechunked and their first chunk is passed to
/// `join_asof_impl` with a filter that accepts every match.
///
/// Only the forward and backward strategies are available here; requesting
/// `AsofStrategy::Nearest` returns an `InvalidOperation` error.
///
/// Returns one nullable index into `other` per row of `input_ca`.
pub(crate) fn join_asof<T>(
    input_ca: &ChunkedArray<T>,
    other: &Series,
    strategy: AsofStrategy,
    allow_eq: bool,
) -> PolarsResult<IdxCa>
where
    T: PolarsDataType,
    for<'a> T::Physical<'a>: TotalOrd,
{
    let other = input_ca.unpack_series_matching_type(other)?;

    let lhs = input_ca.rechunk();
    let rhs = other.rechunk();
    let left = lhs.downcast_iter().next().unwrap();
    let right = rhs.downcast_iter().next().unwrap();

    // No tolerance on this path: every candidate proposed by the strategy is accepted.
    let keep_all = |_l: T::Physical<'_>, _r: T::Physical<'_>| true;

    let matched = match strategy {
        AsofStrategy::Forward => {
            join_asof_impl::<T, AsofJoinForwardState, _>(left, right, keep_all, allow_eq)
        },
        AsofStrategy::Backward => {
            join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, keep_all, allow_eq)
        },
        AsofStrategy::Nearest => polars_bail!(InvalidOperation:
            "AsOf strategy \"nearest\" is not supported for {} data type",
            T::get_static_dtype()
        ),
    };

    Ok(matched)
}
#[cfg(test)]
mod test {
    use arrow::array::PrimitiveArray;

    use super::*;

    /// Backward strategy without tolerance, equal keys allowed.
    #[test]
    fn test_asof_backward() {
        let left = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]);

        // Right side with duplicate keys: the last duplicate (index 3) is the one reported.
        let right = PrimitiveArray::from_slice([1, 2, 3, 3]);
        let matched = join_asof_backward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(matched.len(), left.len());
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), Some(3), Some(3), Some(3), Some(3)]
        );

        // Right side with a gap: left value 3 falls back to right value 2 (index 1).
        let right = PrimitiveArray::from_slice([1, 2, 4, 5]);
        let matched = join_asof_backward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), Some(1), Some(1), Some(1), Some(2)]
        );

        // Repeated left keys that lie beyond the largest right key.
        let left = PrimitiveArray::from_slice([2, 4, 4, 4]);
        let right = PrimitiveArray::from_slice([1, 2, 3, 3]);
        let matched = join_asof_backward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(matched.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]);
    }

    /// Backward strategy where matches further than 4 away are filtered out.
    #[test]
    fn test_asof_backward_tolerance() {
        let left = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]);
        let right = PrimitiveArray::from_slice([10, 20, 30, 30]);
        let matched =
            join_asof_backward::<Int32Type, _>(&left, &right, |l, r| l.abs_diff(r) <= 4u32, true);
        // 25 -> 20 (distance 5) and 40 -> 30 (distance 10) exceed the tolerance.
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), None, Some(3), Some(3), None]
        );
    }

    /// Forward strategy where matches further than 4 away are filtered out.
    #[test]
    fn test_asof_forward_tolerance() {
        let left = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]);
        let right = PrimitiveArray::from_slice([10, 20, 33, 55]);
        let matched =
            join_asof_forward::<Int32Type, _>(&left, &right, |l, r| l.abs_diff(r) <= 4u32, true);
        // -1 -> 10, 25 -> 33 and 40 -> 55 exceed the tolerance; 30 -> 33 and 52 -> 55 do not.
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), None, Some(2), Some(2), None, Some(3)]
        );
    }

    /// Forward strategy without tolerance, equal keys allowed.
    #[test]
    fn test_asof_forward() {
        let left = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]);
        let right = PrimitiveArray::from_slice([1, 2, 4, 5]);
        let matched = join_asof_forward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(matched.len(), left.len());
        // 6 is larger than every right key, so it has no forward match.
        assert_eq!(matched.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]);
    }
}