use rust_decimal::Decimal;
use crate::error::{Error, Result};
use crate::protocol::types::{Oid, oid};
use super::{FromWireValue, ToWireValue};
/// Value of the binary header's sign word for which the decoder negates the result.
const NUMERIC_NEG: u16 = 0x4000;
/// Value of the binary header's sign word that denotes NaN; the decoder rejects it
/// because `Decimal` has no NaN representation.
const NUMERIC_NAN: u16 = 0xC000;
/// Radix of the binary digit groups: each 16-bit wire digit is folded into the
/// accumulated value as one base-10000 digit (four decimal places).
const NBASE: i128 = 10000;
impl FromWireValue<'_> for Decimal {
fn from_text(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::NUMERIC {
return Err(Error::Decode(format!(
"cannot decode oid {} as Decimal",
oid
)));
}
let s = simdutf8::compat::from_utf8(bytes)
.map_err(|e| Error::Decode(format!("invalid UTF-8: {}", e)))?;
if s == "NaN" {
return Err(Error::Decode("NaN cannot be represented as Decimal".into()));
}
Decimal::from_str_exact(s).map_err(|e| Error::Decode(format!("invalid decimal: {}", e)))
}
fn from_binary(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::NUMERIC {
return Err(Error::Decode(format!(
"cannot decode oid {} as Decimal",
oid
)));
}
let (header, digit_bytes) = bytes
.split_first_chunk::<8>()
.ok_or_else(|| Error::Decode(format!("invalid NUMERIC length: {}", bytes.len())))?;
let ndigits = i16::from_be_bytes([header[0], header[1]]) as usize;
let weight = i16::from_be_bytes([header[2], header[3]]);
let sign = u16::from_be_bytes([header[4], header[5]]);
let dscale = u16::from_be_bytes([header[6], header[7]]);
if sign == NUMERIC_NAN {
return Err(Error::Decode("NaN cannot be represented as Decimal".into()));
}
if ndigits == 0 {
return Ok(Decimal::ZERO);
}
if digit_bytes.len() < ndigits * 2 {
return Err(Error::Decode(format!(
"invalid NUMERIC length: {} (expected {})",
bytes.len(),
8 + ndigits * 2
)));
}
let mut digits = Vec::with_capacity(ndigits);
let mut remaining = digit_bytes;
for _ in 0..ndigits {
let (pair, rest) = remaining
.split_first_chunk::<2>()
.ok_or_else(|| Error::Decode("truncated NUMERIC digit".into()))?;
remaining = rest;
digits.push(u16::from_be_bytes(*pair));
}
let mut value: i128 = 0;
for &digit in &digits {
value = value * NBASE + (digit as i128);
}
let exponent = (weight as i32 - ndigits as i32 + 1) * 4;
if sign == NUMERIC_NEG {
value = -value;
}
let mut decimal = Decimal::from_i128_with_scale(value, 0);
if exponent > 0 {
for _ in 0..exponent {
decimal = decimal
.checked_mul(Decimal::TEN)
.ok_or_else(|| Error::Decode("decimal overflow".into()))?;
}
} else if exponent < 0 {
decimal
.set_scale((-exponent) as u32)
.map_err(|e| Error::Decode(format!("decimal scale error: {}", e)))?;
}
if dscale > 0 {
decimal = decimal.round_dp(dscale as u32);
}
Ok(decimal)
}
}
impl ToWireValue for Decimal {
fn natural_oid(&self) -> Oid {
oid::NUMERIC
}
fn encode(&self, target_oid: Oid, buf: &mut Vec<u8>) -> Result<()> {
match target_oid {
oid::NUMERIC => {
let text = self.to_string();
buf.extend_from_slice(&(text.len() as i32).to_be_bytes());
buf.extend_from_slice(text.as_bytes());
Ok(())
}
_ => Err(Error::type_mismatch(self.natural_oid(), target_oid)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Builds a binary NUMERIC payload: the 8-byte header (`ndigits`, `weight`,
    /// `sign`, `dscale`) followed by the given base-10000 digits, all big-endian.
    fn numeric_binary(weight: i16, sign: u16, dscale: u16, digits: &[u16]) -> Vec<u8> {
        let mut out = Vec::with_capacity(8 + digits.len() * 2);
        out.extend_from_slice(&(digits.len() as i16).to_be_bytes());
        out.extend_from_slice(&weight.to_be_bytes());
        out.extend_from_slice(&sign.to_be_bytes());
        out.extend_from_slice(&dscale.to_be_bytes());
        for digit in digits {
            out.extend_from_slice(&digit.to_be_bytes());
        }
        out
    }

    #[test]
    fn decimal_text_decode() {
        let dec = Decimal::from_text(oid::NUMERIC, b"123.45").unwrap();
        assert_eq!(dec, Decimal::from_str("123.45").unwrap());
    }

    #[test]
    fn decimal_text_negative() {
        let dec = Decimal::from_text(oid::NUMERIC, b"-999.999").unwrap();
        assert_eq!(dec, Decimal::from_str("-999.999").unwrap());
    }

    #[test]
    fn decimal_zero() {
        let dec = Decimal::from_text(oid::NUMERIC, b"0").unwrap();
        assert_eq!(dec, Decimal::ZERO);
    }

    #[test]
    fn decimal_encode_text_format() {
        let original = Decimal::from_str("12345.6789").unwrap();
        let mut buf = Vec::new();
        original.encode(original.natural_oid(), &mut buf).unwrap();
        // The first four bytes are the length prefix; the rest is the text.
        let text = std::str::from_utf8(&buf[4..]).unwrap();
        assert_eq!(text, "12345.6789");
    }

    #[test]
    fn decimal_encode_zero() {
        let original = Decimal::ZERO;
        let mut buf = Vec::new();
        original.encode(original.natural_oid(), &mut buf).unwrap();
        let text = std::str::from_utf8(&buf[4..]).unwrap();
        assert_eq!(text, "0");
    }

    #[test]
    fn decimal_encode_negative() {
        let original = Decimal::from_str("-123.456").unwrap();
        let mut buf = Vec::new();
        original.encode(original.natural_oid(), &mut buf).unwrap();
        let text = std::str::from_utf8(&buf[4..]).unwrap();
        assert_eq!(text, "-123.456");
    }

    #[test]
    fn decimal_nan_text() {
        let result = Decimal::from_text(oid::NUMERIC, b"NaN");
        result.unwrap_err();
    }

    #[test]
    fn decimal_binary_decode() {
        // 123.45 is the digit groups 0123 | 4500 with the first group at weight 0.
        let payload = numeric_binary(0, 0x0000, 2, &[123, 4500]);
        let dec = Decimal::from_binary(oid::NUMERIC, &payload).unwrap();
        assert_eq!(dec, Decimal::from_str("123.45").unwrap());
    }

    #[test]
    fn decimal_binary_negative() {
        let payload = numeric_binary(0, NUMERIC_NEG, 3, &[999, 9990]);
        let dec = Decimal::from_binary(oid::NUMERIC, &payload).unwrap();
        assert_eq!(dec, Decimal::from_str("-999.999").unwrap());
    }

    #[test]
    fn decimal_binary_small_fraction() {
        // A single group at weight -1 occupies the first four fractional places.
        let payload = numeric_binary(-1, 0x0000, 4, &[1]);
        let dec = Decimal::from_binary(oid::NUMERIC, &payload).unwrap();
        assert_eq!(dec, Decimal::from_str("0.0001").unwrap());
    }

    #[test]
    fn decimal_binary_integer_with_positive_exponent() {
        // A single group at weight 1 is scaled up by 10^4.
        let payload = numeric_binary(1, 0x0000, 0, &[1]);
        let dec = Decimal::from_binary(oid::NUMERIC, &payload).unwrap();
        assert_eq!(dec, Decimal::from_str("10000").unwrap());
    }

    #[test]
    fn decimal_binary_zero() {
        let payload = numeric_binary(0, 0x0000, 0, &[]);
        let dec = Decimal::from_binary(oid::NUMERIC, &payload).unwrap();
        assert_eq!(dec, Decimal::ZERO);
    }

    #[test]
    fn decimal_binary_nan() {
        let payload = numeric_binary(0, NUMERIC_NAN, 0, &[]);
        Decimal::from_binary(oid::NUMERIC, &payload).unwrap_err();
    }

    #[test]
    fn decimal_binary_truncated_header() {
        Decimal::from_binary(oid::NUMERIC, &[0, 1, 0, 0]).unwrap_err();
    }

    #[test]
    fn decimal_binary_truncated_digits() {
        // Drop the last digit so the header declares more digits than are present.
        let mut payload = numeric_binary(0, 0x0000, 2, &[123, 4500]);
        payload.truncate(payload.len() - 2);
        Decimal::from_binary(oid::NUMERIC, &payload).unwrap_err();
    }
}