1use crate::{
2 crypto::HpkePublicKey,
3 group::{ExtensionType, ExternalSender, RequiredCapabilities},
4 macros::ref_forward_tls_impl,
5 tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6 tree::RatchetTree,
7};
8
9#[derive(
10 Debug,
11 Clone,
12 PartialEq,
13 Eq,
14 Hash,
15 tls_codec::TlsDeserialize,
16 tls_codec::TlsSerialize,
17 tls_codec::TlsSize,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20pub struct RatchetTreeExtension {
21 pub ratchet_tree: RatchetTree,
22}
23
/// An MLS extension, decoded into a typed variant when its extension type is known to
/// this crate and kept as raw bytes ([`Extension::Arbitrary`]) otherwise.
///
/// The wire identifier of each variant is given by `ExtensionType::from(&Extension)`.
// NOTE(review): variant order is observable through serde's variant indices in
// index-based formats, so avoid reordering variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u16)]
pub enum Extension {
    /// Opaque application identifier (`ExtensionType::APPLICATION_ID`), encoded as a
    /// variable-length byte vector inside `extension_data`.
    ApplicationId(Vec<u8>),
    /// Ratchet tree extension (`ExtensionType::RATCHET_TREE`).
    RatchetTree(RatchetTreeExtension),
    /// Required capabilities extension (`ExtensionType::REQUIRED_CAPABILITIES`).
    RequiredCapabilities(RequiredCapabilities),
    /// External public key extension (`ExtensionType::EXTERNAL_PUB`).
    ExternalPub(ExternalPub),
    /// List of external senders (`ExtensionType::EXTERNAL_SENDERS`).
    ExternalSenders(Vec<ExternalSender>),
    /// Application data dictionary (`ExtensionType::APPLICATION_DATA_DICTIONARY`);
    /// only available with the `draft-ietf-mls-extensions` feature.
    #[cfg(feature = "draft-ietf-mls-extensions")]
    ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
    /// Supported wire formats (`ExtensionType::SUPPORTED_WIRE_FORMATS`);
    /// only available with the `draft-ietf-mls-extensions` feature.
    #[cfg(feature = "draft-ietf-mls-extensions")]
    SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
    /// Required wire formats (`ExtensionType::REQUIRED_WIRE_FORMATS`);
    /// only available with the `draft-ietf-mls-extensions` feature.
    #[cfg(feature = "draft-ietf-mls-extensions")]
    RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
    /// Ratchet tree source domains (`ExtensionType::RATCHET_TREE_SOURCE_DOMAINS`);
    /// only available with the `draft-mahy-mls-ratchet-tree-options` feature.
    #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
    RatchetTreeSourceDomains(
        crate::drafts::ratchet_tree_options::RatchetTreeSourceDomainsExtension,
    ),
    /// Any extension type not decoded into one of the variants above; its identifier and
    /// payload bytes are preserved verbatim.
    Arbitrary(ArbitraryExtension),
}
51
52impl From<&Extension> for ExtensionType {
53 fn from(value: &Extension) -> Self {
54 ExtensionType::new_unchecked(match value {
55 Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
56 Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
57 Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
58 Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
59 Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
60 #[cfg(feature = "draft-ietf-mls-extensions")]
61 Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
62 #[cfg(feature = "draft-ietf-mls-extensions")]
63 Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
64 #[cfg(feature = "draft-ietf-mls-extensions")]
65 Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
66 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
67 Extension::RatchetTreeSourceDomains(_) => ExtensionType::RATCHET_TREE_SOURCE_DOMAINS,
68 Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
69 (**extension_id) as u16
70 }
71 })
72 }
73}
74
75impl Extension {
76 pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
77 use tls_codec::Deserialize as _;
78 Ok(match extension_id {
79 ExtensionType::APPLICATION_ID => {
80 Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
81 }
82 ExtensionType::RATCHET_TREE => Self::RatchetTree(
83 RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
84 ),
85 ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
86 RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
87 ),
88 ExtensionType::EXTERNAL_PUB => {
89 Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
90 }
91 ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
92 Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
93 ),
94 #[cfg(feature = "draft-ietf-mls-extensions")]
95 ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
96 crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
97 &extension_data,
98 )?,
99 ),
100 #[cfg(feature = "draft-ietf-mls-extensions")]
101 ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
102 #[cfg(feature = "draft-ietf-mls-extensions")]
103 ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
104 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
105 ExtensionType::RATCHET_TREE_SOURCE_DOMAINS => Self::RatchetTreeSourceDomains(<_>::tls_deserialize_exact(&extension_data)?),
106 _ => Self::Arbitrary(ArbitraryExtension {
107 extension_id: ExtensionType::new_unchecked(extension_id),
108 extension_data,
109 }),
110 })
111 }
112
113 pub fn ext_type(&self) -> ExtensionType {
114 self.into()
115 }
116}
117
118impl tls_codec::Size for Extension {
119 fn tls_serialized_len(&self) -> usize {
120 let ext_type_len = ExtensionType::from(self).tls_serialized_len();
121 let ext_value_len = match self {
122 Extension::ApplicationId(data) => {
123 tls_serialized_len_as_vlvec(data.tls_serialized_len())
124 }
125 Extension::RatchetTree(nodes) => {
126 tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
127 }
128 Extension::RequiredCapabilities(caps) => {
129 tls_serialized_len_as_vlvec(caps.tls_serialized_len())
130 }
131 Extension::ExternalPub(ext_pub) => {
132 tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
133 }
134 Extension::ExternalSenders(ext_senders) => {
135 tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
136 }
137 #[cfg(feature = "draft-ietf-mls-extensions")]
138 Extension::ApplicationData(app_data_dict) => {
139 tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
140 }
141 #[cfg(feature = "draft-ietf-mls-extensions")]
142 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
143 tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
144 }
145 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
146 Extension::RatchetTreeSourceDomains(rtsd) => {
147 tls_serialized_len_as_vlvec(rtsd.tls_serialized_len())
148 }
149 Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
150 tls_serialized_len_as_vlvec(extension_data.len())
151 }
152 };
153
154 ext_type_len + ext_value_len
155 }
156}
157
158impl tls_codec::Serialize for Extension {
159 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
160 use tls_codec::Size as _;
161
162 let extension_id = ExtensionType::from(self);
163 let mut written = extension_id.tls_serialize(writer)?;
164
165 let extdata_len = self.tls_serialized_len() - written;
167 let mut extension_data = Vec::with_capacity(extdata_len);
168
169 let _ = match self {
170 Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
171 Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
172 Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
173 Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
174 Extension::ExternalSenders(ext_senders) => {
175 ext_senders.tls_serialize(&mut extension_data)?
176 }
177 #[cfg(feature = "draft-ietf-mls-extensions")]
178 Extension::ApplicationData(app_data_dict) => {
179 app_data_dict.tls_serialize(&mut extension_data)?
180 }
181 #[cfg(feature = "draft-ietf-mls-extensions")]
182 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
183 wfs.tls_serialize(&mut extension_data)?
184 }
185 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
186 Extension::RatchetTreeSourceDomains(rtsd) => rtsd.tls_serialize(&mut extension_data)?,
187 Extension::Arbitrary(ArbitraryExtension {
188 extension_data: arbitrary_ext_data,
189 ..
190 }) => {
191 use std::io::Write as _;
192 extension_data.write_all(arbitrary_ext_data)?;
193 arbitrary_ext_data.len()
194 }
195 };
196
197 written += extension_data.tls_serialize(writer)?;
198
199 Ok(written)
200 }
201}
202
203impl tls_codec::Deserialize for Extension {
204 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
205 where
206 Self: Sized,
207 {
208 Self::new(
209 *ExtensionType::tls_deserialize(bytes)?,
210 vlbytes::tls_deserialize(bytes)?,
211 )
212 .map_err(|e| match e {
213 crate::MlsSpecError::TlsCodecError(e) => e,
214 _ => tls_codec::Error::DecodingError(e.to_string()),
215 })
216 }
217}
218
219ref_forward_tls_impl!(Extension);
220
221#[derive(
222 Debug,
223 Clone,
224 PartialEq,
225 Eq,
226 Hash,
227 tls_codec::TlsSerialize,
228 tls_codec::TlsDeserialize,
229 tls_codec::TlsSize,
230)]
231#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
232pub struct ExternalPub {
233 pub external_pub: HpkePublicKey,
234}
235
236#[derive(Debug, Clone, PartialEq, Eq, Hash)]
237#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
238pub struct ArbitraryExtension {
239 pub extension_id: ExtensionType,
240 pub extension_data: Vec<u8>,
241}