1#![cfg(feature = "postgres")]
4#![cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
5
6use crate::{
7 utils::{rem_up, trim_end_vec},
8 Uint,
9};
10use bytes::{BufMut, BytesMut};
11use core::{
12 error::Error,
13 iter,
14 str::{from_utf8, FromStr},
15};
16use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type, WrongType};
17use thiserror::Error;
18
/// Type-erased, thread-safe error used as the error type of the Postgres
/// conversions in this module.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;
20
21#[derive(Clone, PartialEq, Eq, Debug, Error)]
22pub enum ToSqlError {
23 #[error("Uint<{0}> value too large to fit target type {1}")]
24 Overflow(usize, Type),
25}
26
27impl<const BITS: usize, const LIMBS: usize> ToSql for Uint<BITS, LIMBS> {
60 fn accepts(ty: &Type) -> bool {
61 matches!(*ty, |Type::BOOL| Type::CHAR
62 | Type::INT2
63 | Type::INT4
64 | Type::INT8
65 | Type::OID
66 | Type::FLOAT4
67 | Type::FLOAT8
68 | Type::MONEY
69 | Type::NUMERIC
70 | Type::BYTEA
71 | Type::TEXT
72 | Type::VARCHAR
73 | Type::JSON
74 | Type::JSONB
75 | Type::BIT
76 | Type::VARBIT)
77 }
78
79 fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
81 match *ty {
82 Type::BOOL => out.put_u8(u8::from(bool::try_from(*self)?)),
85 Type::INT2 => out.put_i16(self.try_into()?),
86 Type::INT4 => out.put_i32(self.try_into()?),
87 Type::OID => out.put_u32(self.try_into()?),
88 Type::INT8 => out.put_i64(self.try_into()?),
89 Type::FLOAT4 => out.put_f32(self.into()),
90 Type::FLOAT8 => out.put_f64(self.into()),
91 Type::MONEY => {
92 out.put_i64(
94 i64::try_from(self)?
95 .checked_mul(100)
96 .ok_or_else(|| ToSqlError::Overflow(BITS, ty.clone()))?,
97 );
98 }
99
100 Type::BYTEA => out.put_slice(&self.to_be_bytes_vec()),
102 Type::BIT | Type::VARBIT => {
103 if BITS == 0 {
106 if *ty == Type::BIT {
107 return Err(Box::new(WrongType::new::<Self>(ty.clone())));
109 }
110 out.put_i32(0);
111 } else {
112 let padding = 8 - rem_up(BITS, 8);
115 out.put_i32(Self::BITS.try_into()?);
116 let bytes = self.as_le_bytes();
117 let mut bytes = bytes.iter().rev();
118 let mut shifted = bytes.next().unwrap() << padding;
119 for byte in bytes {
120 shifted |= if padding > 0 {
121 byte >> (8 - padding)
122 } else {
123 0
124 };
125 out.put_u8(shifted);
126 shifted = byte << padding;
127 }
128 out.put_u8(shifted);
129 }
130 }
131
132 Type::CHAR | Type::TEXT | Type::VARCHAR => {
134 out.put_slice(format!("{self:#x}").as_bytes());
135 }
136 Type::JSON | Type::JSONB => {
137 if *ty == Type::JSONB {
138 out.put_u8(1);
140 }
141 out.put_slice(format!("\"{self:#x}\"").as_bytes());
142 }
143
144 Type::NUMERIC => {
147 const BASE: u64 = 10000;
149 let mut digits: Vec<_> = self.to_base_be(BASE).collect();
150 let exponent = digits.len().saturating_sub(1).try_into()?;
151
152 trim_end_vec(&mut digits, &0);
154
155 out.put_i16(digits.len().try_into()?); out.put_i16(exponent); out.put_i16(0); out.put_i16(0); for digit in digits {
160 debug_assert!(digit < BASE);
161 #[allow(clippy::cast_possible_truncation)] out.put_i16(digit as i16);
163 }
164 }
165
166 _ => {
168 return Err(Box::new(WrongType::new::<Self>(ty.clone())));
169 }
170 }
171 Ok(IsNull::No)
172 }
173
174 to_sql_checked!();
175}
176
177#[derive(Clone, PartialEq, Eq, Debug, Error)]
178pub enum FromSqlError {
179 #[error("The value is too large for the Uint type")]
180 Overflow,
181
182 #[error("Unexpected data for type {0}")]
183 ParseError(Type),
184}
185
186impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Uint<BITS, LIMBS> {
190 fn accepts(ty: &Type) -> bool {
191 <Self as ToSql>::accepts(ty)
192 }
193
194 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
195 Ok(match *ty {
196 Type::BOOL => match raw {
197 [0] => Self::ZERO,
198 [1] => Self::try_from(1)?,
199 _ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
200 },
201 Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
202 Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
203 Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
204 Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
205 Type::FLOAT4 => f32::from_be_bytes(raw.try_into()?).try_into()?,
206 Type::FLOAT8 => f64::from_be_bytes(raw.try_into()?).try_into()?,
207 Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
208
209 Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
211 Type::BIT | Type::VARBIT => {
212 if raw.len() < 4 {
214 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
215 }
216 let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
217 let raw = &raw[4..];
218
219 let padding = 8 - rem_up(len, 8);
221 let mut raw = raw.to_owned();
222 if padding > 0 {
223 for i in (1..raw.len()).rev() {
224 raw[i] = (raw[i] >> padding) | (raw[i - 1] << (8 - padding));
225 }
226 raw[0] >>= padding;
227 }
228 Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
230 }
231
232 Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
234
235 Type::JSON | Type::JSONB => {
237 let raw = if *ty == Type::JSONB {
238 if raw[0] == 1 {
239 &raw[1..]
240 } else {
241 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
243 }
244 } else {
245 raw
246 };
247 let str = from_utf8(raw)?;
248 let str = if str.starts_with('"') && str.ends_with('"') {
249 &str[1..str.len() - 1]
251 } else {
252 str
253 };
254 Self::from_str(str)?
255 }
256
257 Type::NUMERIC => {
259 if raw.len() < 8 {
261 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
262 }
263 let digits = i16::from_be_bytes(raw[0..2].try_into()?);
264 let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
265 let sign = i16::from_be_bytes(raw[4..6].try_into()?);
266 let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
267 let raw = &raw[8..];
268 #[allow(clippy::cast_sign_loss)] if digits < 0
270 || exponent < 0
271 || sign != 0x0000
272 || dscale != 0
273 || digits > exponent + 1
274 || raw.len() != digits as usize * 2
275 {
276 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
277 }
278 let mut error = false;
279 let iter = raw.chunks_exact(2).filter_map(|raw| {
280 if error {
281 return None;
282 }
283 let digit = i16::from_be_bytes(raw.try_into().unwrap());
284 if !(0..10000).contains(&digit) {
285 error = true;
286 return None;
287 }
288 #[allow(clippy::cast_sign_loss)] Some(digit as u64)
290 });
291 #[allow(clippy::cast_sign_loss)]
292 let iter = iter.chain(iter::repeat(0).take((exponent + 1 - digits) as usize));
294
295 let value = Self::from_base_be(10000, iter)?;
296 if error {
297 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
298 }
299 value
300 }
301
302 _ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
304 })
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nbytes, nlimbs};
    use approx::assert_ulps_eq;
    use hex_literal::hex;
    use postgres::{Client, NoTls};
    use proptest::{proptest, test_runner::Config as ProptestConfig};
    use std::{io::Read, sync::Mutex};

    /// Pins the exact wire bytes produced for one fixed 256-bit value.
    #[test]
    fn test_basic() {
        // Limbs are listed least significant first; compare the BYTEA bytes below.
        #[allow(clippy::unreadable_literal)]
        const N: Uint<256, 4> = Uint::from_limbs([
            0xa8ec92344438aaf4_u64,
            0x9819ebdbd1faaab1_u64,
            0x573b1a7064c19c1a_u64,
            0xc85ef7d79691fe79_u64,
        ]);
        // Serializes `N` as `ty`, panicking if the encoding fails.
        #[allow(clippy::needless_pass_by_value)]
        fn bytes(ty: Type) -> Vec<u8> {
            let mut out = BytesMut::new();
            N.to_sql(&ty, &mut out).unwrap();
            out.to_vec()
        }
        // N is far above f32::MAX, so FLOAT4 encodes +infinity (0x7f800000).
        assert_eq!(bytes(Type::FLOAT4), hex!("7f800000"));
        assert_eq!(bytes(Type::FLOAT8), hex!("4fe90bdefaf2d240"));
        // NUMERIC header: 0x0014 = 20 digits, weight 0x0013 = 19, sign 0, dscale 0,
        // followed by the 20 base-10000 digits.
        assert_eq!(bytes(Type::NUMERIC), hex!("0014001300000000000902760e3620f115a21c3b029709bc11e60b3e10d10d6900d123400def1c45091a147900f012f4"));
        // BYTEA is the plain big-endian byte string.
        assert_eq!(
            bytes(Type::BYTEA),
            hex!("c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4")
        );
        // Bit strings: i32 bit count 0x100 = 256, then the same 32 bytes; no
        // shift is needed because 256 is a multiple of 8.
        assert_eq!(
            bytes(Type::BIT),
            hex!("00000100c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4")
        );
        assert_eq!(
            bytes(Type::VARBIT),
            hex!("00000100c85ef7d79691fe79573b1a7064c19c1a9819ebdbd1faaab1a8ec92344438aaf4")
        );
        // Textual encodings are the ASCII bytes of "0xc85ef7...aaf4".
        assert_eq!(bytes(Type::CHAR), hex!("307863383565663764373936393166653739353733623161373036346331396331613938313965626462643166616161623161386563393233343434333861616634"));
        assert_eq!(bytes(Type::TEXT), hex!("307863383565663764373936393166653739353733623161373036346331396331613938313965626462643166616161623161386563393233343434333861616634"));
        assert_eq!(bytes(Type::VARCHAR), hex!("307863383565663764373936393166653739353733623161373036346331396331613938313965626462643166616161623161386563393233343434333861616634"));
    }

    /// Property test: values that encode successfully decode back to an equal
    /// value, within 4 ULPs for the float encodings.
    #[test]
    fn test_roundtrip() {
        // NOTE(review): `const_for!` is defined elsewhere in the crate. It
        // presumably instantiates the body once per bit width in `SIZES`.
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U)| {
                let mut serialized = BytesMut::new();

                // Floats only round-trip approximately, and only when finite.
                if f32::from(value).is_finite() {
                    serialized.clear();
                    if value.to_sql(&Type::FLOAT4, &mut serialized).is_ok() {
                        let deserialized = U::from_sql(&Type::FLOAT4, &serialized).unwrap();
                        assert_ulps_eq!(f32::from(value), f32::from(deserialized), max_ulps = 4);
                    }
                }
                if f64::from(value).is_finite() {
                    serialized.clear();
                    if value.to_sql(&Type::FLOAT8, &mut serialized).is_ok() {
                        let deserialized = U::from_sql(&Type::FLOAT8, &serialized).unwrap();
                        assert_ulps_eq!(f64::from(value), f64::from(deserialized), max_ulps = 4);
                    }
                }
                // Bit strings must round-trip exactly.
                for ty in &[Type::BIT, Type::VARBIT] {
                    serialized.clear();
                    if value.to_sql(ty, &mut serialized).is_ok() {
                        let deserialized = U::from_sql(ty, &serialized).unwrap();
                        assert_eq!(deserialized, value);
                    }
                }
            });
        });
    }

    /// Runs `SELECT {expr}` through `COPY ... TO STDOUT WITH BINARY` and returns
    /// the binary encoding of the single resulting field.
    ///
    /// Panics (via `assert_eq!` or slicing) unless the stream has an empty header
    /// extension, exactly one row with one field, and the -1 trailer.
    fn get_binary(client: &mut Client, expr: &str) -> Vec<u8> {
        let query = format!("COPY (SELECT {expr}) TO STDOUT WITH BINARY;");

        let mut reader = client.copy_out(&query).unwrap();
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();

        // Check the 11-byte signature, then skip it together with the 4-byte flags field.
        let buf = {
            const HEADER: &[u8] = b"PGCOPY\n\xff\r\n\0";
            assert_eq!(&buf[..11], HEADER);
            &buf[11 + 4..]
        };

        // Header extension area length: expected to be empty.
        assert_eq!(&buf[..4], &0_u32.to_be_bytes());
        let buf = &buf[4..];

        // Field count of the single tuple: exactly one column.
        assert_eq!(&buf[..2], &1_i16.to_be_bytes());
        let buf = &buf[2..];

        // Field byte length. NOTE(review): a NULL field (-1) would be read as a
        // huge usize here and panic on the slice below.
        let len = u32::from_be_bytes(buf[..4].try_into().unwrap()) as usize;
        let buf = &buf[4..];

        let data = &buf[..len];
        let buf = &buf[len..];

        // File trailer: a field count of -1 must immediately follow the row.
        assert_eq!(&buf, &(-1_i16).to_be_bytes());

        data.to_owned()
    }

    /// Checks that `value.to_sql(ty)` produces the same bytes as the server's own
    /// binary encoding of an equivalent SQL expression. Floats are compared
    /// within 4 ULPs.
    ///
    /// Returns without checking when `to_sql` fails, or when the FLOAT4/FLOAT8
    /// conversion of `value` is infinite.
    fn test_to<const BITS: usize, const LIMBS: usize>(
        client: &Mutex<Client>,
        value: Uint<BITS, LIMBS>,
        ty: &Type,
    ) {
        println!("testing {value:?} {ty}");

        let mut serialized = BytesMut::new();
        let result = value.to_sql(ty, &mut serialized);
        if result.is_err() {
            return;
        }
        if ty == &Type::FLOAT4 && f32::from(value).is_infinite() {
            return;
        }
        if ty == &Type::FLOAT8 && f64::from(value).is_infinite() {
            return;
        }

        // Build a SQL expression that the server will encode as `ty`.
        // NOTE(review): the BIT/VARBIT/CHAR comparisons assume `{value:b}` /
        // `{value:#x}` render the value at its full BITS width. Confirm this for
        // values with leading zero bits.
        let expr = match *ty {
            // `bit(0)` is not a valid type, hence the `bit(1)` fallback. In practice
            // it is unreachable: `to_sql` rejects BIT when BITS == 0.
            Type::BIT => format!(
                "B'{value:b}'::bit({bits})",
                value = value,
                bits = if BITS == 0 { 1 } else { BITS },
            ),
            Type::VARBIT => format!("B'{value:b}'::varbit"),
            Type::BYTEA => format!("'\\x{value:x}'::bytea"),
            // "0x" plus two hex digits per byte.
            Type::CHAR => format!("'{value:#x}'::char({})", 2 + 2 * nbytes(BITS)),
            Type::TEXT | Type::VARCHAR => format!("'{value:#x}'::{}", ty.name()),
            Type::JSON | Type::JSONB => format!("'\"{value:#x}\"'::{}", ty.name()),
            _ => format!("{value}::{}", ty.name()),
        };
        let ground_truth = {
            let mut client = client.lock().unwrap();
            get_binary(&mut client, &expr)
        };

        if ty == &Type::FLOAT4 {
            let serialized = f32::from_be_bytes(serialized.as_ref().try_into().unwrap());
            let ground_truth = f32::from_be_bytes(ground_truth.try_into().unwrap());
            assert_ulps_eq!(serialized, ground_truth, max_ulps = 4);
        } else if ty == &Type::FLOAT8 {
            let serialized = f64::from_be_bytes(serialized.as_ref().try_into().unwrap());
            let ground_truth = f64::from_be_bytes(ground_truth.try_into().unwrap());
            assert_ulps_eq!(serialized, ground_truth, max_ulps = 4);
        } else {
            assert_eq!(serialized, ground_truth);
        }
    }

    /// Compares every supported encoding against a live Postgres server at
    /// `localhost` (run with `--ignored`).
    #[test]
    #[ignore = "requires a live postgresql server"]
    fn test_postgres() {
        let client = Client::connect("postgresql://postgres:postgres@localhost", NoTls).unwrap();
        // `test_to` takes the connection behind a Mutex, presumably because the
        // proptest closure only has shared access to captured state.
        let client = Mutex::new(client);

        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);

            let mut config = ProptestConfig::default();
            // Very narrow widths have few distinct values, so fewer cases suffice.
            if BITS < 4 {
                config.cases = 16;
            }

            proptest!(config, |(value: Uint<BITS, LIMBS>)| {

                // Only try the fixed-width targets the value actually fits into.
                let bits = value.bit_len();
                if bits <= 1 {
                    test_to(&client, value, &Type::BOOL);
                }
                // Signed targets reserve one bit for the sign.
                if bits <= 15 {
                    test_to(&client, value, &Type::INT2);
                }
                if bits <= 31 {
                    test_to(&client, value, &Type::INT4);
                }
                if bits <= 32 {
                    test_to(&client, value, &Type::OID);
                }
                // MONEY is scaled by 100: 2^50 * 100 stays well inside i64.
                if bits <= 50 {
                    test_to(&client, value, &Type::MONEY);
                }
                if bits <= 63 {
                    test_to(&client, value, &Type::INT8);
                }

                test_to(&client, value, &Type::FLOAT4);
                test_to(&client, value, &Type::FLOAT8);

                for ty in &[Type::NUMERIC, Type::BIT, Type::VARBIT, Type::BYTEA, Type::CHAR, Type::TEXT, Type::VARCHAR, Type::JSON, Type::JSONB] {
                    test_to(&client, value, ty);
                }

            });
        });
    }
}