1use std::{str::FromStr, sync::LazyLock};
4
5use asn1_rs::FromDer;
6use lexe_byte_array::ByteArray;
7use lexe_crypto::ed25519;
8use lexe_sha256::sha256;
9use rcgen::{DistinguishedName, DnType, string::Ia5String};
10use x509_parser::{
11 certificate::X509Certificate, extensions::GeneralName, time::ASN1Time,
12};
13
14pub mod attest_client;
16pub mod ed25519_ext;
18pub mod lexe_ca;
20pub mod p256;
22pub mod shared_seed;
24pub mod types;
26
27pub use lexe_tls_core::*;
29
30#[must_use]
35pub fn cert_contains_dns(cert_der: &[u8], expected_dns: &[&str]) -> bool {
36 fn contains_dns(cert_der: &[u8], expected_dns: &[&str]) -> Option<()> {
37 if expected_dns.is_empty() {
38 return Some(());
39 }
40
41 let (_unparsed, cert) = X509Certificate::from_der(cert_der).ok()?;
42
43 let sans = &cert.subject_alternative_name().ok()??.value.general_names;
44
45 expected_dns
46 .iter()
47 .all(|dns_name| sans.contains(&GeneralName::DNSName(dns_name)))
48 .then_some(())
49 }
50
51 contains_dns(cert_der, expected_dns).is_some()
52}
53
54#[must_use]
59pub fn cert_is_valid_for_at_least(cert_der: &[u8], buffer_days: u16) -> bool {
60 fn is_valid_for_at_least(cert_der: &[u8], buffer_days: i64) -> Option<()> {
61 use std::ops::Add;
62
63 let (_unparsed, cert) = X509Certificate::from_der(cert_der).ok()?;
64
65 let now = ASN1Time::now();
66 let validity = cert.validity();
67
68 if now < validity.not_before {
69 return None;
70 }
71 if now > validity.not_after {
72 return None;
73 }
74
75 let buffer_days_dur = time::Duration::days(buffer_days);
76
77 let now_plus_buffer = now.add(buffer_days_dur)?;
79 if now_plus_buffer < validity.not_before {
80 return None;
81 }
82 if now_plus_buffer > validity.not_after {
83 return None;
84 }
85
86 Some(())
87 }
88
89 is_valid_for_at_least(cert_der, i64::from(buffer_days)).is_some()
90}
91
92pub static DEFAULT_SUBJECT_ALT_NAMES: LazyLock<Vec<rcgen::SanType>> =
95 LazyLock::new(|| {
96 vec![rcgen::SanType::DnsName(
97 Ia5String::from_str("lexe.app").unwrap(),
98 )]
99 });
100
101pub fn build_rcgen_cert_params(
115 common_name: &str,
116 not_before: time::OffsetDateTime,
117 not_after: time::OffsetDateTime,
118 subject_alt_names: Vec<rcgen::SanType>,
119 public_key: &ed25519::PublicKey,
120 overrides: impl FnOnce(&mut rcgen::CertificateParams),
121) -> rcgen::CertificateParams {
122 let mut params = rcgen::CertificateParams::default();
123
124 params.not_before = not_before;
125 params.not_after = not_after;
126 params.subject_alt_names = subject_alt_names;
127 params.distinguished_name = lexe_distinguished_name(common_name);
128
129 overrides(&mut params);
139
140 let pubkey_hash = {
149 let hash = sha256::digest(public_key.as_slice());
150 hash.as_slice()[0..20].to_vec()
151 };
152
153 if matches!(params.is_ca, rcgen::IsCa::Ca(_) | rcgen::IsCa::ExplicitNoCa) {
156 params.key_identifier_method =
157 rcgen::KeyIdMethod::PreSpecified(pubkey_hash.clone());
158 }
159
160 let mut serial = pubkey_hash;
162 serial[0] &= 0x7f; params.serial_number = Some(rcgen::SerialNumber::from(serial));
164
165 params
166}
167
168pub fn lexe_distinguished_name(common_name: &str) -> DistinguishedName {
170 let mut name = DistinguishedName::new();
171 name.push(DnType::CountryName, "US");
172 name.push(DnType::StateOrProvinceName, "CA");
173 name.push(DnType::OrganizationName, "lexe-app");
174 name.push(DnType::CommonName, common_name);
175 name
176}
177
/// TLS helpers for tests. Compiled only for `cfg(test)` builds or when the
/// `test-utils` feature is enabled.
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use std::sync::Arc;

    use anyhow::Context;
    use rustls::{ClientConfig, ServerConfig, pki_types::ServerName};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Runs a TLS handshake plus a tiny request/response exchange between a
    /// client and a server connected through an in-memory duplex pipe.
    ///
    /// The client connects using `expected_dns` as the server name, sends
    /// `hello`, shuts down its write side and expects to read `goodbye`. The
    /// server accepts, expects to read `hello`, replies `goodbye` and shuts
    /// down.
    ///
    /// Returns `[client_result, server_result]`, where each error has been
    /// rendered to a `String` with `{:#}` formatting. Both results are also
    /// printed to stdout.
    ///
    /// # Panics
    ///
    /// Panics if `expected_dns` is not accepted as a `ServerName`, or if either
    /// side receives a payload other than the expected one.
    pub async fn do_tls_handshake(
        client_config: Arc<ClientConfig>,
        server_config: Arc<ServerConfig>,
        expected_dns: &str,
    ) -> [Result<(), String>; 2] {
        // In-memory, bidirectional byte pipe standing in for a socket.
        let (client_io, server_io) = tokio::io::duplex(4096);

        let client_task = async move {
            let tls_connector = tokio_rustls::TlsConnector::from(client_config);
            let server_name =
                ServerName::try_from(expected_dns.to_owned()).unwrap();
            let mut tls = tls_connector
                .connect(server_name, client_io)
                .await
                .context("Client didn't connect")?;

            // Send the request, then close our write half so the server's
            // `read_to_end` can complete.
            tls.write_all(b"hello")
                .await
                .context("Could not write hello")?;
            tls.flush().await.context("Toilet clogged")?;
            tls.shutdown().await.context("Could not shutdown")?;

            // Read the reply until the server closes its side.
            let mut reply = Vec::new();
            tls.read_to_end(&mut reply).await.context("Read failed")?;
            assert_eq!(&reply, b"goodbye");

            Ok::<_, anyhow::Error>(())
        };

        let server_task = async move {
            let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
            let mut tls = tls_acceptor
                .accept(server_io)
                .await
                .context("Server didn't accept")?;

            // Read the request until the client closes its write side.
            let mut request = Vec::new();
            tls.read_to_end(&mut request).await.context("Read failed")?;
            assert_eq!(&request, b"hello");

            tls.write_all(b"goodbye")
                .await
                .context("Could not write goodbye")?;
            tls.shutdown().await.context("Could not shutdown")?;

            Ok::<_, anyhow::Error>(())
        };

        // Drive both sides concurrently; each needs the other to make progress.
        let (client_outcome, server_outcome) =
            tokio::join!(client_task, server_task);

        // Flatten each error and its context chain into a single string.
        let render = |outcome: Result<(), anyhow::Error>| {
            outcome.map_err(|err| format!("{err:#}"))
        };
        let client_result = render(client_outcome);
        let server_result = render(server_outcome);

        println!("Client result: {client_result:?}");
        println!("Server result: {server_result:?}");
        println!("---");

        [client_result, server_result]
    }
}