1use core::{
6 fmt::{self, Debug},
7 ops::Deref,
8};
9
10use crate::error::{AnyError, IntoAnyError};
11use alloc::vec::Vec;
12use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
13
14mod list;
15
16pub use list::*;
17
18#[derive(
21 Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
22)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[repr(transparent)]
26pub struct ExtensionType(u16);
27
28impl ExtensionType {
29 pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
30 pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
31 pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
32 pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
33 pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
34
35 #[cfg(feature = "last_resort_key_package_ext")]
36 pub const LAST_RESORT_KEY_PACKAGE: ExtensionType = ExtensionType(0x000A);
37
38 pub const DEFAULT: &'static [ExtensionType] = &[
41 ExtensionType::APPLICATION_ID,
42 ExtensionType::RATCHET_TREE,
43 ExtensionType::REQUIRED_CAPABILITIES,
44 ExtensionType::EXTERNAL_PUB,
45 ExtensionType::EXTERNAL_SENDERS,
46 ];
47
48 pub const fn new(raw_value: u16) -> Self {
50 ExtensionType(raw_value)
51 }
52
53 pub const fn raw_value(&self) -> u16 {
55 self.0
56 }
57
58 pub const fn is_default(&self) -> bool {
61 self.0 <= 5
62 }
63}
64
65impl From<u16> for ExtensionType {
66 fn from(value: u16) -> Self {
67 ExtensionType(value)
68 }
69}
70
71impl Deref for ExtensionType {
72 type Target = u16;
73
74 fn deref(&self) -> &Self::Target {
75 &self.0
76 }
77}
78
/// Errors returned when converting between a raw [`Extension`] and a typed extension value.
#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum ExtensionError {
    /// Serializing the typed value failed; wraps the underlying serializer error.
    #[cfg_attr(feature = "std", error(transparent))]
    SerializationError(AnyError),
    /// Decoding the raw extension bytes failed; wraps the underlying decoder error.
    #[cfg_attr(feature = "std", error(transparent))]
    DeserializationError(AnyError),
    /// The extension's type tag did not match the expected type; carries the tag actually found.
    #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
    IncorrectType(ExtensionType),
}
89
90impl IntoAnyError for ExtensionError {
91 #[cfg(feature = "std")]
92 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
93 Ok(self.into())
94 }
95}
96
/// A raw extension: a type tag plus its already-encoded payload bytes.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must build values via
/// `Extension::new`.
// NOTE(review): MLS codec traits are derived, so field order presumably defines
// the wire layout — do not reorder fields.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Extension {
    /// Identifier of this extension.
    pub extension_type: ExtensionType,
    /// Opaque payload bytes; encoded with `mls_rs_codec::byte_vec` and, under the
    /// `serde` feature, serialized via `crate::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
    pub extension_data: Vec<u8>,
}
113
114impl Debug for Extension {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("Extension")
117 .field("extension_type", &self.extension_type)
118 .field(
119 "extension_data",
120 &crate::debug::pretty_bytes(&self.extension_data),
121 )
122 .finish()
123 }
124}
125
126impl Extension {
127 pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
129 Extension {
130 extension_type,
131 extension_data,
132 }
133 }
134
135 pub fn extension_type(&self) -> ExtensionType {
137 self.extension_type
138 }
139
140 pub fn extension_data(&self) -> &[u8] {
142 &self.extension_data
143 }
144}
145
146pub trait MlsExtension: Sized {
148 type SerializationError: IntoAnyError;
150
151 type DeserializationError: IntoAnyError;
153
154 fn extension_type() -> ExtensionType;
156
157 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;
159
160 fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;
162
163 fn into_extension(self) -> Result<Extension, ExtensionError> {
165 Ok(Extension::new(
166 Self::extension_type(),
167 self.to_bytes()
168 .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
169 ))
170 }
171
172 fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
174 if ext.extension_type != Self::extension_type() {
175 return Err(ExtensionError::IncorrectType(ext.extension_type));
176 }
177
178 Self::from_bytes(&ext.extension_data)
179 .map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
180 }
181}
182
183pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
186 fn extension_type() -> ExtensionType;
187}
188
189impl<T> MlsExtension for T
190where
191 T: MlsCodecExtension,
192{
193 type SerializationError = mls_rs_codec::Error;
194 type DeserializationError = mls_rs_codec::Error;
195
196 fn extension_type() -> ExtensionType {
197 <Self as MlsCodecExtension>::extension_type()
198 }
199
200 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
201 self.mls_encode_to_vec()
202 }
203
204 fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
205 Self::mls_decode(&mut &*data)
206 }
207}
208
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};

    /// Hand-written `MlsExtension`: type 42, fixed one-byte payload `[0]`.
    struct TestExtension;

    /// Codec-derived extension with type 43, going through the blanket `MlsExtension` impl.
    #[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
    struct AnotherTestExtension;

    impl MlsCodecExtension for AnotherTestExtension {
        fn extension_type() -> ExtensionType {
            ExtensionType::new(43)
        }
    }

    impl MlsExtension for TestExtension {
        type SerializationError = Infallible;
        type DeserializationError = Infallible;

        fn extension_type() -> ExtensionType {
            ExtensionType::new(42)
        }

        fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
            Ok(vec![0])
        }

        fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> {
            Ok(Self)
        }
    }

    #[test]
    fn into_extension() {
        let built = TestExtension.into_extension().unwrap();
        let expected = Extension::new(ExtensionType::new(42), vec![0]);
        assert_eq!(built, expected)
    }

    #[test]
    fn incorrect_type_is_discovered() {
        let mismatched = Extension::new(ExtensionType::new(42), vec![0]);
        let outcome = AnotherTestExtension::from_extension(&mismatched);

        assert_matches!(
            outcome,
            Err(ExtensionError::IncorrectType(found)) if found == ExtensionType::new(42)
        );
    }
}