use openmls::{
framing::{MlsMessageOut, ProcessedMessageContent},
group::{MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig},
prelude::{
tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerialize},
BasicCredential, Ciphersuite, CredentialWithKey, Extension, Extensions, MlsMessageBodyIn,
MlsMessageIn, ProcessedMessage, ProtocolMessage, ProtocolVersion, UnknownExtension,
},
};
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::OpenMlsProvider;
use ping_mls_store::PersistentMlsProvider;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::Arc;
use ulid::Ulid;
use zeroize::Zeroizing;
use crate::{
clock::Hlc,
codec,
device::{DeviceId, GroupSnapshotEntry, GroupStateSnapshot, GROUP_SNAPSHOT_VERSION},
error::{Error, Result},
identity::UserId,
message::{IncomingMessage, MessageEnvelope, MessageKind},
storage::Storage,
sync::SyncCursor,
};
/// Ciphersuite used when this device creates a new group.
const DEFAULT_CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
/// Extension type under which the group name is stored, as UTF-8 bytes, in the
/// group-context extensions.
/// NOTE(review): 0xFF00 appears to be taken from the private-use range
/// (0xF000-0xFFFF) of the RFC 9420 extension-type registry — confirm.
const GROUP_NAME_EXTENSION_TYPE: u16 = 0xFF00;
/// Reads the group name out of a set of group-context extensions.
///
/// Scans for `Unknown` extensions tagged `GROUP_NAME_EXTENSION_TYPE` and returns
/// the first one whose payload is valid, non-empty UTF-8. Returns `None` when no
/// such extension exists.
fn group_name_from_extensions(extensions: &Extensions) -> Option<String> {
    extensions
        .iter()
        .filter_map(|extension| match extension {
            Extension::Unknown(kind, payload) if *kind == GROUP_NAME_EXTENSION_TYPE => {
                Some(payload)
            }
            _ => None,
        })
        // Payloads that are not UTF-8 are skipped rather than ending the search.
        .filter_map(|payload| std::str::from_utf8(&payload.0).ok())
        .find(|name| !name.is_empty())
        .map(str::to_owned)
}
/// Builds the group-context extensions that carry an optional group name.
///
/// A present, non-empty name becomes a single `Unknown` extension of type
/// `GROUP_NAME_EXTENSION_TYPE` holding the name's UTF-8 bytes. `None` or an
/// empty string yields no extensions at all.
fn group_context_extensions_for_name(name: Option<&str>) -> Extensions {
    name.filter(|label| !label.is_empty())
        .map(|label| {
            let payload = UnknownExtension(label.as_bytes().to_vec());
            Extensions::single(Extension::Unknown(GROUP_NAME_EXTENSION_TYPE, payload))
        })
        .unwrap_or_else(Extensions::empty)
}
/// Opaque 16-byte conversation identifier.
///
/// The same bytes are used as the MLS group id. Serialization goes through
/// `serde_bytes_array16`, so the id is encoded as a byte string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ConversationId(#[serde(with = "serde_bytes_array16")] pub [u8; 16]);

impl ConversationId {
    /// Generates a fresh identifier from a newly minted ULID.
    pub fn new() -> Self {
        let ulid = Ulid::new();
        ConversationId(ulid.to_bytes())
    }

    /// Hex rendering of the 16 id bytes, as produced by `hex::encode`.
    pub fn as_hex(&self) -> String {
        let ConversationId(bytes) = self;
        hex::encode(bytes)
    }
}

impl Default for ConversationId {
    /// The default id is freshly generated via `ConversationId::new`, not all-zero.
    fn default() -> Self {
        ConversationId::new()
    }
}
mod serde_bytes_array16 {
use serde::{Deserializer, Serializer};
pub fn serialize<S: Serializer>(b: &[u8; 16], s: S) -> Result<S::Ok, S::Error> {
serde_bytes::serialize(b.as_slice(), s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<[u8; 16], D::Error> {
let v: Vec<u8> = serde_bytes::deserialize(d)?;
v.try_into()
.map_err(|_| serde::de::Error::custom("expected 16 bytes"))
}
}
/// Serializable summary of a conversation, persisted next to the MLS state
/// (encoded by `Conversation::snapshot_inputs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMeta {
    /// Conversation id (same bytes as the MLS group id).
    pub id: ConversationId,
    /// Display name, if any. On join it is read from the group-context name extension.
    pub name: Option<String>,
    /// MLS epoch as of the last metadata refresh.
    pub epoch: u64,
    /// Number of group members as of the last metadata refresh.
    pub member_count: u32,
    /// Whether this conversation is a device group.
    /// NOTE(review): every constructor in this file sets `false`; presumably set elsewhere.
    pub is_device_group: bool,
    /// Local wall-clock time (ms) at which this device created or joined the conversation.
    pub created_at_ms: u64,
}
/// Public view of one group member, as returned by `Conversation::members`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemberInfo {
    /// User id taken from the member's `BasicCredential` identity bytes.
    pub user_id: UserId,
    /// The member's MLS leaf index.
    pub leaf_index: u32,
}
/// One MLS-backed conversation held by this device.
///
/// Wraps the openmls group together with this crate's bookkeeping: an outgoing
/// sequence number, a clock value, the sync cursor and a device-to-leaf map.
pub struct Conversation {
    /// Conversation id; its bytes are the MLS group id.
    pub(crate) id: ConversationId,
    /// Cached metadata; epoch and member count are refreshed after commits merge.
    pub(crate) meta: ConversationMeta,
    /// The underlying openmls group state.
    pub(crate) group: MlsGroup,
    /// Provider supplying crypto operations and the persistent MLS storage.
    pub(crate) crypto: Arc<PersistentMlsProvider>,
    /// This device's MLS signature key pair.
    pub(crate) signing: Arc<SignatureKeyPair>,
    /// This device's id, used as the sender on outgoing envelopes.
    pub(crate) own_device: DeviceId,
    /// Sequence number of the last envelope this device emitted.
    pub(crate) seq: u64,
    /// Clock value ticked for every outgoing envelope.
    pub(crate) hlc: Hlc,
    /// Tracks handled (epoch, device, seq) envelopes so duplicates can be skipped.
    pub(crate) cursor: SyncCursor,
    /// Key-value store the conversation snapshot is written to.
    pub(crate) storage: Arc<dyn Storage>,
    /// Leaf index per device. Holds this device plus devices added through this
    /// instance; NOTE(review): devices added by other members are not recorded.
    pub(crate) device_leaves: BTreeMap<DeviceId, u32>,
}
impl std::fmt::Debug for Conversation {
    /// Prints only the hex id and the metadata; group state, keys and storage
    /// handles are omitted from debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex_id = self.id.as_hex();
        let mut out = f.debug_struct("Conversation");
        out.field("id", &hex_id);
        out.field("meta", &self.meta);
        out.finish()
    }
}
impl Conversation {
pub fn id(&self) -> ConversationId {
self.id
}
pub fn meta(&self) -> &ConversationMeta {
&self.meta
}
pub fn members(&self) -> Vec<MemberInfo> {
self.group
.members()
.filter_map(|m| {
let basic = BasicCredential::try_from(m.credential).ok()?;
Some(MemberInfo {
user_id: UserId(basic.identity().to_vec()),
leaf_index: m.index.u32(),
})
})
.collect()
}
pub fn epoch(&self) -> u64 {
self.group.epoch().as_u64()
}
pub fn cursor(&self) -> &SyncCursor {
&self.cursor
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn create(
id: ConversationId,
name: Option<String>,
own_device: DeviceId,
own_user: &UserId,
crypto: Arc<PersistentMlsProvider>,
signing: Arc<SignatureKeyPair>,
storage: Arc<dyn Storage>,
now_ms: u64,
) -> Result<Self> {
let credential = BasicCredential::new(own_user.0.clone());
let credential_with_key = CredentialWithKey {
credential: credential.into(),
signature_key: signing.public().into(),
};
let cfg = MlsGroupCreateConfig::builder()
.ciphersuite(DEFAULT_CIPHERSUITE)
.use_ratchet_tree_extension(true)
.with_group_context_extensions(group_context_extensions_for_name(name.as_deref()))
.map_err(Error::mls)?
.build();
let group = MlsGroup::new_with_group_id(
crypto.as_ref(),
signing.as_ref(),
&cfg,
openmls::group::GroupId::from_slice(&id.0),
credential_with_key,
)
.map_err(Error::mls)?;
let meta = ConversationMeta {
id,
name,
epoch: 0,
member_count: 1,
is_device_group: false,
created_at_ms: now_ms,
};
let mut device_leaves = BTreeMap::new();
device_leaves.insert(own_device.clone(), group.own_leaf_index().u32());
Ok(Self {
id,
meta,
group,
crypto,
signing,
own_device,
seq: 0,
hlc: Hlc::ZERO.tick(now_ms),
cursor: SyncCursor::default(),
storage,
device_leaves,
})
}
pub(crate) fn join(
welcome_bytes: &[u8],
own_device: DeviceId,
crypto: Arc<PersistentMlsProvider>,
signing: Arc<SignatureKeyPair>,
storage: Arc<dyn Storage>,
now_ms: u64,
) -> Result<Self> {
let mls_in = MlsMessageIn::tls_deserialize_exact(welcome_bytes).map_err(Error::mls)?;
let welcome = match mls_in.extract() {
MlsMessageBodyIn::Welcome(w) => w,
_ => return Err(Error::Invalid("expected Welcome".into())),
};
let cfg = MlsGroupJoinConfig::builder()
.use_ratchet_tree_extension(true)
.build();
let staged =
openmls::group::StagedWelcome::new_from_welcome(crypto.as_ref(), &cfg, welcome, None)
.map_err(Error::mls)?;
let group = staged.into_group(crypto.as_ref()).map_err(Error::mls)?;
let id_bytes: [u8; 16] = group
.group_id()
.as_slice()
.try_into()
.map_err(|_| Error::Invalid("group id must be 16 bytes".into()))?;
let id = ConversationId(id_bytes);
let name = group_name_from_extensions(group.extensions());
let meta = ConversationMeta {
id,
name,
epoch: group.epoch().as_u64(),
member_count: group.members().count() as u32,
is_device_group: false,
created_at_ms: now_ms,
};
let join_epoch = group.epoch().as_u64();
let own_leaf = group.own_leaf_index().u32();
let mut device_leaves = BTreeMap::new();
device_leaves.insert(own_device.clone(), own_leaf);
Ok(Self {
id,
meta,
group,
crypto,
signing,
own_device,
seq: 0,
hlc: Hlc::ZERO.tick(now_ms),
cursor: SyncCursor {
epoch: join_epoch,
..Default::default()
},
storage,
device_leaves,
})
}
pub(crate) fn name_from_group_state(&self) -> Option<String> {
group_name_from_extensions(self.group.extensions())
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn load(
id: ConversationId,
meta: ConversationMeta,
cursor: SyncCursor,
device_leaves: BTreeMap<DeviceId, u32>,
own_device: DeviceId,
crypto: Arc<PersistentMlsProvider>,
signing: Arc<SignatureKeyPair>,
storage: Arc<dyn Storage>,
now_ms: u64,
) -> Result<Option<Self>> {
use openmls::group::GroupId;
let group_id = GroupId::from_slice(&id.0);
let group = match MlsGroup::load(crypto.storage(), &group_id).map_err(Error::mls)? {
Some(g) => g,
None => return Ok(None),
};
let seq = cursor
.last_seq_per_device
.get(&own_device)
.copied()
.unwrap_or(0);
Ok(Some(Self {
id,
meta,
group,
crypto,
signing,
own_device,
seq,
hlc: Hlc::ZERO.tick(now_ms),
cursor,
storage,
device_leaves,
}))
}
pub fn send_application(&mut self, plaintext: &[u8], now_ms: u64) -> Result<MessageEnvelope> {
let out = self
.group
.create_message(self.crypto.as_ref(), self.signing.as_ref(), plaintext)
.map_err(Error::mls)?;
self.seq += 1;
self.hlc = self.hlc.tick(now_ms);
let bytes = out.tls_serialize_detached().map_err(Error::mls)?;
let env = MessageEnvelope::new_application(
self.id,
self.epoch(),
self.own_device.clone(),
self.seq,
self.hlc,
bytes,
plaintext,
);
self.cursor.advance(
env.epoch,
self.own_device.clone(),
self.seq,
self.hlc,
now_ms,
);
Ok(env)
}
pub fn add_members(
&mut self,
entries: Vec<(DeviceId, Vec<u8>)>,
now_ms: u64,
) -> Result<AddOutcome> {
let staged = self.stage_add_members(entries, now_ms)?;
self.confirm_staged(&staged, now_ms)?;
let StagedCommit {
commit, welcome, ..
} = staged;
let welcome =
welcome.ok_or_else(|| Error::Invalid("add_members produced no Welcome".into()))?;
Ok(AddOutcome { commit, welcome })
}
pub(crate) fn stage_add_members(
&mut self,
entries: Vec<(DeviceId, Vec<u8>)>,
now_ms: u64,
) -> Result<StagedCommit> {
let mut kps = Vec::with_capacity(entries.len());
let mut sig_to_device: Vec<(Vec<u8>, DeviceId)> = Vec::with_capacity(entries.len());
for (device_id, raw) in &entries {
let mls_in = MlsMessageIn::tls_deserialize_exact(raw).map_err(Error::mls)?;
let kp_in = match mls_in.extract() {
MlsMessageBodyIn::KeyPackage(kp) => kp,
_ => return Err(Error::Invalid("expected KeyPackage".into())),
};
let kp = kp_in
.validate(self.crypto.crypto(), ProtocolVersion::default())
.map_err(Error::mls)?;
let sig_key = kp.leaf_node().signature_key().as_slice().to_vec();
sig_to_device.push((sig_key, device_id.clone()));
kps.push(kp);
}
let pre_commit_epoch = self.epoch();
let post_commit_epoch = pre_commit_epoch + 1;
let (commit_out, welcome_out, _gi) = self
.group
.add_members(self.crypto.as_ref(), self.signing.as_ref(), &kps)
.map_err(Error::mls)?;
let next_seq = self.seq + 1;
let next_hlc = self.hlc.tick(now_ms);
let commit_bytes = mls_message_out_bytes(commit_out)?;
let commit_env = MessageEnvelope::new(
self.id,
pre_commit_epoch,
MessageKind::Commit,
self.own_device.clone(),
next_seq,
next_hlc,
commit_bytes,
);
let welcome_bytes = mls_message_out_bytes(welcome_out)?;
let welcome_env = MessageEnvelope::new(
self.id,
post_commit_epoch,
MessageKind::Welcome,
self.own_device.clone(),
next_seq,
next_hlc,
welcome_bytes,
);
Ok(StagedCommit {
commit: commit_env,
welcome: Some(welcome_env),
next_seq,
next_hlc,
leaf_update: StagedLeafUpdate::Add(sig_to_device),
})
}
pub fn remove_members(
&mut self,
leaf_indexes: Vec<u32>,
now_ms: u64,
) -> Result<MessageEnvelope> {
let staged = self.stage_remove_members(leaf_indexes, now_ms)?;
self.confirm_staged(&staged, now_ms)?;
let StagedCommit { commit, .. } = staged;
Ok(commit)
}
pub(crate) fn stage_remove_members(
&mut self,
leaf_indexes: Vec<u32>,
now_ms: u64,
) -> Result<StagedCommit> {
use openmls::prelude::LeafNodeIndex;
let leaves: Vec<LeafNodeIndex> = leaf_indexes
.iter()
.copied()
.map(LeafNodeIndex::new)
.collect();
let pre_commit_epoch = self.epoch();
let (commit_out, _welcome_opt, _gi) = self
.group
.remove_members(self.crypto.as_ref(), self.signing.as_ref(), &leaves)
.map_err(Error::mls)?;
let next_seq = self.seq + 1;
let next_hlc = self.hlc.tick(now_ms);
let bytes = mls_message_out_bytes(commit_out)?;
let commit_env = MessageEnvelope::new(
self.id,
pre_commit_epoch,
MessageKind::Commit,
self.own_device.clone(),
next_seq,
next_hlc,
bytes,
);
let removed: std::collections::HashSet<u32> = leaf_indexes.iter().copied().collect();
Ok(StagedCommit {
commit: commit_env,
welcome: None,
next_seq,
next_hlc,
leaf_update: StagedLeafUpdate::Remove(removed),
})
}
pub(crate) fn confirm_staged(&mut self, staged: &StagedCommit, now_ms: u64) -> Result<()> {
self.group
.merge_pending_commit(self.crypto.as_ref())
.map_err(Error::mls)?;
self.meta.epoch = self.epoch();
self.meta.member_count = self.group.members().count() as u32;
match &staged.leaf_update {
StagedLeafUpdate::Add(sig_to_device) => {
for member in self.group.members() {
if let Some((_, device_id)) = sig_to_device
.iter()
.find(|(sig, _)| sig.as_slice() == member.signature_key.as_slice())
{
self.device_leaves
.insert(device_id.clone(), member.index.u32());
}
}
}
StagedLeafUpdate::Remove(removed) => {
self.device_leaves.retain(|_, idx| !removed.contains(idx));
}
}
self.seq = staged.next_seq;
self.hlc = staged.next_hlc;
self.cursor.advance(
self.meta.epoch,
self.own_device.clone(),
self.seq,
self.hlc,
now_ms,
);
Ok(())
}
pub(crate) fn abort_staged(&mut self) -> Result<()> {
self.group
.clear_pending_commit(self.crypto.storage())
.map_err(Error::mls)?;
Ok(())
}
pub fn process(
&mut self,
env: &MessageEnvelope,
now_ms: u64,
) -> Result<Option<IncomingMessage>> {
if !self.cursor.is_new(env.epoch, &env.sender_device, env.seq) {
return Ok(None); }
let mls_in = MlsMessageIn::tls_deserialize_exact(&env.payload).map_err(Error::mls)?;
let protocol_msg: ProtocolMessage = match mls_in.extract() {
MlsMessageBodyIn::PrivateMessage(m) => m.into(),
MlsMessageBodyIn::PublicMessage(m) => m.into(),
MlsMessageBodyIn::Welcome(_) => {
return Err(Error::Invalid(
"Welcome must be handled at client level, not in-group".into(),
));
}
_ => return Err(Error::Invalid("unsupported MLS message body".into())),
};
let processed: ProcessedMessage = self
.group
.process_message(self.crypto.as_ref(), protocol_msg)
.map_err(Error::mls)?;
let sender_user_id = BasicCredential::try_from(processed.credential().clone())
.map(|c| UserId(c.identity().to_vec()))
.unwrap_or_else(|_| UserId(Vec::new()));
let out = match processed.into_content() {
ProcessedMessageContent::ApplicationMessage(app) => {
let pt = app.into_bytes();
if env.v >= 2 {
let computed = crate::message::hash_application_plaintext(&pt);
if computed != env.content_hash {
return Err(Error::Invalid(
"v=2 application content_hash mismatch".into(),
));
}
}
Some(IncomingMessage {
conversation_id: self.id,
sender_device: env.sender_device.clone(),
sender_user_id,
epoch: env.epoch,
hlc: env.hlc,
plaintext: pt,
content_hash: env.content_hash,
})
}
ProcessedMessageContent::StagedCommitMessage(staged) => {
self.group
.merge_staged_commit(self.crypto.as_ref(), *staged)
.map_err(Error::mls)?;
self.meta.epoch = self.epoch();
self.meta.member_count = self.group.members().count() as u32;
None
}
ProcessedMessageContent::ProposalMessage(_)
| ProcessedMessageContent::ExternalJoinProposalMessage(_) => {
None
}
};
self.cursor.advance(
env.epoch,
env.sender_device.clone(),
env.seq,
env.hlc,
now_ms,
);
Ok(out)
}
pub fn export_secret(
&self,
label: &str,
context: &[u8],
length: usize,
) -> Result<Zeroizing<Vec<u8>>> {
if length == 0 {
return Err(Error::Invalid("export_secret length must be > 0".into()));
}
if length > 1024 {
return Err(Error::Invalid(
"export_secret length exceeds 1024-byte cap".into(),
));
}
let bytes = self
.group
.export_secret(self.crypto.as_ref(), label, context, length)
.map_err(Error::mls)?;
Ok(Zeroizing::new(bytes))
}
pub fn export_state_snapshot(&self, now_ms: u64) -> Result<Zeroizing<Vec<u8>>> {
let entries = self.crypto.group_scoped_entries(&self.id.0);
let snap = GroupStateSnapshot {
v: GROUP_SNAPSHOT_VERSION,
group_id: self.id,
openmls_storage_version: openmls_traits::storage::CURRENT_VERSION,
snapshot_created_at_ms: now_ms,
entries: entries
.into_iter()
.map(|(key, value)| GroupSnapshotEntry { key, value })
.collect(),
};
Ok(Zeroizing::new(snap.encode()?))
}
pub fn leaf_index_of(&self, device_id: &DeviceId) -> Option<u32> {
self.device_leaves.get(device_id).copied()
}
pub(crate) fn snapshot_inputs(&self) -> Result<ConversationSnapshot> {
let leaves_vec: Vec<(DeviceId, u32)> = self
.device_leaves
.iter()
.map(|(d, i)| (d.clone(), *i))
.collect();
Ok(ConversationSnapshot {
id: self.id,
crypto: self.crypto.clone(),
storage: self.storage.clone(),
cursor: self.cursor.encode()?,
meta: codec::encode(&self.meta)?,
device_leaves: codec::encode(&leaves_vec)?,
})
}
pub(crate) async fn snapshot_to_storage(&self) -> Result<()> {
self.snapshot_inputs()?.flush().await
}
}
/// Owned, pre-encoded copy of the state persisted by `Conversation::snapshot_to_storage`.
///
/// Presumably owned so the async write does not need to borrow the conversation.
pub(crate) struct ConversationSnapshot {
    /// Conversation id; its hex form is the storage key.
    id: ConversationId,
    /// MLS provider, checkpointed first when flushing.
    crypto: Arc<PersistentMlsProvider>,
    /// Destination key-value store.
    storage: Arc<dyn Storage>,
    /// `SyncCursor::encode` output.
    cursor: Vec<u8>,
    /// `codec::encode` of the `ConversationMeta`.
    meta: Vec<u8>,
    /// `codec::encode` of the device-to-leaf pairs as `Vec<(DeviceId, u32)>`.
    device_leaves: Vec<u8>,
}
impl ConversationSnapshot {
pub(crate) async fn flush(self) -> Result<()> {
self.crypto
.checkpoint_async()
.await
.map_err(|e| Error::Storage(format!("checkpoint: {e}")))?;
let hex = self.id.as_hex();
self.storage.put("cursors", &hex, self.cursor).await?;
self.storage
.put("groups", &format!("{hex}/meta"), self.meta)
.await?;
self.storage
.put("device_leaves", &hex, self.device_leaves)
.await?;
Ok(())
}
}
/// Result of `Conversation::add_members`: the commit envelope and the Welcome
/// envelope for the added members.
#[derive(Debug, Clone)]
pub struct AddOutcome {
    /// Commit envelope, stamped with the epoch in which the commit was created.
    pub commit: MessageEnvelope,
    /// Welcome envelope, stamped with the epoch that follows the commit.
    pub welcome: MessageEnvelope,
}
/// Pending change to `Conversation::device_leaves`, applied by `confirm_staged`.
pub(crate) enum StagedLeafUpdate {
    /// New devices, each keyed by the signature public-key bytes from its
    /// KeyPackage leaf node. Matched against group members after the merge.
    Add(Vec<(Vec<u8>, DeviceId)>),
    /// Leaf indices being removed. Device mappings pointing at them are dropped.
    Remove(std::collections::HashSet<u32>),
}
/// A commit created against the local MLS group but not yet merged.
///
/// Produced by `stage_add_members` / `stage_remove_members`. Pass it to
/// `Conversation::confirm_staged` to merge, or call `Conversation::abort_staged`
/// to discard the group's pending commit.
pub(crate) struct StagedCommit {
    /// Envelope carrying the commit, stamped with the pre-commit epoch.
    pub commit: MessageEnvelope,
    /// Envelope carrying the Welcome (`Some` only for member additions).
    pub welcome: Option<MessageEnvelope>,
    /// Sequence number the conversation adopts once the commit is confirmed.
    next_seq: u64,
    /// Clock value the conversation adopts once the commit is confirmed.
    next_hlc: Hlc,
    /// How `device_leaves` changes once the commit is confirmed.
    leaf_update: StagedLeafUpdate,
}
fn mls_message_out_bytes(m: MlsMessageOut) -> Result<Vec<u8>> {
m.tls_serialize_detached().map_err(Error::mls)
}