use crate::decoder::{decode_to_parts, DecodedDecimal, DecodedValue};
use crate::encoder::{encode_from_parts, encode_special_byte};
use crate::error::{DecodeError, EncodeError, EncodeResult};
use std::cmp::Ordering;
use std::fmt;
use std::fmt::Write as _;
use std::str::FromStr;
/// Values that are not ordinary (finite, non-zero) decimals.
///
/// Each variant is encoded as a single byte via `encode_special_byte`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpecialValue {
    /// Negative infinity, displayed as `-∞`.
    NegativeInfinity,
    /// Zero carrying a negative sign, displayed as `-0`.
    NegativeZero,
    /// Zero carrying a positive sign, displayed as `0`.
    PositiveZero,
    /// Positive infinity, displayed as `∞`.
    PositiveInfinity,
    /// Not-a-number, displayed as `NaN`.
    NaN,
}
/// A decimal number held in its encoded byte form.
///
/// Equality and ordering compare the encoded bytes directly, treating `+0` and
/// `-0` as equal. NOTE(review): numeric ordering therefore relies on the
/// encoder producing an order-preserving byte layout — confirm against `encoder`.
#[derive(Debug, Clone)]
pub struct Decimal {
    // Encoded representation: one byte for special values, otherwise the output
    // of `encode_from_parts`. `from_bytes_unchecked` may store arbitrary bytes.
    bytes: Vec<u8>,
}
impl Decimal {
#[must_use]
pub const fn from_bytes_unchecked(bytes: Vec<u8>) -> Self {
Self { bytes }
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
decode_to_parts(bytes)?;
Ok(Self {
bytes: bytes.to_vec(),
})
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
#[must_use]
pub fn into_bytes(self) -> Vec<u8> {
self.bytes
}
#[must_use]
pub fn decode(&self) -> Option<DecodedDecimal> {
match decode_to_parts(&self.bytes).ok()? {
DecodedValue::Regular(d) => Some(d),
DecodedValue::Special(_) => None,
}
}
#[must_use]
pub fn is_zero(&self) -> bool {
self.bytes.len() == 1 && (self.bytes[0] == 0x80 || self.bytes[0] == 0x40)
}
#[must_use]
pub fn is_nan(&self) -> bool {
self.bytes.len() == 1 && self.bytes[0] == 0b1110_0000
}
#[must_use]
pub fn is_infinity(&self) -> bool {
self.is_pos_infinity() || self.is_neg_infinity()
}
#[must_use]
pub fn is_pos_infinity(&self) -> bool {
self.bytes.len() == 1 && self.bytes[0] == 0xC0
}
#[must_use]
pub fn is_neg_infinity(&self) -> bool {
self.bytes.len() == 1 && self.bytes[0] == 0x00
}
#[must_use]
pub fn is_finite(&self) -> bool {
!self.is_infinity() && !self.is_nan()
}
#[must_use]
pub fn infinity() -> Self {
Self {
bytes: vec![encode_special_byte(SpecialValue::PositiveInfinity)],
}
}
#[must_use]
pub fn neg_infinity() -> Self {
Self {
bytes: vec![encode_special_byte(SpecialValue::NegativeInfinity)],
}
}
#[must_use]
pub fn nan() -> Self {
Self {
bytes: vec![encode_special_byte(SpecialValue::NaN)],
}
}
#[must_use]
pub fn zero() -> Self {
Self {
bytes: vec![encode_special_byte(SpecialValue::PositiveZero)],
}
}
fn parse_and_encode(s: &str) -> EncodeResult<Self> {
let s = s.trim();
if s.eq_ignore_ascii_case("inf")
|| s.eq_ignore_ascii_case("+inf")
|| s.eq_ignore_ascii_case("infinity")
|| s.eq_ignore_ascii_case("+infinity")
{
return Ok(Self::infinity());
}
if s.eq_ignore_ascii_case("-inf") || s.eq_ignore_ascii_case("-infinity") {
return Ok(Self::neg_infinity());
}
if s.eq_ignore_ascii_case("nan") {
return Ok(Self::nan());
}
#[allow(clippy::option_if_let_else)]
let (positive, s) = if let Some(stripped) = s.strip_prefix('-') {
(false, stripped)
} else if let Some(stripped) = s.strip_prefix('+') {
(true, stripped)
} else {
(true, s)
};
if s.is_empty() || !s.bytes().any(|b| b.is_ascii_digit()) {
return Err(EncodeError::InvalidFormat(
"input contains no digits".to_string(),
));
}
if s == "0" || s == "0.0" || s.bytes().all(|b| b == b'0' || b == b'.') {
return if positive {
Ok(Self::zero())
} else {
Ok(Self {
bytes: vec![encode_special_byte(SpecialValue::NegativeZero)],
})
};
}
let (integer_part, fractional_part) = match s.find('.') {
Some(pos) => {
let (int, rest) = s.split_at(pos);
if rest[1..].contains('.') {
return Err(EncodeError::InvalidFormat(
"multiple decimal points".to_string(),
));
}
(int, &rest[1..])
}
None => (s, ""),
};
for b in integer_part.bytes().chain(fractional_part.bytes()) {
if !b.is_ascii_digit() {
return Err(EncodeError::InvalidFormat(format!(
"invalid digit: {}",
b as char
)));
}
}
let leading_zeros = integer_part
.bytes()
.chain(fractional_part.bytes())
.take_while(|&b| b == b'0')
.count();
let total_len = integer_part.len() + fractional_part.len();
let trailing_zeros = fractional_part
.bytes()
.rev()
.chain(integer_part.bytes().rev())
.take_while(|&b| b == b'0')
.count();
let significant_len = total_len - leading_zeros - trailing_zeros;
if significant_len == 0 {
return if positive {
Ok(Self::zero())
} else {
Ok(Self {
bytes: vec![encode_special_byte(SpecialValue::NegativeZero)],
})
};
}
let decimal_point_position = integer_part.trim_start_matches('0').len();
#[allow(clippy::cast_possible_truncation)]
let (exponent, exponent_positive) = if decimal_point_position > 0 {
((decimal_point_position - 1) as u64, true)
} else {
let frac_leading_zeros =
fractional_part.len() - fractional_part.trim_start_matches('0').len();
((frac_leading_zeros + 1) as u64, false)
};
let mut significand = Vec::with_capacity(significant_len);
for b in integer_part
.bytes()
.chain(fractional_part.bytes())
.skip(leading_zeros)
{
significand.push(b - b'0');
}
while significand.last() == Some(&0) && significand.len() > 1 {
significand.pop();
}
let bytes = encode_from_parts(positive, exponent_positive, exponent, &significand);
Ok(Self { bytes })
}
}
impl FromStr for Decimal {
type Err = EncodeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse_and_encode(s)
}
}
impl fmt::Display for Decimal {
    /// Renders regular values in scientific notation, e.g. `-1.032 × 10^2`.
    ///
    /// A single-digit significand is written without a dangling point
    /// (`5 × 10^-1`). Special values render as `-∞`, `-0`, `0`, `∞` or `NaN`;
    /// bytes that fail to decode render as `<invalid>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match decode_to_parts(&self.bytes) {
            Ok(DecodedValue::Regular(d)) => {
                // An empty significand cannot come from a well-formed encoding;
                // render it like undecodable bytes instead of panicking on `[0]`.
                let (first, rest) = match d.significand.split_first() {
                    Some(parts) => parts,
                    None => return f.write_str("<invalid>"),
                };
                if !d.positive {
                    f.write_str("-")?;
                }
                write!(f, "{first}")?;
                if !rest.is_empty() {
                    f.write_str(".")?;
                    for digit in rest {
                        write!(f, "{digit}")?;
                    }
                }
                f.write_str(" × 10^")?;
                if !d.exponent_positive {
                    f.write_str("-")?;
                }
                write!(f, "{}", d.exponent)
            }
            Ok(DecodedValue::Special(s)) => {
                let name = match s {
                    SpecialValue::NegativeInfinity => "-∞",
                    SpecialValue::NegativeZero => "-0",
                    SpecialValue::PositiveZero => "0",
                    SpecialValue::PositiveInfinity => "∞",
                    SpecialValue::NaN => "NaN",
                };
                f.write_str(name)
            }
            Err(_) => f.write_str("<invalid>"),
        }
    }
}
/// Byte-wise equality, except that `+0` and `-0` are equal.
///
/// Identical bytes always compare equal (including NaN), so the relation stays
/// reflexive as `Eq` requires.
impl PartialEq for Decimal {
    fn eq(&self, other: &Self) -> bool {
        let both_zero = self.is_zero() && other.is_zero();
        both_zero || self.bytes == other.bytes
    }
}

impl Eq for Decimal {}
/// Total order delegating to [`Ord`], so it always returns `Some`.
impl PartialOrd for Decimal {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Lexicographic order of the encoded bytes, with `+0` and `-0` treated as
/// equal to stay consistent with `PartialEq`.
impl Ord for Decimal {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.is_zero(), other.is_zero()) {
            (true, true) => Ordering::Equal,
            _ => self.bytes.as_slice().cmp(other.bytes.as_slice()),
        }
    }
}
/// Writes the base-10 digits of `value` (one digit value `0..=9` per byte,
/// most significant first) to the front of `buf` and returns how many were
/// written.
///
/// 39 bytes suffice because `u128::MAX` has 39 decimal digits.
fn u128_to_digits(value: u128, buf: &mut [u8; 39]) -> usize {
    const CAPACITY: usize = 39;
    if value == 0 {
        buf[0] = 0;
        return 1;
    }
    // Produce digits least-significant first, filling the buffer from the back...
    let mut start = CAPACITY;
    let mut remaining = value;
    while remaining != 0 {
        start -= 1;
        // A remainder modulo 10 always fits in a u8.
        #[allow(clippy::cast_possible_truncation)]
        {
            buf[start] = (remaining % 10) as u8;
        }
        remaining /= 10;
    }
    // ...then slide them to the front so callers can use `buf[..len]`.
    buf.copy_within(start.., 0);
    CAPACITY - start
}
/// Encodes the magnitude `value` with the given sign.
///
/// A zero magnitude always yields positive zero, whatever `positive` says.
/// Otherwise the digits are normalized to `d.ddd × 10^e` with
/// `e = digit_count - 1 >= 0`, and trailing zero digits are dropped because the
/// exponent already accounts for them.
fn from_unsigned_with_sign(value: u128, positive: bool) -> Decimal {
    if value == 0 {
        return Decimal::zero();
    }
    let mut scratch = [0u8; 39];
    let digit_count = u128_to_digits(value, &mut scratch);
    let digits = &scratch[..digit_count];
    let trailing_zeros = digits.iter().rev().take_while(|&&digit| digit == 0).count();
    // Keep at least one digit (always true here, since `value` is non-zero).
    let kept = (digit_count - trailing_zeros).max(1);
    #[allow(clippy::cast_possible_truncation)]
    let exponent = (digit_count - 1) as u64;
    Decimal {
        bytes: encode_from_parts(positive, true, exponent, &digits[..kept]),
    }
}
/// Exact conversion of an unsigned 64-bit integer.
impl From<u64> for Decimal {
    fn from(value: u64) -> Self {
        let widened: u128 = value.into();
        from_unsigned_with_sign(widened, true)
    }
}

/// Exact conversion of a signed 64-bit integer; zero becomes positive zero.
impl From<i64> for Decimal {
    fn from(value: i64) -> Self {
        // `unsigned_abs` handles `i64::MIN` without overflow.
        let magnitude = u128::from(value.unsigned_abs());
        from_unsigned_with_sign(magnitude, value >= 0)
    }
}
/// Exact conversion of an unsigned 128-bit integer.
impl From<u128> for Decimal {
    fn from(value: u128) -> Self {
        let magnitude = value;
        from_unsigned_with_sign(magnitude, true)
    }
}

/// Exact conversion of a signed 128-bit integer; zero becomes positive zero.
impl From<i128> for Decimal {
    fn from(value: i128) -> Self {
        // `unsigned_abs` handles `i128::MIN` without overflow.
        from_unsigned_with_sign(value.unsigned_abs(), !value.is_negative())
    }
}
// Narrow integer types widen losslessly to the 64-bit conversions.

impl From<u8> for Decimal {
    fn from(value: u8) -> Self {
        u64::from(value).into()
    }
}

impl From<u16> for Decimal {
    fn from(value: u16) -> Self {
        u64::from(value).into()
    }
}

impl From<u32> for Decimal {
    fn from(value: u32) -> Self {
        u64::from(value).into()
    }
}

impl From<i8> for Decimal {
    fn from(value: i8) -> Self {
        i64::from(value).into()
    }
}

impl From<i16> for Decimal {
    fn from(value: i16) -> Self {
        i64::from(value).into()
    }
}

impl From<i32> for Decimal {
    fn from(value: i32) -> Self {
        i64::from(value).into()
    }
}
/// Fixed-capacity (25-byte), heap-free [`fmt::Write`] sink for formatting small values.
///
/// Each `write_str` call either appends its whole argument or fails with
/// [`fmt::Error`] and leaves the buffer untouched. The stored bytes are
/// therefore always a concatenation of complete `&str`s.
struct StackBuf {
    buf: [u8; 25],
    len: usize,
}

impl StackBuf {
    /// An empty buffer.
    const fn new() -> Self {
        Self {
            buf: [0; 25],
            len: 0,
        }
    }

    /// The text written so far.
    fn as_str(&self) -> &str {
        // A checked conversion replaces the former unjustified
        // `from_utf8_unchecked`. The buffer holds at most 25 bytes, so the check
        // is negligible and no `unsafe` invariant needs upholding.
        std::str::from_utf8(&self.buf[..self.len])
            .expect("StackBuf only stores whole &str fragments, which are valid UTF-8")
    }
}

impl fmt::Write for StackBuf {
    /// Appends `s`, or returns `fmt::Error` without writing anything if it does not fit.
    fn write_str(&mut self, s: &str) -> fmt::Result {
        let end = self.len + s.len();
        // `get_mut` yields `None` exactly when `end` exceeds the capacity.
        let dest = self.buf.get_mut(self.len..end).ok_or(fmt::Error)?;
        dest.copy_from_slice(s.as_bytes());
        self.len = end;
        Ok(())
    }
}
impl From<f64> for Decimal {
fn from(value: f64) -> Self {
if value.is_nan() {
return Self::nan();
}
if value.is_infinite() {
return if value.is_sign_positive() {
Self::infinity()
} else {
Self::neg_infinity()
};
}
if value == 0.0 {
return if value.is_sign_positive() {
Self::zero()
} else {
Self {
bytes: vec![encode_special_byte(SpecialValue::NegativeZero)],
}
};
}
let mut buf = StackBuf::new();
write!(buf, "{value}").expect("f64 Display should fit in 25 bytes");
buf.as_str()
.parse()
.expect("f64 Display output should always be a valid decimal")
}
}
/// Converts by first widening to `f64`.
///
/// Widening is exact, so the result carries the widened value's digits
/// (e.g. `2.72_f32` becomes the decimal expansion of `2.7200000286102295`).
impl From<f32> for Decimal {
    fn from(value: f32) -> Self {
        let widened: f64 = value.into();
        widened.into()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for parsing, special values, equality/ordering, and the
    // integer/float `From` conversions.
    use super::*;
    // "123.456" normalizes to 1.23456 × 10^2.
    #[test]
    fn test_parse_positive() {
        let d: Decimal = "123.456".parse().unwrap();
        let decoded = d.decode().unwrap();
        assert!(decoded.positive);
        assert!(decoded.exponent_positive);
        assert_eq!(decoded.exponent, 2);
        assert!(decoded.significand.starts_with(&[1, 2, 3, 4, 5, 6]));
    }
    // Sign is tracked separately from the exponent: -1.032 × 10^2.
    #[test]
    fn test_parse_negative() {
        let d: Decimal = "-103.2".parse().unwrap();
        let decoded = d.decode().unwrap();
        assert!(!decoded.positive);
        assert!(decoded.exponent_positive);
        assert_eq!(decoded.exponent, 2);
        assert!(decoded.significand.starts_with(&[1, 0, 3, 2]));
    }
    // Values below one get a negative exponent: 4.05 × 10^-2.
    #[test]
    fn test_parse_small() {
        let d: Decimal = "0.0405".parse().unwrap();
        let decoded = d.decode().unwrap();
        assert!(decoded.positive);
        assert!(!decoded.exponent_positive);
        assert_eq!(decoded.exponent, 2);
        assert!(decoded.significand.starts_with(&[4, 0, 5]));
    }
    #[test]
    fn test_parse_zero() {
        let d: Decimal = "0".parse().unwrap();
        assert!(d.is_zero());
    }
    #[test]
    fn test_parse_infinity() {
        let d: Decimal = "+inf".parse().unwrap();
        assert!(d.is_pos_infinity());
    }
    // Signed zeros have distinct encodings but must compare equal.
    #[test]
    fn test_zero_equality() {
        let pos_zero: Decimal = "0".parse().unwrap();
        let neg_zero: Decimal = "-0".parse().unwrap();
        assert_eq!(pos_zero, neg_zero, "+0 should equal -0");
    }
    // Special-value keywords are matched without regard to ASCII case.
    #[test]
    fn test_special_values_case_insensitive() {
        assert!("INF".parse::<Decimal>().unwrap().is_pos_infinity());
        assert!("Inf".parse::<Decimal>().unwrap().is_pos_infinity());
        assert!("+INFINITY".parse::<Decimal>().unwrap().is_pos_infinity());
        assert!("-inf".parse::<Decimal>().unwrap().is_neg_infinity());
        assert!("-Infinity".parse::<Decimal>().unwrap().is_neg_infinity());
        assert!("NaN".parse::<Decimal>().unwrap().is_nan());
        assert!("nan".parse::<Decimal>().unwrap().is_nan());
        assert!("NAN".parse::<Decimal>().unwrap().is_nan());
    }
    #[test]
    fn test_multiple_decimal_points() {
        assert!("123.456.789".parse::<Decimal>().is_err());
        assert!("1.2.3".parse::<Decimal>().is_err());
    }
    // `Ord` must agree with `PartialEq`, which treats +0 and -0 as equal.
    #[test]
    fn test_zero_ord_consistency() {
        let pos_zero: Decimal = "0".parse().unwrap();
        let neg_zero: Decimal = "-0".parse().unwrap();
        assert_eq!(
            pos_zero.cmp(&neg_zero),
            Ordering::Equal,
            "+0.cmp(-0) must be Equal to match PartialEq"
        );
        assert_eq!(
            neg_zero.cmp(&pos_zero),
            Ordering::Equal,
            "-0.cmp(+0) must be Equal to match PartialEq"
        );
    }
    // Inputs with no digits at all must be rejected, not read as zero.
    #[test]
    fn test_reject_empty_and_bare_inputs() {
        assert!("".parse::<Decimal>().is_err(), "empty string should fail");
        assert!("+".parse::<Decimal>().is_err(), "bare '+' should fail");
        assert!("-".parse::<Decimal>().is_err(), "bare '-' should fail");
        assert!(".".parse::<Decimal>().is_err(), "bare '.' should fail");
        assert!("-.".parse::<Decimal>().is_err(), "'-.' should fail");
        assert!("+.".parse::<Decimal>().is_err(), "'+.' should fail");
        assert!(" ".parse::<Decimal>().is_err(), "whitespace should fail");
    }
    // Integer conversions must produce byte-identical output to parsing the
    // integer's decimal string.
    #[test]
    fn test_from_u64_matches_parse() {
        let cases: &[u64] = &[0, 1, 9, 10, 42, 100, 999, 1000, 123456789, u64::MAX];
        for &n in cases {
            let from_int = Decimal::from(n);
            let from_str: Decimal = n.to_string().parse().unwrap();
            assert_eq!(
                from_int.as_bytes(),
                from_str.as_bytes(),
                "From<u64> mismatch for {n}"
            );
        }
    }
    #[test]
    fn test_from_i64_matches_parse() {
        let cases: &[i64] = &[
            i64::MIN,
            -123456789,
            -1000,
            -42,
            -1,
            0,
            1,
            42,
            1000,
            123456789,
            i64::MAX,
        ];
        for &n in cases {
            let from_int = Decimal::from(n);
            let from_str: Decimal = n.to_string().parse().unwrap();
            assert_eq!(
                from_int.as_bytes(),
                from_str.as_bytes(),
                "From<i64> mismatch for {n}"
            );
        }
    }
    // Covers i128::MIN, whose magnitude does not fit in i128.
    #[test]
    fn test_from_i128_extremes() {
        let cases: &[i128] = &[i128::MIN, -1, 0, 1, i128::MAX];
        for &n in cases {
            let from_int = Decimal::from(n);
            let from_str: Decimal = n.to_string().parse().unwrap();
            assert_eq!(
                from_int.as_bytes(),
                from_str.as_bytes(),
                "From<i128> mismatch for {n}"
            );
        }
    }
    // u128::MAX has 39 digits, the full capacity of the digit buffer.
    #[test]
    fn test_from_u128_max() {
        let from_int = Decimal::from(u128::MAX);
        let from_str: Decimal = u128::MAX.to_string().parse().unwrap();
        assert_eq!(from_int.as_bytes(), from_str.as_bytes());
    }
    // Narrow integer types must encode identically to their 64-bit widenings.
    #[test]
    fn test_from_small_types() {
        assert_eq!(
            Decimal::from(42u8).as_bytes(),
            Decimal::from(42u64).as_bytes()
        );
        assert_eq!(
            Decimal::from(42u16).as_bytes(),
            Decimal::from(42u64).as_bytes()
        );
        assert_eq!(
            Decimal::from(42u32).as_bytes(),
            Decimal::from(42u64).as_bytes()
        );
        assert_eq!(
            Decimal::from(-7i8).as_bytes(),
            Decimal::from(-7i64).as_bytes()
        );
        assert_eq!(
            Decimal::from(-7i16).as_bytes(),
            Decimal::from(-7i64).as_bytes()
        );
        assert_eq!(
            Decimal::from(-7i32).as_bytes(),
            Decimal::from(-7i64).as_bytes()
        );
    }
    // Byte order of the encodings must follow numeric order.
    #[test]
    fn test_from_u64_order_preserved() {
        let values: Vec<u64> = vec![0, 1, 2, 9, 10, 99, 100, 999, 1000, u64::MAX];
        let decimals: Vec<Decimal> = values.iter().map(|&v| Decimal::from(v)).collect();
        for i in 1..decimals.len() {
            assert!(
                decimals[i - 1] < decimals[i],
                "Order not preserved: {} < {} failed",
                values[i - 1],
                values[i]
            );
        }
    }
    #[test]
    fn test_from_zero_is_positive_zero() {
        let d = Decimal::from(0u64);
        assert!(d.is_zero());
        assert_eq!(d.as_bytes(), Decimal::zero().as_bytes());
    }
    // Float conversion must agree byte-for-byte with parsing `to_string()`.
    #[test]
    fn test_from_f64_matches_parse() {
        let cases: &[f64] = &[1.0, -1.0, 0.5, -0.5, 42.0, 123.456, 0.001, 1e10, 1e-10];
        for &v in cases {
            let from_float = Decimal::from(v);
            let from_str: Decimal = v.to_string().parse().unwrap();
            assert_eq!(
                from_float.as_bytes(),
                from_str.as_bytes(),
                "From<f64> mismatch for {v}"
            );
        }
    }
    #[test]
    fn test_from_f64_special_values() {
        assert!(Decimal::from(f64::NAN).is_nan());
        assert!(Decimal::from(f64::INFINITY).is_pos_infinity());
        assert!(Decimal::from(f64::NEG_INFINITY).is_neg_infinity());
        assert!(Decimal::from(0.0_f64).is_zero());
        assert!(Decimal::from(-0.0_f64).is_zero());
    }
    // -0.0 keeps a distinct encoding yet compares equal to +0.0.
    #[test]
    fn test_from_f64_negative_zero_preserved() {
        let neg_zero = Decimal::from(-0.0_f64);
        let pos_zero = Decimal::from(0.0_f64);
        assert_eq!(neg_zero, pos_zero);
        assert_ne!(neg_zero.as_bytes(), pos_zero.as_bytes());
    }
    #[test]
    fn test_from_f64_order_preserved() {
        let values: Vec<f64> = vec![-1000.0, -1.0, -0.001, 0.001, 1.0, 1000.0];
        let decimals: Vec<Decimal> = values.iter().map(|&v| Decimal::from(v)).collect();
        for i in 1..decimals.len() {
            assert!(
                decimals[i - 1] < decimals[i],
                "Order not preserved: {} < {} failed",
                values[i - 1],
                values[i]
            );
        }
    }
    // Pins that f32 conversion goes through an exact widening to f64.
    #[test]
    fn test_from_f32_matches_f64_widening() {
        let v = 2.72_f32;
        let from_f32 = Decimal::from(v);
        let from_f64 = Decimal::from(f64::from(v));
        assert_eq!(from_f32.as_bytes(), from_f64.as_bytes());
    }
}