mls_rs/group/message_processor.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use super::{
6    commit_sender,
7    confirmation_tag::ConfirmationTag,
8    framing::{
9        ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
10    },
11    message_signature::AuthenticatedContent,
12    mls_rules::{CommitDirection, MlsRules},
13    proposal_filter::ProposalBundle,
14    state::GroupState,
15    transcript_hash::InterimTranscriptHash,
16    transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal,
17    RemoveProposal, Welcome,
18};
19use crate::{
20    client::MlsError,
21    key_package::validate_key_package_properties,
22    time::MlsTime,
23    tree_kem::{
24        leaf_node_validator::{LeafNodeValidator, ValidationContext},
25        node::LeafIndex,
26        path_secret::PathSecret,
27        validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
28    },
29    CipherSuiteProvider, KeyPackage,
30};
31use itertools::Itertools;
32use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
33
34use alloc::boxed::Box;
35use alloc::vec::Vec;
36use core::fmt::{self, Debug};
37use mls_rs_core::{
38    identity::{IdentityProvider, MemberValidationContext},
39    protocol_version::ProtocolVersion,
40    psk::PreSharedKeyStorage,
41};
42
43#[cfg(feature = "by_ref_proposal")]
44use super::proposal_ref::ProposalRef;
45
46#[cfg(not(feature = "by_ref_proposal"))]
47use crate::group::proposal_cache::resolve_for_commit;
48
49use super::proposal::Proposal;
50use super::proposal_filter::ProposalInfo;
51
52#[cfg(feature = "private_message")]
53use crate::group::framing::PrivateMessage;
54
/// The group state that results from applying a received commit's proposals, before it
/// becomes the group's real state.
///
/// `process_commit` obtains it from `GroupState::apply_resolved`. It then updates it in
/// place (update path, confirmed transcript hash, tree hashes) and hands it to
/// `update_key_schedule`.
#[derive(Debug)]
pub(crate) struct ProvisionalState {
    /// Ratchet tree with the proposals (and later the update path) applied.
    pub(crate) public_tree: TreeKemPublic,
    /// The proposals that were applied.
    pub(crate) applied_proposals: ProposalBundle,
    /// Group context for the new epoch; its confirmed transcript hash and tree hash are
    /// filled in by `process_commit`.
    pub(crate) group_context: GroupContext,
    /// NOTE(review): presumably the leaf index assigned to the joiner of an external
    /// commit — confirm against `apply_resolved`.
    pub(crate) external_init_index: Option<LeafIndex>,
    /// NOTE(review): presumably the leaf indexes assigned to added key packages — confirm.
    pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
    /// Proposals that were not applied; copied into `NewEpoch::unused_proposals`.
    pub(crate) unused_proposals: Vec<ProposalInfo<Proposal>>,
}
64
65//By default, the path field of a Commit MUST be populated. The path field MAY be omitted if
66//(a) it covers at least one proposal and (b) none of the proposals covered by the Commit are
67//of "path required" types. A proposal type requires a path if it cannot change the group
68//membership in a way that requires the forward secrecy and post-compromise security guarantees
69//that an UpdatePath provides. The only proposal types defined in this document that do not
70//require a path are:
71
72// add
73// psk
74// reinit
75pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
76    let res = !proposals.external_init_proposals().is_empty();
77
78    #[cfg(feature = "by_ref_proposal")]
79    let res = res || !proposals.update_proposals().is_empty();
80
81    res || proposals.length() == 0
82        || proposals.group_context_extensions_proposal().is_some()
83        || !proposals.remove_proposals().is_empty()
84}
85
/// Description of the epoch that a processed commit moved the group into.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
pub struct NewEpoch {
    /// Epoch number of the new group context.
    pub epoch: u64,
    /// Snapshot of the group state taken before the commit was applied.
    pub prior_state: GroupState,
    /// Proposals applied by the commit.
    pub applied_proposals: Vec<ProposalInfo<Proposal>>,
    /// Proposals that the commit did not apply.
    pub unused_proposals: Vec<ProposalInfo<Proposal>>,
}
98
99impl NewEpoch {
100    pub(crate) fn new(prior_state: GroupState, provisional_state: &ProvisionalState) -> NewEpoch {
101        NewEpoch {
102            epoch: provisional_state.group_context.epoch,
103            prior_state,
104            unused_proposals: provisional_state.unused_proposals.clone(),
105            applied_proposals: provisional_state
106                .applied_proposals
107                .clone()
108                .into_proposals()
109                .collect_vec(),
110        }
111    }
112}
113
#[cfg(all(feature = "ffi", not(test)))]
#[safer_ffi_gen::safer_ffi_gen]
impl NewEpoch {
    /// Epoch number the group moved into.
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Group state as it was before the commit.
    pub fn prior_state(&self) -> &GroupState {
        &self.prior_state
    }

    /// Proposals applied by the commit.
    pub fn applied_proposals(&self) -> &[ProposalInfo<Proposal>] {
        self.applied_proposals.as_slice()
    }

    /// Proposals the commit did not apply.
    pub fn unused_proposals(&self) -> &[ProposalInfo<Proposal>] {
        self.unused_proposals.as_slice()
    }
}
133
/// What a processed commit did to the group.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq)]
pub enum CommitEffect {
    /// The group moved to a new epoch.
    NewEpoch(Box<NewEpoch>),
    /// The commit removed the local member (as reported by
    /// `MessageProcessor::removal_proposal`).
    Removed {
        /// The epoch created by the commit.
        new_epoch: Box<NewEpoch>,
        /// Sender of the Remove proposal.
        remover: Sender,
    },
    /// The commit carried a ReInit proposal; `process_commit` marks the group as pending
    /// re-initialization.
    ReInit(ProposalInfo<ReInitProposal>),
}
147
148impl MlsSize for CommitEffect {
149    fn mls_encoded_len(&self) -> usize {
150        0u8.mls_encoded_len()
151            + match self {
152                Self::NewEpoch(e) => e.mls_encoded_len(),
153                Self::Removed { new_epoch, remover } => {
154                    new_epoch.mls_encoded_len() + remover.mls_encoded_len()
155                }
156                Self::ReInit(r) => r.mls_encoded_len(),
157            }
158    }
159}
160
161impl MlsEncode for CommitEffect {
162    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
163        match self {
164            Self::NewEpoch(e) => {
165                1u8.mls_encode(writer)?;
166                e.mls_encode(writer)?;
167            }
168            Self::Removed { new_epoch, remover } => {
169                2u8.mls_encode(writer)?;
170                new_epoch.mls_encode(writer)?;
171                remover.mls_encode(writer)?;
172            }
173            Self::ReInit(r) => {
174                3u8.mls_encode(writer)?;
175                r.mls_encode(writer)?;
176            }
177        }
178
179        Ok(())
180    }
181}
182
183impl MlsDecode for CommitEffect {
184    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
185        match u8::mls_decode(reader)? {
186            1u8 => Ok(Self::NewEpoch(NewEpoch::mls_decode(reader)?.into())),
187            2u8 => Ok(Self::Removed {
188                new_epoch: NewEpoch::mls_decode(reader)?.into(),
189                remover: Sender::mls_decode(reader)?,
190            }),
191            3u8 => Ok(Self::ReInit(ProposalInfo::mls_decode(reader)?)),
192            _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
193        }
194    }
195}
196
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
// NOTE(review): variant sizes differ widely (`Welcome` carries no data at all); boxing the
// large variants was presumably judged not worth it — confirm before removing this allow.
#[allow(clippy::large_enum_variant)]
/// An event generated as a result of processing a message for a group with
/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
pub enum ReceivedMessage {
    /// An application message was decrypted.
    ApplicationMessage(ApplicationMessageDescription),
    /// A commit was processed; see [`CommitMessageDescription::effect`] for its outcome.
    Commit(CommitMessageDescription),
    /// A proposal was received.
    Proposal(ProposalMessageDescription),
    /// A GroupInfo message that was validated against the current group state.
    GroupInfo(GroupInfo),
    /// A Welcome message whose cipher suite and protocol version match the group. Its
    /// contents are not returned.
    Welcome,
    /// A key package that passed validation.
    KeyPackage(KeyPackage),
}
219
220impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
221    type Error = MlsError;
222
223    fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
224        Ok(ReceivedMessage::ApplicationMessage(value))
225    }
226}
227
228impl From<CommitMessageDescription> for ReceivedMessage {
229    fn from(value: CommitMessageDescription) -> Self {
230        ReceivedMessage::Commit(value)
231    }
232}
233
234impl From<ProposalMessageDescription> for ReceivedMessage {
235    fn from(value: ProposalMessageDescription) -> Self {
236        ReceivedMessage::Proposal(value)
237    }
238}
239
240impl From<GroupInfo> for ReceivedMessage {
241    fn from(value: GroupInfo) -> Self {
242        ReceivedMessage::GroupInfo(value)
243    }
244}
245
246impl From<Welcome> for ReceivedMessage {
247    fn from(_: Welcome) -> Self {
248        ReceivedMessage::Welcome
249    }
250}
251
252impl From<KeyPackage> for ReceivedMessage {
253    fn from(value: KeyPackage) -> Self {
254        ReceivedMessage::KeyPackage(value)
255    }
256}
257
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, Eq)]
/// Description of a MLS application message.
pub struct ApplicationMessageDescription {
    /// Index, in the group state, of the member that sent this message.
    pub sender_index: u32,
    /// Received application data (read through [`ApplicationMessageDescription::data`]).
    data: ApplicationData,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
}
272
273impl Debug for ApplicationMessageDescription {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        f.debug_struct("ApplicationMessageDescription")
276            .field("sender_index", &self.sender_index)
277            .field("data", &self.data)
278            .field(
279                "authenticated_data",
280                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
281            )
282            .finish()
283    }
284}
285
286#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
287impl ApplicationMessageDescription {
288    pub fn data(&self) -> &[u8] {
289        self.data.as_bytes()
290    }
291}
292
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
/// Description of a processed MLS commit message.
pub struct CommitMessageDescription {
    /// True if this is the result of an external commit (sender `NewMemberCommit`).
    pub is_external: bool,
    /// The index in the group state of the member who performed this commit.
    /// NOTE(review): for an external commit this is presumably the index assigned to the
    /// joiner by `commit_sender` — confirm.
    pub committer: u32,
    /// A full description of group state changes as a result of this commit.
    pub effect: CommitEffect,
    /// Plaintext authenticated data in the received MLS packet.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
}
311
312impl Debug for CommitMessageDescription {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        f.debug_struct("CommitMessageDescription")
315            .field("is_external", &self.is_external)
316            .field("committer", &self.committer)
317            .field("effect", &self.effect)
318            .field(
319                "authenticated_data",
320                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
321            )
322            .finish()
323    }
324}
325
#[derive(Debug, Clone, Copy, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
/// Proposal sender type.
///
/// Obtained from a framing [`Sender`] via `TryFrom`. `Sender::NewMemberCommit` has no
/// proposal-sender equivalent and is rejected with `MlsError::InvalidSender`.
/// NOTE(review): the explicit discriminants are presumably the tags used by the derived
/// MLS codec — do not renumber without confirming.
pub enum ProposalSender {
    /// A current member of the group by index in the group state.
    Member(u32) = 1u8,
    /// An external entity by index within an
    /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
    External(u32) = 2u8,
    /// A new member proposing their addition to the group.
    NewMember = 3u8,
}
339
340impl TryFrom<Sender> for ProposalSender {
341    type Error = MlsError;
342
343    fn try_from(value: Sender) -> Result<Self, Self::Error> {
344        match value {
345            Sender::Member(index) => Ok(Self::Member(index)),
346            #[cfg(feature = "by_ref_proposal")]
347            Sender::External(index) => Ok(Self::External(index)),
348            #[cfg(feature = "by_ref_proposal")]
349            Sender::NewMemberProposal => Ok(Self::NewMember),
350            Sender::NewMemberCommit => Err(MlsError::InvalidSender),
351        }
352    }
353}
354
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
/// Description of a processed MLS proposal message.
pub struct ProposalMessageDescription {
    /// Sender of the proposal.
    pub sender: ProposalSender,
    /// Proposal content.
    pub proposal: Proposal,
    /// Plaintext authenticated data in the received MLS packet.
    pub authenticated_data: Vec<u8>,
    /// Proposal reference, computed from the authenticated content
    /// (`ProposalRef::from_content`).
    pub proposal_ref: ProposalRef,
}
374
/// Hand-written so that `authenticated_data` is shown through `pretty_bytes`.
#[cfg(feature = "by_ref_proposal")]
impl Debug for ProposalMessageDescription {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            sender,
            proposal,
            authenticated_data,
            proposal_ref,
        } = self;

        let aad = mls_rs_core::debug::pretty_bytes(authenticated_data);

        let mut out = f.debug_struct("ProposalMessageDescription");
        out.field("sender", sender);
        out.field("proposal", proposal);
        out.field("authenticated_data", &aad);
        out.field("proposal_ref", proposal_ref);
        out.finish()
    }
}
389
/// A received proposal together with its reference and sender, in a form that can be
/// serialized with `to_bytes` and restored with `from_bytes`.
///
/// Built by `ProposalMessageDescription::cached_proposal`.
#[cfg(feature = "by_ref_proposal")]
#[derive(MlsSize, MlsEncode, MlsDecode)]
pub struct CachedProposal {
    pub(crate) proposal: Proposal,
    pub(crate) proposal_ref: ProposalRef,
    pub(crate) sender: Sender,
}
397
#[cfg(feature = "by_ref_proposal")]
impl CachedProposal {
    /// Deserialize the proposal from its MLS encoding (as produced by `to_bytes`).
    ///
    /// # Errors
    /// Returns an error if `bytes` is not a valid encoding of a `CachedProposal`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
        let mut reader = bytes;
        Self::mls_decode(&mut reader).map_err(MlsError::from)
    }

    /// Serialize the proposal using its MLS encoding.
    ///
    /// # Errors
    /// Returns an error if encoding fails.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
        self.mls_encode_to_vec().map_err(MlsError::from)
    }
}
410
#[cfg(feature = "by_ref_proposal")]
impl ProposalMessageDescription {
    /// Converts this description into a [`CachedProposal`], discarding the authenticated
    /// data.
    pub fn cached_proposal(self) -> CachedProposal {
        let Self {
            sender,
            proposal,
            proposal_ref,
            ..
        } = self;

        let sender = match sender {
            ProposalSender::Member(index) => Sender::Member(index),
            ProposalSender::External(index) => Sender::External(index),
            ProposalSender::NewMember => Sender::NewMemberProposal,
        };

        CachedProposal {
            proposal,
            proposal_ref,
            sender,
        }
    }

    /// Returns a copy of the proposal reference bytes.
    pub fn proposal_ref(&self) -> Vec<u8> {
        self.proposal_ref.to_vec()
    }

    /// Builds a description of `proposal`, which was received in `content`.
    ///
    /// # Errors
    /// `InvalidSender` if the content's sender cannot send proposals, or any error from
    /// computing the proposal reference.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn new<C: CipherSuiteProvider>(
        cs: &C,
        content: &AuthenticatedContent,
        proposal: Proposal,
    ) -> Result<Self, MlsError> {
        let framed = &content.content;
        let authenticated_data = framed.authenticated_data.clone();
        let sender = ProposalSender::try_from(framed.sender)?;
        let proposal_ref = ProposalRef::from_content(cs, content).await?;

        Ok(Self {
            sender,
            proposal,
            authenticated_data,
            proposal_ref,
        })
    }
}
445
#[cfg(not(feature = "by_ref_proposal"))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
/// Description of a processed MLS proposal message.
///
/// Placeholder used when `by_ref_proposal` is disabled. Proposal content is not processed
/// in that configuration (see `process_auth_content`), so this type carries no data.
pub struct ProposalMessageDescription {}
454
/// Intermediate result of reading an incoming message.
///
/// `Event` is already a final output, for example a validated GroupInfo, Welcome or
/// KeyPackage from `get_event_from_incoming_message`. `Content` is authenticated framed
/// content that still has to go through `process_auth_content`.
#[allow(clippy::large_enum_variant)]
pub(crate) enum EventOrContent<E> {
    #[cfg_attr(
        not(all(feature = "private_message", feature = "external_client")),
        allow(dead_code)
    )]
    Event(E),
    Content(AuthenticatedContent),
}
464
465#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
466#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
467#[cfg_attr(
468    all(not(target_arch = "wasm32"), mls_build_async),
469    maybe_async::must_be_async
470)]
471pub(crate) trait MessageProcessor: Send + Sync {
472    type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
473        + From<CommitMessageDescription>
474        + From<ProposalMessageDescription>
475        + From<GroupInfo>
476        + From<Welcome>
477        + From<KeyPackage>
478        + Send;
479
480    type MlsRules: MlsRules;
481    type IdentityProvider: IdentityProvider;
482    type CipherSuiteProvider: CipherSuiteProvider;
483    type PreSharedKeyStorage: PreSharedKeyStorage;
484
485    async fn process_incoming_message(
486        &mut self,
487        message: MlsMessage,
488        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
489    ) -> Result<Self::OutputType, MlsError> {
490        self.process_incoming_message_with_time(
491            message,
492            #[cfg(feature = "by_ref_proposal")]
493            cache_proposal,
494            None,
495        )
496        .await
497    }
498
499    async fn process_incoming_message_with_time(
500        &mut self,
501        message: MlsMessage,
502        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
503        time_sent: Option<MlsTime>,
504    ) -> Result<Self::OutputType, MlsError> {
505        let event_or_content = self.get_event_from_incoming_message(message).await?;
506
507        self.process_event_or_content(
508            event_or_content,
509            #[cfg(feature = "by_ref_proposal")]
510            cache_proposal,
511            time_sent,
512        )
513        .await
514    }
515
516    async fn get_event_from_incoming_message(
517        &mut self,
518        message: MlsMessage,
519    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
520        self.check_metadata(&message)?;
521
522        match message.payload {
523            MlsMessagePayload::Plain(plaintext) => {
524                self.verify_plaintext_authentication(plaintext).await
525            }
526            #[cfg(feature = "private_message")]
527            MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
528            MlsMessagePayload::GroupInfo(group_info) => {
529                validate_group_info_member(
530                    self.group_state(),
531                    message.version,
532                    &group_info,
533                    self.cipher_suite_provider(),
534                )
535                .await?;
536
537                Ok(EventOrContent::Event(group_info.into()))
538            }
539            MlsMessagePayload::Welcome(welcome) => {
540                self.validate_welcome(&welcome, message.version)?;
541
542                Ok(EventOrContent::Event(welcome.into()))
543            }
544            MlsMessagePayload::KeyPackage(key_package) => {
545                self.validate_key_package(&key_package, message.version)
546                    .await?;
547
548                Ok(EventOrContent::Event(key_package.into()))
549            }
550        }
551    }
552
553    async fn process_event_or_content(
554        &mut self,
555        event_or_content: EventOrContent<Self::OutputType>,
556        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
557        time_sent: Option<MlsTime>,
558    ) -> Result<Self::OutputType, MlsError> {
559        let msg = match event_or_content {
560            EventOrContent::Event(event) => event,
561            EventOrContent::Content(content) => {
562                self.process_auth_content(
563                    content,
564                    #[cfg(feature = "by_ref_proposal")]
565                    cache_proposal,
566                    time_sent,
567                )
568                .await?
569            }
570        };
571
572        Ok(msg)
573    }
574
575    async fn process_auth_content(
576        &mut self,
577        auth_content: AuthenticatedContent,
578        #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
579        time_sent: Option<MlsTime>,
580    ) -> Result<Self::OutputType, MlsError> {
581        let event = match auth_content.content.content {
582            #[cfg(feature = "private_message")]
583            Content::Application(data) => {
584                let authenticated_data = auth_content.content.authenticated_data;
585                let sender = auth_content.content.sender;
586
587                self.process_application_message(data, sender, authenticated_data)
588                    .and_then(Self::OutputType::try_from)
589            }
590            Content::Commit(_) => self
591                .process_commit(auth_content, time_sent)
592                .await
593                .map(Self::OutputType::from),
594            #[cfg(feature = "by_ref_proposal")]
595            Content::Proposal(ref proposal) => self
596                .process_proposal(&auth_content, proposal, cache_proposal)
597                .await
598                .map(Self::OutputType::from),
599        }?;
600
601        Ok(event)
602    }
603
604    #[cfg(feature = "private_message")]
605    fn process_application_message(
606        &self,
607        data: ApplicationData,
608        sender: Sender,
609        authenticated_data: Vec<u8>,
610    ) -> Result<ApplicationMessageDescription, MlsError> {
611        let Sender::Member(sender_index) = sender else {
612            return Err(MlsError::InvalidSender);
613        };
614
615        Ok(ApplicationMessageDescription {
616            authenticated_data,
617            sender_index,
618            data,
619        })
620    }
621
622    #[cfg(feature = "by_ref_proposal")]
623    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
624    async fn process_proposal(
625        &mut self,
626        auth_content: &AuthenticatedContent,
627        proposal: &Proposal,
628        cache_proposal: bool,
629    ) -> Result<ProposalMessageDescription, MlsError> {
630        let proposal = ProposalMessageDescription::new(
631            self.cipher_suite_provider(),
632            auth_content,
633            proposal.clone(),
634        )
635        .await?;
636
637        let group_state = self.group_state_mut();
638
639        if cache_proposal {
640            group_state.proposals.insert(
641                proposal.proposal_ref.clone(),
642                proposal.proposal.clone(),
643                auth_content.content.sender,
644            );
645        }
646
647        Ok(proposal)
648    }
649
650    async fn process_commit(
651        &mut self,
652        auth_content: AuthenticatedContent,
653        time_sent: Option<MlsTime>,
654    ) -> Result<CommitMessageDescription, MlsError> {
655        if self.group_state().pending_reinit.is_some() {
656            return Err(MlsError::GroupUsedAfterReInit);
657        }
658
659        // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
660        let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
661            self.cipher_suite_provider(),
662            &self.group_state().interim_transcript_hash,
663            &auth_content,
664        )
665        .await?;
666
667        #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
668        let commit = match auth_content.content.content {
669            Content::Commit(commit) => Ok(commit),
670            _ => Err(MlsError::UnexpectedMessageType),
671        }?;
672
673        #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
674        let Content::Commit(commit) = auth_content.content.content;
675
676        let group_state = self.group_state();
677        let id_provider = self.identity_provider();
678
679        #[cfg(feature = "by_ref_proposal")]
680        let proposals = group_state
681            .proposals
682            .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
683
684        #[cfg(not(feature = "by_ref_proposal"))]
685        let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
686
687        let mut provisional_state = group_state
688            .apply_resolved(
689                auth_content.content.sender,
690                proposals,
691                commit.path.as_ref().map(|path| &path.leaf_node),
692                &id_provider,
693                self.cipher_suite_provider(),
694                &self.psk_storage(),
695                &self.mls_rules(),
696                time_sent,
697                CommitDirection::Receive,
698            )
699            .await?;
700
701        let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
702
703        //Verify that the path value is populated if the proposals vector contains any Update
704        // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
705        if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
706            return Err(MlsError::CommitMissingPath);
707        }
708
709        let self_removed = self.removal_proposal(&provisional_state);
710        let is_self_removed = self_removed.is_some();
711
712        let update_path = match commit.path {
713            Some(update_path) => Some(
714                validate_update_path(
715                    &self.identity_provider(),
716                    self.cipher_suite_provider(),
717                    update_path,
718                    &provisional_state,
719                    sender,
720                    time_sent,
721                    &group_state.context,
722                )
723                .await?,
724            ),
725            None => None,
726        };
727
728        let commit_effect =
729            if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
730                self.group_state_mut().pending_reinit = Some(reinit.proposal.clone());
731                CommitEffect::ReInit(reinit)
732            } else if let Some(remove_proposal) = self_removed {
733                let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
734                CommitEffect::Removed {
735                    remover: remove_proposal.sender,
736                    new_epoch: Box::new(new_epoch),
737                }
738            } else {
739                CommitEffect::NewEpoch(Box::new(NewEpoch::new(
740                    self.group_state().clone(),
741                    &provisional_state,
742                )))
743            };
744
745        let new_secrets = match update_path {
746            Some(update_path) if !is_self_removed => {
747                self.apply_update_path(sender, &update_path, &mut provisional_state)
748                    .await
749            }
750            _ => Ok(None),
751        }?;
752
753        // Update the transcript hash to get the new context.
754        provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
755
756        // Update the parent hashes in the new context
757        provisional_state
758            .public_tree
759            .update_hashes(&[sender], self.cipher_suite_provider())
760            .await?;
761
762        // Update the tree hash in the new context
763        provisional_state.group_context.tree_hash = provisional_state
764            .public_tree
765            .tree_hash(self.cipher_suite_provider())
766            .await?;
767
768        if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
769            if !is_self_removed {
770                // Update the key schedule to calculate new private keys
771                self.update_key_schedule(
772                    new_secrets,
773                    interim_transcript_hash,
774                    confirmation_tag,
775                    provisional_state,
776                )
777                .await?;
778            }
779
780            Ok(CommitMessageDescription {
781                is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
782                authenticated_data: auth_content.content.authenticated_data,
783                committer: *sender,
784                effect: commit_effect,
785            })
786        } else {
787            Err(MlsError::InvalidConfirmationTag)
788        }
789    }
790
    /// Read access to the group state that incoming messages are checked and applied against.
    fn group_state(&self) -> &GroupState;
    /// Mutable access to the group state (used to cache proposals and to record a pending
    /// ReInit).
    fn group_state_mut(&mut self) -> &mut GroupState;
    /// Rules passed to `apply_resolved` when applying a received commit's proposals.
    fn mls_rules(&self) -> Self::MlsRules;
    /// Identity provider passed to proposal application, key package validation and update
    /// path validation.
    fn identity_provider(&self) -> Self::IdentityProvider;
    /// Cipher suite implementation used for transcript/tree hashing, proposal references and
    /// validation.
    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
    /// Pre-shared key storage passed to `apply_resolved` when applying a commit.
    fn psk_storage(&self) -> Self::PreSharedKeyStorage;

    /// Returns the Remove proposal in `provisional_state` that `process_commit` treats as
    /// removing this processor's own member, if any. A `Some` result produces
    /// `CommitEffect::Removed` and skips applying the update path and the key schedule update.
    fn removal_proposal(
        &self,
        provisional_state: &ProvisionalState,
    ) -> Option<ProposalInfo<RemoveProposal>>;

    /// Oldest epoch for which application messages are still accepted by `check_metadata`,
    /// or `None` for no lower bound.
    #[cfg(feature = "private_message")]
    fn min_epoch_available(&self) -> Option<u64>;
805
806    fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
807        let context = &self.group_state().context;
808
809        if message.version != context.protocol_version {
810            return Err(MlsError::ProtocolVersionMismatch);
811        }
812
813        if let Some((group_id, epoch, content_type)) = match &message.payload {
814            MlsMessagePayload::Plain(plaintext) => Some((
815                &plaintext.content.group_id,
816                plaintext.content.epoch,
817                plaintext.content.content_type(),
818            )),
819            #[cfg(feature = "private_message")]
820            MlsMessagePayload::Cipher(ciphertext) => Some((
821                &ciphertext.group_id,
822                ciphertext.epoch,
823                ciphertext.content_type,
824            )),
825            _ => None,
826        } {
827            if group_id != &context.group_id {
828                return Err(MlsError::GroupIdMismatch);
829            }
830
831            match content_type {
832                ContentType::Commit => {
833                    if context.epoch != epoch {
834                        Err(MlsError::InvalidEpoch)
835                    } else {
836                        Ok(())
837                    }
838                }
839                #[cfg(feature = "by_ref_proposal")]
840                ContentType::Proposal => {
841                    if context.epoch != epoch {
842                        Err(MlsError::InvalidEpoch)
843                    } else {
844                        Ok(())
845                    }
846                }
847                #[cfg(feature = "private_message")]
848                ContentType::Application => {
849                    if let Some(min) = self.min_epoch_available() {
850                        if epoch < min {
851                            Err(MlsError::InvalidEpoch)
852                        } else {
853                            Ok(())
854                        }
855                    } else {
856                        Ok(())
857                    }
858                }
859            }?;
860
861            // Proposal and commit messages must be sent in the current epoch
862            let check_epoch = content_type == ContentType::Commit;
863
864            #[cfg(feature = "by_ref_proposal")]
865            let check_epoch = check_epoch || content_type == ContentType::Proposal;
866
867            if check_epoch && epoch != context.epoch {
868                return Err(MlsError::InvalidEpoch);
869            }
870
871            // Unencrypted application messages are not allowed
872            #[cfg(feature = "private_message")]
873            if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
874                && content_type == ContentType::Application
875            {
876                return Err(MlsError::UnencryptedApplicationMessage);
877            }
878        }
879
880        Ok(())
881    }
882
883    fn validate_welcome(
884        &self,
885        welcome: &Welcome,
886        version: ProtocolVersion,
887    ) -> Result<(), MlsError> {
888        let state = self.group_state();
889
890        (welcome.cipher_suite == state.context.cipher_suite
891            && version == state.context.protocol_version)
892            .then_some(())
893            .ok_or(MlsError::InvalidWelcomeMessage)
894    }
895
896    async fn validate_key_package(
897        &self,
898        key_package: &KeyPackage,
899        version: ProtocolVersion,
900    ) -> Result<(), MlsError> {
901        let cs = self.cipher_suite_provider();
902        let id = self.identity_provider();
903
904        validate_key_package(key_package, version, cs, &id).await
905    }
906
    /// Processes an incoming `PrivateMessage`.
    ///
    /// Only compiled with the `private_message` feature. This is a required
    /// method with no default implementation. Judging by the name, implementors
    /// presumably decrypt `cipher_text` before further handling (NOTE(review):
    /// confirm against the implementations). It takes `&mut self`, so
    /// implementations may mutate state.
    ///
    /// Returns an `EventOrContent` parameterized by `Self::OutputType`, or an
    /// `MlsError`.
    #[cfg(feature = "private_message")]
    async fn process_ciphertext(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
912
    /// Handles an incoming, unencrypted `PublicMessage`.
    ///
    /// This is a required method with no default implementation. It takes the
    /// message by value and only `&self`. As the name indicates, implementors are
    /// expected to verify the message's authentication (NOTE(review): the exact
    /// checks are implementation-defined — confirm against implementors).
    ///
    /// Returns an `EventOrContent` parameterized by `Self::OutputType`, or an
    /// `MlsError`.
    async fn verify_plaintext_authentication(
        &self,
        message: PublicMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
917
918    async fn apply_update_path(
919        &mut self,
920        sender: LeafIndex,
921        update_path: &ValidatedUpdatePath,
922        provisional_state: &mut ProvisionalState,
923    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
924        provisional_state
925            .public_tree
926            .apply_update_path(
927                sender,
928                update_path,
929                &provisional_state.group_context.extensions,
930                self.identity_provider(),
931                self.cipher_suite_provider(),
932            )
933            .await
934            .map(|_| None)
935    }
936
    /// Updates the group's key schedule using the results of processing a commit.
    ///
    /// This is a required method with no default implementation; it takes `&mut self`.
    ///
    /// * `secrets` - has the same type as the value returned by
    ///   `apply_update_path`. The trait's default `apply_update_path` always
    ///   produces `None`, so implementations must handle the `None` case.
    /// * `interim_transcript_hash` - the interim transcript hash, passed by value.
    /// * `confirmation_tag` - the confirmation tag, borrowed.
    /// * `provisional_public_state` - the `ProvisionalState` (tree, applied
    ///   proposals, group context, ...), consumed by this call.
    ///
    /// NOTE(review): whether implementations verify `confirmation_tag` here is
    /// not visible from this trait — confirm against implementors.
    async fn update_key_schedule(
        &mut self,
        secrets: Option<(TreeKemPrivate, PathSecret)>,
        interim_transcript_hash: InterimTranscriptHash,
        confirmation_tag: &ConfirmationTag,
        provisional_public_state: ProvisionalState,
    ) -> Result<(), MlsError>;
944}
945
946#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
947pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
948    key_package: &KeyPackage,
949    version: ProtocolVersion,
950    cs: &C,
951    id: &I,
952) -> Result<(), MlsError> {
953    let validator = LeafNodeValidator::new(cs, id, MemberValidationContext::None);
954
955    #[cfg(feature = "std")]
956    let context = Some(MlsTime::now());
957
958    #[cfg(not(feature = "std"))]
959    let context = None;
960
961    let context = ValidationContext::Add(context);
962
963    validator
964        .check_if_valid(&key_package.leaf_node, context)
965        .await?;
966
967    validate_key_package_properties(key_package, version, cs).await?;
968
969    Ok(())
970}
971
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use mls_rs_codec::{MlsDecode, MlsEncode};

    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{test_utils::get_test_group_context, GroupState, Sender},
    };

    use super::{CommitEffect, NewEpoch};

    /// Encodes a list of `CommitEffect` values with the MLS codec and checks
    /// that decoding the bytes yields the same list.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn commit_effect_codec() {
        let prior_state = GroupState {
            #[cfg(feature = "by_ref_proposal")]
            proposals: crate::group::ProposalCache::new(TEST_PROTOCOL_VERSION, vec![]),
            context: get_test_group_context(7, 7.into()).await,
            public_tree: Default::default(),
            interim_transcript_hash: vec![].into(),
            pending_reinit: None,
            confirmation_tag: Default::default(),
        };

        let new_epoch = NewEpoch {
            epoch: 7,
            prior_state,
            applied_proposals: vec![],
            unused_proposals: vec![],
        };

        let effects = vec![
            CommitEffect::NewEpoch(new_epoch.clone().into()),
            CommitEffect::Removed {
                new_epoch: new_epoch.into(),
                remover: Sender::Member(0),
            },
        ];

        let encoded = effects.mls_encode_to_vec().unwrap();
        let decoded = Vec::<CommitEffect>::mls_decode(&mut encoded.as_slice()).unwrap();

        assert_eq!(effects, decoded);
    }
}