1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
//! MLS Message types
//!
//! This module defines two opaque message types that are used by the [`MlsGroup`](crate::group::mls_group::MlsGroup) API.
//! [`MlsMessageIn`] is used for incoming messages, i.e. messages the client receives from the Delivery
//! Service. It can be instantiated from a byte slice.
//! [`MlsMessageOut`] is returned by various functions of the [`MlsGroup`](crate::group::mls_group::MlsGroup) API.
//! It is used for outgoing messages, i.e. messages the client sends to the Delivery Service. It can be
//! serialized to a byte vector.
//!
//! Both types expose the same API, through which the framing part of the message can be inspected. In
//! particular, [`MlsMessageIn::group_id()`] should be checked to determine in which
//! [`MlsGroup`](crate::group::mls_group::MlsGroup) the message should be processed.

use tls_codec::{Deserialize, Serialize};

use super::*;

use crate::error::LibraryError;

/// Unified message type for MLS messages.
///
/// This is only used internally; externally we use either [`MlsMessageIn`] or
/// [`MlsMessageOut`], depending on the context.
///
/// Since the memory footprint can differ considerably between [`VerifiableMlsPlaintext`]
/// and [`MlsCiphertext`], both payloads are stored behind a `Box<T>`. This way the size
/// of the enum itself does not depend on the larger of the two variants.
#[derive(PartialEq, Debug, Clone)]
pub(crate) enum MlsMessage {
    /// Plaintext message
    Plaintext(Box<VerifiableMlsPlaintext>),

    /// Ciphertext message
    Ciphertext(Box<MlsCiphertext>),
}

impl MlsMessage {
    /// Returns the wire format, which follows directly from the variant:
    /// [`WireFormat::MlsCiphertext`] for ciphertexts and
    /// [`WireFormat::MlsPlaintext`] for plaintexts.
    fn wire_format(&self) -> WireFormat {
        match self {
            MlsMessage::Ciphertext(_) => WireFormat::MlsCiphertext,
            MlsMessage::Plaintext(_) => WireFormat::MlsPlaintext,
        }
    }

    /// Returns the group ID carried by the wrapped message.
    fn group_id(&self) -> &GroupId {
        match self {
            MlsMessage::Ciphertext(m) => m.group_id(),
            MlsMessage::Plaintext(m) => m.group_id(),
        }
    }

    /// Returns the epoch carried by the wrapped message.
    fn epoch(&self) -> GroupEpoch {
        match self {
            MlsMessage::Ciphertext(m) => m.epoch(),
            MlsMessage::Plaintext(m) => m.epoch(),
        }
    }

    /// Returns the content type of the wrapped message.
    fn content_type(&self) -> ContentType {
        match self {
            MlsMessage::Ciphertext(m) => m.content_type(),
            MlsMessage::Plaintext(m) => m.content_type(),
        }
    }

    /// Returns `true` if this is a handshake message and `false` otherwise.
    ///
    /// The decision is delegated to the message's [`ContentType`].
    fn is_handshake_message(&self) -> bool {
        self.content_type().is_handshake_message()
    }

    /// Tries to deserialize a message from a byte slice.
    ///
    /// NOTE(review): bytes remaining in the slice after a complete message has been
    /// decoded are not inspected, so trailing data is silently ignored. Confirm this
    /// is intended.
    ///
    /// # Errors
    ///
    /// Returns [`MlsMessageError::UnableToDecode`] if the bytes cannot be decoded.
    /// The underlying codec error is discarded.
    fn try_from_bytes(mut bytes: &[u8]) -> Result<Self, MlsMessageError> {
        MlsMessage::tls_deserialize(&mut bytes).map_err(|_| MlsMessageError::UnableToDecode)
    }

    /// Serializes the message to a byte vector.
    ///
    /// # Errors
    ///
    /// Returns [`MlsMessageError::LibraryError`] on failure. A serialization failure
    /// is mapped to [`LibraryError::missing_bound_check`], i.e. it is treated as an
    /// internal library bug rather than a caller error.
    fn to_bytes(&self) -> Result<Vec<u8>, MlsMessageError> {
        Ok(self
            .tls_serialize_detached()
            .map_err(LibraryError::missing_bound_check)?)
    }
}

/// Unified message type for incoming MLS messages.
///
/// This is a thin wrapper around the crate-internal `MlsMessage`, so a plaintext
/// and a ciphertext can be handled through a single type.
#[derive(Debug, Clone, TlsSize, TlsSerialize, TlsDeserialize)]
pub struct MlsMessageIn {
    pub(crate) mls_message: MlsMessage,
}

impl MlsMessageIn {
    /// Returns the [`WireFormat`] of the wrapped message.
    pub fn wire_format(&self) -> WireFormat {
        MlsMessage::wire_format(&self.mls_message)
    }

    /// Returns the ID of the group this message belongs to.
    pub fn group_id(&self) -> &GroupId {
        MlsMessage::group_id(&self.mls_message)
    }

    /// Returns the [`GroupEpoch`] of the wrapped message.
    pub fn epoch(&self) -> GroupEpoch {
        MlsMessage::epoch(&self.mls_message)
    }

    /// Returns the [`ContentType`] of the wrapped message.
    pub fn content_type(&self) -> ContentType {
        MlsMessage::content_type(&self.mls_message)
    }

    /// Reports whether the wrapped message is a handshake message.
    pub fn is_handshake_message(&self) -> bool {
        MlsMessage::is_handshake_message(&self.mls_message)
    }

    /// Decodes an incoming message from raw bytes.
    ///
    /// # Errors
    ///
    /// Fails with [`MlsMessageError::UnableToDecode`] when the bytes cannot be decoded.
    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, MlsMessageError> {
        MlsMessage::try_from_bytes(bytes).map(|mls_message| Self { mls_message })
    }

    /// Encodes the message into a freshly allocated byte vector.
    ///
    /// # Errors
    ///
    /// Fails with [`MlsMessageError::LibraryError`] when encoding fails.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsMessageError> {
        MlsMessage::to_bytes(&self.mls_message)
    }
}

/// Unified message type for outgoing MLS messages.
///
/// Like the incoming counterpart, this wraps the crate-internal `MlsMessage`
/// so a plaintext and a ciphertext can be handled through a single type.
#[derive(Clone, Debug, PartialEq, TlsSize, TlsSerialize, TlsDeserialize)]
pub struct MlsMessageOut {
    pub(crate) mls_message: MlsMessage,
}

impl From<VerifiableMlsPlaintext> for MlsMessageOut {
    /// Wraps a verifiable plaintext as an outgoing plaintext message.
    fn from(plaintext: VerifiableMlsPlaintext) -> Self {
        let mls_message = MlsMessage::Plaintext(Box::new(plaintext));
        Self { mls_message }
    }
}

impl From<MlsPlaintext> for MlsMessageOut {
    /// Wraps a plaintext as an outgoing plaintext message.
    ///
    /// The plaintext is first turned into a [`VerifiableMlsPlaintext`] via
    /// `VerifiableMlsPlaintext::from_plaintext`, passing `None` for its optional
    /// second argument, and then wrapped like any other verifiable plaintext.
    fn from(plaintext: MlsPlaintext) -> Self {
        let verifiable = VerifiableMlsPlaintext::from_plaintext(plaintext, None);
        Self::from(verifiable)
    }
}

impl From<MlsCiphertext> for MlsMessageOut {
    /// Wraps a ciphertext as an outgoing ciphertext message.
    fn from(ciphertext: MlsCiphertext) -> Self {
        let boxed = Box::new(ciphertext);
        let mls_message = MlsMessage::Ciphertext(boxed);
        Self { mls_message }
    }
}

impl MlsMessageOut {
    /// Returns the [`WireFormat`] of the wrapped message.
    pub fn wire_format(&self) -> WireFormat {
        let inner = &self.mls_message;
        inner.wire_format()
    }

    /// Returns the ID of the group this message belongs to.
    pub fn group_id(&self) -> &GroupId {
        let inner = &self.mls_message;
        inner.group_id()
    }

    /// Returns the [`GroupEpoch`] of the wrapped message.
    pub fn epoch(&self) -> GroupEpoch {
        let inner = &self.mls_message;
        inner.epoch()
    }

    /// Returns the [`ContentType`] of the wrapped message.
    pub fn content_type(&self) -> ContentType {
        let inner = &self.mls_message;
        inner.content_type()
    }

    /// Reports whether the wrapped message is a handshake message.
    pub fn is_handshake_message(&self) -> bool {
        let inner = &self.mls_message;
        inner.is_handshake_message()
    }

    /// Decodes an outgoing message from raw bytes.
    ///
    /// # Errors
    ///
    /// Fails with [`MlsMessageError::UnableToDecode`] when the bytes cannot be decoded.
    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, MlsMessageError> {
        let mls_message = MlsMessage::try_from_bytes(bytes)?;
        Ok(Self { mls_message })
    }

    /// Encodes the message into a freshly allocated byte vector.
    ///
    /// # Errors
    ///
    /// Fails with [`MlsMessageError::LibraryError`] when encoding fails.
    pub fn to_bytes(&self) -> Result<Vec<u8>, MlsMessageError> {
        let inner = &self.mls_message;
        inner.to_bytes()
    }
}

impl From<MlsMessageOut> for MlsMessageIn {
    fn from(message: MlsMessageOut) -> Self {
        MlsMessageIn {
            mls_message: message.mls_message,
        }
    }
}

#[cfg(any(feature = "test-utils", test))]
impl From<VerifiableMlsPlaintext> for MlsMessageIn {
    /// Test helper: wraps a verifiable plaintext as an incoming plaintext message.
    ///
    /// This reuses the outgoing conversion and then moves the wrapped message over.
    fn from(plaintext: VerifiableMlsPlaintext) -> Self {
        MlsMessageOut::from(plaintext).into()
    }
}

#[cfg(any(feature = "test-utils", test))]
impl From<MlsCiphertext> for MlsMessageIn {
    /// Test helper: wraps a ciphertext as an incoming ciphertext message.
    ///
    /// This reuses the outgoing conversion and then moves the wrapped message over.
    fn from(ciphertext: MlsCiphertext) -> Self {
        MlsMessageOut::from(ciphertext).into()
    }
}