use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use num::Zero;
use crate::Decimal;
use postgres::{to_sql_checked, types::*};
use lazy_static::lazy_static;
use std::{error, fmt, io::Cursor, result::*};
use crate::decimal::{div_by_u32, is_all_zero, mul_by_u32};
// Lookup table of the powers of 10000 from 10000^-7 (1e-28) up to 10000^7
// (1e28); index 7 holds 10000^0 = 1. This variant is built with
// `Decimal::from_parts` so it can live in a plain `const` when the
// `const_fn` feature is enabled.
#[cfg(feature = "const_fn")]
const DECIMALS: [Decimal; 15] = [
    // Fractional powers: a mantissa of 1 with a scale of 28, 24, ..., 4.
    Decimal::from_parts(1, 0, 0, false, 28),
    Decimal::from_parts(1, 0, 0, false, 24),
    Decimal::from_parts(1, 0, 0, false, 20),
    Decimal::from_parts(1, 0, 0, false, 16),
    Decimal::from_parts(1, 0, 0, false, 12),
    Decimal::from_parts(1, 0, 0, false, 8),
    Decimal::from_parts(1, 0, 0, false, 4),
    // 10000^0
    Decimal::from_parts(1, 0, 0, false, 0),
    // 10000^1 and 10000^2 still fit in the lowest 32-bit word.
    Decimal::from_parts(10_000, 0, 0, false, 0),
    Decimal::from_parts(100_000_000, 0, 0, false, 0),
    // 10000^3 (1e12) and 10000^4 (1e16) need two words: the low 32 bits go
    // first, the next 32 bits second.
    Decimal::from_parts(
        1_000_000_000_000u64 as u32,
        (1_000_000_000_000u64 >> 32) as u32,
        0,
        false,
        0,
    ),
    Decimal::from_parts(
        10_000_000_000_000_000u64 as u32,
        (10_000_000_000_000_000u64 >> 32) as u32,
        0,
        false,
        0,
    ),
    // 10000^5 (1e20), 10000^6 (1e24) and 10000^7 (1e28) exceed 64 bits, so
    // their three 32-bit words are spelled out directly.
    Decimal::from_parts(1_661_992_960, 1_808_227_885, 5, false, 0),
    Decimal::from_parts(2_701_131_776, 466_537_709, 54_210, false, 0),
    Decimal::from_parts(268_435_456, 1_042_612_833, 542_101_086, false, 0),
];
// Lookup table of the powers of 10000 from 10000^-7 (1e-28) up to 10000^7
// (1e28); index 7 holds 10000^0 = 1. Without the `const_fn` feature the
// table is initialised lazily on first use.
#[cfg(not(feature = "const_fn"))]
lazy_static! {
    static ref DECIMALS: [Decimal; 15] = [
        // Fractional powers: 1e-28, 1e-24, ..., 1e-4.
        Decimal::new(1, 28),
        Decimal::new(1, 24),
        Decimal::new(1, 20),
        Decimal::new(1, 16),
        Decimal::new(1, 12),
        Decimal::new(1, 8),
        Decimal::new(1, 4),
        // 10000^0 through 10000^4 fit in the i64 accepted by `Decimal::new`.
        Decimal::new(1, 0),
        Decimal::new(10_000, 0),
        Decimal::new(100_000_000, 0),
        Decimal::new(1_000_000_000_000, 0),
        Decimal::new(10_000_000_000_000_000, 0),
        // 10000^5 (1e20), 10000^6 (1e24) and 10000^7 (1e28) do not fit in an
        // i64, so they are assembled from their three 32-bit words.
        Decimal::from_parts(1_661_992_960, 1_808_227_885, 5, false, 0),
        Decimal::from_parts(2_701_131_776, 466_537_709, 54_210, false, 0),
        Decimal::from_parts(268_435_456, 1_042_612_833, 542_101_086, false, 0),
    ];
}
/// Error reported when a PostgreSQL `NUMERIC` value cannot be converted into
/// a `Decimal`.
#[derive(Debug, Clone, Copy)]
pub struct InvalidDecimal;

impl InvalidDecimal {
    /// Message shared by the `Display` and `Error::description` outputs.
    const MESSAGE: &'static str = "Invalid Decimal";
}

impl fmt::Display for InvalidDecimal {
    /// Writes the fixed error message; formatter options are not consulted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(Self::MESSAGE)
    }
}

impl error::Error for InvalidDecimal {
    /// Returns the same fixed message that `Display` prints.
    fn description(&self) -> &str {
        Self::MESSAGE
    }
}
impl FromSql for Decimal {
    /// Decodes the binary representation of a PostgreSQL `NUMERIC` into a
    /// `Decimal`.
    ///
    /// The buffer is read as four big-endian 16-bit header fields followed by
    /// the digit groups:
    ///
    /// * `u16` number of digit groups (each group is one base-10000 digit),
    /// * `i16` weight of the first group, as a power of 10000,
    /// * `u16` sign (`0x4000` marks a negative value),
    /// * `u16` display scale (digits after the decimal point),
    /// * the groups themselves, most significant first.
    ///
    /// # Errors
    ///
    /// * An I/O error if `raw` is shorter than the header announces.
    /// * `InvalidDecimal` if the value has more groups than the power table
    ///   covers, if it would need a power of ten that cannot be represented,
    ///   or if the resulting scale is rejected by `Decimal::set_scale`.
    ///
    /// NOTE(review): the accumulation below uses the plain `+` / `*`
    /// operators on `Decimal`; if those panic on overflow, a very large
    /// eight-group value can still panic here — confirm and consider checked
    /// arithmetic.
    /// NOTE(review): a sign field of `0xC000` (NaN in PostgreSQL's format) is
    /// not distinguished and is treated like a non-negative value — confirm
    /// whether it should be rejected.
    fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<error::Error + 'static + Sync + Send>> {
        // `DECIMALS[7..]` holds 10000^0 ..= 10000^7, so at most eight groups
        // can be accumulated and at most seven groups of trailing zeros can
        // be re-applied.
        const UNIT_INDEX: usize = 7;
        const MAX_GROUP_COUNT: usize = 8;
        const MAX_GROUP_SHIFT: i32 = 7;
        // 10^18 is the largest power of ten that fits in an i64.
        const MAX_I64_POW10: i32 = 18;

        let mut raw = Cursor::new(raw);
        let num_groups = raw.read_u16::<BigEndian>()?;
        let weight = raw.read_i16::<BigEndian>()?;
        let sign = raw.read_u16::<BigEndian>()?;
        let fixed_scale = raw.read_u16::<BigEndian>()? as i32;

        // Reject oversized inputs up front instead of indexing past the end
        // of the power table.
        if num_groups as usize > MAX_GROUP_COUNT {
            return Err(Box::new(InvalidDecimal));
        }

        let mut groups = Vec::with_capacity(num_groups as usize);
        for _ in 0..num_groups {
            let group = raw.read_u16::<BigEndian>()?;
            groups.push(Decimal::new(group as i64, 0));
        }
        // Least significant group first, so position `i` pairs with 10000^i.
        groups.reverse();

        // Build the value as one big integer; the decimal point is applied
        // afterwards through the scale.
        let mut result = Decimal::zero();
        for (index, group) in groups.iter().enumerate() {
            result = result + (&DECIMALS[index + UNIT_INDEX] * group);
        }

        // Number of groups that sit after the decimal point. Computed in i32
        // because `num_groups - weight` does not always fit in an i16.
        let fractional_groups = num_groups as i32 - weight as i32 - 1;
        let mut scale = fractional_groups * 4;
        if fractional_groups < 0 {
            // Trailing all-zero groups were omitted by the sender: restore
            // them by multiplying with the matching power of 10000. Using
            // the table (rather than an i64 power of ten) keeps values up to
            // 10^28 decodable.
            let missing_groups = -fractional_groups;
            if missing_groups > MAX_GROUP_SHIFT {
                return Err(Box::new(InvalidDecimal));
            }
            result = &DECIMALS[UNIT_INDEX + missing_groups as usize] * &result;
            scale = 0;
        } else if scale > fixed_scale {
            // The last group is padded to four digits; drop the digits that
            // go beyond the declared display scale.
            let surplus_digits = scale - fixed_scale;
            if surplus_digits > MAX_I64_POW10 {
                return Err(Box::new(InvalidDecimal));
            }
            result /= Decimal::new(10i64.pow(surplus_digits as u32), 0);
            scale = fixed_scale;
        }

        let neg = sign == 0x4000;
        if result.set_scale(scale as u32).is_err() {
            return Err(Box::new(InvalidDecimal));
        }
        // `set_sign` receives the negated "is negative" flag, i.e. presumably
        // "is positive" — confirm against `Decimal::set_sign`.
        result.set_sign(!neg);

        // Strip trailing zeros from the decimal representation.
        Ok(result.normalize())
    }

    /// Only the `NUMERIC` type is accepted.
    fn accepts(ty: &Type) -> bool {
        match *ty {
            NUMERIC => true,
            _ => false,
        }
    }
}
impl ToSql for Decimal {
    /// Encodes this `Decimal` in the binary representation of a PostgreSQL
    /// `NUMERIC` and appends it to `out`.
    ///
    /// The output is four big-endian 16-bit header fields (group count,
    /// weight of the first group as a power of 10000, sign, display scale)
    /// followed by the base-10000 digit groups, most significant first.
    /// Zero is written as an all-zero header with no groups.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while writing to `out`.
    fn to_sql(&self, _: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<error::Error + 'static + Sync + Send>> {
        if self.is_zero() {
            // Four zeroed u16 header fields: no groups, weight 0, positive, scale 0.
            out.write_u64::<BigEndian>(0)?;
            return Ok(IsNull::No);
        }

        let sign = if self.is_sign_negative() { 0x4000 } else { 0x0000 };
        let scale = self.scale() as u16;

        // Digits of the scale that do not fill a whole four-digit group (scale % 4).
        let groups_diff = scale & 0x3;
        // Groups after the decimal point: ceil(scale / 4).
        let fractional_groups_count = (scale >> 2) as isize + if groups_diff > 0 { 1 } else { 0 };

        // Presumably the mantissa widened to four u32 words so the scaling
        // below cannot overflow — confirm against `mantissa_array4`.
        let mut mantissa = self.mantissa_array4();
        if groups_diff > 0 {
            // Pad the fraction with zeros so it ends on a group boundary.
            let remainder = 4 - groups_diff;
            let power = 10u32.pow(remainder as u32);
            mul_by_u32(&mut mantissa, power);
        }

        // Peel off base-10000 digits, least significant first.
        const MAX_GROUP_COUNT: usize = 8;
        let mut groups = [0u16; MAX_GROUP_COUNT];
        let mut num_groups = 0usize;
        while !is_all_zero(&mantissa) {
            let group_digits = div_by_u32(&mut mantissa, 10000) as u16;
            groups[num_groups] = group_digits;
            num_groups += 1;
        }

        // The last emitted group has weight `-fractional_groups_count`, and
        // each earlier group is one power of 10000 higher, so the first group
        // has weight `num_groups - fractional_groups_count - 1`. This holds
        // with or without a whole-number part: 0.12345678 yields the groups
        // [1234, 5678] with two fractional groups and therefore weight -1.
        let weight = (num_groups as isize - fractional_groups_count - 1) as i16;

        out.write_u16::<BigEndian>(num_groups as u16)?;
        out.write_i16::<BigEndian>(weight)?;
        out.write_u16::<BigEndian>(sign)?;
        out.write_u16::<BigEndian>(scale)?;
        // Groups were collected least significant first; emit them reversed.
        for group in groups[0..num_groups].iter().rev() {
            out.write_u16::<BigEndian>(*group)?;
        }
        Ok(IsNull::No)
    }

    /// Only the `NUMERIC` type is accepted.
    fn accepts(ty: &Type) -> bool {
        match *ty {
            NUMERIC => true,
            _ => false,
        }
    }

    to_sql_checked!();
}
#[cfg(test)]
mod test {
    use super::*;
    use postgres::{Connection, TlsMode};
    use std::str::FromStr;

    /// Decimal literals that the read and write tests send through the
    /// database and expect back unchanged.
    pub static TEST_DECIMALS: &[&str; 20] = &[
        "3950.123456",
        "3950",
        "0.1",
        "0.01",
        "0.001",
        "0.0001",
        "0.00001",
        "0.000001",
        "1",
        "-100",
        "-123.456",
        "119996.25",
        "1000000",
        "9999999.99999",
        "12340.56789",
        "79228162514264337593543950335",
        "4951760157141521099596496895",
        "4951760157141521099596496896",
        "18446744073709551615",
        "-18446744073709551615",
    ];

    /// Connects to the local test database, panicking with the
    /// pretty-printed error when the connection cannot be established.
    fn connect() -> Connection {
        Connection::connect("postgres://postgres@localhost", TlsMode::None)
            .unwrap_or_else(|err| panic!("{:#?}", err))
    }

    /// Selects each literal cast to `sql_type` and checks that the decoded
    /// `Decimal` prints back as the same literal.
    fn read_type(sql_type: &str, checks: &[&'static str]) {
        let conn = connect();
        for &literal in checks {
            let sql = format!("SELECT {}::{}", literal, sql_type);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
            let decoded: Decimal = rows.iter().next().unwrap().get(0);
            assert_eq!(literal, decoded.to_string());
        }
    }

    /// Sends each literal as a bound `Decimal` parameter cast to `sql_type`
    /// and checks that the value read back prints as the same literal.
    fn write_type(sql_type: &str, checks: &[&'static str]) {
        let conn = connect();
        for &literal in checks {
            let sql = format!("SELECT $1::{}", sql_type);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            let number = Decimal::from_str(literal).unwrap();
            let rows = stmt.query(&[&number]).unwrap_or_else(|err| panic!("{:#?}", err));
            let decoded: Decimal = rows.iter().next().unwrap().get(0);
            assert_eq!(literal, decoded.to_string());
        }
    }

    /// The power table must hold the same values whichever way it was built.
    #[test]
    fn ensure_equivalent_decimal_constants() {
        let expected_decimals = [
            Decimal::new(1, 28),
            Decimal::new(1, 24),
            Decimal::new(1, 20),
            Decimal::new(1, 16),
            Decimal::new(1, 12),
            Decimal::new(1, 8),
            Decimal::new(1, 4),
            Decimal::new(1, 0),
            Decimal::new(10_000, 0),
            Decimal::new(100_000_000, 0),
            Decimal::new(1_000_000_000_000, 0),
            Decimal::new(10_000_000_000_000_000, 0),
            Decimal::from_parts(1_661_992_960, 1_808_227_885, 5, false, 0),
            Decimal::from_parts(2_701_131_776, 466_537_709, 54_210, false, 0),
            Decimal::from_parts(268_435_456, 1_042_612_833, 542_101_086, false, 0),
        ];
        assert_eq!(&expected_decimals[..], &DECIMALS[..]);
    }

    /// A SQL NULL must decode to `None`.
    #[test]
    fn test_null() {
        let conn = connect();
        let stmt = conn
            .prepare("SELECT NULL::numeric")
            .unwrap_or_else(|err| panic!("{:#?}", err));
        let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
        let result: Option<Decimal> = rows.iter().next().unwrap().get(0);
        assert_eq!(None, result);
    }

    #[test]
    fn read_numeric_type() {
        read_type("NUMERIC(35, 6)", TEST_DECIMALS);
    }

    #[test]
    fn write_numeric_type() {
        write_type("NUMERIC(35, 6)", TEST_DECIMALS);
    }
}