commonware_utils/array/
fixed_bytes.rs

1use crate::{hex, Array};
2use bytes::{Buf, BufMut};
3use commonware_codec::{Codec, Error as CodecError, SizedCodec};
4use std::{
5    cmp::{Ord, PartialOrd},
6    fmt::{Debug, Display},
7    hash::Hash,
8    ops::Deref,
9};
10use thiserror::Error;
11
12/// Errors returned by `Bytes` functions.
13#[derive(Error, Debug, PartialEq)]
14pub enum Error {
15    #[error("invalid length")]
16    InvalidLength,
17}
18
/// An [`Array`] implementation for fixed-length byte arrays.
///
/// The wrapper is `#[repr(transparent)]`, so it has the same layout as `[u8; N]`.
/// Comparison and ordering are derived from the inner array.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
#[repr(transparent)]
pub struct FixedBytes<const N: usize>([u8; N]);

impl<const N: usize> FixedBytes<N> {
    /// Creates a new `FixedBytes` instance from an array of length `N`.
    ///
    /// This is a `const fn`, so it can be used to build compile-time constants.
    pub const fn new(value: [u8; N]) -> Self {
        Self(value)
    }
}
30
31impl<const N: usize> Codec for FixedBytes<N> {
32    fn write(&self, buf: &mut impl BufMut) {
33        self.0.write(buf);
34    }
35
36    fn read(buf: &mut impl Buf) -> Result<Self, CodecError> {
37        Ok(Self(<[u8; N]>::read(buf)?))
38    }
39
40    fn len_encoded(&self) -> usize {
41        N
42    }
43}
44
45impl<const N: usize> SizedCodec for FixedBytes<N> {
46    const LEN_ENCODED: usize = N;
47}
48
49impl<const N: usize> Array for FixedBytes<N> {
50    type Error = Error;
51}
52
53impl<const N: usize> TryFrom<&[u8]> for FixedBytes<N> {
54    type Error = Error;
55
56    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
57        let array: [u8; N] = value.try_into().map_err(|_| Error::InvalidLength)?;
58        Ok(Self(array))
59    }
60}
61
62impl<const N: usize> TryFrom<&Vec<u8>> for FixedBytes<N> {
63    type Error = Error;
64
65    fn try_from(value: &Vec<u8>) -> Result<Self, Self::Error> {
66        Self::try_from(value.as_slice())
67    }
68}
69
70impl<const N: usize> TryFrom<Vec<u8>> for FixedBytes<N> {
71    type Error = Error;
72
73    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
74        if value.len() != N {
75            return Err(Error::InvalidLength);
76        }
77        let boxed_slice = value.into_boxed_slice();
78        let boxed_array: Box<[u8; N]> = boxed_slice.try_into().map_err(|_| Error::InvalidLength)?;
79        Ok(Self(*boxed_array))
80    }
81}
82
83impl<const N: usize> AsRef<[u8]> for FixedBytes<N> {
84    fn as_ref(&self) -> &[u8] {
85        &self.0
86    }
87}
88
89impl<const N: usize> Deref for FixedBytes<N> {
90    type Target = [u8];
91    fn deref(&self) -> &[u8] {
92        &self.0
93    }
94}
95
96impl<const N: usize> Display for FixedBytes<N> {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        write!(f, "{}", hex(&self.0))
99    }
100}
101
// Unit tests for `FixedBytes`: encoding, `TryFrom` conversions, buffer reads,
// `Display` formatting, and derived comparison traits.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Error as ArrayError;
    use bytes::{Buf, BytesMut};

    // Round-trips a value through `encode`/`decode`. It also checks that the encoded length
    // equals the byte count (`original.len()` resolves through `Deref<Target = [u8]>`).
    #[test]
    fn test_codec() {
        let original = FixedBytes::new([1, 2, 3, 4]);
        let encoded = original.encode();
        assert_eq!(encoded.len(), original.len());
        let decoded = FixedBytes::decode(encoded).unwrap();
        assert_eq!(original, decoded);
    }

    // Exercises `new` and every `TryFrom` conversion (slice, `&Vec`, `Vec`).
    // Includes the too-short and too-long inputs that must yield `Error::InvalidLength`.
    #[test]
    fn test_bytes_creation_and_conversion() {
        let value = [1, 2, 3, 4];
        let bytes = FixedBytes::new(value);
        assert_eq!(bytes.as_ref(), &value);

        let slice = [1, 2, 3, 4];
        let bytes_from_slice = FixedBytes::try_from(slice.as_ref()).unwrap();
        assert_eq!(bytes_from_slice, bytes);

        let vec = vec![1, 2, 3, 4];
        let bytes_from_vec_ref = FixedBytes::try_from(&vec).unwrap();
        assert_eq!(bytes_from_vec_ref, bytes);

        let bytes_from_vec = FixedBytes::try_from(vec).unwrap();
        assert_eq!(bytes_from_vec, bytes);

        // Test with incorrect length
        let slice_too_short = [1, 2, 3];
        assert_eq!(
            FixedBytes::<4>::try_from(slice_too_short.as_ref()),
            Err(Error::InvalidLength)
        );

        let vec_too_long = vec![1, 2, 3, 4, 5];
        assert_eq!(
            FixedBytes::<4>::try_from(&vec_too_long),
            Err(Error::InvalidLength)
        );
        assert_eq!(
            FixedBytes::<4>::try_from(vec_too_long),
            Err(Error::InvalidLength)
        );
    }

    // NOTE(review): this test calls `FixedBytes::read_from` and expects
    // `crate::array::Error::InsufficientBytes`. However, this file implements decoding via
    // `Codec::read`, which returns `commonware_codec::Error`. Confirm that `read_from` still
    // exists on `Array`; if it does not, port this test to `Codec::read`.
    #[test]
    fn test_read_from() {
        // Exact-length buffer: all 4 bytes are consumed.
        let mut buf = BytesMut::from(&[1, 2, 3, 4][..]);
        let bytes = FixedBytes::<4>::read_from(&mut buf).unwrap();
        assert_eq!(bytes.as_ref(), &[1, 2, 3, 4]);
        assert_eq!(buf.remaining(), 0);

        // Short buffer: the read fails.
        let mut buf = BytesMut::from(&[1, 2, 3][..]);
        let result = FixedBytes::<4>::read_from(&mut buf);
        assert_eq!(result, Err(ArrayError::InsufficientBytes));

        // Longer buffer: only the first 4 bytes are consumed; the trailing byte remains.
        let mut buf = BytesMut::from(&[1, 2, 3, 4, 5][..]);
        let bytes = FixedBytes::<4>::read_from(&mut buf).unwrap();
        assert_eq!(bytes.as_ref(), &[1, 2, 3, 4]);
        assert_eq!(buf.remaining(), 1);
        assert_eq!(buf[0], 5);
    }

    // `Display` renders the bytes as a hex string.
    #[test]
    fn test_display() {
        let bytes = FixedBytes::new([0x01, 0x02, 0x03, 0x04]);
        assert_eq!(format!("{}", bytes), "01020304");
    }

    // The derived `Ord`/`PartialEq` compare the inner arrays element by element.
    #[test]
    fn test_ord_and_eq() {
        let a = FixedBytes::new([1, 2, 3, 4]);
        let b = FixedBytes::new([1, 2, 3, 5]);
        assert!(a < b);
        assert_ne!(a, b);

        let c = FixedBytes::new([1, 2, 3, 4]);
        assert_eq!(a, c);
    }
}