1use crate::{
2 crypto::HpkePublicKey,
3 group::{ExtensionType, ExternalSender, RequiredCapabilities},
4 macros::ref_forward_tls_impl,
5 tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6 tree::RatchetTree,
7};
8
9#[derive(
10 Debug,
11 Clone,
12 PartialEq,
13 Eq,
14 tls_codec::TlsDeserialize,
15 tls_codec::TlsSerialize,
16 tls_codec::TlsSize,
17)]
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19pub struct RatchetTreeExtension {
20 pub ratchet_tree: RatchetTree,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[repr(u16)]
26pub enum Extension {
27 ApplicationId(Vec<u8>),
31 RatchetTree(RatchetTreeExtension),
33 RequiredCapabilities(RequiredCapabilities),
34 ExternalPub(ExternalPub),
36 ExternalSenders(Vec<ExternalSender>),
38 #[cfg(feature = "draft-ietf-mls-extensions")]
39 ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
40 #[cfg(feature = "draft-ietf-mls-extensions")]
41 SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
42 #[cfg(feature = "draft-ietf-mls-extensions")]
43 RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
44 #[cfg(feature = "draft-ietf-mls-extensions")]
45 TargetedMessagesCapability,
46 Arbitrary(ArbitraryExtension),
47}
48
49impl From<&Extension> for ExtensionType {
50 fn from(value: &Extension) -> Self {
51 ExtensionType::new_unchecked(match value {
52 Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
53 Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
54 Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
55 Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
56 Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
57 #[cfg(feature = "draft-ietf-mls-extensions")]
58 Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
59 #[cfg(feature = "draft-ietf-mls-extensions")]
60 Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
61 #[cfg(feature = "draft-ietf-mls-extensions")]
62 Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
63 #[cfg(feature = "draft-ietf-mls-extensions")]
64 Extension::TargetedMessagesCapability => ExtensionType::TARGETED_MESSAGES_CAPABILITY,
65 Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
66 (**extension_id) as u16
67 }
68 })
69 }
70}
71
72impl Extension {
73 pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
74 use tls_codec::Deserialize as _;
75 Ok(match extension_id {
76 ExtensionType::APPLICATION_ID => {
77 Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
78 }
79 ExtensionType::RATCHET_TREE => Self::RatchetTree(
80 RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
81 ),
82 ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
83 RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
84 ),
85 ExtensionType::EXTERNAL_PUB => {
86 Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
87 }
88 ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
89 Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
90 ),
91 #[cfg(feature = "draft-ietf-mls-extensions")]
92 ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
93 crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
94 &extension_data,
95 )?,
96 ),
97 #[cfg(feature = "draft-ietf-mls-extensions")]
98 ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
99 #[cfg(feature = "draft-ietf-mls-extensions")]
100 ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
101 #[cfg(feature = "draft-ietf-mls-extensions")]
102 ExtensionType::TARGETED_MESSAGES_CAPABILITY => Self::TargetedMessagesCapability,
103 _ => Self::Arbitrary(ArbitraryExtension {
104 extension_id: ExtensionType::new_unchecked(extension_id),
105 extension_data,
106 }),
107 })
108 }
109
110 pub fn ext_type(&self) -> ExtensionType {
111 self.into()
112 }
113}
114
115impl tls_codec::Size for Extension {
116 fn tls_serialized_len(&self) -> usize {
117 let ext_type_len = ExtensionType::from(self).tls_serialized_len();
118 let ext_value_len = match self {
119 Extension::ApplicationId(data) => {
120 tls_serialized_len_as_vlvec(data.tls_serialized_len())
121 }
122 Extension::RatchetTree(nodes) => {
123 tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
124 }
125 Extension::RequiredCapabilities(caps) => {
126 tls_serialized_len_as_vlvec(caps.tls_serialized_len())
127 }
128 Extension::ExternalPub(ext_pub) => {
129 tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
130 }
131 Extension::ExternalSenders(ext_senders) => {
132 tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
133 }
134 #[cfg(feature = "draft-ietf-mls-extensions")]
135 Extension::ApplicationData(app_data_dict) => {
136 tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
137 }
138 #[cfg(feature = "draft-ietf-mls-extensions")]
139 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
140 tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
141 }
142 #[cfg(feature = "draft-ietf-mls-extensions")]
143 Extension::TargetedMessagesCapability => tls_serialized_len_as_vlvec(0),
144 Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
145 tls_serialized_len_as_vlvec(extension_data.len())
146 }
147 };
148
149 ext_type_len + ext_value_len
150 }
151}
152
153impl tls_codec::Serialize for Extension {
154 fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
155 use tls_codec::Size as _;
156
157 let extension_id = ExtensionType::from(self);
158 let mut written = extension_id.tls_serialize(writer)?;
159
160 let extdata_len = self.tls_serialized_len() - written;
162 let mut extension_data = Vec::with_capacity(extdata_len);
163
164 let _ = match self {
165 Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
166 Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
167 Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
168 Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
169 Extension::ExternalSenders(ext_senders) => {
170 ext_senders.tls_serialize(&mut extension_data)?
171 }
172 #[cfg(feature = "draft-ietf-mls-extensions")]
173 Extension::ApplicationData(app_data_dict) => {
174 app_data_dict.tls_serialize(&mut extension_data)?
175 }
176 #[cfg(feature = "draft-ietf-mls-extensions")]
177 Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
178 wfs.tls_serialize(&mut extension_data)?
179 }
180 #[cfg(feature = "draft-ietf-mls-extensions")]
181 Extension::TargetedMessagesCapability => [0u8; 0].tls_serialize(&mut extension_data)?,
182 Extension::Arbitrary(ArbitraryExtension {
183 extension_data: arbitrary_ext_data,
184 ..
185 }) => {
186 use std::io::Write as _;
187 extension_data.write_all(arbitrary_ext_data)?;
188 arbitrary_ext_data.len()
189 }
190 };
191
192 written += extension_data.tls_serialize(writer)?;
193
194 Ok(written)
195 }
196}
197
198impl tls_codec::Deserialize for Extension {
199 fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
200 where
201 Self: Sized,
202 {
203 Self::new(
204 *ExtensionType::tls_deserialize(bytes)?,
205 vlbytes::tls_deserialize(bytes)?,
206 )
207 .map_err(|e| match e {
208 crate::MlsSpecError::TlsCodecError(e) => e,
209 _ => tls_codec::Error::DecodingError(e.to_string()),
210 })
211 }
212}
213
214ref_forward_tls_impl!(Extension);
215
216#[derive(
217 Debug,
218 Clone,
219 PartialEq,
220 Eq,
221 tls_codec::TlsSerialize,
222 tls_codec::TlsDeserialize,
223 tls_codec::TlsSize,
224)]
225#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
226pub struct ExternalPub {
227 pub external_pub: HpkePublicKey,
228}
229
230#[derive(Debug, Clone, PartialEq, Eq)]
231#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
232pub struct ArbitraryExtension {
233 pub extension_id: ExtensionType,
234 pub extension_data: Vec<u8>,
235}