use alloc::vec::Vec;
use hkdf::{Hkdf, hmac::SimpleHmac};
use k256::sha2::Sha256;
use rand::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;
use crate::{
dsa::eddsa_25519_sha512::{KeyExchangeKey, PublicKey},
ecdh::KeyAgreementScheme,
utils::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing},
},
};
/// Output of an X25519 Diffie-Hellman exchange, wrapping `x25519_dalek::SharedSecret`.
pub struct SharedSecret {
    pub(crate) inner: x25519_dalek::SharedSecret,
}

impl SharedSecret {
    /// Wraps a raw `x25519_dalek` shared secret.
    pub(crate) fn new(inner: x25519_dalek::SharedSecret) -> SharedSecret {
        SharedSecret { inner }
    }

    /// Performs the HKDF-Extract step (HMAC-SHA256) over the raw shared-secret bytes.
    ///
    /// `salt` is forwarded unchanged to `Hkdf::new`; the returned HKDF instance can then be
    /// used to expand output key material.
    pub fn extract(&self, salt: Option<&[u8]>) -> Hkdf<Sha256, SimpleHmac<Sha256>> {
        let input_key_material = self.inner.as_bytes();
        Hkdf::<Sha256, SimpleHmac<Sha256>>::new(salt, input_key_material)
    }
}
impl Zeroize for SharedSecret {
    /// Overwrites the wrapped shared-secret bytes with zeros.
    ///
    /// This delegates to `x25519_dalek::SharedSecret`'s own `Zeroize` implementation. The
    /// previous version cast the `&[u8]` from `as_bytes()` to `*mut u8` and wrote through it.
    /// Mutating memory reachable only through a shared reference is undefined behaviour, and
    /// the compiler is free to assume those bytes never change.
    fn zeroize(&mut self) {
        // NOTE(review): requires x25519-dalek's `zeroize` feature, which derives `Zeroize` for
        // `SharedSecret`. The `ZeroizeOnDrop` marker on this type already relies on it.
        self.inner.zeroize();
    }
}
// Marker only: no `Drop` impl is defined here. NOTE(review): wiping on drop presumably comes
// from the inner `x25519_dalek::SharedSecret` (x25519-dalek `zeroize` feature). Confirm it is
// enabled.
impl ZeroizeOnDrop for SharedSecret {}

impl AsRef<[u8]> for SharedSecret {
    /// Borrows the raw shared-secret bytes.
    fn as_ref(&self) -> &[u8] {
        &self.inner.as_bytes()[..]
    }
}
/// Ephemeral X25519 secret key; `diffie_hellman` consumes it, so it can be used only once.
pub struct EphemeralSecretKey {
    inner: x25519_dalek::EphemeralSecret,
}

// Marker only. NOTE(review): wiping presumably relies on `x25519_dalek::EphemeralSecret`'s own
// drop behaviour. Confirm the x25519-dalek `zeroize` feature is enabled.
impl ZeroizeOnDrop for EphemeralSecretKey {}

impl EphemeralSecretKey {
    /// Generates a fresh ephemeral secret key using `rand::rng()` as the entropy source.
    ///
    /// Only available with the `std` feature.
    // No `Default` impl on purpose: a "default" value that silently draws randomness would be
    // surprising.
    #[cfg(feature = "std")]
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self::with_rng(&mut rand::rng())
    }

    /// Generates an ephemeral secret key from the randomness provided by `rng`.
    ///
    /// Draws 32 bytes from `rng` and uses them to seed an HC-128 generator, which
    /// `x25519_dalek` then samples from. NOTE(review): the intermediate generator appears to
    /// bridge the `rand` traits used here and the `rand_core` version `x25519_dalek` expects.
    /// Confirm.
    pub fn with_rng<R: CryptoRng + RngCore>(rng: &mut R) -> Self {
        use k256::elliptic_curve::rand_core::SeedableRng;

        // The seed buffer is held in `Zeroizing` so it is wiped when it goes out of scope.
        let mut seed_buf = Zeroizing::new([0_u8; 32]);
        RngCore::fill_bytes(rng, &mut *seed_buf);
        let hc128 = rand_hc::Hc128Rng::from_seed(*seed_buf);

        Self {
            inner: x25519_dalek::EphemeralSecret::random_from_rng(hc128),
        }
    }

    /// Returns the X25519 public key matching this ephemeral secret.
    pub fn public_key(&self) -> EphemeralPublicKey {
        EphemeralPublicKey { inner: (&self.inner).into() }
    }

    /// Runs X25519 against `pk_other` (converted via `to_x25519`), consuming this secret.
    ///
    /// No all-zero check is performed here; `X25519::exchange_ephemeral_static` does that.
    pub fn diffie_hellman(self, pk_other: &PublicKey) -> SharedSecret {
        let peer = pk_other.to_x25519();
        SharedSecret::new(self.inner.diffie_hellman(&peer))
    }
}
/// Ephemeral X25519 public key sent alongside a key-agreement message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EphemeralPublicKey {
    pub(crate) inner: x25519_dalek::PublicKey,
}

impl Serializable for EphemeralPublicKey {
    /// Writes the public key's raw bytes via `write_bytes`.
    /// These are the same 32 bytes that `read_from` expects.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let raw = self.inner.as_bytes();
        target.write_bytes(raw);
    }
}
impl Deserializable for EphemeralPublicKey {
    /// Reads a 32-byte X25519 public key (a Montgomery u-coordinate) and validates it.
    ///
    /// # Errors
    /// - Propagates any reader error from `read_array`.
    /// - Returns `DeserializationError::InvalidValue` when `MontgomeryPoint::to_edwards` cannot
    ///   lift the point (for example, points on the quadratic twist).
    /// - Returns the same error when the lifted point has small order.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let raw: [u8; 32] = source.read_array()?;
        let u_coord = curve25519_dalek::montgomery::MontgomeryPoint(raw);

        // Accept only points that lift to the Edwards curve and are not small-order.
        // Both rejections share a single error message.
        match u_coord.to_edwards(0) {
            Some(point) if !point.is_small_order() => Ok(Self {
                inner: x25519_dalek::PublicKey::from(raw),
            }),
            _ => Err(DeserializationError::InvalidValue("Invalid X25519 public key".into())),
        }
    }
}
/// X25519 key-agreement scheme with HKDF-SHA256 key derivation.
pub struct X25519;

impl KeyAgreementScheme for X25519 {
    type EphemeralSecretKey = EphemeralSecretKey;
    type EphemeralPublicKey = EphemeralPublicKey;
    type SecretKey = KeyExchangeKey;
    type PublicKey = PublicKey;
    type SharedSecret = SharedSecret;

    /// Samples a fresh ephemeral secret from `rng` and returns it with its public key.
    fn generate_ephemeral_keypair<R: CryptoRng + RngCore>(
        rng: &mut R,
    ) -> (Self::EphemeralSecretKey, Self::EphemeralPublicKey) {
        let secret = EphemeralSecretKey::with_rng(rng);
        let public = secret.public_key();
        (secret, public)
    }

    /// Sender side: combines the (consumed) ephemeral secret with the recipient's static key.
    ///
    /// # Errors
    /// Returns `InvalidSharedSecret` if the resulting shared secret is all zeros.
    fn exchange_ephemeral_static(
        ephemeral_sk: Self::EphemeralSecretKey,
        static_pk: &Self::PublicKey,
    ) -> Result<Self::SharedSecret, super::KeyAgreementError> {
        let secret = ephemeral_sk.diffie_hellman(static_pk);
        // Per RFC 7748 §6.1, an all-zero output signals a low-order peer point.
        if is_all_zero(secret.as_ref()) {
            Err(super::KeyAgreementError::InvalidSharedSecret)
        } else {
            Ok(secret)
        }
    }

    /// Recipient side: combines the static secret with the sender's ephemeral public key.
    ///
    /// # Errors
    /// Returns `InvalidSharedSecret` if the resulting shared secret is all zeros.
    fn exchange_static_ephemeral(
        static_sk: &Self::SecretKey,
        ephemeral_pk: &Self::EphemeralPublicKey,
    ) -> Result<Self::SharedSecret, super::KeyAgreementError> {
        let secret = static_sk.get_shared_secret(ephemeral_pk.clone());
        if is_all_zero(secret.as_ref()) {
            Err(super::KeyAgreementError::InvalidSharedSecret)
        } else {
            Ok(secret)
        }
    }

    /// Derives `length` bytes of key material with HKDF-SHA256.
    ///
    /// Uses no salt and the given `info` string.
    ///
    /// # Errors
    /// Returns `HkdfExpansionFailed` if HKDF-Expand rejects the requested length.
    fn extract_key_material(
        shared_secret: &Self::SharedSecret,
        length: usize,
        info: &[u8],
    ) -> Result<Vec<u8>, super::KeyAgreementError> {
        let mut okm = vec![0_u8; length];
        shared_secret
            .extract(None)
            .expand(info, &mut okm)
            .map_err(|_| super::KeyAgreementError::HkdfExpansionFailed)?;
        Ok(okm)
    }
}
/// Returns `true` iff `bytes` is non-empty and every byte is zero.
///
/// All bytes are OR-combined without branching on their values. The accumulator is then
/// compared to zero with `subtle::ConstantTimeEq`. An empty slice is reported as *not*
/// all-zero.
fn is_all_zero(bytes: &[u8]) -> bool {
    // Only the (public) length influences this branch.
    if bytes.is_empty() {
        return false;
    }
    let mut combined = 0u8;
    for &b in bytes {
        combined |= b;
    }
    bool::from(combined.ct_eq(&0u8))
}
#[cfg(test)]
mod tests {
    use curve25519_dalek::{constants::EIGHT_TORSION, montgomery::MontgomeryPoint};

    use super::*;
    use crate::{
        dsa::eddsa_25519_sha512::KeyExchangeKey,
        ecdh::KeyAgreementError,
        rand::test_utils::seeded_rng,
        utils::{Deserializable, Serializable},
    };

    /// Both sides of an ephemeral-static exchange derive the same raw shared secret.
    #[test]
    fn key_agreement() {
        let mut rng = seeded_rng([0u8; 32]);
        let sk = KeyExchangeKey::with_rng(&mut rng);
        let pk = sk.public_key();
        let sk_e = EphemeralSecretKey::with_rng(&mut rng);
        let pk_e = sk_e.public_key();
        let shared_secret_key_1 = sk_e.diffie_hellman(&pk);
        let shared_secret_key_2 = sk.get_shared_secret(pk_e);
        assert_eq!(shared_secret_key_1.inner.to_bytes(), shared_secret_key_2.inner.to_bytes());
    }

    #[test]
    fn ephemeral_public_key_rejects_small_order() {
        let bytes = EIGHT_TORSION[1].to_montgomery().to_bytes();
        let result = EphemeralPublicKey::read_from_bytes(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn ephemeral_public_key_rejects_twist_point() {
        let bytes = find_twist_point_bytes();
        let result = EphemeralPublicKey::read_from_bytes(&bytes);
        assert!(result.is_err());
    }

    /// A freshly generated ephemeral public key survives a serialize/deserialize round trip.
    #[test]
    fn ephemeral_public_key_round_trips() {
        let mut rng = seeded_rng([2u8; 32]);
        let pk_e = EphemeralSecretKey::with_rng(&mut rng).public_key();
        let bytes = pk_e.to_bytes();
        assert_eq!(bytes.len(), 32);
        let decoded = EphemeralPublicKey::read_from_bytes(&bytes)
            .expect("a generated public key must deserialize");
        assert_eq!(pk_e, decoded);
    }

    #[test]
    fn exchange_static_ephemeral_rejects_zero_shared_secret() {
        let mut rng = seeded_rng([0u8; 32]);
        let static_sk = KeyExchangeKey::with_rng(&mut rng);
        let low_order_bytes = EIGHT_TORSION[0].to_montgomery().to_bytes();
        let low_order_pk = EphemeralPublicKey {
            inner: x25519_dalek::PublicKey::from(low_order_bytes),
        };
        let result = X25519::exchange_static_ephemeral(&static_sk, &low_order_pk);
        assert!(matches!(result, Err(KeyAgreementError::InvalidSharedSecret)));
    }

    /// `SharedSecret::zeroize` must overwrite every byte of the wrapped secret.
    #[test]
    fn shared_secret_zeroize_clears_bytes() {
        let mut rng = seeded_rng([3u8; 32]);
        let sk = KeyExchangeKey::with_rng(&mut rng);
        let pk = sk.public_key();
        let mut shared = EphemeralSecretKey::with_rng(&mut rng).diffie_hellman(&pk);
        assert!(!shared.inner.as_bytes().iter().all(|&b| b == 0));
        shared.zeroize();
        assert!(shared.inner.as_bytes().iter().all(|&b| b == 0));
    }

    /// HKDF output derived from both sides of the exchange must match.
    #[test]
    fn extract_key_material_agrees_on_both_sides() {
        let mut rng = seeded_rng([1u8; 32]);
        let static_sk = KeyExchangeKey::with_rng(&mut rng);
        let static_pk = static_sk.public_key();
        let (eph_sk, eph_pk) = X25519::generate_ephemeral_keypair(&mut rng);

        let Ok(sender) = X25519::exchange_ephemeral_static(eph_sk, &static_pk) else {
            panic!("sender-side exchange failed");
        };
        let Ok(recipient) = X25519::exchange_static_ephemeral(&static_sk, &eph_pk) else {
            panic!("recipient-side exchange failed");
        };

        let Ok(k1) = X25519::extract_key_material(&sender, 48, b"info") else {
            panic!("sender-side HKDF expansion failed");
        };
        let Ok(k2) = X25519::extract_key_material(&recipient, 48, b"info") else {
            panic!("recipient-side HKDF expansion failed");
        };
        assert_eq!(k1.len(), 48);
        assert_eq!(k1, k2);
    }

    /// HKDF-SHA256 cannot expand past 255 * 32 bytes; the failure must be surfaced as an error.
    #[test]
    fn extract_key_material_rejects_oversized_output() {
        let mut rng = seeded_rng([4u8; 32]);
        let sk = KeyExchangeKey::with_rng(&mut rng);
        let pk = sk.public_key();
        let shared = EphemeralSecretKey::with_rng(&mut rng).diffie_hellman(&pk);
        let result = X25519::extract_key_material(&shared, 255 * 32 + 1, b"");
        assert!(matches!(result, Err(KeyAgreementError::HkdfExpansionFailed)));
    }

    #[test]
    fn is_all_zero_accepts_arbitrary_lengths() {
        assert!(!is_all_zero(&[]));
        assert!(is_all_zero(&[0u8; 16]));
        assert!(!is_all_zero(&[0u8, 1u8, 0u8, 0u8]));
    }

    /// Finds the first u-coordinate (scanning the low 16 bits) that does not lift to Edwards,
    /// i.e. a point on the quadratic twist.
    fn find_twist_point_bytes() -> [u8; 32] {
        let mut bytes = [0u8; 32];
        for i in 0u16..=u16::MAX {
            bytes[0] = (i & 0xff) as u8;
            bytes[1] = (i >> 8) as u8;
            if MontgomeryPoint(bytes).to_edwards(0).is_none() {
                return bytes;
            }
        }
        panic!("no twist point found in 16-bit search space");
    }
}