use crate::{common::*, curve_arithmetic::Curve};
use sha3::{Digest, Sha3_256};
use std::convert::Infallible;
use std::fmt::Arguments;
use std::io::{IoSlice, Write};
/// Legacy (version 0) transcript backed by a SHA3-256 hash state.
///
/// New protocols should use [`TranscriptProtocolV1`] instead. Compared to it,
/// this transcript:
/// - absorbs labels as raw bytes, with no length prefix;
/// - does not absorb the number of items in a message sequence;
/// - ignores [`TranscriptProtocol::append_final_prover_message`] entirely.
///
/// These differences are intentional and must be preserved. Existing proofs
/// depend on the exact bytes absorbed here, so changing this type (or a
/// protocol built on it) without changing its proof version breaks
/// compatibility with those proofs.
#[repr(transparent)]
#[derive(Debug)]
pub struct RandomOracle(Sha3_256);
/// Version 1 transcript backed by a SHA3-256 hash state.
///
/// Every label is absorbed together with its byte length (as a `u64`), and
/// every message sequence together with its element count (as a `u64`). This
/// keeps the boundaries between adjacent labels and sequences unambiguous.
/// Create one with [`TranscriptProtocolV1::with_domain`].
#[repr(transparent)]
#[derive(Debug)]
pub struct TranscriptProtocolV1(Sha3_256);
/// The raw 32-byte SHA3-256 digest of a transcript's state at the moment the
/// challenge was extracted.
#[derive(Debug, Serialize, PartialEq, Eq, Clone, Copy)]
pub struct Challenge {
    // The finalized digest bytes.
    challenge: [u8; 32],
}
impl AsRef<[u8]> for Challenge {
    /// Exposes the full 32-byte digest as a byte slice.
    fn as_ref(&self) -> &[u8] {
        let digest: &[u8] = &self.challenge[..];
        digest
    }
}
impl Write for RandomOracle {
    /// Feeds `bytes` into the hash state.
    ///
    /// Hashing never fails and always consumes the entire input, so the full
    /// length is reported as written.
    #[inline(always)]
    fn write(&mut self, bytes: &[u8]) -> std::io::Result<usize> {
        let consumed = bytes.len();
        self.0.update(bytes);
        Ok(consumed)
    }

    /// Feeds all of `bytes` into the hash state in a single update.
    #[inline(always)]
    fn write_all(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.0.update(bytes);
        Ok(())
    }

    /// Nothing is held outside the hash state, so there is nothing to flush.
    #[inline(always)]
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
impl Buffer for RandomOracle {
type Result = sha3::digest::Output<Sha3_256>;
#[inline(always)]
fn start() -> Self {
#[allow(deprecated)]
RandomOracle::empty()
}
fn result(self) -> Self::Result {
self.0.finalize()
}
}
impl Eq for RandomOracle {}

impl PartialEq for RandomOracle {
    /// Two oracles are equal when finalizing copies of their hash states
    /// yields the same digest. Neither oracle is modified.
    fn eq(&self, other: &Self) -> bool {
        let lhs_digest = self.0.clone().finalize();
        let rhs_digest = other.0.clone().finalize();
        lhs_digest == rhs_digest
    }
}
/// Common interface of the transcripts in this module ([`RandomOracle`] and
/// [`TranscriptProtocolV1`]): absorb labels and messages, then derive
/// challenges from everything absorbed so far.
///
/// The two implementations absorb different bytes for the same sequence of
/// calls. Switching an existing protocol from one to the other without
/// changing its proof version therefore breaks compatibility with existing
/// proofs (see the deprecation note on `RandomOracle::empty`).
pub trait TranscriptProtocol {
    /// Absorbs `label` into the transcript.
    ///
    /// `TranscriptProtocolV1` prefixes the label with its byte length as a
    /// `u64`; `RandomOracle` absorbs only the raw bytes.
    fn append_label(&mut self, label: impl AsRef<[u8]>);
    /// Absorbs `label` followed by the serialization of `message`.
    fn append_message(&mut self, label: impl AsRef<[u8]>, message: &impl Serial);
    /// Absorbs `label` followed by the serialization of each item of
    /// `messages`, in iteration order.
    ///
    /// The `ExactSizeIterator` bound lets an implementation absorb the item
    /// count before the items. `TranscriptProtocolV1` does this;
    /// `RandomOracle` does not.
    fn append_messages<'a, T: Serial + 'a, B: IntoIterator<Item = &'a T>>(
        &mut self,
        label: impl AsRef<[u8]>,
        messages: B,
    ) where
        B::IntoIter: ExactSizeIterator;
    /// Absorbs the final prover message.
    ///
    /// `TranscriptProtocolV1` handles it exactly like `append_message`.
    /// `RandomOracle` ignores both arguments.
    fn append_final_prover_message(&mut self, label: impl AsRef<[u8]>, message: &impl Serial);
    /// Absorbs `label`, then calls `append_item` once per element of
    /// `messages`, in order, so the caller decides how each element is
    /// absorbed.
    ///
    /// As with `append_messages`, `TranscriptProtocolV1` also absorbs the
    /// element count right after the label.
    fn append_each_message<T, B: IntoIterator<Item = T>>(
        &mut self,
        label: impl AsRef<[u8]>,
        messages: B,
        append_item: impl FnMut(&mut Self, T),
    ) where
        B::IntoIter: ExactSizeIterator;
    /// Absorbs `label` and derives a scalar of curve `C` from the resulting
    /// state.
    ///
    /// In both implementations here, the label stays in the transcript and
    /// the digest is computed on a copy of the state. The transcript can
    /// therefore keep absorbing afterwards.
    fn extract_challenge_scalar<C: Curve>(&mut self, label: impl AsRef<[u8]>) -> C::Scalar;
    /// Returns the digest of the current state without modifying the
    /// transcript.
    fn extract_raw_challenge(&self) -> Challenge;
}
impl TranscriptProtocol for RandomOracle {
fn append_label(&mut self, label: impl AsRef<[u8]>) {
self.0.update(label)
}
fn append_message(&mut self, label: impl AsRef<[u8]>, message: &impl Serial) {
self.append_label(label);
self.put(message)
}
fn append_messages<'a, T: Serial + 'a, B: IntoIterator<Item = &'a T>>(
&mut self,
label: impl AsRef<[u8]>,
messages: B,
) where
B::IntoIter: ExactSizeIterator,
{
self.append_label(label);
for message in messages {
self.put(message);
}
}
fn append_final_prover_message(&mut self, _label: impl AsRef<[u8]>, _message: &impl Serial) {
}
fn append_each_message<T, B: IntoIterator<Item = T>>(
&mut self,
label: impl AsRef<[u8]>,
messages: B,
mut append_item: impl FnMut(&mut Self, T),
) where
B::IntoIter: ExactSizeIterator,
{
self.append_label(label);
for message in messages {
append_item(self, message);
}
}
fn extract_challenge_scalar<C: Curve>(&mut self, label: impl AsRef<[u8]>) -> C::Scalar {
self.challenge_scalar::<C, _>(label)
}
fn extract_raw_challenge(&self) -> Challenge {
self.split().get_challenge()
}
}
impl RandomOracle {
    /// Creates a legacy oracle with an empty SHA3-256 state.
    ///
    /// Deprecated outside of tests so that new protocols use
    /// `TranscriptProtocolV1` instead.
    #[cfg_attr(
        not(test),
        deprecated(
            note = "Use TranscriptProtocolV1 instead, see documentation on RandomOracle. Do not change existing protocols without changing their proof version since it will break compatibility with existing proofs."
        )
    )]
    pub fn empty() -> Self {
        RandomOracle(Sha3_256::new())
    }

    /// Creates a legacy oracle whose state starts with the raw bytes of
    /// `data`, absorbed without a length prefix.
    ///
    /// Deprecated outside of tests so that new protocols use
    /// `TranscriptProtocolV1::with_domain` instead.
    #[cfg_attr(
        not(test),
        deprecated(
            note = "Use TranscriptProtocolV1 instead, see documentation on RandomOracle. Do not change existing protocols without changing their proof version since it will break compatibility with existing proofs."
        )
    )]
    pub fn domain<B: AsRef<[u8]>>(data: B) -> Self {
        RandomOracle(Sha3_256::new().chain_update(data))
    }

    /// Returns an independent copy of the current state. Absorbing into one
    /// copy does not affect the other.
    pub fn split(&self) -> Self {
        RandomOracle(self.0.clone())
    }

    /// Absorbs raw bytes, without any length prefix.
    pub fn add_bytes<B: AsRef<[u8]>>(&mut self, data: B) {
        self.0.update(data)
    }

    /// Absorbs the raw `label` followed by the serialization of every item of
    /// `iter`, in order. The number of items is not absorbed.
    pub fn extend_from<'a, I, S, B: AsRef<[u8]>>(&mut self, label: B, iter: I)
    where
        S: Serial + 'a,
        I: IntoIterator<Item = &'a S>,
    {
        self.add_bytes(label);
        for item in iter {
            self.put(item)
        }
    }

    /// Finalizes the oracle and maps the digest to a scalar of `C` via
    /// `C::scalar_from_bytes`.
    pub fn result_to_scalar<C: Curve>(self) -> C::Scalar {
        C::scalar_from_bytes(self.result())
    }

    /// Finalizes the oracle and returns the raw 32-byte digest.
    pub fn get_challenge(self) -> Challenge {
        Challenge {
            challenge: self.result().into(),
        }
    }

    /// Absorbs the raw `label`, then derives a scalar of `C` from a copy of
    /// the resulting state.
    ///
    /// `self` keeps the label and is not finalized, so it can continue
    /// absorbing messages.
    pub fn challenge_scalar<C: Curve, B: AsRef<[u8]>>(&mut self, label: B) -> C::Scalar {
        self.add_bytes(label);
        self.split().result_to_scalar::<C>()
    }
}
/// Wraps a [`Write`] sink so it can be used where a [`Buffer`] is required.
///
/// In this module it wraps a borrowed SHA3-256 hasher so that `put` can
/// serialize values straight into the hash state. Every [`Write`] method is
/// forwarded unchanged to the wrapped writer.
struct BufferAdapter<T>(T);

impl<T: Write> Write for BufferAdapter<T> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0.write(buf)
    }

    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
        self.0.write_vectored(bufs)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.0.flush()
    }

    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        self.0.write_all(buf)
    }

    fn write_fmt(&mut self, args: Arguments<'_>) -> std::io::Result<()> {
        self.0.write_fmt(args)
    }
}

/// Only the [`Write`] side of this adapter is meant to be used.
///
/// The adapter always wraps an existing writer, so it can neither be created
/// from scratch nor produce a result (hence `Result = Infallible`). The two
/// methods below exist solely to satisfy the trait and panic if ever called.
impl<T: Write> Buffer for BufferAdapter<T> {
    type Result = Infallible;

    fn start() -> Self {
        unreachable!("BufferAdapter wraps an existing writer and cannot be created via Buffer::start")
    }

    fn result(self) -> Self::Result {
        unreachable!("BufferAdapter only forwards writes and has no result")
    }
}
impl TranscriptProtocol for TranscriptProtocolV1 {
    /// Absorbs the label's byte length (as a `u64`) followed by its bytes.
    fn append_label(&mut self, label: impl AsRef<[u8]>) {
        let bytes = label.as_ref();
        let byte_len = bytes.len() as u64;
        BufferAdapter(&mut self.0).put(&byte_len);
        self.0.update(bytes)
    }

    /// Absorbs the length-prefixed label, then the serialization of `message`.
    fn append_message(&mut self, label: impl AsRef<[u8]>, message: &impl Serial) {
        self.append_label(label);
        let mut sink = BufferAdapter(&mut self.0);
        sink.put(message)
    }

    /// Absorbs the length-prefixed label, the number of messages (as a
    /// `u64`), then each message's serialization in order.
    fn append_messages<'a, T: Serial + 'a, B: IntoIterator<Item = &'a T>>(
        &mut self,
        label: impl AsRef<[u8]>,
        messages: B,
    ) where
        B::IntoIter: ExactSizeIterator,
    {
        let items = messages.into_iter();
        self.append_label(label);
        let mut sink = BufferAdapter(&mut self.0);
        sink.put(&(items.len() as u64));
        items.for_each(|item| sink.put(item));
    }

    /// In V1 the final prover message is absorbed like any other message.
    fn append_final_prover_message(&mut self, label: impl AsRef<[u8]>, message: &impl Serial) {
        self.append_message(label, message);
    }

    /// Absorbs the length-prefixed label and the element count (as a `u64`),
    /// then lets `append_item` absorb each element in order.
    fn append_each_message<T, B: IntoIterator<Item = T>>(
        &mut self,
        label: impl AsRef<[u8]>,
        messages: B,
        mut append_item: impl FnMut(&mut Self, T),
    ) where
        B::IntoIter: ExactSizeIterator,
    {
        let elements = messages.into_iter();
        self.append_label(label);
        let count = elements.len() as u64;
        BufferAdapter(&mut self.0).put(&count);
        for element in elements {
            append_item(self, element);
        }
    }

    /// Absorbs the length-prefixed label and maps the digest of a copy of
    /// the resulting state to a scalar of `C`.
    fn extract_challenge_scalar<C: Curve>(&mut self, label: impl AsRef<[u8]>) -> C::Scalar {
        self.append_label(label);
        let Challenge { challenge } = self.extract_raw_challenge();
        C::scalar_from_bytes(challenge)
    }

    /// Returns the digest of a copy of the current state; `self` is untouched.
    fn extract_raw_challenge(&self) -> Challenge {
        let digest = self.0.clone().finalize();
        Challenge {
            challenge: digest.into(),
        }
    }
}
impl TranscriptProtocolV1 {
pub fn with_domain(domain: impl AsRef<[u8]>) -> Self {
let mut transcript = TranscriptProtocolV1(Sha3_256::new());
transcript.append_label(domain);
transcript
}
pub fn split(&self) -> Self {
TranscriptProtocolV1(self.0.clone())
}
}
#[cfg(test)]
mod tests {
    // The `*_stable` tests pin exact digests. A failure means the absorbed
    // byte stream changed, which breaks compatibility with existing proofs.
    // Fix the code, never the expected values.
    use super::*;
    use crate::common;
    use crate::id::constants::ArCurve;
    use rand::*;

    /// `extend_from` with an empty label absorbs exactly the same bytes as
    /// putting each item individually.
    #[test]
    pub fn test_extend_from() {
        let mut v1 = vec![0u8; 50];
        let mut csprng = thread_rng();
        for _ in 0..1000 {
            for v in v1.iter_mut() {
                *v = csprng.gen::<u8>();
            }
            let mut s1 = RandomOracle::empty();
            for x in v1.iter() {
                s1.put(x);
            }
            let mut s2 = RandomOracle::empty();
            s2.extend_from(b"", v1.iter());
            let res1 = s1.result();
            let ref_res1: &[u8] = res1.as_ref();
            let res2 = s2.result();
            let ref_res2: &[u8] = res2.as_ref();
            assert_eq!(ref_res1, ref_res2);
        }
    }

    /// A split copy continues from the same state as the original.
    #[test]
    pub fn test_split() {
        let mut v1 = vec![0u8; 50];
        let mut csprng = thread_rng();
        for _ in 0..1000 {
            let mut s1 = RandomOracle::empty();
            s1.put(&v1);
            let mut s2 = s1.split();
            for v in v1.iter_mut() {
                *v = csprng.gen::<u8>();
                s1.put(v);
            }
            let res1 = s1.result();
            let ref_res1: &[u8] = res1.as_ref();
            s2.add_bytes(&v1);
            let res2 = s2.result();
            let ref_res2: &[u8] = res2.as_ref();
            assert_eq!(ref_res1, ref_res2);
        }
    }

    #[test]
    pub fn test_v0_domain_stable() {
        let ro = RandomOracle::domain("Domain1");
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "b6dbfe8bfbc515d92bcc322b1e98291a45536f81f6eca2411d8dae54766666f1"
        );
    }

    #[test]
    pub fn test_v0_add_bytes_stable() {
        let mut ro = RandomOracle::empty();
        ro.add_bytes([1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "fd1780a6fc9ee0dab26ceb4b3941ab03e66ccd970d1db91612c66df4515b0a0a"
        );
    }

    /// In v0, `append_label` absorbs raw bytes, so it matches `add_bytes`.
    #[test]
    pub fn test_v0_append_label_stable() {
        let mut ro = RandomOracle::empty();
        ro.append_label([1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "fd1780a6fc9ee0dab26ceb4b3941ab03e66ccd970d1db91612c66df4515b0a0a"
        );
    }

    #[test]
    pub fn test_v0_append_message_stable() {
        let mut ro = RandomOracle::empty();
        // Must stay a `Vec`: its serialization is part of the pinned digest.
        ro.append_message("Label1", &vec![1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "3756eec6f9241f9a1cd8b401f54679cf9be2e057365728336221b1871ff666fb"
        );
    }

    /// `append_messages` and `extend_from` absorb the same bytes in v0.
    #[test]
    pub fn test_v0_append_messages_stable() {
        let mut ro = RandomOracle::empty();
        ro.append_messages("Label1", &[1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "6b1addb1c08e887242f5e78127c31c17851f29349c45aa415adce255f95fd292"
        );
        let mut ro = RandomOracle::empty();
        ro.extend_from("Label1", &[1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "6b1addb1c08e887242f5e78127c31c17851f29349c45aa415adce255f95fd292"
        );
    }

    /// v0 ignores the final prover message, so the digest is that of an
    /// empty state.
    #[test]
    pub fn test_v0_append_final_prover_message_stable() {
        let mut ro = RandomOracle::empty();
        ro.append_final_prover_message("Label1", &vec![1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a"
        );
    }

    /// `extract_challenge_scalar` and `challenge_scalar` agree in v0.
    #[test]
    pub fn test_v0_extract_challenge_scalar_stable() {
        let ro = RandomOracle::empty();
        let scalar_hex = hex::encode(common::to_bytes(
            &ro.split().extract_challenge_scalar::<ArCurve>("Scalar1"),
        ));
        assert_eq!(
            scalar_hex,
            "08646777f9c47efc863115861aa18d95653212c3bdf36899c7db46fbdae095cd"
        );
        let scalar_hex = hex::encode(common::to_bytes(
            &ro.split().challenge_scalar::<ArCurve, _>("Scalar1"),
        ));
        assert_eq!(
            scalar_hex,
            "08646777f9c47efc863115861aa18d95653212c3bdf36899c7db46fbdae095cd"
        );
    }

    #[test]
    pub fn test_v0_append_each_message_stable() {
        let mut ro = RandomOracle::empty();
        ro.append_each_message("Label1", &[1u8, 2, 3], |ro, item| {
            ro.append_message("Item", item)
        });
        let challenge_hex = hex::encode(ro.get_challenge());
        assert_eq!(
            challenge_hex,
            "90da7b2dc7bc9091be9201598ef0d8b43f8b00c53454822a2f8ce41c6a3f3d85"
        );
    }

    #[test]
    pub fn test_v1_with_domain_stable() {
        let ro = TranscriptProtocolV1::with_domain("Domain1");
        let challenge_hex = hex::encode(ro.extract_raw_challenge());
        assert_eq!(
            challenge_hex,
            "5691f0658460c461ffe14baa70071545df78725892d0decfe6f6642233a0d8e2"
        );
    }

    #[test]
    pub fn test_v1_append_label_stable() {
        let mut ro = TranscriptProtocolV1::with_domain("Domain1");
        ro.append_label([1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.extract_raw_challenge());
        assert_eq!(
            challenge_hex,
            "683a300a44b3f9165f78dd9fd90efc9a632c11131ef5e805ff3505b5bf0cc7d2"
        );
    }

    #[test]
    pub fn test_v1_append_message_stable() {
        let mut ro = TranscriptProtocolV1::with_domain("Domain1");
        // Must stay a `Vec`: its serialization is part of the pinned digest.
        ro.append_message("Label1", &vec![1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.extract_raw_challenge());
        assert_eq!(
            challenge_hex,
            "5fb23e3d1cfb33d1b2e2da1c070c7a79056b00d13d642ee47fba542d4863a911"
        );
    }

    /// In V1, absorbing the count plus the items matches absorbing the
    /// serialized `Vec` in `test_v1_append_message_stable`.
    #[test]
    pub fn test_v1_append_messages_stable() {
        let mut ro = TranscriptProtocolV1::with_domain("Domain1");
        ro.append_messages("Label1", &[1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.extract_raw_challenge());
        assert_eq!(
            challenge_hex,
            "5fb23e3d1cfb33d1b2e2da1c070c7a79056b00d13d642ee47fba542d4863a911"
        );
    }

    /// In V1 the final prover message is absorbed like `append_message`.
    #[test]
    pub fn test_v1_append_final_prover_message_stable() {
        let mut ro = TranscriptProtocolV1::with_domain("Domain1");
        ro.append_final_prover_message("Label1", &vec![1u8, 2, 3]);
        let challenge_hex = hex::encode(ro.extract_raw_challenge());
        assert_eq!(
            challenge_hex,
            "5fb23e3d1cfb33d1b2e2da1c070c7a79056b00d13d642ee47fba542d4863a911"
        );
    }

    #[test]
    pub fn test_v1_extract_challenge_scalar_stable() {
        let ro = TranscriptProtocolV1::with_domain("Domain1");
        let scalar_hex = hex::encode(common::to_bytes(
            &ro.split().extract_challenge_scalar::<ArCurve>("Scalar1"),
        ));
        assert_eq!(
            scalar_hex,
            "3efcc0fdddcc90a71a022212338ae1c6c7b102fdb9af6befd460d68561856ad9"
        );
    }

    #[test]
    pub fn test_v1_append_each_message_stable() {
        let mut ro = TranscriptProtocolV1::with_domain("Domain1");
        ro.append_each_message("Label1", &[1u8, 2, 3], |ro, item| {
            ro.append_message("Item", item)
        });
        let challenge_hex = hex::encode(ro.extract_raw_challenge());
        assert_eq!(
            challenge_hex,
            "ffd0694d68003afd3751f33bbadd38ae26db78aa4e62ce4d53814b9676d6c7dd"
        );
    }
}