Skip to main content

agp_config/tls/
server.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::path::Path;
5
6use rustls::{
7    RootCertStore, ServerConfig as RustlsServerConfig,
8    server::WebPkiClientVerifier,
9    version::{TLS12, TLS13},
10};
11use rustls_pki_types::pem::PemObject;
12use rustls_pki_types::{CertificateDer, PrivateKeyDer};
13use serde::Deserialize;
14
15use super::common::{Config, ConfigError, RustlsConfigLoader};
16use crate::component::configuration::{Configuration, ConfigurationError};
17
18#[derive(Debug, Deserialize, PartialEq, Clone)]
19pub struct TlsServerConfig {
20    /// The Config struct
21    #[serde(flatten, default)]
22    pub config: Config,
23
24    /// insecure do not setup a TLS server
25    #[serde(default = "default_insecure")]
26    pub insecure: bool,
27
28    /// Path to the TLS cert to use by the server to verify a client certificate. (optional)
29    pub client_ca_file: Option<String>,
30
31    /// PEM encoded CA cert to use by the server to verify a client certificate. (optional)
32    pub client_ca_pem: Option<String>,
33
34    /// Reload the ClientCAs file when it is modified
35    /// TODO(msardara): not implemented yet
36    #[serde(default = "default_reload_client_ca_file")]
37    pub reload_client_ca_file: bool,
38}
39
40impl Default for TlsServerConfig {
41    fn default() -> Self {
42        TlsServerConfig {
43            config: Config::default(),
44            insecure: default_insecure(),
45            client_ca_file: None,
46            client_ca_pem: None,
47            reload_client_ca_file: default_reload_client_ca_file(),
48        }
49    }
50}
51
/// Serde/`Default` value for `TlsServerConfig::insecure`: TLS stays enabled
/// unless explicitly turned off.
fn default_insecure() -> bool {
    bool::default()
}

/// Serde/`Default` value for `TlsServerConfig::reload_client_ca_file`:
/// reloading is off by default.
fn default_reload_client_ca_file() -> bool {
    bool::default()
}
59
60/// Display the ServerConfig
61impl std::fmt::Display for TlsServerConfig {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        write!(f, "{:?}", self)
64    }
65}
66
67// methods for ServerConfig to create a RustlsServerConfig from the config
68impl TlsServerConfig {
69    /// Create a new TlsServerConfig
70    pub fn new() -> Self {
71        TlsServerConfig {
72            ..Default::default()
73        }
74    }
75
76    /// Set insecure (disable TLS)
77    pub fn with_insecure(self, insecure: bool) -> Self {
78        TlsServerConfig { insecure, ..self }
79    }
80
81    /// Set CA file for client auth
82    pub fn with_client_ca_file(self, client_ca_file: &str) -> Self {
83        TlsServerConfig {
84            client_ca_file: Some(client_ca_file.to_string()),
85            ..self
86        }
87    }
88
89    /// Set CA pem for client auth
90    pub fn with_client_ca_pem(self, client_ca_pem: &str) -> Self {
91        TlsServerConfig {
92            client_ca_pem: Some(client_ca_pem.to_string()),
93            ..self
94        }
95    }
96
97    /// Set reload_client_ca_file
98    pub fn with_reload_client_ca_file(self, reload_client_ca_file: bool) -> Self {
99        TlsServerConfig {
100            reload_client_ca_file,
101            ..self
102        }
103    }
104
105    /// Set CA file
106    pub fn with_ca_file(self, ca_file: &str) -> Self {
107        TlsServerConfig {
108            config: self.config.with_ca_file(ca_file),
109            ..self
110        }
111    }
112
113    /// Set CA pem
114    pub fn with_ca_pem(self, ca_pem: &str) -> Self {
115        TlsServerConfig {
116            config: self.config.with_ca_pem(ca_pem),
117            ..self
118        }
119    }
120
121    /// Set include system CA certs pool
122    pub fn with_include_system_ca_certs_pool(self, include_system_ca_certs_pool: bool) -> Self {
123        TlsServerConfig {
124            config: self
125                .config
126                .with_include_system_ca_certs_pool(include_system_ca_certs_pool),
127            ..self
128        }
129    }
130
131    /// Set cert file
132    pub fn with_cert_file(self, cert_file: &str) -> Self {
133        TlsServerConfig {
134            config: self.config.with_cert_file(cert_file),
135            ..self
136        }
137    }
138
139    /// Set cert pem
140    pub fn with_cert_pem(self, cert_pem: &str) -> Self {
141        TlsServerConfig {
142            config: self.config.with_cert_pem(cert_pem),
143            ..self
144        }
145    }
146
147    /// Set key file
148    pub fn with_key_file(self, key_file: &str) -> Self {
149        TlsServerConfig {
150            config: self.config.with_key_file(key_file),
151            ..self
152        }
153    }
154
155    /// Set key pem
156    pub fn with_key_pem(self, key_pem: &str) -> Self {
157        TlsServerConfig {
158            config: self.config.with_key_pem(key_pem),
159            ..self
160        }
161    }
162
163    /// Set TLS version
164    pub fn with_tls_version(self, tls_version: &str) -> Self {
165        TlsServerConfig {
166            config: self.config.with_tls_version(tls_version),
167            ..self
168        }
169    }
170
171    /// Set reload interval
172    pub fn with_reload_interval(self, reload_interval: Option<std::time::Duration>) -> Self {
173        TlsServerConfig {
174            config: self.config.with_reload_interval(reload_interval),
175            ..self
176        }
177    }
178
179    pub fn load_rustls_server_config(&self) -> Result<Option<RustlsServerConfig>, ConfigError> {
180        // Check if insecure is set
181        if self.insecure {
182            return Ok(None);
183        }
184
185        // Check TLS version
186        let tls_version = match self.config.tls_version.as_str() {
187            "tls1.2" => &TLS12,
188            "tls1.3" => &TLS13,
189            _ => {
190                return Err(ConfigError::InvalidTlsVersion(
191                    self.config.tls_version.clone(),
192                ));
193            }
194        };
195
196        // Get certificate & key
197        let (cert, key) = match (
198            self.config.has_cert_file() && self.config.has_key_file(),
199            self.config.has_cert_pem() && self.config.has_key_pem(),
200        ) {
201            (true, true) => {
202                // If both cert_file and cert_pem are set, return an error
203                return Err(ConfigError::CannotUseBoth("cert".to_string()));
204            }
205            (false, false) => {
206                // If no cert, return an error
207                return Err(ConfigError::MissingServerCertAndKey);
208            }
209            (true, false) => (
210                CertificateDer::from_pem_file(Path::new(self.config.cert_file.as_ref().unwrap()))
211                    .map_err(ConfigError::InvalidPem)?,
212                PrivateKeyDer::from_pem_file(Path::new(self.config.key_file.as_ref().unwrap()))
213                    .map_err(ConfigError::InvalidPem)?,
214            ),
215            (false, true) => (
216                CertificateDer::from_pem_slice(self.config.cert_pem.as_ref().unwrap().as_bytes())
217                    .map_err(ConfigError::InvalidPem)?,
218                PrivateKeyDer::from_pem_slice(self.config.key_pem.as_ref().unwrap().as_bytes())
219                    .map_err(ConfigError::InvalidPem)?,
220            ),
221        };
222
223        // create a server ConfigBuilder
224        let config_builder = RustlsServerConfig::builder_with_protocol_versions(&[tls_version]);
225
226        // Check whether to enable client auth or not
227        let client_ca = match (&self.client_ca_file, &self.client_ca_pem) {
228            (Some(_), Some(_)) => return Err(ConfigError::CannotUseBoth("client_ca".to_string())),
229            (Some(_), None) => Option::Some(
230                CertificateDer::from_pem_file(Path::new(self.client_ca_file.as_ref().unwrap()))
231                    .map_err(ConfigError::InvalidPem)?,
232            ),
233            (None, Some(_)) => Option::Some(
234                CertificateDer::from_pem_slice(self.client_ca_pem.as_ref().unwrap().as_bytes())
235                    .map_err(ConfigError::InvalidPem)?,
236            ),
237            (None, None) => Option::None,
238        };
239
240        // create root store if client_ca is set
241        let server_config = match client_ca {
242            Some(client_ca) => {
243                let mut root_store = RootCertStore::empty();
244                root_store.add(client_ca).map_err(ConfigError::RootStore)?;
245                let verifier = WebPkiClientVerifier::builder(root_store.into())
246                    .build()
247                    .map_err(ConfigError::VerifierBuilder)?;
248                config_builder
249                    .with_client_cert_verifier(verifier)
250                    .with_single_cert(vec![cert], key)
251                    .map_err(ConfigError::ConfigBuilder)?
252            }
253            None => config_builder
254                .with_no_client_auth()
255                .with_single_cert(vec![cert], key)
256                .map_err(ConfigError::ConfigBuilder)?,
257        };
258
259        // We are good to go
260        Ok(Some(server_config))
261    }
262}
263
264// trait implementation
265impl RustlsConfigLoader<RustlsServerConfig> for TlsServerConfig {
266    fn load_rustls_config(&self) -> Result<Option<RustlsServerConfig>, ConfigError> {
267        let server_config = self.load_rustls_server_config()?;
268        Ok(server_config)
269    }
270}
271
272impl Configuration for TlsServerConfig {
273    fn validate(&self) -> Result<(), ConfigurationError> {
274        // TODO(msardara): validate the configuration
275        Ok(())
276    }
277}