mls_rs_crypto_rustcrypto/x509/
validator.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use mls_rs_core::{crypto::SignaturePublicKey, time::MlsTime};
6use mls_rs_identity_x509::{CertificateChain, DerCertificate, X509CredentialValidator};
7use spki::der::{Decode, Encode};
8use std::{
9    collections::HashMap,
10    fmt::{self, Debug},
11};
12use x509_cert::Certificate;
13
14use crate::{
15    ec::pub_key_to_uncompressed,
16    ec_for_x509::{pub_key_from_spki, signer_from_algorithm},
17};
18
19use super::X509Error;
20
21#[derive(Clone)]
22pub struct X509Validator {
23    root_ca_list: HashMap<Vec<u8>, DerCertificate>,
24    pinned_cert: Option<DerCertificate>,
25    allow_self_signed: bool,
26}
27
28impl Debug for X509Validator {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("X509Validator")
31            .field(
32                "root_ca_list",
33                &mls_rs_core::debug::pretty_with(|f| {
34                    f.debug_map()
35                        .entries(
36                            self.root_ca_list
37                                .iter()
38                                .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
39                        )
40                        .finish()
41                }),
42            )
43            .field("pinned_cert", &self.pinned_cert)
44            .field("allow_self_signed", &self.allow_self_signed)
45            .finish()
46    }
47}
48
49impl X509Validator {
50    pub fn new(root_ca_list: Vec<DerCertificate>) -> Result<Self, X509Error> {
51        let root_ca_list = root_ca_list
52            .into_iter()
53            .map(|cert_data| {
54                // Verify the self-signture. Time is validated when CAs are used
55                let cert = Certificate::from_der(&cert_data)?;
56                verify_cert(&cert, &cert, None)?;
57                let subject = cert.tbs_certificate.subject.to_der()?;
58                Ok((subject, cert_data))
59            })
60            .collect::<Result<_, X509Error>>()?;
61
62        Ok(Self {
63            root_ca_list,
64            pinned_cert: None,
65            allow_self_signed: false,
66        })
67    }
68
69    pub fn set_pinned_cert(&mut self, pinned_cert: Option<DerCertificate>) {
70        self.pinned_cert = pinned_cert;
71    }
72
73    /// This MUST be used only in tests. DO NOT use in production.
74    pub fn allow_self_signed(&mut self, allow: bool) {
75        self.allow_self_signed = allow;
76    }
77
78    fn validate_chain(
79        &self,
80        chain: &CertificateChain,
81        timestamp: Option<MlsTime>,
82    ) -> Result<SignaturePublicKey, X509Error> {
83        (!chain.is_empty())
84            .then_some(())
85            .ok_or(X509Error::EmptyCertificateChain)?;
86
87        if let Some(pinned_cert) = self.pinned_cert.as_ref() {
88            chain
89                .contains(pinned_cert)
90                .then_some(())
91                .ok_or(X509Error::PinnedCertNotFound)?;
92        }
93
94        let chain = chain
95            .iter()
96            .map(|cert_data| Certificate::from_der(cert_data))
97            .collect::<Result<Vec<_>, _>>()?;
98
99        for (cert1, cert2) in chain
100            .iter()
101            .zip(chain.iter().skip(1).chain(chain.iter().rev().take(1)))
102        {
103            let maybe_ca = self
104                .root_ca_list
105                .get(&cert1.tbs_certificate.issuer.to_der()?);
106
107            let verifier = maybe_ca
108                .map(|ca| {
109                    let ca = Certificate::from_der(ca)?;
110
111                    if let Some(time) = timestamp {
112                        verify_time(&ca, time)?;
113                    }
114
115                    Ok::<_, X509Error>(ca)
116                })
117                .transpose()?;
118
119            let verifier = verifier.as_ref().unwrap_or(cert2);
120            verify_cert(verifier, cert1, timestamp)?;
121
122            // If we found a CA, we're done with the chain.
123            if maybe_ca.is_some() {
124                let leaf_cert = chain.first().ok_or(X509Error::EmptyCertificateChain)?;
125
126                let pub_key =
127                    pub_key_from_spki(&leaf_cert.tbs_certificate.subject_public_key_info)?;
128
129                let pub_signing_key = pub_key_to_uncompressed(&pub_key).map(Into::into)?;
130
131                return Ok(pub_signing_key);
132            }
133        }
134
135        Err(X509Error::CaNotFound)
136    }
137}
138
139fn verify_time(cert: &Certificate, timestamp: MlsTime) -> Result<(), X509Error> {
140    let validity = cert.tbs_certificate.validity;
141    let not_before = MlsTime::from(validity.not_before.to_unix_duration());
142    let not_after = MlsTime::from(validity.not_after.to_unix_duration());
143
144    if timestamp < not_before || timestamp > not_after {
145        return Err(X509Error::ValidityError {
146            timestamp,
147            not_before,
148            not_after,
149        });
150    }
151
152    Ok(())
153}
154
155fn verify_cert(
156    verifier: &Certificate,
157    verified: &Certificate,
158    timestamp: Option<MlsTime>,
159) -> Result<(), X509Error> {
160    // Re-encode the verified TBS struct to get the signed bytes
161    let mut tbs = Vec::new();
162    verified.tbs_certificate.encode_to_vec(&mut tbs)?;
163
164    // Create a signer for the verifier
165    let signer =
166        signer_from_algorithm(&verifier.tbs_certificate.subject_public_key_info.algorithm)?;
167
168    let pub_key = pub_key_from_spki(&verifier.tbs_certificate.subject_public_key_info)?;
169
170    // Verify the signature
171    signer.verify(
172        &pub_key_to_uncompressed(&pub_key).map(Into::into)?,
173        verified.signature.raw_bytes(),
174        &tbs,
175    )?;
176
177    // Verify properties
178    if let Some(time) = timestamp {
179        verify_time(verified, time)?;
180    }
181
182    Ok(())
183}
184
185fn validate_self_signed(
186    chain: &CertificateChain,
187    timestamp: Option<MlsTime>,
188) -> Result<SignaturePublicKey, X509Error> {
189    if chain.len() != 1 {
190        return Err(X509Error::SelfSignedWrongLength(chain.len()));
191    }
192
193    let cert = Certificate::from_der(&chain[0])?;
194
195    verify_cert(&cert, &cert, timestamp)?;
196
197    let pub_key = pub_key_from_spki(&cert.tbs_certificate.subject_public_key_info)?;
198
199    let pub_signing_key = pub_key_to_uncompressed(&pub_key).map(Into::into)?;
200
201    Ok(pub_signing_key)
202}
203
204impl X509CredentialValidator for X509Validator {
205    type Error = X509Error;
206
207    fn validate_chain(
208        &self,
209        chain: &mls_rs_identity_x509::CertificateChain,
210        timestamp: Option<mls_rs_core::time::MlsTime>,
211    ) -> Result<SignaturePublicKey, Self::Error> {
212        if !self.allow_self_signed {
213            self.validate_chain(chain, timestamp)
214        } else {
215            validate_self_signed(chain, timestamp)
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use mls_rs_core::time::MlsTime;
    use mls_rs_identity_x509::{CertificateChain, X509CredentialValidator};
    use spki::der::Decode;
    use x509_cert::Certificate;

    use crate::{
        ec_signer::EcSignerError,
        x509::{
            util::test_utils::{
                load_another_ca, load_test_ca, load_test_cert_chain, load_test_invalid_ca_chain,
                load_test_invalid_chain, load_test_p384_ca,
            },
            X509Error,
        },
    };

    use super::X509Validator;

    /// Builds a validator whose only trusted root is the test CA.
    fn test_ca_validator() -> X509Validator {
        X509Validator::new(vec![load_test_ca()]).unwrap()
    }

    /// Builds a validator with no roots that is switched into self-signed mode.
    fn self_signed_validator() -> X509Validator {
        let mut validator = X509Validator::new(vec![]).unwrap();
        validator.allow_self_signed(true);
        validator
    }

    #[test]
    fn can_validate_cert_chain() {
        let chain = load_test_cert_chain();

        test_ca_validator()
            .validate_chain(&chain, Some(MlsTime::now()))
            .unwrap();
    }

    #[test]
    fn can_validate_cert_chain_without_ca() {
        let full_chain = load_test_cert_chain();
        let truncated: CertificateChain = full_chain[0..full_chain.len() - 1].to_vec().into();

        test_ca_validator()
            .validate_chain(&truncated, Some(MlsTime::now()))
            .unwrap();
    }

    #[test]
    fn can_validate_cert_chain_with_pinned() {
        let chain = load_test_cert_chain();
        let second_cert = chain.get(1).unwrap().clone();

        let mut validator = test_ca_validator();
        validator.set_pinned_cert(Some(second_cert));

        validator
            .validate_chain(&chain, Some(MlsTime::now()))
            .unwrap();
    }

    #[test]
    fn can_validate_self_signed() {
        let validator = self_signed_validator();
        let chain: CertificateChain = vec![load_test_ca()].into();

        X509CredentialValidator::validate_chain(&validator, &chain, Some(MlsTime::now())).unwrap();
    }

    #[test]
    fn can_validate_p384_cert() {
        let validator = self_signed_validator();
        let chain: CertificateChain = vec![load_test_p384_ca()].into();

        X509CredentialValidator::validate_chain(&validator, &chain, Some(MlsTime::now())).unwrap();
    }

    #[test]
    fn fails_on_too_long_self_signed() {
        let validator = self_signed_validator();
        let chain: CertificateChain = vec![load_test_ca(), load_another_ca()].into();

        let outcome =
            X509CredentialValidator::validate_chain(&validator, &chain, Some(MlsTime::now()));

        assert_matches!(outcome, Err(X509Error::SelfSignedWrongLength(2)))
    }

    #[test]
    fn fails_if_pinned_missing() {
        let chain = load_test_cert_chain();

        let mut validator = test_ca_validator();
        validator.set_pinned_cert(Some(load_another_ca()));

        let outcome = validator.validate_chain(&chain, Some(MlsTime::now()));

        assert_matches!(outcome, Err(X509Error::PinnedCertNotFound));
    }

    #[test]
    fn can_detect_invalid_ca_certificates() {
        let garbage = vec![0u8; 32];

        assert_matches!(
            X509Validator::new(vec![garbage.into()]),
            Err(X509Error::X509DerError(_))
        )
    }

    #[test]
    fn can_detect_ca_cert_with_invalid_self_signed_signature() {
        let non_ca_cert = load_test_cert_chain()[0].clone();

        assert_matches!(
            X509Validator::new(vec![non_ca_cert]),
            Err(X509Error::EcSignerError(EcSignerError::InvalidSignature))
        )
    }

    #[test]
    fn will_fail_on_empty_chain() {
        let validator = X509Validator::new(vec![]).unwrap();
        let empty_chain = CertificateChain::from(Vec::<Vec<u8>>::new());

        let outcome = validator.validate_chain(&empty_chain, Some(MlsTime::now()));

        assert_matches!(outcome, Err(X509Error::EmptyCertificateChain));
    }

    #[test]
    fn will_fail_on_invalid_chain() {
        let chain = load_test_invalid_chain();

        let outcome = test_ca_validator().validate_chain(&chain, Some(MlsTime::now()));

        assert_matches!(
            outcome,
            Err(X509Error::EcSignerError(EcSignerError::InvalidSignature))
        );
    }

    #[test]
    fn will_fail_on_invalid_ca() {
        let chain = load_test_invalid_ca_chain();
        let validator = X509Validator::new(vec![load_another_ca()]).unwrap();

        let outcome = validator.validate_chain(&chain, Some(MlsTime::now()));

        assert_matches!(
            outcome,
            Err(X509Error::EcSignerError(EcSignerError::InvalidSignature))
        );
    }

    #[test]
    fn can_detect_expired_certs() {
        let chain = load_test_cert_chain();

        let res = test_ca_validator().validate_chain(&chain, Some(MlsTime::from(1798761600)));

        match &res {
            Err(X509Error::ValidityError {
                timestamp,
                not_before,
                not_after,
            }) => {
                assert_eq!(timestamp, &MlsTime::from(1798761600));
                assert_eq!(not_before, &MlsTime::from(1673273683));
                assert_eq!(not_after, &MlsTime::from(1767968040));
            }
            _ => panic!("Expected validity error, got: {res:?}"),
        }
    }

    #[test]
    fn will_return_public_key_of_leaf() {
        let chain = load_test_cert_chain();

        let leaf = Certificate::from_der(chain.leaf().unwrap()).unwrap();
        let expected = leaf
            .tbs_certificate
            .subject_public_key_info
            .subject_public_key
            .raw_bytes()
            .to_vec()
            .into();

        let actual = test_ca_validator().validate_chain(&chain, None).unwrap();

        assert_eq!(actual, expected)
    }
}