1use super::{FixedBytes, Sign, Signed};
7use bytes::{BufMut, BytesMut};
8use derive_more::Display;
9use postgres_types::{FromSql, IsNull, ToSql, Type, WrongType, accepts, to_sql_checked};
10use std::{
11 error::Error,
12 iter,
13 str::{FromStr, from_utf8},
14};
15
16impl<const BITS: usize> ToSql for FixedBytes<BITS> {
18 fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
19 out.put_slice(&self[..]);
20 Ok(IsNull::No)
21 }
22
23 accepts!(BYTEA);
24
25 to_sql_checked!();
26}
27
28impl<'a, const BITS: usize> FromSql<'a> for FixedBytes<BITS> {
30 accepts!(BYTEA);
31
32 fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
33 Ok(Self::try_from(raw)?)
34 }
35}
36
/// Boxed, thread-safe error type returned by the SQL conversion routines in this file.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;
40
/// Remainder of `a / b`, except that an exact multiple yields `b` instead of `0`.
///
/// The result is therefore always in `1..=b`.
///
/// # Panics
///
/// Panics if `b` is zero (remainder by zero).
const fn rem_up(a: usize, b: usize) -> usize {
    match a % b {
        0 => b,
        remainder => remainder,
    }
}
45
/// Returns the length `x` would have after removing every trailing element equal to `value`.
///
/// That is one past the index of the last element that differs from `value`, or `0`
/// when the slice is empty or consists solely of `value`.
fn last_idx<T: PartialEq>(x: &[T], value: &T) -> usize {
    match x.iter().rposition(|item| item != value) {
        Some(position) => position + 1,
        None => 0,
    }
}
49
50fn trim_end_vec<T: PartialEq>(vec: &mut Vec<T>, value: &T) {
51 vec.truncate(last_idx(vec, value));
52}
53
54#[derive(Clone, Debug, PartialEq, Eq, Display)]
56pub enum ToSqlError {
57 #[display("Signed<{_0}> value too large to fit target type {_1}")]
59 Overflow(usize, Type),
60}
61
62impl core::error::Error for ToSqlError {}
63
64impl<const BITS: usize, const LIMBS: usize> ToSql for Signed<BITS, LIMBS> {
94 fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
95 match *ty {
96 Type::BOOL => out.put_u8(u8::from(bool::try_from(self.0)?)),
99 Type::INT2 => out.put_i16(self.0.try_into()?),
100 Type::INT4 => out.put_i32(self.0.try_into()?),
101 Type::OID => out.put_u32(self.0.try_into()?),
102 Type::INT8 => out.put_i64(self.0.try_into()?),
103
104 Type::MONEY => {
105 out.put_i64(
107 i64::try_from(self.0)?
108 .checked_mul(100)
109 .ok_or(ToSqlError::Overflow(BITS, ty.clone()))?,
110 );
111 }
112
113 Type::BYTEA => out.put_slice(&self.0.to_be_bytes_vec()),
115 Type::BIT | Type::VARBIT => {
116 if BITS == 0 {
119 if *ty == Type::BIT {
120 return Err(Box::new(WrongType::new::<Self>(ty.clone())));
122 }
123 out.put_i32(0);
124 } else {
125 let padding = 8 - rem_up(BITS, 8);
128 out.put_i32(Self::BITS.try_into()?);
129 let bytes = self.0.as_le_bytes();
130 let mut bytes = bytes.iter().rev();
131 let mut shifted = bytes.next().unwrap() << padding;
132 for byte in bytes {
133 shifted |= if padding > 0 { byte >> (8 - padding) } else { 0 };
134 out.put_u8(shifted);
135 shifted = byte << padding;
136 }
137 out.put_u8(shifted);
138 }
139 }
140
141 Type::CHAR | Type::TEXT | Type::VARCHAR => {
143 out.put_slice(format!("{self:#x}").as_bytes());
144 }
145 Type::JSON | Type::JSONB => {
146 if *ty == Type::JSONB {
147 out.put_u8(1);
149 }
150 out.put_slice(format!("\"{self:#x}\"").as_bytes());
151 }
152
153 Type::NUMERIC => {
156 const BASE: u64 = 10000;
158
159 let sign = match self.sign() {
160 Sign::Positive => 0x0000,
161 _ => 0x4000,
162 };
163
164 let mut digits: Vec<_> = self.abs().0.to_base_be(BASE).collect();
165 let exponent = digits.len().saturating_sub(1).try_into()?;
166
167 trim_end_vec(&mut digits, &0);
169
170 out.put_i16(digits.len().try_into()?); out.put_i16(exponent); out.put_i16(sign);
174 out.put_i16(0); for digit in digits {
176 debug_assert!(digit < BASE);
177 #[allow(clippy::cast_possible_truncation)] out.put_i16(digit as i16);
179 }
180 }
181
182 _ => {
184 return Err(Box::new(WrongType::new::<Self>(ty.clone())));
185 }
186 };
187 Ok(IsNull::No)
188 }
189
190 fn accepts(ty: &Type) -> bool {
191 matches!(*ty, |Type::BOOL| Type::CHAR
192 | Type::INT2
193 | Type::INT4
194 | Type::INT8
195 | Type::OID
196 | Type::FLOAT4
197 | Type::FLOAT8
198 | Type::MONEY
199 | Type::NUMERIC
200 | Type::BYTEA
201 | Type::TEXT
202 | Type::VARCHAR
203 | Type::JSON
204 | Type::JSONB
205 | Type::BIT
206 | Type::VARBIT)
207 }
208
209 to_sql_checked!();
210}
211
212#[derive(Clone, Debug, PartialEq, Eq, Display)]
214pub enum FromSqlError {
215 #[display("the value is too large for the Signed type")]
217 Overflow,
218
219 #[display("unexpected data for type {_0}")]
221 ParseError(Type),
222}
223
224impl core::error::Error for FromSqlError {}
225
226impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Signed<BITS, LIMBS> {
227 fn accepts(ty: &Type) -> bool {
228 <Self as ToSql>::accepts(ty)
229 }
230
231 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
232 Ok(match *ty {
233 Type::BOOL => match raw {
234 [0] => Self::ZERO,
235 [1] => Self::try_from(1)?,
236 _ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
237 },
238 Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
239 Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
240 Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
241 Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
242 Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
243
244 Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
246 Type::BIT | Type::VARBIT => {
247 if raw.len() < 4 {
249 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
250 }
251 let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
252 let raw = &raw[4..];
253
254 let padding = 8 - rem_up(len, 8);
256 let mut raw = raw.to_owned();
257 if padding > 0 {
258 for i in (1..raw.len()).rev() {
259 raw[i] = (raw[i] >> padding) | (raw[i - 1] << (8 - padding));
260 }
261 raw[0] >>= padding;
262 }
263 Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
265 }
266
267 Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
269
270 Type::JSON | Type::JSONB => {
272 let raw = if *ty == Type::JSONB {
273 if raw[0] == 1 {
274 &raw[1..]
275 } else {
276 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
278 }
279 } else {
280 raw
281 };
282 let str = from_utf8(raw)?;
283 let str = if str.starts_with('"') && str.ends_with('"') {
284 &str[1..str.len() - 1]
286 } else {
287 str
288 };
289 Self::from_str(str)?
290 }
291
292 Type::NUMERIC => {
294 if raw.len() < 8 {
296 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
297 }
298 let digits = i16::from_be_bytes(raw[0..2].try_into()?);
299 let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
300 let sign = i16::from_be_bytes(raw[4..6].try_into()?);
301 let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
302 let raw = &raw[8..];
303 #[allow(clippy::cast_sign_loss)] if digits < 0
305 || exponent < 0
306 || dscale != 0
307 || digits > exponent + 1
308 || raw.len() != digits as usize * 2
309 {
310 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
311 }
312 let mut error = false;
313 let iter = raw.chunks_exact(2).filter_map(|raw| {
314 if error {
315 return None;
316 }
317 let digit = i16::from_be_bytes(raw.try_into().unwrap());
318 if !(0..10000).contains(&digit) {
319 error = true;
320 return None;
321 }
322 #[allow(clippy::cast_sign_loss)] Some(digit as u64)
324 });
325 #[allow(clippy::cast_sign_loss)]
326 let iter = iter.chain(iter::repeat_n(0, (exponent + 1 - digits) as usize));
328
329 let mut value = Self::from_base_be(10000, iter)?;
330 if sign == 0x4000 {
331 value = -value;
332 }
333 if error {
334 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
335 }
336
337 value
338 }
339
340 _ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
342 })
343 }
344}
345
#[cfg(test)]
mod test {
    use super::*;

    use crate::I256;

    /// Binary `NUMERIC` payload expected for `1`.
    const NUMERIC_ONE: [u8; 10] = [
        0x00, 0x01, // digit count
        0x00, 0x00, // exponent
        0x00, 0x00, // sign: positive
        0x00, 0x00, // display scale
        0x00, 0x01, // single digit: 1
    ];

    /// Binary `NUMERIC` payload expected for `-1`.
    const NUMERIC_MINUS_ONE: [u8; 10] = [
        0x00, 0x01, // digit count
        0x00, 0x00, // exponent
        0x40, 0x00, // sign: negative
        0x00, 0x00, // display scale
        0x00, 0x01, // single digit: 1
    ];

    #[test]
    fn positive_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &NUMERIC_ONE).unwrap();
        assert_eq!(decoded, I256::ONE);
    }

    #[test]
    fn positive_i256_to_sql() {
        let mut encoded = BytesMut::with_capacity(64);
        I256::ONE.to_sql(&Type::NUMERIC, &mut encoded).unwrap();
        assert_eq!(*encoded.freeze(), NUMERIC_ONE);
    }

    #[test]
    fn negative_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &NUMERIC_MINUS_ONE).unwrap();
        assert_eq!(decoded, I256::MINUS_ONE);
    }

    #[test]
    fn negative_i256_to_sql() {
        let mut encoded = BytesMut::with_capacity(64);
        I256::MINUS_ONE.to_sql(&Type::NUMERIC, &mut encoded).unwrap();
        assert_eq!(*encoded.freeze(), NUMERIC_MINUS_ONE);
    }
}