mls_spec/group/
extensions.rs

1use crate::{
2    crypto::HpkePublicKey,
3    group::{ExtensionType, ExternalSender, RequiredCapabilities},
4    macros::ref_forward_tls_impl,
5    tlspl::{bytes as vlbytes, tls_serialized_len_as_vlvec},
6    tree::RatchetTree,
7};
8
/// Payload of the ratchet tree extension (`ExtensionType::RATCHET_TREE`).
///
/// Wraps a [`RatchetTree`]; its TLS encoding is exactly that of the wrapped tree.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSerialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RatchetTreeExtension {
    /// The ratchet tree carried by this extension.
    pub ratchet_tree: RatchetTree,
}
23
/// An MLS extension, decoded into a typed variant when its type is known.
///
/// On the wire an extension is its [`ExtensionType`] followed by the variant's
/// payload encoded as a variable-length opaque `extension_data` vector.
/// Extension types that are not decoded into a dedicated variant are kept
/// verbatim as [`Extension::Arbitrary`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// NOTE(review): no variant has an explicit discriminant, so this `repr` does not
// make discriminants equal to extension type IDs; that mapping is done by
// `From<&Extension> for ExtensionType`.
#[repr(u16)]
pub enum Extension {
    /// Extension to uniquely identify clients
    ///
    /// <https://www.rfc-editor.org/rfc/rfc9420.html#section-5.3.3>
    ApplicationId(Vec<u8>),
    /// Sparse vec of TreeNodes, that is right-trimmed
    RatchetTree(RatchetTreeExtension),
    /// Required capabilities (`ExtensionType::REQUIRED_CAPABILITIES`).
    RequiredCapabilities(RequiredCapabilities),
    /// Extension that enables "External Joins" via external commits
    ExternalPub(ExternalPub),
    /// Extension that allows external proposals to be signed by a third party (e.g. a server)
    ExternalSenders(Vec<ExternalSender>),
    /// Application data dictionary (draft-ietf-mls-extensions).
    #[cfg(feature = "draft-ietf-mls-extensions")]
    ApplicationData(crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary),
    /// Wire formats supported by this client (draft-ietf-mls-extensions).
    #[cfg(feature = "draft-ietf-mls-extensions")]
    SupportedWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
    /// Wire formats required by the group (draft-ietf-mls-extensions).
    #[cfg(feature = "draft-ietf-mls-extensions")]
    RequiredWireFormats(crate::drafts::mls_extensions::safe_application::WireFormats),
    /// Targeted messages capability marker (draft-ietf-mls-extensions); it
    /// carries no payload and is serialized with an empty `extension_data`.
    #[cfg(feature = "draft-ietf-mls-extensions")]
    TargetedMessagesCapability,
    /// Any extension type not decoded into one of the variants above; the
    /// payload is kept as raw bytes.
    Arbitrary(ArbitraryExtension),
}
49
50impl From<&Extension> for ExtensionType {
51    fn from(value: &Extension) -> Self {
52        ExtensionType::new_unchecked(match value {
53            Extension::ApplicationId(_) => ExtensionType::APPLICATION_ID,
54            Extension::RatchetTree(_) => ExtensionType::RATCHET_TREE,
55            Extension::RequiredCapabilities(_) => ExtensionType::REQUIRED_CAPABILITIES,
56            Extension::ExternalPub(_) => ExtensionType::EXTERNAL_PUB,
57            Extension::ExternalSenders(_) => ExtensionType::EXTERNAL_SENDERS,
58            #[cfg(feature = "draft-ietf-mls-extensions")]
59            Extension::ApplicationData(_) => ExtensionType::APPLICATION_DATA_DICTIONARY,
60            #[cfg(feature = "draft-ietf-mls-extensions")]
61            Extension::SupportedWireFormats(_) => ExtensionType::SUPPORTED_WIRE_FORMATS,
62            #[cfg(feature = "draft-ietf-mls-extensions")]
63            Extension::RequiredWireFormats(_) => ExtensionType::REQUIRED_WIRE_FORMATS,
64            #[cfg(feature = "draft-ietf-mls-extensions")]
65            Extension::TargetedMessagesCapability => ExtensionType::TARGETED_MESSAGES_CAPABILITY,
66            Extension::Arbitrary(ArbitraryExtension { extension_id, .. }) => {
67                (**extension_id) as u16
68            }
69        })
70    }
71}
72
73impl Extension {
74    pub fn new(extension_id: u16, extension_data: Vec<u8>) -> crate::MlsSpecResult<Self> {
75        use tls_codec::Deserialize as _;
76        Ok(match extension_id {
77            ExtensionType::APPLICATION_ID => {
78                Self::ApplicationId(vlbytes::tls_deserialize(&mut &extension_data[..])?)
79            }
80            ExtensionType::RATCHET_TREE => Self::RatchetTree(
81                RatchetTreeExtension::tls_deserialize_exact(&extension_data)?,
82            ),
83            ExtensionType::REQUIRED_CAPABILITIES => Self::RequiredCapabilities(
84                RequiredCapabilities::tls_deserialize_exact(&extension_data)?,
85            ),
86            ExtensionType::EXTERNAL_PUB => {
87                Self::ExternalPub(ExternalPub::tls_deserialize_exact(&extension_data)?)
88            }
89            ExtensionType::EXTERNAL_SENDERS => Self::ExternalSenders(
90                Vec::<ExternalSender>::tls_deserialize_exact(&extension_data)?,
91            ),
92            #[cfg(feature = "draft-ietf-mls-extensions")]
93            ExtensionType::APPLICATION_DATA_DICTIONARY => Self::ApplicationData(
94                crate::drafts::mls_extensions::safe_application::ApplicationDataDictionary::tls_deserialize_exact(
95                    &extension_data,
96                )?,
97            ),
98            #[cfg(feature = "draft-ietf-mls-extensions")]
99            ExtensionType::SUPPORTED_WIRE_FORMATS => Self::SupportedWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
100            #[cfg(feature = "draft-ietf-mls-extensions")]
101            ExtensionType::REQUIRED_WIRE_FORMATS => Self::RequiredWireFormats(<_>::tls_deserialize_exact(&extension_data)?),
102            #[cfg(feature = "draft-ietf-mls-extensions")]
103            ExtensionType::TARGETED_MESSAGES_CAPABILITY => Self::TargetedMessagesCapability,
104            _ => Self::Arbitrary(ArbitraryExtension {
105                extension_id: ExtensionType::new_unchecked(extension_id),
106                extension_data,
107            }),
108        })
109    }
110
111    pub fn ext_type(&self) -> ExtensionType {
112        self.into()
113    }
114}
115
116impl tls_codec::Size for Extension {
117    fn tls_serialized_len(&self) -> usize {
118        let ext_type_len = ExtensionType::from(self).tls_serialized_len();
119        let ext_value_len = match self {
120            Extension::ApplicationId(data) => {
121                tls_serialized_len_as_vlvec(data.tls_serialized_len())
122            }
123            Extension::RatchetTree(nodes) => {
124                tls_serialized_len_as_vlvec(nodes.tls_serialized_len())
125            }
126            Extension::RequiredCapabilities(caps) => {
127                tls_serialized_len_as_vlvec(caps.tls_serialized_len())
128            }
129            Extension::ExternalPub(ext_pub) => {
130                tls_serialized_len_as_vlvec(ext_pub.tls_serialized_len())
131            }
132            Extension::ExternalSenders(ext_senders) => {
133                tls_serialized_len_as_vlvec(ext_senders.tls_serialized_len())
134            }
135            #[cfg(feature = "draft-ietf-mls-extensions")]
136            Extension::ApplicationData(app_data_dict) => {
137                tls_serialized_len_as_vlvec(app_data_dict.tls_serialized_len())
138            }
139            #[cfg(feature = "draft-ietf-mls-extensions")]
140            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
141                tls_serialized_len_as_vlvec(wfs.tls_serialized_len())
142            }
143            #[cfg(feature = "draft-ietf-mls-extensions")]
144            Extension::TargetedMessagesCapability => tls_serialized_len_as_vlvec(0),
145            Extension::Arbitrary(ArbitraryExtension { extension_data, .. }) => {
146                tls_serialized_len_as_vlvec(extension_data.len())
147            }
148        };
149
150        ext_type_len + ext_value_len
151    }
152}
153
154impl tls_codec::Serialize for Extension {
155    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
156        use tls_codec::Size as _;
157
158        let extension_id = ExtensionType::from(self);
159        let mut written = extension_id.tls_serialize(writer)?;
160
161        // FIXME: Probably can get rid of this copy
162        let extdata_len = self.tls_serialized_len() - written;
163        let mut extension_data = Vec::with_capacity(extdata_len);
164
165        let _ = match self {
166            Extension::ApplicationId(data) => data.tls_serialize(&mut extension_data)?,
167            Extension::RatchetTree(nodes) => nodes.tls_serialize(&mut extension_data)?,
168            Extension::RequiredCapabilities(caps) => caps.tls_serialize(&mut extension_data)?,
169            Extension::ExternalPub(ext_pub) => ext_pub.tls_serialize(&mut extension_data)?,
170            Extension::ExternalSenders(ext_senders) => {
171                ext_senders.tls_serialize(&mut extension_data)?
172            }
173            #[cfg(feature = "draft-ietf-mls-extensions")]
174            Extension::ApplicationData(app_data_dict) => {
175                app_data_dict.tls_serialize(&mut extension_data)?
176            }
177            #[cfg(feature = "draft-ietf-mls-extensions")]
178            Extension::SupportedWireFormats(wfs) | Extension::RequiredWireFormats(wfs) => {
179                wfs.tls_serialize(&mut extension_data)?
180            }
181            #[cfg(feature = "draft-ietf-mls-extensions")]
182            Extension::TargetedMessagesCapability => [0u8; 0].tls_serialize(&mut extension_data)?,
183            Extension::Arbitrary(ArbitraryExtension {
184                extension_data: arbitrary_ext_data,
185                ..
186            }) => {
187                use std::io::Write as _;
188                extension_data.write_all(arbitrary_ext_data)?;
189                arbitrary_ext_data.len()
190            }
191        };
192
193        written += extension_data.tls_serialize(writer)?;
194
195        Ok(written)
196    }
197}
198
199impl tls_codec::Deserialize for Extension {
200    fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error>
201    where
202        Self: Sized,
203    {
204        Self::new(
205            *ExtensionType::tls_deserialize(bytes)?,
206            vlbytes::tls_deserialize(bytes)?,
207        )
208        .map_err(|e| match e {
209            crate::MlsSpecError::TlsCodecError(e) => e,
210            _ => tls_codec::Error::DecodingError(e.to_string()),
211        })
212    }
213}
214
// Invokes the crate's `ref_forward_tls_impl!` macro for `Extension`.
// NOTE(review): judging by the name, this forwards the tls_codec impls above to
// references to `Extension` — confirm in `crate::macros`.
ref_forward_tls_impl!(Extension);
216
/// Payload of the external public key extension (`ExtensionType::EXTERNAL_PUB`).
///
/// Its TLS encoding is exactly that of the wrapped HPKE public key.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    tls_codec::TlsSerialize,
    tls_codec::TlsDeserialize,
    tls_codec::TlsSize,
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExternalPub {
    /// The HPKE public key published by this extension.
    pub external_pub: HpkePublicKey,
}
231
/// An extension whose payload is kept undecoded.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ArbitraryExtension {
    /// The raw extension type identifier.
    pub extension_id: ExtensionType,
    /// The payload bytes, without the opaque-vector length prefix (the prefix is
    /// added when the enclosing `Extension` is serialized).
    pub extension_data: Vec<u8>,
}