1use core::{
6 fmt::{self, Debug},
7 ops::Deref,
8};
9
10use crate::error::{AnyError, IntoAnyError};
11use alloc::vec::Vec;
12use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
13
14mod list;
15
16pub use list::*;
17
18#[derive(
21 Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
22)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[repr(transparent)]
27pub struct ExtensionType(u16);
28
29#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
30impl ExtensionType {
31 pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
32 pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
33 pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
34 pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
35 pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
36
37 #[cfg(feature = "last_resort_key_package_ext")]
38 pub const LAST_RESORT_KEY_PACKAGE: ExtensionType = ExtensionType(0x000A);
39
40 pub const DEFAULT: &'static [ExtensionType] = &[
43 ExtensionType::APPLICATION_ID,
44 ExtensionType::RATCHET_TREE,
45 ExtensionType::REQUIRED_CAPABILITIES,
46 ExtensionType::EXTERNAL_PUB,
47 ExtensionType::EXTERNAL_SENDERS,
48 ];
49
50 pub const fn new(raw_value: u16) -> Self {
52 ExtensionType(raw_value)
53 }
54
55 pub const fn raw_value(&self) -> u16 {
57 self.0
58 }
59
60 pub const fn is_default(&self) -> bool {
63 self.0 <= 5
64 }
65}
66
67impl From<u16> for ExtensionType {
68 fn from(value: u16) -> Self {
69 ExtensionType(value)
70 }
71}
72
73impl Deref for ExtensionType {
74 type Target = u16;
75
76 fn deref(&self) -> &Self::Target {
77 &self.0
78 }
79}
80
81#[derive(Debug)]
82#[cfg_attr(feature = "std", derive(thiserror::Error))]
83pub enum ExtensionError {
84 #[cfg_attr(feature = "std", error(transparent))]
85 SerializationError(AnyError),
86 #[cfg_attr(feature = "std", error(transparent))]
87 DeserializationError(AnyError),
88 #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
89 IncorrectType(ExtensionType),
90}
91
92impl IntoAnyError for ExtensionError {
93 #[cfg(feature = "std")]
94 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
95 Ok(self.into())
96 }
97}
98
99#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
100#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
101#[cfg_attr(
102 all(feature = "ffi", not(test)),
103 safer_ffi_gen::ffi_type(clone, opaque)
104)]
105#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
106#[non_exhaustive]
107pub struct Extension {
112 pub extension_type: ExtensionType,
114 #[mls_codec(with = "mls_rs_codec::byte_vec")]
116 #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
117 pub extension_data: Vec<u8>,
118}
119
120impl Debug for Extension {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 f.debug_struct("Extension")
123 .field("extension_type", &self.extension_type)
124 .field(
125 "extension_data",
126 &crate::debug::pretty_bytes(&self.extension_data),
127 )
128 .finish()
129 }
130}
131
132#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
133impl Extension {
134 pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
136 Extension {
137 extension_type,
138 extension_data,
139 }
140 }
141
142 #[cfg(feature = "ffi")]
144 pub fn extension_type(&self) -> ExtensionType {
145 self.extension_type
146 }
147
148 #[cfg(feature = "ffi")]
150 pub fn extension_data(&self) -> &[u8] {
151 &self.extension_data
152 }
153}
154
155pub trait MlsExtension: Sized {
157 type SerializationError: IntoAnyError;
159
160 type DeserializationError: IntoAnyError;
162
163 fn extension_type() -> ExtensionType;
165
166 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;
168
169 fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;
171
172 fn into_extension(self) -> Result<Extension, ExtensionError> {
174 Ok(Extension::new(
175 Self::extension_type(),
176 self.to_bytes()
177 .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
178 ))
179 }
180
181 fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
183 if ext.extension_type != Self::extension_type() {
184 return Err(ExtensionError::IncorrectType(ext.extension_type));
185 }
186
187 Self::from_bytes(&ext.extension_data)
188 .map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
189 }
190}
191
192pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
195 fn extension_type() -> ExtensionType;
196}
197
198impl<T> MlsExtension for T
199where
200 T: MlsCodecExtension,
201{
202 type SerializationError = mls_rs_codec::Error;
203 type DeserializationError = mls_rs_codec::Error;
204
205 fn extension_type() -> ExtensionType {
206 <Self as MlsCodecExtension>::extension_type()
207 }
208
209 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
210 self.mls_encode_to_vec()
211 }
212
213 fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
214 Self::mls_decode(&mut &*data)
215 }
216}
217
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};

    /// Extension implementing `MlsExtension` by hand, with type 42.
    struct TestExtension;

    /// Extension that goes through the `MlsCodecExtension` blanket
    /// implementation, with type 43.
    #[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
    struct AnotherTestExtension;

    impl MlsExtension for TestExtension {
        type SerializationError = Infallible;
        type DeserializationError = Infallible;

        fn extension_type() -> ExtensionType {
            ExtensionType(42)
        }

        fn to_bytes(&self) -> Result<Vec<u8>, Infallible> {
            Ok(vec![0])
        }

        fn from_bytes(_data: &[u8]) -> Result<Self, Infallible> {
            Ok(Self)
        }
    }

    impl MlsCodecExtension for AnotherTestExtension {
        fn extension_type() -> ExtensionType {
            ExtensionType(43)
        }
    }

    #[test]
    fn into_extension() {
        let actual = TestExtension.into_extension().unwrap();
        let expected = Extension::new(ExtensionType::from(42), vec![0]);

        assert_eq!(actual, expected)
    }

    #[test]
    fn incorrect_type_is_discovered() {
        // Tagged with 42, while `AnotherTestExtension` expects 43.
        let mismatched = Extension::new(ExtensionType::from(42), vec![0]);

        assert_matches!(
            AnotherTestExtension::from_extension(&mismatched),
            Err(ExtensionError::IncorrectType(found)) if found == ExtensionType::from(42)
        );
    }
}
271}