1use crate::{
2 crypto::HpkePublicKey,
3 group::{ExtensionType, ExternalSender, RequiredCapabilities},
4 macros::ref_forward_tls_impl,
5 tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6 tree::RatchetTree,
7};
8
9#[derive(
10 Debug,
11 Clone,
12 PartialEq,
13 Eq,
14 Hash,
15 tls_codec::TlsDeserialize,
16 tls_codec::TlsSerialize,
17 tls_codec::TlsSize,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20pub struct RatchetTreeExtension {
21 pub ratchet_tree: RatchetTree,
22}
23
/// An MLS extension, decoded into a typed variant when its extension type is
/// one this crate understands (see `Extension::new`).
///
/// The wire-level extension type of each variant comes from
/// `From<&Extension> for ExtensionType`, not from the enum discriminant.
/// NOTE(review): keep the variant order stable — `#[repr(u16)]` assigns the
/// (implicit) discriminants by position, and index-based serde formats also
/// encode variants by position.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u16)]
pub enum Extension {
    /// `ExtensionType::APPLICATION_ID`: the decoded application identifier bytes.
    ApplicationId(Vec<u8>),
    /// `ExtensionType::RATCHET_TREE`.
    RatchetTree(RatchetTreeExtension),
    /// `ExtensionType::REQUIRED_CAPABILITIES`.
    RequiredCapabilities(RequiredCapabilities),
    /// `ExtensionType::EXTERNAL_PUB`.
    ExternalPub(ExternalPub),
    /// `ExtensionType::EXTERNAL_SENDERS`: the list of external senders.
    ExternalSenders(Vec<ExternalSender>),
    /// `ExtensionType::APPLICATION_DATA_DICTIONARY` (draft-ietf-mls-extensions).
    #[cfg(feature = "draft-ietf-mls-extensions")]
    ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
    /// `ExtensionType::SUPPORTED_WIRE_FORMATS` (draft-ietf-mls-extensions).
    #[cfg(feature = "draft-ietf-mls-extensions")]
    SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
    /// `ExtensionType::REQUIRED_WIRE_FORMATS` (draft-ietf-mls-extensions).
    #[cfg(feature = "draft-ietf-mls-extensions")]
    RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
    /// `ExtensionType::TARGETED_MESSAGES_CAPABILITY` (draft-ietf-mls-extensions).
    /// Carries no payload: it is serialized with empty `extension_data`.
    #[cfg(feature = "draft-ietf-mls-extensions")]
    TargetedMessagesCapability,
    /// `ExtensionType::RATCHET_TREE_SOURCE_DOMAINS`
    /// (draft-mahy-mls-ratchet-tree-options).
    #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
    RatchetTreeSourceDomains(
        crate::drafts::ratchet_tree_options::RatchetTreeSourceDomainsExtension,
    ),
    /// Any extension type that `Extension::new` does not decode, kept verbatim
    /// together with its raw payload bytes.
    Arbitrary(ArbitraryExtension),
}
53
54impl From<&Extension> for ExtensionType {
55 fn from(value: &Extension) -> Self {
56 ExtensionType::new_unchecked(match value {
57 Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
58 Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
59 Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
60 Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
61 Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
62 #[cfg(feature = "draft-ietf-mls-extensions")]
63 Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
64 #[cfg(feature = "draft-ietf-mls-extensions")]
65 Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
66 #[cfg(feature = "draft-ietf-mls-extensions")]
67 Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
68 #[cfg(feature = "draft-ietf-mls-extensions")]
69 Extension::TargetedMessagesCapability => ExtensionType::TARGETED_MESSAGES_CAPABILITY,
70 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
71 Extension::RatchetTreeSourceDomains(_) => ExtensionType::RATCHET_TREE_SOURCE_DOMAINS,
72 Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
73 (**extension_id) as u16
74 }
75 })
76 }
77}
78
79impl Extension {
80 pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
81 use tls_codec::Deserialize as _;
82 Ok(match extension_id {
83 ExtensionType::APPLICATION_ID => {
84 Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
85 }
86 ExtensionType::RATCHET_TREE => Self::RatchetTree(
87 RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
88 ),
89 ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
90 RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
91 ),
92 ExtensionType::EXTERNAL_PUB => {
93 Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
94 }
95 ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
96 Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
97 ),
98 #[cfg(feature = "draft-ietf-mls-extensions")]
99 ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
100 crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
101 &extension_data,
102 )?,
103 ),
104 #[cfg(feature = "draft-ietf-mls-extensions")]
105 ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
106 #[cfg(feature = "draft-ietf-mls-extensions")]
107 ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
108 #[cfg(feature = "draft-ietf-mls-extensions")]
109 ExtensionType::TARGETED_MESSAGES_CAPABILITY => Self::TargetedMessagesCapability,
110 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
111 ExtensionType::RATCHET_TREE_SOURCE_DOMAINS => Self::RatchetTreeSourceDomains(<_>::tls_deserialize_exact(&extension_data)?),
112 _ => Self::Arbitrary(ArbitraryExtension {
113 extension_id: ExtensionType::new_unchecked(extension_id),
114 extension_data,
115 }),
116 })
117 }
118
119 pub fn ext_type(&self) -> ExtensionType {
120 self.into()
121 }
122}
123
124impl tls_codec::Size for Extension {
125 fn tls_serialized_len(&self) -> usize {
126 let ext_type_len = ExtensionType::from(self).tls_serialized_len();
127 let ext_value_len = match self {
128 Extension::ApplicationId(data) => {
129 tls_serialized_len_as_vlvec(data.tls_serialized_len())
130 }
131 Extension::RatchetTree(nodes) => {
132 tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
133 }
134 Extension::RequiredCapabilities(caps) => {
135 tls_serialized_len_as_vlvec(caps.tls_serialized_len())
136 }
137 Extension::ExternalPub(ext_pub) => {
138 tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
139 }
140 Extension::ExternalSenders(ext_senders) => {
141 tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
142 }
143 #[cfg(feature = "draft-ietf-mls-extensions")]
144 Extension::ApplicationData(app_data_dict) => {
145 tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
146 }
147 #[cfg(feature = "draft-ietf-mls-extensions")]
148 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
149 tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
150 }
151 #[cfg(feature = "draft-ietf-mls-extensions")]
152 Extension::TargetedMessagesCapability => tls_serialized_len_as_vlvec(0),
153 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
154 Extension::RatchetTreeSourceDomains(rtsd) => {
155 tls_serialized_len_as_vlvec(rtsd.tls_serialized_len())
156 }
157 Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
158 tls_serialized_len_as_vlvec(extension_data.len())
159 }
160 };
161
162 ext_type_len + ext_value_len
163 }
164}
165
166impl tls_codec::Serialize for Extension {
167 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
168 use tls_codec::Size as _;
169
170 let extension_id = ExtensionType::from(self);
171 let mut written = extension_id.tls_serialize(writer)?;
172
173 let extdata_len = self.tls_serialized_len() - written;
175 let mut extension_data = Vec::with_capacity(extdata_len);
176
177 let _ = match self {
178 Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
179 Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
180 Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
181 Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
182 Extension::ExternalSenders(ext_senders) => {
183 ext_senders.tls_serialize(&mut extension_data)?
184 }
185 #[cfg(feature = "draft-ietf-mls-extensions")]
186 Extension::ApplicationData(app_data_dict) => {
187 app_data_dict.tls_serialize(&mut extension_data)?
188 }
189 #[cfg(feature = "draft-ietf-mls-extensions")]
190 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
191 wfs.tls_serialize(&mut extension_data)?
192 }
193 #[cfg(feature = "draft-ietf-mls-extensions")]
194 Extension::TargetedMessagesCapability => [0u8; 0].tls_serialize(&mut extension_data)?,
195 #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
196 Extension::RatchetTreeSourceDomains(rtsd) => rtsd.tls_serialize(&mut extension_data)?,
197 Extension::Arbitrary(ArbitraryExtension {
198 extension_data: arbitrary_ext_data,
199 ..
200 }) => {
201 use std::io::Write as _;
202 extension_data.write_all(arbitrary_ext_data)?;
203 arbitrary_ext_data.len()
204 }
205 };
206
207 written += extension_data.tls_serialize(writer)?;
208
209 Ok(written)
210 }
211}
212
213impl tls_codec::Deserialize for Extension {
214 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
215 where
216 Self: Sized,
217 {
218 Self::new(
219 *ExtensionType::tls_deserialize(bytes)?,
220 vlbytes::tls_deserialize(bytes)?,
221 )
222 .map_err(|e| match e {
223 crate::MlsSpecError::TlsCodecError(e) => e,
224 _ => tls_codec::Error::DecodingError(e.to_string()),
225 })
226 }
227}
228
// Presumably also makes the `tls_codec` traits implemented above available on
// `&Extension` (see `crate::macros::ref_forward_tls_impl`).
// NOTE(review): confirm exactly which traits the macro forwards.
ref_forward_tls_impl!(Extension);
230
231#[derive(
232 Debug,
233 Clone,
234 PartialEq,
235 Eq,
236 Hash,
237 tls_codec::TlsSerialize,
238 tls_codec::TlsDeserialize,
239 tls_codec::TlsSize,
240)]
241#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
242pub struct ExternalPub {
243 pub external_pub: HpkePublicKey,
244}
245
246#[derive(Debug, Clone, PartialEq, Eq, Hash)]
247#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
248pub struct ArbitraryExtension {
249 pub extension_id: ExtensionType,
250 pub extension_data: Vec<u8>,
251}