mls_rs_core/
extension.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::{
6    fmt::{self, Debug},
7    ops::Deref,
8};
9
10use crate::error::{AnyError, IntoAnyError};
11use alloc::vec::Vec;
12use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
13
14mod list;
15
16pub use list::*;
17
18/// Wrapper type representing an extension identifier along with default values
19/// defined by the MLS RFC.
20#[derive(
21    Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
22)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[repr(transparent)]
27pub struct ExtensionType(u16);
28
29#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
30impl ExtensionType {
31    pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
32    pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
33    pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
34    pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
35    pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
36
37    #[cfg(feature = "last_resort_key_package_ext")]
38    pub const LAST_RESORT_KEY_PACKAGE: ExtensionType = ExtensionType(0x000A);
39
40    /// Default extension types defined
41    /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents)
42    pub const DEFAULT: &'static [ExtensionType] = &[
43        ExtensionType::APPLICATION_ID,
44        ExtensionType::RATCHET_TREE,
45        ExtensionType::REQUIRED_CAPABILITIES,
46        ExtensionType::EXTERNAL_PUB,
47        ExtensionType::EXTERNAL_SENDERS,
48    ];
49
50    /// Extension type from a raw value
51    pub const fn new(raw_value: u16) -> Self {
52        ExtensionType(raw_value)
53    }
54
55    /// Raw numerical wrapped value.
56    pub const fn raw_value(&self) -> u16 {
57        self.0
58    }
59
60    /// Determines if this extension type is required to be implemented
61    /// by the MLS RFC.
62    pub const fn is_default(&self) -> bool {
63        self.0 <= 5
64    }
65}
66
67impl From<u16> for ExtensionType {
68    fn from(value: u16) -> Self {
69        ExtensionType(value)
70    }
71}
72
73impl Deref for ExtensionType {
74    type Target = u16;
75
76    fn deref(&self) -> &Self::Target {
77        &self.0
78    }
79}
80
81#[derive(Debug)]
82#[cfg_attr(feature = "std", derive(thiserror::Error))]
83pub enum ExtensionError {
84    #[cfg_attr(feature = "std", error(transparent))]
85    SerializationError(AnyError),
86    #[cfg_attr(feature = "std", error(transparent))]
87    DeserializationError(AnyError),
88    #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
89    IncorrectType(ExtensionType),
90}
91
92impl IntoAnyError for ExtensionError {
93    #[cfg(feature = "std")]
94    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
95        Ok(self.into())
96    }
97}
98
99#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
100#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
101#[cfg_attr(
102    all(feature = "ffi", not(test)),
103    safer_ffi_gen::ffi_type(clone, opaque)
104)]
105#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
106#[non_exhaustive]
107/// An MLS protocol [extension](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-extensions).
108///
109/// Extensions are used as customization points in various parts of the
110/// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList).
111pub struct Extension {
112    /// Extension type of this extension
113    pub extension_type: ExtensionType,
114    /// Data held within this extension
115    #[mls_codec(with = "mls_rs_codec::byte_vec")]
116    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
117    pub extension_data: Vec<u8>,
118}
119
120impl Debug for Extension {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        f.debug_struct("Extension")
123            .field("extension_type", &self.extension_type)
124            .field(
125                "extension_data",
126                &crate::debug::pretty_bytes(&self.extension_data),
127            )
128            .finish()
129    }
130}
131
132#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
133impl Extension {
134    /// Create an extension with specified type and data properties.
135    pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
136        Extension {
137            extension_type,
138            extension_data,
139        }
140    }
141
142    /// Extension type of this extension
143    #[cfg(feature = "ffi")]
144    pub fn extension_type(&self) -> ExtensionType {
145        self.extension_type
146    }
147
148    /// Data held within this extension
149    #[cfg(feature = "ffi")]
150    pub fn extension_data(&self) -> &[u8] {
151        &self.extension_data
152    }
153}
154
155/// Trait used to convert a type to and from an [Extension]
156pub trait MlsExtension: Sized {
157    /// Error type of the underlying serializer that can convert this type into a `Vec<u8>`.
158    type SerializationError: IntoAnyError;
159
160    /// Error type of the underlying deserializer that can convert a `Vec<u8>` into this type.
161    type DeserializationError: IntoAnyError;
162
163    /// Extension type value that this type represents.
164    fn extension_type() -> ExtensionType;
165
166    /// Convert this type to opaque bytes.
167    fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;
168
169    /// Create this type from opaque bytes.
170    fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;
171
172    /// Convert this type into an [Extension].
173    fn into_extension(self) -> Result<Extension, ExtensionError> {
174        Ok(Extension::new(
175            Self::extension_type(),
176            self.to_bytes()
177                .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
178        ))
179    }
180
181    /// Create this type from an [Extension].
182    fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
183        if ext.extension_type != Self::extension_type() {
184            return Err(ExtensionError::IncorrectType(ext.extension_type));
185        }
186
187        Self::from_bytes(&ext.extension_data)
188            .map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
189    }
190}
191
192/// Convenience trait for custom extension types that use
193/// [mls_rs_codec] as an underlying serialization mechanism
194pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
195    fn extension_type() -> ExtensionType;
196}
197
198impl<T> MlsExtension for T
199where
200    T: MlsCodecExtension,
201{
202    type SerializationError = mls_rs_codec::Error;
203    type DeserializationError = mls_rs_codec::Error;
204
205    fn extension_type() -> ExtensionType {
206        <Self as MlsCodecExtension>::extension_type()
207    }
208
209    fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
210        self.mls_encode_to_vec()
211    }
212
213    fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
214        Self::mls_decode(&mut &*data)
215    }
216}
217
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};

    /// Hand-written `MlsExtension` (type 42) with infallible (de)serialization.
    struct TestExtension;

    /// Codec-backed extension (type 43) exercising the blanket `MlsExtension` impl.
    #[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
    struct AnotherTestExtension;

    impl MlsExtension for TestExtension {
        type SerializationError = Infallible;
        type DeserializationError = Infallible;

        fn extension_type() -> super::ExtensionType {
            ExtensionType(42)
        }

        fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
            Ok(vec![0])
        }

        fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> {
            Ok(TestExtension)
        }
    }

    impl MlsCodecExtension for AnotherTestExtension {
        fn extension_type() -> ExtensionType {
            ExtensionType(43)
        }
    }

    #[test]
    fn into_extension() {
        assert_eq!(
            TestExtension.into_extension().unwrap(),
            Extension::new(42.into(), vec![0])
        )
    }

    #[test]
    fn from_extension_accepts_matching_type() {
        let ext = Extension::new(42.into(), vec![0]);

        assert!(TestExtension::from_extension(&ext).is_ok());
    }

    #[test]
    fn incorrect_type_is_discovered() {
        let ext = Extension::new(42.into(), vec![0]);

        assert_matches!(AnotherTestExtension::from_extension(&ext), Err(ExtensionError::IncorrectType(found)) if found == 42.into());
    }

    #[test]
    fn codec_extension_round_trips() {
        let ext = AnotherTestExtension.into_extension().unwrap();

        assert_eq!(ext.extension_type, ExtensionType::new(43));
        assert_matches!(
            AnotherTestExtension::from_extension(&ext),
            Ok(AnotherTestExtension)
        );
    }
}
271}