mls_rs/group/
framing.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::ops::Deref;
6
7use crate::{client::MlsError, tree_kem::node::LeafIndex, KeyPackage, KeyPackageRef};
8
9use super::{Commit, FramedContentAuthData, GroupInfo, MembershipTag, Welcome};
10
11use crate::group::proposal::{Proposal, ProposalOrRef};
12
13#[cfg(feature = "by_ref_proposal")]
14use crate::mls_rules::ProposalRef;
15
16use alloc::vec::Vec;
17use core::fmt::{self, Debug};
18use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
19use mls_rs_core::{
20    crypto::{CipherSuite, CipherSuiteProvider},
21    protocol_version::ProtocolVersion,
22};
23use zeroize::ZeroizeOnDrop;
24
25#[cfg(feature = "private_message")]
26use alloc::boxed::Box;
27
28#[cfg(feature = "custom_proposal")]
29use crate::group::proposal::CustomProposal;
30
/// Type of the content carried by a protocol message.
///
/// The explicit discriminants match the `ContentType` values defined by
/// RFC 9420. `Application` and `Proposal` only exist when the corresponding
/// crate features are enabled.
#[derive(Copy, Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
pub enum ContentType {
    /// Application data (requires the `private_message` feature).
    #[cfg(feature = "private_message")]
    Application = 1u8,
    /// A standalone proposal message (requires the `by_ref_proposal` feature).
    #[cfg(feature = "by_ref_proposal")]
    Proposal = 2u8,
    /// A commit.
    Commit = 3u8,
}
41
42impl From<&Content> for ContentType {
43    fn from(content: &Content) -> Self {
44        match content {
45            #[cfg(feature = "private_message")]
46            Content::Application(_) => ContentType::Application,
47            #[cfg(feature = "by_ref_proposal")]
48            Content::Proposal(_) => ContentType::Proposal,
49            Content::Commit(_) => ContentType::Commit,
50        }
51    }
52}
53
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
#[non_exhaustive]
/// Description of a [`MlsMessage`] sender.
///
/// The explicit discriminants match the `SenderType` values defined by
/// RFC 9420.
pub enum Sender {
    /// Current group member, identified by its leaf index.
    Member(u32) = 1u8,
    /// An external entity sending a proposal, identified by an index
    /// in the current
    /// [`ExternalSendersExt`](crate::extension::ExternalSendersExt) stored in
    /// group context extensions.
    #[cfg(feature = "by_ref_proposal")]
    External(u32) = 2u8,
    /// A new member proposing their own addition to the group.
    #[cfg(feature = "by_ref_proposal")]
    NewMemberProposal = 3u8,
    /// A new member sending an external commit.
    NewMemberCommit = 4u8,
}
79
80impl From<LeafIndex> for Sender {
81    fn from(leaf_index: LeafIndex) -> Self {
82        Sender::Member(*leaf_index)
83    }
84}
85
86impl From<u32> for Sender {
87    fn from(leaf_index: u32) -> Self {
88        Sender::Member(leaf_index)
89    }
90}
91
/// Application payload carried by `Content::Application` messages.
///
/// Wraps raw bytes that are zeroized when the value is dropped
/// (`ZeroizeOnDrop`). The bytes are encoded as a length-prefixed byte vector
/// (`mls_rs_codec::byte_vec`).
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, ZeroizeOnDrop)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationData(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
100
101impl Debug for ApplicationData {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        mls_rs_core::debug::pretty_bytes(&self.0)
104            .named("ApplicationData")
105            .fmt(f)
106    }
107}
108
109impl From<Vec<u8>> for ApplicationData {
110    fn from(data: Vec<u8>) -> Self {
111        Self(data)
112    }
113}
114
115impl Deref for ApplicationData {
116    type Target = [u8];
117
118    fn deref(&self) -> &Self::Target {
119        &self.0
120    }
121}
122
123impl ApplicationData {
124    /// Underlying message content.
125    pub fn as_bytes(&self) -> &[u8] {
126        &self.0
127    }
128}
129
/// Body of a protocol message, one variant per [`ContentType`].
///
/// The explicit discriminants equal the corresponding [`ContentType`] values.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub(crate) enum Content {
    /// Application payload.
    #[cfg(feature = "private_message")]
    Application(ApplicationData) = 1u8,
    /// A standalone proposal.
    #[cfg(feature = "by_ref_proposal")]
    Proposal(alloc::boxed::Box<Proposal>) = 2u8,
    /// A commit.
    Commit(alloc::boxed::Box<Commit>) = 3u8,
}
141
142impl Content {
143    pub fn content_type(&self) -> ContentType {
144        self.into()
145    }
146}
147
/// A plaintext (public) MLS protocol message.
///
/// The codec impls for this type are written by hand because the membership
/// tag is only read from the wire when the sender is [`Sender::Member`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub(crate) struct PublicMessage {
    /// The framed content being sent.
    pub content: FramedContent,
    /// Authentication data (signature and optional confirmation tag).
    pub auth: FramedContentAuthData,
    /// Membership MAC. Decoding sets it only for member senders; encoding
    /// writes it whenever it is `Some`.
    pub membership_tag: Option<MembershipTag>,
}
155
156impl MlsSize for PublicMessage {
157    fn mls_encoded_len(&self) -> usize {
158        self.content.mls_encoded_len()
159            + self.auth.mls_encoded_len()
160            + self
161                .membership_tag
162                .as_ref()
163                .map_or(0, |tag| tag.mls_encoded_len())
164    }
165}
166
167impl MlsEncode for PublicMessage {
168    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
169        self.content.mls_encode(writer)?;
170        self.auth.mls_encode(writer)?;
171
172        self.membership_tag
173            .as_ref()
174            .map_or(Ok(()), |tag| tag.mls_encode(writer))
175    }
176}
177
178impl MlsDecode for PublicMessage {
179    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
180        let content = FramedContent::mls_decode(reader)?;
181        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
182
183        let membership_tag = match content.sender {
184            Sender::Member(_) => Some(MembershipTag::mls_decode(reader)?),
185            _ => None,
186        };
187
188        Ok(Self {
189            content,
190            auth,
191            membership_tag,
192        })
193    }
194}
195
/// Plaintext content of a [`PrivateMessage`]: the content and its
/// authentication data.
///
/// The codec impls are written by hand because the content is encoded without
/// a type tag. Decoding therefore needs the content type supplied separately.
#[cfg(feature = "private_message")]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct PrivateMessageContent {
    pub content: Content,
    pub auth: FramedContentAuthData,
}
202
#[cfg(feature = "private_message")]
impl MlsSize for PrivateMessageContent {
    fn mls_encoded_len(&self) -> usize {
        // Only the variant body is counted. The content type is not part of
        // this encoding (`mls_decode` receives it as an argument).
        let body_len = match &self.content {
            Content::Application(data) => data.mls_encoded_len(),
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(proposal) => proposal.mls_encoded_len(),
            Content::Commit(commit) => commit.mls_encoded_len(),
        };

        body_len + self.auth.mls_encoded_len()
    }
}

#[cfg(feature = "private_message")]
impl MlsEncode for PrivateMessageContent {
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        // Write the variant body with no content-type prefix, then the auth data.
        match &self.content {
            Content::Application(data) => data.mls_encode(writer)?,
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(proposal) => proposal.mls_encode(writer)?,
            Content::Commit(commit) => commit.mls_encode(writer)?,
        }

        self.auth.mls_encode(writer)
    }
}
232
#[cfg(feature = "private_message")]
impl PrivateMessageContent {
    /// Decodes the plaintext content of a private message.
    ///
    /// The encoded content carries no type tag, so `content_type` must be
    /// supplied by the caller. It is presumably the `content_type` field of the
    /// enclosing [`PrivateMessage`]. After the content and its authentication
    /// data, every remaining byte in `reader` is treated as padding and must be
    /// zero.
    ///
    /// # Errors
    ///
    /// Returns any error from decoding the content or the auth data. Returns
    /// `mls_rs_codec::Error::Custom(5)` if the trailing padding contains a
    /// non-zero byte.
    pub(crate) fn mls_decode(
        reader: &mut &[u8],
        content_type: ContentType,
    ) -> Result<Self, mls_rs_codec::Error> {
        let content = match content_type {
            ContentType::Application => Content::Application(ApplicationData::mls_decode(reader)?),
            #[cfg(feature = "by_ref_proposal")]
            ContentType::Proposal => Content::Proposal(Box::new(Proposal::mls_decode(reader)?)),
            ContentType::Commit => Content::Commit(Box::new(Commit::mls_decode(reader)?)),
        };

        let auth = FramedContentAuthData::mls_decode(reader, content_type)?;

        // Whatever is left is padding and must be all zeros. `Error::Custom`
        // only carries a numeric code; 5 identifies "non-zero padding" and is
        // kept stable for callers that match on it.
        if reader.iter().any(|&byte| byte != 0) {
            return Err(mls_rs_codec::Error::Custom(5));
        }

        Ok(Self { content, auth })
    }
}
263
/// Header fields of a [`PrivateMessage`]: group ID, epoch, content type and
/// authenticated data.
///
/// It can be built from a message via `From<&PrivateMessage>`. As the name
/// suggests, it is presumably used as the AEAD associated data for the
/// encrypted content; the encryption code is not in this file, so confirm there.
#[cfg(feature = "private_message")]
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub struct PrivateContentAAD {
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub group_id: Vec<u8>,
    pub epoch: u64,
    pub content_type: ContentType,
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
}
274
#[cfg(feature = "private_message")]
impl Debug for PrivateContentAAD {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};

        let Self {
            group_id,
            epoch,
            content_type,
            authenticated_data,
        } = self;

        f.debug_struct("PrivateContentAAD")
            .field("group_id", &pretty_group_id(group_id))
            .field("epoch", epoch)
            .field("content_type", content_type)
            .field("authenticated_data", &pretty_bytes(authenticated_data))
            .finish()
    }
}
292
/// An encrypted (private) MLS protocol message.
///
/// The group ID, epoch, content type and authenticated data are readable
/// without decryption (for example by `MlsMessage::description`). The sender
/// data and the content are stored as opaque encrypted byte strings.
#[cfg(feature = "private_message")]
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct PrivateMessage {
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub group_id: Vec<u8>,
    pub epoch: u64,
    pub content_type: ContentType,
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub encrypted_sender_data: Vec<u8>,
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub ciphertext: Vec<u8>,
}
308
#[cfg(feature = "private_message")]
impl Debug for PrivateMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};

        // Byte fields go through mls_rs_core's pretty printers rather than
        // the default Vec<u8> Debug output.
        let mut out = f.debug_struct("PrivateMessage");

        out.field("group_id", &pretty_group_id(&self.group_id));
        out.field("epoch", &self.epoch);
        out.field("content_type", &self.content_type);
        out.field("authenticated_data", &pretty_bytes(&self.authenticated_data));
        out.field(
            "encrypted_sender_data",
            &pretty_bytes(&self.encrypted_sender_data),
        );
        out.field("ciphertext", &pretty_bytes(&self.ciphertext));

        out.finish()
    }
}
334
#[cfg(feature = "private_message")]
impl From<&PrivateMessage> for PrivateContentAAD {
    /// Copies the header fields of `ciphertext`; the encrypted fields are ignored.
    fn from(ciphertext: &PrivateMessage) -> Self {
        let PrivateMessage {
            group_id,
            epoch,
            content_type,
            authenticated_data,
            ..
        } = ciphertext;

        PrivateContentAAD {
            group_id: group_id.clone(),
            epoch: *epoch,
            content_type: *content_type,
            authenticated_data: authenticated_data.clone(),
        }
    }
}
346
/// Lightweight summary of an [`MlsMessage`], borrowing from it where possible.
///
/// Produced by [`MlsMessage::description`] without decrypting or validating
/// the message.
#[derive(Clone, Debug, PartialEq)]
pub enum MlsMessageDescription<'a> {
    /// A Welcome message.
    Welcome {
        /// References of the key packages whose owners the Welcome is addressed to.
        key_package_refs: Vec<&'a KeyPackageRef>,
        /// Cipher suite of the Welcome.
        cipher_suite: CipherSuite,
    },
    /// An encrypted protocol message; only its unencrypted header is exposed.
    PrivateProtocolMessage {
        /// ID of the group the message belongs to.
        group_id: &'a [u8],
        /// Epoch the message belongs to.
        epoch_id: u64,
        /// Type of the encrypted content.
        content_type: ContentType, // commit, proposal, or application
    },
    /// A plaintext protocol message.
    PublicProtocolMessage {
        /// ID of the group the message belongs to.
        group_id: &'a [u8],
        /// Epoch the message belongs to.
        epoch_id: u64,
        /// Type of the content.
        content_type: ContentType,
        /// Sender of the message.
        sender: Sender,
        /// Authenticated data attached by the sender.
        authenticated_data: &'a [u8],
    },
    /// A GroupInfo message.
    GroupInfo,
    /// A KeyPackage message.
    KeyPackage,
}
368
369impl MlsMessage {
370    pub fn description(&self) -> MlsMessageDescription<'_> {
371        match &self.payload {
372            MlsMessagePayload::Welcome(w) => MlsMessageDescription::Welcome {
373                key_package_refs: w.secrets.iter().map(|s| &s.new_member).collect(),
374                cipher_suite: w.cipher_suite,
375            },
376            MlsMessagePayload::Plain(p) => MlsMessageDescription::PublicProtocolMessage {
377                group_id: &p.content.group_id,
378                epoch_id: p.content.epoch,
379                content_type: p.content.content_type(),
380                sender: p.content.sender,
381                authenticated_data: &p.content.authenticated_data,
382            },
383            #[cfg(feature = "private_message")]
384            MlsMessagePayload::Cipher(c) => MlsMessageDescription::PrivateProtocolMessage {
385                group_id: &c.group_id,
386                epoch_id: c.epoch,
387                content_type: c.content_type,
388            },
389            MlsMessagePayload::GroupInfo(_) => MlsMessageDescription::GroupInfo,
390            MlsMessagePayload::KeyPackage(_) => MlsMessageDescription::KeyPackage,
391        }
392    }
393}
394
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    ::safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
/// A MLS protocol message for sending data over the wire.
///
/// Encoded as the protocol version followed by the payload.
pub struct MlsMessage {
    /// Protocol version the message is framed under.
    pub(crate) version: ProtocolVersion,
    /// The message body; its variant determines the [`WireFormat`].
    pub(crate) payload: MlsMessagePayload,
}
406
407#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
408#[allow(dead_code)]
409impl MlsMessage {
410    pub(crate) fn new(version: ProtocolVersion, payload: MlsMessagePayload) -> MlsMessage {
411        Self { version, payload }
412    }
413
414    #[inline(always)]
415    pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
416        match self.payload {
417            MlsMessagePayload::Plain(plaintext) => Some(plaintext),
418            _ => None,
419        }
420    }
421
422    #[cfg(feature = "private_message")]
423    #[inline(always)]
424    pub(crate) fn into_ciphertext(self) -> Option<PrivateMessage> {
425        match self.payload {
426            MlsMessagePayload::Cipher(ciphertext) => Some(ciphertext),
427            _ => None,
428        }
429    }
430
431    #[inline(always)]
432    pub(crate) fn into_welcome(self) -> Option<Welcome> {
433        match self.payload {
434            MlsMessagePayload::Welcome(welcome) => Some(welcome),
435            _ => None,
436        }
437    }
438
439    #[inline(always)]
440    pub fn into_group_info(self) -> Option<GroupInfo> {
441        match self.payload {
442            MlsMessagePayload::GroupInfo(info) => Some(info),
443            _ => None,
444        }
445    }
446
447    #[inline(always)]
448    pub fn as_group_info(&self) -> Option<&GroupInfo> {
449        match &self.payload {
450            MlsMessagePayload::GroupInfo(info) => Some(info),
451            _ => None,
452        }
453    }
454
455    #[inline(always)]
456    pub fn into_key_package(self) -> Option<KeyPackage> {
457        match self.payload {
458            MlsMessagePayload::KeyPackage(kp) => Some(kp),
459            _ => None,
460        }
461    }
462
463    pub fn as_key_package(&self) -> Option<&KeyPackage> {
464        match &self.payload {
465            MlsMessagePayload::KeyPackage(kp) => Some(kp),
466            _ => None,
467        }
468    }
469
470    /// The wire format value describing the contents of this message.
471    pub fn wire_format(&self) -> WireFormat {
472        match self.payload {
473            MlsMessagePayload::Plain(_) => WireFormat::PublicMessage,
474            #[cfg(feature = "private_message")]
475            MlsMessagePayload::Cipher(_) => WireFormat::PrivateMessage,
476            MlsMessagePayload::Welcome(_) => WireFormat::Welcome,
477            MlsMessagePayload::GroupInfo(_) => WireFormat::GroupInfo,
478            MlsMessagePayload::KeyPackage(_) => WireFormat::KeyPackage,
479        }
480    }
481
482    /// The epoch that this message belongs to.
483    ///
484    /// Returns `None` if the message is [`WireFormat::KeyPackage`]
485    /// or [`WireFormat::Welcome`]
486    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
487    pub fn epoch(&self) -> Option<u64> {
488        match &self.payload {
489            MlsMessagePayload::Plain(p) => Some(p.content.epoch),
490            #[cfg(feature = "private_message")]
491            MlsMessagePayload::Cipher(c) => Some(c.epoch),
492            MlsMessagePayload::GroupInfo(gi) => Some(gi.group_context.epoch),
493            _ => None,
494        }
495    }
496
497    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
498    pub fn cipher_suite(&self) -> Option<CipherSuite> {
499        match &self.payload {
500            MlsMessagePayload::GroupInfo(i) => Some(i.group_context.cipher_suite),
501            MlsMessagePayload::Welcome(w) => Some(w.cipher_suite),
502            MlsMessagePayload::KeyPackage(k) => Some(k.cipher_suite),
503            _ => None,
504        }
505    }
506
507    pub fn group_id(&self) -> Option<&[u8]> {
508        match &self.payload {
509            MlsMessagePayload::Plain(p) => Some(&p.content.group_id),
510            #[cfg(feature = "private_message")]
511            MlsMessagePayload::Cipher(p) => Some(&p.group_id),
512            MlsMessagePayload::GroupInfo(p) => Some(&p.group_context.group_id),
513            MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => None,
514        }
515    }
516
517    /// Deserialize a message from transport.
518    #[inline(never)]
519    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
520        Self::mls_decode(&mut &*bytes).map_err(Into::into)
521    }
522
523    /// Serialize a message for transport.
524    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
525        self.mls_encode_to_vec().map_err(Into::into)
526    }
527
528    /// If this is a plaintext commit message, return all custom proposals committed by value.
529    /// If this is not a plaintext or not a commit, this returns an empty list.
530    #[cfg(feature = "custom_proposal")]
531    pub fn custom_proposals_by_value(&self) -> Vec<&CustomProposal> {
532        match &self.payload {
533            MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
534                Content::Commit(commit) => Self::find_custom_proposals(commit),
535                _ => Vec::new(),
536            },
537            _ => Vec::new(),
538        }
539    }
540
541    /// If this is a plaintext commit message, return all proposals committed by value.
542    /// If this is not a plaintext or not a commit, this returns an empty list.
543    /// **Note**: This method is not available in FFI bindings due to Proposal type constraints.
544    #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
545    #[allow(unreachable_patterns)]
546    pub fn proposals_by_value(&self) -> Vec<&Proposal> {
547        match &self.payload {
548            MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
549                Content::Commit(commit) => Self::find_all_proposals(commit),
550                _ => Vec::new(),
551            },
552            _ => Vec::new(),
553        }
554    }
555
556    /// If this is a welcome message, return key package references of all members who can
557    /// join using this message.
558    pub fn welcome_key_package_references(&self) -> Vec<&KeyPackageRef> {
559        let MlsMessagePayload::Welcome(welcome) = &self.payload else {
560            return Vec::new();
561        };
562
563        welcome.secrets.iter().map(|s| &s.new_member).collect()
564    }
565
566    /// If this is a key package, return its key package reference.
567    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
568    pub async fn key_package_reference<C: CipherSuiteProvider>(
569        &self,
570        cipher_suite: &C,
571    ) -> Result<Option<KeyPackageRef>, MlsError> {
572        let MlsMessagePayload::KeyPackage(kp) = &self.payload else {
573            return Ok(None);
574        };
575
576        kp.to_reference(cipher_suite).await.map(Some)
577    }
578
579    /// If this is a plaintext proposal, return the proposal reference that can be matched e.g. with
580    /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
581    #[cfg(feature = "by_ref_proposal")]
582    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
583    pub async fn into_proposal_reference<C: CipherSuiteProvider>(
584        self,
585        cipher_suite: &C,
586    ) -> Result<Option<Vec<u8>>, MlsError> {
587        let MlsMessagePayload::Plain(public_message) = self.payload else {
588            return Ok(None);
589        };
590
591        ProposalRef::from_content(cipher_suite, &public_message.into())
592            .await
593            .map(|r| Some(r.to_vec()))
594    }
595}
596
597impl MlsMessage {
598    #[cfg(feature = "custom_proposal")]
599    fn find_custom_proposals(commit: &Commit) -> Vec<&CustomProposal> {
600        commit
601            .proposals
602            .iter()
603            .filter_map(|p| match p {
604                ProposalOrRef::Proposal(p) => match p.as_ref() {
605                    crate::group::Proposal::Custom(p) => Some(p),
606                    _ => None,
607                },
608                _ => None,
609            })
610            .collect()
611    }
612
613    #[allow(unreachable_patterns)]
614    fn find_all_proposals(commit: &Commit) -> Vec<&Proposal> {
615        commit
616            .proposals
617            .iter()
618            .filter_map(|p| match p {
619                ProposalOrRef::Proposal(p) => Some(p.as_ref()),
620                _ => None,
621            })
622            .collect()
623    }
624}
625
// Payload variants differ widely in size; the lint is silenced rather than
// boxing the larger ones.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u16)]
/// Body of an [`MlsMessage`]. The explicit discriminants are the same values
/// as the corresponding [`WireFormat`] variants.
pub(crate) enum MlsMessagePayload {
    /// A plaintext protocol message.
    Plain(PublicMessage) = 1u16,
    /// An encrypted protocol message (requires `private_message`).
    #[cfg(feature = "private_message")]
    Cipher(PrivateMessage) = 2u16,
    /// A Welcome message.
    Welcome(Welcome) = 3u16,
    /// A GroupInfo message.
    GroupInfo(GroupInfo) = 4u16,
    /// A KeyPackage message.
    KeyPackage(KeyPackage) = 5u16,
}
638
639impl From<PublicMessage> for MlsMessagePayload {
640    fn from(m: PublicMessage) -> Self {
641        Self::Plain(m)
642    }
643}
644
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[derive(
    Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u16)]
#[non_exhaustive]
/// Content description of an [`MlsMessage`].
///
/// The explicit discriminants match the `WireFormat` values defined by
/// RFC 9420. Every variant is always present, even when the corresponding
/// payload type is feature-gated.
pub enum WireFormat {
    /// A plaintext protocol message.
    PublicMessage = 1u16,
    /// An encrypted protocol message.
    PrivateMessage = 2u16,
    /// A Welcome message.
    Welcome = 3u16,
    /// A GroupInfo message.
    GroupInfo = 4u16,
    /// A KeyPackage message.
    KeyPackage = 5u16,
}
660
/// Content of a protocol message together with its framing: the group and
/// epoch it belongs to, its sender, and sender-supplied authenticated data.
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct FramedContent {
    /// ID of the group the message belongs to.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub group_id: Vec<u8>,
    /// Epoch the message belongs to.
    pub epoch: u64,
    /// Sender of the message.
    pub sender: Sender,
    /// Additional authenticated data supplied by the sender.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub authenticated_data: Vec<u8>,
    /// The message body.
    pub content: Content,
}
675
676impl Debug for FramedContent {
677    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
678        f.debug_struct("FramedContent")
679            .field(
680                "group_id",
681                &mls_rs_core::debug::pretty_group_id(&self.group_id),
682            )
683            .field("epoch", &self.epoch)
684            .field("sender", &self.sender)
685            .field(
686                "authenticated_data",
687                &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
688            )
689            .field("content", &self.content)
690            .finish()
691    }
692}
693
694impl FramedContent {
695    pub fn content_type(&self) -> ContentType {
696        self.content.content_type()
697    }
698}
699
#[cfg(test)]
pub(crate) mod test_utils {
    #[cfg(feature = "private_message")]
    use crate::group::test_utils::random_bytes;

    use crate::group::{AuthenticatedContent, MessageSignature};

    use super::*;

    use alloc::boxed::Box;

    /// Builds public-message authenticated content that holds an empty commit
    /// from member 1, in epoch 0, with an empty group ID.
    ///
    /// The commit is not valid and must not be run through validation.
    pub(crate) fn get_test_auth_content() -> AuthenticatedContent {
        let empty_commit = Box::new(Commit {
            proposals: Default::default(),
            path: None,
        });

        let content = FramedContent {
            group_id: Vec::new(),
            epoch: 0,
            sender: Sender::Member(1),
            authenticated_data: Vec::new(),
            content: Content::Commit(empty_commit),
        };

        let auth = FramedContentAuthData {
            signature: MessageSignature::empty(),
            confirmation_tag: None,
        };

        AuthenticatedContent {
            wire_format: WireFormat::PublicMessage,
            content,
            auth,
        }
    }

    /// Builds private-message content with 1024 random application bytes and
    /// a 128-byte random signature.
    #[cfg(feature = "private_message")]
    pub(crate) fn get_test_ciphertext_content() -> PrivateMessageContent {
        let payload: ApplicationData = random_bytes(1024).into();
        let signature = MessageSignature::from(random_bytes(128));

        PrivateMessageContent {
            content: Content::Application(payload),
            auth: FramedContentAuthData {
                signature,
                confirmation_tag: None,
            },
        }
    }

    impl AsRef<[u8]> for ApplicationData {
        fn as_ref(&self) -> &[u8] {
            self.as_bytes()
        }
    }
}
751
#[cfg(feature = "private_message")]
#[cfg(test)]
mod tests {
    use alloc::vec;
    use assert_matches::assert_matches;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{framing::test_utils::get_test_ciphertext_content, test_utils::test_group},
        key_package::test_utils::test_key_package_message,
    };

    // Only `proposal_ref` uses these. That test needs `by_ref_proposal`
    // (`Sender::External`, `ProposalRef` and `into_proposal_reference` are
    // all gated on it), so the imports are gated the same way.
    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::{proposal_ref::test_utils::auth_content_from_proposal, RemoveProposal},
    };

    use super::*;

    /// Zero padding after the encoded content is accepted and ignored.
    #[test]
    fn test_mls_ciphertext_content_mls_encoding() {
        let ciphertext_content = get_test_ciphertext_content();

        let mut encoded = ciphertext_content.mls_encode_to_vec().unwrap();
        encoded.extend_from_slice(&[0u8; 128]);

        let decoded =
            PrivateMessageContent::mls_decode(&mut &*encoded, (&ciphertext_content.content).into())
                .unwrap();

        assert_eq!(ciphertext_content, decoded);
    }

    /// Non-zero padding is rejected with custom error code 5.
    #[test]
    fn test_mls_ciphertext_content_non_zero_padding_error() {
        let ciphertext_content = get_test_ciphertext_content();

        let mut encoded = ciphertext_content.mls_encode_to_vec().unwrap();
        encoded.extend_from_slice(&[1u8; 128]);

        let decoded =
            PrivateMessageContent::mls_decode(&mut &*encoded, (&ciphertext_content.content).into());

        assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(5)));
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn proposal_ref() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let test_auth = auth_content_from_proposal(
            Proposal::Remove(RemoveProposal {
                to_remove: LeafIndex::unchecked(0),
            }),
            Sender::External(0),
        );

        let expected_ref = ProposalRef::from_content(&cs, &test_auth).await.unwrap();

        let test_message = MlsMessage {
            version: TEST_PROTOCOL_VERSION,
            payload: MlsMessagePayload::Plain(PublicMessage {
                content: test_auth.content,
                auth: test_auth.auth,
                membership_tag: Some(cs.mac(&[1, 2, 3], &[1, 2, 3]).await.unwrap().into()),
            }),
        };

        let computed_ref = test_message
            .into_proposal_reference(&cs)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(computed_ref, expected_ref.to_vec());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn message_description() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        let message = group.commit(vec![]).await.unwrap();

        let expected = MlsMessageDescription::PublicProtocolMessage {
            group_id: group.group_id(),
            epoch_id: group.context().epoch,
            content_type: ContentType::Commit,
            sender: Sender::Member(0),
            authenticated_data: &[],
        };

        assert_eq!(message.commit_message.description(), expected);

        group.apply_pending_commit().await.unwrap();

        let message = group
            .encrypt_application_message(b"123", vec![])
            .await
            .unwrap();

        let expected = MlsMessageDescription::PrivateProtocolMessage {
            group_id: group.group_id(),
            epoch_id: group.context().epoch,
            content_type: ContentType::Application,
        };

        assert_eq!(message.description(), expected);

        let group_info = group
            .group_info_message_allowing_ext_commit(true)
            .await
            .unwrap();

        assert_eq!(group_info.description(), MlsMessageDescription::GroupInfo);

        let key_package =
            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "something").await;

        assert_eq!(key_package.description(), MlsMessageDescription::KeyPackage);
    }
}
872}