use ark_ec::CurveGroup;
use ark_ff::{Field, Fp, FpConfig};
use ark_serialize::CanonicalSerialize;
use rand::{CryptoRng, RngCore};
use super::{CommonFieldToUnit, CommonGroupToUnit, FieldToUnitSerialize, GroupToUnitSerialize};
use crate::{
BytesToUnitDeserialize, BytesToUnitSerialize, CommonUnitToBytes, DomainSeparatorMismatch,
DuplexSpongeInterface, ProofResult, ProverState, Unit, UnitTranscript, VerifierState,
};
/// Prover-side scalar messages for transcripts whose unit type is `u8`.
impl<F: Field, H: DuplexSpongeInterface, R: RngCore + CryptoRng> FieldToUnitSerialize<F>
    for ProverState<H, u8, R>
{
    /// Adds the field elements in `input` to the protocol transcript.
    ///
    /// The scalars are first handed to `public_scalars`; the representation it
    /// returns is then appended to `narg_string`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `public_scalars`. In that case nothing
    /// is appended to `narg_string`.
    fn add_scalars(&mut self, input: &[F]) -> ProofResult<()> {
        // Bail out before touching `narg_string` if the public step fails.
        let encoded = self.public_scalars(input)?;
        self.narg_string.extend(encoded);
        Ok(())
    }
}
/// Prover-side scalar messages for transcripts whose unit type is the prime
/// field `Fp<C, N>` itself.
impl<C, H, R, const N: usize> FieldToUnitSerialize<Fp<C, N>> for ProverState<H, Fp<C, N>, R>
where
    C: FpConfig<N>,
    H: DuplexSpongeInterface<Fp<C, N>>,
    R: RngCore + CryptoRng,
{
    /// Adds the field elements in `input` to the protocol transcript.
    ///
    /// The scalars already have the transcript's unit type, so they are passed
    /// directly to `public_units`. Afterwards each element is written to
    /// `narg_string` using its compressed canonical serialization.
    ///
    /// # Errors
    ///
    /// Propagates an error from `public_units`, or the first serialization
    /// error encountered while writing to `narg_string`.
    fn add_scalars(&mut self, input: &[Fp<C, N>]) -> ProofResult<()> {
        self.public_units(input)?;
        // Stops at the first element that fails to serialize.
        input
            .iter()
            .try_for_each(|scalar| scalar.serialize_compressed(&mut self.narg_string))?;
        Ok(())
    }
}
/// Prover-side group-element messages for transcripts whose unit type is `u8`.
impl<G, H, R> GroupToUnitSerialize<G> for ProverState<H, u8, R>
where
    G: CurveGroup,
    H: DuplexSpongeInterface,
    R: RngCore + CryptoRng,
    Self: CommonGroupToUnit<G, Repr = Vec<u8>>,
{
    /// Adds the group elements in `input` to the protocol transcript.
    ///
    /// The points are first handed to `public_points`; the byte vector it
    /// returns (`Repr = Vec<u8>`) is then appended to `narg_string`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `public_points`. In that case nothing
    /// is appended to `narg_string`.
    fn add_points(&mut self, input: &[G]) -> ProofResult<()> {
        let encoded: Vec<u8> = self.public_points(input)?;
        self.narg_string.extend(encoded);
        Ok(())
    }
}
/// Prover-side group-element messages for transcripts whose unit type is the
/// prime field `Fp<C, N>`, for curves whose base field is `Fp<C2, N>`.
impl<G, H, R, C, C2, const N: usize> GroupToUnitSerialize<G> for ProverState<H, Fp<C, N>, R>
where
    C: FpConfig<N>,
    C2: FpConfig<N>,
    G: CurveGroup<BaseField = Fp<C2, N>>,
    H: DuplexSpongeInterface<Fp<C, N>>,
    R: RngCore + CryptoRng,
    Self: CommonGroupToUnit<G> + FieldToUnitSerialize<G::BaseField>,
{
    /// Adds the group elements in `input` to the protocol transcript.
    ///
    /// The points are first handed to `public_points`, whose returned
    /// representation is not needed here. Afterwards each point is written to
    /// `narg_string` using its compressed canonical serialization.
    ///
    /// # Errors
    ///
    /// Propagates an error from `public_points`, or the first serialization
    /// error encountered while writing to `narg_string`.
    fn add_points(&mut self, input: &[G]) -> ProofResult<()> {
        // Only success or failure matters; the returned representation is dropped.
        let _ = self.public_points(input)?;
        input
            .iter()
            .try_for_each(|point| point.serialize_compressed(&mut self.narg_string))?;
        Ok(())
    }
}
/// Prover-side raw byte messages for transcripts whose unit type is the prime
/// field `Fp<C, N>`.
impl<H, R, C, const N: usize> BytesToUnitSerialize for ProverState<H, Fp<C, N>, R>
where
    H: DuplexSpongeInterface<Fp<C, N>>,
    C: FpConfig<N>,
    R: RngCore + CryptoRng,
{
    /// Adds the bytes in `input` to the protocol transcript.
    ///
    /// The bytes are first handed to `public_bytes`; on success they are
    /// appended verbatim to `narg_string`.
    ///
    /// # Errors
    ///
    /// Returns the `DomainSeparatorMismatch` reported by `public_bytes`. In
    /// that case nothing is appended to `narg_string`.
    fn add_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch> {
        self.public_bytes(input)
            .map(|()| self.narg_string.extend(input))
    }
}
impl<H, C, const N: usize> BytesToUnitDeserialize for VerifierState<'_, H, Fp<C, N>>
where
H: DuplexSpongeInterface<Fp<C, N>>,
C: FpConfig<N>,
{
fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), DomainSeparatorMismatch> {
u8::read(&mut self.narg_string, input)?;
self.public_bytes(input)
}
}
#[cfg(test)]
mod tests {
    use ark_bls12_381::Fr;
    use ark_curve25519::EdwardsProjective;
    use ark_ec::PrimeGroup;
    use ark_ff::{Fp64, MontBackend, MontConfig, UniformRand};

    use super::*;
    use crate::{
        codecs::arkworks_algebra::{
            FieldDomainSeparator, FieldToUnitSerialize, GroupDomainSeparator,
        },
        ByteDomainSeparator, DefaultHash, DomainSeparator,
    };

    /// Curve group used by the point tests.
    type G = EdwardsProjective;

    // Montgomery configuration for the 31-bit prime 2013265921 = 2^31 - 2^27 + 1.
    #[derive(MontConfig)]
    #[modulus = "2013265921"]
    #[generator = "21"]
    pub struct BabybearConfig;

    /// Single-limb prime field defined by [`BabybearConfig`].
    pub type BabyBear = Fp64<MontBackend<BabybearConfig, 1>>;

    /// Adding three `BabyBear` scalars must append exactly their compressed
    /// serializations, and doing so twice from the same domain separator must
    /// yield the same proof string.
    #[test]
    fn test_add_scalars() {
        let domsep = <DomainSeparator as FieldDomainSeparator<BabyBear>>::add_scalars(
            DomainSeparator::new("test"),
            3,
            "com",
        );
        let mut first_prover = domsep.to_prover_state();

        let mut rng = ark_std::test_rng();
        let scalars = [
            BabyBear::rand(&mut rng),
            BabyBear::rand(&mut rng),
            BabyBear::rand(&mut rng),
        ];
        first_prover.add_scalars(&scalars).unwrap();

        let mut expected_bytes = Vec::new();
        for scalar in &scalars {
            scalar.serialize_compressed(&mut expected_bytes).unwrap();
        }
        assert_eq!(
            first_prover.narg_string, expected_bytes,
            "Transcript serialization mismatch"
        );

        // A fresh prover fed the same inputs must produce the same bytes.
        let mut second_prover = domsep.to_prover_state();
        second_prover.add_scalars(&scalars).unwrap();
        assert_eq!(
            first_prover.narg_string, second_prover.narg_string,
            "Transcript encoding should be deterministic for same inputs"
        );
    }

    /// Adding two `Fr` scalars must append exactly their compressed
    /// serializations to the proof string.
    #[test]
    fn test_add_scalars_u8_unit() {
        let domsep = <DomainSeparator as FieldDomainSeparator<Fr>>::add_scalars(
            DomainSeparator::new("test-add-scalars-u8"),
            2,
            "com",
        );
        let mut prover = domsep.to_prover_state();

        let scalars = [Fr::from(5u64), Fr::from(42u64)];
        prover.add_scalars(&scalars).unwrap();

        let mut expected = Vec::new();
        for scalar in &scalars {
            scalar.serialize_compressed(&mut expected).unwrap();
        }
        assert_eq!(prover.narg_string, expected);
    }

    /// Adding a point must leave a non-empty proof string.
    #[test]
    fn test_add_points_u8_unit() {
        let domsep = <DomainSeparator as GroupDomainSeparator<G>>::add_points(
            DomainSeparator::new("curve25519"),
            1,
            "pt",
        );
        let mut prover = domsep.to_prover_state();

        let generator = G::generator();
        prover.add_points(&[generator]).unwrap();

        assert!(!prover.narg_string.is_empty());
    }

    /// Adding a point must append exactly its compressed serialization.
    ///
    /// NOTE(review): the domain separator is built the same way as in
    /// `test_add_points_u8_unit`, so this presumably does not use an `Fp` unit
    /// despite the name; confirm against `DomainSeparator`'s default parameters.
    #[test]
    fn test_add_points_fp_unit() {
        let domsep = <DomainSeparator as GroupDomainSeparator<G>>::add_points(
            DomainSeparator::new("curve-bb"),
            1,
            "pt",
        );
        let mut prover = domsep.to_prover_state();

        let generator = G::generator();
        prover.add_points(&[generator]).unwrap();

        let mut expected = Vec::new();
        generator.serialize_compressed(&mut expected).unwrap();
        assert_eq!(prover.narg_string, expected);
    }

    /// Adding raw bytes must append them verbatim to the proof string.
    ///
    /// NOTE(review): despite the `fp_unit` name, the domain separator here is
    /// explicitly typed with `u8` units.
    #[test]
    fn test_add_bytes_fp_unit() {
        let payload = b"hello world!";
        let domsep =
            <DomainSeparator<DefaultHash, u8>>::new("test-add-bytes-fp").add_bytes(12, "com");
        let mut prover = domsep.to_prover_state();

        prover.add_bytes(payload).unwrap();

        assert_eq!(prover.narg_string, payload);
    }

    /// Bytes added by the prover must be read back unchanged by the verifier.
    ///
    /// NOTE(review): despite the `fp_unit` name, the domain separator here is
    /// explicitly typed with `u8` units.
    #[test]
    fn test_fill_next_bytes_fp_unit() {
        let message = b"secret-msg";
        let domsep =
            <DomainSeparator<DefaultHash, u8>>::new("read-bytes").add_bytes(message.len(), "msg");
        let mut prover = domsep.to_prover_state();
        prover.add_bytes(message).unwrap();

        let mut verifier = domsep.to_verifier_state(&prover.narg_string);
        let mut received = [0u8; 10];
        verifier.fill_next_bytes(&mut received).unwrap();

        assert_eq!(received, *message);
    }
}