use crate::postgres::common::*;
use crate::Decimal;
use bytes::{BufMut, BytesMut};
use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::io::{Cursor, Read};
/// Reads the next two bytes from `cursor` and advances it past them.
///
/// # Errors
///
/// Propagates the error from [`Read::read_exact`], which is `UnexpectedEof` when fewer
/// than two bytes remain.
fn read_two_bytes(cursor: &mut Cursor<&[u8]>) -> std::io::Result<[u8; 2]> {
    let mut pair = [0u8; 2];
    cursor.read_exact(&mut pair).map(|()| pair)
}
impl<'a> FromSql<'a> for Decimal {
fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn std::error::Error + 'static + Sync + Send>> {
let mut raw = Cursor::new(raw);
let num_groups = u16::from_be_bytes(read_two_bytes(&mut raw)?);
let weight = i16::from_be_bytes(read_two_bytes(&mut raw)?); let sign = u16::from_be_bytes(read_two_bytes(&mut raw)?);
let scale = u16::from_be_bytes(read_two_bytes(&mut raw)?);
let mut groups = Vec::new();
for _ in 0..num_groups as usize {
groups.push(u16::from_be_bytes(read_two_bytes(&mut raw)?));
}
let Some(result) = Self::checked_from_postgres(PostgresDecimal {
neg: sign == 0x4000,
weight,
scale,
digits: groups.into_iter(),
}) else {
return Err(Box::new(crate::error::Error::ExceedsMaximumPossibleValue));
};
Ok(result)
}
fn accepts(ty: &Type) -> bool {
matches!(*ty, Type::NUMERIC)
}
}
impl ToSql for Decimal {
    /// Encodes this `Decimal` in PostgreSQL's binary `NUMERIC` format.
    ///
    /// Writes the header (group count, weight, sign, scale) followed by each digit group
    /// as a big-endian 16-bit integer. The sign word is `0x4000` for negative values and
    /// `0x0000` otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error, rather than panicking, if the number of digit groups does not
    /// fit in the 16-bit count field.
    fn to_sql(
        &self,
        _: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> {
        let PostgresDecimal {
            neg,
            weight,
            scale,
            digits,
        } = self.to_postgres();

        let num_digits = digits.len();
        let num_groups: u16 = num_digits.try_into()?;

        // 8 header bytes plus 2 bytes per digit group.
        out.reserve(8 + num_digits * 2);
        out.put_u16(num_groups);
        out.put_i16(weight);
        out.put_u16(if neg { 0x4000 } else { 0x0000 });
        out.put_u16(scale);
        for &digit in digits.iter() {
            out.put_i16(digit);
        }

        Ok(IsNull::No)
    }

    /// Only PostgreSQL `NUMERIC` parameters can be encoded from a `Decimal`.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }

    to_sql_checked!();
}
#[cfg(test)]
mod test {
    use super::*;
    use ::postgres::{Client, NoTls};
    use core::str::FromStr;

    /// Connection string for the test database: `POSTGRES_URL` when set, otherwise a
    /// local server accessed as the `postgres` user.
    fn get_postgres_url() -> String {
        std::env::var("POSTGRES_URL")
            .unwrap_or_else(|_| "postgres://postgres@localhost".to_string())
    }

    /// Opens a blocking client, panicking with the pretty-printed error on failure.
    fn connect_client() -> Client {
        Client::connect(&get_postgres_url(), NoTls).unwrap_or_else(|err| panic!("{:#?}", err))
    }

    /// Opens an async client and spawns its connection future on the tokio runtime.
    ///
    /// The spawned task panics if the connection ends with an error.
    #[cfg(feature = "tokio-pg")]
    async fn connect_async_client() -> tokio_postgres::Client {
        use futures::future::FutureExt;

        let (client, connection) = tokio_postgres::connect(&get_postgres_url(), NoTls)
            .await
            .unwrap();
        tokio::spawn(connection.map(|e| e.unwrap()));
        client
    }

    /// `(precision, scale, sent, expected)`: `sent` cast to `NUMERIC(precision, scale)`
    /// must come back as a `Decimal` whose `to_string()` equals `expected`.
    pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[
        (35, 6, "3950.123456", "3950.123456"),
        (35, 2, "3950.123456", "3950.12"),
        (35, 2, "3950.1256", "3950.13"),
        (10, 2, "3950.123456", "3950.12"),
        (35, 6, "3950", "3950.000000"),
        (4, 0, "3950", "3950"),
        (35, 6, "0.1", "0.100000"),
        (35, 6, "0.01", "0.010000"),
        (35, 6, "0.001", "0.001000"),
        (35, 6, "0.0001", "0.000100"),
        (35, 6, "0.00001", "0.000010"),
        (35, 6, "0.000001", "0.000001"),
        (35, 6, "1", "1.000000"),
        (35, 6, "-100", "-100.000000"),
        (35, 6, "-123.456", "-123.456000"),
        (35, 6, "119996.25", "119996.250000"),
        (35, 6, "1000000", "1000000.000000"),
        (35, 6, "9999999.99999", "9999999.999990"),
        (35, 6, "12340.56789", "12340.567890"),
        (65, 30, "1.2", "1.2000000000000000000000000000"),
        (
            65,
            30,
            "3.141592653589793238462643383279",
            "3.1415926535897932384626433833",
        ),
        (
            65,
            34,
            "3.1415926535897932384626433832795028",
            "3.1415926535897932384626433833",
        ),
        (
            65,
            34,
            "1.234567890123456789012345678950000",
            "1.2345678901234567890123456790",
        ),
        (
            65,
            34,
            "1.234567890123456789012345678949999",
            "1.2345678901234567890123456789",
        ),
        (35, 0, "79228162514264337593543950335", "79228162514264337593543950335"),
        (35, 1, "4951760157141521099596496895", "4951760157141521099596496895.0"),
        (35, 1, "4951760157141521099596496896", "4951760157141521099596496896.0"),
        (35, 6, "18446744073709551615", "18446744073709551615.000000"),
        (35, 6, "-18446744073709551615", "-18446744073709551615.000000"),
        (35, 6, "0.10001", "0.100010"),
        (35, 6, "0.12345", "0.123450"),
    ];

    #[test]
    fn test_null() {
        let mut client = connect_client();
        let rows = client
            .query("SELECT NULL::numeric", &[])
            .unwrap_or_else(|err| panic!("{:#?}", err));
        let result: Option<Decimal> = rows.first().unwrap().get(0);
        assert_eq!(None, result);
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_test_null() {
        let client = connect_async_client().await;
        let statement = client.prepare("SELECT NULL::numeric").await.unwrap();
        let rows = client.query(&statement, &[]).await.unwrap();
        let result: Option<Decimal> = rows.first().unwrap().get(0);
        assert_eq!(None, result);
    }

    #[test]
    fn read_very_small_numeric_type() {
        let mut client = connect_client();
        let rows = client
            .query("SELECT 1e-130::NUMERIC(130, 0)", &[])
            .unwrap_or_else(|err| panic!("error - {:#?}", err));
        let result: Decimal = rows.first().unwrap().get(0);
        assert_eq!(Decimal::ZERO, result);
    }

    #[test]
    fn read_small_unconstrained_numeric_type() {
        let mut client = connect_client();
        let rows = client
            .query("SELECT 0.100000000000000000000000000001::NUMERIC", &[])
            .unwrap_or_else(|err| panic!("error - {:#?}", err));
        let result: Decimal = rows.first().unwrap().get(0);
        assert_eq!(result.to_string(), "0.1000000000000000000000000000");
        assert_eq!(result.scale(), 28);
    }

    #[test]
    fn read_small_unconstrained_numeric_type_addition() {
        let mut client = connect_client();
        let rows = client
            .query(
                "SELECT 0.100000000000000000000000000001::NUMERIC, 0.00000000000014780214::NUMERIC",
                &[],
            )
            .unwrap_or_else(|err| panic!("error - {:#?}", err));
        let row = rows.first().unwrap();
        let (a, b): (Decimal, Decimal) = (row.get(0), row.get(1));
        assert_eq!(a + b, Decimal::from_str("0.1000000000001478021400000000").unwrap());
    }

    #[test]
    fn read_numeric_type() {
        let mut client = connect_client();
        for &(precision, scale, sent, expected) in TEST_DECIMALS {
            let sql = format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale);
            let rows = client.query(sql.as_str(), &[]).unwrap_or_else(|err| {
                panic!("SELECT {}::NUMERIC({}, {}), error - {:#?}", sent, precision, scale, err)
            });
            let result: Decimal = rows.first().unwrap().get(0);
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_read_numeric_type() {
        let client = connect_async_client().await;
        for &(precision, scale, sent, expected) in TEST_DECIMALS {
            let statement = client
                .prepare(&format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();
            let rows = client.query(&statement, &[]).await.unwrap();
            let result: Decimal = rows.first().unwrap().get(0);
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[test]
    fn write_numeric_type() {
        let mut client = connect_client();
        for &(precision, scale, sent, expected) in TEST_DECIMALS {
            let number = Decimal::from_str(sent).unwrap();
            let sql = format!("SELECT $1::NUMERIC({}, {})", precision, scale);
            let rows = client
                .query(sql.as_str(), &[&number])
                .unwrap_or_else(|err| panic!("{:#?}", err));
            let result: Decimal = rows.first().unwrap().get(0);
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_write_numeric_type() {
        let client = connect_async_client().await;
        for &(precision, scale, sent, expected) in TEST_DECIMALS {
            let statement = client
                .prepare(&format!("SELECT $1::NUMERIC({}, {})", precision, scale))
                .await
                .unwrap();
            let number = Decimal::from_str(sent).unwrap();
            let rows = client.query(&statement, &[&number]).await.unwrap();
            let result: Decimal = rows.first().unwrap().get(0);
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    /// A value with more integer digits than `NUMERIC(precision, scale)` allows must be
    /// rejected by the server with SQLSTATE 22003 (numeric_value_out_of_range).
    #[test]
    fn numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let mut client = connect_client();
        for &(precision, scale, sent) in tests.iter() {
            let sql = format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale);
            match client.query(sql.as_str(), &[]) {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => {
                    assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
                }
            }
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let client = connect_async_client().await;
        for &(precision, scale, sent) in tests.iter() {
            let statement = client
                .prepare(&format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();
            match client.query(&statement, &[]).await {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => {
                    assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
                }
            }
        }
    }

    /// Decodes hand-built wire payloads on either side of `Decimal`'s upper limit.
    #[test]
    fn numeric_overflow_from_sql() {
        // ndigits = 1, weight = 7, sign = +, dscale = 0, group = 1  =>  1 * 10000^7 = 1e28.
        let close_to_overflow = Decimal::from_sql(
            &Type::NUMERIC,
            &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01],
        );
        assert!(close_to_overflow.is_ok());
        assert_eq!(close_to_overflow.unwrap().to_string(), "10000000000000000000000000000");

        // Same header with group = 10  =>  1e29, expected to exceed Decimal's maximum.
        let overflow = Decimal::from_sql(
            &Type::NUMERIC,
            &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a],
        );
        assert!(overflow.is_err());
        assert_eq!(
            overflow.unwrap_err().to_string(),
            crate::error::Error::ExceedsMaximumPossibleValue.to_string()
        );
    }
}