1use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13 cipher_suite::CipherSuite,
14 client::MlsError,
15 client_config::ClientConfig,
16 extension::RatchetTreeExt,
17 identity::SigningIdentity,
18 protocol_version::ProtocolVersion,
19 signer::Signable,
20 time::MlsTime,
21 tree_kem::{kem::TreeKem, path_secret::PathSecret, TreeKemPrivate, UpdatePath},
22 ExtensionList, MlsRules,
23};
24
25#[cfg(all(not(mls_build_async), feature = "rayon"))]
26use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
27
28use crate::tree_kem::leaf_node::LeafNode;
29
30#[cfg(not(feature = "private_message"))]
31use crate::WireFormat;
32
33#[cfg(feature = "psk")]
34use crate::{
35 group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
36 psk::ExternalPskId,
37};
38
39use super::{
40 confirmation_tag::ConfirmationTag,
41 framing::{Content, MlsMessage, MlsMessagePayload, Sender},
42 key_schedule::{KeySchedule, WelcomeSecret},
43 message_hash::MessageHash,
44 message_processor::{path_update_required, MessageProcessor},
45 message_signature::AuthenticatedContent,
46 mls_rules::CommitDirection,
47 proposal::{Proposal, ProposalOrRef},
48 CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
49 Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
50 PendingCommitSnapshot, Welcome,
51};
52
53#[cfg(not(feature = "by_ref_proposal"))]
54use super::proposal_cache::prepare_commit;
55
56#[cfg(feature = "custom_proposal")]
57use super::proposal::CustomProposal;
58
59#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
60#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
61#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
62pub(crate) struct Commit {
63 pub proposals: Vec<ProposalOrRef>,
64 pub path: Option<UpdatePath>,
65}
66
67#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
68pub(crate) struct PendingCommit {
69 pub(crate) state: GroupState,
70 pub(crate) epoch_secrets: EpochSecrets,
71 pub(crate) private_tree: TreeKemPrivate,
72 pub(crate) key_schedule: KeySchedule,
73 pub(crate) signer: SignatureSecretKey,
74
75 pub(crate) output: CommitMessageDescription,
76
77 pub(crate) commit_message_hash: MessageHash,
78}
79
80#[cfg_attr(
81 all(feature = "ffi", not(test)),
82 safer_ffi_gen::ffi_type(clone, opaque)
83)]
84#[derive(Clone)]
85pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
86
87impl CommitSecrets {
88 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
90 Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
91 }
92
93 pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
95 Ok(self.0.mls_encode_to_vec()?)
96 }
97}
98
99#[cfg_attr(
100 all(feature = "ffi", not(test)),
101 safer_ffi_gen::ffi_type(clone, opaque)
102)]
103#[derive(Clone, Debug)]
104#[non_exhaustive]
105pub struct CommitOutput {
109 pub commit_message: MlsMessage,
111 pub welcome_messages: Vec<MlsMessage>,
118 pub ratchet_tree: Option<ExportedTree<'static>>,
122 pub external_commit_group_info: Option<MlsMessage>,
126 #[cfg(feature = "by_ref_proposal")]
128 pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
129 pub contains_update_path: bool,
131}
132
133#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
134impl CommitOutput {
135 #[cfg(feature = "ffi")]
137 pub fn commit_message(&self) -> &MlsMessage {
138 &self.commit_message
139 }
140
141 #[cfg(feature = "ffi")]
143 pub fn welcome_messages(&self) -> &[MlsMessage] {
144 &self.welcome_messages
145 }
146
147 #[cfg(feature = "ffi")]
151 pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
152 self.ratchet_tree.as_ref()
153 }
154
155 #[cfg(feature = "ffi")]
159 pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
160 self.external_commit_group_info.as_ref()
161 }
162
163 #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
165 pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
166 &self.unused_proposals
167 }
168}
169
170pub struct CommitBuilder<'a, C>
178where
179 C: ClientConfig + Clone,
180{
181 group: &'a mut Group<C>,
182 pub(super) proposals: Vec<Proposal>,
183 authenticated_data: Vec<u8>,
184 group_info_extensions: ExtensionList,
185 new_signer: Option<SignatureSecretKey>,
186 new_signing_identity: Option<SigningIdentity>,
187 new_leaf_node_extensions: Option<ExtensionList>,
188 commit_time: Option<MlsTime>,
189}
190
191impl<'a, C> CommitBuilder<'a, C>
192where
193 C: ClientConfig + Clone,
194{
195 pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
198 let proposal = self.group.add_proposal(key_package)?;
199 self.proposals.push(proposal);
200 Ok(self)
201 }
202
203 pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
214 Self {
215 group_info_extensions: extensions,
216 ..self
217 }
218 }
219
220 pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
223 let proposal = self.group.remove_proposal(index)?;
224 self.proposals.push(proposal);
225 Ok(self)
226 }
227
228 pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
232 let proposal = self.group.group_context_extensions_proposal(extensions);
233 self.proposals.push(proposal);
234 Ok(self)
235 }
236
237 #[cfg(feature = "psk")]
241 pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
242 let key_id = JustPreSharedKeyID::External(psk_id);
243 let proposal = self.group.psk_proposal(key_id)?;
244 self.proposals.push(proposal);
245 Ok(self)
246 }
247
248 #[cfg(feature = "psk")]
252 pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
253 let psk_id = ResumptionPsk {
254 psk_epoch,
255 usage: ResumptionPSKUsage::Application,
256 psk_group_id: PskGroupId(self.group.group_id().to_vec()),
257 };
258
259 let key_id = JustPreSharedKeyID::Resumption(psk_id);
260 let proposal = self.group.psk_proposal(key_id)?;
261 self.proposals.push(proposal);
262 Ok(self)
263 }
264
265 pub fn reinit(
268 mut self,
269 group_id: Option<Vec<u8>>,
270 version: ProtocolVersion,
271 cipher_suite: CipherSuite,
272 extensions: ExtensionList,
273 ) -> Result<Self, MlsError> {
274 let proposal = self
275 .group
276 .reinit_proposal(group_id, version, cipher_suite, extensions)?;
277
278 self.proposals.push(proposal);
279 Ok(self)
280 }
281
282 #[cfg(feature = "custom_proposal")]
285 pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
286 self.proposals.push(Proposal::Custom(proposal));
287 self
288 }
289
290 pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
294 self.proposals.push(proposal);
295 self
296 }
297
298 pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
302 self.proposals.append(&mut proposals);
303 self
304 }
305
306 pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
312 Self {
313 authenticated_data,
314 ..self
315 }
316 }
317
318 pub fn set_new_signing_identity(
326 self,
327 signer: SignatureSecretKey,
328 signing_identity: SigningIdentity,
329 ) -> Self {
330 Self {
331 new_signer: Some(signer),
332 new_signing_identity: Some(signing_identity),
333 ..self
334 }
335 }
336
337 pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
339 Self {
340 new_leaf_node_extensions: Some(new_leaf_node_extensions),
341 ..self
342 }
343 }
344
345 pub fn commit_time(self, commit_time: MlsTime) -> Self {
347 Self {
348 commit_time: Some(commit_time),
349 ..self
350 }
351 }
352
353 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
362 pub async fn build(self) -> Result<CommitOutput, MlsError> {
363 let (output, pending_commit) = self
364 .group
365 .commit_internal(
366 self.proposals,
367 None,
368 self.authenticated_data,
369 self.group_info_extensions,
370 self.new_signer,
371 self.new_signing_identity,
372 self.new_leaf_node_extensions,
373 self.commit_time,
374 )
375 .await?;
376
377 self.group.pending_commit = pending_commit.try_into()?;
378
379 Ok(output)
380 }
381
382 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
387 pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
388 let (output, pending_commit) = self
389 .group
390 .commit_internal(
391 self.proposals,
392 None,
393 self.authenticated_data,
394 self.group_info_extensions,
395 self.new_signer,
396 self.new_signing_identity,
397 self.new_leaf_node_extensions,
398 self.commit_time,
399 )
400 .await?;
401
402 Ok((
403 output,
404 CommitSecrets(PendingCommitSnapshot::PendingCommit(
405 pending_commit.mls_encode_to_vec()?,
406 )),
407 ))
408 }
409}
410
411impl<C> Group<C>
412where
413 C: ClientConfig + Clone,
414{
415 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
456 pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
457 self.commit_builder()
458 .authenticated_data(authenticated_data)
459 .build()
460 .await
461 }
462
463 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
468 pub async fn commit_detached(
469 &mut self,
470 authenticated_data: Vec<u8>,
471 ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
472 self.commit_builder()
473 .authenticated_data(authenticated_data)
474 .build_detached()
475 .await
476 }
477
478 pub fn commit_builder(&mut self) -> CommitBuilder<C> {
481 CommitBuilder {
482 group: self,
483 proposals: Default::default(),
484 authenticated_data: Default::default(),
485 group_info_extensions: Default::default(),
486 new_signer: Default::default(),
487 new_signing_identity: Default::default(),
488 new_leaf_node_extensions: Default::default(),
489 commit_time: None,
490 }
491 }
492
493 #[allow(clippy::too_many_arguments)]
496 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
497 pub(super) async fn commit_internal(
498 &mut self,
499 proposals: Vec<Proposal>,
500 external_leaf: Option<&LeafNode>,
501 authenticated_data: Vec<u8>,
502 mut welcome_group_info_extensions: ExtensionList,
503 new_signer: Option<SignatureSecretKey>,
504 new_signing_identity: Option<SigningIdentity>,
505 new_leaf_node_extensions: Option<ExtensionList>,
506 commit_time: Option<MlsTime>,
507 ) -> Result<(CommitOutput, PendingCommit), MlsError> {
508 if !self.pending_commit.is_none() {
509 return Err(MlsError::ExistingPendingCommit);
510 }
511
512 if self.state.pending_reinit.is_some() {
513 return Err(MlsError::GroupUsedAfterReInit);
514 }
515
516 let mls_rules = self.config.mls_rules();
517
518 let is_external = external_leaf.is_some();
519
520 let sender = if is_external {
524 Sender::NewMemberCommit
525 } else {
526 Sender::Member(*self.private_tree.self_index)
527 };
528
529 let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
530 let old_signer = &self.signer;
531
532 #[cfg(feature = "std")]
533 let time = Some(crate::time::MlsTime::now());
534
535 #[cfg(not(feature = "std"))]
536 let time = None;
537
538 let time = if commit_time.is_some() {
539 commit_time
540 } else {
541 time
542 };
543
544 #[cfg(feature = "by_ref_proposal")]
545 let proposals = self.state.proposals.prepare_commit(sender, proposals);
546
547 #[cfg(not(feature = "by_ref_proposal"))]
548 let proposals = prepare_commit(sender, proposals);
549
550 let mut provisional_state = self
551 .state
552 .apply_resolved(
553 sender,
554 proposals,
555 external_leaf,
556 &self.config.identity_provider(),
557 &self.cipher_suite_provider,
558 &self.config.secret_store(),
559 &mls_rules,
560 time,
561 CommitDirection::Send,
562 )
563 .await?;
564
565 let (mut provisional_private_tree, _) =
566 self.provisional_private_tree(&provisional_state)?;
567
568 if is_external {
569 provisional_private_tree.self_index = provisional_state
570 .external_init_index
571 .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
572
573 self.private_tree.self_index = provisional_private_tree.self_index;
574 }
575
576 let commit_options = mls_rules
580 .commit_options(
581 &provisional_state.public_tree.roster(),
582 &provisional_state.group_context,
583 &provisional_state.applied_proposals,
584 )
585 .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
586
587 let perform_path_update = commit_options.path_required
588 || path_update_required(&provisional_state.applied_proposals);
589
590 let (update_path, path_secrets, commit_secret) = if perform_path_update {
591 let new_leaf_node_extensions =
599 new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
600
601 let new_leaf_node_extensions = match new_leaf_node_extensions {
602 Some(extensions) => extensions,
603 None => self.current_user_leaf_node()?.ungreased_extensions(),
605 };
606
607 let encap_gen = TreeKem::new(
608 &mut provisional_state.public_tree,
609 &mut provisional_private_tree,
610 )
611 .encap(
612 &mut provisional_state.group_context,
613 &provisional_state.indexes_of_added_kpkgs,
614 &new_signer,
615 Some(self.config.leaf_properties(new_leaf_node_extensions)),
616 new_signing_identity,
617 &self.cipher_suite_provider,
618 #[cfg(test)]
619 &self.commit_modifiers,
620 )
621 .await?;
622
623 (
624 Some(encap_gen.update_path),
625 Some(encap_gen.path_secrets),
626 encap_gen.commit_secret,
627 )
628 } else {
629 provisional_state
631 .public_tree
632 .update_hashes(
633 &[provisional_private_tree.self_index],
634 &self.cipher_suite_provider,
635 )
636 .await?;
637
638 provisional_state.group_context.tree_hash = provisional_state
639 .public_tree
640 .tree_hash(&self.cipher_suite_provider)
641 .await?;
642
643 (None, None, PathSecret::empty(&self.cipher_suite_provider))
644 };
645
646 #[cfg(feature = "psk")]
647 let (psk_secret, psks) = self
648 .get_psk(&provisional_state.applied_proposals.psks)
649 .await?;
650
651 #[cfg(not(feature = "psk"))]
652 let psk_secret = self.get_psk();
653
654 let added_key_pkgs: Vec<_> = provisional_state
655 .applied_proposals
656 .additions
657 .iter()
658 .map(|info| info.proposal.key_package.clone())
659 .collect();
660
661 let commit = Commit {
662 proposals: provisional_state.applied_proposals.proposals_or_refs(),
663 path: update_path,
664 };
665
666 let mut auth_content = AuthenticatedContent::new_signed(
667 &self.cipher_suite_provider,
668 self.context(),
669 sender,
670 Content::Commit(Box::new(commit)),
671 old_signer,
672 #[cfg(feature = "private_message")]
673 self.encryption_options()?.control_wire_format(sender),
674 #[cfg(not(feature = "private_message"))]
675 WireFormat::PublicMessage,
676 authenticated_data,
677 )
678 .await?;
679
680 let confirmed_transcript_hash = super::transcript_hash::create(
683 self.cipher_suite_provider(),
684 &self.state.interim_transcript_hash,
685 &auth_content,
686 )
687 .await?;
688
689 provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
690
691 let key_schedule_result = KeySchedule::from_key_schedule(
692 &self.key_schedule,
693 &commit_secret,
694 &provisional_state.group_context,
695 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
696 provisional_state.public_tree.total_leaf_count(),
697 &psk_secret,
698 &self.cipher_suite_provider,
699 )
700 .await?;
701
702 let confirmation_tag = ConfirmationTag::create(
703 &key_schedule_result.confirmation_key,
704 &provisional_state.group_context.confirmed_transcript_hash,
705 &self.cipher_suite_provider,
706 )
707 .await?;
708
709 let interim_transcript_hash = InterimTranscriptHash::create(
710 self.cipher_suite_provider(),
711 &provisional_state.group_context.confirmed_transcript_hash,
712 &confirmation_tag,
713 )
714 .await?;
715
716 auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
717
718 let ratchet_tree_ext = commit_options
719 .ratchet_tree_extension
720 .then(|| RatchetTreeExt {
721 tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
722 });
723
724 let external_commit_group_info = match commit_options.allow_external_commit {
726 true => {
727 let mut extensions = ExtensionList::new();
728
729 extensions.set_from({
730 key_schedule_result
731 .key_schedule
732 .get_external_key_pair_ext(&self.cipher_suite_provider)
733 .await?
734 })?;
735
736 if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
737 if !commit_options.always_out_of_band_ratchet_tree {
738 extensions.set_from(ratchet_tree_ext.clone())?;
739 }
740 }
741
742 let info = self
743 .make_group_info(
744 &provisional_state.group_context,
745 extensions,
746 &confirmation_tag,
747 &new_signer,
748 )
749 .await?;
750
751 let msg =
752 MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
753
754 Some(msg)
755 }
756 false => None,
757 };
758
759 if let Some(ratchet_tree_ext) = ratchet_tree_ext {
762 welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
763 }
764
765 let welcome_group_info = self
766 .make_group_info(
767 &provisional_state.group_context,
768 welcome_group_info_extensions,
769 &confirmation_tag,
770 &new_signer,
771 )
772 .await?;
773
774 let welcome_secret = WelcomeSecret::from_joiner_secret(
777 &self.cipher_suite_provider,
778 &key_schedule_result.joiner_secret,
779 &psk_secret,
780 )
781 .await?;
782
783 let encrypted_group_info = welcome_secret
784 .encrypt(&welcome_group_info.mls_encode_to_vec()?)
785 .await?;
786
787 let path_secrets = path_secrets.as_ref();
789
790 #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
791 let encrypted_path_secrets: Vec<_> = added_key_pkgs
792 .into_par_iter()
793 .zip(&provisional_state.indexes_of_added_kpkgs)
794 .map(|(key_package, leaf_index)| {
795 self.encrypt_group_secrets(
796 &key_package,
797 *leaf_index,
798 &key_schedule_result.joiner_secret,
799 path_secrets,
800 #[cfg(feature = "psk")]
801 psks.clone(),
802 &encrypted_group_info,
803 )
804 })
805 .try_collect()?;
806
807 #[cfg(any(mls_build_async, not(feature = "rayon")))]
808 let encrypted_path_secrets = {
809 let mut secrets = Vec::new();
810
811 for (key_package, leaf_index) in added_key_pkgs
812 .into_iter()
813 .zip(&provisional_state.indexes_of_added_kpkgs)
814 {
815 secrets.push(
816 self.encrypt_group_secrets(
817 &key_package,
818 *leaf_index,
819 &key_schedule_result.joiner_secret,
820 path_secrets,
821 #[cfg(feature = "psk")]
822 psks.clone(),
823 &encrypted_group_info,
824 )
825 .await?,
826 );
827 }
828
829 secrets
830 };
831
832 let welcome_messages =
833 if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
834 vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
835 } else {
836 encrypted_path_secrets
837 .into_iter()
838 .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
839 .collect()
840 };
841
842 let commit_message = self.format_for_wire(auth_content.clone()).await?;
843
844 let ratchet_tree = (!commit_options.ratchet_tree_extension
846 || commit_options.always_out_of_band_ratchet_tree)
847 .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
848
849 let pending_reinit = provisional_state
850 .applied_proposals
851 .reinitializations
852 .first();
853
854 let pending_commit = PendingCommit {
855 output: CommitMessageDescription {
856 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
857 authenticated_data: auth_content.content.authenticated_data,
858 committer: *provisional_private_tree.self_index,
859 effect: match pending_reinit {
860 Some(r) => CommitEffect::ReInit(r.clone()),
861 None => CommitEffect::NewEpoch(
862 NewEpoch::new(self.state.clone(), &provisional_state).into(),
863 ),
864 },
865 },
866
867 state: GroupState {
868 #[cfg(feature = "by_ref_proposal")]
869 proposals: crate::group::ProposalCache::new(
870 self.protocol_version(),
871 self.group_id().to_vec(),
872 ),
873 context: provisional_state.group_context,
874 public_tree: provisional_state.public_tree,
875 interim_transcript_hash,
876 pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
877 confirmation_tag,
878 },
879
880 commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
881 .await?,
882 signer: new_signer,
883 epoch_secrets: key_schedule_result.epoch_secrets,
884 key_schedule: key_schedule_result.key_schedule,
885
886 private_tree: provisional_private_tree,
887 };
888
889 let output = CommitOutput {
890 commit_message,
891 welcome_messages,
892 ratchet_tree,
893 external_commit_group_info,
894 contains_update_path: perform_path_update,
895 #[cfg(feature = "by_ref_proposal")]
896 unused_proposals: provisional_state.unused_proposals,
897 };
898
899 Ok((output, pending_commit))
900 }
901
902 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
905 async fn make_group_info(
906 &self,
907 group_context: &GroupContext,
908 extensions: ExtensionList,
909 confirmation_tag: &ConfirmationTag,
910 signer: &SignatureSecretKey,
911 ) -> Result<GroupInfo, MlsError> {
912 let mut group_info = GroupInfo {
913 group_context: group_context.clone(),
914 extensions,
915 confirmation_tag: confirmation_tag.clone(), signer: self.current_member_leaf_index(),
917 signature: vec![],
918 };
919
920 group_info.grease(self.cipher_suite_provider())?;
921
922 group_info
924 .sign(&self.cipher_suite_provider, signer, &())
925 .await?;
926
927 Ok(group_info)
928 }
929
930 fn make_welcome_message(
931 &self,
932 secrets: Vec<EncryptedGroupSecrets>,
933 encrypted_group_info: Vec<u8>,
934 ) -> MlsMessage {
935 MlsMessage::new(
936 self.context().protocol_version,
937 MlsMessagePayload::Welcome(Welcome {
938 cipher_suite: self.context().cipher_suite,
939 secrets,
940 encrypted_group_info,
941 }),
942 )
943 }
944}
945
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Test hooks passed to `TreeKem::encap` during commit generation (test
    /// builds only).
    ///
    /// The default value is a no-op for every hook.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Hook applied to a leaf node together with a signing key.
        ///
        /// NOTE(review): presumably a returned key replaces the signer —
        /// confirm in `TreeKem::encap`.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Hook that may mutate the public tree.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Hook that maps update-path nodes to (possibly modified) nodes.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    impl Default for CommitModifiers {
        fn default() -> Self {
            fn leave_leaf(_: &mut LeafNode, _: &SignatureSecretKey) -> Option<SignatureSecretKey> {
                None
            }

            fn leave_tree(_: &mut TreeKemPublic) {}

            fn leave_path(nodes: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
                nodes
            }

            CommitModifiers {
                modify_leaf: leave_leaf,
                modify_tree: leave_tree,
                modify_path: leave_path,
            }
        }
    }
}
972
973#[cfg(test)]
974mod tests {
975 use mls_rs_core::{
976 error::IntoAnyError,
977 extension::ExtensionType,
978 identity::{CredentialType, IdentityProvider, MemberValidationContext},
979 time::MlsTime,
980 };
981
982 use crate::extension::RequiredCapabilitiesExt;
983 use crate::{
984 client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
985 client_builder::{
986 test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
987 WithIdentityProvider,
988 },
989 client_config::ClientConfig,
990 crypto::test_utils::TestCryptoProvider,
991 extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
992 group::test_utils::{test_group, test_group_custom},
993 group::{
994 proposal::ProposalType,
995 test_utils::{test_group_custom_config, test_n_member_group},
996 },
997 identity::test_utils::get_test_signing_identity,
998 identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
999 key_package::test_utils::test_key_package_message,
1000 mls_rules::CommitOptions,
1001 Client,
1002 };
1003
1004 #[cfg(feature = "by_ref_proposal")]
1005 use crate::crypto::test_utils::test_cipher_suite_provider;
1006 #[cfg(feature = "by_ref_proposal")]
1007 use crate::extension::ExternalSendersExt;
1008 #[cfg(feature = "by_ref_proposal")]
1009 use crate::group::mls_rules::DefaultMlsRules;
1010
1011 #[cfg(feature = "psk")]
1012 use crate::{
1013 group::proposal::PreSharedKeyProposal,
1014 psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
1015 };
1016
1017 use super::*;
1018
1019 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1020 async fn test_commit_builder_group() -> Group<TestClientConfig> {
1021 test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1022 b.custom_proposal_type(ProposalType::from(42))
1023 .extension_type(TEST_EXTENSION_TYPE.into())
1024 })
1025 .await
1026 .group
1027 }
1028
1029 fn assert_commit_builder_output<C: ClientConfig>(
1030 group: Group<C>,
1031 mut commit_output: CommitOutput,
1032 expected: Vec<Proposal>,
1033 welcome_count: usize,
1034 ) {
1035 let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1036
1037 let commit_data = match plaintext.content.content {
1038 Content::Commit(commit) => commit,
1039 #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1040 _ => panic!("Found non-commit data"),
1041 };
1042
1043 assert_eq!(commit_data.proposals.len(), expected.len());
1044
1045 commit_data.proposals.into_iter().for_each(|proposal| {
1046 let proposal = match proposal {
1047 ProposalOrRef::Proposal(p) => p,
1048 #[cfg(feature = "by_ref_proposal")]
1049 ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1050 };
1051
1052 #[cfg(feature = "psk")]
1053 if let Some(psk_id) = match proposal.as_ref() {
1054 Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1055 _ => None,
1056 } {
1057 let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1058
1059 assert!(found)
1060 } else {
1061 assert!(expected.contains(&proposal));
1062 }
1063
1064 #[cfg(not(feature = "psk"))]
1065 assert!(expected.contains(&proposal));
1066 });
1067
1068 if welcome_count > 0 {
1069 let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1070
1071 assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1072
1073 let welcome_msg = welcome_msg.into_welcome().unwrap();
1074
1075 assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1076 assert_eq!(welcome_msg.secrets.len(), welcome_count);
1077 } else {
1078 assert!(commit_output.welcome_messages.is_empty());
1079 }
1080 }
1081
1082 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1083 async fn test_commit_builder_add() {
1084 let mut group = test_commit_builder_group().await;
1085
1086 let test_key_package =
1087 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1088
1089 let commit_output = group
1090 .commit_builder()
1091 .add_member(test_key_package.clone())
1092 .unwrap()
1093 .build()
1094 .await
1095 .unwrap();
1096
1097 let expected_add = group.add_proposal(test_key_package).unwrap();
1098
1099 assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1100 }
1101
1102 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1103 async fn test_commit_builder_add_with_ext() {
1104 let mut group = test_commit_builder_group().await;
1105
1106 let (bob_client, bob_key_package) =
1107 test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1108
1109 let ext = TestExtension { foo: 42 };
1110 let mut extension_list = ExtensionList::default();
1111 extension_list.set_from(ext.clone()).unwrap();
1112
1113 let welcome_message = group
1114 .commit_builder()
1115 .add_member(bob_key_package)
1116 .unwrap()
1117 .set_group_info_ext(extension_list)
1118 .build()
1119 .await
1120 .unwrap()
1121 .welcome_messages
1122 .remove(0);
1123
1124 let (_, context) = bob_client
1125 .join_group(None, &welcome_message, None)
1126 .await
1127 .unwrap();
1128
1129 assert_eq!(
1130 context
1131 .group_info_extensions
1132 .get_as::<TestExtension>()
1133 .unwrap()
1134 .unwrap(),
1135 ext
1136 );
1137 }
1138
1139 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1140 async fn test_commit_builder_remove() {
1141 let mut group = test_commit_builder_group().await;
1142 let test_key_package =
1143 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1144
1145 group
1146 .commit_builder()
1147 .add_member(test_key_package)
1148 .unwrap()
1149 .build()
1150 .await
1151 .unwrap();
1152
1153 group.apply_pending_commit().await.unwrap();
1154
1155 let commit_output = group
1156 .commit_builder()
1157 .remove_member(1)
1158 .unwrap()
1159 .build()
1160 .await
1161 .unwrap();
1162
1163 let expected_remove = group.remove_proposal(1).unwrap();
1164
1165 assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1166 }
1167
#[cfg(feature = "psk")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_commit_builder_psk() {
    let mut group = test_commit_builder_group().await;
    let psk_id = ExternalPskId::new(vec![1]);

    // Register the PSK secret so the commit can resolve it.
    group
        .config
        .secret_store()
        .insert(psk_id.clone(), PreSharedKey::from(vec![1]));

    let builder = group.commit_builder().add_external_psk(psk_id.clone()).unwrap();
    let commit_output = builder.build().await.unwrap();

    let expected_psk = group
        .psk_proposal(JustPreSharedKeyID::External(psk_id))
        .unwrap();

    assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
}
1192
1193 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1194 async fn test_commit_builder_group_context_ext() {
1195 let mut group = test_commit_builder_group().await;
1196 let mut test_ext = ExtensionList::default();
1197 test_ext
1198 .set_from(RequiredCapabilitiesExt::default())
1199 .unwrap();
1200
1201 let commit_output = group
1202 .commit_builder()
1203 .set_group_context_ext(test_ext.clone())
1204 .unwrap()
1205 .build()
1206 .await
1207 .unwrap();
1208
1209 let expected_ext = group.group_context_extensions_proposal(test_ext);
1210
1211 assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1212 }
1213
1214 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1215 async fn test_commit_builder_reinit() {
1216 let mut group = test_commit_builder_group().await;
1217 let test_group_id = "foo".as_bytes().to_vec();
1218 let test_cipher_suite = TEST_CIPHER_SUITE;
1219 let test_protocol_version = TEST_PROTOCOL_VERSION;
1220 let mut test_ext = ExtensionList::default();
1221
1222 test_ext
1223 .set_from(RequiredCapabilitiesExt::default())
1224 .unwrap();
1225
1226 let commit_output = group
1227 .commit_builder()
1228 .reinit(
1229 Some(test_group_id.clone()),
1230 test_protocol_version,
1231 test_cipher_suite,
1232 test_ext.clone(),
1233 )
1234 .unwrap()
1235 .build()
1236 .await
1237 .unwrap();
1238
1239 let expected_reinit = group
1240 .reinit_proposal(
1241 Some(test_group_id),
1242 test_protocol_version,
1243 test_cipher_suite,
1244 test_ext,
1245 )
1246 .unwrap();
1247
1248 assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1249 }
1250
#[cfg(feature = "custom_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_commit_builder_custom_proposal() {
    let mut group = test_commit_builder_group().await;

    // Type 42 is registered by `test_commit_builder_group`.
    let custom = CustomProposal::new(42.into(), vec![0, 1]);
    let expected = vec![Proposal::Custom(custom.clone())];

    let commit_output = group
        .commit_builder()
        .custom_proposal(custom)
        .build()
        .await
        .unwrap();

    assert_commit_builder_output(group, commit_output, expected, 0);
}
1267
1268 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1269 async fn test_commit_builder_chaining() {
1270 let mut group = test_commit_builder_group().await;
1271 let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1272 let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1273
1274 let expected_adds = vec![
1275 group.add_proposal(kp1.clone()).unwrap(),
1276 group.add_proposal(kp2.clone()).unwrap(),
1277 ];
1278
1279 let commit_output = group
1280 .commit_builder()
1281 .add_member(kp1)
1282 .unwrap()
1283 .add_member(kp2)
1284 .unwrap()
1285 .build()
1286 .await
1287 .unwrap();
1288
1289 assert_commit_builder_output(group, commit_output, expected_adds, 2);
1290 }
1291
1292 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1293 async fn test_commit_builder_empty_commit() {
1294 let mut group = test_commit_builder_group().await;
1295
1296 let commit_output = group.commit_builder().build().await.unwrap();
1297
1298 assert_commit_builder_output(group, commit_output, vec![], 0);
1299 }
1300
1301 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1302 async fn test_commit_builder_authenticated_data() {
1303 let mut group = test_commit_builder_group().await;
1304 let test_data = "test".as_bytes().to_vec();
1305
1306 let commit_output = group
1307 .commit_builder()
1308 .authenticated_data(test_data.clone())
1309 .build()
1310 .await
1311 .unwrap();
1312
1313 assert_eq!(
1314 commit_output
1315 .commit_message
1316 .into_plaintext()
1317 .unwrap()
1318 .content
1319 .authenticated_data,
1320 test_data
1321 );
1322 }
1323
1324 #[cfg(feature = "by_ref_proposal")]
1325 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1326 async fn test_commit_builder_multiple_welcome_messages() {
1327 let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1328 let options = CommitOptions::new().with_single_welcome_message(false);
1329 b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1330 })
1331 .await;
1332
1333 let (alice, alice_kp) =
1334 test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1335
1336 let (bob, bob_kp) =
1337 test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1338
1339 group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
1340
1341 group.propose_add(bob_kp.clone(), vec![]).await.unwrap();
1342
1343 let output = group.commit(Vec::new()).await.unwrap();
1344 let welcomes = output.welcome_messages;
1345
1346 let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1347
1348 for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1349 let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1350
1351 let welcome = welcomes
1352 .iter()
1353 .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1354 .unwrap();
1355
1356 client.join_group(None, welcome, None).await.unwrap();
1357
1358 assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1359 }
1360 }
1361
1362 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1363 async fn commit_can_change_credential() {
1364 let cs = TEST_CIPHER_SUITE;
1365 let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1366 let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1367
1368 let commit_output = groups[0]
1369 .commit_builder()
1370 .set_new_signing_identity(secret_key, identity.clone())
1371 .build()
1372 .await
1373 .unwrap();
1374
1375 groups[0].process_pending_commit().await.unwrap();
1377 let new_member = groups[0].roster().member_with_index(0).unwrap();
1378
1379 assert_eq!(
1380 new_member.signing_identity.credential,
1381 get_test_basic_credential(b"member".to_vec())
1382 );
1383
1384 assert_eq!(
1385 new_member.signing_identity.signature_key,
1386 identity.signature_key
1387 );
1388
1389 groups[1]
1391 .process_message(commit_output.commit_message)
1392 .await
1393 .unwrap();
1394
1395 let new_member = groups[1].roster().member_with_index(0).unwrap();
1396
1397 assert_eq!(
1398 new_member.signing_identity.credential,
1399 get_test_basic_credential(b"member".to_vec())
1400 );
1401
1402 assert_eq!(
1403 new_member.signing_identity.signature_key,
1404 identity.signature_key
1405 );
1406 }
1407
1408 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1409 async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1410 let mut group = test_group_custom(
1411 TEST_PROTOCOL_VERSION,
1412 TEST_CIPHER_SUITE,
1413 Default::default(),
1414 None,
1415 Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1416 )
1417 .await;
1418
1419 let commit = group.commit(vec![]).await.unwrap();
1420
1421 group.apply_pending_commit().await.unwrap();
1422
1423 let new_tree = group.export_tree();
1424
1425 assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1426 }
1427
1428 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1429 async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1430 let mut group = test_group_custom(
1431 TEST_PROTOCOL_VERSION,
1432 TEST_CIPHER_SUITE,
1433 Default::default(),
1434 None,
1435 Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1436 )
1437 .await;
1438
1439 let commit = group.commit(vec![]).await.unwrap();
1440
1441 assert!(commit.ratchet_tree.is_none());
1442 }
1443
1444 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1445 async fn commit_includes_external_commit_group_info_if_requested() {
1446 let mut group = test_group_custom(
1447 TEST_PROTOCOL_VERSION,
1448 TEST_CIPHER_SUITE,
1449 Default::default(),
1450 None,
1451 Some(
1452 CommitOptions::new()
1453 .with_allow_external_commit(true)
1454 .with_ratchet_tree_extension(false),
1455 ),
1456 )
1457 .await;
1458
1459 let commit = group.commit(vec![]).await.unwrap();
1460
1461 let info = commit
1462 .external_commit_group_info
1463 .unwrap()
1464 .into_group_info()
1465 .unwrap();
1466
1467 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1468 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1469 }
1470
1471 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1472 async fn commit_includes_external_commit_and_tree_if_requested() {
1473 let mut group = test_group_custom(
1474 TEST_PROTOCOL_VERSION,
1475 TEST_CIPHER_SUITE,
1476 Default::default(),
1477 None,
1478 Some(
1479 CommitOptions::new()
1480 .with_allow_external_commit(true)
1481 .with_ratchet_tree_extension(true),
1482 ),
1483 )
1484 .await;
1485
1486 let commit = group.commit(vec![]).await.unwrap();
1487
1488 let info = commit
1489 .external_commit_group_info
1490 .unwrap()
1491 .into_group_info()
1492 .unwrap();
1493
1494 assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1495 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1496 }
1497
1498 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1499 async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1500 let mut group = test_group_custom(
1501 TEST_PROTOCOL_VERSION,
1502 TEST_CIPHER_SUITE,
1503 Default::default(),
1504 None,
1505 Some(CommitOptions::new().with_allow_external_commit(false)),
1506 )
1507 .await;
1508
1509 let commit = group.commit(vec![]).await.unwrap();
1510
1511 assert!(commit.external_commit_group_info.is_none());
1512 }
1513
1514 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1515 async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_off(
1516 ) {
1517 let mut group = test_group_custom(
1518 TEST_PROTOCOL_VERSION,
1519 TEST_CIPHER_SUITE,
1520 Default::default(),
1521 None,
1522 Some(
1523 CommitOptions::new()
1524 .with_always_out_of_band_ratchet_tree(true)
1525 .with_ratchet_tree_extension(false)
1526 .with_allow_external_commit(true),
1527 ),
1528 )
1529 .await;
1530
1531 let commit = group.commit(vec![]).await.unwrap();
1532
1533 assert!(commit.ratchet_tree.is_some());
1534
1535 let info = commit
1536 .external_commit_group_info
1537 .unwrap()
1538 .into_group_info()
1539 .unwrap();
1540
1541 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1542 }
1543
1544 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1545 async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_on(
1546 ) {
1547 let mut group = test_group_custom(
1548 TEST_PROTOCOL_VERSION,
1549 TEST_CIPHER_SUITE,
1550 Default::default(),
1551 None,
1552 Some(
1553 CommitOptions::new()
1554 .with_always_out_of_band_ratchet_tree(true)
1555 .with_ratchet_tree_extension(true)
1556 .with_allow_external_commit(true),
1557 ),
1558 )
1559 .await;
1560
1561 let commit = group.commit(vec![]).await.unwrap();
1562
1563 assert!(commit.ratchet_tree.is_some());
1564
1565 let info = commit
1566 .external_commit_group_info
1567 .unwrap()
1568 .into_group_info()
1569 .unwrap();
1570
1571 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1572 }
1573
1574 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1575 async fn member_identity_is_validated_against_new_extensions() {
1576 let alice = client_with_test_extension(b"alice").await;
1577 let mut alice = alice
1578 .create_group(ExtensionList::new(), Default::default(), None)
1579 .await
1580 .unwrap();
1581
1582 let bob = client_with_test_extension(b"bob").await;
1583 let bob_kp = bob
1584 .generate_key_package_message(Default::default(), Default::default(), None)
1585 .await
1586 .unwrap();
1587
1588 let mut extension_list = ExtensionList::new();
1589 let extension = TestExtension { foo: b'a' };
1590 extension_list.set_from(extension).unwrap();
1591
1592 let res = alice
1593 .commit_builder()
1594 .add_member(bob_kp)
1595 .unwrap()
1596 .set_group_context_ext(extension_list.clone())
1597 .unwrap()
1598 .build()
1599 .await;
1600
1601 assert!(res.is_err());
1602
1603 let alex = client_with_test_extension(b"alex").await;
1604
1605 alice
1606 .commit_builder()
1607 .add_member(
1608 alex.generate_key_package_message(Default::default(), Default::default(), None)
1609 .await
1610 .unwrap(),
1611 )
1612 .unwrap()
1613 .set_group_context_ext(extension_list.clone())
1614 .unwrap()
1615 .build()
1616 .await
1617 .unwrap();
1618 }
1619
1620 #[cfg(feature = "by_ref_proposal")]
1621 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1622 async fn server_identity_is_validated_against_new_extensions() {
1623 let alice = client_with_test_extension(b"alice").await;
1624 let mut alice = alice
1625 .create_group(ExtensionList::new(), Default::default(), None)
1626 .await
1627 .unwrap();
1628
1629 let mut extension_list = ExtensionList::new();
1630 let extension = TestExtension { foo: b'a' };
1631 extension_list.set_from(extension).unwrap();
1632
1633 let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1634
1635 let mut alex_extensions = extension_list.clone();
1636
1637 alex_extensions
1638 .set_from(ExternalSendersExt {
1639 allowed_senders: vec![alex_server],
1640 })
1641 .unwrap();
1642
1643 let res = alice
1644 .commit_builder()
1645 .set_group_context_ext(alex_extensions)
1646 .unwrap()
1647 .build()
1648 .await;
1649
1650 assert!(res.is_err());
1651
1652 let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1653
1654 let mut bob_extensions = extension_list;
1655
1656 bob_extensions
1657 .set_from(ExternalSendersExt {
1658 allowed_senders: vec![bob_server],
1659 })
1660 .unwrap();
1661
1662 alice
1663 .commit_builder()
1664 .set_group_context_ext(bob_extensions)
1665 .unwrap()
1666 .build()
1667 .await
1668 .unwrap();
1669 }
1670
1671 #[derive(Debug, Clone)]
1672 struct IdentityProviderWithExtension(BasicIdentityProvider);
1673
1674 #[derive(Clone, Debug)]
1675 #[cfg_attr(feature = "std", derive(thiserror::Error))]
1676 #[cfg_attr(feature = "std", error("test error"))]
1677 struct IdentityProviderWithExtensionError {}
1678
1679 impl IntoAnyError for IdentityProviderWithExtensionError {
1680 #[cfg(feature = "std")]
1681 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1682 Ok(self.into())
1683 }
1684 }
1685
1686 impl IdentityProviderWithExtension {
1687 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1690 async fn starts_with_foo(
1691 &self,
1692 identity: &SigningIdentity,
1693 _timestamp: Option<MlsTime>,
1694 extensions: Option<&ExtensionList>,
1695 ) -> bool {
1696 if let Some(extensions) = extensions {
1697 if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1698 self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1699 } else {
1700 true
1701 }
1702 } else {
1703 true
1704 }
1705 }
1706 }
1707
1708 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1709 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1710 impl IdentityProvider for IdentityProviderWithExtension {
1711 type Error = IdentityProviderWithExtensionError;
1712
1713 async fn validate_member(
1714 &self,
1715 identity: &SigningIdentity,
1716 timestamp: Option<MlsTime>,
1717 context: MemberValidationContext<'_>,
1718 ) -> Result<(), Self::Error> {
1719 self.starts_with_foo(identity, timestamp, context.new_extensions())
1720 .await
1721 .then_some(())
1722 .ok_or(IdentityProviderWithExtensionError {})
1723 }
1724
1725 async fn validate_external_sender(
1726 &self,
1727 identity: &SigningIdentity,
1728 timestamp: Option<MlsTime>,
1729 extensions: Option<&ExtensionList>,
1730 ) -> Result<(), Self::Error> {
1731 (!self.starts_with_foo(identity, timestamp, extensions).await)
1732 .then_some(())
1733 .ok_or(IdentityProviderWithExtensionError {})
1734 }
1735
1736 async fn identity(
1737 &self,
1738 signing_identity: &SigningIdentity,
1739 extensions: &ExtensionList,
1740 ) -> Result<Vec<u8>, Self::Error> {
1741 self.0
1742 .identity(signing_identity, extensions)
1743 .await
1744 .map_err(|_| IdentityProviderWithExtensionError {})
1745 }
1746
1747 async fn valid_successor(
1748 &self,
1749 _predecessor: &SigningIdentity,
1750 _successor: &SigningIdentity,
1751 _extensions: &ExtensionList,
1752 ) -> Result<bool, Self::Error> {
1753 Ok(true)
1754 }
1755
1756 fn supported_types(&self) -> Vec<CredentialType> {
1757 self.0.supported_types()
1758 }
1759 }
1760
1761 type ExtensionClientConfig = WithIdentityProvider<
1762 IdentityProviderWithExtension,
1763 WithCryptoProvider<TestCryptoProvider, BaseConfig>,
1764 >;
1765
1766 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1767 async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1768 let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1769
1770 ClientBuilder::new()
1771 .crypto_provider(TestCryptoProvider::new())
1772 .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1773 .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1774 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1775 .build()
1776 }
1777
1778 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1779 async fn detached_commit() {
1780 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1781
1782 let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1783 assert!(group.pending_commit.is_none());
1784 group.apply_detached_commit(secrets).await.unwrap();
1785 assert_eq!(group.context().epoch, 1);
1786 }
1787}