mls_rs_identity_x509/
identity_extractor.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::vec::Vec;
6use mls_rs_core::{error::IntoAnyError, identity::CertificateChain};
7
8use crate::{
9    DerCertificate, SubjectComponent, X509CertificateReader, X509IdentityError,
10    X509IdentityExtractor,
11};
12
13#[derive(Debug, Clone)]
14/// A utility to determine unique identity for use with MLS by reading
15/// the subject of a certificate.
16///
17/// The default behavior of this struct is to try and produce an identity
18/// based on the common name component of the subject. If a common name
19/// component is not found, then the byte value of the entire subject
20/// is used as a fallback.
21pub struct SubjectIdentityExtractor<R: X509CertificateReader> {
22    offset: usize,
23    reader: R,
24}
25
26impl<R> SubjectIdentityExtractor<R>
27where
28    R: X509CertificateReader,
29{
30    /// Create a new identity extractor.
31    ///
32    /// `offset` is used to determine which certificate in a [`CertificateChain`]
33    /// should be used to evaluate identity. A value of 0 indicates to use the
34    /// leaf (first value) of the chain.
35    pub fn new(offset: usize, reader: R) -> Self {
36        Self { offset, reader }
37    }
38
39    fn extract_common_name(
40        &self,
41        certificate: &DerCertificate,
42    ) -> Result<Option<SubjectComponent>, X509IdentityError> {
43        Ok(self
44            .reader
45            .subject_components(certificate)
46            .map_err(|err| X509IdentityError::IdentityExtractorError(err.into_any_error()))?
47            .iter()
48            .find(|component| matches!(component, SubjectComponent::CommonName(_)))
49            .cloned())
50    }
51
52    /// Get a unique identifier for a `certificate_chain`.
53    pub fn identity(
54        &self,
55        certificate_chain: &CertificateChain,
56    ) -> Result<Vec<u8>, X509IdentityError> {
57        let cert = get_certificate(certificate_chain, self.offset)?;
58
59        let common_name_value = self.extract_common_name(cert)?;
60
61        if let Some(SubjectComponent::CommonName(common_name)) = common_name_value {
62            return Ok(common_name.as_bytes().to_vec());
63        }
64
65        self.subject_bytes(cert)
66    }
67
68    fn subject_bytes(&self, certificate: &DerCertificate) -> Result<Vec<u8>, X509IdentityError> {
69        self.reader
70            .subject_bytes(certificate)
71            .map_err(|e| X509IdentityError::X509ReaderError(e.into_any_error()))
72    }
73
74    /// Determine if `successor` resolves to the same
75    /// identity value as `predecessor`, indicating that
76    /// `predecessor` and `successor` are controlled by the same entity.
77    pub fn valid_successor(
78        &self,
79        predecessor: &CertificateChain,
80        successor: &CertificateChain,
81    ) -> Result<bool, X509IdentityError> {
82        let predecessor_cert = get_certificate(predecessor, 0)?;
83        let successor_cert = get_certificate(successor, 0)?;
84
85        let predecessor_common_name = self.extract_common_name(predecessor_cert)?;
86
87        let successor_common_name = self.extract_common_name(successor_cert)?;
88
89        if let (Some(pre_common_name), Some(succ_common_name)) =
90            (predecessor_common_name, successor_common_name)
91        {
92            return Ok(pre_common_name == succ_common_name);
93        }
94
95        Ok(self.subject_bytes(predecessor_cert)? == self.subject_bytes(successor_cert)?)
96    }
97}
98
99impl<R> X509IdentityExtractor for SubjectIdentityExtractor<R>
100where
101    R: X509CertificateReader,
102{
103    type Error = X509IdentityError;
104
105    fn identity(&self, certificate_chain: &CertificateChain) -> Result<Vec<u8>, Self::Error> {
106        self.identity(certificate_chain)
107    }
108
109    fn valid_successor(
110        &self,
111        predecessor: &CertificateChain,
112        successor: &CertificateChain,
113    ) -> Result<bool, Self::Error> {
114        self.valid_successor(predecessor, successor)
115    }
116}
117
118fn get_certificate(
119    certificate_chain: &CertificateChain,
120    offset: usize,
121) -> Result<&DerCertificate, X509IdentityError> {
122    certificate_chain
123        .get(offset)
124        .ok_or(X509IdentityError::InvalidOffset)
125}
126
// Unit tests for `SubjectIdentityExtractor`, driven by a mocked
// `X509CertificateReader` so that no real certificate parsing takes place.
#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::{
        test_utils::test_certificate_chain, MockX509CertificateReader, SubjectComponent,
        SubjectIdentityExtractor, X509IdentityError,
    };

    use alloc::vec;
    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Build a `SubjectIdentityExtractor` with the given `offset` around a
    /// fresh mock reader, after letting `mock_setup` register expectations
    /// on that mock.
    fn test_setup<F>(
        offset: usize,
        mut mock_setup: F,
    ) -> SubjectIdentityExtractor<MockX509CertificateReader>
    where
        F: FnMut(&mut MockX509CertificateReader),
    {
        let mut x509_reader = MockX509CertificateReader::new();

        mock_setup(&mut x509_reader);

        SubjectIdentityExtractor {
            offset,
            reader: x509_reader,
        }
    }

    /// `identity` must fail with `InvalidOffset`, without reading any subject
    /// bytes, when the configured offset is outside the chain.
    #[test]
    fn invalid_offset_is_rejected() {
        // Offset 4 is expected to be out of range for the test chain.
        // Note: the closure parameter is the mock reader, despite its name.
        let subject_extractor = test_setup(4, |subject_extractor| {
            subject_extractor.expect_subject_bytes().never();
        });

        assert_matches!(
            subject_extractor.identity(&test_certificate_chain()),
            Err(X509IdentityError::InvalidOffset)
        );
    }

    /// When the subject has a common name, `identity` returns its bytes and
    /// never asks the reader for the full subject.
    #[test]
    fn common_name_can_be_retrived_as_identity() {
        let test_subject = b"test_name".to_vec();
        let cert_chain = test_certificate_chain();

        // The extractor is configured with offset 1, so the reader must be
        // handed the second certificate of the chain.
        let expected_certificate = cert_chain[1].clone();

        let subject_extractor = test_setup(1, |parser| {
            parser.expect_subject_bytes().never();

            parser
                .expect_subject_components()
                .with(mockall::predicate::eq(expected_certificate.clone()))
                .times(1)
                .return_once_st(|_| {
                    Ok(vec![
                        SubjectComponent::CommonName("test_name".to_string()),
                        SubjectComponent::CountryName("US".to_string()),
                    ])
                });
        });

        assert_eq!(
            subject_extractor.identity(&cert_chain).unwrap(),
            test_subject
        );
    }

    /// When the subject has no common name, `identity` falls back to the
    /// subject bytes reported by the reader.
    #[test]
    fn subject_can_be_retrived_as_identity_if_no_common_name() {
        let test_subject = b"subject".to_vec();
        let cert_chain = test_certificate_chain();

        // Offset 1 again selects the second certificate of the chain.
        let expected_certificate = cert_chain[1].clone();

        let subject_extractor = test_setup(1, |parser| {
            // Clone so the outer `test_subject` stays available for the
            // assertion below.
            let test_subject = test_subject.clone();

            parser
                .expect_subject_bytes()
                .once()
                .with(mockall::predicate::eq(expected_certificate.clone()))
                .return_once_st(|_| Ok(test_subject));

            // Only a country name is reported, which forces the fallback.
            parser
                .expect_subject_components()
                .with(mockall::predicate::eq(expected_certificate.clone()))
                .times(1)
                .return_once_st(|_| Ok(vec![SubjectComponent::CountryName("US".to_string())]));
        });

        assert_eq!(
            subject_extractor.identity(&cert_chain).unwrap(),
            test_subject
        );
    }

    /// Two chains whose certificates report the same common name are
    /// accepted as predecessor/successor, without consulting subject bytes.
    #[test]
    fn valid_successor_matching_common_name() {
        let predecessor = test_certificate_chain();
        let mut successor = test_certificate_chain();

        // Make sure both chains have the same leaf
        successor[0] = predecessor[0].clone();

        // `valid_successor` reads index 0 of each chain, so the offset of 1
        // passed here is not consulted by this test.
        let subject_extractor = test_setup(1, |reader| {
            let predecessor = predecessor[0].clone();
            let successor = successor[0].clone();

            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(successor))
                .times(1)
                .return_once_st(|_| {
                    Ok(vec![SubjectComponent::CommonName("test_name".to_string())])
                });

            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(predecessor))
                .times(1)
                .return_once_st(|_| {
                    Ok(vec![SubjectComponent::CommonName("test_name".to_string())])
                });

            // Registered twice in the original; the second expectation
            // repeats the first.
            reader.expect_subject_bytes().never();

            reader.expect_subject_bytes().never();
        });

        assert!(subject_extractor
            .valid_successor(&predecessor, &successor)
            .unwrap());
    }

    /// Two chains whose certificates report different common names are
    /// rejected, without consulting subject bytes.
    #[test]
    fn invalid_successor_different_common_name() {
        let predecessor = test_certificate_chain();
        let mut successor = test_certificate_chain();

        // Make sure both chains have the same leaf
        successor[0] = predecessor[0].clone();

        // NOTE(review): the two leaves are identical, so both `with(eq(..))`
        // predicates accept either call. Which expectation serves which call
        // depends on mockall's matching order; confirm before reordering.
        let subject_extractor = test_setup(1, |reader| {
            let predecessor = predecessor[0].clone();
            let successor = successor[0].clone();

            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(successor))
                .times(1)
                .return_once_st(|_| {
                    Ok(vec![
                        SubjectComponent::CommonName("test_name_copy".to_string()),
                        SubjectComponent::CountryName("US".to_string()),
                    ])
                });

            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(predecessor))
                .times(1)
                .return_once_st(|_| {
                    Ok(vec![
                        SubjectComponent::CommonName("test_name".to_string()),
                        SubjectComponent::CountryName("US".to_string()),
                    ])
                });

            // Registered twice in the original; the second expectation
            // repeats the first.
            reader.expect_subject_bytes().never();

            reader.expect_subject_bytes().never();
        });

        assert!(
            !subject_extractor
                .valid_successor(&predecessor, &successor)
                .unwrap(),
            "Successor chain cert with different CommonName passed check!"
        );
    }

    /// When only one of the two certificates reports a common name, the
    /// comparison falls back to subject bytes; equal subjects are accepted.
    #[test]
    fn valid_successor_no_common_name() {
        let predecessor = test_certificate_chain();
        let mut successor = test_certificate_chain();

        // Make sure both chains have the same leaf
        successor[0] = predecessor[0].clone();

        // NOTE(review): the two leaves are identical, so the mock cannot
        // tell predecessor from successor by argument. Registration order of
        // the expectations below may therefore matter; confirm before
        // reordering.
        let subject_extractor = test_setup(1, |reader| {
            let predecessor = predecessor[0].clone();
            let successor = successor[0].clone();

            // One lookup reports no common name...
            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(successor.clone()))
                .times(1)
                .return_once_st(|_| Ok(vec![SubjectComponent::CountryName("US".to_string())]));

            // ...and the other reports one, so the pair is incomplete.
            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(predecessor.clone()))
                .times(1)
                .return_once_st(|_| {
                    Ok(vec![
                        SubjectComponent::CommonName("test_name".to_string()),
                        SubjectComponent::CountryName("US".to_string()),
                    ])
                });

            // Both subject lookups return the same bytes.
            reader
                .expect_subject_bytes()
                .with(mockall::predicate::eq(predecessor))
                .times(1)
                .return_once_st(|_| Ok(b"subject".to_vec()));

            reader
                .expect_subject_bytes()
                .with(mockall::predicate::eq(successor))
                .times(1)
                .return_once_st(|_| Ok(b"subject".to_vec()));
        });

        assert!(subject_extractor
            .valid_successor(&predecessor, &successor)
            .unwrap());
    }

    /// When neither certificate reports a common name and the subject bytes
    /// differ, the successor is rejected.
    #[test]
    fn invalid_successor_no_common_name() {
        let predecessor = test_certificate_chain();
        let mut successor = test_certificate_chain();

        // Make sure both chains have the same leaf
        successor[0] = predecessor[0].clone();

        // NOTE(review): the two leaves are identical, so the mock cannot
        // tell predecessor from successor by argument. Registration order of
        // the expectations below may therefore matter; confirm before
        // reordering.
        let subject_extractor = test_setup(1, |reader| {
            let predecessor = predecessor[0].clone();
            let successor = successor[0].clone();

            // The two subject lookups return different bytes.
            reader
                .expect_subject_bytes()
                .with(mockall::predicate::eq(predecessor.clone()))
                .times(1)
                .return_once_st(|_| Ok(b"subject_copy".to_vec()));

            reader
                .expect_subject_bytes()
                .with(mockall::predicate::eq(successor.clone()))
                .times(1)
                .return_once_st(|_| Ok(b"subject".to_vec()));

            // Neither component lookup reports a common name.
            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(successor))
                .times(1)
                .return_once_st(|_| Ok(vec![SubjectComponent::CountryName("US".to_string())]));

            reader
                .expect_subject_components()
                .with(mockall::predicate::eq(predecessor))
                .times(1)
                .return_once_st(|_| Ok(vec![SubjectComponent::CountryName("US".to_string())]));
        });

        assert!(
            !subject_extractor
                .valid_successor(&predecessor, &successor)
                .unwrap(),
            "Successor cert chain with different subjects passed valid check!"
        );
    }
}
402}