use thiserror::Error;
/// Sign byte that starts the encoding of every negative finite value.
pub(crate) const SIGN_NEGATIVE: u8 = 0x00;
/// Sign byte that forms the entire one-byte encoding of zero.
pub(crate) const SIGN_ZERO: u8 = 0x80;
/// Sign byte that starts the encoding of every positive finite value.
pub(crate) const SIGN_POSITIVE: u8 = 0xFF;
/// Fixed three-byte encoding of negative infinity.
pub const ENCODING_NEG_INFINITY: [u8; 3] = [0x00, 0x00, 0x00];
/// Fixed three-byte encoding of positive infinity.
pub const ENCODING_POS_INFINITY: [u8; 3] = [0xFF, 0xFF, 0xFE];
/// Fixed three-byte encoding of NaN.
pub const ENCODING_NAN: [u8; 3] = [0xFF, 0xFF, 0xFF];
// Raw 16-bit exponent words that `decode_exponent` rejects: the first for negative
// values, the other two for positive values. They match the trailing two bytes of
// the special-value encodings above.
const RESERVED_NEG_INFINITY_EXP: u16 = 0x0000; const RESERVED_POS_INFINITY_EXP: u16 = 0xFFFE; const RESERVED_NAN_EXP: u16 = 0xFFFF;
/// Offset added to a decimal exponent before it is stored as an unsigned 16-bit word.
const EXPONENT_BIAS: i32 = 16384;
// Inclusive exponent range accepted by `parse_decimal`: 16381 and -16383 respectively.
const MAX_EXPONENT: i32 = 32767 - EXPONENT_BIAS - 2; const MIN_EXPONENT: i32 = -EXPONENT_BIAS + 1;
/// Errors reported while parsing, encoding or decoding decimal values.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum DecimalError {
/// The textual input is not a well-formed decimal; the payload describes the problem.
#[error("Invalid format: {0}")]
InvalidFormat(String),
/// The normalized exponent falls outside the range the encoding can store.
#[error("Precision overflow: exponent out of range")]
PrecisionOverflow,
/// The byte sequence is not a valid encoding (empty, too short, unknown sign byte,
/// reserved exponent word, or a mantissa nibble above 9).
#[error("Invalid encoding")]
InvalidEncoding,
}
/// Non-finite values that have a dedicated fixed-width encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpecialValue {
/// Positive infinity; encoded as `ENCODING_POS_INFINITY`.
Infinity,
/// Negative infinity; encoded as `ENCODING_NEG_INFINITY`.
NegInfinity,
/// Not-a-number; encoded as `ENCODING_NAN`.
NaN,
}
/// Encodes a textual decimal into its byte representation.
///
/// Spellings recognised by `parse_special_value` map to their fixed encodings, zero
/// becomes the single byte `SIGN_ZERO`, and every other value is laid out as a sign
/// byte, two exponent bytes and the packed mantissa.
///
/// # Errors
/// Propagates any [`DecimalError`] returned by `parse_decimal`.
pub fn encode_decimal(value: &str) -> Result<Vec<u8>, DecimalError> {
    let special = parse_special_value(value);
    if let Some(kind) = special {
        return Ok(encode_special_value(kind));
    }

    let (negative, mantissa_digits, exp) = parse_decimal(value)?;
    // An empty digit list is how `parse_decimal` reports zero.
    if mantissa_digits.is_empty() {
        return Ok(vec![SIGN_ZERO]);
    }

    let sign_byte = if negative { SIGN_NEGATIVE } else { SIGN_POSITIVE };
    // One sign byte, two exponent bytes, two digits per mantissa byte.
    let packed_len = mantissa_digits.len().div_ceil(2);
    let mut out = Vec::with_capacity(1 + 2 + packed_len);
    out.push(sign_byte);
    encode_exponent(&mut out, exp, negative);
    encode_mantissa(&mut out, &mantissa_digits, negative);
    Ok(out)
}
/// Recognises the textual spellings of the non-finite values.
///
/// Matching ignores surrounding whitespace and letter case. Returns `None` for
/// anything that is not one of the accepted spellings.
fn parse_special_value(value: &str) -> Option<SpecialValue> {
    const POSITIVE_INFINITY: [&str; 4] = ["infinity", "inf", "+infinity", "+inf"];
    const NEGATIVE_INFINITY: [&str; 2] = ["-infinity", "-inf"];
    const NOT_A_NUMBER: [&str; 3] = ["nan", "-nan", "+nan"];

    let normalized = value.trim().to_lowercase();
    let candidate = normalized.as_str();
    if POSITIVE_INFINITY.contains(&candidate) {
        Some(SpecialValue::Infinity)
    } else if NEGATIVE_INFINITY.contains(&candidate) {
        Some(SpecialValue::NegInfinity)
    } else if NOT_A_NUMBER.contains(&candidate) {
        Some(SpecialValue::NaN)
    } else {
        None
    }
}
/// Returns the fixed three-byte encoding of a non-finite value.
pub fn encode_special_value(special: SpecialValue) -> Vec<u8> {
    let encoding: [u8; 3] = match special {
        SpecialValue::NaN => ENCODING_NAN,
        SpecialValue::Infinity => ENCODING_POS_INFINITY,
        SpecialValue::NegInfinity => ENCODING_NEG_INFINITY,
    };
    Vec::from(encoding)
}
/// Identifies a byte sequence that is exactly one of the special-value encodings.
///
/// Returns `None` for every other input, including inputs that merely start with
/// one of those encodings.
pub fn decode_special_value(bytes: &[u8]) -> Option<SpecialValue> {
    if bytes == &ENCODING_NEG_INFINITY[..] {
        Some(SpecialValue::NegInfinity)
    } else if bytes == &ENCODING_POS_INFINITY[..] {
        Some(SpecialValue::Infinity)
    } else if bytes == &ENCODING_NAN[..] {
        Some(SpecialValue::NaN)
    } else {
        None
    }
}
/// Encodes `value` after rounding it to the given `precision` and `scale`.
///
/// Special values (infinities and NaN) bypass the constraints and are encoded as-is.
///
/// # Errors
/// Propagates any [`DecimalError`] from `truncate_decimal` or `encode_decimal`.
pub fn encode_decimal_with_constraints(
    value: &str,
    precision: Option<u32>,
    scale: Option<i32>,
) -> Result<Vec<u8>, DecimalError> {
    match parse_special_value(value) {
        Some(_) => encode_decimal(value),
        None => {
            let constrained = truncate_decimal(value, precision, scale)?;
            encode_decimal(&constrained)
        }
    }
}
/// Decodes an encoded decimal back into its textual form.
///
/// Special encodings yield `"-Infinity"`, `"Infinity"` or `"NaN"`; any input whose
/// first byte is `SIGN_ZERO` yields `"0"`.
///
/// # Errors
/// Returns [`DecimalError::InvalidEncoding`] for empty input or an unknown sign
/// byte, and propagates errors from the exponent and mantissa decoders.
pub fn decode_to_string(bytes: &[u8]) -> Result<String, DecimalError> {
    let (&sign_byte, after_sign) = match bytes.split_first() {
        Some(parts) => parts,
        None => return Err(DecimalError::InvalidEncoding),
    };

    // Special values are checked first because they share sign bytes with finite values.
    match decode_special_value(bytes) {
        Some(SpecialValue::NegInfinity) => return Ok("-Infinity".to_string()),
        Some(SpecialValue::Infinity) => return Ok("Infinity".to_string()),
        Some(SpecialValue::NaN) => return Ok("NaN".to_string()),
        None => {}
    }

    let is_negative = match sign_byte {
        SIGN_ZERO => return Ok("0".to_string()),
        SIGN_NEGATIVE => true,
        SIGN_POSITIVE => false,
        _ => return Err(DecimalError::InvalidEncoding),
    };

    let (exponent, exponent_len) = decode_exponent(after_sign, is_negative)?;
    let digits = decode_mantissa(&after_sign[exponent_len..], is_negative)?;
    format_decimal(is_negative, &digits, exponent)
}
/// Decodes `bytes` and pads the fractional part with zeros up to `scale` digits.
///
/// Special values and non-positive scales return the plain decoded text. A value
/// that already has at least `scale` fractional digits is returned unchanged; it is
/// never shortened.
///
/// # Errors
/// Propagates any [`DecimalError`] from `decode_to_string`.
pub fn decode_to_string_with_scale(bytes: &[u8], scale: i32) -> Result<String, DecimalError> {
    let mut text = decode_to_string(bytes)?;
    let is_special = matches!(text.as_str(), "NaN" | "Infinity" | "-Infinity");
    if is_special || scale <= 0 {
        return Ok(text);
    }

    let wanted = scale as usize;
    let present = match text.find('.') {
        Some(dot) => text.len() - dot - 1,
        None => {
            // No fractional part yet: open one so the padding below has a home.
            text.push('.');
            0
        }
    };
    if present < wanted {
        text.extend(std::iter::repeat('0').take(wanted - present));
    }
    Ok(text)
}
/// Parses a decimal literal into `(is_negative, significant_digits, exponent)`.
///
/// The input may carry surrounding whitespace, one leading `+` or `-`, an optional
/// fractional part and an optional `e`/`E` exponent. Leading and trailing zeros are
/// dropped from the digit list, and `exponent` is the position of the decimal point
/// relative to the first significant digit (`"123.456"` gives 3, `"0.001"` gives -2).
/// Zero, including input with no digits at all, is reported as `(false, vec![], 0)`.
///
/// # Errors
/// * [`DecimalError::InvalidFormat`] for a second decimal point, an unexpected
///   character, or a missing or unparsable exponent.
/// * [`DecimalError::PrecisionOverflow`] when the normalized exponent lies outside
///   `MIN_EXPONENT..=MAX_EXPONENT`.
fn parse_decimal(value: &str) -> Result<(bool, Vec<u8>, i32), DecimalError> {
    let value = value.trim();
    // At most one sign character is consumed; a second one is an invalid character.
    let (is_negative, body) = if let Some(rest) = value.strip_prefix('-') {
        (true, rest)
    } else {
        (false, value.strip_prefix('+').unwrap_or(value))
    };

    let mut integer_part = String::new();
    let mut fractional_part = String::new();
    let mut seen_decimal = false;
    let mut exponent_text: Option<&str> = None;
    for (idx, c) in body.char_indices() {
        match c {
            '.' if seen_decimal => {
                return Err(DecimalError::InvalidFormat(
                    "Multiple decimal points".to_string(),
                ));
            }
            '.' => seen_decimal = true,
            '0'..='9' => {
                if seen_decimal {
                    fractional_part.push(c);
                } else {
                    integer_part.push(c);
                }
            }
            'e' | 'E' => {
                // The marker is one byte wide, so `idx + 1` is a character boundary.
                exponent_text = Some(&body[idx + 1..]);
                break;
            }
            _ => {
                return Err(DecimalError::InvalidFormat(format!(
                    "Invalid character: {}",
                    c
                )));
            }
        }
    }

    let exp_offset: i32 = match exponent_text {
        None => 0,
        Some("") => {
            return Err(DecimalError::InvalidFormat(
                "Missing exponent after 'e'".to_string(),
            ));
        }
        Some(text) => text
            .parse()
            .map_err(|_| DecimalError::InvalidFormat(format!("Invalid exponent: {}", text)))?,
    };

    // A bare fraction such as ".5" is treated as "0.5".
    if integer_part.is_empty() {
        integer_part.push('0');
    }
    let decimal_position = integer_part.len();
    let mut all_digits = integer_part;
    all_digits.push_str(&fractional_part);

    // No non-zero digit means the value is zero, whatever its sign or exponent.
    let Some(first_nonzero) = all_digits.find(|c: char| c != '0') else {
        return Ok((false, vec![], 0));
    };
    let last_nonzero = all_digits
        .rfind(|c: char| c != '0')
        .unwrap_or(first_nonzero);
    let digits: Vec<u8> = all_digits[first_nonzero..=last_nonzero]
        .bytes()
        .map(|b| b - b'0')
        .collect();

    // Computed in i64 so that an extreme user-supplied exponent is reported as an
    // overflow error instead of overflowing the i32 addition.
    let exponent = decimal_position as i64 - first_nonzero as i64 + i64::from(exp_offset);
    if exponent < i64::from(MIN_EXPONENT) || exponent > i64::from(MAX_EXPONENT) {
        return Err(DecimalError::PrecisionOverflow);
    }
    Ok((is_negative, digits, exponent as i32))
}
/// Appends the biased exponent to `result` as two big-endian bytes.
///
/// For negative values the 16-bit word is bit-inverted before it is written.
fn encode_exponent(result: &mut Vec<u8>, exponent: i32, is_negative: bool) {
    let biased = (exponent + EXPONENT_BIAS) as u16;
    let word = if is_negative { !biased } else { biased };
    result.extend_from_slice(&word.to_be_bytes());
}
/// Reads the two-byte exponent that follows the sign byte.
///
/// Returns the unbiased exponent together with the offset at which the mantissa
/// starts within `bytes` (always 2).
///
/// # Errors
/// Returns [`DecimalError::InvalidEncoding`] when fewer than two bytes are available
/// or when the exponent word is one of the reserved values for the given sign.
fn decode_exponent(bytes: &[u8], is_negative: bool) -> Result<(i32, usize), DecimalError> {
    let encoded = match bytes {
        [high, low, ..] => u16::from_be_bytes([*high, *low]),
        _ => return Err(DecimalError::InvalidEncoding),
    };

    let is_reserved = if is_negative {
        encoded == RESERVED_NEG_INFINITY_EXP
    } else {
        encoded == RESERVED_POS_INFINITY_EXP || encoded == RESERVED_NAN_EXP
    };
    if is_reserved {
        return Err(DecimalError::InvalidEncoding);
    }

    // Negative values store the bit-inverted word; undo that before removing the bias.
    let biased = if is_negative { !encoded } else { encoded };
    Ok((i32::from(biased) - EXPONENT_BIAS, 2))
}
/// Appends `digits` to `result`, packed two decimal digits per byte.
///
/// The first digit of each pair goes in the high nibble; an odd trailing digit is
/// paired with 0. For negative values every packed byte is bit-inverted.
///
/// NOTE(review): for negative values a shorter mantissa is a byte-prefix of a longer
/// one with the same leading digits (`[1, 2]` gives `ED`, `[1, 2, 3]` gives `ED CF`),
/// so the shorter sequence compares lower bytewise although it represents the
/// numerically greater value (-1.2 vs -1.23). Confirm whether callers rely on
/// bytewise ordering for such pairs.
fn encode_mantissa(result: &mut Vec<u8>, digits: &[u8], is_negative: bool) {
    let packed_bytes = digits.chunks(2).map(|pair| {
        let high = pair[0];
        let low = pair.get(1).copied().unwrap_or(0);
        let packed = (high << 4) | low;
        if is_negative {
            !packed
        } else {
            packed
        }
    });
    result.extend(packed_bytes);
}
/// Unpacks mantissa bytes into decimal digits, two per byte.
///
/// Bytes of negative values are bit-inverted first. Trailing zero digits are
/// removed, but at least one digit is kept whenever any byte was supplied.
///
/// # Errors
/// Returns [`DecimalError::InvalidEncoding`] if any nibble is greater than 9.
fn decode_mantissa(bytes: &[u8], is_negative: bool) -> Result<Vec<u8>, DecimalError> {
    let mut digits = Vec::with_capacity(bytes.len() * 2);
    for raw in bytes.iter().copied() {
        let packed = if is_negative { !raw } else { raw };
        let (high, low) = (packed >> 4, packed & 0x0F);
        if high.max(low) > 9 {
            return Err(DecimalError::InvalidEncoding);
        }
        digits.extend_from_slice(&[high, low]);
    }

    // Keep everything up to the last non-zero digit, or a single digit if all are zero.
    let significant = digits
        .iter()
        .rposition(|&digit| digit != 0)
        .map_or(1, |last| last + 1);
    digits.truncate(significant);
    Ok(digits)
}
/// Renders a sign, digit list and exponent as a plain decimal string.
///
/// `exponent` is the number of digits that sit before the decimal point; values of
/// zero or below produce a `0.` prefix followed by `-exponent` zeros. An empty
/// digit list renders as `"0"`.
///
/// # Panics
/// Panics if any element of `digits` is greater than 9.
fn format_decimal(is_negative: bool, digits: &[u8], exponent: i32) -> Result<String, DecimalError> {
    if digits.is_empty() {
        return Ok("0".to_string());
    }

    let digit_text: String = digits
        .iter()
        .map(|&digit| char::from_digit(u32::from(digit), 10).unwrap())
        .collect();
    let digit_count = digits.len() as i32;

    let mut text = String::new();
    if is_negative {
        text.push('-');
    }

    if exponent >= digit_count {
        // Pure integer: pad with zeros up to the decimal point.
        text.push_str(&digit_text);
        text.extend((0..(exponent - digit_count)).map(|_| '0'));
    } else if exponent <= 0 {
        // Pure fraction: the digits start `-exponent` places after the point.
        text.push_str("0.");
        text.extend((0..(-exponent)).map(|_| '0'));
        text.push_str(&digit_text);
    } else {
        // The decimal point falls inside the digit run.
        let (whole, fraction) = digit_text.split_at(exponent as usize);
        text.push_str(whole);
        text.push('.');
        text.push_str(fraction);
    }
    Ok(text)
}
fn truncate_decimal(
value: &str,
precision: Option<u32>,
scale: Option<i32>,
) -> Result<String, DecimalError> {
let value = value.trim();
let is_negative = value.starts_with('-');
let value = value.trim_start_matches(['-', '+']);
let (integer_part, fractional_part) = if let Some(dot_pos) = value.find('.') {
(&value[..dot_pos], &value[dot_pos + 1..])
} else {
(value, "")
};
let integer_part = integer_part.trim_start_matches('0');
let integer_part = if integer_part.is_empty() {
"0"
} else {
integer_part
};
let scale_val = scale.unwrap_or(0);
if scale_val < 0 {
let round_digits = (-scale_val) as usize;
let mut int_str = integer_part.to_string();
if int_str.len() <= round_digits {
let num_val: u64 = int_str.parse().unwrap_or(0);
let rounding_unit = 10u64.pow(round_digits as u32);
let half_unit = rounding_unit / 2;
let result = if num_val >= half_unit {
rounding_unit.to_string()
} else {
"0".to_string()
};
return if is_negative && result != "0" {
Ok(format!("-{}", result))
} else {
Ok(result)
};
}
let keep_len = int_str.len() - round_digits;
let keep_part = &int_str[..keep_len];
let round_part = &int_str[keep_len..];
let first_rounded_digit = round_part.chars().next().unwrap_or('0');
let mut result_int = keep_part.to_string();
if first_rounded_digit >= '5' {
result_int = add_one_to_integer(&result_int);
}
int_str = format!("{}{}", result_int, "0".repeat(round_digits));
if let Some(p) = precision {
let max_significant = p as usize;
let significant_len = result_int.trim_start_matches('0').len();
if significant_len > max_significant && max_significant > 0 {
let trimmed = &result_int[result_int.len().saturating_sub(max_significant)..];
int_str = format!("{}{}", trimmed, "0".repeat(round_digits));
}
}
return if is_negative && int_str != "0" {
Ok(format!("-{}", int_str))
} else {
Ok(int_str)
};
}
let scale_usize = scale_val as usize;
let (mut integer_part, fractional_part) = if fractional_part.len() > scale_usize {
let truncated = &fractional_part[..scale_usize];
let next_digit = fractional_part.chars().nth(scale_usize).unwrap_or('0');
if next_digit >= '5' {
if scale_usize == 0 {
(add_one_to_integer(integer_part), String::new())
} else {
let rounded = round_up(truncated);
if rounded.len() > scale_usize {
let new_int = add_one_to_integer(integer_part);
(new_int, "0".repeat(scale_usize))
} else {
(integer_part.to_string(), rounded)
}
}
} else {
(integer_part.to_string(), truncated.to_string())
}
} else {
(integer_part.to_string(), fractional_part.to_string())
};
if let Some(p) = precision {
let max_integer_digits = if (p as i32) > scale_val {
(p as i32 - scale_val) as usize
} else {
0
};
if integer_part.len() > max_integer_digits && max_integer_digits > 0 {
integer_part = integer_part[integer_part.len() - max_integer_digits..].to_string();
} else if max_integer_digits == 0 {
integer_part = "0".to_string();
}
}
let result = if fractional_part.is_empty() || fractional_part.chars().all(|c| c == '0') {
integer_part
} else {
format!("{}.{}", integer_part, fractional_part.trim_end_matches('0'))
};
if is_negative && result != "0" {
Ok(format!("-{}", result))
} else {
Ok(result)
}
}
/// Adds one to a string of decimal digits, growing it by a digit on overflow.
///
/// An empty string yields `"1"`.
///
/// # Panics
/// Panics if the rightmost character that is not `'9'` is not a decimal digit.
fn add_one_to_integer(s: &str) -> String {
    let mut digits: Vec<char> = s.chars().collect();
    // The carry stops at the rightmost digit that is not a nine.
    match digits.iter().rposition(|&c| c != '9') {
        Some(idx) => {
            let bumped = digits[idx].to_digit(10).unwrap() + 1;
            digits[idx] = char::from_digit(bumped, 10).unwrap();
            for nine in &mut digits[idx + 1..] {
                *nine = '0';
            }
            digits.into_iter().collect()
        }
        None => {
            // Every digit was a nine, so the carry runs off the front.
            let mut grown = String::with_capacity(digits.len() + 1);
            grown.push('1');
            grown.extend(digits.iter().map(|_| '0'));
            grown
        }
    }
}
/// Increments a run of fractional digits by one unit in its last place.
///
/// When every digit is `'9'` the result is one character longer than the input,
/// which callers can use to detect a carry into the integer part.
///
/// # Panics
/// Panics if the rightmost character that is not `'9'` is not a decimal digit.
fn round_up(s: &str) -> String {
    let mut digits: Vec<char> = s.chars().collect();
    let mut idx = digits.len();
    while idx > 0 {
        idx -= 1;
        if digits[idx] == '9' {
            // Nine rolls over to zero and the carry keeps moving left.
            digits[idx] = '0';
            continue;
        }
        let next = digits[idx].to_digit(10).unwrap() + 1;
        digits[idx] = char::from_digit(next, 10).unwrap();
        return digits.into_iter().collect();
    }
    // The carry ran off the front of the string.
    std::iter::once('1').chain(digits).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding, decoding and re-encoding must reproduce the original bytes.
    #[test]
    fn test_encode_decode_roundtrip() {
        let values = vec![
            "0",
            "1",
            "-1",
            "123.456",
            "-123.456",
            "0.001",
            "0.1",
            "10",
            "100",
            "1000",
            "-0.001",
            "999999999999999999",
        ];
        for s in values {
            let encoded = encode_decimal(s).unwrap();
            let decoded = decode_to_string(&encoded).unwrap();
            let re_encoded = encode_decimal(&decoded).unwrap();
            assert_eq!(encoded, re_encoded, "Roundtrip failed for {}", s);
        }
    }

    /// Byte-wise comparison of encodings must follow numeric order for these values.
    #[test]
    fn test_lexicographic_ordering() {
        let values = vec![
            "-1000", "-100", "-10", "-1", "-0.1", "-0.01", "0", "0.01", "0.1", "1", "10", "100",
            "1000",
        ];
        let encoded: Vec<Vec<u8>> = values.iter().map(|s| encode_decimal(s).unwrap()).collect();
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "Ordering failed: {} should be < {}",
                values[i],
                values[i + 1]
            );
        }
    }

    /// Every spelling of zero collapses to the single zero sign byte.
    #[test]
    fn test_zero_encoding() {
        let encoded = encode_decimal("0").unwrap();
        assert_eq!(encoded, vec![SIGN_ZERO]);
        let encoded = encode_decimal("0.0").unwrap();
        assert_eq!(encoded, vec![SIGN_ZERO]);
        let encoded = encode_decimal("-0").unwrap();
        assert_eq!(encoded, vec![SIGN_ZERO]);
    }

    /// Positive scales round on the first dropped fractional digit.
    #[test]
    fn test_truncate_scale() {
        assert_eq!(
            truncate_decimal("123.456", None, Some(2)).unwrap(),
            "123.46"
        );
        assert_eq!(
            truncate_decimal("123.454", None, Some(2)).unwrap(),
            "123.45"
        );
        assert_eq!(truncate_decimal("123.995", None, Some(2)).unwrap(), "124");
        assert_eq!(truncate_decimal("9.999", None, Some(2)).unwrap(), "10");
    }

    /// Encodings stay compact: two digits per mantissa byte plus a three-byte header.
    #[test]
    fn test_storage_efficiency() {
        let encoded = encode_decimal("123456789").unwrap();
        assert!(
            encoded.len() <= 8,
            "Expected <= 8 bytes, got {}",
            encoded.len()
        );
        let encoded = encode_decimal("0.1").unwrap();
        assert!(
            encoded.len() <= 4,
            "Expected <= 4 bytes, got {}",
            encoded.len()
        );
    }

    /// Special values encode to their fixed byte sequences.
    #[test]
    fn test_special_value_encoding() {
        let pos_inf = encode_decimal("Infinity").unwrap();
        assert_eq!(pos_inf, ENCODING_POS_INFINITY.to_vec());
        let neg_inf = encode_decimal("-Infinity").unwrap();
        assert_eq!(neg_inf, ENCODING_NEG_INFINITY.to_vec());
        let nan = encode_decimal("NaN").unwrap();
        assert_eq!(nan, ENCODING_NAN.to_vec());
    }

    /// The fixed byte sequences decode to the canonical special-value spellings.
    #[test]
    fn test_special_value_decoding() {
        assert_eq!(
            decode_to_string(&ENCODING_POS_INFINITY).unwrap(),
            "Infinity"
        );
        assert_eq!(
            decode_to_string(&ENCODING_NEG_INFINITY).unwrap(),
            "-Infinity"
        );
        assert_eq!(decode_to_string(&ENCODING_NAN).unwrap(), "NaN");
    }

    /// All accepted spellings and letter cases normalize to the canonical form.
    #[test]
    fn test_special_value_parsing_variants() {
        let variants = vec![
            ("infinity", "Infinity"),
            ("Infinity", "Infinity"),
            ("INFINITY", "Infinity"),
            ("inf", "Infinity"),
            ("Inf", "Infinity"),
            ("+infinity", "Infinity"),
            ("+inf", "Infinity"),
            ("-infinity", "-Infinity"),
            ("-inf", "-Infinity"),
            ("-Infinity", "-Infinity"),
            ("nan", "NaN"),
            ("NaN", "NaN"),
            ("NAN", "NaN"),
            ("-nan", "NaN"),
            ("+nan", "NaN"),
        ];
        for (input, expected) in variants {
            let encoded = encode_decimal(input).unwrap();
            let decoded = decode_to_string(&encoded).unwrap();
            assert_eq!(decoded, expected, "Failed for input: {}", input);
        }
    }

    /// Special values sort below and above every finite value, with NaN last.
    #[test]
    fn test_special_value_ordering() {
        let values = vec![
            "-Infinity",
            "-1000000",
            "-1",
            "-0.001",
            "0",
            "0.001",
            "1",
            "1000000",
            "Infinity",
            "NaN",
        ];
        let encoded: Vec<Vec<u8>> = values.iter().map(|s| encode_decimal(s).unwrap()).collect();
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "Special value ordering failed: {} should be < {} (bytes: {:?} < {:?})",
                values[i],
                values[i + 1],
                encoded[i],
                encoded[i + 1]
            );
        }
    }

    /// Special values survive an encode/decode/encode cycle unchanged.
    #[test]
    fn test_special_value_roundtrip() {
        let values = vec!["Infinity", "-Infinity", "NaN"];
        for s in values {
            let encoded = encode_decimal(s).unwrap();
            let decoded = decode_to_string(&encoded).unwrap();
            let re_encoded = encode_decimal(&decoded).unwrap();
            assert_eq!(
                encoded, re_encoded,
                "Special value roundtrip failed for {}",
                s
            );
        }
    }

    /// `decode_special_value` matches only the exact special encodings.
    #[test]
    fn test_decode_special_value_helper() {
        assert_eq!(
            decode_special_value(&ENCODING_POS_INFINITY),
            Some(SpecialValue::Infinity)
        );
        assert_eq!(
            decode_special_value(&ENCODING_NEG_INFINITY),
            Some(SpecialValue::NegInfinity)
        );
        assert_eq!(decode_special_value(&ENCODING_NAN), Some(SpecialValue::NaN));
        let regular = encode_decimal("123.456").unwrap();
        assert_eq!(decode_special_value(&regular), None);
        let zero = encode_decimal("0").unwrap();
        assert_eq!(decode_special_value(&zero), None);
    }

    /// Negative scales round the integer part to tens, hundreds, thousands.
    #[test]
    fn test_negative_scale_basic() {
        assert_eq!(truncate_decimal("123", None, Some(-1)).unwrap(), "120");
        assert_eq!(truncate_decimal("125", None, Some(-1)).unwrap(), "130");
        assert_eq!(truncate_decimal("124", None, Some(-1)).unwrap(), "120");
        assert_eq!(truncate_decimal("1234", None, Some(-2)).unwrap(), "1200");
        assert_eq!(truncate_decimal("1250", None, Some(-2)).unwrap(), "1300");
        assert_eq!(truncate_decimal("1249", None, Some(-2)).unwrap(), "1200");
        assert_eq!(truncate_decimal("12345", None, Some(-3)).unwrap(), "12000");
        assert_eq!(truncate_decimal("12500", None, Some(-3)).unwrap(), "13000");
    }

    /// Values below one rounding unit round to zero or to exactly one unit.
    #[test]
    fn test_negative_scale_small_numbers() {
        assert_eq!(truncate_decimal("499", None, Some(-3)).unwrap(), "0");
        assert_eq!(truncate_decimal("500", None, Some(-3)).unwrap(), "1000");
        assert_eq!(truncate_decimal("999", None, Some(-3)).unwrap(), "1000");
        assert_eq!(truncate_decimal("49", None, Some(-2)).unwrap(), "0");
        assert_eq!(truncate_decimal("50", None, Some(-2)).unwrap(), "100");
    }

    /// Precision limits the significant digits kept after negative-scale rounding.
    #[test]
    fn test_negative_scale_with_precision() {
        assert_eq!(
            truncate_decimal("12345", Some(2), Some(-3)).unwrap(),
            "12000"
        );
        assert_eq!(
            truncate_decimal("99999", Some(2), Some(-3)).unwrap(),
            "00000"
        );
    }

    /// The sign is preserved when rounding negative values.
    #[test]
    fn test_negative_scale_negative_numbers() {
        assert_eq!(truncate_decimal("-123", None, Some(-1)).unwrap(), "-120");
        assert_eq!(truncate_decimal("-125", None, Some(-1)).unwrap(), "-130");
        assert_eq!(truncate_decimal("-1234", None, Some(-2)).unwrap(), "-1200");
    }

    /// Fractional digits are discarded when the scale is negative.
    #[test]
    fn test_negative_scale_with_decimal_input() {
        assert_eq!(truncate_decimal("123.456", None, Some(-1)).unwrap(), "120");
        assert_eq!(
            truncate_decimal("1234.999", None, Some(-2)).unwrap(),
            "1200"
        );
    }

    /// Ordering still holds after rounding with a negative scale.
    #[test]
    fn test_negative_scale_encoding_ordering() {
        let values = vec!["-1000", "-100", "0", "100", "1000"];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|s| encode_decimal_with_constraints(s, None, Some(-2)).unwrap())
            .collect();
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "Negative scale ordering failed: {} should be < {}",
                values[i],
                values[i + 1]
            );
        }
    }

    /// Precision and scale have no effect on special values.
    #[test]
    fn test_special_values_ignore_precision_scale() {
        let inf = encode_decimal_with_constraints("Infinity", Some(5), Some(2)).unwrap();
        assert_eq!(inf, ENCODING_POS_INFINITY.to_vec());
        let neg_inf = encode_decimal_with_constraints("-Infinity", Some(5), Some(2)).unwrap();
        assert_eq!(neg_inf, ENCODING_NEG_INFINITY.to_vec());
        let nan = encode_decimal_with_constraints("NaN", Some(5), Some(2)).unwrap();
        assert_eq!(nan, ENCODING_NAN.to_vec());
    }
}