Skip to main content

mls_rs/group/
commit.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13    cipher_suite::CipherSuite,
14    client::MlsError,
15    client_config::ClientConfig,
16    extension::RatchetTreeExt,
17    identity::SigningIdentity,
18    protocol_version::ProtocolVersion,
19    signer::Signable,
20    time::MlsTime,
21    tree_kem::{kem::TreeKem, path_secret::PathSecret, TreeKemPrivate, UpdatePath},
22    ExtensionList, MlsRules,
23};
24
25#[cfg(all(not(mls_build_async), feature = "rayon"))]
26use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
27
28use crate::tree_kem::leaf_node::LeafNode;
29
30#[cfg(not(feature = "private_message"))]
31use crate::WireFormat;
32
33#[cfg(feature = "psk")]
34use crate::{
35    group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
36    psk::ExternalPskId,
37};
38
39use super::{
40    confirmation_tag::ConfirmationTag,
41    framing::{Content, MlsMessage, MlsMessagePayload, Sender},
42    key_schedule::{KeySchedule, WelcomeSecret},
43    message_hash::MessageHash,
44    message_processor::{path_update_required, MessageProcessor},
45    message_signature::AuthenticatedContent,
46    mls_rules::CommitDirection,
47    proposal::{Proposal, ProposalOrRef},
48    CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
49    Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
50    PendingCommitSnapshot, Welcome,
51};
52
53#[cfg(not(feature = "by_ref_proposal"))]
54use super::proposal_cache::prepare_commit;
55
56#[cfg(feature = "custom_proposal")]
57use super::proposal::CustomProposal;
58
59#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
60#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
61#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
62pub(crate) struct Commit {
63    pub proposals: Vec<ProposalOrRef>,
64    pub path: Option<UpdatePath>,
65}
66
67#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
68pub(crate) struct PendingCommit {
69    pub(crate) state: GroupState,
70    pub(crate) epoch_secrets: EpochSecrets,
71    pub(crate) private_tree: TreeKemPrivate,
72    pub(crate) key_schedule: KeySchedule,
73    pub(crate) signer: SignatureSecretKey,
74
75    pub(crate) output: CommitMessageDescription,
76
77    pub(crate) commit_message_hash: MessageHash,
78}
79
80#[derive(Clone)]
81pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
82
83impl CommitSecrets {
84    /// Deserialize the commit secrets from bytes
85    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
86        Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
87    }
88
89    /// Serialize the commit secrets to bytes
90    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
91        Ok(self.0.mls_encode_to_vec()?)
92    }
93}
94
95#[derive(Clone, Debug)]
96#[non_exhaustive]
97/// Result of MLS commit operation using
98/// [`Group::commit`](crate::group::Group::commit) or
99/// [`CommitBuilder::build`](CommitBuilder::build).
100pub struct CommitOutput {
101    /// Commit message to send to other group members.
102    pub commit_message: MlsMessage,
103    /// Welcome messages to send to new group members. If the commit does not add members,
104    /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
105    /// set to true, then this list contains a single message sent to all members. Else, the list
106    /// contains one message for each added member. Recipients of each message can be identified using
107    /// [`MlsMessage::key_package_reference`] of their key packages and
108    /// [`MlsMessage::welcome_key_package_references`].
109    pub welcome_messages: Vec<MlsMessage>,
110    /// Ratchet tree that can be sent out of band if
111    /// `ratchet_tree_extension` is not used according to
112    /// [`MlsRules::commit_options`].
113    pub ratchet_tree: Option<ExportedTree<'static>>,
114    /// A group info that can be provided to new members in order to enable external commit
115    /// functionality. This value is set if [`MlsRules::commit_options`] returns
116    /// `allow_external_commit` set to true.
117    pub external_commit_group_info: Option<MlsMessage>,
118    /// Proposals that were received in the prior epoch but not included in the following commit.
119    #[cfg(feature = "by_ref_proposal")]
120    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
121    /// Indicator that the commit contains a path update
122    pub contains_update_path: bool,
123}
124
125impl CommitOutput {
126    /// Commit message to send to other group members.
127    pub fn commit_message(&self) -> &MlsMessage {
128        &self.commit_message
129    }
130
131    /// Welcome message to send to new group members.
132    pub fn welcome_messages(&self) -> &[MlsMessage] {
133        &self.welcome_messages
134    }
135
136    /// Ratchet tree that can be sent out of band if
137    /// `ratchet_tree_extension` is not used according to
138    /// [`MlsRules::commit_options`].
139    pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
140        self.ratchet_tree.as_ref()
141    }
142
143    /// A group info that can be provided to new members in order to enable external commit
144    /// functionality. This value is set if [`MlsRules::commit_options`] returns
145    /// `allow_external_commit` set to true.
146    pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
147        self.external_commit_group_info.as_ref()
148    }
149
150    /// Proposals that were received in the prior epoch but not included in the following commit.
151    #[cfg(feature = "by_ref_proposal")]
152    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
153        &self.unused_proposals
154    }
155}
156
157/// Build a commit with multiple proposals by-value.
158///
159/// Proposals within a commit can be by-value or by-reference.
160/// Proposals received during the current epoch will be added to the resulting
161/// commit by-reference automatically so long as they pass the rules defined
162/// in the current
163/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
164pub struct CommitBuilder<'a, C>
165where
166    C: ClientConfig + Clone,
167{
168    group: &'a mut Group<C>,
169    pub(super) proposals: Vec<Proposal>,
170    authenticated_data: Vec<u8>,
171    group_info_extensions: ExtensionList,
172    new_signer: Option<SignatureSecretKey>,
173    new_signing_identity: Option<SigningIdentity>,
174    new_leaf_node_extensions: Option<ExtensionList>,
175    commit_time: Option<MlsTime>,
176}
177
178impl<'a, C> CommitBuilder<'a, C>
179where
180    C: ClientConfig + Clone,
181{
182    /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
183    /// the current commit that is being built.
184    pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
185        let proposal = self.group.add_proposal(key_package)?;
186        self.proposals.push(proposal);
187        Ok(self)
188    }
189
190    /// Set group info extensions that will be inserted into the resulting
191    /// [welcome messages](CommitOutput::welcome_messages) for new members.
192    ///
193    /// Group info extensions that are transmitted as part of a welcome message
194    /// are encrypted along with other private values.
195    ///
196    /// These extensions can be retrieved as part of
197    /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
198    /// by joining the group via
199    /// [`Client::join_group`](crate::Client::join_group).
200    pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
201        Self {
202            group_info_extensions: extensions,
203            ..self
204        }
205    }
206
207    /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
208    /// the current commit that is being built.
209    pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
210        let proposal = self.group.remove_proposal(index)?;
211        self.proposals.push(proposal);
212        Ok(self)
213    }
214
215    /// Insert a
216    /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
217    /// into the current commit that is being built.
218    pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
219        let proposal = self.group.group_context_extensions_proposal(extensions);
220        self.proposals.push(proposal);
221        Ok(self)
222    }
223
224    /// Insert a
225    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
226    /// an external PSK into the current commit that is being built.
227    #[cfg(feature = "psk")]
228    pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
229        let key_id = JustPreSharedKeyID::External(psk_id);
230        let proposal = self.group.psk_proposal(key_id)?;
231        self.proposals.push(proposal);
232        Ok(self)
233    }
234
235    /// Insert a
236    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
237    /// a resumption PSK into the current commit that is being built.
238    #[cfg(feature = "psk")]
239    pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
240        let psk_id = ResumptionPsk {
241            psk_epoch,
242            usage: ResumptionPSKUsage::Application,
243            psk_group_id: PskGroupId(self.group.group_id().to_vec()),
244        };
245
246        let key_id = JustPreSharedKeyID::Resumption(psk_id);
247        let proposal = self.group.psk_proposal(key_id)?;
248        self.proposals.push(proposal);
249        Ok(self)
250    }
251
252    /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
253    /// the current commit that is being built.
254    pub fn reinit(
255        mut self,
256        group_id: Option<Vec<u8>>,
257        version: ProtocolVersion,
258        cipher_suite: CipherSuite,
259        extensions: ExtensionList,
260    ) -> Result<Self, MlsError> {
261        let proposal = self
262            .group
263            .reinit_proposal(group_id, version, cipher_suite, extensions)?;
264
265        self.proposals.push(proposal);
266        Ok(self)
267    }
268
269    /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
270    /// the current commit that is being built.
271    #[cfg(feature = "custom_proposal")]
272    pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
273        self.proposals.push(Proposal::Custom(proposal));
274        self
275    }
276
277    /// Insert a proposal that was previously constructed such as when a
278    /// proposal is returned from
279    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
280    pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
281        self.proposals.push(proposal);
282        self
283    }
284
285    /// Insert proposals that were previously constructed such as when a
286    /// proposal is returned from
287    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
288    pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
289        self.proposals.append(&mut proposals);
290        self
291    }
292
293    /// Add additional authenticated data to the commit.
294    ///
295    /// # Warning
296    ///
297    /// The data provided here is always sent unencrypted.
298    pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
299        Self {
300            authenticated_data,
301            ..self
302        }
303    }
304
305    /// Change the committer's signing identity as part of making this commit.
306    /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
307    /// in use by the group considers the credential inside this signing_identity
308    /// [valid](crate::IdentityProvider::validate_member)
309    /// and results in the same
310    /// [identity](crate::IdentityProvider::identity)
311    /// being used.
312    pub fn set_new_signing_identity(
313        self,
314        signer: SignatureSecretKey,
315        signing_identity: SigningIdentity,
316    ) -> Self {
317        Self {
318            new_signer: Some(signer),
319            new_signing_identity: Some(signing_identity),
320            ..self
321        }
322    }
323
324    /// Change the committer's leaf node extensions as part of making this commit.
325    pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
326        Self {
327            new_leaf_node_extensions: Some(new_leaf_node_extensions),
328            ..self
329        }
330    }
331
332    /// Add a time to associate with the commit creation.
333    pub fn commit_time(self, commit_time: MlsTime) -> Self {
334        Self {
335            commit_time: Some(commit_time),
336            ..self
337        }
338    }
339
340    /// Finalize the commit to send.
341    ///
342    /// # Errors
343    ///
344    /// This function will return an error if any of the proposals provided
345    /// are not contextually valid according to the rules defined by the
346    /// MLS RFC, or if they do not pass the custom rules defined by the current
347    /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
348    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
349    pub async fn build(self) -> Result<CommitOutput, MlsError> {
350        let (output, pending_commit) = self
351            .group
352            .commit_internal(
353                self.proposals,
354                None,
355                self.authenticated_data,
356                self.group_info_extensions,
357                self.new_signer,
358                self.new_signing_identity,
359                self.new_leaf_node_extensions,
360                self.commit_time,
361            )
362            .await?;
363
364        self.group.pending_commit = pending_commit.try_into()?;
365
366        Ok(output)
367    }
368
369    /// The same function as `GroupBuilder::build` except the secrets generated
370    /// for the commit are outputted instead of being cached internally.
371    ///
372    /// A detached commit can be applied using `Group::apply_detached_commit`.
373    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
374    pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
375        let (output, pending_commit) = self
376            .group
377            .commit_internal(
378                self.proposals,
379                None,
380                self.authenticated_data,
381                self.group_info_extensions,
382                self.new_signer,
383                self.new_signing_identity,
384                self.new_leaf_node_extensions,
385                self.commit_time,
386            )
387            .await?;
388
389        Ok((
390            output,
391            CommitSecrets(PendingCommitSnapshot::PendingCommit(
392                pending_commit.mls_encode_to_vec()?,
393            )),
394        ))
395    }
396}
397
398impl<C> Group<C>
399where
400    C: ClientConfig + Clone,
401{
402    /// Perform a commit of received proposals.
403    ///
404    /// This function is the equivalent of [`Group::commit_builder`] immediately
405    /// followed by [`CommitBuilder::build`]. Any received proposals since the
406    /// last commit will be included in the resulting message by-reference.
407    ///
408    /// Data provided in the `authenticated_data` field will be placed into
409    /// the resulting commit message unencrypted.
410    ///
411    /// # Pending Commits
412    ///
413    /// When a commit is created, it is not applied immediately in order to
414    /// allow for the resolution of conflicts when multiple members of a group
415    /// attempt to make commits at the same time. For example, a central relay
416    /// can be used to decide which commit should be accepted by the group by
417    /// determining a consistent view of commit packet order for all clients.
418    ///
419    /// Pending commits are stored internally as part of the group's state
420    /// so they do not need to be tracked outside of this library. Any commit
421    /// message that is processed before calling [Group::apply_pending_commit]
422    /// will clear the currently pending commit.
423    ///
424    /// # Empty Commits
425    ///
426    /// Sending a commit that contains no proposals is a valid operation
427    /// within the MLS protocol. It is useful for providing stronger forward
428    /// secrecy and post-compromise security, especially for long running
429    /// groups when group membership does not change often.
430    ///
431    /// # Path Updates
432    ///
433    /// Path updates provide forward secrecy and post-compromise security
434    /// within the MLS protocol.
435    /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
436    /// controls the ability of a group to send a commit without a path update.
437    /// An update path will automatically be sent if there are no proposals
438    /// in the commit, or if any proposal other than
439    /// [`Add`](crate::group::proposal::Proposal::Add),
440    /// [`Psk`](crate::group::proposal::Proposal::Psk),
441    /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
442    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
443    pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
444        self.commit_builder()
445            .authenticated_data(authenticated_data)
446            .build()
447            .await
448    }
449
450    /// The same function as `Group::commit` except the secrets generated
451    /// for the commit are outputted instead of being cached internally.
452    ///
453    /// A detached commit can be applied using `Group::apply_detached_commit`.
454    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
455    pub async fn commit_detached(
456        &mut self,
457        authenticated_data: Vec<u8>,
458    ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
459        self.commit_builder()
460            .authenticated_data(authenticated_data)
461            .build_detached()
462            .await
463    }
464
465    /// Create a new commit builder that can include proposals
466    /// by-value.
467    pub fn commit_builder(&mut self) -> CommitBuilder<'_, C> {
468        CommitBuilder {
469            group: self,
470            proposals: Default::default(),
471            authenticated_data: Default::default(),
472            group_info_extensions: Default::default(),
473            new_signer: Default::default(),
474            new_signing_identity: Default::default(),
475            new_leaf_node_extensions: Default::default(),
476            commit_time: None,
477        }
478    }
479
480    /// Returns commit and optional [`MlsMessage`] containing a welcome message
481    /// for newly added members.
482    #[allow(clippy::too_many_arguments)]
483    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
484    pub(super) async fn commit_internal(
485        &mut self,
486        proposals: Vec<Proposal>,
487        external_leaf: Option<&LeafNode>,
488        authenticated_data: Vec<u8>,
489        mut welcome_group_info_extensions: ExtensionList,
490        new_signer: Option<SignatureSecretKey>,
491        new_signing_identity: Option<SigningIdentity>,
492        new_leaf_node_extensions: Option<ExtensionList>,
493        commit_time: Option<MlsTime>,
494    ) -> Result<(CommitOutput, PendingCommit), MlsError> {
495        if !self.pending_commit.is_none() {
496            return Err(MlsError::ExistingPendingCommit);
497        }
498
499        if self.state.pending_reinit.is_some() {
500            return Err(MlsError::GroupUsedAfterReInit);
501        }
502
503        let mls_rules = self.config.mls_rules();
504
505        let is_external = external_leaf.is_some();
506
507        // Construct an initial Commit object with the proposals field populated from Proposals
508        // received during the current epoch, and an empty path field. Add passed in proposals
509        // by value
510        let sender = if is_external {
511            Sender::NewMemberCommit
512        } else {
513            Sender::Member(*self.private_tree.self_index)
514        };
515
516        let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
517        let old_signer = &self.signer;
518
519        #[cfg(feature = "std")]
520        let time = Some(crate::time::MlsTime::now());
521
522        #[cfg(not(feature = "std"))]
523        let time = None;
524
525        let time = if commit_time.is_some() {
526            commit_time
527        } else {
528            time
529        };
530
531        #[cfg(feature = "by_ref_proposal")]
532        let proposals = self.state.proposals.prepare_commit(sender, proposals);
533
534        #[cfg(not(feature = "by_ref_proposal"))]
535        let proposals = prepare_commit(sender, proposals);
536
537        let mut provisional_state = self
538            .state
539            .apply_resolved(
540                sender,
541                proposals,
542                external_leaf,
543                &self.config.identity_provider(),
544                &self.cipher_suite_provider,
545                &self.config.secret_store(),
546                &mls_rules,
547                time,
548                CommitDirection::Send,
549            )
550            .await?;
551
552        let (mut provisional_private_tree, _) =
553            self.provisional_private_tree(&provisional_state)?;
554
555        if is_external {
556            provisional_private_tree.self_index = provisional_state
557                .external_init_index
558                .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
559
560            self.private_tree.self_index = provisional_private_tree.self_index;
561        }
562
563        // Decide whether to populate the path field: If the path field is required based on the
564        // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
565        // sender MAY omit the path field at its discretion.
566        let commit_options = mls_rules
567            .commit_options(
568                &provisional_state.public_tree.roster(),
569                &provisional_state.group_context,
570                &provisional_state.applied_proposals,
571            )
572            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
573
574        let perform_path_update = commit_options.path_required
575            || path_update_required(&provisional_state.applied_proposals);
576
577        let (update_path, path_secrets, commit_secret) = if perform_path_update {
578            // If populating the path field: Create an UpdatePath using the new tree. Any new
579            // member (from an add proposal) MUST be excluded from the resolution during the
580            // computation of the UpdatePath. The GroupContext for this operation uses the
581            // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
582            // GroupContext object. The leaf_key_package for this UpdatePath must have a
583            // parent_hash extension.
584
585            let new_leaf_node_extensions =
586                new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
587
588            let new_leaf_node_extensions = match new_leaf_node_extensions {
589                Some(extensions) => extensions,
590                // If we are not setting new extensions and this is not an external leaf then the current node MUST exist.
591                None => self.current_user_leaf_node()?.ungreased_extensions(),
592            };
593
594            #[cfg(feature = "tree_index")]
595            let old_committer_leaf = provisional_state
596                .public_tree
597                .get_leaf_node(provisional_private_tree.self_index)?
598                .clone();
599
600            let encap_gen = TreeKem::new(
601                &mut provisional_state.public_tree,
602                &mut provisional_private_tree,
603            )
604            .encap(
605                &mut provisional_state.group_context,
606                &provisional_state.indexes_of_added_kpkgs,
607                &new_signer,
608                Some(self.config.leaf_properties(new_leaf_node_extensions)),
609                new_signing_identity,
610                &self.cipher_suite_provider,
611                #[cfg(test)]
612                &self.commit_modifiers,
613            )
614            .await?;
615
616            provisional_state
617                .public_tree
618                .update_committer_leaf(
619                    &self.config.identity_provider(),
620                    &provisional_state.group_context.extensions,
621                    provisional_private_tree.self_index,
622                    #[cfg(feature = "tree_index")]
623                    &old_committer_leaf,
624                    #[cfg(test)]
625                    !self.commit_modifiers.skip_committer_self_update_validation,
626                )
627                .await?;
628
629            (
630                Some(encap_gen.update_path),
631                Some(encap_gen.path_secrets),
632                encap_gen.commit_secret,
633            )
634        } else {
635            // Update the tree hash, since it was not updated by encap.
636            provisional_state
637                .public_tree
638                .update_hashes(
639                    &[provisional_private_tree.self_index],
640                    &self.cipher_suite_provider,
641                )
642                .await?;
643
644            provisional_state.group_context.tree_hash = provisional_state
645                .public_tree
646                .tree_hash(&self.cipher_suite_provider)
647                .await?;
648
649            (None, None, PathSecret::empty(&self.cipher_suite_provider))
650        };
651
652        #[cfg(feature = "psk")]
653        let (psk_secret, psks) = self
654            .get_psk(&provisional_state.applied_proposals.psks)
655            .await?;
656
657        #[cfg(not(feature = "psk"))]
658        let psk_secret = self.get_psk();
659
660        let added_key_pkgs: Vec<_> = provisional_state
661            .applied_proposals
662            .additions
663            .iter()
664            .map(|info| info.proposal.key_package.clone())
665            .collect();
666
667        let commit = Commit {
668            proposals: provisional_state.applied_proposals.proposals_or_refs(),
669            path: update_path,
670        };
671
672        let mut auth_content = AuthenticatedContent::new_signed(
673            &self.cipher_suite_provider,
674            self.context(),
675            sender,
676            Content::Commit(Box::new(commit)),
677            old_signer,
678            #[cfg(feature = "private_message")]
679            self.encryption_options()?.control_wire_format(sender),
680            #[cfg(not(feature = "private_message"))]
681            WireFormat::PublicMessage,
682            authenticated_data,
683        )
684        .await?;
685
686        // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
687        // compute the confirmation_tag value in the MlsPlaintext.
688        let confirmed_transcript_hash = super::transcript_hash::create(
689            self.cipher_suite_provider(),
690            &self.state.interim_transcript_hash,
691            &auth_content,
692        )
693        .await?;
694
695        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
696
697        let key_schedule_result = KeySchedule::from_key_schedule(
698            &self.key_schedule,
699            &commit_secret,
700            &provisional_state.group_context,
701            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
702            provisional_state.public_tree.total_leaf_count(),
703            &psk_secret,
704            &self.cipher_suite_provider,
705        )
706        .await?;
707
708        let confirmation_tag = ConfirmationTag::create(
709            &key_schedule_result.confirmation_key,
710            &provisional_state.group_context.confirmed_transcript_hash,
711            &self.cipher_suite_provider,
712        )
713        .await?;
714
715        let interim_transcript_hash = InterimTranscriptHash::create(
716            self.cipher_suite_provider(),
717            &provisional_state.group_context.confirmed_transcript_hash,
718            &confirmation_tag,
719        )
720        .await?;
721
722        auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
723
724        let ratchet_tree_ext = commit_options
725            .ratchet_tree_extension
726            .then(|| RatchetTreeExt {
727                tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
728            });
729
730        // Generate external commit group info if required by commit_options
731        let external_commit_group_info = match commit_options.allow_external_commit {
732            true => {
733                let mut extensions = ExtensionList::new();
734
735                extensions.set_from({
736                    key_schedule_result
737                        .key_schedule
738                        .get_external_key_pair_ext(&self.cipher_suite_provider)
739                        .await?
740                })?;
741
742                if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
743                    if !commit_options.always_out_of_band_ratchet_tree {
744                        extensions.set_from(ratchet_tree_ext.clone())?;
745                    }
746                }
747
748                let info = self
749                    .make_group_info(
750                        &provisional_state.group_context,
751                        extensions,
752                        &confirmation_tag,
753                        &new_signer,
754                    )
755                    .await?;
756
757                let msg =
758                    MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
759
760                Some(msg)
761            }
762            false => None,
763        };
764
765        // Build the group info that will be placed into the welcome messages.
766        // Add the ratchet tree extension if necessary
767        if let Some(ratchet_tree_ext) = ratchet_tree_ext {
768            welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
769        }
770
771        let welcome_group_info = self
772            .make_group_info(
773                &provisional_state.group_context,
774                welcome_group_info_extensions,
775                &confirmation_tag,
776                &new_signer,
777            )
778            .await?;
779
780        // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
781        // the new epoch
782        let welcome_secret = WelcomeSecret::from_joiner_secret(
783            &self.cipher_suite_provider,
784            &key_schedule_result.joiner_secret,
785            &psk_secret,
786        )
787        .await?;
788
789        let encrypted_group_info = welcome_secret
790            .encrypt(&welcome_group_info.mls_encode_to_vec()?)
791            .await?;
792
793        // Encrypt path secrets and joiner secret to new members
794        let path_secrets = path_secrets.as_ref();
795
796        #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
797        let encrypted_path_secrets: Vec<_> = added_key_pkgs
798            .into_par_iter()
799            .zip(&provisional_state.indexes_of_added_kpkgs)
800            .map(|(key_package, leaf_index)| {
801                self.encrypt_group_secrets(
802                    &key_package,
803                    *leaf_index,
804                    &key_schedule_result.joiner_secret,
805                    path_secrets,
806                    #[cfg(feature = "psk")]
807                    psks.clone(),
808                    &encrypted_group_info,
809                )
810            })
811            .try_collect()?;
812
813        #[cfg(any(mls_build_async, not(feature = "rayon")))]
814        let encrypted_path_secrets = {
815            let mut secrets = Vec::new();
816
817            for (key_package, leaf_index) in added_key_pkgs
818                .into_iter()
819                .zip(&provisional_state.indexes_of_added_kpkgs)
820            {
821                secrets.push(
822                    self.encrypt_group_secrets(
823                        &key_package,
824                        *leaf_index,
825                        &key_schedule_result.joiner_secret,
826                        path_secrets,
827                        #[cfg(feature = "psk")]
828                        psks.clone(),
829                        &encrypted_group_info,
830                    )
831                    .await?,
832                );
833            }
834
835            secrets
836        };
837
838        let welcome_messages =
839            if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
840                vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
841            } else {
842                encrypted_path_secrets
843                    .into_iter()
844                    .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
845                    .collect()
846            };
847
848        let commit_message = self.format_for_wire(auth_content.clone()).await?;
849
850        // TODO is it necessary to clone the tree here? or can we just output serialized bytes?
851        let ratchet_tree = (!commit_options.ratchet_tree_extension
852            || commit_options.always_out_of_band_ratchet_tree)
853            .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
854
855        let pending_reinit = provisional_state
856            .applied_proposals
857            .reinitializations
858            .first();
859
860        let pending_commit = PendingCommit {
861            output: CommitMessageDescription {
862                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
863                authenticated_data: auth_content.content.authenticated_data,
864                committer: *provisional_private_tree.self_index,
865                effect: match pending_reinit {
866                    Some(r) => CommitEffect::ReInit(r.clone()),
867                    None => CommitEffect::NewEpoch(
868                        NewEpoch::new(self.state.clone(), &provisional_state).into(),
869                    ),
870                },
871            },
872
873            state: GroupState {
874                #[cfg(feature = "by_ref_proposal")]
875                proposals: crate::group::ProposalCache::new(
876                    self.protocol_version(),
877                    self.group_id().to_vec(),
878                ),
879                context: provisional_state.group_context,
880                public_tree: provisional_state.public_tree,
881                interim_transcript_hash,
882                pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
883                confirmation_tag,
884            },
885
886            commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
887                .await?,
888            signer: new_signer,
889            epoch_secrets: key_schedule_result.epoch_secrets,
890            key_schedule: key_schedule_result.key_schedule,
891
892            private_tree: provisional_private_tree,
893        };
894
895        let output = CommitOutput {
896            commit_message,
897            welcome_messages,
898            ratchet_tree,
899            external_commit_group_info,
900            contains_update_path: perform_path_update,
901            #[cfg(feature = "by_ref_proposal")]
902            unused_proposals: provisional_state.unused_proposals,
903        };
904
905        Ok((output, pending_commit))
906    }
907
908    // Construct a GroupInfo reflecting the new state
909    // Group ID, epoch, tree, and confirmed transcript hash from the new state
910    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
911    async fn make_group_info(
912        &self,
913        group_context: &GroupContext,
914        extensions: ExtensionList,
915        confirmation_tag: &ConfirmationTag,
916        signer: &SignatureSecretKey,
917    ) -> Result<GroupInfo, MlsError> {
918        let mut group_info = GroupInfo {
919            group_context: group_context.clone(),
920            extensions,
921            confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
922            signer: self.current_member_leaf_index(),
923            signature: vec![],
924        };
925
926        group_info.grease(self.cipher_suite_provider())?;
927
928        // Sign the GroupInfo using the member's private signing key
929        group_info
930            .sign(&self.cipher_suite_provider, signer, &())
931            .await?;
932
933        Ok(group_info)
934    }
935
936    fn make_welcome_message(
937        &self,
938        secrets: Vec<EncryptedGroupSecrets>,
939        encrypted_group_info: Vec<u8>,
940    ) -> MlsMessage {
941        MlsMessage::new(
942            self.context().protocol_version,
943            MlsMessagePayload::Welcome(Welcome {
944                cipher_suite: self.context().cipher_suite,
945                secrets,
946                encrypted_group_info,
947            }),
948        )
949    }
950}
951
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Test-only bundle of hooks and flags for altering a commit.
    ///
    /// NOTE(review): where each hook is invoked is not visible in this module;
    /// confirm against the commit construction code.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// May mutate the leaf node; optionally returns a replacement signing key.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// May mutate the public tree.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Maps the update path nodes to a (possibly altered) list.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
        /// Flag defaulting to `false`; presumably disables validation of the
        /// committer's own update — verify at the use site.
        pub skip_committer_self_update_validation: bool,
    }

    impl Default for CommitModifiers {
        /// Returns hooks that change nothing and leaves the flag cleared.
        fn default() -> Self {
            Self {
                modify_leaf: |_leaf, _signer| None,
                modify_tree: |_tree| {},
                modify_path: |path| path,
                skip_committer_self_update_validation: false,
            }
        }
    }
}
980
981#[cfg(test)]
982mod tests {
983    use mls_rs_core::{
984        error::IntoAnyError,
985        extension::ExtensionType,
986        identity::{CredentialType, IdentityProvider, MemberValidationContext},
987        time::MlsTime,
988    };
989
990    use crate::extension::RequiredCapabilitiesExt;
991    use crate::{
992        client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
993        client_builder::{
994            test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
995            WithIdentityProvider,
996        },
997        client_config::ClientConfig,
998        crypto::test_utils::TestCryptoProvider,
999        extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
1000        group::test_utils::{test_group, test_group_custom},
1001        group::{
1002            proposal::ProposalType,
1003            test_utils::{test_group_custom_config, test_n_member_group},
1004        },
1005        identity::test_utils::get_test_signing_identity,
1006        identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
1007        key_package::test_utils::test_key_package_message,
1008        mls_rules::CommitOptions,
1009        Client,
1010    };
1011
1012    #[cfg(feature = "by_ref_proposal")]
1013    use crate::crypto::test_utils::test_cipher_suite_provider;
1014    #[cfg(feature = "by_ref_proposal")]
1015    use crate::extension::ExternalSendersExt;
1016    #[cfg(feature = "by_ref_proposal")]
1017    use crate::group::mls_rules::DefaultMlsRules;
1018
1019    #[cfg(feature = "psk")]
1020    use crate::{
1021        group::proposal::PreSharedKeyProposal,
1022        psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
1023    };
1024
1025    use super::*;
1026
1027    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1028    async fn test_commit_builder_group() -> Group<TestClientConfig> {
1029        test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1030            b.custom_proposal_type(ProposalType::from(42))
1031                .extension_type(TEST_EXTENSION_TYPE.into())
1032        })
1033        .await
1034        .group
1035    }
1036
1037    fn assert_commit_builder_output<C: ClientConfig>(
1038        group: Group<C>,
1039        mut commit_output: CommitOutput,
1040        expected: Vec<Proposal>,
1041        welcome_count: usize,
1042    ) {
1043        let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1044
1045        let commit_data = match plaintext.content.content {
1046            Content::Commit(commit) => commit,
1047            #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1048            _ => panic!("Found non-commit data"),
1049        };
1050
1051        assert_eq!(commit_data.proposals.len(), expected.len());
1052
1053        commit_data.proposals.into_iter().for_each(|proposal| {
1054            let proposal = match proposal {
1055                ProposalOrRef::Proposal(p) => p,
1056                #[cfg(feature = "by_ref_proposal")]
1057                ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1058            };
1059
1060            #[cfg(feature = "psk")]
1061            if let Some(psk_id) = match proposal.as_ref() {
1062                Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1063                _ => None,
1064            } {
1065                let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1066
1067                assert!(found)
1068            } else {
1069                assert!(expected.contains(&proposal));
1070            }
1071
1072            #[cfg(not(feature = "psk"))]
1073            assert!(expected.contains(&proposal));
1074        });
1075
1076        if welcome_count > 0 {
1077            let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1078
1079            assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1080
1081            let welcome_msg = welcome_msg.into_welcome().unwrap();
1082
1083            assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1084            assert_eq!(welcome_msg.secrets.len(), welcome_count);
1085        } else {
1086            assert!(commit_output.welcome_messages.is_empty());
1087        }
1088    }
1089
1090    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1091    async fn test_commit_builder_add() {
1092        let mut group = test_commit_builder_group().await;
1093
1094        let test_key_package =
1095            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1096
1097        let commit_output = group
1098            .commit_builder()
1099            .add_member(test_key_package.clone())
1100            .unwrap()
1101            .build()
1102            .await
1103            .unwrap();
1104
1105        let expected_add = group.add_proposal(test_key_package).unwrap();
1106
1107        assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1108    }
1109
1110    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1111    async fn test_commit_builder_add_with_ext() {
1112        let mut group = test_commit_builder_group().await;
1113
1114        let (bob_client, bob_key_package) =
1115            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1116
1117        let ext = TestExtension { foo: 42 };
1118        let mut extension_list = ExtensionList::default();
1119        extension_list.set_from(ext.clone()).unwrap();
1120
1121        let welcome_message = group
1122            .commit_builder()
1123            .add_member(bob_key_package)
1124            .unwrap()
1125            .set_group_info_ext(extension_list)
1126            .build()
1127            .await
1128            .unwrap()
1129            .welcome_messages
1130            .remove(0);
1131
1132        let (_, context) = bob_client
1133            .join_group(None, &welcome_message, None)
1134            .await
1135            .unwrap();
1136
1137        assert_eq!(
1138            context
1139                .group_info_extensions
1140                .get_as::<TestExtension>()
1141                .unwrap()
1142                .unwrap(),
1143            ext
1144        );
1145    }
1146
1147    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1148    async fn test_commit_builder_remove() {
1149        let mut group = test_commit_builder_group().await;
1150        let test_key_package =
1151            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1152
1153        group
1154            .commit_builder()
1155            .add_member(test_key_package)
1156            .unwrap()
1157            .build()
1158            .await
1159            .unwrap();
1160
1161        group.apply_pending_commit().await.unwrap();
1162
1163        let commit_output = group
1164            .commit_builder()
1165            .remove_member(1)
1166            .unwrap()
1167            .build()
1168            .await
1169            .unwrap();
1170
1171        let expected_remove = group.remove_proposal(1).unwrap();
1172
1173        assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1174    }
1175
1176    #[cfg(feature = "psk")]
1177    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1178    async fn test_commit_builder_psk() {
1179        let mut group = test_commit_builder_group().await;
1180        let test_psk = ExternalPskId::new(vec![1]);
1181
1182        group
1183            .config
1184            .secret_store()
1185            .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1186
1187        let commit_output = group
1188            .commit_builder()
1189            .add_external_psk(test_psk.clone())
1190            .unwrap()
1191            .build()
1192            .await
1193            .unwrap();
1194
1195        let key_id = JustPreSharedKeyID::External(test_psk);
1196        let expected_psk = group.psk_proposal(key_id).unwrap();
1197
1198        assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1199    }
1200
1201    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1202    async fn test_commit_builder_group_context_ext() {
1203        let mut group = test_commit_builder_group().await;
1204        let mut test_ext = ExtensionList::default();
1205        test_ext
1206            .set_from(RequiredCapabilitiesExt::default())
1207            .unwrap();
1208
1209        let commit_output = group
1210            .commit_builder()
1211            .set_group_context_ext(test_ext.clone())
1212            .unwrap()
1213            .build()
1214            .await
1215            .unwrap();
1216
1217        let expected_ext = group.group_context_extensions_proposal(test_ext);
1218
1219        assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1220    }
1221
1222    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1223    async fn test_commit_builder_reinit() {
1224        let mut group = test_commit_builder_group().await;
1225        let test_group_id = "foo".as_bytes().to_vec();
1226        let test_cipher_suite = TEST_CIPHER_SUITE;
1227        let test_protocol_version = TEST_PROTOCOL_VERSION;
1228        let mut test_ext = ExtensionList::default();
1229
1230        test_ext
1231            .set_from(RequiredCapabilitiesExt::default())
1232            .unwrap();
1233
1234        let commit_output = group
1235            .commit_builder()
1236            .reinit(
1237                Some(test_group_id.clone()),
1238                test_protocol_version,
1239                test_cipher_suite,
1240                test_ext.clone(),
1241            )
1242            .unwrap()
1243            .build()
1244            .await
1245            .unwrap();
1246
1247        let expected_reinit = group
1248            .reinit_proposal(
1249                Some(test_group_id),
1250                test_protocol_version,
1251                test_cipher_suite,
1252                test_ext,
1253            )
1254            .unwrap();
1255
1256        assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1257    }
1258
1259    #[cfg(feature = "custom_proposal")]
1260    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1261    async fn test_commit_builder_custom_proposal() {
1262        let mut group = test_commit_builder_group().await;
1263
1264        let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1265
1266        let commit_output = group
1267            .commit_builder()
1268            .custom_proposal(proposal.clone())
1269            .build()
1270            .await
1271            .unwrap();
1272
1273        assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1274    }
1275
1276    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1277    async fn test_commit_builder_chaining() {
1278        let mut group = test_commit_builder_group().await;
1279        let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1280        let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1281
1282        let expected_adds = vec![
1283            group.add_proposal(kp1.clone()).unwrap(),
1284            group.add_proposal(kp2.clone()).unwrap(),
1285        ];
1286
1287        let commit_output = group
1288            .commit_builder()
1289            .add_member(kp1)
1290            .unwrap()
1291            .add_member(kp2)
1292            .unwrap()
1293            .build()
1294            .await
1295            .unwrap();
1296
1297        assert_commit_builder_output(group, commit_output, expected_adds, 2);
1298    }
1299
1300    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1301    async fn test_commit_builder_empty_commit() {
1302        let mut group = test_commit_builder_group().await;
1303
1304        let commit_output = group.commit_builder().build().await.unwrap();
1305
1306        assert_commit_builder_output(group, commit_output, vec![], 0);
1307    }
1308
1309    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1310    async fn test_commit_builder_authenticated_data() {
1311        let mut group = test_commit_builder_group().await;
1312        let test_data = "test".as_bytes().to_vec();
1313
1314        let commit_output = group
1315            .commit_builder()
1316            .authenticated_data(test_data.clone())
1317            .build()
1318            .await
1319            .unwrap();
1320
1321        assert_eq!(
1322            commit_output
1323                .commit_message
1324                .into_plaintext()
1325                .unwrap()
1326                .content
1327                .authenticated_data,
1328            test_data
1329        );
1330    }
1331
1332    #[cfg(feature = "by_ref_proposal")]
1333    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1334    async fn test_commit_builder_multiple_welcome_messages() {
1335        let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1336            let options = CommitOptions::new().with_single_welcome_message(false);
1337            b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1338        })
1339        .await;
1340
1341        let (alice, alice_kp) =
1342            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1343
1344        let (bob, bob_kp) =
1345            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1346
1347        group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
1348
1349        group.propose_add(bob_kp.clone(), vec![]).await.unwrap();
1350
1351        let output = group.commit(Vec::new()).await.unwrap();
1352        let welcomes = output.welcome_messages;
1353
1354        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1355
1356        for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1357            let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1358
1359            let welcome = welcomes
1360                .iter()
1361                .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1362                .unwrap();
1363
1364            client.join_group(None, welcome, None).await.unwrap();
1365
1366            assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1367        }
1368    }
1369
1370    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1371    async fn commit_can_change_credential() {
1372        let cs = TEST_CIPHER_SUITE;
1373        let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1374        let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1375
1376        let commit_output = groups[0]
1377            .commit_builder()
1378            .set_new_signing_identity(secret_key, identity.clone())
1379            .build()
1380            .await
1381            .unwrap();
1382
1383        // Check that the credential was updated by in the committer's state.
1384        groups[0].process_pending_commit().await.unwrap();
1385        let new_member = groups[0].roster().member_with_index(0).unwrap();
1386
1387        assert_eq!(
1388            new_member.signing_identity.credential,
1389            get_test_basic_credential(b"member".to_vec())
1390        );
1391
1392        assert_eq!(
1393            new_member.signing_identity.signature_key,
1394            identity.signature_key
1395        );
1396
1397        // Check that the credential was updated in another member's state.
1398        groups[1]
1399            .process_message(commit_output.commit_message)
1400            .await
1401            .unwrap();
1402
1403        let new_member = groups[1].roster().member_with_index(0).unwrap();
1404
1405        assert_eq!(
1406            new_member.signing_identity.credential,
1407            get_test_basic_credential(b"member".to_vec())
1408        );
1409
1410        assert_eq!(
1411            new_member.signing_identity.signature_key,
1412            identity.signature_key
1413        );
1414    }
1415
1416    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1417    async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1418        let mut group = test_group_custom(
1419            TEST_PROTOCOL_VERSION,
1420            TEST_CIPHER_SUITE,
1421            Default::default(),
1422            None,
1423            Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1424        )
1425        .await;
1426
1427        let commit = group.commit(vec![]).await.unwrap();
1428
1429        group.apply_pending_commit().await.unwrap();
1430
1431        let new_tree = group.export_tree();
1432
1433        assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1434    }
1435
1436    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1437    async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1438        let mut group = test_group_custom(
1439            TEST_PROTOCOL_VERSION,
1440            TEST_CIPHER_SUITE,
1441            Default::default(),
1442            None,
1443            Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1444        )
1445        .await;
1446
1447        let commit = group.commit(vec![]).await.unwrap();
1448
1449        assert!(commit.ratchet_tree.is_none());
1450    }
1451
1452    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1453    async fn commit_includes_external_commit_group_info_if_requested() {
1454        let mut group = test_group_custom(
1455            TEST_PROTOCOL_VERSION,
1456            TEST_CIPHER_SUITE,
1457            Default::default(),
1458            None,
1459            Some(
1460                CommitOptions::new()
1461                    .with_allow_external_commit(true)
1462                    .with_ratchet_tree_extension(false),
1463            ),
1464        )
1465        .await;
1466
1467        let commit = group.commit(vec![]).await.unwrap();
1468
1469        let info = commit
1470            .external_commit_group_info
1471            .unwrap()
1472            .into_group_info()
1473            .unwrap();
1474
1475        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1476        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1477    }
1478
1479    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1480    async fn commit_includes_external_commit_and_tree_if_requested() {
1481        let mut group = test_group_custom(
1482            TEST_PROTOCOL_VERSION,
1483            TEST_CIPHER_SUITE,
1484            Default::default(),
1485            None,
1486            Some(
1487                CommitOptions::new()
1488                    .with_allow_external_commit(true)
1489                    .with_ratchet_tree_extension(true),
1490            ),
1491        )
1492        .await;
1493
1494        let commit = group.commit(vec![]).await.unwrap();
1495
1496        let info = commit
1497            .external_commit_group_info
1498            .unwrap()
1499            .into_group_info()
1500            .unwrap();
1501
1502        assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1503        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1504    }
1505
1506    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1507    async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1508        let mut group = test_group_custom(
1509            TEST_PROTOCOL_VERSION,
1510            TEST_CIPHER_SUITE,
1511            Default::default(),
1512            None,
1513            Some(CommitOptions::new().with_allow_external_commit(false)),
1514        )
1515        .await;
1516
1517        let commit = group.commit(vec![]).await.unwrap();
1518
1519        assert!(commit.external_commit_group_info.is_none());
1520    }
1521
1522    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1523    async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_off(
1524    ) {
1525        let mut group = test_group_custom(
1526            TEST_PROTOCOL_VERSION,
1527            TEST_CIPHER_SUITE,
1528            Default::default(),
1529            None,
1530            Some(
1531                CommitOptions::new()
1532                    .with_always_out_of_band_ratchet_tree(true)
1533                    .with_ratchet_tree_extension(false)
1534                    .with_allow_external_commit(true),
1535            ),
1536        )
1537        .await;
1538
1539        let commit = group.commit(vec![]).await.unwrap();
1540
1541        assert!(commit.ratchet_tree.is_some());
1542
1543        let info = commit
1544            .external_commit_group_info
1545            .unwrap()
1546            .into_group_info()
1547            .unwrap();
1548
1549        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1550    }
1551
1552    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1553    async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_on(
1554    ) {
1555        let mut group = test_group_custom(
1556            TEST_PROTOCOL_VERSION,
1557            TEST_CIPHER_SUITE,
1558            Default::default(),
1559            None,
1560            Some(
1561                CommitOptions::new()
1562                    .with_always_out_of_band_ratchet_tree(true)
1563                    .with_ratchet_tree_extension(true)
1564                    .with_allow_external_commit(true),
1565            ),
1566        )
1567        .await;
1568
1569        let commit = group.commit(vec![]).await.unwrap();
1570
1571        assert!(commit.ratchet_tree.is_some());
1572
1573        let info = commit
1574            .external_commit_group_info
1575            .unwrap()
1576            .into_group_info()
1577            .unwrap();
1578
1579        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1580    }
1581
1582    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1583    async fn member_identity_is_validated_against_new_extensions() {
1584        let alice = client_with_test_extension(b"alice").await;
1585        let mut alice = alice
1586            .create_group(ExtensionList::new(), Default::default(), None)
1587            .await
1588            .unwrap();
1589
1590        let bob = client_with_test_extension(b"bob").await;
1591        let bob_kp = bob
1592            .generate_key_package_message(Default::default(), Default::default(), None)
1593            .await
1594            .unwrap();
1595
1596        let mut extension_list = ExtensionList::new();
1597        let extension = TestExtension { foo: b'a' };
1598        extension_list.set_from(extension).unwrap();
1599
1600        let res = alice
1601            .commit_builder()
1602            .add_member(bob_kp)
1603            .unwrap()
1604            .set_group_context_ext(extension_list.clone())
1605            .unwrap()
1606            .build()
1607            .await;
1608
1609        assert!(res.is_err());
1610
1611        let alex = client_with_test_extension(b"alex").await;
1612
1613        alice
1614            .commit_builder()
1615            .add_member(
1616                alex.generate_key_package_message(Default::default(), Default::default(), None)
1617                    .await
1618                    .unwrap(),
1619            )
1620            .unwrap()
1621            .set_group_context_ext(extension_list.clone())
1622            .unwrap()
1623            .build()
1624            .await
1625            .unwrap();
1626    }
1627
1628    #[cfg(feature = "by_ref_proposal")]
1629    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1630    async fn server_identity_is_validated_against_new_extensions() {
1631        let alice = client_with_test_extension(b"alice").await;
1632        let mut alice = alice
1633            .create_group(ExtensionList::new(), Default::default(), None)
1634            .await
1635            .unwrap();
1636
1637        let mut extension_list = ExtensionList::new();
1638        let extension = TestExtension { foo: b'a' };
1639        extension_list.set_from(extension).unwrap();
1640
1641        let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1642
1643        let mut alex_extensions = extension_list.clone();
1644
1645        alex_extensions
1646            .set_from(ExternalSendersExt {
1647                allowed_senders: vec![alex_server],
1648            })
1649            .unwrap();
1650
1651        let res = alice
1652            .commit_builder()
1653            .set_group_context_ext(alex_extensions)
1654            .unwrap()
1655            .build()
1656            .await;
1657
1658        assert!(res.is_err());
1659
1660        let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1661
1662        let mut bob_extensions = extension_list;
1663
1664        bob_extensions
1665            .set_from(ExternalSendersExt {
1666                allowed_senders: vec![bob_server],
1667            })
1668            .unwrap();
1669
1670        alice
1671            .commit_builder()
1672            .set_group_context_ext(bob_extensions)
1673            .unwrap()
1674            .build()
1675            .await
1676            .unwrap();
1677    }
1678
1679    #[derive(Debug, Clone)]
1680    struct IdentityProviderWithExtension(BasicIdentityProvider);
1681
1682    #[derive(Clone, Debug)]
1683    #[cfg_attr(feature = "std", derive(thiserror::Error))]
1684    #[cfg_attr(feature = "std", error("test error"))]
1685    struct IdentityProviderWithExtensionError {}
1686
1687    impl IntoAnyError for IdentityProviderWithExtensionError {
1688        #[cfg(feature = "std")]
1689        fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1690            Ok(self.into())
1691        }
1692    }
1693
1694    impl IdentityProviderWithExtension {
1695        // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
1696        // is not set.
1697        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1698        async fn starts_with_foo(
1699            &self,
1700            identity: &SigningIdentity,
1701            _timestamp: Option<MlsTime>,
1702            extensions: Option<&ExtensionList>,
1703        ) -> bool {
1704            if let Some(extensions) = extensions {
1705                if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1706                    self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1707                } else {
1708                    true
1709                }
1710            } else {
1711                true
1712            }
1713        }
1714    }
1715
1716    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1717    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1718    impl IdentityProvider for IdentityProviderWithExtension {
1719        type Error = IdentityProviderWithExtensionError;
1720
1721        async fn validate_member(
1722            &self,
1723            identity: &SigningIdentity,
1724            timestamp: Option<MlsTime>,
1725            context: MemberValidationContext<'_>,
1726        ) -> Result<(), Self::Error> {
1727            self.starts_with_foo(identity, timestamp, context.new_extensions())
1728                .await
1729                .then_some(())
1730                .ok_or(IdentityProviderWithExtensionError {})
1731        }
1732
1733        async fn validate_external_sender(
1734            &self,
1735            identity: &SigningIdentity,
1736            timestamp: Option<MlsTime>,
1737            extensions: Option<&ExtensionList>,
1738        ) -> Result<(), Self::Error> {
1739            (!self.starts_with_foo(identity, timestamp, extensions).await)
1740                .then_some(())
1741                .ok_or(IdentityProviderWithExtensionError {})
1742        }
1743
1744        async fn identity(
1745            &self,
1746            signing_identity: &SigningIdentity,
1747            extensions: &ExtensionList,
1748        ) -> Result<Vec<u8>, Self::Error> {
1749            self.0
1750                .identity(signing_identity, extensions)
1751                .await
1752                .map_err(|_| IdentityProviderWithExtensionError {})
1753        }
1754
1755        async fn valid_successor(
1756            &self,
1757            _predecessor: &SigningIdentity,
1758            _successor: &SigningIdentity,
1759            _extensions: &ExtensionList,
1760        ) -> Result<bool, Self::Error> {
1761            Ok(true)
1762        }
1763
1764        fn supported_types(&self) -> Vec<CredentialType> {
1765            self.0.supported_types()
1766        }
1767    }
1768
1769    type ExtensionClientConfig = WithIdentityProvider<
1770        IdentityProviderWithExtension,
1771        WithCryptoProvider<TestCryptoProvider, BaseConfig>,
1772    >;
1773
1774    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1775    async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1776        let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1777
1778        ClientBuilder::new()
1779            .crypto_provider(TestCryptoProvider::new())
1780            .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1781            .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1782            .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1783            .build()
1784    }
1785
1786    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1787    async fn detached_commit() {
1788        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1789
1790        let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1791        assert!(group.pending_commit.is_none());
1792        group.apply_detached_commit(secrets).await.unwrap();
1793        assert_eq!(group.context().epoch, 1);
1794    }
1795
1796    #[cfg(feature = "tree_index")]
1797    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1798    async fn tree_index_consistent_after_committer_self_update() {
1799        use crate::identity::basic::BasicIdentityProvider;
1800        use crate::tree_kem::TreeKemPublic;
1801
1802        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1803
1804        group.commit(vec![]).await.unwrap();
1805        group.process_pending_commit().await.unwrap();
1806
1807        let mut rebuilt = TreeKemPublic::import_node_data(
1808            group.state.public_tree.nodes.clone(),
1809            &BasicIdentityProvider,
1810            &Default::default(),
1811        )
1812        .await
1813        .unwrap();
1814
1815        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1816        rebuilt.tree_hash(&cs).await.unwrap();
1817
1818        assert!(group.state.public_tree.equal_internals(&rebuilt));
1819    }
1820}