mls_rs_codec/
option.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use crate::{MlsDecode, MlsEncode, MlsSize};
6use alloc::vec::Vec;
7
8impl<T: MlsSize> MlsSize for Option<T> {
9    #[inline]
10    fn mls_encoded_len(&self) -> usize {
11        1 + match self {
12            Some(v) => v.mls_encoded_len(),
13            None => 0,
14        }
15    }
16}
17
18impl<T: MlsEncode> MlsEncode for Option<T> {
19    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error> {
20        if let Some(item) = self {
21            writer.push(1);
22            item.mls_encode(writer)
23        } else {
24            writer.push(0);
25            Ok(())
26        }
27    }
28}
29
30impl<T: MlsDecode> MlsDecode for Option<T> {
31    fn mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error> {
32        match u8::mls_decode(reader)? {
33            0 => Ok(None),
34            1 => T::mls_decode(reader).map(Some),
35            n => Err(crate::Error::OptionOutOfRange(n)),
36        }
37    }
38}
39
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{Error, MlsDecode, MlsEncode};
    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// `None` encodes to a lone `0` presence byte.
    #[test]
    fn none_is_serialized_correctly() {
        let encoded = None::<u8>.mls_encode_to_vec().unwrap();
        assert_eq!(vec![0u8], encoded);
    }

    /// `Some(v)` encodes to a `1` presence byte followed by the encoding of `v`.
    #[test]
    fn some_is_serialized_correctly() {
        let encoded = Some(2u8).mls_encode_to_vec().unwrap();
        assert_eq!(vec![1u8, 2], encoded);
    }

    /// Encoding and then decoding `None` yields `None` again.
    #[test]
    fn none_round_trips() {
        let original = None::<u8>;
        let encoded = original.mls_encode_to_vec().unwrap();
        let decoded = Option::<u8>::mls_decode(&mut encoded.as_slice()).unwrap();
        assert_eq!(original, decoded);
    }

    /// Encoding and then decoding `Some(v)` yields the same value.
    #[test]
    fn some_round_trips() {
        let original = Some(32u8);
        let encoded = original.mls_encode_to_vec().unwrap();
        let decoded = Option::<u8>::mls_decode(&mut encoded.as_slice()).unwrap();
        assert_eq!(original, decoded);
    }

    /// A presence byte other than `0` or `1` is rejected.
    #[test]
    fn deserializing_invalid_discriminant_fails() {
        let mut input: &[u8] = &[2u8];
        let outcome = Option::<u8>::mls_decode(&mut input);
        assert_matches!(outcome, Err(Error::OptionOutOfRange(_)));
    }
}