use core::fmt::Debug;
use digest::OutputSizeUser;
use generic_array::GenericArray;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::fmt;
use crate::encoding::{Base64, Encoding};
use crate::groups::ristretto255::RistrettoPoint;
use crate::groups::HashToGroupElement;
/// A fixed-length hash output of `DIGEST_LEN` bytes.
///
/// The `Debug` and `Display` impls in this module both render the bytes as Base64.
// `#[serde_as]` must stay above `#[derive(..)]` so it can rewrite the field
// attribute before serde's derive runs.
#[serde_as]
#[derive(Hash, PartialEq, Eq, Clone, Serialize, Deserialize, Ord, PartialOrd, Copy)]
pub struct Digest<const DIGEST_LEN: usize> {
    /// The raw digest bytes.
    // serde's built-in array impls only cover fixed small lengths, not a
    // const-generic `DIGEST_LEN`, so serde_with provides the (de)serialization.
    #[serde_as(as = "[_; DIGEST_LEN]")]
    pub digest: [u8; DIGEST_LEN],
}
impl<const DIGEST_LEN: usize> Digest<DIGEST_LEN> {
pub fn new(digest: [u8; DIGEST_LEN]) -> Self {
Digest { digest }
}
pub fn to_vec(&self) -> Vec<u8> {
self.digest.to_vec()
}
pub fn size(&self) -> usize {
DIGEST_LEN
}
}
impl<const DIGEST_LEN: usize> fmt::Debug for Digest<DIGEST_LEN> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "{}", Base64::encode(self.digest))
}
}
/// Renders the digest bytes as a Base64 string.
impl<const DIGEST_LEN: usize> fmt::Display for Digest<DIGEST_LEN> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let encoded = Base64::encode(self.digest);
        f.write_str(&encoded)
    }
}
/// Borrows the digest as a byte slice.
impl<const DIGEST_LEN: usize> AsRef<[u8]> for Digest<DIGEST_LEN> {
    fn as_ref(&self) -> &[u8] {
        &self.digest[..]
    }
}
impl<const DIGEST_LEN: usize> From<Digest<DIGEST_LEN>> for [u8; DIGEST_LEN] {
fn from(digest: Digest<DIGEST_LEN>) -> Self {
digest.digest
}
}
/// A streaming hash function that produces `DIGEST_LENGTH`-byte [`Digest`]s.
///
/// Implementors provide `update` and `finalize`. The other methods have default
/// bodies built from those two plus `Default`.
pub trait HashFunction<const DIGEST_LENGTH: usize>: Default {
    /// Digest size in bytes (equal to `DIGEST_LENGTH`).
    const OUTPUT_SIZE: usize = DIGEST_LENGTH;

    /// Creates a fresh hasher; by default this is the `Default` value.
    fn new() -> Self {
        Default::default()
    }

    /// Feeds `data` into the hasher state.
    fn update<Data: AsRef<[u8]>>(&mut self, data: Data);

    /// Consumes the hasher and returns the digest of all data fed so far.
    fn finalize(self) -> Digest<DIGEST_LENGTH>;

    /// One-shot convenience: hashes `data` with a default-constructed hasher.
    fn digest<Data: AsRef<[u8]>>(data: Data) -> Digest<DIGEST_LENGTH> {
        let mut hasher: Self = Default::default();
        hasher.update(data);
        hasher.finalize()
    }

    /// Feeds every item of `iter`, in order, into a single default-constructed
    /// hasher via `update`, then finalizes it.
    fn digest_iterator<K: AsRef<[u8]>, I: Iterator<Item = K>>(iter: I) -> Digest<DIGEST_LENGTH> {
        let mut hasher: Self = Default::default();
        for chunk in iter {
            hasher.update(chunk);
        }
        hasher.finalize()
    }
}
/// Types that can compute a digest of themselves.
///
/// The digest type is an associated type so an implementor can return its own
/// digest type, as long as it converts into a plain `Digest<DIGEST_LEN>`.
pub trait Hash<const DIGEST_LEN: usize> {
    /// Digest type returned by `digest`; must be convertible into `Digest<DIGEST_LEN>`
    /// and usable as a hash-map key (`Eq + Hash`), copyable, and debuggable.
    type TypedDigest: Into<Digest<DIGEST_LEN>> + Eq + std::hash::Hash + Copy + Debug;
    /// Computes the digest of `self`.
    fn digest(&self) -> Self::TypedDigest;
}
/// Adapter exposing a `digest::Digest` hasher (`Variant`) as a [`HashFunction`]
/// that yields `DIGEST_LEN`-byte [`Digest`]s.
///
/// `DIGEST_LEN` must equal `Variant`'s output size. Otherwise finalizing the
/// hasher panics. The `Sha256`/`Sha512`/... aliases in this module pair them up.
#[derive(Default)]
pub struct HashFunctionWrapper<Variant, const DIGEST_LEN: usize>(Variant);
/// Recovers the underlying hasher type from a wrapper type such as
/// [`HashFunctionWrapper`].
pub trait ReverseWrapper {
    /// The wrapped hasher type.
    type Variant: digest::core_api::CoreProxy + OutputSizeUser;
}
/// A `HashFunctionWrapper<V, N>` reports `V` as its wrapped hasher type.
impl<Variant, const DIGEST_LEN: usize> ReverseWrapper for HashFunctionWrapper<Variant, DIGEST_LEN>
where
    Variant: digest::core_api::CoreProxy + OutputSizeUser,
{
    type Variant = Variant;
}
/// [`HashFunction`] implementation that forwards to the wrapped `digest::Digest` hasher.
impl<Variant: digest::Digest + Default, const DIGEST_LEN: usize> HashFunction<DIGEST_LEN>
    for HashFunctionWrapper<Variant, DIGEST_LEN>
{
    /// Feeds `data` into the wrapped hasher.
    fn update<Data: AsRef<[u8]>>(&mut self, data: Data) {
        self.0.update(data);
    }

    /// Consumes the wrapper and returns the wrapped hasher's output as a [`Digest`].
    ///
    /// # Panics
    ///
    /// Panics if `DIGEST_LEN` differs from the wrapped hasher's output size,
    /// i.e. the wrapper was instantiated with a mismatched length.
    fn finalize(self) -> Digest<DIGEST_LEN> {
        // `GenericArray::from_mut_slice` would also panic on a length mismatch,
        // but its message says nothing about which parameter is wrong.
        let expected = <Variant as digest::Digest>::output_size();
        assert_eq!(
            DIGEST_LEN, expected,
            "HashFunctionWrapper: DIGEST_LEN ({}) does not match the hasher output size ({})",
            DIGEST_LEN, expected
        );
        let mut digest = [0u8; DIGEST_LEN];
        self.0
            .finalize_into(GenericArray::from_mut_slice(&mut digest));
        Digest { digest }
    }
}
impl<Variant: digest::Digest + Default, const DIGEST_LEN: usize> std::io::Write
for HashFunctionWrapper<Variant, DIGEST_LEN>
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.update(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
/// SHA-256 (from the `sha2` crate) producing 32-byte digests.
pub type Sha256 = HashFunctionWrapper<sha2::Sha256, 32>;
/// SHA3-256 (from the `sha3` crate) producing 32-byte digests.
pub type Sha3_256 = HashFunctionWrapper<sha3::Sha3_256, 32>;
/// SHA-512 (from the `sha2` crate) producing 64-byte digests.
pub type Sha512 = HashFunctionWrapper<sha2::Sha512, 64>;
/// SHA3-512 (from the `sha3` crate) producing 64-byte digests.
pub type Sha3_512 = HashFunctionWrapper<sha3::Sha3_512, 64>;
/// Keccak-256 (from the `sha3` crate; a distinct function from `Sha3_256`)
/// producing 32-byte digests.
pub type Keccak256 = HashFunctionWrapper<sha3::Keccak256, 32>;
/// BLAKE2b instantiated with a 32-byte output (`typenum::U32`).
pub type Blake2b256 = HashFunctionWrapper<blake2::Blake2b<typenum::U32>, 32>;
/// An incrementally updatable hash over a multiset of byte strings.
///
/// Items can be added and removed one at a time or in bulk, and two instances
/// can be merged with `union`. `digest` summarizes the current state.
pub trait MultisetHash<const DIGEST_LENGTH: usize>: Eq {
    /// Adds `item` to the multiset.
    fn insert<Data: AsRef<[u8]>>(&mut self, item: Data);
    /// Adds every element of `items` to the multiset.
    fn insert_all<It, Data>(&mut self, items: It)
    where
        It: IntoIterator<Item = Data>,
        Data: AsRef<[u8]>;
    /// Merges the contents of `other` into `self`.
    fn union(&mut self, other: &Self);
    /// Removes `item` from the multiset.
    // NOTE(review): behavior when `item` was never inserted is left to the implementor.
    fn remove<Data: AsRef<[u8]>>(&mut self, item: Data);
    /// Removes every element of `items` from the multiset.
    fn remove_all<It, Data>(&mut self, items: It)
    where
        It: IntoIterator<Item = Data>,
        Data: AsRef<[u8]>;
    /// Returns a `DIGEST_LENGTH`-byte digest of the current state.
    fn digest(&self) -> Digest<DIGEST_LENGTH>;
}
/// Multiset hash over the Ristretto255 group.
///
/// Each item is hashed to a group point. Inserting adds that point to the
/// accumulator, removing subtracts it, and `union` adds two accumulators.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct EllipticCurveMultisetHash {
    /// Running sum of the points of inserted items minus those of removed items.
    // The `Default` starting value comes from `RistrettoPoint::default()`
    // (presumably the group identity — confirm).
    accumulator: RistrettoPoint,
}
/// Two multiset hashes are equal exactly when their accumulator points are equal.
impl PartialEq for EllipticCurveMultisetHash {
    fn eq(&self, other: &Self) -> bool {
        let Self { accumulator } = self;
        accumulator == &other.accumulator
    }
}

/// Accumulator equality is treated as a full equivalence relation.
impl Eq for EllipticCurveMultisetHash {}
/// Ristretto255-backed multiset hash: each item maps to a group point, and the
/// multiset state is the sum of those points.
impl MultisetHash<32> for EllipticCurveMultisetHash {
    /// Adds `item` by adding its hash-to-group point to the accumulator.
    fn insert<Data: AsRef<[u8]>>(&mut self, item: Data) {
        self.accumulator += Self::hash_to_point(item);
    }

    /// Inserts each element of `items`, in iteration order.
    fn insert_all<It, Data>(&mut self, items: It)
    where
        It: IntoIterator<Item = Data>,
        Data: AsRef<[u8]>,
    {
        items.into_iter().for_each(|item| self.insert(item));
    }

    /// Merges `other` into `self` by adding its accumulator point.
    fn union(&mut self, other: &Self) {
        self.accumulator += other.accumulator;
    }

    /// Removes `item` by subtracting its hash-to-group point from the accumulator.
    fn remove<Data: AsRef<[u8]>>(&mut self, item: Data) {
        self.accumulator -= Self::hash_to_point(item);
    }

    /// Removes each element of `items`, in iteration order.
    fn remove_all<It, Data>(&mut self, items: It)
    where
        It: IntoIterator<Item = Data>,
        Data: AsRef<[u8]>,
    {
        items.into_iter().for_each(|item| self.remove(item));
    }

    /// Returns SHA-256 of the bincode serialization of the accumulator point.
    ///
    /// # Panics
    ///
    /// Panics only if bincode fails to serialize the accumulator. That is an
    /// invariant violation, not a recoverable condition.
    fn digest(&self) -> Digest<32> {
        let serialized = bincode::serialize(&self.accumulator)
            .expect("serializing the RistrettoPoint accumulator with bincode should never fail");
        Sha256::digest(&serialized)
    }
}
impl EllipticCurveMultisetHash {
    /// Maps a byte string to a Ristretto255 point using the group's
    /// `HashToGroupElement` implementation.
    fn hash_to_point<Data: AsRef<[u8]>>(item: Data) -> RistrettoPoint {
        let bytes: &[u8] = item.as_ref();
        RistrettoPoint::hash_to_group_element(bytes)
    }
}
/// Opaque debug output: prints only the label `Accumulator` and hides the point.
impl Debug for EllipticCurveMultisetHash {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same bytes as a field-less `debug_struct("Accumulator").finish()`.
        f.write_str("Accumulator")
    }
}