Skip to main content

gbp_mls/
lib.rs

1//! MLS (RFC 9420) integration for the Group Protocol Stack.
2//!
3//! This crate provides:
4//!
5//! * [`MlsContext`] — a member-side wrapper around an `openmls 0.8` group
6//!   (signing key, credential, provider, current group).
7//! * [`StreamLabel`] — labelled exporter constants used to derive AEAD keys
8//!   from the MLS exporter (`gbp/control`, `gbp/audio`, `gbp/text`,
9//!   `gbp/signal`).
10//! * `seal` / `open` — ChaCha20-Poly1305 AEAD with the labelled-exporter key.
11//!
12//! On every epoch change the old key material is invalidated automatically:
13//! the AEAD key is derived on the fly from `MlsGroup::export_secret`, never
14//! cached, and the previous epoch's secret becomes unreachable as soon as the
15//! group ratchets forward.
16
17#![deny(missing_docs)]
18
19use chacha20poly1305::{
20    ChaCha20Poly1305, Key, Nonce,
21    aead::{Aead, KeyInit},
22};
23use gbp_core::StreamType;
24use openmls::prelude::tls_codec::Serialize as _;
25use openmls::prelude::*;
26use openmls_basic_credential::SignatureKeyPair;
27use openmls_rust_crypto::OpenMlsRustCrypto;
28
/// MLS ciphersuite used by the stack: X25519-AES128GCM-SHA256-Ed25519.
///
/// [`MlsContext::new_member`] uses this suite for the signing key type, the
/// key package and the group configuration. It is separate from the payload
/// AEAD: [`MlsContext::seal`] / [`MlsContext::open`] use ChaCha20-Poly1305
/// keyed from the MLS exporter.
pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
31
/// Exporter label that binds the AEAD key to a stream class.
///
/// Each variant corresponds to exactly one label string (see
/// [`StreamLabel::as_str`]). Distinct labels yield distinct exporter
/// secrets, so a ciphertext sealed for one stream class cannot be opened
/// under another.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum StreamLabel {
    /// `gbp/control` — control plane key.
    Control,
    /// `gbp/audio` — GAP key.
    Audio,
    /// `gbp/text` — GTP key.
    Text,
    /// `gbp/signal` — GSP key.
    Signal,
}

impl StreamLabel {
    /// Returns the stable string used as the `MlsGroup::export_secret` label.
    ///
    /// The strings are part of the key-derivation input: changing one changes
    /// every key derived for that stream class. The function is `const` so
    /// the labels can also be used in constant contexts (e.g. to build
    /// `const` tables).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Control => "gbp/control",
            Self::Audio => "gbp/audio",
            Self::Text => "gbp/text",
            Self::Signal => "gbp/signal",
        }
    }
}
56
57/// Maps a [`StreamType`] to the corresponding [`StreamLabel`].
58pub fn label_for(st: StreamType) -> StreamLabel {
59    match st {
60        StreamType::Control => StreamLabel::Control,
61        StreamType::Audio => StreamLabel::Audio,
62        StreamType::Text => StreamLabel::Text,
63        StreamType::Signal => StreamLabel::Signal,
64    }
65}
66
/// Categorises an MLS message processed via
/// [`MlsContext::process_message`].
///
/// The value only reports what kind of message was seen; apart from a staged
/// commit, `process_message` keeps none of the message content.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessedKind {
    /// A Commit message was validated and staged in
    /// [`MlsContext::pending_staged`]. The epoch does **not** advance until
    /// [`MlsContext::finalize_pending_commit`] merges it.
    Commit,
    /// An Application message was decrypted (not used by this stack — GBP
    /// carries application data outside MLS application messages). The
    /// decrypted payload is discarded.
    Application,
    /// A Proposal-only message was processed. The resulting proposal is
    /// discarded by `process_message`, i.e. it is not added to any pending
    /// proposal store.
    Proposal,
    /// An external join proposal was processed; it is discarded and the
    /// group is not advanced.
    External,
}
81
/// Errors raised by the MLS / AEAD layer.
///
/// Underlying library errors are flattened to strings, so callers can only
/// branch on the variant, not on the original error type.
#[derive(Debug, thiserror::Error)]
pub enum MlsError {
    /// Any error returned by `openmls`, serialised as a string. The string is
    /// a short step tag (e.g. `"add_members"`) followed by the `Debug`
    /// rendering of the original error. Also used for local validation
    /// failures such as an unexpected message type.
    #[error("openmls: {0}")]
    OpenMls(String),
    /// AEAD seal or open failure (carries the AEAD error's display text).
    #[error("aead: {0}")]
    Aead(String),
    /// A pending staged commit already exists — the previous transition must
    /// be finalised or cleared before processing another commit. Returned by
    /// [`MlsContext::process_message`] when a Commit arrives while
    /// [`MlsContext::pending_staged`] is still occupied.
    #[error("transition in progress: pending staged commit exists")]
    TransitionInProgress,
}
96
/// MLS context for a single group member.
///
/// Owns the OpenMLS provider, the signing key, the credential and the
/// current `MlsGroup`. The local epoch moves forward when a commit is merged
/// by [`MlsContext::finalize_pending_commit`] (which [`MlsContext::invite`]
/// calls internally), or when the group is replaced by
/// [`MlsContext::accept_welcome`].
pub struct MlsContext {
    /// OpenMLS crypto provider.
    pub provider: OpenMlsRustCrypto,
    /// Signing key pair for this member.
    pub signer: SignatureKeyPair,
    /// Current MLS group.
    pub group: MlsGroup,
    /// Credential with the public signing key.
    pub credential: CredentialWithKey,
    /// Member identity (opaque application-defined bytes).
    pub identity: Vec<u8>,
    /// Staged commit produced by [`MlsContext::process_message`] but not
    /// yet merged. Held until [`MlsContext::finalize_pending_commit`] (on
    /// EXECUTE_TRANSITION) so that the local epoch only advances together
    /// with the rest of the group, never earlier — otherwise this side's
    /// READY frame would be sealed under an epoch the coordinator can't
    /// open. At most one staged commit is held at a time.
    pub pending_staged: Option<StagedCommit>,
}
121
122impl MlsContext {
123    /// Creates a new context with a single-member group, returning the
124    /// context together with a [`KeyPackageBundle`] that other members can
125    /// use to invite this one.
126    pub fn new_member(identity: &[u8]) -> Result<(Self, KeyPackageBundle), MlsError> {
127        let provider = OpenMlsRustCrypto::default();
128        let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())
129            .map_err(|e| MlsError::OpenMls(format!("signer: {e:?}")))?;
130        signer
131            .store(provider.storage())
132            .map_err(|e| MlsError::OpenMls(format!("store signer: {e:?}")))?;
133
134        let credential = BasicCredential::new(identity.to_vec());
135        let credential_with_key = CredentialWithKey {
136            credential: credential.into(),
137            signature_key: signer.public().into(),
138        };
139
140        let kp_bundle = KeyPackage::builder()
141            .build(CIPHERSUITE, &provider, &signer, credential_with_key.clone())
142            .map_err(|e| MlsError::OpenMls(format!("kp: {e:?}")))?;
143
144        let cfg = MlsGroupCreateConfig::builder()
145            .ciphersuite(CIPHERSUITE)
146            .use_ratchet_tree_extension(true)
147            .build();
148        let group = MlsGroup::new(&provider, &signer, &cfg, credential_with_key.clone())
149            .map_err(|e| MlsError::OpenMls(format!("group: {e:?}")))?;
150
151        Ok((
152            Self {
153                provider,
154                signer,
155                group,
156                credential: credential_with_key,
157                identity: identity.to_vec(),
158                pending_staged: None,
159            },
160            kp_bundle,
161        ))
162    }
163
164    /// Result of [`MlsContext::invite_full`]: the Commit message that
165    /// existing members must apply via [`MlsContext::process_message`],
166    /// plus the Welcome that the new joiner must apply via
167    /// [`MlsContext::accept_welcome`].
168    ///
169    /// RFC 9420 §11/§12.4 — Welcome is for the joiner only; existing members
170    /// MUST receive the Commit to advance their epoch.
171    ///
172    /// IMPORTANT: this call **does not** merge the pending commit. The
173    /// caller MUST call [`MlsContext::finalize_pending_commit`] only after
174    /// they are confident the Commit/Welcome have been distributed (e.g.
175    /// the GBP coordinator has observed READY quorum). If the distribution
176    /// fails, call [`MlsContext::clear_pending_commit`] to roll back.
177    pub fn invite_full(
178        &mut self,
179        key_packages: &[KeyPackage],
180    ) -> Result<(Vec<u8>, Vec<u8>), MlsError> {
181        let (commit, welcome, _gi) = self
182            .group
183            .add_members(&self.provider, &self.signer, key_packages)
184            .map_err(|e| MlsError::OpenMls(format!("add_members: {e:?}")))?;
185        let commit_bytes = commit
186            .tls_serialize_detached()
187            .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))?;
188        let welcome_bytes = welcome
189            .tls_serialize_detached()
190            .map_err(|e| MlsError::OpenMls(format!("welcome serialize: {e:?}")))?;
191        Ok((commit_bytes, welcome_bytes))
192    }
193
194    /// Backwards-compatible wrapper. Builds the Commit, eagerly merges, and
195    /// returns only the Welcome bytes. Kept for callers that distribute the
196    /// Commit out-of-band and don't need atomic abort semantics.
197    pub fn invite(&mut self, key_packages: &[KeyPackage]) -> Result<Vec<u8>, MlsError> {
198        let (_commit, welcome) = self.invite_full(key_packages)?;
199        self.finalize_pending_commit()?;
200        Ok(welcome)
201    }
202
203    /// Removes members identified by their MLS LeafIndex via a Remove commit
204    /// and returns the TLS-serialised Commit message that remaining members
205    /// must apply via [`MlsContext::process_message`].
206    ///
207    /// Like [`MlsContext::invite_full`], the caller is responsible for
208    /// calling [`MlsContext::finalize_pending_commit`] after successful
209    /// distribution, or [`MlsContext::clear_pending_commit`] on failure.
210    /// RFC 9420 §12.3.
211    pub fn remove_members(&mut self, leaf_indices: &[u32]) -> Result<Vec<u8>, MlsError> {
212        // Validate indices against the current group size up front so the
213        // caller gets a clear error rather than an opaque openmls failure.
214        let group_size = self.group.members().count() as u32;
215        for &idx in leaf_indices {
216            if idx >= group_size {
217                return Err(MlsError::OpenMls(format!(
218                    "leaf_index {idx} out of range (group size {group_size})"
219                )));
220            }
221        }
222        let leaves: Vec<LeafNodeIndex> = leaf_indices
223            .iter()
224            .copied()
225            .map(LeafNodeIndex::new)
226            .collect();
227        let (commit, _welcome_opt, _gi) = self
228            .group
229            .remove_members(&self.provider, &self.signer, &leaves)
230            .map_err(|e| MlsError::OpenMls(format!("remove_members: {e:?}")))?;
231        commit
232            .tls_serialize_detached()
233            .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))
234    }
235
236    /// Merges any pending commit. Handles both:
237    /// * a self-issued commit produced by [`MlsContext::invite_full`] /
238    ///   [`MlsContext::remove_members`] (merged via `merge_pending_commit`);
239    /// * a staged commit deposited by [`MlsContext::process_message`]
240    ///   (merged via `merge_staged_commit`, consumed from
241    ///   [`MlsContext::pending_staged`]).
242    ///
243    /// Idempotent: if there is nothing to merge, returns Ok. Called from
244    /// the GBP control plane in response to `EXECUTE_TRANSITION`.
245    pub fn finalize_pending_commit(&mut self) -> Result<(), MlsError> {
246        if let Some(staged) = self.pending_staged.take() {
247            self.group
248                .merge_staged_commit(&self.provider, staged)
249                .map_err(|e| MlsError::OpenMls(format!("merge_staged: {e:?}")))?;
250        }
251        // merge_pending_commit errors if there's nothing to merge — for
252        // members that only received a commit (no self-issued one) that's
253        // expected, so swallow the error. Self-issued commits are merged
254        // via this path on the coordinator side.
255        let _ = self.group.merge_pending_commit(&self.provider);
256        Ok(())
257    }
258
259    /// Discards any pending commit (self-issued and/or staged) without
260    /// applying it. Used on `ABORT_TRANSITION`.
261    pub fn clear_pending_commit(&mut self) -> Result<(), MlsError> {
262        self.pending_staged = None;
263        self.group
264            .clear_pending_commit(self.provider.storage())
265            .map_err(|e| MlsError::OpenMls(format!("clear: {e:?}")))?;
266        Ok(())
267    }
268
    /// Parses one TLS-serialised MLS protocol message, hands it to the
    /// group for processing and reports what kind of message it was.
    /// Existing members invoke this after receiving the Commit broadcast
    /// embedded in `PREPARE_TRANSITION` args.
    ///
    /// IMPORTANT: a Commit is staged but **not** merged here. It must be
    /// merged via [`MlsContext::finalize_pending_commit`] in response to the
    /// matching `EXECUTE_TRANSITION`, so that this side's MLS epoch
    /// advances together with the rest of the group — never earlier.
    ///
    /// Only one staged commit can be pending. If a Commit arrives while
    /// [`MlsContext::pending_staged`] is still occupied, this returns
    /// [`MlsError::TransitionInProgress`], the newly processed commit is
    /// dropped and the previously staged commit is kept. Finalise or clear
    /// the pending transition first.
    ///
    /// Application, Proposal and external-join-proposal messages are only
    /// classified; their content is discarded.
    /// NOTE(review): proposals are not stored as pending proposals — confirm
    /// that this stack never relies on commits that reference standalone
    /// proposals.
    ///
    /// # Errors
    ///
    /// * [`MlsError::OpenMls`] if the bytes do not parse, the message is not
    ///   a public/private protocol message, or openmls rejects it.
    /// * [`MlsError::TransitionInProgress`] as described above.
    pub fn process_message(&mut self, msg_bytes: &[u8]) -> Result<ProcessedKind, MlsError> {
        let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(msg_bytes)
            .map_err(|e| MlsError::OpenMls(format!("msg parse: {e:?}")))?;
        // Only handshake/application framing is accepted here; Welcome and
        // other bodies go through their own entry points.
        let protocol_msg = match msg_in.extract() {
            MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::from(m),
            MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::from(m),
            other => {
                return Err(MlsError::OpenMls(format!(
                    "expected protocol message, got {other:?}"
                )));
            }
        };
        let processed = self
            .group
            .process_message(&self.provider, protocol_msg)
            .map_err(|e| MlsError::OpenMls(format!("process: {e:?}")))?;
        match processed.into_content() {
            ProcessedMessageContent::StagedCommitMessage(staged) => {
                // This check runs after openmls has already processed the
                // message; the first staged commit is never overwritten.
                if self.pending_staged.is_some() {
                    return Err(MlsError::TransitionInProgress);
                }
                self.pending_staged = Some(*staged);
                Ok(ProcessedKind::Commit)
            }
            ProcessedMessageContent::ApplicationMessage(_) => Ok(ProcessedKind::Application),
            ProcessedMessageContent::ProposalMessage(_) => Ok(ProcessedKind::Proposal),
            ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(ProcessedKind::External),
        }
    }
308
309    /// Replaces the local group with the one described by the given
310    /// `Welcome` message.
311    pub fn accept_welcome(&mut self, welcome_bytes: &[u8]) -> Result<(), MlsError> {
312        let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(welcome_bytes)
313            .map_err(|e| MlsError::OpenMls(format!("welcome parse: {e:?}")))?;
314        let welcome = match msg_in.extract() {
315            MlsMessageBodyIn::Welcome(w) => w,
316            other => {
317                return Err(MlsError::OpenMls(format!(
318                    "expected welcome, got {other:?}"
319                )));
320            }
321        };
322        let join_cfg = MlsGroupJoinConfig::builder()
323            .use_ratchet_tree_extension(true)
324            .build();
325        let staged = StagedWelcome::new_from_welcome(&self.provider, &join_cfg, welcome, None)
326            .map_err(|e| MlsError::OpenMls(format!("staged: {e:?}")))?;
327        self.group = staged
328            .into_group(&self.provider)
329            .map_err(|e| MlsError::OpenMls(format!("into_group: {e:?}")))?;
330        Ok(())
331    }
332
333    /// Returns the current group epoch.
334    pub fn epoch(&self) -> u64 {
335        self.group.epoch().as_u64()
336    }
337
338    /// Returns the 16-byte group identifier (truncated or zero-padded if the
339    /// underlying MLS group_id has a different length).
340    pub fn group_id_16(&self) -> [u8; 16] {
341        let raw = self.group.group_id().as_slice();
342        let mut out = [0u8; 16];
343        let n = raw.len().min(16);
344        out[..n].copy_from_slice(&raw[..n]);
345        out
346    }
347
348    /// Exports a 32-byte secret under the given stream label.
349    pub fn export_stream_key(&self, label: StreamLabel) -> Result<[u8; 32], MlsError> {
350        let secret = self
351            .group
352            .export_secret(self.provider.crypto(), label.as_str(), &[], 32)
353            .map_err(|e| MlsError::OpenMls(format!("export: {e:?}")))?;
354        let mut out = [0u8; 32];
355        out.copy_from_slice(&secret);
356        Ok(out)
357    }
358
359    /// Exports `len` bytes under an arbitrary `label` and `context`.
360    ///
361    /// Used by external crates (e.g. `hush-sframe`) that need custom KDF
362    /// labels without depending on OpenMLS directly.
363    pub fn export_raw(&self, label: &str, context: &[u8], len: usize) -> Result<Vec<u8>, MlsError> {
364        let secret = self
365            .group
366            .export_secret(self.provider.crypto(), label, context, len)
367            .map_err(|e| MlsError::OpenMls(format!("export_raw: {e:?}")))?;
368        Ok(secret.to_vec())
369    }
370
371    /// Encrypts `plaintext` with ChaCha20-Poly1305 using the stream-labelled
372    /// AEAD key and a nonce derived from the per-stream `seq`.
373    pub fn seal(
374        &self,
375        label: StreamLabel,
376        seq: u32,
377        plaintext: &[u8],
378    ) -> Result<Vec<u8>, MlsError> {
379        let key = self.export_stream_key(label)?;
380        let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
381        let mut nonce = [0u8; 12];
382        nonce[..4].copy_from_slice(&seq.to_be_bytes());
383        cipher
384            .encrypt(Nonce::from_slice(&nonce), plaintext)
385            .map_err(|e| MlsError::Aead(e.to_string()))
386    }
387
388    /// Decrypts `ciphertext` with the same parameters as [`MlsContext::seal`].
389    pub fn open(
390        &self,
391        label: StreamLabel,
392        seq: u32,
393        ciphertext: &[u8],
394    ) -> Result<Vec<u8>, MlsError> {
395        let key = self.export_stream_key(label)?;
396        let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
397        let mut nonce = [0u8; 12];
398        nonce[..4].copy_from_slice(&seq.to_be_bytes());
399        cipher
400            .decrypt(Nonce::from_slice(&nonce), ciphertext)
401            .map_err(|e| MlsError::Aead(e.to_string()))
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// A context paired with the key package bundle used to invite it.
    type Member = (MlsContext, openmls::prelude::KeyPackageBundle);

    /// Creates a fresh single-member context for `name`.
    fn member(name: &[u8]) -> Member {
        MlsContext::new_member(name).unwrap()
    }

    /// Creates "alice" and "bob", has alice invite bob with an eager merge,
    /// and returns alice, bob (who has not accepted yet) and the Welcome.
    fn alice_invites_bob() -> (MlsContext, MlsContext, Vec<u8>) {
        let (mut alice, _alice_kp) = member(b"alice");
        let (bob, bob_kp) = member(b"bob");
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        (alice, bob, welcome)
    }

    // --- labels -----------------------------------------------------------

    #[test]
    fn stream_label_strings_are_correct() {
        let expected = [
            (StreamLabel::Control, "gbp/control"),
            (StreamLabel::Audio, "gbp/audio"),
            (StreamLabel::Text, "gbp/text"),
            (StreamLabel::Signal, "gbp/signal"),
        ];
        for (label, text) in expected {
            assert_eq!(label.as_str(), text);
        }
    }

    #[test]
    fn label_for_maps_every_stream_type() {
        let expected = [
            (StreamType::Control, StreamLabel::Control),
            (StreamType::Audio, StreamLabel::Audio),
            (StreamType::Text, StreamLabel::Text),
            (StreamType::Signal, StreamLabel::Signal),
        ];
        for (stream, label) in expected {
            assert_eq!(label_for(stream), label);
        }
    }

    // --- single-member group ----------------------------------------------

    #[test]
    fn new_member_starts_at_epoch_zero() {
        let (solo, _kp) = member(b"alice");
        assert_eq!(solo.epoch(), 0);
    }

    #[test]
    fn group_id_16_is_16_bytes() {
        let (solo, _kp) = member(b"alice");
        assert_eq!(solo.group_id_16().len(), 16);
    }

    #[test]
    fn export_stream_key_is_32_bytes_and_stable() {
        let (solo, _kp) = member(b"alice");
        let first = solo.export_stream_key(StreamLabel::Text).unwrap();
        let second = solo.export_stream_key(StreamLabel::Text).unwrap();
        assert_eq!(first.len(), 32);
        assert_eq!(first, second);
    }

    #[test]
    fn different_labels_produce_different_keys() {
        let (solo, _kp) = member(b"alice");
        let control_key = solo.export_stream_key(StreamLabel::Control).unwrap();
        let text_key = solo.export_stream_key(StreamLabel::Text).unwrap();
        assert_ne!(control_key, text_key);
    }

    #[test]
    fn seal_open_single_member_round_trip() {
        let (solo, _kp) = member(b"alice");
        let message = b"hello world";
        let sealed = solo.seal(StreamLabel::Text, 1, message).unwrap();
        assert_ne!(sealed, message);
        let opened = solo.open(StreamLabel::Text, 1, &sealed).unwrap();
        assert_eq!(opened, message);
    }

    #[test]
    fn seal_wrong_seq_fails_to_open() {
        let (solo, _kp) = member(b"alice");
        let sealed = solo.seal(StreamLabel::Text, 1, b"secret").unwrap();
        let outcome = solo.open(StreamLabel::Text, 2, &sealed);
        assert!(outcome.is_err());
    }

    #[test]
    fn seal_wrong_label_fails_to_open() {
        let (solo, _kp) = member(b"alice");
        let sealed = solo.seal(StreamLabel::Text, 0, b"secret").unwrap();
        let outcome = solo.open(StreamLabel::Audio, 0, &sealed);
        assert!(outcome.is_err());
    }

    #[test]
    fn export_raw_returns_requested_length() {
        let (solo, _kp) = member(b"alice");
        let exported = solo.export_raw("test/label", b"ctx", 48).unwrap();
        assert_eq!(exported.len(), 48);
    }

    // --- two-member group -------------------------------------------------

    #[test]
    fn two_member_invite_and_welcome() {
        let (alice, mut bob, welcome) = alice_invites_bob();
        // The eager merge inside `invite` moves alice to epoch 1.
        assert_eq!(alice.epoch(), 1);

        bob.accept_welcome(&welcome).unwrap();
        // Bob lands in the same epoch as alice.
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn two_member_seal_open_cross_member() {
        let (alice, mut bob, welcome) = alice_invites_bob();
        bob.accept_welcome(&welcome).unwrap();

        let message = b"cross-member secret";
        let sealed = alice.seal(StreamLabel::Control, 0, message).unwrap();
        let opened = bob.open(StreamLabel::Control, 0, &sealed).unwrap();
        assert_eq!(opened, message);
    }

    // --- pending-commit lifecycle -----------------------------------------

    #[test]
    fn clear_pending_commit_is_idempotent() {
        let (mut solo, _kp) = member(b"alice");
        for _ in 0..2 {
            solo.clear_pending_commit().unwrap();
        }
    }

    #[test]
    fn finalize_pending_commit_on_fresh_group_is_ok() {
        let (mut solo, _kp) = member(b"alice");
        solo.finalize_pending_commit().unwrap();
    }

    #[test]
    fn invite_full_does_not_advance_epoch_until_finalize() {
        // Part 1: the epoch only moves once the pending commit is finalized.
        let (mut inviter, _inviter_kp) = member(b"alice");
        let (_invitee, invitee_kp) = member(b"bob");

        let (_commit, _welcome) = inviter
            .invite_full(&[invitee_kp.key_package().clone()])
            .unwrap();
        assert_eq!(inviter.epoch(), 0);

        inviter.finalize_pending_commit().unwrap();
        assert_eq!(inviter.epoch(), 1);

        // Part 2: a joiner enters through the Welcome, not the Commit.
        let (mut inviter2, _inviter2_kp) = member(b"alice2");
        let (mut joiner, joiner_kp) = member(b"bob2");
        let (_commit_bytes, welcome_bytes) = inviter2
            .invite_full(&[joiner_kp.key_package().clone()])
            .unwrap();
        inviter2.finalize_pending_commit().unwrap();
        joiner.accept_welcome(&welcome_bytes).unwrap();
        assert_eq!(inviter2.epoch(), 1);
        assert_eq!(joiner.epoch(), 1);
    }
}
559}