use tls_codec::{Deserialize, Serialize};
use super::*;
use crate::error::LibraryError;
/// Crate-internal MLS message: either a verifiable plaintext or a ciphertext.
///
/// Both payloads are stored behind a `Box`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum MlsMessage {
    /// A plaintext message, held as a `VerifiableMlsPlaintext`.
    Plaintext(Box<VerifiableMlsPlaintext>),
    /// A ciphertext message, held as an `MlsCiphertext`.
    Ciphertext(Box<MlsCiphertext>),
}
impl MlsMessage {
    /// Returns the [`WireFormat`] corresponding to this message's variant.
    fn wire_format(&self) -> WireFormat {
        match self {
            MlsMessage::Ciphertext(_) => WireFormat::MlsCiphertext,
            MlsMessage::Plaintext(_) => WireFormat::MlsPlaintext,
        }
    }

    /// Returns the group ID of the wrapped plaintext or ciphertext.
    fn group_id(&self) -> &GroupId {
        match self {
            MlsMessage::Ciphertext(m) => m.group_id(),
            MlsMessage::Plaintext(m) => m.group_id(),
        }
    }

    /// Returns the epoch of the wrapped plaintext or ciphertext.
    fn epoch(&self) -> GroupEpoch {
        match self {
            MlsMessage::Ciphertext(m) => m.epoch(),
            MlsMessage::Plaintext(m) => m.epoch(),
        }
    }

    /// Returns the content type of the wrapped plaintext or ciphertext.
    fn content_type(&self) -> ContentType {
        match self {
            MlsMessage::Ciphertext(m) => m.content_type(),
            MlsMessage::Plaintext(m) => m.content_type(),
        }
    }

    /// Returns `true` if this message's content type is a handshake message,
    /// as decided by `ContentType::is_handshake_message`.
    fn is_handshake_message(&self) -> bool {
        self.content_type().is_handshake_message()
    }

    /// Decodes an `MlsMessage` from its TLS encoding.
    ///
    /// The entire input must be consumed. Trailing bytes after a complete
    /// message are rejected rather than silently dropped, because this API
    /// gives the caller no way to learn how much of `bytes` was actually used.
    ///
    /// # Errors
    ///
    /// Returns [`MlsMessageError::UnableToDecode`] if `bytes` is not a valid
    /// encoding or if data remains after the decoded message.
    fn try_from_bytes(mut bytes: &[u8]) -> Result<Self, MlsMessageError> {
        let message = MlsMessage::tls_deserialize(&mut bytes)
            .map_err(|_| MlsMessageError::UnableToDecode)?;
        // Reading from a `&[u8]` advances the slice, so whatever is left here
        // was not part of the message.
        if !bytes.is_empty() {
            return Err(MlsMessageError::UnableToDecode);
        }
        Ok(message)
    }

    /// Serializes this message using its TLS encoding.
    ///
    /// # Errors
    ///
    /// A serialization failure is mapped through
    /// `LibraryError::missing_bound_check`, then converted into
    /// `MlsMessageError` by `?`.
    fn to_bytes(&self) -> Result<Vec<u8>, MlsMessageError> {
        Ok(self
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?)
    }
}
/// Public wrapper around the crate-internal `MlsMessage`
/// (either a plaintext or a ciphertext).
///
/// Can be obtained from an `MlsMessageOut` via `From`, or decoded from bytes
/// with `try_from_bytes`.
#[derive(Clone, Debug, TlsDeserialize, TlsSerialize, TlsSize)]
pub struct MlsMessageIn {
    /// The wrapped message.
    pub(crate) mls_message: MlsMessage,
}
impl MlsMessageIn {
    /// Returns the wire format (plaintext or ciphertext) of the wrapped message.
    pub fn wire_format(&self) -> WireFormat {
        MlsMessage::wire_format(&self.mls_message)
    }

    /// Returns the group ID of the wrapped message.
    pub fn group_id(&self) -> &GroupId {
        MlsMessage::group_id(&self.mls_message)
    }

    /// Returns the epoch of the wrapped message.
    pub fn epoch(&self) -> GroupEpoch {
        MlsMessage::epoch(&self.mls_message)
    }

    /// Returns the content type of the wrapped message.
    pub fn content_type(&self) -> ContentType {
        MlsMessage::content_type(&self.mls_message)
    }

    /// Returns `true` if the wrapped message's content type is a handshake
    /// message.
    pub fn is_handshake_message(&self) -> bool {
        self.content_type().is_handshake_message()
    }

    /// Decodes a message from its TLS encoding.
    ///
    /// # Errors
    ///
    /// Propagates the `MlsMessageError` produced by the inner decoder.
    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, MlsMessageError> {
        MlsMessage::try_from_bytes(bytes).map(|mls_message| Self { mls_message })
    }

    /// Serializes the wrapped message using its TLS encoding.
    ///
    /// # Errors
    ///
    /// Propagates the `MlsMessageError` produced by the inner encoder.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsMessageError> {
        MlsMessage::to_bytes(&self.mls_message)
    }
}
/// Public wrapper around the crate-internal `MlsMessage`
/// (either a plaintext or a ciphertext).
///
/// Built from an `MlsPlaintext`, a `VerifiableMlsPlaintext` or an
/// `MlsCiphertext` via the `From` impls.
#[derive(Clone, Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)]
pub struct MlsMessageOut {
    /// The wrapped message.
    pub(crate) mls_message: MlsMessage,
}
/// Wraps a `VerifiableMlsPlaintext` as a plaintext `MlsMessageOut`.
impl From<VerifiableMlsPlaintext> for MlsMessageOut {
    fn from(plaintext: VerifiableMlsPlaintext) -> Self {
        let mls_message = MlsMessage::Plaintext(Box::new(plaintext));
        Self { mls_message }
    }
}
/// Wraps an `MlsPlaintext` as a plaintext `MlsMessageOut`.
///
/// The plaintext is first converted with
/// `VerifiableMlsPlaintext::from_plaintext(plaintext, None)`.
impl From<MlsPlaintext> for MlsMessageOut {
    fn from(plaintext: MlsPlaintext) -> Self {
        let verifiable = VerifiableMlsPlaintext::from_plaintext(plaintext, None);
        Self::from(verifiable)
    }
}
/// Wraps an `MlsCiphertext` as a ciphertext `MlsMessageOut`.
impl From<MlsCiphertext> for MlsMessageOut {
    fn from(ciphertext: MlsCiphertext) -> Self {
        Self {
            mls_message: MlsMessage::Ciphertext(Box::from(ciphertext)),
        }
    }
}
impl MlsMessageOut {
    /// Returns the wire format (plaintext or ciphertext) of the wrapped message.
    pub fn wire_format(&self) -> WireFormat {
        MlsMessage::wire_format(&self.mls_message)
    }

    /// Returns the group ID of the wrapped message.
    pub fn group_id(&self) -> &GroupId {
        MlsMessage::group_id(&self.mls_message)
    }

    /// Returns the epoch of the wrapped message.
    pub fn epoch(&self) -> GroupEpoch {
        MlsMessage::epoch(&self.mls_message)
    }

    /// Returns the content type of the wrapped message.
    pub fn content_type(&self) -> ContentType {
        MlsMessage::content_type(&self.mls_message)
    }

    /// Returns `true` if the wrapped message's content type is a handshake
    /// message.
    pub fn is_handshake_message(&self) -> bool {
        self.content_type().is_handshake_message()
    }

    /// Decodes a message from its TLS encoding.
    ///
    /// # Errors
    ///
    /// Propagates the `MlsMessageError` produced by the inner decoder.
    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, MlsMessageError> {
        MlsMessage::try_from_bytes(bytes).map(|mls_message| Self { mls_message })
    }

    /// Serializes the wrapped message using its TLS encoding.
    ///
    /// # Errors
    ///
    /// Propagates the `MlsMessageError` produced by the inner encoder.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsMessageError> {
        MlsMessage::to_bytes(&self.mls_message)
    }
}
impl From<MlsMessageOut> for MlsMessageIn {
fn from(message: MlsMessageOut) -> Self {
MlsMessageIn {
mls_message: message.mls_message,
}
}
}
/// Test helper: wraps a `VerifiableMlsPlaintext` as a plaintext `MlsMessageIn`.
#[cfg(any(feature = "test-utils", test))]
impl From<VerifiableMlsPlaintext> for MlsMessageIn {
    fn from(plaintext: VerifiableMlsPlaintext) -> Self {
        let boxed = Box::new(plaintext);
        MlsMessageIn {
            mls_message: MlsMessage::Plaintext(boxed),
        }
    }
}
/// Test helper: wraps an `MlsCiphertext` as a ciphertext `MlsMessageIn`.
#[cfg(any(feature = "test-utils", test))]
impl From<MlsCiphertext> for MlsMessageIn {
    fn from(ciphertext: MlsCiphertext) -> Self {
        let mls_message = MlsMessage::Ciphertext(Box::from(ciphertext));
        MlsMessageIn { mls_message }
    }
}