use openmls::{
framing::{MlsMessageOut, ProcessedMessageContent},
group::{MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig},
prelude::{
tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerialize},
BasicCredential, Ciphersuite, CredentialWithKey, MlsMessageBodyIn, MlsMessageIn,
ProcessedMessage, ProtocolMessage, ProtocolVersion,
},
};
use openmls_basic_credential::SignatureKeyPair;
use openmls_rust_crypto::OpenMlsRustCrypto;
use openmls_traits::OpenMlsProvider;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ulid::Ulid;
use crate::{
clock::Hlc,
codec,
device::DeviceId,
error::{Error, Result},
identity::UserId,
message::{IncomingMessage, MessageEnvelope, MessageKind},
storage::Storage,
sync::SyncCursor,
};
/// Ciphersuite for groups created by [`Conversation::create`]:
/// DHKEM(X25519) key exchange, AES-128-GCM, SHA-256 and Ed25519 signatures.
const DEFAULT_CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
/// Identifier of a conversation.
///
/// The same 16 bytes are used as the MLS group id (see `Conversation::create` and
/// `Conversation::join`). Serialized through `serde_bytes_array16` as a byte string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ConversationId(#[serde(with = "serde_bytes_array16")] pub [u8; 16]);
impl ConversationId {
pub fn new() -> Self {
Self(Ulid::new().to_bytes())
}
pub fn as_hex(&self) -> String {
hex::encode(self.0)
}
}
impl Default for ConversationId {
fn default() -> Self {
Self::new()
}
}
mod serde_bytes_array16 {
use serde::{Deserializer, Serializer};
pub fn serialize<S: Serializer>(b: &[u8; 16], s: S) -> Result<S::Ok, S::Error> {
serde_bytes::serialize(b.as_slice(), s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<[u8; 16], D::Error> {
let v: Vec<u8> = serde_bytes::deserialize(d)?;
v.try_into()
.map_err(|_| serde::de::Error::custom("expected 16 bytes"))
}
}
/// Summary of a conversation; encoded and written to storage by
/// `Conversation::snapshot_to_storage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMeta {
    /// Conversation id (also the MLS group id).
    pub id: ConversationId,
    /// Optional display name; set by `Conversation::create`, `None` after `Conversation::join`.
    pub name: Option<String>,
    /// MLS epoch as last recorded: set on create/join and after each merged commit.
    pub epoch: u64,
    /// Number of MLS group members, recorded at the same points as `epoch`.
    pub member_count: u32,
    /// Device-group flag; both constructors in this module set it to `false`.
    /// NOTE(review): meaning not defined here — presumably a group of one user's devices; confirm.
    pub is_device_group: bool,
    /// The `now_ms` passed when this object was created or joined. For a joined group this is
    /// the local join time, not the group's original creation time.
    pub created_at_ms: u64,
}
/// One MLS group plus the local bookkeeping needed to send, receive and persist messages.
pub struct Conversation {
    /// Same value as `meta.id`; the MLS group id is built from these bytes.
    pub(crate) id: ConversationId,
    /// Cached summary, refreshed after commits are merged.
    pub(crate) meta: ConversationMeta,
    /// The underlying openmls group.
    pub(crate) group: MlsGroup,
    /// openmls provider passed to every group operation.
    pub(crate) crypto: Arc<OpenMlsRustCrypto>,
    /// Signer for outgoing MLS messages and commits.
    pub(crate) signing: Arc<SignatureKeyPair>,
    /// Device id stamped as the sender on outgoing envelopes.
    pub(crate) own_device: DeviceId,
    /// Last sequence number assigned to an outgoing envelope (0 before any send).
    pub(crate) seq: u64,
    /// Clock value stamped on outgoing envelopes; ticked with `now_ms` on each send.
    pub(crate) hlc: Hlc,
    /// Records which incoming envelopes have already been processed.
    pub(crate) cursor: SyncCursor,
    /// Backend that cursor and metadata snapshots are written to.
    pub(crate) storage: Arc<dyn Storage>,
}
impl std::fmt::Debug for Conversation {
    /// Shows only the hex id and the metadata. The MLS group, provider, signing keys,
    /// counters and storage handle are omitted, so key material never appears in
    /// debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex_id = self.id.as_hex();
        let mut out = f.debug_struct("Conversation");
        out.field("id", &hex_id);
        out.field("meta", &self.meta);
        out.finish()
    }
}
impl Conversation {
pub fn id(&self) -> ConversationId {
self.id
}
pub fn meta(&self) -> &ConversationMeta {
&self.meta
}
pub fn epoch(&self) -> u64 {
self.group.epoch().as_u64()
}
pub fn cursor(&self) -> &SyncCursor {
&self.cursor
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn create(
id: ConversationId,
name: Option<String>,
own_device: DeviceId,
own_user: &UserId,
crypto: Arc<OpenMlsRustCrypto>,
signing: Arc<SignatureKeyPair>,
storage: Arc<dyn Storage>,
now_ms: u64,
) -> Result<Self> {
let credential = BasicCredential::new(own_user.0.clone());
let credential_with_key = CredentialWithKey {
credential: credential.into(),
signature_key: signing.public().into(),
};
let cfg = MlsGroupCreateConfig::builder()
.ciphersuite(DEFAULT_CIPHERSUITE)
.use_ratchet_tree_extension(true)
.build();
let group = MlsGroup::new_with_group_id(
crypto.as_ref(),
signing.as_ref(),
&cfg,
openmls::group::GroupId::from_slice(&id.0),
credential_with_key,
)
.map_err(Error::mls)?;
let meta = ConversationMeta {
id,
name,
epoch: 0,
member_count: 1,
is_device_group: false,
created_at_ms: now_ms,
};
Ok(Self {
id,
meta,
group,
crypto,
signing,
own_device,
seq: 0,
hlc: Hlc::ZERO.tick(now_ms),
cursor: SyncCursor::default(),
storage,
})
}
pub(crate) fn join(
welcome_bytes: &[u8],
own_device: DeviceId,
crypto: Arc<OpenMlsRustCrypto>,
signing: Arc<SignatureKeyPair>,
storage: Arc<dyn Storage>,
now_ms: u64,
) -> Result<Self> {
let mls_in = MlsMessageIn::tls_deserialize_exact(welcome_bytes).map_err(Error::mls)?;
let welcome = match mls_in.extract() {
MlsMessageBodyIn::Welcome(w) => w,
_ => return Err(Error::Invalid("expected Welcome".into())),
};
let cfg = MlsGroupJoinConfig::builder()
.use_ratchet_tree_extension(true)
.build();
let staged =
openmls::group::StagedWelcome::new_from_welcome(crypto.as_ref(), &cfg, welcome, None)
.map_err(Error::mls)?;
let group = staged.into_group(crypto.as_ref()).map_err(Error::mls)?;
let id_bytes: [u8; 16] = group
.group_id()
.as_slice()
.try_into()
.map_err(|_| Error::Invalid("group id must be 16 bytes".into()))?;
let id = ConversationId(id_bytes);
let meta = ConversationMeta {
id,
name: None,
epoch: group.epoch().as_u64(),
member_count: group.members().count() as u32,
is_device_group: false,
created_at_ms: now_ms,
};
Ok(Self {
id,
meta,
group,
crypto,
signing,
own_device,
seq: 0,
hlc: Hlc::ZERO.tick(now_ms),
cursor: SyncCursor::default(),
storage,
})
}
pub fn send_application(&mut self, plaintext: &[u8], now_ms: u64) -> Result<MessageEnvelope> {
let out = self
.group
.create_message(self.crypto.as_ref(), self.signing.as_ref(), plaintext)
.map_err(Error::mls)?;
self.seq += 1;
self.hlc = self.hlc.tick(now_ms);
let bytes = out.tls_serialize_detached().map_err(Error::mls)?;
let env = MessageEnvelope::new(
self.id,
self.epoch(),
MessageKind::Application,
self.own_device.clone(),
self.seq,
self.hlc,
bytes,
);
Ok(env)
}
pub fn add_members(&mut self, key_packages: Vec<Vec<u8>>, now_ms: u64) -> Result<AddOutcome> {
let mut kps = Vec::with_capacity(key_packages.len());
for raw in &key_packages {
let mls_in = MlsMessageIn::tls_deserialize_exact(raw).map_err(Error::mls)?;
let kp_in = match mls_in.extract() {
MlsMessageBodyIn::KeyPackage(kp) => kp,
_ => return Err(Error::Invalid("expected KeyPackage".into())),
};
let kp = kp_in
.validate(self.crypto.crypto(), ProtocolVersion::default())
.map_err(Error::mls)?;
kps.push(kp);
}
let (commit_out, welcome_out, _gi) = self
.group
.add_members(self.crypto.as_ref(), self.signing.as_ref(), &kps)
.map_err(Error::mls)?;
self.group
.merge_pending_commit(self.crypto.as_ref())
.map_err(Error::mls)?;
self.meta.epoch = self.epoch();
self.meta.member_count = self.group.members().count() as u32;
self.seq += 1;
self.hlc = self.hlc.tick(now_ms);
let commit_bytes = mls_message_out_bytes(commit_out)?;
let commit_env = MessageEnvelope::new(
self.id,
self.meta.epoch,
MessageKind::Commit,
self.own_device.clone(),
self.seq,
self.hlc,
commit_bytes,
);
let welcome_bytes = mls_message_out_bytes(welcome_out)?;
let welcome_env = MessageEnvelope::new(
self.id,
self.meta.epoch,
MessageKind::Welcome,
self.own_device.clone(),
self.seq,
self.hlc,
welcome_bytes,
);
Ok(AddOutcome {
commit: commit_env,
welcome: welcome_env,
})
}
pub fn remove_members(
&mut self,
leaf_indexes: Vec<u32>,
now_ms: u64,
) -> Result<MessageEnvelope> {
use openmls::prelude::LeafNodeIndex;
let leaves: Vec<LeafNodeIndex> = leaf_indexes.into_iter().map(LeafNodeIndex::new).collect();
let (commit_out, _welcome_opt, _gi) = self
.group
.remove_members(self.crypto.as_ref(), self.signing.as_ref(), &leaves)
.map_err(Error::mls)?;
self.group
.merge_pending_commit(self.crypto.as_ref())
.map_err(Error::mls)?;
self.meta.epoch = self.epoch();
self.meta.member_count = self.group.members().count() as u32;
self.seq += 1;
self.hlc = self.hlc.tick(now_ms);
let bytes = mls_message_out_bytes(commit_out)?;
Ok(MessageEnvelope::new(
self.id,
self.meta.epoch,
MessageKind::Commit,
self.own_device.clone(),
self.seq,
self.hlc,
bytes,
))
}
pub fn process(
&mut self,
env: &MessageEnvelope,
now_ms: u64,
) -> Result<Option<IncomingMessage>> {
if !self.cursor.is_new(env.epoch, &env.sender_device, env.seq) {
return Ok(None); }
let mls_in = MlsMessageIn::tls_deserialize_exact(&env.payload).map_err(Error::mls)?;
let protocol_msg: ProtocolMessage = match mls_in.extract() {
MlsMessageBodyIn::PrivateMessage(m) => m.into(),
MlsMessageBodyIn::PublicMessage(m) => m.into(),
MlsMessageBodyIn::Welcome(_) => {
return Err(Error::Invalid(
"Welcome must be handled at client level, not in-group".into(),
));
}
_ => return Err(Error::Invalid("unsupported MLS message body".into())),
};
let processed: ProcessedMessage = self
.group
.process_message(self.crypto.as_ref(), protocol_msg)
.map_err(Error::mls)?;
let out = match processed.into_content() {
ProcessedMessageContent::ApplicationMessage(app) => {
let pt = app.into_bytes();
Some(IncomingMessage {
conversation_id: self.id,
sender_device: env.sender_device.clone(),
epoch: env.epoch,
hlc: env.hlc,
plaintext: pt,
content_hash: env.content_hash,
})
}
ProcessedMessageContent::StagedCommitMessage(staged) => {
self.group
.merge_staged_commit(self.crypto.as_ref(), *staged)
.map_err(Error::mls)?;
self.meta.epoch = self.epoch();
self.meta.member_count = self.group.members().count() as u32;
None
}
ProcessedMessageContent::ProposalMessage(_)
| ProcessedMessageContent::ExternalJoinProposalMessage(_) => {
None
}
};
self.cursor.advance(
env.epoch,
env.sender_device.clone(),
env.seq,
env.hlc,
now_ms,
);
Ok(out)
}
pub(crate) async fn snapshot_to_storage(&self) -> Result<()> {
let blob = self
.group
.export_secret(self.crypto.as_ref(), "ping-snapshot-marker", &[], 32)
.ok();
let _ = blob; let cursor = self.cursor.encode()?;
self.storage
.put("cursors", &self.id.as_hex(), cursor)
.await?;
let meta = codec::encode(&self.meta)?;
self.storage
.put("groups", &format!("{}/meta", self.id.as_hex()), meta)
.await?;
Ok(())
}
}
/// Envelopes produced by `Conversation::add_members`.
///
/// Both envelopes carry the post-commit epoch and the same sequence number.
#[derive(Debug, Clone)]
pub struct AddOutcome {
    /// Envelope wrapping the serialized Commit (`MessageKind::Commit`).
    pub commit: MessageEnvelope,
    /// Envelope wrapping the serialized Welcome (`MessageKind::Welcome`) for the added members.
    pub welcome: MessageEnvelope,
}
fn mls_message_out_bytes(m: MlsMessageOut) -> Result<Vec<u8>> {
m.tls_serialize_detached().map_err(Error::mls)
}