extern crate byteorder;
extern crate num;
use pg_crate::types::*;
use pg_crate::error::Error;
use std::io::Cursor;
use std::fmt;
use std::result::*;
use std::error;
use self::byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use self::num::{BigUint, One, Zero, ToPrimitive};
use self::num::bigint::ToBigUint;
#[cfg(test)]
use pg_crate::{Connection, TlsMode};
#[cfg(test)]
use std::str::FromStr;
use super::Decimal;
/// Error used when a PostgreSQL `NUMERIC` value cannot be converted into a
/// `Decimal`.
///
/// It carries no payload, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidDecimal;
impl fmt::Display for InvalidDecimal {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str(error::Error::description(self))
}
}
/// Makes `InvalidDecimal` usable as a standard error, e.g. boxed inside a
/// conversion error as the `FromSql` impl in this file does.
impl error::Error for InvalidDecimal {
    /// Short, fixed description of the failure.
    fn description(&self) -> &str {
        const MESSAGE: &'static str = "Invalid Decimal";
        MESSAGE
    }
}
impl FromSql for Decimal {
fn from_sql(_: &Type, raw: &[u8])
-> Result<Decimal, Box<error::Error + 'static + Sync + Send>> {
let mut raw = Cursor::new(raw);
let num_groups = try!(raw.read_u16::<BigEndian>());
let weight = try!(raw.read_i16::<BigEndian>()); let sign = try!(raw.read_u16::<BigEndian>());
let fixed_scale = try!(raw.read_u16::<BigEndian>()) as i32;
let mut powers = Vec::new();
let mult = 10000.to_biguint().unwrap();
let mut val: BigUint = One::one();
powers.push(BigUint::one());
for _ in 1..num_groups {
val = &val * &mult;
powers.push(val.clone());
}
powers.reverse();
let mut result: BigUint = Zero::zero();
for i in 0..num_groups {
let group = try!(raw.read_u16::<BigEndian>());
let calculated = &powers[i as usize] * group.to_biguint().unwrap();
result = result + calculated;
}
let mut scale = (num_groups as i16 - weight - 1) as i32 * 4;
if scale < 0 {
result = result * 10i64.pow((scale * -1) as u32).to_biguint().unwrap();
scale = 0;
} else if scale > fixed_scale {
result = result / 10i64.pow((scale - fixed_scale) as u32).to_biguint().unwrap();
scale = fixed_scale;
}
let neg = sign == 0x4000;
let mut decimal = try!(match Decimal::from_biguint(result, scale as u32, neg) {
Ok(x) => Ok(x),
Err(_) => Err(Box::new(Error::Conversion(Box::new(InvalidDecimal)))),
});
if scale > 0 {
let str_rep = decimal.to_string();
let trailing_zeros = str_rep.chars().rev().take_while(|&x| x == '0').count();
decimal = decimal.rescale(scale as u32 - trailing_zeros as u32);
}
Ok(decimal)
}
fn accepts(ty: &Type) -> bool {
match *ty {
Type::Numeric => true,
_ => false,
}
}
}
impl ToSql for Decimal {
    /// Serialises this `Decimal` into PostgreSQL's binary `NUMERIC` format.
    ///
    /// The output starts with four big-endian header fields:
    /// * group count (`u16`)
    /// * weight (`i16`, the base-10000 exponent of the first group)
    /// * sign (`u16`: `0x0000` positive, `0x4000` negative)
    /// * display scale (`u16`)
    ///
    /// It then writes one `u16` (in `0..10000`) per digit group, most
    /// significant first.
    ///
    /// Leading all-zero groups are omitted and the weight is lowered by one
    /// for each. For example:
    /// * `0.12345` is sent as groups `[1234, 5000]` with weight `-1`;
    /// * `0.00001` is sent as `[1000]` with weight `-2`.
    fn to_sql(&self,
              _: &Type,
              out: &mut Vec<u8>)
              -> Result<IsNull, Box<error::Error + 'static + Sync + Send>> {
        /// Packs ASCII decimal digits (length a multiple of 4) into base-10000
        /// groups, most significant first.
        fn to_groups(digits: &str) -> Vec<u16> {
            digits
                .as_bytes()
                .chunks(4)
                .map(|chunk| chunk.iter().fold(0u16, |acc, &b| acc * 10 + (b - b'0') as u16))
                .collect()
        }

        let sign: u16 = if self.is_negative() { 0x4000 } else { 0x0000 };
        let scale = self.scale() as u16;
        let scale_len = scale as usize;

        // Left-pad the unscaled mantissa to at least `scale` digits so the
        // integer/fraction split below can never underflow.
        let digits = format!("{:0>1$}", self.to_biguint().to_str_radix(10), scale_len);
        let (whole_digits, fraction_digits) = digits.split_at(digits.len() - scale_len);

        // Align both parts to base-10000 group boundaries measured from the
        // decimal point: the integer part is padded on the left, the
        // fractional part on the right.
        let whole_width = (whole_digits.len() + 3) / 4 * 4;
        let fraction_width = (fraction_digits.len() + 3) / 4 * 4;
        let whole_groups = to_groups(&format!("{:0>1$}", whole_digits, whole_width));
        let fraction_groups = to_groups(&format!("{:0<1$}", fraction_digits, fraction_width));

        let leading_zero_groups = whole_groups
            .iter()
            .chain(fraction_groups.iter())
            .take_while(|&&g| g == 0)
            .count();
        let groups: Vec<u16> = whole_groups
            .iter()
            .chain(fraction_groups.iter())
            .skip(leading_zero_groups)
            .cloned()
            .collect();

        // The first integer group has weight `whole_groups.len() - 1`; every
        // skipped leading zero group moves the first emitted group one place
        // lower. A zero value has no groups and is sent with weight 0.
        let weight = if groups.is_empty() {
            0
        } else {
            whole_groups.len() as i16 - 1 - leading_zero_groups as i16
        };

        try!(out.write_u16::<BigEndian>(groups.len() as u16));
        try!(out.write_i16::<BigEndian>(weight));
        try!(out.write_u16::<BigEndian>(sign));
        try!(out.write_u16::<BigEndian>(scale));
        for group in groups {
            try!(out.write_u16::<BigEndian>(group));
        }
        Ok(IsNull::No)
    }

    /// Only PostgreSQL `NUMERIC` columns are encoded by this impl.
    fn accepts(ty: &Type) -> bool {
        match *ty {
            Type::Numeric => true,
            _ => false,
        }
    }

    to_sql_checked!();
}
/// Reads each literal in `checks` back from PostgreSQL as
/// `SELECT <literal>::<sql_type>` and asserts that the decoded `Decimal`
/// formats to exactly that literal.
///
/// Panics (failing the calling test) if connecting, preparing or querying fails.
#[cfg(test)]
fn test_read_type(sql_type: &str, checks: &[&'static str]) {
    let conn = Connection::connect("postgres://paulmason@localhost", TlsMode::None)
        .unwrap_or_else(|err| panic!("{:#?}", err));
    for &expected in checks {
        let query = format!("SELECT {}::{}", expected, sql_type);
        let stmt = conn.prepare(&query).unwrap_or_else(|err| panic!("{:#?}", err));
        let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
        let actual: Decimal = rows.iter().next().unwrap().get(0);
        assert_eq!(expected, actual.to_string());
    }
}
/// Sends each literal in `checks` to PostgreSQL as a bound `Decimal`
/// parameter (`SELECT $1::<sql_type>`), reads it back, and asserts the
/// round-tripped value formats to exactly the original literal.
///
/// Panics (failing the calling test) if connecting, parsing, preparing or
/// querying fails.
#[cfg(test)]
fn test_write_type(sql_type: &str, checks: &[&'static str]) {
    let conn = Connection::connect("postgres://paulmason@localhost", TlsMode::None)
        .unwrap_or_else(|err| panic!("{:#?}", err));
    for &expected in checks {
        let query = format!("SELECT $1::{}", sql_type);
        let stmt = conn.prepare(&query).unwrap_or_else(|err| panic!("{:#?}", err));
        let input = Decimal::from_str(expected).unwrap();
        let rows = stmt.query(&[&input]).unwrap_or_else(|err| panic!("{:#?}", err));
        let actual: Decimal = rows.iter().next().unwrap().get(0);
        assert_eq!(expected, actual.to_string());
    }
}
/// A SQL `NULL` numeric must decode as `None` when read into `Option<Decimal>`.
#[test]
fn test_null() {
    let conn = Connection::connect("postgres://paulmason@localhost", TlsMode::None)
        .unwrap_or_else(|err| panic!("{:#?}", err));
    let stmt = conn
        .prepare("SELECT NULL::numeric")
        .unwrap_or_else(|err| panic!("{:#?}", err));
    let rows = stmt.query(&[]).unwrap_or_else(|err| panic!("{:#?}", err));
    let value: Option<Decimal> = rows.iter().next().unwrap().get(0);
    assert_eq!(None, value);
}
/// Decoding: values selected as `NUMERIC(26,6)` must print back unchanged.
#[test]
fn it_can_read_numeric_type() {
    let cases = [
        "3950.123456", "3950", "0.1", "0.01", "0.001", "0.0001", "0.00001", "0.000001",
        "1", "-100", "-123.456", "119996.25", "1000000", "9999999.99999",
    ];
    test_read_type("NUMERIC(26,6)", &cases);
}
/// Encoding: values bound as parameters and cast to `NUMERIC(26,6)` must
/// round-trip unchanged.
#[test]
fn it_can_write_numeric_type() {
    let cases = [
        "3950.123456", "3950", "0.1", "0.01", "0.001", "0.0001", "0.00001", "0.000001",
        "1", "-100", "-123.456", "119996.25", "1000000", "9999999.99999",
    ];
    test_write_type("NUMERIC(26,6)", &cases);
}