hyperdb_api_core/client/
tls.rs1use std::path::PathBuf;
39
/// Client-side TLS settings, normally assembled with the builder methods on `TlsConfig`.
///
/// The [`Default`] value verifies the server certificate and configures no custom CA,
/// no client certificate and no server-name override.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Whether the server certificate should be verified. `true` by default.
    pub verify_server: bool,
    /// Path to an additional CA certificate file to trust when verifying the server.
    pub ca_cert_path: Option<PathBuf>,
    /// Path to the client certificate chain used for client authentication.
    /// Only takes effect when `client_key_path` is set as well.
    pub client_cert_path: Option<PathBuf>,
    /// Path to the private key belonging to `client_cert_path`.
    pub client_key_path: Option<PathBuf>,
    /// Optional server-name override.
    // NOTE(review): not read by the connector code in this file; presumably the caller
    // uses it to choose the name passed to `rustls_impl::wrap_stream` — confirm.
    pub server_name: Option<String>,
}

impl Default for TlsConfig {
    /// Secure defaults: verification on, every optional path/override unset.
    fn default() -> Self {
        Self {
            verify_server: true,
            ca_cert_path: None,
            client_cert_path: None,
            client_key_path: None,
            server_name: None,
        }
    }
}
66
67impl TlsConfig {
68 #[must_use]
70 pub fn new() -> Self {
71 Self::default()
72 }
73
74 #[must_use]
75 pub fn danger_accept_invalid_certs(mut self) -> Self {
89 tracing::warn!(
90 "TLS certificate verification disabled - this should only be used for testing. \
91 Man-in-the-middle attacks are possible."
92 );
93
94 self.verify_server = false;
95 self
96 }
97
98 #[must_use]
99 pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
101 self.ca_cert_path = Some(path.into());
102 self
103 }
104
105 #[must_use]
106 pub fn client_cert(
108 mut self,
109 cert_path: impl Into<PathBuf>,
110 key_path: impl Into<PathBuf>,
111 ) -> Self {
112 self.client_cert_path = Some(cert_path.into());
113 self.client_key_path = Some(key_path.into());
114 self
115 }
116
117 #[must_use]
118 pub fn server_name(mut self, name: impl Into<String>) -> Self {
120 self.server_name = Some(name.into());
121 self
122 }
123
124 #[must_use]
126 pub fn has_client_cert(&self) -> bool {
127 self.client_cert_path.is_some() && self.client_key_path.is_some()
128 }
129}
130
/// How TLS is used for a connection.
///
/// Each mode is fully described by the predicates on `TlsMode`:
///
/// | mode         | enabled | required | verify server | verify hostname |
/// |--------------|---------|----------|---------------|-----------------|
/// | `Disable`    | no      | no       | no            | no              |
/// | `Prefer`     | yes     | no       | no            | no              |
/// | `Require`    | yes     | yes      | no            | no              |
/// | `VerifyCA`   | yes     | yes      | yes           | no              |
/// | `VerifyFull` | yes     | yes      | yes           | yes             |
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// TLS is not enabled. This is the default.
    #[default]
    Disable,
    /// TLS is enabled but not required.
    Prefer,
    /// TLS is required; the server certificate is not verified.
    Require,
    /// TLS is required and the server certificate is verified, but not the hostname.
    VerifyCA,
    /// TLS is required and both the server certificate and the hostname are verified.
    VerifyFull,
}

impl TlsMode {
    /// `true` for every mode except [`TlsMode::Disable`].
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        match self {
            TlsMode::Disable => false,
            TlsMode::Prefer | TlsMode::Require | TlsMode::VerifyCA | TlsMode::VerifyFull => true,
        }
    }

    /// `true` when the mode does not allow falling back to a non-TLS connection:
    /// [`TlsMode::Require`], [`TlsMode::VerifyCA`] and [`TlsMode::VerifyFull`].
    #[must_use]
    pub fn is_required(&self) -> bool {
        match self {
            TlsMode::Disable | TlsMode::Prefer => false,
            TlsMode::Require | TlsMode::VerifyCA | TlsMode::VerifyFull => true,
        }
    }

    /// `true` when the server certificate must be verified:
    /// [`TlsMode::VerifyCA`] and [`TlsMode::VerifyFull`].
    #[must_use]
    pub fn verify_server(&self) -> bool {
        match self {
            TlsMode::Disable | TlsMode::Prefer | TlsMode::Require => false,
            TlsMode::VerifyCA | TlsMode::VerifyFull => true,
        }
    }

    /// `true` only for [`TlsMode::VerifyFull`], the one mode that also checks the hostname.
    #[must_use]
    pub fn verify_hostname(&self) -> bool {
        match self {
            TlsMode::VerifyFull => true,
            TlsMode::Disable | TlsMode::Prefer | TlsMode::Require | TlsMode::VerifyCA => false,
        }
    }
}
175
176pub mod rustls_impl {
178 use super::TlsConfig;
179 use std::io::BufReader;
180 use std::sync::Arc;
181
182 use tokio::net::TcpStream;
183 use tokio_rustls::rustls::{ClientConfig, RootCertStore};
184 use tokio_rustls::TlsConnector;
185
186 use crate::client::error::{Error, ErrorKind, Result};
187
188 pub fn create_connector(config: &TlsConfig, _host: &str) -> Result<TlsConnector> {
208 let mut root_store = RootCertStore::empty();
209
210 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
212
213 if let Some(ref ca_path) = config.ca_cert_path {
215 let ca_file = std::fs::File::open(ca_path).map_err(|e| {
216 Error::new(ErrorKind::Config, format!("failed to open CA cert: {e}"))
217 })?;
218 let mut ca_reader = BufReader::new(ca_file);
219 let certs = rustls_pemfile::certs(&mut ca_reader)
220 .map(|r| {
221 r.map_err(|e| Error::new(ErrorKind::Config, format!("invalid CA cert: {e}")))
222 })
223 .collect::<Result<Vec<_>>>()?;
224 for cert in certs {
225 root_store.add(cert).map_err(|e| {
226 Error::new(ErrorKind::Config, format!("failed to add CA cert: {e}"))
227 })?;
228 }
229 }
230
231 let provider = Arc::new(rustls::crypto::ring::default_provider());
232 let builder = ClientConfig::builder_with_provider(provider)
233 .with_safe_default_protocol_versions()
234 .map_err(|e| Error::new(ErrorKind::Config, format!("TLS protocol config error: {e}")))?
235 .with_root_certificates(root_store);
236
237 let client_config = if config.has_client_cert() {
238 let cert_path = config.client_cert_path.as_ref().unwrap();
240 let key_path = config.client_key_path.as_ref().unwrap();
241
242 let cert_file = std::fs::File::open(cert_path).map_err(|e| {
243 Error::new(
244 ErrorKind::Config,
245 format!("failed to open client cert: {e}"),
246 )
247 })?;
248 let mut cert_reader = BufReader::new(cert_file);
249 let certs = rustls_pemfile::certs(&mut cert_reader)
250 .map(|r| {
251 r.map_err(|e| {
252 Error::new(ErrorKind::Config, format!("invalid client cert: {e}"))
253 })
254 })
255 .collect::<Result<Vec<_>>>()?;
256
257 let key_file = std::fs::File::open(key_path).map_err(|e| {
258 Error::new(ErrorKind::Config, format!("failed to open client key: {e}"))
259 })?;
260 let mut key_reader = BufReader::new(key_file);
261 let key = rustls_pemfile::private_key(&mut key_reader)
262 .map_err(|e| Error::new(ErrorKind::Config, format!("invalid client key: {e}")))?
263 .ok_or_else(|| Error::new(ErrorKind::Config, "no private key found"))?;
264
265 builder
266 .with_client_auth_cert(certs, key)
267 .map_err(|e| Error::new(ErrorKind::Config, format!("invalid client auth: {e}")))?
268 } else {
269 builder.with_no_client_auth()
270 };
271
272 Ok(TlsConnector::from(Arc::new(client_config)))
273 }
274
275 pub type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;
277
278 pub async fn wrap_stream(
288 stream: TcpStream,
289 connector: &TlsConnector,
290 server_name: &str,
291 ) -> Result<TlsStream> {
292 let domain = rustls::pki_types::ServerName::try_from(server_name.to_string())
293 .map_err(|_| Error::new(ErrorKind::Config, "invalid server name"))?;
294
295 connector
296 .connect(domain, stream)
297 .await
298 .map_err(|e| Error::new(ErrorKind::Connection, format!("TLS handshake failed: {e}")))
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.verify_server);
        assert!(config.ca_cert_path.is_none());
        assert!(config.client_cert_path.is_none());
        assert!(config.client_key_path.is_none());
        assert!(config.server_name.is_none());
        assert!(!config.has_client_cert());
    }

    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/cert.pem", "/path/to/key.pem")
            .server_name("example.com");

        // Builders must not touch verification.
        assert!(config.verify_server);
        assert_eq!(config.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
        assert_eq!(
            config.client_cert_path,
            Some(PathBuf::from("/path/to/cert.pem"))
        );
        assert_eq!(
            config.client_key_path,
            Some(PathBuf::from("/path/to/key.pem"))
        );
        assert!(config.has_client_cert());
        assert_eq!(config.server_name, Some("example.com".to_string()));
    }

    #[test]
    fn test_danger_accept_invalid_certs_disables_verification() {
        let config = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .danger_accept_invalid_certs();
        assert!(!config.verify_server);
        // Other settings survive the call.
        assert_eq!(config.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
    }

    #[test]
    fn test_has_client_cert_requires_both_paths() {
        let cert_only = TlsConfig {
            client_cert_path: Some(PathBuf::from("/path/to/cert.pem")),
            ..TlsConfig::default()
        };
        assert!(!cert_only.has_client_cert());

        let key_only = TlsConfig {
            client_key_path: Some(PathBuf::from("/path/to/key.pem")),
            ..TlsConfig::default()
        };
        assert!(!key_only.has_client_cert());
    }

    #[test]
    fn test_tls_mode() {
        assert_eq!(TlsMode::default(), TlsMode::Disable);

        // (mode, is_enabled, is_required, verify_server, verify_hostname)
        let table = [
            (TlsMode::Disable, false, false, false, false),
            (TlsMode::Prefer, true, false, false, false),
            (TlsMode::Require, true, true, false, false),
            (TlsMode::VerifyCA, true, true, true, false),
            (TlsMode::VerifyFull, true, true, true, true),
        ];
        for (mode, enabled, required, verify_server, verify_hostname) in table {
            assert_eq!(mode.is_enabled(), enabled, "{mode:?}");
            assert_eq!(mode.is_required(), required, "{mode:?}");
            assert_eq!(mode.verify_server(), verify_server, "{mode:?}");
            assert_eq!(mode.verify_hostname(), verify_hostname, "{mode:?}");
        }
    }
}
331}