use num_traits::AsPrimitive;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use polars_utils::float16::pf16;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
use crate::series::ops::SeriesSealed;
/// Strategy used by [`RoundSeries::round`] to decide how a value is rounded
/// to the requested number of decimals.
///
/// `IntoStaticStr` yields the snake_case variant name (e.g. `"half_to_even"`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum RoundMode {
    /// Round to nearest; exact ties go to the even neighbour. This is the default.
    #[default]
    HalfToEven,
    /// Round to nearest; exact ties go away from zero.
    HalfAwayFromZero,
    /// Drop the digits beyond the requested decimals (truncate toward zero).
    ToZero,
}
/// Rounds every value of a float `Series` to `decimals` decimal places.
///
/// `f32_op` / `f64_op` perform the rounding to an integral value (the caller
/// passes e.g. `round_ties_even`, `round` or `trunc`). For `decimals > 0`,
/// each value is handled in three steps:
/// - it is scaled up by `10^decimals`;
/// - it is rounded;
/// - it is scaled back down.
///
/// If that produces a non-finite number, the input value is kept. Beyond a
/// per-dtype cutoff of `decimals`, the series is returned unchanged.
///
/// Returns `None` when `s` is not a float series (f16 is only considered with
/// the `dtype-f16` feature), so the caller can fall back to other dtypes.
fn apply_float_rounding(
    s: &Series,
    decimals: u32,
    f32_op: fn(f32) -> f32,
    f64_op: fn(f64) -> f64,
) -> Option<PolarsResult<Series>> {
    // f16 values are widened to f32 for the arithmetic and narrowed back.
    #[cfg(feature = "dtype-f16")]
    if let Ok(ca) = s.f16() {
        let rounded = match decimals {
            0 => ca
                .apply_values(|x| f32_op(f32::from(x)).into())
                .into_series(),
            d if d >= 11 => s.clone(),
            d => {
                let scale = 10.0_f32.powi(d as i32);
                ca.apply_values(|x| {
                    let y: pf16 = (f32_op(f32::from(x) * scale) / scale).into();
                    if y.is_finite() { y } else { x }
                })
                .into_series()
            },
        };
        return Some(Ok(rounded));
    }

    // f32 values are scaled and rounded in f64, then cast back to f32.
    if let Ok(ca) = s.f32() {
        let rounded = match decimals {
            0 => ca.apply_values(f32_op).into_series(),
            d if d >= 47 => s.clone(),
            d => {
                let scale = 10.0_f64.powi(d as i32);
                ca.apply_values(|x| {
                    let y = (f64_op(f64::from(x) * scale) / scale) as f32;
                    if y.is_finite() { y } else { x }
                })
                .into_series()
            },
        };
        return Some(Ok(rounded));
    }

    if let Ok(ca) = s.f64() {
        let rounded = match decimals {
            0 => ca.apply_values(f64_op).into_series(),
            d if d >= 326 => s.clone(),
            // 10^d overflows f64 for d > 308, so in this range the scale is
            // applied as two separate factors, 2^d and 5^d.
            d if d >= 300 => {
                let pow2 = libm::scalbn(1.0, d as i32);
                let inv_pow2 = 1.0 / pow2;
                let pow5 = 5.0_f64.powi(d as i32);
                ca.apply_values(|x| {
                    let y = f64_op(x * pow2 * pow5) / pow5 * inv_pow2;
                    if y.is_finite() { y } else { x }
                })
                .into_series()
            },
            d => {
                let scale = 10.0_f64.powi(d as i32);
                ca.apply_values(|x| {
                    let y = f64_op(x * scale) / scale;
                    if y.is_finite() { y } else { x }
                })
                .into_series()
            },
        };
        return Some(Ok(rounded));
    }

    None
}
/// Rounding, truncation, floor and ceil operations on numeric [`Series`].
///
/// Dtypes are handled as follows:
/// - Float series are processed element-wise; f16 only when the `dtype-f16`
///   feature is enabled.
/// - Decimal series (with `dtype-decimal`) are processed on their physical
///   integer values, keeping the dtype's precision and scale.
/// - `round`, `truncate`, `floor` and `ceil` return integer series unchanged.
pub trait RoundSeries: SeriesSealed {
    /// Round every value to `decimals` decimal places using `mode`.
    ///
    /// * Floats: see the scaling scheme of the float helper; values whose
    ///   rounded result is not finite are left as they were.
    /// * Decimals: if `decimals >= scale` the series is returned as is;
    ///   otherwise digits below `decimals` are zeroed according to `mode`.
    /// * Integers: returned unchanged.
    ///
    /// # Errors
    /// `InvalidOperation` if the series is not of a numeric dtype.
    fn round(&self, decimals: u32, mode: RoundMode) -> PolarsResult<Series> {
        let s = self.as_series();
        #[allow(clippy::type_complexity)]
        let (f32_op, f64_op): (fn(f32) -> f32, fn(f64) -> f64) = match mode {
            RoundMode::HalfToEven => (f32::round_ties_even, f64::round_ties_even),
            RoundMode::HalfAwayFromZero => (f32::round, f64::round),
            RoundMode::ToZero => (f32::trunc, f64::trunc),
        };
        if let Some(result) = apply_float_rounding(s, decimals, f32_op, f64_op) {
            return result;
        }
        #[cfg(feature = "dtype-decimal")]
        if let Some(ca) = s.try_decimal() {
            let scale = ca.scale() as u32;
            if scale <= decimals {
                return Ok(ca.clone().into_series());
            }
            // `multiplier` is 10^k with k >= 1, so it is always even and
            // `threshold` is exactly half of it.
            let decimal_delta = scale - decimals;
            let multiplier = 10i128.pow(decimal_delta);
            let threshold = multiplier / 2;
            let res = match mode {
                RoundMode::HalfToEven => ca.physical().apply_values(|v| {
                    // Truncating remainder: carries the sign of `v`, so
                    // `v - rem` is `v` truncated toward zero.
                    let rem = v % multiplier;
                    let truncated = v - rem;
                    // Parity of the truncated quotient decides ties. It is
                    // derived from the quotient rather than from
                    // `v % (2 * multiplier)` because `2 * 10^38` overflows i128.
                    let is_quotient_even = (truncated / multiplier) % 2 == 0;
                    // Even quotient: an exact half stays (needs > half to move).
                    // Odd quotient: an exact half rounds away, onto the even value.
                    let threshold = threshold + i128::from(is_quotient_even);
                    let round_offset = if rem.abs() >= threshold {
                        if v < 0 { -multiplier } else { multiplier }
                    } else {
                        0
                    };
                    truncated + round_offset
                }),
                RoundMode::HalfAwayFromZero => ca.physical().apply_values(|v| {
                    let rem = v % multiplier;
                    let round_offset = if rem.abs() >= threshold {
                        if v < 0 { -multiplier } else { multiplier }
                    } else {
                        0
                    };
                    v - rem + round_offset
                }),
                RoundMode::ToZero => ca.physical().apply_values(|v| v - (v % multiplier)),
            };
            return Ok(res
                .into_decimal_unchecked(ca.precision(), scale as usize)
                .into_series());
        }
        let op = match mode {
            RoundMode::ToZero => "truncation ('to_zero')",
            RoundMode::HalfToEven => "rounding ('half_to_even')",
            RoundMode::HalfAwayFromZero => "rounding ('half_away_from_zero')",
        };
        // Floats and decimals were handled above; integers are already rounded.
        polars_ensure!(s.dtype().is_integer(), InvalidOperation: "{} can only be used on numeric types", op);
        Ok(s.clone())
    }

    /// Round every value to `digits` significant figures.
    ///
    /// Primitive numerics are computed in `f64` (integers are converted, so
    /// integers beyond f64's exact range may lose precision), rounding ties
    /// away from zero. Non-finite results keep the original value.
    ///
    /// # Errors
    /// `InvalidOperation` if `digits < 1` or the series is not numeric.
    fn round_sig_figs(&self, digits: i32) -> PolarsResult<Series> {
        let s = self.as_series();
        polars_ensure!(digits >= 1, InvalidOperation: "digits must be an integer >= 1");
        #[cfg(feature = "dtype-decimal")]
        if let Some(ca) = s.try_decimal() {
            let precision = ca.precision();
            let scale = ca.scale() as u32;
            let s = ca
                .physical()
                .apply_values(|v| {
                    if v == 0 {
                        return 0;
                    }
                    // Number of integer digits of |v| (exact powers of ten
                    // count one less, which does not affect their rounding).
                    let mut magnitude = v.abs().ilog10();
                    let magnitude_mult = 10i128.pow(magnitude);
                    if v.abs() > magnitude_mult {
                        magnitude += 1;
                    }
                    // Number of trailing physical digits to round away.
                    let decimals = magnitude.saturating_sub(digits as u32);
                    let multiplier = 10i128.pow(decimals);
                    let threshold = multiplier / 2;
                    let rem = v % multiplier;
                    // NOTE(review): ties round half-to-even only when the rounding
                    // position lies within the fractional digits
                    // (`decimals <= scale`); otherwise they round away from
                    // zero. Confirm this asymmetry is intended.
                    let is_v_floor_even = decimals <= scale && ((v - rem) / multiplier) % 2 == 0;
                    let threshold = threshold + i128::from(is_v_floor_even);
                    let round_offset = if rem.abs() >= threshold {
                        multiplier
                    } else {
                        0
                    };
                    let round_offset = if v < 0 { -round_offset } else { round_offset };
                    v - rem + round_offset
                })
                .into_decimal_unchecked(precision, scale as usize)
                .into_series();
            return Ok(s);
        }
        polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "round_sig_figs can only be used on numeric types" );
        with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
            let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
            let s = ca.apply_values(|value| {
                let value = AsPrimitive::<f64>::as_(value);
                if value == 0.0 {
                    return AsPrimitive::<<$T as PolarsNumericType>::Native>::as_(value);
                }
                // Scale by 10^exp = 2^exp * 5^exp so the leading `digits`
                // significant figures end up before the decimal point.
                let exp = digits - 1 - value.abs().log10().floor() as i32;
                let pow5 = 5.0_f64.powi(exp);
                let scaled = libm::scalbn(value, exp) * pow5;
                let descaled = libm::scalbn(scaled.round() / pow5, -exp);
                AsPrimitive::<<$T as PolarsNumericType>::Native>::as_(
                    if descaled.is_finite() { descaled } else { value }
                )
            }).into_series();
            return Ok(s);
        });
    }

    /// Truncate every value toward zero at `decimals` decimal places.
    ///
    /// Equivalent to [`RoundSeries::round`] with [`RoundMode::ToZero`].
    fn truncate(&self, decimals: u32) -> PolarsResult<Series> {
        self.round(decimals, RoundMode::ToZero)
    }

    /// Round every value down to the nearest integer.
    ///
    /// Decimal results keep the input's precision and scale; integers are
    /// returned unchanged.
    ///
    /// # Errors
    /// `InvalidOperation` if the series is not numeric.
    fn floor(&self) -> PolarsResult<Series> {
        let s = self.as_series();
        // f16 goes through f32, mirroring the f16 handling of `round`; without
        // this branch f16 series would fall through to the generic path below.
        #[cfg(feature = "dtype-f16")]
        if let Ok(ca) = s.f16() {
            let s = ca
                .apply_values(|val| f32::from(val).floor().into())
                .into_series();
            return Ok(s);
        }
        if let Ok(ca) = s.f32() {
            let s = ca.apply_values(|val| val.floor()).into_series();
            return Ok(s);
        }
        if let Ok(ca) = s.f64() {
            let s = ca.apply_values(|val| val.floor()).into_series();
            return Ok(s);
        }
        #[cfg(feature = "dtype-decimal")]
        if let Some(ca) = s.try_decimal() {
            let precision = ca.precision();
            let scale = ca.scale() as u32;
            if scale == 0 {
                return Ok(ca.clone().into_series());
            }
            let decimal_delta = scale;
            let multiplier = 10i128.pow(decimal_delta);
            let ca = ca
                .physical()
                .apply_values(|v| {
                    // `rem` has the sign of `v`; negative values with a
                    // fractional part move one more unit downward.
                    let rem = v % multiplier;
                    let round_offset = if v < 0 { multiplier + rem } else { rem };
                    let round_offset = if rem == 0 { 0 } else { round_offset };
                    v - round_offset
                })
                .into_decimal_unchecked(precision, scale as usize);
            return Ok(ca.into_series());
        }
        polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "floor can only be used on numeric types" );
        Ok(s.clone())
    }

    /// Round every value up to the nearest integer.
    ///
    /// Decimal results keep the input's precision and scale; integers are
    /// returned unchanged.
    ///
    /// # Errors
    /// `InvalidOperation` if the series is not numeric.
    fn ceil(&self) -> PolarsResult<Series> {
        let s = self.as_series();
        // f16 goes through f32, mirroring the f16 handling of `round`; without
        // this branch f16 series would fall through to the generic path below.
        #[cfg(feature = "dtype-f16")]
        if let Ok(ca) = s.f16() {
            let s = ca
                .apply_values(|val| f32::from(val).ceil().into())
                .into_series();
            return Ok(s);
        }
        if let Ok(ca) = s.f32() {
            let s = ca.apply_values(|val| val.ceil()).into_series();
            return Ok(s);
        }
        if let Ok(ca) = s.f64() {
            let s = ca.apply_values(|val| val.ceil()).into_series();
            return Ok(s);
        }
        #[cfg(feature = "dtype-decimal")]
        if let Some(ca) = s.try_decimal() {
            let precision = ca.precision();
            let scale = ca.scale() as u32;
            if scale == 0 {
                return Ok(ca.clone().into_series());
            }
            let decimal_delta = scale;
            let multiplier = 10i128.pow(decimal_delta);
            let ca = ca
                .physical()
                .apply_values(|v| {
                    // `rem` has the sign of `v`; negative values are simply
                    // truncated, positive ones with a fraction move up a unit.
                    let rem = v % multiplier;
                    let round_offset = if v < 0 { -rem } else { multiplier - rem };
                    let round_offset = if rem == 0 { 0 } else { round_offset };
                    v + round_offset
                })
                .into_decimal_unchecked(precision, scale as usize);
            return Ok(ca.into_series());
        }
        polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "ceil can only be used on numeric types" );
        Ok(s.clone())
    }
}

/// `Series` gets every rounding operation from the trait's default methods.
impl RoundSeries for Series {}
#[cfg(test)]
mod test {
    use super::*;

    /// Collects the values of an f64 series for easy comparison.
    fn f64_values(s: &Series) -> Vec<Option<f64>> {
        let ca = s.f64().unwrap();
        (0..s.len()).map(|i| ca.get(i)).collect()
    }

    #[test]
    fn test_round_series() {
        let series = Series::new("a".into(), &[1.003, 2.23222, 3.4352]);
        let out = series.round(2, RoundMode::default()).unwrap();
        assert_eq!(f64_values(&out), vec![Some(1.0), Some(2.23), Some(3.44)]);
    }

    #[test]
    fn test_round_modes_on_ties() {
        let series = Series::new("a".into(), &[0.5, 1.5, 2.5, -2.5]);

        let even = RoundSeries::round(&series, 0, RoundMode::HalfToEven).unwrap();
        assert_eq!(
            f64_values(&even),
            vec![Some(0.0), Some(2.0), Some(2.0), Some(-2.0)]
        );

        let away = RoundSeries::round(&series, 0, RoundMode::HalfAwayFromZero).unwrap();
        assert_eq!(
            f64_values(&away),
            vec![Some(1.0), Some(2.0), Some(3.0), Some(-3.0)]
        );

        let trunc = RoundSeries::truncate(&series, 0).unwrap();
        assert_eq!(
            f64_values(&trunc),
            vec![Some(0.0), Some(1.0), Some(2.0), Some(-2.0)]
        );
    }

    #[test]
    fn test_floor_ceil_series() {
        let series = Series::new("a".into(), &[0.5, 1.5, 2.5, -2.5]);

        let floored = RoundSeries::floor(&series).unwrap();
        assert_eq!(
            f64_values(&floored),
            vec![Some(0.0), Some(1.0), Some(2.0), Some(-3.0)]
        );

        let ceiled = RoundSeries::ceil(&series).unwrap();
        assert_eq!(
            f64_values(&ceiled),
            vec![Some(1.0), Some(2.0), Some(3.0), Some(-2.0)]
        );
    }
}