#![deny(missing_docs)]
use chacha20poly1305::{
ChaCha20Poly1305, Key, Nonce,
aead::{Aead, KeyInit},
};
use gbp_core::StreamType;
use openmls::prelude::tls_codec::DeserializeBytes as _;
use openmls::prelude::tls_codec::Serialize as _;
use openmls::prelude::*;
use openmls_basic_credential::SignatureKeyPair;
use openmls_rust_crypto::{MemoryStorage, OpenMlsRustCrypto};
use std::collections::HashMap;
/// The MLS ciphersuite used for the signature keys, key packages and groups that
/// [`MlsContext::new_member`] creates: X25519 DHKEM, AES-128-GCM, SHA-256 and
/// Ed25519 signatures.
pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
/// Identifies a logical stream.
///
/// The string form of the label (see `StreamLabel::as_str`) is passed to the MLS
/// exporter, so each stream gets its own independent per-epoch key.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum StreamLabel {
    /// The control stream (exporter label `gbp/control`).
    Control,
    /// The audio stream (exporter label `gbp/audio`).
    Audio,
    /// The text stream (exporter label `gbp/text`).
    Text,
    /// The signal stream (exporter label `gbp/signal`).
    Signal,
}
impl StreamLabel {
    /// Returns the exporter label string for this stream (e.g. `"gbp/text"`).
    ///
    /// These strings feed directly into key derivation. Changing one changes the
    /// derived key and breaks interoperability with peers using the old value.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Control => "gbp/control",
            Self::Audio => "gbp/audio",
            Self::Text => "gbp/text",
            Self::Signal => "gbp/signal",
        }
    }
}
/// Maps a wire-level `StreamType` to the [`StreamLabel`] whose key protects it.
///
/// The mapping is one-to-one and exhaustive. Adding a `StreamType` variant
/// forces this match to be updated.
pub fn label_for(st: StreamType) -> StreamLabel {
    match st {
        StreamType::Control => StreamLabel::Control,
        StreamType::Audio => StreamLabel::Audio,
        StreamType::Text => StreamLabel::Text,
        StreamType::Signal => StreamLabel::Signal,
    }
}
/// What kind of MLS content `MlsContext::process_message` just handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessedKind {
    /// A commit was validated and staged in `MlsContext::pending_staged`.
    /// It takes effect only after `MlsContext::finalize_pending_commit`.
    Commit,
    /// An application message was processed. Its payload is not returned.
    Application,
    /// A proposal message was processed.
    Proposal,
    /// An external-join proposal was processed.
    External,
}
/// Errors returned by [`MlsContext`] operations.
#[derive(Debug, thiserror::Error)]
pub enum MlsError {
    /// An openmls call, a TLS (de)serialization step, or parsing of an exported
    /// state blob failed.
    ///
    /// The string names the failing step and carries the underlying error's
    /// `Debug` output.
    #[error("openmls: {0}")]
    OpenMls(String),
    /// ChaCha20-Poly1305 sealing or opening failed.
    ///
    /// When opening, this is what a wrong label, wrong sequence number, different
    /// epoch key or tampered ciphertext produces.
    #[error("aead: {0}")]
    Aead(String),
    /// A previously received commit is still staged.
    ///
    /// It must be finalized or cleared (`MlsContext::finalize_pending_commit` /
    /// `MlsContext::clear_pending_commit`) before this group transition can proceed.
    #[error("transition in progress: pending staged commit exists")]
    TransitionInProgress,
}
/// One member's complete MLS state.
///
/// Holds the crypto/storage provider, signing key, current group and any
/// received commit that has been staged but not yet merged.
pub struct MlsContext {
    /// Crypto and storage provider.
    ///
    /// Its in-memory key/value storage is what `MlsContext::export_state`
    /// serializes.
    pub provider: OpenMlsRustCrypto,
    /// This member's signature key pair (also stored in the provider storage).
    pub signer: SignatureKeyPair,
    /// The group this member currently belongs to.
    ///
    /// Initially this is the one-member group created by `MlsContext::new_member`.
    /// It is replaced by `MlsContext::accept_welcome`.
    pub group: MlsGroup,
    /// Basic credential over `identity`, paired with the signer's public key.
    pub credential: CredentialWithKey,
    /// Raw identity bytes the basic credential was built from.
    pub identity: Vec<u8>,
    /// Commit received via `MlsContext::process_message` that has been staged but
    /// not yet merged.
    ///
    /// This value is not included in exported state.
    pub pending_staged: Option<StagedCommit>,
}
/// Encodes the provider's in-memory key/value store as a flat byte blob.
///
/// All integers in the layout are little-endian `u32`:
///
/// - the entry count, then
/// - for each entry: `key_len`, `key`, `value_len`, `value`.
///
/// Entry order follows the map's iteration order and is therefore unspecified.
///
/// # Errors
/// [`MlsError::OpenMls`] if the storage lock is poisoned, or if the entry count
/// or any key/value length does not fit in a `u32` length prefix.
fn serialize_storage(s: &MemoryStorage) -> Result<Vec<u8>, MlsError> {
    /// Encodes `len` as a 4-byte little-endian prefix, refusing values > u32::MAX.
    fn len_prefix(len: usize) -> Result<[u8; 4], MlsError> {
        // Checked conversion: a silent `as u32` truncation would write a blob
        // whose framing no longer matches its contents and cannot be read back.
        u32::try_from(len)
            .map(u32::to_le_bytes)
            .map_err(|_| MlsError::OpenMls("storage entry too large for u32 length prefix".into()))
    }

    let map = s
        .values
        .read()
        .map_err(|_| MlsError::OpenMls("storage lock poisoned".into()))?;
    // 4 bytes for the count, plus two 4-byte prefixes and the payload per entry.
    let body_len: usize = map.iter().map(|(k, v)| 8 + k.len() + v.len()).sum();
    let mut out = Vec::with_capacity(4 + body_len);
    out.extend_from_slice(&len_prefix(map.len())?);
    for (k, v) in map.iter() {
        out.extend_from_slice(&len_prefix(k.len())?);
        out.extend_from_slice(k);
        out.extend_from_slice(&len_prefix(v.len())?);
        out.extend_from_slice(v);
    }
    Ok(out)
}
/// Decodes a blob written by `serialize_storage` back into a key/value map.
///
/// The blob comes from outside the process, so it is treated as untrusted. The
/// entry count is never used to preallocate more than the remaining input could
/// possibly describe.
///
/// # Errors
/// [`MlsError::OpenMls`] if the blob is truncated anywhere, or if bytes remain
/// after the last declared entry.
fn deserialize_storage(bytes: &[u8]) -> Result<HashMap<Vec<u8>, Vec<u8>>, MlsError> {
    /// Splits `len` bytes off the front of `cur`, failing if fewer remain.
    fn rd_bytes<'a>(cur: &mut &'a [u8], len: usize) -> Result<&'a [u8], MlsError> {
        if cur.len() < len {
            return Err(MlsError::OpenMls("truncated storage blob".into()));
        }
        let (head, tail) = cur.split_at(len);
        *cur = tail;
        Ok(head)
    }

    /// Reads one little-endian `u32` length prefix.
    fn rd_u32(cur: &mut &[u8]) -> Result<usize, MlsError> {
        let raw = rd_bytes(cur, 4)?;
        Ok(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]) as usize)
    }

    let mut cur = bytes;
    let count = rd_u32(&mut cur)?;
    // `count` is read from the blob itself. A corrupted value (up to 2^32) would
    // otherwise request a huge allocation and abort the process. Each entry needs
    // at least 8 bytes of length prefixes, so cap the preallocation accordingly.
    let mut map = HashMap::with_capacity(count.min(cur.len() / 8));
    for _ in 0..count {
        let klen = rd_u32(&mut cur)?;
        let k = rd_bytes(&mut cur, klen)?.to_vec();
        let vlen = rd_u32(&mut cur)?;
        let v = rd_bytes(&mut cur, vlen)?.to_vec();
        map.insert(k, v);
    }
    // `serialize_storage` never writes trailing data. Leftover bytes mean the
    // framing is corrupt.
    if !cur.is_empty() {
        return Err(MlsError::OpenMls("trailing bytes in storage blob".into()));
    }
    Ok(map)
}
impl MlsContext {
    /// Creates a new member identity that owns its own one-member group.
    ///
    /// This method:
    /// - generates a signature key pair for [`CIPHERSUITE`] and stores it in a
    ///   fresh provider's storage;
    /// - builds a basic credential over `identity`;
    /// - returns the context together with a [`KeyPackageBundle`]. Its key package
    ///   is what another member passes to [`MlsContext::invite`] /
    ///   [`MlsContext::invite_full`] to add this member.
    ///
    /// The key package is built with this context's provider, so the resulting
    /// Welcome can be consumed by [`MlsContext::accept_welcome`] on this context.
    /// It can also be consumed by a context restored from a blob that was taken
    /// via [`MlsContext::export_state`] after this call.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if any of these fail: key generation, storing the
    /// signer, building the key package, or creating the group.
    pub fn new_member(identity: &[u8]) -> Result<(Self, KeyPackageBundle), MlsError> {
        let provider = OpenMlsRustCrypto::default();
        let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())
            .map_err(|e| MlsError::OpenMls(format!("signer: {e:?}")))?;
        signer
            .store(provider.storage())
            .map_err(|e| MlsError::OpenMls(format!("store signer: {e:?}")))?;
        let credential = BasicCredential::new(identity.to_vec());
        let credential_with_key = CredentialWithKey {
            credential: credential.into(),
            signature_key: signer.public().into(),
        };
        let kp_bundle = KeyPackage::builder()
            .build(CIPHERSUITE, &provider, &signer, credential_with_key.clone())
            .map_err(|e| MlsError::OpenMls(format!("kp: {e:?}")))?;
        let cfg = MlsGroupCreateConfig::builder()
            .ciphersuite(CIPHERSUITE)
            .use_ratchet_tree_extension(true)
            .build();
        let group = MlsGroup::new(&provider, &signer, &cfg, credential_with_key.clone())
            .map_err(|e| MlsError::OpenMls(format!("group: {e:?}")))?;
        Ok((
            Self {
                provider,
                signer,
                group,
                credential: credential_with_key,
                identity: identity.to_vec(),
                pending_staged: None,
            },
            kp_bundle,
        ))
    }

    /// Refuses to start a new local commit while a received commit is staged.
    ///
    /// `finalize_pending_commit` merges the staged (received) commit first. A
    /// local commit built now targets the pre-merge epoch and could no longer be
    /// applied, so any commit/Welcome returned to the caller would be unusable.
    fn ensure_no_staged_commit(&self) -> Result<(), MlsError> {
        if self.pending_staged.is_some() {
            return Err(MlsError::TransitionInProgress);
        }
        Ok(())
    }

    /// Creates an Add commit for `key_packages` without merging it.
    ///
    /// Returns `(commit_bytes, welcome_bytes)`, both TLS-serialized. One Welcome
    /// serves every added member. The commit stays pending until
    /// [`MlsContext::finalize_pending_commit`] (or
    /// [`MlsContext::clear_pending_commit`]) is called, so the epoch does not
    /// advance yet.
    ///
    /// # Errors
    /// - [`MlsError::TransitionInProgress`] if a received commit is still staged.
    /// - [`MlsError::OpenMls`] if openmls rejects the add or serialization fails.
    pub fn invite_full(
        &mut self,
        key_packages: &[KeyPackage],
    ) -> Result<(Vec<u8>, Vec<u8>), MlsError> {
        self.ensure_no_staged_commit()?;
        let (commit, welcome, _gi) = self
            .group
            .add_members(&self.provider, &self.signer, key_packages)
            .map_err(|e| MlsError::OpenMls(format!("add_members: {e:?}")))?;
        let commit_bytes = commit
            .tls_serialize_detached()
            .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))?;
        let welcome_bytes = welcome
            .tls_serialize_detached()
            .map_err(|e| MlsError::OpenMls(format!("welcome serialize: {e:?}")))?;
        Ok((commit_bytes, welcome_bytes))
    }

    /// Adds `key_packages` to the group and immediately merges the commit.
    ///
    /// Returns only the TLS-serialized Welcome. The commit itself is discarded, so
    /// use [`MlsContext::invite_full`] when the commit must also be distributed to
    /// existing members.
    ///
    /// # Errors
    /// Same as [`MlsContext::invite_full`], plus any merge failure from
    /// [`MlsContext::finalize_pending_commit`].
    pub fn invite(&mut self, key_packages: &[KeyPackage]) -> Result<Vec<u8>, MlsError> {
        let (_commit, welcome) = self.invite_full(key_packages)?;
        self.finalize_pending_commit()?;
        Ok(welcome)
    }

    /// Creates a Remove commit for the members at `leaf_indices` without merging it.
    ///
    /// Returns the TLS-serialized commit. As with [`MlsContext::invite_full`],
    /// the commit stays pending until it is finalized or cleared.
    ///
    /// # Errors
    /// - [`MlsError::TransitionInProgress`] if a received commit is still staged.
    /// - [`MlsError::OpenMls`] if an index does not name an occupied leaf, if
    ///   openmls rejects the removal, or if serialization fails.
    pub fn remove_members(&mut self, leaf_indices: &[u32]) -> Result<Vec<u8>, MlsError> {
        self.ensure_no_staged_commit()?;
        // Leaf indices are tree positions, not ranks. After earlier removals the
        // tree can contain blank leaves, so a valid index may be >= the member
        // count, and an index below it may be blank. Validate against the leaves
        // that are actually occupied.
        let occupied: Vec<LeafNodeIndex> = self.group.members().map(|m| m.index).collect();
        let mut leaves = Vec::with_capacity(leaf_indices.len());
        for &idx in leaf_indices {
            let leaf = LeafNodeIndex::new(idx);
            if !occupied.contains(&leaf) {
                return Err(MlsError::OpenMls(format!(
                    "leaf_index {idx} is not an occupied leaf (group has {} members)",
                    occupied.len()
                )));
            }
            leaves.push(leaf);
        }
        let (commit, _welcome_opt, _gi) = self
            .group
            .remove_members(&self.provider, &self.signer, &leaves)
            .map_err(|e| MlsError::OpenMls(format!("remove_members: {e:?}")))?;
        commit
            .tls_serialize_detached()
            .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))
    }

    /// Applies whatever group transition is pending.
    ///
    /// It first merges a commit staged by [`MlsContext::process_message`] (if
    /// any), then this member's own pending commit (from `invite_full` /
    /// `remove_members`, if any). It is fine to call this when nothing is pending.
    ///
    /// If merging the staged commit leaves the group inactive (this member was
    /// removed), the own-commit merge is skipped.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if either merge fails.
    pub fn finalize_pending_commit(&mut self) -> Result<(), MlsError> {
        if let Some(staged) = self.pending_staged.take() {
            self.group
                .merge_staged_commit(&self.provider, staged)
                .map_err(|e| MlsError::OpenMls(format!("merge_staged: {e:?}")))?;
        }
        // Merging the own pending commit is only meaningful on an active group.
        // On an evicted group it is an expected no-op case, not a failure. Real
        // merge/storage errors are propagated rather than silently dropped, so
        // callers never believe the epoch advanced when it did not.
        if self.group.is_active() {
            self.group
                .merge_pending_commit(&self.provider)
                .map_err(|e| MlsError::OpenMls(format!("merge_pending: {e:?}")))?;
        }
        Ok(())
    }

    /// Discards both a staged received commit and this member's own pending commit.
    ///
    /// Calling it repeatedly is harmless.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if clearing the pending commit in storage fails.
    pub fn clear_pending_commit(&mut self) -> Result<(), MlsError> {
        self.pending_staged = None;
        self.group
            .clear_pending_commit(self.provider.storage())
            .map_err(|e| MlsError::OpenMls(format!("clear: {e:?}")))?;
        Ok(())
    }

    /// Parses and processes one TLS-serialized MLS protocol message (public or private).
    ///
    /// A commit is only staged in `pending_staged`. Call
    /// [`MlsContext::finalize_pending_commit`] to move to the next epoch.
    /// Application message payloads are not returned.
    ///
    /// # Errors
    /// - [`MlsError::OpenMls`] if the bytes are not a protocol message or
    ///   openmls rejects them.
    /// - [`MlsError::TransitionInProgress`] if a commit arrives while another is
    ///   still staged. The newly received commit is dropped in that case.
    pub fn process_message(&mut self, msg_bytes: &[u8]) -> Result<ProcessedKind, MlsError> {
        let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(msg_bytes)
            .map_err(|e| MlsError::OpenMls(format!("msg parse: {e:?}")))?;
        let protocol_msg = match msg_in.extract() {
            MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::from(m),
            MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::from(m),
            other => {
                return Err(MlsError::OpenMls(format!(
                    "expected protocol message, got {other:?}"
                )));
            }
        };
        let processed = self
            .group
            .process_message(&self.provider, protocol_msg)
            .map_err(|e| MlsError::OpenMls(format!("process: {e:?}")))?;
        match processed.into_content() {
            ProcessedMessageContent::StagedCommitMessage(staged) => {
                if self.pending_staged.is_some() {
                    return Err(MlsError::TransitionInProgress);
                }
                self.pending_staged = Some(*staged);
                Ok(ProcessedKind::Commit)
            }
            ProcessedMessageContent::ApplicationMessage(_) => Ok(ProcessedKind::Application),
            ProcessedMessageContent::ProposalMessage(_) => Ok(ProcessedKind::Proposal),
            ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(ProcessedKind::External),
        }
    }

    /// Joins the group described by a TLS-serialized Welcome, replacing `self.group`.
    ///
    /// The ratchet tree is taken from the Welcome's ratchet-tree extension.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if the bytes are not a Welcome or openmls cannot join
    /// with the key packages held in this context's storage.
    pub fn accept_welcome(&mut self, welcome_bytes: &[u8]) -> Result<(), MlsError> {
        let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(welcome_bytes)
            .map_err(|e| MlsError::OpenMls(format!("welcome parse: {e:?}")))?;
        let welcome = match msg_in.extract() {
            MlsMessageBodyIn::Welcome(w) => w,
            other => {
                return Err(MlsError::OpenMls(format!(
                    "expected welcome, got {other:?}"
                )));
            }
        };
        let join_cfg = MlsGroupJoinConfig::builder()
            .use_ratchet_tree_extension(true)
            .build();
        let staged = StagedWelcome::new_from_welcome(&self.provider, &join_cfg, welcome, None)
            .map_err(|e| MlsError::OpenMls(format!("staged: {e:?}")))?;
        self.group = staged
            .into_group(&self.provider)
            .map_err(|e| MlsError::OpenMls(format!("into_group: {e:?}")))?;
        // A commit staged for the previous group can never apply to the newly
        // joined one. Drop it so a later finalize does not try to merge it.
        self.pending_staged = None;
        Ok(())
    }

    /// Returns the current epoch number of the group.
    pub fn epoch(&self) -> u64 {
        self.group.epoch().as_u64()
    }

    /// Returns the group id as exactly 16 bytes.
    ///
    /// Shorter ids are zero-padded on the right. Longer ids are truncated, so two
    /// long ids sharing a 16-byte prefix map to the same value.
    pub fn group_id_16(&self) -> [u8; 16] {
        let raw = self.group.group_id().as_slice();
        let mut out = [0u8; 16];
        let n = raw.len().min(16);
        out[..n].copy_from_slice(&raw[..n]);
        out
    }

    /// Serializes this context so [`MlsContext::restore_state`] can rebuild it.
    ///
    /// The blob is four parts, each prefixed by its little-endian `u32` length:
    /// provider storage, TLS-serialized signer, identity bytes, group id. The
    /// blob contains private keys, so treat it as a secret.
    ///
    /// A received commit held in `pending_staged` is *not* included.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if any of these happens:
    /// - storage cannot be read;
    /// - the signer cannot be serialized;
    /// - a part is too large for a `u32` length prefix.
    pub fn export_state(&self) -> Result<Vec<u8>, MlsError> {
        let storage_buf = serialize_storage(self.provider.storage())?;
        let signer_buf = self
            .signer
            .tls_serialize_detached()
            .map_err(|e| MlsError::OpenMls(format!("signer serialize: {e:?}")))?;
        let gid = self.group.group_id().as_slice();
        let parts: [&[u8]; 4] = [
            storage_buf.as_slice(),
            signer_buf.as_slice(),
            self.identity.as_slice(),
            gid,
        ];
        let mut out = Vec::with_capacity(parts.iter().map(|p| 4 + p.len()).sum());
        for part in parts {
            // Checked: an `as u32` truncation would silently corrupt the framing.
            let len = u32::try_from(part.len())
                .map_err(|_| MlsError::OpenMls("state component too large".into()))?;
            out.extend_from_slice(&len.to_le_bytes());
            out.extend_from_slice(part);
        }
        Ok(out)
    }

    /// Rebuilds a context from a blob produced by [`MlsContext::export_state`].
    ///
    /// Steps:
    /// 1. A fresh provider's storage is replaced wholesale with the saved map.
    /// 2. The signer is parsed.
    /// 3. The credential is reconstructed from the identity and the signer's
    ///    public key.
    /// 4. The group is loaded from storage by its saved id.
    ///
    /// `pending_staged` is always `None`. Bytes after the fourth part are ignored.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if any of these happens:
    /// - the blob or its storage part is truncated or malformed;
    /// - the signer cannot be parsed;
    /// - no group with the saved id exists in the restored storage.
    pub fn restore_state(blob: &[u8]) -> Result<Self, MlsError> {
        let mut cur = blob;
        // Pops one `u32`-length-prefixed part off the front of `cur`.
        let mut take = || -> Result<&[u8], MlsError> {
            if cur.len() < 4 {
                return Err(MlsError::OpenMls("truncated state blob (length)".into()));
            }
            let len = u32::from_le_bytes([cur[0], cur[1], cur[2], cur[3]]) as usize;
            cur = &cur[4..];
            if cur.len() < len {
                return Err(MlsError::OpenMls("truncated state blob (body)".into()));
            }
            let (head, tail) = cur.split_at(len);
            cur = tail;
            Ok(head)
        };
        let storage_bytes = take()?.to_vec();
        let signer_bytes = take()?.to_vec();
        let identity = take()?.to_vec();
        let gid_bytes = take()?.to_vec();
        let provider = OpenMlsRustCrypto::default();
        let map = deserialize_storage(&storage_bytes)?;
        *provider
            .storage()
            .values
            .write()
            .map_err(|_| MlsError::OpenMls("storage lock poisoned".into()))? = map;
        let signer = SignatureKeyPair::tls_deserialize_exact_bytes(&signer_bytes)
            .map_err(|e| MlsError::OpenMls(format!("signer parse: {e:?}")))?;
        let credential = CredentialWithKey {
            credential: BasicCredential::new(identity.clone()).into(),
            signature_key: signer.public().into(),
        };
        let group_id = GroupId::from_slice(&gid_bytes);
        let group = MlsGroup::load(provider.storage(), &group_id)
            .map_err(|e| MlsError::OpenMls(format!("group load: {e:?}")))?
            .ok_or_else(|| MlsError::OpenMls("no group in restored state".into()))?;
        Ok(Self {
            provider,
            signer,
            group,
            credential,
            identity,
            pending_staged: None,
        })
    }

    /// Derives the 32-byte key for `label` from the group's exporter secret.
    ///
    /// The exporter is called with an empty context. Every member in the same
    /// epoch derives the same key, and the key is tied to the current epoch.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if the exporter fails.
    pub fn export_stream_key(&self, label: StreamLabel) -> Result<[u8; 32], MlsError> {
        let secret = self
            .group
            .export_secret(self.provider.crypto(), label.as_str(), &[], 32)
            .map_err(|e| MlsError::OpenMls(format!("export: {e:?}")))?;
        let mut out = [0u8; 32];
        out.copy_from_slice(&secret);
        Ok(out)
    }

    /// Runs the MLS exporter with an arbitrary `label` and `context`.
    ///
    /// Returns `len` bytes of exported secret.
    ///
    /// # Errors
    /// [`MlsError::OpenMls`] if the exporter fails.
    pub fn export_raw(&self, label: &str, context: &[u8], len: usize) -> Result<Vec<u8>, MlsError> {
        let secret = self
            .group
            .export_secret(self.provider.crypto(), label, context, len)
            .map_err(|e| MlsError::OpenMls(format!("export_raw: {e:?}")))?;
        Ok(secret.to_vec())
    }

    /// Builds the stream cipher for `label` and the 96-bit nonce for `seq`.
    ///
    /// The nonce is `seq` big-endian in bytes 0..4, followed by eight zero bytes.
    fn stream_cipher(
        &self,
        label: StreamLabel,
        seq: u32,
    ) -> Result<(ChaCha20Poly1305, [u8; 12]), MlsError> {
        let key = self.export_stream_key(label)?;
        let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
        let mut nonce = [0u8; 12];
        nonce[..4].copy_from_slice(&seq.to_be_bytes());
        Ok((cipher, nonce))
    }

    /// Encrypts `plaintext` with ChaCha20-Poly1305 under the current-epoch key for `label`.
    ///
    /// Returns the ciphertext followed by the 16-byte authentication tag.
    ///
    /// **Nonce uniqueness is the caller's responsibility.** The key is shared by
    /// every member in the epoch, and the nonce depends only on `seq`. A given
    /// `(label, seq)` pair must therefore never be sealed twice in one epoch by
    /// *any* member, or the AEAD nonce is reused.
    ///
    /// # Errors
    /// - [`MlsError::OpenMls`] if key export fails.
    /// - [`MlsError::Aead`] if encryption fails.
    pub fn seal(
        &self,
        label: StreamLabel,
        seq: u32,
        plaintext: &[u8],
    ) -> Result<Vec<u8>, MlsError> {
        let (cipher, nonce) = self.stream_cipher(label, seq)?;
        cipher
            .encrypt(Nonce::from_slice(&nonce), plaintext)
            .map_err(|e| MlsError::Aead(e.to_string()))
    }

    /// Decrypts and authenticates `ciphertext` produced by [`MlsContext::seal`].
    ///
    /// The ciphertext must have been sealed with the same `label` and `seq` in the
    /// same epoch.
    ///
    /// # Errors
    /// - [`MlsError::OpenMls`] if key export fails.
    /// - [`MlsError::Aead`] on authentication failure (wrong label, seq or epoch,
    ///   or tampered data).
    pub fn open(
        &self,
        label: StreamLabel,
        seq: u32,
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, MlsError> {
        let (cipher, nonce) = self.stream_cipher(label, seq)?;
        cipher
            .decrypt(Nonce::from_slice(&nonce), ciphertext)
            .map_err(|e| MlsError::Aead(e.to_string()))
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for stream labels, group lifecycle, AEAD sealing and state
    // export/restore. All scenarios run purely in memory.
    use super::*;

    fn alice() -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(b"alice").unwrap()
    }

    fn bob() -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(b"bob").unwrap()
    }

    #[test]
    fn stream_label_strings_are_correct() {
        assert_eq!(StreamLabel::Control.as_str(), "gbp/control");
        assert_eq!(StreamLabel::Audio.as_str(), "gbp/audio");
        assert_eq!(StreamLabel::Text.as_str(), "gbp/text");
        assert_eq!(StreamLabel::Signal.as_str(), "gbp/signal");
    }

    #[test]
    fn label_for_maps_every_stream_type() {
        assert_eq!(label_for(StreamType::Control), StreamLabel::Control);
        assert_eq!(label_for(StreamType::Audio), StreamLabel::Audio);
        assert_eq!(label_for(StreamType::Text), StreamLabel::Text);
        assert_eq!(label_for(StreamType::Signal), StreamLabel::Signal);
    }

    #[test]
    fn new_member_starts_at_epoch_zero() {
        let (ctx, _kp) = alice();
        assert_eq!(ctx.epoch(), 0);
    }

    // The return type already guarantees 16 bytes; check the content instead:
    // a prefix of the real group id, zero-padded if the id is shorter.
    #[test]
    fn group_id_16_is_zero_padded_prefix_of_group_id() {
        let (ctx, _kp) = alice();
        let raw = ctx.group.group_id().as_slice();
        let id = ctx.group_id_16();
        let n = raw.len().min(16);
        assert_eq!(&id[..n], &raw[..n]);
        assert!(id[n..].iter().all(|&b| b == 0));
    }

    #[test]
    fn export_stream_key_is_32_bytes_and_stable() {
        let (ctx, _kp) = alice();
        let k1 = ctx.export_stream_key(StreamLabel::Text).unwrap();
        let k2 = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_eq!(k1.len(), 32);
        assert_eq!(k1, k2);
    }

    #[test]
    fn different_labels_produce_different_keys() {
        let (ctx, _kp) = alice();
        let k_ctrl = ctx.export_stream_key(StreamLabel::Control).unwrap();
        let k_text = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_ne!(k_ctrl, k_text);
    }

    #[test]
    fn seal_open_single_member_round_trip() {
        let (ctx, _kp) = alice();
        let plaintext = b"hello world";
        let ciphertext = ctx.seal(StreamLabel::Text, 1, plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);
        let recovered = ctx.open(StreamLabel::Text, 1, &ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn seal_appends_16_byte_tag() {
        let (ctx, _kp) = alice();
        let plaintext = b"tag length";
        let ciphertext = ctx.seal(StreamLabel::Signal, 9, plaintext).unwrap();
        assert_eq!(ciphertext.len(), plaintext.len() + 16);
    }

    #[test]
    fn seal_wrong_seq_fails_to_open() {
        let (ctx, _kp) = alice();
        let ciphertext = ctx.seal(StreamLabel::Text, 1, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Text, 2, &ciphertext).is_err());
    }

    #[test]
    fn seal_wrong_label_fails_to_open() {
        let (ctx, _kp) = alice();
        let ciphertext = ctx.seal(StreamLabel::Text, 0, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Audio, 0, &ciphertext).is_err());
    }

    #[test]
    fn two_member_invite_and_welcome() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        assert_eq!(alice.epoch(), 1);
        bob.accept_welcome(&welcome).unwrap();
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn two_member_seal_open_cross_member() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        let plaintext = b"cross-member secret";
        let ct = alice.seal(StreamLabel::Control, 0, plaintext).unwrap();
        let recovered = bob.open(StreamLabel::Control, 0, &ct).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn export_raw_returns_requested_length() {
        let (ctx, _kp) = alice();
        let raw = ctx.export_raw("test/label", b"ctx", 48).unwrap();
        assert_eq!(raw.len(), 48);
    }

    #[test]
    fn clear_pending_commit_is_idempotent() {
        let (mut ctx, _kp) = alice();
        ctx.clear_pending_commit().unwrap();
        ctx.clear_pending_commit().unwrap();
    }

    #[test]
    fn clear_pending_commit_discards_local_invite() {
        let (mut alice, _akp) = alice();
        let (_bob, bob_kp) = bob();
        alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        alice.clear_pending_commit().unwrap();
        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 0);
    }

    #[test]
    fn finalize_pending_commit_on_fresh_group_is_ok() {
        let (mut ctx, _kp) = alice();
        ctx.finalize_pending_commit().unwrap();
    }

    #[test]
    fn invite_full_does_not_advance_epoch_until_finalize() {
        let (mut alice, _akp) = alice();
        let (_bob, bob_kp) = bob();
        let (_commit, _welcome) = alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        assert_eq!(alice.epoch(), 0);
        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 1);
    }

    #[test]
    fn invite_full_welcome_joins_after_finalize() {
        let (mut alice2, _akp2) = MlsContext::new_member(b"alice2").unwrap();
        let (mut bob2, bob2_kp) = MlsContext::new_member(b"bob2").unwrap();
        let (_commit_bytes, welcome_bytes) = alice2
            .invite_full(&[bob2_kp.key_package().clone()])
            .unwrap();
        alice2.finalize_pending_commit().unwrap();
        bob2.accept_welcome(&welcome_bytes).unwrap();
        assert_eq!(alice2.epoch(), 1);
        assert_eq!(bob2.epoch(), 1);
    }

    #[test]
    fn remove_members_rejects_unknown_leaf() {
        let (mut alice, _akp) = alice();
        assert!(alice.remove_members(&[5]).is_err());
        assert_eq!(alice.epoch(), 0);
    }

    #[test]
    fn remove_member_advances_epoch_after_finalize() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        // Alice created the group at leaf 0, so Bob was added at leaf 1.
        let _commit = alice.remove_members(&[1]).unwrap();
        assert_eq!(alice.epoch(), 1);
        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 2);
        assert_eq!(alice.group.members().count(), 1);
    }

    #[test]
    fn process_message_rejects_non_protocol_message() {
        let (mut alice, _akp) = alice();
        let (_bob, bob_kp) = bob();
        let (mut carol, _ckp) = MlsContext::new_member(b"carol").unwrap();
        let (_commit, welcome) = alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        assert!(carol.process_message(&welcome).is_err());
    }

    #[test]
    fn export_restore_round_trip_preserves_state() {
        let (ctx, _kp) = alice();
        let blob = ctx.export_state().unwrap();
        let restored = MlsContext::restore_state(&blob).unwrap();
        assert_eq!(restored.epoch(), ctx.epoch());
        assert_eq!(restored.group_id_16(), ctx.group_id_16());
        assert_eq!(
            restored.export_stream_key(StreamLabel::Text).unwrap(),
            ctx.export_stream_key(StreamLabel::Text).unwrap()
        );
    }

    #[test]
    fn restored_context_can_seal_and_open() {
        let (ctx, _kp) = alice();
        let blob = ctx.export_state().unwrap();
        let restored = MlsContext::restore_state(&blob).unwrap();
        let ct = restored.seal(StreamLabel::Text, 7, b"after restore").unwrap();
        assert_eq!(restored.open(StreamLabel::Text, 7, &ct).unwrap(), b"after restore");
    }

    #[test]
    fn export_restore_preserves_multi_member_group() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        assert_eq!(alice.epoch(), 1);
        let blob = alice.export_state().unwrap();
        let restored_alice = MlsContext::restore_state(&blob).unwrap();
        assert_eq!(restored_alice.epoch(), 1);
        let ct = restored_alice.seal(StreamLabel::Control, 3, b"still in group").unwrap();
        assert_eq!(bob.open(StreamLabel::Control, 3, &ct).unwrap(), b"still in group");
    }

    #[test]
    fn multi_member_invite_one_welcome_serves_all_joiners() {
        let (mut alice, _a) = alice();
        let (mut bob, bob_kp) = bob();
        let (mut carol, carol_kp) = MlsContext::new_member(b"carol").unwrap();
        let welcome = alice
            .invite(&[bob_kp.key_package().clone(), carol_kp.key_package().clone()])
            .unwrap();
        assert_eq!(alice.epoch(), 1, "one Add commit advances the epoch once");
        bob.accept_welcome(&welcome).unwrap();
        carol.accept_welcome(&welcome).unwrap();
        assert_eq!(bob.epoch(), 1);
        assert_eq!(carol.epoch(), 1);
        let ct = alice.seal(StreamLabel::Text, 1, b"hello group").unwrap();
        assert_eq!(bob.open(StreamLabel::Text, 1, &ct).unwrap(), b"hello group");
        assert_eq!(carol.open(StreamLabel::Text, 1, &ct).unwrap(), b"hello group");
    }

    #[test]
    fn restored_prekey_accepts_welcome() {
        let (mut alice, _akp) = alice();
        let (bob, bob_kp) = bob();
        let bob_blob = bob.export_state().unwrap();
        let bob_kp_inner = bob_kp.key_package().clone();
        drop(bob);
        let welcome = alice.invite(&[bob_kp_inner]).unwrap();
        assert_eq!(alice.epoch(), 1);
        let mut bob_restored = MlsContext::restore_state(&bob_blob).unwrap();
        bob_restored.accept_welcome(&welcome).unwrap();
        assert_eq!(bob_restored.epoch(), 1);
        let ct = alice.seal(StreamLabel::Text, 1, b"after reload").unwrap();
        assert_eq!(bob_restored.open(StreamLabel::Text, 1, &ct).unwrap(), b"after reload");
    }

    #[test]
    fn restore_state_rejects_truncated_blob() {
        let (ctx, _kp) = alice();
        let blob = ctx.export_state().unwrap();
        assert!(MlsContext::restore_state(&blob[..blob.len() / 2]).is_err());
        assert!(MlsContext::restore_state(&[]).is_err());
    }
}