Skip to main content

lexe_tls/
lib.rs

1//! Lexe TLS configs, certs, and utilities.
2
3use std::{str::FromStr, sync::LazyLock};
4
5use asn1_rs::FromDer;
6use lexe_byte_array::ByteArray;
7use lexe_crypto::ed25519;
8use lexe_sha256::sha256;
9use rcgen::{DistinguishedName, DnType, string::Ia5String};
10use x509_parser::{
11    certificate::X509Certificate, extensions::GeneralName, time::ASN1Time,
12};
13
14/// mTLS clients for verifying SGX remote attestations.
15pub mod attest_client;
16/// ed25519 key pair extension trait (PEM ser/de).
17pub mod ed25519_ext;
18/// Certs and utilities related to Lexe's CA.
19pub mod lexe_ca;
20/// ECDSA P-256 key pairs for webpki TLS certs.
21pub mod p256;
22/// mTLS based on a shared `RootSeed`.
23pub mod shared_seed;
24/// TLS newtypes, namely DER-encoded certs and cert keys.
25pub mod types;
26
27/// Re-export all of `lexe_tls_core`.
28pub use lexe_tls_core::*;
29
30/// Whether the given DER-encoded cert is bound to the given DNS names.
31///
32/// Returns [`false`] if the cert doesn't contain all the dns names, fails to
33/// parse, or is otherwise invalid.
34#[must_use]
35pub fn cert_contains_dns(cert_der: &[u8], expected_dns: &[&str]) -> bool {
36    fn contains_dns(cert_der: &[u8], expected_dns: &[&str]) -> Option<()> {
37        if expected_dns.is_empty() {
38            return Some(());
39        }
40
41        let (_unparsed, cert) = X509Certificate::from_der(cert_der).ok()?;
42
43        let sans = &cert.subject_alternative_name().ok()??.value.general_names;
44
45        expected_dns
46            .iter()
47            .all(|dns_name| sans.contains(&GeneralName::DNSName(dns_name)))
48            .then_some(())
49    }
50
51    contains_dns(cert_der, expected_dns).is_some()
52}
53
54/// Whether the given DER-encoded cert is currently valid and will be valid for
55/// at least `buffer_days` more days. `buffer_days=0` can be used if you only
56/// wish to check whether the cert is currently valid. Does not validate
57/// anything other than expiry. Returns [`false`] if the cert failed to parse.
58#[must_use]
59pub fn cert_is_valid_for_at_least(cert_der: &[u8], buffer_days: u16) -> bool {
60    fn is_valid_for_at_least(cert_der: &[u8], buffer_days: i64) -> Option<()> {
61        use std::ops::Add;
62
63        let (_unparsed, cert) = X509Certificate::from_der(cert_der).ok()?;
64
65        let now = ASN1Time::now();
66        let validity = cert.validity();
67
68        if now < validity.not_before {
69            return None;
70        }
71        if now > validity.not_after {
72            return None;
73        }
74
75        let buffer_days_dur = time::Duration::days(buffer_days);
76
77        // Check the same conditions `buffer_days` later.
78        let now_plus_buffer = now.add(buffer_days_dur)?;
79        if now_plus_buffer < validity.not_before {
80            return None;
81        }
82        if now_plus_buffer > validity.not_after {
83            return None;
84        }
85
86        Some(())
87    }
88
89    is_valid_for_at_least(cert_der, i64::from(buffer_days)).is_some()
90}
91
92/// A safe default for [`rcgen::CertificateParams::subject_alt_names`] when
93/// there isn't a specific value that makes sense. Used for client / CA certs.
94pub static DEFAULT_SUBJECT_ALT_NAMES: LazyLock<Vec<rcgen::SanType>> =
95    LazyLock::new(|| {
96        vec![rcgen::SanType::DnsName(
97            Ia5String::from_str("lexe.app").unwrap(),
98        )]
99    });
100
101/// Build an [`rcgen::CertificateParams`] with Lexe presets and optional
102/// overrides.
103/// - This builder function helps ensure that important fields in the inner
104///   [`rcgen::CertificateParams`] are considered. See struct for details.
105/// - Any special fields or overrides can be specified using the `overrides`
106///   closure. See usages for examples.
107/// - `key_pair` and `alg` cannot be overridden.
108//
109// TODO(phlip9): needs some normalizing with WebPKI
110// - <https://letsencrypt.org/docs/profiles/>
111// - <https://github.com/cabforum/servercert/blob/main/docs/BR.md>
112// - use `ExplicitNoCa` for end-entity certs, use exact `KeyUsage` and
113//   `ExtendedKeyUsage`, end-entities should just use first SAN as CN
114pub fn build_rcgen_cert_params(
115    common_name: &str,
116    not_before: time::OffsetDateTime,
117    not_after: time::OffsetDateTime,
118    subject_alt_names: Vec<rcgen::SanType>,
119    public_key: &ed25519::PublicKey,
120    overrides: impl FnOnce(&mut rcgen::CertificateParams),
121) -> rcgen::CertificateParams {
122    let mut params = rcgen::CertificateParams::default();
123
124    params.not_before = not_before;
125    params.not_after = not_after;
126    params.subject_alt_names = subject_alt_names;
127    params.distinguished_name = lexe_distinguished_name(common_name);
128
129    // is_ca: IsCa::NoCa,
130    // key_usages: Vec::new(),
131    // extended_key_usages: Vec::new(),
132    // name_constraints: None,
133    // crl_distribution_points: Vec::new(),
134    // custom_extensions: Vec::new(),
135    // use_authority_key_identifier_extension: false,
136
137    // Custom caller overrides
138    overrides(&mut params);
139
140    // Preserve old `ring` pre-v0.14.0 behavior that uses
141    // `key_identifier_method := Sha256(public-key)[0..20]` instead of
142    // `key_identifier_method := Sha256(spki)[0..20]` used in newer `ring`.
143    //
144    // Conveniently also calculate the serial number at the same time, since
145    // it's almost the same thing.
146    //
147    // RFC 5280 specifies at most 20 bytes for a serial/subject key identifier
148    let pubkey_hash = {
149        let hash = sha256::digest(public_key.as_slice());
150        hash.as_slice()[0..20].to_vec()
151    };
152
153    // Only CA certs (including explicit self-signed-only certs) need a
154    // `SubjectKeyIdentifier`.
155    if matches!(params.is_ca, rcgen::IsCa::Ca(_) | rcgen::IsCa::ExplicitNoCa) {
156        params.key_identifier_method =
157            rcgen::KeyIdMethod::PreSpecified(pubkey_hash.clone());
158    }
159
160    // Use the (tweaked) pubkey hash as the cert serial number.
161    let mut serial = pubkey_hash;
162    serial[0] &= 0x7f; // MSB must be 0 to ensure encoding bignum in 20 B
163    params.serial_number = Some(rcgen::SerialNumber::from(serial));
164
165    params
166}
167
168/// Build a Lexe Distinguished Name given a Common Name.
169pub fn lexe_distinguished_name(common_name: &str) -> DistinguishedName {
170    let mut name = DistinguishedName::new();
171    name.push(DnType::CountryName, "US");
172    name.push(DnType::StateOrProvinceName, "CA");
173    name.push(DnType::OrganizationName, "lexe-app");
174    name.push(DnType::CommonName, common_name);
175    name
176}
177
/// TLS-specific test utilities.
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use std::sync::Arc;

    use anyhow::Context;
    use rustls::{ClientConfig, ServerConfig, pki_types::ServerName};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Conducts a TLS handshake without any other [`reqwest`]/[`axum`] infra,
    /// over a fake pair of connected streams. Returns the client and server
    /// results instead of panicking so that negative cases can be tested too.
    ///
    /// After the handshake, the client sends `"hello"` and shuts down its
    /// write side. The server reads to EOF, replies `"goodbye"`, and shuts
    /// down. Every failure on either side is reported in the returned
    /// `[client_result, server_result]`. That includes an `expected_dns`
    /// that isn't a valid server name, and an unexpected payload.
    pub async fn do_tls_handshake(
        client_config: Arc<ClientConfig>,
        server_config: Arc<ServerConfig>,
        // This is the DNS name that the *client* expects the server to have.
        expected_dns: &str,
    ) -> [Result<(), String>; 2] {
        // a fake pair of connected streams
        let (client_stream, server_stream) = tokio::io::duplex(4096);

        // client connects, sends "hello", receives "goodbye"
        let client = async move {
            let connector = tokio_rustls::TlsConnector::from(client_config);
            // Report a bad name as a client error instead of panicking. The
            // client stream is then dropped, so the server sees EOF and
            // fails too, rather than hanging.
            let sni = ServerName::try_from(expected_dns.to_owned())
                .ok()
                .with_context(|| {
                    format!("Invalid server name: {expected_dns}")
                })?;
            let mut stream = connector
                .connect(sni, client_stream)
                .await
                .context("Client didn't connect")?;

            // client: >> send "hello"
            stream
                .write_all(b"hello")
                .await
                .context("Could not write hello")?;
            stream.flush().await.context("Toilet clogged")?;
            stream.shutdown().await.context("Could not shutdown")?;

            // client: << recv "goodbye"
            let mut resp = Vec::new();
            stream.read_to_end(&mut resp).await.context("Read failed")?;
            anyhow::ensure!(
                resp == b"goodbye",
                "Client expected \"goodbye\", got {:?}",
                String::from_utf8_lossy(&resp)
            );

            Ok::<_, anyhow::Error>(())
        };

        // server accepts, receives "hello", responds with "goodbye"
        let server = async move {
            let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
            let mut stream = acceptor
                .accept(server_stream)
                .await
                .context("Server didn't accept")?;

            // server: >> recv "hello"
            let mut req = Vec::new();
            stream.read_to_end(&mut req).await.context("Read failed")?;
            anyhow::ensure!(
                req == b"hello",
                "Server expected \"hello\", got {:?}",
                String::from_utf8_lossy(&req)
            );

            // server: << send "goodbye"
            stream
                .write_all(b"goodbye")
                .await
                .context("Could not write goodbye")?;
            stream.shutdown().await.context("Could not shutdown")?;

            Ok::<_, anyhow::Error>(())
        };

        let (client_result, server_result) = tokio::join!(client, server);

        // Convert `anyhow::Error`s to strings for better ergonomics downstream
        let (client_result, server_result) = (
            client_result.map_err(|e| format!("{e:#}")),
            server_result.map_err(|e| format!("{e:#}")),
        );

        println!("Client result: {client_result:?}");
        println!("Server result: {server_result:?}");
        println!("---");

        [client_result, server_result]
    }
}