mls_rs_core/extension/
list.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use super::{Extension, ExtensionError, ExtensionType, MlsExtension};
6use alloc::vec::Vec;
7use core::ops::Deref;
8use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
9
10/// A collection of MLS [Extensions](super::Extension).
11///
12///
13/// # Warning
14///
15/// Extension lists require that each type of extension has at most one entry.
16#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
17#[cfg_attr(
18    all(feature = "ffi", not(test)),
19    safer_ffi_gen::ffi_type(clone, opaque)
20)]
21#[derive(Debug, Clone, Default, MlsSize, MlsEncode, Eq)]
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23pub struct ExtensionList(Vec<Extension>);
24
25impl Deref for ExtensionList {
26    type Target = Vec<Extension>;
27
28    fn deref(&self) -> &Self::Target {
29        &self.0
30    }
31}
32
33impl PartialEq for ExtensionList {
34    fn eq(&self, other: &Self) -> bool {
35        self.len() == other.len()
36            && self
37                .iter()
38                .all(|ext| other.get(ext.extension_type).as_ref() == Some(ext))
39    }
40}
41
42impl MlsDecode for ExtensionList {
43    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
44        mls_rs_codec::iter::mls_decode_collection(reader, |data| {
45            let mut list = ExtensionList::new();
46
47            while !data.is_empty() {
48                let ext = Extension::mls_decode(data)?;
49                let ext_type = ext.extension_type;
50
51                if list.0.iter().any(|e| e.extension_type == ext_type) {
52                    // #[cfg(feature = "std")]
53                    // return Err(mls_rs_codec::Error::Custom(format!(
54                    //    "Extension list has duplicate extension of type {ext_type:?}"
55                    // )));
56
57                    // #[cfg(not(feature = "std"))]
58                    return Err(mls_rs_codec::Error::Custom(1));
59                }
60
61                list.0.push(ext);
62            }
63
64            Ok(list)
65        })
66    }
67}
68
69impl From<Vec<Extension>> for ExtensionList {
70    fn from(extensions: Vec<Extension>) -> Self {
71        extensions.into_iter().collect()
72    }
73}
74
75impl Extend<Extension> for ExtensionList {
76    fn extend<T: IntoIterator<Item = Extension>>(&mut self, iter: T) {
77        iter.into_iter().for_each(|ext| self.set(ext));
78    }
79}
80
81impl FromIterator<Extension> for ExtensionList {
82    fn from_iter<T: IntoIterator<Item = Extension>>(iter: T) -> Self {
83        let mut list = Self::new();
84        list.extend(iter);
85        list
86    }
87}
88
89impl ExtensionList {
90    /// Create a new empty extension list.
91    pub fn new() -> ExtensionList {
92        Default::default()
93    }
94
95    /// Retrieve an extension by providing a type that implements the
96    /// [MlsExtension](super::MlsExtension) trait.
97    ///
98    /// Returns an error if the underlying deserialization of the extension
99    /// data fails.
100    pub fn get_as<E: MlsExtension>(&self) -> Result<Option<E>, ExtensionError> {
101        self.0
102            .iter()
103            .find(|e| e.extension_type == E::extension_type())
104            .map(E::from_extension)
105            .transpose()
106    }
107
108    /// Determine if a specific extension exists within the list.
109    pub fn has_extension(&self, ext_id: ExtensionType) -> bool {
110        self.0.iter().any(|e| e.extension_type == ext_id)
111    }
112
113    /// Set an extension in the list based on a provided type that implements
114    /// the [MlsExtension](super::MlsExtension) trait.
115    ///
116    /// If there is already an entry in the list for the same extension type,
117    /// then the prior value is removed as part of the insertion.
118    ///
119    /// This function will return an error if `ext` fails to serialize
120    /// properly.
121    pub fn set_from<E: MlsExtension>(&mut self, ext: E) -> Result<(), ExtensionError> {
122        let ext = ext.into_extension()?;
123        self.set(ext);
124        Ok(())
125    }
126
127    /// Set an extension in the list based on a raw
128    /// [Extension](super::Extension) value.
129    ///
130    /// If there is already an entry in the list for the same extension type,
131    /// then the prior value is removed as part of the insertion.
132    pub fn set(&mut self, ext: Extension) {
133        let mut found = self
134            .0
135            .iter_mut()
136            .find(|e| e.extension_type == ext.extension_type);
137
138        if let Some(found) = found.take() {
139            *found = ext;
140        } else {
141            self.0.push(ext);
142        }
143    }
144
145    /// Get a raw [Extension](super::Extension) value based on an
146    /// [ExtensionType](super::ExtensionType).
147    pub fn get(&self, extension_type: ExtensionType) -> Option<Extension> {
148        self.0
149            .iter()
150            .find(|e| e.extension_type == extension_type)
151            .cloned()
152    }
153
154    /// Remove an extension from the list by
155    /// [ExtensionType](super::ExtensionType)
156    pub fn remove(&mut self, ext_type: ExtensionType) {
157        self.0.retain(|e| e.extension_type != ext_type)
158    }
159
160    /// Append another extension list to this one.
161    ///
162    /// If there is already an entry in the list for the same extension type,
163    /// then the existing value is removed.
164    pub fn append(&mut self, others: Self) {
165        self.0.extend(others.0);
166    }
167}
168
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use crate::extension::{
        list::ExtensionList, Extension, ExtensionType, MlsCodecExtension, MlsExtension,
    };

    // Three distinct test extension types, registered as types 128, 129 and 130.

    #[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
    struct TestExtensionA(u32);

    #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)]
    struct TestExtensionB(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);

    #[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq, Eq)]
    struct TestExtensionC(u8);

    impl MlsCodecExtension for TestExtensionA {
        fn extension_type() -> ExtensionType {
            ExtensionType(128)
        }
    }

    impl MlsCodecExtension for TestExtensionB {
        fn extension_type() -> ExtensionType {
            ExtensionType(129)
        }
    }

    impl MlsCodecExtension for TestExtensionC {
        fn extension_type() -> ExtensionType {
            ExtensionType(130)
        }
    }

    /// Typed extensions stored with `set_from` can be read back with `get_as`.
    #[test]
    fn test_extension_list_get_set_from_get_as() {
        let first = TestExtensionA(0);
        let second = TestExtensionB(vec![1]);

        let mut extensions = ExtensionList::new();
        extensions.set_from(first.clone()).unwrap();
        extensions.set_from(second.clone()).unwrap();

        assert_eq!(extensions.len(), 2);
        assert_eq!(extensions.get_as::<TestExtensionA>().unwrap(), Some(first));
        assert_eq!(extensions.get_as::<TestExtensionB>().unwrap(), Some(second));
    }

    /// Raw extensions stored with `set` can be read back with `get`.
    #[test]
    fn test_extension_list_get_set() {
        let first = Extension::new(ExtensionType(254), vec![0, 1, 2]);
        let second = Extension::new(ExtensionType(255), vec![4, 5, 6]);

        let mut extensions = ExtensionList::new();
        extensions.set(first.clone());
        extensions.set(second.clone());

        assert_eq!(extensions.len(), 2);
        assert_eq!(extensions.get(ExtensionType(254)), Some(first));
        assert_eq!(extensions.get(ExtensionType(255)), Some(second));
    }

    /// Setting the same extension type twice keeps only the latest value.
    #[test]
    fn extension_list_can_overwrite_values() {
        let replacement = TestExtensionA(1);

        let mut extensions = ExtensionList::new();
        extensions.set_from(TestExtensionA(0)).unwrap();
        extensions.set_from(replacement.clone()).unwrap();

        assert_eq!(
            extensions.get_as::<TestExtensionA>().unwrap(),
            Some(replacement)
        );
    }

    /// Lookups for a type that was never stored yield `None`.
    #[test]
    fn extension_list_will_return_none_for_type_not_stored() {
        let type_a = <TestExtensionA as MlsCodecExtension>::extension_type();
        let type_b = <TestExtensionB as MlsCodecExtension>::extension_type();

        let mut extensions = ExtensionList::new();

        // Nothing is stored yet.
        assert!(extensions.get_as::<TestExtensionA>().unwrap().is_none());
        assert!(extensions.get(type_a).is_none());

        extensions.set_from(TestExtensionA(1)).unwrap();

        // Only type A is stored, so type B is still absent.
        assert!(extensions.get_as::<TestExtensionB>().unwrap().is_none());
        assert!(extensions.get(type_b).is_none());
    }

    /// `has_extension` reports presence by extension type.
    #[test]
    fn test_extension_list_has_ext() {
        let mut extensions = ExtensionList::new();
        extensions.set_from(TestExtensionA(255)).unwrap();

        let stored_type = <TestExtensionA as MlsCodecExtension>::extension_type();

        assert!(extensions.has_extension(stored_type));
        assert!(!extensions.has_extension(42.into()));
    }

    // Plain vector wrapper used to produce reference encodings, including
    // ones with duplicate extension types.
    #[derive(MlsEncode, MlsSize)]
    struct ExtensionsVec(Vec<Extension>);

    #[test]
    fn extension_list_is_serialized_like_a_sequence_of_extensions() {
        let raw_extensions = vec![
            Extension::new(ExtensionType(128), vec![0, 1, 2, 3]),
            Extension::new(ExtensionType(129), vec![1, 2, 3, 4]),
        ];

        let as_list: ExtensionList = ExtensionList::from(raw_extensions.clone());
        let reference_bytes = ExtensionsVec(raw_extensions).mls_encode_to_vec().unwrap();

        assert_eq!(reference_bytes, as_list.mls_encode_to_vec().unwrap());
    }

    #[test]
    fn deserializing_extension_list_fails_on_duplicate_extension() {
        // Two entries of the same extension type.
        let duplicated = ExtensionsVec(vec![
            TestExtensionA(1).into_extension().unwrap(),
            TestExtensionA(2).into_extension().unwrap(),
        ]);

        let encoded = duplicated.mls_encode_to_vec().unwrap();
        let decoded = ExtensionList::mls_decode(&mut encoded.as_slice());

        assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(_)));
    }

    #[test]
    fn extension_list_equality_does_not_consider_order() {
        let entries = [
            TestExtensionA(33).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ];

        let forward: ExtensionList = entries.iter().cloned().collect();
        let backward: ExtensionList = entries.iter().rev().cloned().collect();

        assert_eq!(forward, backward);
    }

    #[test]
    fn extending_extension_list_maintains_extension_uniqueness() {
        let mut extensions = ExtensionList::new();
        extensions.set_from(TestExtensionA(33)).unwrap();
        extensions.set_from(TestExtensionC(34)).unwrap();

        // Type A appears twice in the new entries; the last value must win.
        let additions = [
            TestExtensionA(35).into_extension().unwrap(),
            TestExtensionB(vec![36]).into_extension().unwrap(),
            TestExtensionA(37).into_extension().unwrap(),
        ];
        extensions.extend(additions);

        let expected = ExtensionList(vec![
            TestExtensionA(37).into_extension().unwrap(),
            TestExtensionB(vec![36]).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ]);

        assert_eq!(extensions, expected);
    }

    #[test]
    fn extension_list_from_vec_maintains_extension_uniqueness() {
        // Type A appears twice in the input; the last value must win.
        let input = vec![
            TestExtensionA(33).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
            TestExtensionA(35).into_extension().unwrap(),
        ];

        let extensions = ExtensionList::from(input);

        let expected = ExtensionList(vec![
            TestExtensionA(35).into_extension().unwrap(),
            TestExtensionC(34).into_extension().unwrap(),
        ]);

        assert_eq!(extensions, expected);
    }
}