ruint2/support/
postgres.rs

1//! Support for the [`postgres`](https://crates.io/crates/postgres) crate.
2#![cfg(feature = "postgres")]
3#![cfg_attr(has_doc_cfg, doc(cfg(feature = "postgres")))]
4
5use crate::{
6    utils::{rem_up, trim_end_vec},
7    Uint,
8};
9use bytes::{BufMut, BytesMut};
10use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type, WrongType};
11use std::{
12    error::Error,
13    iter,
14    str::{from_utf8, FromStr},
15};
16use thiserror::Error;
17
/// Boxed, thread-safe error used as the `Err` type of [`ToSql::to_sql`].
type BoxedError = Box<dyn Error + Send + Sync + 'static>;
19
20#[derive(Clone, PartialEq, Eq, Debug, Error)]
21pub enum ToSqlError {
22    #[error("Uint<{0}> value too large to fit target type {1}")]
23    Overflow(usize, Type),
24}
25
26/// Convert to Postgres types.
27///
28/// Compatible [Postgres data types][dt] are:
29///
30/// * `BOOL`, `SMALLINT`, `INTEGER`, `BIGINT` which are 1, 16, 32 and 64 bit
31///   signed integers respectively.
32/// * `OID` which is a 32 bit unsigned integer.
33/// * `FLOAT`, `DOUBLE PRECISION` which are 32 and 64 bit floating point.
34/// * `DECIMAL` and `NUMERIC`, which are variable length.
35/// * `MONEY` which is a 64 bit integer with two decimals.
36/// * `BYTEA`, `BIT`, `VARBIT` interpreted as a big-endian binary number.
37/// * `CHAR`, `VARCHAR`, `TEXT` as `0x`-prefixed big-endian hex strings.
38/// * `JSON`, `JSONB` as a hex string compatible with the Serde serialization.
39///
40/// Note: [`Uint`]s are never null, use [`Option<Uint>`] instead.
41///
42/// # Errors
43///
44/// Returns an error when trying to convert to a value that is too small to fit
45/// the number. Note that this depends on the value, not the type, so a
46/// [`Uint<256>`] can be stored in a `SMALLINT` column, as long as the values
47/// are less than $2^{16}$.
48///
49/// # Implementation details
50///
51/// The Postgres binary formats are used in the wire-protocol and the
52/// the `COPY BINARY` command, but they have very little documentation. You are
53/// pointed to the source code, for example this is the implementation of the
54/// the `NUMERIC` type serializer: [`numeric.c`][numeric].
55///
56/// [dt]:https://www.postgresql.org/docs/9.5/datatype.html
57/// [numeric]: https://github.com/postgres/postgres/blob/05a5a1775c89f6beb326725282e7eea1373cbec8/src/backend/utils/adt/numeric.c#L1082
58impl<const BITS: usize, const LIMBS: usize> ToSql for Uint<BITS, LIMBS> {
59    fn accepts(ty: &Type) -> bool {
60        matches!(*ty, |Type::BOOL| Type::CHAR
61            | Type::INT2
62            | Type::INT4
63            | Type::INT8
64            | Type::OID
65            | Type::FLOAT4
66            | Type::FLOAT8
67            | Type::MONEY
68            | Type::NUMERIC
69            | Type::BYTEA
70            | Type::TEXT
71            | Type::VARCHAR
72            | Type::JSON
73            | Type::JSONB
74            | Type::BIT
75            | Type::VARBIT)
76    }
77
78    // See <https://github.com/sfackler/rust-postgres/blob/38da7fa8fe0067f47b60c147ccdaa214ab5f5211/postgres-protocol/src/types/mod.rs>
79    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
80        match *ty {
81            // Big-endian simple types
82            // Note `BufMut::put_*` methods write big-endian by default.
83            Type::BOOL => out.put_u8(u8::from(bool::try_from(*self)?)),
84            Type::INT2 => out.put_i16(self.try_into()?),
85            Type::INT4 => out.put_i32(self.try_into()?),
86            Type::OID => out.put_u32(self.try_into()?),
87            Type::INT8 => out.put_i64(self.try_into()?),
88            Type::FLOAT4 => out.put_f32(self.try_into()?),
89            Type::FLOAT8 => out.put_f64(self.try_into()?),
90            Type::MONEY => {
91                // Like i64, but with two decimals.
92                out.put_i64(
93                    i64::try_from(self)?
94                        .checked_mul(100)
95                        .ok_or_else(|| ToSqlError::Overflow(BITS, ty.clone()))?,
96                );
97            }
98
99            // Binary strings
100            Type::BYTEA => out.put_slice(&self.to_be_bytes_vec()),
101            Type::BIT | Type::VARBIT => {
102                // Bit in little-endian so the the first bit is the least significant.
103                // Length must be at least one bit.
104                if BITS == 0 {
105                    if *ty == Type::BIT {
106                        // `bit(0)` is not a valid type, but varbit can be empty.
107                        return Err(Box::new(WrongType::new::<Self>(ty.clone())));
108                    }
109                    out.put_i32(0);
110                } else {
111                    // Bits are output in big-endian order, but padded at the
112                    // least significant end.
113                    let padding = 8 - rem_up(BITS, 8);
114                    out.put_i32(Self::BITS.try_into()?);
115                    let bytes = self.as_le_bytes();
116                    let mut bytes = bytes.iter().rev();
117                    let mut shifted = bytes.next().unwrap() << padding;
118                    for byte in bytes {
119                        shifted |= if padding > 0 {
120                            byte >> (8 - padding)
121                        } else {
122                            0
123                        };
124                        out.put_u8(shifted);
125                        shifted = byte << padding;
126                    }
127                    out.put_u8(shifted);
128                }
129            }
130
131            // Hex strings
132            Type::CHAR | Type::TEXT | Type::VARCHAR => {
133                out.put_slice(format!("{self:#x}").as_bytes());
134            }
135            Type::JSON | Type::JSONB => {
136                if *ty == Type::JSONB {
137                    // Version 1 of JSONB is just plain text JSON.
138                    out.put_u8(1);
139                }
140                out.put_slice(format!("\"{self:#x}\"").as_bytes());
141            }
142
143            // Binary coded decimal types
144            // See <https://github.com/postgres/postgres/blob/05a5a1775c89f6beb326725282e7eea1373cbec8/src/backend/utils/adt/numeric.c#L253>
145            Type::NUMERIC => {
146                // Everything is done in big-endian base 1000 digits.
147                const BASE: u64 = 10000;
148                let mut digits: Vec<_> = self.to_base_be(BASE).collect();
149                let exponent = digits.len().saturating_sub(1).try_into()?;
150
151                // Trailing zeros are removed.
152                trim_end_vec(&mut digits, &0);
153
154                out.put_i16(digits.len().try_into()?); // Number of digits.
155                out.put_i16(exponent); // Exponent of first digit.
156                out.put_i16(0); // sign: 0x0000 = positive, 0x4000 = negative.
157                out.put_i16(0); // dscale: Number of digits to the right of the decimal point.
158                for digit in digits {
159                    debug_assert!(digit < BASE);
160                    #[allow(clippy::cast_possible_truncation)] // 10000 < i16::MAX
161                    out.put_i16(digit as i16);
162                }
163            }
164
165            // Unsupported types
166            _ => {
167                return Err(Box::new(WrongType::new::<Self>(ty.clone())));
168            }
169        };
170        Ok(IsNull::No)
171    }
172
173    to_sql_checked!();
174}
175
176#[derive(Clone, PartialEq, Eq, Debug, Error)]
177pub enum FromSqlError {
178    #[error("The value is too large for the Uint type")]
179    Overflow,
180
181    #[error("Unexpected data for type {0}")]
182    ParseError(Type),
183}
184
185/// Convert from Postgres types.
186///
187/// See [`ToSql`][Self::to_sql] for details.
188impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Uint<BITS, LIMBS> {
189    fn accepts(ty: &Type) -> bool {
190        <Self as ToSql>::accepts(ty)
191    }
192
193    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
194        Ok(match *ty {
195            Type::BOOL => match raw {
196                [0] => Self::ZERO,
197                [1] => Self::try_from(1)?,
198                _ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
199            },
200            Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
201            Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
202            Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
203            Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
204            Type::FLOAT4 => f32::from_be_bytes(raw.try_into()?).try_into()?,
205            Type::FLOAT8 => f64::from_be_bytes(raw.try_into()?).try_into()?,
206            Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
207
208            // Binary strings
209            Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
210            Type::BIT | Type::VARBIT => {
211                // Parse header
212                if raw.len() < 4 {
213                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
214                }
215                let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
216                let raw = &raw[4..];
217
218                // Shift padding to the other end
219                let padding = 8 - rem_up(len, 8);
220                let mut raw = raw.to_owned();
221                if padding > 0 {
222                    for i in (1..raw.len()).rev() {
223                        raw[i] = raw[i] >> padding | raw[i - 1] << (8 - padding);
224                    }
225                    raw[0] >>= padding;
226                }
227                // Construct from bits
228                Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
229            }
230
231            // Hex strings
232            Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
233
234            // Hex strings
235            Type::JSON | Type::JSONB => {
236                let raw = if *ty == Type::JSONB {
237                    if raw[0] == 1 {
238                        &raw[1..]
239                    } else {
240                        // Unsupported version
241                        return Err(Box::new(FromSqlError::ParseError(ty.clone())));
242                    }
243                } else {
244                    raw
245                };
246                let str = from_utf8(raw)?;
247                let str = if str.starts_with('"') && str.ends_with('"') {
248                    // Stringified number
249                    &str[1..str.len() - 1]
250                } else {
251                    str
252                };
253                Self::from_str(str)?
254            }
255
256            // Numeric types
257            Type::NUMERIC => {
258                // Parse header
259                if raw.len() < 8 {
260                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
261                }
262                let digits = i16::from_be_bytes(raw[0..2].try_into()?);
263                let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
264                let sign = i16::from_be_bytes(raw[4..6].try_into()?);
265                let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
266                let raw = &raw[8..];
267                #[allow(clippy::cast_sign_loss)] // Signs are checked
268                if digits < 0
269                    || exponent < 0
270                    || sign != 0x0000
271                    || dscale != 0
272                    || digits > exponent + 1
273                    || raw.len() != digits as usize * 2
274                {
275                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
276                }
277                let mut error = false;
278                let iter = raw.chunks_exact(2).filter_map(|raw| {
279                    if error {
280                        return None;
281                    }
282                    let digit = i16::from_be_bytes(raw.try_into().unwrap());
283                    if !(0..10000).contains(&digit) {
284                        error = true;
285                        return None;
286                    }
287                    #[allow(clippy::cast_sign_loss)] // Signs are checked
288                    Some(digit as u64)
289                });
290                #[allow(clippy::cast_sign_loss)]
291                // Expression can not be negative due to checks above
292                let iter = iter.chain(iter::repeat(0).take((exponent + 1 - digits) as usize));
293
294                let value = Self::from_base_be(10000, iter)?;
295                if error {
296                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
297                }
298                value
299            }
300
301            // Unsupported types
302            _ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
303        })
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nbytes, nlimbs};
    use approx::assert_ulps_eq;
    use hex_literal::hex;
    use postgres::{Client, NoTls};
    use proptest::{proptest, test_runner::Config as ProptestConfig};
    use std::{io::Read, sync::Mutex};

    /// Checks the `to_sql` encodings of one fixed 256-bit value against
    /// hard-coded byte strings.
    #[test]
    fn test_basic() {
        #[allow(clippy::unreadable_literal)]
        const N: Uint<256, 4> = Uint::from_limbs([
            0xa8ec92344438aaf4_u64,
            0x9819ebdbd1faaab1_u64,
            0x573b1a7064c19c1a_u64,
            0xc85ef7d79691fe79_u64,
        ]);
        // Encodes `N` as `ty`; panics if the encoding fails.
        #[allow(clippy::needless_pass_by_value)]
        fn bytes(ty: Type) -> Vec<u8> {
            let mut out = BytesMut::new();
            N.to_sql(&ty, &mut out).unwrap();
            out.to_vec()
        }
        assert_eq!(bytes(Type::FLOAT4), hex!("7f800000")); // +inf
        assert_eq!(bytes(Type::FLOAT8), hex!("4fe90bdefaf2d240"));
        assert_eq!(bytes(Type::NUMERIC), hex!("0014001300000000000902760e3620f115a21c3b029709bc11e60b3e10d10d6900d123400def1c45091a147900f012f4"));
        assert_eq!(
            bytes(Type::BYTEA),
            hex!("c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4")
        );
        assert_eq!(
            bytes(Type::BIT),
            hex!("00000100c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4")
        );
        assert_eq!(
            bytes(Type::VARBIT),
            hex!("00000100c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4")
        );
        assert_eq!(bytes(Type::CHAR), hex!("307863383565663764373936393166653739353733623161373036346331396331613938313965626462643166616161623161386563393233343434333861616634"));
        assert_eq!(bytes(Type::TEXT), hex!("307863383565663764373936393166653739353733623161373036346331396331613938313965626462643166616161623161386563393233343434333861616634"));
        assert_eq!(bytes(Type::VARCHAR), hex!("307863383565663764373936393166653739353733623161373036346331396331613938313965626462643166616161623161386563393233343434333861616634"));
    }

    /// Encodes random values with `to_sql` and decodes them with `from_sql`
    /// for every bit size in `SIZES`. Floats are compared within 4 ULPs;
    /// the other listed types must round-trip exactly. Encodings that
    /// `to_sql` rejects are skipped.
    #[test]
    fn test_roundtrip() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                let mut serialized = BytesMut::new();

                if f32::from(value).is_finite() {
                    serialized.clear();
                    if value.to_sql(&Type::FLOAT4, &mut serialized).is_ok() {
                        // println!("testing {:?} {}", value, Type::FLOAT4);
                        let deserialized = U::from_sql(&Type::FLOAT4, &serialized).unwrap();
                        assert_ulps_eq!(f32::from(value), f32::from(deserialized), max_ulps = 4);
                    }
                }
                if f64::from(value).is_finite() {
                    serialized.clear();
                    if value.to_sql(&Type::FLOAT8, &mut serialized).is_ok() {
                        // println!("testing {:?} {}", value, Type::FLOAT8);
                        let deserialized = U::from_sql(&Type::FLOAT8, &serialized).unwrap();
                        assert_ulps_eq!(f64::from(value), f64::from(deserialized), max_ulps = 4);
                    }
                }
                // NOTE(review): only BIT and VARBIT are exercised for exact
                // round-trips; the other types are commented out in the list
                // below. Confirm whether that is intentional.
                for ty in &[/*Type::BOOL, Type::INT2, Type::INT4, Type::INT8, Type::OID, Type::MONEY, Type::BYTEA, Type::CHAR, Type::TEXT, Type::VARCHAR, Type::JSON, Type::JSONB, Type::NUMERIC,*/ Type::BIT, Type::VARBIT] {
                    serialized.clear();
                    if value.to_sql(ty, &mut serialized).is_ok() {
                        // println!("testing {:?} {}", value, ty);
                        let deserialized = U::from_sql(ty, &serialized).unwrap();
                        assert_eq!(deserialized, value);
                    }
                }
            });
        });
    }

    /// Query the binary encoding of an SQL expression.
    ///
    /// Runs `COPY (SELECT expr) TO STDOUT WITH BINARY`, asserts that the
    /// output is a single tuple with a single field, and returns that field's
    /// raw bytes. Panics (via `unwrap`/`assert`) on any unexpected framing.
    fn get_binary(client: &mut Client, expr: &str) -> Vec<u8> {
        let query = format!("COPY (SELECT {expr}) TO STDOUT WITH BINARY;");

        // See <https://www.postgresql.org/docs/current/sql-copy.html>
        let mut reader = client.copy_out(&query).unwrap();
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();

        // Parse header: the 11-byte signature is checked, then the 4-byte
        // flags field is skipped without inspection.
        let buf = {
            const HEADER: &[u8] = b"PGCOPY\n\xff\r\n\0";
            assert_eq!(&buf[..11], HEADER);
            &buf[11 + 4..]
        };

        // Skip extension headers (must be zero length)
        assert_eq!(&buf[..4], &0_u32.to_be_bytes());
        let buf = &buf[4..];

        // Tuple field count must be one
        assert_eq!(&buf[..2], &1_i16.to_be_bytes());
        let buf = &buf[2..];

        // Field length. NOTE: a NULL field (length -1) is not handled; read as
        // u32 it becomes a huge length and the slice below panics.
        let len = u32::from_be_bytes(buf[..4].try_into().unwrap()) as usize;
        let buf = &buf[4..];

        // Field data
        let data = &buf[..len];
        let buf = &buf[len..];

        // Trailer must be -1_i16
        assert_eq!(&buf, &(-1_i16).to_be_bytes());

        data.to_owned()
    }

    /// Compares the local `to_sql` encoding of `value` as `ty` with the binary
    /// encoding Postgres produces for an equivalent SQL literal.
    ///
    /// Values that `to_sql` rejects, and float encodings of infinite values,
    /// are skipped. Floats are compared within 4 ULPs; every other type must
    /// match byte-for-byte.
    fn test_to<const BITS: usize, const LIMBS: usize>(
        client: &Mutex<Client>,
        value: Uint<BITS, LIMBS>,
        ty: &Type,
    ) {
        println!("testing {value:?} {ty}");

        // Encode value locally
        let mut serialized = BytesMut::new();
        let result = value.to_sql(ty, &mut serialized);
        if result.is_err() {
            // Skip values that are out of range for the type
            return;
        }
        // Skip floating point infinities
        if ty == &Type::FLOAT4 && f32::from(value).is_infinite() {
            return;
        }
        if ty == &Type::FLOAT8 && f64::from(value).is_infinite() {
            return;
        }
        // dbg!(hex::encode(&serialized));

        // Fetch ground truth value from Postgres
        // NOTE(review): SQL `char(n)` is blank-padded to n characters; the CHAR
        // case assumes `{value:#x}` always yields exactly `2 + 2 * nbytes(BITS)`
        // characters — confirm.
        let expr = match *ty {
            Type::BIT => format!(
                "B'{value:b}'::bit({bits})",
                value = value,
                bits = if BITS == 0 { 1 } else { BITS },
            ),
            Type::VARBIT => format!("B'{value:b}'::varbit"),
            Type::BYTEA => format!("'\\x{value:x}'::bytea"),
            Type::CHAR => format!("'{value:#x}'::char({})", 2 + 2 * nbytes(BITS)),
            Type::TEXT | Type::VARCHAR => format!("'{value:#x}'::{}", ty.name()),
            Type::JSON | Type::JSONB => format!("'\"{value:#x}\"'::{}", ty.name()),
            _ => format!("{value}::{}", ty.name()),
        };
        // dbg!(&expr);
        let ground_truth = {
            let mut client = client.lock().unwrap();
            get_binary(&mut client, &expr)
        };
        // dbg!(hex::encode(&ground_truth));

        // Compare with ground truth, for float we allow tiny rounding error
        if ty == &Type::FLOAT4 {
            let serialized = f32::from_be_bytes(serialized.as_ref().try_into().unwrap());
            let ground_truth = f32::from_be_bytes(ground_truth.try_into().unwrap());
            assert_ulps_eq!(serialized, ground_truth, max_ulps = 4);
        } else if ty == &Type::FLOAT8 {
            let serialized = f64::from_be_bytes(serialized.as_ref().try_into().unwrap());
            let ground_truth = f64::from_be_bytes(ground_truth.try_into().unwrap());
            assert_ulps_eq!(serialized, ground_truth, max_ulps = 4);
        } else {
            // Check that the value is exactly the same as the ground truth
            assert_eq!(serialized, ground_truth);
        }
    }

    // This test requires a live postgresql server.
    // To start a server, run:
    //
    //     docker run -it --rm -e POSTGRES_PASSWORD=postgres -p 5432:5432 postgres
    //
    // Then run the test using:
    //
    //     PROPTEST_CASES=1000 cargo test --all-features -- --include-ignored --nocapture postgres
    //
    #[test]
    #[ignore]
    fn test_postgres() {
        // docker run -it --rm -e POSTGRES_PASSWORD=postgres -p 5432:5432 postgres
        let client = Client::connect("postgresql://postgres:postgres@localhost", NoTls).unwrap();
        let client = Mutex::new(client);

        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);

            // By default generates 256 random values per bit size. Configurable
            // with the `PROPTEST_CASES` env variable.
            let mut config = ProptestConfig::default();
            // No point in running many values for small sizes
            if BITS < 4 { config.cases = 16; };

            proptest!(config, |(value: Uint<BITS, LIMBS>)| {

                // Test based on which types value will fit
                let bits = value.bit_len();
                if bits <= 1 {
                    test_to(&client, value, &Type::BOOL);
                }
                if bits <= 15 {
                    test_to(&client, value, &Type::INT2);
                }
                if bits <= 31 {
                    test_to(&client, value, &Type::INT4);
                }
                if bits <= 32 {
                    test_to(&client, value, &Type::OID);
                }
                if bits <= 50 {
                    test_to(&client, value, &Type::MONEY);
                }
                if bits <= 63 {
                    test_to(&client, value, &Type::INT8);
                }

                // Floating points always work, except when the exponent
                // overflows. We map that to +∞, but SQL rejects it. This
                // is handled in the `test_to` function.
                test_to(&client, value, &Type::FLOAT4);
                test_to(&client, value, &Type::FLOAT8);

                // Types that work for any size
                for ty in &[Type::NUMERIC, Type::BIT, Type::VARBIT, Type::BYTEA, Type::CHAR, Type::TEXT, Type::VARCHAR, Type::JSON, Type::JSONB] {
                    test_to(&client, value, ty);
                }

            });
        });
    }
}