1#![deny(missing_docs)]
18
19use chacha20poly1305::{
20 ChaCha20Poly1305, Key, Nonce,
21 aead::{Aead, KeyInit},
22};
23use gbp_core::StreamType;
24use openmls::prelude::tls_codec::DeserializeBytes as _;
25use openmls::prelude::tls_codec::Serialize as _;
26use openmls::prelude::*;
27use openmls_basic_credential::SignatureKeyPair;
28use openmls_rust_crypto::{MemoryStorage, OpenMlsRustCrypto};
29use std::collections::HashMap;
30
31pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
33
/// Identifies which logical stream a key is derived for.
///
/// Each variant maps to its own MLS exporter label (see [`StreamLabel::as_str`]), so
/// every stream gets an independent key within the same epoch.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum StreamLabel {
    /// Control stream (exporter label `"gbp/control"`).
    Control,
    /// Audio stream (exporter label `"gbp/audio"`).
    Audio,
    /// Text stream (exporter label `"gbp/text"`).
    Text,
    /// Signal stream (exporter label `"gbp/signal"`).
    Signal,
}
46
47impl StreamLabel {
48 pub fn as_str(self) -> &'static str {
50 match self {
51 Self::Control => "gbp/control",
52 Self::Audio => "gbp/audio",
53 Self::Text => "gbp/text",
54 Self::Signal => "gbp/signal",
55 }
56 }
57}
58
59pub fn label_for(st: StreamType) -> StreamLabel {
61 match st {
62 StreamType::Control => StreamLabel::Control,
63 StreamType::Audio => StreamLabel::Audio,
64 StreamType::Text => StreamLabel::Text,
65 StreamType::Signal => StreamLabel::Signal,
66 }
67}
68
/// Coarse classification of a message handled by [`MlsContext::process_message`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessedKind {
    /// A commit that has been staged but not merged yet; it takes effect on
    /// [`MlsContext::finalize_pending_commit`].
    Commit,
    /// An MLS application message (its plaintext is not returned).
    Application,
    /// A proposal message.
    Proposal,
    /// An external-join proposal message.
    External,
}
83
84#[derive(Debug, thiserror::Error)]
86pub enum MlsError {
87 #[error("openmls: {0}")]
89 OpenMls(String),
90 #[error("aead: {0}")]
92 Aead(String),
93 #[error("transition in progress: pending staged commit exists")]
96 TransitionInProgress,
97}
98
99pub struct MlsContext {
105 pub provider: OpenMlsRustCrypto,
107 pub signer: SignatureKeyPair,
109 pub group: MlsGroup,
111 pub credential: CredentialWithKey,
113 pub identity: Vec<u8>,
115 pub pending_staged: Option<StagedCommit>,
122}
123
124fn serialize_storage(s: &MemoryStorage) -> Result<Vec<u8>, MlsError> {
130 let map = s
131 .values
132 .read()
133 .map_err(|_| MlsError::OpenMls("storage lock poisoned".into()))?;
134 let mut out = Vec::new();
135 out.extend_from_slice(&(map.len() as u32).to_le_bytes());
136 for (k, v) in map.iter() {
137 out.extend_from_slice(&(k.len() as u32).to_le_bytes());
138 out.extend_from_slice(k);
139 out.extend_from_slice(&(v.len() as u32).to_le_bytes());
140 out.extend_from_slice(v);
141 }
142 Ok(out)
143}
144
145fn deserialize_storage(bytes: &[u8]) -> Result<HashMap<Vec<u8>, Vec<u8>>, MlsError> {
146 let mut cur = bytes;
147 fn rd_u32(cur: &mut &[u8]) -> Result<usize, MlsError> {
148 if cur.len() < 4 {
149 return Err(MlsError::OpenMls("truncated storage blob".into()));
150 }
151 let n = u32::from_le_bytes([cur[0], cur[1], cur[2], cur[3]]) as usize;
152 *cur = &cur[4..];
153 Ok(n)
154 }
155 fn rd_bytes<'a>(cur: &mut &'a [u8], len: usize) -> Result<&'a [u8], MlsError> {
156 if cur.len() < len {
157 return Err(MlsError::OpenMls("truncated storage blob".into()));
158 }
159 let (head, tail) = cur.split_at(len);
160 *cur = tail;
161 Ok(head)
162 }
163 let count = rd_u32(&mut cur)?;
164 let mut map = HashMap::with_capacity(count);
165 for _ in 0..count {
166 let klen = rd_u32(&mut cur)?;
167 let k = rd_bytes(&mut cur, klen)?.to_vec();
168 let vlen = rd_u32(&mut cur)?;
169 let v = rd_bytes(&mut cur, vlen)?.to_vec();
170 map.insert(k, v);
171 }
172 Ok(map)
173}
174
175impl MlsContext {
176 pub fn new_member(identity: &[u8]) -> Result<(Self, KeyPackageBundle), MlsError> {
180 let provider = OpenMlsRustCrypto::default();
181 let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())
182 .map_err(|e| MlsError::OpenMls(format!("signer: {e:?}")))?;
183 signer
184 .store(provider.storage())
185 .map_err(|e| MlsError::OpenMls(format!("store signer: {e:?}")))?;
186
187 let credential = BasicCredential::new(identity.to_vec());
188 let credential_with_key = CredentialWithKey {
189 credential: credential.into(),
190 signature_key: signer.public().into(),
191 };
192
193 let kp_bundle = KeyPackage::builder()
194 .build(CIPHERSUITE, &provider, &signer, credential_with_key.clone())
195 .map_err(|e| MlsError::OpenMls(format!("kp: {e:?}")))?;
196
197 let cfg = MlsGroupCreateConfig::builder()
198 .ciphersuite(CIPHERSUITE)
199 .use_ratchet_tree_extension(true)
200 .build();
201 let group = MlsGroup::new(&provider, &signer, &cfg, credential_with_key.clone())
202 .map_err(|e| MlsError::OpenMls(format!("group: {e:?}")))?;
203
204 Ok((
205 Self {
206 provider,
207 signer,
208 group,
209 credential: credential_with_key,
210 identity: identity.to_vec(),
211 pending_staged: None,
212 },
213 kp_bundle,
214 ))
215 }
216
217 pub fn invite_full(
231 &mut self,
232 key_packages: &[KeyPackage],
233 ) -> Result<(Vec<u8>, Vec<u8>), MlsError> {
234 let (commit, welcome, _gi) = self
235 .group
236 .add_members(&self.provider, &self.signer, key_packages)
237 .map_err(|e| MlsError::OpenMls(format!("add_members: {e:?}")))?;
238 let commit_bytes = commit
239 .tls_serialize_detached()
240 .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))?;
241 let welcome_bytes = welcome
242 .tls_serialize_detached()
243 .map_err(|e| MlsError::OpenMls(format!("welcome serialize: {e:?}")))?;
244 Ok((commit_bytes, welcome_bytes))
245 }
246
247 pub fn invite(&mut self, key_packages: &[KeyPackage]) -> Result<Vec<u8>, MlsError> {
251 let (_commit, welcome) = self.invite_full(key_packages)?;
252 self.finalize_pending_commit()?;
253 Ok(welcome)
254 }
255
256 pub fn remove_members(&mut self, leaf_indices: &[u32]) -> Result<Vec<u8>, MlsError> {
265 let group_size = self.group.members().count() as u32;
268 for &idx in leaf_indices {
269 if idx >= group_size {
270 return Err(MlsError::OpenMls(format!(
271 "leaf_index {idx} out of range (group size {group_size})"
272 )));
273 }
274 }
275 let leaves: Vec<LeafNodeIndex> = leaf_indices
276 .iter()
277 .copied()
278 .map(LeafNodeIndex::new)
279 .collect();
280 let (commit, _welcome_opt, _gi) = self
281 .group
282 .remove_members(&self.provider, &self.signer, &leaves)
283 .map_err(|e| MlsError::OpenMls(format!("remove_members: {e:?}")))?;
284 commit
285 .tls_serialize_detached()
286 .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))
287 }
288
289 pub fn finalize_pending_commit(&mut self) -> Result<(), MlsError> {
299 if let Some(staged) = self.pending_staged.take() {
300 self.group
301 .merge_staged_commit(&self.provider, staged)
302 .map_err(|e| MlsError::OpenMls(format!("merge_staged: {e:?}")))?;
303 }
304 let _ = self.group.merge_pending_commit(&self.provider);
309 Ok(())
310 }
311
312 pub fn clear_pending_commit(&mut self) -> Result<(), MlsError> {
315 self.pending_staged = None;
316 self.group
317 .clear_pending_commit(self.provider.storage())
318 .map_err(|e| MlsError::OpenMls(format!("clear: {e:?}")))?;
319 Ok(())
320 }
321
322 pub fn process_message(&mut self, msg_bytes: &[u8]) -> Result<ProcessedKind, MlsError> {
333 let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(msg_bytes)
334 .map_err(|e| MlsError::OpenMls(format!("msg parse: {e:?}")))?;
335 let protocol_msg = match msg_in.extract() {
336 MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::from(m),
337 MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::from(m),
338 other => {
339 return Err(MlsError::OpenMls(format!(
340 "expected protocol message, got {other:?}"
341 )));
342 }
343 };
344 let processed = self
345 .group
346 .process_message(&self.provider, protocol_msg)
347 .map_err(|e| MlsError::OpenMls(format!("process: {e:?}")))?;
348 match processed.into_content() {
349 ProcessedMessageContent::StagedCommitMessage(staged) => {
350 if self.pending_staged.is_some() {
351 return Err(MlsError::TransitionInProgress);
352 }
353 self.pending_staged = Some(*staged);
354 Ok(ProcessedKind::Commit)
355 }
356 ProcessedMessageContent::ApplicationMessage(_) => Ok(ProcessedKind::Application),
357 ProcessedMessageContent::ProposalMessage(_) => Ok(ProcessedKind::Proposal),
358 ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(ProcessedKind::External),
359 }
360 }
361
362 pub fn accept_welcome(&mut self, welcome_bytes: &[u8]) -> Result<(), MlsError> {
365 let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(welcome_bytes)
366 .map_err(|e| MlsError::OpenMls(format!("welcome parse: {e:?}")))?;
367 let welcome = match msg_in.extract() {
368 MlsMessageBodyIn::Welcome(w) => w,
369 other => {
370 return Err(MlsError::OpenMls(format!(
371 "expected welcome, got {other:?}"
372 )));
373 }
374 };
375 let join_cfg = MlsGroupJoinConfig::builder()
376 .use_ratchet_tree_extension(true)
377 .build();
378 let staged = StagedWelcome::new_from_welcome(&self.provider, &join_cfg, welcome, None)
379 .map_err(|e| MlsError::OpenMls(format!("staged: {e:?}")))?;
380 self.group = staged
381 .into_group(&self.provider)
382 .map_err(|e| MlsError::OpenMls(format!("into_group: {e:?}")))?;
383 Ok(())
384 }
385
386 pub fn epoch(&self) -> u64 {
388 self.group.epoch().as_u64()
389 }
390
391 pub fn group_id_16(&self) -> [u8; 16] {
394 let raw = self.group.group_id().as_slice();
395 let mut out = [0u8; 16];
396 let n = raw.len().min(16);
397 out[..n].copy_from_slice(&raw[..n]);
398 out
399 }
400
401 pub fn export_state(&self) -> Result<Vec<u8>, MlsError> {
411 let storage_buf = serialize_storage(self.provider.storage())?;
412 let signer_buf = self
413 .signer
414 .tls_serialize_detached()
415 .map_err(|e| MlsError::OpenMls(format!("signer serialize: {e:?}")))?;
416 let gid = self.group.group_id().as_slice().to_vec();
417
418 let mut out = Vec::with_capacity(16 + storage_buf.len() + signer_buf.len() + self.identity.len() + gid.len());
419 for part in [
420 storage_buf.as_slice(),
421 signer_buf.as_slice(),
422 self.identity.as_slice(),
423 gid.as_slice(),
424 ] {
425 out.extend_from_slice(&(part.len() as u32).to_le_bytes());
426 out.extend_from_slice(part);
427 }
428 Ok(out)
429 }
430
431 pub fn restore_state(blob: &[u8]) -> Result<Self, MlsError> {
436 let mut cur = blob;
437 let mut take = || -> Result<&[u8], MlsError> {
438 if cur.len() < 4 {
439 return Err(MlsError::OpenMls("truncated state blob (length)".into()));
440 }
441 let len = u32::from_le_bytes([cur[0], cur[1], cur[2], cur[3]]) as usize;
442 cur = &cur[4..];
443 if cur.len() < len {
444 return Err(MlsError::OpenMls("truncated state blob (body)".into()));
445 }
446 let (head, tail) = cur.split_at(len);
447 cur = tail;
448 Ok(head)
449 };
450 let storage_bytes = take()?.to_vec();
451 let signer_bytes = take()?.to_vec();
452 let identity = take()?.to_vec();
453 let gid_bytes = take()?.to_vec();
454
455 let provider = OpenMlsRustCrypto::default();
457 let map = deserialize_storage(&storage_bytes)?;
458 *provider
459 .storage()
460 .values
461 .write()
462 .map_err(|_| MlsError::OpenMls("storage lock poisoned".into()))? = map;
463
464 let signer = SignatureKeyPair::tls_deserialize_exact_bytes(&signer_bytes)
465 .map_err(|e| MlsError::OpenMls(format!("signer parse: {e:?}")))?;
466 let credential = CredentialWithKey {
467 credential: BasicCredential::new(identity.clone()).into(),
468 signature_key: signer.public().into(),
469 };
470 let group_id = GroupId::from_slice(&gid_bytes);
471 let group = MlsGroup::load(provider.storage(), &group_id)
472 .map_err(|e| MlsError::OpenMls(format!("group load: {e:?}")))?
473 .ok_or_else(|| MlsError::OpenMls("no group in restored state".into()))?;
474
475 Ok(Self {
476 provider,
477 signer,
478 group,
479 credential,
480 identity,
481 pending_staged: None,
482 })
483 }
484
485 pub fn export_stream_key(&self, label: StreamLabel) -> Result<[u8; 32], MlsError> {
487 let secret = self
488 .group
489 .export_secret(self.provider.crypto(), label.as_str(), &[], 32)
490 .map_err(|e| MlsError::OpenMls(format!("export: {e:?}")))?;
491 let mut out = [0u8; 32];
492 out.copy_from_slice(&secret);
493 Ok(out)
494 }
495
496 pub fn export_raw(&self, label: &str, context: &[u8], len: usize) -> Result<Vec<u8>, MlsError> {
501 let secret = self
502 .group
503 .export_secret(self.provider.crypto(), label, context, len)
504 .map_err(|e| MlsError::OpenMls(format!("export_raw: {e:?}")))?;
505 Ok(secret.to_vec())
506 }
507
508 pub fn seal(
511 &self,
512 label: StreamLabel,
513 seq: u32,
514 plaintext: &[u8],
515 ) -> Result<Vec<u8>, MlsError> {
516 let key = self.export_stream_key(label)?;
517 let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
518 let mut nonce = [0u8; 12];
519 nonce[..4].copy_from_slice(&seq.to_be_bytes());
520 cipher
521 .encrypt(Nonce::from_slice(&nonce), plaintext)
522 .map_err(|e| MlsError::Aead(e.to_string()))
523 }
524
525 pub fn open(
527 &self,
528 label: StreamLabel,
529 seq: u32,
530 ciphertext: &[u8],
531 ) -> Result<Vec<u8>, MlsError> {
532 let key = self.export_stream_key(label)?;
533 let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
534 let mut nonce = [0u8; 12];
535 nonce[..4].copy_from_slice(&seq.to_be_bytes());
536 cipher
537 .decrypt(Nonce::from_slice(&nonce), ciphertext)
538 .map_err(|e| MlsError::Aead(e.to_string()))
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh single-member context (epoch 0) for `name`, plus the bundle whose key package
    /// another member can use to invite it.
    fn member(name: &[u8]) -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(name).expect("new_member should succeed")
    }

    /// `alice` creates a group and adds `bob` via `invite`; both end up at epoch 1.
    fn alice_with_bob() -> (MlsContext, MlsContext) {
        let (mut alice, _alice_kp) = member(b"alice");
        let (mut bob, bob_kp) = member(b"bob");
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        (alice, bob)
    }

    #[test]
    fn stream_label_strings_are_correct() {
        let expected = [
            (StreamLabel::Control, "gbp/control"),
            (StreamLabel::Audio, "gbp/audio"),
            (StreamLabel::Text, "gbp/text"),
            (StreamLabel::Signal, "gbp/signal"),
        ];
        for (label, text) in expected {
            assert_eq!(label.as_str(), text);
        }
    }

    #[test]
    fn label_for_maps_every_stream_type() {
        let expected = [
            (StreamType::Control, StreamLabel::Control),
            (StreamType::Audio, StreamLabel::Audio),
            (StreamType::Text, StreamLabel::Text),
            (StreamType::Signal, StreamLabel::Signal),
        ];
        for (stream, label) in expected {
            assert_eq!(label_for(stream), label);
        }
    }

    #[test]
    fn new_member_starts_at_epoch_zero() {
        let (ctx, _kp) = member(b"alice");
        assert_eq!(ctx.epoch(), 0);
    }

    #[test]
    fn group_id_16_is_16_bytes() {
        let (ctx, _kp) = member(b"alice");
        assert_eq!(ctx.group_id_16().len(), 16);
    }

    #[test]
    fn export_stream_key_is_32_bytes_and_stable() {
        let (ctx, _kp) = member(b"alice");
        let first = ctx.export_stream_key(StreamLabel::Text).unwrap();
        let second = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_eq!(first.len(), 32);
        assert_eq!(first, second);
    }

    #[test]
    fn different_labels_produce_different_keys() {
        let (ctx, _kp) = member(b"alice");
        assert_ne!(
            ctx.export_stream_key(StreamLabel::Control).unwrap(),
            ctx.export_stream_key(StreamLabel::Text).unwrap()
        );
    }

    #[test]
    fn seal_open_single_member_round_trip() {
        let (ctx, _kp) = member(b"alice");
        let plaintext = b"hello world";
        let sealed = ctx.seal(StreamLabel::Text, 1, plaintext).unwrap();
        assert_ne!(sealed, plaintext);
        assert_eq!(ctx.open(StreamLabel::Text, 1, &sealed).unwrap(), plaintext);
    }

    #[test]
    fn seal_wrong_seq_fails_to_open() {
        let (ctx, _kp) = member(b"alice");
        let sealed = ctx.seal(StreamLabel::Text, 1, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Text, 2, &sealed).is_err());
    }

    #[test]
    fn seal_wrong_label_fails_to_open() {
        let (ctx, _kp) = member(b"alice");
        let sealed = ctx.seal(StreamLabel::Text, 0, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Audio, 0, &sealed).is_err());
    }

    #[test]
    fn two_member_invite_and_welcome() {
        let (mut alice, _alice_kp) = member(b"alice");
        let (mut bob, bob_kp) = member(b"bob");

        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        assert_eq!(alice.epoch(), 1);

        bob.accept_welcome(&welcome).unwrap();
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn two_member_seal_open_cross_member() {
        let (alice, bob) = alice_with_bob();
        let plaintext = b"cross-member secret";
        let sealed = alice.seal(StreamLabel::Control, 0, plaintext).unwrap();
        assert_eq!(bob.open(StreamLabel::Control, 0, &sealed).unwrap(), plaintext);
    }

    #[test]
    fn export_raw_returns_requested_length() {
        let (ctx, _kp) = member(b"alice");
        assert_eq!(ctx.export_raw("test/label", b"ctx", 48).unwrap().len(), 48);
    }

    #[test]
    fn clear_pending_commit_is_idempotent() {
        let (mut ctx, _kp) = member(b"alice");
        for _ in 0..2 {
            ctx.clear_pending_commit().unwrap();
        }
    }

    #[test]
    fn finalize_pending_commit_on_fresh_group_is_ok() {
        let (mut ctx, _kp) = member(b"alice");
        ctx.finalize_pending_commit().unwrap();
    }

    #[test]
    fn invite_full_does_not_advance_epoch_until_finalize() {
        let (mut alice, _alice_kp) = member(b"alice");
        let (_bob, bob_kp) = member(b"bob");

        let (_commit, _welcome) = alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        assert_eq!(alice.epoch(), 0);

        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 1);

        // The welcome from invite_full is usable once the inviter has finalized.
        let (mut alice2, _alice2_kp) = member(b"alice2");
        let (mut bob2, bob2_kp) = member(b"bob2");
        let (_commit_bytes, welcome_bytes) = alice2
            .invite_full(&[bob2_kp.key_package().clone()])
            .unwrap();
        alice2.finalize_pending_commit().unwrap();
        bob2.accept_welcome(&welcome_bytes).unwrap();
        assert_eq!(alice2.epoch(), 1);
        assert_eq!(bob2.epoch(), 1);
    }

    #[test]
    fn export_restore_round_trip_preserves_state() {
        let (ctx, _kp) = member(b"alice");
        let restored = MlsContext::restore_state(&ctx.export_state().unwrap()).unwrap();
        assert_eq!(restored.epoch(), ctx.epoch());
        assert_eq!(restored.group_id_16(), ctx.group_id_16());
        assert_eq!(
            restored.export_stream_key(StreamLabel::Text).unwrap(),
            ctx.export_stream_key(StreamLabel::Text).unwrap()
        );
    }

    #[test]
    fn restored_context_can_seal_and_open() {
        let (ctx, _kp) = member(b"alice");
        let restored = MlsContext::restore_state(&ctx.export_state().unwrap()).unwrap();
        let sealed = restored.seal(StreamLabel::Text, 7, b"after restore").unwrap();
        assert_eq!(restored.open(StreamLabel::Text, 7, &sealed).unwrap(), b"after restore");
    }

    #[test]
    fn export_restore_preserves_multi_member_group() {
        let (alice, bob) = alice_with_bob();
        assert_eq!(alice.epoch(), 1);

        let restored_alice = MlsContext::restore_state(&alice.export_state().unwrap()).unwrap();
        assert_eq!(restored_alice.epoch(), 1);

        let sealed = restored_alice
            .seal(StreamLabel::Control, 3, b"still in group")
            .unwrap();
        assert_eq!(bob.open(StreamLabel::Control, 3, &sealed).unwrap(), b"still in group");
    }

    #[test]
    fn multi_member_invite_one_welcome_serves_all_joiners() {
        let (mut alice, _alice_kp) = member(b"alice");
        let (mut bob, bob_kp) = member(b"bob");
        let (mut carol, carol_kp) = member(b"carol");

        let welcome = alice
            .invite(&[bob_kp.key_package().clone(), carol_kp.key_package().clone()])
            .unwrap();
        assert_eq!(alice.epoch(), 1, "one Add commit advances the epoch once");

        bob.accept_welcome(&welcome).unwrap();
        carol.accept_welcome(&welcome).unwrap();
        assert_eq!(bob.epoch(), 1);
        assert_eq!(carol.epoch(), 1);

        let sealed = alice.seal(StreamLabel::Text, 1, b"hello group").unwrap();
        for joiner in [&bob, &carol] {
            assert_eq!(joiner.open(StreamLabel::Text, 1, &sealed).unwrap(), b"hello group");
        }
    }

    #[test]
    fn restored_prekey_accepts_welcome() {
        let (mut alice, _alice_kp) = member(b"alice");
        let (bob, bob_kp) = member(b"bob");

        // Persist bob before he is invited, then throw the live context away.
        let bob_blob = bob.export_state().unwrap();
        let bob_key_package = bob_kp.key_package().clone();
        drop(bob);

        let welcome = alice.invite(&[bob_key_package]).unwrap();
        assert_eq!(alice.epoch(), 1);

        let mut bob_restored = MlsContext::restore_state(&bob_blob).unwrap();
        bob_restored.accept_welcome(&welcome).unwrap();
        assert_eq!(bob_restored.epoch(), 1);

        let sealed = alice.seal(StreamLabel::Text, 1, b"after reload").unwrap();
        assert_eq!(
            bob_restored.open(StreamLabel::Text, 1, &sealed).unwrap(),
            b"after reload"
        );
    }

    #[test]
    fn restore_state_rejects_truncated_blob() {
        let (ctx, _kp) = member(b"alice");
        let blob = ctx.export_state().unwrap();
        assert!(MlsContext::restore_state(&blob[..blob.len() / 2]).is_err());
        assert!(MlsContext::restore_state(&[]).is_err());
    }

    #[test]
    fn deserialize_storage_parses_and_rejects_truncation() {
        // One entry: key b"k" -> value b"vw".
        let blob = [1u8, 0, 0, 0, 1, 0, 0, 0, b'k', 2, 0, 0, 0, b'v', b'w'];
        let map = deserialize_storage(&blob).unwrap();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&b"k".to_vec()), Some(&b"vw".to_vec()));
        assert!(deserialize_storage(&blob[..blob.len() - 1]).is_err());
        assert!(deserialize_storage(&[]).is_err());
    }

    #[test]
    fn process_message_rejects_garbage() {
        let (mut alice, _kp) = member(b"alice");
        assert!(alice.process_message(&[0u8; 3]).is_err());
        assert!(alice.pending_staged.is_none());
    }

    #[test]
    fn accept_welcome_rejects_non_welcome_message() {
        let (mut alice, _alice_kp) = member(b"alice");
        let (mut bob, bob_kp) = member(b"bob");
        let (commit, _welcome) = alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        assert!(bob.accept_welcome(&commit).is_err());
        assert_eq!(bob.epoch(), 0);
    }

    #[test]
    fn remove_members_rejects_unknown_leaf() {
        let (mut alice, _kp) = member(b"alice");
        assert!(alice.remove_members(&[5]).is_err());
        assert_eq!(alice.epoch(), 0);
    }

    #[test]
    fn removal_commit_is_staged_then_merged_by_remaining_member() {
        let (mut alice, _alice_kp) = member(b"alice");
        let (mut bob, bob_kp) = member(b"bob");
        let (_carol, carol_kp) = member(b"carol");
        let welcome = alice
            .invite(&[bob_kp.key_package().clone(), carol_kp.key_package().clone()])
            .unwrap();
        bob.accept_welcome(&welcome).unwrap();

        // Assumes leaves are assigned in invite order: alice 0, bob 1, carol 2.
        let commit = alice.remove_members(&[2]).unwrap();
        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 2);

        assert_eq!(bob.process_message(&commit).unwrap(), ProcessedKind::Commit);
        assert!(bob.pending_staged.is_some());
        assert_eq!(bob.epoch(), 1, "a received commit is only staged until finalize");

        bob.finalize_pending_commit().unwrap();
        assert_eq!(bob.epoch(), 2);
        assert_eq!(
            alice.export_stream_key(StreamLabel::Text).unwrap(),
            bob.export_stream_key(StreamLabel::Text).unwrap()
        );
    }
}