1use crate::{
2 crypto::HpkePublicKey,
3 group::{ExtensionType, ExternalSender, RequiredCapabilities},
4 macros::ref_forward_tls_impl,
5 tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6 tree::RatchetTree,
7};
8
9#[derive(
10 Debug,
11 Clone,
12 PartialEq,
13 Eq,
14 Hash,
15 tls_codec::TlsDeserialize,
16 tls_codec::TlsSerialize,
17 tls_codec::TlsSize,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20pub struct RatchetTreeExtension {
21 pub ratchet_tree: RatchetTree,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[repr(u16)]
27pub enum Extension {
28 ApplicationId(Vec<u8>),
32 RatchetTree(RatchetTreeExtension),
34 RequiredCapabilities(RequiredCapabilities),
35 ExternalPub(ExternalPub),
37 ExternalSenders(Vec<ExternalSender>),
39 #[cfg(feature = "draft-ietf-mls-extensions")]
40 ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
41 #[cfg(feature = "draft-ietf-mls-extensions")]
42 SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
43 #[cfg(feature = "draft-ietf-mls-extensions")]
44 RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
45 #[cfg(feature = "draft-ietf-mls-extensions")]
46 TargetedMessagesCapability,
47 Arbitrary(ArbitraryExtension),
48}
49
50impl From<&Extension> for ExtensionType {
51 fn from(value: &Extension) -> Self {
52 ExtensionType::new_unchecked(match value {
53 Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
54 Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
55 Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
56 Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
57 Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
58 #[cfg(feature = "draft-ietf-mls-extensions")]
59 Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
60 #[cfg(feature = "draft-ietf-mls-extensions")]
61 Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
62 #[cfg(feature = "draft-ietf-mls-extensions")]
63 Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
64 #[cfg(feature = "draft-ietf-mls-extensions")]
65 Extension::TargetedMessagesCapability => ExtensionType::TARGETED_MESSAGES_CAPABILITY,
66 Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
67 (**extension_id) as u16
68 }
69 })
70 }
71}
72
73impl Extension {
74 pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
75 use tls_codec::Deserialize as _;
76 Ok(match extension_id {
77 ExtensionType::APPLICATION_ID => {
78 Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
79 }
80 ExtensionType::RATCHET_TREE => Self::RatchetTree(
81 RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
82 ),
83 ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
84 RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
85 ),
86 ExtensionType::EXTERNAL_PUB => {
87 Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
88 }
89 ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
90 Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
91 ),
92 #[cfg(feature = "draft-ietf-mls-extensions")]
93 ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
94 crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
95 &extension_data,
96 )?,
97 ),
98 #[cfg(feature = "draft-ietf-mls-extensions")]
99 ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
100 #[cfg(feature = "draft-ietf-mls-extensions")]
101 ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
102 #[cfg(feature = "draft-ietf-mls-extensions")]
103 ExtensionType::TARGETED_MESSAGES_CAPABILITY => Self::TargetedMessagesCapability,
104 _ => Self::Arbitrary(ArbitraryExtension {
105 extension_id: ExtensionType::new_unchecked(extension_id),
106 extension_data,
107 }),
108 })
109 }
110
111 pub fn ext_type(&self) -> ExtensionType {
112 self.into()
113 }
114}
115
116impl tls_codec::Size for Extension {
117 fn tls_serialized_len(&self) -> usize {
118 let ext_type_len = ExtensionType::from(self).tls_serialized_len();
119 let ext_value_len = match self {
120 Extension::ApplicationId(data) => {
121 tls_serialized_len_as_vlvec(data.tls_serialized_len())
122 }
123 Extension::RatchetTree(nodes) => {
124 tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
125 }
126 Extension::RequiredCapabilities(caps) => {
127 tls_serialized_len_as_vlvec(caps.tls_serialized_len())
128 }
129 Extension::ExternalPub(ext_pub) => {
130 tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
131 }
132 Extension::ExternalSenders(ext_senders) => {
133 tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
134 }
135 #[cfg(feature = "draft-ietf-mls-extensions")]
136 Extension::ApplicationData(app_data_dict) => {
137 tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
138 }
139 #[cfg(feature = "draft-ietf-mls-extensions")]
140 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
141 tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
142 }
143 #[cfg(feature = "draft-ietf-mls-extensions")]
144 Extension::TargetedMessagesCapability => tls_serialized_len_as_vlvec(0),
145 Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
146 tls_serialized_len_as_vlvec(extension_data.len())
147 }
148 };
149
150 ext_type_len + ext_value_len
151 }
152}
153
154impl tls_codec::Serialize for Extension {
155 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
156 use tls_codec::Size as _;
157
158 let extension_id = ExtensionType::from(self);
159 let mut written = extension_id.tls_serialize(writer)?;
160
161 let extdata_len = self.tls_serialized_len() - written;
163 let mut extension_data = Vec::with_capacity(extdata_len);
164
165 let _ = match self {
166 Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
167 Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
168 Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
169 Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
170 Extension::ExternalSenders(ext_senders) => {
171 ext_senders.tls_serialize(&mut extension_data)?
172 }
173 #[cfg(feature = "draft-ietf-mls-extensions")]
174 Extension::ApplicationData(app_data_dict) => {
175 app_data_dict.tls_serialize(&mut extension_data)?
176 }
177 #[cfg(feature = "draft-ietf-mls-extensions")]
178 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
179 wfs.tls_serialize(&mut extension_data)?
180 }
181 #[cfg(feature = "draft-ietf-mls-extensions")]
182 Extension::TargetedMessagesCapability => [0u8; 0].tls_serialize(&mut extension_data)?,
183 Extension::Arbitrary(ArbitraryExtension {
184 extension_data: arbitrary_ext_data,
185 ..
186 }) => {
187 use std::io::Write as _;
188 extension_data.write_all(arbitrary_ext_data)?;
189 arbitrary_ext_data.len()
190 }
191 };
192
193 written += extension_data.tls_serialize(writer)?;
194
195 Ok(written)
196 }
197}
198
199impl tls_codec::Deserialize for Extension {
200 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
201 where
202 Self: Sized,
203 {
204 Self::new(
205 *ExtensionType::tls_deserialize(bytes)?,
206 vlbytes::tls_deserialize(bytes)?,
207 )
208 .map_err(|e| match e {
209 crate::MlsSpecError::TlsCodecError(e) => e,
210 _ => tls_codec::Error::DecodingError(e.to_string()),
211 })
212 }
213}
214
215ref_forward_tls_impl!(Extension);
216
217#[derive(
218 Debug,
219 Clone,
220 PartialEq,
221 Eq,
222 Hash,
223 tls_codec::TlsSerialize,
224 tls_codec::TlsDeserialize,
225 tls_codec::TlsSize,
226)]
227#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
228pub struct ExternalPub {
229 pub external_pub: HpkePublicKey,
230}
231
232#[derive(Debug, Clone, PartialEq, Eq, Hash)]
233#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
234pub struct ArbitraryExtension {
235 pub extension_id: ExtensionType,
236 pub extension_data: Vec<u8>,
237}