use smallvec::SmallVec;
use std::cmp;
use std::cmp::Ordering;
use std::fmt;
use std::ops::RangeInclusive;
use std::str::FromStr;
use static_assertions::const_assert;
use super::uint_iterator::IntIterator;
use crate::Error;
// This file casts i16/u16-derived values to usize (e.g. `i16::MAX as usize`,
// `(i16::MIN as u16) as usize`); guarantee at compile time that those casts are lossless.
const_assert!(std::mem::size_of::<usize>() >= std::mem::size_of::<u16>());
/// A decimal number that carries an explicit range of displayed magnitudes, so
/// that leading and trailing zeros survive formatting (e.g. `"0012.3400"`).
///
/// Only the significant digits are stored; every other position inside
/// `lower_magnitude..=upper_magnitude` reads as `0`.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedDecimal {
/// Significant digits, most significant first. When non-empty, the first and
/// last entries are non-zero; an empty vector means the value is zero.
digits: SmallVec<[u8; 8]>,
/// Power of ten of `digits[0]`.
magnitude: i16,
/// Highest displayed power of ten; kept `>= 0` and `>= magnitude`.
upper_magnitude: i16,
/// Lowest displayed power of ten; kept `<= 0` and `<= magnitude`.
lower_magnitude: i16,
/// Whether a leading `-` is written. May be `true` for zero, which formats as `-0`.
pub is_negative: bool,
}
impl Default for FixedDecimal {
fn default() -> Self {
Self {
digits: SmallVec::new(),
magnitude: 0,
upper_magnitude: 0,
lower_magnitude: 0,
is_negative: false,
}
}
}
// Implements `From<$itype>` for `FixedDecimal` for a signed integer type.
// `IntIterator<$utype>` (the unsigned counterpart) supplies both the digits and the sign.
macro_rules! impl_from_signed_integer_type {
    ($itype:ident, $utype:ident) => {
        impl From<$itype> for FixedDecimal {
            fn from(value: $itype) -> Self {
                let digits: IntIterator<$utype> = value.into();
                // Capture the sign before `from_ascending` consumes the iterator.
                let is_negative = digits.is_negative;
                Self {
                    is_negative,
                    ..Self::from_ascending(digits)
                        .expect("All built-in integer types should fit")
                }
            }
        }
    };
}
// Implements `From<$utype>` for `FixedDecimal` for an unsigned integer type.
// The result is never negative: `from_ascending` starts from `Default`.
macro_rules! impl_from_unsigned_integer_type {
    ($utype:ident) => {
        impl From<$utype> for FixedDecimal {
            fn from(value: $utype) -> Self {
                let digits: IntIterator<$utype> = value.into();
                let decimal = Self::from_ascending(digits);
                decimal.expect("All built-in integer types should fit")
            }
        }
    };
}
// Signed integers, each paired with the unsigned type its digit iterator uses.
impl_from_signed_integer_type!(i8, u8);
impl_from_signed_integer_type!(i16, u16);
impl_from_signed_integer_type!(i32, u32);
impl_from_signed_integer_type!(i64, u64);
impl_from_signed_integer_type!(i128, u128);
impl_from_signed_integer_type!(isize, usize);
// Unsigned integers.
impl_from_unsigned_integer_type!(u8);
impl_from_unsigned_integer_type!(u16);
impl_from_unsigned_integer_type!(u32);
impl_from_unsigned_integer_type!(u64);
impl_from_unsigned_integer_type!(u128);
impl_from_unsigned_integer_type!(usize);
impl FixedDecimal {
    /// Builds a non-negative value from decimal digits yielded least significant
    /// first (item `x` of the iterator has magnitude `x`).
    ///
    /// Zeros yielded before the first non-zero digit are not stored. They only
    /// raise `magnitude` and `upper_magnitude`; `lower_magnitude` stays 0.
    ///
    /// The iterator is assumed not to end with zeros, i.e. the number has no
    /// leading zeros. Zeros after the first non-zero digit are stored verbatim,
    /// so a leading zero would end up in `digits[0]` and trip the debug
    /// invariant check.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Limit`] in either of these cases:
    /// - the iterator yields more than `i16::MAX + 1` items;
    /// - more than 39 items follow (and include) the first non-zero digit.
    fn from_ascending<T>(digits_iter: T) -> Result<Self, Error>
    where
        T: Iterator<Item = u8>,
    {
        // Scratch capacity: 39 digits is enough for u128::MAX.
        const X: usize = 39;
        // Digits are written from the end of the buffer backwards, so the
        // filled tail is already in most-significant-first order.
        let mut mem: [u8; X] = [0u8; X];
        let mut trailing_zeros: usize = 0;
        let mut i: usize = 0;
        for (x, d) in digits_iter.enumerate() {
            // `x` is the magnitude of `d`; it must fit in an i16.
            if x > i16::MAX as usize {
                return Err(Error::Limit);
            }
            if i != 0 || d != 0 {
                i += 1;
                match X.checked_sub(i) {
                    Some(v) => mem[v] = d,
                    None => return Err(Error::Limit),
                }
            } else {
                trailing_zeros += 1;
            }
        }
        let mut result: Self = Default::default();
        if i != 0 {
            // Magnitude of the most significant stored digit.
            let magnitude = trailing_zeros + i - 1;
            debug_assert!(magnitude <= i16::MAX as usize);
            result.magnitude = magnitude as i16;
            result.upper_magnitude = result.magnitude;
            debug_assert!(i <= X);
            result.digits.extend_from_slice(&mem[(X - i)..]);
        }
        #[cfg(debug_assertions)]
        result.check_invariants();
        Ok(result)
    }

    /// Returns the digit at power of ten `magnitude`.
    ///
    /// Returns 0 for any position without a stored significant digit,
    /// including positions outside [`magnitude_range`](Self::magnitude_range).
    pub fn digit_at(&self, magnitude: i16) -> u8 {
        if magnitude > self.magnitude {
            return 0;
        }
        // `digits[0]` holds the digit at `self.magnitude`; widen to i32 so the
        // subtraction cannot overflow i16.
        let j = (self.magnitude as i32 - magnitude as i32) as usize;
        self.digits.get(j).copied().unwrap_or(0)
    }

    /// Returns the inclusive range of magnitudes that are displayed.
    ///
    /// The range always contains 0.
    pub const fn magnitude_range(&self) -> RangeInclusive<i16> {
        self.lower_magnitude..=self.upper_magnitude
    }

    /// Multiplies the value by `10^delta` in place.
    ///
    /// The displayed range moves with the digits but always keeps magnitude 0:
    /// - for `delta > 0`, the lower bound is clamped to at most 0;
    /// - for `delta < 0`, the upper bound is clamped to at least 0.
    ///
    /// For example, `0.00` times `10^3` becomes `0000`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Limit`] if the new `upper_magnitude` (for `delta > 0`)
    /// or `lower_magnitude` (for `delta < 0`) would overflow `i16`. In that
    /// case `self` is left unchanged.
    pub fn multiply_pow10(&mut self, delta: i16) -> Result<(), Error> {
        match delta.cmp(&0) {
            Ordering::Greater => {
                self.upper_magnitude = self
                    .upper_magnitude
                    .checked_add(delta)
                    .ok_or(Error::Limit)?;
                // Cannot overflow: lower_magnitude <= 0 and delta > 0.
                let lower_magnitude = self.lower_magnitude + delta;
                self.lower_magnitude = cmp::min(0, lower_magnitude);
            }
            Ordering::Less => {
                self.lower_magnitude = self
                    .lower_magnitude
                    .checked_add(delta)
                    .ok_or(Error::Limit)?;
                // Cannot overflow: upper_magnitude >= 0 and delta < 0.
                let upper_magnitude = self.upper_magnitude + delta;
                self.upper_magnitude = cmp::max(0, upper_magnitude);
            }
            Ordering::Equal => {}
        };
        // Cannot overflow: magnitude lies between the (already shifted) bounds.
        self.magnitude += delta;
        #[cfg(debug_assertions)]
        self.check_invariants();
        Ok(())
    }

    /// Consuming version of [`multiply_pow10`](Self::multiply_pow10).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Limit`] under the same conditions as `multiply_pow10`.
    pub fn multiplied_pow10(mut self, delta: i16) -> Result<Self, Error> {
        self.multiply_pow10(delta)?;
        Ok(self)
    }

    /// Writes the decimal representation to `sink`.
    ///
    /// The output is:
    /// - a `-` if `is_negative` (so a negative zero prints as `-0`);
    /// - one ASCII digit per magnitude from `upper_magnitude` down to
    ///   `lower_magnitude`;
    /// - a `.` between magnitudes 0 and -1.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `sink`.
    pub fn write_to(&self, sink: &mut dyn fmt::Write) -> fmt::Result {
        if self.is_negative {
            sink.write_char('-')?;
        }
        for m in self.magnitude_range().rev() {
            if m == -1 {
                sink.write_char('.')?;
            }
            let d = self.digit_at(m);
            sink.write_char((b'0' + d) as char)?;
        }
        Ok(())
    }

    /// Returns the number of bytes [`write_to`](Self::write_to) produces.
    pub const fn write_len(&self) -> usize {
        let num_digits = 1 + (self.upper_magnitude as i32 - self.lower_magnitude as i32) as usize;
        num_digits
            + (if self.is_negative { 1 } else { 0 })
            + (if self.lower_magnitude < 0 { 1 } else { 0 })
    }

    /// Debug-only check of the structural invariants documented on the fields.
    ///
    /// Panics with the value's `Debug` output if any invariant is violated.
    #[cfg(debug_assertions)]
    fn check_invariants(&self) {
        debug_assert!(self.upper_magnitude >= self.magnitude, "{:?}", self);
        debug_assert!(self.lower_magnitude <= self.magnitude, "{:?}", self);
        debug_assert!(self.upper_magnitude >= 0, "{:?}", self);
        debug_assert!(self.lower_magnitude <= 0, "{:?}", self);
        let max_len = (self.magnitude as i32 - self.lower_magnitude as i32 + 1) as usize;
        debug_assert!(self.digits.len() <= max_len, "{:?}", self);
        if !self.digits.is_empty() {
            // Stored digits must be trimmed of zeros on both ends.
            debug_assert_ne!(self.digits[0], 0, "{:?}", self);
            debug_assert_ne!(self.digits[self.digits.len() - 1], 0, "{:?}", self);
        }
    }
}
impl fmt::Display for FixedDecimal {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.write_to(f)
}
}
impl FromStr for FixedDecimal {
    type Err = Error;

    /// Parses a plain decimal string such as `"-0012.3400"`.
    ///
    /// Accepted syntax:
    /// - an optional leading `-`;
    /// - one or more ASCII digits;
    /// - optionally one `.`, with at least one digit on each side.
    ///
    /// Leading and trailing zeros are kept in the magnitude range, so the
    /// result formats back to the same string.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Syntax`] if the input:
    /// - is empty or just `"-"`;
    /// - contains anything other than digits and one interior `.`;
    /// - starts or ends its unsigned part with `.`.
    ///
    /// Returns [`Error::Limit`] if there are more than 32768 digits before the
    /// dot, or more than 32768 digits after it.
    fn from_str(input_str: &str) -> Result<Self, Self::Err> {
        if input_str.is_empty() || input_str == "-" {
            return Err(Error::Syntax);
        }
        let input_str = input_str.as_bytes();
        let is_negative = input_str[0] == b'-';
        let no_sign_str = if is_negative {
            &input_str[1..]
        } else {
            input_str
        };
        let no_sign_str_len = no_sign_str.len();

        // Validate characters and locate the (at most one) decimal point.
        // Without a dot, `dot_index` sits just past the last digit, which keeps
        // the magnitude arithmetic below uniform for both cases.
        let mut has_dot = false;
        let mut dot_index = no_sign_str_len;
        for (i, c) in no_sign_str.iter().enumerate() {
            if *c == b'.' {
                // A second dot, or a dot without a digit on one side, is invalid.
                if has_dot || i == 0 || i == no_sign_str_len - 1 {
                    return Err(Error::Syntax);
                }
                dot_index = i;
                has_dot = true;
            } else if !c.is_ascii_digit() {
                return Err(Error::Syntax);
            }
        }

        let mut dec = Self {
            is_negative,
            ..Self::default()
        };
        let no_dot_str_len = if has_dot {
            no_sign_str_len - 1
        } else {
            no_sign_str_len
        };

        // Digits before the dot occupy magnitudes `dot_index - 1 ..= 0`.
        // `dot_index >= 1` here: empty input is rejected and a leading dot is a syntax error.
        let temp_upper_magnitude = dot_index - 1;
        if temp_upper_magnitude > i16::MAX as usize {
            return Err(Error::Limit);
        }
        dec.upper_magnitude = temp_upper_magnitude as i16;

        // Each digit after the dot lowers the bottom magnitude by one. Up to
        // 32768 fractional digits fit: `32768usize as i16` is `i16::MIN`, and
        // `wrapping_neg` maps `i16::MIN` to itself.
        let temp_lower_magnitude = no_dot_str_len - dot_index;
        if temp_lower_magnitude > (i16::MIN as u16) as usize {
            return Err(Error::Limit);
        }
        dec.lower_magnitude = (temp_lower_magnitude as i16).wrapping_neg();

        // First non-zero digit (the dot is skipped). If there is none, the
        // value is zero: no stored digits, magnitude 0.
        let leftmost_digit = match no_sign_str
            .iter()
            .position(|c| *c != b'.' && *c != b'0')
        {
            Some(i) => i,
            None => {
                #[cfg(debug_assertions)]
                dec.check_invariants();
                return Ok(dec);
            }
        };

        // Magnitude of the first significant digit. When it lies after the dot,
        // the dot occupies one index and must not be counted.
        let mut temp_magnitude = ((dot_index as i32) - (leftmost_digit as i32) - 1i32) as i16;
        if dot_index < leftmost_digit {
            temp_magnitude += 1;
        }
        dec.magnitude = temp_magnitude;

        // One past the last non-zero digit. It always exists, because
        // `leftmost_digit` was found above.
        let rightmost_digit = no_sign_str
            .iter()
            .rposition(|c| *c != b'.' && *c != b'0')
            .map_or(no_sign_str_len, |i| i + 1);

        // Collect the significant digits, dropping the dot if it falls between them.
        let mut digits_str_len = rightmost_digit - leftmost_digit;
        if leftmost_digit < dot_index && dot_index < rightmost_digit {
            digits_str_len -= 1;
        }
        let mut v: SmallVec<[u8; 8]> = SmallVec::with_capacity(digits_str_len);
        v.extend(
            no_sign_str[leftmost_digit..rightmost_digit]
                .iter()
                .filter(|c| **c != b'.')
                .map(|c| c - b'0'),
        );
        debug_assert_eq!(v.len(), digits_str_len);
        dec.digits = v;

        #[cfg(debug_assertions)]
        dec.check_invariants();
        Ok(dec)
    }
}
/// Integer conversion followed by `multiply_pow10`: checks both the formatted
/// output and that `write_len` matches its length.
#[test]
fn test_basic() {
    #[derive(Debug)]
    struct TestCase {
        pub input: isize,
        pub delta: i16,
        pub expected: &'static str,
    }
    let case = |input: isize, delta: i16, expected: &'static str| TestCase {
        input,
        delta,
        expected,
    };
    let cases = [
        case(51423, 0, "51423"),
        case(51423, -2, "514.23"),
        case(51423, -5, "0.51423"),
        case(51423, -8, "0.00051423"),
        case(51423, 3, "51423000"),
        case(0, 0, "0"),
        case(0, -2, "0.00"),
        case(0, 3, "0000"),
        case(500, 0, "500"),
        case(500, -1, "50.0"),
        case(500, -2, "5.00"),
        case(500, -3, "0.500"),
        case(500, -4, "0.0500"),
        case(500, 3, "500000"),
        case(-123, 0, "-123"),
        case(-123, -2, "-1.23"),
        case(-123, -5, "-0.00123"),
        case(-123, 3, "-123000"),
    ];
    for cas in cases.iter() {
        let mut dec = FixedDecimal::from(cas.input);
        dec.multiply_pow10(cas.delta).unwrap();
        let string = dec.to_string();
        assert_eq!(cas.expected, string, "{:?}", cas);
        assert_eq!(string.len(), dec.write_len(), "{:?}", cas);
    }
}
/// Every input must round-trip unchanged through `from_str` and `Display`,
/// including leading/trailing zeros and the sign of zero.
#[test]
fn test_from_str() {
    let cases: &[&str] = &[
        "-00123400",
        "0.0123400",
        "-00.123400",
        "0012.3400",
        "-0012340.0",
        "1234",
        "0.000000001",
        "0.0000000010",
        "1000000",
        "10000001",
        "123",
        "922337203685477580898230948203840239384.9823094820384023938423424",
        "009223372000.003685477580898230948203840239384000",
        // Negative variant (previously an exact duplicate of the line above).
        "-009223372000.003685477580898230948203840239384000",
        "0",
        "-0",
        "000",
        "-00.0",
    ];
    for input_str in cases {
        let input_str_roundtrip = FixedDecimal::from_str(input_str).unwrap().to_string();
        assert_eq!(*input_str, input_str_roundtrip);
    }
}
/// `isize` extremes convert, format like the integer, re-parse to an equal
/// value, and report the correct `write_len`.
#[test]
fn test_isize_limits() {
    for &num in [std::isize::MAX, std::isize::MIN].iter() {
        let dec = FixedDecimal::from(num);
        let dec_str = dec.to_string();
        assert_eq!(num.to_string(), dec_str);
        let reparsed = FixedDecimal::from_str(&dec_str).unwrap();
        assert_eq!(dec, reparsed);
        assert_eq!(dec.write_len(), dec_str.len());
    }
}
/// Same checks as `test_isize_limits`, for the extremes of `i128` and `u128`.
#[test]
fn test_ui128_limits() {
    // Shared round-trip check, generic over the integer type being converted.
    fn check<T>(num: T)
    where
        T: Copy + ToString,
        FixedDecimal: From<T>,
    {
        let dec = FixedDecimal::from(num);
        let dec_str = dec.to_string();
        assert_eq!(num.to_string(), dec_str);
        assert_eq!(dec, FixedDecimal::from_str(&dec_str).unwrap());
        assert_eq!(dec.write_len(), dec_str.len());
    }
    for &num in [std::i128::MAX, std::i128::MIN].iter() {
        check(num);
    }
    for &num in [std::u128::MAX, std::u128::MIN].iter() {
        check(num);
    }
}
/// Shifting up to exactly `i16::MAX` succeeds; one more step fails with
/// `Error::Limit` and leaves the value untouched.
#[test]
fn test_upper_magnitude_bounds() {
    let mut dec = FixedDecimal::from(98765_i32);
    assert_eq!(dec.upper_magnitude, 4);
    // 4 + 32763 == i16::MAX
    dec.multiply_pow10(32763).unwrap();
    assert_eq!(dec.upper_magnitude, std::i16::MAX);
    let dec_backup = dec.clone();
    let err = dec.multiply_pow10(1).unwrap_err();
    assert_eq!(Error::Limit, err);
    assert_eq!(dec, dec_backup, "Value should be unchanged on failure");
    let formatted = dec.to_string();
    let dec_roundtrip = FixedDecimal::from_str(&formatted).unwrap();
    assert_eq!(dec, dec_roundtrip);
}
/// Shifting down to exactly `i16::MIN` succeeds; one more step fails with
/// `Error::Limit` and leaves the value untouched.
#[test]
fn test_lower_magnitude_bounds() {
    let mut dec = FixedDecimal::from(98765_i32);
    assert_eq!(dec.lower_magnitude, 0);
    // 0 - 32768 == i16::MIN
    dec.multiply_pow10(-32768).unwrap();
    assert_eq!(dec.lower_magnitude, std::i16::MIN);
    let dec_backup = dec.clone();
    let err = dec.multiply_pow10(-1).unwrap_err();
    assert_eq!(Error::Limit, err);
    assert_eq!(dec, dec_backup, "Value should be unchanged on failure");
    let formatted = dec.to_string();
    let dec_roundtrip = FixedDecimal::from_str(&formatted).unwrap();
    assert_eq!(dec, dec_roundtrip);
}
/// Parsing all-zero strings at and just beyond the i16 magnitude limits.
///
/// Up to 32768 digits may appear on either side of the dot; one more yields
/// `Error::Limit`.
#[test]
fn test_zero_str_bounds() {
    #[derive(Debug)]
    struct TestCase {
        pub zeros_before_dot: usize,
        pub zeros_after_dot: usize,
        pub expected_err: Option<Error>,
    }
    let case = |zeros_before_dot: usize, zeros_after_dot: usize, expected_err: Option<Error>| {
        TestCase {
            zeros_before_dot,
            zeros_after_dot,
            expected_err,
        }
    };
    let cases = [
        case(32768, 0, None),
        case(32767, 0, None),
        case(32769, 0, Some(Error::Limit)),
        case(0, 32769, Some(Error::Limit)),
        case(32768, 32768, None),
        case(32769, 32768, Some(Error::Limit)),
        case(32768, 32769, Some(Error::Limit)),
        case(32767, 32769, Some(Error::Limit)),
        case(32767, 32767, None),
        case(32768, 32767, None),
    ];
    for cas in cases.iter() {
        // Note: a width of 0 still formats a single "0".
        let mut input_str = format!("{:0fill$}", 0, fill = cas.zeros_before_dot);
        if cas.zeros_after_dot > 0 {
            input_str.push('.');
            input_str += &format!("{:0fill$}", 0, fill = cas.zeros_after_dot);
        }
        let parsed = FixedDecimal::from_str(&input_str);
        match parsed {
            Ok(dec) => {
                assert_eq!(cas.expected_err, None, "{:?}", cas);
                assert_eq!(input_str, dec.to_string(), "{:?}", cas);
            }
            Err(err) => assert_eq!(cas.expected_err, Some(err), "{:?}", cas),
        }
    }
}
/// Malformed strings are rejected with `Error::Syntax`.
///
/// Well-formed neighbours of those strings parse and round-trip unchanged.
#[test]
fn test_syntax_error() {
    #[derive(Debug)]
    struct TestCase {
        pub input_str: &'static str,
        pub expected_err: Option<Error>,
    }
    let case = |input_str: &'static str, expected_err: Option<Error>| TestCase {
        input_str,
        expected_err,
    };
    let cases = [
        case("-12a34", Some(Error::Syntax)),
        case("0.0123√400", Some(Error::Syntax)),
        case("0.012.3400", Some(Error::Syntax)),
        case("-0-0123400", Some(Error::Syntax)),
        case("0-0123400", Some(Error::Syntax)),
        case("-.00123400", Some(Error::Syntax)),
        case("-0.00123400", None),
        case(".00123400", Some(Error::Syntax)),
        case("00123400.", Some(Error::Syntax)),
        case("00123400.0", None),
        case("123_456", Some(Error::Syntax)),
        case("", Some(Error::Syntax)),
        case("-", Some(Error::Syntax)),
        case("-1", None),
    ];
    for cas in cases.iter() {
        let parsed = FixedDecimal::from_str(cas.input_str);
        match parsed {
            Ok(dec) => {
                assert_eq!(cas.expected_err, None, "{:?}", cas);
                assert_eq!(cas.input_str, dec.to_string(), "{:?}", cas);
            }
            Err(err) => assert_eq!(cas.expected_err, Some(err), "{:?}", cas),
        }
    }
}