// tycho_common/hex_bytes.rs

1use std::{
2    borrow::Borrow,
3    clone::Clone,
4    fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
5    ops::Deref,
6    str::FromStr,
7};
8
9#[cfg(feature = "diesel")]
10use diesel::{
11    deserialize::{self, FromSql, FromSqlRow},
12    expression::AsExpression,
13    pg::Pg,
14    serialize::{self, ToSql},
15    sql_types::Binary,
16};
17use rand::Rng;
18use serde::{Deserialize, Serialize};
19use thiserror::Error;
20
21use crate::serde_primitives::hex_bytes;
22
/// Wrapper type around [`bytes::Bytes`] that serializes/deserializes from/to hex.
///
/// The serde representation of the inner buffer is delegated to
/// `crate::serde_primitives::hex_bytes`. With the `diesel` feature enabled the type is also
/// usable as a diesel expression / query row of SQL type `Binary`.
#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
28
29fn bytes_to_hex(b: &Bytes) -> String {
30    hex::encode(b.0.as_ref())
31}
32
33impl Debug for Bytes {
34    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
35        write!(f, "Bytes(0x{})", bytes_to_hex(self))
36    }
37}
38
39impl Display for Bytes {
40    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
41        write!(f, "0x{}", bytes_to_hex(self))
42    }
43}
44
45impl LowerHex for Bytes {
46    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47        write!(f, "0x{}", bytes_to_hex(self))
48    }
49}
50
51impl Bytes {
52    pub fn new() -> Self {
53        Self(bytes::Bytes::new())
54    }
55    /// This function converts the internal byte array into a `Vec<u8>`
56    ///
57    /// # Returns
58    ///
59    /// A `Vec<u8>` containing the bytes from the `Bytes` struct.
60    ///
61    /// # Example
62    ///
63    /// ```
64    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
65    /// let vec = bytes.to_vec();
66    /// assert_eq!(vec, vec![0x01, 0x02, 0x03]);
67    /// ```
68    pub fn to_vec(&self) -> Vec<u8> {
69        self.as_ref().to_vec()
70    }
71
72    /// Left-pads the byte array to the specified length with the given padding byte.
73    ///
74    /// This function creates a new `Bytes` instance by prepending the specified padding byte
75    /// to the current byte array until its total length matches the desired length.
76    ///
77    /// If the current length of the byte array is greater than or equal to the specified length,
78    /// the original byte array is returned unchanged.
79    ///
80    /// # Arguments
81    ///
82    /// * `length` - The desired total length of the resulting byte array.
83    /// * `pad_byte` - The byte value to use for padding. Commonly `0x00`.
84    ///
85    /// # Returns
86    ///
87    /// A new `Bytes` instance with the byte array left-padded to the desired length.
88    ///
89    /// # Example
90    ///
91    /// ```
92    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
93    /// let padded = bytes.lpad(6, 0x00);
94    /// assert_eq!(padded.to_vec(), vec![0x00, 0x00, 0x00, 0x01, 0x02, 0x03]);
95    /// ```
96    pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
97        let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
98        padded_vec.extend_from_slice(self.as_ref());
99
100        Bytes(bytes::Bytes::from(padded_vec))
101    }
102
103    /// Right-pads the byte array to the specified length with the given padding byte.
104    ///
105    /// This function creates a new `Bytes` instance by appending the specified padding byte
106    /// to the current byte array until its total length matches the desired length.
107    ///
108    /// If the current length of the byte array is greater than or equal to the specified length,
109    /// the original byte array is returned unchanged.
110    ///
111    /// # Arguments
112    ///
113    /// * `length` - The desired total length of the resulting byte array.
114    /// * `pad_byte` - The byte value to use for padding. Commonly `0x00`.
115    ///
116    /// # Returns
117    ///
118    /// A new `Bytes` instance with the byte array right-padded to the desired length.
119    ///
120    /// # Example
121    ///
122    /// ```
123    /// let bytes = Bytes::from(vec![0x01, 0x02, 0x03]);
124    /// let padded = bytes.rpad(6, 0x00);
125    /// assert_eq!(padded.to_vec(), vec![0x01, 0x02, 0x03, 0x00, 0x00, 0x00]);
126    /// ```
127    pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
128        let mut padded_vec = self.to_vec();
129        padded_vec.resize(length, pad_byte);
130
131        Bytes(bytes::Bytes::from(padded_vec))
132    }
133
134    /// Creates a `Bytes` object of the specified length, filled with zeros.
135    ///
136    /// # Arguments
137    ///
138    /// * `length` - The length of the `Bytes` object to be created.
139    ///
140    /// # Returns
141    ///
142    /// A `Bytes` object of the specified length, where each byte is set to zero.
143    ///
144    /// # Example
145    ///
146    /// ```
147    /// let b = Bytes::zero(5);
148    /// assert_eq!(b, Bytes::from(vec![0, 0, 0, 0, 0]));
149    /// ```
150    pub fn zero(length: usize) -> Bytes {
151        Bytes::from(vec![0u8; length])
152    }
153
154    /// Creates a `Bytes` object of the specified length, filled with random bytes.
155    ///
156    /// # Arguments
157    ///
158    /// * `length` - The length of the `Bytes` object to be created.
159    ///
160    /// # Returns
161    ///
162    /// A `Bytes` object of the specified length, filled with random bytes.
163    ///
164    /// # Example
165    ///
166    /// ```
167    /// let random_bytes = Bytes::random(5);
168    /// assert_eq!(random_bytes.len(), 5);
169    /// ```
170    pub fn random(length: usize) -> Bytes {
171        let mut data = vec![0u8; length];
172        rand::thread_rng().fill(&mut data[..]);
173        Bytes::from(data)
174    }
175
176    /// Checks if the byte array is full of zeros.
177    ///
178    /// # Returns
179    ///
180    /// A boolean value indicating whether all bytes in the byte array are zero.
181    ///
182    /// # Example
183    ///
184    /// ```
185    /// let b = Bytes::zero(5);
186    /// assert!(b.is_zero());
187    /// ```
188    pub fn is_zero(&self) -> bool {
189        self.as_ref().iter().all(|b| *b == 0)
190    }
191}
192
193impl Deref for Bytes {
194    type Target = [u8];
195
196    #[inline]
197    fn deref(&self) -> &[u8] {
198        self.as_ref()
199    }
200}
201
202impl AsRef<[u8]> for Bytes {
203    fn as_ref(&self) -> &[u8] {
204        self.0.as_ref()
205    }
206}
207
208impl Borrow<[u8]> for Bytes {
209    fn borrow(&self) -> &[u8] {
210        self.as_ref()
211    }
212}
213
214impl IntoIterator for Bytes {
215    type Item = u8;
216    type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
217
218    fn into_iter(self) -> Self::IntoIter {
219        self.0.into_iter()
220    }
221}
222
223impl<'a> IntoIterator for &'a Bytes {
224    type Item = &'a u8;
225    type IntoIter = core::slice::Iter<'a, u8>;
226
227    fn into_iter(self) -> Self::IntoIter {
228        self.as_ref().iter()
229    }
230}
231
232impl From<&[u8]> for Bytes {
233    fn from(src: &[u8]) -> Self {
234        Self(bytes::Bytes::copy_from_slice(src))
235    }
236}
237
238impl From<bytes::Bytes> for Bytes {
239    fn from(src: bytes::Bytes) -> Self {
240        Self(src)
241    }
242}
243
244impl From<Bytes> for bytes::Bytes {
245    fn from(src: Bytes) -> Self {
246        src.0
247    }
248}
249
250impl From<Vec<u8>> for Bytes {
251    fn from(src: Vec<u8>) -> Self {
252        Self(src.into())
253    }
254}
255
256impl From<Bytes> for Vec<u8> {
257    fn from(value: Bytes) -> Self {
258        value.to_vec()
259    }
260}
261
262impl<const N: usize> From<[u8; N]> for Bytes {
263    fn from(src: [u8; N]) -> Self {
264        src.to_vec().into()
265    }
266}
267
268impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
269    fn from(src: &'a [u8; N]) -> Self {
270        src.to_vec().into()
271    }
272}
273
274impl PartialEq<[u8]> for Bytes {
275    fn eq(&self, other: &[u8]) -> bool {
276        self.as_ref() == other
277    }
278}
279
280impl PartialEq<Bytes> for [u8] {
281    fn eq(&self, other: &Bytes) -> bool {
282        *other == *self
283    }
284}
285
286impl PartialEq<Vec<u8>> for Bytes {
287    fn eq(&self, other: &Vec<u8>) -> bool {
288        self.as_ref() == &other[..]
289    }
290}
291
292impl PartialEq<Bytes> for Vec<u8> {
293    fn eq(&self, other: &Bytes) -> bool {
294        *other == *self
295    }
296}
297
298impl PartialEq<bytes::Bytes> for Bytes {
299    fn eq(&self, other: &bytes::Bytes) -> bool {
300        other == self.as_ref()
301    }
302}
303
304#[derive(Debug, Clone, Error)]
305#[error("Failed to parse bytes: {0}")]
306pub struct ParseBytesError(String);
307
308impl FromStr for Bytes {
309    type Err = ParseBytesError;
310
311    fn from_str(value: &str) -> Result<Self, Self::Err> {
312        if let Some(value) = value.strip_prefix("0x") {
313            hex::decode(value)
314        } else {
315            hex::decode(value)
316        }
317        .map(Into::into)
318        .map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
319    }
320}
321
322impl From<&str> for Bytes {
323    fn from(value: &str) -> Self {
324        value.parse().unwrap()
325    }
326}
327
328#[cfg(feature = "diesel")]
329impl ToSql<Binary, Pg> for Bytes {
330    fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
331        let bytes_slice: &[u8] = &self.0;
332        <&[u8] as ToSql<Binary, Pg>>::to_sql(&bytes_slice, &mut out.reborrow())
333    }
334}
335
336#[cfg(feature = "diesel")]
337impl FromSql<Binary, Pg> for Bytes {
338    fn from_sql(
339        bytes: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
340    ) -> deserialize::Result<Self> {
341        let byte_vec: Vec<u8> = <Vec<u8> as FromSql<Binary, Pg>>::from_sql(bytes)?;
342        Ok(Bytes(bytes::Bytes::from(byte_vec)))
343    }
344}
345
346macro_rules! impl_from_uint_for_bytes {
347    ($($t:ty),*) => {
348        $(
349            impl From<$t> for Bytes {
350                fn from(src: $t) -> Self {
351                    let size = std::mem::size_of::<$t>();
352                    let mut buf = vec![0u8; size];
353                    buf.copy_from_slice(&src.to_be_bytes());
354
355                    Self(bytes::Bytes::from(buf))
356                }
357            }
358        )*
359    };
360}
361
362impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
363
364macro_rules! impl_from_bytes_for_uint {
365    ($($t:ty),*) => {
366        $(
367            impl From<Bytes> for $t {
368                fn from(src: Bytes) -> Self {
369                    let bytes_slice = src.as_ref();
370
371                    // Create an array with zeros.
372                    let mut buf = [0u8; std::mem::size_of::<$t>()];
373
374                    // Copy bytes from `bytes_slice` to the end of `buf` to maintain big-endian order.
375                    buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
376
377                    // Convert to the integer type using big-endian.
378                    <$t>::from_be_bytes(buf)
379                }
380            }
381        )*
382    };
383}
384
385impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
386
387macro_rules! impl_from_bytes_for_signed_int {
388    ($($t:ty),*) => {
389        $(
390            impl From<Bytes> for $t {
391                fn from(src: Bytes) -> Self {
392                    let bytes_slice = src.as_ref();
393
394                    // Create an array with zeros or ones for negative numbers.
395                    let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
396                        [0xFFu8; std::mem::size_of::<$t>()] // Sign-extend with 0xFF for negative numbers.
397                    } else {
398                        [0x00u8; std::mem::size_of::<$t>()] // Fill with 0x00 for positive numbers.
399                    };
400
401                    // Copy bytes from `bytes_slice` to the end of `buf` to maintain big-endian order.
402                    buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
403
404                    // Convert to the signed integer type using big-endian.
405                    <$t>::from_be_bytes(buf)
406                }
407            }
408        )*
409    };
410}
411
412impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
413
414#[cfg(test)]
415mod tests {
416    use super::*;
417
418    #[test]
419    fn test_from_bytes() {
420        let b = bytes::Bytes::from("0123456789abcdef");
421        let wrapped_b = Bytes::from(b.clone());
422        let expected = Bytes(b);
423
424        assert_eq!(wrapped_b, expected);
425    }
426
427    #[test]
428    fn test_from_slice() {
429        let arr = [1, 35, 69, 103, 137, 171, 205, 239];
430        let b = Bytes::from(&arr);
431        let expected = Bytes(bytes::Bytes::from(arr.to_vec()));
432
433        assert_eq!(b, expected);
434    }
435
436    #[test]
437    fn hex_formatting() {
438        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
439        let expected = String::from("0x0123456789abcdef");
440        assert_eq!(format!("{b:x}"), expected);
441        assert_eq!(format!("{b}"), expected);
442    }
443
444    #[test]
445    fn test_from_str() {
446        let b = Bytes::from_str("0x1213");
447        assert!(b.is_ok());
448        let b = b.unwrap();
449        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
450
451        let b = Bytes::from_str("1213");
452        let b = b.unwrap();
453        assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
454    }
455
456    #[test]
457    fn test_debug_formatting() {
458        let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
459        assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
460        assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
461    }
462
463    #[test]
464    fn test_to_vec() {
465        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
466        let b = Bytes::from(vec.clone());
467
468        assert_eq!(b.to_vec(), vec);
469    }
470
471    #[test]
472    fn test_vec_partialeq() {
473        let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
474        let b = Bytes::from(vec.clone());
475        assert_eq!(b, vec);
476        assert_eq!(vec, b);
477
478        let wrong_vec = vec![1, 3, 52, 137];
479        assert_ne!(b, wrong_vec);
480        assert_ne!(wrong_vec, b);
481    }
482
483    #[test]
484    fn test_bytes_partialeq() {
485        let b = bytes::Bytes::from("0123456789abcdef");
486        let wrapped_b = Bytes::from(b.clone());
487        assert_eq!(wrapped_b, b);
488
489        let wrong_b = bytes::Bytes::from("0123absd");
490        assert_ne!(wrong_b, b);
491    }
492
493    #[test]
494    fn test_u128_from_bytes() {
495        let data = Bytes::from(vec![4, 3, 2, 1]);
496        let result: u128 = u128::from(data.clone());
497        assert_eq!(result, u128::from_str("67305985").unwrap());
498    }
499
500    #[test]
501    fn test_i128_from_bytes() {
502        let data = Bytes::from(vec![4, 3, 2, 1]);
503        let result: i128 = i128::from(data.clone());
504        assert_eq!(result, i128::from_str("67305985").unwrap());
505    }
506
507    #[test]
508    fn test_i32_from_bytes() {
509        let data = Bytes::from(vec![4, 3, 2, 1]);
510        let result: i32 = i32::from(data);
511        assert_eq!(result, i32::from_str("67305985").unwrap());
512    }
513}
514
515#[cfg(feature = "diesel")]
516#[cfg(test)]
517mod diesel_tests {
518    use diesel::{insert_into, table, Insertable, Queryable};
519    use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};
520
521    use super::*;
522
523    async fn setup_db() -> AsyncPgConnection {
524        let db_url = std::env::var("DATABASE_URL").unwrap();
525        let mut conn = AsyncPgConnection::establish(&db_url)
526            .await
527            .unwrap();
528        conn.begin_test_transaction()
529            .await
530            .unwrap();
531        conn
532    }
533
534    #[tokio::test]
535    async fn test_bytes_db_round_trip() {
536        table! {
537            bytes_table (id) {
538                id -> Int4,
539                data -> Binary,
540            }
541        }
542
543        #[derive(Insertable)]
544        #[diesel(table_name = bytes_table)]
545        struct NewByteEntry {
546            data: Bytes,
547        }
548
549        #[derive(Queryable, PartialEq)]
550        struct ByteEntry {
551            id: i32,
552            data: Bytes,
553        }
554
555        let mut conn = setup_db().await;
556        let example_bytes = Bytes::from_str("0x0123456789abcdef").unwrap();
557
558        conn.batch_execute(
559            r"
560            CREATE TEMPORARY TABLE bytes_table (
561                id SERIAL PRIMARY KEY,
562                data BYTEA NOT NULL
563            );
564        ",
565        )
566        .await
567        .unwrap();
568
569        let new_entry = NewByteEntry { data: example_bytes.clone() };
570
571        let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
572            .values(&new_entry)
573            .get_results(&mut conn)
574            .await
575            .unwrap();
576
577        assert_eq!(inserted[0].data, example_bytes);
578    }
579}