mls_spec/messages/
content.rs

1use tls_codec::Deserialize;
2
3use crate::{
4    MlsSpecError, MlsSpecResult, SensitiveBytes,
5    crypto::Mac,
6    defs::{Epoch, ProposalType, ProtocolVersion, WireFormat},
7    group::{GroupId, group_info::GroupInfo, welcome::Welcome},
8    key_package::KeyPackage,
9    key_schedule::{ConfirmedTranscriptHashInput, GroupContext},
10    messages::{ContentType, ContentTypeInner, PrivateMessage, PublicMessage, Sender, SenderType},
11};
12
13#[derive(
14    Debug,
15    Clone,
16    PartialEq,
17    Eq,
18    tls_codec::TlsSerialize,
19    tls_codec::TlsDeserialize,
20    tls_codec::TlsSize,
21)]
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23pub struct FramedContent {
24    pub group_id: GroupId,
25    pub epoch: Epoch,
26    pub sender: Sender,
27    pub authenticated_data: SensitiveBytes,
28    pub content: ContentTypeInner,
29}
30
31impl FramedContent {
32    pub fn to_tbs<'a>(
33        &'a self,
34        wire_format: &'a WireFormat,
35        ctx: &'a GroupContext,
36    ) -> MlsSpecResult<FramedContentTBS<'a>> {
37        let sender_type_raw: SenderType = (&self.sender).into();
38        let sender_type =
39            FramedContentTBSSenderType::from_sender_type_with_ctx(sender_type_raw, Some(ctx))?;
40
41        Ok(FramedContentTBS {
42            version: &ctx.version,
43            wire_format,
44            content: self,
45            sender_type,
46        })
47    }
48}
49
50#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
51#[cfg_attr(feature = "serde", derive(serde::Serialize))]
52#[repr(u8)]
53pub enum FramedContentTBSSenderType<'a> {
54    #[tls_codec(discriminant = "SenderType::Member")]
55    Member(FramedContentTBSSenderTypeContext<'a>),
56    #[tls_codec(discriminant = "SenderType::External")]
57    External,
58    #[tls_codec(discriminant = "SenderType::NewMemberCommit")]
59    NewMemberCommit(FramedContentTBSSenderTypeContext<'a>),
60    #[tls_codec(discriminant = "SenderType::NewMemberProposal")]
61    NewMemberProposal,
62}
63
64impl<'a> FramedContentTBSSenderType<'a> {
65    pub fn from_sender_type_with_ctx(
66        sender_type: SenderType,
67        mut ctx: Option<&'a GroupContext>,
68    ) -> MlsSpecResult<Self> {
69        Ok(match sender_type {
70            SenderType::NewMemberCommit => {
71                let Some(context) = ctx.take() else {
72                    return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
73                };
74                Self::NewMemberCommit(FramedContentTBSSenderTypeContext { context })
75            }
76            SenderType::Member => {
77                let Some(context) = ctx.take() else {
78                    return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
79                };
80                Self::Member(FramedContentTBSSenderTypeContext { context })
81            }
82            SenderType::External => Self::External,
83            SenderType::NewMemberProposal => Self::NewMemberProposal,
84            _ => return Err(MlsSpecError::ReservedValueUsage),
85        })
86    }
87}
88
89#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
90#[cfg_attr(feature = "serde", derive(serde::Serialize))]
91pub struct FramedContentTBSSenderTypeContext<'a> {
92    pub context: &'a GroupContext,
93}
94
95#[derive(Debug, Copy, Clone, PartialEq, Eq)]
96pub struct FramedContentTBS<'a> {
97    pub version: &'a ProtocolVersion,
98    pub wire_format: &'a WireFormat,
99    pub content: &'a FramedContent,
100    pub sender_type: FramedContentTBSSenderType<'a>,
101}
102
103// Impl TLS serialization by hand to make this depend on `self.content.sender.sender_type`'s discriminant
104impl tls_codec::Size for FramedContentTBS<'_> {
105    fn tls_serialized_len(&self) -> usize {
106        let mut len = self.version.tls_serialized_len()
107            + self.wire_format.tls_serialized_len()
108            + self.content.tls_serialized_len();
109        if matches!(
110            self.content.sender,
111            Sender::Member(_) | Sender::NewMemberCommit
112        ) {
113            match &self.sender_type {
114                FramedContentTBSSenderType::NewMemberCommit(context)
115                | FramedContentTBSSenderType::Member(context) => {
116                    len += context.tls_serialized_len();
117                }
118                _ => {}
119            }
120        }
121        len
122    }
123}
124
125impl tls_codec::Serialize for FramedContentTBS<'_> {
126    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
127        let mut ret = self.version.tls_serialize(writer)?;
128
129        ret += self.wire_format.tls_serialize(writer)?;
130        ret += self.content.tls_serialize(writer)?;
131        if matches!(
132            self.content.sender,
133            Sender::Member(_) | Sender::NewMemberCommit
134        ) {
135            match &self.sender_type {
136                FramedContentTBSSenderType::NewMemberCommit(context)
137                | FramedContentTBSSenderType::Member(context) => {
138                    ret += context.tls_serialize(writer)?;
139                }
140                _ => {}
141            }
142        }
143
144        Ok(ret)
145    }
146}
147
148impl<'a> tls_codec::Size for &'a FramedContentTBS<'a> {
149    fn tls_serialized_len(&self) -> usize {
150        (*self).tls_serialized_len()
151    }
152}
153
154impl<'a> tls_codec::Serialize for &'a FramedContentTBS<'a> {
155    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
156        (*self).tls_serialize(writer)
157    }
158}
159
160#[derive(Debug, Clone, PartialEq, Eq)]
161#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
162pub struct FramedContentAuthData {
163    pub signature: SensitiveBytes,
164    pub confirmation_tag: Option<Mac>,
165}
166
167impl FramedContentAuthData {
168    pub fn without_confirmation_tag(&self) -> Self {
169        Self {
170            signature: self.signature.clone(),
171            confirmation_tag: None,
172        }
173    }
174}
175
176impl tls_codec::Size for FramedContentAuthData {
177    fn tls_serialized_len(&self) -> usize {
178        self.signature.tls_serialized_len()
179            + self
180                .confirmation_tag
181                .as_ref()
182                .map_or(0, SensitiveBytes::tls_serialized_len)
183    }
184}
185
186impl tls_codec::Size for &FramedContentAuthData {
187    fn tls_serialized_len(&self) -> usize {
188        (*self).tls_serialized_len()
189    }
190}
191
192impl tls_codec::Serialize for FramedContentAuthData {
193    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
194        let mut written = self.signature.tls_serialize(writer)?;
195        if let Some(confirmation_tag) = &self.confirmation_tag {
196            written += confirmation_tag.tls_serialize(writer)?;
197        }
198        Ok(written)
199    }
200}
201
202impl tls_codec::Serialize for &FramedContentAuthData {
203    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
204        (*self).tls_serialize(writer)
205    }
206}
207
208impl FramedContentAuthData {
209    pub fn tls_deserialize_with_content_type<R: std::io::Read>(
210        bytes: &mut R,
211        content_type: ContentType,
212    ) -> Result<Self, tls_codec::Error> {
213        let signature = SensitiveBytes::tls_deserialize(bytes)?;
214        let confirmation_tag = (content_type == ContentType::Commit)
215            .then(|| Mac::tls_deserialize(bytes))
216            .transpose()?;
217
218        Ok(Self {
219            signature,
220            confirmation_tag,
221        })
222    }
223}
224
225#[derive(
226    Debug,
227    Clone,
228    PartialEq,
229    Eq,
230    tls_codec::TlsSerialize,
231    tls_codec::TlsDeserialize,
232    tls_codec::TlsSize,
233)]
234#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
235#[repr(u16)]
236pub enum MlsMessageContent {
237    #[tls_codec(discriminant = "WireFormat::MLS_PUBLIC_MESSAGE")]
238    MlsPublicMessage(PublicMessage),
239    #[tls_codec(discriminant = "WireFormat::MLS_PRIVATE_MESSAGE")]
240    MlsPrivateMessage(PrivateMessage),
241    #[tls_codec(discriminant = "WireFormat::MLS_WELCOME")]
242    Welcome(Welcome),
243    #[tls_codec(discriminant = "WireFormat::MLS_GROUP_INFO")]
244    GroupInfo(GroupInfo),
245    #[tls_codec(discriminant = "WireFormat::MLS_KEY_PACKAGE")]
246    KeyPackage(KeyPackage),
247    #[cfg(feature = "draft-ietf-mls-extensions")]
248    #[tls_codec(discriminant = "WireFormat::MLS_TARGETED_MESSAGE")]
249    MlsTargetedMessage(crate::drafts::mls_extensions::targeted_message::TargetedMessage),
250    #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
251    #[tls_codec(discriminant = "WireFormat::MLS_SEMIPRIVATE_MESSAGE")]
252    MlsSemiPrivateMessage(crate::drafts::semiprivate_message::messages::SemiPrivateMessage),
253    #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
254    #[tls_codec(discriminant = "WireFormat::MLS_SPLIT_COMMIT")]
255    MlsSplitCommitMessage(crate::drafts::split_commit::SplitCommitMessage),
256    #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
257    #[tls_codec(discriminant = "WireFormat::MLS_MESSAGE_WITHOUT_AAD")]
258    MlsMessageWithoutAad(crate::drafts::additional_wire_formats::MessageWithoutAad),
259}
260
261impl MlsMessageContent {
262    pub fn content_type(&self) -> Option<ContentType> {
263        match self {
264            MlsMessageContent::MlsPublicMessage(pub_msg) => Some((&pub_msg.content.content).into()),
265            MlsMessageContent::MlsPrivateMessage(priv_msg) => Some(priv_msg.content_type),
266            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
267            MlsMessageContent::MlsSplitCommitMessage(message) => {
268                message.split_commit_message.content.content_type()
269            }
270            _ => None,
271        }
272    }
273
274    pub fn proposal_type(&self) -> Option<ProposalType> {
275        match self {
276            MlsMessageContent::MlsPublicMessage(pub_msg) => {
277                if let ContentTypeInner::Proposal { proposal } = &pub_msg.content.content {
278                    Some(proposal.into())
279                } else {
280                    None
281                }
282            }
283            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
284            MlsMessageContent::MlsSplitCommitMessage(message) => {
285                message.split_commit_message.content.proposal_type()
286            }
287            _ => None,
288        }
289    }
290}
291
292#[allow(clippy::from_over_into)]
293impl Into<WireFormat> for &MlsMessageContent {
294    fn into(self) -> WireFormat {
295        match self {
296            MlsMessageContent::MlsPublicMessage(_) => {
297                WireFormat::new_unchecked(WireFormat::MLS_PUBLIC_MESSAGE)
298            }
299            MlsMessageContent::MlsPrivateMessage(_) => {
300                WireFormat::new_unchecked(WireFormat::MLS_PRIVATE_MESSAGE)
301            }
302            MlsMessageContent::Welcome(_) => WireFormat::new_unchecked(WireFormat::MLS_WELCOME),
303            MlsMessageContent::GroupInfo(_) => {
304                WireFormat::new_unchecked(WireFormat::MLS_GROUP_INFO)
305            }
306            MlsMessageContent::KeyPackage(_) => {
307                WireFormat::new_unchecked(WireFormat::MLS_KEY_PACKAGE)
308            }
309            #[cfg(feature = "draft-ietf-mls-extensions")]
310            MlsMessageContent::MlsTargetedMessage(_) => {
311                WireFormat::new_unchecked(WireFormat::MLS_TARGETED_MESSAGE)
312            }
313            #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
314            MlsMessageContent::MlsSemiPrivateMessage(_) => {
315                WireFormat::new_unchecked(WireFormat::MLS_SEMIPRIVATE_MESSAGE)
316            }
317            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
318            MlsMessageContent::MlsSplitCommitMessage(_) => {
319                WireFormat::new_unchecked(WireFormat::MLS_SPLIT_COMMIT)
320            }
321            #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
322            MlsMessageContent::MlsMessageWithoutAad(_) => {
323                WireFormat::new_unchecked(WireFormat::MLS_MESSAGE_WITHOUT_AAD)
324            }
325        }
326    }
327}
328
329#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
330#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
331pub struct AuthenticatedContent {
332    pub wire_format: WireFormat,
333    pub content: FramedContent,
334    pub auth: FramedContentAuthData,
335}
336
337#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
338#[cfg_attr(feature = "serde", derive(serde::Serialize))]
339pub struct AuthenticatedContentRef<'a> {
340    pub wire_format: &'a WireFormat,
341    pub content: &'a FramedContent,
342    pub auth: &'a FramedContentAuthData,
343}
344
345impl AuthenticatedContent {
346    pub fn confirmed_transcript_hash_input(&self) -> ConfirmedTranscriptHashInput {
347        ConfirmedTranscriptHashInput {
348            wire_format: &self.wire_format,
349            content: &self.content,
350            signature: &self.auth.signature,
351        }
352    }
353
354    pub fn as_ref(&self) -> AuthenticatedContentRef {
355        AuthenticatedContentRef {
356            wire_format: &self.wire_format,
357            content: &self.content,
358            auth: &self.auth,
359        }
360    }
361}
362
363impl tls_codec::Deserialize for AuthenticatedContent {
364    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
365    where
366        Self: Sized,
367    {
368        let wire_format = WireFormat::tls_deserialize(bytes)?;
369        let content = FramedContent::tls_deserialize(bytes)?;
370        let auth = FramedContentAuthData::tls_deserialize_with_content_type(
371            bytes,
372            (&content.content).into(),
373        )?;
374        Ok(Self {
375            wire_format,
376            content,
377            auth,
378        })
379    }
380}