use ring::{aead, digest, hkdf, rand::SecureRandom};
use crate::keys::{DERIVED_KEY_LEN, HKDF_INFO, HKDF_SALT_CONTEXT};
/// Size in bytes of the random nonce written at the front of every sealed payload.
const NONCE_LEN: usize = 12;
/// Size in bytes of the Poly1305 authentication tag appended after the ciphertext.
const TAG_LEN: usize = 16;
/// Shortest possible sealed payload: a nonce and a tag around an empty plaintext.
const MIN_SEALED_LEN: usize = TAG_LEN + NONCE_LEN;
/// A ChaCha20-Poly1305 key produced by [`DerivedKey::derive`] and consumed by
/// [`seal`] and [`try_decrypt`].
pub(crate) struct DerivedKey {
    /// ring AEAD key used for both sealing and opening. `LessSafeKey` takes the
    /// nonce on every call; in this module nonces are generated randomly by `seal`.
    sealing: aead::LessSafeKey,
}
impl DerivedKey {
    /// Derives a ChaCha20-Poly1305 key from input keying material using HKDF-SHA256.
    ///
    /// * Extract: the salt is the SHA-256 digest of `HKDF_SALT_CONTEXT`.
    /// * Expand: `HKDF_INFO` is the only info component, and the output length is
    ///   the ChaCha20-Poly1305 key length.
    ///
    /// Derivation is deterministic: the same `ikm` always yields the same key.
    ///
    /// # Panics
    ///
    /// Panics only if `DERIVED_KEY_LEN` disagrees with the ChaCha20-Poly1305 key
    /// length, which is a configuration bug rather than a runtime condition.
    pub(crate) fn derive(ikm: &[u8]) -> Self {
        // `Salt` copies the digest bytes into its own HMAC key, so the digest can
        // be dropped as soon as the pseudorandom key has been extracted.
        let prk = {
            let context_hash = digest::digest(&digest::SHA256, HKDF_SALT_CONTEXT);
            hkdf::Salt::new(hkdf::HKDF_SHA256, context_hash.as_ref()).extract(ikm)
        };

        let mut key_bytes = [0u8; DERIVED_KEY_LEN];
        prk.expand(&[HKDF_INFO], &aead::CHACHA20_POLY1305)
            .expect("HKDF-SHA256 expand to a static KeyType cannot fail")
            .fill(&mut key_bytes)
            .expect("OKM length matches DERIVED_KEY_LEN by construction");

        let unbound = aead::UnboundKey::new(&aead::CHACHA20_POLY1305, &key_bytes)
            .expect("raw is exactly DERIVED_KEY_LEN bytes, matching the AEAD key length");
        Self {
            sealing: aead::LessSafeKey::new(unbound),
        }
    }
}
pub(crate) fn seal(key: &DerivedKey, rng: &dyn SecureRandom, plaintext: &[u8]) -> Vec<u8> {
let mut nonce_bytes = [0u8; NONCE_LEN];
rng.fill(&mut nonce_bytes)
.expect("OS RNG must succeed; failure indicates an unrecoverable system fault");
let mut sealed_payload: Vec<u8> = plaintext.to_vec();
let nonce = aead::Nonce::assume_unique_for_key(nonce_bytes);
key.sealing
.seal_in_place_append_tag(nonce, aead::Aad::empty(), &mut sealed_payload)
.expect("ChaCha20-Poly1305 seal cannot fail for a valid key and 12-byte nonce");
let mut out = Vec::with_capacity(NONCE_LEN + sealed_payload.len());
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(&sealed_payload);
out
}
/// Authenticates and decrypts a payload produced by [`seal`], trying each key in turn.
///
/// `payload` must be laid out as `nonce || ciphertext || tag`. Keys are tried in
/// slice order, and the first one that authenticates the payload wins.
///
/// Returns `Some((plaintext, index))`, where `index` is the position in `keys` of the
/// key that authenticated. Returns `None` in any of these cases:
///
/// * `payload` is shorter than `MIN_SEALED_LEN`.
/// * `keys` is empty.
/// * No key authenticates the payload.
pub(crate) fn try_decrypt(keys: &[DerivedKey], payload: &[u8]) -> Option<(Vec<u8>, usize)> {
    if payload.len() < MIN_SEALED_LEN {
        return None;
    }
    let (nonce_slice, ct_plus_tag) = payload.split_at(NONCE_LEN);
    let nonce_array: [u8; NONCE_LEN] = nonce_slice
        .try_into()
        .expect("split_at(NONCE_LEN) yields a slice of exactly NONCE_LEN bytes");

    // Reuse one scratch buffer across all keys instead of allocating one per key.
    // A failed `open_in_place` leaves the buffer contents unspecified, so it is
    // cleared and refilled with the original ciphertext+tag before every attempt.
    let mut buf: Vec<u8> = Vec::with_capacity(ct_plus_tag.len());
    for (idx, key) in keys.iter().enumerate() {
        buf.clear();
        buf.extend_from_slice(ct_plus_tag);
        // `Nonce` is consumed by each open call, so rebuild it per attempt.
        let nonce = aead::Nonce::assume_unique_for_key(nonce_array);
        if let Ok(plain) = key
            .sealing
            .open_in_place(nonce, aead::Aad::empty(), &mut buf)
        {
            // Drop the trailing tag bytes so only the plaintext remains.
            let plain_len = plain.len();
            buf.truncate(plain_len);
            return Some((buf, idx));
        }
    }
    None
}
#[cfg(test)]
mod tests {
    //! Unit tests for HKDF key derivation, sealing, and multi-key decryption.
    //!
    //! The `acN_M` suffixes on test names look like references to acceptance
    //! criteria tracked elsewhere. Presumably they should stay stable; confirm
    //! before renaming.

    use super::*;
    use ring::rand::SystemRandom;
    use std::collections::HashSet;

    // Fixed 32-byte input keying materials. Distinct values derive unrelated keys.
    const IKM_PRIMARY: &[u8; 32] = b"primary-ikm-fixed-bytes-32-len!!";
    const IKM_OTHER_1: &[u8; 32] = b"o1-fallback-ikm-fixed-bytes-32!!";
    const IKM_OTHER_2: &[u8; 32] = b"o2-fallback-ikm-fixed-bytes-32!!";
    const IKM_UNKNOWN: &[u8; 32] = b"x-unknown-ikm-not-in-keylist-32!";

    /// Shorthand for deriving a key from fixed test IKM.
    fn key_from(ikm: &[u8]) -> DerivedKey {
        DerivedKey::derive(ikm)
    }

    /// Flipping any single bit anywhere (nonce, ciphertext, or tag) must break
    /// authentication.
    #[test]
    fn try_decrypt_rejects_every_single_bit_flip_ac1_4() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        let plaintext: Vec<u8> = (0..64u8).collect();
        let sealed = seal(&key, &rng, &plaintext);
        let keys = [key_from(IKM_PRIMARY)];
        for byte_index in 0..sealed.len() {
            for bit in 0..8u8 {
                let mut tampered = sealed.clone();
                tampered[byte_index] ^= 1 << bit;
                let result = try_decrypt(&keys, &tampered);
                assert!(
                    result.is_none(),
                    "bit flip at byte {byte_index} bit {bit} unexpectedly authenticated"
                );
            }
        }
    }

    /// Tampering with the ciphertext bytes covering plaintext offsets 1..9 must be
    /// rejected. Offsets 1..9 hold an i64 here, presumably an `issued_at` timestamp
    /// given the test name.
    #[test]
    fn try_decrypt_rejects_tamper_in_issued_at_region_ac2_3() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        let mut plaintext = Vec::new();
        plaintext.push(1u8);
        plaintext.extend_from_slice(&42i64.to_le_bytes());
        plaintext.extend_from_slice(b"trailing-payload-bytes");
        let sealed = seal(&key, &rng, &plaintext);
        for offset in 1..9 {
            let mut tampered = sealed.clone();
            // Ciphertext byte i sits right after the nonce prefix.
            tampered[NONCE_LEN + offset] ^= 0x01;
            let result = try_decrypt(&[key_from(IKM_PRIMARY)], &tampered);
            assert!(
                result.is_none(),
                "tamper inside issued_at byte {offset} unexpectedly authenticated"
            );
        }
    }

    #[test]
    fn try_decrypt_with_primary_returns_index_zero_ac4_1() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        let plaintext = b"hello, primary";
        let sealed = seal(&key, &rng, plaintext);
        let result = try_decrypt(&[key_from(IKM_PRIMARY)], &sealed);
        let (decrypted, idx) = result.expect("primary key must authenticate its own seal");
        assert_eq!(decrypted, plaintext);
        assert_eq!(idx, 0);
    }

    #[test]
    fn try_decrypt_falls_back_to_index_one_ac4_2() {
        let rng = SystemRandom::new();
        let old_key = key_from(IKM_OTHER_1);
        let plaintext = b"hello, fallback";
        let sealed = seal(&old_key, &rng, plaintext);
        let keys = [key_from(IKM_PRIMARY), key_from(IKM_OTHER_1)];
        let result = try_decrypt(&keys, &sealed);
        let (decrypted, idx) = result.expect("fallback key must authenticate its own seal");
        assert_eq!(decrypted, plaintext);
        assert_eq!(idx, 1);
    }

    #[test]
    fn try_decrypt_returns_none_for_unknown_key_ac4_5() {
        let rng = SystemRandom::new();
        let unknown_key = key_from(IKM_UNKNOWN);
        let plaintext = b"sealed under an unknown key";
        let sealed = seal(&unknown_key, &rng, plaintext);
        let keys = [key_from(IKM_PRIMARY), key_from(IKM_OTHER_1)];
        let result = try_decrypt(&keys, &sealed);
        assert!(result.is_none());
    }

    #[test]
    fn try_decrypt_matches_third_key_ac4_6() {
        let rng = SystemRandom::new();
        let o2_key = key_from(IKM_OTHER_2);
        let plaintext = b"sealed under second fallback";
        let sealed = seal(&o2_key, &rng, plaintext);
        let keys = [
            key_from(IKM_PRIMARY),
            key_from(IKM_OTHER_1),
            key_from(IKM_OTHER_2),
        ];
        let result = try_decrypt(&keys, &sealed);
        let (decrypted, idx) = result.expect("third key must authenticate its own seal");
        assert_eq!(decrypted, plaintext);
        assert_eq!(idx, 2);
    }

    #[test]
    fn try_decrypt_rejects_inputs_under_min_len_ac6_2() {
        let keys = [key_from(IKM_PRIMARY)];
        // Strictly shorter than nonce + tag: rejected by the length guard.
        assert!(try_decrypt(&keys, &[]).is_none());
        assert!(try_decrypt(&keys, &[0u8; MIN_SEALED_LEN - 1]).is_none());
        // Exactly nonce + tag is a valid length (it is what an empty plaintext seals
        // to), so this all-zero input passes the length guard and is rejected by
        // authentication instead.
        assert!(try_decrypt(&keys, &[0u8; MIN_SEALED_LEN]).is_none());
    }

    #[test]
    fn try_decrypt_with_no_keys_returns_none() {
        let rng = SystemRandom::new();
        let sealed = seal(&key_from(IKM_PRIMARY), &rng, b"no keys to try");
        let no_keys: [DerivedKey; 0] = [];
        assert!(try_decrypt(&no_keys, &sealed).is_none());
    }

    /// Dropping or appending a byte moves where the tag is read from, so
    /// authentication must fail.
    #[test]
    fn try_decrypt_rejects_truncated_and_extended_payloads() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        let sealed = seal(&key, &rng, b"length must be exact");
        let keys = [key];
        assert!(try_decrypt(&keys, &sealed[..sealed.len() - 1]).is_none());
        let mut extended = sealed.clone();
        extended.push(0);
        assert!(try_decrypt(&keys, &extended).is_none());
        assert!(try_decrypt(&keys, &sealed).is_some());
    }

    #[test]
    fn seal_output_length_is_plaintext_plus_28() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        for &size in &[0usize, 1, 100, 3000] {
            let plaintext = vec![0xABu8; size];
            let sealed = seal(&key, &rng, &plaintext);
            assert_eq!(
                sealed.len(),
                size + NONCE_LEN + TAG_LEN,
                "seal output length wrong for plaintext size {size}"
            );
        }
    }

    #[test]
    fn seal_then_decrypt_round_trip() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        let keys = [key_from(IKM_PRIMARY)];
        for size in [0usize, 1, 16, 100, 1024, 3000] {
            let plaintext: Vec<u8> = (0..size).map(|i| (i % 251) as u8).collect();
            let sealed = seal(&key, &rng, &plaintext);
            let result = try_decrypt(&keys, &sealed);
            let (decrypted, idx) = result.expect("seal -> decrypt must round-trip");
            assert_eq!(decrypted, plaintext, "round-trip mismatch at size {size}");
            assert_eq!(idx, 0);
        }
    }

    /// Sanity check that nonces come from the RNG rather than a constant. This does
    /// not prove uniqueness in general.
    #[test]
    fn seal_produces_distinct_nonces() {
        let rng = SystemRandom::new();
        let key = key_from(IKM_PRIMARY);
        let plaintext = b"same plaintext every time";
        let mut nonces: HashSet<[u8; NONCE_LEN]> = HashSet::with_capacity(100);
        for _ in 0..100 {
            let sealed = seal(&key, &rng, plaintext);
            let nonce: [u8; NONCE_LEN] = sealed[..NONCE_LEN]
                .try_into()
                .expect("seal output starts with NONCE_LEN bytes");
            assert!(
                nonces.insert(nonce),
                "duplicate nonce produced across 100 seals"
            );
        }
        assert_eq!(nonces.len(), 100);
    }

    #[test]
    fn derive_is_deterministic_across_invocations() {
        let rng = SystemRandom::new();
        let plaintext = b"determinism probe";
        let key1 = key_from(IKM_PRIMARY);
        let sealed = seal(&key1, &rng, plaintext);
        let key2 = key_from(IKM_PRIMARY);
        let result = try_decrypt(&[key2], &sealed);
        let (decrypted, idx) = result.expect("re-derived key must authenticate the same seal");
        assert_eq!(decrypted, plaintext);
        assert_eq!(idx, 0);
    }

    #[test]
    fn derive_domain_separates_via_info_string() {
        let rng = SystemRandom::new();
        let production_key = key_from(IKM_PRIMARY);
        let plaintext = b"sealed under production info";
        let sealed = seal(&production_key, &rng, plaintext);
        let alt_info: &[u8] = b"alt-info-string-not-the-one-used";
        let alt_key = derive_with_alt_info(IKM_PRIMARY, alt_info);
        let result = try_decrypt(&[alt_key], &sealed);
        assert!(
            result.is_none(),
            "key derived with a different info must not authenticate"
        );
        let ok = try_decrypt(&[key_from(IKM_PRIMARY)], &sealed);
        assert!(ok.is_some());
    }

    /// Runs the same HKDF pipeline as `DerivedKey::derive` but with a caller-chosen
    /// `info`. Used to check that the info string domain-separates keys.
    fn derive_with_alt_info(ikm: &[u8], info: &[u8]) -> DerivedKey {
        let salt_digest = digest::digest(&digest::SHA256, HKDF_SALT_CONTEXT);
        let salt = hkdf::Salt::new(hkdf::HKDF_SHA256, salt_digest.as_ref());
        let prk = salt.extract(ikm);
        let info_parts = [info];
        let okm = prk.expand(&info_parts, &aead::CHACHA20_POLY1305).unwrap();
        let mut raw = [0u8; DERIVED_KEY_LEN];
        okm.fill(&mut raw).unwrap();
        let unbound = aead::UnboundKey::new(&aead::CHACHA20_POLY1305, &raw).unwrap();
        DerivedKey {
            sealing: aead::LessSafeKey::new(unbound),
        }
    }
}