#[cfg(feature = "by_ref_proposal")]
use super::proposal::Proposal;
use super::*;
use crate::{client::MlsError, protocol_version::ProtocolVersion};
use alloc::vec::Vec;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use zeroize::ZeroizeOnDrop;
#[cfg(feature = "private_message")]
use alloc::boxed::Box;
/// Tag naming which kind of content a framed message carries.
///
/// Every variant has an explicit `u8` discriminant (`#[repr(u8)]`). The
/// `Application` and `Proposal` variants are compiled in only when their
/// respective cargo features are enabled; `Commit` is always available.
#[derive(Copy, Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
pub enum ContentType {
/// Application data (requires the `private_message` feature).
#[cfg(feature = "private_message")]
Application = 1u8,
/// A proposal (requires the `by_ref_proposal` feature).
#[cfg(feature = "by_ref_proposal")]
Proposal = 2u8,
/// A commit.
Commit = 3u8,
}
/// Derives the [`ContentType`] tag that corresponds to the variant of a
/// borrowed [`Content`] value.
impl From<&Content> for ContentType {
    fn from(value: &Content) -> Self {
        // One arm per `Content` variant; the feature gates mirror the ones on
        // the enum definitions so the match stays exhaustive in every build.
        match value {
            #[cfg(feature = "private_message")]
            Content::Application(_) => Self::Application,
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(_) => Self::Proposal,
            Content::Commit(_) => Self::Commit,
        }
    }
}
/// Identifies who produced a framed message.
///
/// Every variant has an explicit `u8` discriminant (`#[repr(u8)]`). The enum
/// is `#[non_exhaustive]`, so code outside this crate must include a wildcard
/// arm when matching on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
#[non_exhaustive]
pub enum Sender {
/// A group member; the `u32` is the member's leaf index (this is the
/// variant produced by the `From<LeafIndex>` and `From<u32>` conversions).
Member(u32) = 1u8,
/// An external sender (requires the `by_ref_proposal` feature).
/// NOTE(review): the `u32` is presumably an index identifying the external
/// sender; confirm against the code that consumes it.
#[cfg(feature = "by_ref_proposal")]
External(u32) = 2u8,
/// Judging by the name, a prospective member sending a proposal (requires
/// the `by_ref_proposal` feature); confirm against the consuming code.
#[cfg(feature = "by_ref_proposal")]
NewMemberProposal = 3u8,
/// Judging by the name, a prospective member sending a commit; confirm
/// against the consuming code.
NewMemberCommit = 4u8,
}
impl From<LeafIndex> for Sender {
fn from(leaf_index: LeafIndex) -> Self {
Sender::Member(*leaf_index)
}
}
impl From<u32> for Sender {
fn from(leaf_index: u32) -> Self {
Sender::Member(leaf_index)
}
}
/// Owned application payload bytes.
///
/// A newtype around `Vec<u8>` that is encoded with the
/// `mls_rs_codec::byte_vec` codec helper. The type derives `ZeroizeOnDrop`,
/// so the buffer is zeroed when the value is dropped.
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, ZeroizeOnDrop)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct ApplicationData(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
impl From<Vec<u8>> for ApplicationData {
fn from(data: Vec<u8>) -> Self {
Self(data)
}
}
/// Lets [`ApplicationData`] be used wherever a byte slice is expected.
impl Deref for ApplicationData {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
impl ApplicationData {
    /// Borrows the underlying payload as a byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0[..]
    }
}
/// The body of a framed message, tagged by kind.
///
/// The discriminants match those of [`ContentType`] (1 = application,
/// 2 = proposal, 3 = commit), and the same cargo features gate the
/// `Application` and `Proposal` variants.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
pub(crate) enum Content {
/// Application payload bytes (requires the `private_message` feature).
#[cfg(feature = "private_message")]
Application(ApplicationData) = 1u8,
/// A boxed proposal (requires the `by_ref_proposal` feature).
#[cfg(feature = "by_ref_proposal")]
Proposal(alloc::boxed::Box<Proposal>) = 2u8,
/// A boxed commit.
Commit(alloc::boxed::Box<Commit>) = 3u8,
}
impl Content {
pub fn content_type(&self) -> ContentType {
self.into()
}
}
/// A framed message together with its authentication data and an optional
/// membership tag.
///
/// The codec traits are implemented by hand in this file rather than derived:
/// the membership tag is written only when present, and on decode it is read
/// only when the sender is `Sender::Member`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub(crate) struct PublicMessage {
/// The framed content (group id, epoch, sender, authenticated data, body).
pub content: FramedContent,
/// Authentication data accompanying `content`.
pub auth: FramedContentAuthData,
/// Membership tag; `None` when the message carries no tag.
pub membership_tag: Option<MembershipTag>,
}
impl MlsSize for PublicMessage {
    /// Sums the encoded lengths of the content, the auth data and, when
    /// present, the membership tag. An absent tag contributes zero bytes.
    fn mls_encoded_len(&self) -> usize {
        let content_len = self.content.mls_encoded_len();
        let auth_len = self.auth.mls_encoded_len();
        let tag_len = match &self.membership_tag {
            Some(tag) => tag.mls_encoded_len(),
            None => 0,
        };

        content_len + auth_len + tag_len
    }
}
impl MlsEncode for PublicMessage {
    /// Appends the content, then the auth data, then the membership tag (only
    /// if one is set) to `writer`, stopping at the first encoding error.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        self.content.mls_encode(writer)?;
        self.auth.mls_encode(writer)?;

        // An absent tag writes nothing at all, not even a presence marker.
        if let Some(tag) = &self.membership_tag {
            tag.mls_encode(writer)?;
        }

        Ok(())
    }
}
impl MlsDecode for PublicMessage {
    /// Reads a [`PublicMessage`] from `reader`.
    ///
    /// The content is decoded first because its content type drives how the
    /// auth data is parsed and its sender decides whether a membership tag
    /// follows: a tag is read only for `Sender::Member`.
    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
        let content = FramedContent::mls_decode(reader)?;
        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;

        let membership_tag = if matches!(content.sender, Sender::Member(_)) {
            Some(MembershipTag::mls_decode(reader)?)
        } else {
            None
        };

        Ok(Self {
            content,
            auth,
            membership_tag,
        })
    }
}
/// Message body and authentication data, as handled by the hand-written codec
/// impls in this file (requires the `private_message` feature).
///
/// Those impls encode the inner value of `content` directly, without the
/// `Content` variant tag; decoding therefore needs the `ContentType` to be
/// supplied by the caller.
#[cfg(feature = "private_message")]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct PrivateMessageContent {
/// The message body.
pub content: Content,
/// Authentication data accompanying `content`.
pub auth: FramedContentAuthData,
}
#[cfg(feature = "private_message")]
impl MlsSize for PrivateMessageContent {
    /// Encoded length of the inner content value (measured without the
    /// `Content` variant tag) plus the length of the auth data.
    fn mls_encoded_len(&self) -> usize {
        // Measure the payload held by the variant, not the enum itself, so
        // the variant discriminant is not counted.
        let untagged_len = match &self.content {
            Content::Application(data) => data.mls_encoded_len(),
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(proposal) => proposal.mls_encoded_len(),
            Content::Commit(commit) => commit.mls_encoded_len(),
        };
        let auth_len = self.auth.mls_encoded_len();

        untagged_len + auth_len
    }
}
#[cfg(feature = "private_message")]
impl MlsEncode for PrivateMessageContent {
    /// Appends the inner content value (without the `Content` variant tag)
    /// followed by the auth data to `writer`.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        let content_written = match &self.content {
            Content::Application(data) => data.mls_encode(writer),
            #[cfg(feature = "by_ref_proposal")]
            Content::Proposal(proposal) => proposal.mls_encode(writer),
            Content::Commit(commit) => commit.mls_encode(writer),
        };
        // Bail out before touching the auth data if the content failed.
        content_written?;

        self.auth.mls_encode(writer)
    }
}
#[cfg(feature = "private_message")]
impl PrivateMessageContent {
    /// Decodes a [`PrivateMessageContent`] from `reader`.
    ///
    /// `content_type` selects which content variant to parse, because the
    /// encoded form carries no variant tag. After the content and auth data
    /// have been read, every remaining byte in `reader` must be zero;
    /// otherwise `mls_rs_codec::Error::Custom(5)` is returned.
    pub(crate) fn mls_decode(
        reader: &mut &[u8],
        content_type: ContentType,
    ) -> Result<Self, mls_rs_codec::Error> {
        let content = match content_type {
            ContentType::Application => {
                ApplicationData::mls_decode(reader).map(Content::Application)?
            }
            #[cfg(feature = "by_ref_proposal")]
            ContentType::Proposal => Proposal::mls_decode(reader)
                .map(|proposal| Content::Proposal(Box::new(proposal)))?,
            ContentType::Commit => Commit::mls_decode(reader)
                .map(|commit| Content::Commit(alloc::boxed::Box::new(commit)))?,
        };

        let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;

        // Whatever is left over is treated as padding and must be all zeros.
        let padding_is_zero = reader.iter().all(|&byte| byte == 0u8);

        if !padding_is_zero {
            return Err(mls_rs_codec::Error::Custom(5));
        }

        Ok(Self { content, auth })
    }
}
/// The group id, epoch, content type and authenticated data of a
/// [`PrivateMessage`], gathered into one encodable structure (requires the
/// `private_message` feature).
///
/// NOTE(review): the name suggests this is the additional authenticated data
/// used when protecting the message content; confirm at the call sites.
#[cfg(feature = "private_message")]
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub struct PrivateContentAAD {
/// Group identifier bytes.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub group_id: Vec<u8>,
/// Epoch number.
pub epoch: u64,
/// Kind of content the message carries.
pub content_type: ContentType,
/// Authenticated data bytes copied from the message.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub authenticated_data: Vec<u8>,
}
/// A message whose sender data and content are carried as opaque byte
/// strings (requires the `private_message` feature).
///
/// All byte-vector fields are encoded with the `mls_rs_codec::byte_vec` codec
/// helper.
#[cfg(feature = "private_message")]
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct PrivateMessage {
/// Group identifier bytes.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub group_id: Vec<u8>,
/// Epoch number.
pub epoch: u64,
/// Kind of content the message carries.
pub content_type: ContentType,
/// Authenticated data bytes.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub authenticated_data: Vec<u8>,
/// Sender data in encrypted form (per the field name).
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub encrypted_sender_data: Vec<u8>,
/// The ciphertext bytes.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub ciphertext: Vec<u8>,
}
/// Builds a [`PrivateContentAAD`] by copying the group id, epoch, content
/// type and authenticated data out of a [`PrivateMessage`].
#[cfg(feature = "private_message")]
impl From<&PrivateMessage> for PrivateContentAAD {
    fn from(message: &PrivateMessage) -> Self {
        // Destructure once; the two encrypted fields are not part of the AAD.
        let PrivateMessage {
            group_id,
            epoch,
            content_type,
            authenticated_data,
            ..
        } = message;

        Self {
            group_id: group_id.clone(),
            epoch: *epoch,
            content_type: *content_type,
            authenticated_data: authenticated_data.clone(),
        }
    }
}
/// A top-level message: a protocol version paired with one payload.
///
/// Both fields are crate-private; external code inspects the message through
/// the accessor methods and converts it with `from_bytes` / `to_bytes`.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(
all(feature = "ffi", not(test)),
::safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct MlsMessage {
/// Protocol version of the message.
pub(crate) version: ProtocolVersion,
/// The payload carried by the message.
pub(crate) payload: MlsMessagePayload,
}
#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
#[allow(dead_code)]
impl MlsMessage {
pub(crate) fn new(version: ProtocolVersion, payload: MlsMessagePayload) -> MlsMessage {
Self { version, payload }
}
#[inline(always)]
pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
match self.payload {
MlsMessagePayload::Plain(plaintext) => Some(plaintext),
_ => None,
}
}
#[cfg(feature = "private_message")]
#[inline(always)]
pub(crate) fn into_ciphertext(self) -> Option<PrivateMessage> {
match self.payload {
MlsMessagePayload::Cipher(ciphertext) => Some(ciphertext),
_ => None,
}
}
#[inline(always)]
pub(crate) fn into_welcome(self) -> Option<Welcome> {
match self.payload {
MlsMessagePayload::Welcome(welcome) => Some(welcome),
_ => None,
}
}
#[inline(always)]
pub(crate) fn into_group_info(self) -> Option<GroupInfo> {
match self.payload {
MlsMessagePayload::GroupInfo(info) => Some(info),
_ => None,
}
}
#[inline(always)]
pub(crate) fn into_key_package(self) -> Option<KeyPackage> {
match self.payload {
MlsMessagePayload::KeyPackage(kp) => Some(kp),
_ => None,
}
}
pub fn wire_format(&self) -> WireFormat {
match self.payload {
MlsMessagePayload::Plain(_) => WireFormat::PublicMessage,
#[cfg(feature = "private_message")]
MlsMessagePayload::Cipher(_) => WireFormat::PrivateMessage,
MlsMessagePayload::Welcome(_) => WireFormat::Welcome,
MlsMessagePayload::GroupInfo(_) => WireFormat::GroupInfo,
MlsMessagePayload::KeyPackage(_) => WireFormat::KeyPackage,
}
}
#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
pub fn epoch(&self) -> Option<u64> {
match &self.payload {
MlsMessagePayload::Plain(p) => Some(p.content.epoch),
#[cfg(feature = "private_message")]
MlsMessagePayload::Cipher(c) => Some(c.epoch),
MlsMessagePayload::GroupInfo(gi) => Some(gi.group_context.epoch),
_ => None,
}
}
#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
pub fn cipher_suite(&self) -> Option<CipherSuite> {
match &self.payload {
MlsMessagePayload::GroupInfo(i) => Some(i.group_context.cipher_suite),
MlsMessagePayload::Welcome(w) => Some(w.cipher_suite),
MlsMessagePayload::KeyPackage(k) => Some(k.cipher_suite),
_ => None,
}
}
pub fn group_id(&self) -> Option<&[u8]> {
match &self.payload {
MlsMessagePayload::Plain(p) => Some(&p.content.group_id),
#[cfg(feature = "private_message")]
MlsMessagePayload::Cipher(p) => Some(&p.group_id),
MlsMessagePayload::GroupInfo(p) => Some(&p.group_context.group_id),
MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => None,
}
}
#[inline(never)]
pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
Self::mls_decode(&mut &*bytes).map_err(Into::into)
}
pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
self.mls_encode_to_vec().map_err(Into::into)
}
#[cfg(feature = "custom_proposal")]
pub fn custom_proposals_by_value(&self) -> Vec<&CustomProposal> {
match &self.payload {
MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
Content::Commit(commit) => Self::find_custom_proposals(commit),
_ => Vec::new(),
},
_ => Vec::new(),
}
}
pub fn welcome_key_package_references(&self) -> Vec<&KeyPackageRef> {
let MlsMessagePayload::Welcome(welcome) = &self.payload else {
return Vec::new();
};
welcome.secrets.iter().map(|s| &s.new_member).collect()
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn key_package_reference<C: CipherSuiteProvider>(
&self,
cipher_suite: &C,
) -> Result<Option<KeyPackageRef>, MlsError> {
let MlsMessagePayload::KeyPackage(kp) = &self.payload else {
return Ok(None);
};
kp.to_reference(cipher_suite).await.map(Some)
}
}
#[cfg(feature = "custom_proposal")]
impl MlsMessage {
    /// Returns references to every custom proposal that `commit` includes by
    /// value, in the order they appear in the commit.
    fn find_custom_proposals(commit: &Commit) -> Vec<&CustomProposal> {
        commit
            .proposals
            .iter()
            // First keep only the entries that embed a full proposal...
            .filter_map(|entry| match entry {
                ProposalOrRef::Proposal(by_value) => Some(by_value.as_ref()),
                _ => None,
            })
            // ...then keep only those proposals that are custom ones.
            .filter_map(|proposal| match proposal {
                Proposal::Custom(custom) => Some(custom),
                _ => None,
            })
            .collect()
    }
}
/// The payload of an [`MlsMessage`], tagged by kind.
///
/// Every variant has an explicit `u16` discriminant (`#[repr(u16)]`) whose
/// value equals that of the corresponding [`WireFormat`] variant. The
/// variants hold their payloads inline, hence the `large_enum_variant`
/// allowance.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u16)]
pub(crate) enum MlsMessagePayload {
/// A public message.
Plain(PublicMessage) = 1u16,
/// A private message (requires the `private_message` feature).
#[cfg(feature = "private_message")]
Cipher(PrivateMessage) = 2u16,
/// A welcome message.
Welcome(Welcome) = 3u16,
/// A group info object.
GroupInfo(GroupInfo) = 4u16,
/// A key package.
KeyPackage(KeyPackage) = 5u16,
}
impl From<PublicMessage> for MlsMessagePayload {
fn from(m: PublicMessage) -> Self {
Self::Plain(m)
}
}
/// Names the kind of payload an [`MlsMessage`] carries.
///
/// Every variant has an explicit `u16` discriminant (`#[repr(u16)]`). The
/// enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, MlsSize, MlsEncode, MlsDecode,
)]
#[repr(u16)]
#[non_exhaustive]
pub enum WireFormat {
/// A public message.
PublicMessage = 1u16,
/// A private message.
PrivateMessage = 2u16,
/// A welcome message.
Welcome = 3u16,
/// A group info object.
GroupInfo = 4u16,
/// A key package.
KeyPackage = 5u16,
}
/// A message body together with the framing fields that accompany it.
///
/// The byte-vector fields are encoded with the `mls_rs_codec::byte_vec`
/// codec helper.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub(crate) struct FramedContent {
/// Group identifier bytes.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub group_id: Vec<u8>,
/// Epoch number.
pub epoch: u64,
/// Who sent the message.
pub sender: Sender,
/// Authenticated data bytes.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub authenticated_data: Vec<u8>,
/// The message body.
pub content: Content,
}
impl FramedContent {
pub fn content_type(&self) -> ContentType {
self.content.content_type()
}
}
#[cfg(test)]
pub(crate) mod test_utils {
    #[cfg(feature = "private_message")]
    use crate::group::test_utils::random_bytes;

    use super::*;

    use alloc::boxed::Box;

    /// Builds a fixed [`AuthenticatedContent`] fixture: a public-message wire
    /// format wrapping an empty commit (no proposals, no path) sent by member
    /// 1 at epoch 0, with empty group id and authenticated data, an empty
    /// signature and no confirmation tag.
    pub(crate) fn get_test_auth_content() -> AuthenticatedContent {
        let empty_commit = Box::new(Commit {
            proposals: Default::default(),
            path: None,
        });

        let content = FramedContent {
            group_id: Vec::new(),
            epoch: 0,
            sender: Sender::Member(1),
            authenticated_data: Vec::new(),
            content: Content::Commit(empty_commit),
        };

        let auth = FramedContentAuthData {
            signature: MessageSignature::empty(),
            confirmation_tag: None,
        };

        AuthenticatedContent {
            wire_format: WireFormat::PublicMessage,
            content,
            auth,
        }
    }

    /// Builds a [`PrivateMessageContent`] fixture holding application data
    /// made from `random_bytes(1024)`, a signature made from
    /// `random_bytes(128)`, and no confirmation tag.
    #[cfg(feature = "private_message")]
    pub(crate) fn get_test_ciphertext_content() -> PrivateMessageContent {
        let payload = ApplicationData::from(random_bytes(1024));
        let signature = MessageSignature::from(random_bytes(128));

        PrivateMessageContent {
            content: Content::Application(payload),
            auth: FramedContentAuthData {
                signature,
                confirmation_tag: None,
            },
        }
    }

    /// Test-only convenience: view the application data as a byte slice.
    impl AsRef<[u8]> for ApplicationData {
        fn as_ref(&self) -> &[u8] {
            self.0.as_slice()
        }
    }
}
#[cfg(feature = "private_message")]
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::group::framing::test_utils::get_test_ciphertext_content;

    use super::*;

    /// Number of padding bytes appended after the encoded content.
    const PADDING_LEN: usize = 128;

    /// Encoding followed by zero padding must decode back to the original.
    #[test]
    fn test_mls_ciphertext_content_mls_encoding() {
        let original = get_test_ciphertext_content();
        let content_type = ContentType::from(&original.content);

        let mut bytes = original.mls_encode_to_vec().unwrap();
        bytes.resize(bytes.len() + PADDING_LEN, 0u8);

        let decoded =
            PrivateMessageContent::mls_decode(&mut bytes.as_slice(), content_type).unwrap();

        assert_eq!(original, decoded);
    }

    /// Any non-zero byte in the trailing padding must be rejected.
    #[test]
    fn test_mls_ciphertext_content_non_zero_padding_error() {
        let original = get_test_ciphertext_content();
        let content_type = ContentType::from(&original.content);

        let mut bytes = original.mls_encode_to_vec().unwrap();
        bytes.resize(bytes.len() + PADDING_LEN, 1u8);

        let outcome = PrivateMessageContent::mls_decode(&mut bytes.as_slice(), content_type);

        assert_matches!(outcome, Err(mls_rs_codec::Error::Custom(_)));
    }
}