Skip to main content

mls_spec/messages/
content.rs

1use tls_codec::Deserialize;
2
3use crate::{
4    MlsSpecError, MlsSpecResult, SensitiveBytes,
5    crypto::Mac,
6    defs::{Epoch, ProposalType, ProtocolVersion, WireFormat},
7    group::{GroupId, group_info::GroupInfo, welcome::Welcome},
8    key_package::KeyPackage,
9    key_schedule::{ConfirmedTranscriptHashInput, GroupContext},
10    messages::{ContentType, ContentTypeInner, PrivateMessage, PublicMessage, Sender, SenderType},
11};
12
13#[derive(
14    Debug,
15    Clone,
16    PartialEq,
17    Eq,
18    tls_codec::TlsSerialize,
19    tls_codec::TlsDeserialize,
20    tls_codec::TlsSize,
21)]
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23pub struct FramedContent {
24    pub group_id: GroupId,
25    pub epoch: Epoch,
26    pub sender: Sender,
27    pub authenticated_data: SensitiveBytes,
28    pub content: ContentTypeInner,
29}
30
31impl FramedContent {
32    pub fn to_tbs<'a>(
33        &'a self,
34        wire_format: &'a WireFormat,
35        ctx: &'a GroupContext,
36    ) -> MlsSpecResult<FramedContentTBS<'a>> {
37        let sender_type_raw: SenderType = (&self.sender).into();
38        let sender_type =
39            FramedContentTBSSenderType::from_sender_type_with_ctx(sender_type_raw, Some(ctx))?;
40
41        Ok(FramedContentTBS {
42            version: &ctx.version,
43            wire_format,
44            content: self,
45            sender_type,
46        })
47    }
48}
49
50#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
51#[cfg_attr(feature = "serde", derive(serde::Serialize))]
52#[repr(u8)]
53pub enum FramedContentTBSSenderType<'a> {
54    #[tls_codec(discriminant = "SenderType::Member")]
55    Member(FramedContentTBSSenderTypeContext<'a>),
56    #[tls_codec(discriminant = "SenderType::External")]
57    External,
58    #[tls_codec(discriminant = "SenderType::NewMemberCommit")]
59    NewMemberCommit(FramedContentTBSSenderTypeContext<'a>),
60    #[tls_codec(discriminant = "SenderType::NewMemberProposal")]
61    NewMemberProposal,
62}
63
64impl<'a> FramedContentTBSSenderType<'a> {
65    pub fn from_sender_type_with_ctx(
66        sender_type: SenderType,
67        mut ctx: Option<&'a GroupContext>,
68    ) -> MlsSpecResult<Self> {
69        Ok(match sender_type {
70            SenderType::NewMemberCommit => {
71                let Some(context) = ctx.take() else {
72                    return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
73                };
74                Self::NewMemberCommit(FramedContentTBSSenderTypeContext { context })
75            }
76            SenderType::Member => {
77                let Some(context) = ctx.take() else {
78                    return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
79                };
80                Self::Member(FramedContentTBSSenderTypeContext { context })
81            }
82            SenderType::External => Self::External,
83            SenderType::NewMemberProposal => Self::NewMemberProposal,
84            _ => return Err(MlsSpecError::ReservedValueUsage),
85        })
86    }
87}
88
89#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
90#[cfg_attr(feature = "serde", derive(serde::Serialize))]
91pub struct FramedContentTBSSenderTypeContext<'a> {
92    pub context: &'a GroupContext,
93}
94
95#[derive(Debug, Copy, Clone, PartialEq, Eq)]
96pub struct FramedContentTBS<'a> {
97    pub version: &'a ProtocolVersion,
98    pub wire_format: &'a WireFormat,
99    pub content: &'a FramedContent,
100    pub sender_type: FramedContentTBSSenderType<'a>,
101}
102
// TLS serialization is implemented by hand so that the trailing `GroupContext` is only
// encoded when the sender (RFC 9420's `content.sender.sender_type`) is `member` or
// `new_member_commit`.
104impl tls_codec::Size for FramedContentTBS<'_> {
105    fn tls_serialized_len(&self) -> usize {
106        let mut len = self.version.tls_serialized_len()
107            + self.wire_format.tls_serialized_len()
108            + self.content.tls_serialized_len();
109        if matches!(
110            self.content.sender,
111            Sender::Member(_) | Sender::NewMemberCommit
112        ) {
113            match &self.sender_type {
114                FramedContentTBSSenderType::NewMemberCommit(context)
115                | FramedContentTBSSenderType::Member(context) => {
116                    len += context.tls_serialized_len();
117                }
118                _ => {}
119            }
120        }
121        len
122    }
123}
124
125impl tls_codec::Serialize for FramedContentTBS<'_> {
126    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
127        let mut ret = self.version.tls_serialize(writer)?;
128
129        ret += self.wire_format.tls_serialize(writer)?;
130        ret += self.content.tls_serialize(writer)?;
131        if matches!(
132            self.content.sender,
133            Sender::Member(_) | Sender::NewMemberCommit
134        ) {
135            match &self.sender_type {
136                FramedContentTBSSenderType::NewMemberCommit(context)
137                | FramedContentTBSSenderType::Member(context) => {
138                    ret += context.tls_serialize(writer)?;
139                }
140                _ => {}
141            }
142        }
143
144        Ok(ret)
145    }
146}
147
148impl<'a> tls_codec::Size for &'a FramedContentTBS<'a> {
149    fn tls_serialized_len(&self) -> usize {
150        (*self).tls_serialized_len()
151    }
152}
153
154impl<'a> tls_codec::Serialize for &'a FramedContentTBS<'a> {
155    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
156        (*self).tls_serialize(writer)
157    }
158}
159
160#[derive(Debug, Clone, PartialEq, Eq)]
161#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
162pub struct FramedContentAuthData {
163    pub signature: SensitiveBytes,
164    pub confirmation_tag: Option<Mac>,
165}
166
167impl FramedContentAuthData {
168    pub fn without_confirmation_tag(&self) -> Self {
169        Self {
170            signature: self.signature.clone(),
171            confirmation_tag: None,
172        }
173    }
174}
175
176impl tls_codec::Size for FramedContentAuthData {
177    fn tls_serialized_len(&self) -> usize {
178        self.signature.tls_serialized_len()
179            + self
180                .confirmation_tag
181                .as_ref()
182                .map_or(0, SensitiveBytes::tls_serialized_len)
183    }
184}
185
186impl tls_codec::Size for &FramedContentAuthData {
187    fn tls_serialized_len(&self) -> usize {
188        (*self).tls_serialized_len()
189    }
190}
191
192impl tls_codec::Serialize for FramedContentAuthData {
193    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
194        let mut written = self.signature.tls_serialize(writer)?;
195        if let Some(confirmation_tag) = &self.confirmation_tag {
196            written += confirmation_tag.tls_serialize(writer)?;
197        }
198        Ok(written)
199    }
200}
201
202impl tls_codec::Serialize for &FramedContentAuthData {
203    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
204        (*self).tls_serialize(writer)
205    }
206}
207
208impl FramedContentAuthData {
209    pub fn tls_deserialize_with_content_type<R: std::io::Read>(
210        bytes: &mut R,
211        content_type: ContentType,
212    ) -> Result<Self, tls_codec::Error> {
213        let signature = SensitiveBytes::tls_deserialize(bytes)?;
214        let confirmation_tag = (content_type == ContentType::Commit)
215            .then(|| Mac::tls_deserialize(bytes))
216            .transpose()?;
217
218        Ok(Self {
219            signature,
220            confirmation_tag,
221        })
222    }
223}
224
225#[derive(
226    Debug,
227    Clone,
228    PartialEq,
229    Eq,
230    tls_codec::TlsSerialize,
231    tls_codec::TlsDeserialize,
232    tls_codec::TlsSize,
233)]
234#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
235#[repr(u16)]
236pub enum MlsMessageContent {
237    #[tls_codec(discriminant = "WireFormat::MLS_PUBLIC_MESSAGE")]
238    MlsPublicMessage(PublicMessage),
239    #[tls_codec(discriminant = "WireFormat::MLS_PRIVATE_MESSAGE")]
240    MlsPrivateMessage(PrivateMessage),
241    #[tls_codec(discriminant = "WireFormat::MLS_WELCOME")]
242    Welcome(Welcome),
243    #[tls_codec(discriminant = "WireFormat::MLS_GROUP_INFO")]
244    GroupInfo(GroupInfo),
245    #[tls_codec(discriminant = "WireFormat::MLS_KEY_PACKAGE")]
246    KeyPackage(KeyPackage),
247    #[cfg(feature = "draft-ietf-mls-targeted-messages")]
248    #[tls_codec(discriminant = "WireFormat::MLS_TARGETED_MESSAGE")]
249    MlsTargetedMessage(crate::drafts::targeted_messages::TargetedMessage),
250    #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
251    #[tls_codec(discriminant = "WireFormat::MLS_SEMIPRIVATE_MESSAGE")]
252    MlsSemiPrivateMessage(crate::drafts::semiprivate_message::messages::SemiPrivateMessage),
253    #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
254    #[tls_codec(discriminant = "WireFormat::MLS_SPLIT_COMMIT")]
255    MlsSplitCommitMessage(crate::drafts::split_commit::SplitCommitMessage),
256    #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
257    #[tls_codec(discriminant = "WireFormat::MLS_MESSAGE_WITHOUT_AAD")]
258    MlsMessageWithoutAad(crate::drafts::additional_wire_formats::MessageWithoutAad),
259    #[cfg(feature = "draft-mahy-mls-private-external")]
260    #[tls_codec(discriminant = "WireFormat::MLS_PRIVATE_EXTERNAL_MESSAGE")]
261    MlsPrivateExternalMessage(crate::drafts::private_external::PrivateExternalMessage),
262    #[cfg(feature = "draft-kohbrok-mls-leaf-operation-intents")]
263    #[tls_codec(discriminant = "WireFormat::MLS_LEAF_OPERATION_INTENT")]
264    MlsLeafOperationIntent(crate::drafts::leaf_operation_intents::LeafOperationIntent),
265}
266
267impl MlsMessageContent {
268    pub fn content_type(&self) -> Option<ContentType> {
269        match self {
270            MlsMessageContent::MlsPublicMessage(pub_msg) => Some((&pub_msg.content.content).into()),
271            MlsMessageContent::MlsPrivateMessage(priv_msg) => Some(priv_msg.content_type),
272            #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
273            MlsMessageContent::MlsSemiPrivateMessage(semi_priv_msg) => {
274                Some(semi_priv_msg.content_type)
275            }
276            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
277            MlsMessageContent::MlsSplitCommitMessage(message) => {
278                message.split_commit_message.content.content_type()
279            }
280            #[cfg(feature = "draft-mahy-mls-private-external")]
281            MlsMessageContent::MlsPrivateExternalMessage(message) => Some(message.content_type),
282            _ => None,
283        }
284    }
285
286    pub fn proposal_type(&self) -> Option<ProposalType> {
287        match self {
288            MlsMessageContent::MlsPublicMessage(pub_msg) => {
289                if let ContentTypeInner::Proposal { proposal } = &pub_msg.content.content {
290                    Some(proposal.into())
291                } else {
292                    None
293                }
294            }
295            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
296            MlsMessageContent::MlsSplitCommitMessage(message) => {
297                message.split_commit_message.content.proposal_type()
298            }
299            _ => None,
300        }
301    }
302
303    pub fn authenticated_data(&self) -> Option<&[u8]> {
304        match self {
305            MlsMessageContent::MlsPublicMessage(public_message) => {
306                Some(&public_message.content.authenticated_data)
307            }
308            MlsMessageContent::MlsPrivateMessage(private_message) => {
309                Some(&private_message.authenticated_data)
310            }
311            #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
312            MlsMessageContent::MlsSemiPrivateMessage(semi_priv_msg) => {
313                Some(&semi_priv_msg.authenticated_data)
314            }
315            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
316            MlsMessageContent::MlsSplitCommitMessage(message) => {
317                message.split_commit_message.content.authenticated_data()
318            }
319            #[cfg(feature = "draft-mahy-mls-private-external")]
320            MlsMessageContent::MlsPrivateExternalMessage(message) => {
321                Some(&message.authenticated_data)
322            }
323            _ => None,
324        }
325    }
326}
327
328#[allow(clippy::from_over_into)]
329impl Into<WireFormat> for &MlsMessageContent {
330    fn into(self) -> WireFormat {
331        match self {
332            MlsMessageContent::MlsPublicMessage(_) => {
333                WireFormat::new_unchecked(WireFormat::MLS_PUBLIC_MESSAGE)
334            }
335            MlsMessageContent::MlsPrivateMessage(_) => {
336                WireFormat::new_unchecked(WireFormat::MLS_PRIVATE_MESSAGE)
337            }
338            MlsMessageContent::Welcome(_) => WireFormat::new_unchecked(WireFormat::MLS_WELCOME),
339            MlsMessageContent::GroupInfo(_) => {
340                WireFormat::new_unchecked(WireFormat::MLS_GROUP_INFO)
341            }
342            MlsMessageContent::KeyPackage(_) => {
343                WireFormat::new_unchecked(WireFormat::MLS_KEY_PACKAGE)
344            }
345            #[cfg(feature = "draft-ietf-mls-targeted-messages")]
346            MlsMessageContent::MlsTargetedMessage(_) => {
347                WireFormat::new_unchecked(WireFormat::MLS_TARGETED_MESSAGE)
348            }
349            #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
350            MlsMessageContent::MlsSemiPrivateMessage(_) => {
351                WireFormat::new_unchecked(WireFormat::MLS_SEMIPRIVATE_MESSAGE)
352            }
353            #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
354            MlsMessageContent::MlsSplitCommitMessage(_) => {
355                WireFormat::new_unchecked(WireFormat::MLS_SPLIT_COMMIT)
356            }
357            #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
358            MlsMessageContent::MlsMessageWithoutAad(_) => {
359                WireFormat::new_unchecked(WireFormat::MLS_MESSAGE_WITHOUT_AAD)
360            }
361            #[cfg(feature = "draft-mahy-mls-private-external")]
362            MlsMessageContent::MlsPrivateExternalMessage(_) => {
363                WireFormat::new_unchecked(WireFormat::MLS_PRIVATE_EXTERNAL_MESSAGE)
364            }
365            #[cfg(feature = "draft-kohbrok-mls-leaf-operation-intents")]
366            MlsMessageContent::MlsLeafOperationIntent(_) => {
367                WireFormat::new_unchecked(WireFormat::MLS_LEAF_OPERATION_INTENT)
368            }
369        }
370    }
371}
372
373#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
374#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
375pub struct AuthenticatedContent {
376    pub wire_format: WireFormat,
377    pub content: FramedContent,
378    pub auth: FramedContentAuthData,
379}
380
381#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
382#[cfg_attr(feature = "serde", derive(serde::Serialize))]
383pub struct AuthenticatedContentRef<'a> {
384    pub wire_format: &'a WireFormat,
385    pub content: &'a FramedContent,
386    pub auth: &'a FramedContentAuthData,
387}
388
389impl AuthenticatedContent {
390    pub fn confirmed_transcript_hash_input(&self) -> ConfirmedTranscriptHashInput<'_> {
391        ConfirmedTranscriptHashInput {
392            wire_format: &self.wire_format,
393            content: &self.content,
394            signature: &self.auth.signature,
395        }
396    }
397
398    pub fn as_ref(&self) -> AuthenticatedContentRef<'_> {
399        AuthenticatedContentRef {
400            wire_format: &self.wire_format,
401            content: &self.content,
402            auth: &self.auth,
403        }
404    }
405}
406
407impl tls_codec::Deserialize for AuthenticatedContent {
408    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
409    where
410        Self: Sized,
411    {
412        let wire_format = WireFormat::tls_deserialize(bytes)?;
413        let content = FramedContent::tls_deserialize(bytes)?;
414        let auth = FramedContentAuthData::tls_deserialize_with_content_type(
415            bytes,
416            (&content.content).into(),
417        )?;
418        Ok(Self {
419            wire_format,
420            content,
421            auth,
422        })
423    }
424}