1use core::ops::Deref;
6
7use crate::{client::MlsError, tree_kem::node::LeafIndex, KeyPackage, KeyPackageRef};
8
9use super::{Commit, FramedContentAuthData, GroupInfo, MembershipTag, Welcome};
10
11use crate::group::proposal::{Proposal, ProposalOrRef};
12
13#[cfg(feature = "by_ref_proposal")]
14use crate::mls_rules::ProposalRef;
15
16use alloc::vec::Vec;
17use core::fmt::{self, Debug};
18use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
19use mls_rs_core::{
20 crypto::{CipherSuite, CipherSuiteProvider},
21 protocol_version::ProtocolVersion,
22};
23use zeroize::ZeroizeOnDrop;
24
25#[cfg(feature = "private_message")]
26use alloc::boxed::Box;
27
28#[cfg(feature = "custom_proposal")]
29use crate::group::proposal::CustomProposal;
30
/// The kind of content carried by a framed MLS message.
///
/// The discriminants mirror the `ContentType` values of RFC 9420 and are
/// consumed by the codec derives (with `#[repr(u8)]`); do not renumber them.
#[derive(Copy, Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
pub enum ContentType {
    /// Application data. Only compiled with the `private_message` feature.
    #[cfg(feature = "private_message")]
    Application = 1u8,
    /// A standalone proposal. Only compiled with the `by_ref_proposal` feature.
    #[cfg(feature = "by_ref_proposal")]
    Proposal = 2u8,
    /// A commit. Always available.
    Commit = 3u8,
}
41
42impl From<&Content> for ContentType {
43 fn from(content: &Content) -> Self {
44 match content {
45 #[cfg(feature = "private_message")]
46 Content::Application(_) => ContentType::Application,
47 #[cfg(feature = "by_ref_proposal")]
48 Content::Proposal(_) => ContentType::Proposal,
49 Content::Commit(_) => ContentType::Commit,
50 }
51 }
52}
53
/// Identifies who sent a framed MLS message.
///
/// The discriminants mirror the `SenderType` values of RFC 9420 and are
/// consumed by the codec derives (with `#[repr(u8)]`); do not renumber them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
#[non_exhaustive]
pub enum Sender {
    /// A current group member, identified by its leaf index
    /// (see the `From<LeafIndex>` impl).
    Member(u32) = 1u8,
    /// A sender outside the group. NOTE(review): per RFC 9420 the `u32` is an
    /// index into the group's external senders list — confirm against callers.
    #[cfg(feature = "by_ref_proposal")]
    External(u32) = 2u8,
    /// A non-member proposing to join the group (RFC 9420 `new_member_proposal`).
    #[cfg(feature = "by_ref_proposal")]
    NewMemberProposal = 3u8,
    /// A non-member joining via an external commit (RFC 9420 `new_member_commit`).
    NewMemberCommit = 4u8,
}
75
76impl From<LeafIndex> for Sender {
77 fn from(leaf_index: LeafIndex) -> Self {
78 Sender::Member(*leaf_index)
79 }
80}
81
82impl From<u32> for Sender {
83 fn from(leaf_index: u32) -> Self {
84 Sender::Member(leaf_index)
85 }
86}
87
/// Opaque application payload carried by an application message.
///
/// The wrapped bytes are zeroized when the value is dropped (`ZeroizeOnDrop`).
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, ZeroizeOnDrop)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ApplicationData(
    // Encoded as a variable-length byte vector on the wire and via the
    // crate's vec helper for serde.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
96
97impl Debug for ApplicationData {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 mls_rs_core::debug::pretty_bytes(&self.0)
100 .named("ApplicationData")
101 .fmt(f)
102 }
103}
104
105impl From<Vec<u8>> for ApplicationData {
106 fn from(data: Vec<u8>) -> Self {
107 Self(data)
108 }
109}
110
111impl Deref for ApplicationData {
112 type Target = [u8];
113
114 fn deref(&self) -> &Self::Target {
115 &self.0
116 }
117}
118
119impl ApplicationData {
120 pub fn as_bytes(&self) -> &[u8] {
122 &self.0
123 }
124}
125
/// The body of a framed message.
///
/// Discriminants match those of [`ContentType`] (see `From<&Content>`) and are
/// consumed by the codec derives; do not renumber them.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub(crate) enum Content {
    /// Application payload (`private_message` feature only).
    #[cfg(feature = "private_message")]
    Application(ApplicationData) = 1u8,
    // Proposal and Commit are boxed — presumably to keep `Content` small; confirm.
    /// A standalone proposal (`by_ref_proposal` feature only).
    #[cfg(feature = "by_ref_proposal")]
    Proposal(alloc::boxed::Box<Proposal>) = 2u8,
    /// A commit.
    Commit(alloc::boxed::Box<Commit>) = 3u8,
}
137
138impl Content {
139 pub fn content_type(&self) -> ContentType {
140 self.into()
141 }
142}
143
/// A message framed without encryption.
///
/// The codec impls for this type are hand-written: on decode, `membership_tag`
/// is read only when the sender is `Sender::Member`. On encode, the tag is
/// written only if it is `Some`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub(crate) struct PublicMessage {
    /// Group id, epoch, sender, authenticated data and the message body.
    pub content: FramedContent,
    /// Authentication data for `content`.
    pub auth: FramedContentAuthData,
    /// Present for member senders; absent otherwise (see `MlsDecode` impl).
    pub membership_tag: Option<MembershipTag>,
}
151
152impl MlsSize for PublicMessage {
153 fn mls_encoded_len(&self) -> usize {
154 self.content.mls_encoded_len()
155 + self.auth.mls_encoded_len()
156 + self
157 .membership_tag
158 .as_ref()
159 .map_or(0, |tag| tag.mls_encoded_len())
160 }
161}
162
163impl MlsEncode for PublicMessage {
164 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
165 self.content.mls_encode(writer)?;
166 self.auth.mls_encode(writer)?;
167
168 self.membership_tag
169 .as_ref()
170 .map_or(Ok(()), |tag| tag.mls_encode(writer))
171 }
172}
173
174impl MlsDecode for PublicMessage {
175 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
176 let content = FramedContent::mls_decode(reader)?;
177 let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
178
179 let membership_tag = match content.sender {
180 Sender::Member(_) => Some(MembershipTag::mls_decode(reader)?),
181 _ => None,
182 };
183
184 Ok(Self {
185 content,
186 auth,
187 membership_tag,
188 })
189 }
190}
191
/// The plaintext carried inside a `PrivateMessage`.
///
/// The content type is not part of this encoding: `mls_encode` writes only the
/// variant's body. `mls_decode` takes the type as an argument.
#[cfg(feature = "private_message")]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct PrivateMessageContent {
    /// The decrypted message body.
    pub content: Content,
    /// Authentication data for `content`.
    pub auth: FramedContentAuthData,
}
198
#[cfg(feature = "private_message")]
impl MlsSize for PrivateMessageContent {
    /// Encoded length of the content body (no content-type tag) plus the
    /// encoded auth data.
    fn mls_encoded_len(&self) -> usize {
        let body_len = match self.content {
            Content::Application(ref data) => data.mls_encoded_len(),
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(ref proposal) => proposal.mls_encoded_len(),
            Content::Commit(ref commit) => commit.mls_encoded_len(),
        };

        body_len + self.auth.mls_encoded_len()
    }
}
212
#[cfg(feature = "private_message")]
impl MlsEncode for PrivateMessageContent {
    /// Writes the content body without its content-type tag, followed by the
    /// auth data.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        let body_result = match &self.content {
            Content::Application(data) => data.mls_encode(writer),
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(proposal) => proposal.mls_encode(writer),
            Content::Commit(commit) => commit.mls_encode(writer),
        };
        body_result?;

        self.auth.mls_encode(writer)
    }
}
228
#[cfg(feature = "private_message")]
impl PrivateMessageContent {
    /// Decodes private-message content whose type is supplied out of band as
    /// `content_type` rather than read from `reader`.
    ///
    /// Every byte left in `reader` after the auth data is treated as padding
    /// and must be zero. The padding is inspected but not consumed.
    ///
    /// # Errors
    /// Propagates codec errors from the body and auth data. Returns
    /// `mls_rs_codec::Error::Custom(5)` when a padding byte is non-zero.
    pub(crate) fn mls_decode(
        reader: &mut &[u8],
        content_type: ContentType,
    ) -> Result<Self, mls_rs_codec::Error> {
        let content = match content_type {
            ContentType::Application => {
                let data = ApplicationData::mls_decode(reader)?;
                Content::Application(data)
            }
            #[cfg(feature = "by_ref_proposal")]
            ContentType::Proposal => {
                let proposal = Proposal::mls_decode(reader)?;
                Content::Proposal(Box::new(proposal))
            }
            ContentType::Commit => {
                let commit = Commit::mls_decode(reader)?;
                Content::Commit(Box::new(commit))
            }
        };

        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;

        let padding_is_zero = reader.iter().all(|&byte| byte == 0);
        if !padding_is_zero {
            // NOTE(review): `5` appears to be this crate's numeric code for
            // "non-zero padding" (a `Custom` code, not a message string) — confirm.
            return Err(mls_rs_codec::Error::Custom(5));
        }

        Ok(Self { content, auth })
    }
}
259
/// Additional authenticated data for the encrypted content of a
/// `PrivateMessage`.
///
/// It holds copies of the message's plaintext header fields (see the
/// `From<&PrivateMessage>` impl). NOTE(review): presumably fed to the AEAD as
/// associated data when the content is sealed/opened — confirm at call sites.
#[cfg(feature = "private_message")]
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub struct PrivateContentAAD {
    /// Id of the group the message belongs to.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub group_id: Vec<u8>,
    /// Epoch copied from the message header.
    pub epoch: u64,
    /// Content type copied from the message header.
    pub content_type: ContentType,
    /// Application-supplied authenticated data copied from the message header.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
}
270
#[cfg(feature = "private_message")]
impl Debug for PrivateContentAAD {
    /// Formats the AAD, pretty-printing the group id and the authenticated
    /// data instead of dumping raw byte vectors.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};

        let mut out = f.debug_struct("PrivateContentAAD");
        out.field("group_id", &pretty_group_id(&self.group_id));
        out.field("epoch", &self.epoch);
        out.field("content_type", &self.content_type);
        out.field(
            "authenticated_data",
            &pretty_bytes(&self.authenticated_data),
        );
        out.finish()
    }
}
288
/// An encrypted MLS protocol message.
///
/// `group_id`, `epoch`, `content_type` and `authenticated_data` are readable
/// without decryption. The sender and the content are carried only in
/// encrypted form (`encrypted_sender_data`, `ciphertext`).
#[cfg(feature = "private_message")]
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct PrivateMessage {
    /// Id of the group the message belongs to.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub group_id: Vec<u8>,
    /// Epoch the message was sent in.
    pub epoch: u64,
    /// Type of the encrypted content.
    pub content_type: ContentType,
    /// Application-supplied data sent alongside the ciphertext.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub authenticated_data: Vec<u8>,
    /// Encrypted sender information.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub encrypted_sender_data: Vec<u8>,
    /// Encrypted content (decodes to `PrivateMessageContent` once decrypted).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub ciphertext: Vec<u8>,
}
304
#[cfg(feature = "private_message")]
impl Debug for PrivateMessage {
    /// Formats the header and the opaque byte fields with `mls_rs_core`'s
    /// pretty-printers instead of raw byte vectors.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use mls_rs_core::debug::{pretty_bytes, pretty_group_id};

        let mut out = f.debug_struct("PrivateMessage");
        out.field("group_id", &pretty_group_id(&self.group_id));
        out.field("epoch", &self.epoch);
        out.field("content_type", &self.content_type);
        out.field(
            "authenticated_data",
            &pretty_bytes(&self.authenticated_data),
        );
        out.field(
            "encrypted_sender_data",
            &pretty_bytes(&self.encrypted_sender_data),
        );
        out.field("ciphertext", &pretty_bytes(&self.ciphertext));
        out.finish()
    }
}
330
#[cfg(feature = "private_message")]
impl From<&PrivateMessage> for PrivateContentAAD {
    /// Builds the AAD from a copy of the message's plaintext header fields.
    fn from(message: &PrivateMessage) -> Self {
        let PrivateMessage {
            group_id,
            epoch,
            content_type,
            authenticated_data,
            ..
        } = message;

        Self {
            group_id: group_id.clone(),
            epoch: *epoch,
            content_type: *content_type,
            authenticated_data: authenticated_data.clone(),
        }
    }
}
342
/// A borrowed summary of an [`MlsMessage`], as returned by
/// [`MlsMessage::description`].
#[derive(Clone, Debug, PartialEq)]
pub enum MlsMessageDescription<'a> {
    /// A welcome message.
    Welcome {
        /// The `new_member` reference of every encrypted group secret in the
        /// welcome, in order.
        key_package_refs: Vec<&'a KeyPackageRef>,
        /// Cipher suite of the welcome.
        cipher_suite: CipherSuite,
    },
    /// An encrypted protocol message; only its plaintext header is described.
    PrivateProtocolMessage {
        /// Id of the group the message belongs to.
        group_id: &'a [u8],
        /// Epoch from the message header.
        epoch_id: u64,
        /// Type of the encrypted content.
        content_type: ContentType,
    },
    /// An unencrypted protocol message.
    PublicProtocolMessage {
        /// Id of the group the message belongs to.
        group_id: &'a [u8],
        /// Epoch from the framed content.
        epoch_id: u64,
        /// Type of the framed content.
        content_type: ContentType,
        /// Sender of the framed content.
        sender: Sender,
        /// Application-supplied authenticated data.
        authenticated_data: &'a [u8],
    },
    /// A group info message.
    GroupInfo,
    /// A key package message.
    KeyPackage,
}
364
365impl MlsMessage {
366 pub fn description(&self) -> MlsMessageDescription<'_> {
367 match &self.payload {
368 MlsMessagePayload::Welcome(w) => MlsMessageDescription::Welcome {
369 key_package_refs: w.secrets.iter().map(|s| &s.new_member).collect(),
370 cipher_suite: w.cipher_suite,
371 },
372 MlsMessagePayload::Plain(p) => MlsMessageDescription::PublicProtocolMessage {
373 group_id: &p.content.group_id,
374 epoch_id: p.content.epoch,
375 content_type: p.content.content_type(),
376 sender: p.content.sender,
377 authenticated_data: &p.content.authenticated_data,
378 },
379 #[cfg(feature = "private_message")]
380 MlsMessagePayload::Cipher(c) => MlsMessageDescription::PrivateProtocolMessage {
381 group_id: &c.group_id,
382 epoch_id: c.epoch,
383 content_type: c.content_type,
384 },
385 MlsMessagePayload::GroupInfo(_) => MlsMessageDescription::GroupInfo,
386 MlsMessagePayload::KeyPackage(_) => MlsMessageDescription::KeyPackage,
387 }
388 }
389}
390
/// A top-level MLS message: a protocol version plus one payload (public or
/// private protocol message, welcome, group info or key package).
///
/// Use `MlsMessage::from_bytes` / `MlsMessage::to_bytes` to move between this
/// type and its encoded form.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct MlsMessage {
    pub(crate) version: ProtocolVersion,
    pub(crate) payload: MlsMessagePayload,
}
398
399#[allow(dead_code)]
400impl MlsMessage {
401 pub(crate) fn new(version: ProtocolVersion, payload: MlsMessagePayload) -> MlsMessage {
402 Self { version, payload }
403 }
404
405 #[inline(always)]
406 pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
407 match self.payload {
408 MlsMessagePayload::Plain(plaintext) => Some(plaintext),
409 _ => None,
410 }
411 }
412
413 #[cfg(feature = "private_message")]
414 #[inline(always)]
415 pub(crate) fn into_ciphertext(self) -> Option<PrivateMessage> {
416 match self.payload {
417 MlsMessagePayload::Cipher(ciphertext) => Some(ciphertext),
418 _ => None,
419 }
420 }
421
422 #[inline(always)]
423 pub(crate) fn into_welcome(self) -> Option<Welcome> {
424 match self.payload {
425 MlsMessagePayload::Welcome(welcome) => Some(welcome),
426 _ => None,
427 }
428 }
429
430 #[inline(always)]
431 pub fn into_group_info(self) -> Option<GroupInfo> {
432 match self.payload {
433 MlsMessagePayload::GroupInfo(info) => Some(info),
434 _ => None,
435 }
436 }
437
438 #[inline(always)]
439 pub fn as_group_info(&self) -> Option<&GroupInfo> {
440 match &self.payload {
441 MlsMessagePayload::GroupInfo(info) => Some(info),
442 _ => None,
443 }
444 }
445
446 #[inline(always)]
447 pub fn into_key_package(self) -> Option<KeyPackage> {
448 match self.payload {
449 MlsMessagePayload::KeyPackage(kp) => Some(kp),
450 _ => None,
451 }
452 }
453
454 pub fn as_key_package(&self) -> Option<&KeyPackage> {
455 match &self.payload {
456 MlsMessagePayload::KeyPackage(kp) => Some(kp),
457 _ => None,
458 }
459 }
460
461 pub fn wire_format(&self) -> WireFormat {
463 match self.payload {
464 MlsMessagePayload::Plain(_) => WireFormat::PublicMessage,
465 #[cfg(feature = "private_message")]
466 MlsMessagePayload::Cipher(_) => WireFormat::PrivateMessage,
467 MlsMessagePayload::Welcome(_) => WireFormat::Welcome,
468 MlsMessagePayload::GroupInfo(_) => WireFormat::GroupInfo,
469 MlsMessagePayload::KeyPackage(_) => WireFormat::KeyPackage,
470 }
471 }
472
473 pub fn epoch(&self) -> Option<u64> {
478 match &self.payload {
479 MlsMessagePayload::Plain(p) => Some(p.content.epoch),
480 #[cfg(feature = "private_message")]
481 MlsMessagePayload::Cipher(c) => Some(c.epoch),
482 MlsMessagePayload::GroupInfo(gi) => Some(gi.group_context.epoch),
483 _ => None,
484 }
485 }
486
487 pub fn cipher_suite(&self) -> Option<CipherSuite> {
488 match &self.payload {
489 MlsMessagePayload::GroupInfo(i) => Some(i.group_context.cipher_suite),
490 MlsMessagePayload::Welcome(w) => Some(w.cipher_suite),
491 MlsMessagePayload::KeyPackage(k) => Some(k.cipher_suite),
492 _ => None,
493 }
494 }
495
496 pub fn group_id(&self) -> Option<&[u8]> {
497 match &self.payload {
498 MlsMessagePayload::Plain(p) => Some(&p.content.group_id),
499 #[cfg(feature = "private_message")]
500 MlsMessagePayload::Cipher(p) => Some(&p.group_id),
501 MlsMessagePayload::GroupInfo(p) => Some(&p.group_context.group_id),
502 MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => None,
503 }
504 }
505
506 #[inline(never)]
508 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
509 Self::mls_decode(&mut &*bytes).map_err(Into::into)
510 }
511
512 pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
514 self.mls_encode_to_vec().map_err(Into::into)
515 }
516
517 #[cfg(feature = "custom_proposal")]
520 pub fn custom_proposals_by_value(&self) -> Vec<&CustomProposal> {
521 match &self.payload {
522 MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
523 Content::Commit(commit) => Self::find_custom_proposals(commit),
524 _ => Vec::new(),
525 },
526 _ => Vec::new(),
527 }
528 }
529
530 #[allow(unreachable_patterns)]
534 pub fn proposals_by_value(&self) -> Vec<&Proposal> {
535 match &self.payload {
536 MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
537 Content::Commit(commit) => Self::find_all_proposals(commit),
538 _ => Vec::new(),
539 },
540 _ => Vec::new(),
541 }
542 }
543
544 pub fn welcome_key_package_references(&self) -> Vec<&KeyPackageRef> {
547 let MlsMessagePayload::Welcome(welcome) = &self.payload else {
548 return Vec::new();
549 };
550
551 welcome.secrets.iter().map(|s| &s.new_member).collect()
552 }
553
554 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
556 pub async fn key_package_reference<C: CipherSuiteProvider>(
557 &self,
558 cipher_suite: &C,
559 ) -> Result<Option<KeyPackageRef>, MlsError> {
560 let MlsMessagePayload::KeyPackage(kp) = &self.payload else {
561 return Ok(None);
562 };
563
564 kp.to_reference(cipher_suite).await.map(Some)
565 }
566
567 #[cfg(feature = "by_ref_proposal")]
570 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
571 pub async fn into_proposal_reference<C: CipherSuiteProvider>(
572 self,
573 cipher_suite: &C,
574 ) -> Result<Option<Vec<u8>>, MlsError> {
575 let MlsMessagePayload::Plain(public_message) = self.payload else {
576 return Ok(None);
577 };
578
579 ProposalRef::from_content(cipher_suite, &public_message.into())
580 .await
581 .map(|r| Some(r.to_vec()))
582 }
583}
584
585impl MlsMessage {
586 #[cfg(feature = "custom_proposal")]
587 fn find_custom_proposals(commit: &Commit) -> Vec<&CustomProposal> {
588 commit
589 .proposals
590 .iter()
591 .filter_map(|p| match p {
592 ProposalOrRef::Proposal(p) => match p.as_ref() {
593 crate::group::Proposal::Custom(p) => Some(p),
594 _ => None,
595 },
596 _ => None,
597 })
598 .collect()
599 }
600
601 #[allow(unreachable_patterns)]
602 fn find_all_proposals(commit: &Commit) -> Vec<&Proposal> {
603 commit
604 .proposals
605 .iter()
606 .filter_map(|p| match p {
607 ProposalOrRef::Proposal(p) => Some(p.as_ref()),
608 _ => None,
609 })
610 .collect()
611 }
612}
613
/// The payload of an `MlsMessage`.
///
/// Discriminants are the same values as [`WireFormat`] (see
/// `MlsMessage::wire_format`) and are consumed by the codec derives; do not
/// renumber them.
// NOTE(review): variants differ greatly in size; boxing them appears to be
// deliberately avoided — confirm before changing.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u16)]
pub(crate) enum MlsMessagePayload {
    /// An unencrypted protocol message.
    Plain(PublicMessage) = 1u16,
    /// An encrypted protocol message (`private_message` feature only).
    #[cfg(feature = "private_message")]
    Cipher(PrivateMessage) = 2u16,
    /// A welcome for new members.
    Welcome(Welcome) = 3u16,
    /// Group info.
    GroupInfo(GroupInfo) = 4u16,
    /// A key package.
    KeyPackage(KeyPackage) = 5u16,
}
626
627impl From<PublicMessage> for MlsMessagePayload {
628 fn from(m: PublicMessage) -> Self {
629 Self::Plain(m)
630 }
631}
632
/// The wire format of an [`MlsMessage`], i.e. which kind of payload it carries.
///
/// The discriminants mirror the `WireFormat` values of RFC 9420 and are
/// consumed by the codec derives; do not renumber them. Unlike the payload
/// enum, `PrivateMessage` exists here regardless of the `private_message`
/// feature.
#[derive(
    Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u16)]
#[non_exhaustive]
pub enum WireFormat {
    /// An unencrypted protocol message.
    PublicMessage = 1u16,
    /// An encrypted protocol message.
    PrivateMessage = 2u16,
    /// A welcome message.
    Welcome = 3u16,
    /// A group info message.
    GroupInfo = 4u16,
    /// A key package message.
    KeyPackage = 5u16,
}
647
/// The framed content of a protocol message: its header fields plus the body.
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct FramedContent {
    /// Id of the group the message belongs to.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub group_id: Vec<u8>,
    /// Epoch the message was sent in.
    pub epoch: u64,
    /// Who sent the message.
    pub sender: Sender,
    /// Application-supplied authenticated data.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    pub authenticated_data: Vec<u8>,
    /// The message body.
    pub content: Content,
}
662
663impl Debug for FramedContent {
664 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665 f.debug_struct("FramedContent")
666 .field(
667 "group_id",
668 &mls_rs_core::debug::pretty_group_id(&self.group_id),
669 )
670 .field("epoch", &self.epoch)
671 .field("sender", &self.sender)
672 .field(
673 "authenticated_data",
674 &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
675 )
676 .field("content", &self.content)
677 .finish()
678 }
679}
680
681impl FramedContent {
682 pub fn content_type(&self) -> ContentType {
683 self.content.content_type()
684 }
685}
686
#[cfg(test)]
pub(crate) mod test_utils {
    #[cfg(feature = "private_message")]
    use crate::group::test_utils::random_bytes;

    use crate::group::{AuthenticatedContent, MessageSignature};

    use super::*;

    use alloc::boxed::Box;

    /// Builds a public-message `AuthenticatedContent` wrapping an empty commit
    /// (no proposals, no path). The commit is sent by member 1 at epoch 0 with
    /// an empty group id and an empty signature.
    pub(crate) fn get_test_auth_content() -> AuthenticatedContent {
        let empty_commit = Box::new(Commit {
            proposals: Default::default(),
            path: None,
        });

        let framed = FramedContent {
            group_id: Vec::new(),
            epoch: 0,
            sender: Sender::Member(1),
            authenticated_data: Vec::new(),
            content: Content::Commit(empty_commit),
        };

        let auth = FramedContentAuthData {
            signature: MessageSignature::empty(),
            confirmation_tag: None,
        };

        AuthenticatedContent {
            wire_format: WireFormat::PublicMessage,
            content: framed,
            auth,
        }
    }

    /// Builds private-message content holding 1024 random application bytes
    /// and a 128-byte random signature, with no confirmation tag.
    #[cfg(feature = "private_message")]
    pub(crate) fn get_test_ciphertext_content() -> PrivateMessageContent {
        let payload = ApplicationData::from(random_bytes(1024));
        let signature = MessageSignature::from(random_bytes(128));

        PrivateMessageContent {
            content: Content::Application(payload),
            auth: FramedContentAuthData {
                signature,
                confirmation_tag: None,
            },
        }
    }

    impl AsRef<[u8]> for ApplicationData {
        fn as_ref(&self) -> &[u8] {
            self.0.as_slice()
        }
    }
}
738
#[cfg(feature = "private_message")]
#[cfg(test)]
mod tests {
    use alloc::vec;
    use assert_matches::assert_matches;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{framing::test_utils::get_test_ciphertext_content, test_utils::test_group},
        key_package::test_utils::test_key_package_message,
    };

    // Only `proposal_ref` uses these, and that test exercises API that exists
    // only with the `by_ref_proposal` feature.
    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::{proposal_ref::test_utils::auth_content_from_proposal, RemoveProposal},
    };

    use super::*;

    // Encoded content followed by all-zero padding decodes back to the original.
    #[test]
    fn test_mls_ciphertext_content_mls_encoding() {
        let ciphertext_content = get_test_ciphertext_content();

        let mut encoded = ciphertext_content.mls_encode_to_vec().unwrap();
        encoded.extend_from_slice(&[0u8; 128]);

        let decoded =
            PrivateMessageContent::mls_decode(&mut &*encoded, (&ciphertext_content.content).into())
                .unwrap();

        assert_eq!(ciphertext_content, decoded);
    }

    // Non-zero bytes after the auth data are rejected as invalid padding.
    #[test]
    fn test_mls_ciphertext_content_non_zero_padding_error() {
        let ciphertext_content = get_test_ciphertext_content();

        let mut encoded = ciphertext_content.mls_encode_to_vec().unwrap();
        encoded.extend_from_slice(&[1u8; 128]);

        let decoded =
            PrivateMessageContent::mls_decode(&mut &*encoded, (&ciphertext_content.content).into());

        assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(_)));
    }

    // `into_proposal_reference` must agree with `ProposalRef::from_content`.
    // `Sender::External`, `ProposalRef` and `into_proposal_reference` only
    // exist with `by_ref_proposal`, so the test is gated on it as well.
    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn proposal_ref() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let test_auth = auth_content_from_proposal(
            Proposal::Remove(RemoveProposal {
                to_remove: LeafIndex::unchecked(0),
            }),
            Sender::External(0),
        );

        let expected_ref = ProposalRef::from_content(&cs, &test_auth).await.unwrap();

        let test_message = MlsMessage {
            version: TEST_PROTOCOL_VERSION,
            payload: MlsMessagePayload::Plain(PublicMessage {
                content: test_auth.content,
                auth: test_auth.auth,
                membership_tag: Some(cs.mac(&[1, 2, 3], &[1, 2, 3]).await.unwrap().into()),
            }),
        };

        let computed_ref = test_message
            .into_proposal_reference(&cs)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(computed_ref, expected_ref.to_vec());
    }

    // `description` reports the expected metadata for a public commit, a
    // private application message, group info and a key package.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn message_description() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        let message = group.commit(vec![]).await.unwrap();

        let expected = MlsMessageDescription::PublicProtocolMessage {
            group_id: group.group_id(),
            epoch_id: group.context().epoch,
            content_type: ContentType::Commit,
            sender: Sender::Member(0),
            authenticated_data: &[],
        };

        assert_eq!(message.commit_message.description(), expected);

        group.apply_pending_commit().await.unwrap();

        let message = group
            .encrypt_application_message(b"123", vec![])
            .await
            .unwrap();

        let expected = MlsMessageDescription::PrivateProtocolMessage {
            group_id: group.group_id(),
            epoch_id: group.context().epoch,
            content_type: ContentType::Application,
        };

        assert_eq!(message.description(), expected);

        let group_info = group
            .group_info_message_allowing_ext_commit(true)
            .await
            .unwrap();

        assert_eq!(group_info.description(), MlsMessageDescription::GroupInfo);

        let key_package =
            test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "something").await;

        assert_eq!(key_package.description(), MlsMessageDescription::KeyPackage);
    }
}
859}