Skip to main content

mls_rs_core/
extension.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::{
6    fmt::{self, Debug},
7    ops::Deref,
8};
9
10use crate::error::{AnyError, IntoAnyError};
11use alloc::vec::Vec;
12use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
13
14mod list;
15
16pub use list::*;
17
18/// Wrapper type representing an extension identifier along with default values
19/// defined by the MLS RFC.
20#[derive(
21    Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
22)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[repr(transparent)]
26pub struct ExtensionType(u16);
27
28impl ExtensionType {
29    pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
30    pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
31    pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
32    pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
33    pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
34
35    #[cfg(feature = "last_resort_key_package_ext")]
36    pub const LAST_RESORT_KEY_PACKAGE: ExtensionType = ExtensionType(0x000A);
37
38    /// Default extension types defined
39    /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents)
40    pub const DEFAULT: &'static [ExtensionType] = &[
41        ExtensionType::APPLICATION_ID,
42        ExtensionType::RATCHET_TREE,
43        ExtensionType::REQUIRED_CAPABILITIES,
44        ExtensionType::EXTERNAL_PUB,
45        ExtensionType::EXTERNAL_SENDERS,
46    ];
47
48    /// Extension type from a raw value
49    pub const fn new(raw_value: u16) -> Self {
50        ExtensionType(raw_value)
51    }
52
53    /// Raw numerical wrapped value.
54    pub const fn raw_value(&self) -> u16 {
55        self.0
56    }
57
58    /// Determines if this extension type is required to be implemented
59    /// by the MLS RFC.
60    pub const fn is_default(&self) -> bool {
61        self.0 <= 5
62    }
63}
64
65impl From<u16> for ExtensionType {
66    fn from(value: u16) -> Self {
67        ExtensionType(value)
68    }
69}
70
71impl Deref for ExtensionType {
72    type Target = u16;
73
74    fn deref(&self) -> &Self::Target {
75        &self.0
76    }
77}
78
79#[derive(Debug)]
80#[cfg_attr(feature = "std", derive(thiserror::Error))]
81pub enum ExtensionError {
82    #[cfg_attr(feature = "std", error(transparent))]
83    SerializationError(AnyError),
84    #[cfg_attr(feature = "std", error(transparent))]
85    DeserializationError(AnyError),
86    #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
87    IncorrectType(ExtensionType),
88}
89
90impl IntoAnyError for ExtensionError {
91    #[cfg(feature = "std")]
92    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
93        Ok(self.into())
94    }
95}
96
97#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
98#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
99#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
100#[non_exhaustive]
101/// An MLS protocol [extension](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-extensions).
102///
103/// Extensions are used as customization points in various parts of the
104/// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList).
105pub struct Extension {
106    /// Extension type of this extension
107    pub extension_type: ExtensionType,
108    /// Data held within this extension
109    #[mls_codec(with = "mls_rs_codec::byte_vec")]
110    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
111    pub extension_data: Vec<u8>,
112}
113
114impl Debug for Extension {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("Extension")
117            .field("extension_type", &self.extension_type)
118            .field(
119                "extension_data",
120                &crate::debug::pretty_bytes(&self.extension_data),
121            )
122            .finish()
123    }
124}
125
126impl Extension {
127    /// Create an extension with specified type and data properties.
128    pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
129        Extension {
130            extension_type,
131            extension_data,
132        }
133    }
134
135    /// Extension type of this extension
136    pub fn extension_type(&self) -> ExtensionType {
137        self.extension_type
138    }
139
140    /// Data held within this extension
141    pub fn extension_data(&self) -> &[u8] {
142        &self.extension_data
143    }
144}
145
146/// Trait used to convert a type to and from an [Extension]
147pub trait MlsExtension: Sized {
148    /// Error type of the underlying serializer that can convert this type into a `Vec<u8>`.
149    type SerializationError: IntoAnyError;
150
151    /// Error type of the underlying deserializer that can convert a `Vec<u8>` into this type.
152    type DeserializationError: IntoAnyError;
153
154    /// Extension type value that this type represents.
155    fn extension_type() -> ExtensionType;
156
157    /// Convert this type to opaque bytes.
158    fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;
159
160    /// Create this type from opaque bytes.
161    fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;
162
163    /// Convert this type into an [Extension].
164    fn into_extension(self) -> Result<Extension, ExtensionError> {
165        Ok(Extension::new(
166            Self::extension_type(),
167            self.to_bytes()
168                .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
169        ))
170    }
171
172    /// Create this type from an [Extension].
173    fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
174        if ext.extension_type != Self::extension_type() {
175            return Err(ExtensionError::IncorrectType(ext.extension_type));
176        }
177
178        Self::from_bytes(&ext.extension_data)
179            .map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
180    }
181}
182
183/// Convenience trait for custom extension types that use
184/// [mls_rs_codec] as an underlying serialization mechanism
185pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
186    fn extension_type() -> ExtensionType;
187}
188
189impl<T> MlsExtension for T
190where
191    T: MlsCodecExtension,
192{
193    type SerializationError = mls_rs_codec::Error;
194    type DeserializationError = mls_rs_codec::Error;
195
196    fn extension_type() -> ExtensionType {
197        <Self as MlsCodecExtension>::extension_type()
198    }
199
200    fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
201        self.mls_encode_to_vec()
202    }
203
204    fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
205        Self::mls_decode(&mut &*data)
206    }
207}
208
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};

    /// Hand-written `MlsExtension` with type 42 that always encodes to `[0]`.
    struct TestExtension;

    /// Unit type that gets `MlsExtension` through the `MlsCodecExtension`
    /// blanket impl, using type 43.
    #[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
    struct AnotherTestExtension;

    impl MlsExtension for TestExtension {
        type SerializationError = Infallible;
        type DeserializationError = Infallible;

        fn extension_type() -> ExtensionType {
            ExtensionType::new(42)
        }

        fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
            Ok(vec![0])
        }

        fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> {
            Ok(TestExtension)
        }
    }

    impl MlsCodecExtension for AnotherTestExtension {
        fn extension_type() -> ExtensionType {
            ExtensionType::new(43)
        }
    }

    #[test]
    fn into_extension() {
        let produced = TestExtension.into_extension().unwrap();
        let expected = Extension::new(ExtensionType::new(42), vec![0]);

        assert_eq!(produced, expected)
    }

    #[test]
    fn incorrect_type_is_discovered() {
        // Tagged 42, but `AnotherTestExtension` expects 43.
        let mismatched = Extension::new(ExtensionType::new(42), vec![0]);
        let outcome = AnotherTestExtension::from_extension(&mismatched);

        assert_matches!(
            outcome,
            Err(ExtensionError::IncorrectType(found)) if found == ExtensionType::new(42)
        );
    }
}