Skip to main content

htsget_config/http/
mod.rs

1//! Configuration related to TLS.
2//!
3
4pub mod client;
5
6use std::fs::File;
7use std::io;
8use std::io::BufReader;
9use std::path::{Path, PathBuf};
10
11use rustls::ServerConfig;
12use rustls_pemfile::Item::{Pkcs1Key, Pkcs8Key, Sec1Key};
13use rustls_pemfile::{certs, read_one};
14use rustls_pki_types::{CertificateDer, PrivateKeyDer};
15use serde::{Deserialize, Serialize};
16
17use crate::error::Error::ParseError;
18use crate::error::{Error, Result};
19use crate::types::Scheme;
20use crate::types::Scheme::{Http, Https};
21
22/// A trait to determine which scheme a key pair option has.
23pub trait KeyPairScheme {
24  /// Get the scheme.
25  fn get_scheme(&self) -> Scheme;
26}
27
28/// A certificate and key pair used for TLS. Serialization is not implemented because there
29/// is no way to convert back to a `PathBuf`.
30#[derive(Deserialize, Debug, Clone)]
31#[serde(try_from = "CertificateKeyPairPath", deny_unknown_fields)]
32pub struct TlsServerConfig {
33  server_config: ServerConfig,
34}
35
36impl TlsServerConfig {
37  /// Create a new TlsServerConfig.
38  pub fn new(server_config: ServerConfig) -> Self {
39    Self { server_config }
40  }
41
42  /// Get the inner server config.
43  pub fn into_inner(self) -> ServerConfig {
44    self.server_config
45  }
46}
47
48/// The location of a certificate and key pair used for TLS.
49/// This is the path to the PEM formatted X.509 certificate and private key.
50#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
51#[serde(deny_unknown_fields)]
52pub struct CertificateKeyPairPath {
53  cert: PathBuf,
54  key: PathBuf,
55}
56
57/// The certificate and key pair used for TLS.
58#[derive(Debug, PartialEq, Eq)]
59pub struct CertificateKeyPair {
60  certs: Vec<CertificateDer<'static>>,
61  key: PrivateKeyDer<'static>,
62}
63
64impl CertificateKeyPair {
65  /// Create a new CertificateKeyPair.
66  pub fn new(certs: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
67    Self { certs, key }
68  }
69
70  /// Get the owned certificate and private key.
71  pub fn into_inner(self) -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
72    (self.certs, self.key)
73  }
74}
75
76/// The location of a certificate and key pair used for TLS.
77/// This is the path to the PEM formatted X.509 certificate and private key.
78#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
79#[serde(deny_unknown_fields, default)]
80pub struct RootCertStorePair {
81  #[serde(flatten)]
82  key_pair: Option<CertificateKeyPairPath>,
83  root_store: Option<PathBuf>,
84  use_cache: bool,
85}
86
87impl Default for RootCertStorePair {
88  fn default() -> Self {
89    Self {
90      key_pair: None,
91      root_store: None,
92      use_cache: true,
93    }
94  }
95}
96
97impl RootCertStorePair {
98  /// Create a new RootCertStorePair.
99  pub fn new(
100    key_pair: Option<CertificateKeyPairPath>,
101    root_store: Option<PathBuf>,
102    use_cache: bool,
103  ) -> Self {
104    Self {
105      key_pair,
106      root_store,
107      use_cache,
108    }
109  }
110
111  /// Get the owned root store pair.
112  pub fn into_inner(self) -> (Option<CertificateKeyPairPath>, Option<PathBuf>, bool) {
113    (self.key_pair, self.root_store, self.use_cache)
114  }
115}
116
117impl TryFrom<CertificateKeyPairPath> for TlsServerConfig {
118  type Error = Error;
119
120  fn try_from(key_pair: CertificateKeyPairPath) -> Result<Self> {
121    let server_config = tls_server_config(key_pair.try_into()?)?;
122
123    Ok(Self::new(server_config))
124  }
125}
126
127impl TryFrom<CertificateKeyPairPath> for CertificateKeyPair {
128  type Error = Error;
129
130  fn try_from(key_pair: CertificateKeyPairPath) -> Result<Self> {
131    let certs = load_certs(key_pair.cert)?;
132    let key = load_key(key_pair.key)?;
133
134    Ok(CertificateKeyPair::new(certs, key))
135  }
136}
137
138impl CertificateKeyPairPath {
139  /// Create a new certificate key pair.
140  pub fn new(cert: PathBuf, key: PathBuf) -> Self {
141    Self { cert, key }
142  }
143
144  /// Get the certs path.
145  pub fn certs(&self) -> &Path {
146    &self.cert
147  }
148
149  /// Get the key path.
150  pub fn key(&self) -> &Path {
151    &self.key
152  }
153}
154
155impl KeyPairScheme for Option<&TlsServerConfig> {
156  fn get_scheme(&self) -> Scheme {
157    match self {
158      None => Http,
159      Some(_) => Https,
160    }
161  }
162}
163
164/// Loads the first private key from a file. Supports RSA, PKCS8, and Sec1 encoded EC keys.
165pub fn load_key<P: AsRef<Path>>(key_path: P) -> Result<PrivateKeyDer<'static>> {
166  let mut key_reader = BufReader::new(File::open(key_path)?);
167
168  loop {
169    match read_one(&mut key_reader)? {
170      Some(Pkcs1Key(key)) => return Ok(PrivateKeyDer::from(key)),
171      Some(Pkcs8Key(key)) => return Ok(PrivateKeyDer::from(key)),
172      Some(Sec1Key(key)) => return Ok(PrivateKeyDer::from(key)),
173      // Silently disregard unknown private keys.
174      Some(_) => continue,
175      None => break,
176    }
177  }
178
179  Err(ParseError("no keys found in pem file".to_string()))
180}
181
182/// Load certificates from a file.
183pub fn load_certs<P: AsRef<Path>>(certs_path: P) -> Result<Vec<CertificateDer<'static>>> {
184  let mut cert_reader = BufReader::new(File::open(certs_path)?);
185
186  let certs: Vec<CertificateDer> =
187    certs(&mut cert_reader).collect::<io::Result<Vec<CertificateDer>>>()?;
188  if certs.is_empty() {
189    return Err(ParseError("no certificates found in .pem file".to_string()));
190  }
191
192  Ok(certs)
193}
194
195/// Load TLS server config.
196pub fn tls_server_config(key_pair: CertificateKeyPair) -> Result<ServerConfig> {
197  let (certs, key) = key_pair.into_inner();
198
199  let mut config = ServerConfig::builder()
200    .with_no_client_auth()
201    .with_single_cert(certs, key)
202    .map_err(|err| ParseError(err.to_string()))?;
203
204  config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
205
206  Ok(config)
207}
208
#[cfg(test)]
pub(crate) mod tests {
  use std::fs::write;
  use std::io::Cursor;
  use std::path::Path;

  use super::*;
  use rcgen::generate_simple_self_signed;
  use rustls::crypto::aws_lc_rs;
  use rustls_pemfile::{certs, pkcs8_private_keys};
  use tempfile::TempDir;

  /// The key read back from `key.pem` must equal the generated key.
  #[test]
  fn test_load_key() {
    with_test_certificates(|path, key, _| {
      let key_path = path.join("key.pem");
      let loaded_key = load_key(key_path).unwrap();

      assert_eq!(loaded_key, key);
    });
  }

  /// Exactly one certificate is read back from `cert.pem`, equal to the generated one.
  #[test]
  fn test_load_cert() {
    with_test_certificates(|path, _, cert| {
      let cert_path = path.join("cert.pem");
      let certs = load_certs(cert_path).unwrap();

      assert_eq!(certs.len(), 1);
      assert_eq!(certs.into_iter().next().unwrap(), cert);
    });
  }

  /// The built server config advertises `h2` and `http/1.1` through ALPN.
  // Nothing in this test awaits, so it runs as a plain synchronous test.
  #[test]
  fn test_tls_server_config() {
    with_test_certificates(|_, key, cert| {
      let server_config = tls_server_config(CertificateKeyPair::new(vec![cert], key)).unwrap();

      assert_eq!(
        server_config.alpn_protocols,
        vec![b"h2".to_vec(), b"http/1.1".to_vec()]
      );
    });
  }

  /// Generate a self-signed certificate for `localhost`, write it and its key as
  /// `cert.pem` and `key.pem` into a temporary directory, and run `test` with the
  /// directory path, the parsed private key and the parsed certificate.
  pub(crate) fn with_test_certificates<F>(test: F)
  where
    F: FnOnce(&Path, PrivateKeyDer<'static>, CertificateDer<'static>),
  {
    // Installing fails if a default provider is already set, e.g. by another test in
    // the same process; that is fine, so the result is ignored.
    let _ = aws_lc_rs::default_provider().install_default();

    let tmp_dir = TempDir::new().unwrap();

    let key_path = tmp_dir.path().join("key.pem");
    let cert_path = tmp_dir.path().join("cert.pem");

    let cert = generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();

    let key = cert.signing_key.serialize_pem();
    let cert = cert.cert.pem();

    write(key_path, &key).unwrap();
    write(cert_path, &cert).unwrap();

    // The PEM strings are not needed after parsing, so they are moved into the cursors.
    let key = PrivateKeyDer::from(
      pkcs8_private_keys(&mut Cursor::new(key))
        .next()
        .unwrap()
        .unwrap(),
    );
    let cert = certs(&mut Cursor::new(cert)).next().unwrap().unwrap();

    test(tmp_dir.path(), key, cert);
  }
}