// commonware_utils/array/fixed_bytes.rs

1use crate::{hex, Array, SizedSerialize};
2use std::{
3    cmp::{Ord, PartialOrd},
4    fmt::{Debug, Display},
5    hash::Hash,
6    ops::Deref,
7};
8use thiserror::Error;
9
10/// Errors returned by `Bytes` functions.
11#[derive(Error, Debug, PartialEq)]
12pub enum Error {
13    #[error("invalid length")]
14    InvalidLength,
15}
16
/// An `Array` implementation for fixed-length byte arrays.
///
/// Wraps a `[u8; N]`. Equality, ordering and hashing are derived from the
/// inner array, so comparisons are lexicographic over the bytes.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
#[repr(transparent)]
pub struct FixedBytes<const N: usize>([u8; N]);

impl<const N: usize> FixedBytes<N> {
    /// Creates a new `FixedBytes` instance from an array of length `N`.
    ///
    /// This is a `const fn`, so well-known values can be declared as
    /// `const` items.
    pub const fn new(value: [u8; N]) -> Self {
        Self(value)
    }
}
28
29impl<const N: usize> Array for FixedBytes<N> {
30    type Error = Error;
31}
32
33impl<const N: usize> SizedSerialize for FixedBytes<N> {
34    const SERIALIZED_LEN: usize = N;
35}
36
37impl<const N: usize> TryFrom<&[u8]> for FixedBytes<N> {
38    type Error = Error;
39
40    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
41        let array: [u8; N] = value.try_into().map_err(|_| Error::InvalidLength)?;
42        Ok(Self(array))
43    }
44}
45
46impl<const N: usize> TryFrom<&Vec<u8>> for FixedBytes<N> {
47    type Error = Error;
48
49    fn try_from(value: &Vec<u8>) -> Result<Self, Self::Error> {
50        Self::try_from(value.as_slice())
51    }
52}
53
54impl<const N: usize> TryFrom<Vec<u8>> for FixedBytes<N> {
55    type Error = Error;
56
57    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
58        if value.len() != N {
59            return Err(Error::InvalidLength);
60        }
61        let boxed_slice = value.into_boxed_slice();
62        let boxed_array: Box<[u8; N]> = boxed_slice.try_into().map_err(|_| Error::InvalidLength)?;
63        Ok(Self(*boxed_array))
64    }
65}
66
67impl<const N: usize> AsRef<[u8]> for FixedBytes<N> {
68    fn as_ref(&self) -> &[u8] {
69        &self.0
70    }
71}
72
73impl<const N: usize> Deref for FixedBytes<N> {
74    type Target = [u8];
75    fn deref(&self) -> &[u8] {
76        &self.0
77    }
78}
79
80impl<const N: usize> Display for FixedBytes<N> {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(f, "{}", hex(&self.0))
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Error as ArrayError;
    use bytes::{Buf, BytesMut};

    #[test]
    fn test_bytes_creation_and_conversion() {
        let value = [1, 2, 3, 4];
        let bytes = FixedBytes::new(value);
        assert_eq!(bytes.as_ref(), &value);

        let slice = [1, 2, 3, 4];
        let bytes_from_slice = FixedBytes::try_from(slice.as_ref()).unwrap();
        assert_eq!(bytes_from_slice, bytes);

        let vec = vec![1, 2, 3, 4];
        let bytes_from_vec_ref = FixedBytes::try_from(&vec).unwrap();
        assert_eq!(bytes_from_vec_ref, bytes);

        let bytes_from_vec = FixedBytes::try_from(vec).unwrap();
        assert_eq!(bytes_from_vec, bytes);

        // Test with incorrect length
        let slice_too_short = [1, 2, 3];
        assert_eq!(
            FixedBytes::<4>::try_from(slice_too_short.as_ref()),
            Err(Error::InvalidLength)
        );

        let vec_too_long = vec![1, 2, 3, 4, 5];
        assert_eq!(
            FixedBytes::<4>::try_from(&vec_too_long),
            Err(Error::InvalidLength)
        );
        assert_eq!(
            FixedBytes::<4>::try_from(vec_too_long),
            Err(Error::InvalidLength)
        );
    }

    #[test]
    fn test_vec_with_spare_capacity() {
        // An owned `Vec` whose capacity exceeds its length must still convert.
        let mut vec: Vec<u8> = Vec::with_capacity(16);
        vec.extend_from_slice(&[1, 2, 3, 4]);
        let bytes = FixedBytes::<4>::try_from(vec).unwrap();
        assert_eq!(bytes, FixedBytes::new([1, 2, 3, 4]));
    }

    #[test]
    fn test_zero_length() {
        let empty = FixedBytes::<0>::new([]);
        let empty_slice: &[u8] = &[];
        assert_eq!(FixedBytes::<0>::try_from(empty_slice), Ok(empty.clone()));
        assert_eq!(FixedBytes::<0>::try_from(Vec::<u8>::new()), Ok(empty));
        assert_eq!(
            FixedBytes::<0>::try_from(vec![1u8]),
            Err(Error::InvalidLength)
        );
    }

    #[test]
    fn test_deref_and_as_ref() {
        let bytes = FixedBytes::new([1, 2, 3, 4]);
        let via_deref: &[u8] = &bytes;
        let via_as_ref: &[u8] = bytes.as_ref();
        assert_eq!(via_deref, via_as_ref);
        assert_eq!(&bytes[1..3], &[2, 3]);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(Error::InvalidLength.to_string(), "invalid length");
    }

    #[test]
    fn test_read_from() {
        let mut buf = BytesMut::from(&[1, 2, 3, 4][..]);
        let bytes = FixedBytes::<4>::read_from(&mut buf).unwrap();
        assert_eq!(bytes.as_ref(), &[1, 2, 3, 4]);
        assert_eq!(buf.remaining(), 0);

        let mut buf = BytesMut::from(&[1, 2, 3][..]);
        let result = FixedBytes::<4>::read_from(&mut buf);
        assert_eq!(result, Err(ArrayError::InsufficientBytes));

        let mut buf = BytesMut::from(&[1, 2, 3, 4, 5][..]);
        let bytes = FixedBytes::<4>::read_from(&mut buf).unwrap();
        assert_eq!(bytes.as_ref(), &[1, 2, 3, 4]);
        assert_eq!(buf.remaining(), 1);
        assert_eq!(buf[0], 5);
    }

    #[test]
    fn test_display() {
        let bytes = FixedBytes::new([0x01, 0x02, 0x03, 0x04]);
        assert_eq!(format!("{}", bytes), "01020304");
    }

    #[test]
    fn test_ord_and_eq() {
        let a = FixedBytes::new([1, 2, 3, 4]);
        let b = FixedBytes::new([1, 2, 3, 5]);
        assert!(a < b);
        assert_ne!(a, b);

        let c = FixedBytes::new([1, 2, 3, 4]);
        assert_eq!(a, c);
    }
}