polars-ops 0.54.4

More operations on Polars data structures
Documentation
mod default;
mod groups;
use std::borrow::Cow;
use std::cmp::Ordering;

use default::*;
pub use groups::AsofJoinBy;
use polars_core::prelude::*;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::total_ord::TotalOrd;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use super::{_finish_join, build_tables};
use crate::frame::IntoDf;
use crate::series::SeriesMethods;

#[inline]
fn ge_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {
    match l.tot_cmp(r) {
        Ordering::Equal => allow_eq,
        Ordering::Greater => true,
        Ordering::Less => false,
    }
}

#[inline]
fn lt_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {
    match l.tot_cmp(r) {
        Ordering::Equal => allow_eq,
        Ordering::Less => true,
        Ordering::Greater => false,
    }
}

/// Stateful matcher that picks, for one left key at a time, the index of the
/// right-hand entry an asof join should pair it with.
///
/// The implementations in this file keep their scan positions between calls
/// and never move them backwards, so they presumably rely on left values
/// being fed in ascending order against ascending right values — TODO confirm
/// against the call sites.
trait AsofJoinState<T> {
    /// Returns the index of the right-hand match for `left_val`, or `None`
    /// when no right entry qualifies.
    ///
    /// * `left_val` - the left key currently being matched.
    /// * `right` - accessor yielding the right key stored at a given index;
    ///   the implementations in this file skip indices for which it returns
    ///   `None` during their scans.
    /// * `n_right` - exclusive upper bound for the indices handed to `right`.
    fn next<F: FnMut(IdxSize) -> Option<T>>(
        &mut self,
        left_val: &T,
        right: F,
        n_right: IdxSize,
    ) -> Option<IdxSize>;

    /// Creates a fresh state. `allow_eq` decides whether a right key equal to
    /// the left key may be returned as a match.
    fn new(allow_eq: bool) -> Self;
}

/// Scan state for forward matching: each left key is paired with the first
/// right key that is greater than it (or equal to it, when `allow_eq` is set).
struct AsofJoinForwardState {
    /// Index of the next right entry to inspect; it is only ever incremented.
    scan_offset: IdxSize,
    /// Whether a right key equal to the left key counts as a match.
    allow_eq: bool,
}

impl<T: TotalOrd> AsofJoinState<T> for AsofJoinForwardState {
    /// Starts scanning at the first right entry.
    fn new(allow_eq: bool) -> Self {
        Self {
            scan_offset: 0,
            allow_eq,
        }
    }

    /// Advances the cursor until it rests on a right value that is greater
    /// than `left_val` (or equal to it when `allow_eq` is set) and returns
    /// that index; returns `None` once the right side is exhausted.
    ///
    /// The cursor is left on the returned index, so the same right entry can
    /// match several consecutive left values.
    #[inline]
    fn next<F: FnMut(IdxSize) -> Option<T>>(
        &mut self,
        left_val: &T,
        mut right: F,
        n_right: IdxSize,
    ) -> Option<IdxSize> {
        loop {
            let idx = self.scan_offset;
            if idx >= n_right {
                return None;
            }
            match right(idx) {
                Some(candidate) if ge_allow_eq(&candidate, left_val, self.allow_eq) => {
                    return Some(idx);
                },
                // Missing or too-small right value: step past it for good.
                _ => self.scan_offset = idx + 1,
            }
        }
    }
}

/// Scan state for backward matching: each left key is paired with the last
/// right key that is smaller than it (or equal to it, when `allow_eq` is set).
struct AsofJoinBackwardState {
    /// Index of the most recently scanned right entry whose value is less
    /// than the left value (or equal to it when `allow_eq` is set); `None`
    /// until such an entry has been seen.
    best_bound: Option<IdxSize>,
    /// Index of the next right entry to inspect; it is only ever incremented.
    scan_offset: IdxSize,
    /// Whether a right key equal to the left key counts as a match.
    allow_eq: bool,
}

impl<T: TotalOrd> AsofJoinState<T> for AsofJoinBackwardState {
    /// Starts with no match recorded and the cursor on the first right entry.
    fn new(allow_eq: bool) -> Self {
        Self {
            best_bound: None,
            scan_offset: 0,
            allow_eq,
        }
    }

    /// Consumes right entries while they are smaller than `left_val` (or
    /// equal to it when `allow_eq` is set), remembering the last one seen,
    /// and returns that remembered index.
    ///
    /// The scan stops, without consuming it, at the first right value that
    /// is too large, so that entry is re-examined for the next left value.
    #[inline]
    fn next<F: FnMut(IdxSize) -> Option<T>>(
        &mut self,
        left_val: &T,
        mut right: F,
        n_right: IdxSize,
    ) -> Option<IdxSize> {
        loop {
            let idx = self.scan_offset;
            if idx >= n_right {
                break;
            }
            match right(idx) {
                // First right value beyond the left value: leave the cursor here.
                Some(candidate) if !lt_allow_eq(&candidate, left_val, self.allow_eq) => break,
                Some(_) => self.best_bound = Some(idx),
                // Missing right values are stepped over without affecting the match.
                None => {},
            }
            self.scan_offset = idx + 1;
        }
        self.best_bound
    }
}

/// Scan state for nearest matching: tracks one candidate below and one
/// candidate above the current left value so their distances can be compared.
#[derive(Default)]
struct AsofJoinNearestState {
    /// The last value that is strictly smaller than the current
    /// left value.
    strictly_smaller: Option<IdxSize>,
    // NOTE(review): `next` extends this index to the end of the run of equal
    // values whenever it does not return early on an exact match, independent
    // of `allow_eq` — confirm the per-flag wording below is still accurate.
    /// If `allow_eq == false`: the first value strictly greater than the
    /// current left value.
    /// If `allow_eq == true`: the last value of the first chunk of equal
    /// values that are strictly greater than the current left value.
    upper_candidate: IdxSize,
    /// Whether a right value equal to the left value may be returned as the
    /// match.
    allow_eq: bool,
}

impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {
    /// Creates a state with both cursors at their defaults (`None` / `0`).
    fn new(allow_eq: bool) -> Self {
        AsofJoinNearestState {
            allow_eq,
            ..Default::default()
        }
    }
    /// Returns the index of the right value closest to `left_val`.
    ///
    /// A candidate above and a candidate strictly below `left_val` are
    /// located and their absolute differences compared; on a tie the upper
    /// candidate is returned. When `allow_eq` is set, a right value equal to
    /// `left_val` sitting directly before the upper candidate is returned
    /// immediately. `None` is returned only when neither candidate exists.
    #[inline]
    fn next<F: FnMut(IdxSize) -> Option<T>>(
        &mut self,
        left_val: &T,
        mut right: F,
        n_right: IdxSize,
    ) -> Option<IdxSize> {
        // Skipping ahead to the first value greater than left_val. This is
        // cheaper than computing differences.
        // Missing right values (`None`) are stepped over as well; the loop
        // ends either on a value strictly greater than `left_val` or with
        // `upper_candidate == n_right`.
        while self.upper_candidate < n_right {
            let Some(scan_right_val) = right(self.upper_candidate) else {
                self.upper_candidate += 1;
                continue;
            };
            if scan_right_val > *left_val {
                break;
            }
            self.upper_candidate += 1;
        }

        // Exact-match shortcut: everything before `upper_candidate` is
        // missing or `<= left_val`, so an equal value, if present, sits just
        // before it.
        // NOTE(review): only index `upper_candidate - 1` is inspected; if
        // that slot is missing, an equal value further left is not returned
        // here — confirm missing values cannot occur in that position.
        if self.allow_eq
            && self.upper_candidate > 0
            && right(self.upper_candidate - 1) == Some(*left_val)
        {
            return Some(self.upper_candidate - 1);
        }

        // It is possible there are later elements equal to our
        // scan, so keep going on.
        // This moves `upper_candidate` to the last index of the run of
        // identical values; it is a no-op when `upper_candidate` is already
        // at or past the final index.
        while self.upper_candidate + 1 < n_right
            && right(self.upper_candidate + 1) == right(self.upper_candidate)
        {
            self.upper_candidate += 1;
        }

        // Refresh the lower candidate: resume from the previously recorded
        // index (or 0) and remember the last non-missing value strictly
        // below `left_val`, stopping at the first value `>= left_val`.
        let mut cursor = self.strictly_smaller.unwrap_or(0);
        while cursor < self.upper_candidate {
            let Some(scan_right_val) = right(cursor) else {
                cursor += 1;
                continue;
            };
            if scan_right_val >= *left_val {
                break;
            }
            self.strictly_smaller = Some(cursor);
            cursor += 1;
        }

        // Bounds-checked accessor: yields `None` for `idx >= n_right`, which
        // covers the case where no upper candidate exists.
        let mut right_get = |idx: IdxSize| (idx < n_right).then(|| right(idx)).flatten();
        let lower = self.strictly_smaller.and_then(&mut right_get);
        let upper = right_get(self.upper_candidate);
        match (lower, upper) {
            (None, None) => None,
            (Some(_), None) => self.strictly_smaller,
            (None, Some(_)) => Some(self.upper_candidate),
            (Some(lo), Some(hi)) => {
                let lo_diff = left_val.abs_diff(lo);
                let hi_diff = left_val.abs_diff(hi);
                // `<=` resolves equal distances in favour of the upper candidate.
                if hi_diff <= lo_diff {
                    Some(self.upper_candidate)
                } else {
                    self.strictly_smaller
                }
            },
        }
    }
}

/// Options describing how an asof join should be carried out.
#[derive(Clone, Debug, PartialEq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct AsOfOptions {
    /// Which neighbouring right row is selected; see [`AsofStrategy`].
    pub strategy: AsofStrategy,
    /// A tolerance in the same unit as the asof column
    pub tolerance: Option<Scalar>,
    /// A time duration specified as a string, for example:
    /// - "5m"
    /// - "2h15m"
    /// - "1d6h"
    pub tolerance_str: Option<PlSmallStr>,
    // NOTE(review): presumably the names of the left-frame columns that
    // define the 'by' groups — the consumer is not visible here, confirm.
    pub left_by: Option<Vec<PlSmallStr>>,
    // NOTE(review): presumably the right-frame counterpart of `left_by` —
    // confirm against the consumer.
    pub right_by: Option<Vec<PlSmallStr>>,
    /// Allow equal matches
    pub allow_eq: bool,
    // NOTE(review): presumably forwarded as the `check_sortedness` argument
    // of `_check_asof_columns`, which verifies that both keys are sorted —
    // confirm at the call site.
    pub check_sortedness: bool,
}

/// Validates the two key columns of an asof join before any matching is done.
///
/// * `a`, `b` - the left and right key columns.
/// * `has_tolerance` - when `true` the physical key types must be primitive
///   numeric; otherwise any primitive physical type is accepted.
/// * `check_sortedness` - when `true` both keys are verified to be sorted,
///   unless `by_groups_present` is set.
/// * `by_groups_present` - when `true` the sortedness check is replaced by a
///   warning.
///
/// # Errors
///
/// * `InvalidOperation` if a key's physical dtype is not supported.
/// * `ComputeError` if the two key dtypes differ.
/// * Whatever `ensure_sorted_arg` reports when a sortedness check fails.
pub fn _check_asof_columns(
    a: &Series,
    b: &Series,
    has_tolerance: bool,
    check_sortedness: bool,
    by_groups_present: bool,
) -> PolarsResult<()> {
    let (left_dtype, right_dtype) = (a.dtype(), b.dtype());
    let left_physical = left_dtype.to_physical();
    let right_physical = right_dtype.to_physical();

    if has_tolerance {
        polars_ensure!(
            left_physical.is_primitive_numeric() && right_physical.is_primitive_numeric(),
            InvalidOperation:
            "asof join with tolerance is only supported on numeric/temporal keys"
        );
    } else {
        polars_ensure!(
            left_physical.is_primitive() && right_physical.is_primitive(),
            InvalidOperation:
            "asof join is only supported on primitive key types"
        );
    }

    polars_ensure!(
        left_dtype == right_dtype,
        ComputeError: "mismatching key dtypes in asof-join: `{}` and `{}`",
        left_dtype, right_dtype
    );

    if !check_sortedness {
        return Ok(());
    }
    if by_groups_present {
        // With 'by' groups the check is skipped; only a warning is emitted.
        polars_warn!("Sortedness of columns cannot be checked when 'by' groups provided");
        return Ok(());
    }

    a.ensure_sorted_arg("asof_join")?;
    b.ensure_sorted_arg("asof_join")?;
    Ok(())
}

/// Rule deciding which right-hand row an asof join pairs with each left row.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum AsofStrategy {
    /// Selects the last row in the right DataFrame whose 'on' key is less than or equal to the left's key.
    #[default]
    Backward,
    /// Selects the first row in the right DataFrame whose 'on' key is greater than or equal to the left's key.
    Forward,
    /// Selects the row in the right DataFrame whose 'on' key is nearest to the left's key.
    Nearest,
}

/// Asof-join support for anything that can be viewed as a [`DataFrame`].
pub trait AsofJoin: IntoDf {
    /// Joins `self` with `other` on the nearest-key rule given by `strategy`.
    ///
    /// The keys are validated with `_check_asof_columns` (no 'by' groups),
    /// converted to their physical representation and handed to
    /// `_join_asof_dispatch`, which yields one optional right-row index per
    /// left row. Those indices are used to gather rows from `other`, and the
    /// result is combined with the left frame via `_finish_join`.
    ///
    /// * `slice` - optional `(offset, len)` applied to both the left frame
    ///   and the matched indices before the right rows are gathered.
    /// * `coalesce` - when set and both keys share a name, the right key
    ///   column is dropped from `other` before gathering.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    fn _join_asof(
        &self,
        other: &DataFrame,
        left_key: &Series,
        right_key: &Series,
        strategy: AsofStrategy,
        tolerance: Option<AnyValue<'static>>,
        suffix: Option<PlSmallStr>,
        slice: Option<(i64, usize)>,
        coalesce: bool,
        allow_eq: bool,
        check_sortedness: bool,
    ) -> PolarsResult<DataFrame> {
        let self_df = self.to_df();

        let has_tolerance = tolerance.is_some();
        _check_asof_columns(left_key, right_key, has_tolerance, check_sortedness, false)?;

        let left_physical = left_key.to_physical_repr();
        let right_physical = right_key.to_physical_repr();
        let matched_rows = _join_asof_dispatch(
            &left_physical,
            &right_physical,
            strategy,
            tolerance,
            allow_eq,
        )?;

        try_raise_keyboard_interrupt();

        // Same-named keys would be duplicated in the output, so the right
        // one is removed when coalescing was requested.
        let drop_right_key = coalesce && left_physical.name() == right_physical.name();
        let right_source = if drop_right_key {
            Cow::Owned(other.drop(right_physical.name())?)
        } else {
            Cow::Borrowed(other)
        };

        // Apply the optional slice to the left rows and their matches alike
        // so both stay aligned.
        let (left, matched_rows) = match slice {
            Some((offset, len)) => (
                self_df.slice(offset, len),
                matched_rows.slice(offset, len),
            ),
            None => (self_df.clone(), matched_rows),
        };

        // SAFETY: join tuples are in bounds.
        let right_df = unsafe { right_source.take_unchecked(&matched_rows) };

        _finish_join(left, right_df, suffix)
    }
}

pub fn _join_asof_dispatch(
    left_key: &Series,
    right_key: &Series,
    strategy: AsofStrategy,
    tolerance: Option<AnyValue<'static>>,
    allow_eq: bool,
) -> PolarsResult<IdxCa> {
    let take_idx = match left_key.dtype() {
        DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {
            let left_key = left_key.cast(&DataType::Int32).unwrap();
            let right_key = right_key.cast(&DataType::Int32).unwrap();
            let ca = left_key.i32().unwrap();
            join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
        },
        DataType::Int32 => {
            let ca = left_key.i32().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Int64 => {
            let ca = left_key.i64().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        #[cfg(feature = "dtype-i128")]
        DataType::Int128 => {
            let ca = left_key.i128().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::UInt32 => {
            let ca = left_key.u32().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::UInt64 => {
            let ca = left_key.u64().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        #[cfg(feature = "dtype-u128")]
        DataType::UInt128 => {
            let ca = left_key.u128().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        #[cfg(feature = "dtype-f16")]
        DataType::Float16 => {
            let ca = left_key.f16().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Float32 => {
            let ca = left_key.f32().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Float64 => {
            let ca = left_key.f64().unwrap();
            join_asof_numeric(ca, right_key, strategy, tolerance, allow_eq)
        },
        DataType::Boolean => {
            let ca = left_key.bool().unwrap();
            join_asof::<BooleanType>(ca, right_key, strategy, allow_eq)
        },
        DataType::Binary => {
            let ca = left_key.binary().unwrap();
            join_asof::<BinaryType>(ca, right_key, strategy, allow_eq)
        },
        DataType::String => {
            let ca = left_key.str().unwrap();
            let right_binary = right_key.cast(&DataType::Binary).unwrap();
            join_asof::<BinaryType>(&ca.as_binary(), &right_binary, strategy, allow_eq)
        },
        dt => polars_bail!(opq = asof_join, dt),
    }?;
    Ok(take_idx)
}

// `DataFrame` relies entirely on the default `_join_asof` provided by `AsofJoin`.
impl AsofJoin for DataFrame {}