use ark_ec::{CurveGroup, PrimeGroup};
use ark_std::UniformRand;
use rand::rngs::OsRng;
use spongefish::codecs::arkworks_algebra::{
CommonGroupToUnit, DomainSeparator, DuplexSpongeInterface, FieldDomainSeparator,
FieldToUnitDeserialize, FieldToUnitSerialize, GroupDomainSeparator, GroupToUnitDeserialize,
GroupToUnitSerialize, ProofError, ProofResult, ProverState, UnitToField, VerifierState,
};
/// Describes the transcript layout of a Schnorr proof of knowledge of a
/// discrete logarithm over the group `G`.
trait SchnorrDomainSeparator<G>
where
    G: CurveGroup,
{
    /// Starts a new domain separator labelled `domsep` and appends both the
    /// public statement and the protocol messages to it.
    fn new_schnorr_proof(domsep: &str) -> Self;

    /// Appends the public statement (base point and public key).
    fn add_schnorr_statement(self) -> Self;

    /// Appends the protocol messages (commitment, challenge, response).
    fn add_schnorr_domsep(self) -> Self;
}
/// Any spongefish `DomainSeparator` that can describe `G` points and
/// `G::ScalarField` scalars can describe a Schnorr transcript.
impl<G, H> SchnorrDomainSeparator<G> for DomainSeparator<H>
where
    G: CurveGroup,
    H: DuplexSpongeInterface,
    Self: GroupDomainSeparator<G> + FieldDomainSeparator<G::ScalarField>,
{
    fn new_schnorr_proof(domsep: &str) -> Self {
        // The statement is described first, then the interactive messages.
        let fresh = Self::new(domsep);
        let with_statement = <Self as SchnorrDomainSeparator<G>>::add_schnorr_statement(fresh);
        <Self as SchnorrDomainSeparator<G>>::add_schnorr_domsep(with_statement)
    }

    fn add_schnorr_statement(self) -> Self {
        // Two public points, followed by a ratchet before any proof message.
        let with_base = self.add_points(1, "generator (P)");
        let with_key = with_base.add_points(1, "public key (X)");
        with_key.ratchet()
    }

    fn add_schnorr_domsep(self) -> Self {
        // Prover commitment, verifier challenge, prover response — in that order.
        let with_commitment = self.add_points(1, "commitment (K)");
        let with_challenge = with_commitment.challenge_scalars(1, "challenge (c)");
        with_challenge.add_scalars(1, "response (r)")
    }
}
/// Generates a key pair for the group `G`.
///
/// Returns `(sk, pk)`: `sk` is sampled via `UniformRand` from `OsRng`, and
/// `pk = G::generator() * sk`.
fn keygen<G: CurveGroup>() -> (G::ScalarField, G) {
    let mut os_rng = OsRng;
    let secret = G::ScalarField::rand(&mut os_rng);
    (secret, G::generator() * secret)
}
/// Produces a Schnorr proof of knowledge of `x` for the statement `X = P * x`.
///
/// Writes the commitment `K = P * k` for a fresh nonce `k`, squeezes a
/// challenge `c`, and writes the response `r = k + c * x`.
///
/// # Returns
///
/// The proof bytes accumulated in `prover_state` (`narg_string`).
///
/// # Errors
///
/// Propagates any error from writing the commitment, squeezing the challenge,
/// or writing the response.
// Uppercase names such as `P` and `K` follow the protocol's notation for group
// elements, so the snake-case lint is silenced deliberately.
#[allow(non_snake_case)]
fn prove<H, G>(
    prover_state: &mut ProverState<H>,
    P: G,
    x: G::ScalarField,
) -> ProofResult<&[u8]>
where
    H: DuplexSpongeInterface,
    G: CurveGroup,
    ProverState<H>: GroupToUnitSerialize<G> + UnitToField<G::ScalarField>,
{
    // Commitment from a nonce drawn with the prover state's RNG.
    let nonce = G::ScalarField::rand(prover_state.rng());
    let K = P * nonce;
    prover_state.add_points(&[K])?;

    // Challenge squeezed from the transcript.
    let [c] = prover_state.challenge_scalars()?;

    // Response combining the nonce, the challenge and the secret.
    let r = c * x + nonce;
    prover_state.add_scalars(&[r])?;

    Ok(prover_state.narg_string())
}
/// Verifies a Schnorr proof read from `verifier_state` for the statement
/// `X = P * x`.
///
/// Reads the commitment `K`, squeezes the challenge `c`, reads the response
/// `r`, and accepts if and only if `P * r == K + X * c`.
///
/// # Errors
///
/// - Propagates any error from reading the commitment or response, or from
///   squeezing the challenge. This presumably covers truncated or malformed
///   proof bytes; confirm against spongefish.
/// - Returns `ProofError::InvalidProof` when the verification equation fails.
// Uppercase names such as `P`, `X` and `K` follow the protocol's notation for
// group elements, so the snake-case lint is silenced deliberately.
#[allow(non_snake_case)]
fn verify<G, H>(
    verifier_state: &mut VerifierState<H>,
    P: G,
    X: G,
) -> ProofResult<()>
where
    G: CurveGroup,
    H: DuplexSpongeInterface,
    for<'a> VerifierState<'a, H>: GroupToUnitDeserialize<G>
        + FieldToUnitDeserialize<G::ScalarField>
        + UnitToField<G::ScalarField>,
{
    // The proof bytes are untrusted input. Report decoding failures to the
    // caller instead of panicking on them.
    let [K] = verifier_state.next_points()?;
    let [c] = verifier_state.challenge_scalars()?;
    let [r]: [G::ScalarField; 1] = verifier_state.next_scalars()?;

    if P * r == K + X * c {
        Ok(())
    } else {
        Err(ProofError::InvalidProof)
    }
}
/// Runs one Schnorr proof end to end.
///
/// Generates a key pair, proves knowledge of the secret key, prints the proof
/// as hex, and verifies it.
// Uppercase names such as `P` and `X` follow the protocol's notation for group
// elements, so the snake-case lint is silenced deliberately.
#[allow(non_snake_case)]
fn main() {
    type G = ark_curve25519::EdwardsProjective;
    type H = spongefish::duplex_sponge::legacy::DigestBridge<blake2::Blake2s256>;

    // Transcript layout shared by prover and verifier.
    let io: DomainSeparator<H> =
        SchnorrDomainSeparator::<G>::new_schnorr_proof("spongefish::example");

    // `keygen` derives the public key from `G::generator()`. The base point
    // must therefore be that same generator for `X == P * x` to hold.
    let P = G::generator();
    let (x, X) = keygen::<G>();

    // Prover: absorb the exact statement the verifier will absorb, then prove.
    let mut prover_state = io.to_prover_state();
    prover_state.public_points(&[P, X]).unwrap();
    prover_state.ratchet().unwrap();
    let proof = prove(&mut prover_state, P, x).expect("Proof generation failed");
    println!("Here's a Schnorr signature:\n{}", hex::encode(proof));

    // Verifier: absorb the same statement, then check the proof.
    let mut verifier_state = io.to_verifier_state(proof);
    verifier_state.public_points(&[P, X]).unwrap();
    verifier_state.ratchet().unwrap();
    verify(&mut verifier_state, P, X).expect("Invalid proof");
}