mls_rs/group/
commit.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13    cipher_suite::CipherSuite,
14    client::MlsError,
15    client_config::ClientConfig,
16    extension::RatchetTreeExt,
17    identity::SigningIdentity,
18    protocol_version::ProtocolVersion,
19    signer::Signable,
20    tree_kem::{
21        kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
22    },
23    ExtensionList, MlsRules,
24};
25
26#[cfg(all(not(mls_build_async), feature = "rayon"))]
27use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
28
29use crate::tree_kem::leaf_node::LeafNode;
30
31#[cfg(not(feature = "private_message"))]
32use crate::WireFormat;
33
34#[cfg(feature = "psk")]
35use crate::{
36    group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
37    psk::ExternalPskId,
38};
39
40use super::{
41    confirmation_tag::ConfirmationTag,
42    framing::{Content, MlsMessage, MlsMessagePayload, Sender},
43    key_schedule::{KeySchedule, WelcomeSecret},
44    message_hash::MessageHash,
45    message_processor::{path_update_required, MessageProcessor},
46    message_signature::AuthenticatedContent,
47    mls_rules::CommitDirection,
48    proposal::{Proposal, ProposalOrRef},
49    CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
50    Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
51    PendingCommitSnapshot, Welcome,
52};
53
54#[cfg(not(feature = "by_ref_proposal"))]
55use super::proposal_cache::prepare_commit;
56
57#[cfg(feature = "custom_proposal")]
58use super::proposal::CustomProposal;
59
60#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
61#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
62#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
63pub(crate) struct Commit {
64    pub proposals: Vec<ProposalOrRef>,
65    pub path: Option<UpdatePath>,
66}
67
68#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
69pub(crate) struct PendingCommit {
70    pub(crate) state: GroupState,
71    pub(crate) epoch_secrets: EpochSecrets,
72    pub(crate) private_tree: TreeKemPrivate,
73    pub(crate) key_schedule: KeySchedule,
74    pub(crate) signer: SignatureSecretKey,
75
76    pub(crate) output: CommitMessageDescription,
77
78    pub(crate) commit_message_hash: MessageHash,
79}
80
81#[cfg_attr(
82    all(feature = "ffi", not(test)),
83    safer_ffi_gen::ffi_type(clone, opaque)
84)]
85#[derive(Clone)]
86pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
87
88impl CommitSecrets {
89    /// Deserialize the commit secrets from bytes
90    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
91        Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
92    }
93
94    /// Serialize the commit secrets to bytes
95    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
96        Ok(self.0.mls_encode_to_vec()?)
97    }
98}
99
100#[cfg_attr(
101    all(feature = "ffi", not(test)),
102    safer_ffi_gen::ffi_type(clone, opaque)
103)]
104#[derive(Clone, Debug)]
105#[non_exhaustive]
106/// Result of MLS commit operation using
107/// [`Group::commit`](crate::group::Group::commit) or
108/// [`CommitBuilder::build`](CommitBuilder::build).
109pub struct CommitOutput {
110    /// Commit message to send to other group members.
111    pub commit_message: MlsMessage,
112    /// Welcome messages to send to new group members. If the commit does not add members,
113    /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
114    /// set to true, then this list contains a single message sent to all members. Else, the list
115    /// contains one message for each added member. Recipients of each message can be identified using
116    /// [`MlsMessage::key_package_reference`] of their key packages and
117    /// [`MlsMessage::welcome_key_package_references`].
118    pub welcome_messages: Vec<MlsMessage>,
119    /// Ratchet tree that can be sent out of band if
120    /// `ratchet_tree_extension` is not used according to
121    /// [`MlsRules::commit_options`].
122    pub ratchet_tree: Option<ExportedTree<'static>>,
123    /// A group info that can be provided to new members in order to enable external commit
124    /// functionality. This value is set if [`MlsRules::commit_options`] returns
125    /// `allow_external_commit` set to true.
126    pub external_commit_group_info: Option<MlsMessage>,
127    /// Proposals that were received in the prior epoch but not included in the following commit.
128    #[cfg(feature = "by_ref_proposal")]
129    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
130    /// Indicator that the commit contains a path update
131    pub contains_update_path: bool,
132}
133
134#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
135impl CommitOutput {
136    /// Commit message to send to other group members.
137    #[cfg(feature = "ffi")]
138    pub fn commit_message(&self) -> &MlsMessage {
139        &self.commit_message
140    }
141
142    /// Welcome message to send to new group members.
143    #[cfg(feature = "ffi")]
144    pub fn welcome_messages(&self) -> &[MlsMessage] {
145        &self.welcome_messages
146    }
147
148    /// Ratchet tree that can be sent out of band if
149    /// `ratchet_tree_extension` is not used according to
150    /// [`MlsRules::commit_options`].
151    #[cfg(feature = "ffi")]
152    pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
153        self.ratchet_tree.as_ref()
154    }
155
156    /// A group info that can be provided to new members in order to enable external commit
157    /// functionality. This value is set if [`MlsRules::commit_options`] returns
158    /// `allow_external_commit` set to true.
159    #[cfg(feature = "ffi")]
160    pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
161        self.external_commit_group_info.as_ref()
162    }
163
164    /// Proposals that were received in the prior epoch but not included in the following commit.
165    #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
166    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
167        &self.unused_proposals
168    }
169}
170
171/// Build a commit with multiple proposals by-value.
172///
173/// Proposals within a commit can be by-value or by-reference.
174/// Proposals received during the current epoch will be added to the resulting
175/// commit by-reference automatically so long as they pass the rules defined
176/// in the current
177/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
178pub struct CommitBuilder<'a, C>
179where
180    C: ClientConfig + Clone,
181{
182    group: &'a mut Group<C>,
183    pub(super) proposals: Vec<Proposal>,
184    authenticated_data: Vec<u8>,
185    group_info_extensions: ExtensionList,
186    new_signer: Option<SignatureSecretKey>,
187    new_signing_identity: Option<SigningIdentity>,
188    new_leaf_node_extensions: Option<ExtensionList>,
189}
190
191impl<'a, C> CommitBuilder<'a, C>
192where
193    C: ClientConfig + Clone,
194{
195    /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
196    /// the current commit that is being built.
197    pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
198        let proposal = self.group.add_proposal(key_package)?;
199        self.proposals.push(proposal);
200        Ok(self)
201    }
202
203    /// Set group info extensions that will be inserted into the resulting
204    /// [welcome messages](CommitOutput::welcome_messages) for new members.
205    ///
206    /// Group info extensions that are transmitted as part of a welcome message
207    /// are encrypted along with other private values.
208    ///
209    /// These extensions can be retrieved as part of
210    /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
211    /// by joining the group via
212    /// [`Client::join_group`](crate::Client::join_group).
213    pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
214        Self {
215            group_info_extensions: extensions,
216            ..self
217        }
218    }
219
220    /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
221    /// the current commit that is being built.
222    pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
223        let proposal = self.group.remove_proposal(index)?;
224        self.proposals.push(proposal);
225        Ok(self)
226    }
227
228    /// Insert a
229    /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
230    /// into the current commit that is being built.
231    pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
232        let proposal = self.group.group_context_extensions_proposal(extensions);
233        self.proposals.push(proposal);
234        Ok(self)
235    }
236
237    /// Insert a
238    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
239    /// an external PSK into the current commit that is being built.
240    #[cfg(feature = "psk")]
241    pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
242        let key_id = JustPreSharedKeyID::External(psk_id);
243        let proposal = self.group.psk_proposal(key_id)?;
244        self.proposals.push(proposal);
245        Ok(self)
246    }
247
248    /// Insert a
249    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
250    /// a resumption PSK into the current commit that is being built.
251    #[cfg(feature = "psk")]
252    pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
253        let psk_id = ResumptionPsk {
254            psk_epoch,
255            usage: ResumptionPSKUsage::Application,
256            psk_group_id: PskGroupId(self.group.group_id().to_vec()),
257        };
258
259        let key_id = JustPreSharedKeyID::Resumption(psk_id);
260        let proposal = self.group.psk_proposal(key_id)?;
261        self.proposals.push(proposal);
262        Ok(self)
263    }
264
265    /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
266    /// the current commit that is being built.
267    pub fn reinit(
268        mut self,
269        group_id: Option<Vec<u8>>,
270        version: ProtocolVersion,
271        cipher_suite: CipherSuite,
272        extensions: ExtensionList,
273    ) -> Result<Self, MlsError> {
274        let proposal = self
275            .group
276            .reinit_proposal(group_id, version, cipher_suite, extensions)?;
277
278        self.proposals.push(proposal);
279        Ok(self)
280    }
281
282    /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
283    /// the current commit that is being built.
284    #[cfg(feature = "custom_proposal")]
285    pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
286        self.proposals.push(Proposal::Custom(proposal));
287        self
288    }
289
290    /// Insert a proposal that was previously constructed such as when a
291    /// proposal is returned from
292    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
293    pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
294        self.proposals.push(proposal);
295        self
296    }
297
298    /// Insert proposals that were previously constructed such as when a
299    /// proposal is returned from
300    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
301    pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
302        self.proposals.append(&mut proposals);
303        self
304    }
305
306    /// Add additional authenticated data to the commit.
307    ///
308    /// # Warning
309    ///
310    /// The data provided here is always sent unencrypted.
311    pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
312        Self {
313            authenticated_data,
314            ..self
315        }
316    }
317
318    /// Change the committer's signing identity as part of making this commit.
319    /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
320    /// in use by the group considers the credential inside this signing_identity
321    /// [valid](crate::IdentityProvider::validate_member)
322    /// and results in the same
323    /// [identity](crate::IdentityProvider::identity)
324    /// being used.
325    pub fn set_new_signing_identity(
326        self,
327        signer: SignatureSecretKey,
328        signing_identity: SigningIdentity,
329    ) -> Self {
330        Self {
331            new_signer: Some(signer),
332            new_signing_identity: Some(signing_identity),
333            ..self
334        }
335    }
336
337    /// Change the committer's leaf node extensions as part of making this commit.
338    pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
339        Self {
340            new_leaf_node_extensions: Some(new_leaf_node_extensions),
341            ..self
342        }
343    }
344
345    /// Finalize the commit to send.
346    ///
347    /// # Errors
348    ///
349    /// This function will return an error if any of the proposals provided
350    /// are not contextually valid according to the rules defined by the
351    /// MLS RFC, or if they do not pass the custom rules defined by the current
352    /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
353    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
354    pub async fn build(self) -> Result<CommitOutput, MlsError> {
355        let (output, pending_commit) = self
356            .group
357            .commit_internal(
358                self.proposals,
359                None,
360                self.authenticated_data,
361                self.group_info_extensions,
362                self.new_signer,
363                self.new_signing_identity,
364                self.new_leaf_node_extensions,
365            )
366            .await?;
367
368        self.group.pending_commit = pending_commit.try_into()?;
369
370        Ok(output)
371    }
372
373    /// The same function as `GroupBuilder::build` except the secrets generated
374    /// for the commit are outputted instead of being cached internally.
375    ///
376    /// A detached commit can be applied using `Group::apply_detached_commit`.
377    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
378    pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
379        let (output, pending_commit) = self
380            .group
381            .commit_internal(
382                self.proposals,
383                None,
384                self.authenticated_data,
385                self.group_info_extensions,
386                self.new_signer,
387                self.new_signing_identity,
388                self.new_leaf_node_extensions,
389            )
390            .await?;
391
392        Ok((
393            output,
394            CommitSecrets(PendingCommitSnapshot::PendingCommit(
395                pending_commit.mls_encode_to_vec()?,
396            )),
397        ))
398    }
399}
400
401impl<C> Group<C>
402where
403    C: ClientConfig + Clone,
404{
405    /// Perform a commit of received proposals.
406    ///
407    /// This function is the equivalent of [`Group::commit_builder`] immediately
408    /// followed by [`CommitBuilder::build`]. Any received proposals since the
409    /// last commit will be included in the resulting message by-reference.
410    ///
411    /// Data provided in the `authenticated_data` field will be placed into
412    /// the resulting commit message unencrypted.
413    ///
414    /// # Pending Commits
415    ///
416    /// When a commit is created, it is not applied immediately in order to
417    /// allow for the resolution of conflicts when multiple members of a group
418    /// attempt to make commits at the same time. For example, a central relay
419    /// can be used to decide which commit should be accepted by the group by
420    /// determining a consistent view of commit packet order for all clients.
421    ///
422    /// Pending commits are stored internally as part of the group's state
423    /// so they do not need to be tracked outside of this library. Any commit
424    /// message that is processed before calling [Group::apply_pending_commit]
425    /// will clear the currently pending commit.
426    ///
427    /// # Empty Commits
428    ///
429    /// Sending a commit that contains no proposals is a valid operation
430    /// within the MLS protocol. It is useful for providing stronger forward
431    /// secrecy and post-compromise security, especially for long running
432    /// groups when group membership does not change often.
433    ///
434    /// # Path Updates
435    ///
436    /// Path updates provide forward secrecy and post-compromise security
437    /// within the MLS protocol.
438    /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
439    /// controls the ability of a group to send a commit without a path update.
440    /// An update path will automatically be sent if there are no proposals
441    /// in the commit, or if any proposal other than
442    /// [`Add`](crate::group::proposal::Proposal::Add),
443    /// [`Psk`](crate::group::proposal::Proposal::Psk),
444    /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
445    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
446    pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
447        self.commit_builder()
448            .authenticated_data(authenticated_data)
449            .build()
450            .await
451    }
452
453    /// The same function as `Group::commit` except the secrets generated
454    /// for the commit are outputted instead of being cached internally.
455    ///
456    /// A detached commit can be applied using `Group::apply_detached_commit`.
457    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
458    pub async fn commit_detached(
459        &mut self,
460        authenticated_data: Vec<u8>,
461    ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
462        self.commit_builder()
463            .authenticated_data(authenticated_data)
464            .build_detached()
465            .await
466    }
467
468    /// Create a new commit builder that can include proposals
469    /// by-value.
470    pub fn commit_builder(&mut self) -> CommitBuilder<C> {
471        CommitBuilder {
472            group: self,
473            proposals: Default::default(),
474            authenticated_data: Default::default(),
475            group_info_extensions: Default::default(),
476            new_signer: Default::default(),
477            new_signing_identity: Default::default(),
478            new_leaf_node_extensions: Default::default(),
479        }
480    }
481
482    /// Returns commit and optional [`MlsMessage`] containing a welcome message
483    /// for newly added members.
484    #[allow(clippy::too_many_arguments)]
485    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
486    pub(super) async fn commit_internal(
487        &mut self,
488        proposals: Vec<Proposal>,
489        external_leaf: Option<&LeafNode>,
490        authenticated_data: Vec<u8>,
491        mut welcome_group_info_extensions: ExtensionList,
492        new_signer: Option<SignatureSecretKey>,
493        new_signing_identity: Option<SigningIdentity>,
494        new_leaf_node_extensions: Option<ExtensionList>,
495    ) -> Result<(CommitOutput, PendingCommit), MlsError> {
496        if !self.pending_commit.is_none() {
497            return Err(MlsError::ExistingPendingCommit);
498        }
499
500        if self.state.pending_reinit.is_some() {
501            return Err(MlsError::GroupUsedAfterReInit);
502        }
503
504        let mls_rules = self.config.mls_rules();
505
506        let is_external = external_leaf.is_some();
507
508        // Construct an initial Commit object with the proposals field populated from Proposals
509        // received during the current epoch, and an empty path field. Add passed in proposals
510        // by value
511        let sender = if is_external {
512            Sender::NewMemberCommit
513        } else {
514            Sender::Member(*self.private_tree.self_index)
515        };
516
517        let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
518        let old_signer = &self.signer;
519
520        #[cfg(feature = "std")]
521        let time = Some(crate::time::MlsTime::now());
522
523        #[cfg(not(feature = "std"))]
524        let time = None;
525
526        #[cfg(feature = "by_ref_proposal")]
527        let proposals = self.state.proposals.prepare_commit(sender, proposals);
528
529        #[cfg(not(feature = "by_ref_proposal"))]
530        let proposals = prepare_commit(sender, proposals);
531
532        let mut provisional_state = self
533            .state
534            .apply_resolved(
535                sender,
536                proposals,
537                external_leaf,
538                &self.config.identity_provider(),
539                &self.cipher_suite_provider,
540                &self.config.secret_store(),
541                &mls_rules,
542                time,
543                CommitDirection::Send,
544            )
545            .await?;
546
547        let (mut provisional_private_tree, _) =
548            self.provisional_private_tree(&provisional_state)?;
549
550        if is_external {
551            provisional_private_tree.self_index = provisional_state
552                .external_init_index
553                .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
554
555            self.private_tree.self_index = provisional_private_tree.self_index;
556        }
557
558        // Decide whether to populate the path field: If the path field is required based on the
559        // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
560        // sender MAY omit the path field at its discretion.
561        let commit_options = mls_rules
562            .commit_options(
563                &provisional_state.public_tree.roster(),
564                &provisional_state.group_context,
565                &provisional_state.applied_proposals,
566            )
567            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
568
569        let perform_path_update = commit_options.path_required
570            || path_update_required(&provisional_state.applied_proposals);
571
572        let (update_path, path_secrets, commit_secret) = if perform_path_update {
573            // If populating the path field: Create an UpdatePath using the new tree. Any new
574            // member (from an add proposal) MUST be excluded from the resolution during the
575            // computation of the UpdatePath. The GroupContext for this operation uses the
576            // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
577            // GroupContext object. The leaf_key_package for this UpdatePath must have a
578            // parent_hash extension.
579
580            let new_leaf_node_extensions =
581                new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
582
583            let new_leaf_node_extensions = match new_leaf_node_extensions {
584                Some(extensions) => extensions,
585                // If we are not setting new extensions and this is not an external leaf then the current node MUST exist.
586                None => self.current_user_leaf_node()?.ungreased_extensions(),
587            };
588
589            let encap_gen = TreeKem::new(
590                &mut provisional_state.public_tree,
591                &mut provisional_private_tree,
592            )
593            .encap(
594                &mut provisional_state.group_context,
595                &provisional_state.indexes_of_added_kpkgs,
596                &new_signer,
597                Some(self.config.leaf_properties(new_leaf_node_extensions)),
598                new_signing_identity,
599                &self.cipher_suite_provider,
600                #[cfg(test)]
601                &self.commit_modifiers,
602            )
603            .await?;
604
605            (
606                Some(encap_gen.update_path),
607                Some(encap_gen.path_secrets),
608                encap_gen.commit_secret,
609            )
610        } else {
611            // Update the tree hash, since it was not updated by encap.
612            provisional_state
613                .public_tree
614                .update_hashes(
615                    &[provisional_private_tree.self_index],
616                    &self.cipher_suite_provider,
617                )
618                .await?;
619
620            provisional_state.group_context.tree_hash = provisional_state
621                .public_tree
622                .tree_hash(&self.cipher_suite_provider)
623                .await?;
624
625            (None, None, PathSecret::empty(&self.cipher_suite_provider))
626        };
627
628        #[cfg(feature = "psk")]
629        let (psk_secret, psks) = self
630            .get_psk(&provisional_state.applied_proposals.psks)
631            .await?;
632
633        #[cfg(not(feature = "psk"))]
634        let psk_secret = self.get_psk();
635
636        let added_key_pkgs: Vec<_> = provisional_state
637            .applied_proposals
638            .additions
639            .iter()
640            .map(|info| info.proposal.key_package.clone())
641            .collect();
642
643        let commit = Commit {
644            proposals: provisional_state.applied_proposals.proposals_or_refs(),
645            path: update_path,
646        };
647
648        let mut auth_content = AuthenticatedContent::new_signed(
649            &self.cipher_suite_provider,
650            self.context(),
651            sender,
652            Content::Commit(Box::new(commit)),
653            old_signer,
654            #[cfg(feature = "private_message")]
655            self.encryption_options()?.control_wire_format(sender),
656            #[cfg(not(feature = "private_message"))]
657            WireFormat::PublicMessage,
658            authenticated_data,
659        )
660        .await?;
661
662        // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
663        // compute the confirmation_tag value in the MlsPlaintext.
664        let confirmed_transcript_hash = super::transcript_hash::create(
665            self.cipher_suite_provider(),
666            &self.state.interim_transcript_hash,
667            &auth_content,
668        )
669        .await?;
670
671        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
672
673        let key_schedule_result = KeySchedule::from_key_schedule(
674            &self.key_schedule,
675            &commit_secret,
676            &provisional_state.group_context,
677            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
678            provisional_state.public_tree.total_leaf_count(),
679            &psk_secret,
680            &self.cipher_suite_provider,
681        )
682        .await?;
683
684        let confirmation_tag = ConfirmationTag::create(
685            &key_schedule_result.confirmation_key,
686            &provisional_state.group_context.confirmed_transcript_hash,
687            &self.cipher_suite_provider,
688        )
689        .await?;
690
691        let interim_transcript_hash = InterimTranscriptHash::create(
692            self.cipher_suite_provider(),
693            &provisional_state.group_context.confirmed_transcript_hash,
694            &confirmation_tag,
695        )
696        .await?;
697
698        auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
699
700        let ratchet_tree_ext = commit_options
701            .ratchet_tree_extension
702            .then(|| RatchetTreeExt {
703                tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
704            });
705
706        // Generate external commit group info if required by commit_options
707        let external_commit_group_info = match commit_options.allow_external_commit {
708            true => {
709                let mut extensions = ExtensionList::new();
710
711                extensions.set_from({
712                    key_schedule_result
713                        .key_schedule
714                        .get_external_key_pair_ext(&self.cipher_suite_provider)
715                        .await?
716                })?;
717
718                if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
719                    if !commit_options.always_out_of_band_ratchet_tree {
720                        extensions.set_from(ratchet_tree_ext.clone())?;
721                    }
722                }
723
724                let info = self
725                    .make_group_info(
726                        &provisional_state.group_context,
727                        extensions,
728                        &confirmation_tag,
729                        &new_signer,
730                    )
731                    .await?;
732
733                let msg =
734                    MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
735
736                Some(msg)
737            }
738            false => None,
739        };
740
741        // Build the group info that will be placed into the welcome messages.
742        // Add the ratchet tree extension if necessary
743        if let Some(ratchet_tree_ext) = ratchet_tree_ext {
744            welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
745        }
746
747        let welcome_group_info = self
748            .make_group_info(
749                &provisional_state.group_context,
750                welcome_group_info_extensions,
751                &confirmation_tag,
752                &new_signer,
753            )
754            .await?;
755
756        // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
757        // the new epoch
758        let welcome_secret = WelcomeSecret::from_joiner_secret(
759            &self.cipher_suite_provider,
760            &key_schedule_result.joiner_secret,
761            &psk_secret,
762        )
763        .await?;
764
765        let encrypted_group_info = welcome_secret
766            .encrypt(&welcome_group_info.mls_encode_to_vec()?)
767            .await?;
768
769        // Encrypt path secrets and joiner secret to new members
770        let path_secrets = path_secrets.as_ref();
771
772        #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
773        let encrypted_path_secrets: Vec<_> = added_key_pkgs
774            .into_par_iter()
775            .zip(&provisional_state.indexes_of_added_kpkgs)
776            .map(|(key_package, leaf_index)| {
777                self.encrypt_group_secrets(
778                    &key_package,
779                    *leaf_index,
780                    &key_schedule_result.joiner_secret,
781                    path_secrets,
782                    #[cfg(feature = "psk")]
783                    psks.clone(),
784                    &encrypted_group_info,
785                )
786            })
787            .try_collect()?;
788
789        #[cfg(any(mls_build_async, not(feature = "rayon")))]
790        let encrypted_path_secrets = {
791            let mut secrets = Vec::new();
792
793            for (key_package, leaf_index) in added_key_pkgs
794                .into_iter()
795                .zip(&provisional_state.indexes_of_added_kpkgs)
796            {
797                secrets.push(
798                    self.encrypt_group_secrets(
799                        &key_package,
800                        *leaf_index,
801                        &key_schedule_result.joiner_secret,
802                        path_secrets,
803                        #[cfg(feature = "psk")]
804                        psks.clone(),
805                        &encrypted_group_info,
806                    )
807                    .await?,
808                );
809            }
810
811            secrets
812        };
813
814        let welcome_messages =
815            if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
816                vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
817            } else {
818                encrypted_path_secrets
819                    .into_iter()
820                    .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
821                    .collect()
822            };
823
824        let commit_message = self.format_for_wire(auth_content.clone()).await?;
825
826        // TODO is it necessary to clone the tree here? or can we just output serialized bytes?
827        let ratchet_tree = (!commit_options.ratchet_tree_extension
828            || commit_options.always_out_of_band_ratchet_tree)
829            .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
830
831        let pending_reinit = provisional_state
832            .applied_proposals
833            .reinitializations
834            .first();
835
836        let pending_commit = PendingCommit {
837            output: CommitMessageDescription {
838                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
839                authenticated_data: auth_content.content.authenticated_data,
840                committer: *provisional_private_tree.self_index,
841                effect: match pending_reinit {
842                    Some(r) => CommitEffect::ReInit(r.clone()),
843                    None => CommitEffect::NewEpoch(
844                        NewEpoch::new(self.state.clone(), &provisional_state).into(),
845                    ),
846                },
847            },
848
849            state: GroupState {
850                #[cfg(feature = "by_ref_proposal")]
851                proposals: crate::group::ProposalCache::new(
852                    self.protocol_version(),
853                    self.group_id().to_vec(),
854                ),
855                context: provisional_state.group_context,
856                public_tree: provisional_state.public_tree,
857                interim_transcript_hash,
858                pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
859                confirmation_tag,
860            },
861
862            commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
863                .await?,
864            signer: new_signer,
865            epoch_secrets: key_schedule_result.epoch_secrets,
866            key_schedule: key_schedule_result.key_schedule,
867
868            private_tree: provisional_private_tree,
869        };
870
871        let output = CommitOutput {
872            commit_message,
873            welcome_messages,
874            ratchet_tree,
875            external_commit_group_info,
876            contains_update_path: perform_path_update,
877            #[cfg(feature = "by_ref_proposal")]
878            unused_proposals: provisional_state.unused_proposals,
879        };
880
881        Ok((output, pending_commit))
882    }
883
884    // Construct a GroupInfo reflecting the new state
885    // Group ID, epoch, tree, and confirmed transcript hash from the new state
886    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
887    async fn make_group_info(
888        &self,
889        group_context: &GroupContext,
890        extensions: ExtensionList,
891        confirmation_tag: &ConfirmationTag,
892        signer: &SignatureSecretKey,
893    ) -> Result<GroupInfo, MlsError> {
894        let mut group_info = GroupInfo {
895            group_context: group_context.clone(),
896            extensions,
897            confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
898            signer: LeafIndex(self.current_member_index()),
899            signature: vec![],
900        };
901
902        group_info.grease(self.cipher_suite_provider())?;
903
904        // Sign the GroupInfo using the member's private signing key
905        group_info
906            .sign(&self.cipher_suite_provider, signer, &())
907            .await?;
908
909        Ok(group_info)
910    }
911
912    fn make_welcome_message(
913        &self,
914        secrets: Vec<EncryptedGroupSecrets>,
915        encrypted_group_info: Vec<u8>,
916    ) -> MlsMessage {
917        MlsMessage::new(
918            self.context().protocol_version,
919            MlsMessagePayload::Welcome(Welcome {
920                cipher_suite: self.context().cipher_suite,
921                secrets,
922                encrypted_group_info,
923            }),
924        )
925    }
926}
927
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Function hooks that let tests tamper with a leaf node, the public tree and an update
    /// path (presumably while a commit is generated — confirm at the call sites).
    ///
    /// The [`Default`] value changes nothing.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        pub modify_tree: fn(&mut TreeKemPublic),
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    /// Leaves the leaf untouched and requests no replacement signing key.
    fn keep_leaf(_: &mut LeafNode, _: &SignatureSecretKey) -> Option<SignatureSecretKey> {
        None
    }

    /// Leaves the tree untouched.
    fn keep_tree(_: &mut TreeKemPublic) {}

    /// Returns the path unchanged.
    fn keep_path(path: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
        path
    }

    impl Default for CommitModifiers {
        fn default() -> Self {
            Self {
                modify_leaf: keep_leaf,
                modify_tree: keep_tree,
                modify_path: keep_path,
            }
        }
    }
}
954
955#[cfg(test)]
956mod tests {
957    use mls_rs_core::{
958        error::IntoAnyError,
959        extension::ExtensionType,
960        identity::{CredentialType, IdentityProvider, MemberValidationContext},
961        time::MlsTime,
962    };
963
964    use crate::extension::RequiredCapabilitiesExt;
965    use crate::{
966        client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
967        client_builder::{
968            test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
969            WithIdentityProvider,
970        },
971        client_config::ClientConfig,
972        crypto::test_utils::TestCryptoProvider,
973        extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
974        group::test_utils::{test_group, test_group_custom},
975        group::{
976            proposal::ProposalType,
977            test_utils::{test_group_custom_config, test_n_member_group},
978        },
979        identity::test_utils::get_test_signing_identity,
980        identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
981        key_package::test_utils::test_key_package_message,
982        mls_rules::CommitOptions,
983        Client,
984    };
985
986    #[cfg(feature = "by_ref_proposal")]
987    use crate::crypto::test_utils::test_cipher_suite_provider;
988    #[cfg(feature = "by_ref_proposal")]
989    use crate::extension::ExternalSendersExt;
990    #[cfg(feature = "by_ref_proposal")]
991    use crate::group::mls_rules::DefaultMlsRules;
992
993    #[cfg(feature = "psk")]
994    use crate::{
995        group::proposal::PreSharedKeyProposal,
996        psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
997    };
998
999    use super::*;
1000
1001    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1002    async fn test_commit_builder_group() -> Group<TestClientConfig> {
1003        test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1004            b.custom_proposal_type(ProposalType::from(42))
1005                .extension_type(TEST_EXTENSION_TYPE.into())
1006        })
1007        .await
1008        .group
1009    }
1010
1011    fn assert_commit_builder_output<C: ClientConfig>(
1012        group: Group<C>,
1013        mut commit_output: CommitOutput,
1014        expected: Vec<Proposal>,
1015        welcome_count: usize,
1016    ) {
1017        let plaintext = commit_output.commit_message.into_plaintext().unwrap();
1018
1019        let commit_data = match plaintext.content.content {
1020            Content::Commit(commit) => commit,
1021            #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
1022            _ => panic!("Found non-commit data"),
1023        };
1024
1025        assert_eq!(commit_data.proposals.len(), expected.len());
1026
1027        commit_data.proposals.into_iter().for_each(|proposal| {
1028            let proposal = match proposal {
1029                ProposalOrRef::Proposal(p) => p,
1030                #[cfg(feature = "by_ref_proposal")]
1031                ProposalOrRef::Reference(_) => panic!("found proposal reference"),
1032            };
1033
1034            #[cfg(feature = "psk")]
1035            if let Some(psk_id) = match proposal.as_ref() {
1036                Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
1037                _ => None,
1038            } {
1039                let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
1040
1041                assert!(found)
1042            } else {
1043                assert!(expected.contains(&proposal));
1044            }
1045
1046            #[cfg(not(feature = "psk"))]
1047            assert!(expected.contains(&proposal));
1048        });
1049
1050        if welcome_count > 0 {
1051            let welcome_msg = commit_output.welcome_messages.pop().unwrap();
1052
1053            assert_eq!(welcome_msg.version, group.state.context.protocol_version);
1054
1055            let welcome_msg = welcome_msg.into_welcome().unwrap();
1056
1057            assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
1058            assert_eq!(welcome_msg.secrets.len(), welcome_count);
1059        } else {
1060            assert!(commit_output.welcome_messages.is_empty());
1061        }
1062    }
1063
1064    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1065    async fn test_commit_builder_add() {
1066        let mut group = test_commit_builder_group().await;
1067
1068        let test_key_package =
1069            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1070
1071        let commit_output = group
1072            .commit_builder()
1073            .add_member(test_key_package.clone())
1074            .unwrap()
1075            .build()
1076            .await
1077            .unwrap();
1078
1079        let expected_add = group.add_proposal(test_key_package).unwrap();
1080
1081        assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1082    }
1083
1084    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1085    async fn test_commit_builder_add_with_ext() {
1086        let mut group = test_commit_builder_group().await;
1087
1088        let (bob_client, bob_key_package) =
1089            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1090
1091        let ext = TestExtension { foo: 42 };
1092        let mut extension_list = ExtensionList::default();
1093        extension_list.set_from(ext.clone()).unwrap();
1094
1095        let welcome_message = group
1096            .commit_builder()
1097            .add_member(bob_key_package)
1098            .unwrap()
1099            .set_group_info_ext(extension_list)
1100            .build()
1101            .await
1102            .unwrap()
1103            .welcome_messages
1104            .remove(0);
1105
1106        let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
1107
1108        assert_eq!(
1109            context
1110                .group_info_extensions
1111                .get_as::<TestExtension>()
1112                .unwrap()
1113                .unwrap(),
1114            ext
1115        );
1116    }
1117
1118    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1119    async fn test_commit_builder_remove() {
1120        let mut group = test_commit_builder_group().await;
1121        let test_key_package =
1122            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1123
1124        group
1125            .commit_builder()
1126            .add_member(test_key_package)
1127            .unwrap()
1128            .build()
1129            .await
1130            .unwrap();
1131
1132        group.apply_pending_commit().await.unwrap();
1133
1134        let commit_output = group
1135            .commit_builder()
1136            .remove_member(1)
1137            .unwrap()
1138            .build()
1139            .await
1140            .unwrap();
1141
1142        let expected_remove = group.remove_proposal(1).unwrap();
1143
1144        assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1145    }
1146
1147    #[cfg(feature = "psk")]
1148    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1149    async fn test_commit_builder_psk() {
1150        let mut group = test_commit_builder_group().await;
1151        let test_psk = ExternalPskId::new(vec![1]);
1152
1153        group
1154            .config
1155            .secret_store()
1156            .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1157
1158        let commit_output = group
1159            .commit_builder()
1160            .add_external_psk(test_psk.clone())
1161            .unwrap()
1162            .build()
1163            .await
1164            .unwrap();
1165
1166        let key_id = JustPreSharedKeyID::External(test_psk);
1167        let expected_psk = group.psk_proposal(key_id).unwrap();
1168
1169        assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1170    }
1171
1172    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1173    async fn test_commit_builder_group_context_ext() {
1174        let mut group = test_commit_builder_group().await;
1175        let mut test_ext = ExtensionList::default();
1176        test_ext
1177            .set_from(RequiredCapabilitiesExt::default())
1178            .unwrap();
1179
1180        let commit_output = group
1181            .commit_builder()
1182            .set_group_context_ext(test_ext.clone())
1183            .unwrap()
1184            .build()
1185            .await
1186            .unwrap();
1187
1188        let expected_ext = group.group_context_extensions_proposal(test_ext);
1189
1190        assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1191    }
1192
1193    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1194    async fn test_commit_builder_reinit() {
1195        let mut group = test_commit_builder_group().await;
1196        let test_group_id = "foo".as_bytes().to_vec();
1197        let test_cipher_suite = TEST_CIPHER_SUITE;
1198        let test_protocol_version = TEST_PROTOCOL_VERSION;
1199        let mut test_ext = ExtensionList::default();
1200
1201        test_ext
1202            .set_from(RequiredCapabilitiesExt::default())
1203            .unwrap();
1204
1205        let commit_output = group
1206            .commit_builder()
1207            .reinit(
1208                Some(test_group_id.clone()),
1209                test_protocol_version,
1210                test_cipher_suite,
1211                test_ext.clone(),
1212            )
1213            .unwrap()
1214            .build()
1215            .await
1216            .unwrap();
1217
1218        let expected_reinit = group
1219            .reinit_proposal(
1220                Some(test_group_id),
1221                test_protocol_version,
1222                test_cipher_suite,
1223                test_ext,
1224            )
1225            .unwrap();
1226
1227        assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1228    }
1229
1230    #[cfg(feature = "custom_proposal")]
1231    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1232    async fn test_commit_builder_custom_proposal() {
1233        let mut group = test_commit_builder_group().await;
1234
1235        let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1236
1237        let commit_output = group
1238            .commit_builder()
1239            .custom_proposal(proposal.clone())
1240            .build()
1241            .await
1242            .unwrap();
1243
1244        assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1245    }
1246
1247    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1248    async fn test_commit_builder_chaining() {
1249        let mut group = test_commit_builder_group().await;
1250        let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1251        let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1252
1253        let expected_adds = vec![
1254            group.add_proposal(kp1.clone()).unwrap(),
1255            group.add_proposal(kp2.clone()).unwrap(),
1256        ];
1257
1258        let commit_output = group
1259            .commit_builder()
1260            .add_member(kp1)
1261            .unwrap()
1262            .add_member(kp2)
1263            .unwrap()
1264            .build()
1265            .await
1266            .unwrap();
1267
1268        assert_commit_builder_output(group, commit_output, expected_adds, 2);
1269    }
1270
1271    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1272    async fn test_commit_builder_empty_commit() {
1273        let mut group = test_commit_builder_group().await;
1274
1275        let commit_output = group.commit_builder().build().await.unwrap();
1276
1277        assert_commit_builder_output(group, commit_output, vec![], 0);
1278    }
1279
1280    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1281    async fn test_commit_builder_authenticated_data() {
1282        let mut group = test_commit_builder_group().await;
1283        let test_data = "test".as_bytes().to_vec();
1284
1285        let commit_output = group
1286            .commit_builder()
1287            .authenticated_data(test_data.clone())
1288            .build()
1289            .await
1290            .unwrap();
1291
1292        assert_eq!(
1293            commit_output
1294                .commit_message
1295                .into_plaintext()
1296                .unwrap()
1297                .content
1298                .authenticated_data,
1299            test_data
1300        );
1301    }
1302
1303    #[cfg(feature = "by_ref_proposal")]
1304    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1305    async fn test_commit_builder_multiple_welcome_messages() {
1306        let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1307            let options = CommitOptions::new().with_single_welcome_message(false);
1308            b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1309        })
1310        .await;
1311
1312        let (alice, alice_kp) =
1313            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1314
1315        let (bob, bob_kp) =
1316            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1317
1318        group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
1319
1320        group.propose_add(bob_kp.clone(), vec![]).await.unwrap();
1321
1322        let output = group.commit(Vec::new()).await.unwrap();
1323        let welcomes = output.welcome_messages;
1324
1325        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1326
1327        for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1328            let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1329
1330            let welcome = welcomes
1331                .iter()
1332                .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1333                .unwrap();
1334
1335            client.join_group(None, welcome).await.unwrap();
1336
1337            assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1338        }
1339    }
1340
1341    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1342    async fn commit_can_change_credential() {
1343        let cs = TEST_CIPHER_SUITE;
1344        let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1345        let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1346
1347        let commit_output = groups[0]
1348            .commit_builder()
1349            .set_new_signing_identity(secret_key, identity.clone())
1350            .build()
1351            .await
1352            .unwrap();
1353
1354        // Check that the credential was updated by in the committer's state.
1355        groups[0].process_pending_commit().await.unwrap();
1356        let new_member = groups[0].roster().member_with_index(0).unwrap();
1357
1358        assert_eq!(
1359            new_member.signing_identity.credential,
1360            get_test_basic_credential(b"member".to_vec())
1361        );
1362
1363        assert_eq!(
1364            new_member.signing_identity.signature_key,
1365            identity.signature_key
1366        );
1367
1368        // Check that the credential was updated in another member's state.
1369        groups[1]
1370            .process_message(commit_output.commit_message)
1371            .await
1372            .unwrap();
1373
1374        let new_member = groups[1].roster().member_with_index(0).unwrap();
1375
1376        assert_eq!(
1377            new_member.signing_identity.credential,
1378            get_test_basic_credential(b"member".to_vec())
1379        );
1380
1381        assert_eq!(
1382            new_member.signing_identity.signature_key,
1383            identity.signature_key
1384        );
1385    }
1386
1387    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1388    async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1389        let mut group = test_group_custom(
1390            TEST_PROTOCOL_VERSION,
1391            TEST_CIPHER_SUITE,
1392            Default::default(),
1393            None,
1394            Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1395        )
1396        .await;
1397
1398        let commit = group.commit(vec![]).await.unwrap();
1399
1400        group.apply_pending_commit().await.unwrap();
1401
1402        let new_tree = group.export_tree();
1403
1404        assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1405    }
1406
1407    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1408    async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1409        let mut group = test_group_custom(
1410            TEST_PROTOCOL_VERSION,
1411            TEST_CIPHER_SUITE,
1412            Default::default(),
1413            None,
1414            Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1415        )
1416        .await;
1417
1418        let commit = group.commit(vec![]).await.unwrap();
1419
1420        assert!(commit.ratchet_tree.is_none());
1421    }
1422
1423    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1424    async fn commit_includes_external_commit_group_info_if_requested() {
1425        let mut group = test_group_custom(
1426            TEST_PROTOCOL_VERSION,
1427            TEST_CIPHER_SUITE,
1428            Default::default(),
1429            None,
1430            Some(
1431                CommitOptions::new()
1432                    .with_allow_external_commit(true)
1433                    .with_ratchet_tree_extension(false),
1434            ),
1435        )
1436        .await;
1437
1438        let commit = group.commit(vec![]).await.unwrap();
1439
1440        let info = commit
1441            .external_commit_group_info
1442            .unwrap()
1443            .into_group_info()
1444            .unwrap();
1445
1446        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1447        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1448    }
1449
1450    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1451    async fn commit_includes_external_commit_and_tree_if_requested() {
1452        let mut group = test_group_custom(
1453            TEST_PROTOCOL_VERSION,
1454            TEST_CIPHER_SUITE,
1455            Default::default(),
1456            None,
1457            Some(
1458                CommitOptions::new()
1459                    .with_allow_external_commit(true)
1460                    .with_ratchet_tree_extension(true),
1461            ),
1462        )
1463        .await;
1464
1465        let commit = group.commit(vec![]).await.unwrap();
1466
1467        let info = commit
1468            .external_commit_group_info
1469            .unwrap()
1470            .into_group_info()
1471            .unwrap();
1472
1473        assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1474        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1475    }
1476
1477    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1478    async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1479        let mut group = test_group_custom(
1480            TEST_PROTOCOL_VERSION,
1481            TEST_CIPHER_SUITE,
1482            Default::default(),
1483            None,
1484            Some(CommitOptions::new().with_allow_external_commit(false)),
1485        )
1486        .await;
1487
1488        let commit = group.commit(vec![]).await.unwrap();
1489
1490        assert!(commit.external_commit_group_info.is_none());
1491    }
1492
1493    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1494    async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_off(
1495    ) {
1496        let mut group = test_group_custom(
1497            TEST_PROTOCOL_VERSION,
1498            TEST_CIPHER_SUITE,
1499            Default::default(),
1500            None,
1501            Some(
1502                CommitOptions::new()
1503                    .with_always_out_of_band_ratchet_tree(true)
1504                    .with_ratchet_tree_extension(false)
1505                    .with_allow_external_commit(true),
1506            ),
1507        )
1508        .await;
1509
1510        let commit = group.commit(vec![]).await.unwrap();
1511
1512        assert!(commit.ratchet_tree.is_some());
1513
1514        let info = commit
1515            .external_commit_group_info
1516            .unwrap()
1517            .into_group_info()
1518            .unwrap();
1519
1520        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1521    }
1522
1523    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1524    async fn commit_includes_tree_out_of_bounds_and_not_in_external_group_info_if_requested_tree_ext_on(
1525    ) {
1526        let mut group = test_group_custom(
1527            TEST_PROTOCOL_VERSION,
1528            TEST_CIPHER_SUITE,
1529            Default::default(),
1530            None,
1531            Some(
1532                CommitOptions::new()
1533                    .with_always_out_of_band_ratchet_tree(true)
1534                    .with_ratchet_tree_extension(true)
1535                    .with_allow_external_commit(true),
1536            ),
1537        )
1538        .await;
1539
1540        let commit = group.commit(vec![]).await.unwrap();
1541
1542        assert!(commit.ratchet_tree.is_some());
1543
1544        let info = commit
1545            .external_commit_group_info
1546            .unwrap()
1547            .into_group_info()
1548            .unwrap();
1549
1550        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1551    }
1552
1553    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1554    async fn member_identity_is_validated_against_new_extensions() {
1555        let alice = client_with_test_extension(b"alice").await;
1556        let mut alice = alice
1557            .create_group(ExtensionList::new(), Default::default())
1558            .await
1559            .unwrap();
1560
1561        let bob = client_with_test_extension(b"bob").await;
1562        let bob_kp = bob
1563            .generate_key_package_message(Default::default(), Default::default())
1564            .await
1565            .unwrap();
1566
1567        let mut extension_list = ExtensionList::new();
1568        let extension = TestExtension { foo: b'a' };
1569        extension_list.set_from(extension).unwrap();
1570
1571        let res = alice
1572            .commit_builder()
1573            .add_member(bob_kp)
1574            .unwrap()
1575            .set_group_context_ext(extension_list.clone())
1576            .unwrap()
1577            .build()
1578            .await;
1579
1580        assert!(res.is_err());
1581
1582        let alex = client_with_test_extension(b"alex").await;
1583
1584        alice
1585            .commit_builder()
1586            .add_member(
1587                alex.generate_key_package_message(Default::default(), Default::default())
1588                    .await
1589                    .unwrap(),
1590            )
1591            .unwrap()
1592            .set_group_context_ext(extension_list.clone())
1593            .unwrap()
1594            .build()
1595            .await
1596            .unwrap();
1597    }
1598
1599    #[cfg(feature = "by_ref_proposal")]
1600    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1601    async fn server_identity_is_validated_against_new_extensions() {
1602        let alice = client_with_test_extension(b"alice").await;
1603        let mut alice = alice
1604            .create_group(ExtensionList::new(), Default::default())
1605            .await
1606            .unwrap();
1607
1608        let mut extension_list = ExtensionList::new();
1609        let extension = TestExtension { foo: b'a' };
1610        extension_list.set_from(extension).unwrap();
1611
1612        let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1613
1614        let mut alex_extensions = extension_list.clone();
1615
1616        alex_extensions
1617            .set_from(ExternalSendersExt {
1618                allowed_senders: vec![alex_server],
1619            })
1620            .unwrap();
1621
1622        let res = alice
1623            .commit_builder()
1624            .set_group_context_ext(alex_extensions)
1625            .unwrap()
1626            .build()
1627            .await;
1628
1629        assert!(res.is_err());
1630
1631        let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1632
1633        let mut bob_extensions = extension_list;
1634
1635        bob_extensions
1636            .set_from(ExternalSendersExt {
1637                allowed_senders: vec![bob_server],
1638            })
1639            .unwrap();
1640
1641        alice
1642            .commit_builder()
1643            .set_group_context_ext(bob_extensions)
1644            .unwrap()
1645            .build()
1646            .await
1647            .unwrap();
1648    }
1649
1650    #[derive(Debug, Clone)]
1651    struct IdentityProviderWithExtension(BasicIdentityProvider);
1652
1653    #[derive(Clone, Debug)]
1654    #[cfg_attr(feature = "std", derive(thiserror::Error))]
1655    #[cfg_attr(feature = "std", error("test error"))]
1656    struct IdentityProviderWithExtensionError {}
1657
1658    impl IntoAnyError for IdentityProviderWithExtensionError {
1659        #[cfg(feature = "std")]
1660        fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1661            Ok(self.into())
1662        }
1663    }
1664
1665    impl IdentityProviderWithExtension {
1666        // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
1667        // is not set.
1668        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1669        async fn starts_with_foo(
1670            &self,
1671            identity: &SigningIdentity,
1672            _timestamp: Option<MlsTime>,
1673            extensions: Option<&ExtensionList>,
1674        ) -> bool {
1675            if let Some(extensions) = extensions {
1676                if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1677                    self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1678                } else {
1679                    true
1680                }
1681            } else {
1682                true
1683            }
1684        }
1685    }
1686
1687    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1688    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1689    impl IdentityProvider for IdentityProviderWithExtension {
1690        type Error = IdentityProviderWithExtensionError;
1691
1692        async fn validate_member(
1693            &self,
1694            identity: &SigningIdentity,
1695            timestamp: Option<MlsTime>,
1696            context: MemberValidationContext<'_>,
1697        ) -> Result<(), Self::Error> {
1698            self.starts_with_foo(identity, timestamp, context.new_extensions())
1699                .await
1700                .then_some(())
1701                .ok_or(IdentityProviderWithExtensionError {})
1702        }
1703
1704        async fn validate_external_sender(
1705            &self,
1706            identity: &SigningIdentity,
1707            timestamp: Option<MlsTime>,
1708            extensions: Option<&ExtensionList>,
1709        ) -> Result<(), Self::Error> {
1710            (!self.starts_with_foo(identity, timestamp, extensions).await)
1711                .then_some(())
1712                .ok_or(IdentityProviderWithExtensionError {})
1713        }
1714
1715        async fn identity(
1716            &self,
1717            signing_identity: &SigningIdentity,
1718            extensions: &ExtensionList,
1719        ) -> Result<Vec<u8>, Self::Error> {
1720            self.0
1721                .identity(signing_identity, extensions)
1722                .await
1723                .map_err(|_| IdentityProviderWithExtensionError {})
1724        }
1725
1726        async fn valid_successor(
1727            &self,
1728            _predecessor: &SigningIdentity,
1729            _successor: &SigningIdentity,
1730            _extensions: &ExtensionList,
1731        ) -> Result<bool, Self::Error> {
1732            Ok(true)
1733        }
1734
1735        fn supported_types(&self) -> Vec<CredentialType> {
1736            self.0.supported_types()
1737        }
1738    }
1739
    /// Client configuration produced by `client_with_test_extension`.
    ///
    /// It layers `IdentityProviderWithExtension` over the test crypto provider and the base
    /// config.
    type ExtensionClientConfig = WithIdentityProvider<
        IdentityProviderWithExtension,
        WithCryptoProvider<TestCryptoProvider, BaseConfig>,
    >;
1744
1745    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1746    async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1747        let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1748
1749        ClientBuilder::new()
1750            .crypto_provider(TestCryptoProvider::new())
1751            .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1752            .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1753            .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1754            .build()
1755    }
1756
1757    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1758    async fn detached_commit() {
1759        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1760
1761        let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1762        assert!(group.pending_commit.is_none());
1763        group.apply_detached_commit(secrets).await.unwrap();
1764        assert_eq!(group.context().epoch, 1);
1765    }
1766}