use crate::error::BuildError;
/// Minimum number of bytes accepted for any key: `validate_ikm` rejects shorter non-empty input with `BuildError::ShortKey`.
pub(crate) const MIN_KEY_BYTES: usize = 16;
/// NOTE(review): presumably the byte length of each key derived from the input keying material; not referenced in this file — confirm against the derivation code.
pub(crate) const DERIVED_KEY_LEN: usize = 32;
/// NOTE(review): presumably the context bytes from which the HKDF salt is built; the derivation itself is not in this file — confirm.
pub(crate) const HKDF_SALT_CONTEXT: &[u8] = b"seshcookie-rs-v1-salt";
/// NOTE(review): presumably the HKDF `info` parameter; the derivation itself is not in this file — confirm. Changing these bytes would likely change every derived key.
pub(crate) const HKDF_INFO: &[u8] = b"seshcookie-rs v1 ChaCha20-Poly1305 session";
/// Secret key material for a session: one primary key plus zero or more fallback keys.
///
/// Values are built with `SessionKeys::new` and extended with `with_fallback` /
/// `with_fallbacks`; each of those validates the supplied bytes before storing a copy.
/// NOTE(review): fallbacks are presumably older keys kept for rotation — how they are
/// consumed is outside this file; confirm.
#[derive(Clone)]
pub struct SessionKeys {
/// Owned copy of the primary key bytes; always the first item yielded by `ikm_in_order`.
primary: Vec<u8>,
/// Owned copies of the fallback key bytes, kept in insertion order.
fallbacks: Vec<Vec<u8>>,
}
impl SessionKeys {
    /// Creates a key set whose only entry is `primary`.
    ///
    /// The bytes are validated and then copied into the returned value.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::EmptyKey`] when `primary` is empty and
    /// [`BuildError::ShortKey`] when it holds fewer than `MIN_KEY_BYTES` bytes.
    pub fn new(primary: &[u8]) -> Result<Self, BuildError> {
        validate_ikm(primary)?;
        Ok(Self {
            primary: primary.to_vec(),
            fallbacks: Vec::new(),
        })
    }

    /// Appends a single fallback key after any fallbacks already present.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::EmptyKey`] when `fallback` is empty and
    /// [`BuildError::ShortKey`] when it holds fewer than `MIN_KEY_BYTES` bytes.
    /// `self` is consumed either way, so on error the key set is dropped.
    pub fn with_fallback(mut self, fallback: &[u8]) -> Result<Self, BuildError> {
        validate_ikm(fallback)?;
        self.fallbacks.push(fallback.to_vec());
        Ok(self)
    }

    /// Appends every key produced by `iter`, preserving iteration order.
    ///
    /// Accepts anything that yields byte-slice-like items (`&[u8]`, `Vec<u8>`,
    /// `[u8; N]`, ...). Each entry is validated *before* it is copied, and
    /// processing stops at the first invalid entry, so rejected input is never
    /// duplicated on the heap and later items of `iter` are not pulled.
    ///
    /// # Errors
    ///
    /// Returns the error for the first invalid entry in iteration order:
    /// [`BuildError::EmptyKey`] for an empty entry, [`BuildError::ShortKey`]
    /// for one holding fewer than `MIN_KEY_BYTES` bytes. `self` is consumed
    /// either way, so on error the key set is dropped.
    pub fn with_fallbacks<I, B>(mut self, iter: I) -> Result<Self, BuildError>
    where
        I: IntoIterator<Item = B>,
        B: AsRef<[u8]>,
    {
        let entries = iter.into_iter();
        // Reserve from the lower bound only; the upper bound may be absent or loose.
        self.fallbacks.reserve(entries.size_hint().0);
        for entry in entries {
            let ikm = entry.as_ref();
            validate_ikm(ikm)?;
            self.fallbacks.push(ikm.to_vec());
        }
        Ok(self)
    }

    /// Returns the number of keys held: the primary plus all fallbacks.
    // Within this file the method is only called from the unit tests, so a
    // non-test build would otherwise report it as dead code.
    #[allow(dead_code)]
    pub(crate) fn total_keys(&self) -> usize {
        1 + self.fallbacks.len()
    }

    /// Iterates over the raw key bytes: the primary first, then each fallback
    /// in the order it was added.
    pub(crate) fn ikm_in_order(&self) -> impl Iterator<Item = &[u8]> {
        let fallbacks = self.fallbacks.iter().map(Vec::as_slice);
        std::iter::once(self.primary.as_slice()).chain(fallbacks)
    }
}
/// Checks that `bytes` is long enough to be used as input keying material.
///
/// # Errors
///
/// * `BuildError::EmptyKey` when `bytes` has length zero.
/// * `BuildError::ShortKey { len }` when `bytes` is non-empty but holds fewer
///   than `MIN_KEY_BYTES` bytes; `len` is the actual length.
fn validate_ikm(bytes: &[u8]) -> Result<(), BuildError> {
    match bytes.len() {
        // Emptiness is reported separately from "too short", so test it first.
        0 => Err(BuildError::EmptyKey),
        len if len < MIN_KEY_BYTES => Err(BuildError::ShortKey { len }),
        _ => Ok(()),
    }
}
// Unit tests for `SessionKeys` construction, ordering, and the length checks
// performed by `validate_ikm`.
// NOTE(review): the `ac7_N` suffixes in test names presumably reference
// acceptance criteria tracked elsewhere — confirm before renaming any test.
#[cfg(test)]
mod tests {
use super::*;
// An empty primary key is rejected with `EmptyKey`.
#[test]
fn new_rejects_empty_primary_with_empty_key_ac7_3() {
let result = SessionKeys::new(&[]);
assert_eq!(result.err(), Some(BuildError::EmptyKey));
}
// An empty single fallback is rejected with `EmptyKey`.
#[test]
fn with_fallback_rejects_empty_with_empty_key_ac7_3() {
let keys = SessionKeys::new(&[0u8; 16]).expect("16-byte primary is valid");
let result = keys.with_fallback(&[]);
assert_eq!(result.err(), Some(BuildError::EmptyKey));
}
// An empty entry inside a batch is rejected even when an earlier entry is valid.
#[test]
fn with_fallbacks_rejects_empty_entry_with_empty_key_ac7_3() {
let keys = SessionKeys::new(&[0u8; 16]).expect("16-byte primary is valid");
let result = keys.with_fallbacks([&[0u8; 16][..], &[][..]]);
assert_eq!(result.err(), Some(BuildError::EmptyKey));
}
// One byte below the minimum is rejected, and the error carries the actual length.
#[test]
fn new_rejects_15_byte_primary_with_short_key_ac7_4() {
let result = SessionKeys::new(&[0u8; 15]);
assert_eq!(result.err(), Some(BuildError::ShortKey { len: 15 }));
}
// A short single fallback is rejected with its length reported.
#[test]
fn with_fallback_rejects_7_byte_with_short_key_ac7_4() {
let keys = SessionKeys::new(&[0u8; 16]).expect("16-byte primary is valid");
let result = keys.with_fallback(&[0u8; 7]);
assert_eq!(result.err(), Some(BuildError::ShortKey { len: 7 }));
}
// A short entry inside a batch is reported with that entry's length.
#[test]
fn with_fallbacks_reports_short_entry_ac7_4() {
let keys = SessionKeys::new(&[0u8; 16]).expect("16-byte primary is valid");
let result = keys.with_fallbacks([&[0u8; 16][..], &[0u8; 5][..]]);
assert_eq!(result.err(), Some(BuildError::ShortKey { len: 5 }));
}
// Exactly the minimum length is accepted; a fresh key set holds one key.
#[test]
fn new_accepts_16_byte_primary_ac7_5() {
let keys = SessionKeys::new(&[0u8; 16]).expect("16-byte primary is valid");
assert_eq!(keys.total_keys(), 1);
}
// A minimum-length fallback is accepted and is yielded after the primary.
#[test]
fn with_fallback_accepts_16_byte_and_orders_correctly_ac7_5() {
let primary = [0xAAu8; 16];
let fallback = [0xBBu8; 16];
let keys = SessionKeys::new(&primary)
.expect("primary is valid")
.with_fallback(&fallback)
.expect("fallback is valid");
assert_eq!(keys.total_keys(), 2);
let collected: Vec<&[u8]> = keys.ikm_in_order().collect();
assert_eq!(collected.len(), 2);
assert_eq!(collected[0], &primary[..]);
assert_eq!(collected[1], &fallback[..]);
}
// `with_fallbacks` is generic over the item type: borrowed slices work.
#[test]
fn with_fallbacks_accepts_slice_refs() {
let primary = [0u8; 16];
let f1 = [1u8; 16];
let f2 = [2u8; 16];
let keys = SessionKeys::new(&primary)
.expect("primary is valid")
.with_fallbacks([&f1[..], &f2[..]])
.expect("both fallbacks are valid");
assert_eq!(keys.total_keys(), 3);
}
// ... and so do owned `Vec<u8>` items.
#[test]
fn with_fallbacks_accepts_owned_vecs() {
let primary = [0u8; 16];
let f1: Vec<u8> = vec![1u8; 16];
let f2: Vec<u8> = vec![2u8; 16];
let keys = SessionKeys::new(&primary)
.expect("primary is valid")
.with_fallbacks(vec![f1, f2])
.expect("both fallbacks are valid");
assert_eq!(keys.total_keys(), 3);
}
// ... and fixed-size byte arrays passed by value.
#[test]
fn with_fallbacks_accepts_byte_arrays() {
let primary = [0u8; 16];
let arrays: [[u8; 16]; 2] = [[3u8; 16], [4u8; 16]];
let keys = SessionKeys::new(&primary)
.expect("primary is valid")
.with_fallbacks(arrays)
.expect("both fallbacks are valid");
assert_eq!(keys.total_keys(), 3);
}
// `total_keys` grows by one for each fallback added.
#[test]
fn total_keys_tracks_fallback_count() {
let primary = [0u8; 16];
let mut keys = SessionKeys::new(&primary).expect("valid");
assert_eq!(keys.total_keys(), 1);
for n in 1..=5 {
// `with_fallback` consumes `self`, so rebind the returned value each round.
keys = keys
.with_fallback(&[0u8; 16])
.expect("16-byte fallback is valid");
assert_eq!(keys.total_keys(), 1 + n);
}
}
// Mixing `with_fallback` and `with_fallbacks` keeps overall insertion order.
#[test]
fn ikm_in_order_preserves_insertion_order() {
let primary = [0xAAu8; 16];
let f1 = [0xBBu8; 16];
let f2 = [0xCCu8; 16];
let f3 = [0xDDu8; 16];
let keys = SessionKeys::new(&primary)
.expect("primary is valid")
.with_fallback(&f1)
.expect("f1 is valid")
.with_fallbacks([&f2[..], &f3[..]])
.expect("f2, f3 are valid");
let collected: Vec<&[u8]> = keys.ikm_in_order().collect();
assert_eq!(collected, vec![&primary[..], &f1[..], &f2[..], &f3[..]]);
}
// Lengths above the minimum are accepted too.
#[test]
fn new_accepts_17_byte_primary() {
assert!(SessionKeys::new(&[0u8; 17]).is_ok());
}
// A non-empty but too-short key is classified as `ShortKey`, not `EmptyKey`.
#[test]
fn new_rejects_1_byte_with_short_key_not_empty() {
let result = SessionKeys::new(&[0u8; 1]);
assert_eq!(result.err(), Some(BuildError::ShortKey { len: 1 }));
}
// Pins the `Display` text of the two `BuildError` variants produced by this
// module; the `Display` impl itself lives outside this file.
#[test]
fn build_error_display_messages() {
assert_eq!(
BuildError::EmptyKey.to_string(),
"secret key must not be empty"
);
assert_eq!(
BuildError::ShortKey { len: 7 }.to_string(),
"secret must be at least 16 bytes, got 7"
);
}
}