Skip to main content

mls_spec/group/
extensions.rs

1use crate::{
2    crypto::HpkePublicKey,
3    group::{ExtensionType, ExternalSender, RequiredCapabilities},
4    macros::ref_forward_tls_impl,
5    tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6    tree::RatchetTree,
7};
8
9#[derive(
10    Debug,
11    Clone,
12    PartialEq,
13    Eq,
14    Hash,
15    tls_codec::TlsDeserialize,
16    tls_codec::TlsSerialize,
17    tls_codec::TlsSize,
18)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20pub struct RatchetTreeExtension {
21    pub ratchet_tree: RatchetTree,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[repr(u16)]
27pub enum Extension {
28    /// Extension to uniquely identify clients
29    ///
30    /// <https://www.rfc-editor.org/rfc/rfc9420.html#section-5.3.3>
31    ApplicationId(Vec<u8>),
32    /// Sparse vec of TreeNodes, that is right-trimmed
33    RatchetTree(RatchetTreeExtension),
34    RequiredCapabilities(RequiredCapabilities),
35    /// Extension that enables "External Joins" via external commits
36    ExternalPub(ExternalPub),
37    /// Extension that allows external proposals to be signed by a third party (i.e. a server or something)
38    ExternalSenders(Vec<ExternalSender>),
39    #[cfg(feature = "draft-ietf-mls-extensions")]
40    ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
41    #[cfg(feature = "draft-ietf-mls-extensions")]
42    SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
43    #[cfg(feature = "draft-ietf-mls-extensions")]
44    RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
45    #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
46    RatchetTreeSourceDomains(
47        crate::drafts::ratchet_tree_options::RatchetTreeSourceDomainsExtension,
48    ),
49    Arbitrary(ArbitraryExtension),
50}
51
52impl From<&Extension> for ExtensionType {
53    fn from(value: &Extension) -> Self {
54        ExtensionType::new_unchecked(match value {
55            Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
56            Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
57            Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
58            Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
59            Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
60            #[cfg(feature = "draft-ietf-mls-extensions")]
61            Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
62            #[cfg(feature = "draft-ietf-mls-extensions")]
63            Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
64            #[cfg(feature = "draft-ietf-mls-extensions")]
65            Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
66            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
67            Extension::RatchetTreeSourceDomains(_) => ExtensionType::RATCHET_TREE_SOURCE_DOMAINS,
68            Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
69                (**extension_id) as u16
70            }
71        })
72    }
73}
74
75impl Extension {
76    pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
77        use tls_codec::Deserialize as _;
78        Ok(match extension_id {
79            ExtensionType::APPLICATION_ID => {
80                Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
81            }
82            ExtensionType::RATCHET_TREE => Self::RatchetTree(
83                RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
84            ),
85            ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
86                RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
87            ),
88            ExtensionType::EXTERNAL_PUB => {
89                Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
90            }
91            ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
92                Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
93            ),
94            #[cfg(feature = "draft-ietf-mls-extensions")]
95            ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
96                crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
97                    &extension_data,
98                )?,
99            ),
100            #[cfg(feature = "draft-ietf-mls-extensions")]
101            ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
102            #[cfg(feature = "draft-ietf-mls-extensions")]
103            ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
104            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
105            ExtensionType::RATCHET_TREE_SOURCE_DOMAINS => Self::RatchetTreeSourceDomains(<_>::tls_deserialize_exact(&extension_data)?),
106            _ => Self::Arbitrary(ArbitraryExtension {
107                extension_id: ExtensionType::new_unchecked(extension_id),
108                extension_data,
109            }),
110        })
111    }
112
113    pub fn ext_type(&self) -> ExtensionType {
114        self.into()
115    }
116}
117
118impl tls_codec::Size for Extension {
119    fn tls_serialized_len(&self) -> usize {
120        let ext_type_len = ExtensionType::from(self).tls_serialized_len();
121        let ext_value_len = match self {
122            Extension::ApplicationId(data) => {
123                tls_serialized_len_as_vlvec(data.tls_serialized_len())
124            }
125            Extension::RatchetTree(nodes) => {
126                tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
127            }
128            Extension::RequiredCapabilities(caps) => {
129                tls_serialized_len_as_vlvec(caps.tls_serialized_len())
130            }
131            Extension::ExternalPub(ext_pub) => {
132                tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
133            }
134            Extension::ExternalSenders(ext_senders) => {
135                tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
136            }
137            #[cfg(feature = "draft-ietf-mls-extensions")]
138            Extension::ApplicationData(app_data_dict) => {
139                tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
140            }
141            #[cfg(feature = "draft-ietf-mls-extensions")]
142            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
143                tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
144            }
145            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
146            Extension::RatchetTreeSourceDomains(rtsd) => {
147                tls_serialized_len_as_vlvec(rtsd.tls_serialized_len())
148            }
149            Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
150                tls_serialized_len_as_vlvec(extension_data.len())
151            }
152        };
153
154        ext_type_len + ext_value_len
155    }
156}
157
158impl tls_codec::Serialize for Extension {
159    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
160        use tls_codec::Size as _;
161
162        let extension_id = ExtensionType::from(self);
163        let mut written = extension_id.tls_serialize(writer)?;
164
165        // FIXME: Probably can get rid of this copy
166        let extdata_len = self.tls_serialized_len() - written;
167        let mut extension_data = Vec::with_capacity(extdata_len);
168
169        let _ = match self {
170            Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
171            Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
172            Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
173            Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
174            Extension::ExternalSenders(ext_senders) => {
175                ext_senders.tls_serialize(&mut extension_data)?
176            }
177            #[cfg(feature = "draft-ietf-mls-extensions")]
178            Extension::ApplicationData(app_data_dict) => {
179                app_data_dict.tls_serialize(&mut extension_data)?
180            }
181            #[cfg(feature = "draft-ietf-mls-extensions")]
182            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
183                wfs.tls_serialize(&mut extension_data)?
184            }
185            #[cfg(feature = "draft-mahy-mls-ratchet-tree-options")]
186            Extension::RatchetTreeSourceDomains(rtsd) => rtsd.tls_serialize(&mut extension_data)?,
187            Extension::Arbitrary(ArbitraryExtension {
188                extension_data: arbitrary_ext_data,
189                ..
190            }) => {
191                use std::io::Write as _;
192                extension_data.write_all(arbitrary_ext_data)?;
193                arbitrary_ext_data.len()
194            }
195        };
196
197        written += extension_data.tls_serialize(writer)?;
198
199        Ok(written)
200    }
201}
202
203impl tls_codec::Deserialize for Extension {
204    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
205    where
206        Self: Sized,
207    {
208        Self::new(
209            *ExtensionType::tls_deserialize(bytes)?,
210            vlbytes::tls_deserialize(bytes)?,
211        )
212        .map_err(|e| match e {
213            crate::MlsSpecError::TlsCodecError(e) => e,
214            _ => tls_codec::Error::DecodingError(e.to_string()),
215        })
216    }
217}
218
219ref_forward_tls_impl!(Extension);
220
221#[derive(
222    Debug,
223    Clone,
224    PartialEq,
225    Eq,
226    Hash,
227    tls_codec::TlsSerialize,
228    tls_codec::TlsDeserialize,
229    tls_codec::TlsSize,
230)]
231#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
232pub struct ExternalPub {
233    pub external_pub: HpkePublicKey,
234}
235
236#[derive(Debug, Clone, PartialEq, Eq, Hash)]
237#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
238pub struct ArbitraryExtension {
239    pub extension_id: ExtensionType,
240    pub extension_data: Vec<u8>,
241}