use alloc::vec::Vec;
use core::fmt;
use rand::{CryptoRng, Rng, RngCore, SeedableRng};
use crate::{Decoding, DuplexSpongeInterface, Encoding, NargSerialize, StdHash};
type StdRng = rand::rngs::StdRng;
/// Prover-side state of a Fiat–Shamir transcript.
///
/// Holds a private RNG for prover-only randomness, the duplex sponge that absorbs
/// public/prover messages and squeezes verifier messages, and the NARG string built
/// from serialized prover messages.
pub struct ProverState<H = StdHash, R = StdRng>
where
    H: DuplexSpongeInterface,
    R: RngCore + CryptoRng,
{
    /// Prover-private randomness. It has its own sponge, separate from `duplex_sponge_state`.
    pub(crate) private_rng: ReseedableRng<R>,
    /// Transcript sponge (publicly visible when the `yolocrypto` feature is enabled).
    #[cfg(feature = "yolocrypto")]
    pub duplex_sponge_state: H,
    /// Transcript sponge (crate-visible only without the `yolocrypto` feature).
    #[cfg(not(feature = "yolocrypto"))]
    pub(crate) duplex_sponge_state: H,
    /// Bytes appended by `prover_message` through `NargSerialize::serialize_into_narg`.
    pub(crate) narg_string: Vec<u8>,
}
/// A reseedable RNG for prover-private randomness.
///
/// All output bytes are squeezed from `duplex_sponge`. The inner `csrng` is read only to
/// seed that sponge: once on construction via `From<R>`, and again on each call to
/// [`ReseedableRng::reseed`].
pub struct ReseedableRng<R: RngCore + CryptoRng> {
    pub(crate) duplex_sponge: StdHash,
    pub(crate) csrng: R,
}

/// Builds a [`ReseedableRng`] from `R::default()`, seeding the sponge from it.
///
/// The previous `#[derive(Default)]` left `duplex_sponge` as a bare `StdHash::default()`
/// with no seed absorbed. Since output is squeezed only from that sponge, a derived default
/// would produce bytes that do not depend on `R` at all. Routing through `From<R>` absorbs
/// a 32-byte seed drawn from the inner generator first. The trait bounds are the same as
/// the derived impl's, so existing callers keep compiling.
impl<R: RngCore + CryptoRng + Default> Default for ReseedableRng<R> {
    fn default() -> Self {
        Self::from(R::default())
    }
}
impl<R> From<R> for ReseedableRng<R>
where
    R: RngCore + CryptoRng,
{
    /// Wraps `csrng`, drawing a 32-byte seed from it and absorbing that seed into a
    /// fresh sponge.
    fn from(csrng: R) -> Self {
        let mut source = csrng;
        let mut sponge = StdHash::default();
        let initial_seed = source.gen::<[u8; 32]>();
        sponge.absorb(&initial_seed);
        Self {
            duplex_sponge: sponge,
            csrng: source,
        }
    }
}
impl ReseedableRng<StdRng> {
    /// Creates a reseedable RNG backed by a `StdRng` obtained from
    /// `SeedableRng::from_entropy`.
    pub fn new() -> Self {
        Self::from(StdRng::from_entropy())
    }
}
impl<R> RngCore for ReseedableRng<R>
where
    R: RngCore + CryptoRng,
{
    /// Squeezes four bytes and interprets them as a little-endian `u32`.
    fn next_u32(&mut self) -> u32 {
        let mut word = [0u8; 4];
        self.fill_bytes(&mut word);
        u32::from_le_bytes(word)
    }

    /// Squeezes eight bytes and interprets them as a little-endian `u64`.
    fn next_u64(&mut self) -> u64 {
        let mut word = [0u8; 8];
        self.fill_bytes(&mut word);
        u64::from_le_bytes(word)
    }

    /// Fills `dest` by squeezing the internal sponge.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        self.duplex_sponge.squeeze(dest);
    }

    /// Infallible: squeezing the sponge cannot fail, so this always returns `Ok(())`.
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}
impl<R> ReseedableRng<R>
where
    R: RngCore + CryptoRng,
{
    /// Mixes `value` into the sponge, ratcheting both before and after absorbing it.
    pub fn reseed_with(&mut self, value: &[u8]) {
        let sponge = &mut self.duplex_sponge;
        sponge.ratchet();
        sponge.absorb(value);
        sponge.ratchet();
    }

    /// Draws 32 fresh bytes from the inner CSRNG and mixes them in via `reseed_with`.
    pub fn reseed(&mut self) {
        let fresh: [u8; 32] = self.csrng.gen();
        self.reseed_with(&fresh);
    }
}
/// Marker impl: the sponge is seeded from an `R: CryptoRng`.
impl<R> CryptoRng for ReseedableRng<R> where R: RngCore + CryptoRng {}
impl<H, R> fmt::Debug for ProverState<H, R>
where
    H: DuplexSpongeInterface,
    R: RngCore + CryptoRng,
{
    /// Prints `ProverState<SpongeTypeName>`. The sponge, RNG, and NARG contents are
    /// deliberately not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sponge_name = core::any::type_name::<H>();
        f.write_str("ProverState<")?;
        f.write_str(sponge_name)?;
        f.write_str(">")
    }
}
impl<H, R> ProverState<H, R>
where
    H: DuplexSpongeInterface,
    R: RngCore + CryptoRng,
{
    /// Mutable access to the prover-private RNG.
    ///
    /// This RNG has its own sponge; drawing from it does not touch the transcript sponge.
    pub const fn rng(&mut self) -> &mut ReseedableRng<R> {
        &mut self.private_rng
    }

    /// The serialized prover messages accumulated so far.
    #[inline]
    pub const fn narg_string(&self) -> &[u8] {
        self.narg_string.as_slice()
    }

    /// Absorbs the encoding of `message` into the transcript sponge without adding
    /// anything to the NARG string.
    pub fn public_message<T: Encoding<[H::U]> + ?Sized>(&mut self, message: &T) {
        let encoded = message.encode();
        self.duplex_sponge_state.absorb(encoded.as_ref());
    }

    /// Absorbs `message` like `public_message`, then appends its serialization to the
    /// NARG string.
    pub fn prover_message<T: Encoding<[H::U]> + NargSerialize + ?Sized>(&mut self, message: &T) {
        self.public_message(message);
        message.serialize_into_narg(&mut self.narg_string);
    }

    /// Squeezes a `T::Repr` from the transcript sponge and decodes it into a `T`.
    pub fn verifier_message<T: Decoding<[H::U]>>(&mut self) -> T {
        let mut repr: T::Repr = Default::default();
        self.duplex_sponge_state.squeeze(repr.as_mut());
        T::decode(repr)
    }

    /// Deprecated alias of [`ProverState::narg_string`].
    #[deprecated(note = "Please use ProverState::narg_string instead.")]
    #[inline]
    pub const fn transcript(&self) -> &[u8] {
        self.narg_string()
    }

    /// Deprecated alias of [`ProverState::verifier_message`].
    #[deprecated(note = "Please use ProverState::verifier_message instead.")]
    pub fn challenge<T: Decoding<[H::U]>>(&mut self) -> T {
        self.verifier_message()
    }

    /// Calls `public_message` on each element of `messages`, in order.
    pub fn public_messages<T: Encoding<[H::U]>>(&mut self, messages: &[T]) {
        messages.iter().for_each(|m| self.public_message(m));
    }

    /// Calls `public_message` on each item yielded by `messages`, in order.
    pub fn public_messages_iter<J>(&mut self, messages: J)
    where
        J: IntoIterator,
        J::Item: Encoding<[H::U]>,
    {
        for item in messages {
            self.public_message(&item);
        }
    }

    /// Calls `prover_message` on each element of `messages`, in order.
    pub fn prover_messages<T: Encoding<[H::U]> + NargSerialize>(&mut self, messages: &[T]) {
        messages.iter().for_each(|m| self.prover_message(m));
    }

    /// Calls `prover_message` on each item yielded by `messages`, in order.
    pub fn prover_messages_iter<J>(&mut self, messages: J)
    where
        J: IntoIterator,
        J::Item: Encoding<[H::U]> + NargSerialize,
    {
        for item in messages {
            self.prover_message(&item);
        }
    }

    /// Squeezes `N` consecutive verifier messages into an array.
    pub fn verifier_messages<T: Decoding<[H::U]>, const N: usize>(&mut self) -> [T; N] {
        core::array::from_fn(|_index| self.verifier_message())
    }

    /// Squeezes `len` consecutive verifier messages into a vector.
    pub fn verifier_messages_vec<T: Decoding<[H::U]>>(&mut self, len: usize) -> Vec<T> {
        core::iter::repeat_with(|| self.verifier_message())
            .take(len)
            .collect()
    }
}
#[cfg(feature = "yolocrypto")]
impl<H, R> Default for ProverState<H, R>
where
    H: DuplexSpongeInterface + Default,
    R: RngCore + CryptoRng + SeedableRng,
{
    /// A prover state over a default-constructed sponge.
    ///
    /// The private RNG comes from `R::from_entropy()` and the NARG string starts empty.
    fn default() -> Self {
        Self::from(H::default())
    }
}
impl<H: DuplexSpongeInterface, R: RngCore + CryptoRng + SeedableRng> From<H> for ProverState<H, R> {
fn from(value: H) -> Self {
Self {
duplex_sponge_state: value,
private_rng: R::from_entropy().into(),
narg_string: Vec::new(),
}
}
}