use crate::identity::{AgentCertificate, AgentId, UserId};
use crate::mls::{MlsError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Derives the `saorsa_mls::MemberId` for an agent from the first 16 bytes of its id.
///
/// Only the leading 16 bytes of `agent_id.as_bytes()` are used; the rest are ignored.
///
/// NOTE(review): two agent ids that share their first 16 bytes map to the same member
/// id — confirm this truncation is acceptable for how agent ids are generated.
///
/// # Panics
///
/// Panics if `agent_id.as_bytes()` yields fewer than 16 bytes.
fn agent_id_to_member_id(agent_id: &AgentId) -> saorsa_mls::MemberId {
    let prefix = &agent_id.as_bytes()[..16];
    let member_bytes = <[u8; 16]>::try_from(prefix).expect("AgentId is always 32 bytes");
    saorsa_mls::MemberId::from_bytes(member_bytes)
}
/// Per-group state snapshot: the group identifier, the current epoch, and the hash
/// material installed by the most recent commit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsGroupContext {
/// Identifier of the group this context belongs to.
group_id: Vec<u8>,
/// Epoch counter; starts at 0 and is bumped by one for every applied commit.
epoch: u64,
/// Tree hash for the current epoch; empty until crypto material is first installed.
tree_hash: Vec<u8>,
/// Confirmed transcript hash for the current epoch; empty until first installed.
confirmed_transcript_hash: Vec<u8>,
}
impl MlsGroupContext {
    /// Creates the context of a brand-new group: epoch 0 and empty hash material.
    #[must_use]
    pub fn new(group_id: Vec<u8>) -> Self {
        Self::new_with_material(group_id, 0, Vec::new(), Vec::new())
    }

    /// Creates a context from explicit epoch and hash values.
    #[must_use]
    pub(crate) fn new_with_material(
        group_id: Vec<u8>,
        epoch: u64,
        tree_hash: Vec<u8>,
        confirmed_transcript_hash: Vec<u8>,
    ) -> Self {
        Self {
            group_id,
            epoch,
            tree_hash,
            confirmed_transcript_hash,
        }
    }

    /// Returns the group identifier.
    #[must_use]
    pub fn group_id(&self) -> &[u8] {
        self.group_id.as_slice()
    }

    /// Returns the current epoch.
    #[must_use]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Returns the tree hash currently stored in the context.
    #[must_use]
    pub fn tree_hash(&self) -> &[u8] {
        self.tree_hash.as_slice()
    }

    /// Returns the confirmed transcript hash currently stored in the context.
    #[must_use]
    pub fn confirmed_transcript_hash(&self) -> &[u8] {
        self.confirmed_transcript_hash.as_slice()
    }

    /// Advances the epoch by one, saturating at `u64::MAX` instead of overflowing.
    fn increment_epoch(&mut self) {
        let next_epoch = self.epoch.saturating_add(1);
        self.epoch = next_epoch;
    }

    /// Replaces both stored hashes with the supplied values.
    fn update_crypto_material(&mut self, tree_hash: Vec<u8>, transcript_hash: Vec<u8>) {
        self.confirmed_transcript_hash = transcript_hash;
        self.tree_hash = tree_hash;
    }
}
/// Roster entry describing one member of a group.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsMemberInfo {
/// Identifier of the member agent.
agent_id: AgentId,
/// User the agent is associated with, when one was supplied at construction.
user_id: Option<UserId>,
/// Certificate supplied together with the user id, if any.
certificate: Option<AgentCertificate>,
/// Epoch value recorded for this member when the entry was created.
join_epoch: u64,
}
impl MlsMemberInfo {
    /// Creates an entry for an agent with no associated user or certificate.
    #[must_use]
    pub fn new(agent_id: AgentId, join_epoch: u64) -> Self {
        Self {
            join_epoch,
            agent_id,
            certificate: None,
            user_id: None,
        }
    }

    /// Creates an entry for an agent together with its user id and certificate.
    #[must_use]
    pub fn new_with_user(
        agent_id: AgentId,
        user_id: UserId,
        certificate: AgentCertificate,
        join_epoch: u64,
    ) -> Self {
        let user_id = Some(user_id);
        let certificate = Some(certificate);
        Self {
            agent_id,
            user_id,
            certificate,
            join_epoch,
        }
    }

    /// Returns the member's agent id.
    #[must_use]
    pub fn agent_id(&self) -> &AgentId {
        &self.agent_id
    }

    /// Returns the associated user id, or `None` when the entry was built without one.
    #[must_use]
    pub fn user_id(&self) -> Option<&UserId> {
        Option::as_ref(&self.user_id)
    }

    /// Returns the member's certificate, or `None` when the entry was built without one.
    #[must_use]
    pub fn certificate(&self) -> Option<&AgentCertificate> {
        Option::as_ref(&self.certificate)
    }

    /// Returns the epoch value recorded when this entry was created.
    #[must_use]
    pub fn join_epoch(&self) -> u64 {
        self.join_epoch
    }
}
/// A single change carried by an [`MlsCommit`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommitOperation {
/// Add the given agent to the group's member roster.
AddMember(AgentId),
/// Remove the given agent from the group's member roster.
RemoveMember(AgentId),
/// Key update with no roster change; applying it only advances the epoch and
/// installs the commit's hash material.
UpdateKeys,
}
/// A batch of [`CommitOperation`]s plus the hash material to install once applied.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MlsCommit {
/// Identifier of the group the commit targets.
group_id: Vec<u8>,
/// Epoch the commit was created in; a group only applies commits whose epoch
/// equals its own current epoch.
epoch: u64,
/// Operations to apply, in order.
operations: Vec<CommitOperation>,
/// Tree hash stored in the group context after the commit is applied.
new_tree_hash: Vec<u8>,
/// Transcript hash stored in the group context after the commit is applied.
new_transcript_hash: Vec<u8>,
}
impl MlsCommit {
    /// Assembles a commit from its parts; no validation is performed.
    #[must_use]
    pub fn new(
        group_id: Vec<u8>,
        epoch: u64,
        operations: Vec<CommitOperation>,
        new_tree_hash: Vec<u8>,
        new_transcript_hash: Vec<u8>,
    ) -> Self {
        Self {
            new_transcript_hash,
            new_tree_hash,
            operations,
            epoch,
            group_id,
        }
    }

    /// Returns the identifier of the group this commit targets.
    #[must_use]
    pub fn group_id(&self) -> &[u8] {
        self.group_id.as_slice()
    }

    /// Returns the epoch this commit was created in.
    #[must_use]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Returns the operations carried by this commit, in application order.
    #[must_use]
    pub fn operations(&self) -> &[CommitOperation] {
        self.operations.as_slice()
    }

    /// Returns the tree hash to install after applying this commit.
    #[must_use]
    pub fn new_tree_hash(&self) -> &[u8] {
        self.new_tree_hash.as_slice()
    }

    /// Returns the transcript hash to install after applying this commit.
    #[must_use]
    pub fn new_transcript_hash(&self) -> &[u8] {
        self.new_transcript_hash.as_slice()
    }
}
/// Group wrapper that pairs a `saorsa_mls::MlsGroup` with agent-level bookkeeping:
/// the member roster, agent/member id mappings, epoch tracking and pending commits.
#[derive(Debug)]
pub struct MlsGroup {
/// Identifier of this group.
group_id: Vec<u8>,
/// Underlying `saorsa_mls` group used for membership changes and message
/// encryption/decryption.
inner: saorsa_mls::MlsGroup,
/// Context holding the epoch and the most recently installed hash material.
context: MlsGroupContext,
/// Current roster, keyed by agent id.
members: HashMap<AgentId, MlsMemberInfo>,
/// Agent id -> `saorsa_mls` member id, for members added through this wrapper.
agent_to_member: HashMap<AgentId, saorsa_mls::MemberId>,
/// Reverse of `agent_to_member`.
member_to_agent: HashMap<saorsa_mls::MemberId, AgentId>,
/// Commits created by `commit()` that have not been applied yet.
pending_commits: Vec<MlsCommit>,
/// Current epoch; advanced in step with `context`.
epoch: u64,
}
impl MlsGroup {
pub async fn new(group_id: Vec<u8>, initiator: AgentId) -> Result<Self> {
let member_id = agent_id_to_member_id(&initiator);
let identity = saorsa_mls::MemberIdentity::generate(member_id)
.map_err(|e| MlsError::SaorsaMls(format!("identity generation: {e}")))?;
let config = saorsa_mls::GroupConfig::default();
let inner = saorsa_mls::MlsGroup::new(config, identity)
.await
.map_err(|e| MlsError::SaorsaMls(format!("group creation: {e}")))?;
let context = MlsGroupContext::new(group_id.clone());
let mut members = HashMap::new();
members.insert(initiator, MlsMemberInfo::new(initiator, 0));
let mut agent_to_member = HashMap::new();
agent_to_member.insert(initiator, member_id);
let mut member_to_agent = HashMap::new();
member_to_agent.insert(member_id, initiator);
Ok(Self {
group_id,
inner,
context,
members,
agent_to_member,
member_to_agent,
pending_commits: Vec::new(),
epoch: 0,
})
}
#[must_use]
pub fn group_id(&self) -> &[u8] {
&self.group_id
}
#[must_use]
pub fn current_epoch(&self) -> u64 {
self.epoch
}
#[must_use]
pub fn context(&self) -> &MlsGroupContext {
&self.context
}
#[must_use]
pub fn members(&self) -> &HashMap<AgentId, MlsMemberInfo> {
&self.members
}
#[must_use]
pub fn is_member(&self, agent_id: &AgentId) -> bool {
self.members.contains_key(agent_id)
}
pub async fn add_member(&mut self, member: AgentId) -> Result<MlsCommit> {
if self.members.contains_key(&member) {
return Err(MlsError::MlsOperation(format!(
"agent {:?} is already a member",
member.as_bytes()
)));
}
let member_id = agent_id_to_member_id(&member);
let identity = saorsa_mls::MemberIdentity::generate(member_id)
.map_err(|e| MlsError::SaorsaMls(format!("identity generation: {e}")))?;
let _welcome = self
.inner
.add_member(&identity)
.await
.map_err(|e| MlsError::SaorsaMls(format!("add_member: {e}")))?;
self.agent_to_member.insert(member, member_id);
self.member_to_agent.insert(member_id, member);
let operations = vec![CommitOperation::AddMember(member)];
let new_tree_hash =
blake3::hash(&[self.group_id.as_slice(), &self.epoch.to_le_bytes(), b"tree"].concat())
.as_bytes()
.to_vec();
let new_transcript_hash = blake3::hash(
&[
self.group_id.as_slice(),
&self.epoch.to_le_bytes(),
b"transcript",
]
.concat(),
)
.as_bytes()
.to_vec();
let commit = MlsCommit::new(
self.group_id.clone(),
self.epoch,
operations,
new_tree_hash.clone(),
new_transcript_hash.clone(),
);
self.members
.insert(member, MlsMemberInfo::new(member, self.epoch + 1));
self.epoch = self.epoch.saturating_add(1);
self.context.increment_epoch();
self.context
.update_crypto_material(new_tree_hash, new_transcript_hash);
Ok(commit)
}
pub async fn remove_member(&mut self, member: AgentId) -> Result<MlsCommit> {
if !self.members.contains_key(&member) {
return Err(MlsError::MemberNotInGroup(format!(
"{:?}",
member.as_bytes()
)));
}
let member_id = agent_id_to_member_id(&member);
self.inner
.remove_member(&member_id)
.await
.map_err(|e| MlsError::SaorsaMls(format!("remove_member: {e}")))?;
self.agent_to_member.remove(&member);
self.member_to_agent.remove(&member_id);
let operations = vec![CommitOperation::RemoveMember(member)];
let new_tree_hash =
blake3::hash(&[self.group_id.as_slice(), &self.epoch.to_le_bytes(), b"tree"].concat())
.as_bytes()
.to_vec();
let new_transcript_hash = blake3::hash(
&[
self.group_id.as_slice(),
&self.epoch.to_le_bytes(),
b"transcript",
]
.concat(),
)
.as_bytes()
.to_vec();
let commit = MlsCommit::new(
self.group_id.clone(),
self.epoch,
operations,
new_tree_hash.clone(),
new_transcript_hash.clone(),
);
self.members.remove(&member);
self.epoch = self.epoch.saturating_add(1);
self.context.increment_epoch();
self.context
.update_crypto_material(new_tree_hash, new_transcript_hash);
Ok(commit)
}
pub fn commit(&mut self) -> Result<MlsCommit> {
let operations = vec![CommitOperation::UpdateKeys];
let new_tree_hash = blake3::hash(
&[
self.group_id.as_slice(),
&self.epoch.to_le_bytes(),
b"rotate",
]
.concat(),
)
.as_bytes()
.to_vec();
let new_transcript_hash = blake3::hash(
&[
self.group_id.as_slice(),
&self.epoch.to_le_bytes(),
b"transcript-rotate",
]
.concat(),
)
.as_bytes()
.to_vec();
let commit = MlsCommit::new(
self.group_id.clone(),
self.epoch,
operations,
new_tree_hash,
new_transcript_hash,
);
self.pending_commits.push(commit.clone());
Ok(commit)
}
pub fn apply_commit(&mut self, commit: &MlsCommit) -> Result<()> {
if commit.group_id != self.group_id {
return Err(MlsError::MlsOperation(
"commit is for a different group".to_string(),
));
}
if commit.epoch != self.epoch {
return Err(MlsError::EpochMismatch {
current: self.epoch,
received: commit.epoch,
});
}
for operation in &commit.operations {
match operation {
CommitOperation::AddMember(agent_id) => {
if self.members.contains_key(agent_id) {
return Err(MlsError::MlsOperation(format!(
"cannot add existing member {:?}",
agent_id.as_bytes()
)));
}
self.members
.insert(*agent_id, MlsMemberInfo::new(*agent_id, self.epoch + 1));
}
CommitOperation::RemoveMember(agent_id) => {
if self.members.remove(agent_id).is_none() {
return Err(MlsError::MemberNotInGroup(format!(
"{:?}",
agent_id.as_bytes()
)));
}
}
CommitOperation::UpdateKeys => {}
}
}
self.epoch = self.epoch.saturating_add(1);
self.context.increment_epoch();
self.context.update_crypto_material(
commit.new_tree_hash.clone(),
commit.new_transcript_hash.clone(),
);
self.pending_commits
.retain(|c| c.epoch != commit.epoch || c.group_id != commit.group_id);
Ok(())
}
pub fn encrypt_message(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
let msg = self
.inner
.encrypt_message(plaintext)
.map_err(|e| MlsError::EncryptionError(e.to_string()))?;
serde_json::to_vec(&msg)
.map_err(|e| MlsError::EncryptionError(format!("serialization: {e}")))
}
pub fn decrypt_message(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
let msg: saorsa_mls::ApplicationMessage = serde_json::from_slice(ciphertext)
.map_err(|e| MlsError::DecryptionError(format!("deserialization: {e}")))?;
self.inner
.decrypt_message(&msg)
.map_err(|e| MlsError::DecryptionError(e.to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an agent id whose first byte is `id` and whose remaining bytes are zero.
    fn test_agent_id(id: u8) -> AgentId {
        let mut raw = [0u8; 32];
        raw[0] = id;
        AgentId(raw)
    }

    /// Creates a group called `name` with `initiator` as sole member; panics on failure.
    async fn group_named(name: &[u8], initiator: AgentId) -> MlsGroup {
        MlsGroup::new(name.to_vec(), initiator).await.unwrap()
    }

    #[tokio::test]
    async fn test_group_creation() {
        let creator = test_agent_id(1);
        let created = MlsGroup::new(b"test-group".to_vec(), creator).await;
        assert!(created.is_ok());

        let group = created.unwrap();
        assert_eq!(group.group_id(), b"test-group");
        assert_eq!(group.current_epoch(), 0);
        assert_eq!(group.members().len(), 1);
        assert!(group.is_member(&creator));
    }

    #[tokio::test]
    async fn test_add_member() {
        let joiner = test_agent_id(2);
        let mut group = group_named(b"test-group", test_agent_id(1)).await;

        let outcome = group.add_member(joiner).await;
        assert!(outcome.is_ok());

        // The commit is stamped with the pre-change epoch; the group has moved on.
        let commit = outcome.unwrap();
        assert_eq!(commit.epoch(), 0);
        assert_eq!(commit.operations().len(), 1);
        assert_eq!(group.current_epoch(), 1);
        assert_eq!(group.members().len(), 2);
        assert!(group.is_member(&joiner));
    }

    #[tokio::test]
    async fn test_add_duplicate_member() {
        let creator = test_agent_id(1);
        let mut group = group_named(b"test-group", creator).await;

        let outcome = group.add_member(creator).await;
        assert!(matches!(outcome, Err(MlsError::MlsOperation(_))));
    }

    #[tokio::test]
    async fn test_remove_member() {
        let leaver = test_agent_id(2);
        let mut group = group_named(b"test-group", test_agent_id(1)).await;
        let _ = group.add_member(leaver).await.unwrap();
        assert_eq!(group.members().len(), 2);

        let outcome = group.remove_member(leaver).await;
        assert!(outcome.is_ok());
        assert_eq!(group.current_epoch(), 2);
        assert_eq!(group.members().len(), 1);
        assert!(!group.is_member(&leaver));
    }

    #[tokio::test]
    async fn test_remove_nonexistent_member() {
        let stranger = test_agent_id(99);
        let mut group = group_named(b"test-group", test_agent_id(1)).await;

        let outcome = group.remove_member(stranger).await;
        assert!(matches!(outcome, Err(MlsError::MemberNotInGroup(_))));
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let mut group = group_named(b"test-group", test_agent_id(1)).await;
        let epoch_before = group.current_epoch();

        let rotation = group.commit().unwrap();
        group.apply_commit(&rotation).unwrap();

        assert_eq!(group.current_epoch(), epoch_before + 1);
    }

    #[tokio::test]
    async fn test_epoch_mismatch() {
        let mut group = group_named(b"test-group", test_agent_id(1)).await;
        let stale_commit = MlsCommit::new(
            b"test-group".to_vec(),
            999,
            vec![CommitOperation::UpdateKeys],
            vec![],
            vec![],
        );

        let outcome = group.apply_commit(&stale_commit);
        assert!(matches!(outcome, Err(MlsError::EpochMismatch { .. })));
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_message() {
        let group = group_named(b"test-encrypt", test_agent_id(1)).await;
        let plaintext = b"Hello, MLS with PQC!";

        let ciphertext = group.encrypt_message(plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);

        let decrypted = group.decrypt_message(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[tokio::test]
    async fn test_context_updates_on_commit() {
        let mut group = group_named(b"test-group", test_agent_id(1)).await;
        let tree_hash_before = group.context().tree_hash().to_vec();

        let rotation = group.commit().unwrap();
        group.apply_commit(&rotation).unwrap();

        assert_ne!(group.context().tree_hash(), tree_hash_before.as_slice());
        assert_eq!(group.context().epoch(), 1);
    }
}