mls_rs/group/
commit.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13    cipher_suite::CipherSuite,
14    client::MlsError,
15    client_config::ClientConfig,
16    extension::RatchetTreeExt,
17    identity::SigningIdentity,
18    protocol_version::ProtocolVersion,
19    signer::Signable,
20    time::MlsTime,
21    tree_kem::{kem::TreeKem, path_secret::PathSecret, TreeKemPrivate, UpdatePath},
22    ExtensionList, MlsRules,
23};
24
25#[cfg(all(not(mls_build_async), feature = "rayon"))]
26use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
27
28use crate::tree_kem::leaf_node::LeafNode;
29
30#[cfg(not(feature = "private_message"))]
31use crate::WireFormat;
32
33#[cfg(feature = "psk")]
34use crate::{
35    group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
36    psk::ExternalPskId,
37};
38
39use super::{
40    confirmation_tag::ConfirmationTag,
41    framing::{Content, MlsMessage, MlsMessagePayload, Sender},
42    key_schedule::{KeySchedule, WelcomeSecret},
43    message_hash::MessageHash,
44    message_processor::{path_update_required, MessageProcessor},
45    message_signature::AuthenticatedContent,
46    mls_rules::CommitDirection,
47    proposal::{Proposal, ProposalOrRef},
48    CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
49    Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
50    PendingCommitSnapshot, Welcome,
51};
52
53#[cfg(not(feature = "by_ref_proposal"))]
54use super::proposal_cache::prepare_commit;
55
56#[cfg(feature = "custom_proposal")]
57use super::proposal::CustomProposal;
58
/// Body of an MLS commit: the proposals it applies and an optional update path.
///
/// NOTE(review): the derived `MlsEncode`/`MlsDecode` impls presumably
/// serialize fields in declaration order, so reordering them would change the
/// wire format — keep the order stable.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Commit {
    /// Proposals covered by this commit, each either inlined by value or
    /// referenced (see `ProposalOrRef`).
    pub proposals: Vec<ProposalOrRef>,
    /// Update path for the committer; `None` when the commit carries no path
    /// update.
    pub path: Option<UpdatePath>,
}
66
/// Everything computed while creating a commit that is needed to apply it
/// later, once the group has accepted it.
///
/// A detached commit serializes this struct into [`CommitSecrets`], so it
/// contains secret key material (`epoch_secrets`, `signer`, …).
#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
pub(crate) struct PendingCommit {
    // Presumably the group state as it will be after the commit — TODO confirm
    // against where this struct is constructed.
    pub(crate) state: GroupState,
    pub(crate) epoch_secrets: EpochSecrets,
    pub(crate) private_tree: TreeKemPrivate,
    pub(crate) key_schedule: KeySchedule,
    // Signer to use once the commit is applied (may be a new signer set on
    // the commit builder) — NOTE(review): confirm at construction site.
    pub(crate) signer: SignatureSecretKey,

    pub(crate) output: CommitMessageDescription,

    // Hash of the commit message, presumably used to recognise our own commit
    // when it is received back — verify.
    pub(crate) commit_message_hash: MessageHash,
}
79
/// Secrets of a commit created with [`CommitBuilder::build_detached`] or
/// [`Group::commit_detached`], returned to the caller instead of being stored
/// as the group's pending commit.
///
/// The value wraps an encoded pending commit, which includes secret key
/// material; treat serialized bytes from [`CommitSecrets::to_bytes`] as
/// sensitive.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone)]
pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
86
87impl CommitSecrets {
88    /// Deserialize the commit secrets from bytes
89    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
90        Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
91    }
92
93    /// Serialize the commit secrets to bytes
94    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
95        Ok(self.0.mls_encode_to_vec()?)
96    }
97}
98
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug)]
#[non_exhaustive]
/// Result of an MLS commit operation created with
/// [`Group::commit`](crate::group::Group::commit) or
/// [`CommitBuilder::build`](CommitBuilder::build) (and their detached
/// variants).
pub struct CommitOutput {
    /// Commit message to send to other group members.
    pub commit_message: MlsMessage,
    /// Welcome messages to send to new group members. If the commit does not add members,
    /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
    /// set to true, then this list contains a single message sent to all members. Else, the list
    /// contains one message for each added member. Recipients of each message can be identified using
    /// [`MlsMessage::key_package_reference`] of their key packages and
    /// [`MlsMessage::welcome_key_package_references`].
    pub welcome_messages: Vec<MlsMessage>,
    /// Ratchet tree that can be sent out of band if
    /// `ratchet_tree_extension` is not used according to
    /// [`MlsRules::commit_options`].
    pub ratchet_tree: Option<ExportedTree<'static>>,
    /// A group info that can be provided to new members in order to enable external commit
    /// functionality. This value is set if [`MlsRules::commit_options`] returns
    /// `allow_external_commit` set to true.
    pub external_commit_group_info: Option<MlsMessage>,
    /// Proposals that were received during the epoch this commit ends but were
    /// not included in the commit.
    #[cfg(feature = "by_ref_proposal")]
    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
    /// Whether the commit contains an update path.
    pub contains_update_path: bool,
}
132
133#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
134impl CommitOutput {
135    /// Commit message to send to other group members.
136    #[cfg(feature = "ffi")]
137    pub fn commit_message(&self) -> &MlsMessage {
138        &self.commit_message
139    }
140
141    /// Welcome message to send to new group members.
142    #[cfg(feature = "ffi")]
143    pub fn welcome_messages(&self) -> &[MlsMessage] {
144        &self.welcome_messages
145    }
146
147    /// Ratchet tree that can be sent out of band if
148    /// `ratchet_tree_extension` is not used according to
149    /// [`MlsRules::commit_options`].
150    #[cfg(feature = "ffi")]
151    pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
152        self.ratchet_tree.as_ref()
153    }
154
155    /// A group info that can be provided to new members in order to enable external commit
156    /// functionality. This value is set if [`MlsRules::commit_options`] returns
157    /// `allow_external_commit` set to true.
158    #[cfg(feature = "ffi")]
159    pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
160        self.external_commit_group_info.as_ref()
161    }
162
163    /// Proposals that were received in the prior epoch but not included in the following commit.
164    #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
165    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
166        &self.unused_proposals
167    }
168}
169
/// Build a commit with multiple proposals by-value.
///
/// Proposals within a commit can be by-value or by-reference.
/// Proposals received during the current epoch will be added to the resulting
/// commit by-reference automatically so long as they pass the rules defined
/// in the current
/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
///
/// Obtain a builder from [`Group::commit_builder`] and finish it with
/// [`CommitBuilder::build`] or [`CommitBuilder::build_detached`].
pub struct CommitBuilder<'a, C>
where
    C: ClientConfig + Clone,
{
    // Group the commit is created for; borrowed mutably for the builder's lifetime.
    group: &'a mut Group<C>,
    // Proposals added by value through the builder methods.
    pub(super) proposals: Vec<Proposal>,
    // Attached to the commit message unencrypted.
    authenticated_data: Vec<u8>,
    // Extensions placed in the group info carried by welcome messages.
    group_info_extensions: ExtensionList,
    // Replacement signer/identity for the committer, if one was set.
    new_signer: Option<SignatureSecretKey>,
    new_signing_identity: Option<SigningIdentity>,
    // Replacement leaf node extensions for the committer, if any were set.
    new_leaf_node_extensions: Option<ExtensionList>,
    // Overrides the default (current) time used when creating the commit.
    commit_time: Option<MlsTime>,
}
190
191impl<'a, C> CommitBuilder<'a, C>
192where
193    C: ClientConfig + Clone,
194{
195    /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
196    /// the current commit that is being built.
197    pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
198        let proposal = self.group.add_proposal(key_package)?;
199        self.proposals.push(proposal);
200        Ok(self)
201    }
202
203    /// Set group info extensions that will be inserted into the resulting
204    /// [welcome messages](CommitOutput::welcome_messages) for new members.
205    ///
206    /// Group info extensions that are transmitted as part of a welcome message
207    /// are encrypted along with other private values.
208    ///
209    /// These extensions can be retrieved as part of
210    /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
211    /// by joining the group via
212    /// [`Client::join_group`](crate::Client::join_group).
213    pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
214        Self {
215            group_info_extensions: extensions,
216            ..self
217        }
218    }
219
220    /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
221    /// the current commit that is being built.
222    pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
223        let proposal = self.group.remove_proposal(index)?;
224        self.proposals.push(proposal);
225        Ok(self)
226    }
227
228    /// Insert a
229    /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
230    /// into the current commit that is being built.
231    pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
232        let proposal = self.group.group_context_extensions_proposal(extensions);
233        self.proposals.push(proposal);
234        Ok(self)
235    }
236
237    /// Insert a
238    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
239    /// an external PSK into the current commit that is being built.
240    #[cfg(feature = "psk")]
241    pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
242        let key_id = JustPreSharedKeyID::External(psk_id);
243        let proposal = self.group.psk_proposal(key_id)?;
244        self.proposals.push(proposal);
245        Ok(self)
246    }
247
248    /// Insert a
249    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
250    /// a resumption PSK into the current commit that is being built.
251    #[cfg(feature = "psk")]
252    pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
253        let psk_id = ResumptionPsk {
254            psk_epoch,
255            usage: ResumptionPSKUsage::Application,
256            psk_group_id: PskGroupId(self.group.group_id().to_vec()),
257        };
258
259        let key_id = JustPreSharedKeyID::Resumption(psk_id);
260        let proposal = self.group.psk_proposal(key_id)?;
261        self.proposals.push(proposal);
262        Ok(self)
263    }
264
265    /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
266    /// the current commit that is being built.
267    pub fn reinit(
268        mut self,
269        group_id: Option<Vec<u8>>,
270        version: ProtocolVersion,
271        cipher_suite: CipherSuite,
272        extensions: ExtensionList,
273    ) -> Result<Self, MlsError> {
274        let proposal = self
275            .group
276            .reinit_proposal(group_id, version, cipher_suite, extensions)?;
277
278        self.proposals.push(proposal);
279        Ok(self)
280    }
281
282    /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
283    /// the current commit that is being built.
284    #[cfg(feature = "custom_proposal")]
285    pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
286        self.proposals.push(Proposal::Custom(proposal));
287        self
288    }
289
290    /// Insert a proposal that was previously constructed such as when a
291    /// proposal is returned from
292    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
293    pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
294        self.proposals.push(proposal);
295        self
296    }
297
298    /// Insert proposals that were previously constructed such as when a
299    /// proposal is returned from
300    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
301    pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
302        self.proposals.append(&mut proposals);
303        self
304    }
305
306    /// Add additional authenticated data to the commit.
307    ///
308    /// # Warning
309    ///
310    /// The data provided here is always sent unencrypted.
311    pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
312        Self {
313            authenticated_data,
314            ..self
315        }
316    }
317
318    /// Change the committer's signing identity as part of making this commit.
319    /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
320    /// in use by the group considers the credential inside this signing_identity
321    /// [valid](crate::IdentityProvider::validate_member)
322    /// and results in the same
323    /// [identity](crate::IdentityProvider::identity)
324    /// being used.
325    pub fn set_new_signing_identity(
326        self,
327        signer: SignatureSecretKey,
328        signing_identity: SigningIdentity,
329    ) -> Self {
330        Self {
331            new_signer: Some(signer),
332            new_signing_identity: Some(signing_identity),
333            ..self
334        }
335    }
336
337    /// Change the committer's leaf node extensions as part of making this commit.
338    pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
339        Self {
340            new_leaf_node_extensions: Some(new_leaf_node_extensions),
341            ..self
342        }
343    }
344
345    /// Add a time to associate with the commit creation.
346    pub fn commit_time(self, commit_time: MlsTime) -> Self {
347        Self {
348            commit_time: Some(commit_time),
349            ..self
350        }
351    }
352
353    /// Finalize the commit to send.
354    ///
355    /// # Errors
356    ///
357    /// This function will return an error if any of the proposals provided
358    /// are not contextually valid according to the rules defined by the
359    /// MLS RFC, or if they do not pass the custom rules defined by the current
360    /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
361    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
362    pub async fn build(self) -> Result<CommitOutput, MlsError> {
363        let (output, pending_commit) = self
364            .group
365            .commit_internal(
366                self.proposals,
367                None,
368                self.authenticated_data,
369                self.group_info_extensions,
370                self.new_signer,
371                self.new_signing_identity,
372                self.new_leaf_node_extensions,
373                self.commit_time,
374            )
375            .await?;
376
377        self.group.pending_commit = pending_commit.try_into()?;
378
379        Ok(output)
380    }
381
382    /// The same function as `GroupBuilder::build` except the secrets generated
383    /// for the commit are outputted instead of being cached internally.
384    ///
385    /// A detached commit can be applied using `Group::apply_detached_commit`.
386    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
387    pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
388        let (output, pending_commit) = self
389            .group
390            .commit_internal(
391                self.proposals,
392                None,
393                self.authenticated_data,
394                self.group_info_extensions,
395                self.new_signer,
396                self.new_signing_identity,
397                self.new_leaf_node_extensions,
398                self.commit_time,
399            )
400            .await?;
401
402        Ok((
403            output,
404            CommitSecrets(PendingCommitSnapshot::PendingCommit(
405                pending_commit.mls_encode_to_vec()?,
406            )),
407        ))
408    }
409}
410
411impl<C> Group<C>
412where
413    C: ClientConfig + Clone,
414{
415    /// Perform a commit of received proposals.
416    ///
417    /// This function is the equivalent of [`Group::commit_builder`] immediately
418    /// followed by [`CommitBuilder::build`]. Any received proposals since the
419    /// last commit will be included in the resulting message by-reference.
420    ///
421    /// Data provided in the `authenticated_data` field will be placed into
422    /// the resulting commit message unencrypted.
423    ///
424    /// # Pending Commits
425    ///
426    /// When a commit is created, it is not applied immediately in order to
427    /// allow for the resolution of conflicts when multiple members of a group
428    /// attempt to make commits at the same time. For example, a central relay
429    /// can be used to decide which commit should be accepted by the group by
430    /// determining a consistent view of commit packet order for all clients.
431    ///
432    /// Pending commits are stored internally as part of the group's state
433    /// so they do not need to be tracked outside of this library. Any commit
434    /// message that is processed before calling [Group::apply_pending_commit]
435    /// will clear the currently pending commit.
436    ///
437    /// # Empty Commits
438    ///
439    /// Sending a commit that contains no proposals is a valid operation
440    /// within the MLS protocol. It is useful for providing stronger forward
441    /// secrecy and post-compromise security, especially for long running
442    /// groups when group membership does not change often.
443    ///
444    /// # Path Updates
445    ///
446    /// Path updates provide forward secrecy and post-compromise security
447    /// within the MLS protocol.
448    /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
449    /// controls the ability of a group to send a commit without a path update.
450    /// An update path will automatically be sent if there are no proposals
451    /// in the commit, or if any proposal other than
452    /// [`Add`](crate::group::proposal::Proposal::Add),
453    /// [`Psk`](crate::group::proposal::Proposal::Psk),
454    /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
455    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
456    pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
457        self.commit_builder()
458            .authenticated_data(authenticated_data)
459            .build()
460            .await
461    }
462
463    /// The same function as `Group::commit` except the secrets generated
464    /// for the commit are outputted instead of being cached internally.
465    ///
466    /// A detached commit can be applied using `Group::apply_detached_commit`.
467    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
468    pub async fn commit_detached(
469        &mut self,
470        authenticated_data: Vec<u8>,
471    ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
472        self.commit_builder()
473            .authenticated_data(authenticated_data)
474            .build_detached()
475            .await
476    }
477
478    /// Create a new commit builder that can include proposals
479    /// by-value.
480    pub fn commit_builder(&mut self) -> CommitBuilder<C> {
481        CommitBuilder {
482            group: self,
483            proposals: Default::default(),
484            authenticated_data: Default::default(),
485            group_info_extensions: Default::default(),
486            new_signer: Default::default(),
487            new_signing_identity: Default::default(),
488            new_leaf_node_extensions: Default::default(),
489            commit_time: None,
490        }
491    }
492
493    /// Returns commit and optional [`MlsMessage`] containing a welcome message
494    /// for newly added members.
495    #[allow(clippy::too_many_arguments)]
496    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
497    pub(super) async fn commit_internal(
498        &mut self,
499        proposals: Vec<Proposal>,
500        external_leaf: Option<&LeafNode>,
501        authenticated_data: Vec<u8>,
502        mut welcome_group_info_extensions: ExtensionList,
503        new_signer: Option<SignatureSecretKey>,
504        new_signing_identity: Option<SigningIdentity>,
505        new_leaf_node_extensions: Option<ExtensionList>,
506        commit_time: Option<MlsTime>,
507    ) -> Result<(CommitOutput, PendingCommit), MlsError> {
508        if !self.pending_commit.is_none() {
509            return Err(MlsError::ExistingPendingCommit);
510        }
511
512        if self.state.pending_reinit.is_some() {
513            return Err(MlsError::GroupUsedAfterReInit);
514        }
515
516        let mls_rules = self.config.mls_rules();
517
518        let is_external = external_leaf.is_some();
519
520        // Construct an initial Commit object with the proposals field populated from Proposals
521        // received during the current epoch, and an empty path field. Add passed in proposals
522        // by value
523        let sender = if is_external {
524            Sender::NewMemberCommit
525        } else {
526            Sender::Member(*self.private_tree.self_index)
527        };
528
529        let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
530        let old_signer = &self.signer;
531
532        #[cfg(feature = "std")]
533        let time = Some(crate::time::MlsTime::now());
534
535        #[cfg(not(feature = "std"))]
536        let time = None;
537
538        let time = if commit_time.is_some() {
539            commit_time
540        } else {
541            time
542        };
543
544        #[cfg(feature = "by_ref_proposal")]
545        let proposals = self.state.proposals.prepare_commit(sender, proposals);
546
547        #[cfg(not(feature = "by_ref_proposal"))]
548        let proposals = prepare_commit(sender, proposals);
549
550        let mut provisional_state = self
551            .state
552            .apply_resolved(
553                sender,
554                proposals,
555                external_leaf,
556                &self.config.identity_provider(),
557                &self.cipher_suite_provider,
558                &self.config.secret_store(),
559                &mls_rules,
560                time,
561                CommitDirection::Send,
562            )
563            .await?;
564
565        let (mut provisional_private_tree, _) =
566            self.provisional_private_tree(&provisional_state)?;
567
568        if is_external {
569            provisional_private_tree.self_index = provisional_state
570                .external_init_index
571                .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
572
573            self.private_tree.self_index = provisional_private_tree.self_index;
574        }
575
576        // Decide whether to populate the path field: If the path field is required based on the
577        // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
578        // sender MAY omit the path field at its discretion.
579        let commit_options = mls_rules
580            .commit_options(
581                &provisional_state.public_tree.roster(),
582                &provisional_state.group_context,
583                &provisional_state.applied_proposals,
584            )
585            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
586
587        let perform_path_update = commit_options.path_required
588            || path_update_required(&provisional_state.applied_proposals);
589
590        let (update_path, path_secrets, commit_secret) = if perform_path_update {
591            // If populating the path field: Create an UpdatePath using the new tree. Any new
592            // member (from an add proposal) MUST be excluded from the resolution during the
593            // computation of the UpdatePath. The GroupContext for this operation uses the
594            // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
595            // GroupContext object. The leaf_key_package for this UpdatePath must have a
596            // parent_hash extension.
597
598            let new_leaf_node_extensions =
599                new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
600
601            let new_leaf_node_extensions = match new_leaf_node_extensions {
602                Some(extensions) => extensions,
603                // If we are not setting new extensions and this is not an external leaf then the current node MUST exist.
604                None => self.current_user_leaf_node()?.ungreased_extensions(),
605            };
606
607            let encap_gen = TreeKem::new(
608                &mut provisional_state.public_tree,
609                &mut provisional_private_tree,
610            )
611            .encap(
612                &mut provisional_state.group_context,
613                &provisional_state.indexes_of_added_kpkgs,
614                &new_signer,
615                Some(self.config.leaf_properties(new_leaf_node_extensions)),
616                new_signing_identity,
617                &self.cipher_suite_provider,
618                #[cfg(test)]
619                &self.commit_modifiers,
620            )
621            .await?;
622
623            (
624                Some(encap_gen.update_path),
625                Some(encap_gen.path_secrets),
626                encap_gen.commit_secret,
627            )
628        } else {
629            // Update the tree hash, since it was not updated by encap.
630            provisional_state
631                .public_tree
632                .update_hashes(
633                    &[provisional_private_tree.self_index],
634                    &self.cipher_suite_provider,
635                )
636                .await?;
637
638            provisional_state.group_context.tree_hash = provisional_state
639                .public_tree
640                .tree_hash(&self.cipher_suite_provider)
641                .await?;
642
643            (None, None, PathSecret::empty(&self.cipher_suite_provider))
644        };
645
646        #[cfg(feature = "psk")]
647        let (psk_secret, psks) = self
648            .get_psk(&provisional_state.applied_proposals.psks)
649            .await?;
650
651        #[cfg(not(feature = "psk"))]
652        let psk_secret = self.get_psk();
653
654        let added_key_pkgs: Vec<_> = provisional_state
655            .applied_proposals
656            .additions
657            .iter()
658            .map(|info| info.proposal.key_package.clone())
659            .collect();
660
661        let commit = Commit {
662            proposals: provisional_state.applied_proposals.proposals_or_refs(),
663            path: update_path,
664        };
665
666        let mut auth_content = AuthenticatedContent::new_signed(
667            &self.cipher_suite_provider,
668            self.context(),
669            sender,
670            Content::Commit(Box::new(commit)),
671            old_signer,
672            #[cfg(feature = "private_message")]
673            self.encryption_options()?.control_wire_format(sender),
674            #[cfg(not(feature = "private_message"))]
675            WireFormat::PublicMessage,
676            authenticated_data,
677        )
678        .await?;
679
680        // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
681        // compute the confirmation_tag value in the MlsPlaintext.
682        let confirmed_transcript_hash = super::transcript_hash::create(
683            self.cipher_suite_provider(),
684            &self.state.interim_transcript_hash,
685            &auth_content,
686        )
687        .await?;
688
689        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
690
691        let key_schedule_result = KeySchedule::from_key_schedule(
692            &self.key_schedule,
693            &commit_secret,
694            &provisional_state.group_context,
695            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
696            provisional_state.public_tree.total_leaf_count(),
697            &psk_secret,
698            &self.cipher_suite_provider,
699        )
700        .await?;
701
702        let confirmation_tag = ConfirmationTag::create(
703            &key_schedule_result.confirmation_key,
704            &provisional_state.group_context.confirmed_transcript_hash,
705            &self.cipher_suite_provider,
706        )
707        .await?;
708
709        let interim_transcript_hash = InterimTranscriptHash::create(
710            self.cipher_suite_provider(),
711            &provisional_state.group_context.confirmed_transcript_hash,
712            &confirmation_tag,
713        )
714        .await?;
715
716        auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
717
718        let ratchet_tree_ext = commit_options
719            .ratchet_tree_extension
720            .then(|| RatchetTreeExt {
721                tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
722            });
723
724        // Generate external commit group info if required by commit_options
725        let external_commit_group_info = match commit_options.allow_external_commit {
726            true => {
727                let mut extensions = ExtensionList::new();
728
729                extensions.set_from({
730                    key_schedule_result
731                        .key_schedule
732                        .get_external_key_pair_ext(&self.cipher_suite_provider)
733                        .await?
734                })?;
735
736                if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
737                    if !commit_options.always_out_of_band_ratchet_tree {
738                        extensions.set_from(ratchet_tree_ext.clone())?;
739                    }
740                }
741
742                let info = self
743                    .make_group_info(
744                        &provisional_state.group_context,
745                        extensions,
746                        &confirmation_tag,
747                        &new_signer,
748                    )
749                    .await?;
750
751                let msg =
752                    MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
753
754                Some(msg)
755            }
756            false => None,
757        };
758
759        // Build the group info that will be placed into the welcome messages.
760        // Add the ratchet tree extension if necessary
761        if let Some(ratchet_tree_ext) = ratchet_tree_ext {
762            welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
763        }
764
765        let welcome_group_info = self
766            .make_group_info(
767                &provisional_state.group_context,
768                welcome_group_info_extensions,
769                &confirmation_tag,
770                &new_signer,
771            )
772            .await?;
773
774        // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
775        // the new epoch
776        let welcome_secret = WelcomeSecret::from_joiner_secret(
777            &self.cipher_suite_provider,
778            &key_schedule_result.joiner_secret,
779            &psk_secret,
780        )
781        .await?;
782
783        let encrypted_group_info = welcome_secret
784            .encrypt(&welcome_group_info.mls_encode_to_vec()?)
785            .await?;
786
787        // Encrypt path secrets and joiner secret to new members
788        let path_secrets = path_secrets.as_ref();
789
790        #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
791        let encrypted_path_secrets: Vec<_> = added_key_pkgs
792            .into_par_iter()
793            .zip(&provisional_state.indexes_of_added_kpkgs)
794            .map(|(key_package, leaf_index)| {
795                self.encrypt_group_secrets(
796                    &key_package,
797                    *leaf_index,
798                    &key_schedule_result.joiner_secret,
799                    path_secrets,
800                    #[cfg(feature = "psk")]
801                    psks.clone(),
802                    &encrypted_group_info,
803                )
804            })
805            .try_collect()?;
806
807        #[cfg(any(mls_build_async, not(feature = "rayon")))]
808        let encrypted_path_secrets = {
809            let mut secrets = Vec::new();
810
811            for (key_package, leaf_index) in added_key_pkgs
812                .into_iter()
813                .zip(&provisional_state.indexes_of_added_kpkgs)
814            {
815                secrets.push(
816                    self.encrypt_group_secrets(
817                        &key_package,
818                        *leaf_index,
819                        &key_schedule_result.joiner_secret,
820                        path_secrets,
821                        #[cfg(feature = "psk")]
822                        psks.clone(),
823                        &encrypted_group_info,
824                    )
825                    .await?,
826                );
827            }
828
829            secrets
830        };
831
832        let welcome_messages =
833            if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
834                vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
835            } else {
836                encrypted_path_secrets
837                    .into_iter()
838                    .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
839                    .collect()
840            };
841
842        let commit_message = self.format_for_wire(auth_content.clone()).await?;
843
844        // TODO is it necessary to clone the tree here? or can we just output serialized bytes?
845        let ratchet_tree = (!commit_options.ratchet_tree_extension
846            || commit_options.always_out_of_band_ratchet_tree)
847            .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
848
849        let pending_reinit = provisional_state
850            .applied_proposals
851            .reinitializations
852            .first();
853
854        let pending_commit = PendingCommit {
855            output: CommitMessageDescription {
856                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
857                authenticated_data: auth_content.content.authenticated_data,
858                committer: *provisional_private_tree.self_index,
859                effect: match pending_reinit {
860                    Some(r) => CommitEffect::ReInit(r.clone()),
861                    None => CommitEffect::NewEpoch(
862                        NewEpoch::new(self.state.clone(), &provisional_state).into(),
863                    ),
864                },
865            },
866
867            state: GroupState {
868                #[cfg(feature = "by_ref_proposal")]
869                proposals: crate::group::ProposalCache::new(
870                    self.protocol_version(),
871                    self.group_id().to_vec(),
872                ),
873                context: provisional_state.group_context,
874                public_tree: provisional_state.public_tree,
875                interim_transcript_hash,
876                pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
877                confirmation_tag,
878            },
879
880            commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
881                .await?,
882            signer: new_signer,
883            epoch_secrets: key_schedule_result.epoch_secrets,
884            key_schedule: key_schedule_result.key_schedule,
885
886            private_tree: provisional_private_tree,
887        };
888
889        let output = CommitOutput {
890            commit_message,
891            welcome_messages,
892            ratchet_tree,
893            external_commit_group_info,
894            contains_update_path: perform_path_update,
895            #[cfg(feature = "by_ref_proposal")]
896            unused_proposals: provisional_state.unused_proposals,
897        };
898
899        Ok((output, pending_commit))
900    }
901
902    // Construct a GroupInfo reflecting the new state
903    // Group ID, epoch, tree, and confirmed transcript hash from the new state
904    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
905    async fn make_group_info(
906        &self,
907        group_context: &GroupContext,
908        extensions: ExtensionList,
909        confirmation_tag: &ConfirmationTag,
910        signer: &SignatureSecretKey,
911    ) -> Result<GroupInfo, MlsError> {
912        let mut group_info = GroupInfo {
913            group_context: group_context.clone(),
914            extensions,
915            confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
916            signer: self.current_member_leaf_index(),
917            signature: vec![],
918        };
919
920        group_info.grease(self.cipher_suite_provider())?;
921
922        // Sign the GroupInfo using the member's private signing key
923        group_info
924            .sign(&self.cipher_suite_provider, signer, &())
925            .await?;
926
927        Ok(group_info)
928    }
929
930    fn make_welcome_message(
931        &self,
932        secrets: Vec<EncryptedGroupSecrets>,
933        encrypted_group_info: Vec<u8>,
934    ) -> MlsMessage {
935        MlsMessage::new(
936            self.context().protocol_version,
937            MlsMessagePayload::Welcome(Welcome {
938                cipher_suite: self.context().cipher_suite,
939                secrets,
940                encrypted_group_info,
941            }),
942        )
943    }
944}
945
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Hooks that let tests tamper with pieces of a commit while it is being built.
    ///
    /// Every hook is a plain function pointer, which keeps the struct `Copy`. The [`Default`]
    /// value installs hooks that leave their inputs untouched.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Mutates a leaf node, given a signing key. Presumably a `Some` return supplies a
        /// replacement signing key — confirm at the call site.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Mutates a public ratchet tree in place.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Maps a list of update path nodes to the list that should be used instead.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    /// Default leaf hook: leaves the leaf untouched and returns `None`.
    fn unchanged_leaf(
        _leaf: &mut LeafNode,
        _signer: &SignatureSecretKey,
    ) -> Option<SignatureSecretKey> {
        None
    }

    /// Default tree hook: does nothing.
    fn unchanged_tree(_tree: &mut TreeKemPublic) {}

    /// Default path hook: hands back exactly the nodes it was given.
    fn unchanged_path(nodes: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
        nodes
    }

    impl Default for CommitModifiers {
        fn default() -> Self {
            Self {
                modify_leaf: unchanged_leaf,
                modify_tree: unchanged_tree,
                modify_path: unchanged_path,
            }
        }
    }
}
972
973#[cfg(test)]
974mod tests {
975    use mls_rs_core::{
976        error::IntoAnyError,
977        extension::ExtensionType,
978        identity::{CredentialType, IdentityProvider, MemberValidationContext},
979        time::MlsTime,
980    };
981
982    use crate::extension::RequiredCapabilitiesExt;
983    use crate::{
984        client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
985        client_builder::{
986            test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
987            WithIdentityProvider,
988        },
989        client_config::ClientConfig,
990        crypto::test_utils::TestCryptoProvider,
991        extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
992        group::test_utils::{test_group, test_group_custom},
993        group::{
994            proposal::ProposalType,
995            test_utils::{test_group_custom_config, test_n_member_group},
996        },
997        identity::test_utils::get_test_signing_identity,
998        identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
999        key_package::test_utils::test_key_package_message,
1000        mls_rules::CommitOptions,
1001        Client,
1002    };
1003
1004    #[cfg(feature = "by_ref_proposal")]
1005    use crate::crypto::test_utils::test_cipher_suite_provider;
1006    #[cfg(feature = "by_ref_proposal")]
1007    use crate::extension::ExternalSendersExt;
1008    #[cfg(feature = "by_ref_proposal")]
1009    use crate::group::mls_rules::DefaultMlsRules;
1010
1011    #[cfg(feature = "psk")]
1012    use crate::{
1013        group::proposal::PreSharedKeyProposal,
1014        psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
1015    };
1016
1017    use super::*;
1018
1019    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1020    async fn test_commit_builder_group() -> Group<TestClientConfig> {
1021        test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1022            b.custom_proposal_type(ProposalType::from(42))
1023                .extension_type(TEST_EXTENSION_TYPE.into())
1024        })
1025        .await
1026        .group
1027    }
1028
1029    fn assert_commit_builder_output<C: ClientConfig>(
1030        group: Group<C>,
1031        mut commit_output: CommitOutput,
1032        expected: Vec<Proposal>,
1033        welcome_count: usize,
1034    ) {
1035        let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1036
1037        let commit_data = match plaintext.content.content {
1038            Content::Commit(commit) => commit,
1039            #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1040            _ => panic!("Found non-commit data"),
1041        };
1042
1043        assert_eq!(commit_data.proposals.len(), expected.len());
1044
1045        commit_data.proposals.into_iter().for_each(|proposal| {
1046            let proposal = match proposal {
1047                ProposalOrRef::Proposal(p) => p,
1048                #[cfg(feature = "by_ref_proposal")]
1049                ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1050            };
1051
1052            #[cfg(feature = "psk")]
1053            if let Some(psk_id) = match proposal.as_ref() {
1054                Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1055                _ => None,
1056            } {
1057                let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1058
1059                assert!(found)
1060            } else {
1061                assert!(expected.contains(&proposal));
1062            }
1063
1064            #[cfg(not(feature = "psk"))]
1065            assert!(expected.contains(&proposal));
1066        });
1067
1068        if welcome_count > 0 {
1069            let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1070
1071            assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1072
1073            let welcome_msg = welcome_msg.into_welcome().unwrap();
1074
1075            assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1076            assert_eq!(welcome_msg.secrets.len(), welcome_count);
1077        } else {
1078            assert!(commit_output.welcome_messages.is_empty());
1079        }
1080    }
1081
1082    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1083    async fn test_commit_builder_add() {
1084        let mut group = test_commit_builder_group().await;
1085
1086        let test_key_package =
1087            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1088
1089        let commit_output = group
1090            .commit_builder()
1091            .add_member(test_key_package.clone())
1092            .unwrap()
1093            .build()
1094            .await
1095            .unwrap();
1096
1097        let expected_add = group.add_proposal(test_key_package).unwrap();
1098
1099        assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1100    }
1101
1102    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1103    async fn test_commit_builder_add_with_ext() {
1104        let mut group = test_commit_builder_group().await;
1105
1106        let (bob_client, bob_key_package) =
1107            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1108
1109        let ext = TestExtension { foo: 42 };
1110        let mut extension_list = ExtensionList::default();
1111        extension_list.set_from(ext.clone()).unwrap();
1112
1113        let welcome_message = group
1114            .commit_builder()
1115            .add_member(bob_key_package)
1116            .unwrap()
1117            .set_group_info_ext(extension_list)
1118            .build()
1119            .await
1120            .unwrap()
1121            .welcome_messages
1122            .remove(0);
1123
1124        let (_, context) = bob_client
1125            .join_group(None, &welcome_message, None)
1126            .await
1127            .unwrap();
1128
1129        assert_eq!(
1130            context
1131                .group_info_extensions
1132                .get_as::<TestExtension>()
1133                .unwrap()
1134                .unwrap(),
1135            ext
1136        );
1137    }
1138
1139    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1140    async fn test_commit_builder_remove() {
1141        let mut group = test_commit_builder_group().await;
1142        let test_key_package =
1143            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1144
1145        group
1146            .commit_builder()
1147            .add_member(test_key_package)
1148            .unwrap()
1149            .build()
1150            .await
1151            .unwrap();
1152
1153        group.apply_pending_commit().await.unwrap();
1154
1155        let commit_output = group
1156            .commit_builder()
1157            .remove_member(1)
1158            .unwrap()
1159            .build()
1160            .await
1161            .unwrap();
1162
1163        let expected_remove = group.remove_proposal(1).unwrap();
1164
1165        assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1166    }
1167
1168    #[cfg(feature = "psk")]
1169    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1170    async fn test_commit_builder_psk() {
1171        let mut group = test_commit_builder_group().await;
1172        let test_psk = ExternalPskId::new(vec![1]);
1173
1174        group
1175            .config
1176            .secret_store()
1177            .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1178
1179        let commit_output = group
1180            .commit_builder()
1181            .add_external_psk(test_psk.clone())
1182            .unwrap()
1183            .build()
1184            .await
1185            .unwrap();
1186
1187        let key_id = JustPreSharedKeyID::External(test_psk);
1188        let expected_psk = group.psk_proposal(key_id).unwrap();
1189
1190        assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1191    }
1192
1193    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1194    async fn test_commit_builder_group_context_ext() {
1195        let mut group = test_commit_builder_group().await;
1196        let mut test_ext = ExtensionList::default();
1197        test_ext
1198            .set_from(RequiredCapabilitiesExt::default())
1199            .unwrap();
1200
1201        let commit_output = group
1202            .commit_builder()
1203            .set_group_context_ext(test_ext.clone())
1204            .unwrap()
1205            .build()
1206            .await
1207            .unwrap();
1208
1209        let expected_ext = group.group_context_extensions_proposal(test_ext);
1210
1211        assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1212    }
1213
1214    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1215    async fn test_commit_builder_reinit() {
1216        let mut group = test_commit_builder_group().await;
1217        let test_group_id = "foo".as_bytes().to_vec();
1218        let test_cipher_suite = TEST_CIPHER_SUITE;
1219        let test_protocol_version = TEST_PROTOCOL_VERSION;
1220        let mut test_ext = ExtensionList::default();
1221
1222        test_ext
1223            .set_from(RequiredCapabilitiesExt::default())
1224            .unwrap();
1225
1226        let commit_output = group
1227            .commit_builder()
1228            .reinit(
1229                Some(test_group_id.clone()),
1230                test_protocol_version,
1231                test_cipher_suite,
1232                test_ext.clone(),
1233            )
1234            .unwrap()
1235            .build()
1236            .await
1237            .unwrap();
1238
1239        let expected_reinit = group
1240            .reinit_proposal(
1241                Some(test_group_id),
1242                test_protocol_version,
1243                test_cipher_suite,
1244                test_ext,
1245            )
1246            .unwrap();
1247
1248        assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1249    }
1250
1251    #[cfg(feature = "custom_proposal")]
1252    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1253    async fn test_commit_builder_custom_proposal() {
1254        let mut group = test_commit_builder_group().await;
1255
1256        let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1257
1258        let commit_output = group
1259            .commit_builder()
1260            .custom_proposal(proposal.clone())
1261            .build()
1262            .await
1263            .unwrap();
1264
1265        assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1266    }
1267
1268    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1269    async fn test_commit_builder_chaining() {
1270        let mut group = test_commit_builder_group().await;
1271        let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1272        let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1273
1274        let expected_adds = vec![
1275            group.add_proposal(kp1.clone()).unwrap(),
1276            group.add_proposal(kp2.clone()).unwrap(),
1277        ];
1278
1279        let commit_output = group
1280            .commit_builder()
1281            .add_member(kp1)
1282            .unwrap()
1283            .add_member(kp2)
1284            .unwrap()
1285            .build()
1286            .await
1287            .unwrap();
1288
1289        assert_commit_builder_output(group, commit_output, expected_adds, 2);
1290    }
1291
1292    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1293    async fn test_commit_builder_empty_commit() {
1294        let mut group = test_commit_builder_group().await;
1295
1296        let commit_output = group.commit_builder().build().await.unwrap();
1297
1298        assert_commit_builder_output(group, commit_output, vec![], 0);
1299    }
1300
1301    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1302    async fn test_commit_builder_authenticated_data() {
1303        let mut group = test_commit_builder_group().await;
1304        let test_data = "test".as_bytes().to_vec();
1305
1306        let commit_output = group
1307            .commit_builder()
1308            .authenticated_data(test_data.clone())
1309            .build()
1310            .await
1311            .unwrap();
1312
1313        assert_eq!(
1314            commit_output
1315                .commit_message
1316                .into_plaintext()
1317                .unwrap()
1318                .content
1319                .authenticated_data,
1320            test_data
1321        );
1322    }
1323
1324    #[cfg(feature = "by_ref_proposal")]
1325    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1326    async fn test_commit_builder_multiple_welcome_messages() {
1327        let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1328            let options = CommitOptions::new().with_single_welcome_message(false);
1329            b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1330        })
1331        .await;
1332
1333        let (alice, alice_kp) =
1334            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1335
1336        let (bob, bob_kp) =
1337            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1338
1339        group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
1340
1341        group.propose_add(bob_kp.clone(), vec![]).await.unwrap();
1342
1343        let output = group.commit(Vec::new()).await.unwrap();
1344        let welcomes = output.welcome_messages;
1345
1346        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1347
1348        for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1349            let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1350
1351            let welcome = welcomes
1352                .iter()
1353                .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1354                .unwrap();
1355
1356            client.join_group(None, welcome, None).await.unwrap();
1357
1358            assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1359        }
1360    }
1361
1362    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1363    async fn commit_can_change_credential() {
1364        let cs = TEST_CIPHER_SUITE;
1365        let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1366        let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1367
1368        let commit_output = groups[0]
1369            .commit_builder()
1370            .set_new_signing_identity(secret_key, identity.clone())
1371            .build()
1372            .await
1373            .unwrap();
1374
1375        // Check that the credential was updated by in the committer's state.
1376        groups[0].process_pending_commit().await.unwrap();
1377        let new_member = groups[0].roster().member_with_index(0).unwrap();
1378
1379        assert_eq!(
1380            new_member.signing_identity.credential,
1381            get_test_basic_credential(b"member".to_vec())
1382        );
1383
1384        assert_eq!(
1385            new_member.signing_identity.signature_key,
1386            identity.signature_key
1387        );
1388
1389        // Check that the credential was updated in another member's state.
1390        groups[1]
1391            .process_message(commit_output.commit_message)
1392            .await
1393            .unwrap();
1394
1395        let new_member = groups[1].roster().member_with_index(0).unwrap();
1396
1397        assert_eq!(
1398            new_member.signing_identity.credential,
1399            get_test_basic_credential(b"member".to_vec())
1400        );
1401
1402        assert_eq!(
1403            new_member.signing_identity.signature_key,
1404            identity.signature_key
1405        );
1406    }
1407
1408    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1409    async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1410        let mut group = test_group_custom(
1411            TEST_PROTOCOL_VERSION,
1412            TEST_CIPHER_SUITE,
1413            Default::default(),
1414            None,
1415            Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1416        )
1417        .await;
1418
1419        let commit = group.commit(vec![]).await.unwrap();
1420
1421        group.apply_pending_commit().await.unwrap();
1422
1423        let new_tree = group.export_tree();
1424
1425        assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1426    }
1427
1428    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1429    async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1430        let mut group = test_group_custom(
1431            TEST_PROTOCOL_VERSION,
1432            TEST_CIPHER_SUITE,
1433            Default::default(),
1434            None,
1435            Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1436        )
1437        .await;
1438
1439        let commit = group.commit(vec![]).await.unwrap();
1440
1441        assert!(commit.ratchet_tree.is_none());
1442    }
1443
1444    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1445    async fn commit_includes_external_commit_group_info_if_requested() {
1446        let mut group = test_group_custom(
1447            TEST_PROTOCOL_VERSION,
1448            TEST_CIPHER_SUITE,
1449            Default::default(),
1450            None,
1451            Some(
1452                CommitOptions::new()
1453                    .with_allow_external_commit(true)
1454                    .with_ratchet_tree_extension(false),
1455            ),
1456        )
1457        .await;
1458
1459        let commit = group.commit(vec![]).await.unwrap();
1460
1461        let info = commit
1462            .external_commit_group_info
1463            .unwrap()
1464            .into_group_info()
1465            .unwrap();
1466
1467        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1468        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1469    }
1470
1471    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1472    async fn commit_includes_external_commit_and_tree_if_requested() {
1473        let mut group = test_group_custom(
1474            TEST_PROTOCOL_VERSION,
1475            TEST_CIPHER_SUITE,
1476            Default::default(),
1477            None,
1478            Some(
1479                CommitOptions::new()
1480                    .with_allow_external_commit(true)
1481                    .with_ratchet_tree_extension(true),
1482            ),
1483        )
1484        .await;
1485
1486        let commit = group.commit(vec![]).await.unwrap();
1487
1488        let info = commit
1489            .external_commit_group_info
1490            .unwrap()
1491            .into_group_info()
1492            .unwrap();
1493
1494        assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1495        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1496    }
1497
1498    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1499    async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1500        let mut group = test_group_custom(
1501            TEST_PROTOCOL_VERSION,
1502            TEST_CIPHER_SUITE,
1503            Default::default(),
1504            None,
1505            Some(CommitOptions::new().with_allow_external_commit(false)),
1506        )
1507        .await;
1508
1509        let commit = group.commit(vec![]).await.unwrap();
1510
1511        assert!(commit.external_commit_group_info.is_none());
1512    }
1513
1514    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1515    async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_off(
1516    ) {
1517        let mut group = test_group_custom(
1518            TEST_PROTOCOL_VERSION,
1519            TEST_CIPHER_SUITE,
1520            Default::default(),
1521            None,
1522            Some(
1523                CommitOptions::new()
1524                    .with_always_out_of_band_ratchet_tree(true)
1525                    .with_ratchet_tree_extension(false)
1526                    .with_allow_external_commit(true),
1527            ),
1528        )
1529        .await;
1530
1531        let commit = group.commit(vec![]).await.unwrap();
1532
1533        assert!(commit.ratchet_tree.is_some());
1534
1535        let info = commit
1536            .external_commit_group_info
1537            .unwrap()
1538            .into_group_info()
1539            .unwrap();
1540
1541        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1542    }
1543
1544    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1545    async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_on(
1546    ) {
1547        let mut group = test_group_custom(
1548            TEST_PROTOCOL_VERSION,
1549            TEST_CIPHER_SUITE,
1550            Default::default(),
1551            None,
1552            Some(
1553                CommitOptions::new()
1554                    .with_always_out_of_band_ratchet_tree(true)
1555                    .with_ratchet_tree_extension(true)
1556                    .with_allow_external_commit(true),
1557            ),
1558        )
1559        .await;
1560
1561        let commit = group.commit(vec![]).await.unwrap();
1562
1563        assert!(commit.ratchet_tree.is_some());
1564
1565        let info = commit
1566            .external_commit_group_info
1567            .unwrap()
1568            .into_group_info()
1569            .unwrap();
1570
1571        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1572    }
1573
1574    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1575    async fn member_identity_is_validated_against_new_extensions() {
1576        let alice = client_with_test_extension(b"alice").await;
1577        let mut alice = alice
1578            .create_group(ExtensionList::new(), Default::default(), None)
1579            .await
1580            .unwrap();
1581
1582        let bob = client_with_test_extension(b"bob").await;
1583        let bob_kp = bob
1584            .generate_key_package_message(Default::default(), Default::default(), None)
1585            .await
1586            .unwrap();
1587
1588        let mut extension_list = ExtensionList::new();
1589        let extension = TestExtension { foo: b'a' };
1590        extension_list.set_from(extension).unwrap();
1591
1592        let res = alice
1593            .commit_builder()
1594            .add_member(bob_kp)
1595            .unwrap()
1596            .set_group_context_ext(extension_list.clone())
1597            .unwrap()
1598            .build()
1599            .await;
1600
1601        assert!(res.is_err());
1602
1603        let alex = client_with_test_extension(b"alex").await;
1604
1605        alice
1606            .commit_builder()
1607            .add_member(
1608                alex.generate_key_package_message(Default::default(), Default::default(), None)
1609                    .await
1610                    .unwrap(),
1611            )
1612            .unwrap()
1613            .set_group_context_ext(extension_list.clone())
1614            .unwrap()
1615            .build()
1616            .await
1617            .unwrap();
1618    }
1619
1620    #[cfg(feature = "by_ref_proposal")]
1621    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1622    async fn server_identity_is_validated_against_new_extensions() {
1623        let alice = client_with_test_extension(b"alice").await;
1624        let mut alice = alice
1625            .create_group(ExtensionList::new(), Default::default(), None)
1626            .await
1627            .unwrap();
1628
1629        let mut extension_list = ExtensionList::new();
1630        let extension = TestExtension { foo: b'a' };
1631        extension_list.set_from(extension).unwrap();
1632
1633        let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1634
1635        let mut alex_extensions = extension_list.clone();
1636
1637        alex_extensions
1638            .set_from(ExternalSendersExt {
1639                allowed_senders: vec![alex_server],
1640            })
1641            .unwrap();
1642
1643        let res = alice
1644            .commit_builder()
1645            .set_group_context_ext(alex_extensions)
1646            .unwrap()
1647            .build()
1648            .await;
1649
1650        assert!(res.is_err());
1651
1652        let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1653
1654        let mut bob_extensions = extension_list;
1655
1656        bob_extensions
1657            .set_from(ExternalSendersExt {
1658                allowed_senders: vec![bob_server],
1659            })
1660            .unwrap();
1661
1662        alice
1663            .commit_builder()
1664            .set_group_context_ext(bob_extensions)
1665            .unwrap()
1666            .build()
1667            .await
1668            .unwrap();
1669    }
1670
    /// Test identity provider that wraps a [`BasicIdentityProvider`] (used for `identity` and
    /// `supported_types`) and additionally checks identities against the `foo` byte of
    /// `TestExtension` when that extension is present.
    #[derive(Debug, Clone)]
    struct IdentityProviderWithExtension(BasicIdentityProvider);
1673
    /// Detail-free error returned by [`IdentityProviderWithExtension`]. It covers a failed
    /// member or external-sender check, and a failure of the wrapped [`BasicIdentityProvider`]
    /// to produce an identity.
    ///
    /// The `std::error::Error` implementation (via `thiserror`) exists only with the `std` feature.
    #[derive(Clone, Debug)]
    #[cfg_attr(feature = "std", derive(thiserror::Error))]
    #[cfg_attr(feature = "std", error("test error"))]
    struct IdentityProviderWithExtensionError {}
1678
1679    impl IntoAnyError for IdentityProviderWithExtensionError {
1680        #[cfg(feature = "std")]
1681        fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1682            Ok(self.into())
1683        }
1684    }
1685
1686    impl IdentityProviderWithExtension {
1687        // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
1688        // is not set.
1689        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1690        async fn starts_with_foo(
1691            &self,
1692            identity: &SigningIdentity,
1693            _timestamp: Option<MlsTime>,
1694            extensions: Option<&ExtensionList>,
1695        ) -> bool {
1696            if let Some(extensions) = extensions {
1697                if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1698                    self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1699                } else {
1700                    true
1701                }
1702            } else {
1703                true
1704            }
1705        }
1706    }
1707
1708    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1709    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1710    impl IdentityProvider for IdentityProviderWithExtension {
1711        type Error = IdentityProviderWithExtensionError;
1712
1713        async fn validate_member(
1714            &self,
1715            identity: &SigningIdentity,
1716            timestamp: Option<MlsTime>,
1717            context: MemberValidationContext<'_>,
1718        ) -> Result<(), Self::Error> {
1719            self.starts_with_foo(identity, timestamp, context.new_extensions())
1720                .await
1721                .then_some(())
1722                .ok_or(IdentityProviderWithExtensionError {})
1723        }
1724
1725        async fn validate_external_sender(
1726            &self,
1727            identity: &SigningIdentity,
1728            timestamp: Option<MlsTime>,
1729            extensions: Option<&ExtensionList>,
1730        ) -> Result<(), Self::Error> {
1731            (!self.starts_with_foo(identity, timestamp, extensions).await)
1732                .then_some(())
1733                .ok_or(IdentityProviderWithExtensionError {})
1734        }
1735
1736        async fn identity(
1737            &self,
1738            signing_identity: &SigningIdentity,
1739            extensions: &ExtensionList,
1740        ) -> Result<Vec<u8>, Self::Error> {
1741            self.0
1742                .identity(signing_identity, extensions)
1743                .await
1744                .map_err(|_| IdentityProviderWithExtensionError {})
1745        }
1746
1747        async fn valid_successor(
1748            &self,
1749            _predecessor: &SigningIdentity,
1750            _successor: &SigningIdentity,
1751            _extensions: &ExtensionList,
1752        ) -> Result<bool, Self::Error> {
1753            Ok(true)
1754        }
1755
1756        fn supported_types(&self) -> Vec<CredentialType> {
1757            self.0.supported_types()
1758        }
1759    }
1760
    /// Client configuration returned by `client_with_test_extension`: `BaseConfig` with the
    /// test crypto provider layered on, plus [`IdentityProviderWithExtension`] as the identity
    /// provider.
    type ExtensionClientConfig = WithIdentityProvider<
        IdentityProviderWithExtension,
        WithCryptoProvider<TestCryptoProvider, BaseConfig>,
    >;
1765
1766    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1767    async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1768        let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1769
1770        ClientBuilder::new()
1771            .crypto_provider(TestCryptoProvider::new())
1772            .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1773            .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1774            .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1775            .build()
1776    }
1777
1778    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1779    async fn detached_commit() {
1780        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1781
1782        let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1783        assert!(group.pending_commit.is_none());
1784        group.apply_detached_commit(secrets).await.unwrap();
1785        assert_eq!(group.context().epoch, 1);
1786    }
1787}