Skip to main content

mls_rs/group/
message_processor.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5#[cfg(all(
6    feature = "by_ref_proposal",
7    feature = "custom_proposal",
8    feature = "self_remove_proposal"
9))]
10use super::SelfRemoveProposal;
11use super::{
12    commit_sender,
13    confirmation_tag::ConfirmationTag,
14    framing::{
15        ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
16    },
17    message_signature::AuthenticatedContent,
18    mls_rules::{CommitDirection, MlsRules},
19    proposal_filter::ProposalBundle,
20    state::GroupState,
21    transcript_hash::InterimTranscriptHash,
22    transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal,
23    RemoveProposal, Welcome,
24};
25use crate::{
26    client::MlsError,
27    key_package::validate_key_package_properties,
28    time::MlsTime,
29    tree_kem::{
30        leaf_node_validator::{LeafNodeValidator, ValidationContext},
31        node::LeafIndex,
32        path_secret::PathSecret,
33        validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
34    },
35    CipherSuiteProvider, KeyPackage,
36};
37use itertools::Itertools;
38use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
39
40use alloc::boxed::Box;
41use alloc::vec::Vec;
42use core::fmt::{self, Debug};
43use mls_rs_core::{
44    identity::{IdentityProvider, MemberValidationContext},
45    protocol_version::ProtocolVersion,
46    psk::PreSharedKeyStorage,
47};
48
49#[cfg(feature = "by_ref_proposal")]
50use super::proposal_ref::ProposalRef;
51
52#[cfg(not(feature = "by_ref_proposal"))]
53use crate::group::proposal_cache::resolve_for_commit;
54
55use super::proposal::Proposal;
56use super::proposal_filter::ProposalInfo;
57
58#[cfg(feature = "private_message")]
59use crate::group::framing::PrivateMessage;
60
61#[derive(Debug)]
62pub(crate) struct ProvisionalState {
63    pub(crate) public_tree: TreeKemPublic,
64    pub(crate) applied_proposals: ProposalBundle,
65    pub(crate) group_context: GroupContext,
66    pub(crate) external_init_index: Option<LeafIndex>,
67    pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
68    pub(crate) unused_proposals: Vec<ProposalInfo<Proposal>>,
69}
70
71//By default, the path field of a Commit MUST be populated. The path field MAY be omitted if
72//(a) it covers at least one proposal and (b) none of the proposals covered by the Commit are
73//of "path required" types. A proposal type requires a path if it cannot change the group
74//membership in a way that requires the forward secrecy and post-compromise security guarantees
75//that an UpdatePath provides. The only proposal types defined in this document that do not
76//require a path are:
77
78// add
79// psk
80// reinit
81pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
82    let res = !proposals.external_init_proposals().is_empty();
83
84    #[cfg(feature = "by_ref_proposal")]
85    let res = res || !proposals.update_proposals().is_empty();
86
87    #[cfg(all(
88        feature = "by_ref_proposal",
89        feature = "custom_proposal",
90        feature = "self_remove_proposal"
91    ))]
92    let res = res || !proposals.self_removes.is_empty();
93
94    res || proposals.length() == 0
95        || proposals.group_context_extensions_proposal().is_some()
96        || !proposals.remove_proposals().is_empty()
97}
98
99#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
100#[non_exhaustive]
101pub struct NewEpoch {
102    pub epoch: u64,
103    pub prior_state: GroupState,
104    pub applied_proposals: Vec<ProposalInfo<Proposal>>,
105    pub unused_proposals: Vec<ProposalInfo<Proposal>>,
106}
107
108impl NewEpoch {
109    pub(crate) fn new(prior_state: GroupState, provisional_state: &ProvisionalState) -> NewEpoch {
110        NewEpoch {
111            epoch: provisional_state.group_context.epoch,
112            prior_state,
113            unused_proposals: provisional_state.unused_proposals.clone(),
114            applied_proposals: provisional_state
115                .applied_proposals
116                .clone()
117                .into_proposals()
118                .collect_vec(),
119        }
120    }
121}
122
123impl NewEpoch {
124    pub fn epoch(&self) -> u64 {
125        self.epoch
126    }
127
128    pub fn prior_state(&self) -> &GroupState {
129        &self.prior_state
130    }
131
132    pub fn applied_proposals(&self) -> &[ProposalInfo<Proposal>] {
133        &self.applied_proposals
134    }
135
136    pub fn unused_proposals(&self) -> &[ProposalInfo<Proposal>] {
137        &self.unused_proposals
138    }
139}
140
141#[derive(Clone, Debug, PartialEq)]
142pub enum CommitEffect {
143    NewEpoch(Box<NewEpoch>),
144    Removed {
145        new_epoch: Box<NewEpoch>,
146        remover: Sender,
147    },
148    ReInit(ProposalInfo<ReInitProposal>),
149}
150
151impl MlsSize for CommitEffect {
152    fn mls_encoded_len(&self) -> usize {
153        0u8.mls_encoded_len()
154            + match self {
155                Self::NewEpoch(e) => e.mls_encoded_len(),
156                Self::Removed { new_epoch, remover } => {
157                    new_epoch.mls_encoded_len() + remover.mls_encoded_len()
158                }
159                Self::ReInit(r) => r.mls_encoded_len(),
160            }
161    }
162}
163
164impl MlsEncode for CommitEffect {
165    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
166        match self {
167            Self::NewEpoch(e) => {
168                1u8.mls_encode(writer)?;
169                e.mls_encode(writer)?;
170            }
171            Self::Removed { new_epoch, remover } => {
172                2u8.mls_encode(writer)?;
173                new_epoch.mls_encode(writer)?;
174                remover.mls_encode(writer)?;
175            }
176            Self::ReInit(r) => {
177                3u8.mls_encode(writer)?;
178                r.mls_encode(writer)?;
179            }
180        }
181
182        Ok(())
183    }
184}
185
186impl MlsDecode for CommitEffect {
187    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
188        match u8::mls_decode(reader)? {
189            1u8 => Ok(Self::NewEpoch(NewEpoch::mls_decode(reader)?.into())),
190            2u8 => Ok(Self::Removed {
191                new_epoch: NewEpoch::mls_decode(reader)?.into(),
192                remover: Sender::mls_decode(reader)?,
193            }),
194            3u8 => Ok(Self::ReInit(ProposalInfo::mls_decode(reader)?)),
195            _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
196        }
197    }
198}
199
200#[derive(Debug, Clone)]
201#[allow(clippy::large_enum_variant)]
202/// An event generated as a result of processing a message for a group with
203/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
204pub enum ReceivedMessage {
205    /// An application message was decrypted.
206    ApplicationMessage(ApplicationMessageDescription),
207    /// A new commit was processed creating a new group state.
208    Commit(CommitMessageDescription),
209    /// A proposal was received.
210    Proposal(ProposalMessageDescription),
211    /// Validated GroupInfo object
212    GroupInfo(GroupInfo),
213    /// Validated welcome message
214    Welcome,
215    /// Validated key package
216    KeyPackage(KeyPackage),
217}
218
219impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
220    type Error = MlsError;
221
222    fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
223        Ok(ReceivedMessage::ApplicationMessage(value))
224    }
225}
226
227impl From<CommitMessageDescription> for ReceivedMessage {
228    fn from(value: CommitMessageDescription) -> Self {
229        ReceivedMessage::Commit(value)
230    }
231}
232
233impl From<ProposalMessageDescription> for ReceivedMessage {
234    fn from(value: ProposalMessageDescription) -> Self {
235        ReceivedMessage::Proposal(value)
236    }
237}
238
239impl From<GroupInfo> for ReceivedMessage {
240    fn from(value: GroupInfo) -> Self {
241        ReceivedMessage::GroupInfo(value)
242    }
243}
244
245impl From<Welcome> for ReceivedMessage {
246    fn from(_: Welcome) -> Self {
247        ReceivedMessage::Welcome
248    }
249}
250
251impl From<KeyPackage> for ReceivedMessage {
252    fn from(value: KeyPackage) -> Self {
253        ReceivedMessage::KeyPackage(value)
254    }
255}
256
257#[derive(Clone, PartialEq, Eq)]
258/// Description of a MLS application message.
259pub struct ApplicationMessageDescription {
260    /// Index of this user in the group state.
261    pub sender_index: u32,
262    /// Received application data.
263    data: ApplicationData,
264    /// Plaintext authenticated data in the received MLS packet.
265    pub authenticated_data: Vec<u8>,
266    /// Unauthenticated key generation used to decrypt the message. See documentation for
267    /// [`Group::peek_next_key_generation`] for usage.
268    #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
269    pub unauthenticated_key_generation: Option<u32>,
270}
271
272impl Debug for ApplicationMessageDescription {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        let mut res = f.debug_struct("ApplicationMessageDescription");
275        res.field("sender_index", &self.sender_index)
276            .field("data", &self.data)
277            .field(
278                "authenticated_data",
279                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
280            );
281        #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
282        res.field(
283            "unauthenticated_key_generation",
284            &self.unauthenticated_key_generation,
285        );
286        res.finish()
287    }
288}
289
290impl ApplicationMessageDescription {
291    pub fn data(&self) -> &[u8] {
292        self.data.as_bytes()
293    }
294}
295
296#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
297#[non_exhaustive]
298/// Description of a processed MLS commit message.
299pub struct CommitMessageDescription {
300    /// True if this is the result of an external commit.
301    pub is_external: bool,
302    /// The index in the group state of the member who performed this commit.
303    pub committer: u32,
304    /// A full description of group state changes as a result of this commit.
305    pub effect: CommitEffect,
306    /// Plaintext authenticated data in the received MLS packet.
307    #[mls_codec(with = "mls_rs_codec::byte_vec")]
308    pub authenticated_data: Vec<u8>,
309}
310
311impl Debug for CommitMessageDescription {
312    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313        f.debug_struct("CommitMessageDescription")
314            .field("is_external", &self.is_external)
315            .field("committer", &self.committer)
316            .field("effect", &self.effect)
317            .field(
318                "authenticated_data",
319                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
320            )
321            .finish()
322    }
323}
324
325#[derive(Debug, Clone, Copy, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
326#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
327#[repr(u8)]
328/// Proposal sender type.
329pub enum ProposalSender {
330    /// A current member of the group by index in the group state.
331    Member(u32) = 1u8,
332    /// An external entity by index within an
333    /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
334    External(u32) = 2u8,
335    /// A new member proposing their addition to the group.
336    NewMember = 3u8,
337}
338
339impl TryFrom<Sender> for ProposalSender {
340    type Error = MlsError;
341
342    fn try_from(value: Sender) -> Result<Self, Self::Error> {
343        match value {
344            Sender::Member(index) => Ok(Self::Member(index)),
345            #[cfg(feature = "by_ref_proposal")]
346            Sender::External(index) => Ok(Self::External(index)),
347            #[cfg(feature = "by_ref_proposal")]
348            Sender::NewMemberProposal => Ok(Self::NewMember),
349            Sender::NewMemberCommit => Err(MlsError::InvalidSender),
350        }
351    }
352}
353
354#[cfg(feature = "by_ref_proposal")]
355#[derive(Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
356#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
357#[non_exhaustive]
358/// Description of a processed MLS proposal message.
359pub struct ProposalMessageDescription {
360    /// Sender of the proposal.
361    pub sender: ProposalSender,
362    /// Proposal content.
363    pub proposal: Proposal,
364    /// Plaintext authenticated data in the received MLS packet.
365    pub authenticated_data: Vec<u8>,
366    /// Proposal reference.
367    pub proposal_ref: ProposalRef,
368}
369
370#[cfg(feature = "by_ref_proposal")]
371impl Debug for ProposalMessageDescription {
372    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373        f.debug_struct("ProposalMessageDescription")
374            .field("sender", &self.sender)
375            .field("proposal", &self.proposal)
376            .field(
377                "authenticated_data",
378                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
379            )
380            .field("proposal_ref", &self.proposal_ref)
381            .finish()
382    }
383}
384
385#[cfg(feature = "by_ref_proposal")]
386#[derive(MlsSize, MlsEncode, MlsDecode)]
387pub struct CachedProposal {
388    pub(crate) proposal: Proposal,
389    pub(crate) proposal_ref: ProposalRef,
390    pub(crate) sender: Sender,
391}
392
393#[cfg(feature = "by_ref_proposal")]
394impl CachedProposal {
395    /// Deserialize the proposal
396    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
397        Ok(Self::mls_decode(&mut &*bytes)?)
398    }
399
400    /// Serialize the proposal
401    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
402        Ok(self.mls_encode_to_vec()?)
403    }
404
405    /// The proposal content.
406    pub fn proposal(&self) -> &Proposal {
407        &self.proposal
408    }
409
410    /// The proposal reference (hash-based identifier).
411    pub fn proposal_ref(&self) -> &ProposalRef {
412        &self.proposal_ref
413    }
414
415    /// The sender of the proposal.
416    pub fn sender(&self) -> &Sender {
417        &self.sender
418    }
419}
420
421#[cfg(feature = "by_ref_proposal")]
422impl ProposalMessageDescription {
423    pub fn cached_proposal(self) -> CachedProposal {
424        let sender = match self.sender {
425            ProposalSender::Member(i) => Sender::Member(i),
426            ProposalSender::External(i) => Sender::External(i),
427            ProposalSender::NewMember => Sender::NewMemberProposal,
428        };
429
430        CachedProposal {
431            proposal: self.proposal,
432            proposal_ref: self.proposal_ref,
433            sender,
434        }
435    }
436
437    pub fn proposal_ref(&self) -> Vec<u8> {
438        self.proposal_ref.to_vec()
439    }
440
441    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
442    pub(crate) async fn new<C: CipherSuiteProvider>(
443        cs: &C,
444        content: &AuthenticatedContent,
445        proposal: Proposal,
446    ) -> Result<Self, MlsError> {
447        Ok(ProposalMessageDescription {
448            authenticated_data: content.content.authenticated_data.clone(),
449            proposal,
450            sender: content.content.sender.try_into()?,
451            proposal_ref: ProposalRef::from_content(cs, content).await?,
452        })
453    }
454}
455
456#[cfg(not(feature = "by_ref_proposal"))]
457#[derive(Debug, Clone)]
458/// Description of a processed MLS proposal message.
459pub struct ProposalMessageDescription {}
460
461#[allow(clippy::large_enum_variant)]
462pub(crate) enum EventOrContent<E> {
463    #[cfg_attr(
464        not(all(feature = "private_message", feature = "external_client")),
465        allow(dead_code)
466    )]
467    Event(E),
468    Content(AuthenticatedContent),
469}
470
471#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
472#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
473#[cfg_attr(
474    all(not(target_arch = "wasm32"), mls_build_async),
475    maybe_async::must_be_async
476)]
477pub(crate) trait MessageProcessor: Send + Sync {
478    type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
479        + From<CommitMessageDescription>
480        + From<ProposalMessageDescription>
481        + From<GroupInfo>
482        + From<Welcome>
483        + From<KeyPackage>
484        + Send;
485
486    type MlsRules: MlsRules;
487    type IdentityProvider: IdentityProvider;
488    type CipherSuiteProvider: CipherSuiteProvider;
489    type PreSharedKeyStorage: PreSharedKeyStorage;
490
491    async fn process_incoming_message(
492        &mut self,
493        message: MlsMessage,
494        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
495    ) -> Result<Self::OutputType, MlsError> {
496        self.process_incoming_message_with_time(
497            message,
498            #[cfg(feature = "by_ref_proposal")]
499            cache_proposal,
500            None,
501        )
502        .await
503    }
504
505    async fn process_incoming_message_with_time(
506        &mut self,
507        message: MlsMessage,
508        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
509        time_sent: Option<MlsTime>,
510    ) -> Result<Self::OutputType, MlsError> {
511        #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
512        // For encrypted application messages, retrieve the unauthenticated key
513        // generation used to decrypt the message and return it with the plaintext. Does
514        // not return an error on failure, allowing `get_event_from_incoming_message` to
515        // continue owning that task.
516        // Note that this decrypts the SenderData twice, which is not ideal.
517        let unauthn_key_gen_in_app_msg: Option<u32> = match message.payload {
518            MlsMessagePayload::Cipher(ref cipher_text) => self
519                .get_unauthenticated_key_generation_from_sender_data(cipher_text)
520                .unwrap_or_default(),
521            _ => None,
522        };
523
524        let event_or_content = self
525            .get_event_from_incoming_message(message, time_sent)
526            .await?;
527
528        self.process_event_or_content(
529            event_or_content,
530            #[cfg(feature = "by_ref_proposal")]
531            cache_proposal,
532            #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
533            unauthn_key_gen_in_app_msg,
534            time_sent,
535        )
536        .await
537    }
538
539    async fn get_event_from_incoming_message(
540        &mut self,
541        message: MlsMessage,
542        time: Option<MlsTime>,
543    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
544        self.check_metadata(&message)?;
545
546        match message.payload {
547            MlsMessagePayload::Plain(plaintext) => {
548                self.verify_plaintext_authentication(plaintext).await
549            }
550            #[cfg(feature = "private_message")]
551            MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
552            MlsMessagePayload::GroupInfo(group_info) => {
553                validate_group_info_member(
554                    self.group_state(),
555                    message.version,
556                    &group_info,
557                    self.cipher_suite_provider(),
558                )
559                .await?;
560
561                Ok(EventOrContent::Event(group_info.into()))
562            }
563            MlsMessagePayload::Welcome(welcome) => {
564                self.validate_welcome(&welcome, message.version)?;
565
566                Ok(EventOrContent::Event(welcome.into()))
567            }
568            MlsMessagePayload::KeyPackage(key_package) => {
569                self.validate_key_package(&key_package, message.version, time)
570                    .await?;
571
572                Ok(EventOrContent::Event(key_package.into()))
573            }
574        }
575    }
576
577    async fn process_event_or_content(
578        &mut self,
579        event_or_content: EventOrContent<Self::OutputType>,
580        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
581        #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
582        unauthn_key_gen_in_app_msg: Option<u32>,
583        time_sent: Option<MlsTime>,
584    ) -> Result<Self::OutputType, MlsError> {
585        let msg = match event_or_content {
586            EventOrContent::Event(event) => event,
587            EventOrContent::Content(content) => {
588                self.process_auth_content(
589                    content,
590                    #[cfg(feature = "by_ref_proposal")]
591                    cache_proposal,
592                    #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
593                    unauthn_key_gen_in_app_msg,
594                    time_sent,
595                )
596                .await?
597            }
598        };
599
600        Ok(msg)
601    }
602
603    async fn process_auth_content(
604        &mut self,
605        auth_content: AuthenticatedContent,
606        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
607        #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
608        unauthn_key_gen_in_app_msg: Option<u32>,
609        time_sent: Option<MlsTime>,
610    ) -> Result<Self::OutputType, MlsError> {
611        let event = match auth_content.content.content {
612            #[cfg(feature = "private_message")]
613            Content::Application(data) => {
614                let authenticated_data = auth_content.content.authenticated_data;
615                let sender = auth_content.content.sender;
616
617                self.process_application_message(
618                    data,
619                    sender,
620                    authenticated_data,
621                    #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
622                    unauthn_key_gen_in_app_msg,
623                )
624                .and_then(Self::OutputType::try_from)
625            }
626            Content::Commit(_) => self
627                .process_commit(auth_content, time_sent)
628                .await
629                .map(Self::OutputType::from),
630            #[cfg(feature = "by_ref_proposal")]
631            Content::Proposal(ref proposal) => self
632                .process_proposal(&auth_content, proposal, cache_proposal)
633                .await
634                .map(Self::OutputType::from),
635        }?;
636
637        Ok(event)
638    }
639
640    #[cfg(feature = "private_message")]
641    fn process_application_message(
642        &self,
643        data: ApplicationData,
644        sender: Sender,
645        authenticated_data: Vec<u8>,
646        #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
647        unauthenticated_key_generation: Option<u32>,
648    ) -> Result<ApplicationMessageDescription, MlsError> {
649        let Sender::Member(sender_index) = sender else {
650            return Err(MlsError::InvalidSender);
651        };
652
653        Ok(ApplicationMessageDescription {
654            authenticated_data,
655            sender_index,
656            data,
657            #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
658            unauthenticated_key_generation,
659        })
660    }
661
662    #[cfg(feature = "by_ref_proposal")]
663    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
664    async fn process_proposal(
665        &mut self,
666        auth_content: &AuthenticatedContent,
667        proposal: &Proposal,
668        cache_proposal: bool,
669    ) -> Result<ProposalMessageDescription, MlsError> {
670        let proposal = ProposalMessageDescription::new(
671            self.cipher_suite_provider(),
672            auth_content,
673            proposal.clone(),
674        )
675        .await?;
676
677        let group_state = self.group_state_mut();
678
679        if cache_proposal {
680            group_state.proposals.insert(
681                proposal.proposal_ref.clone(),
682                proposal.proposal.clone(),
683                auth_content.content.sender,
684            );
685        }
686
687        Ok(proposal)
688    }
689
690    async fn process_commit(
691        &mut self,
692        auth_content: AuthenticatedContent,
693        time_sent: Option<MlsTime>,
694    ) -> Result<CommitMessageDescription, MlsError> {
695        if self.group_state().pending_reinit.is_some() {
696            return Err(MlsError::GroupUsedAfterReInit);
697        }
698
699        // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
700        let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
701            self.cipher_suite_provider(),
702            &self.group_state().interim_transcript_hash,
703            &auth_content,
704        )
705        .await?;
706
707        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
708        let commit = match auth_content.content.content {
709            Content::Commit(commit) => Ok(commit),
710            _ => Err(MlsError::UnexpectedMessageType),
711        }?;
712
713        #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
714        let Content::Commit(commit) = auth_content.content.content;
715
716        let group_state = self.group_state();
717        let id_provider = self.identity_provider();
718
719        #[cfg(feature = "by_ref_proposal")]
720        let proposals = group_state
721            .proposals
722            .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
723
724        #[cfg(not(feature = "by_ref_proposal"))]
725        let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
726
727        let mut provisional_state = group_state
728            .apply_resolved(
729                auth_content.content.sender,
730                proposals,
731                commit.path.as_ref().map(|path| &path.leaf_node),
732                &id_provider,
733                self.cipher_suite_provider(),
734                &self.psk_storage(),
735                &self.mls_rules(),
736                time_sent,
737                CommitDirection::Receive,
738            )
739            .await?;
740
741        let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
742
743        //Verify that the path value is populated if the proposals vector contains any Update
744        // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
745        if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
746            return Err(MlsError::CommitMissingPath);
747        }
748
749        let self_removed = self.removal_proposal(&provisional_state);
750        #[cfg(all(
751            feature = "by_ref_proposal",
752            feature = "custom_proposal",
753            feature = "self_remove_proposal"
754        ))]
755        let self_removed_by_self = self.self_removal_proposal(&provisional_state);
756
757        let is_self_removed = self_removed.is_some();
758        #[cfg(all(
759            feature = "by_ref_proposal",
760            feature = "custom_proposal",
761            feature = "self_remove_proposal"
762        ))]
763        let is_self_removed = is_self_removed || self_removed_by_self.is_some();
764
765        let update_path = match commit.path {
766            Some(update_path) => Some(
767                validate_update_path(
768                    &self.identity_provider(),
769                    self.cipher_suite_provider(),
770                    update_path,
771                    &provisional_state,
772                    sender,
773                    time_sent,
774                    &group_state.context,
775                )
776                .await?,
777            ),
778            None => None,
779        };
780
781        let commit_effect =
782            if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
783                self.group_state_mut().pending_reinit = Some(reinit.proposal.clone());
784                CommitEffect::ReInit(reinit)
785            } else if let Some(remove_proposal) = self_removed {
786                let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
787                CommitEffect::Removed {
788                    remover: remove_proposal.sender,
789                    new_epoch: Box::new(new_epoch),
790                }
791            } else {
792                CommitEffect::NewEpoch(Box::new(NewEpoch::new(
793                    self.group_state().clone(),
794                    &provisional_state,
795                )))
796            };
797
798        #[cfg(all(
799            feature = "by_ref_proposal",
800            feature = "custom_proposal",
801            feature = "self_remove_proposal"
802        ))]
803        let commit_effect = if let Some(self_remove_proposal) = self_removed_by_self {
804            let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
805            CommitEffect::Removed {
806                remover: self_remove_proposal.sender,
807                new_epoch: Box::new(new_epoch),
808            }
809        } else {
810            commit_effect
811        };
812
813        let new_secrets = match update_path {
814            Some(update_path) if !is_self_removed => {
815                self.apply_update_path(sender, &update_path, &mut provisional_state)
816                    .await
817            }
818            _ => Ok(None),
819        }?;
820
821        // Update the transcript hash to get the new context.
822        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
823
824        // Update the parent hashes in the new context
825        provisional_state
826            .public_tree
827            .update_hashes(&[sender], self.cipher_suite_provider())
828            .await?;
829
830        // Update the tree hash in the new context
831        provisional_state.group_context.tree_hash = provisional_state
832            .public_tree
833            .tree_hash(self.cipher_suite_provider())
834            .await?;
835
836        if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
837            if !is_self_removed {
838                // Update the key schedule to calculate new private keys
839                self.update_key_schedule(
840                    new_secrets,
841                    interim_transcript_hash,
842                    confirmation_tag,
843                    provisional_state,
844                )
845                .await?;
846            }
847            Ok(CommitMessageDescription {
848                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
849                authenticated_data: auth_content.content.authenticated_data,
850                committer: *sender,
851                effect: commit_effect,
852            })
853        } else {
854            Err(MlsError::InvalidConfirmationTag)
855        }
856    }
857
    /// Shared access to the group state this processor works on (for example,
    /// `check_metadata` reads the group context from it).
    fn group_state(&self) -> &GroupState;
    /// Mutable access to the group state.
    fn group_state_mut(&mut self) -> &mut GroupState;
    /// The MLS rules implementation associated with this processor.
    fn mls_rules(&self) -> Self::MlsRules;
    /// The identity provider handed to tree and leaf-node validation
    /// (see `apply_update_path` and `validate_key_package`).
    fn identity_provider(&self) -> Self::IdentityProvider;
    /// The cipher suite provider handed to tree, leaf-node and key-package
    /// validation.
    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
    /// The pre-shared key storage available to this processor.
    fn psk_storage(&self) -> Self::PreSharedKeyStorage;
864
    /// Returns a remove proposal found in `provisional_state`, if any.
    ///
    /// NOTE(review): presumably this is the proposal that removes the local
    /// member, used to detect that the commit removed us — confirm against
    /// implementors.
    fn removal_proposal(
        &self,
        provisional_state: &ProvisionalState,
    ) -> Option<ProposalInfo<RemoveProposal>>;
869
    /// Returns a self-remove proposal found in `provisional_state`, if any.
    ///
    /// NOTE(review): presumably this is a self-remove proposal that applies to
    /// the local member — confirm against implementors.
    #[cfg(all(
        feature = "by_ref_proposal",
        feature = "custom_proposal",
        feature = "self_remove_proposal"
    ))]
    fn self_removal_proposal(
        &self,
        provisional_state: &ProvisionalState,
    ) -> Option<ProposalInfo<SelfRemoveProposal>>;
879
    /// Lower bound on the epoch of application messages this processor will
    /// accept.
    ///
    /// `check_metadata` rejects an application message whose epoch is below
    /// the returned value with `MlsError::InvalidEpoch`; returning `None`
    /// disables that lower-bound check.
    #[cfg(feature = "private_message")]
    fn min_epoch_available(&self) -> Option<u64>;
882
883    fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
884        let context = &self.group_state().context;
885
886        if message.version != context.protocol_version {
887            return Err(MlsError::ProtocolVersionMismatch);
888        }
889
890        if let Some((group_id, epoch, content_type)) = match &message.payload {
891            MlsMessagePayload::Plain(plaintext) => Some((
892                &plaintext.content.group_id,
893                plaintext.content.epoch,
894                plaintext.content.content_type(),
895            )),
896            #[cfg(feature = "private_message")]
897            MlsMessagePayload::Cipher(ciphertext) => Some((
898                &ciphertext.group_id,
899                ciphertext.epoch,
900                ciphertext.content_type,
901            )),
902            _ => None,
903        } {
904            if group_id != &context.group_id {
905                return Err(MlsError::GroupIdMismatch);
906            }
907
908            match content_type {
909                ContentType::Commit => {
910                    if context.epoch != epoch {
911                        Err(MlsError::InvalidEpoch)
912                    } else {
913                        Ok(())
914                    }
915                }
916                #[cfg(feature = "by_ref_proposal")]
917                ContentType::Proposal => {
918                    if context.epoch != epoch {
919                        Err(MlsError::InvalidEpoch)
920                    } else {
921                        Ok(())
922                    }
923                }
924                #[cfg(feature = "private_message")]
925                ContentType::Application => {
926                    if let Some(min) = self.min_epoch_available() {
927                        if epoch < min {
928                            Err(MlsError::InvalidEpoch)
929                        } else {
930                            Ok(())
931                        }
932                    } else {
933                        Ok(())
934                    }
935                }
936            }?;
937
938            // Proposal and commit messages must be sent in the current epoch
939            let check_epoch = content_type == ContentType::Commit;
940
941            #[cfg(feature = "by_ref_proposal")]
942            let check_epoch = check_epoch || content_type == ContentType::Proposal;
943
944            if check_epoch && epoch != context.epoch {
945                return Err(MlsError::InvalidEpoch);
946            }
947
948            // Unencrypted application messages are not allowed
949            #[cfg(feature = "private_message")]
950            if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
951                && content_type == ContentType::Application
952            {
953                return Err(MlsError::UnencryptedApplicationMessage);
954            }
955        }
956
957        Ok(())
958    }
959
960    fn validate_welcome(
961        &self,
962        welcome: &Welcome,
963        version: ProtocolVersion,
964    ) -> Result<(), MlsError> {
965        let state = self.group_state();
966
967        (welcome.cipher_suite == state.context.cipher_suite
968            && version == state.context.protocol_version)
969            .then_some(())
970            .ok_or(MlsError::InvalidWelcomeMessage)
971    }
972
973    async fn validate_key_package(
974        &self,
975        key_package: &KeyPackage,
976        version: ProtocolVersion,
977        time: Option<MlsTime>,
978    ) -> Result<(), MlsError> {
979        let cs = self.cipher_suite_provider();
980        let id = self.identity_provider();
981
982        validate_key_package(key_package, version, cs, &id, time).await
983    }
984
    /// Processes an encrypted `PrivateMessage`, returning either an event or
    /// the message content (`EventOrContent`).
    ///
    /// NOTE(review): presumably this decrypts the message with the group's
    /// epoch secrets — confirm against implementors.
    #[cfg(feature = "private_message")]
    async fn process_ciphertext(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
990
    /// Verifies the authentication of a `PublicMessage` and, on success,
    /// returns either an event or the message content (`EventOrContent`).
    ///
    /// NOTE(review): which checks are performed (signature, membership tag,
    /// ...) is up to the implementor — confirm there.
    async fn verify_plaintext_authentication(
        &self,
        message: PublicMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
995
    #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
    /// Returns the unauthenticated key generation used to decrypt the private message.
    ///
    /// As the name states, the value is read from the sender data without
    /// authenticating the message, so treat it as untrusted input.
    /// NOTE(review): the meaning of `Ok(None)` is not visible here — confirm
    /// against implementors.
    async fn get_unauthenticated_key_generation_from_sender_data(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<Option<u32>, MlsError>;
1002
1003    async fn apply_update_path(
1004        &mut self,
1005        sender: LeafIndex,
1006        update_path: &ValidatedUpdatePath,
1007        provisional_state: &mut ProvisionalState,
1008    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
1009        provisional_state
1010            .public_tree
1011            .apply_update_path(
1012                sender,
1013                update_path,
1014                &provisional_state.group_context.extensions,
1015                self.identity_provider(),
1016                self.cipher_suite_provider(),
1017            )
1018            .await
1019            .map(|_| None)
1020    }
1021
    /// Updates the key schedule for the epoch created by a commit.
    ///
    /// Receives the private secrets returned by `apply_update_path` (if any),
    /// the interim transcript hash, the commit's confirmation tag and the
    /// provisional state. Commit processing calls this only when the commit
    /// carries a confirmation tag and the local member was not removed.
    /// NOTE(review): presumably the provisional state becomes the new group
    /// state here — confirm against implementors.
    async fn update_key_schedule(
        &mut self,
        secrets: Option<(TreeKemPrivate, PathSecret)>,
        interim_transcript_hash: InterimTranscriptHash,
        confirmation_tag: &ConfirmationTag,
        provisional_public_state: ProvisionalState,
    ) -> Result<(), MlsError>;
1029}
1030
1031#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1032pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
1033    key_package: &KeyPackage,
1034    version: ProtocolVersion,
1035    cs: &C,
1036    id: &I,
1037    time: Option<MlsTime>,
1038) -> Result<(), MlsError> {
1039    let validator = LeafNodeValidator::new(cs, id, MemberValidationContext::None);
1040
1041    #[cfg(feature = "std")]
1042    let context = Some(MlsTime::now());
1043
1044    #[cfg(not(feature = "std"))]
1045    let context = None;
1046
1047    let context = if time.is_some() { time } else { context };
1048
1049    let context = ValidationContext::Add(context);
1050
1051    validator
1052        .check_if_valid(&key_package.leaf_node, context)
1053        .await?;
1054
1055    validate_key_package_properties(key_package, version, cs).await?;
1056
1057    Ok(())
1058}
1059
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use mls_rs_codec::{MlsDecode, MlsEncode};

    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{test_utils::get_test_group_context, GroupState, Sender},
    };

    use super::{CommitEffect, NewEpoch};

    /// Encodes both `CommitEffect` variants with the MLS codec and checks that
    /// decoding the bytes gives back exactly the same list.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn commit_effect_codec() {
        let prior_state = GroupState {
            #[cfg(feature = "by_ref_proposal")]
            proposals: crate::group::ProposalCache::new(TEST_PROTOCOL_VERSION, vec![]),
            context: get_test_group_context(7, 7.into()).await,
            public_tree: Default::default(),
            interim_transcript_hash: vec![].into(),
            pending_reinit: None,
            confirmation_tag: Default::default(),
        };

        let new_epoch = NewEpoch {
            epoch: 7,
            prior_state,
            applied_proposals: vec![],
            unused_proposals: vec![],
        };

        let original = vec![
            CommitEffect::NewEpoch(new_epoch.clone().into()),
            CommitEffect::Removed {
                new_epoch: new_epoch.into(),
                remover: Sender::Member(0),
            },
        ];

        let encoded = original.mls_encode_to_vec().unwrap();
        let decoded = Vec::<CommitEffect>::mls_decode(&mut encoded.as_slice()).unwrap();

        assert_eq!(decoded, original);
    }
}