1use crate::{util::credential_to_chain, CertificateChain, X509IdentityError};
6use alloc::vec;
7use alloc::vec::Vec;
8use mls_rs_core::{
9 crypto::SignaturePublicKey,
10 error::IntoAnyError,
11 extension::ExtensionList,
12 identity::{CredentialType, IdentityProvider, MemberValidationContext, SigningIdentity},
13 time::MlsTime,
14};
15
16#[cfg(not(feature = "std"))]
17use alloc::boxed::Box;
18
19#[cfg(all(test, feature = "std"))]
20use mockall::automock;
21
22#[cfg_attr(all(test, feature = "std"), automock(type Error = crate::test_utils::TestError;))]
23pub trait X509IdentityExtractor {
25 type Error: IntoAnyError;
26
27 fn identity(&self, certificate_chain: &CertificateChain) -> Result<Vec<u8>, Self::Error>;
30
31 fn valid_successor(
34 &self,
35 predecessor: &CertificateChain,
36 successor: &CertificateChain,
37 ) -> Result<bool, Self::Error>;
38}
39
40#[cfg_attr(all(test, feature = "std"), automock(type Error = crate::test_utils::TestError;))]
41pub trait X509CredentialValidator {
43 type Error: IntoAnyError;
44
45 fn validate_chain(
49 &self,
50 chain: &CertificateChain,
51 timestamp: Option<MlsTime>,
52 ) -> Result<SignaturePublicKey, Self::Error>;
53}
54
/// Identity provider for X.509 credentials, assembled from two pluggable parts.
///
/// When `IE` implements [`X509IdentityExtractor`] and `V` implements
/// [`X509CredentialValidator`] (both `Send + Sync`), this type implements
/// [`IdentityProvider`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct X509IdentityProvider<IE, V> {
    /// Component used to answer identity and successor queries.
    pub identity_extractor: IE,
    /// Component used to validate certificate chains.
    pub validator: V,
}
67
68impl<IE, V> X509IdentityProvider<IE, V>
69where
70 IE: X509IdentityExtractor,
71 V: X509CredentialValidator,
72{
73 pub fn new(identity_extractor: IE, validator: V) -> Self {
75 Self {
76 identity_extractor,
77 validator,
78 }
79 }
80
81 fn validate(
84 &self,
85 signing_identity: &SigningIdentity,
86 timestamp: Option<MlsTime>,
87 ) -> Result<(), X509IdentityError> {
88 let chain = credential_to_chain(&signing_identity.credential)?;
89
90 let leaf_public_key = self
91 .validator
92 .validate_chain(&chain, timestamp)
93 .map_err(|e| X509IdentityError::X509ValidationError(e.into_any_error()))?;
94
95 if leaf_public_key != signing_identity.signature_key {
96 return Err(X509IdentityError::SignatureKeyMismatch);
97 }
98
99 Ok(())
100 }
101}
102
103#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
104#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
105impl<IE, V> IdentityProvider for X509IdentityProvider<IE, V>
106where
107 IE: X509IdentityExtractor + Send + Sync,
108 V: X509CredentialValidator + Send + Sync,
109{
110 type Error = X509IdentityError;
111
112 async fn validate_member(
115 &self,
116 signing_identity: &SigningIdentity,
117 timestamp: Option<MlsTime>,
118 _context: MemberValidationContext<'_>,
119 ) -> Result<(), X509IdentityError> {
120 self.validate(signing_identity, timestamp)
121 }
122
123 async fn identity(
126 &self,
127 signing_id: &SigningIdentity,
128 _extensions: &ExtensionList,
129 ) -> Result<Vec<u8>, X509IdentityError> {
130 self.identity_extractor
131 .identity(&credential_to_chain(&signing_id.credential)?)
132 .map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
133 }
134
135 async fn valid_successor(
139 &self,
140 predecessor: &SigningIdentity,
141 successor: &SigningIdentity,
142 _extensions: &ExtensionList,
143 ) -> Result<bool, X509IdentityError> {
144 self.identity_extractor
145 .valid_successor(
146 &credential_to_chain(&predecessor.credential)?,
147 &credential_to_chain(&successor.credential)?,
148 )
149 .map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
150 }
151
152 async fn validate_external_sender(
153 &self,
154 signing_identity: &SigningIdentity,
155 timestamp: Option<MlsTime>,
156 _extensions: Option<&ExtensionList>,
157 ) -> Result<(), Self::Error> {
158 self.validate(signing_identity, timestamp)
159 }
160
161 fn supported_types(&self) -> Vec<CredentialType> {
165 vec![CredentialType::X509]
166 }
167}
168
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use mls_rs_core::{crypto::SignaturePublicKey, identity::CredentialType, time::MlsTime};

    use crate::{
        test_utils::{
            test_certificate_chain, test_signing_identity, test_signing_identity_with_chain,
            TestError,
        },
        MockX509CredentialValidator, MockX509IdentityExtractor, X509IdentityError,
        X509IdentityProvider,
    };

    use alloc::vec;

    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Provider under test, with both components replaced by mocks.
    type MockedProvider =
        X509IdentityProvider<MockX509IdentityExtractor, MockX509CredentialValidator>;

    /// Build a provider around fresh mocks, letting `mock_setup` register
    /// expectations on them before they are moved into the provider.
    fn test_setup<F>(mut mock_setup: F) -> MockedProvider
    where
        F: FnMut(&mut MockX509IdentityExtractor, &mut MockX509CredentialValidator),
    {
        let mut extractor_mock = MockX509IdentityExtractor::new();
        let mut validator_mock = MockX509CredentialValidator::new();
        mock_setup(&mut extractor_mock, &mut validator_mock);

        X509IdentityProvider::new(extractor_mock, validator_mock)
    }

    /// The provider advertises exactly one credential type, with raw value 2.
    #[test]
    fn test_supported_types() {
        let provider = test_setup(|_, _| ());
        let expected = vec![CredentialType::new(2)];

        assert_eq!(provider.supported_types(), expected);
    }

    /// Validation succeeds when the validator is called once with the
    /// identity's chain and the given timestamp, and returns the identity's
    /// own signature key.
    #[test]
    fn test_successful_validation() {
        let chain = test_certificate_chain();
        let signing_identity = test_signing_identity_with_chain(chain.clone());
        let now = MlsTime::now();

        let provider = test_setup(|_, validator| {
            let expected_key = signing_identity.signature_key.clone();
            let chain_matches = mockall::predicate::eq(chain.clone());
            let time_matches = mockall::predicate::eq(Some(now));

            validator
                .expect_validate_chain()
                .once()
                .with(chain_matches, time_matches)
                .return_once_st(move |_, _| Ok(expected_key));
        });

        provider.validate(&signing_identity, Some(now)).unwrap();
    }

    /// A validator key that differs from the identity's signature key is
    /// reported as `SignatureKeyMismatch`.
    #[test]
    fn test_signing_identity_key_mismatch() {
        let signing_identity = test_signing_identity();

        let provider = test_setup(|_, validator| {
            let unrelated_key = SignaturePublicKey::from(vec![42u8; 32]);

            validator
                .expect_validate_chain()
                .return_once_st(move |_, _| Ok(unrelated_key));
        });

        let outcome = provider.validate(&signing_identity, None);

        assert_matches!(outcome, Err(X509IdentityError::SignatureKeyMismatch));
    }

    /// A validator failure is surfaced as `X509ValidationError`.
    #[test]
    fn test_failing_validation() {
        let provider = test_setup(|_, validator| {
            validator
                .expect_validate_chain()
                .return_once_st(|_, _| Err(TestError));
        });

        let outcome = provider.validate(&test_signing_identity(), None);

        assert_matches!(outcome, Err(X509IdentityError::X509ValidationError(_)));
    }
}