pub use curve25519_dalek;
use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::IsIdentity;
use digest::core_api::BlockSizeUser;
use digest::{FixedOutput, HashMarker};
use generic_array::GenericArray;
use generic_array::typenum::{IsLess, IsLessOrEqual, U32, U256};
use rand::{CryptoRng, RngCore};
use voprf::Mode;
use zeroize::ZeroizeOnDrop;
use super::{Group, STR_OPAQUE_DERIVE_AUTH_KEY_PAIR};
use crate::errors::{InternalError, ProtocolError};
use crate::key_exchange::shared::DiffieHellman;
use crate::serialization::SliceExt;
/// Unit type naming the Ristretto255 group.
///
/// It carries no data; this file implements [`Group`], `voprf::CipherSuite`
/// and `voprf::Group` on it.
pub struct Ristretto255;
impl Group for Ristretto255 {
type Pk = NonIdentity;
type PkLen = U32;
type Sk = NonZeroScalar;
type SkLen = U32;
fn serialize_pk(pk: &Self::Pk) -> GenericArray<u8, Self::PkLen> {
pk.0.compress().to_bytes().into()
}
fn deserialize_take_pk(bytes: &mut &[u8]) -> Result<Self::Pk, ProtocolError> {
CompressedRistretto(bytes.take_array("public key")?.into())
.decompress()
.ok_or(ProtocolError::SerializationError)
.and_then(NonIdentity::from_point)
}
fn random_sk<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Sk {
loop {
let scalar = Scalar::random(rng);
if scalar != Scalar::ZERO {
break NonZeroScalar(scalar);
}
}
}
fn derive_scalar(seed: GenericArray<u8, Self::SkLen>) -> Result<Self::Sk, InternalError> {
voprf::derive_key::<Self>(&seed, &STR_OPAQUE_DERIVE_AUTH_KEY_PAIR, Mode::Oprf)
.map(NonZeroScalar)
.map_err(InternalError::from)
}
fn public_key(sk: &Self::Sk) -> Self::Pk {
NonIdentity(RISTRETTO_BASEPOINT_POINT * sk.0)
}
fn serialize_sk(sk: &Self::Sk) -> GenericArray<u8, Self::SkLen> {
sk.0.to_bytes().into()
}
fn deserialize_take_sk(bytes: &mut &[u8]) -> Result<Self::Sk, ProtocolError> {
Scalar::from_canonical_bytes(bytes.take_array("secret key")?.into())
.into_option()
.ok_or(ProtocolError::SerializationError)
.and_then(NonZeroScalar::from_scalar)
}
}
impl DiffieHellman<Ristretto255> for NonZeroScalar {
fn diffie_hellman(&self, pk: &NonIdentity) -> GenericArray<u8, U32> {
Ristretto255::serialize_pk(&NonIdentity(pk.0 * self.0))
}
}
/// Wrapper around a [`RistrettoPoint`], used as the public key type of
/// [`Ristretto255`].
///
/// `NonIdentity::from_point` refuses the identity point. With the `serde`
/// feature enabled, the inner point is deserialized through
/// `serde_deserialize_pk`, which applies that same check.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NonIdentity(
#[cfg_attr(feature = "serde", serde(deserialize_with = "serde_deserialize_pk"))] RistrettoPoint,
);
impl NonIdentity {
fn from_point(point: RistrettoPoint) -> Result<Self, ProtocolError> {
if point.is_identity() {
Err(ProtocolError::SerializationError)
} else {
Ok(NonIdentity(point))
}
}
}
/// Serde `deserialize_with` helper for the point inside [`NonIdentity`].
///
/// Deserializes a [`RistrettoPoint`] and passes it through
/// `NonIdentity::from_point`, so an identity point is reported as a custom
/// serde error instead of being accepted.
#[cfg(feature = "serde")]
fn serde_deserialize_pk<'de, D>(deserializer: D) -> Result<RistrettoPoint, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::{Deserialize, Error};

    let candidate = RistrettoPoint::deserialize(deserializer)?;

    match NonIdentity::from_point(candidate) {
        // Hand back the bare point; the derived impl re-wraps it.
        Ok(checked) => Ok(checked.0),
        Err(error) => Err(Error::custom(error)),
    }
}
/// Wrapper around a [`Scalar`], used as the secret key type of
/// [`Ristretto255`].
///
/// `NonZeroScalar::from_scalar` refuses the zero scalar. With the `serde`
/// feature enabled, the inner scalar is deserialized through
/// `serde_deserialize_sk`, which applies that same check. The type derives
/// `ZeroizeOnDrop`.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, Hash, PartialEq, ZeroizeOnDrop)]
pub struct NonZeroScalar(
#[cfg_attr(feature = "serde", serde(deserialize_with = "serde_deserialize_sk"))] Scalar,
);
impl NonZeroScalar {
    /// Wraps `scalar`, refusing zero.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::SerializationError`] when `scalar` equals
    /// `Scalar::ZERO`.
    fn from_scalar(scalar: Scalar) -> Result<Self, ProtocolError> {
        if scalar != Scalar::ZERO {
            return Ok(Self(scalar));
        }

        Err(ProtocolError::SerializationError)
    }
}
/// Serde `deserialize_with` helper for the scalar inside [`NonZeroScalar`].
///
/// Deserializes a [`Scalar`] and passes it through
/// `NonZeroScalar::from_scalar`, so a zero scalar is reported as a custom
/// serde error instead of being accepted.
#[cfg(feature = "serde")]
fn serde_deserialize_sk<'de, D>(deserializer: D) -> Result<Scalar, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::{Deserialize, Error};

    let candidate = Scalar::deserialize(deserializer)?;

    match NonZeroScalar::from_scalar(candidate) {
        // Copy the bare scalar out; the derived impl re-wraps it.
        Ok(checked) => Ok(checked.0),
        Err(error) => Err(Error::custom(error)),
    }
}
/// `voprf::CipherSuite` for the local [`Ristretto255`] type.
///
/// Every item is taken unchanged from `voprf::Ristretto255`.
impl voprf::CipherSuite for Ristretto255 {
// Same suite identifier as `voprf::Ristretto255`.
const ID: &'static str = voprf::Ristretto255::ID;
// Same group and hash types as `voprf::Ristretto255`'s cipher suite.
type Group = <voprf::Ristretto255 as voprf::CipherSuite>::Group;
type Hash = <voprf::Ristretto255 as voprf::CipherSuite>::Hash;
}
/// `voprf::Group` for the local [`Ristretto255`] type.
///
/// Every associated type is the corresponding type of `voprf::Ristretto255`,
/// and every method forwards its arguments unchanged to the
/// `voprf::Ristretto255` implementation of the same name.
impl voprf::Group for Ristretto255 {
type Elem = <voprf::Ristretto255 as voprf::Group>::Elem;
type ElemLen = <voprf::Ristretto255 as voprf::Group>::ElemLen;
type Scalar = <voprf::Ristretto255 as voprf::Group>::Scalar;
type ScalarLen = <voprf::Ristretto255 as voprf::Group>::ScalarLen;
/// Forwards to `voprf::Ristretto255`'s `hash_to_curve` with the same hash
/// type `H`, `input` and `dst`.
fn hash_to_curve<H>(
input: &[&[u8]],
dst: &[&[u8]],
) -> voprf::Result<Self::Elem, voprf::InternalError>
where
H: BlockSizeUser + Default + FixedOutput + HashMarker,
H::OutputSize: IsLess<U256> + IsLessOrEqual<H::BlockSize>,
{
<voprf::Ristretto255 as voprf::Group>::hash_to_curve::<H>(input, dst)
}
/// Forwards to `voprf::Ristretto255`'s `hash_to_scalar` with the same hash
/// type `H`, `input` and `dst`.
fn hash_to_scalar<H>(
input: &[&[u8]],
dst: &[&[u8]],
) -> voprf::Result<Self::Scalar, voprf::InternalError>
where
H: BlockSizeUser + Default + FixedOutput + HashMarker,
H::OutputSize: IsLess<U256> + IsLessOrEqual<H::BlockSize>,
{
<voprf::Ristretto255 as voprf::Group>::hash_to_scalar::<H>(input, dst)
}
/// Forwards to `voprf::Ristretto255`'s `base_elem`.
fn base_elem() -> Self::Elem {
<voprf::Ristretto255 as voprf::Group>::base_elem()
}
/// Forwards to `voprf::Ristretto255`'s `identity_elem`.
fn identity_elem() -> Self::Elem {
<voprf::Ristretto255 as voprf::Group>::identity_elem()
}
/// Forwards to `voprf::Ristretto255`'s `serialize_elem`.
fn serialize_elem(elem: Self::Elem) -> GenericArray<u8, Self::ElemLen> {
<voprf::Ristretto255 as voprf::Group>::serialize_elem(elem)
}
/// Forwards to `voprf::Ristretto255`'s `deserialize_elem`.
fn deserialize_elem(element_bits: &[u8]) -> voprf::Result<Self::Elem> {
<voprf::Ristretto255 as voprf::Group>::deserialize_elem(element_bits)
}
/// Forwards to `voprf::Ristretto255`'s `random_scalar`, drawing from `rng`.
fn random_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar {
<voprf::Ristretto255 as voprf::Group>::random_scalar(rng)
}
/// Forwards to `voprf::Ristretto255`'s `invert_scalar`.
fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
<voprf::Ristretto255 as voprf::Group>::invert_scalar(scalar)
}
/// Forwards to `voprf::Ristretto255`'s `is_zero_scalar`; the result is a
/// `subtle::Choice` rather than a `bool`.
fn is_zero_scalar(scalar: Self::Scalar) -> subtle::Choice {
<voprf::Ristretto255 as voprf::Group>::is_zero_scalar(scalar)
}
/// Forwards to `voprf::Ristretto255`'s `serialize_scalar`.
fn serialize_scalar(scalar: Self::Scalar) -> GenericArray<u8, Self::ScalarLen> {
<voprf::Ristretto255 as voprf::Group>::serialize_scalar(scalar)
}
/// Forwards to `voprf::Ristretto255`'s `deserialize_scalar`.
fn deserialize_scalar(scalar_bits: &[u8]) -> voprf::Result<Self::Scalar> {
<voprf::Ristretto255 as voprf::Group>::deserialize_scalar(scalar_bits)
}
}