Skip to main content

gbp_mls/
lib.rs

1//! MLS (RFC 9420) integration for the Group Protocol Stack.
2//!
3//! This crate provides:
4//!
5//! * [`MlsContext`] — a member-side wrapper around an `openmls 0.8` group
6//!   (signing key, credential, provider, current group).
7//! * [`StreamLabel`] — labelled exporter constants used to derive AEAD keys
8//!   from the MLS exporter (`gbp/control`, `gbp/audio`, `gbp/text`,
9//!   `gbp/signal`).
10//! * `seal` / `open` — ChaCha20-Poly1305 AEAD with the labelled-exporter key.
11//!
12//! On every epoch change the old key material is invalidated automatically:
13//! the AEAD key is derived on the fly from `MlsGroup::export_secret`, never
14//! cached, and the previous epoch's secret becomes unreachable as soon as the
15//! group ratchets forward.
16
17#![deny(missing_docs)]
18
19use chacha20poly1305::{
20    ChaCha20Poly1305, Key, Nonce,
21    aead::{Aead, KeyInit},
22};
23use gbp_core::StreamType;
24use openmls::prelude::tls_codec::DeserializeBytes as _;
25use openmls::prelude::tls_codec::Serialize as _;
26use openmls::prelude::*;
27use openmls_basic_credential::SignatureKeyPair;
28use openmls_rust_crypto::{MemoryStorage, OpenMlsRustCrypto};
29use std::collections::HashMap;
30
/// MLS ciphersuite used by the stack: X25519-AES128GCM-SHA256-Ed25519.
///
/// Per RFC 9420 this fixes DHKEM(X25519) for HPKE, AES-128-GCM as the
/// MLS-internal AEAD, SHA-256 for hashing/KDF and Ed25519 for member
/// signatures (see `new_member`, which sizes the signing key from
/// `CIPHERSUITE.signature_algorithm()`). The stream payload AEAD used by
/// [`MlsContext::seal`] / [`MlsContext::open`] is ChaCha20-Poly1305 keyed
/// from the MLS exporter, independent of this suite's AEAD.
pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
33
/// Stream class whose AEAD key is derived from the MLS exporter.
///
/// Every variant maps to its own exporter label, so each stream class gets
/// an independent key within an epoch.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum StreamLabel {
    /// Control plane — exported under `gbp/control`.
    Control,
    /// GAP (audio) — exported under `gbp/audio`.
    Audio,
    /// GTP (text) — exported under `gbp/text`.
    Text,
    /// GSP (signal) — exported under `gbp/signal`.
    Signal,
}

impl StreamLabel {
    /// Label string handed to `MlsGroup::export_secret` for this stream.
    ///
    /// These strings are part of the cross-member contract: all members must
    /// derive the same key from them, so they must never change.
    pub fn as_str(self) -> &'static str {
        use StreamLabel as L;
        match self {
            L::Control => "gbp/control",
            L::Audio => "gbp/audio",
            L::Text => "gbp/text",
            L::Signal => "gbp/signal",
        }
    }
}
58
59/// Maps a [`StreamType`] to the corresponding [`StreamLabel`].
60pub fn label_for(st: StreamType) -> StreamLabel {
61    match st {
62        StreamType::Control => StreamLabel::Control,
63        StreamType::Audio => StreamLabel::Audio,
64        StreamType::Text => StreamLabel::Text,
65        StreamType::Signal => StreamLabel::Signal,
66    }
67}
68
/// Categorises an MLS message processed via
/// [`MlsContext::process_message`].
///
/// Only [`ProcessedKind::Commit`] leaves lasting state on the context; for
/// the other kinds `process_message` drops the openmls payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessedKind {
    /// A Commit was validated and staged in [`MlsContext::pending_staged`].
    /// The epoch does **not** advance until
    /// [`MlsContext::finalize_pending_commit`] merges it.
    Commit,
    /// An Application message was decrypted; its plaintext is discarded
    /// (not used by this stack — GBP carries application data outside MLS
    /// application messages).
    Application,
    /// A Proposal message was processed. `process_message` drops the
    /// returned proposal rather than storing it as pending.
    Proposal,
    /// An external join proposal was processed and dropped; the group is
    /// not changed.
    External,
}
83
/// Errors raised by the MLS / AEAD layer.
#[derive(Debug, thiserror::Error)]
pub enum MlsError {
    /// Any failure reported by `openmls`, or by this crate's own input
    /// validation and state-blob parsing, carried as a human-readable string.
    #[error("openmls: {0}")]
    OpenMls(String),
    /// ChaCha20-Poly1305 seal or open failure. On open this covers a key
    /// from a different label/epoch, a wrong `seq`, or tampered ciphertext.
    #[error("aead: {0}")]
    Aead(String),
    /// A staged commit is already held — the previous transition must be
    /// finalised or cleared before processing another commit (returned, for
    /// example, by [`MlsContext::process_message`] when a second Commit
    /// arrives).
    #[error("transition in progress: pending staged commit exists")]
    TransitionInProgress,
}
98
/// MLS context for a single group member.
///
/// Owns the OpenMLS provider (crypto plus key-value storage), the signing
/// key, the credential and the current `MlsGroup`. The local epoch moves
/// forward when a commit is merged — by [`MlsContext::invite`], or by
/// [`MlsContext::finalize_pending_commit`] after
/// [`MlsContext::invite_full`] / [`MlsContext::remove_members`] /
/// [`MlsContext::process_message`] — and when a Welcome replaces the group
/// in [`MlsContext::accept_welcome`].
pub struct MlsContext {
    /// OpenMLS crypto provider. Its storage holds the group state and the
    /// private key material that [`MlsContext::export_state`] persists.
    pub provider: OpenMlsRustCrypto,
    /// Signing key pair for this member.
    pub signer: SignatureKeyPair,
    /// Current MLS group.
    pub group: MlsGroup,
    /// Credential with the public signing key.
    pub credential: CredentialWithKey,
    /// Member identity (opaque application-defined bytes); also used as the
    /// `BasicCredential` identity.
    pub identity: Vec<u8>,
    /// Staged commit produced by [`MlsContext::process_message`] but not
    /// yet merged. Held until [`MlsContext::finalize_pending_commit`] (on
    /// EXECUTE_TRANSITION) so that the local epoch only advances together
    /// with the rest of the group, never earlier — otherwise this side's
    /// READY frame would be sealed under an epoch the coordinator can't
    /// open. Not included in [`MlsContext::export_state`]; a restored
    /// context always starts with `None`.
    pub pending_staged: Option<StagedCommit>,
}
123
124// ── Storage (de)serialisation for export_state / restore_state ────────────────
125// MemoryStorage exposes its key-value map as a public field; its built-in
126// serialize/deserialize are behind the `test-utils` feature, so we (de)serialise
127// the map ourselves with a simple length-prefixed (u32-LE) record format.
128
129fn serialize_storage(s: &MemoryStorage) -> Result<Vec<u8>, MlsError> {
130    let map = s
131        .values
132        .read()
133        .map_err(|_| MlsError::OpenMls("storage lock poisoned".into()))?;
134    let mut out = Vec::new();
135    out.extend_from_slice(&(map.len() as u32).to_le_bytes());
136    for (k, v) in map.iter() {
137        out.extend_from_slice(&(k.len() as u32).to_le_bytes());
138        out.extend_from_slice(k);
139        out.extend_from_slice(&(v.len() as u32).to_le_bytes());
140        out.extend_from_slice(v);
141    }
142    Ok(out)
143}
144
145fn deserialize_storage(bytes: &[u8]) -> Result<HashMap<Vec<u8>, Vec<u8>>, MlsError> {
146    let mut cur = bytes;
147    fn rd_u32(cur: &mut &[u8]) -> Result<usize, MlsError> {
148        if cur.len() < 4 {
149            return Err(MlsError::OpenMls("truncated storage blob".into()));
150        }
151        let n = u32::from_le_bytes([cur[0], cur[1], cur[2], cur[3]]) as usize;
152        *cur = &cur[4..];
153        Ok(n)
154    }
155    fn rd_bytes<'a>(cur: &mut &'a [u8], len: usize) -> Result<&'a [u8], MlsError> {
156        if cur.len() < len {
157            return Err(MlsError::OpenMls("truncated storage blob".into()));
158        }
159        let (head, tail) = cur.split_at(len);
160        *cur = tail;
161        Ok(head)
162    }
163    let count = rd_u32(&mut cur)?;
164    let mut map = HashMap::with_capacity(count);
165    for _ in 0..count {
166        let klen = rd_u32(&mut cur)?;
167        let k = rd_bytes(&mut cur, klen)?.to_vec();
168        let vlen = rd_u32(&mut cur)?;
169        let v = rd_bytes(&mut cur, vlen)?.to_vec();
170        map.insert(k, v);
171    }
172    Ok(map)
173}
174
175impl MlsContext {
176    /// Creates a new context with a single-member group, returning the
177    /// context together with a [`KeyPackageBundle`] that other members can
178    /// use to invite this one.
179    pub fn new_member(identity: &[u8]) -> Result<(Self, KeyPackageBundle), MlsError> {
180        let provider = OpenMlsRustCrypto::default();
181        let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())
182            .map_err(|e| MlsError::OpenMls(format!("signer: {e:?}")))?;
183        signer
184            .store(provider.storage())
185            .map_err(|e| MlsError::OpenMls(format!("store signer: {e:?}")))?;
186
187        let credential = BasicCredential::new(identity.to_vec());
188        let credential_with_key = CredentialWithKey {
189            credential: credential.into(),
190            signature_key: signer.public().into(),
191        };
192
193        let kp_bundle = KeyPackage::builder()
194            .build(CIPHERSUITE, &provider, &signer, credential_with_key.clone())
195            .map_err(|e| MlsError::OpenMls(format!("kp: {e:?}")))?;
196
197        let cfg = MlsGroupCreateConfig::builder()
198            .ciphersuite(CIPHERSUITE)
199            .use_ratchet_tree_extension(true)
200            .build();
201        let group = MlsGroup::new(&provider, &signer, &cfg, credential_with_key.clone())
202            .map_err(|e| MlsError::OpenMls(format!("group: {e:?}")))?;
203
204        Ok((
205            Self {
206                provider,
207                signer,
208                group,
209                credential: credential_with_key,
210                identity: identity.to_vec(),
211                pending_staged: None,
212            },
213            kp_bundle,
214        ))
215    }
216
217    /// Result of [`MlsContext::invite_full`]: the Commit message that
218    /// existing members must apply via [`MlsContext::process_message`],
219    /// plus the Welcome that the new joiner must apply via
220    /// [`MlsContext::accept_welcome`].
221    ///
222    /// RFC 9420 §11/§12.4 — Welcome is for the joiner only; existing members
223    /// MUST receive the Commit to advance their epoch.
224    ///
225    /// IMPORTANT: this call **does not** merge the pending commit. The
226    /// caller MUST call [`MlsContext::finalize_pending_commit`] only after
227    /// they are confident the Commit/Welcome have been distributed (e.g.
228    /// the GBP coordinator has observed READY quorum). If the distribution
229    /// fails, call [`MlsContext::clear_pending_commit`] to roll back.
230    pub fn invite_full(
231        &mut self,
232        key_packages: &[KeyPackage],
233    ) -> Result<(Vec<u8>, Vec<u8>), MlsError> {
234        let (commit, welcome, _gi) = self
235            .group
236            .add_members(&self.provider, &self.signer, key_packages)
237            .map_err(|e| MlsError::OpenMls(format!("add_members: {e:?}")))?;
238        let commit_bytes = commit
239            .tls_serialize_detached()
240            .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))?;
241        let welcome_bytes = welcome
242            .tls_serialize_detached()
243            .map_err(|e| MlsError::OpenMls(format!("welcome serialize: {e:?}")))?;
244        Ok((commit_bytes, welcome_bytes))
245    }
246
247    /// Backwards-compatible wrapper. Builds the Commit, eagerly merges, and
248    /// returns only the Welcome bytes. Kept for callers that distribute the
249    /// Commit out-of-band and don't need atomic abort semantics.
250    pub fn invite(&mut self, key_packages: &[KeyPackage]) -> Result<Vec<u8>, MlsError> {
251        let (_commit, welcome) = self.invite_full(key_packages)?;
252        self.finalize_pending_commit()?;
253        Ok(welcome)
254    }
255
256    /// Removes members identified by their MLS LeafIndex via a Remove commit
257    /// and returns the TLS-serialised Commit message that remaining members
258    /// must apply via [`MlsContext::process_message`].
259    ///
260    /// Like [`MlsContext::invite_full`], the caller is responsible for
261    /// calling [`MlsContext::finalize_pending_commit`] after successful
262    /// distribution, or [`MlsContext::clear_pending_commit`] on failure.
263    /// RFC 9420 §12.3.
264    pub fn remove_members(&mut self, leaf_indices: &[u32]) -> Result<Vec<u8>, MlsError> {
265        // Validate indices against the current group size up front so the
266        // caller gets a clear error rather than an opaque openmls failure.
267        let group_size = self.group.members().count() as u32;
268        for &idx in leaf_indices {
269            if idx >= group_size {
270                return Err(MlsError::OpenMls(format!(
271                    "leaf_index {idx} out of range (group size {group_size})"
272                )));
273            }
274        }
275        let leaves: Vec<LeafNodeIndex> = leaf_indices
276            .iter()
277            .copied()
278            .map(LeafNodeIndex::new)
279            .collect();
280        let (commit, _welcome_opt, _gi) = self
281            .group
282            .remove_members(&self.provider, &self.signer, &leaves)
283            .map_err(|e| MlsError::OpenMls(format!("remove_members: {e:?}")))?;
284        commit
285            .tls_serialize_detached()
286            .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))
287    }
288
289    /// Merges any pending commit. Handles both:
290    /// * a self-issued commit produced by [`MlsContext::invite_full`] /
291    ///   [`MlsContext::remove_members`] (merged via `merge_pending_commit`);
292    /// * a staged commit deposited by [`MlsContext::process_message`]
293    ///   (merged via `merge_staged_commit`, consumed from
294    ///   [`MlsContext::pending_staged`]).
295    ///
296    /// Idempotent: if there is nothing to merge, returns Ok. Called from
297    /// the GBP control plane in response to `EXECUTE_TRANSITION`.
298    pub fn finalize_pending_commit(&mut self) -> Result<(), MlsError> {
299        if let Some(staged) = self.pending_staged.take() {
300            self.group
301                .merge_staged_commit(&self.provider, staged)
302                .map_err(|e| MlsError::OpenMls(format!("merge_staged: {e:?}")))?;
303        }
304        // merge_pending_commit errors if there's nothing to merge — for
305        // members that only received a commit (no self-issued one) that's
306        // expected, so swallow the error. Self-issued commits are merged
307        // via this path on the coordinator side.
308        let _ = self.group.merge_pending_commit(&self.provider);
309        Ok(())
310    }
311
312    /// Discards any pending commit (self-issued and/or staged) without
313    /// applying it. Used on `ABORT_TRANSITION`.
314    pub fn clear_pending_commit(&mut self) -> Result<(), MlsError> {
315        self.pending_staged = None;
316        self.group
317            .clear_pending_commit(self.provider.storage())
318            .map_err(|e| MlsError::OpenMls(format!("clear: {e:?}")))?;
319        Ok(())
320    }
321
322    /// Applies a Commit (or staged Proposal) message to the group. Existing
323    /// members invoke this after receiving the Commit broadcast embedded in
324    /// `PREPARE_TRANSITION` args.
325    ///
326    /// IMPORTANT: a Commit is staged but **not** merged here. It must be
327    /// merged via [`MlsContext::finalize_pending_commit`] in response to the
328    /// matching `EXECUTE_TRANSITION`, so that this side's MLS epoch
329    /// advances together with the rest of the group — never earlier.
330    /// Calling this twice without an intervening finalize/clear discards
331    /// the previously staged commit (the second call wins).
332    pub fn process_message(&mut self, msg_bytes: &[u8]) -> Result<ProcessedKind, MlsError> {
333        let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(msg_bytes)
334            .map_err(|e| MlsError::OpenMls(format!("msg parse: {e:?}")))?;
335        let protocol_msg = match msg_in.extract() {
336            MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::from(m),
337            MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::from(m),
338            other => {
339                return Err(MlsError::OpenMls(format!(
340                    "expected protocol message, got {other:?}"
341                )));
342            }
343        };
344        let processed = self
345            .group
346            .process_message(&self.provider, protocol_msg)
347            .map_err(|e| MlsError::OpenMls(format!("process: {e:?}")))?;
348        match processed.into_content() {
349            ProcessedMessageContent::StagedCommitMessage(staged) => {
350                if self.pending_staged.is_some() {
351                    return Err(MlsError::TransitionInProgress);
352                }
353                self.pending_staged = Some(*staged);
354                Ok(ProcessedKind::Commit)
355            }
356            ProcessedMessageContent::ApplicationMessage(_) => Ok(ProcessedKind::Application),
357            ProcessedMessageContent::ProposalMessage(_) => Ok(ProcessedKind::Proposal),
358            ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(ProcessedKind::External),
359        }
360    }
361
362    /// Replaces the local group with the one described by the given
363    /// `Welcome` message.
364    pub fn accept_welcome(&mut self, welcome_bytes: &[u8]) -> Result<(), MlsError> {
365        let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(welcome_bytes)
366            .map_err(|e| MlsError::OpenMls(format!("welcome parse: {e:?}")))?;
367        let welcome = match msg_in.extract() {
368            MlsMessageBodyIn::Welcome(w) => w,
369            other => {
370                return Err(MlsError::OpenMls(format!(
371                    "expected welcome, got {other:?}"
372                )));
373            }
374        };
375        let join_cfg = MlsGroupJoinConfig::builder()
376            .use_ratchet_tree_extension(true)
377            .build();
378        let staged = StagedWelcome::new_from_welcome(&self.provider, &join_cfg, welcome, None)
379            .map_err(|e| MlsError::OpenMls(format!("staged: {e:?}")))?;
380        self.group = staged
381            .into_group(&self.provider)
382            .map_err(|e| MlsError::OpenMls(format!("into_group: {e:?}")))?;
383        Ok(())
384    }
385
386    /// Returns the current group epoch.
387    pub fn epoch(&self) -> u64 {
388        self.group.epoch().as_u64()
389    }
390
391    /// Returns the 16-byte group identifier (truncated or zero-padded if the
392    /// underlying MLS group_id has a different length).
393    pub fn group_id_16(&self) -> [u8; 16] {
394        let raw = self.group.group_id().as_slice();
395        let mut out = [0u8; 16];
396        let n = raw.len().min(16);
397        out[..n].copy_from_slice(&raw[..n]);
398        out
399    }
400
401    /// Serialises the full local MLS state into an opaque blob that
402    /// [`MlsContext::restore_state`] can reconstruct verbatim. Lets a client
403    /// persist the context (disk / IndexedDB) so a chat survives a restart
404    /// without re-establishing the group — the basis for deterministic,
405    /// reload-surviving secret chats.
406    ///
407    /// The blob bundles four length-prefixed (u32-LE) sections:
408    /// `[provider storage | signer | identity | group_id]`. It contains
409    /// **private key material** — callers MUST store it encrypted at rest.
410    pub fn export_state(&self) -> Result<Vec<u8>, MlsError> {
411        let storage_buf = serialize_storage(self.provider.storage())?;
412        let signer_buf = self
413            .signer
414            .tls_serialize_detached()
415            .map_err(|e| MlsError::OpenMls(format!("signer serialize: {e:?}")))?;
416        let gid = self.group.group_id().as_slice().to_vec();
417
418        let mut out = Vec::with_capacity(16 + storage_buf.len() + signer_buf.len() + self.identity.len() + gid.len());
419        for part in [
420            storage_buf.as_slice(),
421            signer_buf.as_slice(),
422            self.identity.as_slice(),
423            gid.as_slice(),
424        ] {
425            out.extend_from_slice(&(part.len() as u32).to_le_bytes());
426            out.extend_from_slice(part);
427        }
428        Ok(out)
429    }
430
431    /// Reconstructs a context from a blob produced by
432    /// [`MlsContext::export_state`]. The restored context is at the same epoch
433    /// with the same group state, signer and identity, and can immediately
434    /// send / receive again.
435    pub fn restore_state(blob: &[u8]) -> Result<Self, MlsError> {
436        let mut cur = blob;
437        let mut take = || -> Result<&[u8], MlsError> {
438            if cur.len() < 4 {
439                return Err(MlsError::OpenMls("truncated state blob (length)".into()));
440            }
441            let len = u32::from_le_bytes([cur[0], cur[1], cur[2], cur[3]]) as usize;
442            cur = &cur[4..];
443            if cur.len() < len {
444                return Err(MlsError::OpenMls("truncated state blob (body)".into()));
445            }
446            let (head, tail) = cur.split_at(len);
447            cur = tail;
448            Ok(head)
449        };
450        let storage_bytes = take()?.to_vec();
451        let signer_bytes = take()?.to_vec();
452        let identity = take()?.to_vec();
453        let gid_bytes = take()?.to_vec();
454
455        // Rehydrate a fresh provider's (public) key-value map from the blob.
456        let provider = OpenMlsRustCrypto::default();
457        let map = deserialize_storage(&storage_bytes)?;
458        *provider
459            .storage()
460            .values
461            .write()
462            .map_err(|_| MlsError::OpenMls("storage lock poisoned".into()))? = map;
463
464        let signer = SignatureKeyPair::tls_deserialize_exact_bytes(&signer_bytes)
465            .map_err(|e| MlsError::OpenMls(format!("signer parse: {e:?}")))?;
466        let credential = CredentialWithKey {
467            credential: BasicCredential::new(identity.clone()).into(),
468            signature_key: signer.public().into(),
469        };
470        let group_id = GroupId::from_slice(&gid_bytes);
471        let group = MlsGroup::load(provider.storage(), &group_id)
472            .map_err(|e| MlsError::OpenMls(format!("group load: {e:?}")))?
473            .ok_or_else(|| MlsError::OpenMls("no group in restored state".into()))?;
474
475        Ok(Self {
476            provider,
477            signer,
478            group,
479            credential,
480            identity,
481            pending_staged: None,
482        })
483    }
484
485    /// Exports a 32-byte secret under the given stream label.
486    pub fn export_stream_key(&self, label: StreamLabel) -> Result<[u8; 32], MlsError> {
487        let secret = self
488            .group
489            .export_secret(self.provider.crypto(), label.as_str(), &[], 32)
490            .map_err(|e| MlsError::OpenMls(format!("export: {e:?}")))?;
491        let mut out = [0u8; 32];
492        out.copy_from_slice(&secret);
493        Ok(out)
494    }
495
496    /// Exports `len` bytes under an arbitrary `label` and `context`.
497    ///
498    /// Used by external crates (e.g. `hush-sframe`) that need custom KDF
499    /// labels without depending on OpenMLS directly.
500    pub fn export_raw(&self, label: &str, context: &[u8], len: usize) -> Result<Vec<u8>, MlsError> {
501        let secret = self
502            .group
503            .export_secret(self.provider.crypto(), label, context, len)
504            .map_err(|e| MlsError::OpenMls(format!("export_raw: {e:?}")))?;
505        Ok(secret.to_vec())
506    }
507
508    /// Encrypts `plaintext` with ChaCha20-Poly1305 using the stream-labelled
509    /// AEAD key and a nonce derived from the per-stream `seq`.
510    pub fn seal(
511        &self,
512        label: StreamLabel,
513        seq: u32,
514        plaintext: &[u8],
515    ) -> Result<Vec<u8>, MlsError> {
516        let key = self.export_stream_key(label)?;
517        let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
518        let mut nonce = [0u8; 12];
519        nonce[..4].copy_from_slice(&seq.to_be_bytes());
520        cipher
521            .encrypt(Nonce::from_slice(&nonce), plaintext)
522            .map_err(|e| MlsError::Aead(e.to_string()))
523    }
524
525    /// Decrypts `ciphertext` with the same parameters as [`MlsContext::seal`].
526    pub fn open(
527        &self,
528        label: StreamLabel,
529        seq: u32,
530        ciphertext: &[u8],
531    ) -> Result<Vec<u8>, MlsError> {
532        let key = self.export_stream_key(label)?;
533        let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
534        let mut nonce = [0u8; 12];
535        nonce[..4].copy_from_slice(&seq.to_be_bytes());
536        cipher
537            .decrypt(Nonce::from_slice(&nonce), ciphertext)
538            .map_err(|e| MlsError::Aead(e.to_string()))
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    fn alice() -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(b"alice").unwrap()
    }

    fn bob() -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(b"bob").unwrap()
    }

    /// Alice (group creator, leaf 0) and Bob (added at leaf 1) sharing a
    /// group at epoch 1.
    fn alice_and_bob() -> (MlsContext, MlsContext) {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        (alice, bob)
    }

    #[test]
    fn stream_label_strings_are_correct() {
        assert_eq!(StreamLabel::Control.as_str(), "gbp/control");
        assert_eq!(StreamLabel::Audio.as_str(), "gbp/audio");
        assert_eq!(StreamLabel::Text.as_str(), "gbp/text");
        assert_eq!(StreamLabel::Signal.as_str(), "gbp/signal");
    }

    #[test]
    fn label_for_maps_every_stream_type() {
        assert_eq!(label_for(StreamType::Control), StreamLabel::Control);
        assert_eq!(label_for(StreamType::Audio), StreamLabel::Audio);
        assert_eq!(label_for(StreamType::Text), StreamLabel::Text);
        assert_eq!(label_for(StreamType::Signal), StreamLabel::Signal);
    }

    #[test]
    fn new_member_starts_at_epoch_zero() {
        let (ctx, _kp) = alice();
        assert_eq!(ctx.epoch(), 0);
    }

    #[test]
    fn group_id_16_is_16_bytes() {
        let (ctx, _kp) = alice();
        let id = ctx.group_id_16();
        assert_eq!(id.len(), 16);
    }

    #[test]
    fn export_stream_key_is_32_bytes_and_stable() {
        let (ctx, _kp) = alice();
        let k1 = ctx.export_stream_key(StreamLabel::Text).unwrap();
        let k2 = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_eq!(k1.len(), 32);
        assert_eq!(k1, k2);
    }

    #[test]
    fn different_labels_produce_different_keys() {
        let (ctx, _kp) = alice();
        let k_ctrl = ctx.export_stream_key(StreamLabel::Control).unwrap();
        let k_text = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_ne!(k_ctrl, k_text);
    }

    #[test]
    fn seal_open_single_member_round_trip() {
        let (ctx, _kp) = alice();
        let plaintext = b"hello world";
        let ciphertext = ctx.seal(StreamLabel::Text, 1, plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);
        let recovered = ctx.open(StreamLabel::Text, 1, &ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn seal_wrong_seq_fails_to_open() {
        let (ctx, _kp) = alice();
        let ciphertext = ctx.seal(StreamLabel::Text, 1, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Text, 2, &ciphertext).is_err());
    }

    #[test]
    fn seal_wrong_label_fails_to_open() {
        let (ctx, _kp) = alice();
        let ciphertext = ctx.seal(StreamLabel::Text, 0, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Audio, 0, &ciphertext).is_err());
    }

    #[test]
    fn two_member_invite_and_welcome() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();

        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        // Alice's epoch advances after invite.
        assert_eq!(alice.epoch(), 1);

        bob.accept_welcome(&welcome).unwrap();
        // Bob joins at epoch 1.
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn two_member_seal_open_cross_member() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();

        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();

        let plaintext = b"cross-member secret";
        let ct = alice.seal(StreamLabel::Control, 0, plaintext).unwrap();
        let recovered = bob.open(StreamLabel::Control, 0, &ct).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn export_raw_returns_requested_length() {
        let (ctx, _kp) = alice();
        let raw = ctx.export_raw("test/label", b"ctx", 48).unwrap();
        assert_eq!(raw.len(), 48);
    }

    #[test]
    fn clear_pending_commit_is_idempotent() {
        let (mut ctx, _kp) = alice();
        ctx.clear_pending_commit().unwrap();
        ctx.clear_pending_commit().unwrap();
    }

    #[test]
    fn finalize_pending_commit_on_fresh_group_is_ok() {
        let (mut ctx, _kp) = alice();
        ctx.finalize_pending_commit().unwrap();
    }

    #[test]
    fn invite_full_does_not_advance_epoch_until_finalize() {
        let (mut alice, _akp) = alice();
        let (_bob, bob_kp) = bob();

        let (_commit, _welcome) = alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        // invite_full does NOT merge → epoch still 0
        assert_eq!(alice.epoch(), 0);

        alice.finalize_pending_commit().unwrap();
        // after finalize → epoch 1
        assert_eq!(alice.epoch(), 1);

        // New members join via welcome, not via commit.
        let (mut alice2, _akp2) = MlsContext::new_member(b"alice2").unwrap();
        let (mut bob2, bob2_kp) = MlsContext::new_member(b"bob2").unwrap();
        let (_commit_bytes, welcome_bytes) = alice2
            .invite_full(&[bob2_kp.key_package().clone()])
            .unwrap();
        alice2.finalize_pending_commit().unwrap();
        bob2.accept_welcome(&welcome_bytes).unwrap();
        assert_eq!(alice2.epoch(), 1);
        assert_eq!(bob2.epoch(), 1);
    }

    #[test]
    fn received_commit_is_staged_until_finalize() {
        // Existing member path: PREPARE_TRANSITION carries the Commit,
        // EXECUTE_TRANSITION merges it.
        let (mut alice, mut bob) = alice_and_bob();
        let (mut carol, carol_kp) = MlsContext::new_member(b"carol").unwrap();

        let (commit, welcome) = alice
            .invite_full(&[carol_kp.key_package().clone()])
            .unwrap();
        assert_eq!(bob.process_message(&commit).unwrap(), ProcessedKind::Commit);
        // Staged, not merged: Bob stays at epoch 1 until finalize.
        assert_eq!(bob.epoch(), 1);
        assert!(bob.pending_staged.is_some());

        alice.finalize_pending_commit().unwrap();
        bob.finalize_pending_commit().unwrap();
        carol.accept_welcome(&welcome).unwrap();
        assert!(bob.pending_staged.is_none());
        assert_eq!(alice.epoch(), 2);
        assert_eq!(bob.epoch(), 2);
        assert_eq!(carol.epoch(), 2);

        let ct = alice.seal(StreamLabel::Text, 9, b"three of us").unwrap();
        assert_eq!(bob.open(StreamLabel::Text, 9, &ct).unwrap(), b"three of us");
        assert_eq!(carol.open(StreamLabel::Text, 9, &ct).unwrap(), b"three of us");
    }

    #[test]
    fn second_commit_while_staged_is_rejected_and_clear_rolls_back() {
        let (mut alice, mut bob) = alice_and_bob();
        let (_carol, carol_kp) = MlsContext::new_member(b"carol").unwrap();
        let (_dave, dave_kp) = MlsContext::new_member(b"dave").unwrap();

        let (first, _w1) = alice
            .invite_full(&[carol_kp.key_package().clone()])
            .unwrap();
        assert_eq!(bob.process_message(&first).unwrap(), ProcessedKind::Commit);

        // Coordinator aborts and issues a different commit for the same epoch.
        alice.clear_pending_commit().unwrap();
        let (second, _w2) = alice
            .invite_full(&[dave_kp.key_package().clone()])
            .unwrap();
        assert!(matches!(
            bob.process_message(&second),
            Err(MlsError::TransitionInProgress)
        ));
        assert!(bob.pending_staged.is_some());

        // ABORT_TRANSITION: Bob drops the staged commit; epoch never moved.
        bob.clear_pending_commit().unwrap();
        assert!(bob.pending_staged.is_none());
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn remove_members_rejects_unknown_leaf() {
        let (mut alice, _kp) = alice();
        assert!(alice.remove_members(&[7]).is_err());
        assert_eq!(alice.epoch(), 0);
    }

    #[test]
    fn remove_member_commit_shrinks_group_after_finalize() {
        let (mut alice, mut bob) = alice_and_bob();

        // Alice created the group (leaf 0); Bob was added at leaf 1.
        let commit = alice.remove_members(&[1]).unwrap();
        // Not merged yet.
        assert_eq!(alice.epoch(), 1);
        assert_eq!(bob.process_message(&commit).unwrap(), ProcessedKind::Commit);

        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 2);
        assert_eq!(alice.group.members().count(), 1);
    }

    #[test]
    fn message_kind_mismatches_are_rejected() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let (commit, welcome) = alice
            .invite_full(&[bob_kp.key_package().clone()])
            .unwrap();

        // A Commit is not a Welcome, and a Welcome is not a protocol message.
        assert!(bob.accept_welcome(&commit).is_err());
        assert!(bob.process_message(&welcome).is_err());
        assert!(bob.process_message(b"not an mls message").is_err());
        assert_eq!(bob.epoch(), 0);
    }

    #[test]
    fn export_restore_round_trip_preserves_state() {
        let (ctx, _kp) = alice();
        let blob = ctx.export_state().unwrap();
        let restored = MlsContext::restore_state(&blob).unwrap();
        assert_eq!(restored.epoch(), ctx.epoch());
        assert_eq!(restored.group_id_16(), ctx.group_id_16());
        // Identical exporter secret ⇒ the full group state was restored.
        assert_eq!(
            restored.export_stream_key(StreamLabel::Text).unwrap(),
            ctx.export_stream_key(StreamLabel::Text).unwrap()
        );
    }

    #[test]
    fn restored_context_can_seal_and_open() {
        let (ctx, _kp) = alice();
        let blob = ctx.export_state().unwrap();
        let restored = MlsContext::restore_state(&blob).unwrap();
        let ct = restored.seal(StreamLabel::Text, 7, b"after restore").unwrap();
        assert_eq!(restored.open(StreamLabel::Text, 7, &ct).unwrap(), b"after restore");
    }

    #[test]
    fn export_restore_preserves_multi_member_group() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        assert_eq!(alice.epoch(), 1);

        // Persist Alice at epoch 1, then restore from the blob.
        let blob = alice.export_state().unwrap();
        let restored_alice = MlsContext::restore_state(&blob).unwrap();
        assert_eq!(restored_alice.epoch(), 1);

        // Restored Alice still shares the group key with Bob.
        let ct = restored_alice.seal(StreamLabel::Control, 3, b"still in group").unwrap();
        assert_eq!(bob.open(StreamLabel::Control, 3, &ct).unwrap(), b"still in group");
    }

    #[test]
    fn multi_member_invite_one_welcome_serves_all_joiners() {
        // A single Add commit for several KeyPackages yields ONE Welcome that
        // every new member accepts with their own KeyPackage (RFC 9420 §12.4).
        // This is what Hush secret groups rely on: claim N KeyPackages, one
        // invite, broadcast one Welcome. (Existing tests only ever added one
        // joiner at a time — this covers the multi-element slice.)
        let (mut alice, _a) = alice();
        let (mut bob, bob_kp) = bob();
        let (mut carol, carol_kp) = MlsContext::new_member(b"carol").unwrap();

        let welcome = alice
            .invite(&[bob_kp.key_package().clone(), carol_kp.key_package().clone()])
            .unwrap();
        assert_eq!(alice.epoch(), 1, "one Add commit advances the epoch once");

        // Both joiners accept the SAME Welcome and land at the same epoch.
        bob.accept_welcome(&welcome).unwrap();
        carol.accept_welcome(&welcome).unwrap();
        assert_eq!(bob.epoch(), 1);
        assert_eq!(carol.epoch(), 1);

        // All three share the group key → mutual decryption.
        let ct = alice.seal(StreamLabel::Text, 1, b"hello group").unwrap();
        assert_eq!(bob.open(StreamLabel::Text, 1, &ct).unwrap(), b"hello group");
        assert_eq!(carol.open(StreamLabel::Text, 1, &ct).unwrap(), b"hello group");
    }

    #[test]
    fn restored_prekey_accepts_welcome() {
        // A published KeyPackage's owner persists its context (export_state),
        // then a fresh process restores it (restore_state) — the restored
        // context MUST still accept a Welcome targeting that KeyPackage. This is
        // the secret-DM reload path (Hush ADR-0023): the joiner's pre-key
        // survives a reload (e.g. browser IndexedDB) and can still join.
        let (mut alice, _akp) = alice();
        let (bob, bob_kp) = bob();

        // Persist bob's pre-key context, then drop the live one (simulate reload).
        let bob_blob = bob.export_state().unwrap();
        let bob_kp_inner = bob_kp.key_package().clone();
        drop(bob);

        // Alice invites bob's published KeyPackage.
        let welcome = alice.invite(&[bob_kp_inner]).unwrap();
        assert_eq!(alice.epoch(), 1);

        // Bob restored from the blob accepts the Welcome — i.e. the private
        // KeyPackage keys (init/encryption) survived export/restore.
        let mut bob_restored = MlsContext::restore_state(&bob_blob).unwrap();
        bob_restored.accept_welcome(&welcome).unwrap();
        assert_eq!(bob_restored.epoch(), 1);

        // Mutual decryption confirms the shared group.
        let ct = alice.seal(StreamLabel::Text, 1, b"after reload").unwrap();
        assert_eq!(bob_restored.open(StreamLabel::Text, 1, &ct).unwrap(), b"after reload");
    }

    #[test]
    fn restore_state_rejects_truncated_blob() {
        let (ctx, _kp) = alice();
        let blob = ctx.export_state().unwrap();
        assert!(MlsContext::restore_state(&blob[..blob.len() / 2]).is_err());
        assert!(MlsContext::restore_state(&[]).is_err());
    }
}