1use std::{
2 borrow::Borrow,
3 clone::Clone,
4 fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5 ops::Deref,
6 str::FromStr,
7};
8
9use deepsize::{Context, DeepSizeOf};
10#[cfg(feature = "diesel")]
11use diesel::{
12 deserialize::{self, FromSql, FromSqlRow},
13 expression::AsExpression,
14 pg::Pg,
15 serialize::{self, ToSql},
16 sql_types::Binary,
17};
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use crate::serde_primitives::hex_bytes;
23
/// Newtype wrapper around [`bytes::Bytes`] used as the crate's general-purpose byte string.
///
/// Equality, ordering, hashing and `Default` (empty) are derived from the wrapped buffer.
/// Serde (de)serialization of the inner buffer goes through the project's `hex_bytes`
/// helper. NOTE(review): presumably a hex-string encoding; confirm in `serde_primitives`.
/// With the `diesel` feature enabled, the type is usable as a diesel expression / query
/// row of SQL type `Binary`.
#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
29
30impl DeepSizeOf for Bytes {
31 fn deep_size_of_children(&self, _ctx: &mut Context) -> usize {
32 self.0.len()
38 }
39}
40
41fn bytes_to_hex(b: &Bytes) -> String {
42 hex::encode(b.0.as_ref())
43}
44
45impl Debug for Bytes {
46 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47 write!(f, "Bytes(0x{})", bytes_to_hex(self))
48 }
49}
50
51impl Display for Bytes {
52 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
53 write!(f, "0x{}", bytes_to_hex(self))
54 }
55}
56
57impl LowerHex for Bytes {
58 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
59 write!(f, "0x{}", bytes_to_hex(self))
60 }
61}
62
63impl Bytes {
64 pub fn new() -> Self {
65 Self(bytes::Bytes::new())
66 }
67 pub fn to_vec(&self) -> Vec<u8> {
81 self.as_ref().to_vec()
82 }
83
84 pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
109 let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
110 padded_vec.extend_from_slice(self.as_ref());
111
112 Bytes(bytes::Bytes::from(padded_vec))
113 }
114
115 pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
140 let mut padded_vec = self.to_vec();
141 padded_vec.resize(length, pad_byte);
142
143 Bytes(bytes::Bytes::from(padded_vec))
144 }
145
146 pub fn zero(length: usize) -> Bytes {
163 Bytes::from(vec![0u8; length])
164 }
165
166 pub fn random(length: usize) -> Bytes {
183 let mut data = vec![0u8; length];
184 rand::thread_rng().fill(&mut data[..]);
185 Bytes::from(data)
186 }
187
188 pub fn is_zero(&self) -> bool {
201 self.as_ref().iter().all(|b| *b == 0)
202 }
203}
204
205impl Deref for Bytes {
206 type Target = [u8];
207
208 #[inline]
209 fn deref(&self) -> &[u8] {
210 self.as_ref()
211 }
212}
213
214impl AsRef<[u8]> for Bytes {
215 fn as_ref(&self) -> &[u8] {
216 self.0.as_ref()
217 }
218}
219
220impl Borrow<[u8]> for Bytes {
221 fn borrow(&self) -> &[u8] {
222 self.as_ref()
223 }
224}
225
226impl IntoIterator for Bytes {
227 type Item = u8;
228 type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
229
230 fn into_iter(self) -> Self::IntoIter {
231 self.0.into_iter()
232 }
233}
234
235impl<'a> IntoIterator for &'a Bytes {
236 type Item = &'a u8;
237 type IntoIter = core::slice::Iter<'a, u8>;
238
239 fn into_iter(self) -> Self::IntoIter {
240 self.as_ref().iter()
241 }
242}
243
244impl From<&[u8]> for Bytes {
245 fn from(src: &[u8]) -> Self {
246 Self(bytes::Bytes::copy_from_slice(src))
247 }
248}
249
250impl From<bytes::Bytes> for Bytes {
251 fn from(src: bytes::Bytes) -> Self {
252 Self(src)
253 }
254}
255
256impl From<Bytes> for bytes::Bytes {
257 fn from(src: Bytes) -> Self {
258 src.0
259 }
260}
261
262impl From<Vec<u8>> for Bytes {
263 fn from(src: Vec<u8>) -> Self {
264 Self(src.into())
265 }
266}
267
268impl From<Bytes> for Vec<u8> {
269 fn from(value: Bytes) -> Self {
270 value.to_vec()
271 }
272}
273
274impl<const N: usize> From<[u8; N]> for Bytes {
275 fn from(src: [u8; N]) -> Self {
276 src.to_vec().into()
277 }
278}
279
280impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
281 fn from(src: &'a [u8; N]) -> Self {
282 src.to_vec().into()
283 }
284}
285
286impl PartialEq<[u8]> for Bytes {
287 fn eq(&self, other: &[u8]) -> bool {
288 self.as_ref() == other
289 }
290}
291
292impl PartialEq<Bytes> for [u8] {
293 fn eq(&self, other: &Bytes) -> bool {
294 *other == *self
295 }
296}
297
298impl PartialEq<Vec<u8>> for Bytes {
299 fn eq(&self, other: &Vec<u8>) -> bool {
300 self.as_ref() == &other[..]
301 }
302}
303
304impl PartialEq<Bytes> for Vec<u8> {
305 fn eq(&self, other: &Bytes) -> bool {
306 *other == *self
307 }
308}
309
310impl PartialEq<bytes::Bytes> for Bytes {
311 fn eq(&self, other: &bytes::Bytes) -> bool {
312 other == self.as_ref()
313 }
314}
315
/// Error returned when a string cannot be parsed into [`Bytes`].
///
/// Wraps a human-readable description of the failure; `Bytes`'s `FromStr` implementation
/// stores the hex decoding error here. Displays as `Failed to parse bytes: <description>`.
#[derive(Debug, Clone, Error)]
#[error("Failed to parse bytes: {0}")]
pub struct ParseBytesError(String);
319
320impl FromStr for Bytes {
321 type Err = ParseBytesError;
322
323 fn from_str(value: &str) -> Result<Self, Self::Err> {
324 if let Some(value) = value.strip_prefix("0x") {
325 hex::decode(value)
326 } else {
327 hex::decode(value)
328 }
329 .map(Into::into)
330 .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
331 }
332}
333
334impl From<&str> for Bytes {
335 fn from(value: &str) -> Self {
336 value.parse().unwrap()
337 }
338}
339
#[cfg(feature = "diesel")]
/// Serializes `Bytes` as a Postgres `Binary` value by delegating to diesel's `&[u8]`
/// implementation.
impl ToSql<Binary, Pg> for Bytes {
    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
        // View the wrapped buffer as a plain byte slice.
        let bytes_slice: &[u8] = &self.0;
        // NOTE(review): `&bytes_slice` borrows a local, so it cannot live for `'b`. Writing
        // through `out.reborrow()` presumably shortens the output lifetime enough to accept
        // it. Confirm against diesel's `Output::reborrow` docs before changing this.
        <&[u8] as ToSql<Binary, Pg>>::to_sql(&bytes_slice, &mut out.reborrow())
    }
}
347
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
    /// Decodes a Postgres `Binary` value by reading it as `Vec<u8>` and wrapping the result.
    fn from_sql(
        bytes: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
    ) -> deserialize::Result<Self> {
        <Vec<u8> as FromSql<Binary, Pg>>::from_sql(bytes).map(Bytes::from)
    }
}
357
358macro_rules! impl_from_uint_for_bytes {
359 ($($t:ty),*) => {
360 $(
361 impl From<$t> for Bytes {
362 fn from(src: $t) -> Self {
363 let size = std::mem::size_of::<$t>();
364 let mut buf = vec![0u8; size];
365 buf.copy_from_slice(&src.to_be_bytes());
366
367 Self(bytes::Bytes::from(buf))
368 }
369 }
370 )*
371 };
372}
373
374impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
375
376macro_rules! impl_from_bytes_for_uint {
377 ($($t:ty),*) => {
378 $(
379 impl From<Bytes> for $t {
380 fn from(src: Bytes) -> Self {
381 let bytes_slice = src.as_ref();
382
383 let mut buf = [0u8; std::mem::size_of::<$t>()];
385
386 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
388
389 <$t>::from_be_bytes(buf)
391 }
392 }
393 )*
394 };
395}
396
397impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
398
399macro_rules! impl_from_bytes_for_signed_int {
400 ($($t:ty),*) => {
401 $(
402 impl From<Bytes> for $t {
403 fn from(src: Bytes) -> Self {
404 let bytes_slice = src.as_ref();
405
406 let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
408 [0xFFu8; std::mem::size_of::<$t>()] } else {
410 [0x00u8; std::mem::size_of::<$t>()] };
412
413 buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
415
416 <$t>::from_be_bytes(buf)
418 }
419 }
420 )*
421 };
422}
423
424impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
425
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_bytes() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        let expected = Bytes(b);

        assert_eq!(wrapped_b, expected);
    }

    #[test]
    fn test_from_slice() {
        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(&arr);
        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));

        assert_eq!(b, expected);
    }

    #[test]
    fn hex_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        let expected = String::from("0x0123456789abcdef");
        assert_eq!(format!("{b:x}"), expected);
        assert_eq!(format!("{b}"), expected);
    }

    #[test]
    fn test_from_str() {
        let b = Bytes::from_str("0x1213").expect("prefixed hex should parse");
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());

        let b = Bytes::from_str("1213").expect("unprefixed hex should parse");
        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
    }

    #[test]
    fn test_from_str_rejects_invalid_hex() {
        assert!(Bytes::from_str("0xzz").is_err());
        // Odd number of hex digits.
        assert!(Bytes::from_str("0x123").is_err());
    }

    #[test]
    fn test_debug_formatting() {
        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
    }

    #[test]
    fn test_to_vec() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());

        assert_eq!(b.to_vec(), vec);
    }

    #[test]
    fn test_vec_partialeq() {
        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
        let b = Bytes::from(vec.clone());
        assert_eq!(b, vec);
        assert_eq!(vec, b);

        let wrong_vec = vec![1, 3, 52, 137];
        assert_ne!(b, wrong_vec);
        assert_ne!(wrong_vec, b);
    }

    #[test]
    fn test_bytes_partialeq() {
        let b = bytes::Bytes::from("0123456789abcdef");
        let wrapped_b = Bytes::from(b.clone());
        assert_eq!(wrapped_b, b);

        let wrong_b = bytes::Bytes::from("0123absd");
        // Compare through the wrapper, so `PartialEq<bytes::Bytes> for Bytes` is exercised
        // on a mismatch too (comparing two `bytes::Bytes` would not touch it).
        assert_ne!(wrapped_b, wrong_b);
    }

    #[test]
    fn test_padding_grows_to_length() {
        let b = Bytes::from(vec![1u8, 2]);
        assert_eq!(b.lpad(4, 0), vec![0u8, 0, 1, 2]);
        assert_eq!(b.rpad(4, 0xff), vec![1u8, 2, 0xff, 0xff]);

        // lpad never truncates.
        let long = Bytes::from(vec![1u8, 2, 3]);
        assert_eq!(long.lpad(2, 0), vec![1u8, 2, 3]);
    }

    #[test]
    fn test_zero_random_and_is_zero() {
        let zeros = Bytes::zero(3);
        assert_eq!(zeros.len(), 3);
        assert!(zeros.is_zero());
        assert!(Bytes::new().is_zero());
        assert!(!Bytes::from(vec![0u8, 1]).is_zero());
        assert_eq!(Bytes::random(16).len(), 16);
    }

    #[test]
    fn test_uint_round_trip() {
        assert_eq!(Bytes::from(0x1234u16), vec![0x12u8, 0x34]);
        assert_eq!(u64::from(Bytes::from(42u64)), 42);
    }

    #[test]
    fn test_u128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: u128 = u128::from(data);
        assert_eq!(result, u128::from_str("67305985").unwrap());
    }

    #[test]
    fn test_i128_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i128 = i128::from(data);
        assert_eq!(result, i128::from_str("67305985").unwrap());
    }

    #[test]
    fn test_i32_from_bytes() {
        let data = Bytes::from(vec![4, 3, 2, 1]);
        let result: i32 = i32::from(data);
        assert_eq!(result, i32::from_str("67305985").unwrap());
    }

    #[test]
    fn test_negative_signed_from_bytes() {
        // Short negative inputs are sign-extended.
        assert_eq!(i32::from(Bytes::from(vec![0xffu8, 0xfe])), -2);
        assert_eq!(i8::from(Bytes::from(vec![0x80u8])), i8::MIN);
    }
}
526
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
    use diesel::{insert_into, table, Insertable, Queryable};
    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};

    use super::*;

    /// Connects to the database named by the `DATABASE_URL` environment variable and opens
    /// a test transaction via `begin_test_transaction`.
    async fn setup_db() -> AsyncPgConnection {
        let db_url = std::env::var("DATABASE_URL")
            .expect("DATABASE_URL must be set to run the diesel tests");
        let mut conn = AsyncPgConnection::establish(&db_url)
            .await
            .expect("failed to connect to DATABASE_URL");
        conn.begin_test_transaction()
            .await
            .expect("failed to begin test transaction");
        conn
    }

    #[tokio::test]
    async fn test_bytes_db_round_trip() {
        table! {
            bytes_table (id) {
                id -> Int4,
                data -> Binary,
            }
        }

        #[derive(Insertable)]
        #[diesel(table_name = bytes_table)]
        struct NewByteEntry {
            data: Bytes,
        }

        #[derive(Queryable, PartialEq)]
        struct ByteEntry {
            id: i32,
            data: Bytes,
        }

        let mut conn = setup_db().await;
        let example_bytes = Bytes::from_str("0x0123456789abcdef").unwrap();

        conn.batch_execute(
            r"
            CREATE TEMPORARY TABLE bytes_table (
                id SERIAL PRIMARY KEY,
                data BYTEA NOT NULL
            );
            ",
        )
        .await
        .expect("failed to create bytes_table");

        let new_entry = NewByteEntry { data: example_bytes.clone() };

        // Rows come back as `ByteEntry`, so `data` is written via `ToSql` and read back via
        // `FromSql`.
        let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
            .values(&new_entry)
            .get_results(&mut conn)
            .await
            .expect("insert into bytes_table failed");

        assert_eq!(inserted.len(), 1, "expected exactly one inserted row");
        assert_eq!(inserted[0].data, example_bytes);
    }
}
591}