rust_decimal 1.0.3

A Decimal Implementation written in pure Rust suitable for financial calculations.
Documentation
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};

use num::Zero;

use crate::Decimal;

use postgres::{to_sql_checked, types::*};

use std::{error, fmt, io::Cursor, result::*};

use crate::decimal::{div_by_u32, is_all_zero, mul_by_u32};

const DECIMALS: [Decimal; 15] = [
    Decimal::from_parts(1, 0, 0, false, 28),
    Decimal::from_parts(1, 0, 0, false, 24),
    Decimal::from_parts(1, 0, 0, false, 20),
    Decimal::from_parts(1, 0, 0, false, 16),
    Decimal::from_parts(1, 0, 0, false, 12),
    Decimal::from_parts(1, 0, 0, false, 8),
    Decimal::from_parts(1, 0, 0, false, 4),
    Decimal::from_parts(1, 0, 0, false, 0),
    Decimal::from_parts(1_0000, 0, 0, false, 0),
    Decimal::from_parts(1_0000_0000, 0, 0, false, 0),
    Decimal::from_parts(
        1_0000_0000_0000u64 as u32,
        (1_0000_0000_0000u64 >> 32) as u32,
        0,
        false,
        0,
    ),
    Decimal::from_parts(
        1_0000_0000_0000_0000u64 as u32,
        (1_0000_0000_0000_0000u64 >> 32) as u32,
        0,
        false,
        0,
    ),
    Decimal::from_parts(1661992960, 1808227885, 5, false, 0),
    Decimal::from_parts(2701131776, 466537709, 54210, false, 0),
    Decimal::from_parts(268435456, 1042612833, 542101086, false, 0),
];

#[derive(Debug, Clone, Copy)]
pub struct InvalidDecimal;

impl fmt::Display for InvalidDecimal {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_str(error::Error::description(self))
    }
}

impl error::Error for InvalidDecimal {
    fn description(&self) -> &str {
        "Invalid Decimal"
    }
}

impl FromSql for Decimal {
    // Decimals are represented as follows:
    // Header:
    //  u16 numGroups
    //  i16 weightFirstGroup (10000^weight)
    //  u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN)
    //  i16 dscale. Number of digits (in base 10) to print after decimal separator
    //
    //  Psuedo code :
    //  const Decimals [
    //          0.0000000000000000000000000001,
    //          0.000000000000000000000001,
    //          0.00000000000000000001,
    //          0.0000000000000001,
    //          0.000000000001,
    //          0.00000001,
    //          0.0001,
    //          1,
    //          10000,
    //          100000000,
    //          1000000000000,
    //          10000000000000000,
    //          100000000000000000000,
    //          1000000000000000000000000,
    //          10000000000000000000000000000
    //  ]
    //  overflow = false
    //  result = 0
    //  for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight--
    //    group = read.u16
    //    if weight < 0 or weight > MaxNum
    //       overflow = true
    //    else
    //       result += Decimals[weight] * group
    //  sign == 0x4000 ? -result : result

    // So if we were to take the number: 3950.123456
    //
    //  Stored on Disk:
    //    00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0
    //
    //  Number of groups: 00 03
    //  Weight of first group: 00 00
    //  Sign: 00 00
    //  DScale: 00 06
    //
    // 0F 6E = 3950
    //   result = result + 3950 * 1;
    // 04 D2 = 1234
    //   result = result + 1234 * 0.0001;
    // 15 E0 = 5600
    //   result = result + 5600 * 0.00000001;
    //

    fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn error::Error + 'static + Sync + Send>> {
        let mut raw = Cursor::new(raw);
        let num_groups = raw.read_u16::<BigEndian>()?;
        let weight = raw.read_i16::<BigEndian>()?; // 10000^weight
                                                   // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN
        let sign = raw.read_u16::<BigEndian>()?;
        // Number of digits (in base 10) to print after decimal separator
        let fixed_scale = raw.read_u16::<BigEndian>()? as i32;

        // Read all of the groups
        let mut groups = Vec::new();
        for _ in 0..num_groups as usize {
            let group = raw.read_u16::<BigEndian>()?;
            groups.push(Decimal::new(group as i64, 0));
        }
        groups.reverse();

        // Now process the number
        let mut result = Decimal::zero();
        for (index, group) in groups.iter().enumerate() {
            result = result + (&DECIMALS[index + 7] * group);
        }

        // Finally, adjust for the scale
        let mut scale = (num_groups as i16 - weight - 1) as i32 * 4;
        // Scale could be negative
        if scale < 0 {
            result *= Decimal::new(10i64.pow((scale * -1) as u32), 0);
            scale = 0;
        } else if scale > fixed_scale {
            result /= Decimal::new(10i64.pow((scale - fixed_scale) as u32), 0);
            scale = fixed_scale;
        }

        // Create the decimal
        let neg = sign == 0x4000;
        if result.set_scale(scale as u32).is_err() {
            return Err(Box::new(InvalidDecimal));
        }
        result.set_sign(!neg);

        // Normalize by truncating any trailing 0's from the decimal representation
        Ok(result.normalize())
    }

    fn accepts(ty: &Type) -> bool {
        match *ty {
            NUMERIC => true,
            _ => false,
        }
    }
}

impl ToSql for Decimal {
    fn to_sql(&self, _: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<dyn error::Error + 'static + Sync + Send>> {
        // If it's zero we can short cut with a u64
        if self.is_zero() {
            out.write_u64::<BigEndian>(0)?;
            return Ok(IsNull::No);
        }
        let sign = if self.is_sign_negative() { 0x4000 } else { 0x0000 };
        let scale = self.scale() as u16;

        let groups_diff = scale & 0x3; // groups_diff = scale % 4
        let mut fractional_groups_count = (scale >> 2) as isize; // fractional_groups_count = scale / 4
        fractional_groups_count += if groups_diff > 0 { 1 } else { 0 };

        let mut mantissa = self.mantissa_array4();

        if groups_diff > 0 {
            let remainder = 4 - groups_diff;
            let power = 10u32.pow(remainder as u32);
            mul_by_u32(&mut mantissa, power);
        }

        // array to store max mantissa of Decimal in Postgres decimal format
        const MAX_GROUP_COUNT: usize = 8;
        let mut groups = [0u16; MAX_GROUP_COUNT];

        let mut num_groups = 0usize;
        while !is_all_zero(&mantissa) {
            let group_digits = div_by_u32(&mut mantissa, 10000) as u16;
            groups[num_groups] = group_digits;
            num_groups += 1;
        }

        let whole_portion_len = num_groups as isize - fractional_groups_count;
        let weight = if whole_portion_len < 0 {
            -(fractional_groups_count as i16)
        } else {
            whole_portion_len as i16 - 1
        };

        // Number of groups
        out.write_u16::<BigEndian>(num_groups as u16)?;
        // Weight of first group
        out.write_i16::<BigEndian>(weight)?;
        // Sign
        out.write_u16::<BigEndian>(sign)?;
        // DScale
        out.write_u16::<BigEndian>(scale)?;
        // Now process the number
        for group in groups[0..num_groups].iter().rev() {
            out.write_u16::<BigEndian>(*group)?;
        }

        Ok(IsNull::No)
    }

    fn accepts(ty: &Type) -> bool {
        match *ty {
            NUMERIC => true,
            _ => false,
        }
    }

    to_sql_checked!();
}

#[cfg(test)]
mod test {
    use super::*;

    use postgres::{Connection, TlsMode};

    use std::str::FromStr;

    pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[
        // precision, scale, sent, expected
        (35, 6, "3950.123456", "3950.123456"),
        (35, 2, "3950.123456", "3950.12"),
        (35, 2, "3950.1256", "3950.13"),
        (10, 2, "3950.123456", "3950.12"),
        (35, 6, "3950", "3950"),
        (4, 0, "3950", "3950"),
        (35, 6, "0.1", "0.1"),
        (35, 6, "0.01", "0.01"),
        (35, 6, "0.001", "0.001"),
        (35, 6, "0.0001", "0.0001"),
        (35, 6, "0.00001", "0.00001"),
        (35, 6, "0.000001", "0.000001"),
        (35, 6, "1", "1"),
        (35, 6, "-100", "-100"),
        (35, 6, "-123.456", "-123.456"),
        (35, 6, "119996.25", "119996.25"),
        (35, 6, "1000000", "1000000"),
        (35, 6, "9999999.99999", "9999999.99999"),
        (35, 6, "12340.56789", "12340.56789"),
        // 0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF (96 bit)
        (35, 6, "79228162514264337593543950335", "79228162514264337593543950335"),
        // 0x0FFF_FFFF_FFFF_FFFF_FFFF_FFFF (95 bit)
        (35, 6, "4951760157141521099596496895", "4951760157141521099596496895"),
        // 0x1000_0000_0000_0000_0000_0000
        (35, 6, "4951760157141521099596496896", "4951760157141521099596496896"),
        (35, 6, "18446744073709551615", "18446744073709551615"),
        (35, 6, "-18446744073709551615", "-18446744073709551615"),
        (35, 6, "0.10001", "0.10001"),
        (35, 6, "0.12345", "0.12345"),
    ];

    #[test]
    fn ensure_equivalent_decimal_constants() {
        let expected_decimals = [
            Decimal::new(1, 28),
            Decimal::new(1, 24),
            Decimal::new(1, 20),
            Decimal::new(1, 16),
            Decimal::new(1, 12),
            Decimal::new(1, 8),
            Decimal::new(1, 4),
            Decimal::new(1, 0),
            Decimal::new(10000, 0),
            Decimal::new(100000000, 0),
            Decimal::new(1000000000000, 0),
            Decimal::new(10000000000000000, 0),
            Decimal::from_parts(1661992960, 1808227885, 5, false, 0),
            Decimal::from_parts(2701131776, 466537709, 54210, false, 0),
            Decimal::from_parts(268435456, 1042612833, 542101086, false, 0),
        ];

        assert_eq!(&expected_decimals[..], &DECIMALS[..]);
    }

    #[test]
    fn test_null() {
        let conn = match Connection::connect("postgres://postgres@localhost", TlsMode::None) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };

        // Test NULL
        let stmt = match conn.prepare(&"SELECT NULL::numeric") {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        let result: Option<Decimal> = match stmt.query(&[]) {
            Ok(x) => x.iter().next().unwrap().get(0),
            Err(err) => panic!("{:#?}", err),
        };
        assert_eq!(None, result);
    }

    #[test]
    fn read_numeric_type() {
        let conn = match Connection::connect("postgres://postgres@localhost", TlsMode::None) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let stmt = match conn.prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale)) {
                Ok(x) => x,
                Err(err) => panic!("{:#?}", err),
            };
            let result: Decimal = match stmt.query(&[]) {
                Ok(x) => x.iter().next().unwrap().get(0),
                Err(err) => panic!("{:#?}", err),
            };
            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    #[test]
    fn write_numeric_type() {
        let conn = match Connection::connect("postgres://postgres@localhost", TlsMode::None) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let stmt = match conn.prepare(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale)) {
                Ok(x) => x,
                Err(err) => panic!("{:#?}", err),
            };
            let number = Decimal::from_str(sent).unwrap();
            let result: Decimal = match stmt.query(&[&number]) {
                Ok(x) => x.iter().next().unwrap().get(0),
                Err(err) => panic!("{:#?}", err),
            };
            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    #[test]
    fn numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let conn = match Connection::connect("postgres://postgres@localhost", TlsMode::None) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        for &(precision, scale, sent) in tests.iter() {
            let stmt = match conn.prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale)) {
                Ok(x) => x,
                Err(err) => panic!("{:#?}", err),
            };
            match stmt.query(&[]) {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => {
                    assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
                }
            };
        }
    }
}