1use openmls::{
10 framing::{MlsMessageOut, ProcessedMessageContent},
11 group::{MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig},
12 prelude::{
13 tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerialize},
14 BasicCredential, Ciphersuite, CredentialWithKey, MlsMessageBodyIn, MlsMessageIn,
15 ProcessedMessage, ProtocolMessage, ProtocolVersion,
16 },
17};
18use openmls_basic_credential::SignatureKeyPair;
19use openmls_traits::OpenMlsProvider;
20use ping_mls_store::PersistentMlsProvider;
21use serde::{Deserialize, Serialize};
22use std::collections::BTreeMap;
23use std::sync::Arc;
24use ulid::Ulid;
25use zeroize::Zeroizing;
26
27use crate::{
28 clock::Hlc,
29 codec,
30 device::{DeviceId, GroupSnapshotEntry, GroupStateSnapshot, GROUP_SNAPSHOT_VERSION},
31 error::{Error, Result},
32 identity::UserId,
33 message::{IncomingMessage, MessageEnvelope, MessageKind},
34 storage::Storage,
35 sync::SyncCursor,
36};
37
38const DEFAULT_CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
42pub struct ConversationId(#[serde(with = "serde_bytes_array16")] pub [u8; 16]);
43
44impl ConversationId {
45 pub fn new() -> Self {
46 Self(Ulid::new().to_bytes())
47 }
48 pub fn as_hex(&self) -> String {
49 hex::encode(self.0)
50 }
51}
52
53impl Default for ConversationId {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59mod serde_bytes_array16 {
60 use serde::{Deserializer, Serializer};
61 pub fn serialize<S: Serializer>(b: &[u8; 16], s: S) -> Result<S::Ok, S::Error> {
62 serde_bytes::serialize(b.as_slice(), s)
63 }
64 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<[u8; 16], D::Error> {
65 let v: Vec<u8> = serde_bytes::deserialize(d)?;
66 v.try_into()
67 .map_err(|_| serde::de::Error::custom("expected 16 bytes"))
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ConversationMeta {
73 pub id: ConversationId,
74 pub name: Option<String>,
75 pub epoch: u64,
76 pub member_count: u32,
77 pub is_device_group: bool,
78 pub created_at_ms: u64,
79}
80
81pub struct Conversation {
83 pub(crate) id: ConversationId,
84 pub(crate) meta: ConversationMeta,
85 pub(crate) group: MlsGroup,
86 pub(crate) crypto: Arc<PersistentMlsProvider>,
87 pub(crate) signing: Arc<SignatureKeyPair>,
88 pub(crate) own_device: DeviceId,
89 pub(crate) seq: u64,
90 pub(crate) hlc: Hlc,
91 pub(crate) cursor: SyncCursor,
92 pub(crate) storage: Arc<dyn Storage>,
93 pub(crate) device_leaves: BTreeMap<DeviceId, u32>,
105}
106
107impl std::fmt::Debug for Conversation {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 f.debug_struct("Conversation")
110 .field("id", &self.id.as_hex())
111 .field("meta", &self.meta)
112 .finish()
113 }
114}
115
116impl Conversation {
117 pub fn id(&self) -> ConversationId {
118 self.id
119 }
120 pub fn meta(&self) -> &ConversationMeta {
121 &self.meta
122 }
123 pub fn epoch(&self) -> u64 {
124 self.group.epoch().as_u64()
125 }
126 pub fn cursor(&self) -> &SyncCursor {
127 &self.cursor
128 }
129
130 #[allow(clippy::too_many_arguments)]
134 pub(crate) fn create(
135 id: ConversationId,
136 name: Option<String>,
137 own_device: DeviceId,
138 own_user: &UserId,
139 crypto: Arc<PersistentMlsProvider>,
140 signing: Arc<SignatureKeyPair>,
141 storage: Arc<dyn Storage>,
142 now_ms: u64,
143 ) -> Result<Self> {
144 let credential = BasicCredential::new(own_user.0.clone());
145 let credential_with_key = CredentialWithKey {
146 credential: credential.into(),
147 signature_key: signing.public().into(),
148 };
149 let cfg = MlsGroupCreateConfig::builder()
150 .ciphersuite(DEFAULT_CIPHERSUITE)
151 .use_ratchet_tree_extension(true)
152 .build();
153 let group = MlsGroup::new_with_group_id(
154 crypto.as_ref(),
155 signing.as_ref(),
156 &cfg,
157 openmls::group::GroupId::from_slice(&id.0),
158 credential_with_key,
159 )
160 .map_err(Error::mls)?;
161
162 let meta = ConversationMeta {
163 id,
164 name,
165 epoch: 0,
166 member_count: 1,
167 is_device_group: false,
168 created_at_ms: now_ms,
169 };
170 let mut device_leaves = BTreeMap::new();
172 device_leaves.insert(own_device.clone(), group.own_leaf_index().u32());
173 Ok(Self {
174 id,
175 meta,
176 group,
177 crypto,
178 signing,
179 own_device,
180 seq: 0,
181 hlc: Hlc::ZERO.tick(now_ms),
182 cursor: SyncCursor::default(),
183 storage,
184 device_leaves,
185 })
186 }
187
188 pub(crate) fn join(
190 welcome_bytes: &[u8],
191 own_device: DeviceId,
192 crypto: Arc<PersistentMlsProvider>,
193 signing: Arc<SignatureKeyPair>,
194 storage: Arc<dyn Storage>,
195 now_ms: u64,
196 ) -> Result<Self> {
197 let mls_in = MlsMessageIn::tls_deserialize_exact(welcome_bytes).map_err(Error::mls)?;
198 let welcome = match mls_in.extract() {
199 MlsMessageBodyIn::Welcome(w) => w,
200 _ => return Err(Error::Invalid("expected Welcome".into())),
201 };
202 let cfg = MlsGroupJoinConfig::builder()
203 .use_ratchet_tree_extension(true)
204 .build();
205 let staged =
206 openmls::group::StagedWelcome::new_from_welcome(crypto.as_ref(), &cfg, welcome, None)
207 .map_err(Error::mls)?;
208 let group = staged.into_group(crypto.as_ref()).map_err(Error::mls)?;
209
210 let id_bytes: [u8; 16] = group
211 .group_id()
212 .as_slice()
213 .try_into()
214 .map_err(|_| Error::Invalid("group id must be 16 bytes".into()))?;
215 let id = ConversationId(id_bytes);
216 let meta = ConversationMeta {
217 id,
218 name: None,
219 epoch: group.epoch().as_u64(),
220 member_count: group.members().count() as u32,
221 is_device_group: false,
222 created_at_ms: now_ms,
223 };
224
225 let join_epoch = group.epoch().as_u64();
230 let own_leaf = group.own_leaf_index().u32();
234 let mut device_leaves = BTreeMap::new();
235 device_leaves.insert(own_device.clone(), own_leaf);
236 Ok(Self {
237 id,
238 meta,
239 group,
240 crypto,
241 signing,
242 own_device,
243 seq: 0,
244 hlc: Hlc::ZERO.tick(now_ms),
245 cursor: SyncCursor {
246 epoch: join_epoch,
247 ..Default::default()
248 },
249 storage,
250 device_leaves,
251 })
252 }
253
254 #[allow(clippy::too_many_arguments)]
262 pub(crate) fn load(
263 id: ConversationId,
264 meta: ConversationMeta,
265 cursor: SyncCursor,
266 device_leaves: BTreeMap<DeviceId, u32>,
267 own_device: DeviceId,
268 crypto: Arc<PersistentMlsProvider>,
269 signing: Arc<SignatureKeyPair>,
270 storage: Arc<dyn Storage>,
271 now_ms: u64,
272 ) -> Result<Option<Self>> {
273 use openmls::group::GroupId;
274 let group_id = GroupId::from_slice(&id.0);
275 let group = match MlsGroup::load(crypto.storage(), &group_id).map_err(Error::mls)? {
276 Some(g) => g,
277 None => return Ok(None),
278 };
279 Ok(Some(Self {
280 id,
281 meta,
282 group,
283 crypto,
284 signing,
285 own_device,
286 seq: 0,
287 hlc: Hlc::ZERO.tick(now_ms),
288 cursor,
289 storage,
290 device_leaves,
291 }))
292 }
293
294 pub fn send_application(&mut self, plaintext: &[u8], now_ms: u64) -> Result<MessageEnvelope> {
300 let out = self
301 .group
302 .create_message(self.crypto.as_ref(), self.signing.as_ref(), plaintext)
303 .map_err(Error::mls)?;
304
305 self.seq += 1;
306 self.hlc = self.hlc.tick(now_ms);
307 let bytes = out.tls_serialize_detached().map_err(Error::mls)?;
308 let env = MessageEnvelope::new_application(
309 self.id,
310 self.epoch(),
311 self.own_device.clone(),
312 self.seq,
313 self.hlc,
314 bytes,
315 plaintext,
316 );
317 self.cursor.advance(
321 env.epoch,
322 self.own_device.clone(),
323 self.seq,
324 self.hlc,
325 now_ms,
326 );
327 Ok(env)
328 }
329
330 pub fn add_members(
341 &mut self,
342 entries: Vec<(DeviceId, Vec<u8>)>,
343 now_ms: u64,
344 ) -> Result<AddOutcome> {
345 let mut kps = Vec::with_capacity(entries.len());
346 let mut sig_to_device: Vec<(Vec<u8>, DeviceId)> = Vec::with_capacity(entries.len());
348 for (device_id, raw) in &entries {
349 let mls_in = MlsMessageIn::tls_deserialize_exact(raw).map_err(Error::mls)?;
350 let kp_in = match mls_in.extract() {
351 MlsMessageBodyIn::KeyPackage(kp) => kp,
352 _ => return Err(Error::Invalid("expected KeyPackage".into())),
353 };
354 let kp = kp_in
357 .validate(self.crypto.crypto(), ProtocolVersion::default())
358 .map_err(Error::mls)?;
359 let sig_key = kp.leaf_node().signature_key().as_slice().to_vec();
360 sig_to_device.push((sig_key, device_id.clone()));
361 kps.push(kp);
362 }
363
364 let pre_commit_epoch = self.epoch();
369
370 let (commit_out, welcome_out, _gi) = self
371 .group
372 .add_members(self.crypto.as_ref(), self.signing.as_ref(), &kps)
373 .map_err(Error::mls)?;
374
375 self.group
376 .merge_pending_commit(self.crypto.as_ref())
377 .map_err(Error::mls)?;
378 self.meta.epoch = self.epoch();
379 self.meta.member_count = self.group.members().count() as u32;
380
381 for member in self.group.members() {
385 if let Some((_, device_id)) = sig_to_device
386 .iter()
387 .find(|(sig, _)| sig.as_slice() == member.signature_key.as_slice())
388 {
389 self.device_leaves
390 .insert(device_id.clone(), member.index.u32());
391 }
392 }
393
394 self.seq += 1;
395 self.hlc = self.hlc.tick(now_ms);
396
397 let commit_bytes = mls_message_out_bytes(commit_out)?;
398 let commit_env = MessageEnvelope::new(
399 self.id,
400 pre_commit_epoch,
401 MessageKind::Commit,
402 self.own_device.clone(),
403 self.seq,
404 self.hlc,
405 commit_bytes,
406 );
407
408 let welcome_bytes = mls_message_out_bytes(welcome_out)?;
409 let welcome_env = MessageEnvelope::new(
410 self.id,
411 self.meta.epoch,
412 MessageKind::Welcome,
413 self.own_device.clone(),
414 self.seq,
415 self.hlc,
416 welcome_bytes,
417 );
418
419 self.cursor.advance(
422 self.meta.epoch,
423 self.own_device.clone(),
424 self.seq,
425 self.hlc,
426 now_ms,
427 );
428
429 Ok(AddOutcome {
430 commit: commit_env,
431 welcome: welcome_env,
432 })
433 }
434
435 pub fn remove_members(
436 &mut self,
437 leaf_indexes: Vec<u32>,
438 now_ms: u64,
439 ) -> Result<MessageEnvelope> {
440 use openmls::prelude::LeafNodeIndex;
441 let leaves: Vec<LeafNodeIndex> = leaf_indexes
442 .iter()
443 .copied()
444 .map(LeafNodeIndex::new)
445 .collect();
446
447 let pre_commit_epoch = self.epoch();
449
450 let (commit_out, _welcome_opt, _gi) = self
451 .group
452 .remove_members(self.crypto.as_ref(), self.signing.as_ref(), &leaves)
453 .map_err(Error::mls)?;
454 self.group
455 .merge_pending_commit(self.crypto.as_ref())
456 .map_err(Error::mls)?;
457 self.meta.epoch = self.epoch();
458 self.meta.member_count = self.group.members().count() as u32;
459
460 let removed: std::collections::HashSet<u32> = leaf_indexes.iter().copied().collect();
464 self.device_leaves.retain(|_, idx| !removed.contains(idx));
465
466 self.seq += 1;
467 self.hlc = self.hlc.tick(now_ms);
468 let bytes = mls_message_out_bytes(commit_out)?;
469 let env = MessageEnvelope::new(
470 self.id,
471 pre_commit_epoch,
472 MessageKind::Commit,
473 self.own_device.clone(),
474 self.seq,
475 self.hlc,
476 bytes,
477 );
478 self.cursor.advance(
481 self.meta.epoch,
482 self.own_device.clone(),
483 self.seq,
484 self.hlc,
485 now_ms,
486 );
487 Ok(env)
488 }
489
490 pub fn process(
492 &mut self,
493 env: &MessageEnvelope,
494 now_ms: u64,
495 ) -> Result<Option<IncomingMessage>> {
496 if !self.cursor.is_new(env.epoch, &env.sender_device, env.seq) {
497 return Ok(None); }
499 let mls_in = MlsMessageIn::tls_deserialize_exact(&env.payload).map_err(Error::mls)?;
500
501 let protocol_msg: ProtocolMessage = match mls_in.extract() {
505 MlsMessageBodyIn::PrivateMessage(m) => m.into(),
506 MlsMessageBodyIn::PublicMessage(m) => m.into(),
507 MlsMessageBodyIn::Welcome(_) => {
508 return Err(Error::Invalid(
509 "Welcome must be handled at client level, not in-group".into(),
510 ));
511 }
512 _ => return Err(Error::Invalid("unsupported MLS message body".into())),
513 };
514
515 let processed: ProcessedMessage = self
516 .group
517 .process_message(self.crypto.as_ref(), protocol_msg)
518 .map_err(Error::mls)?;
519
520 let out = match processed.into_content() {
521 ProcessedMessageContent::ApplicationMessage(app) => {
522 let pt = app.into_bytes();
523 if env.v >= 2 {
529 let computed = crate::message::hash_application_plaintext(&pt);
530 if computed != env.content_hash {
531 return Err(Error::Invalid(
532 "v=2 application content_hash mismatch".into(),
533 ));
534 }
535 }
536 Some(IncomingMessage {
537 conversation_id: self.id,
538 sender_device: env.sender_device.clone(),
539 epoch: env.epoch,
540 hlc: env.hlc,
541 plaintext: pt,
542 content_hash: env.content_hash,
543 })
544 }
545 ProcessedMessageContent::StagedCommitMessage(staged) => {
546 self.group
547 .merge_staged_commit(self.crypto.as_ref(), *staged)
548 .map_err(Error::mls)?;
549 self.meta.epoch = self.epoch();
550 self.meta.member_count = self.group.members().count() as u32;
551 None
552 }
553 ProcessedMessageContent::ProposalMessage(_)
554 | ProcessedMessageContent::ExternalJoinProposalMessage(_) => {
555 None
558 }
559 };
560
561 self.cursor.advance(
562 env.epoch,
563 env.sender_device.clone(),
564 env.seq,
565 env.hlc,
566 now_ms,
567 );
568 Ok(out)
569 }
570
571 pub fn export_secret(
588 &self,
589 label: &str,
590 context: &[u8],
591 length: usize,
592 ) -> Result<Zeroizing<Vec<u8>>> {
593 if length == 0 {
594 return Err(Error::Invalid("export_secret length must be > 0".into()));
595 }
596 if length > 1024 {
599 return Err(Error::Invalid(
600 "export_secret length exceeds 1024-byte cap".into(),
601 ));
602 }
603 let bytes = self
604 .group
605 .export_secret(self.crypto.as_ref(), label, context, length)
606 .map_err(Error::mls)?;
607 Ok(Zeroizing::new(bytes))
608 }
609
610 pub fn export_state_snapshot(&self, now_ms: u64) -> Result<Zeroizing<Vec<u8>>> {
624 let entries = self.crypto.group_scoped_entries(&self.id.0);
625 let snap = GroupStateSnapshot {
626 v: GROUP_SNAPSHOT_VERSION,
627 group_id: self.id,
628 openmls_storage_version: openmls_traits::storage::CURRENT_VERSION,
629 snapshot_created_at_ms: now_ms,
630 entries: entries
631 .into_iter()
632 .map(|(key, value)| GroupSnapshotEntry { key, value })
633 .collect(),
634 };
635 Ok(Zeroizing::new(snap.encode()?))
636 }
637
638 pub fn leaf_index_of(&self, device_id: &DeviceId) -> Option<u32> {
644 self.device_leaves.get(device_id).copied()
645 }
646
647 pub(crate) async fn snapshot_to_storage(&self) -> Result<()> {
648 let blob = self
649 .group
650 .export_secret(self.crypto.as_ref(), "ping-snapshot-marker", &[], 32)
651 .ok();
652 let _ = blob; self.crypto
661 .checkpoint()
662 .map_err(|e| Error::Storage(format!("checkpoint: {e}")))?;
663
664 let cursor = self.cursor.encode()?;
665 self.storage
666 .put("cursors", &self.id.as_hex(), cursor)
667 .await?;
668 let meta = codec::encode(&self.meta)?;
669 self.storage
670 .put("groups", &format!("{}/meta", self.id.as_hex()), meta)
671 .await?;
672 let leaves_vec: Vec<(DeviceId, u32)> = self
676 .device_leaves
677 .iter()
678 .map(|(d, i)| (d.clone(), *i))
679 .collect();
680 let leaves_bytes = codec::encode(&leaves_vec)?;
681 self.storage
682 .put("device_leaves", &self.id.as_hex(), leaves_bytes)
683 .await?;
684 Ok(())
685 }
686}
687
688#[derive(Debug, Clone)]
692pub struct AddOutcome {
693 pub commit: MessageEnvelope,
694 pub welcome: MessageEnvelope,
695}
696
697fn mls_message_out_bytes(m: MlsMessageOut) -> Result<Vec<u8>> {
698 m.tls_serialize_detached().map_err(Error::mls)
699}