use rand::Rng;
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
use serde::Serialize;
use vrf::openssl::{CipherSuite, ECVRF};
use vrf::VRF;
use crate::error::{Error, Result};
use crate::types::Role;
/// Inputs bound into a VRF proof by `VrfClient::prove` and checked by `VrfClient::verify`.
///
/// Both methods serialize this struct with `bincode` and use the bytes as the VRF message.
/// Field order and field types are therefore part of the proved input: reordering,
/// retyping or adding fields changes every proof and hash.
/// NOTE(review): the per-field meanings below are inferred from the names — confirm.
#[derive(Debug, Serialize)]
pub struct VrfParams {
/// Participant weight (presumably stake) — TODO confirm.
pub weight: u64,
/// Presumably the round number this proof is for — TODO confirm.
pub round: u64,
/// Presumably a per-round seed mixed into the VRF input — TODO confirm.
pub seed: u64,
/// Role being selected for (`crate::types::Role`).
pub role: Role,
}
/// Output of `VrfClient::prove`: the VRF proof bytes together with the hash derived from them.
///
/// `VrfClient::verify` recomputes the hash from `proof` and compares it to `hash`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VrfProof {
    /// Proof bytes as returned by the underlying `ECVRF::prove`.
    pub proof: Vec<u8>,
    /// VRF output computed by `ECVRF::proof_to_hash` over `proof`.
    pub hash: Vec<u8>,
}
pub struct VrfClient {
inner: ECVRF,
}
impl VrfClient {
pub fn new() -> Result<Self> {
let inner = ECVRF::from_suite(CipherSuite::SECP256K1_SHA256_TAI)
.map_err(|e| Error::Vrf(format!("{:?}", e)))?;
Ok(Self { inner })
}
pub fn generate_keys(&mut self, seed: u64) -> Result<(Vec<u8>, Vec<u8>)> {
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let secret_key: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
let public_key = self
.inner
.derive_public_key(&secret_key)
.map_err(|e| Error::Vrf(format!("{:?}", e)))?;
Ok((secret_key, public_key))
}
pub fn prove(&mut self, secret_key: &[u8], params: &VrfParams) -> Result<VrfProof> {
let data = bincode::serialize(params).map_err(|e| Error::Vrf(e.to_string()))?;
let proof = self
.inner
.prove(secret_key, &data)
.map_err(|e| Error::Vrf(format!("{:?}", e)))?;
let hash = self
.inner
.proof_to_hash(&proof)
.map_err(|e| Error::Vrf(format!("{:?}", e)))?;
Ok(VrfProof { proof, hash })
}
pub fn verify(
&mut self,
public_key: &[u8],
vrf_proof: &VrfProof,
params: &VrfParams,
) -> Result<bool> {
let data = bincode::serialize(params).map_err(|e| Error::Vrf(e.to_string()))?;
let beta = self
.inner
.verify(public_key, &vrf_proof.proof, &data)
.map_err(|e| Error::Vrf(format!("{:?}", e)))?;
Ok(beta == vrf_proof.hash)
}
}
/// Natural log of the binomial probability mass `P[X = k]` for `X ~ Binomial(n, p)`.
///
/// Degenerate cases are returned exactly instead of going through `ln(0)`:
/// - `k > n` is outside the support and returns `-inf`.
/// - `p <= 0` puts all mass on `k == 0`, and `p >= 1` puts all mass on `k == n`.
/// - `k == 0` and `k == n` skip the binomial coefficient, since `ln C(n, k) = 0` there.
///
/// In the general case, `ln C(n, k)` is computed with log-gamma so that large `n` does not
/// overflow the way explicit factorials would. The caller exponentiates the result.
fn log_binom_pmf(k: u64, n: u64, p: f64) -> f64 {
    // Outside the support. This also keeps `n - k` below from underflowing,
    // which would panic in debug builds and wrap in release builds.
    if k > n {
        return f64::NEG_INFINITY;
    }
    if p <= 0.0 {
        return if k == 0 { 0.0 } else { f64::NEG_INFINITY };
    }
    if p >= 1.0 {
        return if k == n { 0.0 } else { f64::NEG_INFINITY };
    }
    if k == 0 {
        return (n as f64) * (1.0 - p).ln();
    }
    if k == n {
        return (n as f64) * p.ln();
    }
    // ln C(n, k) = lnΓ(n + 1) - lnΓ(k + 1) - lnΓ(n - k + 1); the sign outputs are
    // irrelevant because every argument is >= 1.
    let (lgn1, _) = libm::lgamma_r((n + 1) as f64);
    let (lgk1, _) = libm::lgamma_r((k + 1) as f64);
    let (lgnk1, _) = libm::lgamma_r((n - k + 1) as f64);
    lgn1 - lgk1 - lgnk1 + (k as f64) * p.ln() + ((n - k) as f64) * (1.0 - p).ln()
}
/// Binomial sortition: maps a VRF hash to how many of `weight` units are selected.
///
/// The count is modelled as `Binomial(weight, expected / total)`. It is drawn by inverse
/// CDF sampling, as follows:
/// - Read up to the first 8 bytes of `vrf_hash` as a little-endian `u64`. A hash shorter
///   than 8 bytes is zero-padded in the high bytes.
/// - Scale that value into `[0, 1]` by dividing by `u64::MAX`.
/// - Return the smallest `k` whose cumulative probability is strictly greater than the
///   scaled value.
/// - If no such `k` is found, for example when the scaled value is exactly `1.0`, return
///   `weight`.
///
/// Returns `0` immediately if `weight`, `total` or `expected` is zero. When
/// `expected >= total` the probability is at least 1 and every unit is selected.
pub fn sortition(vrf_hash: &[u8], weight: u64, expected: u64, total: u64) -> u64 {
    if [weight, total, expected].contains(&0) {
        return 0;
    }

    // Little-endian assembly of the (at most) 8-byte prefix.
    let raw = vrf_hash
        .iter()
        .take(8)
        .enumerate()
        .fold(0u64, |acc, (i, &byte)| acc | (u64::from(byte) << (8 * i)));
    let draw = raw as f64 / u64::MAX as f64;
    let prob = expected as f64 / total as f64;

    (0..=weight)
        .scan(0.0_f64, |cdf, j| {
            *cdf += log_binom_pmf(j, weight, prob).exp();
            Some((j, *cdf))
        })
        .find(|&(_, cdf)| draw < cdf)
        .map_or(weight, |(j, _)| j)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sortition_zero_weight_returns_zero() {
        assert_eq!(sortition(&[0; 32], 0, 10, 100), 0);
    }

    #[test]
    fn sortition_zero_expected_returns_zero() {
        assert_eq!(sortition(&[0; 32], 100, 0, 1000), 0);
    }

    #[test]
    fn sortition_zero_total_returns_zero() {
        assert_eq!(sortition(&[0xFF; 32], 100, 10, 0), 0);
    }

    #[test]
    fn sortition_picks_bucket_from_hash_fraction() {
        // weight = 1 and p = 1/2, so P(0) = P(1) = 0.5. Byte 7 is the most significant byte
        // of the little-endian prefix, so 0x40 maps to 0.25 and 0xC0 maps to 0.75.
        let mut low = [0u8; 32];
        low[7] = 0x40;
        assert_eq!(sortition(&low, 1, 1, 2), 0);

        let mut high = [0u8; 32];
        high[7] = 0xC0;
        assert_eq!(sortition(&high, 1, 1, 2), 1);
    }

    #[test]
    fn sortition_extreme_hashes_hit_the_ends() {
        // An all-zero hash picks the first bucket with non-zero mass.
        assert_eq!(sortition(&[0; 32], 10, 5, 10), 0);
        // An all-0xFF hash maps to 1.0, which falls through to `weight`.
        assert_eq!(sortition(&[0xFF; 32], 10, 5, 10), 10);
    }

    #[test]
    fn sortition_probability_at_least_one_selects_everything() {
        // expected > total gives p >= 1, so all probability mass is on k == weight.
        assert_eq!(sortition(&[0; 32], 10, 20, 10), 10);
    }

    #[test]
    fn sortition_never_exceeds_weight() {
        for byte in [0x00u8, 0x11, 0x7F, 0x80, 0xEE, 0xFF] {
            assert!(sortition(&[byte; 32], 50, 7, 100) <= 50);
        }
    }
}