1use tls_codec::Deserialize;
2
3use crate::{
4 MlsSpecError, MlsSpecResult, SensitiveBytes,
5 crypto::Mac,
6 defs::{Epoch, ProposalType, ProtocolVersion, WireFormat},
7 group::{GroupId, group_info::GroupInfo, welcome::Welcome},
8 key_package::KeyPackage,
9 key_schedule::{ConfirmedTranscriptHashInput, GroupContext},
10 messages::{ContentType, ContentTypeInner, PrivateMessage, PublicMessage, Sender, SenderType},
11};
12
/// Framed MLS content: the group and epoch it is associated with, its sender,
/// caller-supplied authenticated data, and the content body.
///
/// The TLS encoding comes from the `tls_codec` derives, which encode fields in
/// declaration order — do not reorder the fields below.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FramedContent {
    /// Identifier of the group this content belongs to.
    pub group_id: GroupId,
    /// Group epoch this content is associated with.
    pub epoch: Epoch,
    /// Originator of the content; its kind decides whether `to_tbs` attaches a
    /// group context to the to-be-signed structure.
    pub sender: Sender,
    /// Authenticated data bytes supplied alongside the content.
    pub authenticated_data: SensitiveBytes,
    /// Content body; its content type decides whether a confirmation tag is
    /// decoded alongside it (see `FramedContentAuthData`).
    pub content: ContentTypeInner,
}
30
31impl FramedContent {
32 pub fn to_tbs<'a>(
33 &'a self,
34 wire_format: &'a WireFormat,
35 ctx: &'a GroupContext,
36 ) -> MlsSpecResult<FramedContentTBS<'a>> {
37 let sender_type_raw: SenderType = (&self.sender).into();
38 let sender_type =
39 FramedContentTBSSenderType::from_sender_type_with_ctx(sender_type_raw, Some(ctx))?;
40
41 Ok(FramedContentTBS {
42 version: &ctx.version,
43 wire_format,
44 content: self,
45 sender_type,
46 })
47 }
48}
49
/// Sender-type selector stored in [`FramedContentTBS`].
///
/// `Member` and `NewMemberCommit` carry a borrowed group context; `External`
/// and `NewMemberProposal` carry nothing. Discriminants are taken from
/// [`SenderType`].
///
/// Note: the hand-written `Serialize`/`Size` impls of `FramedContentTBS` emit
/// only the inner context of the context-bearing variants, never this enum's
/// own discriminant.
#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[repr(u8)]
pub enum FramedContentTBSSenderType<'a> {
    /// Group member sender; carries the group context.
    #[tls_codec(discriminant = "SenderType::Member")]
    Member(FramedContentTBSSenderTypeContext<'a>),
    /// External sender; no extra data.
    #[tls_codec(discriminant = "SenderType::External")]
    External,
    /// New member sending a commit; carries the group context.
    #[tls_codec(discriminant = "SenderType::NewMemberCommit")]
    NewMemberCommit(FramedContentTBSSenderTypeContext<'a>),
    /// New member sending a proposal; no extra data.
    #[tls_codec(discriminant = "SenderType::NewMemberProposal")]
    NewMemberProposal,
}
63
64impl<'a> FramedContentTBSSenderType<'a> {
65 pub fn from_sender_type_with_ctx(
66 sender_type: SenderType,
67 mut ctx: Option<&'a GroupContext>,
68 ) -> MlsSpecResult<Self> {
69 Ok(match sender_type {
70 SenderType::NewMemberCommit => {
71 let Some(context) = ctx.take() else {
72 return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
73 };
74 Self::NewMemberCommit(FramedContentTBSSenderTypeContext { context })
75 }
76 SenderType::Member => {
77 let Some(context) = ctx.take() else {
78 return Err(MlsSpecError::FramedContentTBSMissingGroupContext);
79 };
80 Self::Member(FramedContentTBSSenderTypeContext { context })
81 }
82 SenderType::External => Self::External,
83 SenderType::NewMemberProposal => Self::NewMemberProposal,
84 _ => return Err(MlsSpecError::ReservedValueUsage),
85 })
86 }
87}
88
/// Borrowed [`GroupContext`] attached to context-bearing sender types.
///
/// The derived `tls_codec` impls encode it as its single field.
#[derive(Debug, Copy, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct FramedContentTBSSenderTypeContext<'a> {
    pub context: &'a GroupContext,
}
94
/// Borrowed "to be signed" view of a [`FramedContent`], built by
/// [`FramedContent::to_tbs`].
///
/// Its TLS `Size`/`Serialize` impls are written by hand rather than derived:
/// version, wire format and content are written in order, followed by the group
/// context only when the sender is a member or new-member-commit sender.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct FramedContentTBS<'a> {
    pub version: &'a ProtocolVersion,
    pub wire_format: &'a WireFormat,
    pub content: &'a FramedContent,
    /// Selector holding the group context for context-bearing senders.
    pub sender_type: FramedContentTBSSenderType<'a>,
}
102
103impl tls_codec::Size for FramedContentTBS<'_> {
105 fn tls_serialized_len(&self) -> usize {
106 let mut len = self.version.tls_serialized_len()
107 + self.wire_format.tls_serialized_len()
108 + self.content.tls_serialized_len();
109 if matches!(
110 self.content.sender,
111 Sender::Member(_) | Sender::NewMemberCommit
112 ) {
113 match &self.sender_type {
114 FramedContentTBSSenderType::NewMemberCommit(context)
115 | FramedContentTBSSenderType::Member(context) => {
116 len += context.tls_serialized_len();
117 }
118 _ => {}
119 }
120 }
121 len
122 }
123}
124
125impl tls_codec::Serialize for FramedContentTBS<'_> {
126 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
127 let mut ret = self.version.tls_serialize(writer)?;
128
129 ret += self.wire_format.tls_serialize(writer)?;
130 ret += self.content.tls_serialize(writer)?;
131 if matches!(
132 self.content.sender,
133 Sender::Member(_) | Sender::NewMemberCommit
134 ) {
135 match &self.sender_type {
136 FramedContentTBSSenderType::NewMemberCommit(context)
137 | FramedContentTBSSenderType::Member(context) => {
138 ret += context.tls_serialize(writer)?;
139 }
140 _ => {}
141 }
142 }
143
144 Ok(ret)
145 }
146}
147
148impl<'a> tls_codec::Size for &'a FramedContentTBS<'a> {
149 fn tls_serialized_len(&self) -> usize {
150 (*self).tls_serialized_len()
151 }
152}
153
154impl<'a> tls_codec::Serialize for &'a FramedContentTBS<'a> {
155 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
156 (*self).tls_serialize(writer)
157 }
158}
159
/// Authentication data attached to framed content: a signature and an optional
/// confirmation tag.
///
/// The TLS encoding is hand-written: the tag, when present, is written with no
/// presence marker. Decoding therefore goes through
/// [`FramedContentAuthData::tls_deserialize_with_content_type`], which reads a
/// tag only when the content type is `Commit`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FramedContentAuthData {
    pub signature: SensitiveBytes,
    /// Present for commit content when decoded from the wire.
    pub confirmation_tag: Option<Mac>,
}
166
167impl FramedContentAuthData {
168 pub fn without_confirmation_tag(&self) -> Self {
169 Self {
170 signature: self.signature.clone(),
171 confirmation_tag: None,
172 }
173 }
174}
175
176impl tls_codec::Size for FramedContentAuthData {
177 fn tls_serialized_len(&self) -> usize {
178 self.signature.tls_serialized_len()
179 + self
180 .confirmation_tag
181 .as_ref()
182 .map_or(0, SensitiveBytes::tls_serialized_len)
183 }
184}
185
186impl tls_codec::Size for &FramedContentAuthData {
187 fn tls_serialized_len(&self) -> usize {
188 (*self).tls_serialized_len()
189 }
190}
191
192impl tls_codec::Serialize for FramedContentAuthData {
193 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
194 let mut written = self.signature.tls_serialize(writer)?;
195 if let Some(confirmation_tag) = &self.confirmation_tag {
196 written += confirmation_tag.tls_serialize(writer)?;
197 }
198 Ok(written)
199 }
200}
201
202impl tls_codec::Serialize for &FramedContentAuthData {
203 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
204 (*self).tls_serialize(writer)
205 }
206}
207
208impl FramedContentAuthData {
209 pub fn tls_deserialize_with_content_type<R: std::io::Read>(
210 bytes: &mut R,
211 content_type: ContentType,
212 ) -> Result<Self, tls_codec::Error> {
213 let signature = SensitiveBytes::tls_deserialize(bytes)?;
214 let confirmation_tag = (content_type == ContentType::Commit)
215 .then(|| Mac::tls_deserialize(bytes))
216 .transpose()?;
217
218 Ok(Self {
219 signature,
220 confirmation_tag,
221 })
222 }
223}
224
/// Body of an MLS message, tagged on the wire by a `u16` [`WireFormat`]
/// discriminant.
///
/// Variants for IETF drafts are compiled in only when the matching cargo
/// feature is enabled.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u16)]
pub enum MlsMessageContent {
    #[tls_codec(discriminant = "WireFormat::MLS_PUBLIC_MESSAGE")]
    MlsPublicMessage(PublicMessage),
    #[tls_codec(discriminant = "WireFormat::MLS_PRIVATE_MESSAGE")]
    MlsPrivateMessage(PrivateMessage),
    #[tls_codec(discriminant = "WireFormat::MLS_WELCOME")]
    Welcome(Welcome),
    #[tls_codec(discriminant = "WireFormat::MLS_GROUP_INFO")]
    GroupInfo(GroupInfo),
    #[tls_codec(discriminant = "WireFormat::MLS_KEY_PACKAGE")]
    KeyPackage(KeyPackage),
    // Draft-only wire formats below; each requires its cargo feature.
    #[cfg(feature = "draft-ietf-mls-extensions")]
    #[tls_codec(discriminant = "WireFormat::MLS_TARGETED_MESSAGE")]
    MlsTargetedMessage(crate::drafts::mls_extensions::targeted_message::TargetedMessage),
    #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
    #[tls_codec(discriminant = "WireFormat::MLS_SEMIPRIVATE_MESSAGE")]
    MlsSemiPrivateMessage(crate::drafts::semiprivate_message::messages::SemiPrivateMessage),
    #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
    #[tls_codec(discriminant = "WireFormat::MLS_SPLIT_COMMIT")]
    MlsSplitCommitMessage(crate::drafts::split_commit::SplitCommitMessage),
    #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
    #[tls_codec(discriminant = "WireFormat::MLS_MESSAGE_WITHOUT_AAD")]
    MlsMessageWithoutAad(crate::drafts::additional_wire_formats::MessageWithoutAad),
}
260
261impl MlsMessageContent {
262 pub fn content_type(&self) -> Option<ContentType> {
263 match self {
264 MlsMessageContent::MlsPublicMessage(pub_msg) => Some((&pub_msg.content.content).into()),
265 MlsMessageContent::MlsPrivateMessage(priv_msg) => Some(priv_msg.content_type),
266 #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
267 MlsMessageContent::MlsSemiPrivateMessage(semi_priv_msg) => {
268 Some(semi_priv_msg.content_type)
269 }
270 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
271 MlsMessageContent::MlsSplitCommitMessage(message) => {
272 message.split_commit_message.content.content_type()
273 }
274 _ => None,
275 }
276 }
277
278 pub fn proposal_type(&self) -> Option<ProposalType> {
279 match self {
280 MlsMessageContent::MlsPublicMessage(pub_msg) => {
281 if let ContentTypeInner::Proposal { proposal } = &pub_msg.content.content {
282 Some(proposal.into())
283 } else {
284 None
285 }
286 }
287 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
288 MlsMessageContent::MlsSplitCommitMessage(message) => {
289 message.split_commit_message.content.proposal_type()
290 }
291 _ => None,
292 }
293 }
294
295 pub fn authenticated_data(&self) -> Option<&[u8]> {
296 match self {
297 MlsMessageContent::MlsPublicMessage(public_message) => {
298 Some(&public_message.content.authenticated_data)
299 }
300 MlsMessageContent::MlsPrivateMessage(private_message) => {
301 Some(&private_message.authenticated_data)
302 }
303 #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
304 MlsMessageContent::MlsSemiPrivateMessage(semi_priv_msg) => {
305 Some(&semi_priv_msg.authenticated_data)
306 }
307 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
308 MlsMessageContent::MlsSplitCommitMessage(message) => {
309 message.split_commit_message.content.authenticated_data()
310 }
311 _ => None,
312 }
313 }
314}
315
316#[allow(clippy::from_over_into)]
317impl Into<WireFormat> for &MlsMessageContent {
318 fn into(self) -> WireFormat {
319 match self {
320 MlsMessageContent::MlsPublicMessage(_) => {
321 WireFormat::new_unchecked(WireFormat::MLS_PUBLIC_MESSAGE)
322 }
323 MlsMessageContent::MlsPrivateMessage(_) => {
324 WireFormat::new_unchecked(WireFormat::MLS_PRIVATE_MESSAGE)
325 }
326 MlsMessageContent::Welcome(_) => WireFormat::new_unchecked(WireFormat::MLS_WELCOME),
327 MlsMessageContent::GroupInfo(_) => {
328 WireFormat::new_unchecked(WireFormat::MLS_GROUP_INFO)
329 }
330 MlsMessageContent::KeyPackage(_) => {
331 WireFormat::new_unchecked(WireFormat::MLS_KEY_PACKAGE)
332 }
333 #[cfg(feature = "draft-ietf-mls-extensions")]
334 MlsMessageContent::MlsTargetedMessage(_) => {
335 WireFormat::new_unchecked(WireFormat::MLS_TARGETED_MESSAGE)
336 }
337 #[cfg(feature = "draft-mahy-mls-semiprivatemessage")]
338 MlsMessageContent::MlsSemiPrivateMessage(_) => {
339 WireFormat::new_unchecked(WireFormat::MLS_SEMIPRIVATE_MESSAGE)
340 }
341 #[cfg(feature = "draft-mularczyk-mls-splitcommit")]
342 MlsMessageContent::MlsSplitCommitMessage(_) => {
343 WireFormat::new_unchecked(WireFormat::MLS_SPLIT_COMMIT)
344 }
345 #[cfg(feature = "draft-pham-mls-additional-wire-formats")]
346 MlsMessageContent::MlsMessageWithoutAad(_) => {
347 WireFormat::new_unchecked(WireFormat::MLS_MESSAGE_WITHOUT_AAD)
348 }
349 }
350 }
351}
352
/// Framed content together with its wire format and authentication data.
///
/// Serialization is derived (fields in declaration order). Deserialization is
/// hand-written in the `tls_codec::Deserialize` impl, because decoding `auth`
/// depends on the content type of `content`.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AuthenticatedContent {
    pub wire_format: WireFormat,
    pub content: FramedContent,
    pub auth: FramedContentAuthData,
}
360
/// Borrowed counterpart of [`AuthenticatedContent`], produced by
/// [`AuthenticatedContent::as_ref`]. Its fields mirror the owned struct in the
/// same order.
#[derive(Debug, Clone, PartialEq, Eq, tls_codec::TlsSerialize, tls_codec::TlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct AuthenticatedContentRef<'a> {
    pub wire_format: &'a WireFormat,
    pub content: &'a FramedContent,
    pub auth: &'a FramedContentAuthData,
}
368
369impl AuthenticatedContent {
370 pub fn confirmed_transcript_hash_input(&self) -> ConfirmedTranscriptHashInput {
371 ConfirmedTranscriptHashInput {
372 wire_format: &self.wire_format,
373 content: &self.content,
374 signature: &self.auth.signature,
375 }
376 }
377
378 pub fn as_ref(&self) -> AuthenticatedContentRef {
379 AuthenticatedContentRef {
380 wire_format: &self.wire_format,
381 content: &self.content,
382 auth: &self.auth,
383 }
384 }
385}
386
387impl tls_codec::Deserialize for AuthenticatedContent {
388 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
389 where
390 Self: Sized,
391 {
392 let wire_format = WireFormat::tls_deserialize(bytes)?;
393 let content = FramedContent::tls_deserialize(bytes)?;
394 let auth = FramedContentAuthData::tls_deserialize_with_content_type(
395 bytes,
396 (&content.content).into(),
397 )?;
398 Ok(Self {
399 wire_format,
400 content,
401 auth,
402 })
403 }
404}