mls_spec/group/
extensions.rs

1use crate::{
2    crypto::HpkePublicKey,
3    group::{ExtensionType, ExternalSender, RequiredCapabilities},
4    macros::ref_forward_tls_impl,
5    tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6    tree::RatchetTree,
7};
8
9#[derive(
10    Debug,
11    Clone,
12    PartialEq,
13    Eq,
14    tls_codec::TlsDeserialize,
15    tls_codec::TlsSerialize,
16    tls_codec::TlsSize,
17)]
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19pub struct RatchetTreeExtension {
20    pub ratchet_tree: RatchetTree,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[repr(u16)]
26pub enum Extension {
27    /// Extension to uniquely identify clients
28    ///
29    /// <https://www.rfc-editor.org/rfc/rfc9420.html#section-5.3.3>
30    ApplicationId(Vec<u8>),
31    /// Sparse vec of TreeNodes, that is right-trimmed
32    RatchetTree(RatchetTreeExtension),
33    RequiredCapabilities(RequiredCapabilities),
34    /// Extension that enables "External Joins" via external commits
35    ExternalPub(ExternalPub),
36    /// Extension that allows external proposals to be signed by a third party (i.e. a server or something)
37    ExternalSenders(Vec<ExternalSender>),
38    #[cfg(feature = "draft-ietf-mls-extensions")]
39    ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
40    #[cfg(feature = "draft-ietf-mls-extensions")]
41    SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
42    #[cfg(feature = "draft-ietf-mls-extensions")]
43    RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
44    #[cfg(feature = "draft-ietf-mls-extensions")]
45    TargetedMessagesCapability,
46    Arbitrary(ArbitraryExtension),
47}
48
49impl From<&Extension> for ExtensionType {
50    fn from(value: &Extension) -> Self {
51        ExtensionType::new_unchecked(match value {
52            Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
53            Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
54            Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
55            Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
56            Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
57            #[cfg(feature = "draft-ietf-mls-extensions")]
58            Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
59            #[cfg(feature = "draft-ietf-mls-extensions")]
60            Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
61            #[cfg(feature = "draft-ietf-mls-extensions")]
62            Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
63            #[cfg(feature = "draft-ietf-mls-extensions")]
64            Extension::TargetedMessagesCapability => ExtensionType::TARGETED_MESSAGES_CAPABILITY,
65            Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
66                (**extension_id) as u16
67            }
68        })
69    }
70}
71
72impl Extension {
73    pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
74        use tls_codec::Deserialize as _;
75        Ok(match extension_id {
76            ExtensionType::APPLICATION_ID => {
77                Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
78            }
79            ExtensionType::RATCHET_TREE => Self::RatchetTree(
80                RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
81            ),
82            ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
83                RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
84            ),
85            ExtensionType::EXTERNAL_PUB => {
86                Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
87            }
88            ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
89                Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
90            ),
91            #[cfg(feature = "draft-ietf-mls-extensions")]
92            ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
93                crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
94                    &extension_data,
95                )?,
96            ),
97            #[cfg(feature = "draft-ietf-mls-extensions")]
98            ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
99            #[cfg(feature = "draft-ietf-mls-extensions")]
100            ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
101            #[cfg(feature = "draft-ietf-mls-extensions")]
102            ExtensionType::TARGETED_MESSAGES_CAPABILITY => Self::TargetedMessagesCapability,
103            _ => Self::Arbitrary(ArbitraryExtension {
104                extension_id: ExtensionType::new_unchecked(extension_id),
105                extension_data,
106            }),
107        })
108    }
109
110    pub fn ext_type(&self) -> ExtensionType {
111        self.into()
112    }
113}
114
115impl tls_codec::Size for Extension {
116    fn tls_serialized_len(&self) -> usize {
117        let ext_type_len = ExtensionType::from(self).tls_serialized_len();
118        let ext_value_len = match self {
119            Extension::ApplicationId(data) => {
120                tls_serialized_len_as_vlvec(data.tls_serialized_len())
121            }
122            Extension::RatchetTree(nodes) => {
123                tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
124            }
125            Extension::RequiredCapabilities(caps) => {
126                tls_serialized_len_as_vlvec(caps.tls_serialized_len())
127            }
128            Extension::ExternalPub(ext_pub) => {
129                tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
130            }
131            Extension::ExternalSenders(ext_senders) => {
132                tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
133            }
134            #[cfg(feature = "draft-ietf-mls-extensions")]
135            Extension::ApplicationData(app_data_dict) => {
136                tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
137            }
138            #[cfg(feature = "draft-ietf-mls-extensions")]
139            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
140                tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
141            }
142            #[cfg(feature = "draft-ietf-mls-extensions")]
143            Extension::TargetedMessagesCapability => tls_serialized_len_as_vlvec(0),
144            Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
145                tls_serialized_len_as_vlvec(extension_data.len())
146            }
147        };
148
149        ext_type_len + ext_value_len
150    }
151}
152
153impl tls_codec::Serialize for Extension {
154    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
155        use tls_codec::Size as _;
156
157        let extension_id = ExtensionType::from(self);
158        let mut written = extension_id.tls_serialize(writer)?;
159
160        // FIXME: Probably can get rid of this copy
161        let extdata_len = self.tls_serialized_len() - written;
162        let mut extension_data = Vec::with_capacity(extdata_len);
163
164        let _ = match self {
165            Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
166            Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
167            Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
168            Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
169            Extension::ExternalSenders(ext_senders) => {
170                ext_senders.tls_serialize(&mut extension_data)?
171            }
172            #[cfg(feature = "draft-ietf-mls-extensions")]
173            Extension::ApplicationData(app_data_dict) => {
174                app_data_dict.tls_serialize(&mut extension_data)?
175            }
176            #[cfg(feature = "draft-ietf-mls-extensions")]
177            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
178                wfs.tls_serialize(&mut extension_data)?
179            }
180            #[cfg(feature = "draft-ietf-mls-extensions")]
181            Extension::TargetedMessagesCapability => [0u8; 0].tls_serialize(&mut extension_data)?,
182            Extension::Arbitrary(ArbitraryExtension {
183                extension_data: arbitrary_ext_data,
184                ..
185            }) => {
186                use std::io::Write as _;
187                extension_data.write_all(arbitrary_ext_data)?;
188                arbitrary_ext_data.len()
189            }
190        };
191
192        written += extension_data.tls_serialize(writer)?;
193
194        Ok(written)
195    }
196}
197
198impl tls_codec::Deserialize for Extension {
199    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
200    where
201        Self: Sized,
202    {
203        Self::new(
204            *ExtensionType::tls_deserialize(bytes)?,
205            vlbytes::tls_deserialize(bytes)?,
206        )
207        .map_err(|e| match e {
208            crate::MlsSpecError::TlsCodecError(e) => e,
209            _ => tls_codec::Error::DecodingError(e.to_string()),
210        })
211    }
212}
213
214ref_forward_tls_impl!(Extension);
215
216#[derive(
217    Debug,
218    Clone,
219    PartialEq,
220    Eq,
221    tls_codec::TlsSerialize,
222    tls_codec::TlsDeserialize,
223    tls_codec::TlsSize,
224)]
225#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
226pub struct ExternalPub {
227    pub external_pub: HpkePublicKey,
228}
229
230#[derive(Debug, Clone, PartialEq, Eq)]
231#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
232pub struct ArbitraryExtension {
233    pub extension_id: ExtensionType,
234    pub extension_data: Vec<u8>,
235}