use std::marker::PhantomData;
use crate::{
CipherSuite,
crypto::{
aead::{AeadDecrypt, AeadEncrypt},
buffer::{DecryptionBufferView, EncryptionBufferView},
key_derivation::KeyDerivation,
},
error::Result,
header::{Counter, KeyId},
};
/// Key material used on the sending side: an AEAD instance together with the
/// secret it encrypts with, tagged with the cipher suite and key id it was
/// created for.
///
/// `D` selects the key-derivation scheme; only its associated `Secret` type is
/// stored, so `D` itself is carried as a zero-sized `PhantomData` marker.
pub struct EncryptionKey<A: AeadEncrypt, D: KeyDerivation> {
    /// AEAD implementation built for `cipher_suite`.
    aead: A,
    /// Secret handed to `aead` on every encryption.
    secret: D::Secret,
    /// Cipher suite this key was created for.
    cipher_suite: CipherSuite,
    /// Identifier of this key.
    key_id: KeyId,
    /// Marker tying the key to its derivation scheme `D`.
    _derivation: PhantomData<D>,
}
impl<A, D> EncryptionKey<A, D>
where
    D: KeyDerivation,
    A: AeadEncrypt<Secret = D::Secret>,
{
    /// Creates an encryption key for `key_id` by building the AEAD for
    /// `cipher_suite` and expanding `key_material` through `D`.
    ///
    /// # Errors
    ///
    /// Returns an error if `A` cannot be constructed for `cipher_suite`, or if
    /// `D::expand_from` fails. The AEAD is built first, so its error is the one
    /// reported when both steps would fail.
    pub fn derive_from<K, M>(cipher_suite: CipherSuite, key_id: K, key_material: M) -> Result<Self>
    where
        K: Into<KeyId>,
        M: AsRef<[u8]>,
    {
        let id: KeyId = key_id.into();
        let aead = A::try_from(cipher_suite)?;
        let secret = D::expand_from(cipher_suite, key_material, id)?;
        Ok(Self {
            cipher_suite,
            key_id: id,
            aead,
            secret,
            _derivation: PhantomData,
        })
    }

    /// Encrypts `buffer` with this key's secret and the given `counter` by
    /// delegating to the AEAD. Nothing but `()` is returned, so any output is
    /// produced through the buffer view.
    pub fn encrypt<'a, B>(&self, buffer: B, counter: Counter) -> Result<()>
    where
        B: Into<EncryptionBufferView<'a>>,
    {
        let Self { aead, secret, .. } = self;
        aead.encrypt(secret, buffer, counter)
    }

    /// Identifier of this key.
    pub fn key_id(&self) -> KeyId {
        self.key_id
    }

    /// Cipher suite this key was created for.
    pub fn cipher_suite(&self) -> CipherSuite {
        self.cipher_suite
    }

    /// Test helper: builds a key from a test vector.
    ///
    /// CTR-mode suites go through the regular `derive_from` path. For all
    /// other suites the secret is taken from the vector via
    /// `Secret::from_test_vector` instead of being derived.
    #[cfg(all(test, crypto_backend))]
    pub(crate) fn from_test_vector(
        cipher_suite: CipherSuite,
        test_vec: &crate::test_vectors::SframeTest,
    ) -> Self
    where
        D: KeyDerivation<Secret = crate::crypto::secret::Secret>,
    {
        if cipher_suite.is_ctr_mode() {
            return Self::derive_from(cipher_suite, test_vec.key_id, &test_vec.key_material)
                .unwrap();
        }

        // The secret is computed before the AEAD, matching the original order.
        let secret = crate::crypto::secret::Secret::from_test_vector(test_vec);
        Self {
            aead: A::try_from(cipher_suite).unwrap(),
            secret,
            cipher_suite,
            key_id: test_vec.key_id,
            _derivation: PhantomData,
        }
    }
}
/// Key material used on the receiving side: an AEAD instance together with
/// the secret it decrypts with, tagged with the cipher suite and key id it was
/// created for.
///
/// `D` selects the key-derivation scheme; only its associated `Secret` type is
/// stored, so `D` itself is carried as a zero-sized `PhantomData` marker.
pub struct DecryptionKey<A: AeadDecrypt, D: KeyDerivation> {
    /// AEAD implementation built for `cipher_suite`.
    aead: A,
    /// Secret handed to `aead` on every decryption.
    secret: D::Secret,
    /// Cipher suite this key was created for.
    cipher_suite: CipherSuite,
    /// Identifier of this key.
    key_id: KeyId,
    /// Marker tying the key to its derivation scheme `D`.
    _derivation: PhantomData<D>,
}
impl<A, D> DecryptionKey<A, D>
where
    D: KeyDerivation,
    A: AeadDecrypt<Secret = D::Secret>,
{
    /// Creates a decryption key for `key_id` by building the AEAD for
    /// `cipher_suite` and expanding `key_material` through `D`.
    ///
    /// # Errors
    ///
    /// Returns an error if `A` cannot be constructed for `cipher_suite`, or if
    /// `D::expand_from` fails. The AEAD is built first, so its error is the one
    /// reported when both steps would fail.
    pub fn derive_from<K, M>(cipher_suite: CipherSuite, key_id: K, key_material: M) -> Result<Self>
    where
        K: Into<KeyId>,
        M: AsRef<[u8]>,
    {
        let id: KeyId = key_id.into();
        let aead = A::try_from(cipher_suite)?;
        let secret = D::expand_from(cipher_suite, key_material, id)?;
        Ok(Self {
            cipher_suite,
            key_id: id,
            aead,
            secret,
            _derivation: PhantomData,
        })
    }

    /// Decrypts `buffer` with this key's secret and the given `counter` by
    /// delegating to the AEAD. Nothing but `()` is returned, so any output is
    /// produced through the buffer view.
    pub fn decrypt<'a, B>(&self, buffer: B, counter: Counter) -> Result<()>
    where
        B: Into<DecryptionBufferView<'a>>,
    {
        let Self { aead, secret, .. } = self;
        aead.decrypt(secret, buffer, counter)
    }

    /// Identifier of this key.
    pub fn key_id(&self) -> KeyId {
        self.key_id
    }

    /// Cipher suite this key was created for.
    pub fn cipher_suite(&self) -> CipherSuite {
        self.cipher_suite
    }

    /// Test-only access to the stored secret.
    #[cfg(all(test, crypto_backend))]
    pub(crate) fn secret(&self) -> &D::Secret {
        &self.secret
    }

    /// Test helper: builds a key from a test vector.
    ///
    /// CTR-mode suites go through the regular `derive_from` path. For all
    /// other suites the secret is taken from the vector via
    /// `Secret::from_test_vector` instead of being derived.
    #[cfg(all(test, crypto_backend))]
    pub(crate) fn from_test_vector(
        cipher_suite: CipherSuite,
        test_vec: &crate::test_vectors::SframeTest,
    ) -> Self
    where
        D: KeyDerivation<Secret = crate::crypto::secret::Secret>,
    {
        if cipher_suite.is_ctr_mode() {
            return Self::derive_from(cipher_suite, test_vec.key_id, &test_vec.key_material)
                .unwrap();
        }

        // The secret is computed before the AEAD, matching the original order.
        let secret = crate::crypto::secret::Secret::from_test_vector(test_vec);
        Self {
            aead: A::try_from(cipher_suite).unwrap(),
            secret,
            cipher_suite,
            key_id: test_vec.key_id,
            _derivation: PhantomData,
        }
    }
}
/// Implements `Clone`, `Debug`, `PartialEq` and `Eq` for a key type `$name<A, D>`
/// whose AEAD parameter `A` is bounded by the trait `$aead`.
///
/// The impls are hand-written rather than derived because `#[derive]` would add
/// bounds on `D` itself (e.g. `D: Clone`). Only `D::Secret` is stored, so only
/// that associated type needs the corresponding bound.
macro_rules! impl_key_traits {
    ($name:ident, $aead:ident) => {
        impl<A, D> Clone for $name<A, D>
        where
            A: $aead + Clone,
            D: KeyDerivation,
            D::Secret: Clone,
        {
            fn clone(&self) -> Self {
                Self {
                    aead: self.aead.clone(),
                    secret: self.secret.clone(),
                    cipher_suite: self.cipher_suite,
                    key_id: self.key_id,
                    _derivation: PhantomData,
                }
            }
        }

        // The secret is deliberately redacted. Printing it would leak key
        // material into logs, panic messages and failed `assert_eq!` output.
        // Because the secret is never formatted, `D::Secret` needs no `Debug`
        // bound.
        impl<A, D> std::fmt::Debug for $name<A, D>
        where
            A: $aead + std::fmt::Debug,
            D: KeyDerivation,
        {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($name))
                    .field("aead", &self.aead)
                    .field("secret", &format_args!("<redacted>"))
                    .field("cipher_suite", &self.cipher_suite)
                    .field("key_id", &self.key_id)
                    .finish()
            }
        }

        // Two keys are equal when every stored field is equal. The
        // `PhantomData` marker carries no data and is ignored.
        // NOTE(review): `D::Secret` is compared with its own `==`, which is not
        // guaranteed to be constant-time. Avoid this on attacker-influenced
        // inputs.
        impl<A, D> PartialEq for $name<A, D>
        where
            A: $aead + PartialEq,
            D: KeyDerivation,
            D::Secret: PartialEq,
        {
            fn eq(&self, other: &Self) -> bool {
                self.aead == other.aead
                    && self.secret == other.secret
                    && self.cipher_suite == other.cipher_suite
                    && self.key_id == other.key_id
            }
        }

        impl<A, D> Eq for $name<A, D>
        where
            A: $aead + Eq,
            D: KeyDerivation,
            D::Secret: Eq,
        {
        }
    };
}
// Generate Clone/Debug/PartialEq/Eq for both key types. Each type is paired
// with the AEAD trait that its struct definition bounds `A` by.
impl_key_traits! { EncryptionKey, AeadEncrypt }
impl_key_traits! { DecryptionKey, AeadDecrypt }