use core::convert::TryFrom;
use core::ops::Add;
use derive_where::derive_where;
use digest::{Digest, Output, OutputSizeUser};
use generic_array::sequence::Concat;
use generic_array::typenum::{IsLess, Unsigned, U2, U256, U9};
use generic_array::{ArrayLength, GenericArray};
use rand_core::{TryCryptoRng, TryRngCore};
use subtle::ConstantTimeEq;
#[cfg(feature = "serde")]
use crate::serialization::serde::{Element, Scalar};
use crate::{CipherSuite, Error, Group, InternalError, Result};
// Protocol labels mixed into hash transcripts and domain separation tags.
// NOTE(review): these appear to mirror the labels of RFC 9497 — confirm any
// change against the specification, as they must match byte-for-byte.
/// Suffix of the finalization transcript hashed in `server_evaluate_hash_input`.
pub(crate) const STR_FINALIZE: [u8; 8] = *b"Finalize";
/// Prefix of the DST used when deriving the composite seed in `compute_composites`.
pub(crate) const STR_SEED: [u8; 5] = *b"Seed-";
/// Prefix of the DST used by `derive_key_internal`.
pub(crate) const STR_DERIVE_KEYPAIR: [u8; 13] = *b"DeriveKeyPair";
/// Suffix of each per-element transcript hashed in `compute_composites`.
pub(crate) const STR_COMPOSITE: [u8; 9] = *b"Composite";
/// Suffix of the proof challenge transcript.
pub(crate) const STR_CHALLENGE: [u8; 9] = *b"Challenge";
/// "Info" label; not referenced within this section of the file.
pub(crate) const STR_INFO: [u8; 4] = *b"Info";
/// Context-string prefix that `Dst::new` places in front of the mode byte.
pub(crate) const STR_OPRF: [u8; 7] = *b"OPRFV1-";
/// Prefix of the DST used for every hash-to-scalar call in this file.
pub(crate) const STR_HASH_TO_SCALAR: [u8; 13] = *b"HashToScalar-";
/// Prefix of the DST used by `hash_to_group`.
pub(crate) const STR_HASH_TO_GROUP: [u8; 12] = *b"HashToGroup-";
/// Protocol variant in use.
///
/// The one-byte identifier returned by [`Mode::to_u8`] is embedded in every
/// domain separation tag built for hashing, so each variant hashes into a
/// distinct domain.
#[derive(Clone, Copy, Debug)]
pub enum Mode {
    /// `Oprf` mode, identifier `0`.
    Oprf = 0,
    /// `Voprf` mode, identifier `1`.
    Voprf = 1,
    /// `Poprf` mode, identifier `2`.
    Poprf = 2,
}

impl Mode {
    /// Returns the single-byte identifier of this mode (`0`, `1` or `2`).
    pub fn to_u8(self) -> u8 {
        // The explicit discriminants above are the protocol identifiers.
        self as u8
    }
}
/// Newtype around a single group element of the cipher suite's group.
///
/// With the `serde` feature enabled, the element is (de)serialized through
/// `Element::<CS::Group>`.
// NOTE(review): by its name this is presumably the client's blinded input as
// sent to the server — confirm against the client/server modules.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(bound = "")
)]
pub struct BlindedElement<CS: CipherSuite>(
    #[cfg_attr(feature = "serde", serde(with = "Element::<CS::Group>"))]
    pub(crate) <CS::Group as Group>::Elem,
);
/// Newtype around a single group element of the cipher suite's group.
///
/// With the `serde` feature enabled, the element is (de)serialized through
/// `Element::<CS::Group>`.
// NOTE(review): by its name this is presumably the server's evaluated element
// returned to the client — confirm against the server module.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(bound = "")
)]
pub struct EvaluationElement<CS: CipherSuite>(
    #[cfg_attr(feature = "serde", serde(with = "Element::<CS::Group>"))]
    pub(crate) <CS::Group as Group>::Elem,
);
/// Newtype wrapping an [`EvaluationElement`].
// NOTE(review): presumably marks an evaluation element that has gone through
// some preparation step (e.g. in batched/POPRF flows) — confirm with callers.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(bound = "")
)]
pub struct PreparedEvaluationElement<CS: CipherSuite>(pub(crate) EvaluationElement<CS>);
/// Proof produced by `generate_proof` and checked by `verify_proof`.
///
/// Both fields are scalars of the cipher suite's group. With the `serde`
/// feature enabled, each is (de)serialized through `Scalar::<CS::Group>`.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(bound = "")
)]
pub struct Proof<CS: CipherSuite> {
    /// Challenge scalar `c`: the hash-to-scalar of the challenge transcript.
    #[cfg_attr(feature = "serde", serde(with = "Scalar::<CS::Group>"))]
    pub(crate) c_scalar: <CS::Group as Group>::Scalar,
    /// Response scalar `s = r - c * k`, where `r` is the prover's random nonce.
    #[cfg_attr(feature = "serde", serde(with = "Scalar::<CS::Group>"))]
    pub(crate) s_scalar: <CS::Group as Group>::Scalar,
}
#[allow(clippy::many_single_char_names)]
pub(crate) fn generate_proof<CS: CipherSuite, R: TryRngCore + TryCryptoRng>(
rng: &mut R,
k: <CS::Group as Group>::Scalar,
a: <CS::Group as Group>::Elem,
b: <CS::Group as Group>::Elem,
cs: impl ExactSizeIterator<Item = <CS::Group as Group>::Elem>,
ds: impl ExactSizeIterator<Item = <CS::Group as Group>::Elem>,
mode: Mode,
) -> Result<Proof<CS>> {
let (m, z) = compute_composites::<CS, _, _>(Some(k), b, cs, ds, mode)?;
let r = CS::Group::random_scalar(rng)?;
let t2 = a * &r;
let t3 = m * &r;
let bm = CS::Group::serialize_elem(b);
let a0 = CS::Group::serialize_elem(m);
let a1 = CS::Group::serialize_elem(z);
let a2 = CS::Group::serialize_elem(t2);
let a3 = CS::Group::serialize_elem(t3);
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
let h2_input = [
&elem_len,
bm.as_slice(),
&elem_len,
&a0,
&elem_len,
&a1,
&elem_len,
&a2,
&elem_len,
&a3,
&STR_CHALLENGE,
];
let dst = Dst::new::<CS, _, _>(STR_HASH_TO_SCALAR, mode);
let c_scalar = CS::Group::hash_to_scalar::<CS::Hash>(&h2_input, &dst.as_dst()).unwrap();
let s_scalar = r - &(c_scalar * &k);
Ok(Proof { c_scalar, s_scalar })
}
#[allow(clippy::many_single_char_names)]
pub(crate) fn verify_proof<CS: CipherSuite>(
a: <CS::Group as Group>::Elem,
b: <CS::Group as Group>::Elem,
cs: impl ExactSizeIterator<Item = <CS::Group as Group>::Elem>,
ds: impl ExactSizeIterator<Item = <CS::Group as Group>::Elem>,
proof: &Proof<CS>,
mode: Mode,
) -> Result<()> {
let (m, z) = compute_composites::<CS, _, _>(None, b, cs, ds, mode)?;
let t2 = (a * &proof.s_scalar) + &(b * &proof.c_scalar);
let t3 = (m * &proof.s_scalar) + &(z * &proof.c_scalar);
let bm = CS::Group::serialize_elem(b);
let a0 = CS::Group::serialize_elem(m);
let a1 = CS::Group::serialize_elem(z);
let a2 = CS::Group::serialize_elem(t2);
let a3 = CS::Group::serialize_elem(t3);
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
let h2_input = [
&elem_len,
bm.as_slice(),
&elem_len,
&a0,
&elem_len,
&a1,
&elem_len,
&a2,
&elem_len,
&a3,
&STR_CHALLENGE,
];
let dst = Dst::new::<CS, _, _>(STR_HASH_TO_SCALAR, mode);
let c = CS::Group::hash_to_scalar::<CS::Hash>(&h2_input, &dst.as_dst()).unwrap();
match c.ct_eq(&proof.c_scalar).into() {
true => Ok(()),
false => Err(Error::ProofVerification),
}
}
/// Return type of `compute_composites`: the composite elements `(M, Z)`.
type ComputeCompositesResult<CS> = (
    <<CS as CipherSuite>::Group as Group>::Elem,
    <<CS as CipherSuite>::Group as Group>::Elem,
);
fn compute_composites<
CS: CipherSuite,
IC: Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
ID: Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
>(
k_option: Option<<CS::Group as Group>::Scalar>,
b: <CS::Group as Group>::Elem,
c_slice: IC,
d_slice: ID,
mode: Mode,
) -> Result<ComputeCompositesResult<CS>> {
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
if c_slice.len() != d_slice.len() {
return Err(Error::Batch);
}
let len = u16::try_from(c_slice.len()).map_err(|_| Error::Batch)?;
let seed_dst = Dst::new::<CS, _, _>(STR_SEED, mode);
let seed = CS::Hash::new()
.chain_update(elem_len)
.chain_update(CS::Group::serialize_elem(b))
.chain_update(seed_dst.i2osp_2())
.chain_update_multi(&seed_dst.as_dst())
.finalize();
let seed_len = i2osp_2_array::<<CS::Hash as OutputSizeUser>::OutputSize>();
let mut m = CS::Group::identity_elem();
let mut z = CS::Group::identity_elem();
for (i, (c, d)) in (0..len).zip(c_slice.zip(d_slice)) {
let ci = CS::Group::serialize_elem(c);
let di = CS::Group::serialize_elem(d);
let h2_input = [
seed_len.as_slice(),
&seed,
&i.to_be_bytes(),
&elem_len,
&ci,
&elem_len,
&di,
&STR_COMPOSITE,
];
let dst = Dst::new::<CS, _, _>(STR_HASH_TO_SCALAR, mode);
let di = CS::Group::hash_to_scalar::<CS::Hash>(&h2_input, &dst.as_dst()).unwrap();
m = c * &di + &m;
z = match k_option {
Some(_) => z,
None => d * &di + &z,
};
}
z = match k_option {
Some(k) => m * &k,
None => z,
};
Ok((m, z))
}
pub(crate) fn derive_key_internal<CS: CipherSuite>(
seed: &[u8],
info: &[u8],
mode: Mode,
) -> Result<<CS::Group as Group>::Scalar, Error> {
let dst = Dst::new::<CS, _, _>(STR_DERIVE_KEYPAIR, mode);
let info_len = i2osp_2(info.len()).map_err(|_| Error::DeriveKeyPair)?;
for counter in 0_u8..=u8::MAX {
let sk_s = CS::Group::hash_to_scalar::<CS::Hash>(
&[seed, &info_len, info, &counter.to_be_bytes()],
&dst.as_dst(),
)
.map_err(|_| Error::DeriveKeyPair)?;
if !bool::from(CS::Group::is_zero_scalar(sk_s)) {
return Ok(sk_s);
}
}
Err(Error::Protocol)
}
/// Deterministically derives a secret key scalar from `seed` and `info` for
/// the given `mode`.
///
/// Only available with the `danger` feature.
///
/// # Errors
/// - [`Error::DeriveKeyPair`] if `info` is longer than `u16::MAX` bytes or
///   hashing to a scalar fails.
/// - [`Error::Protocol`] if no non-zero scalar is found within 256 counter
///   values.
#[cfg(feature = "danger")]
pub fn derive_key<CS: CipherSuite>(
    seed: &[u8],
    info: &[u8],
    mode: Mode,
) -> Result<<CS::Group as Group>::Scalar, Error> {
    derive_key_internal::<CS>(seed, info, mode)
}
/// Return type of `derive_keypair`: `(secret scalar, public element)`.
type DeriveKeypairResult<CS> = (
    <<CS as CipherSuite>::Group as Group>::Scalar,
    <<CS as CipherSuite>::Group as Group>::Elem,
);
/// Derives a key pair `(sk, pk)` where `sk` comes from `derive_key_internal`
/// and `pk = G·sk`, with `G` the group's base element.
///
/// # Errors
/// Propagates any error from `derive_key_internal`.
pub(crate) fn derive_keypair<CS: CipherSuite>(
    seed: &[u8],
    info: &[u8],
    mode: Mode,
) -> Result<DeriveKeypairResult<CS>, Error> {
    derive_key_internal::<CS>(seed, info, mode).map(|secret| {
        let public = CS::Group::base_elem() * &secret;
        (secret, public)
    })
}
/// Hashes `input` to a group element and multiplies it by `blind`.
///
/// The blind is used as given; this function performs no checks on it.
///
/// # Errors
/// [`Error::Input`] if hashing `input` to the group fails.
pub(crate) fn deterministic_blind_unchecked<CS: CipherSuite>(
    input: &[u8],
    blind: &<CS::Group as Group>::Scalar,
    mode: Mode,
) -> Result<<CS::Group as Group>::Elem> {
    hash_to_group::<CS>(input, mode).map(|point| point * blind)
}
pub(crate) fn hash_to_group<CS: CipherSuite>(
input: &[u8],
mode: Mode,
) -> Result<<CS::Group as Group>::Elem> {
let dst = Dst::new::<CS, _, _>(STR_HASH_TO_GROUP, mode);
CS::Group::hash_to_curve::<CS::Hash>(&[input], &dst.as_dst()).map_err(|_| Error::Input)
}
pub(crate) fn server_evaluate_hash_input<CS: CipherSuite>(
input: &[u8],
info: Option<&[u8]>,
issued_element: GenericArray<u8, <<CS as CipherSuite>::Group as Group>::ElemLen>,
) -> Result<Output<CS::Hash>> {
let mut hash = CS::Hash::new()
.chain_update(i2osp_2(input.as_ref().len()).map_err(|_| Error::Input)?)
.chain_update(input.as_ref());
if let Some(info) = info {
hash = hash
.chain_update(i2osp_2(info.as_ref().len()).map_err(|_| Error::Input)?)
.chain_update(info.as_ref());
}
Ok(hash
.chain_update(i2osp_2(issued_element.as_slice().len()).map_err(|_| Error::Input)?)
.chain_update(issued_element)
.chain_update(STR_FINALIZE)
.finalize())
}
/// Domain separation tag stored as two parts, so it can be fed to hash
/// functions without concatenating into a single buffer.
///
/// The full tag is `dst_1 || dst_2`.
pub(crate) struct Dst<L: ArrayLength> {
    /// `par_1 || "OPRFV1-" || mode byte || "-"`.
    dst_1: GenericArray<u8, L>,
    /// The cipher suite identifier `CS::ID`.
    dst_2: &'static str,
}

impl<L: ArrayLength> Dst<L> {
    /// Builds the tag `par_1 || "OPRFV1-" || mode.to_u8() || "-" || CS::ID`.
    ///
    /// # Panics
    /// Panics if the total tag length exceeds `u16::MAX` bytes. [`Dst::i2osp_2`]
    /// relies on this check.
    pub(crate) fn new<CS, T, TL>(par_1: T, mode: Mode) -> Self
    where
        CS: CipherSuite,
        T: Into<GenericArray<u8, TL>>,
        TL: ArrayLength + Add<U9, Output = L>,
    {
        let label = par_1.into();
        // "OPRFV1-" (7 bytes) + mode byte + '-' = 9 bytes, hence `Add<U9>`.
        let context = GenericArray::from(STR_OPRF)
            .concat([mode.to_u8()].into())
            .concat([b'-'].into());
        let dst_1 = label.concat(context);
        let dst_2 = CS::ID;
        assert!(
            L::USIZE + dst_2.len() <= u16::MAX.into(),
            "constructed DST longer than {}",
            u16::MAX
        );
        Self { dst_1, dst_2 }
    }

    /// Returns the two tag parts, `[dst_1, dst_2]`, for multi-slice hashing APIs.
    pub(crate) fn as_dst(&self) -> [&[u8]; 2] {
        [&self.dst_1, self.dst_2.as_bytes()]
    }

    /// Returns the total tag length as a two-byte big-endian integer.
    pub(crate) fn i2osp_2(&self) -> [u8; 2] {
        u16::try_from(L::USIZE + self.dst_2.len())
            .expect("DST length is checked against u16::MAX in Dst::new")
            .to_be_bytes()
    }
}
/// Extension methods for [`Digest`] hashers.
trait DigestExt {
    /// Feeds each slice of `data` into the hasher, in order, and returns it.
    fn chain_update_multi(self, data: &[&[u8]]) -> Self;
}

impl<T> DigestExt for T
where
    T: Digest,
{
    fn chain_update_multi(self, parts: &[&[u8]]) -> Self {
        parts
            .iter()
            .fold(self, |hasher, part| hasher.chain_update(part))
    }
}
/// Encodes `input` as a two-byte big-endian integer (`I2OSP(input, 2)`).
///
/// # Errors
/// `InternalError::I2osp` if `input` does not fit in a `u16`.
pub(crate) fn i2osp_2(input: usize) -> Result<[u8; 2], InternalError> {
    match u16::try_from(input) {
        Ok(value) => Ok(value.to_be_bytes()),
        Err(_) => Err(InternalError::I2osp),
    }
}
pub(crate) fn i2osp_2_array<L: ArrayLength + IsLess<U256>>() -> GenericArray<u8, U2> {
L::U16.to_be_bytes().into()
}