1use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13 cipher_suite::CipherSuite,
14 client::MlsError,
15 client_config::ClientConfig,
16 extension::RatchetTreeExt,
17 identity::SigningIdentity,
18 protocol_version::ProtocolVersion,
19 signer::Signable,
20 time::MlsTime,
21 tree_kem::{kem::TreeKem, path_secret::PathSecret, TreeKemPrivate, UpdatePath},
22 ExtensionList, MlsRules,
23};
24
25#[cfg(all(not(mls_build_async), feature = "rayon"))]
26use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
27
28use crate::tree_kem::leaf_node::LeafNode;
29
30#[cfg(not(feature = "private_message"))]
31use crate::WireFormat;
32
33#[cfg(feature = "psk")]
34use crate::{
35 group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
36 psk::ExternalPskId,
37};
38
39use super::{
40 confirmation_tag::ConfirmationTag,
41 framing::{Content, MlsMessage, MlsMessagePayload, Sender},
42 key_schedule::{KeySchedule, WelcomeSecret},
43 message_hash::MessageHash,
44 message_processor::{path_update_required, MessageProcessor},
45 message_signature::AuthenticatedContent,
46 mls_rules::CommitDirection,
47 proposal::{Proposal, ProposalOrRef},
48 CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
49 Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
50 PendingCommitSnapshot, Welcome,
51};
52
53#[cfg(not(feature = "by_ref_proposal"))]
54use super::proposal_cache::prepare_commit;
55
56#[cfg(feature = "custom_proposal")]
57use super::proposal::CustomProposal;
58
/// Content of a commit message: the proposals being committed together with
/// an optional update path.
///
/// NOTE(review): the codec derives presumably serialize fields in declaration
/// order, so the field order should not be changed — confirm before reordering.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Commit {
    /// Proposals covered by this commit, either inline or by reference.
    pub proposals: Vec<ProposalOrRef>,
    /// Update path of the committer; `None` when the commit carries no path
    /// update.
    pub path: Option<UpdatePath>,
}
66
/// Everything the committer must keep in order to later apply a commit it has
/// created but not yet applied.
///
/// NOTE(review): the codec derives presumably serialize fields in declaration
/// order, so the field order should not be changed — confirm before reordering.
#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
pub(crate) struct PendingCommit {
    /// Group state as it will be once the commit is applied.
    pub(crate) state: GroupState,
    /// Epoch secrets produced by the key schedule for the new epoch.
    pub(crate) epoch_secrets: EpochSecrets,
    /// Private tree of the committer after the commit.
    pub(crate) private_tree: TreeKemPrivate,
    /// Key schedule of the new epoch.
    pub(crate) key_schedule: KeySchedule,
    /// Signature key to use after the commit (the new signer if one was
    /// supplied, otherwise a copy of the current one).
    pub(crate) signer: SignatureSecretKey,

    /// Description of the commit's effect.
    pub(crate) output: CommitMessageDescription,

    /// Hash of the wire-format commit message this state belongs to.
    pub(crate) commit_message_hash: MessageHash,
}
79
80#[derive(Clone)]
81pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
82
83impl CommitSecrets {
84 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
86 Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
87 }
88
89 pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
91 Ok(self.0.mls_encode_to_vec()?)
92 }
93}
94
95#[derive(Clone, Debug)]
96#[non_exhaustive]
97pub struct CommitOutput {
101 pub commit_message: MlsMessage,
103 pub welcome_messages: Vec<MlsMessage>,
110 pub ratchet_tree: Option<ExportedTree<'static>>,
114 pub external_commit_group_info: Option<MlsMessage>,
118 #[cfg(feature = "by_ref_proposal")]
120 pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
121 pub contains_update_path: bool,
123}
124
125impl CommitOutput {
126 pub fn commit_message(&self) -> &MlsMessage {
128 &self.commit_message
129 }
130
131 pub fn welcome_messages(&self) -> &[MlsMessage] {
133 &self.welcome_messages
134 }
135
136 pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
140 self.ratchet_tree.as_ref()
141 }
142
143 pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
147 self.external_commit_group_info.as_ref()
148 }
149
150 #[cfg(feature = "by_ref_proposal")]
152 pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
153 &self.unused_proposals
154 }
155}
156
157pub struct CommitBuilder<'a, C>
165where
166 C: ClientConfig + Clone,
167{
168 group: &'a mut Group<C>,
169 pub(super) proposals: Vec<Proposal>,
170 authenticated_data: Vec<u8>,
171 group_info_extensions: ExtensionList,
172 new_signer: Option<SignatureSecretKey>,
173 new_signing_identity: Option<SigningIdentity>,
174 new_leaf_node_extensions: Option<ExtensionList>,
175 commit_time: Option<MlsTime>,
176}
177
178impl<'a, C> CommitBuilder<'a, C>
179where
180 C: ClientConfig + Clone,
181{
182 pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
185 let proposal = self.group.add_proposal(key_package)?;
186 self.proposals.push(proposal);
187 Ok(self)
188 }
189
190 pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
201 Self {
202 group_info_extensions: extensions,
203 ..self
204 }
205 }
206
207 pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
210 let proposal = self.group.remove_proposal(index)?;
211 self.proposals.push(proposal);
212 Ok(self)
213 }
214
215 pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
219 let proposal = self.group.group_context_extensions_proposal(extensions);
220 self.proposals.push(proposal);
221 Ok(self)
222 }
223
224 #[cfg(feature = "psk")]
228 pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
229 let key_id = JustPreSharedKeyID::External(psk_id);
230 let proposal = self.group.psk_proposal(key_id)?;
231 self.proposals.push(proposal);
232 Ok(self)
233 }
234
235 #[cfg(feature = "psk")]
239 pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
240 let psk_id = ResumptionPsk {
241 psk_epoch,
242 usage: ResumptionPSKUsage::Application,
243 psk_group_id: PskGroupId(self.group.group_id().to_vec()),
244 };
245
246 let key_id = JustPreSharedKeyID::Resumption(psk_id);
247 let proposal = self.group.psk_proposal(key_id)?;
248 self.proposals.push(proposal);
249 Ok(self)
250 }
251
252 pub fn reinit(
255 mut self,
256 group_id: Option<Vec<u8>>,
257 version: ProtocolVersion,
258 cipher_suite: CipherSuite,
259 extensions: ExtensionList,
260 ) -> Result<Self, MlsError> {
261 let proposal = self
262 .group
263 .reinit_proposal(group_id, version, cipher_suite, extensions)?;
264
265 self.proposals.push(proposal);
266 Ok(self)
267 }
268
269 #[cfg(feature = "custom_proposal")]
272 pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
273 self.proposals.push(Proposal::Custom(proposal));
274 self
275 }
276
277 pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
281 self.proposals.push(proposal);
282 self
283 }
284
285 pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
289 self.proposals.append(&mut proposals);
290 self
291 }
292
293 pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
299 Self {
300 authenticated_data,
301 ..self
302 }
303 }
304
305 pub fn set_new_signing_identity(
313 self,
314 signer: SignatureSecretKey,
315 signing_identity: SigningIdentity,
316 ) -> Self {
317 Self {
318 new_signer: Some(signer),
319 new_signing_identity: Some(signing_identity),
320 ..self
321 }
322 }
323
324 pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
326 Self {
327 new_leaf_node_extensions: Some(new_leaf_node_extensions),
328 ..self
329 }
330 }
331
332 pub fn commit_time(self, commit_time: MlsTime) -> Self {
334 Self {
335 commit_time: Some(commit_time),
336 ..self
337 }
338 }
339
340 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
349 pub async fn build(self) -> Result<CommitOutput, MlsError> {
350 let (output, pending_commit) = self
351 .group
352 .commit_internal(
353 self.proposals,
354 None,
355 self.authenticated_data,
356 self.group_info_extensions,
357 self.new_signer,
358 self.new_signing_identity,
359 self.new_leaf_node_extensions,
360 self.commit_time,
361 )
362 .await?;
363
364 self.group.pending_commit = pending_commit.try_into()?;
365
366 Ok(output)
367 }
368
369 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
374 pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
375 let (output, pending_commit) = self
376 .group
377 .commit_internal(
378 self.proposals,
379 None,
380 self.authenticated_data,
381 self.group_info_extensions,
382 self.new_signer,
383 self.new_signing_identity,
384 self.new_leaf_node_extensions,
385 self.commit_time,
386 )
387 .await?;
388
389 Ok((
390 output,
391 CommitSecrets(PendingCommitSnapshot::PendingCommit(
392 pending_commit.mls_encode_to_vec()?,
393 )),
394 ))
395 }
396}
397
398impl<C> Group<C>
399where
400 C: ClientConfig + Clone,
401{
402 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
443 pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
444 self.commit_builder()
445 .authenticated_data(authenticated_data)
446 .build()
447 .await
448 }
449
450 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
455 pub async fn commit_detached(
456 &mut self,
457 authenticated_data: Vec<u8>,
458 ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
459 self.commit_builder()
460 .authenticated_data(authenticated_data)
461 .build_detached()
462 .await
463 }
464
465 pub fn commit_builder(&mut self) -> CommitBuilder<'_, C> {
468 CommitBuilder {
469 group: self,
470 proposals: Default::default(),
471 authenticated_data: Default::default(),
472 group_info_extensions: Default::default(),
473 new_signer: Default::default(),
474 new_signing_identity: Default::default(),
475 new_leaf_node_extensions: Default::default(),
476 commit_time: None,
477 }
478 }
479
480 #[allow(clippy::too_many_arguments)]
483 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
484 pub(super) async fn commit_internal(
485 &mut self,
486 proposals: Vec<Proposal>,
487 external_leaf: Option<&LeafNode>,
488 authenticated_data: Vec<u8>,
489 mut welcome_group_info_extensions: ExtensionList,
490 new_signer: Option<SignatureSecretKey>,
491 new_signing_identity: Option<SigningIdentity>,
492 new_leaf_node_extensions: Option<ExtensionList>,
493 commit_time: Option<MlsTime>,
494 ) -> Result<(CommitOutput, PendingCommit), MlsError> {
495 if !self.pending_commit.is_none() {
496 return Err(MlsError::ExistingPendingCommit);
497 }
498
499 if self.state.pending_reinit.is_some() {
500 return Err(MlsError::GroupUsedAfterReInit);
501 }
502
503 let mls_rules = self.config.mls_rules();
504
505 let is_external = external_leaf.is_some();
506
507 let sender = if is_external {
511 Sender::NewMemberCommit
512 } else {
513 Sender::Member(*self.private_tree.self_index)
514 };
515
516 let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
517 let old_signer = &self.signer;
518
519 #[cfg(feature = "std")]
520 let time = Some(crate::time::MlsTime::now());
521
522 #[cfg(not(feature = "std"))]
523 let time = None;
524
525 let time = if commit_time.is_some() {
526 commit_time
527 } else {
528 time
529 };
530
531 #[cfg(feature = "by_ref_proposal")]
532 let proposals = self.state.proposals.prepare_commit(sender, proposals);
533
534 #[cfg(not(feature = "by_ref_proposal"))]
535 let proposals = prepare_commit(sender, proposals);
536
537 let mut provisional_state = self
538 .state
539 .apply_resolved(
540 sender,
541 proposals,
542 external_leaf,
543 &self.config.identity_provider(),
544 &self.cipher_suite_provider,
545 &self.config.secret_store(),
546 &mls_rules,
547 time,
548 CommitDirection::Send,
549 )
550 .await?;
551
552 let (mut provisional_private_tree, _) =
553 self.provisional_private_tree(&provisional_state)?;
554
555 if is_external {
556 provisional_private_tree.self_index = provisional_state
557 .external_init_index
558 .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
559
560 self.private_tree.self_index = provisional_private_tree.self_index;
561 }
562
563 let commit_options = mls_rules
567 .commit_options(
568 &provisional_state.public_tree.roster(),
569 &provisional_state.group_context,
570 &provisional_state.applied_proposals,
571 )
572 .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
573
574 let perform_path_update = commit_options.path_required
575 || path_update_required(&provisional_state.applied_proposals);
576
577 let (update_path, path_secrets, commit_secret) = if perform_path_update {
578 let new_leaf_node_extensions =
586 new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
587
588 let new_leaf_node_extensions = match new_leaf_node_extensions {
589 Some(extensions) => extensions,
590 None => self.current_user_leaf_node()?.ungreased_extensions(),
592 };
593
594 #[cfg(feature = "tree_index")]
595 let old_committer_leaf = provisional_state
596 .public_tree
597 .get_leaf_node(provisional_private_tree.self_index)?
598 .clone();
599
600 let encap_gen = TreeKem::new(
601 &mut provisional_state.public_tree,
602 &mut provisional_private_tree,
603 )
604 .encap(
605 &mut provisional_state.group_context,
606 &provisional_state.indexes_of_added_kpkgs,
607 &new_signer,
608 Some(self.config.leaf_properties(new_leaf_node_extensions)),
609 new_signing_identity,
610 &self.cipher_suite_provider,
611 #[cfg(test)]
612 &self.commit_modifiers,
613 )
614 .await?;
615
616 provisional_state
617 .public_tree
618 .update_committer_leaf(
619 &self.config.identity_provider(),
620 &provisional_state.group_context.extensions,
621 provisional_private_tree.self_index,
622 #[cfg(feature = "tree_index")]
623 &old_committer_leaf,
624 #[cfg(test)]
625 !self.commit_modifiers.skip_committer_self_update_validation,
626 )
627 .await?;
628
629 (
630 Some(encap_gen.update_path),
631 Some(encap_gen.path_secrets),
632 encap_gen.commit_secret,
633 )
634 } else {
635 provisional_state
637 .public_tree
638 .update_hashes(
639 &[provisional_private_tree.self_index],
640 &self.cipher_suite_provider,
641 )
642 .await?;
643
644 provisional_state.group_context.tree_hash = provisional_state
645 .public_tree
646 .tree_hash(&self.cipher_suite_provider)
647 .await?;
648
649 (None, None, PathSecret::empty(&self.cipher_suite_provider))
650 };
651
652 #[cfg(feature = "psk")]
653 let (psk_secret, psks) = self
654 .get_psk(&provisional_state.applied_proposals.psks)
655 .await?;
656
657 #[cfg(not(feature = "psk"))]
658 let psk_secret = self.get_psk();
659
660 let added_key_pkgs: Vec<_> = provisional_state
661 .applied_proposals
662 .additions
663 .iter()
664 .map(|info| info.proposal.key_package.clone())
665 .collect();
666
667 let commit = Commit {
668 proposals: provisional_state.applied_proposals.proposals_or_refs(),
669 path: update_path,
670 };
671
672 let mut auth_content = AuthenticatedContent::new_signed(
673 &self.cipher_suite_provider,
674 self.context(),
675 sender,
676 Content::Commit(Box::new(commit)),
677 old_signer,
678 #[cfg(feature = "private_message")]
679 self.encryption_options()?.control_wire_format(sender),
680 #[cfg(not(feature = "private_message"))]
681 WireFormat::PublicMessage,
682 authenticated_data,
683 )
684 .await?;
685
686 let confirmed_transcript_hash = super::transcript_hash::create(
689 self.cipher_suite_provider(),
690 &self.state.interim_transcript_hash,
691 &auth_content,
692 )
693 .await?;
694
695 provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
696
697 let key_schedule_result = KeySchedule::from_key_schedule(
698 &self.key_schedule,
699 &commit_secret,
700 &provisional_state.group_context,
701 #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
702 provisional_state.public_tree.total_leaf_count(),
703 &psk_secret,
704 &self.cipher_suite_provider,
705 )
706 .await?;
707
708 let confirmation_tag = ConfirmationTag::create(
709 &key_schedule_result.confirmation_key,
710 &provisional_state.group_context.confirmed_transcript_hash,
711 &self.cipher_suite_provider,
712 )
713 .await?;
714
715 let interim_transcript_hash = InterimTranscriptHash::create(
716 self.cipher_suite_provider(),
717 &provisional_state.group_context.confirmed_transcript_hash,
718 &confirmation_tag,
719 )
720 .await?;
721
722 auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
723
724 let ratchet_tree_ext = commit_options
725 .ratchet_tree_extension
726 .then(|| RatchetTreeExt {
727 tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
728 });
729
730 let external_commit_group_info = match commit_options.allow_external_commit {
732 true => {
733 let mut extensions = ExtensionList::new();
734
735 extensions.set_from({
736 key_schedule_result
737 .key_schedule
738 .get_external_key_pair_ext(&self.cipher_suite_provider)
739 .await?
740 })?;
741
742 if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
743 if !commit_options.always_out_of_band_ratchet_tree {
744 extensions.set_from(ratchet_tree_ext.clone())?;
745 }
746 }
747
748 let info = self
749 .make_group_info(
750 &provisional_state.group_context,
751 extensions,
752 &confirmation_tag,
753 &new_signer,
754 )
755 .await?;
756
757 let msg =
758 MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
759
760 Some(msg)
761 }
762 false => None,
763 };
764
765 if let Some(ratchet_tree_ext) = ratchet_tree_ext {
768 welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
769 }
770
771 let welcome_group_info = self
772 .make_group_info(
773 &provisional_state.group_context,
774 welcome_group_info_extensions,
775 &confirmation_tag,
776 &new_signer,
777 )
778 .await?;
779
780 let welcome_secret = WelcomeSecret::from_joiner_secret(
783 &self.cipher_suite_provider,
784 &key_schedule_result.joiner_secret,
785 &psk_secret,
786 )
787 .await?;
788
789 let encrypted_group_info = welcome_secret
790 .encrypt(&welcome_group_info.mls_encode_to_vec()?)
791 .await?;
792
793 let path_secrets = path_secrets.as_ref();
795
796 #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
797 let encrypted_path_secrets: Vec<_> = added_key_pkgs
798 .into_par_iter()
799 .zip(&provisional_state.indexes_of_added_kpkgs)
800 .map(|(key_package, leaf_index)| {
801 self.encrypt_group_secrets(
802 &key_package,
803 *leaf_index,
804 &key_schedule_result.joiner_secret,
805 path_secrets,
806 #[cfg(feature = "psk")]
807 psks.clone(),
808 &encrypted_group_info,
809 )
810 })
811 .try_collect()?;
812
813 #[cfg(any(mls_build_async, not(feature = "rayon")))]
814 let encrypted_path_secrets = {
815 let mut secrets = Vec::new();
816
817 for (key_package, leaf_index) in added_key_pkgs
818 .into_iter()
819 .zip(&provisional_state.indexes_of_added_kpkgs)
820 {
821 secrets.push(
822 self.encrypt_group_secrets(
823 &key_package,
824 *leaf_index,
825 &key_schedule_result.joiner_secret,
826 path_secrets,
827 #[cfg(feature = "psk")]
828 psks.clone(),
829 &encrypted_group_info,
830 )
831 .await?,
832 );
833 }
834
835 secrets
836 };
837
838 let welcome_messages =
839 if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
840 vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
841 } else {
842 encrypted_path_secrets
843 .into_iter()
844 .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
845 .collect()
846 };
847
848 let commit_message = self.format_for_wire(auth_content.clone()).await?;
849
850 let ratchet_tree = (!commit_options.ratchet_tree_extension
852 || commit_options.always_out_of_band_ratchet_tree)
853 .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
854
855 let pending_reinit = provisional_state
856 .applied_proposals
857 .reinitializations
858 .first();
859
860 let pending_commit = PendingCommit {
861 output: CommitMessageDescription {
862 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
863 authenticated_data: auth_content.content.authenticated_data,
864 committer: *provisional_private_tree.self_index,
865 effect: match pending_reinit {
866 Some(r) => CommitEffect::ReInit(r.clone()),
867 None => CommitEffect::NewEpoch(
868 NewEpoch::new(self.state.clone(), &provisional_state).into(),
869 ),
870 },
871 },
872
873 state: GroupState {
874 #[cfg(feature = "by_ref_proposal")]
875 proposals: crate::group::ProposalCache::new(
876 self.protocol_version(),
877 self.group_id().to_vec(),
878 ),
879 context: provisional_state.group_context,
880 public_tree: provisional_state.public_tree,
881 interim_transcript_hash,
882 pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
883 confirmation_tag,
884 },
885
886 commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
887 .await?,
888 signer: new_signer,
889 epoch_secrets: key_schedule_result.epoch_secrets,
890 key_schedule: key_schedule_result.key_schedule,
891
892 private_tree: provisional_private_tree,
893 };
894
895 let output = CommitOutput {
896 commit_message,
897 welcome_messages,
898 ratchet_tree,
899 external_commit_group_info,
900 contains_update_path: perform_path_update,
901 #[cfg(feature = "by_ref_proposal")]
902 unused_proposals: provisional_state.unused_proposals,
903 };
904
905 Ok((output, pending_commit))
906 }
907
908 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
911 async fn make_group_info(
912 &self,
913 group_context: &GroupContext,
914 extensions: ExtensionList,
915 confirmation_tag: &ConfirmationTag,
916 signer: &SignatureSecretKey,
917 ) -> Result<GroupInfo, MlsError> {
918 let mut group_info = GroupInfo {
919 group_context: group_context.clone(),
920 extensions,
921 confirmation_tag: confirmation_tag.clone(), signer: self.current_member_leaf_index(),
923 signature: vec![],
924 };
925
926 group_info.grease(self.cipher_suite_provider())?;
927
928 group_info
930 .sign(&self.cipher_suite_provider, signer, &())
931 .await?;
932
933 Ok(group_info)
934 }
935
936 fn make_welcome_message(
937 &self,
938 secrets: Vec<EncryptedGroupSecrets>,
939 encrypted_group_info: Vec<u8>,
940 ) -> MlsMessage {
941 MlsMessage::new(
942 self.context().protocol_version,
943 MlsMessagePayload::Welcome(Welcome {
944 cipher_suite: self.context().cipher_suite,
945 secrets,
946 encrypted_group_info,
947 }),
948 )
949 }
950}
951
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Test-only hooks that let a test tamper with a commit while it is being
    /// created. The default value changes nothing.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Hook receiving the leaf node and a signature key.
        /// NOTE(review): the returned key presumably replaces the signer used
        /// for the leaf — confirm at the call site.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Hook receiving the public tree.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Hook mapping the update path nodes to the nodes to use instead.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
        /// When `true`, validation of the committer's own leaf update is
        /// skipped.
        pub skip_committer_self_update_validation: bool,
    }

    impl Default for CommitModifiers {
        fn default() -> Self {
            // No-op hooks: leave the leaf, the tree and the path untouched.
            fn keep_leaf(_: &mut LeafNode, _: &SignatureSecretKey) -> Option<SignatureSecretKey> {
                None
            }

            fn keep_tree(_: &mut TreeKemPublic) {}

            fn keep_path(path: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
                path
            }

            Self {
                modify_leaf: keep_leaf,
                modify_tree: keep_tree,
                modify_path: keep_path,
                skip_committer_self_update_validation: false,
            }
        }
    }
}
980
981#[cfg(test)]
982mod tests {
983 use mls_rs_core::{
984 error::IntoAnyError,
985 extension::ExtensionType,
986 identity::{CredentialType, IdentityProvider, MemberValidationContext},
987 time::MlsTime,
988 };
989
990 use crate::extension::RequiredCapabilitiesExt;
991 use crate::{
992 client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
993 client_builder::{
994 test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
995 WithIdentityProvider,
996 },
997 client_config::ClientConfig,
998 crypto::test_utils::TestCryptoProvider,
999 extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
1000 group::test_utils::{test_group, test_group_custom},
1001 group::{
1002 proposal::ProposalType,
1003 test_utils::{test_group_custom_config, test_n_member_group},
1004 },
1005 identity::test_utils::get_test_signing_identity,
1006 identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
1007 key_package::test_utils::test_key_package_message,
1008 mls_rules::CommitOptions,
1009 Client,
1010 };
1011
1012 #[cfg(feature = "by_ref_proposal")]
1013 use crate::crypto::test_utils::test_cipher_suite_provider;
1014 #[cfg(feature = "by_ref_proposal")]
1015 use crate::extension::ExternalSendersExt;
1016 #[cfg(feature = "by_ref_proposal")]
1017 use crate::group::mls_rules::DefaultMlsRules;
1018
1019 #[cfg(feature = "psk")]
1020 use crate::{
1021 group::proposal::PreSharedKeyProposal,
1022 psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
1023 };
1024
1025 use super::*;
1026
1027 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1028 async fn test_commit_builder_group() -> Group<TestClientConfig> {
1029 test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1030 b.custom_proposal_type(ProposalType::from(42))
1031 .extension_type(TEST_EXTENSION_TYPE.into())
1032 })
1033 .await
1034 .group
1035 }
1036
1037 fn assert_commit_builder_output<C: ClientConfig>(
1038 group: Group<C>,
1039 mut commit_output: CommitOutput,
1040 expected: Vec<Proposal>,
1041 welcome_count: usize,
1042 ) {
1043 let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1044
1045 let commit_data = match plaintext.content.content {
1046 Content::Commit(commit) => commit,
1047 #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1048 _ => panic!("Found non-commit data"),
1049 };
1050
1051 assert_eq!(commit_data.proposals.len(), expected.len());
1052
1053 commit_data.proposals.into_iter().for_each(|proposal| {
1054 let proposal = match proposal {
1055 ProposalOrRef::Proposal(p) => p,
1056 #[cfg(feature = "by_ref_proposal")]
1057 ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1058 };
1059
1060 #[cfg(feature = "psk")]
1061 if let Some(psk_id) = match proposal.as_ref() {
1062 Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1063 _ => None,
1064 } {
1065 let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1066
1067 assert!(found)
1068 } else {
1069 assert!(expected.contains(&proposal));
1070 }
1071
1072 #[cfg(not(feature = "psk"))]
1073 assert!(expected.contains(&proposal));
1074 });
1075
1076 if welcome_count > 0 {
1077 let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1078
1079 assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1080
1081 let welcome_msg = welcome_msg.into_welcome().unwrap();
1082
1083 assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1084 assert_eq!(welcome_msg.secrets.len(), welcome_count);
1085 } else {
1086 assert!(commit_output.welcome_messages.is_empty());
1087 }
1088 }
1089
1090 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1091 async fn test_commit_builder_add() {
1092 let mut group = test_commit_builder_group().await;
1093
1094 let test_key_package =
1095 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1096
1097 let commit_output = group
1098 .commit_builder()
1099 .add_member(test_key_package.clone())
1100 .unwrap()
1101 .build()
1102 .await
1103 .unwrap();
1104
1105 let expected_add = group.add_proposal(test_key_package).unwrap();
1106
1107 assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1108 }
1109
1110 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1111 async fn test_commit_builder_add_with_ext() {
1112 let mut group = test_commit_builder_group().await;
1113
1114 let (bob_client, bob_key_package) =
1115 test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1116
1117 let ext = TestExtension { foo: 42 };
1118 let mut extension_list = ExtensionList::default();
1119 extension_list.set_from(ext.clone()).unwrap();
1120
1121 let welcome_message = group
1122 .commit_builder()
1123 .add_member(bob_key_package)
1124 .unwrap()
1125 .set_group_info_ext(extension_list)
1126 .build()
1127 .await
1128 .unwrap()
1129 .welcome_messages
1130 .remove(0);
1131
1132 let (_, context) = bob_client
1133 .join_group(None, &welcome_message, None)
1134 .await
1135 .unwrap();
1136
1137 assert_eq!(
1138 context
1139 .group_info_extensions
1140 .get_as::<TestExtension>()
1141 .unwrap()
1142 .unwrap(),
1143 ext
1144 );
1145 }
1146
1147 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1148 async fn test_commit_builder_remove() {
1149 let mut group = test_commit_builder_group().await;
1150 let test_key_package =
1151 test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1152
1153 group
1154 .commit_builder()
1155 .add_member(test_key_package)
1156 .unwrap()
1157 .build()
1158 .await
1159 .unwrap();
1160
1161 group.apply_pending_commit().await.unwrap();
1162
1163 let commit_output = group
1164 .commit_builder()
1165 .remove_member(1)
1166 .unwrap()
1167 .build()
1168 .await
1169 .unwrap();
1170
1171 let expected_remove = group.remove_proposal(1).unwrap();
1172
1173 assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1174 }
1175
1176 #[cfg(feature = "psk")]
1177 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1178 async fn test_commit_builder_psk() {
1179 let mut group = test_commit_builder_group().await;
1180 let test_psk = ExternalPskId::new(vec![1]);
1181
1182 group
1183 .config
1184 .secret_store()
1185 .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1186
1187 let commit_output = group
1188 .commit_builder()
1189 .add_external_psk(test_psk.clone())
1190 .unwrap()
1191 .build()
1192 .await
1193 .unwrap();
1194
1195 let key_id = JustPreSharedKeyID::External(test_psk);
1196 let expected_psk = group.psk_proposal(key_id).unwrap();
1197
1198 assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1199 }
1200
1201 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1202 async fn test_commit_builder_group_context_ext() {
1203 let mut group = test_commit_builder_group().await;
1204 let mut test_ext = ExtensionList::default();
1205 test_ext
1206 .set_from(RequiredCapabilitiesExt::default())
1207 .unwrap();
1208
1209 let commit_output = group
1210 .commit_builder()
1211 .set_group_context_ext(test_ext.clone())
1212 .unwrap()
1213 .build()
1214 .await
1215 .unwrap();
1216
1217 let expected_ext = group.group_context_extensions_proposal(test_ext);
1218
1219 assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1220 }
1221
1222 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1223 async fn test_commit_builder_reinit() {
1224 let mut group = test_commit_builder_group().await;
1225 let test_group_id = "foo".as_bytes().to_vec();
1226 let test_cipher_suite = TEST_CIPHER_SUITE;
1227 let test_protocol_version = TEST_PROTOCOL_VERSION;
1228 let mut test_ext = ExtensionList::default();
1229
1230 test_ext
1231 .set_from(RequiredCapabilitiesExt::default())
1232 .unwrap();
1233
1234 let commit_output = group
1235 .commit_builder()
1236 .reinit(
1237 Some(test_group_id.clone()),
1238 test_protocol_version,
1239 test_cipher_suite,
1240 test_ext.clone(),
1241 )
1242 .unwrap()
1243 .build()
1244 .await
1245 .unwrap();
1246
1247 let expected_reinit = group
1248 .reinit_proposal(
1249 Some(test_group_id),
1250 test_protocol_version,
1251 test_cipher_suite,
1252 test_ext,
1253 )
1254 .unwrap();
1255
1256 assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1257 }
1258
1259 #[cfg(feature = "custom_proposal")]
1260 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1261 async fn test_commit_builder_custom_proposal() {
1262 let mut group = test_commit_builder_group().await;
1263
1264 let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1265
1266 let commit_output = group
1267 .commit_builder()
1268 .custom_proposal(proposal.clone())
1269 .build()
1270 .await
1271 .unwrap();
1272
1273 assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1274 }
1275
1276 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1277 async fn test_commit_builder_chaining() {
1278 let mut group = test_commit_builder_group().await;
1279 let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1280 let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1281
1282 let expected_adds = vec![
1283 group.add_proposal(kp1.clone()).unwrap(),
1284 group.add_proposal(kp2.clone()).unwrap(),
1285 ];
1286
1287 let commit_output = group
1288 .commit_builder()
1289 .add_member(kp1)
1290 .unwrap()
1291 .add_member(kp2)
1292 .unwrap()
1293 .build()
1294 .await
1295 .unwrap();
1296
1297 assert_commit_builder_output(group, commit_output, expected_adds, 2);
1298 }
1299
1300 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1301 async fn test_commit_builder_empty_commit() {
1302 let mut group = test_commit_builder_group().await;
1303
1304 let commit_output = group.commit_builder().build().await.unwrap();
1305
1306 assert_commit_builder_output(group, commit_output, vec![], 0);
1307 }
1308
1309 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1310 async fn test_commit_builder_authenticated_data() {
1311 let mut group = test_commit_builder_group().await;
1312 let test_data = "test".as_bytes().to_vec();
1313
1314 let commit_output = group
1315 .commit_builder()
1316 .authenticated_data(test_data.clone())
1317 .build()
1318 .await
1319 .unwrap();
1320
1321 assert_eq!(
1322 commit_output
1323 .commit_message
1324 .into_plaintext()
1325 .unwrap()
1326 .content
1327 .authenticated_data,
1328 test_data
1329 );
1330 }
1331
// With `with_single_welcome_message(false)`, each added member must be able
// to find a welcome that references its key package, join with it, and that
// welcome must contain exactly one encrypted secret.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_commit_builder_multiple_welcome_messages() {
    let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
        let rules = DefaultMlsRules::new()
            .with_commit_options(CommitOptions::new().with_single_welcome_message(false));

        b.mls_rules(rules)
    })
    .await;

    let (alice, alice_kp) =
        test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;

    let (bob, bob_kp) =
        test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;

    // Both adds are proposed by reference and picked up by the commit below.
    group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
    group.propose_add(bob_kp.clone(), vec![]).await.unwrap();

    let welcomes = group.commit(vec![]).await.unwrap().welcome_messages;

    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

    for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
        let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();

        // Locate the welcome addressed to this particular key package.
        let own_welcome = welcomes
            .iter()
            .find(|candidate| {
                candidate
                    .welcome_key_package_references()
                    .contains(&&kp_ref)
            })
            .unwrap();

        client.join_group(None, own_welcome, None).await.unwrap();

        let decoded = own_welcome.clone().into_welcome().unwrap();

        assert_eq!(decoded.secrets.len(), 1);
    }
}
1369
1370 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1371 async fn commit_can_change_credential() {
1372 let cs = TEST_CIPHER_SUITE;
1373 let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1374 let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1375
1376 let commit_output = groups[0]
1377 .commit_builder()
1378 .set_new_signing_identity(secret_key, identity.clone())
1379 .build()
1380 .await
1381 .unwrap();
1382
1383 groups[0].process_pending_commit().await.unwrap();
1385 let new_member = groups[0].roster().member_with_index(0).unwrap();
1386
1387 assert_eq!(
1388 new_member.signing_identity.credential,
1389 get_test_basic_credential(b"member".to_vec())
1390 );
1391
1392 assert_eq!(
1393 new_member.signing_identity.signature_key,
1394 identity.signature_key
1395 );
1396
1397 groups[1]
1399 .process_message(commit_output.commit_message)
1400 .await
1401 .unwrap();
1402
1403 let new_member = groups[1].roster().member_with_index(0).unwrap();
1404
1405 assert_eq!(
1406 new_member.signing_identity.credential,
1407 get_test_basic_credential(b"member".to_vec())
1408 );
1409
1410 assert_eq!(
1411 new_member.signing_identity.signature_key,
1412 identity.signature_key
1413 );
1414 }
1415
1416 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1417 async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1418 let mut group = test_group_custom(
1419 TEST_PROTOCOL_VERSION,
1420 TEST_CIPHER_SUITE,
1421 Default::default(),
1422 None,
1423 Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1424 )
1425 .await;
1426
1427 let commit = group.commit(vec![]).await.unwrap();
1428
1429 group.apply_pending_commit().await.unwrap();
1430
1431 let new_tree = group.export_tree();
1432
1433 assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1434 }
1435
1436 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1437 async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1438 let mut group = test_group_custom(
1439 TEST_PROTOCOL_VERSION,
1440 TEST_CIPHER_SUITE,
1441 Default::default(),
1442 None,
1443 Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1444 )
1445 .await;
1446
1447 let commit = group.commit(vec![]).await.unwrap();
1448
1449 assert!(commit.ratchet_tree.is_none());
1450 }
1451
1452 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1453 async fn commit_includes_external_commit_group_info_if_requested() {
1454 let mut group = test_group_custom(
1455 TEST_PROTOCOL_VERSION,
1456 TEST_CIPHER_SUITE,
1457 Default::default(),
1458 None,
1459 Some(
1460 CommitOptions::new()
1461 .with_allow_external_commit(true)
1462 .with_ratchet_tree_extension(false),
1463 ),
1464 )
1465 .await;
1466
1467 let commit = group.commit(vec![]).await.unwrap();
1468
1469 let info = commit
1470 .external_commit_group_info
1471 .unwrap()
1472 .into_group_info()
1473 .unwrap();
1474
1475 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1476 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1477 }
1478
1479 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1480 async fn commit_includes_external_commit_and_tree_if_requested() {
1481 let mut group = test_group_custom(
1482 TEST_PROTOCOL_VERSION,
1483 TEST_CIPHER_SUITE,
1484 Default::default(),
1485 None,
1486 Some(
1487 CommitOptions::new()
1488 .with_allow_external_commit(true)
1489 .with_ratchet_tree_extension(true),
1490 ),
1491 )
1492 .await;
1493
1494 let commit = group.commit(vec![]).await.unwrap();
1495
1496 let info = commit
1497 .external_commit_group_info
1498 .unwrap()
1499 .into_group_info()
1500 .unwrap();
1501
1502 assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1503 assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1504 }
1505
1506 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1507 async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1508 let mut group = test_group_custom(
1509 TEST_PROTOCOL_VERSION,
1510 TEST_CIPHER_SUITE,
1511 Default::default(),
1512 None,
1513 Some(CommitOptions::new().with_allow_external_commit(false)),
1514 )
1515 .await;
1516
1517 let commit = group.commit(vec![]).await.unwrap();
1518
1519 assert!(commit.external_commit_group_info.is_none());
1520 }
1521
1522 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1523 async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_off(
1524 ) {
1525 let mut group = test_group_custom(
1526 TEST_PROTOCOL_VERSION,
1527 TEST_CIPHER_SUITE,
1528 Default::default(),
1529 None,
1530 Some(
1531 CommitOptions::new()
1532 .with_always_out_of_band_ratchet_tree(true)
1533 .with_ratchet_tree_extension(false)
1534 .with_allow_external_commit(true),
1535 ),
1536 )
1537 .await;
1538
1539 let commit = group.commit(vec![]).await.unwrap();
1540
1541 assert!(commit.ratchet_tree.is_some());
1542
1543 let info = commit
1544 .external_commit_group_info
1545 .unwrap()
1546 .into_group_info()
1547 .unwrap();
1548
1549 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1550 }
1551
1552 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1553 async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_on(
1554 ) {
1555 let mut group = test_group_custom(
1556 TEST_PROTOCOL_VERSION,
1557 TEST_CIPHER_SUITE,
1558 Default::default(),
1559 None,
1560 Some(
1561 CommitOptions::new()
1562 .with_always_out_of_band_ratchet_tree(true)
1563 .with_ratchet_tree_extension(true)
1564 .with_allow_external_commit(true),
1565 ),
1566 )
1567 .await;
1568
1569 let commit = group.commit(vec![]).await.unwrap();
1570
1571 assert!(commit.ratchet_tree.is_some());
1572
1573 let info = commit
1574 .external_commit_group_info
1575 .unwrap()
1576 .into_group_info()
1577 .unwrap();
1578
1579 assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1580 }
1581
1582 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1583 async fn member_identity_is_validated_against_new_extensions() {
1584 let alice = client_with_test_extension(b"alice").await;
1585 let mut alice = alice
1586 .create_group(ExtensionList::new(), Default::default(), None)
1587 .await
1588 .unwrap();
1589
1590 let bob = client_with_test_extension(b"bob").await;
1591 let bob_kp = bob
1592 .generate_key_package_message(Default::default(), Default::default(), None)
1593 .await
1594 .unwrap();
1595
1596 let mut extension_list = ExtensionList::new();
1597 let extension = TestExtension { foo: b'a' };
1598 extension_list.set_from(extension).unwrap();
1599
1600 let res = alice
1601 .commit_builder()
1602 .add_member(bob_kp)
1603 .unwrap()
1604 .set_group_context_ext(extension_list.clone())
1605 .unwrap()
1606 .build()
1607 .await;
1608
1609 assert!(res.is_err());
1610
1611 let alex = client_with_test_extension(b"alex").await;
1612
1613 alice
1614 .commit_builder()
1615 .add_member(
1616 alex.generate_key_package_message(Default::default(), Default::default(), None)
1617 .await
1618 .unwrap(),
1619 )
1620 .unwrap()
1621 .set_group_context_ext(extension_list.clone())
1622 .unwrap()
1623 .build()
1624 .await
1625 .unwrap();
1626 }
1627
// External senders listed in new group context extensions must be validated
// against those same extensions. The extension carries `foo = b'a'`:
// installing "alex" as an external sender is expected to fail, "bob" to succeed.
#[cfg(feature = "by_ref_proposal")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn server_identity_is_validated_against_new_extensions() {
    let alice_client = client_with_test_extension(b"alice").await;

    let mut alice = alice_client
        .create_group(ExtensionList::new(), Default::default(), None)
        .await
        .unwrap();

    let mut base_extensions = ExtensionList::new();
    base_extensions.set_from(TestExtension { foo: b'a' }).unwrap();

    // Rejected case: external sender "alex".
    let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;

    let alex_senders = ExternalSendersExt {
        allowed_senders: vec![alex_server],
    };

    let mut alex_extensions = base_extensions.clone();
    alex_extensions.set_from(alex_senders).unwrap();

    let rejected = alice
        .commit_builder()
        .set_group_context_ext(alex_extensions)
        .unwrap()
        .build()
        .await;

    assert!(rejected.is_err());

    // Accepted case: external sender "bob".
    let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;

    let bob_senders = ExternalSendersExt {
        allowed_senders: vec![bob_server],
    };

    let mut bob_extensions = base_extensions;
    bob_extensions.set_from(bob_senders).unwrap();

    alice
        .commit_builder()
        .set_group_context_ext(bob_extensions)
        .unwrap()
        .build()
        .await
        .unwrap();
}
1678
1679 #[derive(Debug, Clone)]
1680 struct IdentityProviderWithExtension(BasicIdentityProvider);
1681
/// Field-less error returned by `IdentityProviderWithExtension` for every
/// failure. With the `std` feature it implements `std::error::Error` via
/// `thiserror` and displays as "test error".
#[derive(Debug, Clone)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
#[cfg_attr(feature = "std", error("test error"))]
struct IdentityProviderWithExtensionError {}
1686
1687 impl IntoAnyError for IdentityProviderWithExtensionError {
1688 #[cfg(feature = "std")]
1689 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1690 Ok(self.into())
1691 }
1692 }
1693
1694 impl IdentityProviderWithExtension {
1695 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1698 async fn starts_with_foo(
1699 &self,
1700 identity: &SigningIdentity,
1701 _timestamp: Option<MlsTime>,
1702 extensions: Option<&ExtensionList>,
1703 ) -> bool {
1704 if let Some(extensions) = extensions {
1705 if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1706 self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1707 } else {
1708 true
1709 }
1710 } else {
1711 true
1712 }
1713 }
1714 }
1715
1716 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1717 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1718 impl IdentityProvider for IdentityProviderWithExtension {
1719 type Error = IdentityProviderWithExtensionError;
1720
1721 async fn validate_member(
1722 &self,
1723 identity: &SigningIdentity,
1724 timestamp: Option<MlsTime>,
1725 context: MemberValidationContext<'_>,
1726 ) -> Result<(), Self::Error> {
1727 self.starts_with_foo(identity, timestamp, context.new_extensions())
1728 .await
1729 .then_some(())
1730 .ok_or(IdentityProviderWithExtensionError {})
1731 }
1732
1733 async fn validate_external_sender(
1734 &self,
1735 identity: &SigningIdentity,
1736 timestamp: Option<MlsTime>,
1737 extensions: Option<&ExtensionList>,
1738 ) -> Result<(), Self::Error> {
1739 (!self.starts_with_foo(identity, timestamp, extensions).await)
1740 .then_some(())
1741 .ok_or(IdentityProviderWithExtensionError {})
1742 }
1743
1744 async fn identity(
1745 &self,
1746 signing_identity: &SigningIdentity,
1747 extensions: &ExtensionList,
1748 ) -> Result<Vec<u8>, Self::Error> {
1749 self.0
1750 .identity(signing_identity, extensions)
1751 .await
1752 .map_err(|_| IdentityProviderWithExtensionError {})
1753 }
1754
1755 async fn valid_successor(
1756 &self,
1757 _predecessor: &SigningIdentity,
1758 _successor: &SigningIdentity,
1759 _extensions: &ExtensionList,
1760 ) -> Result<bool, Self::Error> {
1761 Ok(true)
1762 }
1763
1764 fn supported_types(&self) -> Vec<CredentialType> {
1765 self.0.supported_types()
1766 }
1767 }
1768
// Client configuration type produced by `client_with_test_extension`: the
// base config with `TestCryptoProvider` as crypto provider and
// `IdentityProviderWithExtension` as identity provider.
type ExtensionClientConfig = WithIdentityProvider<
    IdentityProviderWithExtension,
    WithCryptoProvider<TestCryptoProvider, BaseConfig>,
>;
1773
1774 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1775 async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1776 let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1777
1778 ClientBuilder::new()
1779 .crypto_provider(TestCryptoProvider::new())
1780 .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1781 .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1782 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1783 .build()
1784 }
1785
1786 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1787 async fn detached_commit() {
1788 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1789
1790 let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1791 assert!(group.pending_commit.is_none());
1792 group.apply_detached_commit(secrets).await.unwrap();
1793 assert_eq!(group.context().epoch, 1);
1794 }
1795
// After the committer applies its own empty commit, the group's public tree
// must have the same internals as a tree re-imported from its node data.
#[cfg(feature = "tree_index")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn tree_index_consistent_after_committer_self_update() {
    use crate::identity::basic::BasicIdentityProvider;
    use crate::tree_kem::TreeKemPublic;

    let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

    group.commit(Vec::new()).await.unwrap();
    group.process_pending_commit().await.unwrap();

    // Rebuild a tree from scratch out of a copy of the current node data.
    let node_data = group.state.public_tree.nodes.clone();

    let mut rebuilt =
        TreeKemPublic::import_node_data(node_data, &BasicIdentityProvider, &Default::default())
            .await
            .unwrap();

    // Compute the rebuilt tree's hash before the comparison.
    let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
    rebuilt.tree_hash(&cs).await.unwrap();

    assert!(group.state.public_tree.equal_internals(&rebuilt));
}
1820}