use super::{hasher::Hasher, state::State};
use crate::{
decoder::{vint64, Decodable, Event},
error::{self, Error},
field::WireType,
message::Element,
verihash::DigestOutput,
};
use digest::Digest;
/// Decoder for a sequence of values that share a single wire type.
///
/// Keeps track of how much of the sequence is left to consume and, while a
/// hasher is present, feeds every decoded event into it.
pub(crate) struct Decoder<D>
where
    D: Digest,
{
    /// Wire type handed to the state decoder and to the hasher.
    wire_type: WireType,

    /// Total length of the sequence, as supplied to [`Decoder::new`].
    length: usize,

    /// Bytes of the sequence not yet consumed; starts out equal to `length`.
    remaining: usize,

    /// Current decoding state.
    state: State,

    /// Hasher fed with each decoded event; hashing is skipped when `None`.
    hasher: Option<Hasher<D>>,
}
impl<D> Decoder<D>
where
    D: Digest,
{
    /// Create a decoder for a sequence of `length` bytes whose values use
    /// `wire_type`. The decoder starts in the default state with a hasher
    /// installed.
    pub fn new(wire_type: WireType, length: usize) -> Self {
        let state = State::default();
        let hasher = Some(Hasher::new(wire_type));

        Self {
            wire_type,
            length,
            remaining: length,
            state,
            hasher,
        }
    }

    /// Number of bytes of the sequence consumed so far
    /// (`length - remaining`).
    ///
    /// # Panics
    ///
    /// Panics if `remaining` is larger than `length`.
    pub fn position(&self) -> usize {
        let consumed = self.length.checked_sub(self.remaining);
        consumed.unwrap()
    }

    /// Number of bytes of the sequence that have not been consumed yet.
    pub fn remaining(&self) -> usize {
        self.remaining
    }

    /// Switch to the state that follows `event`.
    ///
    /// # Panics
    ///
    /// Panics on any event other than the ones matched below.
    fn transition(&mut self, event: &Event<'_>) {
        let next = match event {
            // A header is followed by a body of `length` bytes.
            Event::LengthDelimiter { wire_type, length }
            | Event::SequenceHeader { wire_type, length } => State::Body {
                wire_type: *wire_type,
                remaining: *length,
            },
            // After an integer event, continue with a fresh vint64 decoder.
            Event::UInt64(_) | Event::SInt64(_) => State::Value(vint64::Decoder::new()),
            // Stay in the body until the chunk reports nothing left over.
            Event::ValueChunk {
                wire_type,
                remaining,
                ..
            } => match *remaining {
                0 => State::default(),
                left => State::Body {
                    wire_type: *wire_type,
                    remaining: left,
                },
            },
            other => unreachable!("unexpected event: {:?}", other),
        };

        self.state = next;
    }

    /// Feed the digest of a nested message into the hasher, if one is
    /// present; does nothing otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the hasher.
    pub fn hash_message_digest(&mut self, digest: &DigestOutput<D>) -> Result<(), Error> {
        match self.hasher.as_mut() {
            Some(hasher) => {
                hasher.hash_message_digest(digest)?;
                Ok(())
            }
            None => Ok(()),
        }
    }

    /// Consume the decoder and finish the hasher.
    ///
    /// Returns `Ok(None)` when no hasher is present.
    ///
    /// # Errors
    ///
    /// Propagates any error reported while finishing the hasher.
    pub fn compute_digest(self) -> Result<Option<DigestOutput<D>>, Error> {
        match self.hasher {
            Some(hasher) => hasher.finish().map(Some),
            None => Ok(None),
        }
    }
}
impl<D> Decodable for Decoder<D>
where
    D: Digest,
{
    /// Decode the next event from `input`, advancing `input` past whatever
    /// was consumed.
    ///
    /// The consumed byte count is subtracted from `remaining`. When an event
    /// is produced it is hashed (if a hasher is present) and the state is
    /// advanced. Returns `Ok(None)` when the state decoder yields no event.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state decoder and from the hasher.
    ///
    /// # Panics
    ///
    /// Panics if the state decoder consumes more bytes than remain in the
    /// sequence.
    fn decode<'a>(&mut self, input: &mut &'a [u8]) -> Result<Option<Event<'a>>, Error> {
        let orig_input_len = input.len();
        let maybe_event = self.state.decode(self.wire_type, input)?;

        let consumed = orig_input_len
            .checked_sub(input.len())
            .expect("input slice grew while decoding");

        // NOTE(review): whether `input` may extend past the end of the
        // sequence (making this reachable from malformed data) is not visible
        // from this file -- confirm against the callers.
        self.remaining = self
            .remaining
            .checked_sub(consumed)
            .expect("consumed more input than remains in the sequence");

        if let Some(event) = &maybe_event {
            if let Some(hasher) = &mut self.hasher {
                hasher.hash_event(event)?;
            }

            // `event` is already a reference; no extra borrow needed.
            self.transition(event);
        }

        Ok(maybe_event)
    }

    /// Decode one length-delimited value of `expected_type`: first its
    /// length delimiter, then its body, which must arrive as a single chunk.
    ///
    /// # Errors
    ///
    /// - `UnexpectedWireType` if `expected_type` differs from this decoder's
    ///   wire type.
    /// - `Decode` (element `LengthDelimiter`, with the current position) if
    ///   the first event is not a length delimiter.
    /// - `Truncated` if the value chunk reports bytes still remaining.
    /// - `Decode` (element `Value`) if the second event is not a value chunk.
    /// - Any error propagated from [`Decodable::decode`].
    fn decode_dynamically_sized_value<'a>(
        &mut self,
        expected_type: WireType,
        input: &mut &'a [u8],
    ) -> Result<&'a [u8], Error> {
        if expected_type != self.wire_type {
            return Err(error::Kind::UnexpectedWireType {
                actual: self.wire_type,
                wanted: expected_type,
            }
            .into());
        }

        debug_assert!(
            self.wire_type.is_dynamically_sized(),
            "not a dynamically sized wire type: {:?}",
            self.wire_type
        );

        let length = match self.decode(input)? {
            Some(Event::LengthDelimiter { length, .. }) => Ok(length),
            _ => Err(error::Kind::Decode {
                element: Element::LengthDelimiter,
                wire_type: self.wire_type,
            }
            // Reuse the shared helper rather than repeating its arithmetic.
            .position(self.position())),
        }?;

        match self.decode(input)? {
            Some(Event::ValueChunk {
                bytes, remaining, ..
            }) => {
                if remaining == 0 {
                    // A complete chunk must match the announced length.
                    debug_assert_eq!(length, bytes.len());
                    Ok(bytes)
                } else {
                    Err(error::Kind::Truncated {
                        remaining,
                        wire_type: self.wire_type,
                    }
                    .into())
                }
            }
            _ => Err(error::Kind::Decode {
                element: Element::Value,
                wire_type: self.wire_type,
            }
            .into()),
        }
    }
}
// Cargo exposes enabled features through the `feature` cfg key; with the
// misspelled `features` key this module was never compiled.
#[cfg(all(test, feature = "sha2"))]
mod tests {
    use super::{Decodable, Decoder, WireType};
    use sha2::Sha256;

    /// Three single-byte encodings are expected to decode as 1, 2, 3 and to
    /// consume the whole input.
    #[test]
    fn decode_uint64_sequence() {
        let input = [3, 5, 7];
        let mut input_ref = &input[..];
        let mut decoder: Decoder<Sha256> = Decoder::new(WireType::UInt64, input.len());

        assert_eq!(1, decoder.decode_uint64(&mut input_ref).unwrap());
        assert_eq!(2, decoder.decode_uint64(&mut input_ref).unwrap());
        assert_eq!(3, decoder.decode_uint64(&mut input_ref).unwrap());
        assert!(input_ref.is_empty());
    }

    /// Three single-byte encodings are expected to decode as -1, -2, -3 and
    /// to consume the whole input.
    #[test]
    fn decode_sint64_sequence() {
        let input = [3, 7, 11];
        let mut input_ref = &input[..];
        let mut decoder: Decoder<Sha256> = Decoder::new(WireType::SInt64, input.len());

        for n in &[-1, -2, -3] {
            assert_eq!(*n, decoder.decode_sint64(&mut input_ref).unwrap());
        }

        assert!(input_ref.is_empty());
    }

    /// Each value is one prefix byte followed by three payload bytes; the
    /// payloads are expected back as `foo`, `bar`, `baz`.
    #[test]
    fn decode_bytes_sequence() {
        let input = [7, 102, 111, 111, 7, 98, 97, 114, 7, 98, 97, 122];
        let mut input_ref = &input[..];
        let mut decoder: Decoder<Sha256> = Decoder::new(WireType::Bytes, input.len());

        for &b in &[b"foo", b"bar", b"baz"] {
            assert_eq!(b, decoder.decode_bytes(&mut input_ref).unwrap());
        }

        assert!(input_ref.is_empty());
    }

    /// Same input as the bytes test, decoded as strings.
    #[test]
    fn decode_string_sequence() {
        let input = [7, 102, 111, 111, 7, 98, 97, 114, 7, 98, 97, 122];
        let mut input_ref = &input[..];
        let mut decoder: Decoder<Sha256> = Decoder::new(WireType::String, input.len());

        for &s in &["foo", "bar", "baz"] {
            assert_eq!(s, decoder.decode_string(&mut input_ref).unwrap());
        }

        assert!(input_ref.is_empty());
    }
}