// mls_rs_identity_x509/identity_extractor.rs
use alloc::vec::Vec;
6use mls_rs_core::{error::IntoAnyError, identity::CertificateChain};
7
8use crate::{
9 DerCertificate, SubjectComponent, X509CertificateReader, X509IdentityError,
10 X509IdentityExtractor,
11};
12
13#[derive(Debug, Clone)]
14pub struct SubjectIdentityExtractor<R: X509CertificateReader> {
22 offset: usize,
23 reader: R,
24}
25
26impl<R> SubjectIdentityExtractor<R>
27where
28 R: X509CertificateReader,
29{
30 pub fn new(offset: usize, reader: R) -> Self {
36 Self { offset, reader }
37 }
38
39 fn extract_common_name(
40 &self,
41 certificate: &DerCertificate,
42 ) -> Result<Option<SubjectComponent>, X509IdentityError> {
43 Ok(self
44 .reader
45 .subject_components(certificate)
46 .map_err(|err| X509IdentityError::IdentityExtractorError(err.into_any_error()))?
47 .iter()
48 .find(|component| matches!(component, SubjectComponent::CommonName(_)))
49 .cloned())
50 }
51
52 pub fn identity(
54 &self,
55 certificate_chain: &CertificateChain,
56 ) -> Result<Vec<u8>, X509IdentityError> {
57 let cert = get_certificate(certificate_chain, self.offset)?;
58
59 let common_name_value = self.extract_common_name(cert)?;
60
61 if let Some(SubjectComponent::CommonName(common_name)) = common_name_value {
62 return Ok(common_name.as_bytes().to_vec());
63 }
64
65 self.subject_bytes(cert)
66 }
67
68 fn subject_bytes(&self, certificate: &DerCertificate) -> Result<Vec<u8>, X509IdentityError> {
69 self.reader
70 .subject_bytes(certificate)
71 .map_err(|e| X509IdentityError::X509ReaderError(e.into_any_error()))
72 }
73
74 pub fn valid_successor(
78 &self,
79 predecessor: &CertificateChain,
80 successor: &CertificateChain,
81 ) -> Result<bool, X509IdentityError> {
82 let predecessor_cert = get_certificate(predecessor, 0)?;
83 let successor_cert = get_certificate(successor, 0)?;
84
85 let predecessor_common_name = self.extract_common_name(predecessor_cert)?;
86
87 let successor_common_name = self.extract_common_name(successor_cert)?;
88
89 if let (Some(pre_common_name), Some(succ_common_name)) =
90 (predecessor_common_name, successor_common_name)
91 {
92 return Ok(pre_common_name == succ_common_name);
93 }
94
95 Ok(self.subject_bytes(predecessor_cert)? == self.subject_bytes(successor_cert)?)
96 }
97}
98
99impl<R> X509IdentityExtractor for SubjectIdentityExtractor<R>
100where
101 R: X509CertificateReader,
102{
103 type Error = X509IdentityError;
104
105 fn identity(&self, certificate_chain: &CertificateChain) -> Result<Vec<u8>, Self::Error> {
106 self.identity(certificate_chain)
107 }
108
109 fn valid_successor(
110 &self,
111 predecessor: &CertificateChain,
112 successor: &CertificateChain,
113 ) -> Result<bool, Self::Error> {
114 self.valid_successor(predecessor, successor)
115 }
116}
117
118fn get_certificate(
119 certificate_chain: &CertificateChain,
120 offset: usize,
121) -> Result<&DerCertificate, X509IdentityError> {
122 certificate_chain
123 .get(offset)
124 .ok_or(X509IdentityError::InvalidOffset)
125}
126
#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::{
        test_utils::test_certificate_chain, MockX509CertificateReader, SubjectComponent,
        SubjectIdentityExtractor, X509IdentityError,
    };

    use alloc::vec;
    use assert_matches::assert_matches;
    use mockall::predicate::eq;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    // Builds a common name subject component holding `value`.
    fn common_name(value: &str) -> SubjectComponent {
        SubjectComponent::CommonName(value.to_string())
    }

    // Builds the country name subject component shared by these tests.
    fn country_us() -> SubjectComponent {
        SubjectComponent::CountryName("US".to_string())
    }

    // Creates an extractor over a mock reader after letting `mock_setup`
    // register expectations on that reader.
    fn test_setup(
        offset: usize,
        mut mock_setup: impl FnMut(&mut MockX509CertificateReader),
    ) -> SubjectIdentityExtractor<MockX509CertificateReader> {
        let mut mock_reader = MockX509CertificateReader::new();
        mock_setup(&mut mock_reader);

        SubjectIdentityExtractor::new(offset, mock_reader)
    }

    // An offset of 4 must yield `InvalidOffset` for the test chain, and the
    // subject bytes must never be read.
    #[test]
    fn invalid_offset_is_rejected() {
        let extractor = test_setup(4, |reader| {
            reader.expect_subject_bytes().never();
        });

        let outcome = extractor.identity(&test_certificate_chain());

        assert_matches!(outcome, Err(X509IdentityError::InvalidOffset));
    }

    // With a common name in the subject, the identity is that common name and
    // the subject bytes are never requested.
    #[test]
    fn common_name_can_be_retrived_as_identity() {
        let chain = test_certificate_chain();
        let identity_cert = chain[1].clone();

        let extractor = test_setup(1, |reader| {
            reader.expect_subject_bytes().never();

            reader
                .expect_subject_components()
                .with(eq(identity_cert.clone()))
                .once()
                .return_once_st(|_| Ok(vec![common_name("test_name"), country_us()]));
        });

        let identity = extractor.identity(&chain).unwrap();

        assert_eq!(identity, b"test_name".to_vec());
    }

    // Without a common name, the identity falls back to the subject bytes.
    #[test]
    fn subject_can_be_retrived_as_identity_if_no_common_name() {
        let expected_identity = b"subject".to_vec();
        let chain = test_certificate_chain();
        let identity_cert = chain[1].clone();

        let extractor = test_setup(1, |reader| {
            let subject = expected_identity.clone();

            reader
                .expect_subject_bytes()
                .with(eq(identity_cert.clone()))
                .once()
                .return_once_st(move |_| Ok(subject));

            reader
                .expect_subject_components()
                .with(eq(identity_cert.clone()))
                .once()
                .return_once_st(|_| Ok(vec![country_us()]));
        });

        let identity = extractor.identity(&chain).unwrap();

        assert_eq!(identity, expected_identity);
    }

    // Equal common names on both leaf certificates make the successor valid,
    // without the subject bytes being consulted.
    #[test]
    fn valid_successor_matching_common_name() {
        let old_chain = test_certificate_chain();
        let mut new_chain = test_certificate_chain();

        new_chain[0] = old_chain[0].clone();

        let extractor = test_setup(1, |reader| {
            let (old_leaf, new_leaf) = (old_chain[0].clone(), new_chain[0].clone());

            // The two leaves are equal, so the registration order of these
            // expectations is kept exactly as-is.
            reader
                .expect_subject_components()
                .with(eq(new_leaf))
                .once()
                .return_once_st(|_| Ok(vec![common_name("test_name")]));

            reader
                .expect_subject_components()
                .with(eq(old_leaf))
                .once()
                .return_once_st(|_| Ok(vec![common_name("test_name")]));

            // Two blanket "never" expectations; each one forbids the call.
            reader.expect_subject_bytes().never();

            reader.expect_subject_bytes().never();
        });

        let is_valid = extractor.valid_successor(&old_chain, &new_chain).unwrap();

        assert!(is_valid);
    }

    // Differing common names make the successor invalid, without the subject
    // bytes being consulted.
    #[test]
    fn invalid_successor_different_common_name() {
        let old_chain = test_certificate_chain();
        let mut new_chain = test_certificate_chain();

        new_chain[0] = old_chain[0].clone();

        let extractor = test_setup(1, |reader| {
            let (old_leaf, new_leaf) = (old_chain[0].clone(), new_chain[0].clone());

            reader
                .expect_subject_components()
                .with(eq(new_leaf))
                .once()
                .return_once_st(|_| Ok(vec![common_name("test_name_copy"), country_us()]));

            reader
                .expect_subject_components()
                .with(eq(old_leaf))
                .once()
                .return_once_st(|_| Ok(vec![common_name("test_name"), country_us()]));

            // Two blanket "never" expectations; each one forbids the call.
            reader.expect_subject_bytes().never();

            reader.expect_subject_bytes().never();
        });

        let is_valid = extractor.valid_successor(&old_chain, &new_chain).unwrap();

        assert!(
            !is_valid,
            "Successor chain cert with different CommonName passed check!"
        );
    }

    // When one side has no common name, equal subject bytes make the
    // successor valid.
    #[test]
    fn valid_successor_no_common_name() {
        let old_chain = test_certificate_chain();
        let mut new_chain = test_certificate_chain();

        new_chain[0] = old_chain[0].clone();

        let extractor = test_setup(1, |reader| {
            let (old_leaf, new_leaf) = (old_chain[0].clone(), new_chain[0].clone());

            reader
                .expect_subject_components()
                .with(eq(new_leaf.clone()))
                .once()
                .return_once_st(|_| Ok(vec![country_us()]));

            reader
                .expect_subject_components()
                .with(eq(old_leaf.clone()))
                .once()
                .return_once_st(|_| Ok(vec![common_name("test_name"), country_us()]));

            reader
                .expect_subject_bytes()
                .with(eq(old_leaf))
                .once()
                .return_once_st(|_| Ok(b"subject".to_vec()));

            reader
                .expect_subject_bytes()
                .with(eq(new_leaf))
                .once()
                .return_once_st(|_| Ok(b"subject".to_vec()));
        });

        let is_valid = extractor.valid_successor(&old_chain, &new_chain).unwrap();

        assert!(is_valid);
    }

    // With no common name on either side, differing subject bytes make the
    // successor invalid.
    #[test]
    fn invalid_successor_no_common_name() {
        let old_chain = test_certificate_chain();
        let mut new_chain = test_certificate_chain();

        new_chain[0] = old_chain[0].clone();

        let extractor = test_setup(1, |reader| {
            let (old_leaf, new_leaf) = (old_chain[0].clone(), new_chain[0].clone());

            reader
                .expect_subject_bytes()
                .with(eq(old_leaf.clone()))
                .once()
                .return_once_st(|_| Ok(b"subject_copy".to_vec()));

            reader
                .expect_subject_bytes()
                .with(eq(new_leaf.clone()))
                .once()
                .return_once_st(|_| Ok(b"subject".to_vec()));

            reader
                .expect_subject_components()
                .with(eq(new_leaf))
                .once()
                .return_once_st(|_| Ok(vec![country_us()]));

            reader
                .expect_subject_components()
                .with(eq(old_leaf))
                .once()
                .return_once_st(|_| Ok(vec![country_us()]));
        });

        let is_valid = extractor.valid_successor(&old_chain, &new_chain).unwrap();

        assert!(
            !is_valid,
            "Successor cert chain with different subjects passed valid check!"
        );
    }
}