1use tls_codec::Deserialize;
2
3use crate::{
4 MlsSpecError, MlsSpecResult, SensitiveBytes,
5 crypto::Mac,
6 defs::{Epoch, ProposalType, ProtocolVersion, WireFormat},
7 group::{GroupId, group_info::GroupInfo, welcome::Welcome},
8 key_package::KeyPackage,
9 key_schedule::{ConfirmedTranscriptHashInput, GroupContext},
10 messages::{ContentType, ContentTypeInner, PrivateMessage, PublicMessage, Sender, SenderType},
11};
12
13#[derive(
14 Debug,
15 Clone,
16 PartialEq,
17 Eq,
18 tls_codec::TlsSerialize,
19 tls_codec::TlsDeserialize,
20 tls_codec::TlsSize,
21)]
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23pub struct FramedContent {
24 pub group_id: GroupId,
25 pub epoch: Epoch,
26 pub sender: Sender,
27 pub authenticated_data: SensitiveBytes,
28 pub content: ContentTypeInner,
29}
30
31impl FramedContent {
32 pub fn to_tbs<'a>(
33 &'a self,
34 wire_format: &'a WireFormat,
35 ctx: &'a GroupContext,
36 ) -> MlsSpecResult<FramedContentTBS<'a>> {
37 let sender_type_raw: SenderType = (&self.sender).into();
38 let sender_type =
39 FramedContentTBSSenderType::from_sender_type_with_ctx(sender_type_raw, Some(ctx))?;
40
41 Ok(FramedContentTBS {
42 version: &ctx.version,
43 wire_format,
44 content: self,
45 sender_type,
46 })
47 }
48}
49
50#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
51#[cfg_attr(feature = "serde", derive(serde::Serialize))]
52#[repr(u8)]
53pub enum FramedContentTBSSenderType<'a> {
54 #[tls_codec(discriminant = "SenderType::Member")]
55 Member(FramedContentTBSSenderTypeContext<'a>),
56 #[tls_codec(discriminant = "SenderType::External")]
57 External,
58 #[tls_codec(discriminant = "SenderType::NewMemberCommit")]
59 NewMemberCommit(FramedContentTBSSenderTypeContext<'a>),
60 #[tls_codec(discriminant = "SenderType::NewMemberProposal")]
61 NewMemberProposal,
62}
63
64impl<'a> FramedContentTBSSenderType<'a> {
65 pub fn from_sender_type_with_ctx(
66 sender_type: SenderType,
67 mut ctx: Option<&'a GroupContext>,
68 ) -> MlsSpecResult<Self> {
69 Ok(match sender_type {
70 SenderType::NewMemberCommit => {
71 let Some(context) = ctx.take() else {
72 return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
73 };
74 Self::NewMemberCommit(FramedContentTBSSenderTypeContext { context })
75 }
76 SenderType::Member => {
77 let Some(context) = ctx.take() else {
78 return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
79 };
80 Self::Member(FramedContentTBSSenderTypeContext { context })
81 }
82 SenderType::External => Self::External,
83 SenderType::NewMemberProposal => Self::NewMemberProposal,
84 _ => return Err(MlsSpecError::ReservedValueUsage),
85 })
86 }
87}
88
89#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
90#[cfg_attr(feature = "serde", derive(serde::Serialize))]
91pub struct FramedContentTBSSenderTypeContext<'a> {
92 pub context: &'a GroupContext,
93}
94
95#[derive(Debug, Copy, Clone, PartialEq, Eq)]
96pub struct FramedContentTBS<'a> {
97 pub version: &'a ProtocolVersion,
98 pub wire_format: &'a WireFormat,
99 pub content: &'a FramedContent,
100 pub sender_type: FramedContentTBSSenderType<'a>,
101}
102
103impl tls_codec::Size for FramedContentTBS<'_> {
105 fn tls_serialized_len(&self) -> usize {
106 let mut len = self.version.tls_serialized_len()
107 + self.wire_format.tls_serialized_len()
108 + self.content.tls_serialized_len();
109 if matches!(
110 self.content.sender,
111 Sender::Member(_) | Sender::NewMemberCommit
112 ) {
113 match &self.sender_type {
114 FramedContentTBSSenderType::NewMemberCommit(context)
115 | FramedContentTBSSenderType::Member(context) => {
116 len += context.tls_serialized_len();
117 }
118 _ => {}
119 }
120 }
121 len
122 }
123}
124
125impl tls_codec::Serialize for FramedContentTBS<'_> {
126 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
127 let mut ret = self.version.tls_serialize(writer)?;
128
129 ret += self.wire_format.tls_serialize(writer)?;
130 ret += self.content.tls_serialize(writer)?;
131 if matches!(
132 self.content.sender,
133 Sender::Member(_) | Sender::NewMemberCommit
134 ) {
135 match &self.sender_type {
136 FramedContentTBSSenderType::NewMemberCommit(context)
137 | FramedContentTBSSenderType::Member(context) => {
138 ret += context.tls_serialize(writer)?;
139 }
140 _ => {}
141 }
142 }
143
144 Ok(ret)
145 }
146}
147
148impl<'a> tls_codec::Size for &'a FramedContentTBS<'a> {
149 fn tls_serialized_len(&self) -> usize {
150 (*self).tls_serialized_len()
151 }
152}
153
154impl<'a> tls_codec::Serialize for &'a FramedContentTBS<'a> {
155 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
156 (*self).tls_serialize(writer)
157 }
158}
159
160#[derive(Debug, Clone, PartialEq, Eq)]
161#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
162pub struct FramedContentAuthData {
163 pub signature: SensitiveBytes,
164 pub confirmation_tag: Option<Mac>,
165}
166
167impl FramedContentAuthData {
168 pub fn without_confirmation_tag(&self) -> Self {
169 Self {
170 signature: self.signature.clone(),
171 confirmation_tag: None,
172 }
173 }
174}
175
176impl tls_codec::Size for FramedContentAuthData {
177 fn tls_serialized_len(&self) -> usize {
178 self.signature.tls_serialized_len()
179 + self
180 .confirmation_tag
181 .as_ref()
182 .map_or(0, SensitiveBytes::tls_serialized_len)
183 }
184}
185
186impl tls_codec::Size for &FramedContentAuthData {
187 fn tls_serialized_len(&self) -> usize {
188 (*self).tls_serialized_len()
189 }
190}
191
192impl tls_codec::Serialize for FramedContentAuthData {
193 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
194 let mut written = self.signature.tls_serialize(writer)?;
195 if let Some(confirmation_tag) = &self.confirmation_tag {
196 written += confirmation_tag.tls_serialize(writer)?;
197 }
198 Ok(written)
199 }
200}
201
202impl tls_codec::Serialize for &FramedContentAuthData {
203 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
204 (*self).tls_serialize(writer)
205 }
206}
207
208impl FramedContentAuthData {
209 pub fn tls_deserialize_with_content_type<R: std::io::Read>(
210 bytes: &mut R,
211 content_type: ContentType,
212 ) -> Result<Self, tls_codec::Error> {
213 let signature = SensitiveBytes::tls_deserialize(bytes)?;
214 let confirmation_tag = (content_type == ContentType::Commit)
215 .then(|| Mac::tls_deserialize(bytes))
216 .transpose()?;
217
218 Ok(Self {
219 signature,
220 confirmation_tag,
221 })
222 }
223}
224
225#[derive(
226 Debug,
227 Clone,
228 PartialEq,
229 Eq,
230 tls_codec::TlsSerialize,
231 tls_codec::TlsDeserialize,
232 tls_codec::TlsSize,
233)]
234#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
235#[repr(u16)]
236pub enum MlsMessageContent {
237 #[tls_codec(discriminant = "WireFormat::MLS_PUBLIC_MESSAGE")]
238 MlsPublicMessage(PublicMessage),
239 #[tls_codec(discriminant = "WireFormat::MLS_PRIVATE_MESSAGE")]
240 MlsPrivateMessage(PrivateMessage),
241 #[tls_codec(discriminant = "WireFormat::MLS_WELCOME")]
242 Welcome(Welcome),
243 #[tls_codec(discriminant = "WireFormat::MLS_GROUP_INFO")]
244 GroupInfo(GroupInfo),
245 #[tls_codec(discriminant = "WireFormat::MLS_KEY_PACKAGE")]
246 KeyPackage(KeyPackage),
247 #[cfg(feature = "draft-ietf-mls-extensions")]
248 #[tls_codec(discriminant = "WireFormat::MLS_TARGETED_MESSAGE")]
249 MlsTargetedMessage(crate::drafts::mls_extensions::targeted_message::TargetedMessage),
250 #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
251 #[tls_codec(discriminant = "WireFormat::MLS_SEMIPRIVATE_MESSAGE")]
252 MlsSemiPrivateMessage(crate::drafts::semiprivate_message::messages::SemiPrivateMessage),
253 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
254 #[tls_codec(discriminant = "WireFormat::MLS_SPLIT_COMMIT")]
255 MlsSplitCommitMessage(crate::drafts::split_commit::SplitCommitMessage),
256 #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
257 #[tls_codec(discriminant = "WireFormat::MLS_MESSAGE_WITHOUT_AAD")]
258 MlsMessageWithoutAad(crate::drafts::additional_wire_formats::MessageWithoutAad),
259}
260
261impl MlsMessageContent {
262 pub fn content_type(&self) -> Option<ContentType> {
263 match self {
264 MlsMessageContent::MlsPublicMessage(pub_msg) => Some((&pub_msg.content.content).into()),
265 MlsMessageContent::MlsPrivateMessage(priv_msg) => Some(priv_msg.content_type),
266 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
267 MlsMessageContent::MlsSplitCommitMessage(message) => {
268 message.split_commit_message.content.content_type()
269 }
270 _ => None,
271 }
272 }
273
274 pub fn proposal_type(&self) -> Option<ProposalType> {
275 match self {
276 MlsMessageContent::MlsPublicMessage(pub_msg) => {
277 if let ContentTypeInner::Proposal { proposal } = &pub_msg.content.content {
278 Some(proposal.into())
279 } else {
280 None
281 }
282 }
283 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
284 MlsMessageContent::MlsSplitCommitMessage(message) => {
285 message.split_commit_message.content.proposal_type()
286 }
287 _ => None,
288 }
289 }
290}
291
292#[allow(clippy::from_over_into)]
293impl Into<WireFormat> for &MlsMessageContent {
294 fn into(self) -> WireFormat {
295 match self {
296 MlsMessageContent::MlsPublicMessage(_) => {
297 WireFormat::new_unchecked(WireFormat::MLS_PUBLIC_MESSAGE)
298 }
299 MlsMessageContent::MlsPrivateMessage(_) => {
300 WireFormat::new_unchecked(WireFormat::MLS_PRIVATE_MESSAGE)
301 }
302 MlsMessageContent::Welcome(_) => WireFormat::new_unchecked(WireFormat::MLS_WELCOME),
303 MlsMessageContent::GroupInfo(_) => {
304 WireFormat::new_unchecked(WireFormat::MLS_GROUP_INFO)
305 }
306 MlsMessageContent::KeyPackage(_) => {
307 WireFormat::new_unchecked(WireFormat::MLS_KEY_PACKAGE)
308 }
309 #[cfg(feature = "draft-ietf-mls-extensions")]
310 MlsMessageContent::MlsTargetedMessage(_) => {
311 WireFormat::new_unchecked(WireFormat::MLS_TARGETED_MESSAGE)
312 }
313 #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
314 MlsMessageContent::MlsSemiPrivateMessage(_) => {
315 WireFormat::new_unchecked(WireFormat::MLS_SEMIPRIVATE_MESSAGE)
316 }
317 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
318 MlsMessageContent::MlsSplitCommitMessage(_) => {
319 WireFormat::new_unchecked(WireFormat::MLS_SPLIT_COMMIT)
320 }
321 #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
322 MlsMessageContent::MlsMessageWithoutAad(_) => {
323 WireFormat::new_unchecked(WireFormat::MLS_MESSAGE_WITHOUT_AAD)
324 }
325 }
326 }
327}
328
329#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
330#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
331pub struct AuthenticatedContent {
332 pub wire_format: WireFormat,
333 pub content: FramedContent,
334 pub auth: FramedContentAuthData,
335}
336
337#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
338#[cfg_attr(feature = "serde", derive(serde::Serialize))]
339pub struct AuthenticatedContentRef<'a> {
340 pub wire_format: &'a WireFormat,
341 pub content: &'a FramedContent,
342 pub auth: &'a FramedContentAuthData,
343}
344
345impl AuthenticatedContent {
346 pub fn confirmed_transcript_hash_input(&self) -> ConfirmedTranscriptHashInput {
347 ConfirmedTranscriptHashInput {
348 wire_format: &self.wire_format,
349 content: &self.content,
350 signature: &self.auth.signature,
351 }
352 }
353
354 pub fn as_ref(&self) -> AuthenticatedContentRef {
355 AuthenticatedContentRef {
356 wire_format: &self.wire_format,
357 content: &self.content,
358 auth: &self.auth,
359 }
360 }
361}
362
363impl tls_codec::Deserialize for AuthenticatedContent {
364 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
365 where
366 Self: Sized,
367 {
368 let wire_format = WireFormat::tls_deserialize(bytes)?;
369 let content = FramedContent::tls_deserialize(bytes)?;
370 let auth = FramedContentAuthData::tls_deserialize_with_content_type(
371 bytes,
372 (&content.content).into(),
373 )?;
374 Ok(Self {
375 wire_format,
376 content,
377 auth,
378 })
379 }
380}