Skip to main content

agp_config/tls/
common.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use rustls::RootCertStore;
5use rustls::server::VerifierBuilderError;
6use rustls_native_certs;
7use rustls_pki_types::CertificateDer;
8use rustls_pki_types::pem::PemObject;
9use serde::Deserialize;
10use std::path::Path;
11use std::time::Duration;
12use thiserror::Error;
13
14#[derive(Debug, Deserialize, PartialEq, Clone)]
15pub struct Config {
16    /// Path to the CA cert. For a client this verifies the server certificate.
17    /// For a server this verifies client certificates. If empty uses system root CA.
18    /// (optional)
19    pub ca_file: Option<String>,
20
21    /// In memory PEM encoded cert. (optional)
22    pub ca_pem: Option<String>,
23
24    /// If true, load system CA certificates pool in addition to the certificates
25    /// configured in this struct.
26    #[serde(default = "default_include_system_ca_certs_pool")]
27    pub include_system_ca_certs_pool: bool,
28
29    /// Path to the TLS cert to use for TLS required connections. (optional)
30    pub cert_file: Option<String>,
31
32    /// In memory PEM encoded TLS cert to use for TLS required connections. (optional)
33    pub cert_pem: Option<String>,
34
35    /// Path to the TLS key to use for TLS required connections. (optional)
36    pub key_file: Option<String>,
37
38    /// In memory PEM encoded TLS key to use for TLS required connections. (optional)
39    pub key_pem: Option<String>,
40
41    /// The TLS version to use. If not set, the default is "tls1.3".
42    /// The value must be either "tls1.2" or "tls1.3".
43    /// (optional)
44    #[serde(default = "default_tls_version")]
45    pub tls_version: String,
46
47    /// ReloadInterval specifies the duration after which the certificate will be reloaded
48    /// If not set, it will never be reloaded
49    // TODO(msardara): not implemented yet
50    pub reload_interval: Option<Duration>,
51}
52
53/// Errors for Config
54#[derive(Error, Debug)]
55pub enum ConfigError {
56    #[error("invalid tls version: {0}")]
57    InvalidTlsVersion(String),
58    #[error("invalid pem format: {0}")]
59    InvalidPem(rustls_pki_types::pem::Error),
60    #[error("cannot use both file and pem for {0}")]
61    CannotUseBoth(String),
62    #[error("root store error: {0}")]
63    RootStore(rustls::Error),
64    #[error("config builder error")]
65    ConfigBuilder(rustls::Error),
66    #[error("missing server cert and key. cert_{{file, pem}} and key_{{file, pem}} must be set")]
67    MissingServerCertAndKey,
68    #[error("verifier builder error")]
69    VerifierBuilder(VerifierBuilderError),
70    #[error("unknown error")]
71    Unknown,
72}
73
74// Defaults for Config
75impl Default for Config {
76    fn default() -> Config {
77        Config {
78            ca_file: None,
79            ca_pem: None,
80            include_system_ca_certs_pool: default_include_system_ca_certs_pool(),
81            cert_file: None,
82            cert_pem: None,
83            key_file: None,
84            key_pem: None,
85            tls_version: "tls1.3".to_string(),
86            reload_interval: None,
87        }
88    }
89}
90
/// Serde default for `Config::include_system_ca_certs_pool`.
///
/// The platform CA pool is opt-in: it is only loaded when explicitly enabled.
fn default_include_system_ca_certs_pool() -> bool {
    false
}
95
/// Serde default for `Config::tls_version`: TLS 1.3.
fn default_tls_version() -> String {
    String::from("tls1.3")
}
100
101impl Config {
102    pub(crate) fn with_ca_file(self, ca_file: &str) -> Config {
103        Config {
104            ca_file: Some(ca_file.to_string()),
105            ..self
106        }
107    }
108
109    pub(crate) fn with_ca_pem(self, ca_pem: &str) -> Config {
110        Config {
111            ca_pem: Some(ca_pem.to_string()),
112            ..self
113        }
114    }
115
116    pub(crate) fn with_include_system_ca_certs_pool(
117        self,
118        include_system_ca_certs_pool: bool,
119    ) -> Config {
120        Config {
121            include_system_ca_certs_pool,
122            ..self
123        }
124    }
125
126    pub(crate) fn with_cert_file(self, cert_file: &str) -> Config {
127        Config {
128            cert_file: Some(cert_file.to_string()),
129            ..self
130        }
131    }
132
133    pub(crate) fn with_cert_pem(self, cert_pem: &str) -> Config {
134        Config {
135            cert_pem: Some(cert_pem.to_string()),
136            ..self
137        }
138    }
139
140    pub(crate) fn with_key_file(self, key_file: &str) -> Config {
141        Config {
142            key_file: Some(key_file.to_string()),
143            ..self
144        }
145    }
146
147    pub(crate) fn with_key_pem(self, key_pem: &str) -> Config {
148        Config {
149            key_pem: Some(key_pem.to_string()),
150            ..self
151        }
152    }
153
154    pub(crate) fn with_tls_version(self, tls_version: &str) -> Config {
155        Config {
156            tls_version: tls_version.to_string(),
157            ..self
158        }
159    }
160
161    pub(crate) fn with_reload_interval(self, reload_interval: Option<Duration>) -> Config {
162        Config {
163            reload_interval,
164            ..self
165        }
166    }
167
168    pub(crate) fn load_ca_cert_pool(&self) -> Result<RootCertStore, ConfigError> {
169        let mut root_store = RootCertStore::empty();
170
171        let cert = match (self.has_ca_file(), self.has_ca_pem()) {
172            (true, true) => return Err(ConfigError::CannotUseBoth("ca".to_string())),
173            (true, false) => {
174                CertificateDer::from_pem_file(Path::new(self.ca_file.as_ref().unwrap()))
175                    .map_err(ConfigError::InvalidPem)?
176            }
177            (false, true) => {
178                CertificateDer::from_pem_slice(self.ca_pem.as_ref().unwrap().as_bytes())
179                    .map_err(ConfigError::InvalidPem)?
180            }
181            (false, false) => return Ok(root_store),
182        };
183
184        root_store.add(cert).map_err(ConfigError::RootStore)?;
185
186        if self.include_system_ca_certs_pool {
187            for cert in
188                rustls_native_certs::load_native_certs().expect("could not load platform certs")
189            {
190                root_store.add(cert).map_err(ConfigError::RootStore)?;
191            }
192        }
193
194        Ok(root_store)
195    }
196
197    /// Returns true if the config has a CA cert
198    pub fn has_ca(&self) -> bool {
199        self.has_ca_file() || self.has_ca_pem()
200    }
201
202    /// Returns true if the config has a CA file
203    pub fn has_ca_file(&self) -> bool {
204        self.ca_file.is_some()
205    }
206
207    /// Returns true if the config has a CA PEM
208    pub fn has_ca_pem(&self) -> bool {
209        self.ca_pem.is_some()
210    }
211
212    /// Returns true if the config has a cert file
213    pub fn has_cert_file(&self) -> bool {
214        self.cert_file.is_some()
215    }
216
217    /// Returns true if the config has a cert PEM
218    pub fn has_cert_pem(&self) -> bool {
219        self.cert_pem.is_some()
220    }
221
222    /// Returns true if the config has a key file
223    pub fn has_key_file(&self) -> bool {
224        self.key_file.is_some()
225    }
226
227    /// Returns true if the config has a key PEM
228    pub fn has_key_pem(&self) -> bool {
229        self.key_pem.is_some()
230    }
231}
232
233// trait load_rustls_config
234pub trait RustlsConfigLoader<T> {
235    fn load_rustls_config(&self) -> Result<Option<T>, ConfigError>;
236}
237
// Tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default() {
        let config = Config::default();
        assert_eq!(config.ca_file, None);
        assert_eq!(config.ca_pem, None);
        assert!(!config.include_system_ca_certs_pool);
        assert_eq!(config.cert_file, None);
        assert_eq!(config.cert_pem, None);
        assert_eq!(config.key_file, None);
        assert_eq!(config.key_pem, None);
        assert_eq!(config.tls_version, "tls1.3".to_string());
        assert_eq!(config.reload_interval, None);
    }

    #[test]
    fn test_builders_and_predicates() {
        let config = Config::default()
            .with_ca_file("ca.pem")
            .with_cert_pem("cert")
            .with_key_file("key.pem")
            .with_include_system_ca_certs_pool(true)
            .with_tls_version("tls1.2")
            .with_reload_interval(Some(Duration::from_secs(30)));

        assert_eq!(config.ca_file.as_deref(), Some("ca.pem"));
        assert_eq!(config.cert_pem.as_deref(), Some("cert"));
        assert_eq!(config.key_file.as_deref(), Some("key.pem"));
        assert!(config.has_ca());
        assert!(config.has_ca_file());
        assert!(!config.has_ca_pem());
        assert!(config.has_cert_pem());
        assert!(!config.has_cert_file());
        assert!(config.has_key_file());
        assert!(!config.has_key_pem());
        assert!(config.include_system_ca_certs_pool);
        assert_eq!(config.tls_version, "tls1.2");
        assert_eq!(config.reload_interval, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_load_ca_cert_pool_without_ca_is_empty() {
        let store = Config::default()
            .load_ca_cert_pool()
            .expect("no CA configured must not be an error");
        assert!(store.is_empty());
    }

    #[test]
    fn test_load_ca_cert_pool_rejects_file_and_pem() {
        let config = Config::default()
            .with_ca_file("ca.pem")
            .with_ca_pem("pem");
        assert!(matches!(
            config.load_ca_cert_pool(),
            Err(ConfigError::CannotUseBoth(ref what)) if what == "ca"
        ));
    }

    #[test]
    fn test_load_ca_cert_pool_rejects_pem_without_certificates() {
        let config = Config::default().with_ca_pem("this is not a certificate");
        assert!(matches!(
            config.load_ca_cert_pool(),
            Err(ConfigError::InvalidPem(_))
        ));
    }
}