1use super::{Extension, ExtensionError, ExtensionType, MlsExtension};
6use alloc::vec::Vec;
7use core::ops::Deref;
8use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
9
/// A collection of MLS [`Extension`] values keyed by [`ExtensionType`].
///
/// `set`, `Extend`, `FromIterator` and `From<Vec<Extension>>` replace an
/// existing extension of the same type instead of adding a second one, and
/// decoding rejects input that contains the same extension type twice.
/// Equality ignores the order in which extensions are stored.
///
/// NOTE(review): the `serde` / `arbitrary` derives construct the inner `Vec`
/// directly and do not appear to check for duplicate extension types —
/// confirm whether that is intended.
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone, Default, MlsSize, MlsEncode, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtensionList(Vec<Extension>);
24
25impl Deref for ExtensionList {
26 type Target = Vec<Extension>;
27
28 fn deref(&self) -> &Self::Target {
29 &self.0
30 }
31}
32
33impl PartialEq for ExtensionList {
34 fn eq(&self, other: &Self) -> bool {
35 self.len() == other.len()
36 && self
37 .iter()
38 .all(|ext| other.get(ext.extension_type).as_ref() == Some(ext))
39 }
40}
41
42impl MlsDecode for ExtensionList {
43 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
44 mls_rs_codec::iter::mls_decode_collection(reader, |data| {
45 let mut list = ExtensionList::new();
46
47 while !data.is_empty() {
48 let ext = Extension::mls_decode(data)?;
49 let ext_type = ext.extension_type;
50
51 if list.0.iter().any(|e| e.extension_type == ext_type) {
52 return Err(mls_rs_codec::Error::Custom(1));
59 }
60
61 list.0.push(ext);
62 }
63
64 Ok(list)
65 })
66 }
67}
68
69impl From<Vec<Extension>> for ExtensionList {
70 fn from(extensions: Vec<Extension>) -> Self {
71 extensions.into_iter().collect()
72 }
73}
74
75impl Extend<Extension> for ExtensionList {
76 fn extend<T: IntoIterator<Item = Extension>>(&mut self, iter: T) {
77 iter.into_iter().for_each(|ext| self.set(ext));
78 }
79}
80
81impl FromIterator<Extension> for ExtensionList {
82 fn from_iter<T: IntoIterator<Item = Extension>>(iter: T) -> Self {
83 let mut list = Self::new();
84 list.extend(iter);
85 list
86 }
87}
88
89impl ExtensionList {
90 pub fn new() -> ExtensionList {
92 Default::default()
93 }
94
95 pub fn get_as<E: MlsExtension>(&self) -> Result<Option<E>, ExtensionError> {
101 self.0
102 .iter()
103 .find(|e| e.extension_type == E::extension_type())
104 .map(E::from_extension)
105 .transpose()
106 }
107
108 pub fn has_extension(&self, ext_id: ExtensionType) -> bool {
110 self.0.iter().any(|e| e.extension_type == ext_id)
111 }
112
113 pub fn set_from<E: MlsExtension>(&mut self, ext: E) -> Result<(), ExtensionError> {
122 let ext = ext.into_extension()?;
123 self.set(ext);
124 Ok(())
125 }
126
127 pub fn set(&mut self, ext: Extension) {
133 let mut found = self
134 .0
135 .iter_mut()
136 .find(|e| e.extension_type == ext.extension_type);
137
138 if let Some(found) = found.take() {
139 *found = ext;
140 } else {
141 self.0.push(ext);
142 }
143 }
144
145 pub fn get(&self, extension_type: ExtensionType) -> Option<Extension> {
148 self.0
149 .iter()
150 .find(|e| e.extension_type == extension_type)
151 .cloned()
152 }
153
154 pub fn remove(&mut self, ext_type: ExtensionType) {
157 self.0.retain(|e| e.extension_type != ext_type)
158 }
159
160 pub fn append(&mut self, others: Self) {
165 self.0.extend(others.0);
166 }
167}
168
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use crate::extension::{
        list::ExtensionList, Extension, ExtensionType, MlsCodecExtension, MlsExtension,
    };

    // Three minimal extension payloads with distinct extension types
    // (128, 129 and 130), used to exercise the list operations below.
    #[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
    struct TestExtensionA(u32);

    #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)]
    struct TestExtensionB(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);

    #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)]
    struct TestExtensionC(u8);

    impl MlsCodecExtension for TestExtensionA {
        fn extension_type() -> ExtensionType {
            ExtensionType(128)
        }
    }

    impl MlsCodecExtension for TestExtensionB {
        fn extension_type() -> ExtensionType {
            ExtensionType(129)
        }
    }

    impl MlsCodecExtension for TestExtensionC {
        fn extension_type() -> ExtensionType {
            ExtensionType(130)
        }
    }

    // Typed extensions stored with `set_from` round-trip through `get_as`.
    #[test]
    fn test_extension_list_get_set_from_get_as() {
        let mut list = ExtensionList::new();

        let ext_a = TestExtensionA(0);
        let ext_b = TestExtensionB(vec![1]);

        list.set_from(ext_a.clone()).unwrap();
        list.set_from(ext_b.clone()).unwrap();

        assert_eq!(list.len(), 2);
        assert_eq!(list.get_as::<TestExtensionA>().unwrap(), Some(ext_a));
        assert_eq!(list.get_as::<TestExtensionB>().unwrap(), Some(ext_b));
    }

    // Raw extensions stored with `set` are returned unchanged by `get`.
    #[test]
    fn test_extension_list_get_set() {
        let mut list = ExtensionList::new();

        let ext_a = Extension::new(ExtensionType(254), vec![0, 1, 2]);
        let ext_b = Extension::new(ExtensionType(255), vec![4, 5, 6]);

        list.set(ext_a.clone());
        list.set(ext_b.clone());

        assert_eq!(list.len(), 2);
        assert_eq!(list.get(ExtensionType(254)), Some(ext_a));
        assert_eq!(list.get(ExtensionType(255)), Some(ext_b));
    }

    // Setting the same extension type twice keeps only the latest value.
    #[test]
    fn extension_list_can_overwrite_values() {
        let mut list = ExtensionList::new();

        let ext_1 = TestExtensionA(0);
        let ext_2 = TestExtensionA(1);

        list.set_from(ext_1).unwrap();
        list.set_from(ext_2.clone()).unwrap();

        assert_eq!(list.get_as::<TestExtensionA>().unwrap(), Some(ext_2));
    }

    // Lookups for absent types yield `None`, both before and after an
    // unrelated extension is inserted.
    #[test]
    fn extension_list_will_return_none_for_type_not_stored() {
        let mut list = ExtensionList::new();

        assert!(list.get_as::<TestExtensionA>().unwrap().is_none());

        assert!(list
            .get(<TestExtensionA as MlsCodecExtension>::extension_type())
            .is_none());

        list.set_from(TestExtensionA(1)).unwrap();

        assert!(list.get_as::<TestExtensionB>().unwrap().is_none());

        assert!(list
            .get(<TestExtensionB as MlsCodecExtension>::extension_type())
            .is_none());
    }

    // `has_extension` reports presence by extension type only.
    #[test]
    fn test_extension_list_has_ext() {
        let mut list = ExtensionList::new();

        let ext = TestExtensionA(255);

        list.set_from(ext).unwrap();

        assert!(list.has_extension(<TestExtensionA as MlsCodecExtension>::extension_type()));
        assert!(!list.has_extension(42.into()));
    }

    // A plain `Vec<Extension>` wrapper with no uniqueness handling. It serves
    // as the reference encoding and lets tests encode inputs that contain
    // duplicate extension types.
    #[derive(MlsEncode, MlsSize)]
    struct ExtensionsVec(Vec<Extension>);

    // The list must encode exactly like the equivalent `Vec<Extension>`.
    #[test]
    fn extension_list_is_serialized_like_a_sequence_of_extensions() {
        let extension_vec = vec![
            Extension::new(ExtensionType(128), vec![0, 1, 2, 3]),
            Extension::new(ExtensionType(129), vec![1, 2, 3, 4]),
        ];

        let extension_list: ExtensionList = ExtensionList::from(extension_vec.clone());

        assert_eq!(
            ExtensionsVec(extension_vec).mls_encode_to_vec().unwrap(),
            extension_list.mls_encode_to_vec().unwrap(),
        );
    }

    // Decoding rejects an encoded list that repeats an extension type.
    #[test]
    fn deserializing_extension_list_fails_on_duplicate_extension() {
        let extensions = ExtensionsVec(vec![
            TestExtensionA(1).into_extension().unwrap(),
            TestExtensionA(2).into_extension().unwrap(),
        ]);

        let serialized_extensions = extensions.mls_encode_to_vec().unwrap();

        assert_matches!(
            ExtensionList::mls_decode(&mut &*serialized_extensions),
            Err(mls_rs_codec::Error::Custom(_))
        );
    }

    // The same extensions collected in opposite orders compare equal.
    #[test]
    fn extension_list_equality_does_not_consider_order() {
        let extensions = [
            TestExtensionA(33).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ];

        let a = extensions.iter().cloned().collect::<ExtensionList>();
        let b = extensions.iter().rev().cloned().collect::<ExtensionList>();

        assert_eq!(a, b);
    }

    // `extend` replaces existing entries of the same type with the latest
    // incoming value. `expected` is written in a different order than `list`
    // stores it, which relies on equality ignoring order.
    #[test]
    fn extending_extension_list_maintains_extension_uniqueness() {
        let mut list = ExtensionList::new();
        list.set_from(TestExtensionA(33)).unwrap();
        list.set_from(TestExtensionC(34)).unwrap();
        list.extend([
            TestExtensionA(35).into_extension().unwrap(),
            TestExtensionB(vec![36]).into_extension().unwrap(),
            TestExtensionA(37).into_extension().unwrap(),
        ]);

        let expected = ExtensionList(vec![
            TestExtensionA(37).into_extension().unwrap(),
            TestExtensionB(vec![36]).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ]);

        assert_eq!(list, expected);
    }

    // Converting from a `Vec` keeps the last value seen for each type.
    #[test]
    fn extension_list_from_vec_maintains_extension_uniqueness() {
        let list = ExtensionList::from(vec![
            TestExtensionA(33).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
            TestExtensionA(35).into_extension().unwrap(),
        ]);

        let expected = ExtensionList(vec![
            TestExtensionA(35).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ]);

        assert_eq!(list, expected);
    }
}