1#![deny(missing_docs)]
18
19use chacha20poly1305::{
20 ChaCha20Poly1305, Key, Nonce,
21 aead::{Aead, KeyInit},
22};
23use gbp_core::StreamType;
24use openmls::prelude::tls_codec::Serialize as _;
25use openmls::prelude::*;
26use openmls_basic_credential::SignatureKeyPair;
27use openmls_rust_crypto::OpenMlsRustCrypto;
28
29pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
31
/// Logical stream whose payloads are protected by a dedicated exporter-derived key.
///
/// Each variant maps to its own MLS exporter label (see `StreamLabel::as_str`), so every
/// stream gets an independent key within an epoch.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamLabel {
    /// Control stream (exporter label `gbp/control`).
    Control,
    /// Audio stream (exporter label `gbp/audio`).
    Audio,
    /// Text stream (exporter label `gbp/text`).
    Text,
    /// Signal stream (exporter label `gbp/signal`).
    Signal,
}
44
45impl StreamLabel {
46 pub fn as_str(self) -> &'static str {
48 match self {
49 Self::Control => "gbp/control",
50 Self::Audio => "gbp/audio",
51 Self::Text => "gbp/text",
52 Self::Signal => "gbp/signal",
53 }
54 }
55}
56
57pub fn label_for(st: StreamType) -> StreamLabel {
59 match st {
60 StreamType::Control => StreamLabel::Control,
61 StreamType::Audio => StreamLabel::Audio,
62 StreamType::Text => StreamLabel::Text,
63 StreamType::Signal => StreamLabel::Signal,
64 }
65}
66
/// Classification of an incoming message handled by `MlsContext::process_message`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProcessedKind {
    /// A commit. It has been staged, not merged; the epoch advances only once
    /// `MlsContext::finalize_pending_commit` is called.
    Commit,
    /// An application message. Its content is not returned by `process_message`.
    Application,
    /// A proposal message.
    Proposal,
    /// An external-join proposal message.
    External,
}
81
82#[derive(Debug, thiserror::Error)]
84pub enum MlsError {
85 #[error("openmls: {0}")]
87 OpenMls(String),
88 #[error("aead: {0}")]
90 Aead(String),
91 #[error("transition in progress: pending staged commit exists")]
94 TransitionInProgress,
95}
96
97pub struct MlsContext {
103 pub provider: OpenMlsRustCrypto,
105 pub signer: SignatureKeyPair,
107 pub group: MlsGroup,
109 pub credential: CredentialWithKey,
111 pub identity: Vec<u8>,
113 pub pending_staged: Option<StagedCommit>,
120}
121
122impl MlsContext {
123 pub fn new_member(identity: &[u8]) -> Result<(Self, KeyPackageBundle), MlsError> {
127 let provider = OpenMlsRustCrypto::default();
128 let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())
129 .map_err(|e| MlsError::OpenMls(format!("signer: {e:?}")))?;
130 signer
131 .store(provider.storage())
132 .map_err(|e| MlsError::OpenMls(format!("store signer: {e:?}")))?;
133
134 let credential = BasicCredential::new(identity.to_vec());
135 let credential_with_key = CredentialWithKey {
136 credential: credential.into(),
137 signature_key: signer.public().into(),
138 };
139
140 let kp_bundle = KeyPackage::builder()
141 .build(CIPHERSUITE, &provider, &signer, credential_with_key.clone())
142 .map_err(|e| MlsError::OpenMls(format!("kp: {e:?}")))?;
143
144 let cfg = MlsGroupCreateConfig::builder()
145 .ciphersuite(CIPHERSUITE)
146 .use_ratchet_tree_extension(true)
147 .build();
148 let group = MlsGroup::new(&provider, &signer, &cfg, credential_with_key.clone())
149 .map_err(|e| MlsError::OpenMls(format!("group: {e:?}")))?;
150
151 Ok((
152 Self {
153 provider,
154 signer,
155 group,
156 credential: credential_with_key,
157 identity: identity.to_vec(),
158 pending_staged: None,
159 },
160 kp_bundle,
161 ))
162 }
163
164 pub fn invite_full(
178 &mut self,
179 key_packages: &[KeyPackage],
180 ) -> Result<(Vec<u8>, Vec<u8>), MlsError> {
181 let (commit, welcome, _gi) = self
182 .group
183 .add_members(&self.provider, &self.signer, key_packages)
184 .map_err(|e| MlsError::OpenMls(format!("add_members: {e:?}")))?;
185 let commit_bytes = commit
186 .tls_serialize_detached()
187 .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))?;
188 let welcome_bytes = welcome
189 .tls_serialize_detached()
190 .map_err(|e| MlsError::OpenMls(format!("welcome serialize: {e:?}")))?;
191 Ok((commit_bytes, welcome_bytes))
192 }
193
194 pub fn invite(&mut self, key_packages: &[KeyPackage]) -> Result<Vec<u8>, MlsError> {
198 let (_commit, welcome) = self.invite_full(key_packages)?;
199 self.finalize_pending_commit()?;
200 Ok(welcome)
201 }
202
203 pub fn remove_members(&mut self, leaf_indices: &[u32]) -> Result<Vec<u8>, MlsError> {
212 let group_size = self.group.members().count() as u32;
215 for &idx in leaf_indices {
216 if idx >= group_size {
217 return Err(MlsError::OpenMls(format!(
218 "leaf_index {idx} out of range (group size {group_size})"
219 )));
220 }
221 }
222 let leaves: Vec<LeafNodeIndex> = leaf_indices
223 .iter()
224 .copied()
225 .map(LeafNodeIndex::new)
226 .collect();
227 let (commit, _welcome_opt, _gi) = self
228 .group
229 .remove_members(&self.provider, &self.signer, &leaves)
230 .map_err(|e| MlsError::OpenMls(format!("remove_members: {e:?}")))?;
231 commit
232 .tls_serialize_detached()
233 .map_err(|e| MlsError::OpenMls(format!("commit serialize: {e:?}")))
234 }
235
236 pub fn finalize_pending_commit(&mut self) -> Result<(), MlsError> {
246 if let Some(staged) = self.pending_staged.take() {
247 self.group
248 .merge_staged_commit(&self.provider, staged)
249 .map_err(|e| MlsError::OpenMls(format!("merge_staged: {e:?}")))?;
250 }
251 let _ = self.group.merge_pending_commit(&self.provider);
256 Ok(())
257 }
258
259 pub fn clear_pending_commit(&mut self) -> Result<(), MlsError> {
262 self.pending_staged = None;
263 self.group
264 .clear_pending_commit(self.provider.storage())
265 .map_err(|e| MlsError::OpenMls(format!("clear: {e:?}")))?;
266 Ok(())
267 }
268
269 pub fn process_message(&mut self, msg_bytes: &[u8]) -> Result<ProcessedKind, MlsError> {
280 let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(msg_bytes)
281 .map_err(|e| MlsError::OpenMls(format!("msg parse: {e:?}")))?;
282 let protocol_msg = match msg_in.extract() {
283 MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::from(m),
284 MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::from(m),
285 other => {
286 return Err(MlsError::OpenMls(format!(
287 "expected protocol message, got {other:?}"
288 )));
289 }
290 };
291 let processed = self
292 .group
293 .process_message(&self.provider, protocol_msg)
294 .map_err(|e| MlsError::OpenMls(format!("process: {e:?}")))?;
295 match processed.into_content() {
296 ProcessedMessageContent::StagedCommitMessage(staged) => {
297 if self.pending_staged.is_some() {
298 return Err(MlsError::TransitionInProgress);
299 }
300 self.pending_staged = Some(*staged);
301 Ok(ProcessedKind::Commit)
302 }
303 ProcessedMessageContent::ApplicationMessage(_) => Ok(ProcessedKind::Application),
304 ProcessedMessageContent::ProposalMessage(_) => Ok(ProcessedKind::Proposal),
305 ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(ProcessedKind::External),
306 }
307 }
308
309 pub fn accept_welcome(&mut self, welcome_bytes: &[u8]) -> Result<(), MlsError> {
312 let msg_in = MlsMessageIn::tls_deserialize_exact_bytes(welcome_bytes)
313 .map_err(|e| MlsError::OpenMls(format!("welcome parse: {e:?}")))?;
314 let welcome = match msg_in.extract() {
315 MlsMessageBodyIn::Welcome(w) => w,
316 other => {
317 return Err(MlsError::OpenMls(format!(
318 "expected welcome, got {other:?}"
319 )));
320 }
321 };
322 let join_cfg = MlsGroupJoinConfig::builder()
323 .use_ratchet_tree_extension(true)
324 .build();
325 let staged = StagedWelcome::new_from_welcome(&self.provider, &join_cfg, welcome, None)
326 .map_err(|e| MlsError::OpenMls(format!("staged: {e:?}")))?;
327 self.group = staged
328 .into_group(&self.provider)
329 .map_err(|e| MlsError::OpenMls(format!("into_group: {e:?}")))?;
330 Ok(())
331 }
332
333 pub fn epoch(&self) -> u64 {
335 self.group.epoch().as_u64()
336 }
337
338 pub fn group_id_16(&self) -> [u8; 16] {
341 let raw = self.group.group_id().as_slice();
342 let mut out = [0u8; 16];
343 let n = raw.len().min(16);
344 out[..n].copy_from_slice(&raw[..n]);
345 out
346 }
347
348 pub fn export_stream_key(&self, label: StreamLabel) -> Result<[u8; 32], MlsError> {
350 let secret = self
351 .group
352 .export_secret(self.provider.crypto(), label.as_str(), &[], 32)
353 .map_err(|e| MlsError::OpenMls(format!("export: {e:?}")))?;
354 let mut out = [0u8; 32];
355 out.copy_from_slice(&secret);
356 Ok(out)
357 }
358
359 pub fn export_raw(&self, label: &str, context: &[u8], len: usize) -> Result<Vec<u8>, MlsError> {
364 let secret = self
365 .group
366 .export_secret(self.provider.crypto(), label, context, len)
367 .map_err(|e| MlsError::OpenMls(format!("export_raw: {e:?}")))?;
368 Ok(secret.to_vec())
369 }
370
371 pub fn seal(
374 &self,
375 label: StreamLabel,
376 seq: u32,
377 plaintext: &[u8],
378 ) -> Result<Vec<u8>, MlsError> {
379 let key = self.export_stream_key(label)?;
380 let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
381 let mut nonce = [0u8; 12];
382 nonce[..4].copy_from_slice(&seq.to_be_bytes());
383 cipher
384 .encrypt(Nonce::from_slice(&nonce), plaintext)
385 .map_err(|e| MlsError::Aead(e.to_string()))
386 }
387
388 pub fn open(
390 &self,
391 label: StreamLabel,
392 seq: u32,
393 ciphertext: &[u8],
394 ) -> Result<Vec<u8>, MlsError> {
395 let key = self.export_stream_key(label)?;
396 let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
397 let mut nonce = [0u8; 12];
398 nonce[..4].copy_from_slice(&seq.to_be_bytes());
399 cipher
400 .decrypt(Nonce::from_slice(&nonce), ciphertext)
401 .map_err(|e| MlsError::Aead(e.to_string()))
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh context and key package for the identity `alice`.
    fn alice() -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(b"alice").unwrap()
    }

    /// Fresh context and key package for the identity `bob`.
    fn bob() -> (MlsContext, openmls::prelude::KeyPackageBundle) {
        MlsContext::new_member(b"bob").unwrap()
    }

    /// Alice creates a group and invites Bob, who joins; both end at epoch 1.
    fn alice_and_bob_in_group() -> (MlsContext, MlsContext) {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();
        (alice, bob)
    }

    #[test]
    fn stream_label_strings_are_correct() {
        assert_eq!(StreamLabel::Control.as_str(), "gbp/control");
        assert_eq!(StreamLabel::Audio.as_str(), "gbp/audio");
        assert_eq!(StreamLabel::Text.as_str(), "gbp/text");
        assert_eq!(StreamLabel::Signal.as_str(), "gbp/signal");
    }

    #[test]
    fn label_for_maps_every_stream_type() {
        assert_eq!(label_for(StreamType::Control), StreamLabel::Control);
        assert_eq!(label_for(StreamType::Audio), StreamLabel::Audio);
        assert_eq!(label_for(StreamType::Text), StreamLabel::Text);
        assert_eq!(label_for(StreamType::Signal), StreamLabel::Signal);
    }

    #[test]
    fn new_member_starts_at_epoch_zero() {
        let (ctx, _kp) = alice();
        assert_eq!(ctx.epoch(), 0);
    }

    #[test]
    fn group_id_16_is_prefix_of_group_id() {
        let (ctx, _kp) = alice();
        let id = ctx.group_id_16();
        let raw = ctx.group.group_id().as_slice();
        let n = raw.len().min(16);
        // Leading bytes copy the real group ID; anything past its end is zero padding.
        assert_eq!(&id[..n], &raw[..n]);
        assert!(id[n..].iter().all(|&b| b == 0));
    }

    #[test]
    fn export_stream_key_is_32_bytes_and_stable() {
        let (ctx, _kp) = alice();
        let k1 = ctx.export_stream_key(StreamLabel::Text).unwrap();
        let k2 = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_eq!(k1.len(), 32);
        assert_eq!(k1, k2);
    }

    #[test]
    fn different_labels_produce_different_keys() {
        let (ctx, _kp) = alice();
        let k_ctrl = ctx.export_stream_key(StreamLabel::Control).unwrap();
        let k_text = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_ne!(k_ctrl, k_text);
    }

    #[test]
    fn stream_key_changes_with_epoch() {
        let (mut ctx, _kp) = alice();
        let (_bob, bob_kp) = bob();
        let before = ctx.export_stream_key(StreamLabel::Text).unwrap();
        ctx.invite(&[bob_kp.key_package().clone()]).unwrap();
        let after = ctx.export_stream_key(StreamLabel::Text).unwrap();
        assert_ne!(before, after);
    }

    #[test]
    fn seal_open_single_member_round_trip() {
        let (ctx, _kp) = alice();
        let plaintext = b"hello world";
        let ciphertext = ctx.seal(StreamLabel::Text, 1, plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);
        let recovered = ctx.open(StreamLabel::Text, 1, &ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn seal_wrong_seq_fails_to_open() {
        let (ctx, _kp) = alice();
        let ciphertext = ctx.seal(StreamLabel::Text, 1, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Text, 2, &ciphertext).is_err());
    }

    #[test]
    fn seal_wrong_label_fails_to_open() {
        let (ctx, _kp) = alice();
        let ciphertext = ctx.seal(StreamLabel::Text, 0, b"secret").unwrap();
        assert!(ctx.open(StreamLabel::Audio, 0, &ciphertext).is_err());
    }

    #[test]
    fn open_rejects_tampered_ciphertext() {
        let (ctx, _kp) = alice();
        let mut ct = ctx.seal(StreamLabel::Signal, 7, b"payload").unwrap();
        ct[0] ^= 0x01;
        assert!(matches!(
            ctx.open(StreamLabel::Signal, 7, &ct),
            Err(MlsError::Aead(_))
        ));
    }

    #[test]
    fn two_member_invite_and_welcome() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();

        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        assert_eq!(alice.epoch(), 1);

        bob.accept_welcome(&welcome).unwrap();
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn two_member_seal_open_cross_member() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();

        let welcome = alice.invite(&[bob_kp.key_package().clone()]).unwrap();
        bob.accept_welcome(&welcome).unwrap();

        let plaintext = b"cross-member secret";
        let ct = alice.seal(StreamLabel::Control, 0, plaintext).unwrap();
        let recovered = bob.open(StreamLabel::Control, 0, &ct).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn joined_members_share_group_id_and_keys() {
        let (alice, bob) = alice_and_bob_in_group();
        assert_eq!(alice.group_id_16(), bob.group_id_16());
        assert_eq!(
            alice.export_stream_key(StreamLabel::Audio).unwrap(),
            bob.export_stream_key(StreamLabel::Audio).unwrap()
        );
    }

    #[test]
    fn accept_welcome_rejects_non_welcome_bytes() {
        let (mut alice, _akp) = alice();
        let (mut bob, bob_kp) = bob();
        let (commit, _welcome) = alice
            .invite_full(&[bob_kp.key_package().clone()])
            .unwrap();
        assert!(bob.accept_welcome(&commit).is_err());
        assert!(bob.accept_welcome(&[]).is_err());
        assert_eq!(bob.epoch(), 0);
    }

    #[test]
    fn process_message_rejects_malformed_bytes() {
        let (mut ctx, _kp) = alice();
        assert!(matches!(ctx.process_message(&[]), Err(MlsError::OpenMls(_))));
        assert!(matches!(
            ctx.process_message(b"not an mls message"),
            Err(MlsError::OpenMls(_))
        ));
    }

    #[test]
    fn remove_members_rejects_non_member_leaf() {
        let (mut ctx, _kp) = alice();
        assert!(matches!(ctx.remove_members(&[5]), Err(MlsError::OpenMls(_))));
        assert_eq!(ctx.epoch(), 0);
    }

    #[test]
    fn incoming_commit_is_staged_until_finalized() {
        let (mut alice, mut bob) = alice_and_bob_in_group();
        let (_carol, carol_kp) = MlsContext::new_member(b"carol").unwrap();

        let (commit, _welcome) = alice
            .invite_full(&[carol_kp.key_package().clone()])
            .unwrap();
        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 2);

        assert_eq!(bob.process_message(&commit).unwrap(), ProcessedKind::Commit);
        assert!(bob.pending_staged.is_some());
        assert_eq!(bob.epoch(), 1);

        bob.finalize_pending_commit().unwrap();
        assert!(bob.pending_staged.is_none());
        assert_eq!(bob.epoch(), 2);

        let ct = alice.seal(StreamLabel::Text, 3, b"epoch two").unwrap();
        assert_eq!(bob.open(StreamLabel::Text, 3, &ct).unwrap(), b"epoch two");
    }

    #[test]
    fn clear_pending_commit_discards_staged_incoming_commit() {
        let (mut alice, mut bob) = alice_and_bob_in_group();
        let (_carol, carol_kp) = MlsContext::new_member(b"carol").unwrap();

        let (commit, _welcome) = alice
            .invite_full(&[carol_kp.key_package().clone()])
            .unwrap();
        bob.process_message(&commit).unwrap();
        bob.clear_pending_commit().unwrap();
        assert!(bob.pending_staged.is_none());

        // Nothing left to apply, so finalizing keeps Bob at epoch 1.
        bob.finalize_pending_commit().unwrap();
        assert_eq!(bob.epoch(), 1);
    }

    #[test]
    fn export_raw_returns_requested_length() {
        let (ctx, _kp) = alice();
        let raw = ctx.export_raw("test/label", b"ctx", 48).unwrap();
        assert_eq!(raw.len(), 48);
    }

    #[test]
    fn clear_pending_commit_is_idempotent() {
        let (mut ctx, _kp) = alice();
        ctx.clear_pending_commit().unwrap();
        ctx.clear_pending_commit().unwrap();
    }

    #[test]
    fn finalize_pending_commit_on_fresh_group_is_ok() {
        let (mut ctx, _kp) = alice();
        ctx.finalize_pending_commit().unwrap();
    }

    #[test]
    fn invite_full_does_not_advance_epoch_until_finalize() {
        let (mut alice, _akp) = alice();
        let (_bob, bob_kp) = bob();

        let (_commit, _welcome) = alice.invite_full(&[bob_kp.key_package().clone()]).unwrap();
        assert_eq!(alice.epoch(), 0);

        alice.finalize_pending_commit().unwrap();
        assert_eq!(alice.epoch(), 1);

        let (mut alice2, _akp2) = MlsContext::new_member(b"alice2").unwrap();
        let (mut bob2, bob2_kp) = MlsContext::new_member(b"bob2").unwrap();
        let (_commit_bytes, welcome_bytes) = alice2
            .invite_full(&[bob2_kp.key_package().clone()])
            .unwrap();
        alice2.finalize_pending_commit().unwrap();
        bob2.accept_welcome(&welcome_bytes).unwrap();
        assert_eq!(alice2.epoch(), 1);
        assert_eq!(bob2.epoch(), 1);
    }
}