use super::Kdf;
use crate::{
crypto::{
cipher_suite::CipherSuite,
common::key_derivation::expand_subsecret,
key_derivation::{
KeyDerivation, Ratcheting, get_hkdf_key_expand_label, get_hkdf_ratchet_expand_label,
get_hkdf_salt_expand_label,
},
secret::Secret,
},
error::{Result, SframeError},
header::KeyId,
};
impl KeyDerivation for Kdf {
type Secret = Secret;
fn expand_from<M, K>(cipher_suite: CipherSuite, key_material: M, key_id: K) -> Result<Secret>
where
M: AsRef<[u8]>,
K: Into<KeyId>,
{
let key_id = key_id.into();
let try_expand = || {
let (base_key, salt) = expand_secret(cipher_suite, key_material.as_ref(), key_id)?;
let secret = if cipher_suite.is_ctr_mode() {
let (key, auth) = expand_subsecret(cipher_suite, &base_key);
Secret::aes_ctr(key, salt, auth)
} else {
Secret::aead(base_key, salt)
};
Ok(secret)
};
try_expand().map_err(|err: openssl::error::ErrorStack| {
log::debug!("Key derivation failed, OpenSSL error stack: {err}");
SframeError::KeyDerivationFailure
})
}
}
impl Ratcheting for Kdf {
fn ratchet<M>(cipher_suite: CipherSuite, base_key: M) -> Result<Vec<u8>>
where
M: AsRef<[u8]>,
{
let prk = extract_pseudo_random_key(cipher_suite, base_key.as_ref(), b"")?;
expand_key(
cipher_suite,
&prk,
get_hkdf_ratchet_expand_label(),
cipher_suite.nonce_len(),
)
.map_err(|_: openssl::error::ErrorStack| SframeError::RatchetingFailure)
}
}
/// Derives the raw SFrame base key and salt for `key_id` from `key_material`.
///
/// Performs one HKDF-Extract (empty salt) followed by two HKDF-Expand calls:
/// - the key, using the key label and `cipher_suite.key_len()` bytes;
/// - the salt, using the salt label and `cipher_suite.nonce_len()` bytes.
///
/// Returns `(key, salt)` on success, or the OpenSSL error stack of the first
/// failing step.
fn expand_secret(
    cipher_suite: CipherSuite,
    key_material: &[u8],
    key_id: u64,
) -> std::result::Result<(Vec<u8>, Vec<u8>), openssl::error::ErrorStack> {
    let prk = extract_pseudo_random_key(cipher_suite, key_material, b"")?;

    let key_label = get_hkdf_key_expand_label(key_id, cipher_suite);
    let key = expand_key(cipher_suite, &prk, &key_label, cipher_suite.key_len())?;

    let salt_label = get_hkdf_salt_expand_label(key_id, cipher_suite);
    let salt = expand_key(cipher_suite, &prk, &salt_label, cipher_suite.nonce_len())?;

    Ok((key, salt))
}
/// Runs HKDF-Extract over `key_material` with the given `salt`.
///
/// The hash function is chosen from `cipher_suite` (see `init_openssl_ctx`).
/// The output buffer is sized by OpenSSL through `derive_to_vec`.
fn extract_pseudo_random_key(
    cipher_suite: CipherSuite,
    key_material: &[u8],
    salt: &[u8],
) -> std::result::Result<Vec<u8>, openssl::error::ErrorStack> {
    use openssl::pkey_ctx::HkdfMode;

    let mut hkdf = init_openssl_ctx(cipher_suite)?;
    hkdf.set_hkdf_mode(HkdfMode::EXTRACT_ONLY)?;
    hkdf.set_hkdf_salt(salt)?;
    hkdf.set_hkdf_key(key_material)?;

    let mut pseudo_random_key = Vec::new();
    hkdf.derive_to_vec(&mut pseudo_random_key)
        .map(|_| pseudo_random_key)
}
/// Runs HKDF-Expand on `prk` with the context string `info`.
///
/// Writes into a zero-initialised buffer of exactly `key_len` bytes and
/// returns that buffer. The hash function is chosen from `cipher_suite`.
fn expand_key(
    cipher_suite: CipherSuite,
    prk: &[u8],
    info: &[u8],
    key_len: usize,
) -> std::result::Result<Vec<u8>, openssl::error::ErrorStack> {
    use openssl::pkey_ctx::HkdfMode;

    let mut hkdf = init_openssl_ctx(cipher_suite)?;
    hkdf.set_hkdf_mode(HkdfMode::EXPAND_ONLY)?;
    hkdf.set_hkdf_key(prk)?;
    hkdf.add_hkdf_info(info)?;

    let mut output = vec![0u8; key_len];
    hkdf.derive(Some(output.as_mut_slice()))?;
    Ok(output)
}
/// Creates an OpenSSL HKDF key-derivation context.
///
/// The context is initialised for deriving and bound to the digest that
/// matches `cipher_suite`. Callers still have to set the HKDF mode, key, and
/// salt/info.
fn init_openssl_ctx(
    cipher_suite: CipherSuite,
) -> std::result::Result<openssl::pkey_ctx::PkeyCtx<()>, openssl::error::ErrorStack> {
    use openssl::{md::MdRef, pkey::Id, pkey_ctx::PkeyCtx};

    let mut hkdf = PkeyCtx::new_id(Id::HKDF)?;
    hkdf.derive_init()?;

    let hash: &'static MdRef = cipher_suite.into();
    hkdf.set_hkdf_md(hash)?;

    Ok(hkdf)
}
impl From<CipherSuite> for &'static openssl::md::MdRef {
    /// Maps a cipher suite to the OpenSSL digest used for its HKDF.
    ///
    /// `AesGcm256Sha512` uses SHA-512. The AES-GCM-128 suite and the three
    /// AES-CTR/HMAC-SHA256 suites use SHA-256.
    fn from(cipher_suite: CipherSuite) -> Self {
        use openssl::md::Md;

        match cipher_suite {
            CipherSuite::AesGcm256Sha512 => Md::sha512(),
            CipherSuite::AesGcm128Sha256
            | CipherSuite::AesCtr128HmacSha256_80
            | CipherSuite::AesCtr128HmacSha256_64
            | CipherSuite::AesCtr128HmacSha256_32 => Md::sha256(),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{test_vectors::get_aes_ctr_test_vector, util::test::assert_bytes_eq};
    use test_case::test_case;

    // Splitting the AES-CTR base key from the test vector must reproduce the
    // vector's encryption key and authentication key exactly.
    #[test_case(CipherSuite::AesCtr128HmacSha256_80; "AesCtr128HmacSha256_80")]
    #[test_case(CipherSuite::AesCtr128HmacSha256_64; "AesCtr128HmacSha256_64")]
    #[test_case(CipherSuite::AesCtr128HmacSha256_32; "AesCtr128HmacSha256_32")]
    fn derive_correct_sub_keys(cipher_suite: CipherSuite) {
        let vector = get_aes_ctr_test_vector(&cipher_suite.to_string());

        let (enc_key, auth_key) = expand_subsecret(cipher_suite, &vector.base_key);

        assert_bytes_eq(&enc_key, &vector.enc_key);
        assert_bytes_eq(&auth_key, &vector.auth_key);
    }
}