hyperdb_api_core/client/
tls.rs1use std::path::PathBuf;
39
/// Client-side TLS settings, normally assembled with the builder methods on
/// `TlsConfig`.
///
/// The `Default` configuration requests server verification and leaves every
/// path and the server-name field unset.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Request verification of the server certificate.
    /// This is cleared by `danger_accept_invalid_certs`.
    pub verify_server: bool,
    /// Optional PEM file with additional CA certificates to trust.
    pub ca_cert_path: Option<PathBuf>,
    /// Optional PEM file with the client certificate chain (mutual TLS).
    pub client_cert_path: Option<PathBuf>,
    /// Optional PEM file with the private key for `client_cert_path`.
    pub client_key_path: Option<PathBuf>,
    /// Optional server name set via `server_name()`.
    /// NOTE(review): this is consumed by the caller, not read in this file; confirm its use.
    pub server_name: Option<String>,
}
54
55impl Default for TlsConfig {
56 fn default() -> Self {
57 TlsConfig {
58 verify_server: true,
59 ca_cert_path: None,
60 client_cert_path: None,
61 client_key_path: None,
62 server_name: None,
63 }
64 }
65}
66
67impl TlsConfig {
68 #[must_use]
70 pub fn new() -> Self {
71 Self::default()
72 }
73
74 #[must_use]
75 pub fn danger_accept_invalid_certs(mut self) -> Self {
89 tracing::warn!(
90 "TLS certificate verification disabled - this should only be used for testing. \
91 Man-in-the-middle attacks are possible."
92 );
93
94 self.verify_server = false;
95 self
96 }
97
98 #[must_use]
99 pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
101 self.ca_cert_path = Some(path.into());
102 self
103 }
104
105 #[must_use]
106 pub fn client_cert(
108 mut self,
109 cert_path: impl Into<PathBuf>,
110 key_path: impl Into<PathBuf>,
111 ) -> Self {
112 self.client_cert_path = Some(cert_path.into());
113 self.client_key_path = Some(key_path.into());
114 self
115 }
116
117 #[must_use]
118 pub fn server_name(mut self, name: impl Into<String>) -> Self {
120 self.server_name = Some(name.into());
121 self
122 }
123
124 #[must_use]
126 pub fn has_client_cert(&self) -> bool {
127 self.client_cert_path.is_some() && self.client_key_path.is_some()
128 }
129}
130
/// How TLS is negotiated for a connection.
///
/// The default is [`TlsMode::Disable`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TlsMode {
    /// TLS is not used.
    #[default]
    Disable,
    /// TLS is enabled but not required; the server is not verified.
    Prefer,
    /// TLS is required; the server is not verified.
    Require,
    /// TLS is required and the server is verified, without a hostname check.
    VerifyCA,
    /// TLS is required and both the server and its hostname are verified.
    VerifyFull,
}

impl TlsMode {
    /// `true` for every mode except [`TlsMode::Disable`].
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        match self {
            Self::Disable => false,
            Self::Prefer | Self::Require | Self::VerifyCA | Self::VerifyFull => true,
        }
    }

    /// `true` when a connection must not fall back to plaintext.
    #[must_use]
    pub fn is_required(&self) -> bool {
        match self {
            Self::Disable | Self::Prefer => false,
            Self::Require | Self::VerifyCA | Self::VerifyFull => true,
        }
    }

    /// `true` when the server certificate must be verified.
    #[must_use]
    pub fn verify_server(&self) -> bool {
        match self {
            Self::Disable | Self::Prefer | Self::Require => false,
            Self::VerifyCA | Self::VerifyFull => true,
        }
    }

    /// `true` only for [`TlsMode::VerifyFull`], the one mode that also checks the hostname.
    #[must_use]
    pub fn verify_hostname(&self) -> bool {
        match self {
            Self::VerifyFull => true,
            Self::Disable | Self::Prefer | Self::Require | Self::VerifyCA => false,
        }
    }
}
175
176pub mod rustls_impl {
178 use super::TlsConfig;
179 use std::sync::Arc;
180
181 use rustls::pki_types::pem::PemObject;
182 use rustls::pki_types::{CertificateDer, PrivateKeyDer};
183 use tokio::net::TcpStream;
184 use tokio_rustls::rustls::{ClientConfig, RootCertStore};
185 use tokio_rustls::TlsConnector;
186
187 use crate::client::error::{Error, ErrorKind, Result};
188
189 pub fn create_connector(config: &TlsConfig, _host: &str) -> Result<TlsConnector> {
209 let mut root_store = RootCertStore::empty();
210
211 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
213
214 if let Some(ref ca_path) = config.ca_cert_path {
216 let certs = CertificateDer::pem_file_iter(ca_path)
217 .map_err(|e| Error::new(ErrorKind::Config, format!("failed to read CA cert: {e}")))?
218 .collect::<std::result::Result<Vec<_>, _>>()
219 .map_err(|e| Error::new(ErrorKind::Config, format!("invalid CA cert: {e}")))?;
220 for cert in certs {
221 root_store.add(cert).map_err(|e| {
222 Error::new(ErrorKind::Config, format!("failed to add CA cert: {e}"))
223 })?;
224 }
225 }
226
227 let provider = Arc::new(rustls::crypto::ring::default_provider());
228 let builder = ClientConfig::builder_with_provider(provider)
229 .with_safe_default_protocol_versions()
230 .map_err(|e| Error::new(ErrorKind::Config, format!("TLS protocol config error: {e}")))?
231 .with_root_certificates(root_store);
232
233 let client_config = if config.has_client_cert() {
234 let cert_path = config.client_cert_path.as_ref().unwrap();
236 let key_path = config.client_key_path.as_ref().unwrap();
237
238 let certs = CertificateDer::pem_file_iter(cert_path)
239 .map_err(|e| {
240 Error::new(
241 ErrorKind::Config,
242 format!("failed to read client cert: {e}"),
243 )
244 })?
245 .collect::<std::result::Result<Vec<_>, _>>()
246 .map_err(|e| Error::new(ErrorKind::Config, format!("invalid client cert: {e}")))?;
247
248 let key = PrivateKeyDer::from_pem_file(key_path)
252 .map_err(|e| Error::new(ErrorKind::Config, format!("invalid client key: {e}")))?;
253
254 builder
255 .with_client_auth_cert(certs, key)
256 .map_err(|e| Error::new(ErrorKind::Config, format!("invalid client auth: {e}")))?
257 } else {
258 builder.with_no_client_auth()
259 };
260
261 Ok(TlsConnector::from(Arc::new(client_config)))
262 }
263
264 pub type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;
266
267 pub async fn wrap_stream(
277 stream: TcpStream,
278 connector: &TlsConnector,
279 server_name: &str,
280 ) -> Result<TlsStream> {
281 let domain = rustls::pki_types::ServerName::try_from(server_name.to_string())
282 .map_err(|_| Error::new(ErrorKind::Config, "invalid server name"))?;
283
284 connector
285 .connect(domain, stream)
286 .await
287 .map_err(|e| Error::new(ErrorKind::Connection, format!("TLS handshake failed: {e}")))
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.verify_server);
        assert!(config.ca_cert_path.is_none());
        assert!(config.client_cert_path.is_none());
        assert!(config.client_key_path.is_none());
        assert!(config.server_name.is_none());
        assert!(!config.has_client_cert());
    }

    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/cert.pem", "/path/to/key.pem")
            .server_name("example.com");

        // Builders other than `danger_accept_invalid_certs` must not touch verification.
        assert!(config.verify_server);
        assert_eq!(config.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
        assert_eq!(config.client_cert_path, Some(PathBuf::from("/path/to/cert.pem")));
        assert_eq!(config.client_key_path, Some(PathBuf::from("/path/to/key.pem")));
        assert!(config.has_client_cert());
        assert_eq!(config.server_name, Some("example.com".to_string()));
    }

    #[test]
    fn test_danger_accept_invalid_certs_clears_verification() {
        let config = TlsConfig::new().danger_accept_invalid_certs();
        assert!(!config.verify_server);
    }

    #[test]
    fn test_has_client_cert_requires_both_paths() {
        let mut config = TlsConfig::new();
        config.client_cert_path = Some(PathBuf::from("/path/to/cert.pem"));
        assert!(!config.has_client_cert());
        config.client_key_path = Some(PathBuf::from("/path/to/key.pem"));
        assert!(config.has_client_cert());
    }

    #[test]
    fn test_tls_mode() {
        assert_eq!(TlsMode::default(), TlsMode::Disable);

        assert!(!TlsMode::Disable.is_enabled());
        assert!(!TlsMode::Disable.is_required());

        assert!(TlsMode::Prefer.is_enabled());
        assert!(!TlsMode::Prefer.is_required());
        assert!(!TlsMode::Prefer.verify_server());

        assert!(TlsMode::Require.is_required());
        assert!(!TlsMode::Require.verify_server());

        assert!(TlsMode::VerifyCA.verify_server());
        assert!(!TlsMode::VerifyCA.verify_hostname());

        assert!(TlsMode::VerifyFull.verify_server());
        assert!(TlsMode::VerifyFull.verify_hostname());
    }

    #[test]
    fn test_missing_ca_file_is_config_error() {
        let config = TlsConfig::new().ca_cert("/nonexistent/hyperdb-test-ca.pem");
        assert!(rustls_impl::create_connector(&config, "localhost").is_err());
    }

    #[test]
    fn test_connector_builds_without_verification() {
        let config = TlsConfig::new().danger_accept_invalid_certs();
        assert!(rustls_impl::create_connector(&config, "localhost").is_ok());
    }
}