use core::convert::TryFrom;
use derive_where::derive_where;
use digest::core_api::BlockSizeUser;
use digest::{Digest, Output, OutputSizeUser};
use generic_array::sequence::Concat;
use generic_array::typenum::{IsLess, IsLessOrEqual, Unsigned, U11, U2, U256};
use generic_array::{ArrayLength, GenericArray};
use rand_core::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;
#[cfg(feature = "serde")]
use crate::serialization::serde::{Element, Scalar};
use crate::{CipherSuite, Error, Group, InternalError, Result};
/// Label appended as the last piece of the hash input in `server_evaluate_hash_input`.
pub(crate) const STR_FINALIZE: [u8; 8] = *b"Finalize";
/// Prefix of the DST used to derive the composite seed in `compute_composites`.
pub(crate) const STR_SEED: [u8; 5] = *b"Seed-";
/// Prefix of the DST used by the hash-to-scalar step of `derive_key_internal`.
pub(crate) const STR_DERIVE_KEYPAIR: [u8; 13] = *b"DeriveKeyPair";
/// Label appended to each per-element hash input in `compute_composites`.
pub(crate) const STR_COMPOSITE: [u8; 9] = *b"Composite";
/// Label appended to the proof challenge hash input in `generate_proof` / `verify_proof`.
pub(crate) const STR_CHALLENGE: [u8; 9] = *b"Challenge";
/// NOTE(review): not referenced by the code visible in this module; presumably used
/// elsewhere in the crate — confirm before changing.
pub(crate) const STR_INFO: [u8; 4] = *b"Info";
/// Leading bytes of the context string built by `create_context_string`.
/// NOTE(review): the `10` presumably pins the protocol draft version — confirm when upgrading.
pub(crate) const STR_VOPRF: [u8; 8] = *b"VOPRF10-";
/// Prefix of the DST for every `hash_to_scalar` call made with a context string
/// (proof challenges and composite weights).
pub(crate) const STR_HASH_TO_SCALAR: [u8; 13] = *b"HashToScalar-";
/// Prefix of the DST used by `hash_to_group`.
pub(crate) const STR_HASH_TO_GROUP: [u8; 12] = *b"HashToGroup-";
/// Protocol variant.
///
/// The variant is encoded as a single byte (see [`Mode::to_u8`]). That byte is
/// mixed into the context string built by `create_context_string`, so every
/// hash in the protocol is domain-separated by mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// OPRF mode (identifier `0`).
    Oprf,
    /// VOPRF mode (identifier `1`).
    Voprf,
    /// POPRF mode (identifier `2`).
    Poprf,
}

impl Mode {
    /// Returns the one-byte identifier of this mode, as written into the context string.
    #[must_use]
    pub const fn to_u8(self) -> u8 {
        match self {
            Mode::Oprf => 0,
            Mode::Voprf => 1,
            Mode::Poprf => 2,
        }
    }
}
/// Newtype around a single group element (`<CS::Group as Group>::Elem`).
///
/// NOTE(review): presumably holds the client's blinded input. `deterministic_blind_unchecked`
/// computes `hash_to_group(input) * blind`, which looks like the value stored here. Confirm
/// at the construction sites.
///
/// `Clone` is derived and the inner value is zeroized on drop. `Debug`, `Eq`, `Hash`, `Ord`,
/// `PartialEq` and `PartialOrd` are derived with a bound on `<CS::Group as Group>::Elem`.
/// With the `serde` feature, the element is (de)serialized through `Element::<CS::Group>`.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct BlindedElement<CS: CipherSuite>(
/// The wrapped group element.
#[cfg_attr(feature = "serde", serde(with = "Element::<CS::Group>"))]
pub(crate) <CS::Group as Group>::Elem,
)
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;
/// Newtype around a single group element (`<CS::Group as Group>::Elem`).
///
/// NOTE(review): presumably the server's evaluation of a blinded element. This module
/// never constructs it, so confirm at the construction sites.
///
/// `Clone` is derived and the inner value is zeroized on drop. `Debug`, `Eq`, `Hash`, `Ord`,
/// `PartialEq` and `PartialOrd` are derived with a bound on `<CS::Group as Group>::Elem`.
/// With the `serde` feature, the element is (de)serialized through `Element::<CS::Group>`.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct EvaluationElement<CS: CipherSuite>(
/// The wrapped group element.
#[cfg_attr(feature = "serde", serde(with = "Element::<CS::Group>"))]
pub(crate) <CS::Group as Group>::Elem,
)
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;
/// Newtype around an [`EvaluationElement`].
///
/// NOTE(review): what "prepared" means (e.g. which processing step produced it) is not
/// visible in this module. Confirm at the construction sites.
///
/// `Clone` is derived and the inner value is zeroized on drop. `Debug`, `Eq`, `Hash`, `Ord`,
/// `PartialEq` and `PartialOrd` are derived with a bound on `<CS::Group as Group>::Elem`.
/// With the `serde` feature, it is (de)serialized as the wrapped `EvaluationElement`.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct PreparedEvaluationElement<CS: CipherSuite>(pub(crate) EvaluationElement<CS>)
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;
/// Proof produced by `generate_proof` and checked by `verify_proof`.
///
/// It consists of the challenge scalar `c` and the response scalar `s = r - c * k`.
/// `Clone` is derived and both scalars are zeroized on drop. `Debug`, `Eq`, `Hash`, `Ord`,
/// `PartialEq` and `PartialOrd` are derived with a bound on `<CS::Group as Group>::Scalar`.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct Proof<CS: CipherSuite>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
/// Challenge scalar `c`: the hash of the proof transcript.
#[cfg_attr(feature = "serde", serde(with = "Scalar::<CS::Group>"))]
pub(crate) c_scalar: <CS::Group as Group>::Scalar,
/// Response scalar `s`. `generate_proof` computes it as `r - c * k`.
#[cfg_attr(feature = "serde", serde(with = "Scalar::<CS::Group>"))]
pub(crate) s_scalar: <CS::Group as Group>::Scalar,
}
#[allow(clippy::many_single_char_names)]
pub(crate) fn generate_proof<CS: CipherSuite, R: RngCore + CryptoRng>(
rng: &mut R,
k: <CS::Group as Group>::Scalar,
a: <CS::Group as Group>::Elem,
b: <CS::Group as Group>::Elem,
cs: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
ds: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
mode: Mode,
) -> Result<Proof<CS>>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let (m, z) = compute_composites::<CS, _, _>(Some(k), b, cs, ds, mode)?;
let r = CS::Group::random_scalar(rng);
let t2 = a * &r;
let t3 = m * &r;
let bm = CS::Group::serialize_elem(b);
let a0 = CS::Group::serialize_elem(m);
let a1 = CS::Group::serialize_elem(z);
let a2 = CS::Group::serialize_elem(t2);
let a3 = CS::Group::serialize_elem(t3);
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
let h2_input = [
&elem_len,
bm.as_slice(),
&elem_len,
&a0,
&elem_len,
&a1,
&elem_len,
&a2,
&elem_len,
&a3,
&STR_CHALLENGE,
];
let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(create_context_string::<CS>(mode));
let c_scalar = CS::Group::hash_to_scalar::<CS::Hash>(&h2_input, &dst).unwrap();
let s_scalar = r - &(c_scalar * &k);
Ok(Proof { c_scalar, s_scalar })
}
#[allow(clippy::many_single_char_names)]
pub(crate) fn verify_proof<CS: CipherSuite>(
a: <CS::Group as Group>::Elem,
b: <CS::Group as Group>::Elem,
cs: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
ds: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
proof: &Proof<CS>,
mode: Mode,
) -> Result<()>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let (m, z) = compute_composites::<CS, _, _>(None, b, cs, ds, mode)?;
let t2 = (a * &proof.s_scalar) + &(b * &proof.c_scalar);
let t3 = (m * &proof.s_scalar) + &(z * &proof.c_scalar);
let bm = CS::Group::serialize_elem(b);
let a0 = CS::Group::serialize_elem(m);
let a1 = CS::Group::serialize_elem(z);
let a2 = CS::Group::serialize_elem(t2);
let a3 = CS::Group::serialize_elem(t3);
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
let h2_input = [
&elem_len,
bm.as_slice(),
&elem_len,
&a0,
&elem_len,
&a1,
&elem_len,
&a2,
&elem_len,
&a3,
&STR_CHALLENGE,
];
let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(create_context_string::<CS>(mode));
let c = CS::Group::hash_to_scalar::<CS::Hash>(&h2_input, &dst).unwrap();
match c.ct_eq(&proof.c_scalar).into() {
true => Ok(()),
false => Err(Error::ProofVerification),
}
}
type ComputeCompositesResult<CS> = (
<<CS as CipherSuite>::Group as Group>::Elem,
<<CS as CipherSuite>::Group as Group>::Elem,
);
fn compute_composites<
CS: CipherSuite,
IC: Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
ID: Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
>(
k_option: Option<<CS::Group as Group>::Scalar>,
b: <CS::Group as Group>::Elem,
c_slice: IC,
d_slice: ID,
mode: Mode,
) -> Result<ComputeCompositesResult<CS>>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
if c_slice.len() != d_slice.len() {
return Err(Error::Batch);
}
let len = u16::try_from(c_slice.len()).map_err(|_| Error::Batch)?;
let seed_dst = GenericArray::from(STR_SEED).concat(create_context_string::<CS>(mode));
let seed = CS::Hash::new()
.chain_update(elem_len)
.chain_update(CS::Group::serialize_elem(b))
.chain_update(i2osp_2_array(&seed_dst))
.chain_update(seed_dst)
.finalize();
let seed_len = i2osp_2_array(&seed);
let mut m = CS::Group::identity_elem();
let mut z = CS::Group::identity_elem();
for (i, (c, d)) in (0..len).zip(c_slice.zip(d_slice)) {
let ci = CS::Group::serialize_elem(c);
let di = CS::Group::serialize_elem(d);
let h2_input = [
seed_len.as_slice(),
&seed,
&i.to_be_bytes(),
&elem_len,
&ci,
&elem_len,
&di,
&STR_COMPOSITE,
];
let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(create_context_string::<CS>(mode));
let di = CS::Group::hash_to_scalar::<CS::Hash>(&h2_input, &dst).unwrap();
m = c * &di + &m;
z = match k_option {
Some(_) => z,
None => d * &di + &z,
};
}
z = match k_option {
Some(k) => m * &k,
None => z,
};
Ok((m, z))
}
pub(crate) fn derive_key_internal<CS: CipherSuite>(
seed: &[u8],
info: &[u8],
mode: Mode,
) -> Result<<CS::Group as Group>::Scalar, Error>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let context_string = create_context_string::<CS>(mode);
let dst = GenericArray::from(STR_DERIVE_KEYPAIR).concat(context_string);
let info_len = i2osp_2(info.len()).map_err(|_| Error::DeriveKeyPair)?;
for counter in 0_u8..=u8::MAX {
let sk_s = CS::Group::hash_to_scalar::<CS::Hash>(
&[seed, &info_len, info, &counter.to_be_bytes()],
&dst,
)
.map_err(|_| Error::DeriveKeyPair)?;
if !bool::from(CS::Group::is_zero_scalar(sk_s)) {
return Ok(sk_s);
}
}
Err(Error::Protocol)
}
/// Deterministically derives a private scalar from `seed` and `info` for the given `mode`.
///
/// Only available with the `danger` feature. It exposes the crate's internal key-derivation
/// routine unchanged.
///
/// # Errors
/// - `Error::DeriveKeyPair` if `info` is longer than `u16::MAX` bytes or hash-to-scalar fails.
/// - `Error::Protocol` if no counter value yields a non-zero scalar.
#[cfg(feature = "danger")]
pub fn derive_key<CS: CipherSuite>(
    seed: &[u8],
    info: &[u8],
    mode: Mode,
) -> Result<<CS::Group as Group>::Scalar, Error>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    let private_scalar = derive_key_internal::<CS>(seed, info, mode)?;
    Ok(private_scalar)
}
/// `(private scalar, public element)` pair returned by [`derive_keypair`].
type DeriveKeypairResult<CS> = (
    <<CS as CipherSuite>::Group as Group>::Scalar,
    <<CS as CipherSuite>::Group as Group>::Elem,
);

/// Derives a private scalar with `derive_key_internal` and pairs it with its public element.
///
/// The public element is the group base element multiplied by the private scalar.
///
/// # Errors
/// Propagates any error from `derive_key_internal`.
pub(crate) fn derive_keypair<CS: CipherSuite>(
    seed: &[u8],
    info: &[u8],
    mode: Mode,
) -> Result<DeriveKeypairResult<CS>, Error>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    derive_key_internal::<CS>(seed, info, mode).map(|secret| {
        let public = CS::Group::base_elem() * &secret;
        (secret, public)
    })
}
/// Hashes `input` to the group (via `hash_to_group`) and multiplies the result by `blind`.
///
/// This function does not inspect `blind`.
/// NOTE(review): presumably the `_unchecked` suffix means the caller must validate `blind`
/// — confirm.
///
/// # Errors
/// `Error::Input` if hashing `input` to the group fails.
pub(crate) fn deterministic_blind_unchecked<CS: CipherSuite>(
    input: &[u8],
    blind: &<CS::Group as Group>::Scalar,
    mode: Mode,
) -> Result<<CS::Group as Group>::Elem>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    hash_to_group::<CS>(input, mode).map(|point| point * blind)
}
pub(crate) fn hash_to_group<CS: CipherSuite>(
input: &[u8],
mode: Mode,
) -> Result<<CS::Group as Group>::Elem>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(create_context_string::<CS>(mode));
CS::Group::hash_to_curve::<CS::Hash>(&[input], &dst).map_err(|_| Error::Input)
}
pub(crate) fn server_evaluate_hash_input<CS: CipherSuite>(
input: &[u8],
info: Option<&[u8]>,
issued_element: GenericArray<u8, <<CS as CipherSuite>::Group as Group>::ElemLen>,
) -> Result<Output<CS::Hash>>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let mut hash = CS::Hash::new()
.chain_update(i2osp_2(input.as_ref().len()).map_err(|_| Error::Input)?)
.chain_update(input.as_ref());
if let Some(info) = info {
hash = hash
.chain_update(i2osp_2(info.as_ref().len()).map_err(|_| Error::Input)?)
.chain_update(info.as_ref());
}
Ok(hash
.chain_update(i2osp_2(issued_element.as_ref().len()).map_err(|_| Error::Input)?)
.chain_update(issued_element)
.chain_update(STR_FINALIZE)
.finalize())
}
/// Builds the 11-byte context string `"VOPRF10-" || mode byte || CS::ID (big-endian)`.
///
/// The context string is appended to every DST used by this module, which separates hashes
/// by protocol mode and cipher suite.
pub(crate) fn create_context_string<CS: CipherSuite>(mode: Mode) -> GenericArray<u8, U11>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    let prefix = GenericArray::from(STR_VOPRF);
    let mode_id = GenericArray::from([mode.to_u8()]);
    let suite_id = GenericArray::from(CS::ID.to_be_bytes());
    prefix.concat(mode_id).concat(suite_id)
}
/// Encodes `input` as a 2-byte big-endian integer (I2OSP with length 2).
///
/// # Errors
/// `InternalError::I2osp` if `input` does not fit in a `u16`.
pub(crate) fn i2osp_2(input: usize) -> Result<[u8; 2], InternalError> {
    match u16::try_from(input) {
        Ok(narrowed) => Ok(narrowed.to_be_bytes()),
        Err(_) => Err(InternalError::I2osp),
    }
}
pub(crate) fn i2osp_2_array<L: ArrayLength<u8> + IsLess<U256>>(
_: &GenericArray<u8, L>,
) -> GenericArray<u8, U2> {
L::U16.to_be_bytes().into()
}