use std::{any::type_name, iter};
use anyhow::bail;
use either::Either;
use ethereum_types::U256;
use nunny::NonEmpty;
use serde::de::DeserializeOwned;
use u4::{U4x2, U4};
use winnow::{
combinator::{empty, eof, fail, preceded, repeat_till, trace},
error::{ErrorKind, FromExternalError, StrContext},
stream::Stream,
token::{any, one_of, take},
Parser as _,
};
/// Parse a complete witness.
///
/// The input must be one header byte (accepted only if it is `0` or `1`)
/// followed by one or more [`Instruction`]s that together consume every
/// remaining byte.
///
/// # Errors
/// Fails with a message containing the byte offset and the parser's error
/// context if the header byte is rejected, if no instruction follows it, or if
/// any later bytes cannot be decoded as an instruction.
pub fn parse(input: &[u8]) -> anyhow::Result<NonEmpty<Vec<Instruction>>> {
    let outcome = preceded(
        one_of((0u8, 1u8)),
        repeat_till(1.., instruction, eof).map(|(instructions, _end)| {
            NonEmpty::<Vec<_>>::new(instructions)
                .expect("repeat_till should ensure non-empty collection")
        }),
    )
    .parse(input);

    match outcome {
        Err(failure) => bail!(
            "parse error at offset {}: {}",
            failure.offset(),
            failure.inner()
        ),
        Ok(instructions) => Ok(instructions),
    }
}
/// A single decoded witness instruction.
///
/// Each variant is produced by the module's `instruction` parser from a
/// one-byte opcode (noted on each variant) followed by that variant's payload.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Instruction {
    /// Opcode `0x00`: a key followed by a CBOR-encoded byte-string value.
    Leaf {
        /// Nibbles decoded from a CBOR byte string.
        key: NonEmpty<Vec<U4>>,
        value: NonEmpty<Vec<u8>>,
    },
    /// Opcode `0x01`: a key.
    Extension {
        /// Nibbles decoded from a CBOR byte string.
        key: NonEmpty<Vec<U4>>,
    },
    /// Opcode `0x02`: a CBOR-encoded `u64`.
    // NOTE(review): presumably a bitmask of populated children — confirm.
    Branch {
        mask: u64,
    },
    /// Opcode `0x03`: exactly 32 raw bytes (not CBOR-encoded).
    Hash {
        raw_hash: [u8; 32],
    },
    /// Opcode `0x04`: a CBOR-encoded byte string.
    Code {
        raw_code: NonEmpty<Vec<u8>>,
    },
    /// Opcode `0x05`: a key, a flags byte, then the optional fields the flags
    /// select.
    AccountLeaf {
        /// Nibbles decoded from a CBOR byte string.
        key: NonEmpty<Vec<U4>>,
        /// `Some` only when the flags byte announced a nonce.
        nonce: Option<u64>,
        /// `Some` only when the flags byte announced a balance; decoded from
        /// big-endian bytes.
        balance: Option<U256>,
        /// Whether the code flag was set. The code length that accompanies it
        /// on the wire is consumed by the parser but not stored here.
        has_code: bool,
        /// Whether the storage flag was set (no extra payload is read for it).
        has_storage: bool,
    },
    /// Opcode `0x07`: see [`SmtLeaf`].
    SmtLeaf(SmtLeaf),
    /// Opcode `0x06`: no payload.
    EmptyRoot,
    /// Opcode `0xBB`: no payload.
    NewTrie,
}
/// Payload of [`Instruction::SmtLeaf`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmtLeaf {
    /// Leaf kind, taken from the byte that follows the opcode.
    pub node_type: SmtLeafType,
    /// CBOR-encoded byte string read right after the node-type byte.
    // NOTE(review): presumably an account address — confirm against the producer.
    pub address: NonEmpty<Vec<u8>>,
    /// CBOR-encoded byte string read last, after any type-specific payload.
    pub value: NonEmpty<Vec<u8>>,
}
/// Kind of an [`SmtLeaf`], encoded on the wire as a single byte `0`..=`4`
/// (in declaration order below). Any other byte is a parse error.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SmtLeafType {
    /// Byte `0`.
    Balance,
    /// Byte `1`.
    Nonce,
    /// Byte `2`.
    Code,
    /// Byte `3`. Carries an extra CBOR-encoded byte string that is read after
    /// the leaf's address.
    // NOTE(review): presumably the storage slot key — confirm.
    Storage(NonEmpty<Vec<u8>>),
    /// Byte `4`.
    CodeLength,
}
/// Result type shared by every parser in this module: winnow's `PResult`
/// specialised to winnow's `ContextError`.
type PResult<T> = winnow::PResult<T, winnow::error::ContextError>;
fn instruction(input: &mut &[u8]) -> PResult<Instruction> {
let start = input.checkpoint();
let opcode = any(input)?;
match opcode {
0x00 => trace(
"leaf",
(key, cbor).map(|(key, value)| Instruction::Leaf { key, value }),
)
.parse_next(input),
0x01 => trace("extension", key.map(|key| Instruction::Extension { key })).parse_next(input),
0x02 => trace("branch", cbor.map(|mask| Instruction::Branch { mask })).parse_next(input),
0x03 => {
trace("hash", array.map(|raw_hash| Instruction::Hash { raw_hash })).parse_next(input)
}
0x04 => {
trace("code", cbor.map(|raw_code| Instruction::Code { raw_code })).parse_next(input)
}
0x05 => trace("account_leaf", account_leaf).parse_next(input),
0x06 => trace("empty_root", empty.value(Instruction::EmptyRoot)).parse_next(input),
0x07 => trace("smt_leaf", smt_leaf).parse_next(input),
0xBB => trace("new_trie", empty.value(Instruction::NewTrie)).parse_next(input),
_ => {
input.reset(&start);
fail.context(StrContext::Label("unrecognised opcode"))
.parse_next(input)
}
}
}
#[derive(thiserror::Error, Debug)]
#[error("{}", .0)]
struct Error(&'static str);
fn account_leaf(input: &mut &[u8]) -> PResult<Instruction> {
bitflags::bitflags! {
struct AccountLeafFlags: u8 {
const HAS_CODE = 0b0000_0001;
const HAS_STORAGE = 0b0000_0010;
const ENCODES_NONCE = 0b0000_0100;
const ENCODES_BALANCE = 0b0000_1000;
}
}
let key = key(input)?;
let flags = any
.try_map(|byte| {
AccountLeafFlags::from_bits(byte)
.ok_or(Error("unrecognised bits in flags for account leaf"))
})
.parse_next(input)?;
Ok(Instruction::AccountLeaf {
key,
nonce: match flags.contains(AccountLeafFlags::ENCODES_NONCE) {
true => Some(trace("nonce", cbor).parse_next(input)?),
false => None,
},
balance: match flags.contains(AccountLeafFlags::ENCODES_BALANCE) {
true => Some(
trace(
"balance",
cbor::<Vec<u8>>.map(|bytes| U256::from_big_endian(&bytes)),
)
.parse_next(input)?,
),
false => None,
},
has_storage: flags.contains(AccountLeafFlags::HAS_STORAGE),
has_code: {
let has_code = flags.contains(AccountLeafFlags::HAS_CODE);
if has_code {
trace("code_length", cbor::<u64>).parse_next(input)?;
}
has_code
},
})
}
fn smt_leaf(input: &mut &[u8]) -> PResult<Instruction> {
let start = input.checkpoint();
let node_type = any(input)?;
Ok(Instruction::SmtLeaf(SmtLeaf {
address: cbor(input)?,
node_type: match node_type {
0 => SmtLeafType::Balance,
1 => SmtLeafType::Nonce,
2 => SmtLeafType::Code,
3 => SmtLeafType::Storage(cbor(input)?),
4 => SmtLeafType::CodeLength,
_ => {
input.reset(&start);
fail.context(StrContext::Label("unrecognised leaf node type"))
.parse_next(input)?
}
},
value: cbor(input)?,
}))
}
/// Parse a key: a non-empty CBOR byte string, decoded into nibbles by
/// `decode_key`. Decoding failures surface as parse errors.
fn key(input: &mut &[u8]) -> PResult<NonEmpty<Vec<U4>>> {
    let encoded_key = cbor.try_map(|bytes: NonEmpty<Vec<u8>>| decode_key(&bytes));
    trace("key", encoded_key).parse_next(input)
}
/// Deserialize a single CBOR data item of type `T` from the front of `input`.
///
/// On success, `input` is advanced past the bytes `ciborium` consumed. On
/// failure, `input` is rewound to where it started, and the `ciborium` error is
/// attached to the returned error as an external error with
/// [`ErrorKind::Verify`].
///
/// NOTE(review): `ciborium::from_reader` decodes through a fixed-size scratch
/// buffer. Confirm that byte strings longer than that buffer (e.g. large
/// contract code) still deserialize into `Vec<u8>`.
fn cbor<T: DeserializeOwned + std::fmt::Debug>(input: &mut &[u8]) -> PResult<T> {
    trace(
        // `format_args!` defers formatting the trace label until it is actually
        // displayed, instead of allocating a fresh `String` with `format!` on
        // every call of this hot-path parser.
        format_args!("cbor{{{}}}", type_name::<T>()),
        |input: &mut &[u8]| {
            let start = input.checkpoint();
            match ciborium::from_reader::<T, _>(&mut *input) {
                Ok(it) => Ok(it),
                Err(e) => {
                    // Undo any partial read so callers observe no consumption
                    // on failure.
                    input.reset(&start);
                    Err(FromExternalError::from_external_error(
                        input,
                        ErrorKind::Verify,
                        e,
                    ))
                }
            }
        },
    )
    .parse_next(input)
}
fn decode_key(bytes: &NonEmpty<[u8]>) -> Result<NonEmpty<Vec<U4>>, Error> {
bitflags::bitflags! {
struct EncodeKeyFlags: u8 {
const ODD = 0b0000_0001;
const TERMINATED = 0b0000_0010;
}
}
let v = match bytes.split_first() {
(only, &[]) => nunny::vec![U4::new(*only).ok_or(Error("excess bits in single nibble"))?],
(flags, rest) => {
let flags = EncodeKeyFlags::from_bits(*flags)
.ok_or(Error("unrecognised bits in flags for key encoding"))?;
let (last, rest) = match rest.split_last() {
Some((last, rest)) => (Some(*last), rest),
None => (None, &[][..]),
};
let tail = match (last, flags.contains(EncodeKeyFlags::ODD)) {
(None, true) => return Err(Error("span was zero but odd flag was set")),
(None, false) => Either::Left(iter::empty::<U4>()),
(Some(left), true) => Either::Right(Either::Left(iter::once(
U4::new(left.rotate_right(4)).ok_or(Error("unrecognised bits in nibble"))?,
))),
(Some(packed), false) => {
let both = U4x2 { packed };
Either::Right(Either::Right([both.left(), both.right()].into_iter()))
}
};
nunny::Vec::new(
rest.iter()
.copied()
.flat_map(|packed| {
let both = U4x2 { packed };
[both.left(), both.right()]
})
.chain(tail)
.collect(),
)
.expect("an empty `rest` must be caught by the Erigon special case")
}
};
Ok(v)
}
/// Take exactly `N` raw bytes (no CBOR framing) from the front of `input` as an
/// array. Fails if fewer than `N` bytes remain.
fn array<const N: usize>(input: &mut &[u8]) -> PResult<[u8; N]> {
    let into_array = |bytes: &[u8]| -> [u8; N] {
        <[u8; N]>::try_from(bytes).expect("take has already selected N bytes")
    };
    take(N).map(into_array).parse_next(input)
}
/// Test helper: run `parser` over `src` with winnow's `Parser::parse` and
/// assert that it yields `expected`.
///
/// Panics if parsing fails or the output differs. `#[track_caller]` makes the
/// panic report the calling test's line.
#[cfg(test)]
#[track_caller]
fn do_test<'a, T: PartialEq + core::fmt::Debug>(
    src: &'a [u8],
    expected: T,
    mut parser: impl winnow::Parser<&'a [u8], T, winnow::error::ContextError>,
) {
    let outcome = parser.parse(src);
    assert_eq!(expected, outcome.unwrap())
}
/// `cbor` decodes CBOR items whose header uses the inline argument, one-byte
/// and two-byte arguments, a negative integer, and a definite-length byte
/// string (the shape that keys and values are read from).
#[test]
fn cbor_test_cases() {
    // Unsigned integers 0..=23 are encoded inline in the initial byte.
    do_test(b"\x00", 0, cbor);
    do_test(b"\x01", 1, cbor);
    do_test(b"\x0a", 10, cbor);
    do_test(b"\x17", 23, cbor);
    // 24..=255 need a one-byte argument (additional info 24).
    do_test(b"\x18\x18", 24, cbor);
    do_test(b"\x18\xff", 255, cbor);
    // 256 needs a two-byte big-endian argument (additional info 25).
    do_test(b"\x19\x01\x00", 256, cbor);
    // Major type 1 encodes -1 - n; 0x20 is -1.
    do_test(b"\x20", -1, cbor);
    // Major type 2: a byte string of length 3.
    do_test(b"\x43\x01\x02\x03", vec![1u8, 2, 3], cbor);
}