use bytes::{BufMut, BytesMut};
use rust_decimal::Decimal;
use crate::error::{Error, Result};
use crate::types::{FromSql, Oid, ToSql};
const NUMERIC_POS: u16 = 0x0000;
const NUMERIC_NEG: u16 = 0x4000;
const NUMERIC_NAN: u16 = 0xC000;
impl ToSql for Decimal {
fn oid(&self) -> Oid {
Oid::NUMERIC
}
fn to_sql(&self, buf: &mut BytesMut) -> Result<()> {
if self.is_zero() {
buf.put_i16(0); buf.put_i16(0); buf.put_u16(NUMERIC_POS); buf.put_u16(0); return Ok(());
}
let sign = if self.is_sign_negative() {
NUMERIC_NEG
} else {
NUMERIC_POS
};
let scale = self.scale() as u16;
let abs = if self.is_sign_negative() {
-*self
} else {
*self
};
let mantissa = abs.mantissa() as u128;
let mantissa_str = mantissa.to_string();
let total_len = mantissa_str.len();
let point_pos = total_len.saturating_sub(scale as usize);
let int_part_len = point_pos;
let int_pad = (4 - (int_part_len % 4)) % 4;
let frac_part_len = total_len - point_pos + (scale as usize).saturating_sub(total_len);
let frac_pad = (4 - (frac_part_len % 4)) % 4;
let mut padded = String::new();
if total_len <= scale as usize {
let extra_zeros = (scale as usize).saturating_sub(total_len);
for _ in 0..extra_zeros {
padded.push('0');
}
padded.push_str(&mantissa_str);
let trail = (4 - (padded.len() % 4)) % 4;
for _ in 0..trail {
padded.push('0');
}
} else {
for _ in 0..int_pad {
padded.push('0');
}
padded.push_str(&mantissa_str);
for _ in 0..frac_pad {
padded.push('0');
}
}
while padded.len() % 4 != 0 {
padded.push('0');
}
let mut aligned_digits = Vec::new();
for chunk in padded.as_bytes().chunks(4) {
let s = std::str::from_utf8(chunk).unwrap_or("0000");
aligned_digits.push(s.parse::<u16>().unwrap_or(0));
}
let weight: i16 = if total_len <= scale as usize {
-1
} else {
let groups_before_point = (int_part_len + int_pad) / 4;
groups_before_point as i16 - 1
};
while aligned_digits.last() == Some(&0) && aligned_digits.len() > 1 {
aligned_digits.pop();
}
let mut final_weight = weight;
while aligned_digits.first() == Some(&0) && aligned_digits.len() > 1 {
aligned_digits.remove(0);
final_weight -= 1;
}
let ndigits = aligned_digits.len() as i16;
buf.put_i16(ndigits);
buf.put_i16(final_weight);
buf.put_u16(sign);
buf.put_u16(scale);
for d in &aligned_digits {
buf.put_u16(*d);
}
Ok(())
}
}
impl FromSql for Decimal {
fn oid() -> Oid {
Oid::NUMERIC
}
fn from_sql(buf: &[u8]) -> Result<Self> {
if buf.len() < 8 {
return Err(Error::Decode(format!(
"numeric: expected at least 8 bytes, got {}",
buf.len()
)));
}
let ndigits = i16::from_be_bytes([buf[0], buf[1]]);
let weight = i16::from_be_bytes([buf[2], buf[3]]);
let sign = u16::from_be_bytes([buf[4], buf[5]]);
let dscale = u16::from_be_bytes([buf[6], buf[7]]);
if sign == NUMERIC_NAN {
return Err(Error::Decode(
"numeric: NaN cannot be represented as Decimal".into(),
));
}
if ndigits == 0 {
return Ok(Decimal::ZERO);
}
let expected_len = 8 + (ndigits as usize * 2);
if buf.len() < expected_len {
return Err(Error::Decode(format!(
"numeric: expected {} bytes for {} digits, got {}",
expected_len,
ndigits,
buf.len()
)));
}
let mut digits = Vec::with_capacity(ndigits as usize);
for i in 0..ndigits as usize {
let d = u16::from_be_bytes([buf[8 + i * 2], buf[9 + i * 2]]);
digits.push(d);
}
let mut result = Decimal::ZERO;
let base = Decimal::new(10000, 0);
for (i, &d) in digits.iter().enumerate() {
let power = weight as i32 - i as i32;
let digit_val = Decimal::new(d as i64, 0);
if power >= 0 {
let mut multiplier = Decimal::ONE;
for _ in 0..power {
multiplier *= base;
}
result += digit_val * multiplier;
} else {
let mut divisor = Decimal::ONE;
for _ in 0..(-power) {
divisor *= base;
}
result += digit_val / divisor;
}
}
result.rescale(dscale as u32);
if sign == NUMERIC_NEG {
result.set_sign_negative(true);
}
Ok(result)
}
}