use std::ops::{Add, Div, Mul, Sub};
use arrow::array::PrimitiveArray;
use arrow::bitmap::MutableBitmap;
use bytemuck::allocation::zeroed_vec;
use num_traits::{NumCast, Zero};
use polars_core::prelude::*;
use polars_utils::slice::SliceAble;
use super::linear_itp;
/// Fills the interior of one null gap when the `by` coordinates are already sorted.
///
/// `x` holds the `by` values from the last known point (`x[0]`, whose value is `y_start`)
/// through the next known point (`x[x.len() - 1]`, whose value is `y_end`). For every element
/// strictly between those two endpoints, one value is appended to `out`, in order. Each value
/// comes from `linear_itp`, given `y_start`, the offset of the point from `x[0]`, and the slope
/// `(y_end - y_start) / (x_last - x_first)`.
///
/// # Safety
/// `x` must contain at least two elements.
///
/// # Panics
/// Panics if a difference between `x` values cannot be converted via `NumCast`.
#[inline]
unsafe fn signed_interp_by_sorted<T, F>(y_start: T, y_end: T, x: &[F], out: &mut Vec<T>)
where
    T: Sub<Output = T>
        + Mul<Output = T>
        + Add<Output = T>
        + Div<Output = T>
        + NumCast
        + Copy
        + Zero,
    F: Sub<Output = F> + NumCast + Copy,
{
    let rise = y_end - y_start;
    let n = x.len();
    // SAFETY: the caller guarantees `n >= 2`, so index 0, index `n - 1` and the range
    // `1..n - 1` are all in bounds.
    let (x_first, x_last, interior) = unsafe {
        (
            *x.get_unchecked(0),
            *x.get_unchecked(n - 1),
            x.slice_unchecked(1..n - 1),
        )
    };
    let span_x: T = NumCast::from(x_last - x_first).unwrap();
    // NOTE: equal endpoint coordinates (zero `span_x`) are not guarded against here.
    let slope = rise / span_x;

    out.extend(interior.iter().map(|&x_i| {
        let offset = NumCast::from(x_i - x_first).unwrap();
        linear_itp(y_start, offset, slope)
    }));
}
/// Fills the interior of one null gap for an unsorted `by` column and scatters the results.
///
/// `x` holds `by` values in the order the caller sorted them. It runs from the last known
/// point (`x[0]`, whose value is `y_start`) to the next known point (`x[x.len() - 1]`, whose
/// value is `y_end`). `sorting_indices[i]` is the position in `out` that corresponds to `x[i]`.
/// The value computed for each interior point `x[i]` (`0 < i < x.len() - 1`) is written to
/// `out[sorting_indices[i]]`.
///
/// # Safety
/// - `x` must contain at least two elements.
/// - `sorting_indices` must have an entry for every interior point of `x`, and each such
///   entry must be a valid index into `out`.
///
/// # Panics
/// Panics if a difference between `x` values cannot be converted via `NumCast`.
#[inline]
unsafe fn signed_interp_by<T, F>(
    y_start: T,
    y_end: T,
    x: &[F],
    out: &mut [T],
    sorting_indices: &[IdxSize],
) where
    T: Sub<Output = T>
        + Mul<Output = T>
        + Add<Output = T>
        + Div<Output = T>
        + NumCast
        + Copy
        + Zero,
    F: Sub<Output = F> + NumCast + Copy,
{
    let rise = y_end - y_start;
    let n = x.len();
    // SAFETY: the caller guarantees `n >= 2`, so index 0, index `n - 1` and the range
    // `1..n - 1` are all in bounds.
    let (x_first, x_last, interior) = unsafe {
        (
            *x.get_unchecked(0),
            *x.get_unchecked(n - 1),
            x.slice_unchecked(1..n - 1),
        )
    };
    let span_x: T = NumCast::from(x_last - x_first).unwrap();
    // NOTE: equal endpoint coordinates (zero `span_x`) are not guarded against here.
    let slope = rise / span_x;

    // Interior point `x[i]` (for `i` in `1..n - 1`) pairs with `sorting_indices[i]`.
    for (&x_i, &target) in interior.iter().zip(sorting_indices.iter().skip(1)) {
        let offset = NumCast::from(x_i - x_first).unwrap();
        let value = linear_itp(y_start, offset, slope);
        // SAFETY: the caller guarantees every target index is in bounds for `out`.
        unsafe { *out.get_unchecked_mut(target as usize) = value };
    }
}
/// Interpolates the interior null gaps of `chunked_arr`, using the already-sorted `by`
/// column as the x-coordinate.
///
/// Nulls before the first non-null value and after the last non-null value are kept as
/// nulls, so the output has the same length as the input. If the input has no nulls, or
/// consists only of nulls, it is returned unchanged (as a clone).
///
/// `interpolation_branch` is called once per interior gap. It receives the two bounding
/// values, the `by` slice spanning the gap (both endpoints included), and the output buffer.
/// It must append exactly one value per null inside the gap.
///
/// Assumes `by` has the same length as `chunked_arr`; `interpolate_by` checks this before
/// dispatching here.
///
/// # Errors
/// Returns `InvalidOperation` if `by` contains nulls.
fn interpolate_impl_by_sorted<T, F, I>(
    chunked_arr: &ChunkedArray<T>,
    by: &ChunkedArray<F>,
    interpolation_branch: I,
) -> PolarsResult<ChunkedArray<T>>
where
    T: PolarsNumericType,
    F: PolarsNumericType,
    I: Fn(T::Native, T::Native, &[F::Native], &mut Vec<T::Native>),
{
    let len = chunked_arr.len();
    // Nothing to interpolate: either there are no gaps, or there is no known value to anchor on.
    if !chunked_arr.has_nulls() || chunked_arr.null_count() == len {
        return Ok(chunked_arr.clone());
    }

    polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression");
    let by = by.rechunk();
    let by_values = by.cont_slice().unwrap();

    // Both exist because at least one value is non-null (checked above).
    let first = chunked_arr.first_non_null().unwrap();
    let last = chunked_arr.last_non_null().unwrap() + 1;

    // Placeholder zeros for the leading nulls; the validity bitmap below masks them out.
    let mut out: Vec<T::Native> = Vec::with_capacity(len);
    out.resize(first, Zero::zero());

    // The most recent non-null `(index, value)` seen so far.
    let mut anchor: Option<(usize, T::Native)> = None;
    for (idx, opt_v) in chunked_arr.iter().enumerate().skip(first) {
        // A null is filled once the next non-null value closes its gap.
        let Some(v) = opt_v else {
            continue;
        };
        if let Some((low_idx, low)) = anchor {
            if idx > low_idx + 1 {
                // SAFETY: `low_idx < idx < len`, and `by_values` has `len` elements, so the
                // range is in bounds and spans at least three elements.
                let x = unsafe { by_values.slice_unchecked(low_idx..idx + 1) };
                interpolation_branch(low, v, x, &mut out);
            }
        }
        out.push(v);
        anchor = Some((idx, v));
    }

    if first == 0 && last == len {
        return Ok(ChunkedArray::from_vec(chunked_arr.name().clone(), out));
    }

    // `out` currently holds `last` values; pad the trailing nulls with placeholder zeros.
    out.resize(len, Zero::zero());
    let mut validity = MutableBitmap::with_capacity(len);
    validity.extend_constant(first, false);
    validity.extend_constant(last - first, true);
    validity.extend_constant(len - last, false);

    let array = PrimitiveArray::new(
        T::get_static_dtype().to_arrow(CompatLevel::newest()),
        out.into(),
        Some(validity.into()),
    );
    Ok(ChunkedArray::with_chunk(chunked_arr.name().clone(), array))
}
/// Interpolates the interior null gaps of `ca`, using `by` as the x-coordinate, when `by`
/// is not known to be sorted.
///
/// Both columns are reordered by `by.arg_sort(..)`. Gaps are detected and interpolated in
/// that order, and every result is written back to its original row through the sort indices.
/// Nulls that come before the first or after the last non-null value (in sorted order) stay
/// null. If `ca` has no nulls, or consists only of nulls, it is returned unchanged (as a clone).
///
/// `interpolation_branch` is called once per interior gap. It receives the two bounding
/// values, the sorted `by` slice spanning the gap (both endpoints included), the output
/// buffer, and the matching slice of sort indices.
///
/// Assumes `by` has the same length as `ca`; `interpolate_by` checks this before dispatching
/// here.
///
/// # Errors
/// Returns `InvalidOperation` if `by` contains nulls.
fn interpolate_impl_by<T, F, I>(
    ca: &ChunkedArray<T>,
    by: &ChunkedArray<F>,
    interpolation_branch: I,
) -> PolarsResult<ChunkedArray<T>>
where
    T: PolarsNumericType,
    F: PolarsNumericType,
    I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]),
{
    // Nothing to interpolate: either there are no gaps, or there is no known value to anchor on.
    if !ca.has_nulls() || ca.null_count() == ca.len() {
        return Ok(ca.clone());
    }

    polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression");
    let sorting_indices = by.arg_sort(Default::default());
    let sorting_indices = sorting_indices
        .cont_slice()
        .expect("arg sort produces single chunk");
    // SAFETY: the indices come from `arg_sort` on `by`, so each is `< by.len() == ca.len()`.
    let by_sorted = unsafe { by.take_unchecked(sorting_indices) };
    let ca_sorted = unsafe { ca.take_unchecked(sorting_indices) };
    let by_sorted_values = by_sorted
        .cont_slice()
        .expect("We already checked for nulls, and `take_unchecked` produces single chunk");

    let len = ca_sorted.len();
    // Both exist because at least one value is non-null (checked above).
    let first = ca_sorted.first_non_null().unwrap();
    let last = ca_sorted.last_non_null().unwrap() + 1;

    // Pre-zeroed so rows that remain null still hold a defined placeholder value.
    let mut out: Vec<T::Native> = zeroed_vec(len);

    // The most recent non-null `(sorted index, value)` seen so far.
    let mut anchor: Option<(usize, T::Native)> = None;
    for (idx, opt_v) in ca_sorted.iter().enumerate().skip(first) {
        // A null is filled once the next non-null value closes its gap.
        let Some(v) = opt_v else {
            continue;
        };
        if let Some((low_idx, low)) = anchor {
            if idx > low_idx + 1 {
                // SAFETY: `low_idx < idx < len`, and both sorted slices have `len` elements.
                unsafe {
                    interpolation_branch(
                        low,
                        v,
                        by_sorted_values.slice_unchecked(low_idx..idx + 1),
                        &mut out,
                        sorting_indices.slice_unchecked(low_idx..idx + 1),
                    );
                }
            }
        }
        // SAFETY: `idx < len`, and every sort index is a valid row of `out`.
        unsafe {
            let row = *sorting_indices.get_unchecked(idx) as usize;
            *out.get_unchecked_mut(row) = v;
        }
        anchor = Some((idx, v));
    }

    if first == 0 && last == len {
        return Ok(ChunkedArray::from_vec(ca_sorted.name().clone(), out));
    }

    // Mask the rows that were null before the first or after the last known value (in sorted
    // order), mapping each sorted position back to its original row.
    let mut validity = MutableBitmap::with_capacity(len);
    validity.extend_constant(len, true);
    for &row in sorting_indices[..first]
        .iter()
        .chain(&sorting_indices[last..])
    {
        // SAFETY: every sort index is `< len`, which is the bitmap's length.
        unsafe { validity.set_unchecked(row as usize, false) };
    }

    let array = PrimitiveArray::new(
        T::get_static_dtype().to_arrow(CompatLevel::newest()),
        out.into(),
        Some(validity.into()),
    );
    Ok(ChunkedArray::with_chunk(ca_sorted.name().clone(), array))
}
pub fn interpolate_by(s: &Column, by: &Column, by_is_sorted: bool) -> PolarsResult<Column> {
polars_ensure!(s.len() == by.len(), InvalidOperation: "`by` column must be the same length as Series ({}), got {}", s.len(), by.len());
fn func<T, F>(
ca: &ChunkedArray<T>,
by: &ChunkedArray<F>,
is_sorted: bool,
) -> PolarsResult<Column>
where
T: PolarsNumericType,
F: PolarsNumericType,
ChunkedArray<T>: IntoColumn,
{
if is_sorted {
interpolate_impl_by_sorted(ca, by, |y_start, y_end, x, out| unsafe {
signed_interp_by_sorted(y_start, y_end, x, out)
})
.map(|x| x.into_column())
} else {
interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe {
signed_interp_by(y_start, y_end, x, out, sorting_indices)
})
.map(|x| x.into_column())
}
}
match (s.dtype(), by.dtype()) {
(DataType::Float64, DataType::Float64) => {
func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Float32) => {
func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float64) => {
func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float32) => {
func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Int64) => {
func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Int32) => {
func(s.f64().unwrap(), by.i32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::UInt64) => {
func(s.f64().unwrap(), by.u64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::UInt32) => {
func(s.f64().unwrap(), by.u32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Int64) => {
func(s.f32().unwrap(), by.i64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Int32) => {
func(s.f32().unwrap(), by.i32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::UInt64) => {
func(s.f32().unwrap(), by.u64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::UInt32) => {
func(s.f32().unwrap(), by.u32().unwrap(), by_is_sorted)
},
#[cfg(feature = "dtype-date")]
(_, DataType::Date) => interpolate_by(s, &by.cast(&DataType::Int32).unwrap(), by_is_sorted),
#[cfg(feature = "dtype-datetime")]
(_, DataType::Datetime(_, _)) => {
interpolate_by(s, &by.cast(&DataType::Int64).unwrap(), by_is_sorted)
},
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
interpolate_by(&s.cast(&DataType::Float64).unwrap(), by, by_is_sorted)
},
_ => {
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
UInt64, UInt32, Float32 or Float64")
},
}
}