1use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13 cipher_suite::CipherSuite,
14 client::MlsError,
15 client_config::ClientConfig,
16 extension::RatchetTreeExt,
17 identity::SigningIdentity,
18 protocol_version::ProtocolVersion,
19 signer::Signable,
20 tree_kem::{
21 kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
22 },
23 ExtensionList, MlsRules,
24};
25
26#[cfg(all(not(mls_build_async), feature = "rayon"))]
27use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
28
29use crate::tree_kem::leaf_node::LeafNode;
30
31#[cfg(not(feature = "private_message"))]
32use crate::WireFormat;
33
34#[cfg(feature = "psk")]
35use crate::{
36 group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
37 psk::ExternalPskId,
38};
39
40use super::{
41 confirmation_tag::ConfirmationTag,
42 framing::{Content, MlsMessage, MlsMessagePayload, Sender},
43 key_schedule::{KeySchedule, WelcomeSecret},
44 message_hash::MessageHash,
45 message_processor::{path_update_required, MessageProcessor},
46 message_signature::AuthenticatedContent,
47 mls_rules::CommitDirection,
48 proposal::{Proposal, ProposalOrRef},
49 CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
50 Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
51 PendingCommitSnapshot, Welcome,
52};
53
54#[cfg(not(feature = "by_ref_proposal"))]
55use super::proposal_cache::prepare_commit;
56
57#[cfg(feature = "custom_proposal")]
58use super::proposal::CustomProposal;
59
/// Content of an MLS `Commit` message.
///
/// Built by `Group::commit_internal` from the proposals applied to the
/// provisional state and, when a path update is performed, the generated
/// `UpdatePath`.
///
/// NOTE: the derived `MlsEncode`/`MlsDecode` impls presumably serialize the
/// fields in declaration order, so do not reorder them.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Commit {
    /// Proposals covered by this commit, each either included by value
    /// (`ProposalOrRef::Proposal`) or by reference.
    pub proposals: Vec<ProposalOrRef>,
    /// The committer's update path; `None` when no path update is performed.
    pub path: Option<UpdatePath>,
}
67
/// Everything produced while creating a commit that is needed to move the
/// group into the new epoch once the commit is accepted.
///
/// `CommitBuilder::build` stores it on the group (converted into a
/// `PendingCommitSnapshot`), while `CommitBuilder::build_detached` returns it
/// MLS-encoded inside `CommitSecrets`.
///
/// NOTE: field order presumably defines the derived MLS encoding — do not reorder.
#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
pub(crate) struct PendingCommit {
    /// Group state of the new epoch (context, public tree, interim transcript
    /// hash, pending re-init and confirmation tag).
    pub(crate) state: GroupState,
    /// Epoch secrets derived by the key schedule for the new epoch.
    pub(crate) epoch_secrets: EpochSecrets,
    /// The committer's private tree state after applying the commit.
    pub(crate) private_tree: TreeKemPrivate,
    /// Key schedule of the new epoch.
    pub(crate) key_schedule: KeySchedule,
    /// Signer for the new epoch: the new signer if one was requested,
    /// otherwise a copy of the current signer.
    pub(crate) signer: SignatureSecretKey,

    /// Description of the commit (sender kind, authenticated data, committer
    /// index and effect); presumably reported when the commit is applied.
    pub(crate) output: CommitMessageDescription,

    /// Hash of the commit message that was produced; presumably used to
    /// recognize this member's own commit when it is received back.
    pub(crate) commit_message_hash: MessageHash,
}
80
/// Secrets of a commit created with `CommitBuilder::build_detached` /
/// `Group::commit_detached`.
///
/// Wraps a `PendingCommitSnapshot` holding the MLS-encoded pending commit.
/// It can be round-tripped through `to_bytes` / `from_bytes`.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone)]
pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
87
88impl CommitSecrets {
89 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
91 Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
92 }
93
94 pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
96 Ok(self.0.mls_encode_to_vec()?)
97 }
98}
99
/// Result of creating a commit.
///
/// Returned by `CommitBuilder::build` / `CommitBuilder::build_detached`, and
/// therefore also by `Group::commit` / `Group::commit_detached`.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CommitOutput {
    /// The commit message, already formatted for the wire, to be sent to the group.
    pub commit_message: MlsMessage,
    /// Welcome messages for members added by this commit.
    ///
    /// When the commit options request a single welcome message and at least
    /// one member is added, this holds one message covering all new members.
    /// Otherwise it holds one message per added member, and is empty when
    /// nobody is added.
    pub welcome_messages: Vec<MlsMessage>,
    /// The new epoch's ratchet tree.
    ///
    /// Only present when the commit options do NOT embed the tree as a ratchet
    /// tree extension. In that case new members presumably need it supplied
    /// separately (e.g. via the tree argument of `join_group`).
    pub ratchet_tree: Option<ExportedTree<'static>>,
    /// A GroupInfo message usable for external commits.
    ///
    /// Only present when the commit options allow external commits. It carries
    /// the external public key extension and, if enabled, the ratchet tree
    /// extension.
    pub external_commit_group_info: Option<MlsMessage>,
    /// Proposals that were not applied by this commit, as reported by proposal
    /// resolution (`unused_proposals` of the provisional state).
    #[cfg(feature = "by_ref_proposal")]
    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
    /// Whether the commit includes an update path.
    pub contains_update_path: bool,
}
133
134#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
135impl CommitOutput {
136 #[cfg(feature = "ffi")]
138 pub fn commit_message(&self) -> &MlsMessage {
139 &self.commit_message
140 }
141
142 #[cfg(feature = "ffi")]
144 pub fn welcome_messages(&self) -> &[MlsMessage] {
145 &self.welcome_messages
146 }
147
148 #[cfg(feature = "ffi")]
152 pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
153 self.ratchet_tree.as_ref()
154 }
155
156 #[cfg(feature = "ffi")]
160 pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
161 self.external_commit_group_info.as_ref()
162 }
163
164 #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
166 pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
167 &self.unused_proposals
168 }
169}
170
/// Builder for a commit on a [`Group`], obtained from `Group::commit_builder`.
///
/// Proposal methods append proposals to the commit. Setter methods configure
/// the commit. `build` / `build_detached` finally create it via
/// `Group::commit_internal`.
pub struct CommitBuilder<'a, C>
where
    C: ClientConfig + Clone,
{
    /// Group the commit is created for; `build` stores the pending commit on it.
    group: &'a mut Group<C>,
    /// Proposals to include, in the order they were added to the builder.
    pub(super) proposals: Vec<Proposal>,
    /// Authenticated data attached to the commit's content.
    authenticated_data: Vec<u8>,
    /// Extensions for the GroupInfo sent inside Welcome messages.
    group_info_extensions: ExtensionList,
    /// Replacement signing key for the committer, if any.
    new_signer: Option<SignatureSecretKey>,
    /// Replacement signing identity for the committer, if any.
    new_signing_identity: Option<SigningIdentity>,
    /// Replacement leaf node extensions for the committer, if any.
    new_leaf_node_extensions: Option<ExtensionList>,
}
190
191impl<'a, C> CommitBuilder<'a, C>
192where
193 C: ClientConfig + Clone,
194{
195 pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
198 let proposal = self.group.add_proposal(key_package)?;
199 self.proposals.push(proposal);
200 Ok(self)
201 }
202
203 pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
214 Self {
215 group_info_extensions: extensions,
216 ..self
217 }
218 }
219
220 pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
223 let proposal = self.group.remove_proposal(index)?;
224 self.proposals.push(proposal);
225 Ok(self)
226 }
227
228 pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
232 let proposal = self.group.group_context_extensions_proposal(extensions);
233 self.proposals.push(proposal);
234 Ok(self)
235 }
236
237 #[cfg(feature = "psk")]
241 pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
242 let key_id = JustPreSharedKeyID::External(psk_id);
243 let proposal = self.group.psk_proposal(key_id)?;
244 self.proposals.push(proposal);
245 Ok(self)
246 }
247
248 #[cfg(feature = "psk")]
252 pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
253 let psk_id = ResumptionPsk {
254 psk_epoch,
255 usage: ResumptionPSKUsage::Application,
256 psk_group_id: PskGroupId(self.group.group_id().to_vec()),
257 };
258
259 let key_id = JustPreSharedKeyID::Resumption(psk_id);
260 let proposal = self.group.psk_proposal(key_id)?;
261 self.proposals.push(proposal);
262 Ok(self)
263 }
264
265 pub fn reinit(
268 mut self,
269 group_id: Option<Vec<u8>>,
270 version: ProtocolVersion,
271 cipher_suite: CipherSuite,
272 extensions: ExtensionList,
273 ) -> Result<Self, MlsError> {
274 let proposal = self
275 .group
276 .reinit_proposal(group_id, version, cipher_suite, extensions)?;
277
278 self.proposals.push(proposal);
279 Ok(self)
280 }
281
282 #[cfg(feature = "custom_proposal")]
285 pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
286 self.proposals.push(Proposal::Custom(proposal));
287 self
288 }
289
290 pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
294 self.proposals.push(proposal);
295 self
296 }
297
298 pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
302 self.proposals.append(&mut proposals);
303 self
304 }
305
306 pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
312 Self {
313 authenticated_data,
314 ..self
315 }
316 }
317
318 pub fn set_new_signing_identity(
326 self,
327 signer: SignatureSecretKey,
328 signing_identity: SigningIdentity,
329 ) -> Self {
330 Self {
331 new_signer: Some(signer),
332 new_signing_identity: Some(signing_identity),
333 ..self
334 }
335 }
336
337 pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
339 Self {
340 new_leaf_node_extensions: Some(new_leaf_node_extensions),
341 ..self
342 }
343 }
344
345 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
354 pub async fn build(self) -> Result<CommitOutput, MlsError> {
355 let (output, pending_commit) = self
356 .group
357 .commit_internal(
358 self.proposals,
359 None,
360 self.authenticated_data,
361 self.group_info_extensions,
362 self.new_signer,
363 self.new_signing_identity,
364 self.new_leaf_node_extensions,
365 )
366 .await?;
367
368 self.group.pending_commit = pending_commit.try_into()?;
369
370 Ok(output)
371 }
372
373 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
378 pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
379 let (output, pending_commit) = self
380 .group
381 .commit_internal(
382 self.proposals,
383 None,
384 self.authenticated_data,
385 self.group_info_extensions,
386 self.new_signer,
387 self.new_signing_identity,
388 self.new_leaf_node_extensions,
389 )
390 .await?;
391
392 Ok((
393 output,
394 CommitSecrets(PendingCommitSnapshot::PendingCommit(
395 pending_commit.mls_encode_to_vec()?,
396 )),
397 ))
398 }
399}
400
impl<C> Group<C>
where
    C: ClientConfig + Clone,
{
    /// Create a commit using default builder settings and the given
    /// authenticated data.
    ///
    /// Shorthand for `self.commit_builder().authenticated_data(..).build()`.
    /// On success the pending commit is stored on the group.
    ///
    /// # Errors
    ///
    /// Returns `MlsError::ExistingPendingCommit` if a commit is already pending,
    /// `MlsError::GroupUsedAfterReInit` if a re-init is pending, or any error
    /// raised while creating the commit.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
        self.commit_builder()
            .authenticated_data(authenticated_data)
            .build()
            .await
    }

    /// Like `commit`, but the pending commit is not stored on the group.
    ///
    /// Its secrets are returned as `CommitSecrets` instead (see
    /// `CommitBuilder::build_detached`).
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn commit_detached(
        &mut self,
        authenticated_data: Vec<u8>,
    ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
        self.commit_builder()
            .authenticated_data(authenticated_data)
            .build_detached()
            .await
    }

    /// Start building a commit for this group.
    ///
    /// Every builder field starts at its `Default::default()` value.
    pub fn commit_builder(&mut self) -> CommitBuilder<C> {
        CommitBuilder {
            group: self,
            proposals: Default::default(),
            authenticated_data: Default::default(),
            group_info_extensions: Default::default(),
            new_signer: Default::default(),
            new_signing_identity: Default::default(),
            new_leaf_node_extensions: Default::default(),
        }
    }

    /// Shared implementation of commit creation.
    ///
    /// `CommitBuilder` calls this with `external_leaf == None`. A `Some` value
    /// makes the commit external (sender `NewMemberCommit`, own index taken
    /// from the ExternalInit proposal).
    ///
    /// Returns the caller-facing `CommitOutput` and the `PendingCommit` holding
    /// the next epoch's state. Note that for external commits
    /// `self.private_tree.self_index` is updated even if a later step fails.
    // Every commit input is passed explicitly, hence the argument count.
    #[allow(clippy::too_many_arguments)]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn commit_internal(
        &mut self,
        proposals: Vec<Proposal>,
        external_leaf: Option<&LeafNode>,
        authenticated_data: Vec<u8>,
        mut welcome_group_info_extensions: ExtensionList,
        new_signer: Option<SignatureSecretKey>,
        new_signing_identity: Option<SigningIdentity>,
        new_leaf_node_extensions: Option<ExtensionList>,
    ) -> Result<(CommitOutput, PendingCommit), MlsError> {
        // Only one commit may be outstanding at a time.
        if !self.pending_commit.is_none() {
            return Err(MlsError::ExistingPendingCommit);
        }

        // A group with a pending re-init must not be used for new commits.
        if self.state.pending_reinit.is_some() {
            return Err(MlsError::GroupUsedAfterReInit);
        }

        let mls_rules = self.config.mls_rules();

        let is_external = external_leaf.is_some();

        // External commits are sent as a new member; otherwise the committer
        // is identified by its current leaf index.
        let sender = if is_external {
            Sender::NewMemberCommit
        } else {
            Sender::Member(*self.private_tree.self_index)
        };

        // `new_signer` (defaulting to the current signer) signs the GroupInfos
        // below and is stored in the pending commit. `old_signer` signs the
        // commit content itself.
        let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
        let old_signer = &self.signer;

        // The current time is only available with `std`; it is passed to
        // proposal application (presumably for validity checks).
        #[cfg(feature = "std")]
        let time = Some(crate::time::MlsTime::now());

        #[cfg(not(feature = "std"))]
        let time = None;

        // With by-reference proposals, presumably merges the supplied proposals
        // with those cached in the group state.
        #[cfg(feature = "by_ref_proposal")]
        let proposals = self.state.proposals.prepare_commit(sender, proposals);

        #[cfg(not(feature = "by_ref_proposal"))]
        let proposals = prepare_commit(sender, proposals);

        // Validate and apply the proposals, yielding the provisional next-epoch state.
        let mut provisional_state = self
            .state
            .apply_resolved(
                sender,
                proposals,
                external_leaf,
                &self.config.identity_provider(),
                &self.cipher_suite_provider,
                &self.config.secret_store(),
                &mls_rules,
                time,
                CommitDirection::Send,
            )
            .await?;

        let (mut provisional_private_tree, _) =
            self.provisional_private_tree(&provisional_state)?;

        // For an external commit our leaf is the one created by the ExternalInit
        // proposal. The group's own index is updated as well, presumably so that
        // `current_member_index` (used by `make_group_info`) reports the new leaf.
        if is_external {
            provisional_private_tree.self_index = provisional_state
                .external_init_index
                .ok_or(MlsError::ExternalCommitMissingExternalInit)?;

            self.private_tree.self_index = provisional_private_tree.self_index;
        }

        // The MLS rules decide path requirements and what the output contains.
        let commit_options = mls_rules
            .commit_options(
                &provisional_state.public_tree.roster(),
                &provisional_state.group_context,
                &provisional_state.applied_proposals,
            )
            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;

        // A path update happens if the rules demand it or the applied proposals need one.
        let perform_path_update = commit_options.path_required
            || path_update_required(&provisional_state.applied_proposals);

        let (update_path, path_secrets, commit_secret) = if perform_path_update {
            // Leaf extensions, by precedence: explicitly requested ones, then the
            // external leaf's, then the current leaf's.
            let new_leaf_node_extensions =
                new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));

            let new_leaf_node_extensions = match new_leaf_node_extensions {
                Some(extensions) => extensions,
                None => self.current_user_leaf_node()?.ungreased_extensions(),
            };

            // NOTE(review): `new_signing_identity` and `new_leaf_node_extensions`
            // are only consumed on this branch. Without a path update they are
            // dropped, while `new_signer` still signs the GroupInfos below —
            // confirm whether requesting them should force a path update.
            let encap_gen = TreeKem::new(
                &mut provisional_state.public_tree,
                &mut provisional_private_tree,
            )
            .encap(
                &mut provisional_state.group_context,
                &provisional_state.indexes_of_added_kpkgs,
                &new_signer,
                Some(self.config.leaf_properties(new_leaf_node_extensions)),
                new_signing_identity,
                &self.cipher_suite_provider,
                #[cfg(test)]
                &self.commit_modifiers,
            )
            .await?;

            (
                Some(encap_gen.update_path),
                Some(encap_gen.path_secrets),
                encap_gen.commit_secret,
            )
        } else {
            // No path: recompute hashes starting from our own leaf and record the
            // resulting tree hash in the new group context.
            provisional_state
                .public_tree
                .update_hashes(
                    &[provisional_private_tree.self_index],
                    &self.cipher_suite_provider,
                )
                .await?;

            provisional_state.group_context.tree_hash = provisional_state
                .public_tree
                .tree_hash(&self.cipher_suite_provider)
                .await?;

            // Without a path the commit secret is `PathSecret::empty`.
            (None, None, PathSecret::empty(&self.cipher_suite_provider))
        };

        #[cfg(feature = "psk")]
        let (psk_secret, psks) = self
            .get_psk(&provisional_state.applied_proposals.psks)
            .await?;

        #[cfg(not(feature = "psk"))]
        let psk_secret = self.get_psk();

        // Key packages of added members. They are zipped with
        // `indexes_of_added_kpkgs` below, so both are assumed to share one order.
        let added_key_pkgs: Vec<_> = provisional_state
            .applied_proposals
            .additions
            .iter()
            .map(|info| info.proposal.key_package.clone())
            .collect();

        let commit = Commit {
            proposals: provisional_state.applied_proposals.proposals_or_refs(),
            path: update_path,
        };

        // Sign the commit with the current signer under the current group context.
        let mut auth_content = AuthenticatedContent::new_signed(
            &self.cipher_suite_provider,
            self.context(),
            sender,
            Content::Commit(Box::new(commit)),
            old_signer,
            #[cfg(feature = "private_message")]
            self.encryption_options()?.control_wire_format(sender),
            #[cfg(not(feature = "private_message"))]
            WireFormat::PublicMessage,
            authenticated_data,
        )
        .await?;

        // Confirmed transcript hash from the current interim hash and the signed
        // content. The confirmation tag is attached only further below.
        let confirmed_transcript_hash = super::transcript_hash::create(
            self.cipher_suite_provider(),
            &self.state.interim_transcript_hash,
            &auth_content,
        )
        .await?;

        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;

        // Advance the key schedule into the new epoch.
        let key_schedule_result = KeySchedule::from_key_schedule(
            &self.key_schedule,
            &commit_secret,
            &provisional_state.group_context,
            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            provisional_state.public_tree.total_leaf_count(),
            &psk_secret,
            &self.cipher_suite_provider,
        )
        .await?;

        let confirmation_tag = ConfirmationTag::create(
            &key_schedule_result.confirmation_key,
            &provisional_state.group_context.confirmed_transcript_hash,
            &self.cipher_suite_provider,
        )
        .await?;

        // Interim transcript hash carried into the new epoch's state.
        let interim_transcript_hash = InterimTranscriptHash::create(
            self.cipher_suite_provider(),
            &provisional_state.group_context.confirmed_transcript_hash,
            &confirmation_tag,
        )
        .await?;

        auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());

        // The ratchet tree is embedded as an extension only if the rules ask for it.
        let ratchet_tree_ext = commit_options
            .ratchet_tree_extension
            .then(|| RatchetTreeExt {
                tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
            });

        // GroupInfo for external joiners: external public key plus the optional
        // ratchet tree. The caller's Welcome GroupInfo extensions are not used here.
        let external_commit_group_info = match commit_options.allow_external_commit {
            true => {
                let mut extensions = ExtensionList::new();

                extensions.set_from({
                    key_schedule_result
                        .key_schedule
                        .get_external_key_pair_ext(&self.cipher_suite_provider)
                        .await?
                })?;

                if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
                    extensions.set_from(ratchet_tree_ext.clone())?;
                }

                let info = self
                    .make_group_info(
                        &provisional_state.group_context,
                        extensions,
                        &confirmation_tag,
                        &new_signer,
                    )
                    .await?;

                let msg =
                    MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));

                Some(msg)
            }
            false => None,
        };

        // The Welcome GroupInfo carries the caller's extensions plus the ratchet
        // tree extension when enabled.
        if let Some(ratchet_tree_ext) = ratchet_tree_ext {
            welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
        }

        let welcome_group_info = self
            .make_group_info(
                &provisional_state.group_context,
                welcome_group_info_extensions,
                &confirmation_tag,
                &new_signer,
            )
            .await?;

        // Encrypt the GroupInfo with the welcome secret derived from the joiner
        // secret and the PSK secret.
        let welcome_secret = WelcomeSecret::from_joiner_secret(
            &self.cipher_suite_provider,
            &key_schedule_result.joiner_secret,
            &psk_secret,
        )
        .await?;

        let encrypted_group_info = welcome_secret
            .encrypt(&welcome_group_info.mls_encode_to_vec()?)
            .await?;

        let path_secrets = path_secrets.as_ref();

        // Encrypt group secrets for every added member. The condition below is
        // equivalent to `all(not(mls_build_async), feature = "rayon")`, which
        // runs in parallel; all other builds run the sequential loop that follows.
        #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
        let encrypted_path_secrets: Vec<_> = added_key_pkgs
            .into_par_iter()
            .zip(&provisional_state.indexes_of_added_kpkgs)
            .map(|(key_package, leaf_index)| {
                self.encrypt_group_secrets(
                    &key_package,
                    *leaf_index,
                    &key_schedule_result.joiner_secret,
                    path_secrets,
                    #[cfg(feature = "psk")]
                    psks.clone(),
                    &encrypted_group_info,
                )
            })
            .try_collect()?;

        #[cfg(any(mls_build_async, not(feature = "rayon")))]
        let encrypted_path_secrets = {
            let mut secrets = Vec::new();

            for (key_package, leaf_index) in added_key_pkgs
                .into_iter()
                .zip(&provisional_state.indexes_of_added_kpkgs)
            {
                secrets.push(
                    self.encrypt_group_secrets(
                        &key_package,
                        *leaf_index,
                        &key_schedule_result.joiner_secret,
                        path_secrets,
                        #[cfg(feature = "psk")]
                        psks.clone(),
                        &encrypted_group_info,
                    )
                    .await?,
                );
            }

            secrets
        };

        // One Welcome holding every member's secrets, or one Welcome per member.
        let welcome_messages =
            if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
                vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
            } else {
                encrypted_path_secrets
                    .into_iter()
                    .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
                    .collect()
            };

        let commit_message = self.format_for_wire(auth_content.clone()).await?;

        // When the tree is not embedded as an extension, return it to the caller instead.
        let ratchet_tree = (!commit_options.ratchet_tree_extension)
            .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));

        // Only the first applied re-init proposal is taken into account.
        let pending_reinit = provisional_state
            .applied_proposals
            .reinitializations
            .first();

        let pending_commit = PendingCommit {
            output: CommitMessageDescription {
                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
                authenticated_data: auth_content.content.authenticated_data,
                committer: *provisional_private_tree.self_index,
                effect: match pending_reinit {
                    Some(r) => CommitEffect::ReInit(r.clone()),
                    None => CommitEffect::NewEpoch(
                        NewEpoch::new(self.state.clone(), &provisional_state).into(),
                    ),
                },
            },

            // Next epoch's state, starting with a newly created proposal cache.
            state: GroupState {
                #[cfg(feature = "by_ref_proposal")]
                proposals: crate::group::ProposalCache::new(
                    self.protocol_version(),
                    self.group_id().to_vec(),
                ),
                context: provisional_state.group_context,
                public_tree: provisional_state.public_tree,
                interim_transcript_hash,
                pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
                confirmation_tag,
            },

            commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
                .await?,
            signer: new_signer,
            epoch_secrets: key_schedule_result.epoch_secrets,
            key_schedule: key_schedule_result.key_schedule,

            private_tree: provisional_private_tree,
        };

        let output = CommitOutput {
            commit_message,
            welcome_messages,
            ratchet_tree,
            external_commit_group_info,
            contains_update_path: perform_path_update,
            #[cfg(feature = "by_ref_proposal")]
            unused_proposals: provisional_state.unused_proposals,
        };

        Ok((output, pending_commit))
    }

    /// Build and sign a GroupInfo for `group_context`.
    ///
    /// The signer field is this member's current leaf index. The GroupInfo is
    /// greased first and then signed with `signer`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn make_group_info(
        &self,
        group_context: &GroupContext,
        extensions: ExtensionList,
        confirmation_tag: &ConfirmationTag,
        signer: &SignatureSecretKey,
    ) -> Result<GroupInfo, MlsError> {
        let mut group_info = GroupInfo {
            group_context: group_context.clone(),
            extensions,
            confirmation_tag: confirmation_tag.clone(),
            signer: LeafIndex(self.current_member_index()),
            // Placeholder; presumably filled in by `sign` below.
            signature: vec![],
        };

        // Grease before signing, so the signature presumably covers the grease values.
        group_info.grease(self.cipher_suite_provider())?;

        group_info
            .sign(&self.cipher_suite_provider, signer, &())
            .await?;

        Ok(group_info)
    }

    /// Wrap encrypted group secrets and the encrypted GroupInfo in a Welcome message.
    ///
    /// The message uses the protocol version and cipher suite of the group's
    /// current context.
    fn make_welcome_message(
        &self,
        secrets: Vec<EncryptedGroupSecrets>,
        encrypted_group_info: Vec<u8>,
    ) -> MlsMessage {
        MlsMessage::new(
            self.context().protocol_version,
            MlsMessagePayload::Welcome(Welcome {
                cipher_suite: self.context().cipher_suite,
                secrets,
                encrypted_group_info,
            }),
        )
    }
}
924
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Test-only hooks passed to the tree update while a commit is generated,
    /// letting tests tamper with the result.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Hook applied to the committer's leaf; presumably a returned key
        /// replaces the signer.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Hook applied to the public tree.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Hook mapping the update path nodes to the ones actually used.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    impl Default for CommitModifiers {
        /// No-op hooks: nothing is modified and the path is passed through unchanged.
        fn default() -> Self {
            fn untouched_leaf(
                _leaf: &mut LeafNode,
                _signer: &SignatureSecretKey,
            ) -> Option<SignatureSecretKey> {
                None
            }

            fn untouched_tree(_tree: &mut TreeKemPublic) {}

            fn untouched_path(nodes: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
                nodes
            }

            Self {
                modify_leaf: untouched_leaf,
                modify_tree: untouched_tree,
                modify_path: untouched_path,
            }
        }
    }
}
951
952#[cfg(test)]
953mod tests {
954 use mls_rs_core::{
955 error::IntoAnyError,
956 extension::ExtensionType,
957 identity::{CredentialType, IdentityProvider, MemberValidationContext},
958 time::MlsTime,
959 };
960
961 use crate::extension::RequiredCapabilitiesExt;
962 use crate::{
963 client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
964 client_builder::{
965 test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
966 WithIdentityProvider,
967 },
968 client_config::ClientConfig,
969 crypto::test_utils::TestCryptoProvider,
970 extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
971 group::test_utils::{test_group, test_group_custom},
972 group::{
973 proposal::ProposalType,
974 test_utils::{test_group_custom_config, test_n_member_group},
975 },
976 identity::test_utils::get_test_signing_identity,
977 identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
978 key_package::test_utils::test_key_package_message,
979 mls_rules::CommitOptions,
980 Client,
981 };
982
983 #[cfg(feature = "by_ref_proposal")]
984 use crate::crypto::test_utils::test_cipher_suite_provider;
985 #[cfg(feature = "by_ref_proposal")]
986 use crate::extension::ExternalSendersExt;
987 #[cfg(feature = "by_ref_proposal")]
988 use crate::group::mls_rules::DefaultMlsRules;
989
990 #[cfg(feature = "psk")]
991 use crate::{
992 group::proposal::PreSharedKeyProposal,
993 psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
994 };
995
996 use super::*;
997
998 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
999 async fn test_commit_builder_group() -> Group<TestClientConfig> {
1000 test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1001 b.custom_proposal_type(ProposalType::from(42))
1002 .extension_type(TEST_EXTENSION_TYPE.into())
1003 })
1004 .await
1005 .group
1006 }
1007
1008 fn assert_commit_builder_output<C: ClientConfig>(
1009 group: Group<C>,
1010 mut commit_output: CommitOutput,
1011 expected: Vec<Proposal>,
1012 welcome_count: usize,
1013 ) {
1014 let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1015
1016 let commit_data = match plaintext.content.content {
1017 Content::Commit(commit) => commit,
1018 #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1019 _ => panic!("Found non-commit data"),
1020 };
1021
1022 assert_eq!(commit_data.proposals.len(), expected.len());
1023
1024 commit_data.proposals.into_iter().for_each(|proposal| {
1025 let proposal = match proposal {
1026 ProposalOrRef::Proposal(p) => p,
1027 #[cfg(feature = "by_ref_proposal")]
1028 ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1029 };
1030
1031 #[cfg(feature = "psk")]
1032 if let Some(psk_id) = match proposal.as_ref() {
1033 Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1034 _ => None,
1035 } {
1036 let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1037
1038 assert!(found)
1039 } else {
1040 assert!(expected.contains(&proposal));
1041 }
1042
1043 #[cfg(not(feature = "psk"))]
1044 assert!(expected.contains(&proposal));
1045 });
1046
1047 if welcome_count > 0 {
1048 let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1049
1050 assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1051
1052 let welcome_msg = welcome_msg.into_welcome().unwrap();
1053
1054 assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1055 assert_eq!(welcome_msg.secrets.len(), welcome_count);
1056 } else {
1057 assert!(commit_output.welcome_messages.is_empty());
1058 }
1059 }
1060
1061 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1062 async fn test_commit_builder_add() {
1063 let mut group = test_commit_builder_group().await;
1064
1065 let test_key_package =
1066 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1067
1068 let commit_output = group
1069 .commit_builder()
1070 .add_member(test_key_package.clone())
1071 .unwrap()
1072 .build()
1073 .await
1074 .unwrap();
1075
1076 let expected_add = group.add_proposal(test_key_package).unwrap();
1077
1078 assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1079 }
1080
1081 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1082 async fn test_commit_builder_add_with_ext() {
1083 let mut group = test_commit_builder_group().await;
1084
1085 let (bob_client, bob_key_package) =
1086 test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1087
1088 let ext = TestExtension { foo: 42 };
1089 let mut extension_list = ExtensionList::default();
1090 extension_list.set_from(ext.clone()).unwrap();
1091
1092 let welcome_message = group
1093 .commit_builder()
1094 .add_member(bob_key_package)
1095 .unwrap()
1096 .set_group_info_ext(extension_list)
1097 .build()
1098 .await
1099 .unwrap()
1100 .welcome_messages
1101 .remove(0);
1102
1103 let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
1104
1105 assert_eq!(
1106 context
1107 .group_info_extensions
1108 .get_as::<TestExtension>()
1109 .unwrap()
1110 .unwrap(),
1111 ext
1112 );
1113 }
1114
1115 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1116 async fn test_commit_builder_remove() {
1117 let mut group = test_commit_builder_group().await;
1118 let test_key_package =
1119 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1120
1121 group
1122 .commit_builder()
1123 .add_member(test_key_package)
1124 .unwrap()
1125 .build()
1126 .await
1127 .unwrap();
1128
1129 group.apply_pending_commit().await.unwrap();
1130
1131 let commit_output = group
1132 .commit_builder()
1133 .remove_member(1)
1134 .unwrap()
1135 .build()
1136 .await
1137 .unwrap();
1138
1139 let expected_remove = group.remove_proposal(1).unwrap();
1140
1141 assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1142 }
1143
1144 #[cfg(feature = "psk")]
1145 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1146 async fn test_commit_builder_psk() {
1147 let mut group = test_commit_builder_group().await;
1148 let test_psk = ExternalPskId::new(vec![1]);
1149
1150 group
1151 .config
1152 .secret_store()
1153 .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1154
1155 let commit_output = group
1156 .commit_builder()
1157 .add_external_psk(test_psk.clone())
1158 .unwrap()
1159 .build()
1160 .await
1161 .unwrap();
1162
1163 let key_id = JustPreSharedKeyID::External(test_psk);
1164 let expected_psk = group.psk_proposal(key_id).unwrap();
1165
1166 assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1167 }
1168
1169 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1170 async fn test_commit_builder_group_context_ext() {
1171 let mut group = test_commit_builder_group().await;
1172 let mut test_ext = ExtensionList::default();
1173 test_ext
1174 .set_from(RequiredCapabilitiesExt::default())
1175 .unwrap();
1176
1177 let commit_output = group
1178 .commit_builder()
1179 .set_group_context_ext(test_ext.clone())
1180 .unwrap()
1181 .build()
1182 .await
1183 .unwrap();
1184
1185 let expected_ext = group.group_context_extensions_proposal(test_ext);
1186
1187 assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1188 }
1189
1190 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1191 async fn test_commit_builder_reinit() {
1192 let mut group = test_commit_builder_group().await;
1193 let test_group_id = "foo".as_bytes().to_vec();
1194 let test_cipher_suite = TEST_CIPHER_SUITE;
1195 let test_protocol_version = TEST_PROTOCOL_VERSION;
1196 let mut test_ext = ExtensionList::default();
1197
1198 test_ext
1199 .set_from(RequiredCapabilitiesExt::default())
1200 .unwrap();
1201
1202 let commit_output = group
1203 .commit_builder()
1204 .reinit(
1205 Some(test_group_id.clone()),
1206 test_protocol_version,
1207 test_cipher_suite,
1208 test_ext.clone(),
1209 )
1210 .unwrap()
1211 .build()
1212 .await
1213 .unwrap();
1214
1215 let expected_reinit = group
1216 .reinit_proposal(
1217 Some(test_group_id),
1218 test_protocol_version,
1219 test_cipher_suite,
1220 test_ext,
1221 )
1222 .unwrap();
1223
1224 assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1225 }
1226
1227 #[cfg(feature = "custom_proposal")]
1228 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1229 async fn test_commit_builder_custom_proposal() {
1230 let mut group = test_commit_builder_group().await;
1231
1232 let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1233
1234 let commit_output = group
1235 .commit_builder()
1236 .custom_proposal(proposal.clone())
1237 .build()
1238 .await
1239 .unwrap();
1240
1241 assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1242 }
1243
1244 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1245 async fn test_commit_builder_chaining() {
1246 let mut group = test_commit_builder_group().await;
1247 let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1248 let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1249
1250 let expected_adds = vec![
1251 group.add_proposal(kp1.clone()).unwrap(),
1252 group.add_proposal(kp2.clone()).unwrap(),
1253 ];
1254
1255 let commit_output = group
1256 .commit_builder()
1257 .add_member(kp1)
1258 .unwrap()
1259 .add_member(kp2)
1260 .unwrap()
1261 .build()
1262 .await
1263 .unwrap();
1264
1265 assert_commit_builder_output(group, commit_output, expected_adds, 2);
1266 }
1267
1268 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1269 async fn test_commit_builder_empty_commit() {
1270 let mut group = test_commit_builder_group().await;
1271
1272 let commit_output = group.commit_builder().build().await.unwrap();
1273
1274 assert_commit_builder_output(group, commit_output, vec![], 0);
1275 }
1276
1277 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1278 async fn test_commit_builder_authenticated_data() {
1279 let mut group = test_commit_builder_group().await;
1280 let test_data = "test".as_bytes().to_vec();
1281
1282 let commit_output = group
1283 .commit_builder()
1284 .authenticated_data(test_data.clone())
1285 .build()
1286 .await
1287 .unwrap();
1288
1289 assert_eq!(
1290 commit_output
1291 .commit_message
1292 .into_plaintext()
1293 .unwrap()
1294 .content
1295 .authenticated_data,
1296 test_data
1297 );
1298 }
1299
// With `with_single_welcome_message(false)`, committing two pending add
// proposals produces welcome messages such that each joiner finds one that
// references its key package, can join with it, and that welcome holds
// exactly one encrypted group secret.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_commit_builder_multiple_welcome_messages() {
    let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |builder| {
        let commit_options = CommitOptions::new().with_single_welcome_message(false);
        builder.mls_rules(DefaultMlsRules::new().with_commit_options(commit_options))
    })
    .await;

    let (alice, alice_kp) =
        test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
    let (bob, bob_kp) =
        test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;

    for kp in [&alice_kp, &bob_kp] {
        group.propose_add(kp.clone(), vec![]).await.unwrap();
    }

    let welcomes = group.commit(Vec::new()).await.unwrap().welcome_messages;
    let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
        let kp_ref = kp
            .key_package_reference(&cipher_suite_provider)
            .await
            .unwrap()
            .unwrap();

        let welcome = welcomes
            .iter()
            .find(|candidate| candidate.welcome_key_package_references().contains(&&kp_ref))
            .unwrap();

        client.join_group(None, welcome).await.unwrap();

        let secret_count = welcome.clone().into_welcome().unwrap().secrets.len();
        assert_eq!(secret_count, 1);
    }
}
1337
1338 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1339 async fn commit_can_change_credential() {
1340 let cs = TEST_CIPHER_SUITE;
1341 let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1342 let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1343
1344 let commit_output = groups[0]
1345 .commit_builder()
1346 .set_new_signing_identity(secret_key, identity.clone())
1347 .build()
1348 .await
1349 .unwrap();
1350
1351 groups[0].process_pending_commit().await.unwrap();
1353 let new_member = groups[0].roster().member_with_index(0).unwrap();
1354
1355 assert_eq!(
1356 new_member.signing_identity.credential,
1357 get_test_basic_credential(b"member".to_vec())
1358 );
1359
1360 assert_eq!(
1361 new_member.signing_identity.signature_key,
1362 identity.signature_key
1363 );
1364
1365 groups[1]
1367 .process_message(commit_output.commit_message)
1368 .await
1369 .unwrap();
1370
1371 let new_member = groups[1].roster().member_with_index(0).unwrap();
1372
1373 assert_eq!(
1374 new_member.signing_identity.credential,
1375 get_test_basic_credential(b"member".to_vec())
1376 );
1377
1378 assert_eq!(
1379 new_member.signing_identity.signature_key,
1380 identity.signature_key
1381 );
1382 }
1383
1384 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1385 async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1386 let mut group = test_group_custom(
1387 TEST_PROTOCOL_VERSION,
1388 TEST_CIPHER_SUITE,
1389 Default::default(),
1390 None,
1391 Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1392 )
1393 .await;
1394
1395 let commit = group.commit(vec![]).await.unwrap();
1396
1397 group.apply_pending_commit().await.unwrap();
1398
1399 let new_tree = group.export_tree();
1400
1401 assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1402 }
1403
1404 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1405 async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1406 let mut group = test_group_custom(
1407 TEST_PROTOCOL_VERSION,
1408 TEST_CIPHER_SUITE,
1409 Default::default(),
1410 None,
1411 Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1412 )
1413 .await;
1414
1415 let commit = group.commit(vec![]).await.unwrap();
1416
1417 assert!(commit.ratchet_tree.is_none());
1418 }
1419
1420 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1421 async fn commit_includes_external_commit_group_info_if_requested() {
1422 let mut group = test_group_custom(
1423 TEST_PROTOCOL_VERSION,
1424 TEST_CIPHER_SUITE,
1425 Default::default(),
1426 None,
1427 Some(
1428 CommitOptions::new()
1429 .with_allow_external_commit(true)
1430 .with_ratchet_tree_extension(false),
1431 ),
1432 )
1433 .await;
1434
1435 let commit = group.commit(vec![]).await.unwrap();
1436
1437 let info = commit
1438 .external_commit_group_info
1439 .unwrap()
1440 .into_group_info()
1441 .unwrap();
1442
1443 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1444 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1445 }
1446
1447 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1448 async fn commit_includes_external_commit_and_tree_if_requested() {
1449 let mut group = test_group_custom(
1450 TEST_PROTOCOL_VERSION,
1451 TEST_CIPHER_SUITE,
1452 Default::default(),
1453 None,
1454 Some(
1455 CommitOptions::new()
1456 .with_allow_external_commit(true)
1457 .with_ratchet_tree_extension(true),
1458 ),
1459 )
1460 .await;
1461
1462 let commit = group.commit(vec![]).await.unwrap();
1463
1464 let info = commit
1465 .external_commit_group_info
1466 .unwrap()
1467 .into_group_info()
1468 .unwrap();
1469
1470 assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1471 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1472 }
1473
1474 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1475 async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1476 let mut group = test_group_custom(
1477 TEST_PROTOCOL_VERSION,
1478 TEST_CIPHER_SUITE,
1479 Default::default(),
1480 None,
1481 Some(CommitOptions::new().with_allow_external_commit(false)),
1482 )
1483 .await;
1484
1485 let commit = group.commit(vec![]).await.unwrap();
1486
1487 assert!(commit.external_commit_group_info.is_none());
1488 }
1489
1490 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1491 async fn member_identity_is_validated_against_new_extensions() {
1492 let alice = client_with_test_extension(b"alice").await;
1493 let mut alice = alice
1494 .create_group(ExtensionList::new(), Default::default())
1495 .await
1496 .unwrap();
1497
1498 let bob = client_with_test_extension(b"bob").await;
1499 let bob_kp = bob
1500 .generate_key_package_message(Default::default(), Default::default())
1501 .await
1502 .unwrap();
1503
1504 let mut extension_list = ExtensionList::new();
1505 let extension = TestExtension { foo: b'a' };
1506 extension_list.set_from(extension).unwrap();
1507
1508 let res = alice
1509 .commit_builder()
1510 .add_member(bob_kp)
1511 .unwrap()
1512 .set_group_context_ext(extension_list.clone())
1513 .unwrap()
1514 .build()
1515 .await;
1516
1517 assert!(res.is_err());
1518
1519 let alex = client_with_test_extension(b"alex").await;
1520
1521 alice
1522 .commit_builder()
1523 .add_member(
1524 alex.generate_key_package_message(Default::default(), Default::default())
1525 .await
1526 .unwrap(),
1527 )
1528 .unwrap()
1529 .set_group_context_ext(extension_list.clone())
1530 .unwrap()
1531 .build()
1532 .await
1533 .unwrap();
1534 }
1535
// Alice commits group-context extensions that combine `TestExtension
// { foo: b'a' }` with an `ExternalSendersExt`. An external sender named
// "alex" is expected to be rejected and one named "bob" to be accepted.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn server_identity_is_validated_against_new_extensions() {
    let alice_client = client_with_test_extension(b"alice").await;
    let mut alice = alice_client
        .create_group(ExtensionList::new(), Default::default())
        .await
        .unwrap();

    let mut base_extensions = ExtensionList::new();
    base_extensions.set_from(TestExtension { foo: b'a' }).unwrap();

    // The base extensions plus an `ExternalSendersExt` allowing only `server`.
    let with_external_sender = |server| {
        let mut extensions = base_extensions.clone();
        extensions
            .set_from(ExternalSendersExt {
                allowed_senders: vec![server],
            })
            .unwrap();
        extensions
    };

    let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;

    let rejected = alice
        .commit_builder()
        .set_group_context_ext(with_external_sender(alex_server))
        .unwrap()
        .build()
        .await;

    assert!(rejected.is_err());

    let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;

    alice
        .commit_builder()
        .set_group_context_ext(with_external_sender(bob_server))
        .unwrap()
        .build()
        .await
        .unwrap();
}
1586
1587 #[derive(Debug, Clone)]
1588 struct IdentityProviderWithExtension(BasicIdentityProvider);
1589
/// Field-less error returned by `IdentityProviderWithExtension` for every
/// failed validation or identity lookup. With the `std` feature it implements
/// `std::error::Error` and displays as "test error".
#[derive(Debug, Clone)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
#[cfg_attr(feature = "std", error("test error"))]
struct IdentityProviderWithExtensionError {}
1594
1595 impl IntoAnyError for IdentityProviderWithExtensionError {
1596 #[cfg(feature = "std")]
1597 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1598 Ok(self.into())
1599 }
1600 }
1601
1602 impl IdentityProviderWithExtension {
1603 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1606 async fn starts_with_foo(
1607 &self,
1608 identity: &SigningIdentity,
1609 _timestamp: Option<MlsTime>,
1610 extensions: Option<&ExtensionList>,
1611 ) -> bool {
1612 if let Some(extensions) = extensions {
1613 if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1614 self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1615 } else {
1616 true
1617 }
1618 } else {
1619 true
1620 }
1621 }
1622 }
1623
1624 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1625 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1626 impl IdentityProvider for IdentityProviderWithExtension {
1627 type Error = IdentityProviderWithExtensionError;
1628
1629 async fn validate_member(
1630 &self,
1631 identity: &SigningIdentity,
1632 timestamp: Option<MlsTime>,
1633 context: MemberValidationContext<'_>,
1634 ) -> Result<(), Self::Error> {
1635 self.starts_with_foo(identity, timestamp, context.new_extensions())
1636 .await
1637 .then_some(())
1638 .ok_or(IdentityProviderWithExtensionError {})
1639 }
1640
1641 async fn validate_external_sender(
1642 &self,
1643 identity: &SigningIdentity,
1644 timestamp: Option<MlsTime>,
1645 extensions: Option<&ExtensionList>,
1646 ) -> Result<(), Self::Error> {
1647 (!self.starts_with_foo(identity, timestamp, extensions).await)
1648 .then_some(())
1649 .ok_or(IdentityProviderWithExtensionError {})
1650 }
1651
1652 async fn identity(
1653 &self,
1654 signing_identity: &SigningIdentity,
1655 extensions: &ExtensionList,
1656 ) -> Result<Vec<u8>, Self::Error> {
1657 self.0
1658 .identity(signing_identity, extensions)
1659 .await
1660 .map_err(|_| IdentityProviderWithExtensionError {})
1661 }
1662
1663 async fn valid_successor(
1664 &self,
1665 _predecessor: &SigningIdentity,
1666 _successor: &SigningIdentity,
1667 _extensions: &ExtensionList,
1668 ) -> Result<bool, Self::Error> {
1669 Ok(true)
1670 }
1671
1672 fn supported_types(&self) -> Vec<CredentialType> {
1673 self.0.supported_types()
1674 }
1675 }
1676
1677 type ExtensionClientConfig = WithIdentityProvider<
1678 IdentityProviderWithExtension,
1679 WithCryptoProvider<TestCryptoProvider, BaseConfig>,
1680 >;
1681
1682 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1683 async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1684 let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1685
1686 ClientBuilder::new()
1687 .crypto_provider(TestCryptoProvider::new())
1688 .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1689 .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1690 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1691 .build()
1692 }
1693
1694 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1695 async fn detached_commit() {
1696 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1697
1698 let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1699 assert!(group.pending_commit.is_none());
1700 group.apply_detached_commit(secrets).await.unwrap();
1701 assert_eq!(group.context().epoch, 1);
1702 }
1703}