mls_rs/group/
commit.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::boxed::Box;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
10use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
11
12use crate::{
13    cipher_suite::CipherSuite,
14    client::MlsError,
15    client_config::ClientConfig,
16    extension::RatchetTreeExt,
17    identity::SigningIdentity,
18    protocol_version::ProtocolVersion,
19    signer::Signable,
20    tree_kem::{
21        kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
22    },
23    ExtensionList, MlsRules,
24};
25
26#[cfg(all(not(mls_build_async), feature = "rayon"))]
27use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
28
29use crate::tree_kem::leaf_node::LeafNode;
30
31#[cfg(not(feature = "private_message"))]
32use crate::WireFormat;
33
34#[cfg(feature = "psk")]
35use crate::{
36    group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
37    psk::ExternalPskId,
38};
39
40use super::{
41    confirmation_tag::ConfirmationTag,
42    framing::{Content, MlsMessage, MlsMessagePayload, Sender},
43    key_schedule::{KeySchedule, WelcomeSecret},
44    message_hash::MessageHash,
45    message_processor::{path_update_required, MessageProcessor},
46    message_signature::AuthenticatedContent,
47    mls_rules::CommitDirection,
48    proposal::{Proposal, ProposalOrRef},
49    CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
50    Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
51    PendingCommitSnapshot, Welcome,
52};
53
54#[cfg(not(feature = "by_ref_proposal"))]
55use super::proposal_cache::prepare_commit;
56
57#[cfg(feature = "custom_proposal")]
58use super::proposal::CustomProposal;
59
60#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
61#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
62#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
63pub(crate) struct Commit {
64    pub proposals: Vec<ProposalOrRef>,
65    pub path: Option<UpdatePath>,
66}
67
68#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
69pub(crate) struct PendingCommit {
70    pub(crate) state: GroupState,
71    pub(crate) epoch_secrets: EpochSecrets,
72    pub(crate) private_tree: TreeKemPrivate,
73    pub(crate) key_schedule: KeySchedule,
74    pub(crate) signer: SignatureSecretKey,
75
76    pub(crate) output: CommitMessageDescription,
77
78    pub(crate) commit_message_hash: MessageHash,
79}
80
81#[cfg_attr(
82    all(feature = "ffi", not(test)),
83    safer_ffi_gen::ffi_type(clone, opaque)
84)]
85#[derive(Clone)]
86pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
87
88impl CommitSecrets {
89    /// Deserialize the commit secrets from bytes
90    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
91        Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
92    }
93
94    /// Serialize the commit secrets to bytes
95    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
96        Ok(self.0.mls_encode_to_vec()?)
97    }
98}
99
100#[cfg_attr(
101    all(feature = "ffi", not(test)),
102    safer_ffi_gen::ffi_type(clone, opaque)
103)]
104#[derive(Clone, Debug)]
105#[non_exhaustive]
106/// Result of MLS commit operation using
107/// [`Group::commit`](crate::group::Group::commit) or
108/// [`CommitBuilder::build`](CommitBuilder::build).
109pub struct CommitOutput {
110    /// Commit message to send to other group members.
111    pub commit_message: MlsMessage,
112    /// Welcome messages to send to new group members. If the commit does not add members,
113    /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
114    /// set to true, then this list contains a single message sent to all members. Else, the list
115    /// contains one message for each added member. Recipients of each message can be identified using
116    /// [`MlsMessage::key_package_reference`] of their key packages and
117    /// [`MlsMessage::welcome_key_package_references`].
118    pub welcome_messages: Vec<MlsMessage>,
119    /// Ratchet tree that can be sent out of band if
120    /// `ratchet_tree_extension` is not used according to
121    /// [`MlsRules::commit_options`].
122    pub ratchet_tree: Option<ExportedTree<'static>>,
123    /// A group info that can be provided to new members in order to enable external commit
124    /// functionality. This value is set if [`MlsRules::commit_options`] returns
125    /// `allow_external_commit` set to true.
126    pub external_commit_group_info: Option<MlsMessage>,
127    /// Proposals that were received in the prior epoch but not included in the following commit.
128    #[cfg(feature = "by_ref_proposal")]
129    pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
130    /// Indicator that the commit contains a path update
131    pub contains_update_path: bool,
132}
133
134#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
135impl CommitOutput {
136    /// Commit message to send to other group members.
137    #[cfg(feature = "ffi")]
138    pub fn commit_message(&self) -> &MlsMessage {
139        &self.commit_message
140    }
141
142    /// Welcome message to send to new group members.
143    #[cfg(feature = "ffi")]
144    pub fn welcome_messages(&self) -> &[MlsMessage] {
145        &self.welcome_messages
146    }
147
148    /// Ratchet tree that can be sent out of band if
149    /// `ratchet_tree_extension` is not used according to
150    /// [`MlsRules::commit_options`].
151    #[cfg(feature = "ffi")]
152    pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
153        self.ratchet_tree.as_ref()
154    }
155
156    /// A group info that can be provided to new members in order to enable external commit
157    /// functionality. This value is set if [`MlsRules::commit_options`] returns
158    /// `allow_external_commit` set to true.
159    #[cfg(feature = "ffi")]
160    pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
161        self.external_commit_group_info.as_ref()
162    }
163
164    /// Proposals that were received in the prior epoch but not included in the following commit.
165    #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
166    pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
167        &self.unused_proposals
168    }
169}
170
171/// Build a commit with multiple proposals by-value.
172///
173/// Proposals within a commit can be by-value or by-reference.
174/// Proposals received during the current epoch will be added to the resulting
175/// commit by-reference automatically so long as they pass the rules defined
176/// in the current
177/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
178pub struct CommitBuilder<'a, C>
179where
180    C: ClientConfig + Clone,
181{
182    group: &'a mut Group<C>,
183    pub(super) proposals: Vec<Proposal>,
184    authenticated_data: Vec<u8>,
185    group_info_extensions: ExtensionList,
186    new_signer: Option<SignatureSecretKey>,
187    new_signing_identity: Option<SigningIdentity>,
188    new_leaf_node_extensions: Option<ExtensionList>,
189}
190
191impl<'a, C> CommitBuilder<'a, C>
192where
193    C: ClientConfig + Clone,
194{
195    /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
196    /// the current commit that is being built.
197    pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
198        let proposal = self.group.add_proposal(key_package)?;
199        self.proposals.push(proposal);
200        Ok(self)
201    }
202
203    /// Set group info extensions that will be inserted into the resulting
204    /// [welcome messages](CommitOutput::welcome_messages) for new members.
205    ///
206    /// Group info extensions that are transmitted as part of a welcome message
207    /// are encrypted along with other private values.
208    ///
209    /// These extensions can be retrieved as part of
210    /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
211    /// by joining the group via
212    /// [`Client::join_group`](crate::Client::join_group).
213    pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
214        Self {
215            group_info_extensions: extensions,
216            ..self
217        }
218    }
219
220    /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
221    /// the current commit that is being built.
222    pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
223        let proposal = self.group.remove_proposal(index)?;
224        self.proposals.push(proposal);
225        Ok(self)
226    }
227
228    /// Insert a
229    /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
230    /// into the current commit that is being built.
231    pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
232        let proposal = self.group.group_context_extensions_proposal(extensions);
233        self.proposals.push(proposal);
234        Ok(self)
235    }
236
237    /// Insert a
238    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
239    /// an external PSK into the current commit that is being built.
240    #[cfg(feature = "psk")]
241    pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
242        let key_id = JustPreSharedKeyID::External(psk_id);
243        let proposal = self.group.psk_proposal(key_id)?;
244        self.proposals.push(proposal);
245        Ok(self)
246    }
247
248    /// Insert a
249    /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
250    /// a resumption PSK into the current commit that is being built.
251    #[cfg(feature = "psk")]
252    pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
253        let psk_id = ResumptionPsk {
254            psk_epoch,
255            usage: ResumptionPSKUsage::Application,
256            psk_group_id: PskGroupId(self.group.group_id().to_vec()),
257        };
258
259        let key_id = JustPreSharedKeyID::Resumption(psk_id);
260        let proposal = self.group.psk_proposal(key_id)?;
261        self.proposals.push(proposal);
262        Ok(self)
263    }
264
265    /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
266    /// the current commit that is being built.
267    pub fn reinit(
268        mut self,
269        group_id: Option<Vec<u8>>,
270        version: ProtocolVersion,
271        cipher_suite: CipherSuite,
272        extensions: ExtensionList,
273    ) -> Result<Self, MlsError> {
274        let proposal = self
275            .group
276            .reinit_proposal(group_id, version, cipher_suite, extensions)?;
277
278        self.proposals.push(proposal);
279        Ok(self)
280    }
281
282    /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
283    /// the current commit that is being built.
284    #[cfg(feature = "custom_proposal")]
285    pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
286        self.proposals.push(Proposal::Custom(proposal));
287        self
288    }
289
290    /// Insert a proposal that was previously constructed such as when a
291    /// proposal is returned from
292    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
293    pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
294        self.proposals.push(proposal);
295        self
296    }
297
298    /// Insert proposals that were previously constructed such as when a
299    /// proposal is returned from
300    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
301    pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
302        self.proposals.append(&mut proposals);
303        self
304    }
305
306    /// Add additional authenticated data to the commit.
307    ///
308    /// # Warning
309    ///
310    /// The data provided here is always sent unencrypted.
311    pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
312        Self {
313            authenticated_data,
314            ..self
315        }
316    }
317
318    /// Change the committer's signing identity as part of making this commit.
319    /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
320    /// in use by the group considers the credential inside this signing_identity
321    /// [valid](crate::IdentityProvider::validate_member)
322    /// and results in the same
323    /// [identity](crate::IdentityProvider::identity)
324    /// being used.
325    pub fn set_new_signing_identity(
326        self,
327        signer: SignatureSecretKey,
328        signing_identity: SigningIdentity,
329    ) -> Self {
330        Self {
331            new_signer: Some(signer),
332            new_signing_identity: Some(signing_identity),
333            ..self
334        }
335    }
336
337    /// Change the committer's leaf node extensions as part of making this commit.
338    pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
339        Self {
340            new_leaf_node_extensions: Some(new_leaf_node_extensions),
341            ..self
342        }
343    }
344
345    /// Finalize the commit to send.
346    ///
347    /// # Errors
348    ///
349    /// This function will return an error if any of the proposals provided
350    /// are not contextually valid according to the rules defined by the
351    /// MLS RFC, or if they do not pass the custom rules defined by the current
352    /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
353    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
354    pub async fn build(self) -> Result<CommitOutput, MlsError> {
355        let (output, pending_commit) = self
356            .group
357            .commit_internal(
358                self.proposals,
359                None,
360                self.authenticated_data,
361                self.group_info_extensions,
362                self.new_signer,
363                self.new_signing_identity,
364                self.new_leaf_node_extensions,
365            )
366            .await?;
367
368        self.group.pending_commit = pending_commit.try_into()?;
369
370        Ok(output)
371    }
372
373    /// The same function as `GroupBuilder::build` except the secrets generated
374    /// for the commit are outputted instead of being cached internally.
375    ///
376    /// A detached commit can be applied using `Group::apply_detached_commit`.
377    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
378    pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
379        let (output, pending_commit) = self
380            .group
381            .commit_internal(
382                self.proposals,
383                None,
384                self.authenticated_data,
385                self.group_info_extensions,
386                self.new_signer,
387                self.new_signing_identity,
388                self.new_leaf_node_extensions,
389            )
390            .await?;
391
392        Ok((
393            output,
394            CommitSecrets(PendingCommitSnapshot::PendingCommit(
395                pending_commit.mls_encode_to_vec()?,
396            )),
397        ))
398    }
399}
400
401impl<C> Group<C>
402where
403    C: ClientConfig + Clone,
404{
405    /// Perform a commit of received proposals.
406    ///
407    /// This function is the equivalent of [`Group::commit_builder`] immediately
408    /// followed by [`CommitBuilder::build`]. Any received proposals since the
409    /// last commit will be included in the resulting message by-reference.
410    ///
411    /// Data provided in the `authenticated_data` field will be placed into
412    /// the resulting commit message unencrypted.
413    ///
414    /// # Pending Commits
415    ///
416    /// When a commit is created, it is not applied immediately in order to
417    /// allow for the resolution of conflicts when multiple members of a group
418    /// attempt to make commits at the same time. For example, a central relay
419    /// can be used to decide which commit should be accepted by the group by
420    /// determining a consistent view of commit packet order for all clients.
421    ///
422    /// Pending commits are stored internally as part of the group's state
423    /// so they do not need to be tracked outside of this library. Any commit
424    /// message that is processed before calling [Group::apply_pending_commit]
425    /// will clear the currently pending commit.
426    ///
427    /// # Empty Commits
428    ///
429    /// Sending a commit that contains no proposals is a valid operation
430    /// within the MLS protocol. It is useful for providing stronger forward
431    /// secrecy and post-compromise security, especially for long running
432    /// groups when group membership does not change often.
433    ///
434    /// # Path Updates
435    ///
436    /// Path updates provide forward secrecy and post-compromise security
437    /// within the MLS protocol.
438    /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
439    /// controls the ability of a group to send a commit without a path update.
440    /// An update path will automatically be sent if there are no proposals
441    /// in the commit, or if any proposal other than
442    /// [`Add`](crate::group::proposal::Proposal::Add),
443    /// [`Psk`](crate::group::proposal::Proposal::Psk),
444    /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
445    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
446    pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
447        self.commit_builder()
448            .authenticated_data(authenticated_data)
449            .build()
450            .await
451    }
452
453    /// The same function as `Group::commit` except the secrets generated
454    /// for the commit are outputted instead of being cached internally.
455    ///
456    /// A detached commit can be applied using `Group::apply_detached_commit`.
457    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
458    pub async fn commit_detached(
459        &mut self,
460        authenticated_data: Vec<u8>,
461    ) -> Result<(CommitOutput, CommitSecrets), MlsError> {
462        self.commit_builder()
463            .authenticated_data(authenticated_data)
464            .build_detached()
465            .await
466    }
467
468    /// Create a new commit builder that can include proposals
469    /// by-value.
470    pub fn commit_builder(&mut self) -> CommitBuilder<C> {
471        CommitBuilder {
472            group: self,
473            proposals: Default::default(),
474            authenticated_data: Default::default(),
475            group_info_extensions: Default::default(),
476            new_signer: Default::default(),
477            new_signing_identity: Default::default(),
478            new_leaf_node_extensions: Default::default(),
479        }
480    }
481
482    /// Returns commit and optional [`MlsMessage`] containing a welcome message
483    /// for newly added members.
484    #[allow(clippy::too_many_arguments)]
485    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
486    pub(super) async fn commit_internal(
487        &mut self,
488        proposals: Vec<Proposal>,
489        external_leaf: Option<&LeafNode>,
490        authenticated_data: Vec<u8>,
491        mut welcome_group_info_extensions: ExtensionList,
492        new_signer: Option<SignatureSecretKey>,
493        new_signing_identity: Option<SigningIdentity>,
494        new_leaf_node_extensions: Option<ExtensionList>,
495    ) -> Result<(CommitOutput, PendingCommit), MlsError> {
496        if !self.pending_commit.is_none() {
497            return Err(MlsError::ExistingPendingCommit);
498        }
499
500        if self.state.pending_reinit.is_some() {
501            return Err(MlsError::GroupUsedAfterReInit);
502        }
503
504        let mls_rules = self.config.mls_rules();
505
506        let is_external = external_leaf.is_some();
507
508        // Construct an initial Commit object with the proposals field populated from Proposals
509        // received during the current epoch, and an empty path field. Add passed in proposals
510        // by value
511        let sender = if is_external {
512            Sender::NewMemberCommit
513        } else {
514            Sender::Member(*self.private_tree.self_index)
515        };
516
517        let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
518        let old_signer = &self.signer;
519
520        #[cfg(feature = "std")]
521        let time = Some(crate::time::MlsTime::now());
522
523        #[cfg(not(feature = "std"))]
524        let time = None;
525
526        #[cfg(feature = "by_ref_proposal")]
527        let proposals = self.state.proposals.prepare_commit(sender, proposals);
528
529        #[cfg(not(feature = "by_ref_proposal"))]
530        let proposals = prepare_commit(sender, proposals);
531
532        let mut provisional_state = self
533            .state
534            .apply_resolved(
535                sender,
536                proposals,
537                external_leaf,
538                &self.config.identity_provider(),
539                &self.cipher_suite_provider,
540                &self.config.secret_store(),
541                &mls_rules,
542                time,
543                CommitDirection::Send,
544            )
545            .await?;
546
547        let (mut provisional_private_tree, _) =
548            self.provisional_private_tree(&provisional_state)?;
549
550        if is_external {
551            provisional_private_tree.self_index = provisional_state
552                .external_init_index
553                .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
554
555            self.private_tree.self_index = provisional_private_tree.self_index;
556        }
557
558        // Decide whether to populate the path field: If the path field is required based on the
559        // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
560        // sender MAY omit the path field at its discretion.
561        let commit_options = mls_rules
562            .commit_options(
563                &provisional_state.public_tree.roster(),
564                &provisional_state.group_context,
565                &provisional_state.applied_proposals,
566            )
567            .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
568
569        let perform_path_update = commit_options.path_required
570            || path_update_required(&provisional_state.applied_proposals);
571
572        let (update_path, path_secrets, commit_secret) = if perform_path_update {
573            // If populating the path field: Create an UpdatePath using the new tree. Any new
574            // member (from an add proposal) MUST be excluded from the resolution during the
575            // computation of the UpdatePath. The GroupContext for this operation uses the
576            // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
577            // GroupContext object. The leaf_key_package for this UpdatePath must have a
578            // parent_hash extension.
579
580            let new_leaf_node_extensions =
581                new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
582
583            let new_leaf_node_extensions = match new_leaf_node_extensions {
584                Some(extensions) => extensions,
585                // If we are not setting new extensions and this is not an external leaf then the current node MUST exist.
586                None => self.current_user_leaf_node()?.ungreased_extensions(),
587            };
588
589            let encap_gen = TreeKem::new(
590                &mut provisional_state.public_tree,
591                &mut provisional_private_tree,
592            )
593            .encap(
594                &mut provisional_state.group_context,
595                &provisional_state.indexes_of_added_kpkgs,
596                &new_signer,
597                Some(self.config.leaf_properties(new_leaf_node_extensions)),
598                new_signing_identity,
599                &self.cipher_suite_provider,
600                #[cfg(test)]
601                &self.commit_modifiers,
602            )
603            .await?;
604
605            (
606                Some(encap_gen.update_path),
607                Some(encap_gen.path_secrets),
608                encap_gen.commit_secret,
609            )
610        } else {
611            // Update the tree hash, since it was not updated by encap.
612            provisional_state
613                .public_tree
614                .update_hashes(
615                    &[provisional_private_tree.self_index],
616                    &self.cipher_suite_provider,
617                )
618                .await?;
619
620            provisional_state.group_context.tree_hash = provisional_state
621                .public_tree
622                .tree_hash(&self.cipher_suite_provider)
623                .await?;
624
625            (None, None, PathSecret::empty(&self.cipher_suite_provider))
626        };
627
628        #[cfg(feature = "psk")]
629        let (psk_secret, psks) = self
630            .get_psk(&provisional_state.applied_proposals.psks)
631            .await?;
632
633        #[cfg(not(feature = "psk"))]
634        let psk_secret = self.get_psk();
635
636        let added_key_pkgs: Vec<_> = provisional_state
637            .applied_proposals
638            .additions
639            .iter()
640            .map(|info| info.proposal.key_package.clone())
641            .collect();
642
643        let commit = Commit {
644            proposals: provisional_state.applied_proposals.proposals_or_refs(),
645            path: update_path,
646        };
647
648        let mut auth_content = AuthenticatedContent::new_signed(
649            &self.cipher_suite_provider,
650            self.context(),
651            sender,
652            Content::Commit(Box::new(commit)),
653            old_signer,
654            #[cfg(feature = "private_message")]
655            self.encryption_options()?.control_wire_format(sender),
656            #[cfg(not(feature = "private_message"))]
657            WireFormat::PublicMessage,
658            authenticated_data,
659        )
660        .await?;
661
662        // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
663        // compute the confirmation_tag value in the MlsPlaintext.
664        let confirmed_transcript_hash = super::transcript_hash::create(
665            self.cipher_suite_provider(),
666            &self.state.interim_transcript_hash,
667            &auth_content,
668        )
669        .await?;
670
671        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
672
673        let key_schedule_result = KeySchedule::from_key_schedule(
674            &self.key_schedule,
675            &commit_secret,
676            &provisional_state.group_context,
677            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
678            provisional_state.public_tree.total_leaf_count(),
679            &psk_secret,
680            &self.cipher_suite_provider,
681        )
682        .await?;
683
684        let confirmation_tag = ConfirmationTag::create(
685            &key_schedule_result.confirmation_key,
686            &provisional_state.group_context.confirmed_transcript_hash,
687            &self.cipher_suite_provider,
688        )
689        .await?;
690
691        let interim_transcript_hash = InterimTranscriptHash::create(
692            self.cipher_suite_provider(),
693            &provisional_state.group_context.confirmed_transcript_hash,
694            &confirmation_tag,
695        )
696        .await?;
697
698        auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
699
700        let ratchet_tree_ext = commit_options
701            .ratchet_tree_extension
702            .then(|| RatchetTreeExt {
703                tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
704            });
705
706        // Generate external commit group info if required by commit_options
707        let external_commit_group_info = match commit_options.allow_external_commit {
708            true => {
709                let mut extensions = ExtensionList::new();
710
711                extensions.set_from({
712                    key_schedule_result
713                        .key_schedule
714                        .get_external_key_pair_ext(&self.cipher_suite_provider)
715                        .await?
716                })?;
717
718                if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
719                    extensions.set_from(ratchet_tree_ext.clone())?;
720                }
721
722                let info = self
723                    .make_group_info(
724                        &provisional_state.group_context,
725                        extensions,
726                        &confirmation_tag,
727                        &new_signer,
728                    )
729                    .await?;
730
731                let msg =
732                    MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
733
734                Some(msg)
735            }
736            false => None,
737        };
738
739        // Build the group info that will be placed into the welcome messages.
740        // Add the ratchet tree extension if necessary
741        if let Some(ratchet_tree_ext) = ratchet_tree_ext {
742            welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
743        }
744
745        let welcome_group_info = self
746            .make_group_info(
747                &provisional_state.group_context,
748                welcome_group_info_extensions,
749                &confirmation_tag,
750                &new_signer,
751            )
752            .await?;
753
754        // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
755        // the new epoch
756        let welcome_secret = WelcomeSecret::from_joiner_secret(
757            &self.cipher_suite_provider,
758            &key_schedule_result.joiner_secret,
759            &psk_secret,
760        )
761        .await?;
762
763        let encrypted_group_info = welcome_secret
764            .encrypt(&welcome_group_info.mls_encode_to_vec()?)
765            .await?;
766
767        // Encrypt path secrets and joiner secret to new members
768        let path_secrets = path_secrets.as_ref();
769
770        #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
771        let encrypted_path_secrets: Vec<_> = added_key_pkgs
772            .into_par_iter()
773            .zip(&provisional_state.indexes_of_added_kpkgs)
774            .map(|(key_package, leaf_index)| {
775                self.encrypt_group_secrets(
776                    &key_package,
777                    *leaf_index,
778                    &key_schedule_result.joiner_secret,
779                    path_secrets,
780                    #[cfg(feature = "psk")]
781                    psks.clone(),
782                    &encrypted_group_info,
783                )
784            })
785            .try_collect()?;
786
787        #[cfg(any(mls_build_async, not(feature = "rayon")))]
788        let encrypted_path_secrets = {
789            let mut secrets = Vec::new();
790
791            for (key_package, leaf_index) in added_key_pkgs
792                .into_iter()
793                .zip(&provisional_state.indexes_of_added_kpkgs)
794            {
795                secrets.push(
796                    self.encrypt_group_secrets(
797                        &key_package,
798                        *leaf_index,
799                        &key_schedule_result.joiner_secret,
800                        path_secrets,
801                        #[cfg(feature = "psk")]
802                        psks.clone(),
803                        &encrypted_group_info,
804                    )
805                    .await?,
806                );
807            }
808
809            secrets
810        };
811
812        let welcome_messages =
813            if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
814                vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
815            } else {
816                encrypted_path_secrets
817                    .into_iter()
818                    .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
819                    .collect()
820            };
821
822        let commit_message = self.format_for_wire(auth_content.clone()).await?;
823
824        // TODO is it necessary to clone the tree here? or can we just output serialized bytes?
825        let ratchet_tree = (!commit_options.ratchet_tree_extension)
826            .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
827
828        let pending_reinit = provisional_state
829            .applied_proposals
830            .reinitializations
831            .first();
832
833        let pending_commit = PendingCommit {
834            output: CommitMessageDescription {
835                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
836                authenticated_data: auth_content.content.authenticated_data,
837                committer: *provisional_private_tree.self_index,
838                effect: match pending_reinit {
839                    Some(r) => CommitEffect::ReInit(r.clone()),
840                    None => CommitEffect::NewEpoch(
841                        NewEpoch::new(self.state.clone(), &provisional_state).into(),
842                    ),
843                },
844            },
845
846            state: GroupState {
847                #[cfg(feature = "by_ref_proposal")]
848                proposals: crate::group::ProposalCache::new(
849                    self.protocol_version(),
850                    self.group_id().to_vec(),
851                ),
852                context: provisional_state.group_context,
853                public_tree: provisional_state.public_tree,
854                interim_transcript_hash,
855                pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
856                confirmation_tag,
857            },
858
859            commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
860                .await?,
861            signer: new_signer,
862            epoch_secrets: key_schedule_result.epoch_secrets,
863            key_schedule: key_schedule_result.key_schedule,
864
865            private_tree: provisional_private_tree,
866        };
867
868        let output = CommitOutput {
869            commit_message,
870            welcome_messages,
871            ratchet_tree,
872            external_commit_group_info,
873            contains_update_path: perform_path_update,
874            #[cfg(feature = "by_ref_proposal")]
875            unused_proposals: provisional_state.unused_proposals,
876        };
877
878        Ok((output, pending_commit))
879    }
880
881    // Construct a GroupInfo reflecting the new state
882    // Group ID, epoch, tree, and confirmed transcript hash from the new state
883    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
884    async fn make_group_info(
885        &self,
886        group_context: &GroupContext,
887        extensions: ExtensionList,
888        confirmation_tag: &ConfirmationTag,
889        signer: &SignatureSecretKey,
890    ) -> Result<GroupInfo, MlsError> {
891        let mut group_info = GroupInfo {
892            group_context: group_context.clone(),
893            extensions,
894            confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
895            signer: LeafIndex(self.current_member_index()),
896            signature: vec![],
897        };
898
899        group_info.grease(self.cipher_suite_provider())?;
900
901        // Sign the GroupInfo using the member's private signing key
902        group_info
903            .sign(&self.cipher_suite_provider, signer, &())
904            .await?;
905
906        Ok(group_info)
907    }
908
909    fn make_welcome_message(
910        &self,
911        secrets: Vec<EncryptedGroupSecrets>,
912        encrypted_group_info: Vec<u8>,
913    ) -> MlsMessage {
914        MlsMessage::new(
915            self.context().protocol_version,
916            MlsMessagePayload::Welcome(Welcome {
917                cipher_suite: self.context().cipher_suite,
918                secrets,
919                encrypted_group_info,
920            }),
921        )
922    }
923}
924
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Hooks that tests can plug in to tamper with parts of a commit.
    ///
    /// `CommitModifiers::default()` installs hooks that change nothing.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Hook over a leaf node and a signing key; the default returns `None`.
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Hook over the public tree; the default does nothing.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Hook over update path nodes; the default returns its input unchanged.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    impl Default for CommitModifiers {
        fn default() -> Self {
            fn leave_leaf(_: &mut LeafNode, _: &SignatureSecretKey) -> Option<SignatureSecretKey> {
                None
            }

            fn leave_tree(_: &mut TreeKemPublic) {}

            fn leave_path(path: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
                path
            }

            Self {
                modify_leaf: leave_leaf,
                modify_tree: leave_tree,
                modify_path: leave_path,
            }
        }
    }
}
951
952#[cfg(test)]
953mod tests {
954    use mls_rs_core::{
955        error::IntoAnyError,
956        extension::ExtensionType,
957        identity::{CredentialType, IdentityProvider, MemberValidationContext},
958        time::MlsTime,
959    };
960
961    use crate::extension::RequiredCapabilitiesExt;
962    use crate::{
963        client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
964        client_builder::{
965            test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
966            WithIdentityProvider,
967        },
968        client_config::ClientConfig,
969        crypto::test_utils::TestCryptoProvider,
970        extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
971        group::test_utils::{test_group, test_group_custom},
972        group::{
973            proposal::ProposalType,
974            test_utils::{test_group_custom_config, test_n_member_group},
975        },
976        identity::test_utils::get_test_signing_identity,
977        identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
978        key_package::test_utils::test_key_package_message,
979        mls_rules::CommitOptions,
980        Client,
981    };
982
983    #[cfg(feature = "by_ref_proposal")]
984    use crate::crypto::test_utils::test_cipher_suite_provider;
985    #[cfg(feature = "by_ref_proposal")]
986    use crate::extension::ExternalSendersExt;
987    #[cfg(feature = "by_ref_proposal")]
988    use crate::group::mls_rules::DefaultMlsRules;
989
990    #[cfg(feature = "psk")]
991    use crate::{
992        group::proposal::PreSharedKeyProposal,
993        psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
994    };
995
996    use super::*;
997
998    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
999    async fn test_commit_builder_group() -> Group<TestClientConfig> {
1000        test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1001            b.custom_proposal_type(ProposalType::from(42))
1002                .extension_type(TEST_EXTENSION_TYPE.into())
1003        })
1004        .await
1005        .group
1006    }
1007
    /// Asserts that `commit_output` carries a plaintext commit whose by-value proposals
    /// match `expected`, and that its welcome output matches `welcome_count`.
    ///
    /// * `group` is the committer's group; its context supplies the expected protocol
    ///   version and cipher suite of the welcome message.
    /// * `expected` must have the same length as the commit's proposal list, and every
    ///   committed proposal must appear in it.
    /// * `welcome_count == 0` means no welcome messages are expected. Otherwise the last
    ///   welcome message must contain exactly `welcome_count` encrypted secrets.
    ///
    /// Panics, as a test assertion, on any mismatch.
    fn assert_commit_builder_output<C: ClientConfig>(
        group: Group<C>,
        mut commit_output: CommitOutput,
        expected: Vec<Proposal>,
        welcome_count: usize,
    ) {
        let plaintext = commit_output.commit_message.into_plaintext().unwrap();

        let commit_data = match plaintext.content.content {
            Content::Commit(commit) => commit,
            // Other content variants only exist when these features are enabled.
            #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
            _ => panic!("Found non-commit data"),
        };

        assert_eq!(commit_data.proposals.len(), expected.len());

        commit_data.proposals.into_iter().for_each(|proposal| {
            // The builder is expected to inline every proposal by value.
            let proposal = match proposal {
                ProposalOrRef::Proposal(p) => p,
                #[cfg(feature = "by_ref_proposal")]
                ProposalOrRef::Reference(_) => panic!("found proposal reference"),
            };

            // External PSK proposals are matched on the external PSK id only. The other
            // fields (presumably the nonce) are not compared because they can differ
            // between independently built proposals. TODO confirm.
            #[cfg(feature = "psk")]
            if let Some(psk_id) = match proposal.as_ref() {
                Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
                _ => None,
            } {
                let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));

                assert!(found)
            } else {
                assert!(expected.contains(&proposal));
            }

            #[cfg(not(feature = "psk"))]
            assert!(expected.contains(&proposal));
        });

        if welcome_count > 0 {
            // Only the last welcome message is inspected.
            let welcome_msg = commit_output.welcome_messages.pop().unwrap();

            assert_eq!(welcome_msg.version, group.state.context.protocol_version);

            let welcome_msg = welcome_msg.into_welcome().unwrap();

            assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
            assert_eq!(welcome_msg.secrets.len(), welcome_count);
        } else {
            assert!(commit_output.welcome_messages.is_empty());
        }
    }
1060
1061    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1062    async fn test_commit_builder_add() {
1063        let mut group = test_commit_builder_group().await;
1064
1065        let test_key_package =
1066            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1067
1068        let commit_output = group
1069            .commit_builder()
1070            .add_member(test_key_package.clone())
1071            .unwrap()
1072            .build()
1073            .await
1074            .unwrap();
1075
1076        let expected_add = group.add_proposal(test_key_package).unwrap();
1077
1078        assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
1079    }
1080
1081    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1082    async fn test_commit_builder_add_with_ext() {
1083        let mut group = test_commit_builder_group().await;
1084
1085        let (bob_client, bob_key_package) =
1086            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1087
1088        let ext = TestExtension { foo: 42 };
1089        let mut extension_list = ExtensionList::default();
1090        extension_list.set_from(ext.clone()).unwrap();
1091
1092        let welcome_message = group
1093            .commit_builder()
1094            .add_member(bob_key_package)
1095            .unwrap()
1096            .set_group_info_ext(extension_list)
1097            .build()
1098            .await
1099            .unwrap()
1100            .welcome_messages
1101            .remove(0);
1102
1103        let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
1104
1105        assert_eq!(
1106            context
1107                .group_info_extensions
1108                .get_as::<TestExtension>()
1109                .unwrap()
1110                .unwrap(),
1111            ext
1112        );
1113    }
1114
1115    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1116    async fn test_commit_builder_remove() {
1117        let mut group = test_commit_builder_group().await;
1118        let test_key_package =
1119            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1120
1121        group
1122            .commit_builder()
1123            .add_member(test_key_package)
1124            .unwrap()
1125            .build()
1126            .await
1127            .unwrap();
1128
1129        group.apply_pending_commit().await.unwrap();
1130
1131        let commit_output = group
1132            .commit_builder()
1133            .remove_member(1)
1134            .unwrap()
1135            .build()
1136            .await
1137            .unwrap();
1138
1139        let expected_remove = group.remove_proposal(1).unwrap();
1140
1141        assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1142    }
1143
1144    #[cfg(feature = "psk")]
1145    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1146    async fn test_commit_builder_psk() {
1147        let mut group = test_commit_builder_group().await;
1148        let test_psk = ExternalPskId::new(vec![1]);
1149
1150        group
1151            .config
1152            .secret_store()
1153            .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1154
1155        let commit_output = group
1156            .commit_builder()
1157            .add_external_psk(test_psk.clone())
1158            .unwrap()
1159            .build()
1160            .await
1161            .unwrap();
1162
1163        let key_id = JustPreSharedKeyID::External(test_psk);
1164        let expected_psk = group.psk_proposal(key_id).unwrap();
1165
1166        assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1167    }
1168
1169    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1170    async fn test_commit_builder_group_context_ext() {
1171        let mut group = test_commit_builder_group().await;
1172        let mut test_ext = ExtensionList::default();
1173        test_ext
1174            .set_from(RequiredCapabilitiesExt::default())
1175            .unwrap();
1176
1177        let commit_output = group
1178            .commit_builder()
1179            .set_group_context_ext(test_ext.clone())
1180            .unwrap()
1181            .build()
1182            .await
1183            .unwrap();
1184
1185        let expected_ext = group.group_context_extensions_proposal(test_ext);
1186
1187        assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1188    }
1189
1190    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1191    async fn test_commit_builder_reinit() {
1192        let mut group = test_commit_builder_group().await;
1193        let test_group_id = "foo".as_bytes().to_vec();
1194        let test_cipher_suite = TEST_CIPHER_SUITE;
1195        let test_protocol_version = TEST_PROTOCOL_VERSION;
1196        let mut test_ext = ExtensionList::default();
1197
1198        test_ext
1199            .set_from(RequiredCapabilitiesExt::default())
1200            .unwrap();
1201
1202        let commit_output = group
1203            .commit_builder()
1204            .reinit(
1205                Some(test_group_id.clone()),
1206                test_protocol_version,
1207                test_cipher_suite,
1208                test_ext.clone(),
1209            )
1210            .unwrap()
1211            .build()
1212            .await
1213            .unwrap();
1214
1215        let expected_reinit = group
1216            .reinit_proposal(
1217                Some(test_group_id),
1218                test_protocol_version,
1219                test_cipher_suite,
1220                test_ext,
1221            )
1222            .unwrap();
1223
1224        assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1225    }
1226
1227    #[cfg(feature = "custom_proposal")]
1228    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1229    async fn test_commit_builder_custom_proposal() {
1230        let mut group = test_commit_builder_group().await;
1231
1232        let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1233
1234        let commit_output = group
1235            .commit_builder()
1236            .custom_proposal(proposal.clone())
1237            .build()
1238            .await
1239            .unwrap();
1240
1241        assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1242    }
1243
1244    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1245    async fn test_commit_builder_chaining() {
1246        let mut group = test_commit_builder_group().await;
1247        let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1248        let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1249
1250        let expected_adds = vec![
1251            group.add_proposal(kp1.clone()).unwrap(),
1252            group.add_proposal(kp2.clone()).unwrap(),
1253        ];
1254
1255        let commit_output = group
1256            .commit_builder()
1257            .add_member(kp1)
1258            .unwrap()
1259            .add_member(kp2)
1260            .unwrap()
1261            .build()
1262            .await
1263            .unwrap();
1264
1265        assert_commit_builder_output(group, commit_output, expected_adds, 2);
1266    }
1267
1268    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1269    async fn test_commit_builder_empty_commit() {
1270        let mut group = test_commit_builder_group().await;
1271
1272        let commit_output = group.commit_builder().build().await.unwrap();
1273
1274        assert_commit_builder_output(group, commit_output, vec![], 0);
1275    }
1276
1277    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1278    async fn test_commit_builder_authenticated_data() {
1279        let mut group = test_commit_builder_group().await;
1280        let test_data = "test".as_bytes().to_vec();
1281
1282        let commit_output = group
1283            .commit_builder()
1284            .authenticated_data(test_data.clone())
1285            .build()
1286            .await
1287            .unwrap();
1288
1289        assert_eq!(
1290            commit_output
1291                .commit_message
1292                .into_plaintext()
1293                .unwrap()
1294                .content
1295                .authenticated_data,
1296            test_data
1297        );
1298    }
1299
1300    #[cfg(feature = "by_ref_proposal")]
1301    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1302    async fn test_commit_builder_multiple_welcome_messages() {
1303        let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1304            let options = CommitOptions::new().with_single_welcome_message(false);
1305            b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1306        })
1307        .await;
1308
1309        let (alice, alice_kp) =
1310            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1311
1312        let (bob, bob_kp) =
1313            test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1314
1315        group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
1316
1317        group.propose_add(bob_kp.clone(), vec![]).await.unwrap();
1318
1319        let output = group.commit(Vec::new()).await.unwrap();
1320        let welcomes = output.welcome_messages;
1321
1322        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1323
1324        for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1325            let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1326
1327            let welcome = welcomes
1328                .iter()
1329                .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1330                .unwrap();
1331
1332            client.join_group(None, welcome).await.unwrap();
1333
1334            assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1335        }
1336    }
1337
1338    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1339    async fn commit_can_change_credential() {
1340        let cs = TEST_CIPHER_SUITE;
1341        let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1342        let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1343
1344        let commit_output = groups[0]
1345            .commit_builder()
1346            .set_new_signing_identity(secret_key, identity.clone())
1347            .build()
1348            .await
1349            .unwrap();
1350
1351        // Check that the credential was updated by in the committer's state.
1352        groups[0].process_pending_commit().await.unwrap();
1353        let new_member = groups[0].roster().member_with_index(0).unwrap();
1354
1355        assert_eq!(
1356            new_member.signing_identity.credential,
1357            get_test_basic_credential(b"member".to_vec())
1358        );
1359
1360        assert_eq!(
1361            new_member.signing_identity.signature_key,
1362            identity.signature_key
1363        );
1364
1365        // Check that the credential was updated in another member's state.
1366        groups[1]
1367            .process_message(commit_output.commit_message)
1368            .await
1369            .unwrap();
1370
1371        let new_member = groups[1].roster().member_with_index(0).unwrap();
1372
1373        assert_eq!(
1374            new_member.signing_identity.credential,
1375            get_test_basic_credential(b"member".to_vec())
1376        );
1377
1378        assert_eq!(
1379            new_member.signing_identity.signature_key,
1380            identity.signature_key
1381        );
1382    }
1383
1384    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1385    async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1386        let mut group = test_group_custom(
1387            TEST_PROTOCOL_VERSION,
1388            TEST_CIPHER_SUITE,
1389            Default::default(),
1390            None,
1391            Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1392        )
1393        .await;
1394
1395        let commit = group.commit(vec![]).await.unwrap();
1396
1397        group.apply_pending_commit().await.unwrap();
1398
1399        let new_tree = group.export_tree();
1400
1401        assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1402    }
1403
1404    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1405    async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1406        let mut group = test_group_custom(
1407            TEST_PROTOCOL_VERSION,
1408            TEST_CIPHER_SUITE,
1409            Default::default(),
1410            None,
1411            Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1412        )
1413        .await;
1414
1415        let commit = group.commit(vec![]).await.unwrap();
1416
1417        assert!(commit.ratchet_tree.is_none());
1418    }
1419
1420    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1421    async fn commit_includes_external_commit_group_info_if_requested() {
1422        let mut group = test_group_custom(
1423            TEST_PROTOCOL_VERSION,
1424            TEST_CIPHER_SUITE,
1425            Default::default(),
1426            None,
1427            Some(
1428                CommitOptions::new()
1429                    .with_allow_external_commit(true)
1430                    .with_ratchet_tree_extension(false),
1431            ),
1432        )
1433        .await;
1434
1435        let commit = group.commit(vec![]).await.unwrap();
1436
1437        let info = commit
1438            .external_commit_group_info
1439            .unwrap()
1440            .into_group_info()
1441            .unwrap();
1442
1443        assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1444        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1445    }
1446
1447    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1448    async fn commit_includes_external_commit_and_tree_if_requested() {
1449        let mut group = test_group_custom(
1450            TEST_PROTOCOL_VERSION,
1451            TEST_CIPHER_SUITE,
1452            Default::default(),
1453            None,
1454            Some(
1455                CommitOptions::new()
1456                    .with_allow_external_commit(true)
1457                    .with_ratchet_tree_extension(true),
1458            ),
1459        )
1460        .await;
1461
1462        let commit = group.commit(vec![]).await.unwrap();
1463
1464        let info = commit
1465            .external_commit_group_info
1466            .unwrap()
1467            .into_group_info()
1468            .unwrap();
1469
1470        assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1471        assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1472    }
1473
1474    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1475    async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1476        let mut group = test_group_custom(
1477            TEST_PROTOCOL_VERSION,
1478            TEST_CIPHER_SUITE,
1479            Default::default(),
1480            None,
1481            Some(CommitOptions::new().with_allow_external_commit(false)),
1482        )
1483        .await;
1484
1485        let commit = group.commit(vec![]).await.unwrap();
1486
1487        assert!(commit.external_commit_group_info.is_none());
1488    }
1489
1490    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1491    async fn member_identity_is_validated_against_new_extensions() {
1492        let alice = client_with_test_extension(b"alice").await;
1493        let mut alice = alice
1494            .create_group(ExtensionList::new(), Default::default())
1495            .await
1496            .unwrap();
1497
1498        let bob = client_with_test_extension(b"bob").await;
1499        let bob_kp = bob
1500            .generate_key_package_message(Default::default(), Default::default())
1501            .await
1502            .unwrap();
1503
1504        let mut extension_list = ExtensionList::new();
1505        let extension = TestExtension { foo: b'a' };
1506        extension_list.set_from(extension).unwrap();
1507
1508        let res = alice
1509            .commit_builder()
1510            .add_member(bob_kp)
1511            .unwrap()
1512            .set_group_context_ext(extension_list.clone())
1513            .unwrap()
1514            .build()
1515            .await;
1516
1517        assert!(res.is_err());
1518
1519        let alex = client_with_test_extension(b"alex").await;
1520
1521        alice
1522            .commit_builder()
1523            .add_member(
1524                alex.generate_key_package_message(Default::default(), Default::default())
1525                    .await
1526                    .unwrap(),
1527            )
1528            .unwrap()
1529            .set_group_context_ext(extension_list.clone())
1530            .unwrap()
1531            .build()
1532            .await
1533            .unwrap();
1534    }
1535
1536    #[cfg(feature = "by_ref_proposal")]
1537    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1538    async fn server_identity_is_validated_against_new_extensions() {
1539        let alice = client_with_test_extension(b"alice").await;
1540        let mut alice = alice
1541            .create_group(ExtensionList::new(), Default::default())
1542            .await
1543            .unwrap();
1544
1545        let mut extension_list = ExtensionList::new();
1546        let extension = TestExtension { foo: b'a' };
1547        extension_list.set_from(extension).unwrap();
1548
1549        let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1550
1551        let mut alex_extensions = extension_list.clone();
1552
1553        alex_extensions
1554            .set_from(ExternalSendersExt {
1555                allowed_senders: vec![alex_server],
1556            })
1557            .unwrap();
1558
1559        let res = alice
1560            .commit_builder()
1561            .set_group_context_ext(alex_extensions)
1562            .unwrap()
1563            .build()
1564            .await;
1565
1566        assert!(res.is_err());
1567
1568        let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1569
1570        let mut bob_extensions = extension_list;
1571
1572        bob_extensions
1573            .set_from(ExternalSendersExt {
1574                allowed_senders: vec![bob_server],
1575            })
1576            .unwrap();
1577
1578        alice
1579            .commit_builder()
1580            .set_group_context_ext(bob_extensions)
1581            .unwrap()
1582            .build()
1583            .await
1584            .unwrap();
1585    }
1586
1587    #[derive(Debug, Clone)]
1588    struct IdentityProviderWithExtension(BasicIdentityProvider);
1589
1590    #[derive(Clone, Debug)]
1591    #[cfg_attr(feature = "std", derive(thiserror::Error))]
1592    #[cfg_attr(feature = "std", error("test error"))]
1593    struct IdentityProviderWithExtensionError {}
1594
1595    impl IntoAnyError for IdentityProviderWithExtensionError {
1596        #[cfg(feature = "std")]
1597        fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1598            Ok(self.into())
1599        }
1600    }
1601
1602    impl IdentityProviderWithExtension {
1603        // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
1604        // is not set.
1605        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1606        async fn starts_with_foo(
1607            &self,
1608            identity: &SigningIdentity,
1609            _timestamp: Option<MlsTime>,
1610            extensions: Option<&ExtensionList>,
1611        ) -> bool {
1612            if let Some(extensions) = extensions {
1613                if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1614                    self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1615                } else {
1616                    true
1617                }
1618            } else {
1619                true
1620            }
1621        }
1622    }
1623
1624    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1625    #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1626    impl IdentityProvider for IdentityProviderWithExtension {
1627        type Error = IdentityProviderWithExtensionError;
1628
1629        async fn validate_member(
1630            &self,
1631            identity: &SigningIdentity,
1632            timestamp: Option<MlsTime>,
1633            context: MemberValidationContext<'_>,
1634        ) -> Result<(), Self::Error> {
1635            self.starts_with_foo(identity, timestamp, context.new_extensions())
1636                .await
1637                .then_some(())
1638                .ok_or(IdentityProviderWithExtensionError {})
1639        }
1640
1641        async fn validate_external_sender(
1642            &self,
1643            identity: &SigningIdentity,
1644            timestamp: Option<MlsTime>,
1645            extensions: Option<&ExtensionList>,
1646        ) -> Result<(), Self::Error> {
1647            (!self.starts_with_foo(identity, timestamp, extensions).await)
1648                .then_some(())
1649                .ok_or(IdentityProviderWithExtensionError {})
1650        }
1651
1652        async fn identity(
1653            &self,
1654            signing_identity: &SigningIdentity,
1655            extensions: &ExtensionList,
1656        ) -> Result<Vec<u8>, Self::Error> {
1657            self.0
1658                .identity(signing_identity, extensions)
1659                .await
1660                .map_err(|_| IdentityProviderWithExtensionError {})
1661        }
1662
1663        async fn valid_successor(
1664            &self,
1665            _predecessor: &SigningIdentity,
1666            _successor: &SigningIdentity,
1667            _extensions: &ExtensionList,
1668        ) -> Result<bool, Self::Error> {
1669            Ok(true)
1670        }
1671
1672        fn supported_types(&self) -> Vec<CredentialType> {
1673            self.0.supported_types()
1674        }
1675    }
1676
    /// Client configuration produced by `client_with_test_extension`: `BaseConfig` with the
    /// test crypto provider layered on, and identity checks delegated to
    /// `IdentityProviderWithExtension`.
    type ExtensionClientConfig = WithIdentityProvider<
        IdentityProviderWithExtension,
        WithCryptoProvider<TestCryptoProvider, BaseConfig>,
    >;
1681
1682    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1683    async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1684        let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1685
1686        ClientBuilder::new()
1687            .crypto_provider(TestCryptoProvider::new())
1688            .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1689            .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1690            .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1691            .build()
1692    }
1693
1694    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
1695    async fn detached_commit() {
1696        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
1697
1698        let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
1699        assert!(group.pending_commit.is_none());
1700        group.apply_detached_commit(secrets).await.unwrap();
1701        assert_eq!(group.context().epoch, 1);
1702    }
1703}