use crate::{common::*, curve_arithmetic::Curve};
use sha3::{Digest, Sha3_256};
use std::io::Write;
/// Random oracle implemented as an incremental SHA3-256 hash state.
///
/// Bytes are absorbed into the wrapped hasher and later finalized into a
/// digest. `repr(transparent)` gives the wrapper the same layout as the
/// inner hasher.
#[derive(Debug)]
#[repr(transparent)]
pub struct RandomOracle(Sha3_256);
/// A finalized oracle output: 32 raw digest bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)]
pub struct Challenge {
    /// Digest bytes exactly as produced when the oracle was finalized.
    challenge: [u8; 32],
}
/// Borrows the challenge's digest bytes as a plain byte slice.
impl AsRef<[u8]> for Challenge {
    fn as_ref(&self) -> &[u8] {
        // Full-range slice over the fixed-size array; no copy is made.
        &self.challenge[..]
    }
}
/// Lets any `std::io::Write` consumer stream bytes straight into the hash state.
impl Write for RandomOracle {
    /// Absorbs `buf` in a single hasher update. This replaces the default
    /// implementation, which would loop over `write`.
    #[inline(always)]
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        self.0.update(buf);
        Ok(())
    }

    /// Absorbs all of `buf` and reports its full length. Hashing never
    /// performs a short write and never fails.
    #[inline(always)]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.write_all(buf).map(|()| buf.len())
    }

    /// Nothing is buffered outside the hash state, so there is nothing to flush.
    #[inline(always)]
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
impl Buffer for RandomOracle {
type Result = sha3::digest::Output<Sha3_256>;
#[inline(always)]
fn start() -> Self {
RandomOracle::empty()
}
fn result(self) -> Self::Result {
self.0.finalize()
}
}
/// Two oracles are equal when finalizing copies of their states yields the
/// same digest. Both operands are cloned first, so neither is consumed or
/// modified.
impl PartialEq for RandomOracle {
    fn eq(&self, other: &Self) -> bool {
        let mine = self.0.clone().finalize();
        let theirs = other.0.clone().finalize();
        mine == theirs
    }
}

/// Digest equality is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for RandomOracle {}
impl RandomOracle {
    /// Creates an oracle over a fresh SHA3-256 state with nothing absorbed.
    #[must_use]
    pub fn empty() -> Self {
        RandomOracle(Sha3_256::new())
    }

    /// Creates an oracle whose state has already absorbed the raw bytes of
    /// `data`. The name suggests this is meant for a domain-separation tag.
    ///
    /// Produces the same state as [`RandomOracle::empty`] followed by
    /// [`RandomOracle::add_bytes`] with `data`.
    #[must_use]
    pub fn domain<B: AsRef<[u8]>>(data: B) -> Self {
        RandomOracle(Sha3_256::new().chain_update(data))
    }

    /// Returns an independent copy of the current hash state.
    ///
    /// Data absorbed into either oracle afterwards does not affect the other.
    /// A snapshot can therefore be finalized while the original keeps
    /// accumulating.
    #[must_use]
    pub fn split(&self) -> Self {
        RandomOracle(self.0.clone())
    }

    /// Absorbs `data` by serializing it into the oracle via `put`.
    pub fn add<B: Serial>(&mut self, data: &B) {
        self.put(data)
    }

    /// Absorbs the raw bytes of `data`, with no length prefix or other framing.
    pub fn add_bytes<B: AsRef<[u8]>>(&mut self, data: B) {
        self.0.update(data)
    }

    /// Absorbs `label` as raw bytes, then the serialization of `message`.
    ///
    /// NOTE(review): the label is not length-prefixed. In principle, distinct
    /// (label, message) pairs can therefore produce the same absorbed byte
    /// stream. Changing this would alter every transcript hash, so it is left
    /// as is.
    pub fn append_message<S: Serial, B: AsRef<[u8]>>(&mut self, label: B, message: &S) {
        self.add_bytes(label);
        self.add(message)
    }

    /// Absorbs `label` as raw bytes, then the serialization of each item
    /// yielded by `iter`, in order.
    ///
    /// The number of items is not absorbed. With an empty label this is the
    /// same as calling [`RandomOracle::add`] on each item. Callers that need
    /// the length bound into the transcript must add it themselves.
    pub fn extend_from<'a, I, S, B: AsRef<[u8]>>(&mut self, label: B, iter: I)
    where
        S: Serial + 'a,
        I: IntoIterator<Item = &'a S>,
    {
        self.add_bytes(label);
        // `for` already calls `IntoIterator::into_iter`; no explicit call needed.
        for item in iter {
            self.add(item)
        }
    }

    /// Finalizes the oracle and maps the SHA3-256 digest to a scalar of `C`
    /// via `C::scalar_from_bytes`.
    #[must_use]
    pub fn result_to_scalar<C: Curve>(self) -> C::Scalar {
        C::scalar_from_bytes(self.result())
    }

    /// Finalizes the oracle and returns the 32-byte digest as a [`Challenge`].
    #[must_use]
    pub fn get_challenge(self) -> Challenge {
        Challenge {
            challenge: self.result().into(),
        }
    }

    /// Absorbs `label`, then derives a scalar of `C` from a snapshot of the
    /// resulting state.
    ///
    /// `self` keeps the label but does not absorb the derived scalar. Callers
    /// that want the scalar bound into later challenges must add it explicitly.
    pub fn challenge_scalar<C: Curve, B: AsRef<[u8]>>(&mut self, label: B) -> C::Scalar {
        self.add_bytes(label);
        self.split().result_to_scalar::<C>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::*;

    /// Adding bytes one at a time must match `extend_from` with an empty
    /// label, because `extend_from` adds no framing around the items.
    #[test]
    pub fn test_extend_from() {
        let mut rng = thread_rng();
        let mut data = vec![0u8; 50];
        for _ in 0..1000 {
            data.iter_mut().for_each(|b| *b = rng.gen::<u8>());

            let mut one_by_one = RandomOracle::empty();
            data.iter().for_each(|b| one_by_one.add(b));

            let mut batched = RandomOracle::empty();
            batched.extend_from(b"", data.iter());

            let digest_a = one_by_one.result();
            let lhs: &[u8] = digest_a.as_ref();
            let digest_b = batched.result();
            let rhs: &[u8] = digest_b.as_ref();
            assert_eq!(lhs, rhs);
        }
    }

    /// A snapshot taken with `split` must reach the same digest as the
    /// original once both have absorbed the same subsequent bytes.
    #[test]
    pub fn test_split() {
        let mut rng = thread_rng();
        let mut data = vec![0u8; 50];
        for _ in 0..1000 {
            let mut original = RandomOracle::empty();
            original.add(&data);
            let mut snapshot = original.split();

            for byte in data.iter_mut() {
                *byte = rng.gen::<u8>();
                original.add(byte);
            }

            let digest_a = original.result();
            let lhs: &[u8] = digest_a.as_ref();
            // Raw bytes match the per-byte serialization absorbed above.
            snapshot.add_bytes(&data);
            let digest_b = snapshot.result();
            let rhs: &[u8] = digest_b.as_ref();
            assert_eq!(lhs, rhs);
        }
    }
}