1use clap::Args;
2use miette::{Diagnostic, Result};
3use rustls::{ClientConfig, RootCertStore, ServerConfig};
4use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
5use serde::{Deserialize, Serialize, ser::SerializeStruct};
6use std::{
7 path::{Path, PathBuf},
8 sync::OnceLock,
9};
10use thiserror::Error;
11
/// One-time guard for installing the process-wide rustls crypto provider.
///
/// `TlsOptions::server_config` and `TlsOptions::client_config` call
/// `CELL.get_or_init(install_default_tls_provider)` before building any rustls
/// configuration, so the provider setup runs at most once per process. The stored
/// `bool` is just the initializer's return value.
static CELL: OnceLock<bool> = OnceLock::new();
13
14#[derive(Debug, Diagnostic, Error)]
15pub enum TlsError {
16 #[error("missing TLS certificate")]
17 #[diagnostic()]
18 MissingTlsCert,
19
20 #[error("missing TLS key")]
21 #[diagnostic()]
22 MissingTlsKey,
23
24 #[error("invalid TLS file: {0}, {1}")]
25 #[diagnostic()]
26 InvalidTlsFile(PathBuf, rustls_pki_types::pem::Error),
27
28 #[error("failed to parse TLS key: {0}")]
29 #[diagnostic()]
30 FailedToParseTlsKey(String),
31
32 #[error("failed to parse config: {0}")]
33 #[diagnostic()]
34 FailedToParseConfig(#[from] rustls::Error),
35}
36
37#[derive(Args, Clone, Debug, Deserialize, Serialize)]
38pub struct TlsOptions {
39 #[arg(long, conflicts_with = "remote")]
41 #[serde(default)]
42 pub tls_cert: Option<PathBuf>,
43 #[arg(long, conflicts_with = "remote")]
45 #[serde(default)]
46 pub tls_key: Option<PathBuf>,
47 #[arg(long, conflicts_with = "remote")]
49 #[serde(default)]
50 pub tls_ca: Option<PathBuf>,
51
52 #[cfg(test)]
53 pub config_dir: PathBuf,
54}
55
56impl TlsOptions {
57 #[cfg(not(test))]
58 pub fn new(
59 tls_cert: Option<PathBuf>,
60 tls_key: Option<PathBuf>,
61 tls_ca: Option<PathBuf>,
62 ) -> Self {
63 Self {
64 tls_cert,
65 tls_key,
66 tls_ca,
67 }
68 }
69
70 #[cfg(test)]
71 pub fn new(
72 tls_cert: Option<PathBuf>,
73 tls_key: Option<PathBuf>,
74 tls_ca: Option<PathBuf>,
75 ) -> Self {
76 Self {
77 tls_cert,
78 tls_key,
79 tls_ca,
80 config_dir: tempfile::TempDir::new().unwrap().path().to_path_buf(),
81 }
82 }
83
84 pub fn is_secure(&self) -> bool {
85 self.cert_path().is_some() && self.key_path().is_some()
86 }
87
88 pub fn server_config(&self) -> Result<Option<ServerConfig>> {
89 if !self.is_secure() {
90 return Ok(None);
91 }
92
93 CELL.get_or_init(install_default_tls_provider);
94
95 let (mut cert_chain, key) =
96 parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
97
98 if let Some(path) = self.ca_path() {
99 let certs = parse_certificates(path)?;
100 if !certs.is_empty() {
101 cert_chain.extend(certs);
102 }
103 }
104
105 let mut config = ServerConfig::builder()
106 .with_no_client_auth()
107 .with_single_cert(cert_chain, key)
108 .map_err(TlsError::FailedToParseConfig)?;
109
110 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
111
112 Ok(Some(config))
113 }
114
115 pub fn client_config(&self) -> Result<ClientConfig> {
116 CELL.get_or_init(install_default_tls_provider);
117
118 let builder = if let Some(path) = self.ca_path() {
119 let mut root_store = RootCertStore::empty();
120 root_store.add_parsable_certificates(parse_certificates(path)?);
121 ClientConfig::builder().with_root_certificates(root_store)
122 } else {
123 use rustls_platform_verifier::BuilderVerifierExt;
124 ClientConfig::builder().with_platform_verifier()
125 };
126
127 let (cert, key) = parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
128
129 let config = builder
130 .with_client_auth_cert(cert, key)
131 .map_err(TlsError::FailedToParseConfig)?;
132
133 Ok(config)
134 }
135
136 fn cert_path(&self) -> Option<PathBuf> {
137 self.tls_cert.clone().or_else(|| self.cached_cert_path())
138 }
139
140 fn key_path(&self) -> Option<PathBuf> {
141 self.tls_key.clone().or_else(|| self.cached_key_path())
142 }
143
144 fn ca_path(&self) -> Option<PathBuf> {
145 self.tls_ca.clone().or_else(|| self.cached_ca_path())
146 }
147
148 fn cached_cert_path(&self) -> Option<PathBuf> {
149 let cache = self.config_dir().map(|p| p.join("cert.pem"));
150 if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
151 return cache;
152 }
153
154 None
155 }
156
157 fn cached_key_path(&self) -> Option<PathBuf> {
158 let cache = self.config_dir().map(|p| p.join("key.pem"));
159 if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
160 return cache;
161 }
162
163 None
164 }
165
166 fn cached_ca_path(&self) -> Option<PathBuf> {
167 let cache = self.config_dir().map(|p| p.join("ca.pem"));
168 if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
169 return cache;
170 }
171
172 None
173 }
174
175 #[cfg(not(test))]
176 fn config_dir(&self) -> Option<PathBuf> {
177 dirs::config_dir().map(|p| p.join("cargo-lambda"))
178 }
179
180 #[cfg(test)]
181 fn config_dir(&self) -> Option<PathBuf> {
182 Some(self.config_dir.clone())
183 }
184
185 pub fn count_fields(&self) -> usize {
186 self.tls_cert.is_some() as usize
187 + self.tls_key.is_some() as usize
188 + self.tls_ca.is_some() as usize
189 }
190
191 pub fn serialize_fields<S>(
192 &self,
193 state: &mut <S as serde::Serializer>::SerializeStruct,
194 ) -> Result<(), S::Error>
195 where
196 S: serde::Serializer,
197 {
198 if let Some(tls_cert) = &self.tls_cert {
199 state.serialize_field("tls_cert", tls_cert)?;
200 }
201 if let Some(tls_key) = &self.tls_key {
202 state.serialize_field("tls_key", tls_key)?;
203 }
204 if let Some(tls_ca) = &self.tls_ca {
205 state.serialize_field("tls_ca", tls_ca)?;
206 }
207 Ok(())
208 }
209}
210
211impl Default for TlsOptions {
212 fn default() -> Self {
213 Self::new(None, None, None)
214 }
215}
216
217fn parse_certificates<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
218 let path = path.as_ref();
219 let parser = CertificateDer::pem_file_iter(path)
220 .map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?
221 .collect::<Vec<_>>();
222
223 let mut certs = Vec::with_capacity(parser.len());
224 for cert in parser {
225 certs.push(cert.map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?);
226 }
227
228 Ok(certs)
229}
230
231fn parse_cert_and_key(
232 cert: Option<&PathBuf>,
233 key: Option<&PathBuf>,
234) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
235 let path = cert.ok_or(TlsError::MissingTlsCert)?;
236 let cert = parse_certificates(path)?;
237
238 let path = key.ok_or(TlsError::MissingTlsKey)?;
239 let key = PrivateKeyDer::from_pem_file(path)
240 .map_err(|e| TlsError::FailedToParseTlsKey(e.to_string()))?;
241
242 Ok((cert, key))
243}
244
245fn install_default_tls_provider() -> bool {
246 rustls::crypto::aws_lc_rs::default_provider()
247 .install_default()
248 .expect("failed to install the default TLS provider");
249 true
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory holding the PEM fixtures, relative to where the tests run.
    const FIXTURES_DIR: &str = "../../tests/certs";

    /// Path of the fixture file `name`.
    fn fixture(name: &str) -> PathBuf {
        Path::new(FIXTURES_DIR).join(name)
    }

    /// Copies fixture `name` into `config_dir` under the same file name,
    /// creating the directory first if needed.
    fn copy_fixture(name: &str, config_dir: &Path) {
        std::fs::create_dir_all(config_dir).unwrap();
        std::fs::copy(fixture(name), config_dir.join(name)).unwrap();
    }

    /// Places `cert.pem` and `key.pem` in the options' config directory, plus `ca.pem`
    /// when `with_ca` is set, so the cached-path fallbacks pick them up.
    fn install_cached_certs(opts: &TlsOptions, with_ca: bool) {
        copy_fixture("cert.pem", &opts.config_dir);
        copy_fixture("key.pem", &opts.config_dir);
        if with_ca {
            copy_fixture("ca.pem", &opts.config_dir);
        }
    }

    #[tokio::test]
    async fn test_tls_options_default() {
        let opts = TlsOptions::default();
        assert!(!opts.is_secure());

        install_cached_certs(&opts, true);

        assert_eq!(opts.cert_path().unwrap(), opts.config_dir.join("cert.pem"));
        assert_eq!(opts.key_path().unwrap(), opts.config_dir.join("key.pem"));
        assert_eq!(opts.ca_path().unwrap(), opts.config_dir.join("ca.pem"));
        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_tls_options_with_paths() {
        let opts = TlsOptions::new(
            Some(fixture("cert.pem")),
            Some(fixture("key.pem")),
            Some(fixture("ca.pem")),
        );

        assert_eq!(
            opts.cert_path().unwrap(),
            PathBuf::from("../../tests/certs/cert.pem")
        );
        assert_eq!(
            opts.key_path().unwrap(),
            PathBuf::from("../../tests/certs/key.pem")
        );
        assert_eq!(
            opts.ca_path().unwrap(),
            PathBuf::from("../../tests/certs/ca.pem")
        );
        assert!(opts.is_secure());
    }

    #[test]
    fn test_cached_paths() {
        let opts = TlsOptions::default();

        assert!(opts.cached_cert_path().is_none());
        assert!(opts.cached_key_path().is_none());
        assert!(opts.cached_ca_path().is_none());

        install_cached_certs(&opts, true);

        assert_eq!(
            opts.cached_cert_path().unwrap(),
            opts.config_dir.join("cert.pem")
        );
        assert_eq!(
            opts.cached_key_path().unwrap(),
            opts.config_dir.join("key.pem")
        );
        assert_eq!(
            opts.cached_ca_path().unwrap(),
            opts.config_dir.join("ca.pem")
        );
    }

    #[tokio::test]
    async fn test_server_config_with_valid_files_in_temp_dir() {
        let opts = TlsOptions::new(Some(fixture("cert.pem")), Some(fixture("key.pem")), None);

        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[tokio::test]
    async fn test_server_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached_certs(&opts, true);

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[tokio::test]
    async fn test_client_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached_certs(&opts, true);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty());
    }

    #[tokio::test]
    async fn test_client_config_without_ca() {
        let opts = TlsOptions::default();
        install_cached_certs(&opts, false);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty());
    }
}