1use super::{
6 commit_sender,
7 confirmation_tag::ConfirmationTag,
8 framing::{
9 ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
10 },
11 message_signature::AuthenticatedContent,
12 mls_rules::{CommitDirection, MlsRules},
13 proposal_filter::ProposalBundle,
14 state::GroupState,
15 transcript_hash::InterimTranscriptHash,
16 transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal,
17 RemoveProposal, Welcome,
18};
19use crate::{
20 client::MlsError,
21 key_package::validate_key_package_properties,
22 time::MlsTime,
23 tree_kem::{
24 leaf_node_validator::{LeafNodeValidator, ValidationContext},
25 node::LeafIndex,
26 path_secret::PathSecret,
27 validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
28 },
29 CipherSuiteProvider, KeyPackage,
30};
31use itertools::Itertools;
32use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
33
34use alloc::boxed::Box;
35use alloc::vec::Vec;
36use core::fmt::{self, Debug};
37use mls_rs_core::{
38 identity::{IdentityProvider, MemberValidationContext},
39 protocol_version::ProtocolVersion,
40 psk::PreSharedKeyStorage,
41};
42
43#[cfg(feature = "by_ref_proposal")]
44use super::proposal_ref::ProposalRef;
45
46#[cfg(not(feature = "by_ref_proposal"))]
47use crate::group::proposal_cache::resolve_for_commit;
48
49use super::proposal::Proposal;
50use super::proposal_filter::ProposalInfo;
51
52#[cfg(feature = "private_message")]
53use crate::group::framing::PrivateMessage;
54
55#[derive(Debug)]
56pub(crate) struct ProvisionalState {
57 pub(crate) public_tree: TreeKemPublic,
58 pub(crate) applied_proposals: ProposalBundle,
59 pub(crate) group_context: GroupContext,
60 pub(crate) external_init_index: Option<LeafIndex>,
61 pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
62 pub(crate) unused_proposals: Vec<ProposalInfo<Proposal>>,
63}
64
65pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
76 let res = !proposals.external_init_proposals().is_empty();
77
78 #[cfg(feature = "by_ref_proposal")]
79 let res = res || !proposals.update_proposals().is_empty();
80
81 res || proposals.length() == 0
82 || proposals.group_context_extensions_proposal().is_some()
83 || !proposals.remove_proposals().is_empty()
84}
85
86#[cfg_attr(
87 all(feature = "ffi", not(test)),
88 safer_ffi_gen::ffi_type(clone, opaque)
89)]
90#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
91#[non_exhaustive]
92pub struct NewEpoch {
93 pub epoch: u64,
94 pub prior_state: GroupState,
95 pub applied_proposals: Vec<ProposalInfo<Proposal>>,
96 pub unused_proposals: Vec<ProposalInfo<Proposal>>,
97}
98
99impl NewEpoch {
100 pub(crate) fn new(prior_state: GroupState, provisional_state: &ProvisionalState) -> NewEpoch {
101 NewEpoch {
102 epoch: provisional_state.group_context.epoch,
103 prior_state,
104 unused_proposals: provisional_state.unused_proposals.clone(),
105 applied_proposals: provisional_state
106 .applied_proposals
107 .clone()
108 .into_proposals()
109 .collect_vec(),
110 }
111 }
112}
113
114#[cfg(all(feature = "ffi", not(test)))]
115#[safer_ffi_gen::safer_ffi_gen]
116impl NewEpoch {
117 pub fn epoch(&self) -> u64 {
118 self.epoch
119 }
120
121 pub fn prior_state(&self) -> &GroupState {
122 &self.prior_state
123 }
124
125 pub fn applied_proposals(&self) -> &[ProposalInfo<Proposal>] {
126 &self.applied_proposals
127 }
128
129 pub fn unused_proposals(&self) -> &[ProposalInfo<Proposal>] {
130 &self.unused_proposals
131 }
132}
133
134#[cfg_attr(
135 all(feature = "ffi", not(test)),
136 safer_ffi_gen::ffi_type(clone, opaque)
137)]
138#[derive(Clone, Debug, PartialEq)]
139pub enum CommitEffect {
140 NewEpoch(Box<NewEpoch>),
141 Removed {
142 new_epoch: Box<NewEpoch>,
143 remover: Sender,
144 },
145 ReInit(ProposalInfo<ReInitProposal>),
146}
147
148impl MlsSize for CommitEffect {
149 fn mls_encoded_len(&self) -> usize {
150 0u8.mls_encoded_len()
151 + match self {
152 Self::NewEpoch(e) => e.mls_encoded_len(),
153 Self::Removed { new_epoch, remover } => {
154 new_epoch.mls_encoded_len() + remover.mls_encoded_len()
155 }
156 Self::ReInit(r) => r.mls_encoded_len(),
157 }
158 }
159}
160
161impl MlsEncode for CommitEffect {
162 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
163 match self {
164 Self::NewEpoch(e) => {
165 1u8.mls_encode(writer)?;
166 e.mls_encode(writer)?;
167 }
168 Self::Removed { new_epoch, remover } => {
169 2u8.mls_encode(writer)?;
170 new_epoch.mls_encode(writer)?;
171 remover.mls_encode(writer)?;
172 }
173 Self::ReInit(r) => {
174 3u8.mls_encode(writer)?;
175 r.mls_encode(writer)?;
176 }
177 }
178
179 Ok(())
180 }
181}
182
183impl MlsDecode for CommitEffect {
184 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
185 match u8::mls_decode(reader)? {
186 1u8 => Ok(Self::NewEpoch(NewEpoch::mls_decode(reader)?.into())),
187 2u8 => Ok(Self::Removed {
188 new_epoch: NewEpoch::mls_decode(reader)?.into(),
189 remover: Sender::mls_decode(reader)?,
190 }),
191 3u8 => Ok(Self::ReInit(ProposalInfo::mls_decode(reader)?)),
192 _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
193 }
194 }
195}
196
197#[cfg_attr(
198 all(feature = "ffi", not(test)),
199 safer_ffi_gen::ffi_type(clone, opaque)
200)]
201#[derive(Debug, Clone)]
202#[allow(clippy::large_enum_variant)]
203pub enum ReceivedMessage {
206 ApplicationMessage(ApplicationMessageDescription),
208 Commit(CommitMessageDescription),
210 Proposal(ProposalMessageDescription),
212 GroupInfo(GroupInfo),
214 Welcome,
216 KeyPackage(KeyPackage),
218}
219
220impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
221 type Error = MlsError;
222
223 fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
224 Ok(ReceivedMessage::ApplicationMessage(value))
225 }
226}
227
228impl From<CommitMessageDescription> for ReceivedMessage {
229 fn from(value: CommitMessageDescription) -> Self {
230 ReceivedMessage::Commit(value)
231 }
232}
233
234impl From<ProposalMessageDescription> for ReceivedMessage {
235 fn from(value: ProposalMessageDescription) -> Self {
236 ReceivedMessage::Proposal(value)
237 }
238}
239
240impl From<GroupInfo> for ReceivedMessage {
241 fn from(value: GroupInfo) -> Self {
242 ReceivedMessage::GroupInfo(value)
243 }
244}
245
246impl From<Welcome> for ReceivedMessage {
247 fn from(_: Welcome) -> Self {
248 ReceivedMessage::Welcome
249 }
250}
251
252impl From<KeyPackage> for ReceivedMessage {
253 fn from(value: KeyPackage) -> Self {
254 ReceivedMessage::KeyPackage(value)
255 }
256}
257
258#[cfg_attr(
259 all(feature = "ffi", not(test)),
260 safer_ffi_gen::ffi_type(clone, opaque)
261)]
262#[derive(Clone, PartialEq, Eq)]
263pub struct ApplicationMessageDescription {
265 pub sender_index: u32,
267 data: ApplicationData,
269 pub authenticated_data: Vec<u8>,
271}
272
273impl Debug for ApplicationMessageDescription {
274 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275 f.debug_struct("ApplicationMessageDescription")
276 .field("sender_index", &self.sender_index)
277 .field("data", &self.data)
278 .field(
279 "authenticated_data",
280 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
281 )
282 .finish()
283 }
284}
285
286#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
287impl ApplicationMessageDescription {
288 pub fn data(&self) -> &[u8] {
289 self.data.as_bytes()
290 }
291}
292
293#[cfg_attr(
294 all(feature = "ffi", not(test)),
295 safer_ffi_gen::ffi_type(clone, opaque)
296)]
297#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
298#[non_exhaustive]
299pub struct CommitMessageDescription {
301 pub is_external: bool,
303 pub committer: u32,
305 pub effect: CommitEffect,
307 #[mls_codec(with = "mls_rs_codec::byte_vec")]
309 pub authenticated_data: Vec<u8>,
310}
311
312impl Debug for CommitMessageDescription {
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 f.debug_struct("CommitMessageDescription")
315 .field("is_external", &self.is_external)
316 .field("committer", &self.committer)
317 .field("effect", &self.effect)
318 .field(
319 "authenticated_data",
320 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
321 )
322 .finish()
323 }
324}
325
326#[derive(Debug, Clone, Copy, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
327#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
328#[repr(u8)]
329pub enum ProposalSender {
331 Member(u32) = 1u8,
333 External(u32) = 2u8,
336 NewMember = 3u8,
338}
339
340impl TryFrom<Sender> for ProposalSender {
341 type Error = MlsError;
342
343 fn try_from(value: Sender) -> Result<Self, Self::Error> {
344 match value {
345 Sender::Member(index) => Ok(Self::Member(index)),
346 #[cfg(feature = "by_ref_proposal")]
347 Sender::External(index) => Ok(Self::External(index)),
348 #[cfg(feature = "by_ref_proposal")]
349 Sender::NewMemberProposal => Ok(Self::NewMember),
350 Sender::NewMemberCommit => Err(MlsError::InvalidSender),
351 }
352 }
353}
354
/// Description of a processed proposal message.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ProposalMessageDescription {
    /// Who sent the proposal.
    pub sender: ProposalSender,
    /// The proposal itself.
    pub proposal: Proposal,
    /// Additional authenticated data carried with the proposal.
    pub authenticated_data: Vec<u8>,
    /// Reference computed from the authenticated content by `ProposalRef::from_content`.
    pub proposal_ref: ProposalRef,
}

#[cfg(feature = "by_ref_proposal")]
impl Debug for ProposalMessageDescription {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let aad = mls_rs_core::debug::pretty_bytes(&self.authenticated_data);

        let mut out = f.debug_struct("ProposalMessageDescription");
        out.field("sender", &self.sender);
        out.field("proposal", &self.proposal);
        out.field("authenticated_data", &aad);
        out.field("proposal_ref", &self.proposal_ref);
        out.finish()
    }
}
389
/// A proposal together with its reference and sender, serializable via
/// [`CachedProposal::to_bytes`] and restorable via [`CachedProposal::from_bytes`].
#[cfg(feature = "by_ref_proposal")]
#[derive(MlsSize, MlsEncode, MlsDecode)]
pub struct CachedProposal {
    pub(crate) proposal: Proposal,
    pub(crate) proposal_ref: ProposalRef,
    pub(crate) sender: Sender,
}

#[cfg(feature = "by_ref_proposal")]
impl CachedProposal {
    /// Decodes a cached proposal from its MLS encoding.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
        let mut reader = bytes;
        Self::mls_decode(&mut reader).map_err(MlsError::from)
    }

    /// Encodes this cached proposal with the MLS codec.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
        self.mls_encode_to_vec().map_err(MlsError::from)
    }
}
410
#[cfg(feature = "by_ref_proposal")]
impl ProposalMessageDescription {
    /// Converts this description into a [`CachedProposal`], dropping the authenticated data.
    pub fn cached_proposal(self) -> CachedProposal {
        let Self {
            sender,
            proposal,
            proposal_ref,
            ..
        } = self;

        let sender = match sender {
            ProposalSender::Member(index) => Sender::Member(index),
            ProposalSender::External(index) => Sender::External(index),
            ProposalSender::NewMember => Sender::NewMemberProposal,
        };

        CachedProposal {
            proposal,
            proposal_ref,
            sender,
        }
    }

    /// Bytes of the proposal reference.
    pub fn proposal_ref(&self) -> Vec<u8> {
        self.proposal_ref.to_vec()
    }

    /// Builds a description from authenticated proposal content.
    ///
    /// Fails with `MlsError::InvalidSender` if the content was sent as `NewMemberCommit`,
    /// or with any error from computing the proposal reference.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn new<C: CipherSuiteProvider>(
        cs: &C,
        content: &AuthenticatedContent,
        proposal: Proposal,
    ) -> Result<Self, MlsError> {
        let authenticated_data = content.content.authenticated_data.clone();
        let sender: ProposalSender = content.content.sender.try_into()?;
        let proposal_ref = ProposalRef::from_content(cs, content).await?;

        Ok(Self {
            authenticated_data,
            proposal,
            sender,
            proposal_ref,
        })
    }
}
445
/// Empty stand-in used when the `by_ref_proposal` feature is disabled, so that
/// `From<ProposalMessageDescription>` bounds still resolve.
#[cfg(not(feature = "by_ref_proposal"))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone)]
pub struct ProposalMessageDescription {}
454
455#[allow(clippy::large_enum_variant)]
456pub(crate) enum EventOrContent<E> {
457 #[cfg_attr(
458 not(all(feature = "private_message", feature = "external_client")),
459 allow(dead_code)
460 )]
461 Event(E),
462 Content(AuthenticatedContent),
463}
464
465#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
466#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
467#[cfg_attr(
468 all(not(target_arch = "wasm32"), mls_build_async),
469 maybe_async::must_be_async
470)]
471pub(crate) trait MessageProcessor: Send + Sync {
472 type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
473 + From<CommitMessageDescription>
474 + From<ProposalMessageDescription>
475 + From<GroupInfo>
476 + From<Welcome>
477 + From<KeyPackage>
478 + Send;
479
480 type MlsRules: MlsRules;
481 type IdentityProvider: IdentityProvider;
482 type CipherSuiteProvider: CipherSuiteProvider;
483 type PreSharedKeyStorage: PreSharedKeyStorage;
484
485 async fn process_incoming_message(
486 &mut self,
487 message: MlsMessage,
488 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
489 ) -> Result<Self::OutputType, MlsError> {
490 self.process_incoming_message_with_time(
491 message,
492 #[cfg(feature = "by_ref_proposal")]
493 cache_proposal,
494 None,
495 )
496 .await
497 }
498
499 async fn process_incoming_message_with_time(
500 &mut self,
501 message: MlsMessage,
502 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
503 time_sent: Option<MlsTime>,
504 ) -> Result<Self::OutputType, MlsError> {
505 let event_or_content = self.get_event_from_incoming_message(message).await?;
506
507 self.process_event_or_content(
508 event_or_content,
509 #[cfg(feature = "by_ref_proposal")]
510 cache_proposal,
511 time_sent,
512 )
513 .await
514 }
515
516 async fn get_event_from_incoming_message(
517 &mut self,
518 message: MlsMessage,
519 ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
520 self.check_metadata(&message)?;
521
522 match message.payload {
523 MlsMessagePayload::Plain(plaintext) => {
524 self.verify_plaintext_authentication(plaintext).await
525 }
526 #[cfg(feature = "private_message")]
527 MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
528 MlsMessagePayload::GroupInfo(group_info) => {
529 validate_group_info_member(
530 self.group_state(),
531 message.version,
532 &group_info,
533 self.cipher_suite_provider(),
534 )
535 .await?;
536
537 Ok(EventOrContent::Event(group_info.into()))
538 }
539 MlsMessagePayload::Welcome(welcome) => {
540 self.validate_welcome(&welcome, message.version)?;
541
542 Ok(EventOrContent::Event(welcome.into()))
543 }
544 MlsMessagePayload::KeyPackage(key_package) => {
545 self.validate_key_package(&key_package, message.version)
546 .await?;
547
548 Ok(EventOrContent::Event(key_package.into()))
549 }
550 }
551 }
552
553 async fn process_event_or_content(
554 &mut self,
555 event_or_content: EventOrContent<Self::OutputType>,
556 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
557 time_sent: Option<MlsTime>,
558 ) -> Result<Self::OutputType, MlsError> {
559 let msg = match event_or_content {
560 EventOrContent::Event(event) => event,
561 EventOrContent::Content(content) => {
562 self.process_auth_content(
563 content,
564 #[cfg(feature = "by_ref_proposal")]
565 cache_proposal,
566 time_sent,
567 )
568 .await?
569 }
570 };
571
572 Ok(msg)
573 }
574
575 async fn process_auth_content(
576 &mut self,
577 auth_content: AuthenticatedContent,
578 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
579 time_sent: Option<MlsTime>,
580 ) -> Result<Self::OutputType, MlsError> {
581 let event = match auth_content.content.content {
582 #[cfg(feature = "private_message")]
583 Content::Application(data) => {
584 let authenticated_data = auth_content.content.authenticated_data;
585 let sender = auth_content.content.sender;
586
587 self.process_application_message(data, sender, authenticated_data)
588 .and_then(Self::OutputType::try_from)
589 }
590 Content::Commit(_) => self
591 .process_commit(auth_content, time_sent)
592 .await
593 .map(Self::OutputType::from),
594 #[cfg(feature = "by_ref_proposal")]
595 Content::Proposal(ref proposal) => self
596 .process_proposal(&auth_content, proposal, cache_proposal)
597 .await
598 .map(Self::OutputType::from),
599 }?;
600
601 Ok(event)
602 }
603
604 #[cfg(feature = "private_message")]
605 fn process_application_message(
606 &self,
607 data: ApplicationData,
608 sender: Sender,
609 authenticated_data: Vec<u8>,
610 ) -> Result<ApplicationMessageDescription, MlsError> {
611 let Sender::Member(sender_index) = sender else {
612 return Err(MlsError::InvalidSender);
613 };
614
615 Ok(ApplicationMessageDescription {
616 authenticated_data,
617 sender_index,
618 data,
619 })
620 }
621
622 #[cfg(feature = "by_ref_proposal")]
623 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
624 async fn process_proposal(
625 &mut self,
626 auth_content: &AuthenticatedContent,
627 proposal: &Proposal,
628 cache_proposal: bool,
629 ) -> Result<ProposalMessageDescription, MlsError> {
630 let proposal = ProposalMessageDescription::new(
631 self.cipher_suite_provider(),
632 auth_content,
633 proposal.clone(),
634 )
635 .await?;
636
637 let group_state = self.group_state_mut();
638
639 if cache_proposal {
640 group_state.proposals.insert(
641 proposal.proposal_ref.clone(),
642 proposal.proposal.clone(),
643 auth_content.content.sender,
644 );
645 }
646
647 Ok(proposal)
648 }
649
650 async fn process_commit(
651 &mut self,
652 auth_content: AuthenticatedContent,
653 time_sent: Option<MlsTime>,
654 ) -> Result<CommitMessageDescription, MlsError> {
655 if self.group_state().pending_reinit.is_some() {
656 return Err(MlsError::GroupUsedAfterReInit);
657 }
658
659 let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
661 self.cipher_suite_provider(),
662 &self.group_state().interim_transcript_hash,
663 &auth_content,
664 )
665 .await?;
666
667 #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
668 let commit = match auth_content.content.content {
669 Content::Commit(commit) => Ok(commit),
670 _ => Err(MlsError::UnexpectedMessageType),
671 }?;
672
673 #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
674 let Content::Commit(commit) = auth_content.content.content;
675
676 let group_state = self.group_state();
677 let id_provider = self.identity_provider();
678
679 #[cfg(feature = "by_ref_proposal")]
680 let proposals = group_state
681 .proposals
682 .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
683
684 #[cfg(not(feature = "by_ref_proposal"))]
685 let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
686
687 let mut provisional_state = group_state
688 .apply_resolved(
689 auth_content.content.sender,
690 proposals,
691 commit.path.as_ref().map(|path| &path.leaf_node),
692 &id_provider,
693 self.cipher_suite_provider(),
694 &self.psk_storage(),
695 &self.mls_rules(),
696 time_sent,
697 CommitDirection::Receive,
698 )
699 .await?;
700
701 let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
702
703 if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
706 return Err(MlsError::CommitMissingPath);
707 }
708
709 let self_removed = self.removal_proposal(&provisional_state);
710 let is_self_removed = self_removed.is_some();
711
712 let update_path = match commit.path {
713 Some(update_path) => Some(
714 validate_update_path(
715 &self.identity_provider(),
716 self.cipher_suite_provider(),
717 update_path,
718 &provisional_state,
719 sender,
720 time_sent,
721 &group_state.context,
722 )
723 .await?,
724 ),
725 None => None,
726 };
727
728 let commit_effect =
729 if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
730 self.group_state_mut().pending_reinit = Some(reinit.proposal.clone());
731 CommitEffect::ReInit(reinit)
732 } else if let Some(remove_proposal) = self_removed {
733 let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
734 CommitEffect::Removed {
735 remover: remove_proposal.sender,
736 new_epoch: Box::new(new_epoch),
737 }
738 } else {
739 CommitEffect::NewEpoch(Box::new(NewEpoch::new(
740 self.group_state().clone(),
741 &provisional_state,
742 )))
743 };
744
745 let new_secrets = match update_path {
746 Some(update_path) if !is_self_removed => {
747 self.apply_update_path(sender, &update_path, &mut provisional_state)
748 .await
749 }
750 _ => Ok(None),
751 }?;
752
753 provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
755
756 provisional_state
758 .public_tree
759 .update_hashes(&[sender], self.cipher_suite_provider())
760 .await?;
761
762 provisional_state.group_context.tree_hash = provisional_state
764 .public_tree
765 .tree_hash(self.cipher_suite_provider())
766 .await?;
767
768 if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
769 if !is_self_removed {
770 self.update_key_schedule(
772 new_secrets,
773 interim_transcript_hash,
774 confirmation_tag,
775 provisional_state,
776 )
777 .await?;
778 }
779
780 Ok(CommitMessageDescription {
781 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
782 authenticated_data: auth_content.content.authenticated_data,
783 committer: *sender,
784 effect: commit_effect,
785 })
786 } else {
787 Err(MlsError::InvalidConfirmationTag)
788 }
789 }
790
791 fn group_state(&self) -> &GroupState;
792 fn group_state_mut(&mut self) -> &mut GroupState;
793 fn mls_rules(&self) -> Self::MlsRules;
794 fn identity_provider(&self) -> Self::IdentityProvider;
795 fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
796 fn psk_storage(&self) -> Self::PreSharedKeyStorage;
797
798 fn removal_proposal(
799 &self,
800 provisional_state: &ProvisionalState,
801 ) -> Option<ProposalInfo<RemoveProposal>>;
802
803 #[cfg(feature = "private_message")]
804 fn min_epoch_available(&self) -> Option<u64>;
805
806 fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
807 let context = &self.group_state().context;
808
809 if message.version != context.protocol_version {
810 return Err(MlsError::ProtocolVersionMismatch);
811 }
812
813 if let Some((group_id, epoch, content_type)) = match &message.payload {
814 MlsMessagePayload::Plain(plaintext) => Some((
815 &plaintext.content.group_id,
816 plaintext.content.epoch,
817 plaintext.content.content_type(),
818 )),
819 #[cfg(feature = "private_message")]
820 MlsMessagePayload::Cipher(ciphertext) => Some((
821 &ciphertext.group_id,
822 ciphertext.epoch,
823 ciphertext.content_type,
824 )),
825 _ => None,
826 } {
827 if group_id != &context.group_id {
828 return Err(MlsError::GroupIdMismatch);
829 }
830
831 match content_type {
832 ContentType::Commit => {
833 if context.epoch != epoch {
834 Err(MlsError::InvalidEpoch)
835 } else {
836 Ok(())
837 }
838 }
839 #[cfg(feature = "by_ref_proposal")]
840 ContentType::Proposal => {
841 if context.epoch != epoch {
842 Err(MlsError::InvalidEpoch)
843 } else {
844 Ok(())
845 }
846 }
847 #[cfg(feature = "private_message")]
848 ContentType::Application => {
849 if let Some(min) = self.min_epoch_available() {
850 if epoch < min {
851 Err(MlsError::InvalidEpoch)
852 } else {
853 Ok(())
854 }
855 } else {
856 Ok(())
857 }
858 }
859 }?;
860
861 let check_epoch = content_type == ContentType::Commit;
863
864 #[cfg(feature = "by_ref_proposal")]
865 let check_epoch = check_epoch || content_type == ContentType::Proposal;
866
867 if check_epoch && epoch != context.epoch {
868 return Err(MlsError::InvalidEpoch);
869 }
870
871 #[cfg(feature = "private_message")]
873 if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
874 && content_type == ContentType::Application
875 {
876 return Err(MlsError::UnencryptedApplicationMessage);
877 }
878 }
879
880 Ok(())
881 }
882
883 fn validate_welcome(
884 &self,
885 welcome: &Welcome,
886 version: ProtocolVersion,
887 ) -> Result<(), MlsError> {
888 let state = self.group_state();
889
890 (welcome.cipher_suite == state.context.cipher_suite
891 && version == state.context.protocol_version)
892 .then_some(())
893 .ok_or(MlsError::InvalidWelcomeMessage)
894 }
895
896 async fn validate_key_package(
897 &self,
898 key_package: &KeyPackage,
899 version: ProtocolVersion,
900 ) -> Result<(), MlsError> {
901 let cs = self.cipher_suite_provider();
902 let id = self.identity_provider();
903
904 validate_key_package(key_package, version, cs, &id).await
905 }
906
907 #[cfg(feature = "private_message")]
908 async fn process_ciphertext(
909 &mut self,
910 cipher_text: &PrivateMessage,
911 ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
912
913 async fn verify_plaintext_authentication(
914 &self,
915 message: PublicMessage,
916 ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
917
918 async fn apply_update_path(
919 &mut self,
920 sender: LeafIndex,
921 update_path: &ValidatedUpdatePath,
922 provisional_state: &mut ProvisionalState,
923 ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
924 provisional_state
925 .public_tree
926 .apply_update_path(
927 sender,
928 update_path,
929 &provisional_state.group_context.extensions,
930 self.identity_provider(),
931 self.cipher_suite_provider(),
932 )
933 .await
934 .map(|_| None)
935 }
936
937 async fn update_key_schedule(
938 &mut self,
939 secrets: Option<(TreeKemPrivate, PathSecret)>,
940 interim_transcript_hash: InterimTranscriptHash,
941 confirmation_tag: &ConfirmationTag,
942 provisional_public_state: ProvisionalState,
943 ) -> Result<(), MlsError>;
944}
945
946#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
947pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
948 key_package: &KeyPackage,
949 version: ProtocolVersion,
950 cs: &C,
951 id: &I,
952) -> Result<(), MlsError> {
953 let validator = LeafNodeValidator::new(cs, id, MemberValidationContext::None);
954
955 #[cfg(feature = "std")]
956 let context = Some(MlsTime::now());
957
958 #[cfg(not(feature = "std"))]
959 let context = None;
960
961 let context = ValidationContext::Add(context);
962
963 validator
964 .check_if_valid(&key_package.leaf_node, context)
965 .await?;
966
967 validate_key_package_properties(key_package, version, cs).await?;
968
969 Ok(())
970}
971
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use mls_rs_codec::{MlsDecode, MlsEncode};

    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{test_utils::get_test_group_context, GroupState, Sender},
    };

    use super::{CommitEffect, NewEpoch};

    /// Round-trips the `NewEpoch` and `Removed` variants through the MLS codec.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn commit_effect_codec() {
        let prior_state = GroupState {
            #[cfg(feature = "by_ref_proposal")]
            proposals: crate::group::ProposalCache::new(TEST_PROTOCOL_VERSION, vec![]),
            context: get_test_group_context(7, 7.into()).await,
            public_tree: Default::default(),
            interim_transcript_hash: vec![].into(),
            pending_reinit: None,
            confirmation_tag: Default::default(),
        };

        let new_epoch = NewEpoch {
            epoch: 7,
            prior_state,
            applied_proposals: vec![],
            unused_proposals: vec![],
        };

        let effects = vec![
            CommitEffect::NewEpoch(new_epoch.clone().into()),
            CommitEffect::Removed {
                new_epoch: new_epoch.into(),
                remover: Sender::Member(0),
            },
        ];

        let encoded = effects.mls_encode_to_vec().unwrap();
        let decoded = Vec::<CommitEffect>::mls_decode(&mut encoded.as_slice()).unwrap();

        assert_eq!(effects, decoded);
    }
}
1016}