// mls_rs/group/message_processor.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5#[cfg(all(
6    feature = "by_ref_proposal",
7    feature = "custom_proposal",
8    feature = "self_remove_proposal"
9))]
10use super::SelfRemoveProposal;
11use super::{
12    commit_sender,
13    confirmation_tag::ConfirmationTag,
14    framing::{
15        ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
16    },
17    message_signature::AuthenticatedContent,
18    mls_rules::{CommitDirection, MlsRules},
19    proposal_filter::ProposalBundle,
20    state::GroupState,
21    transcript_hash::InterimTranscriptHash,
22    transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal,
23    RemoveProposal, Welcome,
24};
25use crate::{
26    client::MlsError,
27    key_package::validate_key_package_properties,
28    time::MlsTime,
29    tree_kem::{
30        leaf_node_validator::{LeafNodeValidator, ValidationContext},
31        node::LeafIndex,
32        path_secret::PathSecret,
33        validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
34    },
35    CipherSuiteProvider, KeyPackage,
36};
37use itertools::Itertools;
38use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
39
40use alloc::boxed::Box;
41use alloc::vec::Vec;
42use core::fmt::{self, Debug};
43use mls_rs_core::{
44    identity::{IdentityProvider, MemberValidationContext},
45    protocol_version::ProtocolVersion,
46    psk::PreSharedKeyStorage,
47};
48
49#[cfg(feature = "by_ref_proposal")]
50use super::proposal_ref::ProposalRef;
51
52#[cfg(not(feature = "by_ref_proposal"))]
53use crate::group::proposal_cache::resolve_for_commit;
54
55use super::proposal::Proposal;
56use super::proposal_filter::ProposalInfo;
57
58#[cfg(feature = "private_message")]
59use crate::group::framing::PrivateMessage;
60
/// Group state computed from a received commit before that commit is accepted.
///
/// `process_commit` obtains it from `GroupState::apply_resolved`. It then applies the update
/// path and fills in the confirmed transcript hash and tree hash before handing the value to
/// `update_key_schedule`.
#[derive(Debug)]
pub(crate) struct ProvisionalState {
    /// Public ratchet tree resulting from applying the commit's proposals.
    pub(crate) public_tree: TreeKemPublic,
    /// Proposals applied by the commit.
    pub(crate) applied_proposals: ProposalBundle,
    /// Group context for the epoch the commit would create.
    pub(crate) group_context: GroupContext,
    /// NOTE(review): presumably the leaf index assigned to an external joiner (ExternalInit) — confirm.
    pub(crate) external_init_index: Option<LeafIndex>,
    /// NOTE(review): presumably the leaf indexes of members added from key packages — confirm.
    pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
    /// Proposals that were not applied by the commit (copied into `NewEpoch::unused_proposals`).
    pub(crate) unused_proposals: Vec<ProposalInfo<Proposal>>,
}
70
71//By default, the path field of a Commit MUST be populated. The path field MAY be omitted if
72//(a) it covers at least one proposal and (b) none of the proposals covered by the Commit are
73//of "path required" types. A proposal type requires a path if it cannot change the group
74//membership in a way that requires the forward secrecy and post-compromise security guarantees
75//that an UpdatePath provides. The only proposal types defined in this document that do not
76//require a path are:
77
78// add
79// psk
80// reinit
81pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
82    let res = !proposals.external_init_proposals().is_empty();
83
84    #[cfg(feature = "by_ref_proposal")]
85    let res = res || !proposals.update_proposals().is_empty();
86
87    #[cfg(all(
88        feature = "by_ref_proposal",
89        feature = "custom_proposal",
90        feature = "self_remove_proposal"
91    ))]
92    let res = res || !proposals.self_removes.is_empty();
93
94    res || proposals.length() == 0
95        || proposals.group_context_extensions_proposal().is_some()
96        || !proposals.remove_proposals().is_empty()
97}
98
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
/// Description of the epoch a processed commit moved the group into.
pub struct NewEpoch {
    /// Epoch number of the group context created by the commit.
    pub epoch: u64,
    /// Group state as it was before the commit was applied.
    pub prior_state: GroupState,
    /// Proposals the commit applied.
    pub applied_proposals: Vec<ProposalInfo<Proposal>>,
    /// Proposals that were not applied by the commit.
    pub unused_proposals: Vec<ProposalInfo<Proposal>>,
}
111
112impl NewEpoch {
113    pub(crate) fn new(prior_state: GroupState, provisional_state: &ProvisionalState) -> NewEpoch {
114        NewEpoch {
115            epoch: provisional_state.group_context.epoch,
116            prior_state,
117            unused_proposals: provisional_state.unused_proposals.clone(),
118            applied_proposals: provisional_state
119                .applied_proposals
120                .clone()
121                .into_proposals()
122                .collect_vec(),
123        }
124    }
125}
126
#[cfg(all(feature = "ffi", not(test)))]
#[safer_ffi_gen::safer_ffi_gen]
impl NewEpoch {
    /// Epoch number created by the commit.
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Group state before the commit was applied.
    pub fn prior_state(&self) -> &GroupState {
        &self.prior_state
    }

    /// Proposals applied by the commit.
    pub fn applied_proposals(&self) -> &[ProposalInfo<Proposal>] {
        self.applied_proposals.as_slice()
    }

    /// Proposals that were not applied by the commit.
    pub fn unused_proposals(&self) -> &[ProposalInfo<Proposal>] {
        self.unused_proposals.as_slice()
    }
}
146
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq)]
/// How a processed commit changed the group, from this member's point of view.
///
/// MLS-encoded with a one-byte tag (1 = `NewEpoch`, 2 = `Removed`, 3 = `ReInit`) followed by
/// the variant's fields.
pub enum CommitEffect {
    /// The group moved to a new epoch that still includes this member.
    NewEpoch(Box<NewEpoch>),
    /// The commit removed this member from the group.
    Removed {
        /// The epoch the commit created.
        new_epoch: Box<NewEpoch>,
        /// Sender of the proposal that removed this member.
        remover: Sender,
    },
    /// The commit contained a ReInit proposal; afterwards the group rejects further commits.
    ReInit(ProposalInfo<ReInitProposal>),
}
160
161impl MlsSize for CommitEffect {
162    fn mls_encoded_len(&self) -> usize {
163        0u8.mls_encoded_len()
164            + match self {
165                Self::NewEpoch(e) => e.mls_encoded_len(),
166                Self::Removed { new_epoch, remover } => {
167                    new_epoch.mls_encoded_len() + remover.mls_encoded_len()
168                }
169                Self::ReInit(r) => r.mls_encoded_len(),
170            }
171    }
172}
173
174impl MlsEncode for CommitEffect {
175    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
176        match self {
177            Self::NewEpoch(e) => {
178                1u8.mls_encode(writer)?;
179                e.mls_encode(writer)?;
180            }
181            Self::Removed { new_epoch, remover } => {
182                2u8.mls_encode(writer)?;
183                new_epoch.mls_encode(writer)?;
184                remover.mls_encode(writer)?;
185            }
186            Self::ReInit(r) => {
187                3u8.mls_encode(writer)?;
188                r.mls_encode(writer)?;
189            }
190        }
191
192        Ok(())
193    }
194}
195
196impl MlsDecode for CommitEffect {
197    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
198        match u8::mls_decode(reader)? {
199            1u8 => Ok(Self::NewEpoch(NewEpoch::mls_decode(reader)?.into())),
200            2u8 => Ok(Self::Removed {
201                new_epoch: NewEpoch::mls_decode(reader)?.into(),
202                remover: Sender::mls_decode(reader)?,
203            }),
204            3u8 => Ok(Self::ReInit(ProposalInfo::mls_decode(reader)?)),
205            _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
206        }
207    }
208}
209
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
// Payloads are stored inline rather than boxed, so variant sizes differ widely.
#[allow(clippy::large_enum_variant)]
/// An event generated as a result of processing a message for a group with
/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
pub enum ReceivedMessage {
    /// An application message was decrypted.
    ApplicationMessage(ApplicationMessageDescription),
    /// A commit was processed; see `CommitMessageDescription::effect` for the resulting change.
    Commit(CommitMessageDescription),
    /// A proposal was received.
    Proposal(ProposalMessageDescription),
    /// A validated `GroupInfo` object.
    GroupInfo(GroupInfo),
    /// A validated Welcome message (the message itself is not retained).
    Welcome,
    /// A validated key package.
    KeyPackage(KeyPackage),
}
232
233impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
234    type Error = MlsError;
235
236    fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
237        Ok(ReceivedMessage::ApplicationMessage(value))
238    }
239}
240
241impl From<CommitMessageDescription> for ReceivedMessage {
242    fn from(value: CommitMessageDescription) -> Self {
243        ReceivedMessage::Commit(value)
244    }
245}
246
247impl From<ProposalMessageDescription> for ReceivedMessage {
248    fn from(value: ProposalMessageDescription) -> Self {
249        ReceivedMessage::Proposal(value)
250    }
251}
252
253impl From<GroupInfo> for ReceivedMessage {
254    fn from(value: GroupInfo) -> Self {
255        ReceivedMessage::GroupInfo(value)
256    }
257}
258
259impl From<Welcome> for ReceivedMessage {
260    fn from(_: Welcome) -> Self {
261        ReceivedMessage::Welcome
262    }
263}
264
265impl From<KeyPackage> for ReceivedMessage {
266    fn from(value: KeyPackage) -> Self {
267        ReceivedMessage::KeyPackage(value)
268    }
269}
270
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, Eq)]
/// Description of an MLS application message.
pub struct ApplicationMessageDescription {
    /// Index, in the group state, of the member that sent this message.
    pub sender_index: u32,
    /// Received application data (read through the `data()` accessor).
    data: ApplicationData,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
}
285
286impl Debug for ApplicationMessageDescription {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        f.debug_struct("ApplicationMessageDescription")
289            .field("sender_index", &self.sender_index)
290            .field("data", &self.data)
291            .field(
292                "authenticated_data",
293                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
294            )
295            .finish()
296    }
297}
298
299#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
300impl ApplicationMessageDescription {
301    pub fn data(&self) -> &[u8] {
302        self.data.as_bytes()
303    }
304}
305
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
/// Description of a processed MLS commit message.
pub struct CommitMessageDescription {
    /// True if this is the result of an external commit.
    pub is_external: bool,
    /// The index in the group state of the member who performed this commit.
    pub committer: u32,
    /// A full description of group state changes as a result of this commit.
    pub effect: CommitEffect,
    /// Plaintext authenticated data in the received MLS packet (encoded as an opaque byte vector).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
}
324
325impl Debug for CommitMessageDescription {
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        f.debug_struct("CommitMessageDescription")
328            .field("is_external", &self.is_external)
329            .field("committer", &self.committer)
330            .field("effect", &self.effect)
331            .field(
332                "authenticated_data",
333                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
334            )
335            .finish()
336    }
337}
338
#[derive(Debug, Clone, Copy, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// NOTE(review): the explicit discriminants presumably act as the wire tags of the derived
// MLS codec — confirm before renumbering.
#[repr(u8)]
/// Proposal sender type.
///
/// Obtained from a framing `Sender` via `TryFrom`; `Sender::NewMemberCommit` has no
/// proposal-sender equivalent and is rejected with `MlsError::InvalidSender`.
pub enum ProposalSender {
    /// A current member of the group by index in the group state.
    Member(u32) = 1u8,
    /// An external entity by index within an
    /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
    External(u32) = 2u8,
    /// A new member proposing their addition to the group.
    NewMember = 3u8,
}
352
353impl TryFrom<Sender> for ProposalSender {
354    type Error = MlsError;
355
356    fn try_from(value: Sender) -> Result<Self, Self::Error> {
357        match value {
358            Sender::Member(index) => Ok(Self::Member(index)),
359            #[cfg(feature = "by_ref_proposal")]
360            Sender::External(index) => Ok(Self::External(index)),
361            #[cfg(feature = "by_ref_proposal")]
362            Sender::NewMemberProposal => Ok(Self::NewMember),
363            Sender::NewMemberCommit => Err(MlsError::InvalidSender),
364        }
365    }
366}
367
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
/// Description of a processed MLS proposal message.
pub struct ProposalMessageDescription {
    /// Sender of the proposal.
    pub sender: ProposalSender,
    /// Proposal content.
    pub proposal: Proposal,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
    /// Proposal reference, computed from the authenticated content that carried the proposal.
    pub proposal_ref: ProposalRef,
}
387
#[cfg(feature = "by_ref_proposal")]
impl Debug for ProposalMessageDescription {
    /// Formats every field; the authenticated data is shown via `pretty_bytes`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            sender,
            proposal,
            authenticated_data,
            proposal_ref,
        } = self;

        let pretty_authenticated_data = mls_rs_core::debug::pretty_bytes(authenticated_data);

        let mut out = f.debug_struct("ProposalMessageDescription");
        out.field("sender", sender);
        out.field("proposal", proposal);
        out.field("authenticated_data", &pretty_authenticated_data);
        out.field("proposal_ref", proposal_ref);
        out.finish()
    }
}
402
#[cfg(feature = "by_ref_proposal")]
#[derive(MlsSize, MlsEncode, MlsDecode)]
/// A received by-reference proposal together with its reference and sender, in a form that
/// can be serialized with `to_bytes` and restored with `from_bytes`.
pub struct CachedProposal {
    /// The proposal itself.
    pub(crate) proposal: Proposal,
    /// Reference identifying the proposal.
    pub(crate) proposal_ref: ProposalRef,
    /// Framing sender of the proposal.
    pub(crate) sender: Sender,
}
410
#[cfg(feature = "by_ref_proposal")]
impl CachedProposal {
    /// Deserializes a cached proposal from its MLS encoding.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
        let mut reader = bytes;
        Self::mls_decode(&mut reader).map_err(Into::into)
    }

    /// Serializes this cached proposal to its MLS encoding.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
        self.mls_encode_to_vec().map_err(Into::into)
    }
}
423
#[cfg(feature = "by_ref_proposal")]
impl ProposalMessageDescription {
    /// Converts this description into a [`CachedProposal`], dropping the authenticated data.
    pub fn cached_proposal(self) -> CachedProposal {
        let Self {
            sender,
            proposal,
            proposal_ref,
            ..
        } = self;

        let sender = match sender {
            ProposalSender::Member(index) => Sender::Member(index),
            ProposalSender::External(index) => Sender::External(index),
            ProposalSender::NewMember => Sender::NewMemberProposal,
        };

        CachedProposal {
            proposal,
            proposal_ref,
            sender,
        }
    }

    /// Returns the proposal reference as an owned byte vector.
    pub fn proposal_ref(&self) -> Vec<u8> {
        self.proposal_ref.to_vec()
    }

    /// Builds a description for `proposal`, received inside `content`.
    ///
    /// The sender is validated before the proposal reference is computed.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn new<C: CipherSuiteProvider>(
        cs: &C,
        content: &AuthenticatedContent,
        proposal: Proposal,
    ) -> Result<Self, MlsError> {
        let authenticated_data = content.content.authenticated_data.clone();
        let sender = ProposalSender::try_from(content.content.sender)?;
        let proposal_ref = ProposalRef::from_content(cs, content).await?;

        Ok(Self {
            sender,
            proposal,
            authenticated_data,
            proposal_ref,
        })
    }
}
458
#[cfg(not(feature = "by_ref_proposal"))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
/// Description of a processed MLS proposal message.
///
/// Without `by_ref_proposal`, standalone proposal messages are not processed, so this
/// placeholder carries no data.
pub struct ProposalMessageDescription {}
467
// Variant sizes can differ substantially; the lint is silenced rather than boxing.
#[allow(clippy::large_enum_variant)]
/// Result of the first stage of incoming-message processing: either a finished event, or
/// authenticated framed content that still has to be processed.
pub(crate) enum EventOrContent<E> {
    /// A message fully handled while decoding (a validated GroupInfo, Welcome or KeyPackage).
    // NOTE(review): `Event` is built unconditionally in `get_event_from_incoming_message`;
    // confirm this `allow(dead_code)` is still needed.
    #[cfg_attr(
        not(all(feature = "private_message", feature = "external_client")),
        allow(dead_code)
    )]
    Event(E),
    /// Authenticated content (application data, proposal or commit) awaiting processing.
    Content(AuthenticatedContent),
}
477
478#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
479#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
480#[cfg_attr(
481    all(not(target_arch = "wasm32"), mls_build_async),
482    maybe_async::must_be_async
483)]
484pub(crate) trait MessageProcessor: Send + Sync {
    /// Event type returned to callers; it must be constructible from every kind of processed
    /// message (application messages via a fallible conversion).
    type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
        + From<CommitMessageDescription>
        + From<ProposalMessageDescription>
        + From<GroupInfo>
        + From<Welcome>
        + From<KeyPackage>
        + Send;

    /// Rules passed to `apply_resolved` when applying a received commit's proposals.
    type MlsRules: MlsRules;
    /// Identity provider used when applying proposals and validating update paths.
    type IdentityProvider: IdentityProvider;
    /// Cipher suite provider used for the hashing and cryptographic operations in this trait.
    type CipherSuiteProvider: CipherSuiteProvider;
    /// PSK storage passed to `apply_resolved` when applying a received commit's proposals.
    type PreSharedKeyStorage: PreSharedKeyStorage;
497
498    async fn process_incoming_message(
499        &mut self,
500        message: MlsMessage,
501        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
502    ) -> Result<Self::OutputType, MlsError> {
503        self.process_incoming_message_with_time(
504            message,
505            #[cfg(feature = "by_ref_proposal")]
506            cache_proposal,
507            None,
508        )
509        .await
510    }
511
512    async fn process_incoming_message_with_time(
513        &mut self,
514        message: MlsMessage,
515        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
516        time_sent: Option<MlsTime>,
517    ) -> Result<Self::OutputType, MlsError> {
518        let event_or_content = self
519            .get_event_from_incoming_message(message, time_sent)
520            .await?;
521
522        self.process_event_or_content(
523            event_or_content,
524            #[cfg(feature = "by_ref_proposal")]
525            cache_proposal,
526            time_sent,
527        )
528        .await
529    }
530
531    async fn get_event_from_incoming_message(
532        &mut self,
533        message: MlsMessage,
534        time: Option<MlsTime>,
535    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
536        self.check_metadata(&message)?;
537
538        match message.payload {
539            MlsMessagePayload::Plain(plaintext) => {
540                self.verify_plaintext_authentication(plaintext).await
541            }
542            #[cfg(feature = "private_message")]
543            MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
544            MlsMessagePayload::GroupInfo(group_info) => {
545                validate_group_info_member(
546                    self.group_state(),
547                    message.version,
548                    &group_info,
549                    self.cipher_suite_provider(),
550                )
551                .await?;
552
553                Ok(EventOrContent::Event(group_info.into()))
554            }
555            MlsMessagePayload::Welcome(welcome) => {
556                self.validate_welcome(&welcome, message.version)?;
557
558                Ok(EventOrContent::Event(welcome.into()))
559            }
560            MlsMessagePayload::KeyPackage(key_package) => {
561                self.validate_key_package(&key_package, message.version, time)
562                    .await?;
563
564                Ok(EventOrContent::Event(key_package.into()))
565            }
566        }
567    }
568
569    async fn process_event_or_content(
570        &mut self,
571        event_or_content: EventOrContent<Self::OutputType>,
572        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
573        time_sent: Option<MlsTime>,
574    ) -> Result<Self::OutputType, MlsError> {
575        let msg = match event_or_content {
576            EventOrContent::Event(event) => event,
577            EventOrContent::Content(content) => {
578                self.process_auth_content(
579                    content,
580                    #[cfg(feature = "by_ref_proposal")]
581                    cache_proposal,
582                    time_sent,
583                )
584                .await?
585            }
586        };
587
588        Ok(msg)
589    }
590
591    async fn process_auth_content(
592        &mut self,
593        auth_content: AuthenticatedContent,
594        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
595        time_sent: Option<MlsTime>,
596    ) -> Result<Self::OutputType, MlsError> {
597        let event = match auth_content.content.content {
598            #[cfg(feature = "private_message")]
599            Content::Application(data) => {
600                let authenticated_data = auth_content.content.authenticated_data;
601                let sender = auth_content.content.sender;
602
603                self.process_application_message(data, sender, authenticated_data)
604                    .and_then(Self::OutputType::try_from)
605            }
606            Content::Commit(_) => self
607                .process_commit(auth_content, time_sent)
608                .await
609                .map(Self::OutputType::from),
610            #[cfg(feature = "by_ref_proposal")]
611            Content::Proposal(ref proposal) => self
612                .process_proposal(&auth_content, proposal, cache_proposal)
613                .await
614                .map(Self::OutputType::from),
615        }?;
616
617        Ok(event)
618    }
619
620    #[cfg(feature = "private_message")]
621    fn process_application_message(
622        &self,
623        data: ApplicationData,
624        sender: Sender,
625        authenticated_data: Vec<u8>,
626    ) -> Result<ApplicationMessageDescription, MlsError> {
627        let Sender::Member(sender_index) = sender else {
628            return Err(MlsError::InvalidSender);
629        };
630
631        Ok(ApplicationMessageDescription {
632            authenticated_data,
633            sender_index,
634            data,
635        })
636    }
637
638    #[cfg(feature = "by_ref_proposal")]
639    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
640    async fn process_proposal(
641        &mut self,
642        auth_content: &AuthenticatedContent,
643        proposal: &Proposal,
644        cache_proposal: bool,
645    ) -> Result<ProposalMessageDescription, MlsError> {
646        let proposal = ProposalMessageDescription::new(
647            self.cipher_suite_provider(),
648            auth_content,
649            proposal.clone(),
650        )
651        .await?;
652
653        let group_state = self.group_state_mut();
654
655        if cache_proposal {
656            group_state.proposals.insert(
657                proposal.proposal_ref.clone(),
658                proposal.proposal.clone(),
659                auth_content.content.sender,
660            );
661        }
662
663        Ok(proposal)
664    }
665
666    async fn process_commit(
667        &mut self,
668        auth_content: AuthenticatedContent,
669        time_sent: Option<MlsTime>,
670    ) -> Result<CommitMessageDescription, MlsError> {
671        if self.group_state().pending_reinit.is_some() {
672            return Err(MlsError::GroupUsedAfterReInit);
673        }
674
675        // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
676        let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
677            self.cipher_suite_provider(),
678            &self.group_state().interim_transcript_hash,
679            &auth_content,
680        )
681        .await?;
682
683        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
684        let commit = match auth_content.content.content {
685            Content::Commit(commit) => Ok(commit),
686            _ => Err(MlsError::UnexpectedMessageType),
687        }?;
688
689        #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
690        let Content::Commit(commit) = auth_content.content.content;
691
692        let group_state = self.group_state();
693        let id_provider = self.identity_provider();
694
695        #[cfg(feature = "by_ref_proposal")]
696        let proposals = group_state
697            .proposals
698            .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
699
700        #[cfg(not(feature = "by_ref_proposal"))]
701        let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
702
703        let mut provisional_state = group_state
704            .apply_resolved(
705                auth_content.content.sender,
706                proposals,
707                commit.path.as_ref().map(|path| &path.leaf_node),
708                &id_provider,
709                self.cipher_suite_provider(),
710                &self.psk_storage(),
711                &self.mls_rules(),
712                time_sent,
713                CommitDirection::Receive,
714            )
715            .await?;
716
717        let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
718
719        //Verify that the path value is populated if the proposals vector contains any Update
720        // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
721        if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
722            return Err(MlsError::CommitMissingPath);
723        }
724
725        let self_removed = self.removal_proposal(&provisional_state);
726        #[cfg(all(
727            feature = "by_ref_proposal",
728            feature = "custom_proposal",
729            feature = "self_remove_proposal"
730        ))]
731        let self_removed_by_self = self.self_removal_proposal(&provisional_state);
732
733        let is_self_removed = self_removed.is_some();
734        #[cfg(all(
735            feature = "by_ref_proposal",
736            feature = "custom_proposal",
737            feature = "self_remove_proposal"
738        ))]
739        let is_self_removed = is_self_removed || self_removed_by_self.is_some();
740
741        let update_path = match commit.path {
742            Some(update_path) => Some(
743                validate_update_path(
744                    &self.identity_provider(),
745                    self.cipher_suite_provider(),
746                    update_path,
747                    &provisional_state,
748                    sender,
749                    time_sent,
750                    &group_state.context,
751                )
752                .await?,
753            ),
754            None => None,
755        };
756
757        let commit_effect =
758            if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
759                self.group_state_mut().pending_reinit = Some(reinit.proposal.clone());
760                CommitEffect::ReInit(reinit)
761            } else if let Some(remove_proposal) = self_removed {
762                let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
763                CommitEffect::Removed {
764                    remover: remove_proposal.sender,
765                    new_epoch: Box::new(new_epoch),
766                }
767            } else {
768                CommitEffect::NewEpoch(Box::new(NewEpoch::new(
769                    self.group_state().clone(),
770                    &provisional_state,
771                )))
772            };
773
774        #[cfg(all(
775            feature = "by_ref_proposal",
776            feature = "custom_proposal",
777            feature = "self_remove_proposal"
778        ))]
779        let commit_effect = if let Some(self_remove_proposal) = self_removed_by_self {
780            let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
781            CommitEffect::Removed {
782                remover: self_remove_proposal.sender,
783                new_epoch: Box::new(new_epoch),
784            }
785        } else {
786            commit_effect
787        };
788
789        let new_secrets = match update_path {
790            Some(update_path) if !is_self_removed => {
791                self.apply_update_path(sender, &update_path, &mut provisional_state)
792                    .await
793            }
794            _ => Ok(None),
795        }?;
796
797        // Update the transcript hash to get the new context.
798        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
799
800        // Update the parent hashes in the new context
801        provisional_state
802            .public_tree
803            .update_hashes(&[sender], self.cipher_suite_provider())
804            .await?;
805
806        // Update the tree hash in the new context
807        provisional_state.group_context.tree_hash = provisional_state
808            .public_tree
809            .tree_hash(self.cipher_suite_provider())
810            .await?;
811
812        if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
813            if !is_self_removed {
814                // Update the key schedule to calculate new private keys
815                self.update_key_schedule(
816                    new_secrets,
817                    interim_transcript_hash,
818                    confirmation_tag,
819                    provisional_state,
820                )
821                .await?;
822            }
823            Ok(CommitMessageDescription {
824                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
825                authenticated_data: auth_content.content.authenticated_data,
826                committer: *sender,
827                effect: commit_effect,
828            })
829        } else {
830            Err(MlsError::InvalidConfirmationTag)
831        }
832    }
833
    /// Current group state.
    fn group_state(&self) -> &GroupState;
    /// Mutable access to the current group state.
    fn group_state_mut(&mut self) -> &mut GroupState;
    /// Rules passed to `apply_resolved` when applying a received commit's proposals.
    fn mls_rules(&self) -> Self::MlsRules;
    /// Identity provider used when applying proposals and validating update paths.
    fn identity_provider(&self) -> Self::IdentityProvider;
    /// Cipher suite provider used for transcript/tree hashing and other cryptographic work.
    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
    /// PSK storage passed to `apply_resolved` when applying a received commit's proposals.
    fn psk_storage(&self) -> Self::PreSharedKeyStorage;

    /// Returns the Remove proposal in `provisional_state` that affects this processor
    /// (presumably the one removing this member — confirm in implementations).
    ///
    /// `process_commit` reports `CommitEffect::Removed` and skips the update path and key
    /// schedule whenever this returns `Some`.
    fn removal_proposal(
        &self,
        provisional_state: &ProvisionalState,
    ) -> Option<ProposalInfo<RemoveProposal>>;

    /// Returns the SelfRemove proposal in `provisional_state` that affects this processor
    /// (presumably the one removing this member — confirm in implementations).
    ///
    /// Handled by `process_commit` in the same way as `removal_proposal`.
    #[cfg(all(
        feature = "by_ref_proposal",
        feature = "custom_proposal",
        feature = "self_remove_proposal"
    ))]
    #[cfg_attr(feature = "ffi", safer_ffi_gen::safer_ffi_gen_ignore)]
    fn self_removal_proposal(
        &self,
        provisional_state: &ProvisionalState,
    ) -> Option<ProposalInfo<SelfRemoveProposal>>;

    /// Oldest epoch whose application messages can still be processed.
    ///
    /// `check_metadata` rejects application messages from earlier epochs with
    /// `MlsError::InvalidEpoch`; `None` means no lower bound is enforced.
    #[cfg(feature = "private_message")]
    fn min_epoch_available(&self) -> Option<u64>;
859
860    fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
861        let context = &self.group_state().context;
862
863        if message.version != context.protocol_version {
864            return Err(MlsError::ProtocolVersionMismatch);
865        }
866
867        if let Some((group_id, epoch, content_type)) = match &message.payload {
868            MlsMessagePayload::Plain(plaintext) => Some((
869                &plaintext.content.group_id,
870                plaintext.content.epoch,
871                plaintext.content.content_type(),
872            )),
873            #[cfg(feature = "private_message")]
874            MlsMessagePayload::Cipher(ciphertext) => Some((
875                &ciphertext.group_id,
876                ciphertext.epoch,
877                ciphertext.content_type,
878            )),
879            _ => None,
880        } {
881            if group_id != &context.group_id {
882                return Err(MlsError::GroupIdMismatch);
883            }
884
885            match content_type {
886                ContentType::Commit => {
887                    if context.epoch != epoch {
888                        Err(MlsError::InvalidEpoch)
889                    } else {
890                        Ok(())
891                    }
892                }
893                #[cfg(feature = "by_ref_proposal")]
894                ContentType::Proposal => {
895                    if context.epoch != epoch {
896                        Err(MlsError::InvalidEpoch)
897                    } else {
898                        Ok(())
899                    }
900                }
901                #[cfg(feature = "private_message")]
902                ContentType::Application => {
903                    if let Some(min) = self.min_epoch_available() {
904                        if epoch < min {
905                            Err(MlsError::InvalidEpoch)
906                        } else {
907                            Ok(())
908                        }
909                    } else {
910                        Ok(())
911                    }
912                }
913            }?;
914
915            // Proposal and commit messages must be sent in the current epoch
916            let check_epoch = content_type == ContentType::Commit;
917
918            #[cfg(feature = "by_ref_proposal")]
919            let check_epoch = check_epoch || content_type == ContentType::Proposal;
920
921            if check_epoch && epoch != context.epoch {
922                return Err(MlsError::InvalidEpoch);
923            }
924
925            // Unencrypted application messages are not allowed
926            #[cfg(feature = "private_message")]
927            if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
928                && content_type == ContentType::Application
929            {
930                return Err(MlsError::UnencryptedApplicationMessage);
931            }
932        }
933
934        Ok(())
935    }
936
937    fn validate_welcome(
938        &self,
939        welcome: &Welcome,
940        version: ProtocolVersion,
941    ) -> Result<(), MlsError> {
942        let state = self.group_state();
943
944        (welcome.cipher_suite == state.context.cipher_suite
945            && version == state.context.protocol_version)
946            .then_some(())
947            .ok_or(MlsError::InvalidWelcomeMessage)
948    }
949
950    async fn validate_key_package(
951        &self,
952        key_package: &KeyPackage,
953        version: ProtocolVersion,
954        time: Option<MlsTime>,
955    ) -> Result<(), MlsError> {
956        let cs = self.cipher_suite_provider();
957        let id = self.identity_provider();
958
959        validate_key_package(key_package, version, cs, &id, time).await
960    }
961
    /// Processes an encrypted `PrivateMessage`.
    ///
    /// The result is either an event or content, as expressed by `EventOrContent`.
    /// NOTE(review): presumably this decrypts and authenticates the message
    /// before returning its content. Confirm in the implementors.
    #[cfg(feature = "private_message")]
    async fn process_ciphertext(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
967
    /// Verifies the authentication of a `PublicMessage` and yields either an
    /// event or its content.
    ///
    /// NOTE(review): the exact checks (signature, membership tag, ...) are defined
    /// by the implementors. Confirm there.
    async fn verify_plaintext_authentication(
        &self,
        message: PublicMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
972
973    async fn apply_update_path(
974        &mut self,
975        sender: LeafIndex,
976        update_path: &ValidatedUpdatePath,
977        provisional_state: &mut ProvisionalState,
978    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
979        provisional_state
980            .public_tree
981            .apply_update_path(
982                sender,
983                update_path,
984                &provisional_state.group_context.extensions,
985                self.identity_provider(),
986                self.cipher_suite_provider(),
987            )
988            .await
989            .map(|_| None)
990    }
991
    /// Advances the key schedule using the results of processing a commit.
    ///
    /// NOTE(review): from the signature, the parameters presumably are:
    /// * `secrets` — the value produced by `apply_update_path`.
    /// * `interim_transcript_hash` and `confirmation_tag` — the new epoch's.
    /// * `provisional_public_state` — the state to adopt.
    ///
    /// Confirm against the implementors.
    async fn update_key_schedule(
        &mut self,
        secrets: Option<(TreeKemPrivate, PathSecret)>,
        interim_transcript_hash: InterimTranscriptHash,
        confirmation_tag: &ConfirmationTag,
        provisional_public_state: ProvisionalState,
    ) -> Result<(), MlsError>;
999}
1000
1001#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1002pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
1003    key_package: &KeyPackage,
1004    version: ProtocolVersion,
1005    cs: &C,
1006    id: &I,
1007    time: Option<MlsTime>,
1008) -> Result<(), MlsError> {
1009    let validator = LeafNodeValidator::new(cs, id, MemberValidationContext::None);
1010
1011    #[cfg(feature = "std")]
1012    let context = Some(MlsTime::now());
1013
1014    #[cfg(not(feature = "std"))]
1015    let context = None;
1016
1017    let context = if time.is_some() { time } else { context };
1018
1019    let context = ValidationContext::Add(context);
1020
1021    validator
1022        .check_if_valid(&key_package.leaf_node, context)
1023        .await?;
1024
1025    validate_key_package_properties(key_package, version, cs).await?;
1026
1027    Ok(())
1028}
1029
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use mls_rs_codec::{MlsDecode, MlsEncode};

    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{test_utils::get_test_group_context, GroupState, Sender},
    };

    use super::{CommitEffect, NewEpoch};

    /// Round-trips both `CommitEffect` variants through the MLS codec and checks
    /// that decoding reproduces the original values.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn commit_effect_codec() {
        let prior_state = GroupState {
            #[cfg(feature = "by_ref_proposal")]
            proposals: crate::group::ProposalCache::new(TEST_PROTOCOL_VERSION, vec![]),
            context: get_test_group_context(7, 7.into()).await,
            public_tree: Default::default(),
            interim_transcript_hash: vec![].into(),
            pending_reinit: None,
            confirmation_tag: Default::default(),
        };

        let new_epoch = NewEpoch {
            epoch: 7,
            prior_state,
            applied_proposals: vec![],
            unused_proposals: vec![],
        };

        let original = vec![
            CommitEffect::NewEpoch(new_epoch.clone().into()),
            CommitEffect::Removed {
                new_epoch: new_epoch.into(),
                remover: Sender::Member(0),
            },
        ];

        let encoded = original.mls_encode_to_vec().unwrap();
        let decoded = Vec::<CommitEffect>::mls_decode(&mut &*encoded).unwrap();

        assert_eq!(original, decoded);
    }
}