ic_bn_lib/tls/
mod.rs

1#[cfg(feature = "acme")]
2pub mod acme;
3#[cfg(feature = "cert-providers")]
4pub mod providers;
5pub mod resolver;
6pub mod sessions;
7pub mod tickets;
8pub mod verify;
9
10use std::{fs::read, path::PathBuf, sync::Arc};
11
12use anyhow::{Context, anyhow};
13use fqdn::{FQDN, Fqdn};
14use ic_bn_lib_common::types::{
15    http::{ALPN_H1, ALPN_H2},
16    tls::TlsOptions,
17};
18use prometheus::Registry;
19use rustls::{
20    ClientConfig, ServerConfig, SupportedProtocolVersion, TicketRotator,
21    client::{ClientSessionMemoryCache, Resumption},
22    compress::CompressionCache,
23    crypto::ring,
24    server::ResolvesServerCert,
25    sign::CertifiedKey,
26};
27use rustls_platform_verifier::Verifier;
28use std::net::{Ipv4Addr, Ipv6Addr};
29use x509_parser::prelude::{FromDer, GeneralName, ParsedExtension, X509Certificate};
30
31/// Generic error for now
32/// TODO improve
33#[derive(thiserror::Error, Debug)]
34pub enum Error {
35    #[error(transparent)]
36    Generic(#[from] anyhow::Error),
37}
38
39/// Checks if given host matches any of domains.
40/// If wildcard is true then also checks if host is a direct child of any of domains
41pub fn sni_matches(host: &Fqdn, domains: &[FQDN], wildcard: bool) -> bool {
42    domains
43        .iter()
44        .any(|x| x == host || (wildcard && Some(x.as_ref()) == host.parent()))
45}
46
47fn parse_general_name(name: &GeneralName<'_>) -> Result<Option<String>, Error> {
48    let name = match name {
49        GeneralName::DNSName(v) => (*v).to_string(),
50        GeneralName::IPAddress(v) => match v.len() {
51            4 => {
52                let b: [u8; 4] = (*v).try_into().unwrap(); // We already checked that it's 4
53                let ip = Ipv4Addr::from(b);
54                ip.to_string()
55            }
56
57            16 => {
58                let b: [u8; 16] = (*v).try_into().unwrap(); // We already checked that it's 16
59                let ip = Ipv6Addr::from(b);
60                ip.to_string()
61            }
62
63            _ => return Err(anyhow!("Invalid IP address length {}", v.len()).into()),
64        },
65
66        // Ignore other types
67        _ => return Ok(None),
68    };
69
70    Ok(Some(name))
71}
72
73/// Extracts a list of SubjectAlternativeName from a single certificate in DER format, formatted as strings.
74/// Skips everything except DNSName and IPAddress
75pub fn extract_sans_der(cert: &[u8]) -> Result<Vec<String>, Error> {
76    let cert = X509Certificate::from_der(cert)
77        .context("unable to parse DER-encoded certificate")?
78        .1;
79
80    // Extract a list of SANs from the 1st certificate in the chain (the leaf one)
81    extract_sans(&cert)
82}
83
84/// Parses the given PEM-encoded certificate (1st if there are more than one) & extracts its validity period.
85pub fn extract_validity(mut pem: &[u8]) -> Result<(i64, i64), Error> {
86    let certs = rustls_pemfile::certs(&mut pem)
87        .collect::<Result<Vec<_>, _>>()
88        .context("unable to read certificate")?;
89
90    if certs.is_empty() {
91        return Err(anyhow!("no certificates found").into());
92    }
93
94    extract_validity_der(&certs[0])
95}
96
97/// Parses the given DER-encoded certificate (1st if there are more than one) & extracts its validity period.
98pub fn extract_validity_der(der: &[u8]) -> Result<(i64, i64), Error> {
99    let cert = X509Certificate::from_der(der)
100        .context("unable to parse DER-encoded certificate")?
101        .1;
102
103    Ok((
104        cert.validity().not_before.timestamp(),
105        cert.validity().not_after.timestamp(),
106    ))
107}
108
109/// Extracts a list of SubjectAlternativeName from a single certificate, formatted as strings.
110/// Skips everything except DNSName and IPAddress
111pub fn extract_sans(cert: &X509Certificate) -> Result<Vec<String>, Error> {
112    for ext in cert.extensions() {
113        if let ParsedExtension::SubjectAlternativeName(san) = ext.parsed_extension() {
114            let names = san
115                .general_names
116                .iter()
117                .map(parse_general_name)
118                .collect::<Result<Vec<_>, _>>()?
119                .into_iter()
120                .flatten()
121                .collect::<Vec<_>>();
122
123            return Ok(names);
124        }
125    }
126
127    Err(anyhow!("SubjectAlternativeName extension not found").into())
128}
129
130/// Converts raw PEM certificate chain & private key to a CertifiedKey ready to be consumed by Rustls.
131/// This reads the first private key and ignores any others.
132pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result<CertifiedKey, Error> {
133    let (key, certs) = (key.to_vec(), certs.to_vec());
134    #[allow(clippy::tuple_array_conversions)] // Clippy being stupid here
135    let pem = [key, certs].concat();
136
137    pem_convert_to_rustls_single(&pem)
138}
139
140/// Converts raw concatenated PEM certificate chain & private key to a CertifiedKey ready to be consumed by Rustls.
141/// This reads the first private key and ignores any others.
142pub fn pem_convert_to_rustls_single(pem: &[u8]) -> Result<CertifiedKey, Error> {
143    let pem = pem.to_vec();
144
145    let key = rustls_pemfile::private_key(&mut pem.as_ref())
146        .context("unable to read private key")?
147        .ok_or_else(|| anyhow!("no private key found"))?;
148
149    // Load the cert chain
150    let certs = rustls_pemfile::certs(&mut pem.as_ref())
151        .collect::<Result<Vec<_>, _>>()
152        .context("unable to read certificate chain")?;
153
154    if certs.is_empty() {
155        return Err(anyhow!("no certificates found").into());
156    }
157
158    // Parse private key
159    let key = ring::sign::any_supported_type(&key).context("unable to parse private key")?;
160
161    Ok(CertifiedKey::new(certs, key))
162}
163
164/// Loads raw concatenated PEM certificate chain & private key and converts to a CertifiedKey ready to be consumed by Rustls.
165/// This reads the first private key and ignores any others.
166pub fn pem_load_rustls(key: PathBuf, certs: PathBuf) -> Result<CertifiedKey, Error> {
167    let key = read(key).context("unable to read private key")?;
168    let certs = read(certs).context("unable to read certificate chain")?;
169    pem_convert_to_rustls(&key, &certs)
170}
171
172/// Loads raw PEM certificate chain & private key and converts to a CertifiedKey ready to be consumed by Rustls.
173/// This reads the first private key and ignores any others.
174pub fn pem_load_rustls_single(pem: PathBuf) -> Result<CertifiedKey, Error> {
175    let pem = read(pem).context("unable to read PEM file")?;
176    pem_convert_to_rustls_single(&pem)
177}
178
179/// Creates Rustls server config.
180/// Must be run in Tokio environment since it spawns a task to record metrics
181pub fn prepare_server_config(
182    opts: TlsOptions,
183    resolver: Arc<dyn ResolvesServerCert>,
184    registry: &Registry,
185) -> ServerConfig {
186    let mut cfg = ServerConfig::builder_with_protocol_versions(&opts.tls_versions)
187        .with_no_client_auth()
188        .with_cert_resolver(resolver);
189
190    // Create custom session storage to allow effective TLS session resumption
191    let session_storage = Arc::new(sessions::Storage::new(
192        opts.sessions_count,
193        opts.sessions_tti,
194        registry,
195    ));
196    let session_storage_metrics = session_storage.clone();
197    // Spawn metrics runner
198    tokio::spawn(async move { session_storage_metrics.metrics_runner().await });
199    cfg.session_storage = session_storage;
200
201    // Enable ticketer to encrypt/decrypt TLS tickets.
202    // TicketSwitcher rotates the inner ticketers every `ticket_lifetime`
203    // while keeping the previous one available for decryption of tickets
204    // issued earlier than `ticket_lifetime` ago.
205    let ticketer = tickets::WithMetrics(
206        TicketRotator::new(opts.ticket_lifetime.as_secs() as u32, move || {
207            Ok(Box::new(tickets::Ticketer::new()))
208        })
209        .unwrap(),
210        tickets::Metrics::new(registry),
211    );
212    cfg.ticketer = Arc::new(ticketer);
213
214    // Enable certificate compression cache.
215    // See https://datatracker.ietf.org/doc/rfc8879/ for details
216    cfg.cert_compression_cache = Arc::new(CompressionCache::new(8192));
217
218    // Enable ALPN
219    cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()];
220    cfg.alpn_protocols.extend_from_slice(&opts.additional_alpn);
221
222    cfg
223}
224
225pub fn prepare_client_config(tls_versions: &[&'static SupportedProtocolVersion]) -> ClientConfig {
226    let cfg = ClientConfig::builder_with_protocol_versions(tls_versions);
227
228    let crypto_provider = rustls::crypto::CryptoProvider::get_default()
229        .unwrap()
230        .clone();
231
232    // Use a custom certificate verifier from rustls project that is presumably secure.
233    let verifier = Verifier::new_with_extra_roots(
234        webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().cloned(),
235        crypto_provider,
236    )
237    .unwrap();
238
239    let mut cfg = cfg
240        .dangerous() // Nothing really dangerous here
241        .with_custom_certificate_verifier(Arc::new(verifier))
242        .with_no_client_auth();
243
244    // Session resumption
245    let store = ClientSessionMemoryCache::new(2048);
246    cfg.resumption = Resumption::store(Arc::new(store));
247    cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()];
248
249    cfg
250}
251
#[cfg(test)]
mod test {
    use fqdn::fqdn;

    use crate::tests::{TEST_CERT_1, TEST_KEY_1};

    use super::*;

    #[test]
    fn test_sni_matches() {
        let domains = [fqdn!("foo1.bar"), fqdn!("foo2.bar"), fqdn!("foo3.bar")];

        // Exact matching: every listed domain matches, anything else does not
        for exact in [fqdn!("foo1.bar"), fqdn!("foo2.bar"), fqdn!("foo3.bar")] {
            assert!(sni_matches(&exact, &domains, false));
        }
        assert!(!sni_matches(&fqdn!("foo4.bar"), &domains, false));

        // Wildcard matching: the domain itself plus its direct children
        for accepted in [
            fqdn!("foo1.bar"),
            fqdn!("baz.foo1.bar"),
            fqdn!("bza.foo1.bar"),
            fqdn!("baz.foo2.bar"),
            fqdn!("bza.foo2.bar"),
        ] {
            assert!(sni_matches(&accepted, &domains, true));
        }

        // Grandchildren (two labels deeper) must not match
        assert!(!sni_matches(&fqdn!("baz.baz.foo1.bar"), &domains, true));
    }

    #[test]
    fn test_pem_convert_to_rustls_single() {
        let combined = [TEST_KEY_1, TEST_CERT_1].concat();
        let certified = pem_convert_to_rustls_single(combined.as_bytes()).unwrap();
        assert_eq!(certified.cert.len(), 1);
    }

    #[test]
    fn test_pem_convert_to_rustls() {
        let certified =
            pem_convert_to_rustls(TEST_KEY_1.as_bytes(), TEST_CERT_1.as_bytes()).unwrap();
        assert_eq!(certified.cert.len(), 1);
    }

    #[test]
    fn test_prepare_client_config() {
        let _ = prepare_client_config(&[&rustls::version::TLS13, &rustls::version::TLS12]);
    }

    #[test]
    fn test_extract_validity() {
        let validity = extract_validity(TEST_CERT_1.as_bytes()).unwrap();
        assert_eq!(validity, (1673300396, 1988660396));
    }
}