#![allow(deprecated)]
use elliptic_curve::hash2curve::{ExpandMsgXmd, GroupDigest};
use elliptic_curve::{
bigint::{NonZero, U256},
scalar::FromUintUnchecked,
sec1::{FromEncodedPoint, ToEncodedPoint},
Curve, Field,
};
use p256::{AffinePoint, EncodedPoint, NistP256, ProjectivePoint, Scalar};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use super::dleq::{decode_proof, encode_proof, prove, verify};
/// Byte length of a SEC1 compressed P-256 point: one tag byte (`0x02`/`0x03`) followed by the
/// 32-byte big-endian x-coordinate.
const COMPRESSED_POINT_LEN: usize = 1 + 32;
/// Failures reported by the blinding, evaluation and token-verification operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The token does not have the expected fixed length / layout.
    Decode,
    /// A point failed to decode, was the identity, or hashing to the curve failed.
    InvalidPoint,
    /// The DLEQ proof (or the token's binding to the request) did not verify.
    InvalidProof,
    /// The secret key reduced to zero modulo the group order.
    ZeroScalar,
    /// The token's leading version byte is not a version this code understands.
    UnsupportedVersion,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Error::Decode => "malformed token encoding",
            Error::InvalidPoint => "invalid or identity curve point",
            Error::InvalidProof => "token proof failed to verify",
            Error::ZeroScalar => "secret key is zero modulo the group order",
            Error::UnsupportedVersion => "unsupported token version",
        };
        f.write_str(msg)
    }
}

// Lets callers box this as `dyn std::error::Error` or propagate it with `?` into generic
// error types.
impl std::error::Error for Error {}
/// Version tag written as the first byte of every token built by [`Server::evaluate`].
const TOKEN_VERSION_V1: u8 = 0x01;
/// Width of the version prefix.
const TOKEN_VERSION_LEN: usize = 1;
/// Width of each encoded point (`A`, then `B`) inside a token.
const TOKEN_POINT_LEN: usize = COMPRESSED_POINT_LEN;
/// Width of the encoded DLEQ proof that ends a token.
const TOKEN_PROOF_LEN: usize = 64;
/// Total token size for the layout `version || A || B || proof`.
const TOKEN_LEN: usize = TOKEN_VERSION_LEN + 2 * TOKEN_POINT_LEN + TOKEN_PROOF_LEN;
/// Hashes `input` onto P-256 with the `P256_XMD:SHA-256_SSWU_RO_` suite.
///
/// `ctx` is appended to the suite identifier to form the domain-separation tag, so different
/// contexts map the same input to unrelated points. Returns `None` if the underlying
/// hash-to-curve call reports an error.
fn hash_to_curve(input: &[u8], ctx: &[u8]) -> Option<ProjectivePoint> {
    const BASE_DST: &[u8] = b"P256_XMD:SHA-256_SSWU_RO_";
    let dst: Vec<u8> = [BASE_DST, ctx].concat();
    NistP256::hash_from_bytes::<ExpandMsgXmd<Sha256>>(&[input], &[&dst])
        .ok()
        .map(ProjectivePoint::from)
}
/// Serializes `p` as a 33-byte SEC1 compressed point.
///
/// # Panics
///
/// Panics if `p` is the identity: SEC1 encodes the point at infinity as the single byte
/// `0x00`, which cannot fill a `COMPRESSED_POINT_LEN`-byte array. Every caller in this module
/// passes a non-identity point (decoded points reject the identity, and non-zero multiples of
/// a non-identity point in a prime-order group are never the identity).
fn encode_point_compressed(p: &ProjectivePoint) -> [u8; COMPRESSED_POINT_LEN] {
    p.to_affine()
        .to_encoded_point(true)
        .as_bytes()
        .try_into()
        .expect("non-identity P-256 point must have a 33-byte compressed SEC1 encoding")
}
/// Parses a SEC1-encoded point and returns it in projective form.
///
/// Returns `None` if the bytes are not a valid SEC1 encoding, do not lie on the curve, or
/// encode the identity (which is never a legitimate protocol element here).
///
/// NOTE(review): `EncodedPoint::from_bytes` is not limited to the compressed form, so
/// e.g. a 65-byte uncompressed issuer key would also be accepted. Confirm whether that is
/// intended.
fn decode_point_compressed(bytes: &[u8]) -> Option<ProjectivePoint> {
    let encoded = EncodedPoint::from_bytes(bytes).ok()?;
    let affine = Option::<AffinePoint>::from(AffinePoint::from_encoded_point(&encoded))?;
    let is_identity: bool = affine.is_identity().into();
    if is_identity {
        None
    } else {
        Some(ProjectivePoint::from(affine))
    }
}
/// Wire encoding for every point this module emits: 33-byte SEC1 compressed.
///
/// Kept as a single indirection so the on-the-wire format is chosen in exactly one place.
/// Panics for the identity, like `encode_point_compressed`.
fn encode_point(p: &ProjectivePoint) -> [u8; COMPRESSED_POINT_LEN] {
    encode_point_compressed(p)
}
fn decode_point(bytes: &[u8]) -> Result<ProjectivePoint, Error> {
decode_point_compressed(bytes).ok_or(Error::InvalidPoint)
}
/// The P-256 base point `G`.
///
/// Used to derive the issuer key `Q = k·G` and as the first base of the DLEQ proof.
fn generator() -> ProjectivePoint {
    ProjectivePoint::GENERATOR
}
fn scalar_from_be32(bytes: [u8; 32]) -> Result<Scalar, Error> {
let u = U256::from_be_slice(&bytes);
let n = NonZero::new(NistP256::ORDER).unwrap();
let s = Scalar::from_uint_unchecked(u.rem(&n));
let zero = Scalar::ZERO;
let is_zero = s.to_bytes().ct_eq(&zero.to_bytes());
if bool::from(is_zero) {
return Err(Error::ZeroScalar);
}
Ok(s)
}
/// Derives the 32-byte token output as
/// `SHA-256("VOPRF-P256-SHA256:Finalize" || ctx || compressed(B))`.
///
/// The label and the encoded point have fixed lengths, so `ctx` is unambiguously delimited.
fn prf_output_from_b(b: &ProjectivePoint, ctx: &[u8]) -> [u8; 32] {
    Sha256::new()
        .chain_update(b"VOPRF-P256-SHA256:Finalize")
        .chain_update(ctx)
        .chain_update(encode_point(b))
        .finalize()
        .into()
}
/// Client-side values produced by [`Client::blind`] and passed to [`Client::finalize`].
///
/// NOTE(review): `r` is the secret blinding scalar. Anyone holding it together with the
/// blinded element can recover `p`, so this should stay with the client.
pub struct BlindState {
    /// Random blinding scalar applied to the hashed input.
    pub r: Scalar,
    /// Hashed input point `P` (output of hash-to-curve), before blinding.
    pub p: ProjectivePoint,
}
/// Requesting party: blinds inputs and checks the issuer's evaluation tokens.
pub struct Client {
    /// Domain-separation context; must match the issuer's and verifier's.
    ctx: Vec<u8>,
}
/// Issuing party: holds the secret evaluation key and answers blinded requests.
pub struct Server {
    /// Secret evaluation scalar `k` (non-zero when built via `from_secret_key`).
    k: Scalar,
    /// Issuer public key `Q = k·G`.
    q: ProjectivePoint,
    /// Domain-separation context bound into the DLEQ proof.
    ctx: Vec<u8>,
}
/// Third party that checks tokens against an issuer public key.
pub struct Verifier {
    /// Domain-separation context; must match the one used at issuance.
    ctx: Vec<u8>,
}
impl Client {
pub fn new(ctx: &[u8]) -> Self {
Self { ctx: ctx.to_vec() }
}
pub fn blind(&mut self, input: &[u8]) -> Result<(Vec<u8>, BlindState), Error> {
let p = hash_to_curve(input, &self.ctx).ok_or(Error::InvalidPoint)?;
let r = Scalar::random(rand::rngs::OsRng);
let a = p * r;
Ok((encode_point(&a).to_vec(), BlindState { r, p }))
}
pub fn finalize(
self,
_st: BlindState,
token_bytes: &[u8],
issuer_pubkey_sec1_compressed: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), Error> {
if token_bytes.len() != TOKEN_LEN {
return Err(Error::Decode);
}
if token_bytes[0] != TOKEN_VERSION_V1 {
return Err(Error::UnsupportedVersion);
}
let offset = TOKEN_VERSION_LEN;
let a = decode_point(&token_bytes[offset..offset + TOKEN_POINT_LEN])?;
let b = decode_point(&token_bytes[offset + TOKEN_POINT_LEN..offset + TOKEN_POINT_LEN * 2])?;
let proof_bytes: &[u8; 64] = token_bytes[offset + TOKEN_POINT_LEN * 2..]
.try_into()
.map_err(|_| Error::Decode)?;
let proof = decode_proof(proof_bytes);
let q = decode_point(issuer_pubkey_sec1_compressed)?;
let ok = verify(
&generator().to_affine(),
&q.to_affine(),
&a.to_affine(),
&b.to_affine(),
&proof,
Some(&self.ctx),
);
if !ok {
return Err(Error::InvalidProof);
}
let y = prf_output_from_b(&b, &self.ctx);
Ok((token_bytes.to_vec(), y.to_vec()))
}
}
impl Server {
    /// Builds an issuer from a 32-byte big-endian secret key, bound to context `ctx`.
    ///
    /// # Errors
    ///
    /// [`Error::ZeroScalar`] if the key reduces to zero modulo the group order.
    pub fn from_secret_key(sk_bytes: [u8; 32], ctx: &[u8]) -> Result<Self, Error> {
        let secret = scalar_from_be32(sk_bytes)?;
        Ok(Self {
            k: secret,
            q: generator() * secret,
            ctx: ctx.to_vec(),
        })
    }

    /// The issuer public key `Q = k·G` in 33-byte SEC1 compressed form.
    pub fn public_key_sec1_compressed(&self) -> [u8; COMPRESSED_POINT_LEN] {
        encode_point(&self.q)
    }

    /// Evaluates a blinded element `A` and returns the token `version || A || B || proof`.
    ///
    /// `B = k·A`. The proof is a DLEQ proof that `log_G(Q) == log_A(B)`, bound to `ctx`.
    ///
    /// # Errors
    ///
    /// [`Error::InvalidPoint`] if `blinded_bytes` does not decode to a non-identity point.
    pub fn evaluate(&self, blinded_bytes: &[u8]) -> Result<Vec<u8>, Error> {
        let blinded = decode_point(blinded_bytes)?;
        let evaluated = blinded * self.k;
        let base = generator().to_affine();
        let proof = prove(
            &self.k,
            &base,
            &self.q.to_affine(),
            &blinded.to_affine(),
            &evaluated.to_affine(),
            &mut rand::rngs::OsRng,
            Some(&self.ctx),
        );

        let mut token = Vec::with_capacity(TOKEN_LEN);
        token.push(TOKEN_VERSION_V1);
        token.extend_from_slice(&encode_point(&blinded));
        token.extend_from_slice(&encode_point(&evaluated));
        token.extend_from_slice(&encode_proof(&proof));
        Ok(token)
    }
}
impl Verifier {
pub fn new(ctx: &[u8]) -> Self {
Self { ctx: ctx.to_vec() }
}
pub fn verify(
&self,
token_bytes: &[u8],
issuer_pubkey_sec1_compressed: &[u8],
) -> Result<Vec<u8>, Error> {
if token_bytes.len() != TOKEN_LEN {
return Err(Error::Decode);
}
if token_bytes[0] != TOKEN_VERSION_V1 {
return Err(Error::UnsupportedVersion);
}
let offset = TOKEN_VERSION_LEN;
let a = decode_point(&token_bytes[offset..offset + TOKEN_POINT_LEN])?;
let b = decode_point(&token_bytes[offset + TOKEN_POINT_LEN..offset + TOKEN_POINT_LEN * 2])?;
let proof_bytes: &[u8; 64] = token_bytes[offset + TOKEN_POINT_LEN * 2..]
.try_into()
.map_err(|_| Error::Decode)?;
let proof = decode_proof(proof_bytes);
let q = decode_point(issuer_pubkey_sec1_compressed)?;
let ok = verify(
&generator().to_affine(),
&q.to_affine(),
&a.to_affine(),
&b.to_affine(),
&proof,
Some(&self.ctx),
);
if !ok {
return Err(Error::InvalidProof);
}
Ok(prf_output_from_b(&b, &self.ctx).to_vec())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Big-endian encoding of the P-256 group order `n`.
    const P256_ORDER_BE: [u8; 32] = [
        0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
        0xff, 0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84, 0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63,
        0x25, 0x51,
    ];

    // Formerly `test_voprf_rfc_test_vectors`. It never used RFC 9497 vectors; it is a
    // blind -> evaluate -> finalize -> verify round trip under a fixed key.
    #[test]
    fn test_voprf_roundtrip() {
        let ctx = b"VOPRF-TEST";
        let sk_bytes = [
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
            0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c,
            0x1d, 0x1e, 0x1f, 0x20,
        ];
        let server = Server::from_secret_key(sk_bytes, ctx).unwrap();
        let pk = server.public_key_sec1_compressed();
        let mut client = Client::new(ctx);
        let input = b"test input";
        let (blinded, state) = client.blind(input).unwrap();
        let token_bytes = server.evaluate(blinded.as_slice()).unwrap();
        assert_eq!(token_bytes.len(), TOKEN_LEN);
        assert_eq!(token_bytes[0], TOKEN_VERSION_V1);
        let (token, output) = client.finalize(state, &token_bytes, &pk).unwrap();
        let verifier = Verifier::new(ctx);
        let verified_output = verifier.verify(&token, &pk).unwrap();
        assert_eq!(output, verified_output);
    }

    #[test]
    fn test_constant_time_scalar_zero_check() {
        let zero_bytes = [0u8; 32];
        let result = scalar_from_be32(zero_bytes);
        assert!(matches!(result, Err(Error::ZeroScalar)));
    }

    #[test]
    fn test_scalar_equal_to_order_is_rejected() {
        // n ≡ 0 (mod n), so a key equal to the group order must be rejected as well.
        let result = scalar_from_be32(P256_ORDER_BE);
        assert!(matches!(result, Err(Error::ZeroScalar)));
    }

    #[test]
    fn test_hash_to_curve_consistency() {
        let input = b"test input";
        let ctx = b"test-ctx";
        let p1 = hash_to_curve(input, ctx);
        let p2 = hash_to_curve(input, ctx);
        assert!(p1.is_some());
        assert!(p2.is_some());
        assert_eq!(p1.unwrap(), p2.unwrap());
    }

    #[test]
    fn test_hash_to_curve_domain_separation() {
        let p1 = hash_to_curve(b"test input", b"ctx-a").unwrap();
        let p2 = hash_to_curve(b"test input", b"ctx-b").unwrap();
        assert_ne!(p1, p2);
    }

    #[test]
    fn test_point_encoding_roundtrip() {
        let g = generator();
        let encoded = encode_point(&g);
        let decoded = decode_point(&encoded).unwrap();
        assert_eq!(g, decoded);
    }

    #[test]
    fn test_decode_rejects_identity() {
        // SEC1 encodes the point at infinity as a single 0x00 byte.
        assert!(matches!(decode_point(&[0x00]), Err(Error::InvalidPoint)));
    }

    #[test]
    fn test_token_version_checking() {
        let ctx = b"test";
        let sk_bytes = [1u8; 32];
        let server = Server::from_secret_key(sk_bytes, ctx).unwrap();
        let pk = server.public_key_sec1_compressed();
        let mut client = Client::new(ctx);
        let (blinded, state) = client.blind(b"input").unwrap();
        let mut token_bytes = server.evaluate(blinded.as_slice()).unwrap();
        token_bytes[0] = 0xFF;
        let result = client.finalize(state, &token_bytes, &pk);
        assert!(matches!(result, Err(Error::UnsupportedVersion)));
    }

    #[test]
    fn test_dleq_proof_verification() {
        let ctx = b"dleq-test";
        let sk_bytes = [42u8; 32];
        let server = Server::from_secret_key(sk_bytes, ctx).unwrap();
        let pk = server.public_key_sec1_compressed();
        let mut client = Client::new(ctx);
        let (blinded, state) = client.blind(b"test").unwrap();
        let token_bytes = server.evaluate(blinded.as_slice()).unwrap();
        let result = client.finalize(state, &token_bytes, &pk);
        assert!(result.is_ok());
    }

    #[test]
    fn test_verifier_rejects_wrong_issuer_key() {
        let ctx = b"dleq-test";
        let server = Server::from_secret_key([42u8; 32], ctx).unwrap();
        let other_pk = Server::from_secret_key([7u8; 32], ctx)
            .unwrap()
            .public_key_sec1_compressed();
        let mut client = Client::new(ctx);
        let (blinded, _state) = client.blind(b"test").unwrap();
        let token_bytes = server.evaluate(blinded.as_slice()).unwrap();
        let result = Verifier::new(ctx).verify(&token_bytes, &other_pk);
        assert!(matches!(result, Err(Error::InvalidProof)));
    }

    #[test]
    fn test_verifier_rejects_bad_length() {
        let ctx = b"len-test";
        let server = Server::from_secret_key([3u8; 32], ctx).unwrap();
        let pk = server.public_key_sec1_compressed();
        let mut client = Client::new(ctx);
        let (blinded, _state) = client.blind(b"input").unwrap();
        let token_bytes = server.evaluate(blinded.as_slice()).unwrap();
        let verifier = Verifier::new(ctx);
        assert!(matches!(
            verifier.verify(&token_bytes[..TOKEN_LEN - 1], &pk),
            Err(Error::Decode)
        ));
        assert!(matches!(verifier.verify(&[], &pk), Err(Error::Decode)));
    }
}