use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use num::Zero;
use crate::Decimal;
use postgres::{to_sql_checked, types::*};
#[cfg(not(feature = "const_fn"))]
use lazy_static::lazy_static;
use std::{error, fmt, io::Cursor, result::*};
use crate::decimal::{div_by_u32, is_all_zero, mul_by_u32};
/// Powers of ten used to assemble a `Decimal` from base-10000 `NUMERIC` digit groups
/// (compile-time variant, used when the `const_fn` feature is enabled).
///
/// Entry `i` equals 10^(4 * (i - 7)): indices 0..=6 hold 1e-28 ..= 1e-4 (expressed through
/// the scale argument), index 7 holds 1, and indices 8..=14 hold 1e4 ..= 1e28.
/// `FromSql` weights the least significant digit group with entry 7 and each more
/// significant group with the next entry.
#[cfg(feature = "const_fn")]
const DECIMALS: [Decimal; 15] = [
Decimal::from_parts(1, 0, 0, false, 28),
Decimal::from_parts(1, 0, 0, false, 24),
Decimal::from_parts(1, 0, 0, false, 20),
Decimal::from_parts(1, 0, 0, false, 16),
Decimal::from_parts(1, 0, 0, false, 12),
Decimal::from_parts(1, 0, 0, false, 8),
Decimal::from_parts(1, 0, 0, false, 4),
Decimal::from_parts(1, 0, 0, false, 0),
Decimal::from_parts(1_0000, 0, 0, false, 0),
Decimal::from_parts(1_0000_0000, 0, 0, false, 0),
// 10^12 no longer fits in 32 bits, so it is split into its low and middle words.
Decimal::from_parts(
1_0000_0000_0000u64 as u32,
(1_0000_0000_0000u64 >> 32) as u32,
0,
false,
0,
),
// 10^16, split the same way.
Decimal::from_parts(
1_0000_0000_0000_0000u64 as u32,
(1_0000_0000_0000_0000u64 >> 32) as u32,
0,
false,
0,
),
// 10^20, 10^24 and 10^28 written as precomputed (lo, mid, hi) 32-bit words.
Decimal::from_parts(1661992960, 1808227885, 5, false, 0),
Decimal::from_parts(2701131776, 466537709, 54210, false, 0),
Decimal::from_parts(268435456, 1042612833, 542101086, false, 0),
];
// Fallback for builds without the `const_fn` feature: the same fifteen powers of ten
// (entry `i` equals 10^(4 * (i - 7)), i.e. 1e-28 ..= 1e28), built through `lazy_static!`.
// Must stay identical to the `const_fn` table; `ensure_equivalent_decimal_constants`
// checks whichever variant is compiled in.
#[cfg(not(feature = "const_fn"))]
lazy_static! {
static ref DECIMALS: [Decimal; 15] = [
Decimal::new(1, 28),
Decimal::new(1, 24),
Decimal::new(1, 20),
Decimal::new(1, 16),
Decimal::new(1, 12),
Decimal::new(1, 8),
Decimal::new(1, 4),
Decimal::new(1, 0),
Decimal::new(10000, 0),
Decimal::new(100000000, 0),
Decimal::new(1000000000000, 0),
Decimal::new(10000000000000000, 0),
// 10^20, 10^24 and 10^28 exceed i64, so they are given as (lo, mid, hi) 32-bit words.
Decimal::from_parts(1661992960, 1808227885, 5, false, 0),
Decimal::from_parts(2701131776, 466537709, 54210, false, 0),
Decimal::from_parts(268435456, 1042612833, 542101086, false, 0),
];
}
/// Error returned when a PostgreSQL `NUMERIC` value read from the database cannot be
/// converted into a `Decimal`.
#[derive(Debug, Clone, Copy)]
pub struct InvalidDecimal;

impl fmt::Display for InvalidDecimal {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        // Write the message directly instead of going through the deprecated
        // `Error::description`.
        fmt.write_str("Invalid Decimal")
    }
}

impl error::Error for InvalidDecimal {
    // Kept so existing callers of the (deprecated) `description` API still get the same
    // text; `Display` is the preferred way to obtain the message.
    fn description(&self) -> &str {
        "Invalid Decimal"
    }
}
impl FromSql for Decimal {
fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<error::Error + 'static + Sync + Send>> {
let mut raw = Cursor::new(raw);
let num_groups = raw.read_u16::<BigEndian>()?;
let weight = raw.read_i16::<BigEndian>()?; let sign = raw.read_u16::<BigEndian>()?;
let fixed_scale = raw.read_u16::<BigEndian>()? as i32;
let mut groups = Vec::new();
for _ in 0..num_groups as usize {
let group = raw.read_u16::<BigEndian>()?;
groups.push(Decimal::new(group as i64, 0));
}
groups.reverse();
let mut result = Decimal::zero();
for (index, group) in groups.iter().enumerate() {
result = result + (&DECIMALS[index + 7] * group);
}
let mut scale = (num_groups as i16 - weight - 1) as i32 * 4;
if scale < 0 {
result *= Decimal::new(10i64.pow((scale * -1) as u32), 0);
scale = 0;
} else if scale > fixed_scale {
result /= Decimal::new(10i64.pow((scale - fixed_scale) as u32), 0);
scale = fixed_scale;
}
let neg = sign == 0x4000;
if result.set_scale(scale as u32).is_err() {
return Err(Box::new(InvalidDecimal));
}
result.set_sign(!neg);
Ok(result.normalize())
}
fn accepts(ty: &Type) -> bool {
match *ty {
NUMERIC => true,
_ => false,
}
}
}
impl ToSql for Decimal {
fn to_sql(&self, _: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<error::Error + 'static + Sync + Send>> {
if self.is_zero() {
out.write_u64::<BigEndian>(0)?;
return Ok(IsNull::No);
}
let sign = if self.is_sign_negative() { 0x4000 } else { 0x0000 };
let scale = self.scale() as u16;
let groups_diff = scale & 0x3; let mut fractional_groups_count = (scale >> 2) as isize; fractional_groups_count += if groups_diff > 0 { 1 } else { 0 };
let mut mantissa = self.mantissa_array4();
if groups_diff > 0 {
let remainder = 4 - groups_diff;
let power = 10u32.pow(remainder as u32);
mul_by_u32(&mut mantissa, power);
}
const MAX_GROUP_COUNT: usize = 8;
let mut groups = [0u16; MAX_GROUP_COUNT];
let mut num_groups = 0usize;
while !is_all_zero(&mantissa) {
let group_digits = div_by_u32(&mut mantissa, 10000) as u16;
groups[num_groups] = group_digits;
num_groups += 1;
}
let whole_portion_len = num_groups as isize - fractional_groups_count;
let weight = if whole_portion_len < 0 {
-(fractional_groups_count as i16)
} else {
whole_portion_len as i16 - 1
};
out.write_u16::<BigEndian>(num_groups as u16)?;
out.write_i16::<BigEndian>(weight)?;
out.write_u16::<BigEndian>(sign)?;
out.write_u16::<BigEndian>(scale)?;
for group in groups[0..num_groups].iter().rev() {
out.write_u16::<BigEndian>(*group)?;
}
Ok(IsNull::No)
}
fn accepts(ty: &Type) -> bool {
match *ty {
NUMERIC => true,
_ => false,
}
}
to_sql_checked!();
}
#[cfg(test)]
mod test {
    use super::*;
    use postgres::{Connection, TlsMode};
    use std::str::FromStr;

    /// Round-trip fixtures: `(precision, scale, literal sent, expected Display output)`.
    pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[
        (35, 6, "3950.123456", "3950.123456"),
        (35, 2, "3950.123456", "3950.12"),
        (35, 2, "3950.1256", "3950.13"),
        (10, 2, "3950.123456", "3950.12"),
        (35, 6, "3950", "3950"),
        (4, 0, "3950", "3950"),
        (35, 6, "0.1", "0.1"),
        (35, 6, "0.01", "0.01"),
        (35, 6, "0.001", "0.001"),
        (35, 6, "0.0001", "0.0001"),
        (35, 6, "0.00001", "0.00001"),
        (35, 6, "0.000001", "0.000001"),
        (35, 6, "1", "1"),
        (35, 6, "-100", "-100"),
        (35, 6, "-123.456", "-123.456"),
        (35, 6, "119996.25", "119996.25"),
        (35, 6, "1000000", "1000000"),
        (35, 6, "9999999.99999", "9999999.99999"),
        (35, 6, "12340.56789", "12340.56789"),
        (35, 6, "79228162514264337593543950335", "79228162514264337593543950335"),
        (35, 6, "4951760157141521099596496895", "4951760157141521099596496895"),
        (35, 6, "4951760157141521099596496896", "4951760157141521099596496896"),
        (35, 6, "18446744073709551615", "18446744073709551615"),
        (35, 6, "-18446744073709551615", "-18446744073709551615"),
        (35, 6, "0.10001", "0.10001"),
        (35, 6, "0.12345", "0.12345"),
    ];

    /// Opens a connection to the local test server, panicking with the pretty-printed
    /// error if that fails.
    fn connect() -> Connection {
        Connection::connect("postgres://postgres@localhost", TlsMode::None)
            .unwrap_or_else(|err| panic!("{:#?}", err))
    }

    #[test]
    fn ensure_equivalent_decimal_constants() {
        let expected_decimals = [
            Decimal::new(1, 28),
            Decimal::new(1, 24),
            Decimal::new(1, 20),
            Decimal::new(1, 16),
            Decimal::new(1, 12),
            Decimal::new(1, 8),
            Decimal::new(1, 4),
            Decimal::new(1, 0),
            Decimal::new(10000, 0),
            Decimal::new(100000000, 0),
            Decimal::new(1000000000000, 0),
            Decimal::new(10000000000000000, 0),
            Decimal::from_parts(1661992960, 1808227885, 5, false, 0),
            Decimal::from_parts(2701131776, 466537709, 54210, false, 0),
            Decimal::from_parts(268435456, 1042612833, 542101086, false, 0),
        ];
        assert_eq!(&expected_decimals[..], &DECIMALS[..]);
    }

    #[test]
    fn test_null() {
        let conn = connect();
        let stmt = conn
            .prepare("SELECT NULL::numeric")
            .unwrap_or_else(|err| panic!("{:#?}", err));
        let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
        let result: Option<Decimal> = rows.iter().next().unwrap().get(0);
        assert_eq!(None, result);
    }

    #[test]
    fn read_numeric_type() {
        let conn = connect();
        for &(precision, scale, sent, expected) in TEST_DECIMALS {
            let sql = format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
            let result: Decimal = rows.iter().next().unwrap().get(0);
            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    #[test]
    fn write_numeric_type() {
        let conn = connect();
        for &(precision, scale, sent, expected) in TEST_DECIMALS {
            let sql = format!("SELECT $1::NUMERIC({}, {})", precision, scale);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            let number = Decimal::from_str(sent).unwrap();
            let rows = stmt
                .query(&[&number])
                .unwrap_or_else(|err| panic!("{:#?}", err));
            let result: Decimal = rows.iter().next().unwrap().get(0);
            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    #[test]
    fn numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let conn = connect();
        for &(precision, scale, sent) in tests.iter() {
            let sql = format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            // The cast must be rejected by the server with numeric_value_out_of_range.
            let err = stmt.query(&[]).err().unwrap_or_else(|| {
                panic!("Expected numeric overflow for {}::NUMERIC({}, {})", sent, precision, scale)
            });
            assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
        }
    }
}