mls_rs_crypto_rustcrypto/x509/
validator.rs1use mls_rs_core::{crypto::SignaturePublicKey, time::MlsTime};
6use mls_rs_identity_x509::{CertificateChain, DerCertificate, X509CredentialValidator};
7use spki::der::{Decode, Encode};
8use std::{
9 collections::HashMap,
10 fmt::{self, Debug},
11};
12use x509_cert::Certificate;
13
14use crate::{
15 ec::pub_key_to_uncompressed,
16 ec_for_x509::{pub_key_from_spki, signer_from_algorithm},
17};
18
19use super::X509Error;
20
21#[derive(Clone)]
22pub struct X509Validator {
23 root_ca_list: HashMap<Vec<u8>, DerCertificate>,
24 pinned_cert: Option<DerCertificate>,
25 allow_self_signed: bool,
26}
27
28impl Debug for X509Validator {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 f.debug_struct("X509Validator")
31 .field(
32 "root_ca_list",
33 &mls_rs_core::debug::pretty_with(|f| {
34 f.debug_map()
35 .entries(
36 self.root_ca_list
37 .iter()
38 .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
39 )
40 .finish()
41 }),
42 )
43 .field("pinned_cert", &self.pinned_cert)
44 .field("allow_self_signed", &self.allow_self_signed)
45 .finish()
46 }
47}
48
49impl X509Validator {
50 pub fn new(root_ca_list: Vec<DerCertificate>) -> Result<Self, X509Error> {
51 let root_ca_list = root_ca_list
52 .into_iter()
53 .map(|cert_data| {
54 let cert = Certificate::from_der(&cert_data)?;
56 verify_cert(&cert, &cert, None)?;
57 let subject = cert.tbs_certificate.subject.to_der()?;
58 Ok((subject, cert_data))
59 })
60 .collect::<Result<_, X509Error>>()?;
61
62 Ok(Self {
63 root_ca_list,
64 pinned_cert: None,
65 allow_self_signed: false,
66 })
67 }
68
69 pub fn set_pinned_cert(&mut self, pinned_cert: Option<DerCertificate>) {
70 self.pinned_cert = pinned_cert;
71 }
72
73 pub fn allow_self_signed(&mut self, allow: bool) {
75 self.allow_self_signed = allow;
76 }
77
78 fn validate_chain(
79 &self,
80 chain: &CertificateChain,
81 timestamp: Option<MlsTime>,
82 ) -> Result<SignaturePublicKey, X509Error> {
83 (!chain.is_empty())
84 .then_some(())
85 .ok_or(X509Error::EmptyCertificateChain)?;
86
87 if let Some(pinned_cert) = self.pinned_cert.as_ref() {
88 chain
89 .contains(pinned_cert)
90 .then_some(())
91 .ok_or(X509Error::PinnedCertNotFound)?;
92 }
93
94 let chain = chain
95 .iter()
96 .map(|cert_data| Certificate::from_der(cert_data))
97 .collect::<Result<Vec<_>, _>>()?;
98
99 for (cert1, cert2) in chain
100 .iter()
101 .zip(chain.iter().skip(1).chain(chain.iter().rev().take(1)))
102 {
103 let maybe_ca = self
104 .root_ca_list
105 .get(&cert1.tbs_certificate.issuer.to_der()?);
106
107 let verifier = maybe_ca
108 .map(|ca| {
109 let ca = Certificate::from_der(ca)?;
110
111 if let Some(time) = timestamp {
112 verify_time(&ca, time)?;
113 }
114
115 Ok::<_, X509Error>(ca)
116 })
117 .transpose()?;
118
119 let verifier = verifier.as_ref().unwrap_or(cert2);
120 verify_cert(verifier, cert1, timestamp)?;
121
122 if maybe_ca.is_some() {
124 let leaf_cert = chain.first().ok_or(X509Error::EmptyCertificateChain)?;
125
126 let pub_key =
127 pub_key_from_spki(&leaf_cert.tbs_certificate.subject_public_key_info)?;
128
129 let pub_signing_key = pub_key_to_uncompressed(&pub_key).map(Into::into)?;
130
131 return Ok(pub_signing_key);
132 }
133 }
134
135 Err(X509Error::CaNotFound)
136 }
137}
138
139fn verify_time(cert: &Certificate, timestamp: MlsTime) -> Result<(), X509Error> {
140 let validity = cert.tbs_certificate.validity;
141 let not_before = MlsTime::from(validity.not_before.to_unix_duration());
142 let not_after = MlsTime::from(validity.not_after.to_unix_duration());
143
144 if timestamp < not_before || timestamp > not_after {
145 return Err(X509Error::ValidityError {
146 timestamp,
147 not_before,
148 not_after,
149 });
150 }
151
152 Ok(())
153}
154
155fn verify_cert(
156 verifier: &Certificate,
157 verified: &Certificate,
158 timestamp: Option<MlsTime>,
159) -> Result<(), X509Error> {
160 let mut tbs = Vec::new();
162 verified.tbs_certificate.encode_to_vec(&mut tbs)?;
163
164 let signer =
166 signer_from_algorithm(&verifier.tbs_certificate.subject_public_key_info.algorithm)?;
167
168 let pub_key = pub_key_from_spki(&verifier.tbs_certificate.subject_public_key_info)?;
169
170 signer.verify(
172 &pub_key_to_uncompressed(&pub_key).map(Into::into)?,
173 verified.signature.raw_bytes(),
174 &tbs,
175 )?;
176
177 if let Some(time) = timestamp {
179 verify_time(verified, time)?;
180 }
181
182 Ok(())
183}
184
185fn validate_self_signed(
186 chain: &CertificateChain,
187 timestamp: Option<MlsTime>,
188) -> Result<SignaturePublicKey, X509Error> {
189 if chain.len() != 1 {
190 return Err(X509Error::SelfSignedWrongLength(chain.len()));
191 }
192
193 let cert = Certificate::from_der(&chain[0])?;
194
195 verify_cert(&cert, &cert, timestamp)?;
196
197 let pub_key = pub_key_from_spki(&cert.tbs_certificate.subject_public_key_info)?;
198
199 let pub_signing_key = pub_key_to_uncompressed(&pub_key).map(Into::into)?;
200
201 Ok(pub_signing_key)
202}
203
204impl X509CredentialValidator for X509Validator {
205 type Error = X509Error;
206
207 fn validate_chain(
208 &self,
209 chain: &mls_rs_identity_x509::CertificateChain,
210 timestamp: Option<mls_rs_core::time::MlsTime>,
211 ) -> Result<SignaturePublicKey, Self::Error> {
212 if !self.allow_self_signed {
213 self.validate_chain(chain, timestamp)
214 } else {
215 validate_self_signed(chain, timestamp)
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use mls_rs_core::time::MlsTime;
    use mls_rs_identity_x509::{CertificateChain, X509CredentialValidator};
    use spki::der::Decode;
    use x509_cert::Certificate;

    use crate::{
        ec_signer::EcSignerError,
        x509::{
            util::test_utils::{
                load_another_ca, load_test_ca, load_test_cert_chain, load_test_invalid_ca_chain,
                load_test_invalid_chain, load_test_p384_ca,
            },
            X509Error,
        },
    };

    use super::X509Validator;

    // NOTE(review): most tests validate against `MlsTime::now()`, while
    // `can_detect_expired_certs` pins the test chain's `not_after` to
    // 1767968040. Once the wall clock passes that instant the `now()`-based
    // tests for that chain will presumably start failing; confirm and
    // consider validating at a fixed timestamp instead.

    /// Validator that trusts only the test root CA.
    fn test_ca_validator() -> X509Validator {
        X509Validator::new(vec![load_test_ca()]).unwrap()
    }

    /// Validator without any root CA that runs in self-signed mode.
    fn self_signed_validator() -> X509Validator {
        let mut validator = X509Validator::new(vec![]).unwrap();
        validator.allow_self_signed(true);
        validator
    }

    /// Current time wrapped the way `validate_chain` expects it.
    fn now() -> Option<MlsTime> {
        Some(MlsTime::now())
    }

    #[test]
    fn can_validate_cert_chain() {
        let chain = load_test_cert_chain();

        test_ca_validator().validate_chain(&chain, now()).unwrap();
    }

    #[test]
    fn can_validate_cert_chain_without_ca() {
        let full_chain = load_test_cert_chain();
        // Drop the trailing certificate; the validator's own root must anchor the chain.
        let chain = full_chain[0..full_chain.len() - 1].to_vec().into();

        test_ca_validator().validate_chain(&chain, now()).unwrap();
    }

    #[test]
    fn can_validate_cert_chain_with_pinned() {
        let chain = load_test_cert_chain();
        let pinned = chain.get(1).unwrap().clone();

        let mut validator = test_ca_validator();
        validator.set_pinned_cert(Some(pinned));

        validator.validate_chain(&chain, now()).unwrap();
    }

    #[test]
    fn can_validate_self_signed() {
        let validator = self_signed_validator();
        let chain = vec![load_test_ca()].into();

        X509CredentialValidator::validate_chain(&validator, &chain, now()).unwrap();
    }

    #[test]
    fn can_validate_p384_cert() {
        let validator = self_signed_validator();
        let chain = vec![load_test_p384_ca()].into();

        X509CredentialValidator::validate_chain(&validator, &chain, now()).unwrap();
    }

    #[test]
    fn fails_on_too_long_self_signed() {
        let validator = self_signed_validator();
        let chain = vec![load_test_ca(), load_another_ca()].into();

        let res = X509CredentialValidator::validate_chain(&validator, &chain, now());

        assert_matches!(res, Err(X509Error::SelfSignedWrongLength(2)))
    }

    #[test]
    fn fails_if_pinned_missing() {
        let chain = load_test_cert_chain();

        let mut validator = test_ca_validator();
        validator.set_pinned_cert(Some(load_another_ca()));

        assert_matches!(
            validator.validate_chain(&chain, now()),
            Err(X509Error::PinnedCertNotFound)
        );
    }

    #[test]
    fn can_detect_invalid_ca_certificates() {
        // 32 zero bytes are not a DER certificate.
        let garbage = vec![0u8; 32].into();

        assert_matches!(
            X509Validator::new(vec![garbage]),
            Err(X509Error::X509DerError(_))
        )
    }

    #[test]
    fn can_detect_ca_cert_with_invalid_self_signed_signature() {
        // The first chain entry is not self-signed, so it cannot act as a root.
        let not_a_root = load_test_cert_chain()[0].clone();

        assert_matches!(
            X509Validator::new(vec![not_a_root]),
            Err(X509Error::EcSignerError(EcSignerError::InvalidSignature))
        )
    }

    #[test]
    fn will_fail_on_empty_chain() {
        let validator = X509Validator::new(vec![]).unwrap();
        let no_certs: Vec<Vec<u8>> = Vec::new();
        let chain = CertificateChain::from(no_certs);

        assert_matches!(
            validator.validate_chain(&chain, now()),
            Err(X509Error::EmptyCertificateChain)
        );
    }

    #[test]
    fn will_fail_on_invalid_chain() {
        let chain = load_test_invalid_chain();

        let res = test_ca_validator().validate_chain(&chain, now());

        assert_matches!(
            res,
            Err(X509Error::EcSignerError(EcSignerError::InvalidSignature))
        );
    }

    #[test]
    fn will_fail_on_invalid_ca() {
        let chain = load_test_invalid_ca_chain();
        let validator = X509Validator::new(vec![load_another_ca()]).unwrap();

        assert_matches!(
            validator.validate_chain(&chain, now()),
            Err(X509Error::EcSignerError(EcSignerError::InvalidSignature))
        );
    }

    #[test]
    fn can_detect_expired_certs() {
        let chain = load_test_cert_chain();
        let validator = test_ca_validator();

        match validator.validate_chain(&chain, Some(MlsTime::from(1798761600))) {
            Err(X509Error::ValidityError {
                timestamp,
                not_before,
                not_after,
            }) => {
                assert_eq!(timestamp, MlsTime::from(1798761600));
                assert_eq!(not_before, MlsTime::from(1673273683));
                assert_eq!(not_after, MlsTime::from(1767968040));
            }
            res => panic!("Expected validity error, got: {res:?}"),
        }
    }

    #[test]
    fn will_return_public_key_of_leaf() {
        let chain = load_test_cert_chain();

        let leaf = Certificate::from_der(chain.leaf().unwrap()).unwrap();
        let expected = leaf
            .tbs_certificate
            .subject_public_key_info
            .subject_public_key
            .raw_bytes()
            .to_vec()
            .into();

        let actual = test_ca_validator().validate_chain(&chain, None).unwrap();

        assert_eq!(actual, expected)
    }
}