use std::{
borrow::Borrow,
clone::Clone,
fmt::{Debug, Display, Formatter, LowerHex, Result as FmtResult},
ops::Deref,
str::FromStr,
};
use deepsize::{Context, DeepSizeOf};
#[cfg(feature = "diesel")]
use diesel::{
deserialize::{self, FromSql, FromSqlRow},
expression::AsExpression,
pg::Pg,
serialize::{self, ToSql},
sql_types::Binary,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::serde_primitives::hex_bytes;
#[derive(Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
#[cfg_attr(feature = "diesel", derive(AsExpression, FromSqlRow,))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Binary))]
pub struct Bytes(#[serde(with = "hex_bytes")] pub bytes::Bytes);
impl DeepSizeOf for Bytes {
fn deep_size_of_children(&self, _ctx: &mut Context) -> usize {
self.0.len()
}
}
fn bytes_to_hex(b: &Bytes) -> String {
hex::encode(b.0.as_ref())
}
impl Debug for Bytes {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "Bytes(0x{})", bytes_to_hex(self))
}
}
impl Display for Bytes {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "0x{}", bytes_to_hex(self))
}
}
impl LowerHex for Bytes {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "0x{}", bytes_to_hex(self))
}
}
impl Bytes {
pub fn new() -> Self {
Self(bytes::Bytes::new())
}
pub fn to_vec(&self) -> Vec<u8> {
self.as_ref().to_vec()
}
pub fn lpad(&self, length: usize, pad_byte: u8) -> Bytes {
let mut padded_vec = vec![pad_byte; length.saturating_sub(self.len())];
padded_vec.extend_from_slice(self.as_ref());
Bytes(bytes::Bytes::from(padded_vec))
}
pub fn rpad(&self, length: usize, pad_byte: u8) -> Bytes {
let mut padded_vec = self.to_vec();
padded_vec.resize(length, pad_byte);
Bytes(bytes::Bytes::from(padded_vec))
}
pub fn zero(length: usize) -> Bytes {
Bytes::from(vec![0u8; length])
}
pub fn random(length: usize) -> Bytes {
let mut data = vec![0u8; length];
rand::thread_rng().fill(&mut data[..]);
Bytes::from(data)
}
pub fn is_zero(&self) -> bool {
self.as_ref().iter().all(|b| *b == 0)
}
}
impl Deref for Bytes {
type Target = [u8];
#[inline]
fn deref(&self) -> &[u8] {
self.as_ref()
}
}
impl AsRef<[u8]> for Bytes {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl Borrow<[u8]> for Bytes {
fn borrow(&self) -> &[u8] {
self.as_ref()
}
}
impl IntoIterator for Bytes {
type Item = u8;
type IntoIter = bytes::buf::IntoIter<bytes::Bytes>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a> IntoIterator for &'a Bytes {
type Item = &'a u8;
type IntoIter = core::slice::Iter<'a, u8>;
fn into_iter(self) -> Self::IntoIter {
self.as_ref().iter()
}
}
impl From<&[u8]> for Bytes {
fn from(src: &[u8]) -> Self {
Self(bytes::Bytes::copy_from_slice(src))
}
}
impl From<bytes::Bytes> for Bytes {
fn from(src: bytes::Bytes) -> Self {
Self(src)
}
}
impl From<Bytes> for bytes::Bytes {
fn from(src: Bytes) -> Self {
src.0
}
}
impl From<Vec<u8>> for Bytes {
fn from(src: Vec<u8>) -> Self {
Self(src.into())
}
}
impl From<Bytes> for Vec<u8> {
fn from(value: Bytes) -> Self {
value.to_vec()
}
}
impl<const N: usize> From<[u8; N]> for Bytes {
fn from(src: [u8; N]) -> Self {
src.to_vec().into()
}
}
impl<'a, const N: usize> From<&'a [u8; N]> for Bytes {
fn from(src: &'a [u8; N]) -> Self {
src.to_vec().into()
}
}
impl PartialEq<[u8]> for Bytes {
fn eq(&self, other: &[u8]) -> bool {
self.as_ref() == other
}
}
impl PartialEq<Bytes> for [u8] {
fn eq(&self, other: &Bytes) -> bool {
*other == *self
}
}
impl PartialEq<Vec<u8>> for Bytes {
fn eq(&self, other: &Vec<u8>) -> bool {
self.as_ref() == &other[..]
}
}
impl PartialEq<Bytes> for Vec<u8> {
fn eq(&self, other: &Bytes) -> bool {
*other == *self
}
}
impl PartialEq<bytes::Bytes> for Bytes {
fn eq(&self, other: &bytes::Bytes) -> bool {
other == self.as_ref()
}
}
#[derive(Debug, Clone, Error)]
#[error("Failed to parse bytes: {0}")]
pub struct ParseBytesError(String);
impl FromStr for Bytes {
type Err = ParseBytesError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
if let Some(value) = value.strip_prefix("0x") {
hex::decode(value)
} else {
hex::decode(value)
}
.map(Into::into)
.map_err(|e| ParseBytesError(format!("Invalid hex: {e}")))
}
}
impl From<&str> for Bytes {
fn from(value: &str) -> Self {
value.parse().unwrap()
}
}
#[cfg(feature = "diesel")]
impl ToSql<Binary, Pg> for Bytes {
fn to_sql<'b>(&'b self, out: &mut serialize::Output<'b, '_, Pg>) -> serialize::Result {
let bytes_slice: &[u8] = &self.0;
<&[u8] as ToSql<Binary, Pg>>::to_sql(&bytes_slice, &mut out.reborrow())
}
}
#[cfg(feature = "diesel")]
impl FromSql<Binary, Pg> for Bytes {
fn from_sql(
bytes: <diesel::pg::Pg as diesel::backend::Backend>::RawValue<'_>,
) -> deserialize::Result<Self> {
let byte_vec: Vec<u8> = <Vec<u8> as FromSql<Binary, Pg>>::from_sql(bytes)?;
Ok(Bytes(bytes::Bytes::from(byte_vec)))
}
}
macro_rules! impl_from_uint_for_bytes {
($($t:ty),*) => {
$(
impl From<$t> for Bytes {
fn from(src: $t) -> Self {
let size = std::mem::size_of::<$t>();
let mut buf = vec![0u8; size];
buf.copy_from_slice(&src.to_be_bytes());
Self(bytes::Bytes::from(buf))
}
}
)*
};
}
impl_from_uint_for_bytes!(u8, u16, u32, u64, u128);
macro_rules! impl_from_bytes_for_uint {
($($t:ty),*) => {
$(
impl From<Bytes> for $t {
fn from(src: Bytes) -> Self {
let bytes_slice = src.as_ref();
let mut buf = [0u8; std::mem::size_of::<$t>()];
buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
<$t>::from_be_bytes(buf)
}
}
)*
};
}
impl_from_bytes_for_uint!(u8, u16, u32, u64, u128);
macro_rules! impl_from_bytes_for_signed_int {
($($t:ty),*) => {
$(
impl From<Bytes> for $t {
fn from(src: Bytes) -> Self {
let bytes_slice = src.as_ref();
let mut buf = if bytes_slice.get(0).map_or(false, |&b| b & 0x80 != 0) {
[0xFFu8; std::mem::size_of::<$t>()] } else {
[0x00u8; std::mem::size_of::<$t>()] };
buf[std::mem::size_of::<$t>() - bytes_slice.len()..].copy_from_slice(bytes_slice);
<$t>::from_be_bytes(buf)
}
}
)*
};
}
impl_from_bytes_for_signed_int!(i8, i16, i32, i64, i128);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_bytes() {
let b = bytes::Bytes::from("0123456789abcdef");
let wrapped_b = Bytes::from(b.clone());
let expected = Bytes(b);
assert_eq!(wrapped_b, expected);
}
#[test]
fn test_from_slice() {
let arr = [1, 35, 69, 103, 137, 171, 205, 239];
let b = Bytes::from(&arr);
let expected = Bytes(bytes::Bytes::from(arr.to_vec()));
assert_eq!(b, expected);
}
#[test]
fn hex_formatting() {
let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
let expected = String::from("0x0123456789abcdef");
assert_eq!(format!("{b:x}"), expected);
assert_eq!(format!("{b}"), expected);
}
#[test]
fn test_from_str() {
let b = Bytes::from_str("0x1213");
assert!(b.is_ok());
let b = b.unwrap();
assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
let b = Bytes::from_str("1213");
let b = b.unwrap();
assert_eq!(b.as_ref(), hex::decode("1213").unwrap());
}
#[test]
fn test_debug_formatting() {
let b = Bytes::from(vec![1, 35, 69, 103, 137, 171, 205, 239]);
assert_eq!(format!("{b:?}"), "Bytes(0x0123456789abcdef)");
assert_eq!(format!("{b:#?}"), "Bytes(0x0123456789abcdef)");
}
#[test]
fn test_to_vec() {
let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
let b = Bytes::from(vec.clone());
assert_eq!(b.to_vec(), vec);
}
#[test]
fn test_vec_partialeq() {
let vec = vec![1, 35, 69, 103, 137, 171, 205, 239];
let b = Bytes::from(vec.clone());
assert_eq!(b, vec);
assert_eq!(vec, b);
let wrong_vec = vec![1, 3, 52, 137];
assert_ne!(b, wrong_vec);
assert_ne!(wrong_vec, b);
}
#[test]
fn test_bytes_partialeq() {
let b = bytes::Bytes::from("0123456789abcdef");
let wrapped_b = Bytes::from(b.clone());
assert_eq!(wrapped_b, b);
let wrong_b = bytes::Bytes::from("0123absd");
assert_ne!(wrong_b, b);
}
#[test]
fn test_u128_from_bytes() {
let data = Bytes::from(vec![4, 3, 2, 1]);
let result: u128 = u128::from(data.clone());
assert_eq!(result, u128::from_str("67305985").unwrap());
}
#[test]
fn test_i128_from_bytes() {
let data = Bytes::from(vec![4, 3, 2, 1]);
let result: i128 = i128::from(data.clone());
assert_eq!(result, i128::from_str("67305985").unwrap());
}
#[test]
fn test_i32_from_bytes() {
let data = Bytes::from(vec![4, 3, 2, 1]);
let result: i32 = i32::from(data);
assert_eq!(result, i32::from_str("67305985").unwrap());
}
}
#[cfg(feature = "diesel")]
#[cfg(test)]
mod diesel_tests {
use diesel::{insert_into, table, Insertable, Queryable};
use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection};
use super::*;
async fn setup_db() -> AsyncPgConnection {
let db_url = std::env::var("DATABASE_URL").unwrap();
let mut conn = AsyncPgConnection::establish(&db_url)
.await
.unwrap();
conn.begin_test_transaction()
.await
.unwrap();
conn
}
#[tokio::test]
async fn test_bytes_db_round_trip() {
table! {
bytes_table (id) {
id -> Int4,
data -> Binary,
}
}
#[derive(Insertable)]
#[diesel(table_name = bytes_table)]
struct NewByteEntry {
data: Bytes,
}
#[derive(Queryable, PartialEq)]
struct ByteEntry {
id: i32,
data: Bytes,
}
let mut conn = setup_db().await;
let example_bytes = Bytes::from_str("0x0123456789abcdef").unwrap();
conn.batch_execute(
r"
CREATE TEMPORARY TABLE bytes_table (
id SERIAL PRIMARY KEY,
data BYTEA NOT NULL
);
",
)
.await
.unwrap();
let new_entry = NewByteEntry { data: example_bytes.clone() };
let inserted: Vec<ByteEntry> = insert_into(bytes_table::table)
.values(&new_entry)
.get_results(&mut conn)
.await
.unwrap();
assert_eq!(inserted[0].data, example_bytes);
}
}