use bigdecimal::{
BigDecimal, Zero,
num_bigint::{BigInt, Sign},
};
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
use std::{cmp, convert::TryInto, error, fmt, io::Cursor};
/// Newtype around [`BigDecimal`] that implements [`FromSql`] and [`ToSql`] for the
/// PostgreSQL `NUMERIC` type.
#[derive(Debug, Clone)]
pub struct DecimalWrapper(pub BigDecimal);
/// Error produced when a PostgreSQL `NUMERIC` value cannot be decoded into a `BigDecimal`.
///
/// The wrapped static string describes what was wrong with the value.
#[derive(Debug, Clone)]
pub struct InvalidDecimal(&'static str);

impl fmt::Display for InvalidDecimal {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidDecimal(reason) = self;
        write!(f, "Invalid Decimal: {}", reason)
    }
}

impl error::Error for InvalidDecimal {}
/// Decomposed form of a PostgreSQL `NUMERIC` value, mirroring the fields of its binary
/// wire format.
///
/// The represented magnitude is `sum(digits[i] * 10000^(weight - i))`, negated when `neg`
/// is set.
struct PostgresDecimal<D> {
    /// `true` for negative values (wire sign word `0x4000`).
    neg: bool,
    /// Base-10000 exponent of the first (most significant) digit group.
    weight: i16,
    /// Display scale. `to_postgres` sets it to the number of fractional decimal digits;
    /// `from_postgres` currently ignores it.
    scale: u16,
    /// Base-10000 digit groups, most significant first. When decoding this is an
    /// iterator of `u16`; when encoding it is a `Vec<i16>`.
    digits: D,
}
/// Converts a decoded PostgreSQL `NUMERIC` (base-10000 digit groups) into a [`BigDecimal`].
///
/// The value is `sum(digits[i] * 10000^(weight - i))`, negated when `neg` is set. An empty
/// digit sequence yields zero. The display scale in `dec.scale` is not used; the resulting
/// `BigDecimal` has scale `4 * (digits.len() - weight - 1)`.
///
/// # Errors
///
/// Returns [`InvalidDecimal`] if any digit group lies outside `0..=9999`.
fn from_postgres<D: ExactSizeIterator<Item = u16>>(dec: PostgresDecimal<D>) -> Result<BigDecimal, InvalidDecimal> {
    let PostgresDecimal {
        neg, digits, weight, ..
    } = dec;
    // `ExactSizeIterator::is_empty` is unstable, so compare the length explicitly.
    if digits.len() == 0 {
        return Ok(0u64.into());
    }
    let sign = if neg { Sign::Minus } else { Sign::Plus };
    // The last group has base-10000 exponent `weight - (len - 1)`; each group spans four
    // decimal places, so the base-10 scale is four times the negation of that exponent.
    let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
    // Split every base-10000 group into two base-100 "cents" so the whole integer can be
    // rebuilt in one `from_radix_be` call.
    let mut cents = Vec::with_capacity(digits.len() * 2);
    for digit in digits {
        // Validate up front: for digits >= 25600, `digit / 100` would be truncated by the
        // `as u8` cast and could slip past `from_radix_be`'s own range check.
        if digit >= 10000 {
            return Err(InvalidDecimal("PostgresDecimal contained an out-of-range digit"));
        }
        cents.push((digit / 100) as u8);
        cents.push((digit % 100) as u8);
    }
    let bigint = BigInt::from_radix_be(sign, &cents, 100)
        .ok_or(InvalidDecimal("PostgresDecimal contained an out-of-range digit"))?;
    Ok(BigDecimal::new(bigint, scale))
}
/// Converts a [`BigDecimal`] into PostgreSQL's base-10000 `NUMERIC` representation.
///
/// Zero is encoded with no digit groups, weight 0 and scale 0. Otherwise the decimal digits
/// of the unscaled integer are regrouped into base-10000 groups aligned on the decimal
/// point, and trailing all-zero groups are dropped. The display scale is the number of
/// fractional decimal digits of `decimal` (0 when its exponent is negative).
///
/// # Errors
///
/// Fails if the scale does not fit in `u16` or the weight does not fit in `i16`.
fn to_postgres(decimal: &BigDecimal) -> crate::Result<PostgresDecimal<Vec<i16>>> {
    if decimal.is_zero() {
        return Ok(PostgresDecimal {
            neg: false,
            weight: 0,
            scale: 0,
            digits: vec![],
        });
    }

    // Packs a run of at most four decimal digits (most significant first) into one value.
    let pack = |run: &[u8]| -> i16 {
        let mut acc = 0i16;
        for &d in run {
            acc = acc * 10 + i16::from(d);
        }
        acc
    };

    let (unscaled, exponent) = decimal.as_bigint_and_exponent();
    let (sign, decimal_digits) = unscaled.to_radix_be(10);
    let n = decimal_digits.len();

    // Number of decimal digits in front of the decimal point (may be zero or negative).
    let integer_digits = n as i64 - exponent;
    let scale: u16 = cmp::max(0, exponent).try_into()?;
    // Base-10000 exponent of the leading group: floor((integer_digits - 1) / 4).
    let weight: i16 = (integer_digits - 1).div_euclid(4).try_into()?;

    // How many decimal digits fall into the leading (possibly partial) group.
    let lead = integer_digits.rem_euclid(4) as usize;
    let mut groups: Vec<i16> = Vec::with_capacity((n + 3) / 4);

    match decimal_digits.get(..lead) {
        Some(head) if !head.is_empty() => groups.push(pack(head)),
        Some(_) => {}
        // Fewer digits than the leading group holds: right-pad the group with zeros.
        None => groups.push(pack(&decimal_digits) * 10i16.pow((lead - n) as u32)),
    }

    if let Some(tail) = decimal_digits.get(lead..) {
        for chunk in tail.chunks(4) {
            // A short final chunk is left-aligned within its group.
            let pad = 10i16.pow(4 - chunk.len() as u32);
            groups.push(pack(chunk) * pad);
        }
    }

    while groups.last() == Some(&0) {
        groups.pop();
    }

    Ok(PostgresDecimal {
        neg: matches!(sign, Sign::Minus),
        weight,
        scale,
        digits: groups,
    })
}
impl FromSql<'_> for DecimalWrapper {
fn from_sql(_: &Type, raw: &[u8]) -> Result<DecimalWrapper, Box<dyn error::Error + 'static + Sync + Send>> {
let mut raw = Cursor::new(raw);
let num_groups = raw.read_u16::<BigEndian>()?;
let weight = raw.read_i16::<BigEndian>()?; let sign = raw.read_u16::<BigEndian>()?;
let scale = raw.read_u16::<BigEndian>()?;
let mut groups = Vec::new();
for _ in 0..num_groups as usize {
groups.push(raw.read_u16::<BigEndian>()?);
}
let dec = from_postgres(PostgresDecimal {
neg: sign == 0x4000,
weight,
scale,
digits: groups.into_iter(),
})
.map_err(Box::new)?;
Ok(DecimalWrapper(dec))
}
fn accepts(ty: &Type) -> bool {
matches!(*ty, Type::NUMERIC)
}
}
impl ToSql for DecimalWrapper {
    /// Encodes the wrapped value in PostgreSQL's binary `NUMERIC` format.
    ///
    /// Writes four big-endian 16-bit header words (group count, weight, sign, display
    /// scale), followed by one 16-bit word per base-10000 digit group.
    fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn error::Error + 'static + Sync + Send>> {
        const NUMERIC_POS: u16 = 0x0000;
        const NUMERIC_NEG: u16 = 0x4000;

        let encoded = to_postgres(&self.0)?;
        let group_count: u16 = encoded.digits.len().try_into()?;

        out.reserve(8 + 2 * encoded.digits.len());
        out.put_u16(group_count);
        out.put_i16(encoded.weight);
        out.put_u16(if encoded.neg { NUMERIC_NEG } else { NUMERIC_POS });
        out.put_u16(encoded.scale);
        for &group in &encoded.digits {
            out.put_i16(group);
        }
        Ok(IsNull::No)
    }

    fn accepts(ty: &Type) -> bool {
        *ty == Type::NUMERIC
    }

    to_sql_checked!();
}