use crate::identity::AgentId;
use crate::mls::{MlsCipher, MlsError, MlsGroup, MlsGroupContext, Result};
use blake3;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A Welcome message that lets an invited agent join an existing MLS group.
///
/// Produced by [`MlsWelcome::create`] and consumed by [`MlsWelcome::accept`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MlsWelcome {
/// Identifier of the group being joined, copied from the group context at creation time.
group_id: Vec<u8>,
/// Group epoch at which this Welcome was created.
epoch: u64,
/// Ciphertext of the serialized group secrets (group id, epoch, tree hash, confirmed
/// transcript hash), keyed by the invitee's agent id. `create` inserts exactly one entry.
encrypted_group_secrets: HashMap<AgentId, Vec<u8>>,
/// `b"TREE"`, then the group id length as a little-endian `u32`, the group id and the
/// tree hash, as written by `serialize_tree`.
tree: Vec<u8>,
/// 32-byte BLAKE3 hash over a domain label, the group id, epoch, invitee id, tree hash and
/// confirmed transcript hash (see `generate_confirmation_tag`).
confirmation_tag: Vec<u8>,
}
impl MlsWelcome {
    /// Length in bytes of a confirmation tag (one BLAKE3 digest).
    const CONFIRMATION_TAG_LEN: usize = blake3::OUT_LEN;

    /// Creates a Welcome inviting `invitee` into `group` at the group's current epoch.
    ///
    /// The group id, epoch, tree hash and confirmed transcript hash are encrypted for the
    /// invitee, with the group id, epoch and invitee id bound in as AAD. A confirmation tag
    /// over the same values plus the invitee id is attached.
    ///
    /// # Security
    ///
    /// NOTE(review): the encryption key comes from `derive_invitee_key`, which hashes only
    /// the invitee's `AgentId`, the group id and the epoch. This message carries all three
    /// in clear (the invitee id is the map key). Anyone holding the Welcome can therefore
    /// re-derive the key and decrypt the group secrets, so the ciphertext provides no
    /// confidentiality. Sealing the secrets to the invitee's public key (as MLS does with
    /// HPKE) is required for that.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `MlsCipher::encrypt`.
    pub fn create(group: &MlsGroup, invitee: &AgentId) -> Result<Self> {
        let context = group.context();
        let epoch = context.epoch();
        let group_id = context.group_id().to_vec();
        let invitee_key = Self::derive_invitee_key(invitee, &group_id, epoch);
        // NOTE(review): fixed all-zero nonce. The key varies only with (invitee, group, epoch),
        // so repeated Welcomes for the same invitee within one epoch reuse the key/nonce pair.
        let cipher = MlsCipher::new(invitee_key, vec![0u8; 12]);
        let group_secrets = Self::serialize_group_secrets(context);
        let aad = Self::build_aad(&group_id, epoch, invitee);
        let encrypted_secrets = cipher.encrypt(&group_secrets, &aad, 0)?;
        let mut encrypted_group_secrets = HashMap::new();
        encrypted_group_secrets.insert(*invitee, encrypted_secrets);
        let tree = Self::serialize_tree(context);
        let confirmation_tag = Self::generate_confirmation_tag(context, invitee);
        Ok(Self {
            group_id,
            epoch,
            encrypted_group_secrets,
            tree,
            confirmation_tag,
        })
    }

    /// Performs structural checks that need no decryption.
    ///
    /// Checks that the confirmation tag is 32 bytes long, and that the group id, tree and
    /// encrypted-secrets map are non-empty. The tag's *value* can only be checked once the
    /// secrets are decrypted; [`MlsWelcome::accept`] does that.
    ///
    /// # Errors
    ///
    /// Returns [`MlsError::MlsOperation`] describing the first failed check.
    pub fn verify(&self) -> Result<()> {
        if self.confirmation_tag.len() != Self::CONFIRMATION_TAG_LEN {
            return Err(MlsError::MlsOperation(
                "invalid confirmation tag length".to_string(),
            ));
        }
        if self.group_id.is_empty() {
            return Err(MlsError::MlsOperation("empty group_id".to_string()));
        }
        if self.tree.is_empty() {
            return Err(MlsError::MlsOperation("empty tree".to_string()));
        }
        if self.encrypted_group_secrets.is_empty() {
            return Err(MlsError::MlsOperation("no encrypted secrets".to_string()));
        }
        Ok(())
    }

    /// Opens this Welcome as `agent_id` and returns the reconstructed group context.
    ///
    /// The steps are:
    /// 1. Run [`MlsWelcome::verify`].
    /// 2. Decrypt the secrets addressed to `agent_id`.
    /// 3. Check the decrypted group id and epoch against this message.
    /// 4. Recompute the confirmation tag and reject the Welcome if it does not match.
    ///
    /// # Errors
    ///
    /// - Any error from [`MlsWelcome::verify`].
    /// - [`MlsError::MemberNotInGroup`] if no secrets are addressed to `agent_id`.
    /// - Any error from `MlsCipher::decrypt`.
    /// - [`MlsError::MlsOperation`] or [`MlsError::EpochMismatch`] if the decrypted secrets
    ///   are malformed or disagree with this message.
    /// - [`MlsError::MlsOperation`] if the confirmation tag does not match.
    pub fn accept(&self, agent_id: &AgentId) -> Result<MlsGroupContext> {
        self.verify()?;
        let encrypted_secrets = self
            .encrypted_group_secrets
            .get(agent_id)
            .ok_or_else(|| MlsError::MemberNotInGroup(format!("{:?}", agent_id)))?;
        let invitee_key = Self::derive_invitee_key(agent_id, &self.group_id, self.epoch);
        let cipher = MlsCipher::new(invitee_key, vec![0u8; 12]);
        let aad = Self::build_aad(&self.group_id, self.epoch, agent_id);
        let group_secrets = cipher.decrypt(encrypted_secrets, &aad, 0)?;
        let context =
            Self::deserialize_group_context(&group_secrets, &self.group_id, self.epoch)?;
        // Previously the tag was carried but never compared, so any 32 bytes were accepted.
        self.verify_confirmation_tag(agent_id, &group_secrets)?;
        Ok(context)
    }

    /// Derives the 32-byte per-invitee symmetric key for this Welcome.
    ///
    /// The key is BLAKE3(`invitee || group_id || epoch_le || "welcome-key"`). NOTE(review):
    /// every input is public, so this key is not secret (see the `create` docs).
    fn derive_invitee_key(invitee: &AgentId, group_id: &[u8], epoch: u64) -> Vec<u8> {
        let mut key_material = Vec::new();
        key_material.extend_from_slice(invitee.as_bytes());
        key_material.extend_from_slice(group_id);
        key_material.extend_from_slice(&epoch.to_le_bytes());
        key_material.extend_from_slice(b"welcome-key");
        // BLAKE3's default output is exactly 32 bytes.
        blake3::hash(&key_material).as_bytes().to_vec()
    }

    /// Builds the AEAD associated data: `"MLS-Welcome" || group_id || epoch_le || invitee`.
    fn build_aad(group_id: &[u8], epoch: u64, invitee: &AgentId) -> Vec<u8> {
        let mut aad = Vec::new();
        aad.extend_from_slice(b"MLS-Welcome");
        aad.extend_from_slice(group_id);
        aad.extend_from_slice(&epoch.to_le_bytes());
        aad.extend_from_slice(invitee.as_bytes());
        aad
    }

    /// Serializes the plaintext secrets as
    /// `group_id || epoch_le || tree_hash || confirmed_transcript_hash` (no length prefixes).
    fn serialize_group_secrets(context: &MlsGroupContext) -> Vec<u8> {
        let mut secrets = Vec::new();
        secrets.extend_from_slice(context.group_id());
        secrets.extend_from_slice(&context.epoch().to_le_bytes());
        secrets.extend_from_slice(context.tree_hash());
        secrets.extend_from_slice(context.confirmed_transcript_hash());
        secrets
    }

    /// Serializes the tree field as `"TREE" || len(group_id) as u32 LE || group_id || tree_hash`.
    fn serialize_tree(context: &MlsGroupContext) -> Vec<u8> {
        let mut tree = Vec::new();
        tree.extend_from_slice(b"TREE");
        tree.extend_from_slice(&(context.group_id().len() as u32).to_le_bytes());
        tree.extend_from_slice(context.group_id());
        tree.extend_from_slice(context.tree_hash());
        tree
    }

    /// Returns a hasher pre-loaded with the confirmation-tag prefix
    /// `"MLS-Welcome-Tag" || group_id || epoch_le || invitee`.
    ///
    /// Shared by tag generation and verification so both sides hash identical bytes.
    fn confirmation_tag_hasher(group_id: &[u8], epoch: u64, invitee: &AgentId) -> blake3::Hasher {
        let mut hasher = blake3::Hasher::new();
        hasher
            .update(b"MLS-Welcome-Tag")
            .update(group_id)
            .update(&epoch.to_le_bytes())
            .update(invitee.as_bytes());
        hasher
    }

    /// Computes the confirmation tag. This is BLAKE3 over the tag prefix (see
    /// `confirmation_tag_hasher`) followed by `tree_hash || confirmed_transcript_hash`.
    fn generate_confirmation_tag(context: &MlsGroupContext, invitee: &AgentId) -> Vec<u8> {
        let mut hasher =
            Self::confirmation_tag_hasher(context.group_id(), context.epoch(), invitee);
        // Incremental updates hash exactly the same bytes as a single concatenated buffer.
        hasher
            .update(context.tree_hash())
            .update(context.confirmed_transcript_hash());
        hasher.finalize().as_bytes().to_vec()
    }

    /// Recomputes the confirmation tag from the decrypted `group_secrets` and compares it with
    /// the tag carried in this Welcome.
    ///
    /// `group_secrets` has the layout produced by `serialize_group_secrets`. The tag hashes
    /// the tree hash and transcript hash back to back. The bytes after
    /// `group_id || epoch_le` can therefore be fed in as one run, without needing to know
    /// where one hash ends and the other begins.
    fn verify_confirmation_tag(&self, invitee: &AgentId, group_secrets: &[u8]) -> Result<()> {
        let context_hashes = group_secrets
            .get(self.group_id.len() + 8..)
            .ok_or_else(|| MlsError::MlsOperation("invalid group secrets length".to_string()))?;
        let expected: [u8; blake3::OUT_LEN] = self
            .confirmation_tag
            .as_slice()
            .try_into()
            .map_err(|_| MlsError::MlsOperation("invalid confirmation tag length".to_string()))?;
        let mut hasher = Self::confirmation_tag_hasher(&self.group_id, self.epoch, invitee);
        hasher.update(context_hashes);
        // `blake3::Hash` equality is constant-time, so the comparison does not leak
        // how many leading bytes matched.
        if hasher.finalize() != blake3::Hash::from(expected) {
            return Err(MlsError::MlsOperation(
                "confirmation tag mismatch".to_string(),
            ));
        }
        Ok(())
    }

    /// Parses decrypted group secrets back into a group context.
    ///
    /// Validates that the plaintext begins with `expected_group_id` followed by
    /// `expected_epoch` (8 bytes, little-endian).
    ///
    /// # Errors
    ///
    /// Returns [`MlsError::MlsOperation`] on a short buffer or group-id mismatch. Returns
    /// [`MlsError::EpochMismatch`] if the embedded epoch differs.
    fn deserialize_group_context(
        secrets: &[u8],
        expected_group_id: &[u8],
        expected_epoch: u64,
    ) -> Result<MlsGroupContext> {
        if secrets.len() < expected_group_id.len() + 8 {
            return Err(MlsError::MlsOperation(
                "invalid group secrets length".to_string(),
            ));
        }
        let (group_id, rest) = secrets.split_at(expected_group_id.len());
        if group_id != expected_group_id {
            return Err(MlsError::MlsOperation("group ID mismatch".to_string()));
        }
        let (epoch_bytes, context_hashes) = rest.split_at(8);
        let epoch_bytes: [u8; 8] = epoch_bytes
            .try_into()
            .map_err(|_| MlsError::MlsOperation("invalid epoch bytes".to_string()))?;
        let epoch = u64::from_le_bytes(epoch_bytes);
        if epoch != expected_epoch {
            return Err(MlsError::EpochMismatch {
                current: expected_epoch,
                received: epoch,
            });
        }
        // NOTE(review): the wire format has no length prefixes, so the remaining bytes are split
        // in half. This is only correct when the tree hash and confirmed transcript hash have
        // equal length. An unequal pair (e.g. an empty transcript hash) would be mis-split.
        // Fixing this needs a versioned, length-prefixed format.
        let (tree_hash, confirmed_transcript_hash) =
            context_hashes.split_at(context_hashes.len() / 2);
        Ok(MlsGroupContext::new_with_material(
            group_id.to_vec(),
            epoch,
            tree_hash.to_vec(),
            confirmed_transcript_hash.to_vec(),
        ))
    }

    /// Returns the id of the group this Welcome invites into.
    #[must_use]
    pub fn group_id(&self) -> &[u8] {
        &self.group_id
    }

    /// Returns the group epoch at which this Welcome was created.
    #[must_use]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::Identity;

    /// Creates a fresh group owned by a newly generated identity.
    /// Returns the group and the creator's agent id.
    async fn create_test_group() -> (MlsGroup, AgentId) {
        let identity = Identity::generate().expect("identity generation failed");
        let agent_id = identity.agent_id();
        let group_id = b"test-group".to_vec();
        let group = MlsGroup::new(group_id, agent_id)
            .await
            .expect("group creation failed");
        (group, agent_id)
    }

    /// Generates the agent id of a fresh identity, for use as an invitee.
    fn create_test_invitee() -> AgentId {
        let identity = Identity::generate().expect("identity generation failed");
        identity.agent_id()
    }

    #[tokio::test]
    async fn test_welcome_creation() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        assert_eq!(welcome.group_id(), group.context().group_id());
        assert_eq!(welcome.epoch(), group.current_epoch());
        assert!(welcome.encrypted_group_secrets.contains_key(&invitee));
        assert!(!welcome.tree.is_empty());
        assert_eq!(welcome.confirmation_tag.len(), 32);
    }

    #[tokio::test]
    async fn test_welcome_verification() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        assert!(welcome.verify().is_ok());
    }

    #[tokio::test]
    async fn test_welcome_verification_rejects_empty_group_id() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let mut welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        welcome.group_id = Vec::new();
        assert!(welcome.verify().is_err());
    }

    #[tokio::test]
    async fn test_welcome_verification_rejects_empty_tree() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let mut welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        welcome.tree = Vec::new();
        assert!(welcome.verify().is_err());
    }

    #[tokio::test]
    async fn test_welcome_verification_rejects_invalid_tag() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let mut welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        welcome.confirmation_tag = vec![0u8; 16];
        assert!(welcome.verify().is_err());
    }

    #[tokio::test]
    async fn test_welcome_verification_rejects_missing_secrets() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let mut welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        welcome.encrypted_group_secrets.clear();
        assert!(welcome.verify().is_err());
    }

    #[tokio::test]
    async fn test_welcome_accept_by_invitee() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        let context = welcome.accept(&invitee).expect("accept failed");
        assert_eq!(context.group_id(), group.context().group_id());
        assert_eq!(context.epoch(), group.current_epoch());
    }

    #[tokio::test]
    async fn test_welcome_accept_rejects_wrong_agent() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let wrong_agent = create_test_invitee();
        let welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        assert!(matches!(
            welcome.accept(&wrong_agent),
            Err(MlsError::MemberNotInGroup(_))
        ));
    }

    #[tokio::test]
    async fn test_welcome_accept_rejects_tampered_ciphertext() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let mut welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        let ciphertext = welcome
            .encrypted_group_secrets
            .get_mut(&invitee)
            .expect("invitee secrets missing");
        ciphertext[0] ^= 0x01;
        assert!(welcome.accept(&invitee).is_err());
    }

    #[test]
    fn test_invitee_key_derivation_is_deterministic() {
        let invitee = create_test_invitee();
        let group_id = b"test-group";
        let epoch = 5;
        let key1 = MlsWelcome::derive_invitee_key(&invitee, group_id, epoch);
        let key2 = MlsWelcome::derive_invitee_key(&invitee, group_id, epoch);
        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    // Plain #[test]: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_invitee_key_varies_with_epoch() {
        let invitee = create_test_invitee();
        let group_id = b"test-group";
        let key1 = MlsWelcome::derive_invitee_key(&invitee, group_id, 1);
        let key2 = MlsWelcome::derive_invitee_key(&invitee, group_id, 2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_invitee_key_varies_with_agent() {
        let invitee1 = create_test_invitee();
        let invitee2 = create_test_invitee();
        let group_id = b"test-group";
        let epoch = 1;
        let key1 = MlsWelcome::derive_invitee_key(&invitee1, group_id, epoch);
        let key2 = MlsWelcome::derive_invitee_key(&invitee2, group_id, epoch);
        assert_ne!(key1, key2);
    }

    #[tokio::test]
    async fn test_welcome_serialization() {
        let (group, _creator) = create_test_group().await;
        let invitee = create_test_invitee();
        let welcome = MlsWelcome::create(&group, &invitee).expect("welcome creation failed");
        let serialized = bincode::serialize(&welcome).expect("serialization failed");
        let deserialized: MlsWelcome =
            bincode::deserialize(&serialized).expect("deserialization failed");
        assert_eq!(deserialized.group_id(), welcome.group_id());
        assert_eq!(deserialized.epoch(), welcome.epoch());
        // The round-tripped message must still be openable by the invitee.
        let context = deserialized
            .accept(&invitee)
            .expect("accept after round-trip failed");
        assert_eq!(context.epoch(), welcome.epoch());
    }
}