use crate::{DecapsulationKey, EncapsulationKey};
use core::marker::PhantomData;
use elliptic_curve::{
AffinePoint, CurveArithmetic, Error, FieldBytes, FieldBytesSize, PublicKey, SecretKey,
ecdh::EphemeralSecret,
sec1::{FromSec1Point, ModulusSize, ToSec1Point, UncompressedPoint, UncompressedPointSize},
};
use kem::{
Ciphertext, Decapsulator, Encapsulate, Generate, InvalidKey, Kem, KeyExport, KeySizeUser,
SharedKey, TryDecapsulate, TryKeyInit,
};
use rand_core::{CryptoRng, TryCryptoRng};
#[cfg(doc)]
use crate::Expander;
/// Key Encapsulation Mechanism built on Elliptic Curve Diffie-Hellman over the curve `C`.
///
/// `EcdhKem` is a zero-sized marker type. Its [`Kem`] impl names the key types
/// ([`EcdhEncapsulationKey`], [`EcdhDecapsulationKey`]) and the wire sizes:
///
/// - ciphertext: an ephemeral public key encoded as a SEC1 uncompressed point;
/// - shared key: the raw ECDH secret bytes, one field element wide, returned
///   without any key derivation applied by this module.
///
/// NOTE(review): since the shared key is the raw ECDH output, callers presumably
/// post-process it (e.g. with [`Expander`]) before using it as a symmetric key —
/// confirm the intended pairing.
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub struct EcdhKem<C>(PhantomData<C>)
where
    C: CurveArithmetic;
/// Connects the ECDH key types for curve `C` and fixes the encoded sizes:
/// ciphertexts are SEC1 uncompressed points and shared keys are one field
/// element wide.
impl<C> Kem for EcdhKem<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
    EcdhEncapsulationKey<C>: Clone + Encapsulate<Kem = Self>,
    EcdhDecapsulationKey<C>: Generate + TryDecapsulate<Kem = Self>,
{
    // Sizes of the values exchanged / produced by the mechanism.
    type CiphertextSize = UncompressedPointSize<C>;
    type SharedKeySize = FieldBytesSize<C>;

    // Public and private key types.
    type EncapsulationKey = EcdhEncapsulationKey<C>;
    type DecapsulationKey = EcdhDecapsulationKey<C>;
}
/// Decapsulation (private) key for [`EcdhKem`].
///
/// Holds a curve [`SecretKey`] (field `dk`, used for the Diffie-Hellman step and
/// for serialization) together with an encapsulation key built on [`PublicKey`]
/// (field `ek`, returned by [`Decapsulator::encapsulation_key`]).
pub type EcdhDecapsulationKey<C> = DecapsulationKey<SecretKey<C>, PublicKey<C>>;
/// Lets a holder of the private key obtain the encapsulation key stored next to it.
impl<C> Decapsulator for EcdhDecapsulationKey<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
{
    type Kem = EcdhKem<C>;

    /// Borrows the stored encapsulation (public) key.
    fn encapsulation_key(&self) -> &EcdhEncapsulationKey<C> {
        let Self { ek, .. } = self;
        ek
    }
}
/// A serialized decapsulation key is exactly one field element: the secret scalar bytes.
impl<C: CurveArithmetic> KeySizeUser for EcdhDecapsulationKey<C> {
    type KeySize = FieldBytesSize<C>;
}
impl<C: CurveArithmetic> TryKeyInit for EcdhDecapsulationKey<C> {
    /// Builds a decapsulation key from serialized secret-scalar bytes.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKey`] when `SecretKey::from_bytes` rejects `key`; the
    /// underlying error detail is discarded.
    fn new(key: &FieldBytes<C>) -> Result<Self, InvalidKey> {
        match SecretKey::<C>::from_bytes(key) {
            Ok(secret) => Ok(secret.into()),
            Err(_) => Err(InvalidKey),
        }
    }
}
impl<C> KeyExport for EcdhDecapsulationKey<C>
where
C: CurveArithmetic,
{
fn to_bytes(&self) -> FieldBytes<C> {
self.dk.to_bytes()
}
}
impl<C> Generate for EcdhDecapsulationKey<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
{
    /// Draws a fresh random secret key from `rng` and wraps it as a decapsulation key.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `rng`.
    fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        SecretKey::<C>::try_generate_from_rng(rng).map(Into::into)
    }
}
impl<C> TryDecapsulate for EcdhDecapsulationKey<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
{
    type Error = Error;

    /// Recomputes the shared key from `encapsulated_key`, the sender's ephemeral
    /// public key in SEC1 encoding, by performing Diffie-Hellman with the stored
    /// secret key. The raw ECDH secret bytes are returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] when `PublicKey::from_sec1_bytes` cannot decode
    /// `encapsulated_key` into a public key.
    fn try_decapsulate(
        &self,
        encapsulated_key: &Ciphertext<EcdhKem<C>>,
    ) -> Result<SharedKey<EcdhKem<C>>, Error> {
        PublicKey::<C>::from_sec1_bytes(encapsulated_key).map(|peer_public| {
            self.dk
                .diffie_hellman(&peer_public)
                .raw_secret_bytes()
                .clone()
        })
    }
}
/// Encapsulation (public) key for [`EcdhKem`]: a wrapped curve [`PublicKey`].
///
/// Its [`TryKeyInit`] and [`KeyExport`] impls use the SEC1 uncompressed point
/// encoding.
pub type EcdhEncapsulationKey<C> = EncapsulationKey<PublicKey<C>>;
/// A serialized encapsulation key is a SEC1 uncompressed point.
impl<C: CurveArithmetic> KeySizeUser for EcdhEncapsulationKey<C>
where
    FieldBytesSize<C>: ModulusSize,
{
    type KeySize = UncompressedPointSize<C>;
}
impl<C> TryKeyInit for EcdhEncapsulationKey<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
{
    /// Builds an encapsulation key from a SEC1 uncompressed point.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKey`] when `PublicKey::from_sec1_bytes` rejects the
    /// encoding; the underlying error detail is discarded.
    fn new(encapsulation_key: &UncompressedPoint<C>) -> Result<Self, InvalidKey> {
        match PublicKey::<C>::from_sec1_bytes(encapsulation_key) {
            Ok(public_key) => Ok(public_key.into()),
            Err(_) => Err(InvalidKey),
        }
    }
}
impl<C> KeyExport for EcdhEncapsulationKey<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
{
    /// Serializes the wrapped public key as a SEC1 uncompressed point.
    fn to_bytes(&self) -> UncompressedPoint<C> {
        let public_key = &self.0;
        public_key.to_uncompressed_point()
    }
}
impl<C> Encapsulate for EcdhEncapsulationKey<C>
where
    C: CurveArithmetic,
    FieldBytesSize<C>: ModulusSize,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
{
    type Kem = EcdhKem<C>;

    /// Generates a fresh ephemeral secret from `rng` for this call and returns
    /// `(ciphertext, shared_key)`.
    ///
    /// - `ciphertext` is the ephemeral public key as a SEC1 uncompressed point.
    /// - `shared_key` is the raw ECDH secret between the ephemeral secret and
    ///   this encapsulation key, with no key derivation applied.
    fn encapsulate_with_rng<R>(
        &self,
        rng: &mut R,
    ) -> (Ciphertext<EcdhKem<C>>, SharedKey<EcdhKem<C>>)
    where
        R: CryptoRng + ?Sized,
    {
        let ephemeral = EphemeralSecret::<C>::generate_from_rng(rng);

        // The recipient needs the ephemeral public half to redo the DH step.
        let ciphertext = ephemeral.public_key().to_uncompressed_point();
        let shared_key = ephemeral
            .diffie_hellman(&self.0)
            .raw_secret_bytes()
            .clone();

        (ciphertext, shared_key)
    }
}