1use clap::Args;
2use miette::{Diagnostic, Result};
3use rustls::{ClientConfig, RootCertStore, ServerConfig};
4use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
5use serde::{Deserialize, Serialize, ser::SerializeStruct};
6use std::{
7 path::{Path, PathBuf},
8 sync::OnceLock,
9};
10use thiserror::Error;
11
// One-time guard for installing the rustls crypto provider: the config
// builders call `CELL.get_or_init(install_default_tls_provider)`, so the
// installer runs at most once per process. The stored `bool` is always `true`
// and only records that initialisation happened.
static CELL: OnceLock<bool> = OnceLock::new();
13
14#[derive(Debug, Diagnostic, Error)]
15pub enum TlsError {
16 #[error("missing TLS certificate")]
17 #[diagnostic()]
18 MissingTlsCert,
19
20 #[error("missing TLS key")]
21 #[diagnostic()]
22 MissingTlsKey,
23
24 #[error("invalid TLS file: {0}, {1}")]
25 #[diagnostic()]
26 InvalidTlsFile(PathBuf, rustls_pki_types::pem::Error),
27
28 #[error("failed to parse TLS key: {0}")]
29 #[diagnostic()]
30 FailedToParseTlsKey(String),
31
32 #[error("failed to parse config: {0}")]
33 #[diagnostic()]
34 FailedToParseConfig(#[from] rustls::Error),
35}
36
37#[derive(Args, Clone, Debug, Deserialize, Serialize)]
38pub struct TlsOptions {
39 #[arg(long, conflicts_with = "remote")]
41 #[serde(default)]
42 pub tls_cert: Option<PathBuf>,
43 #[arg(long, conflicts_with = "remote")]
45 #[serde(default)]
46 pub tls_key: Option<PathBuf>,
47 #[arg(long, conflicts_with = "remote")]
49 #[serde(default)]
50 pub tls_ca: Option<PathBuf>,
51
52 #[cfg(test)]
53 pub config_dir: PathBuf,
54}
55
56impl TlsOptions {
57 #[cfg(not(test))]
58 pub fn new(
59 tls_cert: Option<PathBuf>,
60 tls_key: Option<PathBuf>,
61 tls_ca: Option<PathBuf>,
62 ) -> Self {
63 Self {
64 tls_cert,
65 tls_key,
66 tls_ca,
67 }
68 }
69
70 #[cfg(test)]
71 pub fn new(
72 tls_cert: Option<PathBuf>,
73 tls_key: Option<PathBuf>,
74 tls_ca: Option<PathBuf>,
75 ) -> Self {
76 Self {
77 tls_cert,
78 tls_key,
79 tls_ca,
80 config_dir: tempfile::TempDir::new().unwrap().path().to_path_buf(),
81 }
82 }
83
84 pub fn is_secure(&self) -> bool {
85 self.cert_path().is_some() && self.key_path().is_some()
86 }
87
88 pub fn server_config(&self) -> Result<Option<ServerConfig>> {
89 if !self.is_secure() {
90 return Ok(None);
91 }
92
93 CELL.get_or_init(install_default_tls_provider);
94
95 let (mut cert_chain, key) =
96 parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
97
98 if let Some(path) = self.ca_path() {
99 let certs = parse_certificates(path)?;
100 if !certs.is_empty() {
101 cert_chain.extend(certs);
102 }
103 }
104
105 let mut config = ServerConfig::builder()
106 .with_no_client_auth()
107 .with_single_cert(cert_chain, key)
108 .map_err(TlsError::FailedToParseConfig)?;
109
110 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
111
112 Ok(Some(config))
113 }
114
115 pub fn client_config(&self) -> Result<ClientConfig> {
116 CELL.get_or_init(install_default_tls_provider);
117
118 let builder = if let Some(path) = self.ca_path() {
119 let mut root_store = RootCertStore::empty();
120 root_store.add_parsable_certificates(parse_certificates(path)?);
121 ClientConfig::builder().with_root_certificates(root_store)
122 } else {
123 use rustls_platform_verifier::BuilderVerifierExt;
124 ClientConfig::builder().with_platform_verifier()
125 };
126
127 let (cert, key) = parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
128
129 let config = builder
130 .with_client_auth_cert(cert, key)
131 .map_err(TlsError::FailedToParseConfig)?;
132
133 Ok(config)
134 }
135
136 fn cert_path(&self) -> Option<PathBuf> {
137 self.tls_cert.clone().or_else(|| self.cached_cert_path())
138 }
139
140 fn key_path(&self) -> Option<PathBuf> {
141 self.tls_key.clone().or_else(|| self.cached_key_path())
142 }
143
144 fn ca_path(&self) -> Option<PathBuf> {
145 self.tls_ca.clone().or_else(|| self.cached_ca_path())
146 }
147
148 fn cached_cert_path(&self) -> Option<PathBuf> {
149 let cache = self.config_dir().map(|p| p.join("cert.pem"));
150 if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
151 return cache;
152 }
153
154 None
155 }
156
157 fn cached_key_path(&self) -> Option<PathBuf> {
158 let cache = self.config_dir().map(|p| p.join("key.pem"));
159 if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
160 return cache;
161 }
162
163 None
164 }
165
166 fn cached_ca_path(&self) -> Option<PathBuf> {
167 let cache = self.config_dir().map(|p| p.join("ca.pem"));
168 if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
169 return cache;
170 }
171
172 None
173 }
174
175 #[cfg(not(test))]
176 fn config_dir(&self) -> Option<PathBuf> {
177 if let Ok(xdg_config) = std::env::var("XDG_CONFIG_HOME") {
179 let path = PathBuf::from(xdg_config).join("cargo-lambda");
180 if path.exists() {
181 return Some(path);
182 }
183 }
184
185 dirs::config_dir().map(|p| p.join("cargo-lambda"))
187 }
188
189 #[cfg(test)]
190 fn config_dir(&self) -> Option<PathBuf> {
191 Some(self.config_dir.clone())
192 }
193
194 pub fn count_fields(&self) -> usize {
195 self.tls_cert.is_some() as usize
196 + self.tls_key.is_some() as usize
197 + self.tls_ca.is_some() as usize
198 }
199
200 pub fn serialize_fields<S>(
201 &self,
202 state: &mut <S as serde::Serializer>::SerializeStruct,
203 ) -> Result<(), S::Error>
204 where
205 S: serde::Serializer,
206 {
207 if let Some(tls_cert) = &self.tls_cert {
208 state.serialize_field("tls_cert", tls_cert)?;
209 }
210 if let Some(tls_key) = &self.tls_key {
211 state.serialize_field("tls_key", tls_key)?;
212 }
213 if let Some(tls_ca) = &self.tls_ca {
214 state.serialize_field("tls_ca", tls_ca)?;
215 }
216 Ok(())
217 }
218}
219
220impl Default for TlsOptions {
221 fn default() -> Self {
222 Self::new(None, None, None)
223 }
224}
225
226fn parse_certificates<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
227 let path = path.as_ref();
228 let parser = CertificateDer::pem_file_iter(path)
229 .map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?
230 .collect::<Vec<_>>();
231
232 let mut certs = Vec::with_capacity(parser.len());
233 for cert in parser {
234 certs.push(cert.map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?);
235 }
236
237 Ok(certs)
238}
239
240fn parse_cert_and_key(
241 cert: Option<&PathBuf>,
242 key: Option<&PathBuf>,
243) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
244 let path = cert.ok_or(TlsError::MissingTlsCert)?;
245 let cert = parse_certificates(path)?;
246
247 let path = key.ok_or(TlsError::MissingTlsKey)?;
248 let key = PrivateKeyDer::from_pem_file(path)
249 .map_err(|e| TlsError::FailedToParseTlsKey(e.to_string()))?;
250
251 Ok((cert, key))
252}
253
254fn install_default_tls_provider() -> bool {
255 rustls::crypto::aws_lc_rs::default_provider()
256 .install_default()
257 .expect("failed to install the default TLS provider");
258 true
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture files, relative to the crate directory that `cargo test` runs in.
    const TEST_CERT: &str = "../../tests/certs/cert.pem";
    const TEST_KEY: &str = "../../tests/certs/key.pem";
    const TEST_CA: &str = "../../tests/certs/ca.pem";

    /// Copies `source` to `destination`, creating parent directories first.
    fn create_test_file(source: &str, destination: &Path) {
        std::fs::create_dir_all(destination.parent().unwrap()).unwrap();
        std::fs::copy(source, destination).unwrap();
    }

    /// Copies the fixture cert and key (and CA when `with_ca`) into the
    /// options' test config directory under their cached file names.
    fn install_cached_files(opts: &TlsOptions, with_ca: bool) {
        create_test_file(TEST_CERT, &opts.config_dir.join("cert.pem"));
        create_test_file(TEST_KEY, &opts.config_dir.join("key.pem"));
        if with_ca {
            create_test_file(TEST_CA, &opts.config_dir.join("ca.pem"));
        }
    }

    #[test]
    fn test_tls_options_default() {
        let opts = TlsOptions::default();
        assert!(!opts.is_secure());

        install_cached_files(&opts, true);

        assert_eq!(opts.cert_path().unwrap(), opts.config_dir.join("cert.pem"));
        assert_eq!(opts.key_path().unwrap(), opts.config_dir.join("key.pem"));
        assert_eq!(opts.ca_path().unwrap(), opts.config_dir.join("ca.pem"));
        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_tls_options_with_paths() {
        let opts = TlsOptions::new(
            Some(TEST_CERT.into()),
            Some(TEST_KEY.into()),
            Some(TEST_CA.into()),
        );

        assert_eq!(opts.cert_path().unwrap(), PathBuf::from(TEST_CERT));
        assert_eq!(opts.key_path().unwrap(), PathBuf::from(TEST_KEY));
        assert_eq!(opts.ca_path().unwrap(), PathBuf::from(TEST_CA));
        assert!(opts.is_secure());
    }

    #[test]
    fn test_cached_paths() {
        let opts = TlsOptions::default();

        assert!(opts.cached_cert_path().is_none());
        assert!(opts.cached_key_path().is_none());
        assert!(opts.cached_ca_path().is_none());

        install_cached_files(&opts, true);

        assert_eq!(
            opts.cached_cert_path().unwrap(),
            opts.config_dir.join("cert.pem")
        );
        assert_eq!(
            opts.cached_key_path().unwrap(),
            opts.config_dir.join("key.pem")
        );
        assert_eq!(
            opts.cached_ca_path().unwrap(),
            opts.config_dir.join("ca.pem")
        );
    }

    #[test]
    fn test_cached_paths_ignore_directories() {
        let opts = TlsOptions::default();

        // A directory with the cached name must not be mistaken for a file.
        std::fs::create_dir_all(opts.config_dir.join("cert.pem")).unwrap();

        assert!(opts.cached_cert_path().is_none());
        assert!(!opts.is_secure());
    }

    #[test]
    fn test_server_config_without_identity_is_none() {
        let opts = TlsOptions::default();

        let config = opts.server_config().unwrap();
        assert!(config.is_none());
    }

    #[test]
    fn test_server_config_with_valid_files_in_temp_dir() {
        let opts = TlsOptions::new(Some(TEST_CERT.into()), Some(TEST_KEY.into()), None);

        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_server_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached_files(&opts, true);

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_client_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached_files(&opts, true);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_client_config_without_ca() {
        let opts = TlsOptions::default();
        install_cached_files(&opts, false);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty());
    }
}
416}