extern crate alloc;
use alloc::vec::Vec;
use ring::hmac;
/// Label mixed into the HMAC input by [`derive_session_key`], so that session keys
/// are domain-separated from every other key derived from the same master material.
pub const SESSION_KEY_TAG: &[u8] = b"SessionKey";
/// Label mixed into the HMAC input by [`derive_session_hmac_key`], keeping the
/// per-receiver session HMAC key distinct from the session key.
pub const SESSION_RECEIVER_KEY_TAG: &[u8] = b"SessionReceiverKey";
/// Length in bytes of the fixed header that [`compute_aad`] writes before the
/// extension: three 4-byte fields (transformation kind, transformation key id,
/// session id) followed by 4 bytes of zero padding.
pub const AAD_HEADER_LEN: usize = 3 * 4 + 4;
/// Derives the 32-byte session key for `session_id` from the sender's master key.
///
/// The key is an HMAC-SHA256 output keyed by `master_key` whose input covers
/// `master_salt`, the [`SESSION_KEY_TAG`] label and `session_id`. The derivation is
/// deterministic: identical inputs always produce the identical key.
#[must_use]
pub fn derive_session_key(
    master_key: &[u8],
    master_salt: &[u8],
    session_id: &[u8; 4],
) -> [u8; 32] {
    let label = SESSION_KEY_TAG;
    derive_with_tag(master_key, master_salt, label, session_id)
}
/// Derives the 32-byte per-receiver session HMAC key for `session_id`.
///
/// It is built the same way as [`derive_session_key`], with two differences:
/// - the HMAC is keyed by `master_receiver_specific_key`;
/// - the label is [`SESSION_RECEIVER_KEY_TAG`].
///
/// Because the label differs, this key differs from the session key even when both
/// are derived from the same master key, salt and session id.
#[must_use]
pub fn derive_session_hmac_key(
    master_receiver_specific_key: &[u8],
    master_salt: &[u8],
    session_id: &[u8; 4],
) -> [u8; 32] {
    let (hmac_master, label) = (master_receiver_specific_key, SESSION_RECEIVER_KEY_TAG);
    derive_with_tag(hmac_master, master_salt, label, session_id)
}
/// Computes `HMAC-SHA256(key, tag || master_salt || session_id)` and returns the
/// 32-byte digest.
///
/// `tag` is a fixed ASCII label, such as [`SESSION_KEY_TAG`] or
/// [`SESSION_RECEIVER_KEY_TAG`]. It separates the different keys derived from the
/// same master material.
///
/// The label is hashed *before* the salt. This matches the DDS-Security key
/// derivation `SessionKey = HMAC256(MasterSenderKey, "SessionKey" | MasterSalt |
/// SessionId)` and its `"SessionReceiverKey"` analogue. Hashing the salt first yields
/// keys that other implementations will not reproduce.
/// NOTE(review): confirm against the DDS-Security revision this crate targets.
fn derive_with_tag(key: &[u8], master_salt: &[u8], tag: &[u8], session_id: &[u8; 4]) -> [u8; 32] {
    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
    let mut ctx = hmac::Context::with_key(&hmac_key);
    ctx.update(tag);
    ctx.update(master_salt);
    ctx.update(session_id);
    let digest = ctx.sign();
    let mut out = [0u8; 32];
    // HMAC-SHA256 always produces exactly 32 bytes, so the lengths match.
    out.copy_from_slice(digest.as_ref());
    out
}
/// Builds the additional authenticated data (AAD) from the transform identifiers and
/// an optional extension.
///
/// The output is [`AAD_HEADER_LEN`] header bytes followed by `extension` verbatim:
///
/// | bytes  | content                 |
/// |--------|-------------------------|
/// | 0..4   | `transformation_kind`   |
/// | 4..8   | `transformation_key_id` |
/// | 8..12  | `session_id`            |
/// | 12..16 | zero padding            |
/// | 16..   | `extension`             |
#[must_use]
pub fn compute_aad(
    transformation_kind: [u8; 4],
    transformation_key_id: [u8; 4],
    session_id: [u8; 4],
    extension: &[u8],
) -> Vec<u8> {
    const PADDING: [u8; 4] = [0; 4];
    let header_words = [transformation_kind, transformation_key_id, session_id, PADDING];

    let mut aad = Vec::with_capacity(AAD_HEADER_LEN + extension.len());
    for word in &header_words {
        aad.extend_from_slice(word);
    }
    aad.extend_from_slice(extension);
    aad
}
#[cfg(test)]
// Tests may use unwrap/expect even where the library code's lints forbid them.
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn session_key_is_deterministic() {
        let mk = [0xAA; 32];
        let salt = [0xBB; 32];
        let sid = [0x01, 0x02, 0x03, 0x04];
        let k1 = derive_session_key(&mk, &salt, &sid);
        let k2 = derive_session_key(&mk, &salt, &sid);
        assert_eq!(k1, k2);
    }

    #[test]
    fn session_key_changes_with_session_id() {
        let mk = [0xAA; 32];
        let salt = [0xBB; 32];
        let k1 = derive_session_key(&mk, &salt, &[1, 2, 3, 4]);
        let k2 = derive_session_key(&mk, &salt, &[1, 2, 3, 5]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn session_key_changes_with_master_salt() {
        let mk = [0xAA; 32];
        let sid = [0x01, 0x02, 0x03, 0x04];
        let k1 = derive_session_key(&mk, &[0xBB; 32], &sid);
        let k2 = derive_session_key(&mk, &[0xCC; 32], &sid);
        assert_ne!(k1, k2);
    }

    #[test]
    fn session_key_changes_with_master_key() {
        let salt = [0xBB; 32];
        let sid = [0x01, 0x02, 0x03, 0x04];
        let k1 = derive_session_key(&[0xAA; 32], &salt, &sid);
        let k2 = derive_session_key(&[0xCC; 32], &salt, &sid);
        assert_ne!(k1, k2);
    }

    #[test]
    fn session_hmac_key_is_deterministic_and_changes_with_session_id() {
        let mk = [0xAA; 32];
        let salt = [0xBB; 32];
        let k1 = derive_session_hmac_key(&mk, &salt, &[1, 2, 3, 4]);
        let k2 = derive_session_hmac_key(&mk, &salt, &[1, 2, 3, 4]);
        let k3 = derive_session_hmac_key(&mk, &salt, &[1, 2, 3, 5]);
        assert_eq!(k1, k2);
        assert_ne!(k1, k3);
    }

    #[test]
    fn sender_key_and_receiver_key_use_different_tags() {
        let mk = [0xAA; 32];
        let salt = [0xBB; 32];
        let sid = [0x01, 0x02, 0x03, 0x04];
        let sender = derive_session_key(&mk, &salt, &sid);
        let receiver = derive_session_hmac_key(&mk, &salt, &sid);
        assert_ne!(sender, receiver);
    }

    /// A 16-byte (AES-128) master key is accepted. The 16-byte prefix used as the
    /// AES-128-GCM session key is stable for equal inputs and still varies with the
    /// session id.
    #[test]
    fn session_key_aes128_gcm_truncated_to_16_byte() {
        let mk = [0xAA; 16];
        let salt = [0xBB; 32];
        let first = derive_session_key(&mk, &salt, &[0; 4]);
        let again = derive_session_key(&mk, &salt, &[0; 4]);
        let other_session = derive_session_key(&mk, &salt, &[0, 0, 0, 1]);
        assert_eq!(&first[..16], &again[..16]);
        assert_ne!(&first[..16], &other_session[..16]);
    }

    #[test]
    fn aad_layout_is_16_byte_with_padding() {
        let aad = compute_aad([0, 0, 0, 0x02], [0, 0, 0, 0x07], [0, 0, 0, 0x42], &[]);
        assert_eq!(aad.len(), AAD_HEADER_LEN);
        assert_eq!(&aad[0..4], &[0, 0, 0, 0x02]);
        assert_eq!(&aad[4..8], &[0, 0, 0, 0x07]);
        assert_eq!(&aad[8..12], &[0, 0, 0, 0x42]);
        assert_eq!(&aad[12..16], &[0, 0, 0, 0]);
    }

    #[test]
    fn aad_with_extension_appends_after_header() {
        let ext = b"rtps-header-bytes";
        let aad = compute_aad([0; 4], [0; 4], [0; 4], ext);
        assert_eq!(aad.len(), AAD_HEADER_LEN + ext.len());
        assert_eq!(&aad[AAD_HEADER_LEN..], ext);
    }

    #[test]
    fn aad_distinct_for_different_session_ids() {
        let a = compute_aad([0; 4], [0; 4], [1, 2, 3, 4], &[]);
        let b = compute_aad([0; 4], [0; 4], [1, 2, 3, 5], &[]);
        assert_ne!(a, b);
    }

    #[test]
    fn aad_distinct_for_different_kinds() {
        let a = compute_aad([0, 0, 0, 0x02], [0; 4], [0; 4], &[]);
        let b = compute_aad([0, 0, 0, 0x04], [0; 4], [0; 4], &[]);
        assert_ne!(a, b);
    }

    #[test]
    fn aad_distinct_for_different_key_ids() {
        let a = compute_aad([0; 4], [0, 0, 0, 0x01], [0; 4], &[]);
        let b = compute_aad([0; 4], [0, 0, 0, 0x02], [0; 4], &[]);
        assert_ne!(a, b);
    }

    /// RFC 4231 test case 1 (key = 0x0b * 20, data = "Hi There").
    ///
    /// The tag is empty, so the HMAC input is `"Hi T" + "here"` = "Hi There"
    /// regardless of the order in which `derive_with_tag` feeds its pieces.
    #[test]
    fn rfc_4231_hmac_sha256_known_vector_via_derive() {
        let key = [0x0b; 20];
        let master_salt = b"Hi T";
        let tag: &[u8] = b"";
        let sid = [b'h', b'e', b'r', b'e'];
        let out = derive_with_tag(&key, master_salt, tag, &sid);
        let expected: [u8; 32] = [
            0xb0, 0x34, 0x4c, 0x61, 0xd8, 0xdb, 0x38, 0x53, 0x5c, 0xa8, 0xaf, 0xce, 0xaf, 0x0b,
            0xf1, 0x2b, 0x88, 0x1d, 0xc2, 0x00, 0xc9, 0x83, 0x3d, 0xa7, 0x26, 0xe9, 0x37, 0x6c,
            0x2e, 0x32, 0xcf, 0xf7,
        ];
        assert_eq!(out, expected);
    }

    /// RFC 4231 test case 2 (key = "Jefe", data = "what do ya want for nothing?").
    ///
    /// This case checks a short key that is not block-aligned. The tag is empty, so
    /// the result is again independent of the concatenation order.
    #[test]
    fn rfc_4231_hmac_sha256_case_2_via_derive() {
        let key = b"Jefe";
        let master_salt = b"what do ya want for noth";
        let tag: &[u8] = b"";
        let sid = *b"ing?";
        let out = derive_with_tag(key, master_salt, tag, &sid);
        let expected: [u8; 32] = [
            0x5b, 0xdc, 0xc1, 0x46, 0xbf, 0x60, 0x75, 0x4e, 0x6a, 0x04, 0x24, 0x26, 0x08, 0x95,
            0x75, 0xc7, 0x5a, 0x00, 0x3f, 0x08, 0x9d, 0x27, 0x39, 0x83, 0x9d, 0xec, 0x58, 0xb9,
            0x64, 0xec, 0x38, 0x43,
        ];
        assert_eq!(out, expected);
    }
}