extern crate byteorder;
extern crate num;
use self::byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use self::num::{Zero, ToPrimitive};
use super::Decimal;
use pg_crate::types::*;
use std::error;
use std::fmt;
use std::io::Cursor;
use std::result::*;
// Lookup table of powers of ten in steps of 10^4 (PostgreSQL's base-10000
// digit groups): DECIMALS[i] == 10^(4 * (i - 7)), so index 0 is 10^-28,
// index 7 is 1 and index 14 is 10^28 (assuming `Decimal::new(m, s)` means
// m * 10^-s). `from_sql` multiplies the i-th digit group, counted from the
// least significant one, by DECIMALS[i + 7].
lazy_static! {
static ref DECIMALS: [Decimal; 15] = [
Decimal::new(1, 28),
Decimal::new(1, 24),
Decimal::new(1, 20),
Decimal::new(1, 16),
Decimal::new(1, 12),
Decimal::new(1, 8),
Decimal::new(1, 4),
Decimal::new(1, 0),
Decimal::new(10000, 0),
Decimal::new(100000000, 0),
Decimal::new(1000000000000, 0),
Decimal::new(10000000000000000, 0),
// 10^20, 10^24 and 10^28 do not fit in an i64, so they are spelled as
// 96-bit integers; read as (low, middle, high) 32-bit words, the three
// leading arguments below encode exactly those values.
Decimal::from_parts(1661992960, 1808227885, 5, false, 0),
Decimal::from_parts(2701131776, 466537709, 54210, false, 0),
Decimal::from_parts(268435456, 1042612833, 542101086, false, 0),
];
}
/// Error reported by the `NUMERIC` decoder when a value cannot be converted
/// into a `Decimal`.
#[derive(Clone, Copy, Debug)]
pub struct InvalidDecimal;

/// Human-readable message shared by the `Display` and `Error` impls.
const INVALID_DECIMAL_MESSAGE: &'static str = "Invalid Decimal";

impl fmt::Display for InvalidDecimal {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(INVALID_DECIMAL_MESSAGE)
    }
}

impl error::Error for InvalidDecimal {
    fn description(&self) -> &str {
        INVALID_DECIMAL_MESSAGE
    }
}
impl FromSql for Decimal {
    /// Decodes PostgreSQL's binary `NUMERIC` wire format.
    ///
    /// The payload is four big-endian 16-bit header fields — digit-group
    /// count, weight (the power of 10000 of the first group), sign and
    /// display scale — followed by that many base-10000 digit groups, most
    /// significant first. The decoded value is returned normalized.
    ///
    /// # Errors
    ///
    /// Returns an error if the payload is truncated, if the sign field is
    /// neither positive (`0x0000`) nor negative (`0x4000`) — e.g. `NaN` — or
    /// if the value cannot be represented as a `Decimal` (more than eight
    /// digit groups, a magnitude beyond the lookup table, or a scale that
    /// `Decimal` rejects).
    fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<error::Error + 'static + Sync + Send>> {
        let mut raw = Cursor::new(raw);
        let num_groups = raw.read_u16::<BigEndian>()?;
        let weight = raw.read_i16::<BigEndian>()?;
        let sign = raw.read_u16::<BigEndian>()?;
        let fixed_scale = raw.read_u16::<BigEndian>()? as i32;

        // Special values (0xC000 is NaN) carry no digit groups and would
        // otherwise be silently decoded as zero.
        let negative = match sign {
            0x0000 => false,
            0x4000 => true,
            _ => return Err(Box::new(InvalidDecimal)),
        };

        // DECIMALS only reaches 10000^7, so at most eight groups can be
        // accumulated below; indexing past the table would panic.
        if num_groups > 8 {
            return Err(Box::new(InvalidDecimal));
        }
        let mut groups = Vec::with_capacity(num_groups as usize);
        for _ in 0..num_groups {
            let group = raw.read_u16::<BigEndian>()?;
            groups.push(Decimal::new(group as i64, 0));
        }

        // Treat the groups as one integer: the least significant group is
        // multiplied by DECIMALS[7] (= 1), the next by DECIMALS[8], and so on.
        // NOTE(review): eight large groups can still exceed Decimal's range
        // in this sum; that case is not checked here.
        let mut result = Decimal::zero();
        for (index, group) in groups.iter().rev().enumerate() {
            result = result + (&DECIMALS[index + 7] * group);
        }

        // Decimal places represented by that integer. Widened to i32 before
        // subtracting so the arithmetic cannot overflow i16.
        let mut scale = (num_groups as i32 - weight as i32 - 1) * 4;
        if scale < 0 {
            // Trailing zero groups of the integer part are not transmitted.
            // -scale is a multiple of 4, so the missing factor 10^-scale is
            // DECIMALS[7 - scale / 4]; an i64 power of ten would overflow
            // from 10^19 upwards.
            let multiplier = match DECIMALS.get((7 - scale / 4) as usize) {
                Some(multiplier) => *multiplier,
                None => return Err(Box::new(InvalidDecimal)),
            };
            result *= multiplier;
            scale = 0;
        } else if scale > fixed_scale {
            // The last group can extend past the display scale; those digits
            // are expected to be zero padding and are divided away.
            // 10^18 is the largest power of ten an i64 can hold.
            let excess = (scale - fixed_scale) as u32;
            if excess > 18 {
                return Err(Box::new(InvalidDecimal));
            }
            result /= Decimal::new(10i64.pow(excess), 0);
            scale = fixed_scale;
        }

        if result.set_scale(scale as u32).is_err() {
            return Err(Box::new(InvalidDecimal));
        }
        result.set_sign(!negative);
        Ok(result.normalize())
    }

    fn accepts(ty: &Type) -> bool {
        match *ty {
            NUMERIC => true,
            _ => false,
        }
    }
}
impl ToSql for Decimal {
    /// Encodes the value in PostgreSQL's binary `NUMERIC` wire format.
    ///
    /// Writes four big-endian 16-bit header fields — digit-group count,
    /// weight (power of 10000 of the first group), sign (`0x0000` positive,
    /// `0x4000` negative) and display scale — followed by the base-10000
    /// digit groups, most significant first. The display scale sent is this
    /// `Decimal`'s own scale.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing into `out`.
    fn to_sql(&self, _: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<error::Error + 'static + Sync + Send>> {
        if self.is_zero() {
            // Zero is sent as no digit groups with weight, sign and scale all
            // 0, i.e. eight zero bytes.
            out.write_u64::<BigEndian>(0)?;
            return Ok(IsNull::No);
        }
        let sign = if self.is_sign_negative() { 0x4000 } else { 0x0000 };
        let scale = self.scale() as u16;

        // The unscaled magnitude as plain decimal digits, e.g. -123.45 -> "12345".
        let mut mantissa = *self;
        mantissa.set_sign(true);
        mantissa.set_scale(0).ok();
        let (weight, groups) = to_base_10000_groups(&mantissa.to_string(), scale as usize);

        out.write_u16::<BigEndian>(groups.len() as u16)?;
        out.write_i16::<BigEndian>(weight)?;
        out.write_u16::<BigEndian>(sign)?;
        out.write_u16::<BigEndian>(scale)?;
        for group in groups {
            out.write_u16::<BigEndian>(group)?;
        }
        Ok(IsNull::No)
    }

    fn accepts(ty: &Type) -> bool {
        match *ty {
            NUMERIC => true,
            _ => false,
        }
    }

    to_sql_checked!();
}

/// Splits the decimal digits of an unscaled magnitude into PostgreSQL's
/// base-10000 digit groups.
///
/// `digits` must be ASCII decimal digits; `scale` says how many of them,
/// counted from the right, lie after the decimal point (when `scale` exceeds
/// `digits.len()` the missing positions are leading zeros of the fraction).
///
/// Returns `(weight, groups)`: the groups are aligned on the decimal point,
/// most significant first, with leading and trailing all-zero groups removed,
/// and `weight` is the power of 10000 of the first returned group.
fn to_base_10000_groups(digits: &str, scale: usize) -> (i16, Vec<u16>) {
    // Digits before the decimal point, and zeros implied between the point
    // and the first digit when there are fewer digits than decimal places.
    let int_len = digits.len().saturating_sub(scale);
    let implied_zeros = scale.saturating_sub(digits.len());

    // Pad the integer part on the left and the fraction on the right so each
    // splits into whole 4-digit groups on its side of the decimal point.
    let int_pad = (4 - int_len % 4) % 4;
    let frac_pad = (4 - scale % 4) % 4;
    let mut padded = vec![b'0'; int_pad + implied_zeros];
    padded.extend_from_slice(digits.as_bytes());
    let padded_len = padded.len() + frac_pad;
    padded.resize(padded_len, b'0');

    let all_groups: Vec<u16> = padded
        .chunks(4)
        .map(|chunk| chunk.iter().fold(0u16, |acc, &d| acc * 10 + (d - b'0') as u16))
        .collect();
    let int_groups = (int_len + int_pad) / 4;

    // Dropping each leading zero group moves the first remaining group one
    // base-10000 place further right, so the weight drops by one per group.
    // (The previous encoder used -(number of fraction groups) instead, which
    // mis-weighted values like 0.123456.)
    let leading = all_groups.iter().take_while(|&&g| g == 0).count();
    let trailing = all_groups[leading..].iter().rev().take_while(|&&g| g == 0).count();
    let weight = int_groups as i16 - 1 - leading as i16;
    (weight, all_groups[leading..all_groups.len() - trailing].to_vec())
}
/// Round-trip tests against a live PostgreSQL server reachable as
/// `postgres://postgres@localhost` without TLS.
#[cfg(test)]
mod test {
    use super::*;
    use pg_crate::{Connection, TlsMode};
    use std::str::FromStr;

    /// Opens the test connection, panicking with the error's debug output on failure.
    fn connect() -> Connection {
        Connection::connect("postgres://postgres@localhost", TlsMode::None)
            .unwrap_or_else(|err| panic!("{:#?}", err))
    }

    /// Casts each literal to `sql_type` on the server, decodes the single
    /// result as a `Decimal` and checks it prints back as the same literal.
    fn read_type(sql_type: &str, checks: &[&'static str]) {
        let conn = connect();
        for &val in checks {
            let sql = format!("SELECT {}::{}", val, sql_type);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
            let result: Decimal = rows.iter().next().unwrap().get(0);
            assert_eq!(val, result.to_string());
        }
    }

    /// Sends each literal as a `Decimal` parameter cast to `sql_type`, decodes
    /// the echoed value and checks it prints back as the same literal.
    fn write_type(sql_type: &str, checks: &[&'static str]) {
        let conn = connect();
        for &val in checks {
            let sql = format!("SELECT $1::{}", sql_type);
            let stmt = conn.prepare(&sql).unwrap_or_else(|err| panic!("{:#?}", err));
            let number = Decimal::from_str(val).unwrap();
            let rows = stmt.query(&[&number]).unwrap_or_else(|err| panic!("{:#?}", err));
            let result: Decimal = rows.iter().next().unwrap().get(0);
            assert_eq!(val, result.to_string());
        }
    }

    #[test]
    fn test_null() {
        let conn = connect();
        let stmt = conn
            .prepare("SELECT NULL::numeric")
            .unwrap_or_else(|err| panic!("{:#?}", err));
        let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
        let result: Option<Decimal> = rows.iter().next().unwrap().get(0);
        assert_eq!(None, result);
    }

    #[test]
    fn read_numeric_type() {
        read_type(
            "NUMERIC(26,6)",
            &[
                "3950.123456",
                "3950",
                "0.1",
                "0.01",
                "0.001",
                "0.0001",
                "0.00001",
                "0.000001",
                "1",
                "-100",
                "-123.456",
                "119996.25",
                "1000000",
                "9999999.99999",
            ],
        );
    }

    #[test]
    fn write_numeric_type() {
        write_type(
            "NUMERIC(26,6)",
            &[
                "3950.123456",
                "3950",
                "0",
                "0.1",
                "0.01",
                "0.001",
                "0.0001",
                "0.00001",
                "0.000001",
                "1",
                "-100",
                "-123.456",
                "119996.25",
                "1000000",
                "9999999.99999",
            ],
        );
    }
}