1use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13 cipher_suite::CipherSuite,
14 client::MlsError,
15 client_config::ClientConfig,
16 extension::RatchetTreeExt,
17 identity::SigningIdentity,
18 protocol_version::ProtocolVersion,
19 signer::Signable,
20 tree_kem::{
21 kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
22 },
23 ExtensionList, MlsRules,
24};
25
26#[cfg(all(not(mls_build_async), feature = "rayon"))]
27use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
28
29use crate::tree_kem::leaf_node::LeafNode;
30
31#[cfg(not(feature = "private_message"))]
32use crate::WireFormat;
33
34#[cfg(feature = "psk")]
35use crate::{
36 group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
37 psk::ExternalPskId,
38};
39
40use super::{
41 confirmation_tag::ConfirmationTag,
42 framing::{Content, MlsMessage, MlsMessagePayload, Sender},
43 key_schedule::{KeySchedule, WelcomeSecret},
44 message_hash::MessageHash,
45 message_processor::{path_update_required, MessageProcessor},
46 message_signature::AuthenticatedContent,
47 mls_rules::CommitDirection,
48 proposal::{Proposal, ProposalOrRef},
49 CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
50 Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
51 PendingCommitSnapshot, Welcome,
52};
53
54#[cfg(not(feature = "by_ref_proposal"))]
55use super::proposal_cache::prepare_commit;
56
57#[cfg(feature = "custom_proposal")]
58use super::proposal::CustomProposal;
59
/// Body of an MLS Commit: the proposals being committed plus an optional UpdatePath.
///
/// Serialized with the derived `MlsEncode`/`MlsDecode` impls, so the field order below is
/// part of the wire encoding — do not reorder fields.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Commit {
    /// Proposals covered by this commit, either inline or by reference.
    pub proposals: Vec<ProposalOrRef>,
    /// UpdatePath generated by the committer; `None` when no path update was performed.
    pub path: Option<UpdatePath>,
}
67
/// Everything a committer needs to move to the next epoch once its own commit is accepted.
///
/// Produced by `Group::commit_internal`. It is either stored in the group (`CommitBuilder::build`)
/// or serialized into [`CommitSecrets`] (`CommitBuilder::build_detached`). The derived codec
/// makes field order significant.
#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
pub(crate) struct PendingCommit {
    /// Group state (context, public tree, transcript hash, ...) after the commit.
    pub(crate) state: GroupState,
    /// Epoch secrets derived by the key schedule for the new epoch.
    pub(crate) epoch_secrets: EpochSecrets,
    /// Committer's private tree state after the commit (including any new path secrets).
    pub(crate) private_tree: TreeKemPrivate,
    /// Key schedule of the new epoch.
    pub(crate) key_schedule: KeySchedule,
    /// Signing key to use from the new epoch on (the new signer if one was set).
    pub(crate) signer: SignatureSecretKey,

    /// Description of the commit (committer, authenticated data, resulting effect).
    pub(crate) output: CommitMessageDescription,

    /// Hash of the outgoing commit message; presumably used to recognize the committer's own
    /// commit when it comes back — TODO confirm against the message processor.
    pub(crate) commit_message_hash: MessageHash,
}
80
/// Opaque secrets of a commit created with `CommitBuilder::build_detached` or
/// `Group::commit_detached`.
///
/// Wraps an encoded pending-commit snapshot. Use [`CommitSecrets::to_bytes`] and
/// [`CommitSecrets::from_bytes`] to persist it.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone)]
pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
87
88impl CommitSecrets {
89 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
91 Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
92 }
93
94 pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
96 Ok(self.0.mls_encode_to_vec()?)
97 }
98}
99
/// Result of building a commit.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CommitOutput {
    /// Commit message to send to the other members of the group.
    pub commit_message: MlsMessage,
    /// Welcome messages for members added by this commit.
    ///
    /// With the `single_welcome_message` commit option this holds at most one message that
    /// covers every new member. Otherwise it holds one message per new member.
    pub welcome_messages: Vec<MlsMessage>,
    /// Ratchet tree to deliver to new members out of band.
    ///
    /// `Some` when the ratchet tree extension is disabled, or when
    /// `always_out_of_band_ratchet_tree` is set.
    pub ratchet_tree: Option<ExportedTree<'static>>,
    /// GroupInfo message that enables external commits.
    ///
    /// `Some` only when the MLS rules allow external commits.
    pub external_commit_group_info: Option<MlsMessage>,
    /// Proposals reported as unused by proposal resolution while building this commit.
    /// Presumably these are proposals that were not included in it.
    #[cfg(feature = "by_ref_proposal")]
    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
    /// Whether the commit message contains an UpdatePath.
    pub contains_update_path: bool,
}
133
134#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
135impl CommitOutput {
136 #[cfg(feature = "ffi")]
138 pub fn commit_message(&self) -> &MlsMessage {
139 &self.commit_message
140 }
141
142 #[cfg(feature = "ffi")]
144 pub fn welcome_messages(&self) -> &[MlsMessage] {
145 &self.welcome_messages
146 }
147
148 #[cfg(feature = "ffi")]
152 pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
153 self.ratchet_tree.as_ref()
154 }
155
156 #[cfg(feature = "ffi")]
160 pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
161 self.external_commit_group_info.as_ref()
162 }
163
164 #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
166 pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
167 &self.unused_proposals
168 }
169}
170
/// Builder for a commit on a [`Group`], obtained from `Group::commit_builder`.
///
/// Proposals and options are collected through the chaining methods. The commit is created
/// by `build`, which stores the pending commit in the group, or by `build_detached`, which
/// returns the pending commit as [`CommitSecrets`].
pub struct CommitBuilder<'a, C>
where
    C: ClientConfig + Clone,
{
    // Group being committed to; borrowed mutably for the builder's whole lifetime.
    group: &'a mut Group<C>,
    // Proposals to include by value, in insertion order.
    pub(super) proposals: Vec<Proposal>,
    // Authenticated data of the commit message.
    authenticated_data: Vec<u8>,
    // Extensions for the GroupInfo encrypted into Welcome messages.
    group_info_extensions: ExtensionList,
    // Replacement signing key; `None` keeps the group's current signer.
    new_signer: Option<SignatureSecretKey>,
    // Replacement signing identity, applied through the UpdatePath.
    new_signing_identity: Option<SigningIdentity>,
    // Replacement leaf node extensions, applied through the UpdatePath.
    new_leaf_node_extensions: Option<ExtensionList>,
}
190
191impl<'a, C> CommitBuilder<'a, C>
192where
193 C: ClientConfig + Clone,
194{
195 pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
198 let proposal = self.group.add_proposal(key_package)?;
199 self.proposals.push(proposal);
200 Ok(self)
201 }
202
203 pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
214 Self {
215 group_info_extensions: extensions,
216 ..self
217 }
218 }
219
220 pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
223 let proposal = self.group.remove_proposal(index)?;
224 self.proposals.push(proposal);
225 Ok(self)
226 }
227
228 pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
232 let proposal = self.group.group_context_extensions_proposal(extensions);
233 self.proposals.push(proposal);
234 Ok(self)
235 }
236
237 #[cfg(feature = "psk")]
241 pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
242 let key_id = JustPreSharedKeyID::External(psk_id);
243 let proposal = self.group.psk_proposal(key_id)?;
244 self.proposals.push(proposal);
245 Ok(self)
246 }
247
248 #[cfg(feature = "psk")]
252 pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
253 let psk_id = ResumptionPsk {
254 psk_epoch,
255 usage: ResumptionPSKUsage::Application,
256 psk_group_id: PskGroupId(self.group.group_id().to_vec()),
257 };
258
259 let key_id = JustPreSharedKeyID::Resumption(psk_id);
260 let proposal = self.group.psk_proposal(key_id)?;
261 self.proposals.push(proposal);
262 Ok(self)
263 }
264
265 pub fn reinit(
268 mut self,
269 group_id: Option<Vec<u8>>,
270 version: ProtocolVersion,
271 cipher_suite: CipherSuite,
272 extensions: ExtensionList,
273 ) -> Result<Self, MlsError> {
274 let proposal = self
275 .group
276 .reinit_proposal(group_id, version, cipher_suite, extensions)?;
277
278 self.proposals.push(proposal);
279 Ok(self)
280 }
281
282 #[cfg(feature = "custom_proposal")]
285 pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
286 self.proposals.push(Proposal::Custom(proposal));
287 self
288 }
289
290 pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
294 self.proposals.push(proposal);
295 self
296 }
297
298 pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
302 self.proposals.append(&mut proposals);
303 self
304 }
305
306 pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
312 Self {
313 authenticated_data,
314 ..self
315 }
316 }
317
318 pub fn set_new_signing_identity(
326 self,
327 signer: SignatureSecretKey,
328 signing_identity: SigningIdentity,
329 ) -> Self {
330 Self {
331 new_signer: Some(signer),
332 new_signing_identity: Some(signing_identity),
333 ..self
334 }
335 }
336
337 pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
339 Self {
340 new_leaf_node_extensions: Some(new_leaf_node_extensions),
341 ..self
342 }
343 }
344
345 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
354 pub async fn build(self) -> Result<CommitOutput, MlsError> {
355 let (output, pending_commit) = self
356 .group
357 .commit_internal(
358 self.proposals,
359 None,
360 self.authenticated_data,
361 self.group_info_extensions,
362 self.new_signer,
363 self.new_signing_identity,
364 self.new_leaf_node_extensions,
365 )
366 .await?;
367
368 self.group.pending_commit = pending_commit.try_into()?;
369
370 Ok(output)
371 }
372
373 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
378 pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
379 let (output, pending_commit) = self
380 .group
381 .commit_internal(
382 self.proposals,
383 None,
384 self.authenticated_data,
385 self.group_info_extensions,
386 self.new_signer,
387 self.new_signing_identity,
388 self.new_leaf_node_extensions,
389 )
390 .await?;
391
392 Ok((
393 output,
394 CommitSecrets(PendingCommitSnapshot::PendingCommit(
395 pending_commit.mls_encode_to_vec()?,
396 )),
397 ))
398 }
399}
400
401impl<C> Group<C>
402where
403 C: ClientConfig + Clone,
404{
405 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
446 pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
447 self.commit_builder()
448 .authenticated_data(authenticated_data)
449 .build()
450 .await
451 }
452
453 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
458 pub async fn commit_detached(
459 &mut self,
460 authenticated_data: Vec<u8>,
461 ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
462 self.commit_builder()
463 .authenticated_data(authenticated_data)
464 .build_detached()
465 .await
466 }
467
468 pub fn commit_builder(&mut self) -> CommitBuilder<C> {
471 CommitBuilder {
472 group: self,
473 proposals: Default::default(),
474 authenticated_data: Default::default(),
475 group_info_extensions: Default::default(),
476 new_signer: Default::default(),
477 new_signing_identity: Default::default(),
478 new_leaf_node_extensions: Default::default(),
479 }
480 }
481
482 #[allow(clippy::too_many_arguments)]
485 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
486 pub(super) async fn commit_internal(
487 &mut self,
488 proposals: Vec<Proposal>,
489 external_leaf: Option<&LeafNode>,
490 authenticated_data: Vec<u8>,
491 mut welcome_group_info_extensions: ExtensionList,
492 new_signer: Option<SignatureSecretKey>,
493 new_signing_identity: Option<SigningIdentity>,
494 new_leaf_node_extensions: Option<ExtensionList>,
495 ) -> Result<(CommitOutput, PendingCommit), MlsError> {
496 if !self.pending_commit.is_none() {
497 return Err(MlsError::ExistingPendingCommit);
498 }
499
500 if self.state.pending_reinit.is_some() {
501 return Err(MlsError::GroupUsedAfterReInit);
502 }
503
504 let mls_rules = self.config.mls_rules();
505
506 let is_external = external_leaf.is_some();
507
508 let sender = if is_external {
512 Sender::NewMemberCommit
513 } else {
514 Sender::Member(*self.private_tree.self_index)
515 };
516
517 let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
518 let old_signer = &self.signer;
519
520 #[cfg(feature = "std")]
521 let time = Some(crate::time::MlsTime::now());
522
523 #[cfg(not(feature = "std"))]
524 let time = None;
525
526 #[cfg(feature = "by_ref_proposal")]
527 let proposals = self.state.proposals.prepare_commit(sender, proposals);
528
529 #[cfg(not(feature = "by_ref_proposal"))]
530 let proposals = prepare_commit(sender, proposals);
531
532 let mut provisional_state = self
533 .state
534 .apply_resolved(
535 sender,
536 proposals,
537 external_leaf,
538 &self.config.identity_provider(),
539 &self.cipher_suite_provider,
540 &self.config.secret_store(),
541 &mls_rules,
542 time,
543 CommitDirection::Send,
544 )
545 .await?;
546
547 let (mut provisional_private_tree, _) =
548 self.provisional_private_tree(&provisional_state)?;
549
550 if is_external {
551 provisional_private_tree.self_index = provisional_state
552 .external_init_index
553 .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
554
555 self.private_tree.self_index = provisional_private_tree.self_index;
556 }
557
558 let commit_options = mls_rules
562 .commit_options(
563 &provisional_state.public_tree.roster(),
564 &provisional_state.group_context,
565 &provisional_state.applied_proposals,
566 )
567 .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
568
569 let perform_path_update = commit_options.path_required
570 || path_update_required(&provisional_state.applied_proposals);
571
572 let (update_path, path_secrets, commit_secret) = if perform_path_update {
573 let new_leaf_node_extensions =
581 new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
582
583 let new_leaf_node_extensions = match new_leaf_node_extensions {
584 Some(extensions) => extensions,
585 None => self.current_user_leaf_node()?.ungreased_extensions(),
587 };
588
589 let encap_gen = TreeKem::new(
590 &mut provisional_state.public_tree,
591 &mut provisional_private_tree,
592 )
593 .encap(
594 &mut provisional_state.group_context,
595 &provisional_state.indexes_of_added_kpkgs,
596 &new_signer,
597 Some(self.config.leaf_properties(new_leaf_node_extensions)),
598 new_signing_identity,
599 &self.cipher_suite_provider,
600 #[cfg(test)]
601 &self.commit_modifiers,
602 )
603 .await?;
604
605 (
606 Some(encap_gen.update_path),
607 Some(encap_gen.path_secrets),
608 encap_gen.commit_secret,
609 )
610 } else {
611 provisional_state
613 .public_tree
614 .update_hashes(
615 &[provisional_private_tree.self_index],
616 &self.cipher_suite_provider,
617 )
618 .await?;
619
620 provisional_state.group_context.tree_hash = provisional_state
621 .public_tree
622 .tree_hash(&self.cipher_suite_provider)
623 .await?;
624
625 (None, None, PathSecret::empty(&self.cipher_suite_provider))
626 };
627
628 #[cfg(feature = "psk")]
629 let (psk_secret, psks) = self
630 .get_psk(&provisional_state.applied_proposals.psks)
631 .await?;
632
633 #[cfg(not(feature = "psk"))]
634 let psk_secret = self.get_psk();
635
636 let added_key_pkgs: Vec<_> = provisional_state
637 .applied_proposals
638 .additions
639 .iter()
640 .map(|info| info.proposal.key_package.clone())
641 .collect();
642
643 let commit = Commit {
644 proposals: provisional_state.applied_proposals.proposals_or_refs(),
645 path: update_path,
646 };
647
648 let mut auth_content = AuthenticatedContent::new_signed(
649 &self.cipher_suite_provider,
650 self.context(),
651 sender,
652 Content::Commit(Box::new(commit)),
653 old_signer,
654 #[cfg(feature = "private_message")]
655 self.encryption_options()?.control_wire_format(sender),
656 #[cfg(not(feature = "private_message"))]
657 WireFormat::PublicMessage,
658 authenticated_data,
659 )
660 .await?;
661
662 let confirmed_transcript_hash = super::transcript_hash::create(
665 self.cipher_suite_provider(),
666 &self.state.interim_transcript_hash,
667 &auth_content,
668 )
669 .await?;
670
671 provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
672
673 let key_schedule_result = KeySchedule::from_key_schedule(
674 &self.key_schedule,
675 &commit_secret,
676 &provisional_state.group_context,
677 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
678 provisional_state.public_tree.total_leaf_count(),
679 &psk_secret,
680 &self.cipher_suite_provider,
681 )
682 .await?;
683
684 let confirmation_tag = ConfirmationTag::create(
685 &key_schedule_result.confirmation_key,
686 &provisional_state.group_context.confirmed_transcript_hash,
687 &self.cipher_suite_provider,
688 )
689 .await?;
690
691 let interim_transcript_hash = InterimTranscriptHash::create(
692 self.cipher_suite_provider(),
693 &provisional_state.group_context.confirmed_transcript_hash,
694 &confirmation_tag,
695 )
696 .await?;
697
698 auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
699
700 let ratchet_tree_ext = commit_options
701 .ratchet_tree_extension
702 .then(|| RatchetTreeExt {
703 tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
704 });
705
706 let external_commit_group_info = match commit_options.allow_external_commit {
708 true => {
709 let mut extensions = ExtensionList::new();
710
711 extensions.set_from({
712 key_schedule_result
713 .key_schedule
714 .get_external_key_pair_ext(&self.cipher_suite_provider)
715 .await?
716 })?;
717
718 if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
719 if !commit_options.always_out_of_band_ratchet_tree {
720 extensions.set_from(ratchet_tree_ext.clone())?;
721 }
722 }
723
724 let info = self
725 .make_group_info(
726 &provisional_state.group_context,
727 extensions,
728 &confirmation_tag,
729 &new_signer,
730 )
731 .await?;
732
733 let msg =
734 MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
735
736 Some(msg)
737 }
738 false => None,
739 };
740
741 if let Some(ratchet_tree_ext) = ratchet_tree_ext {
744 welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
745 }
746
747 let welcome_group_info = self
748 .make_group_info(
749 &provisional_state.group_context,
750 welcome_group_info_extensions,
751 &confirmation_tag,
752 &new_signer,
753 )
754 .await?;
755
756 let welcome_secret = WelcomeSecret::from_joiner_secret(
759 &self.cipher_suite_provider,
760 &key_schedule_result.joiner_secret,
761 &psk_secret,
762 )
763 .await?;
764
765 let encrypted_group_info = welcome_secret
766 .encrypt(&welcome_group_info.mls_encode_to_vec()?)
767 .await?;
768
769 let path_secrets = path_secrets.as_ref();
771
772 #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
773 let encrypted_path_secrets: Vec<_> = added_key_pkgs
774 .into_par_iter()
775 .zip(&provisional_state.indexes_of_added_kpkgs)
776 .map(|(key_package, leaf_index)| {
777 self.encrypt_group_secrets(
778 &key_package,
779 *leaf_index,
780 &key_schedule_result.joiner_secret,
781 path_secrets,
782 #[cfg(feature = "psk")]
783 psks.clone(),
784 &encrypted_group_info,
785 )
786 })
787 .try_collect()?;
788
789 #[cfg(any(mls_build_async, not(feature = "rayon")))]
790 let encrypted_path_secrets = {
791 let mut secrets = Vec::new();
792
793 for (key_package, leaf_index) in added_key_pkgs
794 .into_iter()
795 .zip(&provisional_state.indexes_of_added_kpkgs)
796 {
797 secrets.push(
798 self.encrypt_group_secrets(
799 &key_package,
800 *leaf_index,
801 &key_schedule_result.joiner_secret,
802 path_secrets,
803 #[cfg(feature = "psk")]
804 psks.clone(),
805 &encrypted_group_info,
806 )
807 .await?,
808 );
809 }
810
811 secrets
812 };
813
814 let welcome_messages =
815 if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
816 vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
817 } else {
818 encrypted_path_secrets
819 .into_iter()
820 .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
821 .collect()
822 };
823
824 let commit_message = self.format_for_wire(auth_content.clone()).await?;
825
826 let ratchet_tree = (!commit_options.ratchet_tree_extension
828 || commit_options.always_out_of_band_ratchet_tree)
829 .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
830
831 let pending_reinit = provisional_state
832 .applied_proposals
833 .reinitializations
834 .first();
835
836 let pending_commit = PendingCommit {
837 output: CommitMessageDescription {
838 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
839 authenticated_data: auth_content.content.authenticated_data,
840 committer: *provisional_private_tree.self_index,
841 effect: match pending_reinit {
842 Some(r) => CommitEffect::ReInit(r.clone()),
843 None => CommitEffect::NewEpoch(
844 NewEpoch::new(self.state.clone(), &provisional_state).into(),
845 ),
846 },
847 },
848
849 state: GroupState {
850 #[cfg(feature = "by_ref_proposal")]
851 proposals: crate::group::ProposalCache::new(
852 self.protocol_version(),
853 self.group_id().to_vec(),
854 ),
855 context: provisional_state.group_context,
856 public_tree: provisional_state.public_tree,
857 interim_transcript_hash,
858 pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
859 confirmation_tag,
860 },
861
862 commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
863 .await?,
864 signer: new_signer,
865 epoch_secrets: key_schedule_result.epoch_secrets,
866 key_schedule: key_schedule_result.key_schedule,
867
868 private_tree: provisional_private_tree,
869 };
870
871 let output = CommitOutput {
872 commit_message,
873 welcome_messages,
874 ratchet_tree,
875 external_commit_group_info,
876 contains_update_path: perform_path_update,
877 #[cfg(feature = "by_ref_proposal")]
878 unused_proposals: provisional_state.unused_proposals,
879 };
880
881 Ok((output, pending_commit))
882 }
883
884 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
887 async fn make_group_info(
888 &self,
889 group_context: &GroupContext,
890 extensions: ExtensionList,
891 confirmation_tag: &ConfirmationTag,
892 signer: &SignatureSecretKey,
893 ) -> Result<GroupInfo, MlsError> {
894 let mut group_info = GroupInfo {
895 group_context: group_context.clone(),
896 extensions,
897 confirmation_tag: confirmation_tag.clone(), signer: LeafIndex(self.current_member_index()),
899 signature: vec![],
900 };
901
902 group_info.grease(self.cipher_suite_provider())?;
903
904 group_info
906 .sign(&self.cipher_suite_provider, signer, &())
907 .await?;
908
909 Ok(group_info)
910 }
911
912 fn make_welcome_message(
913 &self,
914 secrets: Vec<EncryptedGroupSecrets>,
915 encrypted_group_info: Vec<u8>,
916 ) -> MlsMessage {
917 MlsMessage::new(
918 self.context().protocol_version,
919 MlsMessagePayload::Welcome(Welcome {
920 cipher_suite: self.context().cipher_suite,
921 secrets,
922 encrypted_group_info,
923 }),
924 )
925 }
926}
927
928#[cfg(test)]
929pub(crate) mod test_utils {
930 use alloc::vec::Vec;
931
932 use crate::{
933 crypto::SignatureSecretKey,
934 tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
935 };
936
937 #[derive(Copy, Clone, Debug)]
938 pub struct CommitModifiers {
939 pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
940 pub modify_tree: fn(&mut TreeKemPublic),
941 pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
942 }
943
944 impl Default for CommitModifiers {
945 fn default() -> Self {
946 Self {
947 modify_leaf: |_, _| None,
948 modify_tree: |_| (),
949 modify_path: |a| a,
950 }
951 }
952 }
953}
954
955#[cfg(test)]
956mod tests {
957 use mls_rs_core::{
958 error::IntoAnyError,
959 extension::ExtensionType,
960 identity::{CredentialType, IdentityProvider, MemberValidationContext},
961 time::MlsTime,
962 };
963
964 use crate::extension::RequiredCapabilitiesExt;
965 use crate::{
966 client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
967 client_builder::{
968 test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
969 WithIdentityProvider,
970 },
971 client_config::ClientConfig,
972 crypto::test_utils::TestCryptoProvider,
973 extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
974 group::test_utils::{test_group, test_group_custom},
975 group::{
976 proposal::ProposalType,
977 test_utils::{test_group_custom_config, test_n_member_group},
978 },
979 identity::test_utils::get_test_signing_identity,
980 identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
981 key_package::test_utils::test_key_package_message,
982 mls_rules::CommitOptions,
983 Client,
984 };
985
986 #[cfg(feature = "by_ref_proposal")]
987 use crate::crypto::test_utils::test_cipher_suite_provider;
988 #[cfg(feature = "by_ref_proposal")]
989 use crate::extension::ExternalSendersExt;
990 #[cfg(feature = "by_ref_proposal")]
991 use crate::group::mls_rules::DefaultMlsRules;
992
993 #[cfg(feature = "psk")]
994 use crate::{
995 group::proposal::PreSharedKeyProposal,
996 psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
997 };
998
999 use super::*;
1000
1001 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1002 async fn test_commit_builder_group() -> Group<TestClientConfig> {
1003 test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1004 b.custom_proposal_type(ProposalType::from(42))
1005 .extension_type(TEST_EXTENSION_TYPE.into())
1006 })
1007 .await
1008 .group
1009 }
1010
1011 fn assert_commit_builder_output<C: ClientConfig>(
1012 group: Group<C>,
1013 mut commit_output: CommitOutput,
1014 expected: Vec<Proposal>,
1015 welcome_count: usize,
1016 ) {
1017 let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1018
1019 let commit_data = match plaintext.content.content {
1020 Content::Commit(commit) => commit,
1021 #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1022 _ => panic!("Found non-commit data"),
1023 };
1024
1025 assert_eq!(commit_data.proposals.len(), expected.len());
1026
1027 commit_data.proposals.into_iter().for_each(|proposal| {
1028 let proposal = match proposal {
1029 ProposalOrRef::Proposal(p) => p,
1030 #[cfg(feature = "by_ref_proposal")]
1031 ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1032 };
1033
1034 #[cfg(feature = "psk")]
1035 if let Some(psk_id) = match proposal.as_ref() {
1036 Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1037 _ => None,
1038 } {
1039 let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1040
1041 assert!(found)
1042 } else {
1043 assert!(expected.contains(&proposal));
1044 }
1045
1046 #[cfg(not(feature = "psk"))]
1047 assert!(expected.contains(&proposal));
1048 });
1049
1050 if welcome_count > 0 {
1051 let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1052
1053 assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1054
1055 let welcome_msg = welcome_msg.into_welcome().unwrap();
1056
1057 assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1058 assert_eq!(welcome_msg.secrets.len(), welcome_count);
1059 } else {
1060 assert!(commit_output.welcome_messages.is_empty());
1061 }
1062 }
1063
1064 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1065 async fn test_commit_builder_add() {
1066 let mut group = test_commit_builder_group().await;
1067
1068 let test_key_package =
1069 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1070
1071 let commit_output = group
1072 .commit_builder()
1073 .add_member(test_key_package.clone())
1074 .unwrap()
1075 .build()
1076 .await
1077 .unwrap();
1078
1079 let expected_add = group.add_proposal(test_key_package).unwrap();
1080
1081 assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1082 }
1083
1084 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1085 async fn test_commit_builder_add_with_ext() {
1086 let mut group = test_commit_builder_group().await;
1087
1088 let (bob_client, bob_key_package) =
1089 test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1090
1091 let ext = TestExtension { foo: 42 };
1092 let mut extension_list = ExtensionList::default();
1093 extension_list.set_from(ext.clone()).unwrap();
1094
1095 let welcome_message = group
1096 .commit_builder()
1097 .add_member(bob_key_package)
1098 .unwrap()
1099 .set_group_info_ext(extension_list)
1100 .build()
1101 .await
1102 .unwrap()
1103 .welcome_messages
1104 .remove(0);
1105
1106 let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
1107
1108 assert_eq!(
1109 context
1110 .group_info_extensions
1111 .get_as::<TestExtension>()
1112 .unwrap()
1113 .unwrap(),
1114 ext
1115 );
1116 }
1117
1118 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1119 async fn test_commit_builder_remove() {
1120 let mut group = test_commit_builder_group().await;
1121 let test_key_package =
1122 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1123
1124 group
1125 .commit_builder()
1126 .add_member(test_key_package)
1127 .unwrap()
1128 .build()
1129 .await
1130 .unwrap();
1131
1132 group.apply_pending_commit().await.unwrap();
1133
1134 let commit_output = group
1135 .commit_builder()
1136 .remove_member(1)
1137 .unwrap()
1138 .build()
1139 .await
1140 .unwrap();
1141
1142 let expected_remove = group.remove_proposal(1).unwrap();
1143
1144 assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1145 }
1146
1147 #[cfg(feature = "psk")]
1148 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1149 async fn test_commit_builder_psk() {
1150 let mut group = test_commit_builder_group().await;
1151 let test_psk = ExternalPskId::new(vec![1]);
1152
1153 group
1154 .config
1155 .secret_store()
1156 .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1157
1158 let commit_output = group
1159 .commit_builder()
1160 .add_external_psk(test_psk.clone())
1161 .unwrap()
1162 .build()
1163 .await
1164 .unwrap();
1165
1166 let key_id = JustPreSharedKeyID::External(test_psk);
1167 let expected_psk = group.psk_proposal(key_id).unwrap();
1168
1169 assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1170 }
1171
1172 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1173 async fn test_commit_builder_group_context_ext() {
1174 let mut group = test_commit_builder_group().await;
1175 let mut test_ext = ExtensionList::default();
1176 test_ext
1177 .set_from(RequiredCapabilitiesExt::default())
1178 .unwrap();
1179
1180 let commit_output = group
1181 .commit_builder()
1182 .set_group_context_ext(test_ext.clone())
1183 .unwrap()
1184 .build()
1185 .await
1186 .unwrap();
1187
1188 let expected_ext = group.group_context_extensions_proposal(test_ext);
1189
1190 assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1191 }
1192
1193 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1194 async fn test_commit_builder_reinit() {
1195 let mut group = test_commit_builder_group().await;
1196 let test_group_id = "foo".as_bytes().to_vec();
1197 let test_cipher_suite = TEST_CIPHER_SUITE;
1198 let test_protocol_version = TEST_PROTOCOL_VERSION;
1199 let mut test_ext = ExtensionList::default();
1200
1201 test_ext
1202 .set_from(RequiredCapabilitiesExt::default())
1203 .unwrap();
1204
1205 let commit_output = group
1206 .commit_builder()
1207 .reinit(
1208 Some(test_group_id.clone()),
1209 test_protocol_version,
1210 test_cipher_suite,
1211 test_ext.clone(),
1212 )
1213 .unwrap()
1214 .build()
1215 .await
1216 .unwrap();
1217
1218 let expected_reinit = group
1219 .reinit_proposal(
1220 Some(test_group_id),
1221 test_protocol_version,
1222 test_cipher_suite,
1223 test_ext,
1224 )
1225 .unwrap();
1226
1227 assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1228 }
1229
1230 #[cfg(feature = "custom_proposal")]
1231 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1232 async fn test_commit_builder_custom_proposal() {
1233 let mut group = test_commit_builder_group().await;
1234
1235 let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1236
1237 let commit_output = group
1238 .commit_builder()
1239 .custom_proposal(proposal.clone())
1240 .build()
1241 .await
1242 .unwrap();
1243
1244 assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1245 }
1246
1247 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1248 async fn test_commit_builder_chaining() {
1249 let mut group = test_commit_builder_group().await;
1250 let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1251 let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1252
1253 let expected_adds = vec![
1254 group.add_proposal(kp1.clone()).unwrap(),
1255 group.add_proposal(kp2.clone()).unwrap(),
1256 ];
1257
1258 let commit_output = group
1259 .commit_builder()
1260 .add_member(kp1)
1261 .unwrap()
1262 .add_member(kp2)
1263 .unwrap()
1264 .build()
1265 .await
1266 .unwrap();
1267
1268 assert_commit_builder_output(group, commit_output, expected_adds, 2);
1269 }
1270
1271 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1272 async fn test_commit_builder_empty_commit() {
1273 let mut group = test_commit_builder_group().await;
1274
1275 let commit_output = group.commit_builder().build().await.unwrap();
1276
1277 assert_commit_builder_output(group, commit_output, vec![], 0);
1278 }
1279
1280 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1281 async fn test_commit_builder_authenticated_data() {
1282 let mut group = test_commit_builder_group().await;
1283 let test_data = "test".as_bytes().to_vec();
1284
1285 let commit_output = group
1286 .commit_builder()
1287 .authenticated_data(test_data.clone())
1288 .build()
1289 .await
1290 .unwrap();
1291
1292 assert_eq!(
1293 commit_output
1294 .commit_message
1295 .into_plaintext()
1296 .unwrap()
1297 .content
1298 .authenticated_data,
1299 test_data
1300 );
1301 }
1302
// With `single_welcome_message(false)`, each added member must find its own
// welcome message (looked up by key package reference), be able to join with
// it, and that welcome must carry exactly one group secret.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_commit_builder_multiple_welcome_messages() {
    let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |builder| {
        builder.mls_rules(
            DefaultMlsRules::new()
                .with_commit_options(CommitOptions::new().with_single_welcome_message(false)),
        )
    })
    .await;

    let (alice, alice_kp) =
        test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
    let (bob, bob_kp) =
        test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;

    group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
    group.propose_add(bob_kp.clone(), vec![]).await.unwrap();

    let welcomes = group.commit(Vec::new()).await.unwrap().welcome_messages;
    let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    for (joiner, key_package) in [(alice, alice_kp), (bob, bob_kp)] {
        let kp_ref = key_package
            .key_package_reference(&provider)
            .await
            .unwrap()
            .unwrap();

        let welcome = welcomes
            .iter()
            .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
            .unwrap();

        joiner.join_group(None, welcome).await.unwrap();

        let secrets = welcome.clone().into_welcome().unwrap().secrets;
        assert_eq!(secrets.len(), 1);
    }
}
1340
1341 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1342 async fn commit_can_change_credential() {
1343 let cs = TEST_CIPHER_SUITE;
1344 let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1345 let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1346
1347 let commit_output = groups[0]
1348 .commit_builder()
1349 .set_new_signing_identity(secret_key, identity.clone())
1350 .build()
1351 .await
1352 .unwrap();
1353
1354 groups[0].process_pending_commit().await.unwrap();
1356 let new_member = groups[0].roster().member_with_index(0).unwrap();
1357
1358 assert_eq!(
1359 new_member.signing_identity.credential,
1360 get_test_basic_credential(b"member".to_vec())
1361 );
1362
1363 assert_eq!(
1364 new_member.signing_identity.signature_key,
1365 identity.signature_key
1366 );
1367
1368 groups[1]
1370 .process_message(commit_output.commit_message)
1371 .await
1372 .unwrap();
1373
1374 let new_member = groups[1].roster().member_with_index(0).unwrap();
1375
1376 assert_eq!(
1377 new_member.signing_identity.credential,
1378 get_test_basic_credential(b"member".to_vec())
1379 );
1380
1381 assert_eq!(
1382 new_member.signing_identity.signature_key,
1383 identity.signature_key
1384 );
1385 }
1386
1387 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1388 async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1389 let mut group = test_group_custom(
1390 TEST_PROTOCOL_VERSION,
1391 TEST_CIPHER_SUITE,
1392 Default::default(),
1393 None,
1394 Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1395 )
1396 .await;
1397
1398 let commit = group.commit(vec![]).await.unwrap();
1399
1400 group.apply_pending_commit().await.unwrap();
1401
1402 let new_tree = group.export_tree();
1403
1404 assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1405 }
1406
1407 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1408 async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1409 let mut group = test_group_custom(
1410 TEST_PROTOCOL_VERSION,
1411 TEST_CIPHER_SUITE,
1412 Default::default(),
1413 None,
1414 Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1415 )
1416 .await;
1417
1418 let commit = group.commit(vec![]).await.unwrap();
1419
1420 assert!(commit.ratchet_tree.is_none());
1421 }
1422
1423 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1424 async fn commit_includes_external_commit_group_info_if_requested() {
1425 let mut group = test_group_custom(
1426 TEST_PROTOCOL_VERSION,
1427 TEST_CIPHER_SUITE,
1428 Default::default(),
1429 None,
1430 Some(
1431 CommitOptions::new()
1432 .with_allow_external_commit(true)
1433 .with_ratchet_tree_extension(false),
1434 ),
1435 )
1436 .await;
1437
1438 let commit = group.commit(vec![]).await.unwrap();
1439
1440 let info = commit
1441 .external_commit_group_info
1442 .unwrap()
1443 .into_group_info()
1444 .unwrap();
1445
1446 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1447 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1448 }
1449
1450 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1451 async fn commit_includes_external_commit_and_tree_if_requested() {
1452 let mut group = test_group_custom(
1453 TEST_PROTOCOL_VERSION,
1454 TEST_CIPHER_SUITE,
1455 Default::default(),
1456 None,
1457 Some(
1458 CommitOptions::new()
1459 .with_allow_external_commit(true)
1460 .with_ratchet_tree_extension(true),
1461 ),
1462 )
1463 .await;
1464
1465 let commit = group.commit(vec![]).await.unwrap();
1466
1467 let info = commit
1468 .external_commit_group_info
1469 .unwrap()
1470 .into_group_info()
1471 .unwrap();
1472
1473 assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1474 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1475 }
1476
1477 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1478 async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1479 let mut group = test_group_custom(
1480 TEST_PROTOCOL_VERSION,
1481 TEST_CIPHER_SUITE,
1482 Default::default(),
1483 None,
1484 Some(CommitOptions::new().with_allow_external_commit(false)),
1485 )
1486 .await;
1487
1488 let commit = group.commit(vec![]).await.unwrap();
1489
1490 assert!(commit.external_commit_group_info.is_none());
1491 }
1492
1493 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1494 async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_off(
1495 ) {
1496 let mut group = test_group_custom(
1497 TEST_PROTOCOL_VERSION,
1498 TEST_CIPHER_SUITE,
1499 Default::default(),
1500 None,
1501 Some(
1502 CommitOptions::new()
1503 .with_always_out_of_band_ratchet_tree(true)
1504 .with_ratchet_tree_extension(false)
1505 .with_allow_external_commit(true),
1506 ),
1507 )
1508 .await;
1509
1510 let commit = group.commit(vec![]).await.unwrap();
1511
1512 assert!(commit.ratchet_tree.is_some());
1513
1514 let info = commit
1515 .external_commit_group_info
1516 .unwrap()
1517 .into_group_info()
1518 .unwrap();
1519
1520 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1521 }
1522
1523 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1524 async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_on(
1525 ) {
1526 let mut group = test_group_custom(
1527 TEST_PROTOCOL_VERSION,
1528 TEST_CIPHER_SUITE,
1529 Default::default(),
1530 None,
1531 Some(
1532 CommitOptions::new()
1533 .with_always_out_of_band_ratchet_tree(true)
1534 .with_ratchet_tree_extension(true)
1535 .with_allow_external_commit(true),
1536 ),
1537 )
1538 .await;
1539
1540 let commit = group.commit(vec![]).await.unwrap();
1541
1542 assert!(commit.ratchet_tree.is_some());
1543
1544 let info = commit
1545 .external_commit_group_info
1546 .unwrap()
1547 .into_group_info()
1548 .unwrap();
1549
1550 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1551 }
1552
1553 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1554 async fn member_identity_is_validated_against_new_extensions() {
1555 let alice = client_with_test_extension(b"alice").await;
1556 let mut alice = alice
1557 .create_group(ExtensionList::new(), Default::default())
1558 .await
1559 .unwrap();
1560
1561 let bob = client_with_test_extension(b"bob").await;
1562 let bob_kp = bob
1563 .generate_key_package_message(Default::default(), Default::default())
1564 .await
1565 .unwrap();
1566
1567 let mut extension_list = ExtensionList::new();
1568 let extension = TestExtension { foo: b'a' };
1569 extension_list.set_from(extension).unwrap();
1570
1571 let res = alice
1572 .commit_builder()
1573 .add_member(bob_kp)
1574 .unwrap()
1575 .set_group_context_ext(extension_list.clone())
1576 .unwrap()
1577 .build()
1578 .await;
1579
1580 assert!(res.is_err());
1581
1582 let alex = client_with_test_extension(b"alex").await;
1583
1584 alice
1585 .commit_builder()
1586 .add_member(
1587 alex.generate_key_package_message(Default::default(), Default::default())
1588 .await
1589 .unwrap(),
1590 )
1591 .unwrap()
1592 .set_group_context_ext(extension_list.clone())
1593 .unwrap()
1594 .build()
1595 .await
1596 .unwrap();
1597 }
1598
// Installing `TestExtension { foo: b'a' }` together with an external sender
// must validate that sender against the new extensions: the test provider
// rejects external senders whose identity starts with `foo`, so "alex" is
// rejected and "bob" is accepted.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn server_identity_is_validated_against_new_extensions() {
    let alice_client = client_with_test_extension(b"alice").await;
    let mut alice = alice_client
        .create_group(ExtensionList::new(), Default::default())
        .await
        .unwrap();

    let mut base_extensions = ExtensionList::new();
    base_extensions.set_from(TestExtension { foo: b'a' }).unwrap();

    // Base extensions plus an `ExternalSendersExt` allowing only `server`.
    let with_external_sender = |server: SigningIdentity| {
        let mut extensions = base_extensions.clone();
        extensions
            .set_from(ExternalSendersExt {
                allowed_senders: vec![server],
            })
            .unwrap();
        extensions
    };

    let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
    let alex_extensions = with_external_sender(alex_server);

    let rejected = alice
        .commit_builder()
        .set_group_context_ext(alex_extensions)
        .unwrap()
        .build()
        .await;

    assert!(rejected.is_err());

    let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
    let bob_extensions = with_external_sender(bob_server);

    alice
        .commit_builder()
        .set_group_context_ext(bob_extensions)
        .unwrap()
        .build()
        .await
        .unwrap();
}
1649
1650 #[derive(Debug, Clone)]
1651 struct IdentityProviderWithExtension(BasicIdentityProvider);
1652
// Error returned by `IdentityProviderWithExtension` when a validation check
// fails or the wrapped `BasicIdentityProvider` returns an error.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
#[cfg_attr(feature = "std", error("test error"))]
struct IdentityProviderWithExtensionError {}
1657
1658 impl IntoAnyError for IdentityProviderWithExtensionError {
1659 #[cfg(feature = "std")]
1660 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1661 Ok(self.into())
1662 }
1663 }
1664
1665 impl IdentityProviderWithExtension {
1666 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1669 async fn starts_with_foo(
1670 &self,
1671 identity: &SigningIdentity,
1672 _timestamp: Option<MlsTime>,
1673 extensions: Option<&ExtensionList>,
1674 ) -> bool {
1675 if let Some(extensions) = extensions {
1676 if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1677 self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1678 } else {
1679 true
1680 }
1681 } else {
1682 true
1683 }
1684 }
1685 }
1686
1687 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1688 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1689 impl IdentityProvider for IdentityProviderWithExtension {
1690 type Error = IdentityProviderWithExtensionError;
1691
1692 async fn validate_member(
1693 &self,
1694 identity: &SigningIdentity,
1695 timestamp: Option<MlsTime>,
1696 context: MemberValidationContext<'_>,
1697 ) -> Result<(), Self::Error> {
1698 self.starts_with_foo(identity, timestamp, context.new_extensions())
1699 .await
1700 .then_some(())
1701 .ok_or(IdentityProviderWithExtensionError {})
1702 }
1703
1704 async fn validate_external_sender(
1705 &self,
1706 identity: &SigningIdentity,
1707 timestamp: Option<MlsTime>,
1708 extensions: Option<&ExtensionList>,
1709 ) -> Result<(), Self::Error> {
1710 (!self.starts_with_foo(identity, timestamp, extensions).await)
1711 .then_some(())
1712 .ok_or(IdentityProviderWithExtensionError {})
1713 }
1714
1715 async fn identity(
1716 &self,
1717 signing_identity: &SigningIdentity,
1718 extensions: &ExtensionList,
1719 ) -> Result<Vec<u8>, Self::Error> {
1720 self.0
1721 .identity(signing_identity, extensions)
1722 .await
1723 .map_err(|_| IdentityProviderWithExtensionError {})
1724 }
1725
1726 async fn valid_successor(
1727 &self,
1728 _predecessor: &SigningIdentity,
1729 _successor: &SigningIdentity,
1730 _extensions: &ExtensionList,
1731 ) -> Result<bool, Self::Error> {
1732 Ok(true)
1733 }
1734
1735 fn supported_types(&self) -> Vec<CredentialType> {
1736 self.0.supported_types()
1737 }
1738 }
1739
1740 type ExtensionClientConfig = WithIdentityProvider<
1741 IdentityProviderWithExtension,
1742 WithCryptoProvider<TestCryptoProvider, BaseConfig>,
1743 >;
1744
1745 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1746 async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1747 let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1748
1749 ClientBuilder::new()
1750 .crypto_provider(TestCryptoProvider::new())
1751 .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1752 .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1753 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1754 .build()
1755 }
1756
1757 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1758 async fn detached_commit() {
1759 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1760
1761 let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1762 assert!(group.pending_commit.is_none());
1763 group.apply_detached_commit(secrets).await.unwrap();
1764 assert_eq!(group.context().epoch, 1);
1765 }
1766}