#[cfg(feature = "std")]
use core::cell::RefCell;
use crate::{
i128_div_mod_floor, i128_shifted_div_mod_floor, i256_div_mod_floor,
ten_pow,
};
/// Strategy used to turn an inexact quotient into one of its two integral
/// neighbours.
///
/// Wherever a mode is optional, `None` means `RoundingMode::default()`,
/// which is `RoundHalfEven` unless another default has been installed via
/// `RoundingMode::set_default` (only available with the `std` feature).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoundingMode {
    /// Round towards zero, unless the last digit of that truncated result
    /// is 0 or 5, in which case round away from zero.
    Round05Up,
    /// Round towards positive infinity.
    RoundCeiling,
    /// Round towards zero (truncate).
    RoundDown,
    /// Round towards negative infinity.
    RoundFloor,
    /// Round to the nearest neighbour; exact ties go towards zero.
    RoundHalfDown,
    /// Round to the nearest neighbour; exact ties go to the even one.
    RoundHalfEven,
    /// Round to the nearest neighbour; exact ties go away from zero.
    RoundHalfUp,
    /// Round away from zero.
    RoundUp,
}
#[cfg(feature = "std")]
thread_local! {
    /// Per-thread default rounding mode.
    ///
    /// Every thread starts out with `RoundHalfEven`. The value is read by
    /// `RoundingMode::default()` and replaced by `RoundingMode::set_default()`;
    /// being thread-local, a change is only visible to the thread making it.
    static DFLT_ROUNDING_MODE: RefCell<RoundingMode> =
        const { RefCell::new(RoundingMode::RoundHalfEven) };
}
#[cfg(feature = "std")]
impl Default for RoundingMode {
    /// Returns the calling thread's current default rounding mode.
    ///
    /// This is `RoundHalfEven` unless the thread has installed another mode
    /// via `RoundingMode::set_default`.
    fn default() -> Self {
        // The mode is `Copy`, so it is copied out and the `Ref` guard ends
        // inside the closure.
        DFLT_ROUNDING_MODE.with(|current| *RefCell::borrow(current))
    }
}
#[cfg(feature = "std")]
impl RoundingMode {
    /// Installs `mode` as the value returned by `RoundingMode::default()`.
    ///
    /// The default is stored thread-locally, so only the calling thread is
    /// affected. Every other thread keeps its own setting, which starts as
    /// `RoundHalfEven`.
    pub fn set_default(mode: Self) {
        DFLT_ROUNDING_MODE.with(|cell| {
            // The previously installed mode is not needed.
            let _previous_mode = cell.replace(mode);
        });
    }
}
/// Fixed default rounding mode used when the `std` feature is disabled.
///
/// Without `std` there is no thread-local storage, so the default cannot be
/// changed at runtime.
#[cfg(not(feature = "std"))]
const DFLT_ROUNDING_MODE: RoundingMode = RoundingMode::RoundHalfEven;

#[cfg(not(feature = "std"))]
impl Default for RoundingMode {
    /// Always returns `RoundHalfEven` (the fixed no_std default).
    #[inline]
    fn default() -> Self {
        DFLT_ROUNDING_MODE
    }
}
/// Rounding of a value to a given number of fractional digits.
///
/// NOTE(review): which `RoundingMode` implementors apply is not visible
/// here. Presumably it is `RoundingMode::default()`; confirm per implementation.
pub trait Round: Sized {
    /// Returns `self` rounded to `n_frac_digits` fractional digits.
    ///
    /// NOTE(review): a negative `n_frac_digits` presumably rounds to tens,
    /// hundreds, … (confirm with implementors).
    fn round(self, n_frac_digits: i8) -> Self;

    /// Fallible variant of [`Round::round`]: returns `None` when the rounded
    /// value cannot be represented. The exact failure conditions are
    /// implementation-defined.
    fn checked_round(self, n_frac_digits: i8) -> Option<Self>;
}
/// Rounds an already floor-divided quotient according to `mode`.
///
/// `quot` and `rem` are expected to be the floor quotient and non-negative
/// remainder of a division by `divisor`. The exact quotient is then
/// `quot + rem / divisor` with `0 <= rem < divisor`, so the result is either
/// `quot` (the floor) or `quot + 1` (the ceiling).
///
/// `mode == None` selects `RoundingMode::default()`.
///
/// Callers must ensure `quot + 1` cannot overflow when `rem != 0`.
/// Otherwise rounding up panics in builds with overflow checks.
#[inline]
fn round_quot(
    quot: i128,
    rem: u128,
    divisor: u128,
    mode: Option<RoundingMode>,
) -> i128 {
    // Exact quotient: nothing to round.
    if rem == 0 {
        return quot;
    }
    // Compare the fraction `rem / divisor` with 1/2 as `rem` vs `divisor - rem`
    // rather than `2 * rem` vs `divisor`. Doubling silently drops the top bit
    // for `rem >= 2^127`, which the `u128` parameters allow.
    // `saturating_sub` keeps an out-of-range `rem >= divisor` classified as
    // "above half", as the doubled comparison did.
    let complement = divisor.saturating_sub(rem);
    let above_half = rem > complement;
    let exactly_half = rem == complement;
    let round_up = match mode.unwrap_or_default() {
        RoundingMode::Round05Up => {
            // Result of rounding towards zero; it is moved away from zero
            // only if its last digit is 0 or 5.
            let truncated = if quot < 0 { quot + 1 } else { quot };
            let away_from_zero = truncated % 5 == 0;
            // Away from zero means `quot + 1` for non-negative values, but
            // staying at the floor `quot` for negative ones.
            away_from_zero == (quot >= 0)
        }
        RoundingMode::RoundCeiling => true,
        // Towards zero: the ceiling only for negative quotients.
        RoundingMode::RoundDown => quot < 0,
        RoundingMode::RoundFloor => false,
        RoundingMode::RoundHalfDown => above_half || (exactly_half && quot < 0),
        RoundingMode::RoundHalfEven => {
            above_half || (exactly_half && quot % 2 != 0)
        }
        RoundingMode::RoundHalfUp => above_half || (exactly_half && quot >= 0),
        // Away from zero: the ceiling only for non-negative quotients.
        RoundingMode::RoundUp => quot >= 0,
    };
    if round_up {
        quot + 1
    } else {
        quot
    }
}
/// Divides `divident` by `divisor` and rounds the quotient according to
/// `mode` (`None` selects `RoundingMode::default()`).
///
/// A negative divisor is handled without negating `divident`. That negation
/// would overflow for `i128::MIN`. As a result, `i128::MIN` is accepted as
/// either operand whenever the exact quotient fits in an `i128`.
///
/// # Panics
///
/// - If the exact quotient overflows `i128`, which happens only for
///   `i128::MIN / -1`. Builds without overflow checks wrap to `i128::MIN`
///   instead.
/// - NOTE(review): division by zero is delegated to `i128_div_mod_floor`,
///   which presumably panics — confirm.
#[doc(hidden)]
#[must_use]
#[allow(clippy::cast_sign_loss)]
pub fn i128_div_rounded(
    divident: i128,
    divisor: i128,
    mode: Option<RoundingMode>,
) -> i128 {
    if divisor >= 0 {
        let (quot, rem) = i128_div_mod_floor(divident, divisor);
        return round_quot(quot, rem as u128, divisor as u128, mode);
    }
    if divisor == i128::MIN {
        // |divisor| = 2^127 does not fit in an i128. The exact quotient
        // divident / -2^127 lies in (-1, 1], so floor and remainder follow
        // directly from the sign of divident.
        let abs_divisor = divisor.unsigned_abs();
        let (quot, rem): (i128, u128) = if divident == i128::MIN {
            (1, 0)
        } else if divident > 0 {
            (-1, abs_divisor - divident.unsigned_abs())
        } else {
            (0, divident.unsigned_abs())
        };
        return round_quot(quot, rem, abs_divisor, mode);
    }
    // Negative divisor: floor(divident / divisor) == -ceil(divident / |divisor|).
    // |divisor| fits here because divisor != i128::MIN.
    let abs_divisor = -divisor;
    let (quot_pos, rem_pos) = i128_div_mod_floor(divident, abs_divisor);
    let rem_pos = rem_pos as u128;
    let (quot, rem) = if rem_pos == 0 {
        // Exact division: only the sign flips (overflows only for MIN / -1).
        (-quot_pos, 0)
    } else {
        // -ceil(y) == -(floor(y) + 1); `-1 - q` cannot overflow for any q.
        (-1 - quot_pos, abs_divisor as u128 - rem_pos)
    };
    round_quot(quot, rem, abs_divisor as u128, mode)
}
/// Divides the shifted `divident` by `divisor` and rounds the quotient
/// according to `mode` (`None` selects `RoundingMode::default()`).
///
/// NOTE(review): "shifted" presumably means `divident · 10^p`, as computed
/// by `i128_shifted_div_mod_floor` — confirm there.
///
/// Returns `None` in these cases:
/// - `i128_shifted_div_mod_floor` returns `None` (presumably because the
///   quotient does not fit in an `i128`);
/// - the sign-corrected quotient overflows;
/// - `divisor == i128::MIN`, whose magnitude cannot be passed to the helper.
///
/// Negating `divident` is avoided when it would overflow
/// (`divident == i128::MIN`).
#[doc(hidden)]
#[must_use]
#[allow(clippy::cast_sign_loss)]
pub fn i128_shifted_div_rounded(
    divident: i128,
    p: u8,
    divisor: i128,
    mode: Option<RoundingMode>,
) -> Option<i128> {
    if divisor >= 0 {
        let (quot, rem) = i128_shifted_div_mod_floor(divident, p, divisor)?;
        return Some(round_quot(quot, rem as u128, divisor as u128, mode));
    }
    // |i128::MIN| is not representable as an i128, so that divisor cannot
    // be normalised to a positive one.
    let abs_divisor = divisor.checked_neg()?;
    let (quot, rem) = if let Some(neg_divident) = divident.checked_neg() {
        // Common case: x / -d == (-x) / d.
        let (quot, rem) =
            i128_shifted_div_mod_floor(neg_divident, p, abs_divisor)?;
        (quot, rem as u128)
    } else {
        // divident == i128::MIN: floor(-y) == -ceil(y) with
        // y = shifted divident / |divisor|.
        let (quot_pos, rem_pos) =
            i128_shifted_div_mod_floor(divident, p, abs_divisor)?;
        let rem_pos = rem_pos as u128;
        if rem_pos == 0 {
            (quot_pos.checked_neg()?, 0)
        } else {
            // `-1 - q` cannot overflow for any i128 q.
            (-1 - quot_pos, abs_divisor as u128 - rem_pos)
        }
    };
    Some(round_quot(quot, rem, abs_divisor as u128, mode))
}
/// Divides the product of `x` and `y` by `10^p` and rounds the quotient
/// according to `mode` (`None` selects `RoundingMode::default()`).
///
/// NOTE(review): assumes `i256_div_mod_floor(x, y, d)` yields the floor
/// quotient and non-negative remainder of `x * y / d`, using a wide
/// intermediate, and `None` when the quotient does not fit — confirm.
/// Any `None` from it is passed through unchanged.
#[doc(hidden)]
#[must_use]
#[allow(clippy::cast_sign_loss)]
pub fn i128_mul_div_ten_pow_rounded(
    x: i128,
    y: i128,
    p: u8,
    mode: Option<RoundingMode>,
) -> Option<i128> {
    let scale = ten_pow(p);
    i256_div_mod_floor(x, y, scale)
        .map(|(quot, rem)| round_quot(quot, rem as u128, scale as u128, mode))
}
#[cfg(feature = "std")]
#[cfg(test)]
mod rounding_mode_tests {
    use super::*;

    /// Checks that the thread starts with `RoundHalfEven`, that `mode`
    /// takes effect after `set_default`, and that restoring works.
    fn assert_switch_and_restore(mode: RoundingMode) {
        let initial = RoundingMode::RoundHalfEven;
        assert_eq!(RoundingMode::default(), initial);
        RoundingMode::set_default(mode);
        assert_eq!(RoundingMode::default(), mode);
        RoundingMode::set_default(initial);
        assert_eq!(RoundingMode::default(), initial);
    }

    #[test]
    fn test1() {
        assert_switch_and_restore(RoundingMode::RoundUp);
    }

    #[test]
    fn test2() {
        assert_switch_and_restore(RoundingMode::RoundHalfUp);
    }
}
#[cfg(test)]
mod helper_tests {
    use super::RoundingMode::{
        Round05Up, RoundCeiling, RoundDown, RoundFloor, RoundHalfDown,
        RoundHalfEven, RoundHalfUp, RoundUp,
    };
    use super::*;

    /// Cases for `i128_div_rounded`, as
    /// `(divident, divisor, rounding mode, expected rounded quotient)`.
    const TESTDATA: [(i128, i128, RoundingMode, i128); 34] = [
        // Round05Up
        (17, 5, Round05Up, 3),
        (27, 5, Round05Up, 6),
        (-17, 5, Round05Up, -3),
        (-27, 5, Round05Up, -6),
        // RoundCeiling
        (17, 5, RoundCeiling, 4),
        (15, 5, RoundCeiling, 3),
        (-17, 5, RoundCeiling, -3),
        (-15, 5, RoundCeiling, -3),
        // RoundDown
        (19, 5, RoundDown, 3),
        (15, 5, RoundDown, 3),
        (-18, 5, RoundDown, -3),
        (-15, 5, RoundDown, -3),
        // RoundFloor
        (19, 5, RoundFloor, 3),
        (15, 5, RoundFloor, 3),
        (-18, 5, RoundFloor, -4),
        (-15, 5, RoundFloor, -3),
        // RoundHalfDown
        (19, 2, RoundHalfDown, 9),
        (15, 4, RoundHalfDown, 4),
        (-19, 2, RoundHalfDown, -9),
        (-15, 4, RoundHalfDown, -4),
        // RoundHalfEven (last case: negative divisor, |quotient| < 1/2)
        (19, 2, RoundHalfEven, 10),
        (15, 4, RoundHalfEven, 4),
        (-225, 50, RoundHalfEven, -4),
        (-15, 4, RoundHalfEven, -4),
        (u64::MAX as i128, i64::MIN as i128 * 10, RoundHalfEven, 0),
        // RoundHalfUp
        (19, 2, RoundHalfUp, 10),
        (10802, 4321, RoundHalfUp, 2),
        (-19, 2, RoundHalfUp, -10),
        (-10802, 4321, RoundHalfUp, -2),
        // RoundUp (last case: exact division)
        (19, 2, RoundUp, 10),
        (10802, 4321, RoundUp, 3),
        (-19, 2, RoundUp, -10),
        (-10802, 4321, RoundUp, -3),
        (i32::MAX as i128, 1, RoundUp, i32::MAX as i128),
    ];

    #[test]
    fn test_div_rounded() {
        TESTDATA.iter().for_each(|&(divident, divisor, mode, expected)| {
            assert_eq!(i128_div_rounded(divident, divisor, Some(mode)), expected);
        });
    }
}