use super::{FixedBytes, Sign, Signed};
use bytes::{BufMut, BytesMut};
use derive_more::Display;
use postgres_types::{FromSql, IsNull, ToSql, Type, WrongType, accepts, to_sql_checked};
use std::{
error::Error,
iter,
str::{FromStr, from_utf8},
};
impl<const BITS: usize> ToSql for FixedBytes<BITS> {
    accepts!(BYTEA);

    to_sql_checked!();

    /// Writes the bytes of `self` verbatim into `out` as a `BYTEA` payload.
    ///
    /// The requested type is ignored here; `accepts!` restricts this
    /// implementation to `BYTEA`. The value is never SQL `NULL`.
    fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
        let payload: &[u8] = &self[..];
        out.put_slice(payload);
        Ok(IsNull::No)
    }
}
impl<'a, const BITS: usize> FromSql<'a> for FixedBytes<BITS> {
accepts!(BYTEA);
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
Ok(Self::try_from(raw)?)
}
}
/// Boxed, thread-safe (`Send + Sync`), `'static` error object used as the
/// error type of the `ToSql::to_sql` implementations in this file.
type BoxedError = Box<dyn Error + Sync + Send + 'static>;
/// Remainder of `a / b`, except that a remainder of zero is reported as `b`.
///
/// The result therefore lies in `1..=b`. Panics if `b` is zero, because the
/// underlying `%` does.
const fn rem_up(a: usize, b: usize) -> usize {
    match a % b {
        0 => b,
        remainder => remainder,
    }
}
/// Length of `x` once every trailing element equal to `value` is dropped.
///
/// Equivalently: one past the index of the last element that differs from
/// `value`, or `0` when `x` is empty or consists solely of `value`.
fn last_idx<T: PartialEq>(x: &[T], value: &T) -> usize {
    let trailing = x.iter().rev().take_while(|item| *item == value).count();
    x.len() - trailing
}
/// Removes, in place, every trailing element of `vec` that equals `value`.
fn trim_end_vec<T: PartialEq>(vec: &mut Vec<T>, value: &T) {
    let keep = last_idx(vec.as_slice(), value);
    vec.truncate(keep);
}
#[derive(Clone, Debug, PartialEq, Eq, Display)]
pub enum ToSqlError {
#[display("Signed<{_0}> value too large to fit target type {_1}")]
Overflow(usize, Type),
}
impl core::error::Error for ToSqlError {}
impl<const BITS: usize, const LIMBS: usize> ToSql for Signed<BITS, LIMBS> {
    /// Encodes `self` in the PostgreSQL binary format of `ty` and appends it
    /// to `out`.
    ///
    /// `BufMut::put_*` writes big-endian, which is what the wire format uses.
    ///
    /// # Errors
    ///
    /// * a conversion error when the value does not fit a fixed-width target
    ///   (`BOOL`, `INT2`, `INT4`, `OID`, `INT8`, `MONEY`);
    /// * [`ToSqlError::Overflow`] when scaling to `MONEY` overflows `i64`;
    /// * [`WrongType`] for `BIT` with a zero-width `Signed`, and for any type
    ///   this implementation has no encoding for.
    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
        match *ty {
            // Fixed-width integer types.
            // NOTE(review): these conversions go through the inner `self.0`
            // representation; confirm whether negative values are meant to be
            // rejected by these arms.
            Type::BOOL => out.put_u8(u8::from(bool::try_from(self.0)?)),
            Type::INT2 => out.put_i16(self.0.try_into()?),
            Type::INT4 => out.put_i32(self.0.try_into()?),
            Type::OID => out.put_u32(self.0.try_into()?),
            Type::INT8 => out.put_i64(self.0.try_into()?),

            // The value is scaled by 100 before being written as an `i64`.
            Type::MONEY => {
                let units = i64::try_from(self.0)?;
                let scaled = units
                    .checked_mul(100)
                    .ok_or_else(|| ToSqlError::Overflow(BITS, ty.clone()))?;
                out.put_i64(scaled);
            }

            // Raw big-endian bytes.
            Type::BYTEA => out.put_slice(&self.0.to_be_bytes_vec()),

            // Bit strings: an `i32` bit count followed by the bits, most
            // significant first, with unused bits at the end of the last byte.
            Type::BIT | Type::VARBIT => {
                if BITS == 0 {
                    // Only `VARBIT` gets an empty encoding; `BIT` is rejected.
                    if *ty == Type::BIT {
                        return Err(Box::new(WrongType::new::<Self>(ty.clone())));
                    }
                    out.put_i32(0);
                } else {
                    // Number of unused low bits in the final output byte (0..=7).
                    let padding = 8 - rem_up(BITS, 8);
                    out.put_i32(Self::BITS.try_into()?);

                    // Walk the bytes most-significant first and shift the whole
                    // bit string left by `padding`, carrying across bytes.
                    let bytes = self.0.as_le_bytes();
                    let mut bytes = bytes.iter().rev();
                    let mut shifted = bytes
                        .next()
                        .expect("a non-zero bit width must have at least one byte")
                        << padding;
                    for byte in bytes {
                        // A shift by 8 would overflow `u8`, hence the guard.
                        shifted |= if padding > 0 { byte >> (8 - padding) } else { 0 };
                        out.put_u8(shifted);
                        shifted = byte << padding;
                    }
                    out.put_u8(shifted);
                }
            }

            // Textual types carry the `{:#x}` rendering of the value.
            Type::CHAR | Type::TEXT | Type::VARCHAR => {
                out.put_slice(format!("{self:#x}").as_bytes());
            }

            // JSON carries the `{:#x}` rendering as a JSON string; JSONB is
            // the same text preceded by a single version byte of `1`.
            Type::JSON | Type::JSONB => {
                if *ty == Type::JSONB {
                    out.put_u8(1);
                }
                out.put_slice(format!("\"{self:#x}\"").as_bytes());
            }

            // Header of four `i16` (digit count, weight of the first digit,
            // sign, display scale) followed by the base-10000 digits.
            Type::NUMERIC => {
                const BASE: u64 = 10000;

                let sign: i16 = match self.sign() {
                    Sign::Positive => 0x0000,
                    _ => 0x4000,
                };

                let mut digits: Vec<_> = self.abs().0.to_base_be(BASE).collect();
                // The weight must be taken before trailing zeros are trimmed;
                // `saturating_sub` covers an empty digit list.
                let exponent: i16 = digits.len().saturating_sub(1).try_into()?;
                trim_end_vec(&mut digits, &0);

                out.put_i16(digits.len().try_into()?);
                out.put_i16(exponent);
                out.put_i16(sign);
                // Display scale: no digits after the decimal point.
                out.put_i16(0);
                for digit in digits {
                    debug_assert!(digit < BASE);
                    out.put_i16(i16::try_from(digit)?);
                }
            }

            _ => {
                return Err(Box::new(WrongType::new::<Self>(ty.clone())));
            }
        };
        Ok(IsNull::No)
    }

    /// Returns `true` exactly for the types `to_sql` has an encoding for.
    ///
    /// `FLOAT4`/`FLOAT8` are deliberately absent: `to_sql` has no arm for
    /// them, so advertising them could only ever lead to a `WrongType` error.
    fn accepts(ty: &Type) -> bool {
        matches!(
            *ty,
            Type::BOOL
                | Type::CHAR
                | Type::INT2
                | Type::INT4
                | Type::INT8
                | Type::OID
                | Type::MONEY
                | Type::NUMERIC
                | Type::BYTEA
                | Type::TEXT
                | Type::VARCHAR
                | Type::JSON
                | Type::JSONB
                | Type::BIT
                | Type::VARBIT
        )
    }

    to_sql_checked!();
}
#[derive(Clone, Debug, PartialEq, Eq, Display)]
pub enum FromSqlError {
#[display("the value is too large for the Signed type")]
Overflow,
#[display("unexpected data for type {_0}")]
ParseError(Type),
}
impl core::error::Error for FromSqlError {}
impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Signed<BITS, LIMBS> {
fn accepts(ty: &Type) -> bool {
<Self as ToSql>::accepts(ty)
}
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
Ok(match *ty {
Type::BOOL => match raw {
[0] => Self::ZERO,
[1] => Self::try_from(1)?,
_ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
},
Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
Type::BIT | Type::VARBIT => {
if raw.len() < 4 {
return Err(Box::new(FromSqlError::ParseError(ty.clone())));
}
let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
let raw = &raw[4..];
let padding = 8 - rem_up(len, 8);
let mut raw = raw.to_owned();
if padding > 0 {
for i in (1..raw.len()).rev() {
raw[i] = (raw[i] >> padding) | (raw[i - 1] << (8 - padding));
}
raw[0] >>= padding;
}
Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
}
Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
Type::JSON | Type::JSONB => {
let raw = if *ty == Type::JSONB {
if raw[0] == 1 {
&raw[1..]
} else {
return Err(Box::new(FromSqlError::ParseError(ty.clone())));
}
} else {
raw
};
let str = from_utf8(raw)?;
let str = if str.starts_with('"') && str.ends_with('"') {
&str[1..str.len() - 1]
} else {
str
};
Self::from_str(str)?
}
Type::NUMERIC => {
if raw.len() < 8 {
return Err(Box::new(FromSqlError::ParseError(ty.clone())));
}
let digits = i16::from_be_bytes(raw[0..2].try_into()?);
let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
let sign = i16::from_be_bytes(raw[4..6].try_into()?);
let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
let raw = &raw[8..];
#[allow(clippy::cast_sign_loss)] if digits < 0
|| exponent < 0
|| dscale != 0
|| digits > exponent + 1
|| raw.len() != digits as usize * 2
{
return Err(Box::new(FromSqlError::ParseError(ty.clone())));
}
let mut error = false;
let iter = raw.chunks_exact(2).filter_map(|raw| {
if error {
return None;
}
let digit = i16::from_be_bytes(raw.try_into().unwrap());
if !(0..10000).contains(&digit) {
error = true;
return None;
}
#[allow(clippy::cast_sign_loss)] Some(digit as u64)
});
#[allow(clippy::cast_sign_loss)]
let iter = iter.chain(iter::repeat_n(0, (exponent + 1 - digits) as usize));
let mut value = Self::from_base_be(10000, iter)?;
if sign == 0x4000 {
value = -value;
}
if error {
return Err(Box::new(FromSqlError::ParseError(ty.clone())));
}
value
}
_ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
})
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::I256;

    /// Binary `NUMERIC` payload used as the encoding of `1` in these tests.
    const NUMERIC_ONE: [u8; 10] = [
        0x00, 0x01, // digit count
        0x00, 0x00, // weight of the first digit
        0x00, 0x00, // sign: positive
        0x00, 0x00, // display scale
        0x00, 0x01, // the single base-10000 digit
    ];

    /// Binary `NUMERIC` payload used as the encoding of `-1` in these tests.
    const NUMERIC_MINUS_ONE: [u8; 10] = [
        0x00, 0x01, // digit count
        0x00, 0x00, // weight of the first digit
        0x40, 0x00, // sign: negative
        0x00, 0x00, // display scale
        0x00, 0x01, // the single base-10000 digit
    ];

    /// Encodes `value` as `NUMERIC` and returns the produced bytes.
    fn encode_numeric(value: &I256) -> Vec<u8> {
        let mut buffer = BytesMut::with_capacity(64);
        value.to_sql(&Type::NUMERIC, &mut buffer).unwrap();
        buffer.to_vec()
    }

    #[test]
    fn positive_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &NUMERIC_ONE).unwrap();
        assert_eq!(decoded, I256::ONE);
    }

    #[test]
    fn positive_i256_to_sql() {
        assert_eq!(encode_numeric(&I256::ONE), NUMERIC_ONE);
    }

    #[test]
    fn negative_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &NUMERIC_MINUS_ONE).unwrap();
        assert_eq!(decoded, I256::MINUS_ONE);
    }

    #[test]
    fn negative_i256_to_sql() {
        assert_eq!(encode_numeric(&I256::MINUS_ONE), NUMERIC_MINUS_ONE);
    }
}