alloy_primitives/
postgres.rs

//! Support for the [`postgres_types`] crate.
//!
//! **WARNING**: this module depends entirely on [`postgres_types`], which is not yet stable;
//! therefore this module is exempt from the semver guarantees of this crate.
5
6use super::{FixedBytes, Sign, Signed};
7use bytes::{BufMut, BytesMut};
8use derive_more::Display;
9use postgres_types::{FromSql, IsNull, ToSql, Type, WrongType, accepts, to_sql_checked};
10use std::{
11    error::Error,
12    iter,
13    str::{FromStr, from_utf8},
14};
15
16/// Converts `FixedBytes` to Postgres Bytea Type.
17impl<const BITS: usize> ToSql for FixedBytes<BITS> {
18    fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
19        out.put_slice(&self[..]);
20        Ok(IsNull::No)
21    }
22
23    accepts!(BYTEA);
24
25    to_sql_checked!();
26}
27
28/// Converts `FixedBytes` From Postgres Bytea Type.
29impl<'a, const BITS: usize> FromSql<'a> for FixedBytes<BITS> {
30    accepts!(BYTEA);
31
32    fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
33        Ok(Self::try_from(raw)?)
34    }
35}
36
// Based on ruint's Postgres support:
// https://github.com/recmo/uint/blob/6c755ad7cd54a0706d20f11f3f63b0d977af0226/src/support/postgres.rs#L22
38
/// Boxed, thread-safe error type used by the conversions in this module.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;
40
/// Like `a % b`, but an exact multiple of `b` (including 0) maps to `b` instead of 0.
///
/// For `b > 0` the result is therefore always in `1..=b`.
const fn rem_up(a: usize, b: usize) -> usize {
    match a % b {
        0 => b,
        rem => rem,
    }
}
45
/// Returns the length `x` would have after dropping every trailing element equal to
/// `value`: one past the index of the last differing element, or 0 if there is none.
fn last_idx<T: PartialEq>(x: &[T], value: &T) -> usize {
    let trailing = x.iter().rev().take_while(|item| *item == value).count();
    x.len() - trailing
}
49
50fn trim_end_vec<T: PartialEq>(vec: &mut Vec<T>, value: &T) {
51    vec.truncate(last_idx(vec, value));
52}
53
54/// Error when converting to Postgres types.
55#[derive(Clone, Debug, PartialEq, Eq, Display)]
56pub enum ToSqlError {
57    /// The value is too large for the type.
58    #[display("Signed<{_0}> value too large to fit target type {_1}")]
59    Overflow(usize, Type),
60}
61
62impl core::error::Error for ToSqlError {}
63
64/// Convert to Postgres types.
65///
66/// Compatible [Postgres data types][dt] are:
67///
68/// * `BOOL`, `SMALLINT`, `INTEGER`, `BIGINT` which are 1, 16, 32 and 64 bit signed integers
69///   respectively.
70/// * `OID` which is a 32 bit unsigned integer.
71/// * `DECIMAL` and `NUMERIC`, which are variable length.
72/// * `MONEY` which is a 64 bit integer with two decimals.
73/// * `BYTEA`, `BIT`, `VARBIT` interpreted as a big-endian binary number.
74/// * `CHAR`, `VARCHAR`, `TEXT` as `0x`-prefixed big-endian hex strings.
75/// * `JSON`, `JSONB` as a hex string compatible with the Serde serialization.
76///
77/// # Errors
78///
79/// Returns an error when trying to convert to a value that is too small to fit
80/// the number. Note that this depends on the value, not the type, so a
81/// [`Signed<256>`] can be stored in a `SMALLINT` column, as long as the values
82/// are less than $2^{16}$.
83///
84/// # Implementation details
85///
86/// The Postgres binary formats are used in the wire-protocol and the
87/// the `COPY BINARY` command, but they have very little documentation. You are
88/// pointed to the source code, for example this is the implementation of the
89/// the `NUMERIC` type serializer: [`numeric.c`][numeric].
90///
91/// [dt]:https://www.postgresql.org/docs/9.5/datatype.html
92/// [numeric]: https://github.com/postgres/postgres/blob/05a5a1775c89f6beb326725282e7eea1373cbec8/src/backend/utils/adt/numeric.c#L1082
93impl<const BITS: usize, const LIMBS: usize> ToSql for Signed<BITS, LIMBS> {
94    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
95        match *ty {
96            // Big-endian simple types
97            // Note `BufMut::put_*` methods write big-endian by default.
98            Type::BOOL => out.put_u8(u8::from(bool::try_from(self.0)?)),
99            Type::INT2 => out.put_i16(self.0.try_into()?),
100            Type::INT4 => out.put_i32(self.0.try_into()?),
101            Type::OID => out.put_u32(self.0.try_into()?),
102            Type::INT8 => out.put_i64(self.0.try_into()?),
103
104            Type::MONEY => {
105                // Like i64, but with two decimals.
106                out.put_i64(
107                    i64::try_from(self.0)?
108                        .checked_mul(100)
109                        .ok_or(ToSqlError::Overflow(BITS, ty.clone()))?,
110                );
111            }
112
113            // Binary strings
114            Type::BYTEA => out.put_slice(&self.0.to_be_bytes_vec()),
115            Type::BIT | Type::VARBIT => {
116                // Bit in little-endian so the first bit is the least significant.
117                // Length must be at least one bit.
118                if BITS == 0 {
119                    if *ty == Type::BIT {
120                        // `bit(0)` is not a valid type, but varbit can be empty.
121                        return Err(Box::new(WrongType::new::<Self>(ty.clone())));
122                    }
123                    out.put_i32(0);
124                } else {
125                    // Bits are output in big-endian order, but padded at the
126                    // least significant end.
127                    let padding = 8 - rem_up(BITS, 8);
128                    out.put_i32(Self::BITS.try_into()?);
129                    let bytes = self.0.as_le_bytes();
130                    let mut bytes = bytes.iter().rev();
131                    let mut shifted = bytes.next().unwrap() << padding;
132                    for byte in bytes {
133                        shifted |= if padding > 0 { byte >> (8 - padding) } else { 0 };
134                        out.put_u8(shifted);
135                        shifted = byte << padding;
136                    }
137                    out.put_u8(shifted);
138                }
139            }
140
141            // Hex strings
142            Type::CHAR | Type::TEXT | Type::VARCHAR => {
143                out.put_slice(format!("{self:#x}").as_bytes());
144            }
145            Type::JSON | Type::JSONB => {
146                if *ty == Type::JSONB {
147                    // Version 1 of JSONB is just plain text JSON.
148                    out.put_u8(1);
149                }
150                out.put_slice(format!("\"{self:#x}\"").as_bytes());
151            }
152
153            // Binary coded decimal types
154            // See <https://github.com/postgres/postgres/blob/05a5a1775c89f6beb326725282e7eea1373cbec8/src/backend/utils/adt/numeric.c#L253>
155            Type::NUMERIC => {
156                // Everything is done in big-endian base 1000 digits.
157                const BASE: u64 = 10000;
158
159                let sign = match self.sign() {
160                    Sign::Positive => 0x0000,
161                    _ => 0x4000,
162                };
163
164                let mut digits: Vec<_> = self.abs().0.to_base_be(BASE).collect();
165                let exponent = digits.len().saturating_sub(1).try_into()?;
166
167                // Trailing zeros are removed.
168                trim_end_vec(&mut digits, &0);
169
170                out.put_i16(digits.len().try_into()?); // Number of digits.
171                out.put_i16(exponent); // Exponent of first digit.
172
173                out.put_i16(sign);
174                out.put_i16(0); // dscale: Number of digits to the right of the decimal point.
175                for digit in digits {
176                    debug_assert!(digit < BASE);
177                    #[allow(clippy::cast_possible_truncation)] // 10000 < i16::MAX
178                    out.put_i16(digit as i16);
179                }
180            }
181
182            // Unsupported types
183            _ => {
184                return Err(Box::new(WrongType::new::<Self>(ty.clone())));
185            }
186        };
187        Ok(IsNull::No)
188    }
189
190    fn accepts(ty: &Type) -> bool {
191        matches!(*ty, |Type::BOOL| Type::CHAR
192            | Type::INT2
193            | Type::INT4
194            | Type::INT8
195            | Type::OID
196            | Type::FLOAT4
197            | Type::FLOAT8
198            | Type::MONEY
199            | Type::NUMERIC
200            | Type::BYTEA
201            | Type::TEXT
202            | Type::VARCHAR
203            | Type::JSON
204            | Type::JSONB
205            | Type::BIT
206            | Type::VARBIT)
207    }
208
209    to_sql_checked!();
210}
211
212/// Error when converting from Postgres types.
213#[derive(Clone, Debug, PartialEq, Eq, Display)]
214pub enum FromSqlError {
215    /// The value is too large for the type.
216    #[display("the value is too large for the Signed type")]
217    Overflow,
218
219    /// The value is not valid for the type.
220    #[display("unexpected data for type {_0}")]
221    ParseError(Type),
222}
223
224impl core::error::Error for FromSqlError {}
225
226impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Signed<BITS, LIMBS> {
227    fn accepts(ty: &Type) -> bool {
228        <Self as ToSql>::accepts(ty)
229    }
230
231    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
232        Ok(match *ty {
233            Type::BOOL => match raw {
234                [0] => Self::ZERO,
235                [1] => Self::try_from(1)?,
236                _ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
237            },
238            Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
239            Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
240            Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
241            Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
242            Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
243
244            // Binary strings
245            Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
246            Type::BIT | Type::VARBIT => {
247                // Parse header
248                if raw.len() < 4 {
249                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
250                }
251                let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
252                let raw = &raw[4..];
253
254                // Shift padding to the other end
255                let padding = 8 - rem_up(len, 8);
256                let mut raw = raw.to_owned();
257                if padding > 0 {
258                    for i in (1..raw.len()).rev() {
259                        raw[i] = (raw[i] >> padding) | (raw[i - 1] << (8 - padding));
260                    }
261                    raw[0] >>= padding;
262                }
263                // Construct from bits
264                Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
265            }
266
267            // Hex strings
268            Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
269
270            // Hex strings
271            Type::JSON | Type::JSONB => {
272                let raw = if *ty == Type::JSONB {
273                    if raw[0] == 1 {
274                        &raw[1..]
275                    } else {
276                        // Unsupported version
277                        return Err(Box::new(FromSqlError::ParseError(ty.clone())));
278                    }
279                } else {
280                    raw
281                };
282                let str = from_utf8(raw)?;
283                let str = if str.starts_with('"') && str.ends_with('"') {
284                    // Stringified number
285                    &str[1..str.len() - 1]
286                } else {
287                    str
288                };
289                Self::from_str(str)?
290            }
291
292            // Numeric types
293            Type::NUMERIC => {
294                // Parse header
295                if raw.len() < 8 {
296                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
297                }
298                let digits = i16::from_be_bytes(raw[0..2].try_into()?);
299                let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
300                let sign = i16::from_be_bytes(raw[4..6].try_into()?);
301                let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
302                let raw = &raw[8..];
303                #[allow(clippy::cast_sign_loss)] // Signs are checked
304                if digits < 0
305                    || exponent < 0
306                    || dscale != 0
307                    || digits > exponent + 1
308                    || raw.len() != digits as usize * 2
309                {
310                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
311                }
312                let mut error = false;
313                let iter = raw.chunks_exact(2).filter_map(|raw| {
314                    if error {
315                        return None;
316                    }
317                    let digit = i16::from_be_bytes(raw.try_into().unwrap());
318                    if !(0..10000).contains(&digit) {
319                        error = true;
320                        return None;
321                    }
322                    #[allow(clippy::cast_sign_loss)] // Signs are checked
323                    Some(digit as u64)
324                });
325                #[allow(clippy::cast_sign_loss)]
326                // Expression can not be negative due to checks above
327                let iter = iter.chain(iter::repeat_n(0, (exponent + 1 - digits) as usize));
328
329                let mut value = Self::from_base_be(10000, iter)?;
330                if sign == 0x4000 {
331                    value = -value;
332                }
333                if error {
334                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
335                }
336
337                value
338            }
339
340            // Unsupported types
341            _ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
342        })
343    }
344}
345
#[cfg(test)]
mod test {
    use super::*;

    use crate::I256;

    /// Binary NUMERIC encoding of ±1: one base-10000 digit, weight 0, scale 0.
    /// `sign_hi` is the high byte of the sign word (0x00 positive, 0x40 negative).
    const fn numeric_unit(sign_hi: u8) -> [u8; 10] {
        [
            0x00, 0x01, // ndigits: 1
            0x00, 0x00, // weight: 0
            sign_hi, 0x00, // sign: 0x0000 (positive) or 0x4000 (negative)
            0x00, 0x00, // scale: 0
            0x00, 0x01, // digit: 1
        ]
    }

    const PLUS_ONE: [u8; 10] = numeric_unit(0x00);
    const MINUS_ONE: [u8; 10] = numeric_unit(0x40);

    /// Encodes `value` as a binary NUMERIC, panicking on failure.
    fn encode_numeric(value: I256) -> Vec<u8> {
        let mut buf = BytesMut::with_capacity(64);
        value.to_sql(&Type::NUMERIC, &mut buf).unwrap();
        buf.to_vec()
    }

    #[test]
    fn positive_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &PLUS_ONE).unwrap();
        assert_eq!(decoded, I256::ONE);
    }

    #[test]
    fn positive_i256_to_sql() {
        assert_eq!(encode_numeric(I256::ONE), PLUS_ONE);
    }

    #[test]
    fn negative_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &MINUS_ONE).unwrap();
        assert_eq!(decoded, I256::MINUS_ONE);
    }

    #[test]
    fn negative_i256_to_sql() {
        assert_eq!(encode_numeric(I256::MINUS_ONE), MINUS_ONE);
    }
}