mls_rs_identity_x509/
provider.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use crate::{util::credential_to_chain, CertificateChain, X509IdentityError};
6use alloc::vec;
7use alloc::vec::Vec;
8use mls_rs_core::{
9    crypto::SignaturePublicKey,
10    error::IntoAnyError,
11    extension::ExtensionList,
12    identity::{CredentialType, IdentityProvider, MemberValidationContext, SigningIdentity},
13    time::MlsTime,
14};
15
16#[cfg(not(feature = "std"))]
17use alloc::boxed::Box;
18
19#[cfg(all(test, feature = "std"))]
20use mockall::automock;
21
22#[cfg_attr(all(test, feature = "std"), automock(type Error = crate::test_utils::TestError;))]
23/// X.509 certificate unique identity trait.
24pub trait X509IdentityExtractor {
25    type Error: IntoAnyError;
26
27    /// Produce a unique identity value to represent the entity controlling a
28    /// certificate credential within an MLS group.
29    fn identity(&self, certificate_chain: &CertificateChain) -> Result<Vec<u8>, Self::Error>;
30
31    /// Determine if `successor` is controlled by the same entity as
32    /// `predecessor`.
33    fn valid_successor(
34        &self,
35        predecessor: &CertificateChain,
36        successor: &CertificateChain,
37    ) -> Result<bool, Self::Error>;
38}
39
40#[cfg_attr(all(test, feature = "std"), automock(type Error = crate::test_utils::TestError;))]
41/// X.509 certificate validation trait.
42pub trait X509CredentialValidator {
43    type Error: IntoAnyError;
44
45    /// Validate a certificate chain.
46    ///
47    /// If `timestamp` is set to `None` then expiration checks should be skipped.
48    fn validate_chain(
49        &self,
50        chain: &CertificateChain,
51        timestamp: Option<MlsTime>,
52    ) -> Result<SignaturePublicKey, Self::Error>;
53}
54
/// Generic, customizable identity provider for X.509 certificate credentials.
///
/// Each piece of [`IdentityProvider`] behavior is delegated to one of the two
/// generic sub-components held by this struct.
///
/// X509 is the only credential type this provider supports.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct X509IdentityProvider<IE, V> {
    /// Component used to derive identities and to judge credential succession.
    pub identity_extractor: IE,
    /// Component used to validate certificate chains.
    pub validator: V,
}
67
68impl<IE, V> X509IdentityProvider<IE, V>
69where
70    IE: X509IdentityExtractor,
71    V: X509CredentialValidator,
72{
73    /// Create a new identity provider.
74    pub fn new(identity_extractor: IE, validator: V) -> Self {
75        Self {
76            identity_extractor,
77            validator,
78        }
79    }
80
81    /// Determine if a certificate is valid based on the behavior of the
82    /// underlying validator provided.
83    fn validate(
84        &self,
85        signing_identity: &SigningIdentity,
86        timestamp: Option<MlsTime>,
87    ) -> Result<(), X509IdentityError> {
88        let chain = credential_to_chain(&signing_identity.credential)?;
89
90        let leaf_public_key = self
91            .validator
92            .validate_chain(&chain, timestamp)
93            .map_err(|e| X509IdentityError::X509ValidationError(e.into_any_error()))?;
94
95        if leaf_public_key != signing_identity.signature_key {
96            return Err(X509IdentityError::SignatureKeyMismatch);
97        }
98
99        Ok(())
100    }
101}
102
103#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
104#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
105impl<IE, V> IdentityProvider for X509IdentityProvider<IE, V>
106where
107    IE: X509IdentityExtractor + Send + Sync,
108    V: X509CredentialValidator + Send + Sync,
109{
110    type Error = X509IdentityError;
111
112    /// Determine if a certificate is valid based on the behavior of the
113    /// underlying validator provided.
114    async fn validate_member(
115        &self,
116        signing_identity: &SigningIdentity,
117        timestamp: Option<MlsTime>,
118        _context: MemberValidationContext<'_>,
119    ) -> Result<(), X509IdentityError> {
120        self.validate(signing_identity, timestamp)
121    }
122
123    /// Produce a unique identity value to represent the entity controlling a
124    /// certificate credential within an MLS group.
125    async fn identity(
126        &self,
127        signing_id: &SigningIdentity,
128        _extensions: &ExtensionList,
129    ) -> Result<Vec<u8>, X509IdentityError> {
130        self.identity_extractor
131            .identity(&credential_to_chain(&signing_id.credential)?)
132            .map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
133    }
134
135    /// Determine if `successor` is controlled by the same entity as
136    /// `predecessor` based on the behavior of the underlying identity
137    /// extractor provided.
138    async fn valid_successor(
139        &self,
140        predecessor: &SigningIdentity,
141        successor: &SigningIdentity,
142        _extensions: &ExtensionList,
143    ) -> Result<bool, X509IdentityError> {
144        self.identity_extractor
145            .valid_successor(
146                &credential_to_chain(&predecessor.credential)?,
147                &credential_to_chain(&successor.credential)?,
148            )
149            .map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
150    }
151
152    async fn validate_external_sender(
153        &self,
154        signing_identity: &SigningIdentity,
155        timestamp: Option<MlsTime>,
156        _extensions: Option<&ExtensionList>,
157    ) -> Result<(), Self::Error> {
158        self.validate(signing_identity, timestamp)
159    }
160
161    /// Supported credential types.
162    ///
163    /// Only [`CredentialType::X509`] is supported.
164    fn supported_types(&self) -> Vec<CredentialType> {
165        vec![CredentialType::X509]
166    }
167}
168
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use mls_rs_core::{crypto::SignaturePublicKey, identity::CredentialType, time::MlsTime};

    use crate::{
        test_utils::{
            test_certificate_chain, test_signing_identity, test_signing_identity_with_chain,
            TestError,
        },
        MockX509CredentialValidator, MockX509IdentityExtractor, X509IdentityError,
        X509IdentityProvider,
    };

    use alloc::vec;

    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Provider under test, wired to mocked sub-components.
    type MockedProvider =
        X509IdentityProvider<MockX509IdentityExtractor, MockX509CredentialValidator>;

    /// Creates fresh mocks, lets `mock_setup` register expectations on them,
    /// and wraps them in a provider.
    fn test_setup<F>(mut mock_setup: F) -> MockedProvider
    where
        F: FnMut(&mut MockX509IdentityExtractor, &mut MockX509CredentialValidator),
    {
        let mut extractor_mock = MockX509IdentityExtractor::new();
        let mut validator_mock = MockX509CredentialValidator::new();

        mock_setup(&mut extractor_mock, &mut validator_mock);

        X509IdentityProvider::new(extractor_mock, validator_mock)
    }

    /// The provider advertises exactly one credential type, with value 2.
    #[test]
    fn test_supported_types() {
        let provider = test_setup(|_, _| ());

        let expected = vec![CredentialType::new(2)];

        assert_eq!(provider.supported_types(), expected);
    }

    /// Validation succeeds when the validator returns the identity's own
    /// signature key, and the chain and timestamp are forwarded unchanged.
    #[test]
    fn test_successful_validation() {
        use mockall::predicate::eq;

        let chain = test_certificate_chain();
        let signing_identity = test_signing_identity_with_chain(chain.clone());
        let timestamp = MlsTime::now();

        let provider = test_setup(|_, validator| {
            let expected_key = signing_identity.signature_key.clone();

            validator
                .expect_validate_chain()
                .once()
                .with(eq(chain.clone()), eq(Some(timestamp)))
                .return_once_st(move |_, _| Ok(expected_key));
        });

        provider
            .validate(&signing_identity, Some(timestamp))
            .unwrap();
    }

    /// A validator key that differs from the identity's signature key is
    /// reported as a mismatch.
    #[test]
    fn test_signing_identity_key_mismatch() {
        let signing_identity = test_signing_identity();

        let provider = test_setup(|_, validator| {
            let unrelated_key = SignaturePublicKey::from(vec![42u8; 32]);

            validator
                .expect_validate_chain()
                .return_once_st(move |_, _| Ok(unrelated_key));
        });

        let outcome = provider.validate(&signing_identity, None);

        assert_matches!(outcome, Err(X509IdentityError::SignatureKeyMismatch));
    }

    /// A validator failure is surfaced as an X.509 validation error.
    #[test]
    fn test_failing_validation() {
        let provider = test_setup(|_, validator| {
            validator
                .expect_validate_chain()
                .return_once_st(|_, _| Err(TestError));
        });

        let signing_identity = test_signing_identity();
        let outcome = provider.validate(&signing_identity, None);

        assert_matches!(outcome, Err(X509IdentityError::X509ValidationError(_)));
    }
}