// cargo_lambda_remote/tls.rs

1use clap::Args;
2use miette::{Diagnostic, Result};
3use rustls::{ClientConfig, RootCertStore, ServerConfig};
4use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
5use serde::{Deserialize, Serialize, ser::SerializeStruct};
6use std::{
7    path::{Path, PathBuf},
8    sync::OnceLock,
9};
10use thiserror::Error;
11
/// Guards the one-time installation of the process-wide rustls crypto provider.
/// Both `server_config` and `client_config` call
/// `CELL.get_or_init(install_default_tls_provider)` before building a config.
static CELL: OnceLock<bool> = OnceLock::new();
13
/// Errors raised while loading TLS material or building rustls configurations.
#[derive(Debug, Diagnostic, Error)]
pub enum TlsError {
    /// No certificate path was available: none was given explicitly and no
    /// cached certificate file was found.
    #[error("missing TLS certificate")]
    #[diagnostic()]
    MissingTlsCert,

    /// No private key path was available: none was given explicitly and no
    /// cached key file was found.
    #[error("missing TLS key")]
    #[diagnostic()]
    MissingTlsKey,

    /// A certificate or CA file could not be opened or decoded as PEM.
    /// Carries the offending file path and the underlying PEM error.
    #[error("invalid TLS file: {0}, {1}")]
    #[diagnostic()]
    InvalidTlsFile(PathBuf, rustls_pki_types::pem::Error),

    /// The private key file could not be loaded. Carries the rendered
    /// message of the underlying PEM error.
    #[error("failed to parse TLS key: {0}")]
    #[diagnostic()]
    FailedToParseTlsKey(String),

    /// rustls rejected the certificate/key material while building a server or
    /// client configuration.
    #[error("failed to parse config: {0}")]
    #[diagnostic()]
    FailedToParseConfig(#[from] rustls::Error),
}
36
// TLS settings shared by the CLI (clap) and serialized configuration (serde).
// Each path is optional. When one is unset, `TlsOptions` falls back to a
// `cert.pem`, `key.pem`, or `ca.pem` file cached in its config directory.
// Plain `//` comments are used for this note because `///` doc comments on a
// clap-derived type can end up in the generated `--help` output.
// NOTE(review): `conflicts_with = "remote"` assumes the enclosing command defines
// an argument with id `remote` — TODO confirm.
#[derive(Args, Clone, Debug, Deserialize, Serialize)]
pub struct TlsOptions {
    /// Path to a TLS certificate file
    #[arg(long, conflicts_with = "remote")]
    #[serde(default)]
    pub tls_cert: Option<PathBuf>,
    /// Path to a TLS key file
    #[arg(long, conflicts_with = "remote")]
    #[serde(default)]
    pub tls_key: Option<PathBuf>,
    /// Path to a TLS CA file
    #[arg(long, conflicts_with = "remote")]
    #[serde(default)]
    pub tls_ca: Option<PathBuf>,

    // Test-only replacement for the user's config directory. Tests read and
    // write cached certificates here instead of in the real location.
    #[cfg(test)]
    pub config_dir: PathBuf,
}
55
56impl TlsOptions {
57    #[cfg(not(test))]
58    pub fn new(
59        tls_cert: Option<PathBuf>,
60        tls_key: Option<PathBuf>,
61        tls_ca: Option<PathBuf>,
62    ) -> Self {
63        Self {
64            tls_cert,
65            tls_key,
66            tls_ca,
67        }
68    }
69
70    #[cfg(test)]
71    pub fn new(
72        tls_cert: Option<PathBuf>,
73        tls_key: Option<PathBuf>,
74        tls_ca: Option<PathBuf>,
75    ) -> Self {
76        Self {
77            tls_cert,
78            tls_key,
79            tls_ca,
80            config_dir: tempfile::TempDir::new().unwrap().path().to_path_buf(),
81        }
82    }
83
84    pub fn is_secure(&self) -> bool {
85        self.cert_path().is_some() && self.key_path().is_some()
86    }
87
88    pub fn server_config(&self) -> Result<Option<ServerConfig>> {
89        if !self.is_secure() {
90            return Ok(None);
91        }
92
93        CELL.get_or_init(install_default_tls_provider);
94
95        let (mut cert_chain, key) =
96            parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
97
98        if let Some(path) = self.ca_path() {
99            let certs = parse_certificates(path)?;
100            if !certs.is_empty() {
101                cert_chain.extend(certs);
102            }
103        }
104
105        let mut config = ServerConfig::builder()
106            .with_no_client_auth()
107            .with_single_cert(cert_chain, key)
108            .map_err(TlsError::FailedToParseConfig)?;
109
110        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
111
112        Ok(Some(config))
113    }
114
115    pub fn client_config(&self) -> Result<ClientConfig> {
116        CELL.get_or_init(install_default_tls_provider);
117
118        let builder = if let Some(path) = self.ca_path() {
119            let mut root_store = RootCertStore::empty();
120            root_store.add_parsable_certificates(parse_certificates(path)?);
121            ClientConfig::builder().with_root_certificates(root_store)
122        } else {
123            use rustls_platform_verifier::BuilderVerifierExt;
124            ClientConfig::builder().with_platform_verifier()
125        };
126
127        let (cert, key) = parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
128
129        let config = builder
130            .with_client_auth_cert(cert, key)
131            .map_err(TlsError::FailedToParseConfig)?;
132
133        Ok(config)
134    }
135
136    fn cert_path(&self) -> Option<PathBuf> {
137        self.tls_cert.clone().or_else(|| self.cached_cert_path())
138    }
139
140    fn key_path(&self) -> Option<PathBuf> {
141        self.tls_key.clone().or_else(|| self.cached_key_path())
142    }
143
144    fn ca_path(&self) -> Option<PathBuf> {
145        self.tls_ca.clone().or_else(|| self.cached_ca_path())
146    }
147
148    fn cached_cert_path(&self) -> Option<PathBuf> {
149        let cache = self.config_dir().map(|p| p.join("cert.pem"));
150        if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
151            return cache;
152        }
153
154        None
155    }
156
157    fn cached_key_path(&self) -> Option<PathBuf> {
158        let cache = self.config_dir().map(|p| p.join("key.pem"));
159        if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
160            return cache;
161        }
162
163        None
164    }
165
166    fn cached_ca_path(&self) -> Option<PathBuf> {
167        let cache = self.config_dir().map(|p| p.join("ca.pem"));
168        if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
169            return cache;
170        }
171
172        None
173    }
174
175    #[cfg(not(test))]
176    fn config_dir(&self) -> Option<PathBuf> {
177        dirs::config_dir().map(|p| p.join("cargo-lambda"))
178    }
179
180    #[cfg(test)]
181    fn config_dir(&self) -> Option<PathBuf> {
182        Some(self.config_dir.clone())
183    }
184
185    pub fn count_fields(&self) -> usize {
186        self.tls_cert.is_some() as usize
187            + self.tls_key.is_some() as usize
188            + self.tls_ca.is_some() as usize
189    }
190
191    pub fn serialize_fields<S>(
192        &self,
193        state: &mut <S as serde::Serializer>::SerializeStruct,
194    ) -> Result<(), S::Error>
195    where
196        S: serde::Serializer,
197    {
198        if let Some(tls_cert) = &self.tls_cert {
199            state.serialize_field("tls_cert", tls_cert)?;
200        }
201        if let Some(tls_key) = &self.tls_key {
202            state.serialize_field("tls_key", tls_key)?;
203        }
204        if let Some(tls_ca) = &self.tls_ca {
205            state.serialize_field("tls_ca", tls_ca)?;
206        }
207        Ok(())
208    }
209}
210
211impl Default for TlsOptions {
212    fn default() -> Self {
213        Self::new(None, None, None)
214    }
215}
216
217fn parse_certificates<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
218    let path = path.as_ref();
219    let parser = CertificateDer::pem_file_iter(path)
220        .map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?
221        .collect::<Vec<_>>();
222
223    let mut certs = Vec::with_capacity(parser.len());
224    for cert in parser {
225        certs.push(cert.map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?);
226    }
227
228    Ok(certs)
229}
230
231fn parse_cert_and_key(
232    cert: Option<&PathBuf>,
233    key: Option<&PathBuf>,
234) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
235    let path = cert.ok_or(TlsError::MissingTlsCert)?;
236    let cert = parse_certificates(path)?;
237
238    let path = key.ok_or(TlsError::MissingTlsKey)?;
239    let key = PrivateKeyDer::from_pem_file(path)
240        .map_err(|e| TlsError::FailedToParseTlsKey(e.to_string()))?;
241
242    Ok((cert, key))
243}
244
245fn install_default_tls_provider() -> bool {
246    rustls::crypto::aws_lc_rs::default_provider()
247        .install_default()
248        .expect("failed to install the default TLS provider");
249    true
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    // Shared certificate fixtures. The paths are relative to the crate directory,
    // which cargo uses as the working directory when running tests.
    const CERT_FIXTURE: &str = "../../tests/certs/cert.pem";
    const KEY_FIXTURE: &str = "../../tests/certs/key.pem";
    const CA_FIXTURE: &str = "../../tests/certs/ca.pem";

    /// Copies `source` to `destination`, creating the destination's parent directory.
    fn create_test_file(source: &str, destination: &Path) {
        std::fs::create_dir_all(destination.parent().unwrap()).unwrap();
        std::fs::copy(source, destination).unwrap();
    }

    /// Copies the certificate and key fixtures into `opts.config_dir` under their
    /// cached names. The CA fixture is copied too when `with_ca` is true.
    fn install_cached_files(opts: &TlsOptions, with_ca: bool) {
        create_test_file(CERT_FIXTURE, &opts.config_dir.join("cert.pem"));
        create_test_file(KEY_FIXTURE, &opts.config_dir.join("key.pem"));
        if with_ca {
            create_test_file(CA_FIXTURE, &opts.config_dir.join("ca.pem"));
        }
    }

    #[tokio::test]
    async fn test_tls_options_default() {
        let opts = TlsOptions::default();
        assert!(!opts.is_secure());

        install_cached_files(&opts, true);

        // Without explicit paths, the files cached in the temp config dir are used.
        assert_eq!(opts.cert_path().unwrap(), opts.config_dir.join("cert.pem"));
        assert_eq!(opts.key_path().unwrap(), opts.config_dir.join("key.pem"));
        assert_eq!(opts.ca_path().unwrap(), opts.config_dir.join("ca.pem"));
        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_tls_options_with_paths() {
        let opts = TlsOptions::new(
            Some(CERT_FIXTURE.into()),
            Some(KEY_FIXTURE.into()),
            Some(CA_FIXTURE.into()),
        );

        assert_eq!(opts.cert_path().unwrap(), PathBuf::from(CERT_FIXTURE));
        assert_eq!(opts.key_path().unwrap(), PathBuf::from(KEY_FIXTURE));
        assert_eq!(opts.ca_path().unwrap(), PathBuf::from(CA_FIXTURE));
        assert!(opts.is_secure());
    }

    #[test]
    fn test_cached_paths() {
        let opts = TlsOptions::default();

        assert!(opts.cached_cert_path().is_none());
        assert!(opts.cached_key_path().is_none());
        assert!(opts.cached_ca_path().is_none());

        install_cached_files(&opts, true);

        assert_eq!(
            opts.cached_cert_path().unwrap(),
            opts.config_dir.join("cert.pem")
        );
        assert_eq!(
            opts.cached_key_path().unwrap(),
            opts.config_dir.join("key.pem")
        );
        assert_eq!(
            opts.cached_ca_path().unwrap(),
            opts.config_dir.join("ca.pem")
        );
    }

    #[test]
    fn test_server_config_is_none_without_key() {
        // A certificate alone is not enough: with no explicit or cached key,
        // TLS stays disabled.
        let opts = TlsOptions::new(Some(CERT_FIXTURE.into()), None, None);

        assert!(!opts.is_secure());
        assert!(opts.server_config().unwrap().is_none());
    }

    #[tokio::test]
    async fn test_server_config_with_valid_files_in_temp_dir() {
        let opts = TlsOptions::new(Some(CERT_FIXTURE.into()), Some(KEY_FIXTURE.into()), None);

        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[tokio::test]
    async fn test_server_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached_files(&opts, true);

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[tokio::test]
    async fn test_client_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached_files(&opts, true);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty()); // Default client config has no ALPN protocols
    }

    #[tokio::test]
    async fn test_client_config_without_ca() {
        let opts = TlsOptions::default();
        install_cached_files(&opts, false);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty());
    }
}