use crate::{
decoder::Event,
error::{self, Error},
field::WireType,
verihash::{self, DigestOutput},
};
use core::fmt::{self, Debug};
use digest::Digest;
/// Hasher for a sequence of decoder events.
///
/// Wraps a [`verihash::Hasher`] together with a small state machine that
/// checks events arrive in an order the hasher accepts.
pub(super) struct Hasher<D>
where
    D: Digest,
{
    /// Underlying hasher that receives the data to be digested.
    verihash: verihash::Hasher<D>,

    /// Current position in the state machine. This is `None` once a state
    /// transition has failed; every later `hash_event` call is then rejected.
    state: Option<State>,
}
impl<D> Hasher<D>
where
D: Digest,
{
pub fn new(wire_type: WireType) -> Self {
let mut verihash = verihash::Hasher::new();
verihash.input(&[wire_type.to_u8()]);
Self {
verihash,
state: Some(State::default()),
}
}
pub fn hash_event(&mut self, event: &Event<'_>) -> Result<(), Error> {
if let Some(state) = self.state.take() {
let new_state = state.transition(event, &mut self.verihash)?;
self.state = Some(new_state);
Ok(())
} else {
Err(error::Kind::Failed.into())
}
}
pub fn hash_message_digest(&mut self, digest: &DigestOutput<D>) -> Result<(), Error> {
match self.state {
Some(State::Message { remaining }) if remaining == 0 => {
self.verihash.input(digest);
self.state = Some(State::Initial);
Ok(())
}
_ => Err(error::Kind::Hashing.into()),
}
}
pub fn finish(self) -> Result<DigestOutput<D>, Error> {
if self.state == Some(State::Initial) {
Ok(self.verihash.finish())
} else {
Err(error::Kind::Hashing.into())
}
}
}
impl<D: Digest> Debug for Hasher<D> {
    /// Formats the hasher as an empty `sequence::Hasher` struct; no fields
    /// are included in the output.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("sequence::Hasher");
        repr.finish()
    }
}
/// States of the sequence hasher's state machine.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// No value is in progress; the next event starts a new value.
    Initial,

    /// Inside a bytes value.
    Bytes {
        /// Number of bytes of the value not yet received.
        remaining: usize,
    },

    /// Inside a string value.
    String {
        /// Number of bytes of the value not yet received.
        remaining: usize,
    },

    /// Inside a nested message.
    Message {
        /// Number of bytes of the message not yet received.
        remaining: usize,
    },
}
impl Default for State {
fn default() -> Self {
State::Initial
}
}
impl State {
pub fn transition<D: Digest>(
self,
event: &Event<'_>,
verihash: &mut verihash::Hasher<D>,
) -> Result<Self, Error> {
match event {
Event::LengthDelimiter { wire_type, length } => {
self.handle_length_delimiter(*wire_type, *length, verihash)
}
Event::UInt64(_) | Event::SInt64(_) => self.handle_fixed_sized_value(event, verihash),
Event::ValueChunk {
wire_type,
bytes,
remaining,
} => self.handle_value_chunk(*wire_type, bytes, *remaining, verihash),
_ => Err(error::Kind::Hashing.into()),
}
}
fn handle_length_delimiter<D: Digest>(
self,
wire_type: WireType,
length: usize,
verihash: &mut verihash::Hasher<D>,
) -> Result<Self, Error> {
if self != State::Initial {
return Err(error::Kind::Hashing.into());
}
let new_state = match wire_type {
WireType::Bytes => State::Bytes { remaining: length },
WireType::String => State::String { remaining: length },
WireType::Message => State::Message { remaining: length },
_ => unreachable!(),
};
verihash.dynamically_sized_value(wire_type, length);
Ok(new_state)
}
fn handle_fixed_sized_value<D: Digest>(
self,
value: &Event<'_>,
verihash: &mut verihash::Hasher<D>,
) -> Result<Self, Error> {
if self != State::Initial {
return Err(error::Kind::Hashing.into());
}
match value {
Event::UInt64(value) => {
verihash.fixed_size_value(WireType::UInt64, &value.to_le_bytes())
}
Event::SInt64(value) => {
verihash.fixed_size_value(WireType::SInt64, &value.to_le_bytes())
}
_ => unreachable!(),
}
Ok(State::Initial)
}
fn handle_value_chunk<D: Digest>(
self,
wire_type: WireType,
bytes: &[u8],
new_remaining: usize,
verihash: &mut verihash::Hasher<D>,
) -> Result<Self, Error> {
let new_state = match self {
State::Bytes { remaining } => {
if wire_type != WireType::Bytes || remaining - bytes.len() != new_remaining {
return Err(error::Kind::Hashing.into());
}
if new_remaining == 0 {
State::Initial
} else {
State::Bytes {
remaining: new_remaining,
}
}
}
State::String { remaining } => {
if wire_type != WireType::String || remaining - bytes.len() != new_remaining {
return Err(error::Kind::Hashing.into());
}
if new_remaining == 0 {
State::Initial
} else {
State::String {
remaining: new_remaining,
}
}
}
State::Message { remaining } => {
if wire_type != WireType::Message || remaining - bytes.len() != new_remaining {
return Err(error::Kind::Hashing.into());
}
return Ok(State::Message {
remaining: new_remaining,
});
}
_ => return Err(error::Kind::Hashing.into()),
};
verihash.input(bytes);
Ok(new_state)
}
}