// mls_spec/group/extensions.rs

1use crate::{
2    crypto::HpkePublicKey,
3    group::{ExtensionType, ExternalSender, RequiredCapabilities},
4    macros::ref_forward_tls_impl,
5    tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6    tree::RatchetTree,
7};
8
9#[derive(
10    Debug,
11    Clone,
12    PartialEq,
13    Eq,
14    Hash,
15    tls_codec::TlsDeserialize,
16    tls_codec::TlsSerialize,
17    tls_codec::TlsSize,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20pub struct RatchetTreeExtension {
21    pub ratchet_tree: RatchetTree,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[repr(u16)]
27pub enum Extension {
28    /// Extension to uniquely identify clients
29    ///
30    /// <https://www.rfc-editor.org/rfc/rfc9420.html#section-5.3.3>
31    ApplicationId(Vec<u8>),
32    /// Sparse vec of TreeNodes, that is right-trimmed
33    RatchetTree(RatchetTreeExtension),
34    RequiredCapabilities(RequiredCapabilities),
35    /// Extension that enables "External Joins" via external commits
36    ExternalPub(ExternalPub),
37    /// Extension that allows external proposals to be signed by a third party (i.e. a server or something)
38    ExternalSenders(Vec<ExternalSender>),
39    #[cfg(feature = "draft-ietf-mls-extensions")]
40    ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
41    #[cfg(feature = "draft-ietf-mls-extensions")]
42    SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
43    #[cfg(feature = "draft-ietf-mls-extensions")]
44    RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
45    #[cfg(feature = "draft-ietf-mls-extensions")]
46    TargetedMessagesCapability,
47    #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
48    RatchetTreeSourceDomains(
49        crate::drafts::ratchet_tree_options::RatchetTreeSourceDomainsExtension,
50    ),
51    Arbitrary(ArbitraryExtension),
52}
53
54impl From<&Extension> for ExtensionType {
55    fn from(value: &Extension) -> Self {
56        ExtensionType::new_unchecked(match value {
57            Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
58            Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
59            Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
60            Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
61            Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
62            #[cfg(feature = "draft-ietf-mls-extensions")]
63            Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
64            #[cfg(feature = "draft-ietf-mls-extensions")]
65            Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
66            #[cfg(feature = "draft-ietf-mls-extensions")]
67            Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
68            #[cfg(feature = "draft-ietf-mls-extensions")]
69            Extension::TargetedMessagesCapability => ExtensionType::TARGETED_MESSAGES_CAPABILITY,
70            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
71            Extension::RatchetTreeSourceDomains(_) => ExtensionType::RATCHET_TREE_SOURCE_DOMAINS,
72            Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
73                (**extension_id) as u16
74            }
75        })
76    }
77}
78
79impl Extension {
80    pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
81        use tls_codec::Deserialize as _;
82        Ok(match extension_id {
83            ExtensionType::APPLICATION_ID => {
84                Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
85            }
86            ExtensionType::RATCHET_TREE => Self::RatchetTree(
87                RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
88            ),
89            ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
90                RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
91            ),
92            ExtensionType::EXTERNAL_PUB => {
93                Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
94            }
95            ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
96                Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
97            ),
98            #[cfg(feature = "draft-ietf-mls-extensions")]
99            ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
100                crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
101                    &extension_data,
102                )?,
103            ),
104            #[cfg(feature = "draft-ietf-mls-extensions")]
105            ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
106            #[cfg(feature = "draft-ietf-mls-extensions")]
107            ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
108            #[cfg(feature = "draft-ietf-mls-extensions")]
109            ExtensionType::TARGETED_MESSAGES_CAPABILITY => Self::TargetedMessagesCapability,
110            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
111            ExtensionType::RATCHET_TREE_SOURCE_DOMAINS => Self::RatchetTreeSourceDomains(<_>::tls_deserialize_exact(&extension_data)?),
112            _ => Self::Arbitrary(ArbitraryExtension {
113                extension_id: ExtensionType::new_unchecked(extension_id),
114                extension_data,
115            }),
116        })
117    }
118
119    pub fn ext_type(&self) -> ExtensionType {
120        self.into()
121    }
122}
123
124impl tls_codec::Size for Extension {
125    fn tls_serialized_len(&self) -> usize {
126        let ext_type_len = ExtensionType::from(self).tls_serialized_len();
127        let ext_value_len = match self {
128            Extension::ApplicationId(data) => {
129                tls_serialized_len_as_vlvec(data.tls_serialized_len())
130            }
131            Extension::RatchetTree(nodes) => {
132                tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
133            }
134            Extension::RequiredCapabilities(caps) => {
135                tls_serialized_len_as_vlvec(caps.tls_serialized_len())
136            }
137            Extension::ExternalPub(ext_pub) => {
138                tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
139            }
140            Extension::ExternalSenders(ext_senders) => {
141                tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
142            }
143            #[cfg(feature = "draft-ietf-mls-extensions")]
144            Extension::ApplicationData(app_data_dict) => {
145                tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
146            }
147            #[cfg(feature = "draft-ietf-mls-extensions")]
148            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
149                tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
150            }
151            #[cfg(feature = "draft-ietf-mls-extensions")]
152            Extension::TargetedMessagesCapability => tls_serialized_len_as_vlvec(0),
153            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
154            Extension::RatchetTreeSourceDomains(rtsd) => {
155                tls_serialized_len_as_vlvec(rtsd.tls_serialized_len())
156            }
157            Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
158                tls_serialized_len_as_vlvec(extension_data.len())
159            }
160        };
161
162        ext_type_len + ext_value_len
163    }
164}
165
166impl tls_codec::Serialize for Extension {
167    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
168        use tls_codec::Size as _;
169
170        let extension_id = ExtensionType::from(self);
171        let mut written = extension_id.tls_serialize(writer)?;
172
173        // FIXME: Probably can get rid of this copy
174        let extdata_len = self.tls_serialized_len() - written;
175        let mut extension_data = Vec::with_capacity(extdata_len);
176
177        let _ = match self {
178            Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
179            Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
180            Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
181            Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
182            Extension::ExternalSenders(ext_senders) => {
183                ext_senders.tls_serialize(&mut extension_data)?
184            }
185            #[cfg(feature = "draft-ietf-mls-extensions")]
186            Extension::ApplicationData(app_data_dict) => {
187                app_data_dict.tls_serialize(&mut extension_data)?
188            }
189            #[cfg(feature = "draft-ietf-mls-extensions")]
190            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
191                wfs.tls_serialize(&mut extension_data)?
192            }
193            #[cfg(feature = "draft-ietf-mls-extensions")]
194            Extension::TargetedMessagesCapability => [0u8; 0].tls_serialize(&mut extension_data)?,
195            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
196            Extension::RatchetTreeSourceDomains(rtsd) => rtsd.tls_serialize(&mut extension_data)?,
197            Extension::Arbitrary(ArbitraryExtension {
198                extension_data: arbitrary_ext_data,
199                ..
200            }) => {
201                use std::io::Write as _;
202                extension_data.write_all(arbitrary_ext_data)?;
203                arbitrary_ext_data.len()
204            }
205        };
206
207        written += extension_data.tls_serialize(writer)?;
208
209        Ok(written)
210    }
211}
212
213impl tls_codec::Deserialize for Extension {
214    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
215    where
216        Self: Sized,
217    {
218        Self::new(
219            *ExtensionType::tls_deserialize(bytes)?,
220            vlbytes::tls_deserialize(bytes)?,
221        )
222        .map_err(|e| match e {
223            crate::MlsSpecError::TlsCodecError(e) => e,
224            _ => tls_codec::Error::DecodingError(e.to_string()),
225        })
226    }
227}
228
229ref_forward_tls_impl!(Extension);
230
231#[derive(
232    Debug,
233    Clone,
234    PartialEq,
235    Eq,
236    Hash,
237    tls_codec::TlsSerialize,
238    tls_codec::TlsDeserialize,
239    tls_codec::TlsSize,
240)]
241#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
242pub struct ExternalPub {
243    pub external_pub: HpkePublicKey,
244}
245
246#[derive(Debug, Clone, PartialEq, Eq, Hash)]
247#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
248pub struct ArbitraryExtension {
249    pub extension_id: ExtensionType,
250    pub extension_data: Vec<u8>,
251}