mls_rs/extension/
built_in.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::vec::Vec;
6use core::fmt::{self, Debug};
7use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
8use mls_rs_core::extension::{ExtensionType, MlsCodecExtension};
9
10use mls_rs_core::{group::ProposalType, identity::CredentialType};
11
12#[cfg(feature = "by_ref_proposal")]
13use mls_rs_core::{
14    extension::ExtensionList,
15    identity::{IdentityProvider, SigningIdentity},
16    time::MlsTime,
17};
18
19use crate::group::ExportedTree;
20
21use mls_rs_core::crypto::HpkePublicKey;
22
23/// Application specific identifier.
24///
25/// A custom application level identifier that can be optionally stored
26/// within the `leaf_node_extensions` of a group [Member](crate::group::Member).
27#[cfg_attr(
28    all(feature = "ffi", not(test)),
29    safer_ffi_gen::ffi_type(clone, opaque)
30)]
31#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
32pub struct ApplicationIdExt {
33    /// Application level identifier presented by this extension.
34    #[mls_codec(with = "mls_rs_codec::byte_vec")]
35    pub identifier: Vec<u8>,
36}
37
38impl Debug for ApplicationIdExt {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.debug_struct("ApplicationIdExt")
41            .field(
42                "identifier",
43                &mls_rs_core::debug::pretty_bytes(&self.identifier),
44            )
45            .finish()
46    }
47}
48
49#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
50impl ApplicationIdExt {
51    /// Create a new application level identifier extension.
52    pub fn new(identifier: Vec<u8>) -> Self {
53        ApplicationIdExt { identifier }
54    }
55
56    /// Get the application level identifier presented by this extension.
57    #[cfg(feature = "ffi")]
58    pub fn identifier(&self) -> &[u8] {
59        &self.identifier
60    }
61}
62
63impl MlsCodecExtension for ApplicationIdExt {
64    fn extension_type() -> ExtensionType {
65        ExtensionType::APPLICATION_ID
66    }
67}
68
69/// Representation of an MLS ratchet tree.
70///
71/// Used to provide new members
72/// a copy of the current group state in-band.
73#[cfg_attr(
74    all(feature = "ffi", not(test)),
75    safer_ffi_gen::ffi_type(clone, opaque)
76)]
77#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
78pub struct RatchetTreeExt {
79    pub tree_data: ExportedTree<'static>,
80}
81
82#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
83impl RatchetTreeExt {
84    /// Required custom extension types.
85    #[cfg(feature = "ffi")]
86    pub fn tree_data(&self) -> &ExportedTree<'static> {
87        &self.tree_data
88    }
89}
90
91impl MlsCodecExtension for RatchetTreeExt {
92    fn extension_type() -> ExtensionType {
93        ExtensionType::RATCHET_TREE
94    }
95}
96
97/// Require members to have certain capabilities.
98///
99/// Used within a
100/// [Group Context Extensions Proposal](crate::group::proposal::Proposal)
101/// in order to require that all current and future members of a group MUST
102/// support specific extensions, proposals, or credentials.
103///
104/// # Warning
105///
106/// Extension, proposal, and credential types defined by the MLS RFC and
107/// provided are considered required by default and should NOT be used
108/// within this extension.
109#[cfg_attr(
110    all(feature = "ffi", not(test)),
111    safer_ffi_gen::ffi_type(clone, opaque)
112)]
113#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
114pub struct RequiredCapabilitiesExt {
115    pub extensions: Vec<ExtensionType>,
116    pub proposals: Vec<ProposalType>,
117    pub credentials: Vec<CredentialType>,
118}
119
120#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
121impl RequiredCapabilitiesExt {
122    /// Create a required capabilities extension.
123    pub fn new(
124        extensions: Vec<ExtensionType>,
125        proposals: Vec<ProposalType>,
126        credentials: Vec<CredentialType>,
127    ) -> Self {
128        Self {
129            extensions,
130            proposals,
131            credentials,
132        }
133    }
134
135    /// Required custom extension types.
136    #[cfg(feature = "ffi")]
137    pub fn extensions(&self) -> &[ExtensionType] {
138        &self.extensions
139    }
140
141    /// Required custom proposal types.
142    #[cfg(feature = "ffi")]
143    pub fn proposals(&self) -> &[ProposalType] {
144        &self.proposals
145    }
146
147    /// Required custom credential types.
148    #[cfg(feature = "ffi")]
149    pub fn credentials(&self) -> &[CredentialType] {
150        &self.credentials
151    }
152}
153
154impl MlsCodecExtension for RequiredCapabilitiesExt {
155    fn extension_type() -> ExtensionType {
156        ExtensionType::REQUIRED_CAPABILITIES
157    }
158}
159
160/// External public key used for [External Commits](crate::Client::commit_external).
161///
162/// This proposal type is optionally provided as part of a
163/// [Group Info](crate::group::Group::group_info_message).
164#[cfg_attr(
165    all(feature = "ffi", not(test)),
166    safer_ffi_gen::ffi_type(clone, opaque)
167)]
168#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
169pub struct ExternalPubExt {
170    /// Public key to be used for an external commit.
171    #[mls_codec(with = "mls_rs_codec::byte_vec")]
172    pub external_pub: HpkePublicKey,
173}
174
175#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
176impl ExternalPubExt {
177    /// Get the public key to be used for an external commit.
178    #[cfg(feature = "ffi")]
179    pub fn external_pub(&self) -> &HpkePublicKey {
180        &self.external_pub
181    }
182}
183
184impl MlsCodecExtension for ExternalPubExt {
185    fn extension_type() -> ExtensionType {
186        ExtensionType::EXTERNAL_PUB
187    }
188}
189
/// Enable proposals by an [ExternalClient](crate::external_client::ExternalClient).
///
/// Holds the list of signing identities that are allowed to act as external
/// senders for the group.
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
pub struct ExternalSendersExt {
    /// Signing identities permitted to act as external senders.
    pub allowed_senders: Vec<SigningIdentity>,
}
201
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl ExternalSendersExt {
    /// Build the extension from the identities allowed to act as external senders.
    pub fn new(allowed_senders: Vec<SigningIdentity>) -> Self {
        ExternalSendersExt { allowed_senders }
    }

    /// Borrow the identities allowed to act as external senders.
    ///
    /// Only compiled with the `ffi` feature; Rust callers can read the public
    /// `allowed_senders` field directly.
    #[cfg(feature = "ffi")]
    pub fn allowed_senders(&self) -> &[SigningIdentity] {
        self.allowed_senders.as_slice()
    }

    /// Ask `provider` to validate every allowed external sender.
    ///
    /// Senders are checked in list order. Each call to
    /// `IdentityProvider::validate_external_sender` receives `timestamp` and the
    /// group context extensions. The first error reported by the provider is
    /// returned immediately, and the remaining senders are not checked.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn verify_all<I: IdentityProvider>(
        &self,
        provider: &I,
        timestamp: Option<MlsTime>,
        group_context_extensions: &ExtensionList,
    ) -> Result<(), I::Error> {
        // `Option<&ExtensionList>` is `Copy`, so it can be built once and
        // reused for every sender.
        let context = Some(group_context_extensions);

        for sender in &self.allowed_senders {
            provider
                .validate_external_sender(sender, timestamp, context)
                .await?;
        }

        Ok(())
    }
}
230
#[cfg(feature = "by_ref_proposal")]
impl MlsCodecExtension for ExternalSendersExt {
    /// Tags this value as an `EXTERNAL_SENDERS` extension.
    fn extension_type() -> ExtensionType {
        ExtensionType::EXTERNAL_SENDERS
    }
}
237
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tree_kem::node::NodeVec;
    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE, identity::test_utils::get_test_signing_identity,
    };

    use mls_rs_core::extension::MlsExtension;

    use mls_rs_core::identity::BasicCredential;

    use alloc::vec;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Convert `ext` into a generic extension and check its type tag against
    /// `expected_type`. Then decode it back and check that the result equals
    /// the original value.
    fn assert_round_trip<T>(ext: T, expected_type: ExtensionType)
    where
        T: MlsExtension + Clone + PartialEq + core::fmt::Debug,
    {
        let encoded = ext.clone().into_extension().unwrap();
        assert_eq!(encoded.extension_type, expected_type);

        let decoded = T::from_extension(&encoded).unwrap();
        assert_eq!(ext, decoded);
    }

    #[test]
    fn test_application_id_extension() {
        let ext = ApplicationIdExt {
            identifier: vec![0u8; 32],
        };

        assert_round_trip(ext, ExtensionType::APPLICATION_ID);
    }

    #[test]
    fn test_ratchet_tree() {
        let tree = ExportedTree::new(NodeVec::from(vec![None, None]));
        let ext = RatchetTreeExt { tree_data: tree };

        assert_round_trip(ext, ExtensionType::RATCHET_TREE);
    }

    #[test]
    fn test_required_capabilities() {
        let ext = RequiredCapabilitiesExt {
            extensions: vec![0.into(), 1.into()],
            proposals: vec![42.into(), 43.into()],
            credentials: vec![BasicCredential::credential_type()],
        };

        assert_round_trip(ext, ExtensionType::REQUIRED_CAPABILITIES);
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_external_senders() {
        let (identity, _) = get_test_signing_identity(TEST_CIPHER_SUITE, &[1]).await;
        let ext = ExternalSendersExt::new(vec![identity]);

        assert_round_trip(ext, ExtensionType::EXTERNAL_SENDERS);
    }

    #[test]
    fn test_external_pub() {
        let ext = ExternalPubExt {
            external_pub: vec![0, 1, 2, 3].into(),
        };

        assert_round_trip(ext, ExtensionType::EXTERNAL_PUB);
    }
}