1use crate::{MlsDecode, MlsEncode, MlsSize};
6use alloc::vec::Vec;
7
8impl<T: MlsSize> MlsSize for Option<T> {
9 #[inline]
10 fn mls_encoded_len(&self) -> usize {
11 1 + match self {
12 Some(v) => v.mls_encoded_len(),
13 None => 0,
14 }
15 }
16}
17
18impl<T: MlsEncode> MlsEncode for Option<T> {
19 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error> {
20 if let Some(item) = self {
21 writer.push(1);
22 item.mls_encode(writer)
23 } else {
24 writer.push(0);
25 Ok(())
26 }
27 }
28}
29
30impl<T: MlsDecode> MlsDecode for Option<T> {
31 fn mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error> {
32 match u8::mls_decode(reader)? {
33 0 => Ok(None),
34 1 => T::mls_decode(reader).map(Some),
35 n => Err(crate::Error::OptionOutOfRange(n)),
36 }
37 }
38}
39
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{Error, MlsDecode, MlsEncode};
    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Encodes `value`, then decodes the produced bytes back into an `Option<u8>`.
    fn encode_then_decode(value: Option<u8>) -> Option<u8> {
        let encoded = value.mls_encode_to_vec().unwrap();
        let mut cursor: &[u8] = &encoded;
        Option::<u8>::mls_decode(&mut cursor).unwrap()
    }

    #[test]
    fn none_is_serialized_correctly() {
        // An absent value is just the `0` presence flag.
        let encoded = None::<u8>.mls_encode_to_vec().unwrap();
        assert_eq!(encoded, vec![0u8]);
    }

    #[test]
    fn some_is_serialized_correctly() {
        // Presence flag `1`, followed by the inner byte.
        let encoded = Some(2u8).mls_encode_to_vec().unwrap();
        assert_eq!(encoded, vec![1u8, 2]);
    }

    #[test]
    fn none_round_trips() {
        assert_eq!(encode_then_decode(None), None);
    }

    #[test]
    fn some_round_trips() {
        assert_eq!(encode_then_decode(Some(32)), Some(32));
    }

    #[test]
    fn deserializing_invalid_discriminant_fails() {
        // `2` is neither the absent (0) nor the present (1) marker.
        let mut input: &[u8] = &[2];
        assert_matches!(
            Option::<u8>::mls_decode(&mut input),
            Err(Error::OptionOutOfRange(_))
        );
    }
}