Skip to main content

cargo_lambda_remote/
tls.rs

1use clap::Args;
2use miette::{Diagnostic, Result};
3use rustls::{ClientConfig, RootCertStore, ServerConfig};
4use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
5use serde::{Deserialize, Serialize, ser::SerializeStruct};
6use std::{
7    path::{Path, PathBuf},
8    sync::OnceLock,
9};
10use thiserror::Error;
11
/// One-time guard for installing the process-wide rustls crypto provider.
///
/// `server_config` and `client_config` call `CELL.get_or_init(install_default_tls_provider)`
/// before building a rustls config. The stored `bool` is always `true`; it only records
/// that initialisation has already run.
static CELL: OnceLock<bool> = OnceLock::new();
13
/// Errors produced while locating and loading TLS material or building rustls configs.
#[derive(Debug, Diagnostic, Error)]
pub enum TlsError {
    /// No certificate path was given explicitly and no `cert.pem` was found in the
    /// config directory.
    #[error("missing TLS certificate")]
    #[diagnostic()]
    MissingTlsCert,

    /// No key path was given explicitly and no `key.pem` was found in the config directory.
    #[error("missing TLS key")]
    #[diagnostic()]
    MissingTlsKey,

    /// A PEM certificate file could not be opened or contained a malformed entry.
    /// Carries the offending path and the PEM parser error.
    #[error("invalid TLS file: {0}, {1}")]
    #[diagnostic()]
    InvalidTlsFile(PathBuf, rustls_pki_types::pem::Error),

    /// The private key file could not be read or parsed; carries the rendered error text.
    #[error("failed to parse TLS key: {0}")]
    #[diagnostic()]
    FailedToParseTlsKey(String),

    /// rustls rejected the certificate chain / key while building a client or server config.
    #[error("failed to parse config: {0}")]
    #[diagnostic()]
    FailedToParseConfig(#[from] rustls::Error),
}
36
// TLS settings used to build both the remote server config and the client config.
// Any path left unset falls back to a file in the cargo-lambda config directory
// (`cert.pem`, `key.pem`, `ca.pem`) when such a file exists.
//
// NOTE: these are plain `//` comments on purpose. Doc comments on clap-derived items
// can surface as CLI help text, so the field `///` lines below are user-facing strings.
#[derive(Args, Clone, Debug, Deserialize, Serialize)]
pub struct TlsOptions {
    // All three flags are declared as conflicting with a `remote` argument that is
    // defined outside this struct.
    /// Path to a TLS certificate file
    #[arg(long, conflicts_with = "remote")]
    #[serde(default)]
    pub tls_cert: Option<PathBuf>,
    /// Path to a TLS key file
    #[arg(long, conflicts_with = "remote")]
    #[serde(default)]
    pub tls_key: Option<PathBuf>,
    /// Path to a TLS CA file
    #[arg(long, conflicts_with = "remote")]
    #[serde(default)]
    pub tls_ca: Option<PathBuf>,

    // Test-only override for the config directory, so tests never read the real user
    // config. NOTE(review): as an un-attributed field of a clap `Args` type it is also
    // treated as a CLI argument in test builds, and it has no `#[serde(default)]`.
    // Confirm that is intended.
    #[cfg(test)]
    pub config_dir: PathBuf,
}
55
56impl TlsOptions {
57    #[cfg(not(test))]
58    pub fn new(
59        tls_cert: Option<PathBuf>,
60        tls_key: Option<PathBuf>,
61        tls_ca: Option<PathBuf>,
62    ) -> Self {
63        Self {
64            tls_cert,
65            tls_key,
66            tls_ca,
67        }
68    }
69
70    #[cfg(test)]
71    pub fn new(
72        tls_cert: Option<PathBuf>,
73        tls_key: Option<PathBuf>,
74        tls_ca: Option<PathBuf>,
75    ) -> Self {
76        Self {
77            tls_cert,
78            tls_key,
79            tls_ca,
80            config_dir: tempfile::TempDir::new().unwrap().path().to_path_buf(),
81        }
82    }
83
84    pub fn is_secure(&self) -> bool {
85        self.cert_path().is_some() && self.key_path().is_some()
86    }
87
88    pub fn server_config(&self) -> Result<Option<ServerConfig>> {
89        if !self.is_secure() {
90            return Ok(None);
91        }
92
93        CELL.get_or_init(install_default_tls_provider);
94
95        let (mut cert_chain, key) =
96            parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
97
98        if let Some(path) = self.ca_path() {
99            let certs = parse_certificates(path)?;
100            if !certs.is_empty() {
101                cert_chain.extend(certs);
102            }
103        }
104
105        let mut config = ServerConfig::builder()
106            .with_no_client_auth()
107            .with_single_cert(cert_chain, key)
108            .map_err(TlsError::FailedToParseConfig)?;
109
110        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
111
112        Ok(Some(config))
113    }
114
115    pub fn client_config(&self) -> Result<ClientConfig> {
116        CELL.get_or_init(install_default_tls_provider);
117
118        let builder = if let Some(path) = self.ca_path() {
119            let mut root_store = RootCertStore::empty();
120            root_store.add_parsable_certificates(parse_certificates(path)?);
121            ClientConfig::builder().with_root_certificates(root_store)
122        } else {
123            use rustls_platform_verifier::BuilderVerifierExt;
124            ClientConfig::builder().with_platform_verifier()
125        };
126
127        let (cert, key) = parse_cert_and_key(self.cert_path().as_ref(), self.key_path().as_ref())?;
128
129        let config = builder
130            .with_client_auth_cert(cert, key)
131            .map_err(TlsError::FailedToParseConfig)?;
132
133        Ok(config)
134    }
135
136    fn cert_path(&self) -> Option<PathBuf> {
137        self.tls_cert.clone().or_else(|| self.cached_cert_path())
138    }
139
140    fn key_path(&self) -> Option<PathBuf> {
141        self.tls_key.clone().or_else(|| self.cached_key_path())
142    }
143
144    fn ca_path(&self) -> Option<PathBuf> {
145        self.tls_ca.clone().or_else(|| self.cached_ca_path())
146    }
147
148    fn cached_cert_path(&self) -> Option<PathBuf> {
149        let cache = self.config_dir().map(|p| p.join("cert.pem"));
150        if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
151            return cache;
152        }
153
154        None
155    }
156
157    fn cached_key_path(&self) -> Option<PathBuf> {
158        let cache = self.config_dir().map(|p| p.join("key.pem"));
159        if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
160            return cache;
161        }
162
163        None
164    }
165
166    fn cached_ca_path(&self) -> Option<PathBuf> {
167        let cache = self.config_dir().map(|p| p.join("ca.pem"));
168        if cache.as_ref().is_some_and(|p| p.exists() && p.is_file()) {
169            return cache;
170        }
171
172        None
173    }
174
175    #[cfg(not(test))]
176    fn config_dir(&self) -> Option<PathBuf> {
177        // First check XDG_CONFIG_HOME environment variable explicitly
178        if let Ok(xdg_config) = std::env::var("XDG_CONFIG_HOME") {
179            let path = PathBuf::from(xdg_config).join("cargo-lambda");
180            if path.exists() {
181                return Some(path);
182            }
183        }
184
185        // Fall back to dirs::config_dir()
186        dirs::config_dir().map(|p| p.join("cargo-lambda"))
187    }
188
189    #[cfg(test)]
190    fn config_dir(&self) -> Option<PathBuf> {
191        Some(self.config_dir.clone())
192    }
193
194    pub fn count_fields(&self) -> usize {
195        self.tls_cert.is_some() as usize
196            + self.tls_key.is_some() as usize
197            + self.tls_ca.is_some() as usize
198    }
199
200    pub fn serialize_fields<S>(
201        &self,
202        state: &mut <S as serde::Serializer>::SerializeStruct,
203    ) -> Result<(), S::Error>
204    where
205        S: serde::Serializer,
206    {
207        if let Some(tls_cert) = &self.tls_cert {
208            state.serialize_field("tls_cert", tls_cert)?;
209        }
210        if let Some(tls_key) = &self.tls_key {
211            state.serialize_field("tls_key", tls_key)?;
212        }
213        if let Some(tls_ca) = &self.tls_ca {
214            state.serialize_field("tls_ca", tls_ca)?;
215        }
216        Ok(())
217    }
218}
219
220impl Default for TlsOptions {
221    fn default() -> Self {
222        Self::new(None, None, None)
223    }
224}
225
226fn parse_certificates<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
227    let path = path.as_ref();
228    let parser = CertificateDer::pem_file_iter(path)
229        .map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?
230        .collect::<Vec<_>>();
231
232    let mut certs = Vec::with_capacity(parser.len());
233    for cert in parser {
234        certs.push(cert.map_err(|e| TlsError::InvalidTlsFile(path.to_path_buf(), e))?);
235    }
236
237    Ok(certs)
238}
239
240fn parse_cert_and_key(
241    cert: Option<&PathBuf>,
242    key: Option<&PathBuf>,
243) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
244    let path = cert.ok_or(TlsError::MissingTlsCert)?;
245    let cert = parse_certificates(path)?;
246
247    let path = key.ok_or(TlsError::MissingTlsKey)?;
248    let key = PrivateKeyDer::from_pem_file(path)
249        .map_err(|e| TlsError::FailedToParseTlsKey(e.to_string()))?;
250
251    Ok((cert, key))
252}
253
254fn install_default_tls_provider() -> bool {
255    rustls::crypto::aws_lc_rs::default_provider()
256        .install_default()
257        .expect("failed to install the default TLS provider");
258    true
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory holding the PEM fixtures, relative to this crate's root
    /// (cargo runs unit tests with the package root as the working directory).
    const CERTS_DIR: &str = "../../tests/certs";

    /// Path of a fixture file inside [`CERTS_DIR`].
    fn fixture(name: &str) -> PathBuf {
        Path::new(CERTS_DIR).join(name)
    }

    /// Copies `source` to `destination`, creating parent directories as needed.
    fn create_test_file(source: impl AsRef<Path>, destination: &Path) {
        std::fs::create_dir_all(destination.parent().unwrap()).unwrap();
        std::fs::copy(source, destination).unwrap();
    }

    /// Copies the named fixtures into `opts.config_dir`, where the options pick them
    /// up as cached files.
    fn install_cached(opts: &TlsOptions, names: &[&str]) {
        for &name in names {
            create_test_file(fixture(name), &opts.config_dir.join(name));
        }
    }

    #[test]
    fn test_tls_options_default() {
        let opts = TlsOptions::default();
        assert!(!opts.is_secure());

        install_cached(&opts, &["cert.pem", "key.pem", "ca.pem"]);

        // Should return temp paths in test mode
        assert_eq!(opts.cert_path().unwrap(), opts.config_dir.join("cert.pem"));
        assert_eq!(opts.key_path().unwrap(), opts.config_dir.join("key.pem"));
        assert_eq!(opts.ca_path().unwrap(), opts.config_dir.join("ca.pem"));
        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_tls_options_with_paths() {
        let opts = TlsOptions::new(
            Some(fixture("cert.pem")),
            Some(fixture("key.pem")),
            Some(fixture("ca.pem")),
        );

        assert_eq!(
            opts.cert_path().unwrap(),
            PathBuf::from("../../tests/certs/cert.pem")
        );
        assert_eq!(
            opts.key_path().unwrap(),
            PathBuf::from("../../tests/certs/key.pem")
        );
        assert_eq!(
            opts.ca_path().unwrap(),
            PathBuf::from("../../tests/certs/ca.pem")
        );
        assert!(opts.is_secure());
    }

    #[test]
    fn test_cached_paths() {
        let opts = TlsOptions::default();

        assert!(opts.cached_cert_path().is_none());
        assert!(opts.cached_key_path().is_none());
        assert!(opts.cached_ca_path().is_none());

        install_cached(&opts, &["cert.pem", "key.pem", "ca.pem"]);

        assert_eq!(
            opts.cached_cert_path().unwrap(),
            opts.config_dir.join("cert.pem")
        );
        assert_eq!(
            opts.cached_key_path().unwrap(),
            opts.config_dir.join("key.pem")
        );
        assert_eq!(
            opts.cached_ca_path().unwrap(),
            opts.config_dir.join("ca.pem")
        );
    }

    #[test]
    fn test_server_config_with_valid_files_in_temp_dir() {
        let opts = TlsOptions::new(Some(fixture("cert.pem")), Some(fixture("key.pem")), None);

        assert!(opts.is_secure());

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_server_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached(&opts, &["cert.pem", "key.pem", "ca.pem"]);

        let config = opts.server_config().unwrap();
        assert!(config.is_some());
    }

    #[test]
    fn test_client_config_with_ca() {
        let opts = TlsOptions::default();
        install_cached(&opts, &["cert.pem", "key.pem", "ca.pem"]);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty()); // Default client config has no ALPN protocols
    }

    #[test]
    fn test_client_config_without_ca() {
        let opts = TlsOptions::default();
        install_cached(&opts, &["cert.pem", "key.pem"]);

        let config = opts.client_config().unwrap();
        assert!(config.alpn_protocols.is_empty());
    }
}