Skip to main content

alloy_rlp/
decode.rs

1use crate::{Error, Header, Result};
2use bytes::{Bytes, BytesMut};
3use core::{
4    marker::{PhantomData, PhantomPinned},
5    num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize},
6};
7
/// Message carried by `Error::Custom` when a `NonZero*` integer decodes to zero.
const NON_ZERO_INTEGER_ERROR: &str = "non-zero integer cannot be zero";
9
/// A type that can be decoded from an RLP blob.
pub trait Decodable: Sized {
    /// Decodes the blob into the appropriate type. `buf` must be advanced past
    /// the decoded object.
    ///
    /// # Errors
    ///
    /// Implementations return an `Error` when `buf` does not begin with a valid
    /// RLP encoding of `Self` (for example a truncated item, a list where a
    /// string was expected, or an integer with a leading zero byte).
    fn decode(buf: &mut &[u8]) -> Result<Self>;
}
16
/// An active RLP decoder, with a specific slice of a payload.
///
/// Built with `Rlp::new`, which decodes the outer list header; the list's items
/// are then pulled one at a time with `Rlp::get_next`.
#[derive(Debug)]
pub struct Rlp<'a> {
    /// The not-yet-decoded remainder of the list payload.
    payload_view: &'a [u8],
}
22
23impl<'a> Rlp<'a> {
24    /// Instantiate an RLP decoder with a payload slice.
25    pub fn new(mut payload: &'a [u8]) -> Result<Self> {
26        let payload_view = Header::decode_bytes(&mut payload, true)?;
27        Ok(Self { payload_view })
28    }
29
30    /// Decode the next item from the buffer.
31    #[inline]
32    pub fn get_next<T: Decodable>(&mut self) -> Result<Option<T>> {
33        if self.payload_view.is_empty() {
34            Ok(None)
35        } else {
36            T::decode(&mut self.payload_view).map(Some)
37        }
38    }
39}
40
41impl<T: ?Sized> Decodable for PhantomData<T> {
42    fn decode(_buf: &mut &[u8]) -> Result<Self> {
43        Ok(Self)
44    }
45}
46
47impl Decodable for PhantomPinned {
48    fn decode(_buf: &mut &[u8]) -> Result<Self> {
49        Ok(Self)
50    }
51}
52
53impl Decodable for bool {
54    #[inline]
55    fn decode(buf: &mut &[u8]) -> Result<Self> {
56        Ok(match u8::decode(buf)? {
57            0 => false,
58            1 => true,
59            _ => return Err(Error::Custom("invalid bool value, must be 0 or 1")),
60        })
61    }
62}
63
64impl<const N: usize> Decodable for [u8; N] {
65    #[inline]
66    fn decode(from: &mut &[u8]) -> Result<Self> {
67        let bytes = Header::decode_bytes(from, false)?;
68        Self::try_from(bytes).map_err(|_| Error::UnexpectedLength)
69    }
70}
71
72macro_rules! decode_integer {
73    ($($t:ty),+ $(,)?) => {$(
74        impl Decodable for $t {
75            #[inline]
76            fn decode(buf: &mut &[u8]) -> Result<Self> {
77                let bytes = Header::decode_bytes(buf, false)?;
78                static_left_pad(bytes).map(<$t>::from_be_bytes)
79            }
80        }
81    )+};
82}
83
84decode_integer!(u8, u16, u32, u64, usize, u128);
85
86macro_rules! decode_nonzero_integer {
87    ($($t:ty => $inner:ty),+ $(,)?) => {$(
88        impl Decodable for $t {
89            #[inline]
90            fn decode(buf: &mut &[u8]) -> Result<Self> {
91                <$inner>::decode(buf).and_then(|value| {
92                    <$t>::new(value).ok_or(Error::Custom(NON_ZERO_INTEGER_ERROR))
93                })
94            }
95        }
96    )+};
97}
98
99decode_nonzero_integer! {
100    NonZeroU8 => u8,
101    NonZeroU16 => u16,
102    NonZeroU32 => u32,
103    NonZeroU64 => u64,
104    NonZeroUsize => usize,
105    NonZeroU128 => u128,
106}
107
108impl Decodable for Bytes {
109    #[inline]
110    fn decode(buf: &mut &[u8]) -> Result<Self> {
111        Header::decode_bytes(buf, false).map(|x| Self::from(x.to_vec()))
112    }
113}
114
115impl Decodable for BytesMut {
116    #[inline]
117    fn decode(buf: &mut &[u8]) -> Result<Self> {
118        Header::decode_bytes(buf, false).map(Self::from)
119    }
120}
121
122impl Decodable for alloc::string::String {
123    #[inline]
124    fn decode(buf: &mut &[u8]) -> Result<Self> {
125        Header::decode_str(buf).map(Into::into)
126    }
127}
128
129impl<T: Decodable> Decodable for alloc::vec::Vec<T> {
130    #[inline]
131    fn decode(buf: &mut &[u8]) -> Result<Self> {
132        let mut bytes = Header::decode_bytes(buf, true)?;
133        let mut vec = Self::new();
134        let payload_view = &mut bytes;
135        while !payload_view.is_empty() {
136            vec.push(T::decode(payload_view)?);
137        }
138        Ok(vec)
139    }
140}
141
142macro_rules! wrap_impl {
143    ($($(#[$attr:meta])* [$($gen:tt)*] <$t:ty>::$new:ident($t2:ty)),+ $(,)?) => {$(
144        $(#[$attr])*
145        impl<$($gen)*> Decodable for $t {
146            #[inline]
147            fn decode(buf: &mut &[u8]) -> Result<Self> {
148                <$t2 as Decodable>::decode(buf).map(<$t>::$new)
149            }
150        }
151    )+};
152}
153
154wrap_impl! {
155    #[cfg(feature = "arrayvec")]
156    [const N: usize] <arrayvec::ArrayVec<u8, N>>::from([u8; N]),
157    [T: Decodable] <alloc::boxed::Box<T>>::new(T),
158    [T: Decodable] <alloc::rc::Rc<T>>::new(T),
159    #[cfg(target_has_atomic = "ptr")]
160    [T: Decodable] <alloc::sync::Arc<T>>::new(T),
161}
162
163impl<T: ?Sized + alloc::borrow::ToOwned> Decodable for alloc::borrow::Cow<'_, T>
164where
165    T::Owned: Decodable,
166{
167    #[inline]
168    fn decode(buf: &mut &[u8]) -> Result<Self> {
169        T::Owned::decode(buf).map(Self::Owned)
170    }
171}
172
#[cfg(any(feature = "std", feature = "core-net"))]
mod std_impl {
    use super::*;
    #[cfg(all(feature = "core-net", not(feature = "std")))]
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    #[cfg(feature = "std")]
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Decodes an RLP string holding either 4 octets (IPv4) or 16 octets (IPv6).
    impl Decodable for IpAddr {
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let octets = Header::decode_bytes(buf, false)?;
            // A 4-byte payload is IPv4; otherwise it must be exactly 16 bytes,
            // and any other length surfaces as `Error::UnexpectedLength`.
            if let Ok(v4) = slice_to_array::<4>(octets) {
                return Ok(Self::V4(Ipv4Addr::from(v4)));
            }
            slice_to_array::<16>(octets).map(|v6| Self::V6(Ipv6Addr::from(v6)))
        }
    }

    /// Decodes an RLP string holding exactly 4 octets.
    impl Decodable for Ipv4Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let payload = Header::decode_bytes(buf, false)?;
            let octets = slice_to_array::<4>(payload)?;
            Ok(Self::from(octets))
        }
    }

    /// Decodes an RLP string holding exactly 16 octets.
    impl Decodable for Ipv6Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let payload = Header::decode_bytes(buf, false)?;
            let octets = slice_to_array::<16>(payload)?;
            Ok(Self::from(octets))
        }
    }

    /// Copies `slice` into a fixed-size array, mapping a length mismatch to
    /// `Error::UnexpectedLength`.
    #[inline]
    fn slice_to_array<const N: usize>(slice: &[u8]) -> Result<[u8; N]> {
        <[u8; N]>::try_from(slice).map_err(|_| Error::UnexpectedLength)
    }
}
215
216/// Decodes the entire input, ensuring no trailing bytes remain.
217///
218/// # Errors
219///
220/// Returns an error if the encoding is invalid or if data remains after decoding the RLP item.
221#[inline]
222pub fn decode_exact<T: Decodable>(bytes: impl AsRef<[u8]>) -> Result<T> {
223    let mut buf = bytes.as_ref();
224    let out = T::decode(&mut buf)?;
225
226    // check if there are any remaining bytes after decoding
227    if !buf.is_empty() {
228        // TODO: introduce a new variant TrailingBytes to better distinguish this error
229        return Err(Error::UnexpectedLength);
230    }
231
232    Ok(out)
233}
234
235/// Left-pads a slice to a statically known size array.
236///
237/// # Errors
238///
239/// Returns an error if the slice is too long or if the first byte is 0.
240#[inline]
241pub(crate) fn static_left_pad<const N: usize>(data: &[u8]) -> Result<[u8; N]> {
242    if data.len() > N {
243        return Err(Error::Overflow);
244    }
245
246    let mut v = [0; N];
247
248    // yes, data may empty, e.g. we decode a bool false value
249    if data.is_empty() {
250        return Ok(v);
251    }
252
253    if data[0] == 0 {
254        return Err(Error::LeadingZero);
255    }
256
257    // SAFETY: length checked above
258    unsafe { v.get_unchecked_mut(N - data.len()..) }.copy_from_slice(data);
259    Ok(v)
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{encode, Encodable};
    use core::{
        fmt::Debug,
        num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize},
    };
    use hex_literal::hex;

    // NOTE(review): presumably allowed because `String`/`Vec` are also reachable
    // through the std prelude under some feature combinations — confirm.
    #[allow(unused_imports)]
    use alloc::{string::String, vec::Vec};

    /// Checks each `(expected, input)` fixture. Successful fixtures must
    /// re-encode to exactly `input`. Decoding `input` must yield `expected`, and
    /// a successful decode must consume the whole input.
    fn check_decode<'a, T, IT>(fixtures: IT)
    where
        T: Encodable + Decodable + PartialEq + Debug,
        IT: IntoIterator<Item = (Result<T>, &'a [u8])>,
    {
        for (expected, mut input) in fixtures {
            if let Ok(expected) = &expected {
                assert_eq!(crate::encode(expected), input, "{expected:?}");
            }

            let orig = input;
            assert_eq!(
                T::decode(&mut input),
                expected,
                "input: {}{}",
                hex::encode(orig),
                expected.as_ref().map_or_else(
                    |_| String::new(),
                    |expected| format!("; expected: {}", hex::encode(crate::encode(expected)))
                )
            );

            if expected.is_ok() {
                assert_eq!(input, &[]);
            }
        }
    }

    #[test]
    fn rlp_bool() {
        let out = [0x80];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Ok(false), val);

        let out = [0x01];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Ok(true), val);

        // Any decoded byte other than 0 or 1 is rejected.
        let out = [0x02];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Err(Error::Custom("invalid bool value, must be 0 or 1")), val);
    }

    #[test]
    fn rlp_strings() {
        check_decode::<Bytes, _>([
            (Ok(hex!("00")[..].to_vec().into()), &hex!("00")[..]),
            (
                Ok(hex!("6f62636465666768696a6b6c6d")[..].to_vec().into()),
                &hex!("8D6F62636465666768696A6B6C6D")[..],
            ),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
        ])
    }

    #[test]
    fn rlp_fixed_length() {
        check_decode([
            (Ok(hex!("6f62636465666768696a6b6c6d")), &hex!("8D6F62636465666768696A6B6C6D")[..]),
            (Err(Error::UnexpectedLength), &hex!("8C6F62636465666768696A6B6C")[..]),
            (Err(Error::UnexpectedLength), &hex!("8E6F62636465666768696A6B6C6D6E")[..]),
        ])
    }

    #[test]
    fn rlp_u64() {
        check_decode([
            (Ok(9_u64), &hex!("09")[..]),
            (Ok(0_u64), &hex!("80")[..]),
            (Ok(0x0505_u64), &hex!("820505")[..]),
            (Ok(0xCE05050505_u64), &hex!("85CE05050505")[..]),
            (Err(Error::Overflow), &hex!("8AFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::InputTooShort), &hex!("8BFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
            (Err(Error::LeadingZero), &hex!("00")[..]),
            (Err(Error::NonCanonicalSingleByte), &hex!("8105")[..]),
            (Err(Error::LeadingZero), &hex!("8200F4")[..]),
            (Err(Error::NonCanonicalSize), &hex!("B8020004")[..]),
            (
                Err(Error::Overflow),
                &hex!("A101000000000000000000000000000000000000008B000000000000000000000000")[..],
            ),
        ])
    }

    #[test]
    fn rlp_nonzero_uints() {
        check_decode([(Ok(NonZeroU8::new(9).unwrap()), &hex!("09")[..])]);
        check_decode([(Ok(NonZeroU16::new(0x0505).unwrap()), &hex!("820505")[..])]);
        check_decode([(Ok(NonZeroU32::new(0xCE0505).unwrap()), &hex!("83CE0505")[..])]);
        check_decode([(Ok(NonZeroU64::new(0xCE05050505).unwrap()), &hex!("85CE05050505")[..])]);
        check_decode([(Ok(NonZeroUsize::new(0x80).unwrap()), &hex!("8180")[..])]);
        check_decode([(
            Ok(NonZeroU128::new(0x10203E405060708090A0B0C0D0E0F2).unwrap()),
            &hex!("8f10203e405060708090a0b0c0d0e0f2")[..],
        )]);
        check_decode::<NonZeroU8, _>([(
            Err(Error::Custom(NON_ZERO_INTEGER_ERROR)),
            &hex!("80")[..],
        )]);
        check_decode::<NonZeroU8, _>([(Err(Error::LeadingZero), &hex!("00")[..])]);
    }

    #[test]
    fn rlp_vectors() {
        check_decode::<Vec<u64>, _>([
            (Ok(vec![]), &hex!("C0")[..]),
            (Ok(vec![0xBBCCB5_u64, 0xFFC0B5_u64]), &hex!("C883BBCCB583FFC0B5")[..]),
        ])
    }

    #[cfg(feature = "std")]
    #[test]
    fn rlp_ip() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        let localhost4 = Ipv4Addr::new(127, 0, 0, 1);
        let localhost6 = localhost4.to_ipv6_mapped();
        let expected4 = &hex!("847F000001")[..];
        let expected6 = &hex!("9000000000000000000000ffff7f000001")[..];
        check_decode::<Ipv4Addr, _>([(Ok(localhost4), expected4)]);
        check_decode::<Ipv6Addr, _>([(Ok(localhost6), expected6)]);
        check_decode::<IpAddr, _>([
            (Ok(IpAddr::V4(localhost4)), expected4),
            (Ok(IpAddr::V6(localhost6)), expected6),
        ]);
    }

    #[test]
    fn malformed_rlp() {
        check_decode::<Bytes, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<[u8; 5], _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        #[cfg(feature = "std")]
        check_decode::<std::net::IpAddr, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<Vec<u8>, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<String, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<u8, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
        check_decode::<u64, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
    }

    #[test]
    fn rlp_full() {
        fn check_decode_exact<T: Decodable + Encodable + PartialEq + Debug>(input: T) {
            let encoded = encode(&input);
            assert_eq!(decode_exact::<T>(&encoded), Ok(input));
            assert_eq!(
                decode_exact::<T>([encoded, vec![0x00]].concat()),
                Err(Error::UnexpectedLength)
            );
        }

        check_decode_exact::<String>("".into());
        check_decode_exact::<String>("test1234".into());
        check_decode_exact::<Vec<u64>>(vec![]);
        check_decode_exact::<Vec<u64>>(vec![0; 4]);
    }
}