use super::{frexp, log10u};
use std::collections::VecDeque;
/// Number of decimal digits packed into each stored digit of a `Decimal`.
pub const DIGIT_WIDTH: usize = 9;
/// Base of each stored digit; derived from `DIGIT_WIDTH` so the two cannot drift apart.
pub const DIGIT_BASE: u32 = 10u32.pow(DIGIT_WIDTH as u32);
/// Maximum number of bits shifted per step in `Decimal::shift_left`, and the number of
/// mantissa bits moved above the binary point in `Decimal::new`.
pub const BITS_PER_DIGIT: usize = 29;

// Compile-time checks of the invariants the digit arithmetic relies on.
// A value below 1 scaled by 2^BITS_PER_DIGIT must still fit in a single digit.
const _: () = assert!((1u64 << BITS_PER_DIGIT) <= DIGIT_BASE as u64);
// `digit << BITS_PER_DIGIT` plus a carry (digit < DIGIT_BASE) must fit in a u64.
const _: () = assert!((DIGIT_BASE as u64) <= (u64::MAX >> BITS_PER_DIGIT));
// Shifting right by up to DIGIT_WIDTH bits needs `DIGIT_BASE >> sh` to be exact.
const _: () = assert!(DIGIT_BASE % (1u32 << DIGIT_WIDTH) == 0);
/// Floor division with modulus.
///
/// Returns `(q, r)` with `q = floor(n / d)` and `r = n - q * d`. This means `r` is zero
/// or has the same sign as `d`. For positive `d` the result equals Euclidean division.
///
/// # Panics
/// Panics if `d == 0`, or on overflow (`n == i32::MIN && d == -1`).
#[inline]
pub fn divmod_floor(n: i32, d: i32) -> (i32, i32) {
    let (q, r) = (n / d, n % d);
    // `/` truncates toward zero. When the remainder is nonzero and its sign differs
    // from the divisor's, the truncated quotient is one above the floor.
    if r != 0 && (r < 0) != (d < 0) {
        (q - 1, r + d)
    } else {
        (q, r)
    }
}
/// Bound on how many base-`DIGIT_BASE` digits `Decimal::shift_right` keeps while
/// dividing. Digits beyond the bound are truncated, not rounded.
#[derive(Clone, Copy, Debug)]
pub enum DigitLimit {
    /// Keep at most this many digits in total, counted from the most significant one.
    Total(usize),
    /// Keep at most this many digits after the radix point.
    Fractional(usize),
}
/// A magnitude stored as base-`DIGIT_BASE` digits, plus a separate sign flag.
///
/// The digit at index `i` has weight `DIGIT_BASE^(radix - i)`. Positions outside the
/// stored range are implicitly zero.
#[derive(Debug)]
pub struct Decimal {
    /// Base-`DIGIT_BASE` digits, most significant first; each is below `DIGIT_BASE`.
    pub digits: VecDeque<u32>,
    /// Index into `digits` of the units digit (weight `DIGIT_BASE^0`). It may be
    /// negative or past the end when that digit is not stored.
    pub radix: i32,
    /// Whether the represented number carries a negative sign.
    pub negative: bool,
}
impl Decimal {
    /// Builds the base-`DIGIT_BASE` expansion of the finite float `y`.
    ///
    /// The sign of `y` is recorded in `negative`. Because this uses `is_sign_negative`,
    /// `-0.0` counts as negative. The magnitude is split by `frexp`, and the mantissa is
    /// scaled by `2^BITS_PER_DIGIT` so its integer part forms the leading digit. The
    /// remaining binary exponent is then applied with [`Decimal::shift_left`] or
    /// [`Decimal::shift_right`].
    ///
    /// `limit` is only used when that exponent is negative. It is passed to
    /// `shift_right`, which truncates (does not round) digits beyond the limit.
    ///
    /// NOTE(review): this assumes `frexp` returns a mantissa in `[0.5, 1)` (or `0.0`),
    /// like C's `frexp`. The `debug_assert!` in the digit loop depends on it.
    pub fn new(y: f64, limit: DigitLimit) -> Self {
        debug_assert!(y.is_finite());
        let negative = y.is_sign_negative();
        let (mut y, mut e2) = frexp(y.abs());
        if y != 0.0 {
            // Move BITS_PER_DIGIT mantissa bits above the binary point, so the first
            // digit produced below is the integer part.
            y *= (1 << BITS_PER_DIGIT) as f64;
            e2 -= BITS_PER_DIGIT as i32;
        }
        // Peel off one base-DIGIT_BASE digit per iteration. The integer part is the
        // digit, and the remaining fraction is scaled up to produce the next one.
        let mut digits = Vec::new();
        while y != 0.0 {
            debug_assert!(y >= 0.0 && y < DIGIT_BASE as f64);
            let digit = y as u32;
            digits.push(digit);
            y = (DIGIT_BASE as f64) * (y - digit as f64);
        }
        // The first digit is the integer part, i.e. the units digit, hence radix 0.
        let mut decimal = Decimal {
            digits: digits.into(),
            radix: 0,
            negative,
        };
        if e2 >= 0 {
            decimal.shift_left(e2 as usize);
        } else {
            decimal.shift_right(-e2 as usize, limit);
        }
        decimal
    }

    /// Prepends `digit` as the new most significant digit.
    ///
    /// `radix` is incremented so every existing digit keeps its weight.
    pub fn push_front(&mut self, digit: u32) {
        self.digits.push_front(digit);
        self.radix += 1;
    }

    /// Appends `digit` as the new least significant digit; `radix` is unchanged.
    pub fn push_back(&mut self, digit: u32) {
        self.digits.push_back(digit);
    }

    /// Returns the least significant stored digit, or `None` if no digits are stored.
    pub fn last(&self) -> Option<u32> {
        self.digits.back().copied()
    }

    /// Returns the most significant stored digit, or `None` if no digits are stored.
    pub fn first(&self) -> Option<u32> {
        self.digits.front().copied()
    }

    /// Multiplies the value by `2^amt`.
    ///
    /// The shift is done in steps of at most `BITS_PER_DIGIT` bits, so `digit << sh`
    /// plus the carry fits in a `u64`. A final nonzero carry becomes a new leading
    /// digit. Trailing zero digits are dropped after every step.
    pub fn shift_left(&mut self, mut amt: usize) {
        while amt > 0 {
            let sh = amt.min(BITS_PER_DIGIT);
            let mut carry: u32 = 0;
            // Walk from least to most significant so the carry moves upward.
            for digit in self.digits.iter_mut().rev() {
                let nd = ((*digit as u64) << sh) + carry as u64;
                *digit = (nd % DIGIT_BASE as u64) as u32;
                carry = (nd / DIGIT_BASE as u64) as u32;
            }
            if carry != 0 {
                self.push_front(carry);
            }
            self.trim_trailing_zeros();
            amt -= sh;
        }
    }

    /// Divides the value by `2^amt`, truncating the expansion according to `limit`.
    ///
    /// The shift is done in steps of at most `DIGIT_WIDTH` bits. `DIGIT_BASE` (10^9) is
    /// divisible by `2^DIGIT_WIDTH`, so the low `sh` bits shifted out of a digit are
    /// worth exactly `(DIGIT_BASE >> sh) * remainder` in the next, less significant
    /// digit. After each step:
    ///
    /// * leading zero digits are dropped, decrementing `radix`;
    /// * a nonzero final carry is appended as a new least significant digit;
    /// * the expansion is cut down according to `limit`:
    ///   * `DigitLimit::Total(n)` keeps at most `n` digits in total;
    ///   * `DigitLimit::Fractional(n)` keeps at most `n` digit positions after the
    ///     radix point.
    ///
    /// Digits removed by the limit are discarded without rounding.
    pub fn shift_right(&mut self, mut amt: usize, limit: DigitLimit) {
        while amt > 0 {
            let sh = amt.min(DIGIT_WIDTH);
            // Process the deque's two contiguous halves as plain slices, carrying from
            // the front (more significant) half into the back half.
            let (front, back) = self.digits.as_mut_slices();
            let carry = Self::shift_slice_right(front, sh, 0);
            let carry = Self::shift_slice_right(back, sh, carry);
            self.trim_leading_zeros();
            if carry != 0 {
                self.push_back(carry);
            }
            amt -= sh;
            match limit {
                DigitLimit::Total(n) => {
                    self.digits.truncate(n);
                }
                DigitLimit::Fractional(n) => {
                    // Digit positions after the units digit (index `radix`) up to the
                    // last stored digit. This includes implicit zeros when `radix` < 0.
                    let current = (self.len_i32() - self.radix - 1).max(0) as usize;
                    let to_trunc = current.saturating_sub(n);
                    self.digits.truncate(self.digits.len().saturating_sub(to_trunc));
                }
            }
        }
    }

    /// Shifts each digit of `digits` (most significant first) right by `sh` bits.
    ///
    /// The incoming `carry` is added to the first digit. The return value is the carry
    /// owed to the digit that would follow the slice. Requires `sh <= DIGIT_WIDTH` so
    /// that `DIGIT_BASE >> sh` is exact.
    fn shift_slice_right(digits: &mut [u32], sh: usize, mut carry: u32) -> u32 {
        let mask = (1u32 << sh) - 1;
        let scale = DIGIT_BASE >> sh;
        for digit in digits.iter_mut() {
            let remainder = *digit & mask; // digit % 2^sh
            *digit = (*digit >> sh) + carry;
            carry = scale * remainder;
        }
        carry
    }

    /// Returns the number of stored digits as an `i32`, for signed index arithmetic.
    pub fn len_i32(&self) -> i32 {
        self.digits.len() as i32
    }

    /// Returns `DIGIT_WIDTH * radix + log10u(first digit)`, or 0 if no digits are stored.
    ///
    /// NOTE(review): this assumes `log10u(x)` is `floor(log10(x))` for `x > 0`. That
    /// makes the result the power of ten of the leading decimal digit; confirm against
    /// `log10u`.
    pub fn exponent(&self) -> i32 {
        let Some(first_digit) = self.first() else {
            return 0;
        };
        self.radix * (DIGIT_WIDTH as i32) + log10u(first_digit)
    }

    /// Returns the number of decimal positions stored after the radix point, counted in
    /// whole base digits (`DIGIT_WIDTH` per digit).
    ///
    /// The result is zero when the last stored digit is the units digit. It is negative
    /// when the last stored digit is more significant than the units digit, e.g.
    /// `-DIGIT_WIDTH` for an empty expansion with `radix == 0`.
    pub fn fractional_digit_count(&self) -> i32 {
        (DIGIT_WIDTH as i32) * (self.len_i32() - self.radix - 1)
    }

    /// Drops leading zero digits, decrementing `radix` for each one so the remaining
    /// digits keep their weights.
    fn trim_leading_zeros(&mut self) {
        while self.digits.front() == Some(&0) {
            self.digits.pop_front();
            self.radix -= 1;
        }
    }

    /// Drops trailing zero digits.
    ///
    /// Positions past the last stored digit are implicitly zero (and `radix` is left
    /// unchanged), so this does not change the value.
    fn trim_trailing_zeros(&mut self) {
        while self.digits.back() == Some(&0) {
            self.digits.pop_back();
        }
    }

    /// Rounds the value to `desired_frac_digits` decimal digits after the decimal point.
    ///
    /// - A negative `desired_frac_digits` rounds to a multiple of
    ///   `10^(-desired_frac_digits)`.
    /// - Whether to round up is decided by [`Decimal::should_round_up`].
    /// - Does nothing when `desired_frac_digits >= self.fractional_digit_count()`.
    /// - Digits past the rounding position are removed, as are trailing zero digits.
    /// - A carry out of the leading digit adds a new leading digit.
    pub fn round_to_fractional_digits(&mut self, desired_frac_digits: i32) {
        let frac_digit_count = self.fractional_digit_count();
        if desired_frac_digits >= frac_digit_count {
            return;
        }
        // Split the decimal position into a base-digit index (relative to the first
        // fractional digit at `radix + 1`) and a decimal offset within that digit.
        let (quot, rem) = divmod_floor(desired_frac_digits, DIGIT_WIDTH as i32);
        let mut last_digit_idx = self.radix + 1 + quot;
        // The rounding position may precede the first stored digit. Materialize the
        // implicit leading zeros so it can be indexed.
        while last_digit_idx < 0 {
            self.push_front(0);
            last_digit_idx += 1;
        }
        debug_assert!(DIGIT_WIDTH as i32 > rem);
        // Everything in this digit below `mod_base` (10^(decimal places to drop)) is
        // rounded away.
        let mod_base = 10u32.pow((DIGIT_WIDTH as i32 - rem) as u32);
        debug_assert!(mod_base <= DIGIT_BASE);
        let remainder_to_round = self[last_digit_idx] % mod_base;
        self[last_digit_idx] -= remainder_to_round;
        if self.should_round_up(last_digit_idx, remainder_to_round, mod_base) {
            self[last_digit_idx] += mod_base;
            // Propagate the carry toward more significant digits, adding a new leading
            // digit if it runs off the front.
            while self[last_digit_idx] >= DIGIT_BASE {
                self[last_digit_idx] = 0;
                last_digit_idx -= 1;
                if last_digit_idx < 0 {
                    self.push_front(0);
                    last_digit_idx = 0;
                }
                self[last_digit_idx] += 1;
            }
        }
        self.digits.truncate(last_digit_idx as usize + 1);
        self.trim_trailing_zeros();
    }

    /// Decides whether the digit at `digit_idx` must be incremented by `mod_base`, after
    /// `remainder` (the discarded part of that digit, below `mod_base`) was removed.
    ///
    /// Returns `false` immediately when `remainder` is zero and no digits follow.
    /// Otherwise the decision is delegated to floating-point addition:
    ///
    /// - `round` is `2^53`, a magnitude where adjacent `f64` values are 2 apart. It gets
    ///   `+2` when the last kept decimal digit is odd.
    /// - `small` is 0.5 when the discarded amount is below half, 1.0 when it is exactly
    ///   half with no digits after `digit_idx`, and 1.5 otherwise.
    /// - `round + small != round` is the answer. Under the default round-to-nearest-even
    ///   mode this gives round-half-to-even.
    ///
    /// Both values are negated for negative numbers. Under round-to-nearest this does
    /// not change the outcome.
    #[inline]
    fn should_round_up(&self, digit_idx: i32, remainder: u32, mod_base: u32) -> bool {
        if remainder == 0 && digit_idx + 1 == self.len_i32() {
            return false;
        }
        let mut round = 2.0_f64.powi(f64::MANTISSA_DIGITS as i32);
        // The low bit gives the parity of the last kept decimal digit. It comes from the
        // quotient when the kept part ends inside this digit, otherwise from the
        // preceding digit (0 if there is none).
        let rounding_digit = if mod_base < DIGIT_BASE {
            self[digit_idx] / mod_base
        } else if digit_idx > 0 {
            self[digit_idx - 1]
        } else {
            0
        };
        if rounding_digit & 1 != 0 {
            round += 2.0;
            // 2^53 + 2 has an odd significand.
            debug_assert!(round.to_bits() & 1 != 0);
        }
        let mut small = if remainder < mod_base / 2 {
            0.5
        } else if remainder == mod_base / 2 && digit_idx + 1 == self.len_i32() {
            1.0
        } else {
            1.5
        };
        if self.negative {
            round = -round;
            small = -small;
        }
        round + small != round
    }
}
impl std::ops::Index<i32> for Decimal {
    type Output = u32;

    /// Returns the digit at position `index`, counting from the most significant digit
    /// (position 0).
    ///
    /// The index is `i32` to match the signed index arithmetic (relative to `radix`)
    /// used by the rounding code.
    ///
    /// # Panics
    /// Panics if `index` is negative or not less than the number of digits.
    fn index(&self, index: i32) -> &Self::Output {
        assert!(index >= 0);
        let position = index as usize;
        &self.digits[position]
    }
}
impl std::ops::IndexMut<i32> for Decimal {
    /// Returns a mutable reference to the digit at position `index`, counting from the
    /// most significant digit (position 0).
    ///
    /// # Panics
    /// Panics if `index` is negative or not less than the number of digits.
    fn index_mut(&mut self, index: i32) -> &mut Self::Output {
        assert!(index >= 0);
        let position = index as usize;
        &mut self.digits[position]
    }
}