#![allow(unsafe_code)]
use derive_where::derive_where;
use generic_array::{ArrayLength, GenericArray};
use rand::{CryptoRng, RngCore};
use crate::errors::{InternalError, ProtocolError};
use crate::key_exchange::group::KeGroup;
/// A public/private key pair for the key-exchange group `KG`.
///
/// The private half is generic over [`SecretKey`] so callers can plug in their
/// own private-key representation; it defaults to [`PrivateKey`].
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(
    feature = "serde",
    serde(
        crate = "serde",
        bound(
            serialize = "S: serde::Serialize",
            deserialize = "S: serde::Deserialize<'de>"
        )
    )
)]
#[derive_where(Clone)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Pk, S)]
pub struct KeyPair<KG: KeGroup, S: SecretKey<KG> = PrivateKey<KG>> {
    /// Public half; the constructors in this file derive it from `sk`.
    pk: PublicKey<KG>,
    /// Private half.
    sk: S,
}
impl<KG: KeGroup, S: SecretKey<KG>> KeyPair<KG, S> {
    /// Borrows the public key of this pair.
    pub fn public(&self) -> &PublicKey<KG> {
        &self.pk
    }

    /// Borrows the private key of this pair.
    pub fn private(&self) -> &S {
        &self.sk
    }

    /// Builds a key pair from the serialized bytes of a private key.
    ///
    /// # Errors
    /// Fails if [`SecretKey::deserialize`] rejects `input`, or if deriving the
    /// public key from the decoded secret fails.
    pub fn from_private_key_slice(input: &[u8]) -> Result<Self, ProtocolError<S::Error>> {
        let decoded = S::deserialize(input)?;
        Self::from_private_key(decoded)
    }

    /// Builds a key pair by deriving the public key from `sk`.
    ///
    /// # Errors
    /// Propagates any failure from [`SecretKey::public_key`], converted into a
    /// [`ProtocolError`].
    pub fn from_private_key(sk: S) -> Result<Self, ProtocolError<S::Error>> {
        sk.public_key()
            .map(|pk| Self { pk, sk })
            .map_err(Into::into)
    }
}
impl<KG: KeGroup> KeyPair<KG> {
    /// Samples a fresh private key from `rng` and pairs it with its public key.
    pub(crate) fn generate_random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        let secret = KG::random_sk(rng);
        Self {
            pk: PublicKey(KG::public_key(secret)),
            sk: PrivateKey(secret),
        }
    }
}
#[cfg(test)]
impl<KG: KeGroup> KeyPair<KG>
where
    KG::Pk: std::fmt::Debug,
    KG::Sk: std::fmt::Debug,
{
    /// Proptest strategy producing key pairs seeded from uniformly random 32-byte
    /// `StdRng` seeds.
    fn uniform_keypair_strategy() -> proptest::prelude::BoxedStrategy<Self> {
        use proptest::prelude::*;
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        // `generate_random` is infallible, so every seed maps to a key pair and no
        // filtering is needed: a plain `prop_map` expresses that directly.
        // Shrinking is disabled because a "smaller" seed just yields an unrelated
        // key pair, which is not a simpler counter-example.
        any::<[u8; 32]>()
            .prop_map(|seed| Self::generate_random(&mut StdRng::from_seed(seed)))
            .no_shrink()
            .boxed()
    }
}
/// Wrapper around a group secret key (`KG::Sk`) marking it as private.
///
/// Derives `ZeroizeOnDrop` (via `derive_where`) so the wrapped secret is
/// zeroized when the wrapper is dropped.
///
/// NOTE(review): the derived `Debug` forwards to `KG::Sk`'s `Debug`, which may
/// print secret material — confirm this is acceptable.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Sk)]
pub struct PrivateKey<KG: KeGroup>(KG::Sk);
/// Abstraction over a private key usable by the key exchange.
///
/// [`PrivateKey`] is the built-in implementation. Other implementations can wrap
/// or forward to a key held elsewhere while exposing the same operations.
pub trait SecretKey<KG: KeGroup>: Clone + Sized {
    /// Implementation-specific error carried inside [`InternalError`].
    type Error;

    /// Length in bytes of the serialized key produced by [`SecretKey::serialize`].
    type Len: ArrayLength<u8>;

    /// Derives the public key corresponding to this secret.
    fn public_key(&self) -> Result<PublicKey<KG>, InternalError<Self::Error>>;

    /// Performs Diffie-Hellman with the peer public key `pk`.
    ///
    /// The result is a byte array sized like an encoded public key (`KG::PkLen`).
    fn diffie_hellman(
        &self,
        pk: PublicKey<KG>,
    ) -> Result<GenericArray<u8, KG::PkLen>, InternalError<Self::Error>>;

    /// Encodes this key as `Self::Len` bytes.
    fn serialize(&self) -> GenericArray<u8, Self::Len>;

    /// Decodes a key from `input`.
    fn deserialize(input: &[u8]) -> Result<Self, InternalError<Self::Error>>;
}
impl<KG: KeGroup> SecretKey<KG> for PrivateKey<KG> {
type Error = core::convert::Infallible;
type Len = KG::SkLen;
fn diffie_hellman(
&self,
pk: PublicKey<KG>,
) -> Result<GenericArray<u8, KG::PkLen>, InternalError> {
Ok(KG::diffie_hellman(pk.0, self.0))
}
fn public_key(&self) -> Result<PublicKey<KG>, InternalError> {
Ok(PublicKey(KG::public_key(self.0)))
}
fn serialize(&self) -> GenericArray<u8, Self::Len> {
KG::serialize_sk(self.0)
}
fn deserialize(input: &[u8]) -> Result<Self, InternalError> {
KG::deserialize_sk(input).map(Self)
}
}
#[cfg(feature = "serde")]
impl<'de, KG: KeGroup> serde::Deserialize<'de> for PrivateKey<KG> {
    /// Deserializes a `GenericArray<u8, KG::SkLen>` and decodes it with
    /// `KG::deserialize_sk`; a decode failure becomes a custom deserializer error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let bytes: GenericArray<u8, KG::SkLen> = serde::Deserialize::deserialize(deserializer)?;
        // NOTE(review): `bytes` holds secret material and is dropped without being
        // zeroized — consider wrapping it in a zeroizing buffer.
        match KG::deserialize_sk(&bytes) {
            Ok(scalar) => Ok(Self(scalar)),
            Err(err) => Err(<D::Error as serde::de::Error>::custom(err)),
        }
    }
}
#[cfg(feature = "serde")]
impl<KG: KeGroup> serde::Serialize for PrivateKey<KG> {
    /// Serializes the byte array produced by `KG::serialize_sk`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let encoded = KG::serialize_sk(self.0);
        serde::Serialize::serialize(&encoded, serializer)
    }
}
/// Wrapper around a group public key (`KG::Pk`) marking it as public.
///
/// Like [`PrivateKey`], it derives `ZeroizeOnDrop` via `derive_where`.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Pk)]
pub struct PublicKey<KG: KeGroup>(KG::Pk);
impl<KG: KeGroup> PublicKey<KG> {
    /// Decodes a public key from `key_bytes` with `KG::deserialize_pk`.
    ///
    /// # Errors
    /// Returns the [`InternalError`] produced by `KG::deserialize_pk` when it
    /// fails to decode the bytes.
    pub fn deserialize(key_bytes: &[u8]) -> Result<Self, InternalError> {
        let point = KG::deserialize_pk(key_bytes)?;
        Ok(Self(point))
    }

    /// Encodes this public key with `KG::serialize_pk`.
    pub fn serialize(&self) -> GenericArray<u8, KG::PkLen> {
        let Self(point) = self;
        KG::serialize_pk(*point)
    }
}
#[cfg(feature = "serde")]
impl<'de, KG: KeGroup> serde::Deserialize<'de> for PublicKey<KG> {
    /// Deserializes a `GenericArray<u8, KG::PkLen>` and decodes it with
    /// `KG::deserialize_pk`; a decode failure becomes a custom deserializer error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let bytes: GenericArray<u8, KG::PkLen> = serde::Deserialize::deserialize(deserializer)?;
        KG::deserialize_pk(&bytes)
            .map(Self)
            .map_err(<D::Error as serde::de::Error>::custom)
    }
}
#[cfg(feature = "serde")]
impl<KG: KeGroup> serde::Serialize for PublicKey<KG> {
    /// Serializes the byte array produced by `KG::serialize_pk`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let encoded = KG::serialize_pk(self.0);
        serde::Serialize::serialize(&encoded, serializer)
    }
}
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;

    use super::*;
    use crate::errors::*;
    use crate::util;

    // Runs `util::test_zeroize_on_drop` against a freshly generated `PrivateKey`
    // for each enabled group.
    #[test]
    fn test_zeroize_key() {
        fn inner<G: KeGroup>() {
            let mut rng = OsRng;
            let mut key = PrivateKey::<G>(G::random_sk(&mut rng));
            util::test_zeroize_on_drop(&mut key);
        }

        #[cfg(feature = "ristretto255")]
        inner::<crate::Ristretto255>();
        inner::<::p256::NistP256>();
    }

    // Generates `mod $mod` containing property tests for the group `$point`.
    macro_rules! test {
        ($mod:ident, $point:ty) => {
            mod $mod {
                // NOTE(review): presumably imported explicitly because the crate is
                // `no_std` and the `prop_assert*` macros expand to `format!`.
                use std::format;

                use proptest::prelude::*;

                use super::*;

                proptest! {
                    // The stored public key equals the one re-derived from the
                    // private key.
                    #[test]
                    fn pub_from_priv(kp in KeyPair::<$point>::uniform_keypair_strategy()) {
                        let pk = kp.public();
                        let sk = kp.private();
                        prop_assert_eq!(&sk.public_key()?, pk);
                    }

                    // Diffie-Hellman yields the same output in both directions.
                    #[test]
                    fn dh(kp1 in KeyPair::<$point>::uniform_keypair_strategy(),
                          kp2 in KeyPair::<$point>::uniform_keypair_strategy()) {
                        let dh1 = kp2.private().diffie_hellman(kp1.public().clone())?;
                        let dh2 = kp1.private().diffie_hellman(kp2.public().clone())?;
                        prop_assert_eq!(dh1, dh2);
                    }

                    // Serializing a private key and rebuilding a key pair from those
                    // bytes reproduces the same serialized private key.
                    #[test]
                    fn private_key_slice(kp in KeyPair::<$point>::uniform_keypair_strategy()) {
                        let sk_bytes = kp.private().serialize().to_vec();
                        let kp2 = KeyPair::<$point>::from_private_key_slice(&sk_bytes)?;
                        let kp2_private_bytes = kp2.private().serialize().to_vec();
                        prop_assert_eq!(sk_bytes, kp2_private_bytes);
                    }
                }
            }
        };
    }

    #[cfg(feature = "ristretto255")]
    test!(ristretto, crate::Ristretto255);
    test!(p256, ::p256::NistP256);

    // End-to-end registration and login where the server's key pair uses a custom
    // `SecretKey` implementation (`RemoteKey`) instead of `PrivateKey`.
    #[test]
    fn remote_key() {
        use rand::rngs::OsRng;

        use crate::{
            CipherSuite, ClientLogin, ClientLoginFinishParameters, ClientLoginFinishResult,
            ClientLoginStartResult, ClientRegistration, ClientRegistrationFinishParameters,
            ClientRegistrationFinishResult, ClientRegistrationStartResult, ServerLogin,
            ServerLoginStartParameters, ServerLoginStartResult, ServerRegistration,
            ServerRegistrationStartResult, ServerSetup,
        };

        // Test cipher suite. Note that this local type's name shadows the prelude
        // `Default` trait name inside this function.
        struct Default;

        impl CipherSuite for Default {
            #[cfg(feature = "ristretto255")]
            type OprfCs = crate::Ristretto255;
            #[cfg(not(feature = "ristretto255"))]
            type OprfCs = ::p256::NistP256;
            #[cfg(feature = "ristretto255")]
            type KeGroup = crate::Ristretto255;
            #[cfg(not(feature = "ristretto255"))]
            type KeGroup = ::p256::NistP256;
            type KeyExchange = crate::key_exchange::tripledh::TripleDh;
            type Ksf = crate::ksf::Identity;
        }

        type KeCurve = <Default as CipherSuite>::KeGroup;

        // Custom `SecretKey` that forwards every operation to a wrapped
        // `PrivateKey`, standing in for a key held outside the process.
        #[derive(Clone)]
        struct RemoteKey(PrivateKey<KeCurve>);

        impl SecretKey<KeCurve> for RemoteKey {
            type Error = core::convert::Infallible;
            type Len = <KeCurve as KeGroup>::SkLen;

            fn diffie_hellman(
                &self,
                pk: PublicKey<KeCurve>,
            ) -> Result<GenericArray<u8, <KeCurve as KeGroup>::PkLen>, InternalError<Self::Error>>
            {
                self.0.diffie_hellman(pk)
            }

            fn public_key(&self) -> Result<PublicKey<KeCurve>, InternalError<Self::Error>> {
                self.0.public_key()
            }

            fn serialize(&self) -> GenericArray<u8, Self::Len> {
                self.0.serialize()
            }

            fn deserialize(input: &[u8]) -> Result<Self, InternalError<Self::Error>> {
                PrivateKey::deserialize(input).map(Self)
            }
        }

        const PASSWORD: &str = "password";

        // Server setup keyed with a `RemoteKey`-backed key pair.
        let sk = KeCurve::random_sk(&mut OsRng);
        let sk = RemoteKey(PrivateKey(sk));
        let keypair = KeyPair::from_private_key(sk).unwrap();
        let server_setup = ServerSetup::<Default, RemoteKey>::new_with_key(&mut OsRng, keypair);

        // Registration: client start -> server start -> client finish -> server finish.
        let ClientRegistrationStartResult {
            message,
            state: client,
        } = ClientRegistration::<Default>::start(&mut OsRng, PASSWORD.as_bytes()).unwrap();
        let ServerRegistrationStartResult { message, .. } =
            ServerRegistration::start(&server_setup, message, &[]).unwrap();
        let ClientRegistrationFinishResult { message, .. } = client
            .finish(
                &mut OsRng,
                PASSWORD.as_bytes(),
                message,
                ClientRegistrationFinishParameters::default(),
            )
            .unwrap();
        let file = ServerRegistration::finish(message);

        // Login: client start -> server start -> client finish -> server finish,
        // with every step expected to succeed.
        let ClientLoginStartResult {
            message,
            state: client,
        } = ClientLogin::<Default>::start(&mut OsRng, PASSWORD.as_bytes()).unwrap();
        let ServerLoginStartResult {
            message,
            state: server,
            ..
        } = ServerLogin::start(
            &mut OsRng,
            &server_setup,
            Some(file),
            message,
            &[],
            ServerLoginStartParameters::default(),
        )
        .unwrap();
        let ClientLoginFinishResult { message, .. } = client
            .finish(
                PASSWORD.as_bytes(),
                message,
                ClientLoginFinishParameters::default(),
            )
            .unwrap();
        server.finish(message).unwrap();
    }
}