use crate::error::FastCryptoError;
use crate::traits::AllowedRng;
/// Secret half of a VRF key pair.
///
/// The associated `PublicKey` names the matching public key type, and its bound forces that
/// type to name `Self` back as its private key, so the two types always come as a pair.
pub trait VRFPrivateKey {
    /// Public key type paired with this private key.
    type PublicKey: VRFPublicKey<PrivateKey = Self>;
}

/// Public half of a VRF key pair.
///
/// Mirror image of `VRFPrivateKey`: `PrivateKey` must be a private key type whose
/// `PublicKey` is `Self`.
pub trait VRFPublicKey {
    /// Private key type paired with this public key.
    type PrivateKey: VRFPrivateKey<PublicKey = Self>;
}
/// A VRF key pair whose proofs hash to `OUTPUT_SIZE`-byte outputs.
///
/// The associated types are tied together: the public and private key types must point at
/// each other, and the proof must be checked against this pair's public key type.
pub trait VRFKeyPair<const OUTPUT_SIZE: usize> {
    /// Public key type of this pair.
    type PublicKey: VRFPublicKey<PrivateKey = Self::PrivateKey>;
    /// Private key type of this pair.
    type PrivateKey: VRFPrivateKey<PublicKey = Self::PublicKey>;
    /// Proof type returned by `prove`.
    type Proof: VRFProof<OUTPUT_SIZE, PublicKey = Self::PublicKey>;

    /// Creates a fresh key pair using randomness drawn from `rng`.
    fn generate<R: AllowedRng>(rng: &mut R) -> Self;

    /// Creates a proof for `input` under this key pair.
    fn prove(&self, input: &[u8]) -> Self::Proof;

    /// Proves `input` and returns the VRF output (the proof's `to_hash`) together with
    /// the proof itself.
    fn output(&self, input: &[u8]) -> ([u8; OUTPUT_SIZE], Self::Proof) {
        let proof = self.prove(input);
        // Left-to-right evaluation: hash while borrowed, then move the proof out.
        (proof.to_hash(), proof)
    }
}
/// A VRF proof whose hash is an `OUTPUT_SIZE`-byte VRF output.
pub trait VRFProof<const OUTPUT_SIZE: usize> {
    /// Public key type this proof is checked against.
    type PublicKey: VRFPublicKey;

    /// Hashes the proof to its `OUTPUT_SIZE`-byte VRF output.
    fn to_hash(&self) -> [u8; OUTPUT_SIZE];

    /// Checks this proof against `input` and `public_key`; `Ok(())` means it is accepted.
    fn verify(&self, input: &[u8], public_key: &Self::PublicKey) -> Result<(), FastCryptoError>;

    /// Verifies the proof, then checks that `output` equals the proof's `to_hash`.
    ///
    /// # Errors
    ///
    /// Returns whatever `verify` returns if the proof is rejected (this is checked first).
    /// Otherwise returns `FastCryptoError::GeneralOpaqueError` when `output` does not match.
    fn verify_output(
        &self,
        input: &[u8],
        public_key: &Self::PublicKey,
        output: &[u8; OUTPUT_SIZE],
    ) -> Result<(), FastCryptoError> {
        self.verify(input, public_key)?;
        if self.to_hash() == *output {
            Ok(())
        } else {
            Err(FastCryptoError::GeneralOpaqueError)
        }
    }
}
pub mod ecvrf {
use crate::error::FastCryptoError;
use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar};
use crate::groups::{GroupElement, MultiScalarMul, Scalar};
use crate::hash::{HashFunction, ReverseWrapper, Sha512};
use crate::serde_helpers::ToFromByteArray;
use crate::traits::AllowedRng;
use crate::vrf::{VRFKeyPair, VRFPrivateKey, VRFProof, VRFPublicKey};
use elliptic_curve::hash2curve::{ExpandMsg, Expander};
use serde::{Deserialize, Serialize};
use zeroize::ZeroizeOnDrop;
/// Suite string prefixed to every hash this module computes (challenge generation and
/// proof-to-hash), separating these hashes from other uses of the same hash function.
const SUITE_STRING: &[u8; 7] = b"sui_vrf";
/// Length in bytes of a challenge `c`.
///
/// A `Challenge` is cut from the front of a hash digest and then zero-extended into a
/// 32-byte scalar buffer, so it must be non-empty and at most 32 bytes.
const C_LEN: usize = 16;
// Compile-time guard: a larger C_LEN would make the 32-byte copy in
// `From<&Challenge> for RistrettoScalar` panic at run time, and a zero-length challenge
// would make every proof's challenge the zero scalar.
const _: () = assert!(
    0 < C_LEN && C_LEN <= 32,
    "C_LEN must be in 1..=32 to fit in a scalar"
);
/// Hash function used by every step of the construction.
type H = Sha512;
/// Domain separation tag passed to `expand_message_xmd` when hashing to the curve.
const DST: &[u8; 49] = b"ECVRF_ristretto255_XMD:SHA-512_R255MAP_RO_sui_vrf";
/// Public key of the ECVRF construction: a Ristretto point (`generator · x` for key pairs
/// built from a private key `x` via `ECVRFKeyPair::from`).
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ECVRFPublicKey(RistrettoPoint);

impl VRFPublicKey for ECVRFPublicKey {
    type PrivateKey = ECVRFPrivateKey;
}
impl ECVRFPublicKey {
    /// Hashes `alpha_string` to a Ristretto point bound to this public key
    /// (ECVRF encode-to-curve, RFC 9381 section 5.4.1.2).
    ///
    /// The message `compress(pk) || alpha_string` is expanded with `expand_message_xmd`
    /// (hash `H`, domain separation tag `DST`) to `H::OUTPUT_SIZE` bytes. Those bytes are
    /// mapped into the group with `RistrettoPoint::from_uniform_bytes`.
    ///
    /// # Panics
    ///
    /// Only if `expand_message` rejects its parameters. The DST and the output length are
    /// fixed constants, so a panic here indicates a programming error, not bad input.
    fn ecvrf_encode_to_curve(&self, alpha_string: &[u8]) -> RistrettoPoint {
        let mut expanded_message = elliptic_curve::hash2curve::ExpandMsgXmd::<
            <H as ReverseWrapper>::Variant,
        >::expand_message(
            &[&self.0.compress(), alpha_string],
            &[DST],
            H::OUTPUT_SIZE,
        )
        // expand_message_xmd only aborts on parameter limits (output length, DST), both
        // fixed here; the message contents never cause a failure.
        .expect("expand_message_xmd rejected the fixed DST / output length");
        let mut bytes = [0u8; H::OUTPUT_SIZE];
        expanded_message.fill_bytes(&mut bytes);
        RistrettoPoint::from_uniform_bytes(&bytes)
    }

    /// Public-key validation: returns `false` for the identity point.
    ///
    /// `ECVRFProof::verify` rejects any public key for which this returns `false`.
    fn valid(&self) -> bool {
        self.0 != RistrettoPoint::zero()
    }
}
/// Secret key of the ECVRF construction: a Ristretto scalar `x`.
///
/// `ZeroizeOnDrop` wipes the scalar when the key is dropped.
// NOTE(review): `Debug` and `PartialEq` are derived and delegate to `RistrettoScalar`;
// confirm that neither prints the secret nor compares it in variable time.
#[derive(ZeroizeOnDrop, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ECVRFPrivateKey(RistrettoScalar);

impl VRFPrivateKey for ECVRFPrivateKey {
    type PublicKey = ECVRFPublicKey;
}
impl ECVRFPrivateKey {
    /// Deterministically derives the proof nonce `k` from this secret key and `h_string`.
    ///
    /// This is the RFC 8032-style derivation used by the ECVRF spec (RFC 9381 section
    /// 5.4.2.2), in four steps:
    /// 1. hash the byte encoding of the secret scalar with `H`;
    /// 2. keep bytes `32..64` of that digest;
    /// 3. hash those 32 bytes followed by `h_string`;
    /// 4. reduce the resulting wide digest modulo the group order.
    fn ecvrf_nonce_generation(&self, h_string: &[u8]) -> RistrettoScalar {
        let sk_digest = H::digest(self.0.to_byte_array());
        let mut nonce_seed = [0u8; 32];
        nonce_seed.copy_from_slice(&sk_digest.digest[32..64]);

        let mut hasher = H::default();
        hasher.update(nonce_seed);
        hasher.update(h_string);
        RistrettoScalar::from_bytes_mod_order_wide(&hasher.finalize().digest)
    }
}
/// An ECVRF key pair.
///
/// `pk` is expected to equal `generator · sk`, which is what `ECVRFKeyPair::from` builds.
/// Both fields are public, so that pairing is not enforced by the type.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ECVRFKeyPair {
    // Field order is part of the serde serialized form; do not reorder.
    /// Public half, used for hashing to the curve and for verification.
    pub pk: ECVRFPublicKey,
    /// Secret half, used for proving.
    pub sk: ECVRFPrivateKey,
}
/// Hashes five points into a challenge (ECVRF challenge generation, RFC 9381 section 5.4.3).
///
/// The hash input is `SUITE_STRING || 0x02 || compress(p1) || ... || compress(p5) || 0x00`.
/// The challenge is the first `C_LEN` bytes of the digest.
fn ecvrf_challenge_generation(points: [&RistrettoPoint; 5]) -> Challenge {
    const FRONT_DOMAIN_SEPARATOR: u8 = 0x02;
    const BACK_DOMAIN_SEPARATOR: u8 = 0x00;

    let mut hasher = H::default();
    hasher.update(SUITE_STRING);
    hasher.update([FRONT_DOMAIN_SEPARATOR]);
    for point in points {
        hasher.update(point.compress());
    }
    hasher.update([BACK_DOMAIN_SEPARATOR]);

    let digest = hasher.finalize();
    let mut truncated = [0u8; C_LEN];
    truncated.copy_from_slice(&digest.digest[..C_LEN]);
    Challenge(truncated)
}
/// A challenge: the first `C_LEN` bytes of the challenge-generation hash.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct Challenge([u8; C_LEN]);

impl From<&Challenge> for RistrettoScalar {
    /// Copies the challenge bytes into the start of a zeroed 32-byte buffer and reduces that
    /// buffer modulo the group order.
    fn from(challenge: &Challenge) -> Self {
        let mut wide = [0u8; 32];
        wide[..C_LEN].copy_from_slice(&challenge.0);
        Self::from_bytes_mod_order(&wide)
    }
}
impl VRFKeyPair<64> for ECVRFKeyPair {
    type Proof = ECVRFProof;
    type PrivateKey = ECVRFPrivateKey;
    type PublicKey = ECVRFPublicKey;

    /// Draws a random secret scalar from `rng` and derives the matching public key.
    fn generate<R: AllowedRng>(rng: &mut R) -> Self {
        Self::from(ECVRFPrivateKey(RistrettoScalar::rand(rng)))
    }

    /// Produces an ECVRF proof for `alpha_string` (RFC 9381 section 5.1).
    ///
    /// The inputs are:
    /// - secret `x`;
    /// - `H = encode_to_curve(pk, alpha_string)`;
    /// - nonce `k`, derived from `x` and `compress(H)`.
    ///
    /// The proof is `(Gamma, c, s)` with:
    /// - `Gamma = x·H`;
    /// - `c = challenge(pk, H, Gamma, k·G, k·H)`;
    /// - `s = k + c·x`.
    fn prove(&self, alpha_string: &[u8]) -> ECVRFProof {
        let secret = self.sk.0;
        let h = self.pk.ecvrf_encode_to_curve(alpha_string);
        let gamma = h * secret;
        let k = self.sk.ecvrf_nonce_generation(&h.compress());

        let k_times_g = RistrettoPoint::generator() * k;
        let k_times_h = h * k;
        let c = ecvrf_challenge_generation([&self.pk.0, &h, &gamma, &k_times_g, &k_times_h]);

        let s = k + RistrettoScalar::from(&c) * secret;
        ECVRFProof { gamma, c, s }
    }
}
impl From<ECVRFPrivateKey> for ECVRFKeyPair {
    /// Builds a key pair by deriving the public key `generator · sk` from `sk`.
    fn from(sk: ECVRFPrivateKey) -> Self {
        let pk = ECVRFPublicKey(RistrettoPoint::generator() * sk.0);
        Self { pk, sk }
    }
}
/// ECVRF proof `(Gamma, c, s)`, as produced by `ECVRFKeyPair::prove`.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ECVRFProof {
    // Field order is part of the serde serialized form; do not reorder.
    /// `Gamma = x·H`; the VRF output is the hash of this point.
    gamma: RistrettoPoint,
    /// Truncated challenge hash.
    c: Challenge,
    /// Response scalar `s = k + c·x`.
    s: RistrettoScalar,
}
impl VRFProof<64> for ECVRFProof {
    type PublicKey = ECVRFPublicKey;

    /// Checks this proof against `alpha_string` and `public_key` (RFC 9381 section 5.3).
    ///
    /// # Errors
    ///
    /// * `FastCryptoError::InvalidInput` if `public_key` is the identity point.
    /// * Any error returned by `RistrettoPoint::multi_scalar_mul`.
    /// * `FastCryptoError::GeneralOpaqueError` if the recomputed challenge differs from `c`.
    fn verify(
        &self,
        alpha_string: &[u8],
        public_key: &Self::PublicKey,
    ) -> Result<(), FastCryptoError> {
        if !public_key.valid() {
            return Err(FastCryptoError::InvalidInput);
        }

        let h = public_key.ecvrf_encode_to_curve(alpha_string);
        let neg_c = -RistrettoScalar::from(&self.c);

        // U = s·G - c·Y and V = s·H - c·Gamma. For an honest proof (s = k + c·x,
        // Y = x·G, Gamma = x·H) these reduce to k·G and k·H.
        let u = RistrettoPoint::multi_scalar_mul(
            &[self.s, neg_c],
            &[RistrettoPoint::generator(), public_key.0],
        )?;
        let v = RistrettoPoint::multi_scalar_mul(&[self.s, neg_c], &[h, self.gamma])?;

        let recomputed = ecvrf_challenge_generation([&public_key.0, &h, &self.gamma, &u, &v]);
        if recomputed == self.c {
            Ok(())
        } else {
            Err(FastCryptoError::GeneralOpaqueError)
        }
    }

    /// Hashes `Gamma` into the 64-byte VRF output (RFC 9381 section 5.2):
    /// `H(SUITE_STRING || 0x03 || compress(Gamma) || 0x00)`.
    fn to_hash(&self) -> [u8; 64] {
        const FRONT_DOMAIN_SEPARATOR: u8 = 0x03;
        const BACK_DOMAIN_SEPARATOR: u8 = 0x00;

        let mut hasher = H::default();
        hasher.update(SUITE_STRING);
        hasher.update([FRONT_DOMAIN_SEPARATOR]);
        hasher.update(self.gamma.compress());
        hasher.update([BACK_DOMAIN_SEPARATOR]);
        hasher.finalize().digest
    }
}
}