1use openmls::{
10 framing::{MlsMessageOut, ProcessedMessageContent},
11 group::{MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig},
12 prelude::{
13 tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerialize},
14 BasicCredential, Ciphersuite, CredentialWithKey, MlsMessageBodyIn, MlsMessageIn,
15 ProcessedMessage, ProtocolMessage, ProtocolVersion,
16 },
17};
18use openmls_basic_credential::SignatureKeyPair;
19use openmls_traits::OpenMlsProvider;
20use ping_mls_store::PersistentMlsProvider;
21use serde::{Deserialize, Serialize};
22use std::collections::BTreeMap;
23use std::sync::Arc;
24use ulid::Ulid;
25use zeroize::Zeroizing;
26
27use crate::{
28 clock::Hlc,
29 codec,
30 device::{DeviceId, GroupSnapshotEntry, GroupStateSnapshot, GROUP_SNAPSHOT_VERSION},
31 error::{Error, Result},
32 identity::UserId,
33 message::{IncomingMessage, MessageEnvelope, MessageKind},
34 storage::Storage,
35 sync::SyncCursor,
36};
37
/// Ciphersuite handed to `MlsGroupCreateConfig` when this module creates a new group:
/// DHKEM(X25519) key encapsulation, AES-128-GCM, SHA-256 and Ed25519 signatures.
/// Joined groups use whatever suite their Welcome specifies.
const DEFAULT_CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
42pub struct ConversationId(#[serde(with = "serde_bytes_array16")] pub [u8; 16]);
43
44impl ConversationId {
45 pub fn new() -> Self {
46 Self(Ulid::new().to_bytes())
47 }
48 pub fn as_hex(&self) -> String {
49 hex::encode(self.0)
50 }
51}
52
53impl Default for ConversationId {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59mod serde_bytes_array16 {
60 use serde::{Deserializer, Serializer};
61 pub fn serialize<S: Serializer>(b: &[u8; 16], s: S) -> Result<S::Ok, S::Error> {
62 serde_bytes::serialize(b.as_slice(), s)
63 }
64 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<[u8; 16], D::Error> {
65 let v: Vec<u8> = serde_bytes::deserialize(d)?;
66 v.try_into()
67 .map_err(|_| serde::de::Error::custom("expected 16 bytes"))
68 }
69}
70
/// Cached, serializable summary of a conversation.
///
/// `Conversation::snapshot_to_storage` persists it under the `groups` namespace with
/// key `<hex id>/meta`, using `codec::encode`.
/// NOTE(review): field order may matter for the on-disk encoding; do not reorder fields
/// without checking `codec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMeta {
    /// Conversation id; identical to the owning `Conversation`'s id.
    pub id: ConversationId,
    /// Optional display name. `None` for conversations joined via a Welcome.
    pub name: Option<String>,
    /// MLS epoch as of the last time this metadata was refreshed.
    pub epoch: u64,
    /// Number of group members as of the last refresh.
    pub member_count: u32,
    /// Always `false` in the constructors visible here.
    /// Presumably marks a group made of one user's own devices (TODO confirm).
    pub is_device_group: bool,
    /// Caller-supplied millisecond timestamp from when this device created or joined
    /// the group. For joins this is the local join time, not the group's original
    /// creation time.
    pub created_at_ms: u64,
}
80
/// One MLS-backed conversation as seen from the local device.
///
/// Bundles the openmls group with the local bookkeeping needed to build and
/// deduplicate `MessageEnvelope`s: per-device sequence numbers, a hybrid logical
/// clock and a sync cursor.
pub struct Conversation {
    /// Conversation id; its 16 bytes are also the MLS group id.
    pub(crate) id: ConversationId,
    /// Cached metadata (epoch, member count, ...), refreshed after local and remote commits.
    pub(crate) meta: ConversationMeta,
    /// The underlying openmls group state.
    pub(crate) group: MlsGroup,
    /// Shared openmls provider (crypto + persistent key/value storage).
    pub(crate) crypto: Arc<PersistentMlsProvider>,
    /// This device's MLS signature key pair.
    pub(crate) signing: Arc<SignatureKeyPair>,
    /// Identifier of the local device; used as the sender of outgoing envelopes.
    pub(crate) own_device: DeviceId,
    /// Sequence number of the last envelope this device produced in this conversation.
    pub(crate) seq: u64,
    /// Hybrid logical clock, ticked on every outgoing envelope.
    pub(crate) hlc: Hlc,
    /// Tracks which (epoch, device, seq) envelopes have already been applied.
    pub(crate) cursor: SyncCursor,
    /// App-level storage used by `snapshot_to_storage`.
    pub(crate) storage: Arc<dyn Storage>,
    /// Known device -> MLS leaf index mapping: the local device plus devices added locally.
    pub(crate) device_leaves: BTreeMap<DeviceId, u32>,
}
106
107impl std::fmt::Debug for Conversation {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 f.debug_struct("Conversation")
110 .field("id", &self.id.as_hex())
111 .field("meta", &self.meta)
112 .finish()
113 }
114}
115
116impl Conversation {
117 pub fn id(&self) -> ConversationId {
118 self.id
119 }
120 pub fn meta(&self) -> &ConversationMeta {
121 &self.meta
122 }
123 pub fn epoch(&self) -> u64 {
124 self.group.epoch().as_u64()
125 }
126 pub fn cursor(&self) -> &SyncCursor {
127 &self.cursor
128 }
129
130 #[allow(clippy::too_many_arguments)]
134 pub(crate) fn create(
135 id: ConversationId,
136 name: Option<String>,
137 own_device: DeviceId,
138 own_user: &UserId,
139 crypto: Arc<PersistentMlsProvider>,
140 signing: Arc<SignatureKeyPair>,
141 storage: Arc<dyn Storage>,
142 now_ms: u64,
143 ) -> Result<Self> {
144 let credential = BasicCredential::new(own_user.0.clone());
145 let credential_with_key = CredentialWithKey {
146 credential: credential.into(),
147 signature_key: signing.public().into(),
148 };
149 let cfg = MlsGroupCreateConfig::builder()
150 .ciphersuite(DEFAULT_CIPHERSUITE)
151 .use_ratchet_tree_extension(true)
152 .build();
153 let group = MlsGroup::new_with_group_id(
154 crypto.as_ref(),
155 signing.as_ref(),
156 &cfg,
157 openmls::group::GroupId::from_slice(&id.0),
158 credential_with_key,
159 )
160 .map_err(Error::mls)?;
161
162 let meta = ConversationMeta {
163 id,
164 name,
165 epoch: 0,
166 member_count: 1,
167 is_device_group: false,
168 created_at_ms: now_ms,
169 };
170 let mut device_leaves = BTreeMap::new();
172 device_leaves.insert(own_device.clone(), group.own_leaf_index().u32());
173 Ok(Self {
174 id,
175 meta,
176 group,
177 crypto,
178 signing,
179 own_device,
180 seq: 0,
181 hlc: Hlc::ZERO.tick(now_ms),
182 cursor: SyncCursor::default(),
183 storage,
184 device_leaves,
185 })
186 }
187
188 pub(crate) fn join(
190 welcome_bytes: &[u8],
191 own_device: DeviceId,
192 crypto: Arc<PersistentMlsProvider>,
193 signing: Arc<SignatureKeyPair>,
194 storage: Arc<dyn Storage>,
195 now_ms: u64,
196 ) -> Result<Self> {
197 let mls_in = MlsMessageIn::tls_deserialize_exact(welcome_bytes).map_err(Error::mls)?;
198 let welcome = match mls_in.extract() {
199 MlsMessageBodyIn::Welcome(w) => w,
200 _ => return Err(Error::Invalid("expected Welcome".into())),
201 };
202 let cfg = MlsGroupJoinConfig::builder()
203 .use_ratchet_tree_extension(true)
204 .build();
205 let staged =
206 openmls::group::StagedWelcome::new_from_welcome(crypto.as_ref(), &cfg, welcome, None)
207 .map_err(Error::mls)?;
208 let group = staged.into_group(crypto.as_ref()).map_err(Error::mls)?;
209
210 let id_bytes: [u8; 16] = group
211 .group_id()
212 .as_slice()
213 .try_into()
214 .map_err(|_| Error::Invalid("group id must be 16 bytes".into()))?;
215 let id = ConversationId(id_bytes);
216 let meta = ConversationMeta {
217 id,
218 name: None,
219 epoch: group.epoch().as_u64(),
220 member_count: group.members().count() as u32,
221 is_device_group: false,
222 created_at_ms: now_ms,
223 };
224
225 let join_epoch = group.epoch().as_u64();
230 let own_leaf = group.own_leaf_index().u32();
234 let mut device_leaves = BTreeMap::new();
235 device_leaves.insert(own_device.clone(), own_leaf);
236 Ok(Self {
237 id,
238 meta,
239 group,
240 crypto,
241 signing,
242 own_device,
243 seq: 0,
244 hlc: Hlc::ZERO.tick(now_ms),
245 cursor: SyncCursor {
246 epoch: join_epoch,
247 ..Default::default()
248 },
249 storage,
250 device_leaves,
251 })
252 }
253
254 #[allow(clippy::too_many_arguments)]
262 pub(crate) fn load(
263 id: ConversationId,
264 meta: ConversationMeta,
265 cursor: SyncCursor,
266 device_leaves: BTreeMap<DeviceId, u32>,
267 own_device: DeviceId,
268 crypto: Arc<PersistentMlsProvider>,
269 signing: Arc<SignatureKeyPair>,
270 storage: Arc<dyn Storage>,
271 now_ms: u64,
272 ) -> Result<Option<Self>> {
273 use openmls::group::GroupId;
274 let group_id = GroupId::from_slice(&id.0);
275 let group = match MlsGroup::load(crypto.storage(), &group_id).map_err(Error::mls)? {
276 Some(g) => g,
277 None => return Ok(None),
278 };
279 let seq = cursor
285 .last_seq_per_device
286 .get(&own_device)
287 .copied()
288 .unwrap_or(0);
289 Ok(Some(Self {
290 id,
291 meta,
292 group,
293 crypto,
294 signing,
295 own_device,
296 seq,
297 hlc: Hlc::ZERO.tick(now_ms),
298 cursor,
299 storage,
300 device_leaves,
301 }))
302 }
303
304 pub fn send_application(&mut self, plaintext: &[u8], now_ms: u64) -> Result<MessageEnvelope> {
310 let out = self
311 .group
312 .create_message(self.crypto.as_ref(), self.signing.as_ref(), plaintext)
313 .map_err(Error::mls)?;
314
315 self.seq += 1;
316 self.hlc = self.hlc.tick(now_ms);
317 let bytes = out.tls_serialize_detached().map_err(Error::mls)?;
318 let env = MessageEnvelope::new_application(
319 self.id,
320 self.epoch(),
321 self.own_device.clone(),
322 self.seq,
323 self.hlc,
324 bytes,
325 plaintext,
326 );
327 self.cursor.advance(
331 env.epoch,
332 self.own_device.clone(),
333 self.seq,
334 self.hlc,
335 now_ms,
336 );
337 Ok(env)
338 }
339
340 pub fn add_members(
351 &mut self,
352 entries: Vec<(DeviceId, Vec<u8>)>,
353 now_ms: u64,
354 ) -> Result<AddOutcome> {
355 let mut kps = Vec::with_capacity(entries.len());
356 let mut sig_to_device: Vec<(Vec<u8>, DeviceId)> = Vec::with_capacity(entries.len());
358 for (device_id, raw) in &entries {
359 let mls_in = MlsMessageIn::tls_deserialize_exact(raw).map_err(Error::mls)?;
360 let kp_in = match mls_in.extract() {
361 MlsMessageBodyIn::KeyPackage(kp) => kp,
362 _ => return Err(Error::Invalid("expected KeyPackage".into())),
363 };
364 let kp = kp_in
367 .validate(self.crypto.crypto(), ProtocolVersion::default())
368 .map_err(Error::mls)?;
369 let sig_key = kp.leaf_node().signature_key().as_slice().to_vec();
370 sig_to_device.push((sig_key, device_id.clone()));
371 kps.push(kp);
372 }
373
374 let pre_commit_epoch = self.epoch();
379
380 let (commit_out, welcome_out, _gi) = self
381 .group
382 .add_members(self.crypto.as_ref(), self.signing.as_ref(), &kps)
383 .map_err(Error::mls)?;
384
385 self.group
386 .merge_pending_commit(self.crypto.as_ref())
387 .map_err(Error::mls)?;
388 self.meta.epoch = self.epoch();
389 self.meta.member_count = self.group.members().count() as u32;
390
391 for member in self.group.members() {
395 if let Some((_, device_id)) = sig_to_device
396 .iter()
397 .find(|(sig, _)| sig.as_slice() == member.signature_key.as_slice())
398 {
399 self.device_leaves
400 .insert(device_id.clone(), member.index.u32());
401 }
402 }
403
404 self.seq += 1;
405 self.hlc = self.hlc.tick(now_ms);
406
407 let commit_bytes = mls_message_out_bytes(commit_out)?;
408 let commit_env = MessageEnvelope::new(
409 self.id,
410 pre_commit_epoch,
411 MessageKind::Commit,
412 self.own_device.clone(),
413 self.seq,
414 self.hlc,
415 commit_bytes,
416 );
417
418 let welcome_bytes = mls_message_out_bytes(welcome_out)?;
419 let welcome_env = MessageEnvelope::new(
420 self.id,
421 self.meta.epoch,
422 MessageKind::Welcome,
423 self.own_device.clone(),
424 self.seq,
425 self.hlc,
426 welcome_bytes,
427 );
428
429 self.cursor.advance(
432 self.meta.epoch,
433 self.own_device.clone(),
434 self.seq,
435 self.hlc,
436 now_ms,
437 );
438
439 Ok(AddOutcome {
440 commit: commit_env,
441 welcome: welcome_env,
442 })
443 }
444
445 pub fn remove_members(
446 &mut self,
447 leaf_indexes: Vec<u32>,
448 now_ms: u64,
449 ) -> Result<MessageEnvelope> {
450 use openmls::prelude::LeafNodeIndex;
451 let leaves: Vec<LeafNodeIndex> = leaf_indexes
452 .iter()
453 .copied()
454 .map(LeafNodeIndex::new)
455 .collect();
456
457 let pre_commit_epoch = self.epoch();
459
460 let (commit_out, _welcome_opt, _gi) = self
461 .group
462 .remove_members(self.crypto.as_ref(), self.signing.as_ref(), &leaves)
463 .map_err(Error::mls)?;
464 self.group
465 .merge_pending_commit(self.crypto.as_ref())
466 .map_err(Error::mls)?;
467 self.meta.epoch = self.epoch();
468 self.meta.member_count = self.group.members().count() as u32;
469
470 let removed: std::collections::HashSet<u32> = leaf_indexes.iter().copied().collect();
474 self.device_leaves.retain(|_, idx| !removed.contains(idx));
475
476 self.seq += 1;
477 self.hlc = self.hlc.tick(now_ms);
478 let bytes = mls_message_out_bytes(commit_out)?;
479 let env = MessageEnvelope::new(
480 self.id,
481 pre_commit_epoch,
482 MessageKind::Commit,
483 self.own_device.clone(),
484 self.seq,
485 self.hlc,
486 bytes,
487 );
488 self.cursor.advance(
491 self.meta.epoch,
492 self.own_device.clone(),
493 self.seq,
494 self.hlc,
495 now_ms,
496 );
497 Ok(env)
498 }
499
500 pub fn process(
502 &mut self,
503 env: &MessageEnvelope,
504 now_ms: u64,
505 ) -> Result<Option<IncomingMessage>> {
506 if !self.cursor.is_new(env.epoch, &env.sender_device, env.seq) {
507 return Ok(None); }
509 let mls_in = MlsMessageIn::tls_deserialize_exact(&env.payload).map_err(Error::mls)?;
510
511 let protocol_msg: ProtocolMessage = match mls_in.extract() {
515 MlsMessageBodyIn::PrivateMessage(m) => m.into(),
516 MlsMessageBodyIn::PublicMessage(m) => m.into(),
517 MlsMessageBodyIn::Welcome(_) => {
518 return Err(Error::Invalid(
519 "Welcome must be handled at client level, not in-group".into(),
520 ));
521 }
522 _ => return Err(Error::Invalid("unsupported MLS message body".into())),
523 };
524
525 let processed: ProcessedMessage = self
526 .group
527 .process_message(self.crypto.as_ref(), protocol_msg)
528 .map_err(Error::mls)?;
529
530 let out = match processed.into_content() {
531 ProcessedMessageContent::ApplicationMessage(app) => {
532 let pt = app.into_bytes();
533 if env.v >= 2 {
539 let computed = crate::message::hash_application_plaintext(&pt);
540 if computed != env.content_hash {
541 return Err(Error::Invalid(
542 "v=2 application content_hash mismatch".into(),
543 ));
544 }
545 }
546 Some(IncomingMessage {
547 conversation_id: self.id,
548 sender_device: env.sender_device.clone(),
549 epoch: env.epoch,
550 hlc: env.hlc,
551 plaintext: pt,
552 content_hash: env.content_hash,
553 })
554 }
555 ProcessedMessageContent::StagedCommitMessage(staged) => {
556 self.group
557 .merge_staged_commit(self.crypto.as_ref(), *staged)
558 .map_err(Error::mls)?;
559 self.meta.epoch = self.epoch();
560 self.meta.member_count = self.group.members().count() as u32;
561 None
562 }
563 ProcessedMessageContent::ProposalMessage(_)
564 | ProcessedMessageContent::ExternalJoinProposalMessage(_) => {
565 None
568 }
569 };
570
571 self.cursor.advance(
572 env.epoch,
573 env.sender_device.clone(),
574 env.seq,
575 env.hlc,
576 now_ms,
577 );
578 Ok(out)
579 }
580
581 pub fn export_secret(
598 &self,
599 label: &str,
600 context: &[u8],
601 length: usize,
602 ) -> Result<Zeroizing<Vec<u8>>> {
603 if length == 0 {
604 return Err(Error::Invalid("export_secret length must be > 0".into()));
605 }
606 if length > 1024 {
609 return Err(Error::Invalid(
610 "export_secret length exceeds 1024-byte cap".into(),
611 ));
612 }
613 let bytes = self
614 .group
615 .export_secret(self.crypto.as_ref(), label, context, length)
616 .map_err(Error::mls)?;
617 Ok(Zeroizing::new(bytes))
618 }
619
620 pub fn export_state_snapshot(&self, now_ms: u64) -> Result<Zeroizing<Vec<u8>>> {
634 let entries = self.crypto.group_scoped_entries(&self.id.0);
635 let snap = GroupStateSnapshot {
636 v: GROUP_SNAPSHOT_VERSION,
637 group_id: self.id,
638 openmls_storage_version: openmls_traits::storage::CURRENT_VERSION,
639 snapshot_created_at_ms: now_ms,
640 entries: entries
641 .into_iter()
642 .map(|(key, value)| GroupSnapshotEntry { key, value })
643 .collect(),
644 };
645 Ok(Zeroizing::new(snap.encode()?))
646 }
647
648 pub fn leaf_index_of(&self, device_id: &DeviceId) -> Option<u32> {
654 self.device_leaves.get(device_id).copied()
655 }
656
657 pub(crate) async fn snapshot_to_storage(&self) -> Result<()> {
658 let blob = self
659 .group
660 .export_secret(self.crypto.as_ref(), "ping-snapshot-marker", &[], 32)
661 .ok();
662 let _ = blob; self.crypto
675 .checkpoint_async()
676 .await
677 .map_err(|e| Error::Storage(format!("checkpoint: {e}")))?;
678
679 let cursor = self.cursor.encode()?;
680 self.storage
681 .put("cursors", &self.id.as_hex(), cursor)
682 .await?;
683 let meta = codec::encode(&self.meta)?;
684 self.storage
685 .put("groups", &format!("{}/meta", self.id.as_hex()), meta)
686 .await?;
687 let leaves_vec: Vec<(DeviceId, u32)> = self
691 .device_leaves
692 .iter()
693 .map(|(d, i)| (d.clone(), *i))
694 .collect();
695 let leaves_bytes = codec::encode(&leaves_vec)?;
696 self.storage
697 .put("device_leaves", &self.id.as_hex(), leaves_bytes)
698 .await?;
699 Ok(())
700 }
701}
702
/// Envelopes produced by `Conversation::add_members`.
#[derive(Debug, Clone)]
pub struct AddOutcome {
    /// Commit for existing members, stamped with the epoch *before* the commit.
    pub commit: MessageEnvelope,
    /// Welcome for the newly added devices, stamped with the epoch *after* the commit.
    pub welcome: MessageEnvelope,
}
711
712fn mls_message_out_bytes(m: MlsMessageOut) -> Result<Vec<u8>> {
713 m.tls_serialize_detached().map_err(Error::mls)
714}