use super::*;
use crate::{
key_packages::KeyPackageIn, messages::group_info::VerifiableGroupInfo,
versions::ProtocolVersion,
};
/// An inbound MLS message: a protocol version plus a typed body.
///
/// NOTE(review): this type derives `TlsSize` (and `TlsSerialize` only under
/// `test-utils`) but not `TlsDeserialize`; deserialization is presumably
/// implemented by hand elsewhere — confirm before changing field order.
#[derive(PartialEq, Debug, Clone, TlsSize)]
#[cfg_attr(feature = "test-utils", derive(TlsSerialize))]
pub struct MlsMessageIn {
    /// Protocol version the message was encoded with.
    pub(crate) version: ProtocolVersion,
    /// The message payload, tagged by wire format.
    pub(crate) body: MlsMessageBodyIn,
}
/// The body of an inbound MLS message, one variant per wire format.
///
/// The `#[tls_codec(discriminant = N)]` values (with `#[repr(u16)]`) are the
/// tags used when (de)serializing the body; they are presumably fixed by the
/// MLS wire format, so do not renumber or reorder them.
#[derive(Debug, PartialEq, Clone, TlsDeserialize, TlsDeserializeBytes, TlsSize)]
#[cfg_attr(feature = "test-utils", derive(TlsSerialize))]
#[repr(u16)]
pub enum MlsMessageBodyIn {
    /// A public (plaintext) protocol message.
    #[tls_codec(discriminant = 1)]
    PublicMessage(PublicMessageIn),
    /// A private (encrypted) protocol message.
    #[tls_codec(discriminant = 2)]
    PrivateMessage(PrivateMessageIn),
    /// A Welcome message.
    #[tls_codec(discriminant = 3)]
    Welcome(Welcome),
    /// A group info that has not yet been verified.
    #[tls_codec(discriminant = 4)]
    GroupInfo(VerifiableGroupInfo),
    /// An inbound key package.
    #[tls_codec(discriminant = 5)]
    KeyPackage(KeyPackageIn),
}
impl MlsMessageIn {
    /// Returns the [`WireFormat`] corresponding to the variant of this message's body.
    pub fn wire_format(&self) -> WireFormat {
        match self.body {
            MlsMessageBodyIn::PrivateMessage(_) => WireFormat::PrivateMessage,
            MlsMessageBodyIn::PublicMessage(_) => WireFormat::PublicMessage,
            MlsMessageBodyIn::Welcome(_) => WireFormat::Welcome,
            MlsMessageBodyIn::GroupInfo(_) => WireFormat::GroupInfo,
            MlsMessageBodyIn::KeyPackage(_) => WireFormat::KeyPackage,
        }
    }

    /// Consumes the message and returns its body, discarding the protocol version.
    pub fn extract(self) -> MlsMessageBodyIn {
        self.body
    }

    /// Converts this message into a [`ProtocolMessage`].
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolMessageError::WrongWireFormat`] if the body is not a
    /// public or private message.
    pub fn try_into_protocol_message(self) -> Result<ProtocolMessage, ProtocolMessageError> {
        self.try_into()
    }

    /// Test helper: returns the contained key package, or `None` if the body
    /// is not a key package.
    ///
    /// In debug builds this asserts that the key package supports the message's
    /// protocol version.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn into_keypackage(self) -> Option<crate::key_packages::KeyPackage> {
        match self.body {
            MlsMessageBodyIn::KeyPackage(key_package) => {
                debug_assert!(key_package.version_is_supported(self.version));
                Some(key_package.into())
            }
            _ => None,
        }
    }

    /// Test helper: returns the body as a [`PublicMessage`], or `None` if the
    /// body is not a public message.
    #[cfg(test)]
    pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
        match self.body {
            MlsMessageBodyIn::PublicMessage(m) => Some(m.into()),
            _ => None,
        }
    }

    /// Test helper: returns the contained private message, or `None` if the
    /// body is not a private message.
    #[cfg(test)]
    pub(crate) fn into_ciphertext(self) -> Option<PrivateMessageIn> {
        match self.body {
            MlsMessageBodyIn::PrivateMessage(m) => Some(m),
            _ => None,
        }
    }

    /// Test helper: returns the contained [`Welcome`], or `None` if the body
    /// is not a Welcome message.
    #[cfg(any(feature = "test-utils", test))]
    pub fn into_welcome(self) -> Option<Welcome> {
        match self.body {
            MlsMessageBodyIn::Welcome(w) => Some(w),
            _ => None,
        }
    }

    /// Test helper: like [`MlsMessageIn::try_into_protocol_message`], but
    /// returns `None` instead of an error for non-protocol bodies.
    #[cfg(any(feature = "test-utils", test))]
    pub fn into_protocol_message(self) -> Option<ProtocolMessage> {
        // Delegate to the `TryFrom` impl so the public/private-message mapping
        // lives in exactly one place.
        ProtocolMessage::try_from(self).ok()
    }

    /// Test helper: returns the contained unverified group info, or `None` if
    /// the body is not a group info.
    #[cfg(any(feature = "test-utils", test))]
    pub fn into_verifiable_group_info(self) -> Option<VerifiableGroupInfo> {
        match self.body {
            MlsMessageBodyIn::GroupInfo(group_info) => Some(group_info),
            _ => None,
        }
    }
}
/// An inbound protocol (handshake or application) message: either a private
/// (encrypted) or a public (plaintext) message.
#[derive(Debug, Clone)]
pub enum ProtocolMessage {
    /// A private message.
    PrivateMessage(PrivateMessageIn),
    /// A public message. Boxed — presumably to keep the enum small
    /// (cf. clippy `large_enum_variant`); confirm before unboxing.
    PublicMessage(Box<PublicMessageIn>),
}
impl ProtocolMessage {
    /// Returns the [`WireFormat`] of this message.
    pub fn wire_format(&self) -> WireFormat {
        match self {
            ProtocolMessage::PrivateMessage(_) => WireFormat::PrivateMessage,
            ProtocolMessage::PublicMessage(_) => WireFormat::PublicMessage,
        }
    }

    /// Returns the ID of the group this message is addressed to.
    pub fn group_id(&self) -> &GroupId {
        // `self` is a reference, so match ergonomics already bind `m` by
        // reference; an explicit `ref` is redundant (and rejected in edition 2024).
        match self {
            ProtocolMessage::PrivateMessage(m) => m.group_id(),
            ProtocolMessage::PublicMessage(m) => m.group_id(),
        }
    }

    /// Returns the epoch this message was sent in.
    pub fn epoch(&self) -> GroupEpoch {
        match self {
            ProtocolMessage::PrivateMessage(m) => m.epoch(),
            ProtocolMessage::PublicMessage(m) => m.epoch(),
        }
    }

    /// Returns the content type of this message.
    pub fn content_type(&self) -> ContentType {
        match self {
            ProtocolMessage::PrivateMessage(m) => m.content_type(),
            ProtocolMessage::PublicMessage(m) => m.content_type(),
        }
    }

    /// Returns `true` if the message was sent by a non-member.
    ///
    /// Only public messages can be external: their sender must be
    /// `NewMemberProposal`, `NewMemberCommit`, or `External`. Private messages
    /// always return `false`.
    pub fn is_external(&self) -> bool {
        match self {
            ProtocolMessage::PublicMessage(p) => matches!(
                p.sender(),
                Sender::NewMemberProposal | Sender::NewMemberCommit | Sender::External(_)
            ),
            ProtocolMessage::PrivateMessage(_) => false,
        }
    }

    /// Returns `true` if the message's content type is a handshake content type.
    pub fn is_handshake_message(&self) -> bool {
        self.content_type().is_handshake_message()
    }
}
impl From<PrivateMessageIn> for ProtocolMessage {
fn from(private_message: PrivateMessageIn) -> Self {
ProtocolMessage::PrivateMessage(private_message)
}
}
impl From<PublicMessageIn> for ProtocolMessage {
fn from(public_message: PublicMessageIn) -> Self {
ProtocolMessage::PublicMessage(Box::new(public_message))
}
}
impl TryFrom<MlsMessageIn> for ProtocolMessage {
    type Error = ProtocolMessageError;

    /// Extracts a [`ProtocolMessage`] from an inbound MLS message.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolMessageError::WrongWireFormat`] if the body is a
    /// Welcome, GroupInfo, or KeyPackage rather than a public or private message.
    fn try_from(msg: MlsMessageIn) -> Result<Self, Self::Error> {
        match msg.body {
            MlsMessageBodyIn::PublicMessage(m) => Ok(m.into()),
            MlsMessageBodyIn::PrivateMessage(m) => Ok(m.into()),
            // Listed explicitly instead of `_` so that adding a new body variant
            // forces a compile-time decision here rather than silently being
            // rejected as the wrong wire format.
            MlsMessageBodyIn::Welcome(_)
            | MlsMessageBodyIn::GroupInfo(_)
            | MlsMessageBodyIn::KeyPackage(_) => Err(ProtocolMessageError::WrongWireFormat),
        }
    }
}
#[cfg(any(feature = "test-utils", test))]
impl From<PublicMessage> for ProtocolMessage {
    /// Test helper: converts a [`PublicMessage`] into its inbound form and
    /// wraps it as a [`ProtocolMessage`].
    fn from(message: PublicMessage) -> Self {
        let incoming: PublicMessageIn = message.into();
        Self::from(incoming)
    }
}