#![allow(clippy::integer_arithmetic)]
use bytes::Bytes;
use hex::FromHex;
use more_asserts::{debug_assert_ge, debug_assert_lt};
use once_cell::sync::{Lazy, OnceCell};
#[cfg(any(test, feature = "fuzzing"))]
use proptest_derive::Arbitrary;
use rand::{rngs::OsRng, Rng};
use serde::{de, ser, Deserialize, Serialize};
use std::{
self,
convert::{AsRef, TryFrom},
fmt,
str::FromStr,
};
use tiny_keccak::{Hasher, Sha3};
/// Byte string prepended to a salt before it is hashed by
/// `DefaultHasher::prefixed_hash`.
pub(crate) const HASH_PREFIX: &[u8] = b"APTOS::";
/// A fixed-size hash output of `HashValue::LENGTH` bytes.
///
/// The type is `Copy` and derives equality, ordering and `Hash` from the
/// underlying byte array. An `Arbitrary` implementation is derived only for
/// tests and the `fuzzing` feature.
#[derive(Clone, Copy, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[cfg_attr(any(test, feature = "fuzzing"), derive(Arbitrary))]
pub struct HashValue {
// Raw bytes of the hash, most significant byte first for bit/nibble access.
hash: [u8; HashValue::LENGTH],
}
/// Constructors, accessors and conversions for [`HashValue`].
impl HashValue {
    /// Length of a hash in bytes.
    pub const LENGTH: usize = 32;
    /// Length of a hash in bits.
    pub const LENGTH_IN_BITS: usize = Self::LENGTH * 8;

    /// Wraps an existing byte array as a `HashValue`; nothing is hashed.
    pub fn new(hash: [u8; HashValue::LENGTH]) -> Self {
        Self { hash }
    }

    /// Builds a `HashValue` from a byte slice.
    ///
    /// # Errors
    /// Returns [`HashValueParseError`] unless the slice is exactly
    /// [`HashValue::LENGTH`] bytes long.
    pub fn from_slice<T: AsRef<[u8]>>(bytes: T) -> Result<Self, HashValueParseError> {
        match <[u8; Self::LENGTH]>::try_from(bytes.as_ref()) {
            Ok(array) => Ok(Self::new(array)),
            Err(_) => Err(HashValueParseError),
        }
    }

    /// Copies the hash bytes into a freshly allocated `Vec<u8>`.
    pub fn to_vec(&self) -> Vec<u8> {
        Vec::from(&self.hash[..])
    }

    /// Returns the hash whose bytes are all zero.
    pub const fn zero() -> Self {
        Self {
            hash: [0u8; Self::LENGTH],
        }
    }

    /// Returns a hash filled with bytes drawn from `OsRng`.
    pub fn random() -> Self {
        Self::random_with_rng(&mut OsRng)
    }

    /// Returns a hash filled with bytes drawn from the supplied generator.
    pub fn random_with_rng<R: Rng>(rng: &mut R) -> Self {
        Self { hash: rng.gen() }
    }

    /// Computes the SHA3-256 digest of `buffer`.
    pub fn sha3_256_of(buffer: &[u8]) -> Self {
        let mut state = Sha3::v256();
        state.update(buffer);
        Self::from_keccak(state)
    }

    /// Computes the SHA3-256 digest of all `buffers` fed in order (test only).
    #[cfg(test)]
    pub fn from_iter_sha3<'a, I>(buffers: I) -> Self
    where
        I: IntoIterator<Item = &'a [u8]>,
    {
        let mut state = Sha3::v256();
        buffers.into_iter().for_each(|chunk| state.update(chunk));
        Self::from_keccak(state)
    }

    /// Mutable view of the underlying bytes, used as a hasher output buffer.
    fn as_ref_mut(&mut self) -> &mut [u8] {
        &mut self.hash
    }

    /// Finalizes a SHA3 state and wraps the resulting digest.
    fn from_keccak(state: Sha3) -> Self {
        let mut digest = [0u8; Self::LENGTH];
        state.finalize(&mut digest);
        Self::new(digest)
    }

    /// Returns the `index`-th bit, counting from the most significant bit of
    /// the first byte.
    ///
    /// # Panics
    /// Panics if `index` is not below [`HashValue::LENGTH_IN_BITS`].
    pub fn bit(&self, index: usize) -> bool {
        debug_assert!(index < Self::LENGTH_IN_BITS);
        // 0x80 selects the most significant bit; shift it right to the target.
        let mask = 0x80u8 >> (index % 8);
        (self.hash[index / 8] & mask) != 0
    }

    /// Returns the `index`-th 4-bit nibble; even indices are the high half of
    /// a byte, odd indices the low half.
    ///
    /// # Panics
    /// Panics if `index` is not below `2 * HashValue::LENGTH`.
    pub fn nibble(&self, index: usize) -> u8 {
        debug_assert!(index < Self::LENGTH * 2);
        let byte = self.hash[index / 2];
        if index % 2 == 0 {
            byte >> 4
        } else {
            byte & 0x0f
        }
    }

    /// Returns an iterator over the bits of the hash, most significant first.
    pub fn iter_bits(&self) -> HashValueBitIterator<'_> {
        HashValueBitIterator::new(self)
    }

    /// Assembles a `HashValue` from exactly [`HashValue::LENGTH_IN_BITS`]
    /// bits, most significant bit first.
    ///
    /// # Errors
    /// Returns [`HashValueParseError`] when the iterator reports any other
    /// length.
    pub fn from_bit_iter(
        iter: impl ExactSizeIterator<Item = bool>,
    ) -> Result<Self, HashValueParseError> {
        if iter.len() != Self::LENGTH_IN_BITS {
            return Err(HashValueParseError);
        }
        let mut bytes = [0u8; Self::LENGTH];
        // Only set bits need work; the buffer already holds zeros.
        for (position, _) in iter.enumerate().filter(|(_, is_set)| *is_set) {
            bytes[position / 8] |= 0x80 >> (position % 8);
        }
        Ok(Self::new(bytes))
    }

    /// Returns how many leading bits `self` and `other` have in common.
    pub fn common_prefix_bits_len(&self, other: HashValue) -> usize {
        // Locate the first differing byte, then count the equal leading bits
        // inside it via the XOR of the two bytes.
        self.hash
            .iter()
            .zip(other.hash.iter())
            .position(|(mine, theirs)| mine != theirs)
            .map_or(Self::LENGTH_IN_BITS, |byte_index| {
                let diff = self.hash[byte_index] ^ other.hash[byte_index];
                byte_index * 8 + diff.leading_zeros() as usize
            })
    }

    /// Full lowercase hex encoding without a prefix.
    pub fn to_hex(&self) -> String {
        format!("{:x}", self)
    }

    /// Full lowercase hex encoding with a leading `0x`.
    pub fn to_hex_literal(&self) -> String {
        format!("{:#x}", self)
    }

    /// Parses a hex string into a `HashValue`.
    ///
    /// # Errors
    /// Returns [`HashValueParseError`] when hex decoding into a
    /// `[u8; HashValue::LENGTH]` fails.
    pub fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, HashValueParseError> {
        match <[u8; Self::LENGTH]>::from_hex(hex) {
            Ok(decoded) => Ok(Self::new(decoded)),
            Err(_) => Err(HashValueParseError),
        }
    }

    /// Builds a hash whose trailing bytes are the big-endian encoding of `v`
    /// and whose remaining bytes are zero (tests and fuzzing only).
    #[cfg(any(test, feature = "fuzzing"))]
    pub fn from_u64(v: u64) -> Self {
        let encoded = v.to_be_bytes();
        let mut bytes = [0u8; Self::LENGTH];
        let (_, tail) = bytes.split_at_mut(Self::LENGTH - encoded.len());
        tail.copy_from_slice(&encoded);
        Self::new(bytes)
    }
}
/// Serializes as a hex string for human-readable formats and as a
/// single-field struct named `HashValue` otherwise.
impl ser::Serialize for HashValue {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        // Borrowing wrapper that fixes the serialized struct and field names
        // used by the non-human-readable encoding.
        #[derive(Serialize)]
        #[serde(rename = "HashValue")]
        struct Value<'a> {
            hash: &'a [u8; HashValue::LENGTH],
        }

        if !serializer.is_human_readable() {
            return Value { hash: &self.hash }.serialize(serializer);
        }
        serializer.serialize_str(&self.to_hex())
    }
}
impl<'de> de::Deserialize<'de> for HashValue {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
if deserializer.is_human_readable() {
let encoded_hash = <String>::deserialize(deserializer)?;
HashValue::from_hex(encoded_hash.as_str())
.map_err(<D::Error as ::serde::de::Error>::custom)
} else {
#[derive(Deserialize)]
#[serde(rename = "HashValue")]
struct Value {
hash: [u8; HashValue::LENGTH],
}
let value = Value::deserialize(deserializer)
.map_err(<D::Error as ::serde::de::Error>::custom)?;
Ok(Self::new(value.hash))
}
}
}
impl Default for HashValue {
fn default() -> Self {
HashValue::zero()
}
}
/// Borrows the hash as its underlying fixed-size byte array.
impl AsRef<[u8; HashValue::LENGTH]> for HashValue {
    fn as_ref(&self) -> &[u8; HashValue::LENGTH] {
        let Self { hash } = self;
        hash
    }
}
impl std::ops::Deref for HashValue {
type Target = [u8; Self::LENGTH];
fn deref(&self) -> &Self::Target {
&self.hash
}
}
/// Indexes individual bytes of the hash.
///
/// # Panics
/// Panics when the index is out of range for the byte array.
impl std::ops::Index<usize> for HashValue {
    type Output = u8;

    fn index(&self, s: usize) -> &u8 {
        &self.hash[s]
    }
}
/// Formats every byte as eight binary digits, first byte first.
impl fmt::Binary for HashValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.hash
            .iter()
            .try_for_each(|octet| write!(f, "{:08b}", octet))
    }
}
/// Formats the whole hash as lowercase hex; the alternate flag (`{:#x}`)
/// adds a leading `0x`.
impl fmt::LowerHex for HashValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if f.alternate() {
            f.write_str("0x")?;
        }
        self.hash
            .iter()
            .try_for_each(|octet| write!(f, "{:02x}", octet))
    }
}
/// Debug output is `HashValue(<hex>)`, with the hex part produced by the
/// `LowerHex` implementation using the same formatter.
impl fmt::Debug for HashValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("HashValue(")?;
        fmt::LowerHex::fmt(self, f)?;
        f.write_str(")")
    }
}
/// Displays a short form: only the first four bytes, as lowercase hex.
impl fmt::Display for HashValue {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let prefix = &self.hash[..4];
        prefix
            .iter()
            .try_for_each(|octet| write!(f, "{:02x}", octet))
    }
}
/// Copies the hash bytes into a new `Bytes` buffer.
impl From<HashValue> for Bytes {
    fn from(value: HashValue) -> Bytes {
        let raw: &[u8] = &value.hash;
        Bytes::copy_from_slice(raw)
    }
}
impl FromStr for HashValue {
type Err = HashValueParseError;
fn from_str(s: &str) -> Result<Self, HashValueParseError> {
HashValue::from_hex(s)
}
}
/// Error returned when bytes, bits or hex text cannot be turned into a
/// `HashValue`. It carries no further detail.
#[derive(Clone, Copy, Debug)]
pub struct HashValueParseError;

/// Renders a fixed, human-readable message.
impl fmt::Display for HashValueParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("unable to parse HashValue")
    }
}

impl std::error::Error for HashValueParseError {}
/// Iterator over the bits of a `HashValue`, created by
/// `HashValue::iter_bits`. It can be walked from either end.
pub struct HashValueBitIterator<'a> {
// Borrowed bytes of the hash being iterated.
hash_bytes: &'a [u8],
// Range of bit indices not yet yielded from either end.
pos: std::ops::Range<usize>,
}
impl<'a> HashValueBitIterator<'a> {
    /// Creates an iterator covering every bit of `hash_value`.
    fn new(hash_value: &'a HashValue) -> Self {
        Self {
            hash_bytes: &hash_value.hash[..],
            pos: 0..HashValue::LENGTH_IN_BITS,
        }
    }

    /// Reads the bit at `index`, where index 0 is the most significant bit of
    /// the first byte.
    fn get_bit(&self, index: usize) -> bool {
        debug_assert_eq!(self.hash_bytes.len(), HashValue::LENGTH);
        debug_assert_lt!(index, self.hash_bytes.len() * 8);
        debug_assert_ge!(index, 0);
        // 0x80 selects the most significant bit; shift it right to the target.
        let mask = 0x80u8 >> (index % 8);
        (self.hash_bytes[index / 8] & mask) != 0
    }
}
/// Yields bits front to back, most significant bit first.
impl<'a> std::iter::Iterator for HashValueBitIterator<'a> {
    type Item = bool;

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.pos.next()?;
        Some(self.get_bit(index))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The index range knows exactly how many positions remain.
        let remaining = self.pos.len();
        (remaining, Some(remaining))
    }
}
/// Yields bits from the back, starting at the last remaining index.
impl<'a> std::iter::DoubleEndedIterator for HashValueBitIterator<'a> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let index = self.pos.next_back()?;
        Some(self.get_bit(index))
    }
}
// The length is exact because `size_hint` is forwarded from the underlying
// `Range<usize>` of remaining bit indices.
impl<'a> std::iter::ExactSizeIterator for HashValueBitIterator<'a> {}
/// Implemented by types that can produce a `HashValue` of themselves.
pub trait CryptoHash {
/// The `CryptoHasher` associated with this type.
/// NOTE(review): presumably the hasher `hash` is expected to use — confirm
/// against implementors.
type Hasher: CryptoHasher;
/// Returns the hash of `self`.
fn hash(&self) -> HashValue;
}
/// Incremental hasher that produces a `HashValue`.
///
/// Implementors must also be `Default` (a fresh hasher) and `std::io::Write`.
pub trait CryptoHasher: Default + std::io::Write {
    /// Returns the 32-byte seed associated with this hasher type.
    fn seed() -> &'static [u8; 32];

    /// Feeds `bytes` into the hasher.
    fn update(&mut self, bytes: &[u8]);

    /// Consumes the hasher and returns the resulting hash.
    fn finish(self) -> HashValue;

    /// Convenience helper: hashes `bytes` with a fresh default hasher.
    fn hash_all(bytes: &[u8]) -> HashValue {
        let mut state = Self::default();
        CryptoHasher::update(&mut state, bytes);
        CryptoHasher::finish(state)
    }
}
// SHA3-256 based hasher that the `define_hasher!` macro wraps for each
// concrete hasher type. Hidden from rustdoc.
#[doc(hidden)]
#[derive(Clone)]
pub struct DefaultHasher {
// Running SHA3 state; `DefaultHasher::new` may pre-feed it with a salt hash.
state: Sha3,
}
impl DefaultHasher {
    /// Returns the SHA3-256 digest of `HASH_PREFIX` followed by `buffer`.
    #[doc(hidden)]
    pub fn prefixed_hash(buffer: &[u8]) -> [u8; HashValue::LENGTH] {
        let mut salted = Vec::with_capacity(HASH_PREFIX.len() + buffer.len());
        salted.extend_from_slice(HASH_PREFIX);
        salted.extend_from_slice(buffer);
        HashValue::sha3_256_of(&salted).hash
    }

    /// Creates a hasher. A non-empty `typename` is turned into a prefixed
    /// hash that is fed into the state first; an empty one leaves the state
    /// untouched.
    #[doc(hidden)]
    pub fn new(typename: &[u8]) -> Self {
        let mut hasher = Self {
            state: Sha3::v256(),
        };
        if !typename.is_empty() {
            let seed = Self::prefixed_hash(typename);
            hasher.update(&seed);
        }
        hasher
    }

    /// Feeds `bytes` into the running SHA3 state.
    #[doc(hidden)]
    pub fn update(&mut self, bytes: &[u8]) {
        self.state.update(bytes);
    }

    /// Consumes the hasher and returns the finalized digest.
    #[doc(hidden)]
    pub fn finish(self) -> HashValue {
        HashValue::from_keccak(self.state)
    }
}
/// Prints a fixed placeholder; the internal SHA3 state is not shown.
impl fmt::Debug for DefaultHasher {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("DefaultHasher: state = Sha3")
    }
}
// Defines a salted hasher type wrapping `DefaultHasher`.
//
// Arguments: optional attributes for the generated struct, then a tuple of
// (struct name, name of the static holding a prototype hasher, name of the
// static holding the cached seed, salt byte string).
//
// Generates the public struct, a private `new`, the two statics, and impls of
// `Default`, `CryptoHasher` and `std::io::Write`.
macro_rules! define_hasher {
(
$(#[$attr:meta])*
($hasher_type: ident, $hasher_name: ident, $seed_name: ident, $salt: expr)
) => {
#[derive(Clone, Debug)]
$(#[$attr])*
pub struct $hasher_type(DefaultHasher);
impl $hasher_type {
fn new() -> Self {
$hasher_type(DefaultHasher::new($salt))
}
}
// Prototype hasher, built on first use; `Default` clones it instead of
// recomputing the salted state.
static $hasher_name: Lazy<$hasher_type> = Lazy::new(|| { $hasher_type::new() });
// Seed bytes, filled in on the first call to `seed()`.
static $seed_name: OnceCell<[u8; 32]> = OnceCell::new();
impl Default for $hasher_type {
fn default() -> Self {
$hasher_name.clone()
}
}
impl CryptoHasher for $hasher_type {
fn seed() -> &'static [u8;32] {
// The seed is the prefixed hash of the salt, computed once.
$seed_name.get_or_init(|| {
DefaultHasher::prefixed_hash($salt)
})
}
fn update(&mut self, bytes: &[u8]) {
self.0.update(bytes);
}
fn finish(self) -> HashValue {
self.0.finish()
}
}
impl std::io::Write for $hasher_type {
// Writing feeds all bytes into the hasher and reports them as written.
fn write(&mut self, bytes: &[u8]) -> std::io::Result<usize> {
self.0.update(bytes);
Ok(bytes.len())
}
// Nothing is buffered, so flushing is a no-op.
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
};
}
// Hasher salted with b"TransactionAccumulator".
define_hasher! {
(
TransactionAccumulatorHasher,
TRANSACTION_ACCUMULATOR_HASHER,
TRANSACTION_ACCUMULATOR_SEED,
b"TransactionAccumulator"
)
}
// Hasher salted with b"EventAccumulator".
define_hasher! {
(
EventAccumulatorHasher,
EVENT_ACCUMULATOR_HASHER,
EVENT_ACCUMULATOR_SEED,
b"EventAccumulator"
)
}
// Hasher salted with b"SparseMerkleInternal".
define_hasher! {
(
SparseMerkleInternalHasher,
SPARSE_MERKLE_INTERNAL_HASHER,
SPARSE_MERKLE_INTERNAL_SEED,
b"SparseMerkleInternal"
)
}
// Hasher with an empty salt: `DefaultHasher::new` feeds nothing into the
// state for it, while `seed()` is the prefixed hash of the empty string.
define_hasher! {
(TestOnlyHasher, TEST_ONLY_HASHER, TEST_ONLY_SEED, b"")
}
/// Builds a `HashValue` whose leading bytes are the UTF-8 bytes of `word`
/// and whose remaining bytes are zero. Nothing is hashed.
///
/// # Panics
/// Panics if `word` is longer than `HashValue::LENGTH` bytes.
fn create_literal_hash(word: &str) -> HashValue {
    let s = word.as_bytes();
    assert!(s.len() <= HashValue::LENGTH);
    let mut padded = [0u8; HashValue::LENGTH];
    padded[..s.len()].copy_from_slice(s);
    HashValue::new(padded)
}
/// Literal (not hashed) value: the bytes of "ACCUMULATOR_PLACEHOLDER_HASH"
/// padded with zeros to `HashValue::LENGTH`.
pub static ACCUMULATOR_PLACEHOLDER_HASH: Lazy<HashValue> =
Lazy::new(|| create_literal_hash("ACCUMULATOR_PLACEHOLDER_HASH"));
/// Literal (not hashed) value: the bytes of "SPARSE_MERKLE_PLACEHOLDER_HASH"
/// padded with zeros to `HashValue::LENGTH`.
pub static SPARSE_MERKLE_PLACEHOLDER_HASH: Lazy<HashValue> =
Lazy::new(|| create_literal_hash("SPARSE_MERKLE_PLACEHOLDER_HASH"));
/// Literal (not hashed) value: the bytes of "PRE_GENESIS_BLOCK_ID" padded
/// with zeros to `HashValue::LENGTH`.
pub static PRE_GENESIS_BLOCK_ID: Lazy<HashValue> =
Lazy::new(|| create_literal_hash("PRE_GENESIS_BLOCK_ID"));
/// Hard-coded 32-byte value.
/// NOTE(review): presumably the id of the genesis block; how these bytes were
/// derived is not visible here — confirm.
pub static GENESIS_BLOCK_ID: Lazy<HashValue> = Lazy::new(|| {
HashValue::new([
0x5e, 0x10, 0xba, 0xd4, 0x5b, 0x35, 0xed, 0x92, 0x9c, 0xd6, 0xd2, 0xc7, 0x09, 0x8b, 0x13,
0x5d, 0x02, 0xdd, 0x25, 0x9a, 0xe8, 0x8a, 0x8d, 0x09, 0xf4, 0xeb, 0x5f, 0xba, 0xe9, 0xa6,
0xf6, 0xe4,
])
});
/// Provides a hash for arbitrary values; a blanket implementation exists for
/// every `Serialize` type.
/// NOTE(review): the name suggests this is meant for tests only, but nothing
/// here restricts it to test builds — confirm.
pub trait TestOnlyHash {
/// Returns a hash of `self`.
fn test_only_hash(&self) -> HashValue;
}
impl<T: ser::Serialize + ?Sized> TestOnlyHash for T {
fn test_only_hash(&self) -> HashValue {
let bytes = bcs::to_bytes(self).expect("serialize failed during hash.");
let mut hasher = TestOnlyHasher::default();
hasher.update(&bytes);
hasher.finish()
}
}