use crate::mls::{MlsGroup, Result};
use blake3;
/// Per-epoch symmetric key material: an encryption key, a base nonce, the
/// intermediate secret they are derived from, and a PSK identifier hash.
///
/// `Debug` is implemented by hand (not derived) so that formatting a schedule,
/// e.g. in a log line or a failing `assert_eq!`, never prints the secret,
/// the key, or the base nonce.
///
/// NOTE(review): the derived `PartialEq` compares the secret fields with a
/// regular, non-constant-time byte comparison. Avoid it on attacker-influenced
/// paths, or replace it with a constant-time comparison.
#[derive(Clone, PartialEq, Eq)]
pub struct MlsKeySchedule {
    /// Epoch number this material was derived for.
    epoch: u64,
    /// Hash identifying the pre-shared-key context.
    psk_id_hash: Vec<u8>,
    /// Intermediate secret from which `key` and `base_nonce` are derived.
    secret: Vec<u8>,
    /// Symmetric encryption key.
    key: Vec<u8>,
    /// Base nonce that per-message nonces are derived from.
    base_nonce: Vec<u8>,
}

impl std::fmt::Debug for MlsKeySchedule {
    /// Shows the epoch and PSK id hash; replaces secret material with a
    /// `"<redacted>"` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MlsKeySchedule")
            .field("epoch", &self.epoch)
            .field("psk_id_hash", &self.psk_id_hash)
            .field("secret", &"<redacted>")
            .field("key", &"<redacted>")
            .field("base_nonce", &"<redacted>")
            .finish()
    }
}
impl MlsKeySchedule {
    /// Number of digest bytes kept as the encryption key.
    const KEY_LEN: usize = 32;
    /// Number of digest bytes kept as the base nonce.
    const NONCE_LEN: usize = 12;

    /// Returns the BLAKE3 digest of `parts` concatenated in order.
    fn blake3_concat(parts: &[&[u8]]) -> [u8; 32] {
        *blake3::hash(&parts.concat()).as_bytes()
    }

    /// Derives the key schedule for `group`'s current epoch.
    ///
    /// With `epoch_le` being the epoch as 8 little-endian bytes and `H` BLAKE3:
    /// - `psk_id_hash = H(group_id || epoch_le)`
    /// - `secret      = H(group_id || tree_hash || confirmed_transcript_hash || epoch_le)`
    /// - `key         = H(secret || "encryption" || epoch_le)[..32]`
    /// - `base_nonce  = H(secret || "nonce" || epoch_le)[..12]`
    ///
    /// NOTE(review): every input above is read from the group context. If that
    /// context is public (as the MLS GroupContext is), anyone who observes it can
    /// derive the same key. Confirm whether a real epoch secret should be mixed in.
    ///
    /// NOTE(review): the inputs are concatenated without length prefixes. The
    /// encoding is only unambiguous if `tree_hash` and `confirmed_transcript_hash`
    /// have fixed lengths. Confirm this. Changing the encoding would change
    /// every derived key.
    ///
    /// # Errors
    ///
    /// As written this never returns `Err`.
    pub fn from_group(group: &MlsGroup) -> Result<Self> {
        let epoch = group.current_epoch();
        let context = group.context();

        let epoch_le = epoch.to_le_bytes();
        let epoch_bytes: &[u8] = &epoch_le;
        let group_id: &[u8] = context.group_id();
        let tree_hash: &[u8] = context.tree_hash();
        let transcript_hash: &[u8] = context.confirmed_transcript_hash();

        let psk_id_hash = Self::blake3_concat(&[group_id, epoch_bytes]).to_vec();
        let secret =
            Self::blake3_concat(&[group_id, tree_hash, transcript_hash, epoch_bytes]).to_vec();

        // Label-separated sub-derivations from the shared secret.
        let key_digest =
            Self::blake3_concat(&[secret.as_slice(), &b"encryption"[..], epoch_bytes]);
        let nonce_digest = Self::blake3_concat(&[secret.as_slice(), &b"nonce"[..], epoch_bytes]);

        Ok(Self {
            epoch,
            psk_id_hash,
            key: key_digest[..Self::KEY_LEN].to_vec(),
            base_nonce: nonce_digest[..Self::NONCE_LEN].to_vec(),
            secret,
        })
    }

    /// Symmetric encryption key for this epoch (32 bytes when built by `from_group`).
    #[must_use]
    pub fn encryption_key(&self) -> &[u8] {
        self.key.as_slice()
    }

    /// Base nonce for this epoch (12 bytes when built by `from_group`).
    #[must_use]
    pub fn base_nonce(&self) -> &[u8] {
        self.base_nonce.as_slice()
    }

    /// Builds a per-message nonce by XOR-ing `counter`'s 8 little-endian bytes
    /// into the base nonce, starting at byte offset 4.
    ///
    /// A counter of 0 returns the base nonce unchanged. With a 12-byte base
    /// nonce, distinct counters yield distinct nonces. Reusing a counter under
    /// the same key reproduces the same nonce.
    #[must_use]
    pub fn derive_nonce(&self, counter: u64) -> Vec<u8> {
        let mut nonce = self.base_nonce.to_vec();
        for (slot, counter_byte) in nonce.iter_mut().skip(4).zip(counter.to_le_bytes()) {
            *slot ^= counter_byte;
        }
        nonce
    }

    /// Epoch this schedule was derived for.
    #[must_use]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Hash of the group id and epoch, identifying the PSK context.
    #[must_use]
    pub fn psk_id_hash(&self) -> &[u8] {
        self.psk_id_hash.as_slice()
    }

    /// Intermediate secret that the key and base nonce are derived from.
    #[must_use]
    pub fn secret(&self) -> &[u8] {
        self.secret.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::AgentId;

    /// Deterministic agent id: first byte is `id`, the remaining 31 bytes are zero.
    fn test_agent_id(id: u8) -> AgentId {
        let mut raw = [0u8; 32];
        raw[0] = id;
        AgentId(raw)
    }

    /// Creates a fresh group with the given id, initiated by agent 1.
    async fn new_group(name: &[u8]) -> MlsGroup {
        MlsGroup::new(name.to_vec(), test_agent_id(1)).await.unwrap()
    }

    /// Key schedule of a freshly created `test-group`.
    async fn default_schedule() -> MlsKeySchedule {
        let group = new_group(b"test-group").await;
        MlsKeySchedule::from_group(&group).unwrap()
    }

    #[tokio::test]
    async fn test_key_derivation_from_group() {
        let group = new_group(b"test-group").await;
        let schedule = MlsKeySchedule::from_group(&group)
            .expect("from_group should succeed for a freshly created group");
        assert_eq!(schedule.encryption_key().len(), 32);
        assert_eq!(schedule.base_nonce().len(), 12);
        assert_eq!(schedule.epoch(), 0);
    }

    #[tokio::test]
    async fn test_key_derivation_is_deterministic() {
        let group = new_group(b"test-group").await;
        let first = MlsKeySchedule::from_group(&group).unwrap();
        let second = MlsKeySchedule::from_group(&group).unwrap();
        assert_eq!(first.encryption_key(), second.encryption_key());
        assert_eq!(first.base_nonce(), second.base_nonce());
        assert_eq!(first.secret(), second.secret());
        assert_eq!(first.psk_id_hash(), second.psk_id_hash());
    }

    #[tokio::test]
    async fn test_different_epochs_produce_different_keys() {
        let mut group = new_group(b"test-group").await;
        let before = MlsKeySchedule::from_group(&group).unwrap();

        let commit = group.commit().unwrap();
        group.apply_commit(&commit).unwrap();
        assert_eq!(group.current_epoch(), 1);

        let after = MlsKeySchedule::from_group(&group).unwrap();
        assert_ne!(before.encryption_key(), after.encryption_key());
        assert_ne!(before.base_nonce(), after.base_nonce());
        assert_ne!(before.secret(), after.secret());
        assert_ne!(before.epoch(), after.epoch());
    }

    #[tokio::test]
    async fn test_nonce_derivation_is_deterministic() {
        let schedule = default_schedule().await;
        let nonce_a = schedule.derive_nonce(42);
        let nonce_b = schedule.derive_nonce(42);
        assert_eq!(nonce_a, nonce_b);
        assert_eq!(nonce_a.len(), 12);
    }

    #[tokio::test]
    async fn test_nonce_unique_per_counter() {
        let schedule = default_schedule().await;
        let n0 = schedule.derive_nonce(0);
        let n1 = schedule.derive_nonce(1);
        let n100 = schedule.derive_nonce(100);
        assert_ne!(n0, n1);
        assert_ne!(n1, n100);
        assert_ne!(n0, n100);
        for nonce in [&n0, &n1, &n100] {
            assert_eq!(nonce.len(), 12);
        }
    }

    #[tokio::test]
    async fn test_nonce_xor_with_counter() {
        let schedule = default_schedule().await;
        let base = schedule.base_nonce();
        // XOR-ing an all-zero counter leaves the base nonce untouched.
        assert_eq!(base, schedule.derive_nonce(0).as_slice());
        assert_ne!(base, schedule.derive_nonce(1).as_slice());
    }

    #[tokio::test]
    async fn test_different_groups_produce_different_keys() {
        let group_one = new_group(b"group-1").await;
        let group_two = new_group(b"group-2").await;
        let one = MlsKeySchedule::from_group(&group_one).unwrap();
        let two = MlsKeySchedule::from_group(&group_two).unwrap();
        assert_ne!(one.encryption_key(), two.encryption_key());
        assert_ne!(one.base_nonce(), two.base_nonce());
        assert_ne!(one.psk_id_hash(), two.psk_id_hash());
    }

    #[tokio::test]
    async fn test_key_schedule_accessors() {
        let schedule = default_schedule().await;
        assert_eq!(schedule.epoch(), 0);
        assert!(!schedule.encryption_key().is_empty());
        assert!(!schedule.base_nonce().is_empty());
        assert!(!schedule.psk_id_hash().is_empty());
        assert!(!schedule.secret().is_empty());
    }

    #[tokio::test]
    async fn test_key_schedule_clone() {
        let original = default_schedule().await;
        let copy = original.clone();
        assert_eq!(original, copy);
        assert_eq!(original.encryption_key(), copy.encryption_key());
        assert_eq!(original.base_nonce(), copy.base_nonce());
    }
}