1#[cfg(all(
6 feature = "by_ref_proposal",
7 feature = "custom_proposal",
8 feature = "self_remove_proposal"
9))]
10use super::SelfRemoveProposal;
11use super::{
12 commit_sender,
13 confirmation_tag::ConfirmationTag,
14 framing::{
15 ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
16 },
17 message_signature::AuthenticatedContent,
18 mls_rules::{CommitDirection, MlsRules},
19 proposal_filter::ProposalBundle,
20 state::GroupState,
21 transcript_hash::InterimTranscriptHash,
22 transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal,
23 RemoveProposal, Welcome,
24};
25use crate::{
26 client::MlsError,
27 key_package::validate_key_package_properties,
28 time::MlsTime,
29 tree_kem::{
30 leaf_node_validator::{LeafNodeValidator, ValidationContext},
31 node::LeafIndex,
32 path_secret::PathSecret,
33 validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
34 },
35 CipherSuiteProvider, KeyPackage,
36};
37use itertools::Itertools;
38use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
39
40use alloc::boxed::Box;
41use alloc::vec::Vec;
42use core::fmt::{self, Debug};
43use mls_rs_core::{
44 identity::{IdentityProvider, MemberValidationContext},
45 protocol_version::ProtocolVersion,
46 psk::PreSharedKeyStorage,
47};
48
49#[cfg(feature = "by_ref_proposal")]
50use super::proposal_ref::ProposalRef;
51
52#[cfg(not(feature = "by_ref_proposal"))]
53use crate::group::proposal_cache::resolve_for_commit;
54
55use super::proposal::Proposal;
56use super::proposal_filter::ProposalInfo;
57
58#[cfg(feature = "private_message")]
59use crate::group::framing::PrivateMessage;
60
/// Group state computed from a commit's proposals before it is adopted.
///
/// Returned by `GroupState::apply_resolved` while a commit is processed; it is
/// read to build the [`NewEpoch`] description and is handed to
/// `MessageProcessor::update_key_schedule` when the new epoch is installed.
#[derive(Debug)]
pub(crate) struct ProvisionalState {
    /// Ratchet tree after the commit's proposals (and, later, its update path)
    /// have been applied.
    pub(crate) public_tree: TreeKemPublic,
    /// Proposals applied by the commit.
    pub(crate) applied_proposals: ProposalBundle,
    /// Group context for the new epoch.
    pub(crate) group_context: GroupContext,
    /// NOTE(review): presumably the leaf index of the joiner for commits that
    /// carry an ExternalInit proposal — confirm against `apply_resolved`.
    pub(crate) external_init_index: Option<LeafIndex>,
    /// NOTE(review): presumably the leaves assigned to key packages added by
    /// this commit — confirm against `apply_resolved`.
    pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
    /// Proposals not applied by this commit; copied into
    /// `NewEpoch::unused_proposals`.
    pub(crate) unused_proposals: Vec<ProposalInfo<Proposal>>,
}
70
71pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
82 let res = !proposals.external_init_proposals().is_empty();
83
84 #[cfg(feature = "by_ref_proposal")]
85 let res = res || !proposals.update_proposals().is_empty();
86
87 #[cfg(all(
88 feature = "by_ref_proposal",
89 feature = "custom_proposal",
90 feature = "self_remove_proposal"
91 ))]
92 let res = res || !proposals.self_removes.is_empty();
93
94 res || proposals.length() == 0
95 || proposals.group_context_extensions_proposal().is_some()
96 || !proposals.remove_proposals().is_empty()
97}
98
/// Description of the epoch a group entered after processing a commit.
///
/// The field order below is the MLS wire encoding produced by the derives.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
pub struct NewEpoch {
    /// Epoch number the group moved into.
    pub epoch: u64,
    /// Group state as it was before the commit was applied.
    pub prior_state: GroupState,
    /// Proposals applied by the commit.
    pub applied_proposals: Vec<ProposalInfo<Proposal>>,
    /// Proposals that were not applied by the commit.
    pub unused_proposals: Vec<ProposalInfo<Proposal>>,
}
107
108impl NewEpoch {
109 pub(crate) fn new(prior_state: GroupState, provisional_state: &ProvisionalState) -> NewEpoch {
110 NewEpoch {
111 epoch: provisional_state.group_context.epoch,
112 prior_state,
113 unused_proposals: provisional_state.unused_proposals.clone(),
114 applied_proposals: provisional_state
115 .applied_proposals
116 .clone()
117 .into_proposals()
118 .collect_vec(),
119 }
120 }
121}
122
123impl NewEpoch {
124 pub fn epoch(&self) -> u64 {
125 self.epoch
126 }
127
128 pub fn prior_state(&self) -> &GroupState {
129 &self.prior_state
130 }
131
132 pub fn applied_proposals(&self) -> &[ProposalInfo<Proposal>] {
133 &self.applied_proposals
134 }
135
136 pub fn unused_proposals(&self) -> &[ProposalInfo<Proposal>] {
137 &self.unused_proposals
138 }
139}
140
/// Effect a processed commit had on the local member.
///
/// Encoded by hand (`MlsSize`/`MlsEncode`/`MlsDecode` impls) with a one-byte
/// discriminant: 1 = `NewEpoch`, 2 = `Removed`, 3 = `ReInit`.
#[derive(Clone, Debug, PartialEq)]
pub enum CommitEffect {
    /// The group moved to a new epoch.
    NewEpoch(Box<NewEpoch>),
    /// The local member was removed by the commit.
    Removed {
        /// Description of the epoch created by the removing commit.
        new_epoch: Box<NewEpoch>,
        /// Sender of the proposal that removed the local member.
        remover: Sender,
    },
    /// The commit applied a ReInit proposal; the group now has a pending
    /// ReInit and further commits are rejected.
    ReInit(ProposalInfo<ReInitProposal>),
}
150
151impl MlsSize for CommitEffect {
152 fn mls_encoded_len(&self) -> usize {
153 0u8.mls_encoded_len()
154 + match self {
155 Self::NewEpoch(e) => e.mls_encoded_len(),
156 Self::Removed { new_epoch, remover } => {
157 new_epoch.mls_encoded_len() + remover.mls_encoded_len()
158 }
159 Self::ReInit(r) => r.mls_encoded_len(),
160 }
161 }
162}
163
164impl MlsEncode for CommitEffect {
165 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
166 match self {
167 Self::NewEpoch(e) => {
168 1u8.mls_encode(writer)?;
169 e.mls_encode(writer)?;
170 }
171 Self::Removed { new_epoch, remover } => {
172 2u8.mls_encode(writer)?;
173 new_epoch.mls_encode(writer)?;
174 remover.mls_encode(writer)?;
175 }
176 Self::ReInit(r) => {
177 3u8.mls_encode(writer)?;
178 r.mls_encode(writer)?;
179 }
180 }
181
182 Ok(())
183 }
184}
185
186impl MlsDecode for CommitEffect {
187 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
188 match u8::mls_decode(reader)? {
189 1u8 => Ok(Self::NewEpoch(NewEpoch::mls_decode(reader)?.into())),
190 2u8 => Ok(Self::Removed {
191 new_epoch: NewEpoch::mls_decode(reader)?.into(),
192 remover: Sender::mls_decode(reader)?,
193 }),
194 3u8 => Ok(Self::ReInit(ProposalInfo::mls_decode(reader)?)),
195 _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
196 }
197 }
198}
199
/// An incoming message after it has been processed.
#[derive(Debug, Clone)]
#[allow(clippy::large_enum_variant)]
pub enum ReceivedMessage {
    /// An application message with its sender and authenticated data.
    ApplicationMessage(ApplicationMessageDescription),
    /// A commit and the effect it had on the local member.
    Commit(CommitMessageDescription),
    /// A proposal message.
    Proposal(ProposalMessageDescription),
    /// A GroupInfo message.
    GroupInfo(GroupInfo),
    /// A Welcome message; its payload is not retained.
    Welcome,
    /// A KeyPackage message.
    KeyPackage(KeyPackage),
}
218
219impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
220 type Error = MlsError;
221
222 fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
223 Ok(ReceivedMessage::ApplicationMessage(value))
224 }
225}
226
227impl From<CommitMessageDescription> for ReceivedMessage {
228 fn from(value: CommitMessageDescription) -> Self {
229 ReceivedMessage::Commit(value)
230 }
231}
232
233impl From<ProposalMessageDescription> for ReceivedMessage {
234 fn from(value: ProposalMessageDescription) -> Self {
235 ReceivedMessage::Proposal(value)
236 }
237}
238
239impl From<GroupInfo> for ReceivedMessage {
240 fn from(value: GroupInfo) -> Self {
241 ReceivedMessage::GroupInfo(value)
242 }
243}
244
245impl From<Welcome> for ReceivedMessage {
246 fn from(_: Welcome) -> Self {
247 ReceivedMessage::Welcome
248 }
249}
250
251impl From<KeyPackage> for ReceivedMessage {
252 fn from(value: KeyPackage) -> Self {
253 ReceivedMessage::KeyPackage(value)
254 }
255}
256
/// Description of a received application message.
#[derive(Clone, PartialEq, Eq)]
pub struct ApplicationMessageDescription {
    /// Index of the member that sent the message (from `Sender::Member`).
    pub sender_index: u32,
    /// Application payload; exposed read-only through `data()`.
    data: ApplicationData,
    /// Authenticated data carried with the message.
    pub authenticated_data: Vec<u8>,
    /// Key generation read from the message's sender data before the message
    /// was authenticated; `None` when it could not be obtained.
    #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
    pub unauthenticated_key_generation: Option<u32>,
}
271
272impl Debug for ApplicationMessageDescription {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 let mut res = f.debug_struct("ApplicationMessageDescription");
275 res.field("sender_index", &self.sender_index)
276 .field("data", &self.data)
277 .field(
278 "authenticated_data",
279 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
280 );
281 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
282 res.field(
283 "unauthenticated_key_generation",
284 &self.unauthenticated_key_generation,
285 );
286 res.finish()
287 }
288}
289
290impl ApplicationMessageDescription {
291 pub fn data(&self) -> &[u8] {
292 self.data.as_bytes()
293 }
294}
295
/// Description of a processed commit.
///
/// The field order below is the MLS wire encoding produced by the derives.
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
pub struct CommitMessageDescription {
    /// `true` when the commit was sent as `Sender::NewMemberCommit`
    /// (an external commit).
    pub is_external: bool,
    /// Leaf index of the committer.
    pub committer: u32,
    /// Effect of the commit on the local member.
    pub effect: CommitEffect,
    /// Authenticated data carried with the commit.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
}
310
311impl Debug for CommitMessageDescription {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 f.debug_struct("CommitMessageDescription")
314 .field("is_external", &self.is_external)
315 .field("committer", &self.committer)
316 .field("effect", &self.effect)
317 .field(
318 "authenticated_data",
319 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
320 )
321 .finish()
322 }
323}
324
/// Sender of a proposal message.
///
/// Encoded with the explicit `u8` discriminants below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ProposalSender {
    /// A group member; corresponds to `Sender::Member`.
    Member(u32) = 1u8,
    /// An external sender; corresponds to `Sender::External`.
    /// NOTE(review): presumably an index into the group's external senders
    /// extension — confirm.
    External(u32) = 2u8,
    /// Corresponds to `Sender::NewMemberProposal`.
    NewMember = 3u8,
}
338
339impl TryFrom<Sender> for ProposalSender {
340 type Error = MlsError;
341
342 fn try_from(value: Sender) -> Result<Self, Self::Error> {
343 match value {
344 Sender::Member(index) => Ok(Self::Member(index)),
345 #[cfg(feature = "by_ref_proposal")]
346 Sender::External(index) => Ok(Self::External(index)),
347 #[cfg(feature = "by_ref_proposal")]
348 Sender::NewMemberProposal => Ok(Self::NewMember),
349 Sender::NewMemberCommit => Err(MlsError::InvalidSender),
350 }
351 }
352}
353
/// Description of a received proposal message.
///
/// The field order below is the MLS wire encoding produced by the derives.
#[cfg(feature = "by_ref_proposal")]
#[derive(Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ProposalMessageDescription {
    /// Sender of the proposal.
    pub sender: ProposalSender,
    /// Proposal content.
    pub proposal: Proposal,
    /// Authenticated data carried with the proposal.
    // Encoded as an opaque byte vector, as `CommitMessageDescription` does,
    // instead of element-by-element through the generic `Vec<T>` codec.
    // NOTE(review): both forms are expected to produce identical bytes
    // (varint length prefix followed by the raw bytes) — confirm with
    // persisted data before release.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
    /// Reference identifying this proposal.
    pub proposal_ref: ProposalRef,
}
369
#[cfg(feature = "by_ref_proposal")]
impl Debug for ProposalMessageDescription {
    /// Formats the description, rendering `authenticated_data` through
    /// `pretty_bytes` instead of as a raw byte list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            sender,
            proposal,
            authenticated_data,
            proposal_ref,
        } = self;

        f.debug_struct("ProposalMessageDescription")
            .field("sender", sender)
            .field("proposal", proposal)
            .field(
                "authenticated_data",
                &mls_rs_core::debug::pretty_bytes(authenticated_data),
            )
            .field("proposal_ref", proposal_ref)
            .finish()
    }
}
384
/// A received proposal in a form that can be persisted and restored later.
///
/// Produced by `ProposalMessageDescription::cached_proposal`; serialized with
/// `to_bytes` and restored with `from_bytes`. The field order below is the
/// MLS wire encoding produced by the derives.
// `Debug` and `Clone` are derived so this public type can be logged and
// duplicated like the other public message descriptions in this module.
#[cfg(feature = "by_ref_proposal")]
#[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode)]
pub struct CachedProposal {
    pub(crate) proposal: Proposal,
    pub(crate) proposal_ref: ProposalRef,
    pub(crate) sender: Sender,
}
392
#[cfg(feature = "by_ref_proposal")]
impl CachedProposal {
    /// Decodes a cached proposal from its MLS encoding, as produced by `to_bytes`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
        let mut reader = bytes;
        Self::mls_decode(&mut reader).map_err(MlsError::from)
    }

    /// Encodes this cached proposal with the MLS codec.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
        self.mls_encode_to_vec().map_err(MlsError::from)
    }

    /// The cached proposal content.
    pub fn proposal(&self) -> &Proposal {
        let Self { proposal, .. } = self;
        proposal
    }

    /// Reference identifying the cached proposal.
    pub fn proposal_ref(&self) -> &ProposalRef {
        let Self { proposal_ref, .. } = self;
        proposal_ref
    }

    /// Framing sender of the cached proposal.
    pub fn sender(&self) -> &Sender {
        let Self { sender, .. } = self;
        sender
    }
}
420
#[cfg(feature = "by_ref_proposal")]
impl ProposalMessageDescription {
    /// Converts this description into a [`CachedProposal`], mapping the
    /// proposal sender back to the framing [`Sender`] it came from.
    /// The authenticated data is not kept.
    pub fn cached_proposal(self) -> CachedProposal {
        let Self {
            sender,
            proposal,
            proposal_ref,
            ..
        } = self;

        let sender = match sender {
            ProposalSender::Member(index) => Sender::Member(index),
            ProposalSender::External(index) => Sender::External(index),
            ProposalSender::NewMember => Sender::NewMemberProposal,
        };

        CachedProposal {
            proposal,
            proposal_ref,
            sender,
        }
    }

    /// Proposal reference as an owned byte vector.
    pub fn proposal_ref(&self) -> Vec<u8> {
        self.proposal_ref.to_vec()
    }

    /// Builds a description of `proposal` from its authenticated content,
    /// computing the proposal reference with `cs`.
    ///
    /// Fails with `MlsError::InvalidSender` when the content was sent as
    /// `Sender::NewMemberCommit`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn new<C: CipherSuiteProvider>(
        cs: &C,
        content: &AuthenticatedContent,
        proposal: Proposal,
    ) -> Result<Self, MlsError> {
        // The sender is validated before the (hash-based) reference is computed.
        let sender = ProposalSender::try_from(content.content.sender)?;
        let proposal_ref = ProposalRef::from_content(cs, content).await?;
        let authenticated_data = content.content.authenticated_data.clone();

        Ok(Self {
            sender,
            proposal,
            authenticated_data,
            proposal_ref,
        })
    }
}
455
/// Field-less stand-in used when `by_ref_proposal` is disabled.
///
/// It keeps the `From<ProposalMessageDescription>` bound on
/// `MessageProcessor::OutputType` and the `ReceivedMessage::Proposal`
/// variant nameable in every feature configuration.
#[cfg(not(feature = "by_ref_proposal"))]
#[derive(Debug, Clone)]
pub struct ProposalMessageDescription {}
460
/// Result of the first processing stage of an incoming message.
///
/// `Event` carries a finished output (for example a GroupInfo, Welcome or
/// KeyPackage event built in `get_event_from_incoming_message`). `Content`
/// carries authenticated framed content that still has to go through
/// `MessageProcessor::process_auth_content`.
#[allow(clippy::large_enum_variant)]
pub(crate) enum EventOrContent<E> {
    #[cfg_attr(
        not(all(feature = "private_message", feature = "external_client")),
        allow(dead_code)
    )]
    Event(E),
    Content(AuthenticatedContent),
}
470
471#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
472#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
473#[cfg_attr(
474 all(not(target_arch = "wasm32"), mls_build_async),
475 maybe_async::must_be_async
476)]
477pub(crate) trait MessageProcessor: Send + Sync {
478 type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
479 + From<CommitMessageDescription>
480 + From<ProposalMessageDescription>
481 + From<GroupInfo>
482 + From<Welcome>
483 + From<KeyPackage>
484 + Send;
485
486 type MlsRules: MlsRules;
487 type IdentityProvider: IdentityProvider;
488 type CipherSuiteProvider: CipherSuiteProvider;
489 type PreSharedKeyStorage: PreSharedKeyStorage;
490
491 async fn process_incoming_message(
492 &mut self,
493 message: MlsMessage,
494 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
495 ) -> Result<Self::OutputType, MlsError> {
496 self.process_incoming_message_with_time(
497 message,
498 #[cfg(feature = "by_ref_proposal")]
499 cache_proposal,
500 None,
501 )
502 .await
503 }
504
505 async fn process_incoming_message_with_time(
506 &mut self,
507 message: MlsMessage,
508 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
509 time_sent: Option<MlsTime>,
510 ) -> Result<Self::OutputType, MlsError> {
511 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
512 let unauthn_key_gen_in_app_msg: Option<u32> = match message.payload {
518 MlsMessagePayload::Cipher(ref cipher_text) => self
519 .get_unauthenticated_key_generation_from_sender_data(cipher_text)
520 .unwrap_or_default(),
521 _ => None,
522 };
523
524 let event_or_content = self
525 .get_event_from_incoming_message(message, time_sent)
526 .await?;
527
528 self.process_event_or_content(
529 event_or_content,
530 #[cfg(feature = "by_ref_proposal")]
531 cache_proposal,
532 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
533 unauthn_key_gen_in_app_msg,
534 time_sent,
535 )
536 .await
537 }
538
539 async fn get_event_from_incoming_message(
540 &mut self,
541 message: MlsMessage,
542 time: Option<MlsTime>,
543 ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
544 self.check_metadata(&message)?;
545
546 match message.payload {
547 MlsMessagePayload::Plain(plaintext) => {
548 self.verify_plaintext_authentication(plaintext).await
549 }
550 #[cfg(feature = "private_message")]
551 MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
552 MlsMessagePayload::GroupInfo(group_info) => {
553 validate_group_info_member(
554 self.group_state(),
555 message.version,
556 &group_info,
557 self.cipher_suite_provider(),
558 )
559 .await?;
560
561 Ok(EventOrContent::Event(group_info.into()))
562 }
563 MlsMessagePayload::Welcome(welcome) => {
564 self.validate_welcome(&welcome, message.version)?;
565
566 Ok(EventOrContent::Event(welcome.into()))
567 }
568 MlsMessagePayload::KeyPackage(key_package) => {
569 self.validate_key_package(&key_package, message.version, time)
570 .await?;
571
572 Ok(EventOrContent::Event(key_package.into()))
573 }
574 }
575 }
576
577 async fn process_event_or_content(
578 &mut self,
579 event_or_content: EventOrContent<Self::OutputType>,
580 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
581 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
582 unauthn_key_gen_in_app_msg: Option<u32>,
583 time_sent: Option<MlsTime>,
584 ) -> Result<Self::OutputType, MlsError> {
585 let msg = match event_or_content {
586 EventOrContent::Event(event) => event,
587 EventOrContent::Content(content) => {
588 self.process_auth_content(
589 content,
590 #[cfg(feature = "by_ref_proposal")]
591 cache_proposal,
592 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
593 unauthn_key_gen_in_app_msg,
594 time_sent,
595 )
596 .await?
597 }
598 };
599
600 Ok(msg)
601 }
602
603 async fn process_auth_content(
604 &mut self,
605 auth_content: AuthenticatedContent,
606 #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
607 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
608 unauthn_key_gen_in_app_msg: Option<u32>,
609 time_sent: Option<MlsTime>,
610 ) -> Result<Self::OutputType, MlsError> {
611 let event = match auth_content.content.content {
612 #[cfg(feature = "private_message")]
613 Content::Application(data) => {
614 let authenticated_data = auth_content.content.authenticated_data;
615 let sender = auth_content.content.sender;
616
617 self.process_application_message(
618 data,
619 sender,
620 authenticated_data,
621 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
622 unauthn_key_gen_in_app_msg,
623 )
624 .and_then(Self::OutputType::try_from)
625 }
626 Content::Commit(_) => self
627 .process_commit(auth_content, time_sent)
628 .await
629 .map(Self::OutputType::from),
630 #[cfg(feature = "by_ref_proposal")]
631 Content::Proposal(ref proposal) => self
632 .process_proposal(&auth_content, proposal, cache_proposal)
633 .await
634 .map(Self::OutputType::from),
635 }?;
636
637 Ok(event)
638 }
639
640 #[cfg(feature = "private_message")]
641 fn process_application_message(
642 &self,
643 data: ApplicationData,
644 sender: Sender,
645 authenticated_data: Vec<u8>,
646 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
647 unauthenticated_key_generation: Option<u32>,
648 ) -> Result<ApplicationMessageDescription, MlsError> {
649 let Sender::Member(sender_index) = sender else {
650 return Err(MlsError::InvalidSender);
651 };
652
653 Ok(ApplicationMessageDescription {
654 authenticated_data,
655 sender_index,
656 data,
657 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
658 unauthenticated_key_generation,
659 })
660 }
661
662 #[cfg(feature = "by_ref_proposal")]
663 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
664 async fn process_proposal(
665 &mut self,
666 auth_content: &AuthenticatedContent,
667 proposal: &Proposal,
668 cache_proposal: bool,
669 ) -> Result<ProposalMessageDescription, MlsError> {
670 let proposal = ProposalMessageDescription::new(
671 self.cipher_suite_provider(),
672 auth_content,
673 proposal.clone(),
674 )
675 .await?;
676
677 let group_state = self.group_state_mut();
678
679 if cache_proposal {
680 group_state.proposals.insert(
681 proposal.proposal_ref.clone(),
682 proposal.proposal.clone(),
683 auth_content.content.sender,
684 );
685 }
686
687 Ok(proposal)
688 }
689
690 async fn process_commit(
691 &mut self,
692 auth_content: AuthenticatedContent,
693 time_sent: Option<MlsTime>,
694 ) -> Result<CommitMessageDescription, MlsError> {
695 if self.group_state().pending_reinit.is_some() {
696 return Err(MlsError::GroupUsedAfterReInit);
697 }
698
699 let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
701 self.cipher_suite_provider(),
702 &self.group_state().interim_transcript_hash,
703 &auth_content,
704 )
705 .await?;
706
707 #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
708 let commit = match auth_content.content.content {
709 Content::Commit(commit) => Ok(commit),
710 _ => Err(MlsError::UnexpectedMessageType),
711 }?;
712
713 #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
714 let Content::Commit(commit) = auth_content.content.content;
715
716 let group_state = self.group_state();
717 let id_provider = self.identity_provider();
718
719 #[cfg(feature = "by_ref_proposal")]
720 let proposals = group_state
721 .proposals
722 .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
723
724 #[cfg(not(feature = "by_ref_proposal"))]
725 let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
726
727 let mut provisional_state = group_state
728 .apply_resolved(
729 auth_content.content.sender,
730 proposals,
731 commit.path.as_ref().map(|path| &path.leaf_node),
732 &id_provider,
733 self.cipher_suite_provider(),
734 &self.psk_storage(),
735 &self.mls_rules(),
736 time_sent,
737 CommitDirection::Receive,
738 )
739 .await?;
740
741 let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
742
743 if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
746 return Err(MlsError::CommitMissingPath);
747 }
748
749 let self_removed = self.removal_proposal(&provisional_state);
750 #[cfg(all(
751 feature = "by_ref_proposal",
752 feature = "custom_proposal",
753 feature = "self_remove_proposal"
754 ))]
755 let self_removed_by_self = self.self_removal_proposal(&provisional_state);
756
757 let is_self_removed = self_removed.is_some();
758 #[cfg(all(
759 feature = "by_ref_proposal",
760 feature = "custom_proposal",
761 feature = "self_remove_proposal"
762 ))]
763 let is_self_removed = is_self_removed || self_removed_by_self.is_some();
764
765 let update_path = match commit.path {
766 Some(update_path) => Some(
767 validate_update_path(
768 &self.identity_provider(),
769 self.cipher_suite_provider(),
770 update_path,
771 &provisional_state,
772 sender,
773 time_sent,
774 &group_state.context,
775 )
776 .await?,
777 ),
778 None => None,
779 };
780
781 let commit_effect =
782 if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
783 self.group_state_mut().pending_reinit = Some(reinit.proposal.clone());
784 CommitEffect::ReInit(reinit)
785 } else if let Some(remove_proposal) = self_removed {
786 let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
787 CommitEffect::Removed {
788 remover: remove_proposal.sender,
789 new_epoch: Box::new(new_epoch),
790 }
791 } else {
792 CommitEffect::NewEpoch(Box::new(NewEpoch::new(
793 self.group_state().clone(),
794 &provisional_state,
795 )))
796 };
797
798 #[cfg(all(
799 feature = "by_ref_proposal",
800 feature = "custom_proposal",
801 feature = "self_remove_proposal"
802 ))]
803 let commit_effect = if let Some(self_remove_proposal) = self_removed_by_self {
804 let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
805 CommitEffect::Removed {
806 remover: self_remove_proposal.sender,
807 new_epoch: Box::new(new_epoch),
808 }
809 } else {
810 commit_effect
811 };
812
813 let new_secrets = match update_path {
814 Some(update_path) if !is_self_removed => {
815 self.apply_update_path(sender, &update_path, &mut provisional_state)
816 .await
817 }
818 _ => Ok(None),
819 }?;
820
821 provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
823
824 provisional_state
826 .public_tree
827 .update_hashes(&[sender], self.cipher_suite_provider())
828 .await?;
829
830 provisional_state.group_context.tree_hash = provisional_state
832 .public_tree
833 .tree_hash(self.cipher_suite_provider())
834 .await?;
835
836 if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
837 if !is_self_removed {
838 self.update_key_schedule(
840 new_secrets,
841 interim_transcript_hash,
842 confirmation_tag,
843 provisional_state,
844 )
845 .await?;
846 }
847 Ok(CommitMessageDescription {
848 is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
849 authenticated_data: auth_content.content.authenticated_data,
850 committer: *sender,
851 effect: commit_effect,
852 })
853 } else {
854 Err(MlsError::InvalidConfirmationTag)
855 }
856 }
857
858 fn group_state(&self) -> &GroupState;
859 fn group_state_mut(&mut self) -> &mut GroupState;
860 fn mls_rules(&self) -> Self::MlsRules;
861 fn identity_provider(&self) -> Self::IdentityProvider;
862 fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
863 fn psk_storage(&self) -> Self::PreSharedKeyStorage;
864
865 fn removal_proposal(
866 &self,
867 provisional_state: &ProvisionalState,
868 ) -> Option<ProposalInfo<RemoveProposal>>;
869
870 #[cfg(all(
871 feature = "by_ref_proposal",
872 feature = "custom_proposal",
873 feature = "self_remove_proposal"
874 ))]
875 fn self_removal_proposal(
876 &self,
877 provisional_state: &ProvisionalState,
878 ) -> Option<ProposalInfo<SelfRemoveProposal>>;
879
880 #[cfg(feature = "private_message")]
881 fn min_epoch_available(&self) -> Option<u64>;
882
883 fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
884 let context = &self.group_state().context;
885
886 if message.version != context.protocol_version {
887 return Err(MlsError::ProtocolVersionMismatch);
888 }
889
890 if let Some((group_id, epoch, content_type)) = match &message.payload {
891 MlsMessagePayload::Plain(plaintext) => Some((
892 &plaintext.content.group_id,
893 plaintext.content.epoch,
894 plaintext.content.content_type(),
895 )),
896 #[cfg(feature = "private_message")]
897 MlsMessagePayload::Cipher(ciphertext) => Some((
898 &ciphertext.group_id,
899 ciphertext.epoch,
900 ciphertext.content_type,
901 )),
902 _ => None,
903 } {
904 if group_id != &context.group_id {
905 return Err(MlsError::GroupIdMismatch);
906 }
907
908 match content_type {
909 ContentType::Commit => {
910 if context.epoch != epoch {
911 Err(MlsError::InvalidEpoch)
912 } else {
913 Ok(())
914 }
915 }
916 #[cfg(feature = "by_ref_proposal")]
917 ContentType::Proposal => {
918 if context.epoch != epoch {
919 Err(MlsError::InvalidEpoch)
920 } else {
921 Ok(())
922 }
923 }
924 #[cfg(feature = "private_message")]
925 ContentType::Application => {
926 if let Some(min) = self.min_epoch_available() {
927 if epoch < min {
928 Err(MlsError::InvalidEpoch)
929 } else {
930 Ok(())
931 }
932 } else {
933 Ok(())
934 }
935 }
936 }?;
937
938 let check_epoch = content_type == ContentType::Commit;
940
941 #[cfg(feature = "by_ref_proposal")]
942 let check_epoch = check_epoch || content_type == ContentType::Proposal;
943
944 if check_epoch && epoch != context.epoch {
945 return Err(MlsError::InvalidEpoch);
946 }
947
948 #[cfg(feature = "private_message")]
950 if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
951 && content_type == ContentType::Application
952 {
953 return Err(MlsError::UnencryptedApplicationMessage);
954 }
955 }
956
957 Ok(())
958 }
959
960 fn validate_welcome(
961 &self,
962 welcome: &Welcome,
963 version: ProtocolVersion,
964 ) -> Result<(), MlsError> {
965 let state = self.group_state();
966
967 (welcome.cipher_suite == state.context.cipher_suite
968 && version == state.context.protocol_version)
969 .then_some(())
970 .ok_or(MlsError::InvalidWelcomeMessage)
971 }
972
973 async fn validate_key_package(
974 &self,
975 key_package: &KeyPackage,
976 version: ProtocolVersion,
977 time: Option<MlsTime>,
978 ) -> Result<(), MlsError> {
979 let cs = self.cipher_suite_provider();
980 let id = self.identity_provider();
981
982 validate_key_package(key_package, version, cs, &id, time).await
983 }
984
985 #[cfg(feature = "private_message")]
986 async fn process_ciphertext(
987 &mut self,
988 cipher_text: &PrivateMessage,
989 ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
990
991 async fn verify_plaintext_authentication(
992 &self,
993 message: PublicMessage,
994 ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
995
996 #[cfg(all(feature = "export_key_generation", feature = "private_message"))]
997 async fn get_unauthenticated_key_generation_from_sender_data(
999 &mut self,
1000 cipher_text: &PrivateMessage,
1001 ) -> Result<Option<u32>, MlsError>;
1002
1003 async fn apply_update_path(
1004 &mut self,
1005 sender: LeafIndex,
1006 update_path: &ValidatedUpdatePath,
1007 provisional_state: &mut ProvisionalState,
1008 ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
1009 provisional_state
1010 .public_tree
1011 .apply_update_path(
1012 sender,
1013 update_path,
1014 &provisional_state.group_context.extensions,
1015 self.identity_provider(),
1016 self.cipher_suite_provider(),
1017 )
1018 .await
1019 .map(|_| None)
1020 }
1021
1022 async fn update_key_schedule(
1023 &mut self,
1024 secrets: Option<(TreeKemPrivate, PathSecret)>,
1025 interim_transcript_hash: InterimTranscriptHash,
1026 confirmation_tag: &ConfirmationTag,
1027 provisional_public_state: ProvisionalState,
1028 ) -> Result<(), MlsError>;
1029}
1030
1031#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1032pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
1033 key_package: &KeyPackage,
1034 version: ProtocolVersion,
1035 cs: &C,
1036 id: &I,
1037 time: Option<MlsTime>,
1038) -> Result<(), MlsError> {
1039 let validator = LeafNodeValidator::new(cs, id, MemberValidationContext::None);
1040
1041 #[cfg(feature = "std")]
1042 let context = Some(MlsTime::now());
1043
1044 #[cfg(not(feature = "std"))]
1045 let context = None;
1046
1047 let context = if time.is_some() { time } else { context };
1048
1049 let context = ValidationContext::Add(context);
1050
1051 validator
1052 .check_if_valid(&key_package.leaf_node, context)
1053 .await?;
1054
1055 validate_key_package_properties(key_package, version, cs).await?;
1056
1057 Ok(())
1058}
1059
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use mls_rs_codec::{MlsDecode, MlsEncode};

    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{test_utils::get_test_group_context, GroupState, Sender},
    };

    use super::{CommitEffect, NewEpoch};

    /// Encoding then decoding a list of commit effects yields the same list.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn commit_effect_codec() {
        let epoch = NewEpoch {
            epoch: 7,
            prior_state: GroupState {
                #[cfg(feature = "by_ref_proposal")]
                proposals: crate::group::ProposalCache::new(TEST_PROTOCOL_VERSION, vec![]),
                context: get_test_group_context(7, 7.into()).await,
                public_tree: Default::default(),
                interim_transcript_hash: vec![].into(),
                pending_reinit: None,
                confirmation_tag: Default::default(),
            },
            applied_proposals: vec![],
            unused_proposals: vec![],
        };

        let effects = vec![
            CommitEffect::NewEpoch(epoch.clone().into()),
            CommitEffect::Removed {
                new_epoch: epoch.into(),
                remover: Sender::Member(0),
            },
        ];

        let bytes = effects.mls_encode_to_vec().unwrap();

        assert_eq!(
            effects,
            Vec::<CommitEffect>::mls_decode(&mut &*bytes).unwrap()
        );
    }

    /// The hand-written decoder only accepts discriminants 1, 2 and 3.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn commit_effect_decode_rejects_unknown_discriminant() {
        for discriminant in [0u8, 4u8, u8::MAX] {
            let bytes = [discriminant];
            let res = CommitEffect::mls_decode(&mut &bytes[..]);

            assert!(matches!(
                res,
                Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant)
            ));
        }
    }
}