use crate::arithmetic::pow10;
use crate::decimal::Decimal;
use crate::error::ArithmeticError;
/// Strategy used by `Decimal::round` to resolve discarded fractional digits.
///
/// The "half" modes round to the nearest representable value and differ only
/// in how an exact tie (discarded part equal to half a unit) is resolved.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// Round toward negative infinity (floor).
    Down,
    /// Round toward positive infinity (ceiling).
    Up,
    /// Discard the extra digits (truncate toward zero).
    TowardZero,
    /// Move one unit away from zero whenever any non-zero digit is discarded.
    AwayFromZero,
    /// Round to nearest; ties move away from zero (so `-0.5` becomes `-1`).
    HalfUp,
    /// Round to nearest; ties move toward zero.
    HalfDown,
    /// Round to nearest; ties move to the neighbour with an even last digit.
    ///
    /// This is the default mode.
    #[default]
    HalfEven,
}
impl Decimal {
    /// Rounds this value to `dp` decimal places using `mode`.
    ///
    /// When `dp` is greater than or equal to the current scale the value is
    /// returned unchanged; this method never increases the scale.
    ///
    /// # Errors
    ///
    /// * `ArithmeticError::ScaleExceeded` if `dp` is greater than `MAX_SCALE`.
    /// * Any error produced by `pow10` for the scale difference or by
    ///   `Decimal::new` for the rounded mantissa.
    pub fn round(self, dp: u8, mode: RoundingMode) -> Result<Self, ArithmeticError> {
        use crate::decimal::MAX_SCALE;

        if dp > MAX_SCALE {
            return Err(ArithmeticError::ScaleExceeded);
        }
        if dp >= self.scale {
            return Ok(self);
        }

        let diff = self.scale - dp;
        let factor = pow10(diff)?;
        // `diff >= 1`, so `factor` is a multiple of ten and `half` is exact.
        let half = factor / 2;

        // Integer division truncates toward zero and the remainder carries the
        // sign of the mantissa, so `quotient` is already the `TowardZero` result.
        let quotient = self.mantissa / factor;
        let remainder = self.mantissa % factor;

        // `|remainder| < factor`, so `abs` cannot overflow.
        let abs_rem = remainder.abs();

        // Neighbour one unit further from zero. It equals `quotient` when
        // nothing is discarded, and cannot overflow because
        // `|quotient| <= |mantissa| / 10`.
        let away = quotient + remainder.signum();

        let adjusted = match mode {
            RoundingMode::TowardZero => quotient,
            RoundingMode::AwayFromZero => away,
            RoundingMode::Down => {
                if remainder < 0 {
                    quotient - 1
                } else {
                    quotient
                }
            }
            RoundingMode::Up => {
                if remainder > 0 {
                    quotient + 1
                } else {
                    quotient
                }
            }
            RoundingMode::HalfUp => {
                if abs_rem >= half {
                    away
                } else {
                    quotient
                }
            }
            RoundingMode::HalfDown => {
                if abs_rem > half {
                    away
                } else {
                    quotient
                }
            }
            RoundingMode::HalfEven => {
                // On an exact tie keep the quotient if it is already even.
                if abs_rem > half || (abs_rem == half && quotient % 2 != 0) {
                    away
                } else {
                    quotient
                }
            }
        };

        Decimal::new(adjusted, dp)
    }

    /// Re-expresses this value at the larger scale `new_scale` by multiplying
    /// the mantissa by `10^(new_scale - scale)`.
    ///
    /// When `new_scale` is less than or equal to the current scale the value
    /// is returned unchanged; this method never reduces the scale.
    ///
    /// # Errors
    ///
    /// * `ArithmeticError::Overflow` if the scaled mantissa overflows.
    /// * Any error produced by `pow10` for the scale difference or by
    ///   `Decimal::new` for the new mantissa and scale.
    pub fn rescale_up(self, new_scale: u8) -> Result<Self, ArithmeticError> {
        if new_scale <= self.scale {
            return Ok(self);
        }

        let diff = new_scale - self.scale;
        let factor = pow10(diff)?;
        let mantissa = self
            .mantissa
            .checked_mul(factor)
            .ok_or(ArithmeticError::Overflow)?;

        Decimal::new(mantissa, new_scale)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decimal::Decimal;

    /// Builds a `Decimal`, panicking if the constructor rejects the inputs.
    fn d(mantissa: i128, scale: u8) -> Decimal {
        Decimal::new(mantissa, scale).unwrap()
    }

    // --- TowardZero -------------------------------------------------------

    #[test]
    fn round_toward_zero_positive() {
        assert_eq!(d(19, 1).round(0, RoundingMode::TowardZero).unwrap(), d(1, 0));
    }

    #[test]
    fn round_toward_zero_negative() {
        assert_eq!(
            d(-19, 1).round(0, RoundingMode::TowardZero).unwrap(),
            d(-1, 0)
        );
    }

    // --- AwayFromZero -----------------------------------------------------

    #[test]
    fn round_away_from_zero_positive() {
        assert_eq!(
            d(11, 1).round(0, RoundingMode::AwayFromZero).unwrap(),
            d(2, 0)
        );
    }

    #[test]
    fn round_away_from_zero_negative() {
        assert_eq!(
            d(-11, 1).round(0, RoundingMode::AwayFromZero).unwrap(),
            d(-2, 0)
        );
    }

    #[test]
    fn round_away_from_zero_exact() {
        assert_eq!(
            d(10, 1).round(0, RoundingMode::AwayFromZero).unwrap(),
            d(1, 0)
        );
    }

    // The truncated quotient is zero here, so the direction must come from
    // the sign of the discarded part alone.
    #[test]
    fn round_away_from_zero_small_positive() {
        assert_eq!(
            d(1, 1).round(0, RoundingMode::AwayFromZero).unwrap(),
            d(1, 0)
        );
    }

    #[test]
    fn round_away_from_zero_small_negative() {
        assert_eq!(
            d(-1, 1).round(0, RoundingMode::AwayFromZero).unwrap(),
            d(-1, 0)
        );
    }

    // --- Down / Up --------------------------------------------------------

    #[test]
    fn round_down_positive() {
        assert_eq!(d(19, 1).round(0, RoundingMode::Down).unwrap(), d(1, 0));
    }

    #[test]
    fn round_down_negative() {
        assert_eq!(d(-19, 1).round(0, RoundingMode::Down).unwrap(), d(-2, 0));
    }

    #[test]
    fn round_up_positive() {
        assert_eq!(d(11, 1).round(0, RoundingMode::Up).unwrap(), d(2, 0));
    }

    #[test]
    fn round_up_negative() {
        assert_eq!(d(-11, 1).round(0, RoundingMode::Up).unwrap(), d(-1, 0));
    }

    // --- HalfUp -----------------------------------------------------------

    #[test]
    fn round_half_up_at_midpoint() {
        assert_eq!(d(5, 1).round(0, RoundingMode::HalfUp).unwrap(), d(1, 0));
    }

    #[test]
    fn round_half_up_below_midpoint() {
        assert_eq!(d(4, 1).round(0, RoundingMode::HalfUp).unwrap(), d(0, 0));
    }

    #[test]
    fn round_half_up_negative_midpoint() {
        assert_eq!(d(-5, 1).round(0, RoundingMode::HalfUp).unwrap(), d(-1, 0));
    }

    // --- HalfDown ---------------------------------------------------------

    #[test]
    fn round_half_down_at_midpoint() {
        assert_eq!(d(5, 1).round(0, RoundingMode::HalfDown).unwrap(), d(0, 0));
    }

    #[test]
    fn round_half_down_above_midpoint() {
        assert_eq!(d(6, 1).round(0, RoundingMode::HalfDown).unwrap(), d(1, 0));
    }

    #[test]
    fn round_half_down_negative_midpoint() {
        assert_eq!(d(-5, 1).round(0, RoundingMode::HalfDown).unwrap(), d(0, 0));
    }

    // --- HalfEven ---------------------------------------------------------

    #[test]
    fn round_half_even_round_to_even_up() {
        assert_eq!(
            d(15, 1).round(0, RoundingMode::HalfEven).unwrap(),
            d(2, 0)
        );
    }

    #[test]
    fn round_half_even_round_to_even_down() {
        assert_eq!(
            d(25, 1).round(0, RoundingMode::HalfEven).unwrap(),
            d(2, 0)
        );
    }

    #[test]
    fn round_half_even_past_midpoint() {
        assert_eq!(
            d(16, 1).round(0, RoundingMode::HalfEven).unwrap(),
            d(2, 0)
        );
    }

    #[test]
    fn round_half_even_negative_midpoints() {
        assert_eq!(
            d(-15, 1).round(0, RoundingMode::HalfEven).unwrap(),
            d(-2, 0)
        );
        assert_eq!(
            d(-25, 1).round(0, RoundingMode::HalfEven).unwrap(),
            d(-2, 0)
        );
    }

    // --- Scale handling ---------------------------------------------------

    #[test]
    fn round_no_op_when_dp_equals_scale() {
        let x = d(12345, 3);
        assert_eq!(x.round(3, RoundingMode::HalfEven).unwrap(), x);
    }

    #[test]
    fn round_no_op_when_dp_exceeds_scale() {
        let x = d(12345, 3);
        assert_eq!(x.round(5, RoundingMode::HalfEven).unwrap(), x);
    }

    #[test]
    fn round_to_nonzero_dp() {
        assert_eq!(
            d(12345, 3).round(1, RoundingMode::HalfEven).unwrap(),
            d(123, 1)
        );
        assert_eq!(
            d(12355, 3).round(1, RoundingMode::HalfUp).unwrap(),
            d(124, 1)
        );
    }

    #[test]
    fn round_exact_value_is_same_in_every_mode() {
        let modes = [
            RoundingMode::Down,
            RoundingMode::Up,
            RoundingMode::TowardZero,
            RoundingMode::AwayFromZero,
            RoundingMode::HalfUp,
            RoundingMode::HalfDown,
            RoundingMode::HalfEven,
        ];
        for mode in modes {
            assert_eq!(d(20, 1).round(0, mode).unwrap(), d(2, 0));
            assert_eq!(d(-20, 1).round(0, mode).unwrap(), d(-2, 0));
            assert_eq!(d(0, 2).round(0, mode).unwrap(), d(0, 0));
        }
    }

    // --- rescale_up -------------------------------------------------------

    #[test]
    fn rescale_up_basic() {
        let x = d(1, 0);
        let y = x.rescale_up(6).unwrap();
        assert_eq!(y.mantissa(), 1_000_000);
        assert_eq!(y.scale(), 6);
    }

    #[test]
    fn rescale_up_noop() {
        let x = d(1_000_000, 6);
        assert_eq!(x.rescale_up(3).unwrap(), x);
    }
}