1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
use crate::{decode::static_left_pad, Error, Result, EMPTY_LIST_CODE, EMPTY_STRING_CODE};
use bytes::{Buf, BufMut};
use core::hint::unreachable_unchecked;
/// Decoded prefix of a single RLP item.
///
/// Records whether the item is a list or a byte string, and how long its payload is.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct Header {
    /// `true` when the item is an RLP list, `false` when it is a byte string.
    pub list: bool,
    /// Length of the item's payload in bytes.
    pub payload_length: usize,
}
impl Header {
    /// Decodes an RLP header from the given buffer.
    ///
    /// On success, `buf` has been advanced past the header's prefix bytes so that
    /// it starts at the payload, and at least `payload_length` bytes remain in it.
    /// A single byte in `0x00..=0x7F` is its own payload. In that case `buf` is
    /// not advanced, and the returned header is a non-list item with
    /// `payload_length == 1`.
    ///
    /// On error, `buf` may already have been partially advanced.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is too short or the header is invalid:
    /// - [`Error::InputTooShort`] if `buf` is empty, the length bytes are
    ///   truncated, or fewer than `payload_length` bytes follow the header;
    /// - [`Error::NonCanonicalSingleByte`] for a one-byte string whose byte is
    ///   below `EMPTY_STRING_CODE`, because such a byte must be encoded bare;
    /// - [`Error::NonCanonicalSize`] for a long-form header announcing fewer than
    ///   56 bytes, because the short form must be used instead;
    /// - `Error::Custom("Input too big")` if the announced length does not fit
    ///   in `usize`;
    /// - whatever `static_left_pad` returns for the length bytes.
    #[inline]
    pub fn decode(buf: &mut &[u8]) -> Result<Self> {
        let payload_length;
        let mut list = false;
        // Peek at the prefix byte. `get_next_byte` does not advance `buf`.
        match get_next_byte(buf)? {
            // The byte encodes itself, so there is no prefix to skip.
            0..=0x7F => payload_length = 1,
            // Short string: the prefix byte itself carries the payload length.
            b @ EMPTY_STRING_CODE..=0xB7 => {
                buf.advance(1);
                payload_length = (b - EMPTY_STRING_CODE) as usize;
                // A 1-byte string whose byte is below `EMPTY_STRING_CODE` has to be
                // encoded as that bare byte, so the prefixed form is non-canonical.
                // This peeks at the payload byte without consuming it.
                if payload_length == 1 && get_next_byte(buf)? < EMPTY_STRING_CODE {
                    return Err(Error::NonCanonicalSingleByte)
                }
            }
            // Long string (0xB8..=0xBF) or long list (0xF8..=0xFF). The prefix byte
            // gives the number of big-endian length bytes that follow it.
            b @ (0xB8..=0xBF | 0xF8..=0xFF) => {
                buf.advance(1);
                list = b >= 0xF8; // second range
                let code = if list { 0xF7 } else { 0xB7 };
                // SAFETY: `b - code` is always in the range `1..=8` in the current match arm.
                // The compiler/LLVM apparently cannot prove this because of the `|` pattern +
                // the above `if`, since it can do it in the other arms with only 1 range.
                let len_of_len = unsafe { b.checked_sub(code).unwrap_unchecked() } as usize;
                // Optimizer hint only: restates the `1..=8` invariant above for the
                // compiler. Reaching this branch would be undefined behavior.
                if len_of_len == 0 || len_of_len > 8 {
                    unsafe { unreachable_unchecked() }
                }
                if buf.len() < len_of_len {
                    return Err(Error::InputTooShort)
                }
                // SAFETY: length checked above
                let len = unsafe { buf.get_unchecked(..len_of_len) };
                buf.advance(len_of_len);
                // `len` holds 1..=8 big-endian bytes. `static_left_pad` presumably
                // left-pads them to the 8 bytes `u64::from_be_bytes` needs. It can
                // fail; NOTE(review): see its definition for the exact error cases.
                let len = u64::from_be_bytes(static_left_pad(len)?);
                // Can only fail on targets where `usize` is narrower than 64 bits.
                payload_length =
                    usize::try_from(len).map_err(|_| Error::Custom("Input too big"))?;
                // Lengths below 56 must use the one-byte short form. `encode` below
                // also emits the short form for them.
                if payload_length < 56 {
                    return Err(Error::NonCanonicalSize)
                }
            }
            // Short list: the prefix byte itself carries the payload length.
            b @ EMPTY_LIST_CODE..=0xF7 => {
                buf.advance(1);
                list = true;
                payload_length = (b - EMPTY_LIST_CODE) as usize;
            }
        }
        // Reject headers that announce more payload than is actually available.
        // `decode_bytes` relies on this check for its unchecked slicing.
        if buf.remaining() < payload_length {
            return Err(Error::InputTooShort)
        }
        Ok(Self {
            list,
            payload_length,
        })
    }
    /// Decodes the next payload from the given buffer, advancing it.
    ///
    /// Reads a header with [`Header::decode`] and checks that its kind (list or
    /// string) matches `is_list`. Returns the payload as a sub-slice of the input
    /// and advances `buf` past both the header and the payload. When the item is
    /// a single byte in `0x00..=0x7F` and `is_list` is `false`, the returned slice
    /// is that one byte.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is too short or the header is invalid
    /// (see [`Header::decode`]). Also returns [`Error::UnexpectedString`] if a
    /// list was expected but a string was found, and [`Error::UnexpectedList`]
    /// if a string was expected but a list was found.
    #[inline]
    pub fn decode_bytes<'a>(buf: &mut &'a [u8], is_list: bool) -> Result<&'a [u8]> {
        let Self {
            list,
            payload_length,
        } = Self::decode(buf)?;
        if list != is_list {
            // Name the kind that was actually found, not the one that was expected.
            return Err(if is_list {
                Error::UnexpectedString
            } else {
                Error::UnexpectedList
            })
        }
        // SAFETY: this is already checked in `decode`
        if buf.remaining() < payload_length {
            unsafe { unreachable_unchecked() }
        }
        // SAFETY: `payload_length <= buf.len()` by the invariant stated just above.
        let bytes = unsafe { buf.get_unchecked(..payload_length) };
        buf.advance(payload_length);
        Ok(bytes)
    }
    /// Decodes a string slice from the given buffer, advancing it.
    ///
    /// The item must be a string (not a list), and its payload must be valid
    /// UTF-8. The returned `&str` borrows from the input buffer. `buf` is advanced
    /// past the item before UTF-8 validation runs, so it stays advanced even when
    /// validation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is too short or the header is invalid. Also
    /// returns [`Error::UnexpectedList`] if the item is a list, and
    /// `Error::Custom("invalid string")` if the payload is not valid UTF-8.
    #[inline]
    pub fn decode_str<'a>(buf: &mut &'a [u8]) -> Result<&'a str> {
        let bytes = Self::decode_bytes(buf, false)?;
        core::str::from_utf8(bytes).map_err(|_| Error::Custom("invalid string"))
    }
    /// Encodes the header into the `out` buffer.
    ///
    /// Payloads shorter than 56 bytes use a one-byte short form:
    /// `EMPTY_LIST_CODE` or `EMPTY_STRING_CODE` plus the length. Longer payloads
    /// write `0xF7` or `0xB7` plus the number of length bytes, followed by those
    /// length bytes. They come from `to_be_bytes_trimmed!`, which by its name
    /// presumably yields big-endian bytes with leading zeros removed; confirm.
    ///
    /// Only the header is written, not the payload. A prefix byte is always
    /// written, including for a 1-byte payload that `decode` requires to be
    /// encoded bare. NOTE(review): callers presumably handle that case themselves;
    /// confirm.
    #[inline]
    pub fn encode(&self, out: &mut dyn BufMut) {
        if self.payload_length < 56 {
            let code = if self.list {
                EMPTY_LIST_CODE
            } else {
                EMPTY_STRING_CODE
            };
            out.put_u8(code + self.payload_length as u8);
        } else {
            // Backing storage handed to the macro, which then shadows `len_be` with
            // what is presumably a byte slice borrowing from it.
            let len_be;
            let len_be = crate::encode::to_be_bytes_trimmed!(len_be, self.payload_length);
            let code = if self.list { 0xF7 } else { 0xB7 };
            out.put_u8(code + len_be.len() as u8);
            out.put_slice(len_be);
        }
    }
    /// Returns the length of the encoded header.
    ///
    /// Delegates to `crate::length_of_length` with `payload_length`; the payload
    /// bytes themselves are not counted. NOTE(review): presumably equal to the
    /// number of bytes `encode` writes; confirm.
    #[inline]
    pub const fn length(&self) -> usize {
        crate::length_of_length(self.payload_length)
    }
}
/// Same as `buf.first().ok_or(Error::InputTooShort)`.
#[inline(always)]
fn get_next_byte(buf: &[u8]) -> Result<u8> {
if buf.is_empty() {
return Err(Error::InputTooShort)
}
// SAFETY: length checked above
Ok(*unsafe { buf.get_unchecked(0) })
}