mod default;
mod groups;
use std::borrow::Cow;
use std::cmp::Ordering;
use default::*;
pub use groups::AsofJoinBy;
use polars_core::prelude::*;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::total_ord::TotalOrd;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use super::{_finish_join, build_tables};
use crate::frame::IntoDf;
use crate::series::SeriesMethods;
#[inline]
fn ge_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {
match l.tot_cmp(r) {
Ordering::Equal => allow_eq,
Ordering::Greater => true,
Ordering::Less => false,
}
}
#[inline]
fn lt_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {
match l.tot_cmp(r) {
Ordering::Equal => allow_eq,
Ordering::Less => true,
Ordering::Greater => false,
}
}
/// Incremental matcher used by the asof-join strategies.
///
/// An implementor is created once via [`AsofJoinState::new`] and then fed
/// left values one at a time through [`AsofJoinState::next`]; it may keep
/// scan positions between calls.
trait AsofJoinState<T> {
    /// Creates a fresh state. `allow_eq` controls whether a right value equal
    /// to the left value is accepted as a match.
    fn new(allow_eq: bool) -> Self;

    /// Returns the index of the right-side row matched to `left_val`, if any.
    ///
    /// `right` is a lookup from a right-side index to its (nullable) value and
    /// `n_right` is the number of right-side rows that may be queried.
    fn next<F>(&mut self, left_val: &T, right: F, n_right: IdxSize) -> Option<IdxSize>
    where
        F: FnMut(IdxSize) -> Option<T>;
}
/// Scan state for the matcher that returns the first right-side index whose
/// value is greater than (or, with `allow_eq`, equal to) the left value.
struct AsofJoinForwardState {
/// Next right-side index to examine; it only moves forward between calls.
scan_offset: IdxSize,
/// Whether a right value equal to the left value counts as a match.
allow_eq: bool,
}
impl<T: TotalOrd> AsofJoinState<T> for AsofJoinForwardState {
    /// Starts scanning at right index 0.
    fn new(allow_eq: bool) -> Self {
        Self {
            scan_offset: 0,
            allow_eq,
        }
    }

    /// Advances the scan until it reaches a non-null right value that is
    /// greater than `left_val` (or equal, when `allow_eq` is set) and returns
    /// its index. The scan position stays on the returned index, so the same
    /// right row can match several consecutive left values. Returns `None`
    /// once the right side is exhausted.
    #[inline]
    fn next<F: FnMut(IdxSize) -> Option<T>>(
        &mut self,
        left_val: &T,
        mut right: F,
        n_right: IdxSize,
    ) -> Option<IdxSize> {
        loop {
            let idx = self.scan_offset;
            if idx >= n_right {
                return None;
            }
            match right(idx) {
                Some(candidate) if ge_allow_eq(&candidate, left_val, self.allow_eq) => {
                    return Some(idx);
                },
                // Null or too-small value: skip it for good.
                _ => self.scan_offset = idx + 1,
            }
        }
    }
}
/// Scan state for the matcher that returns the last right-side index whose
/// value is less than (or, with `allow_eq`, equal to) the left value.
struct AsofJoinBackwardState {
/// Most recent right-side index accepted as a match so far, if any.
best_bound: Option<IdxSize>,
/// Next right-side index to examine; it only moves forward between calls.
scan_offset: IdxSize,
/// Whether a right value equal to the left value counts as a match.
allow_eq: bool,
}
impl<T: TotalOrd> AsofJoinState<T> for AsofJoinBackwardState {
    /// Starts scanning at right index 0 with no match recorded yet.
    fn new(allow_eq: bool) -> Self {
        Self {
            best_bound: None,
            scan_offset: 0,
            allow_eq,
        }
    }

    /// Consumes right values while they are less than `left_val` (or equal,
    /// when `allow_eq` is set), remembering the last accepted index, and
    /// returns that index. Null right values are skipped. The scan stops
    /// without consuming the first value that is too large, so it is
    /// re-examined for the next left value.
    #[inline]
    fn next<F: FnMut(IdxSize) -> Option<T>>(
        &mut self,
        left_val: &T,
        mut right: F,
        n_right: IdxSize,
    ) -> Option<IdxSize> {
        while self.scan_offset < n_right {
            match right(self.scan_offset) {
                Some(candidate) if !lt_allow_eq(&candidate, left_val, self.allow_eq) => break,
                Some(_) => self.best_bound = Some(self.scan_offset),
                None => {},
            }
            self.scan_offset += 1;
        }
        self.best_bound
    }
}
/// Scan state for the matcher that returns the right-side index whose value
/// is closest to the left value.
#[derive(Default)]
struct AsofJoinNearestState {
/// Most recent right-side index found whose value is strictly less than the
/// left value, if any.
strictly_smaller: Option<IdxSize>,
/// Scan position of the candidate on the "greater than left" side; it can
/// reach `n_right` when no such value exists.
upper_candidate: IdxSize,
/// Whether a right value equal to the left value counts as a match.
allow_eq: bool,
}
impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {
/// Creates a state with both scan positions at their defaults
/// (`strictly_smaller = None`, `upper_candidate = 0`).
fn new(allow_eq: bool) -> Self {
AsofJoinNearestState {
allow_eq,
..Default::default()
}
}
/// Returns the index of the right value nearest to `left_val`, or `None`
/// when neither a smaller nor a greater candidate is available.
///
/// When `allow_eq` is set and the entry just before the upper candidate
/// equals `left_val`, that entry is returned directly. Otherwise the choice
/// is between the last strictly smaller value and the first strictly greater
/// value, with ties in distance going to the greater one.
#[inline]
fn next<F: FnMut(IdxSize) -> Option<T>>(
&mut self,
left_val: &T,
mut right: F,
n_right: IdxSize,
) -> Option<IdxSize> {
// Phase 1: move `upper_candidate` past nulls and past every value that is
// not greater than `left_val`; it stops on the first value > `left_val`
// or at `n_right`.
while self.upper_candidate < n_right {
let Some(scan_right_val) = right(self.upper_candidate) else {
self.upper_candidate += 1;
continue;
};
if scan_right_val > *left_val {
break;
}
self.upper_candidate += 1;
}
// Exact match shortcut: the entry right before the upper candidate is the
// last one that was not greater than `left_val`.
if self.allow_eq
&& self.upper_candidate > 0
&& right(self.upper_candidate - 1) == Some(*left_val)
{
return Some(self.upper_candidate - 1);
}
// Phase 2: if the upper candidate is followed by entries with the same
// value, settle on the last one of that run.
while self.upper_candidate + 1 < n_right
&& right(self.upper_candidate + 1) == right(self.upper_candidate)
{
self.upper_candidate += 1;
}
// Phase 3: resume from the previously found smaller index (or 0) and
// record the last non-null value strictly below `left_val`, stopping at
// the first value >= `left_val` or at the upper candidate.
let mut cursor = self.strictly_smaller.unwrap_or(0);
while cursor < self.upper_candidate {
let Some(scan_right_val) = right(cursor) else {
cursor += 1;
continue;
};
if scan_right_val >= *left_val {
break;
}
self.strictly_smaller = Some(cursor);
cursor += 1;
}
// Bounds-checked lookup: `upper_candidate` may equal `n_right`.
let mut right_get = |idx: IdxSize| (idx < n_right).then(|| right(idx)).flatten();
let lower = self.strictly_smaller.and_then(&mut right_get);
let upper = right_get(self.upper_candidate);
match (lower, upper) {
(None, None) => None,
(Some(_), None) => self.strictly_smaller,
(None, Some(_)) => Some(self.upper_candidate),
(Some(lo), Some(hi)) => {
// Both sides exist: pick the closer one; `<=` sends ties to the upper
// (later) candidate.
let lo_diff = left_val.abs_diff(lo);
let hi_diff = left_val.abs_diff(hi);
if hi_diff <= lo_diff {
Some(self.upper_candidate)
} else {
self.strictly_smaller
}
},
}
}
}
// User-facing configuration of an asof join.
//
// NOTE: plain `//` comments are used on purpose. `///` doc comments would be
// picked up by the `schemars::JsonSchema` derive as schema descriptions and
// would therefore change the generated schema.
#[derive(Clone, Debug, PartialEq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct AsOfOptions {
// Matching strategy; defaults to `AsofStrategy::Backward`.
pub strategy: AsofStrategy,
// Optional maximum distance between matched keys. `_check_asof_columns`
// only accepts a tolerance for numeric physical key dtypes.
pub tolerance: Option<Scalar>,
// Presumably a textual form of the tolerance (e.g. a duration string);
// not consumed in this file. TODO confirm against the caller that parses it.
pub tolerance_str: Option<PlSmallStr>,
// Presumably the left/right column names of the 'by' groups within which
// the join is performed; not consumed in this file. TODO confirm.
pub left_by: Option<Vec<PlSmallStr>>,
pub right_by: Option<Vec<PlSmallStr>>,
// Whether a right key equal to the left key counts as a match.
pub allow_eq: bool,
// Whether key columns should be verified to be sorted before joining.
pub check_sortedness: bool,
}
/// Validates the two key columns of an asof join.
///
/// # Errors
/// - `InvalidOperation` when the physical key dtypes are not primitive, or
///   not primitive numeric while a tolerance is supplied.
/// - `ComputeError` when the logical dtypes of `a` and `b` differ.
/// - Whatever `ensure_sorted_arg` reports when `check_sortedness` is set,
///   no 'by' groups are present, and a key fails the sortedness check.
///
/// When `check_sortedness` is set but 'by' groups are present, only a
/// warning is emitted because sortedness is not checked in that case.
pub fn _check_asof_columns(
    a: &Series,
    b: &Series,
    has_tolerance: bool,
    check_sortedness: bool,
    by_groups_present: bool,
) -> PolarsResult<()> {
    let (left_dtype, right_dtype) = (a.dtype(), b.dtype());
    let left_physical = left_dtype.to_physical();
    let right_physical = right_dtype.to_physical();

    if !has_tolerance {
        polars_ensure!(
            left_physical.is_primitive() && right_physical.is_primitive(),
            InvalidOperation:
            "asof join is only supported on primitive key types"
        );
    } else {
        polars_ensure!(
            left_physical.is_primitive_numeric() && right_physical.is_primitive_numeric(),
            InvalidOperation:
            "asof join with tolerance is only supported on numeric/temporal keys"
        );
    }

    polars_ensure!(
        left_dtype == right_dtype,
        ComputeError: "mismatching key dtypes in asof-join: `{}` and `{}`",
        left_dtype, right_dtype
    );

    match (check_sortedness, by_groups_present) {
        (false, _) => {},
        (true, true) => {
            polars_warn!("Sortedness of columns cannot be checked when 'by' groups provided");
        },
        (true, false) => {
            a.ensure_sorted_arg("asof_join")?;
            b.ensure_sorted_arg("asof_join")?;
        },
    }
    Ok(())
}
// Direction in which an asof join looks for the matching right-side key.
//
// NOTE: plain `//` comments are used on purpose. `///` doc comments would be
// picked up by the `schemars::JsonSchema` derive as schema descriptions and
// would therefore change the generated schema.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum AsofStrategy {
// Default variant. Presumably served by the backward scan in this file
// (last right key not greater than the left key); the wiring lives outside
// this view.
#[default]
Backward,
// Presumably served by the forward scan (first right key not smaller than
// the left key).
Forward,
// Presumably served by the nearest scan (right key with the smallest
// absolute difference to the left key).
Nearest,
}
/// Asof-join support for anything that can be viewed as a `DataFrame`.
pub trait AsofJoin: IntoDf {
    /// Joins `self` with `other` by matching each `left_key` entry to a row of
    /// `other` chosen from `right_key` according to `strategy`.
    ///
    /// Steps: validate the key columns, compute the per-left-row indices into
    /// `other` on the physical key representation, optionally drop the right
    /// key column (when `coalesce` is set and both keys share a name),
    /// optionally apply `slice` to both the left frame and the indices, then
    /// gather the right rows and combine both sides via `_finish_join`.
    ///
    /// # Errors
    /// Propagates errors from key validation, the dispatch, dropping the right
    /// key column and the final join assembly.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    fn _join_asof(
        &self,
        other: &DataFrame,
        left_key: &Series,
        right_key: &Series,
        strategy: AsofStrategy,
        tolerance: Option<AnyValue<'static>>,
        suffix: Option<PlSmallStr>,
        slice: Option<(i64, usize)>,
        coalesce: bool,
        allow_eq: bool,
        check_sortedness: bool,
    ) -> PolarsResult<DataFrame> {
        let left_source = self.to_df();

        // No 'by' groups on this code path, hence the trailing `false`.
        let has_tolerance = tolerance.is_some();
        _check_asof_columns(left_key, right_key, has_tolerance, check_sortedness, false)?;

        let left_phys = left_key.to_physical_repr();
        let right_phys = right_key.to_physical_repr();
        let matched_rows =
            _join_asof_dispatch(&left_phys, &right_phys, strategy, tolerance, allow_eq)?;
        try_raise_keyboard_interrupt();

        // With coalescing and identically named keys the right key column is
        // removed before the right rows are gathered.
        let same_key_name = left_phys.name() == right_phys.name();
        let right_source = if coalesce && same_key_name {
            Cow::Owned(other.drop(right_phys.name())?)
        } else {
            Cow::Borrowed(other)
        };

        // The slice is applied to the left rows and their match indices alike.
        let (left, matched_rows) = match slice {
            Some((offset, len)) => (
                left_source.slice(offset, len),
                matched_rows.slice(offset, len),
            ),
            None => (left_source.clone(), matched_rows),
        };

        // SAFETY: relies on `_join_asof_dispatch` only producing in-bounds (or
        // null) row indices into the right frame.
        let right_df = unsafe { right_source.take_unchecked(&matched_rows) };
        _finish_join(left, right_df, suffix)
    }
}
/// Computes, for every entry of `left_key`, the index of the matching row in
/// `right_key`, dispatching on the dtype of `left_key`.
///
/// Both keys are expected to be in their physical representation and to share
/// a dtype (see `_check_asof_columns`). Small integer dtypes are widened to
/// `Int32`; numeric dtypes go through `join_asof_numeric` (which receives the
/// `tolerance`), while boolean, binary and string keys go through `join_asof`
/// (strings are compared as binary), where `tolerance` is not used.
///
/// # Errors
/// - An `asof_join` "operation not supported" error for any other dtype.
/// - Cast failures on the keys are returned as errors instead of panicking,
///   which matters when this public function is called with a right key that
///   was not validated beforehand.
/// - Any error reported by the underlying join kernels.
pub fn _join_asof_dispatch(
    left_key: &Series,
    right_key: &Series,
    strategy: AsofStrategy,
    tolerance: Option<AnyValue<'static>>,
    allow_eq: bool,
) -> PolarsResult<IdxCa> {
    // The `.unwrap()` downcasts below cannot fail: each one is guarded by the
    // dtype matched in its arm (or by the cast just performed).
    let take_idx = match left_key.dtype() {
        DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {
            // All of these fit losslessly into i32.
            let left_key = left_key.cast(&DataType::Int32)?;
            let right_key = right_key.cast(&DataType::Int32)?;
            let ca = left_key.i32().unwrap();
            join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
        },
        DataType::Int32 => {
            let ca = left_key.i32().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Int64 => {
            let ca = left_key.i64().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        #[cfg(feature = "dtype-i128")]
        DataType::Int128 => {
            let ca = left_key.i128().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::UInt32 => {
            let ca = left_key.u32().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::UInt64 => {
            let ca = left_key.u64().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        #[cfg(feature = "dtype-u128")]
        DataType::UInt128 => {
            let ca = left_key.u128().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        #[cfg(feature = "dtype-f16")]
        DataType::Float16 => {
            let ca = left_key.f16().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Float32 => {
            let ca = left_key.f32().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Float64 => {
            let ca = left_key.f64().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Boolean => {
            let ca = left_key.bool().unwrap();
            join_asof::<BooleanType>(ca, right_key, strategy, allow_eq)
        },
        DataType::Binary => {
            let ca = left_key.binary().unwrap();
            join_asof::<BinaryType>(ca, right_key, strategy, allow_eq)
        },
        DataType::String => {
            // Strings are joined on their binary view.
            let ca = left_key.str().unwrap();
            let right_binary = right_key.cast(&DataType::Binary)?;
            join_asof::<BinaryType>(&ca.as_binary(), &right_binary, strategy, allow_eq)
        },
        dt => polars_bail!(opq = asof_join, dt),
    }?;
    Ok(take_idx)
}
// `DataFrame` uses the provided `_join_asof` implementation unchanged.
impl AsofJoin for DataFrame {}