use elliptic_curve::ecdh::SharedSecret;
use elliptic_curve::group::Curve as GroupCurve;
use elliptic_curve::point::AffineCoordinates;
use elliptic_curve::rand_core::OsRng;
use elliptic_curve::sec1::{FromEncodedPoint, ModulusSize, ToEncodedPoint};
use elliptic_curve::subtle::ConstantTimeEq;
use elliptic_curve::zeroize::Zeroizing;
use elliptic_curve::{
ecdh, AffinePoint, Curve, CurveArithmetic, FieldBytesSize, JwkParameters, ProjectivePoint,
PublicKey, SecretKey,
};
use crate::jose::{EcJwk, Jwk, JwkCurve};
use crate::{Error, Result};
/// Symmetric key material produced by [`create_enc_key`] / [`recover_enc_key`].
///
/// The bytes are held in a [`Zeroizing`] buffer so they are wiped when the value is
/// dropped. `Debug` is implemented by hand (instead of derived) so that formatting a
/// key — e.g. via `{:?}`, `dbg!`, or a failed `assert_eq!` — never prints the secret.
#[derive(Clone)]
pub struct EncryptionKey<const KEYBYTES: usize>(Zeroizing<[u8; KEYBYTES]>);

impl<const KEYBYTES: usize> core::fmt::Debug for EncryptionKey<KEYBYTES> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Only the length is shown; the key bytes themselves are deliberately omitted.
        f.debug_struct("EncryptionKey")
            .field("len", &KEYBYTES)
            .finish_non_exhaustive()
    }
}
impl<const KEYBYTES: usize> EncryptionKey<KEYBYTES> {
#[must_use]
pub fn as_bytes(&self) -> &[u8; KEYBYTES] {
&self.0
}
}
impl<const KEYBYTES: usize> ConstantTimeEq for EncryptionKey<KEYBYTES> {
    /// Compares two keys byte-for-byte without data-dependent branching, so the
    /// comparison time does not reveal where the keys first differ.
    fn ct_eq(&self, other: &Self) -> elliptic_curve::subtle::Choice {
        let lhs: &[u8] = self.0.as_ref();
        let rhs: &[u8] = other.0.as_ref();
        lhs.ct_eq(rhs)
    }
}
impl<const KEYBYTES: usize> PartialEq for EncryptionKey<KEYBYTES> {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
pub fn create_enc_key<const N: usize>(s_pub_jwk: &EcJwk) -> Result<(EcJwk, EncryptionKey<N>)> {
match s_pub_jwk.get_curve()? {
JwkCurve::P256 => create_enc_key_inner::<p256::NistP256, N>(s_pub_jwk),
JwkCurve::P384 => create_enc_key_inner::<p384::NistP384, N>(s_pub_jwk),
JwkCurve::P521 => create_enc_key_inner::<p521::NistP521, N>(s_pub_jwk),
}
}
/// Recovers an `N`-byte encryption key previously provisioned with [`create_enc_key`].
///
/// `c_pub_jwk` is the client public key kept from provisioning, and `s_pub_jwk` is the
/// server public key. `server_key_exchange` performs the round trip to the server. It
/// receives a JWK whose `alg` is set to `"ECMR"` and must return the server's reply as
/// a JWK.
///
/// # Errors
///
/// Propagates failures from curve detection, key decoding, the point arithmetic, the
/// server exchange itself, and conversion of the server reply into an EC JWK.
pub fn recover_enc_key<const N: usize>(
    c_pub_jwk: &EcJwk,
    s_pub_jwk: &EcJwk,
    server_key_exchange: impl FnOnce(&Jwk) -> Result<Jwk>,
) -> Result<EncryptionKey<N>> {
    // Adapt the caller's generic-JWK callback to the EC-typed callback the
    // curve-specific code expects: tag the outgoing key as ECMR, then narrow the
    // reply back to an EC JWK.
    let ec_exchange = |outgoing: &EcJwk| -> Result<EcJwk> {
        let mut request: Jwk = outgoing.clone().into();
        request.alg = Some("ECMR".into());
        let reply = server_key_exchange(&request)?;
        EcJwk::try_from(reply)
    };

    // The stored client key fixes the curve; `s_pub_jwk` is decoded on that same curve.
    let curve = c_pub_jwk.get_curve()?;
    match curve {
        JwkCurve::P256 => {
            recover_enc_key_inner::<p256::NistP256, N>(c_pub_jwk, s_pub_jwk, ec_exchange)
        }
        JwkCurve::P384 => {
            recover_enc_key_inner::<p384::NistP384, N>(c_pub_jwk, s_pub_jwk, ec_exchange)
        }
        JwkCurve::P521 => {
            recover_enc_key_inner::<p521::NistP521, N>(c_pub_jwk, s_pub_jwk, ec_exchange)
        }
    }
}
/// Curve-specific body of [`create_enc_key`].
///
/// Decodes the server key as a public key on curve `C`, then performs ECDH with a newly
/// generated ephemeral secret. Returns the ephemeral public key (as a JWK) alongside the
/// derived encryption key.
fn create_enc_key_inner<C, const KEYBYTES: usize>(
    remote_jwk: &EcJwk,
) -> Result<(EcJwk, EncryptionKey<KEYBYTES>)>
where
    C: CurveArithmetic + JwkParameters,
    AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
    FieldBytesSize<C>: ModulusSize,
{
    let server_public = remote_jwk.to_pub::<C>()?;
    let ephemeral = ecdh::EphemeralSecret::<C>::random(&mut OsRng);

    let shared = ephemeral.diffie_hellman(&server_public);
    let client_jwk = EcJwk::from_pub(&ephemeral.public_key());

    let key = secret_to_key(shared);
    Ok((client_jwk, key))
}
/// Curve-specific body of [`recover_enc_key`].
///
/// Blinds the stored client key with a fresh ephemeral key and sends the blinded point
/// to the server. The ephemeral contribution is then removed from the server's answer,
/// and the key is derived from the x-coordinate of the resulting point.
fn recover_enc_key_inner<C, const KEYBYTES: usize>(
    c_pub_jwk: &EcJwk,
    s_pub_jwk: &EcJwk,
    server_key_exchange: impl FnOnce(&EcJwk) -> Result<EcJwk>,
) -> Result<EncryptionKey<KEYBYTES>>
where
    C: CurveArithmetic + JwkParameters,
    AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
    FieldBytesSize<C>: ModulusSize,
{
    // Client public key C and server public key S, both on curve `C`.
    let client_pub = c_pub_jwk.to_pub::<C>()?;
    let server_pub = s_pub_jwk.to_pub::<C>()?;

    // Blind C with a fresh ephemeral key pair (e, E): X = C + E.
    let ephemeral_secret = SecretKey::<C>::random(&mut OsRng);
    let blinded = ecmr_add(&client_pub, &ephemeral_secret.public_key())?;

    // Round trip to the server; its reply Y is presumably s·X in the ECMR scheme.
    let response_jwk = server_key_exchange(&EcJwk::from_pub(&blinded))?;
    let response = response_jwk.to_pub::<C>()?;

    // Unblind: K = Y - e·S.
    let unblinding = diffie_hellman(&ephemeral_secret, &server_pub)?;
    let recovered = ecmr_sub(&response, &unblinding)?;

    // Only the affine x-coordinate of K is fed to the KDF.
    let shared: SharedSecret<C> = recovered.as_affine().x().into();
    Ok(secret_to_key(shared))
}
/// Multiplies `public_key`'s point by `secret_key`'s scalar and returns the product
/// as a [`PublicKey`].
///
/// # Errors
///
/// Fails if the product cannot be represented as a `PublicKey` (the identity point).
fn diffie_hellman<C>(secret_key: &SecretKey<C>, public_key: &PublicKey<C>) -> Result<PublicKey<C>>
where
    C: CurveArithmetic,
{
    let scalar = secret_key.to_nonzero_scalar();
    let product = public_key.to_projective() * scalar.as_ref();
    PublicKey::from_affine(product.to_affine()).map_err(Into::into)
}
pub fn ecmr_add<C>(local: &PublicKey<C>, remote: &PublicKey<C>) -> Result<PublicKey<C>>
where
C: CurveArithmetic,
{
let local_point: ProjectivePoint<C> = (*local.as_affine()).into();
let remote_point: ProjectivePoint<C> = (*remote.as_affine()).into();
PublicKey::from_affine((local_point + remote_point).to_affine())
.map_err(|_| Error::IdentityPointCreated)
}
pub fn ecmr_sub<C>(local: &PublicKey<C>, remote: &PublicKey<C>) -> Result<PublicKey<C>>
where
C: CurveArithmetic,
{
let local_point: ProjectivePoint<C> = (*local.as_affine()).into();
let remote_point: ProjectivePoint<C> = (*remote.as_affine()).into();
PublicKey::from_affine((local_point - remote_point).to_affine())
.map_err(|_| Error::IdentityPointCreated)
}
/// Derives a `KEYBYTES`-byte [`EncryptionKey`] from an ECDH shared secret.
///
/// The raw shared-secret bytes are fed to Concat KDF (SHA-256) with an empty
/// `OtherInfo`, and the output fills the entire key buffer.
///
/// # Panics
///
/// Panics if Concat KDF rejects the request. The shared secret is a fixed-size, non-empty
/// field element, so the only reachable failure is an unsupported output length
/// (e.g. `KEYBYTES == 0`). That is a programming error at the call site, not a
/// runtime condition.
// `secret` is taken by value so the caller gives it up and it is dropped right after
// the derivation (`SharedSecret` zeroizes its bytes on drop); hence the clippy allow.
#[allow(clippy::needless_pass_by_value)]
fn secret_to_key<C: Curve, const KEYBYTES: usize>(
    secret: SharedSecret<C>,
) -> EncryptionKey<KEYBYTES> {
    let mut key_bytes = Zeroizing::new([0u8; KEYBYTES]);
    concat_kdf::derive_key_into::<sha2::Sha256>(secret.raw_secret_bytes(), &[], key_bytes.as_mut())
        .expect("concat-KDF rejected the requested key length; KEYBYTES must be non-zero");
    EncryptionKey(key_bytes)
}
// Unit tests for this module are kept in the separate file `key_exchange_tests.rs`.
#[cfg(test)]
#[path = "key_exchange_tests.rs"]
mod tests;