1use std::pin::Pin;
53use std::sync::Arc;
54use std::task::{Context, Poll};
55
56use rustls::{ClientConfig, RootCertStore};
57use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
58use tokio::net::TcpStream;
59use tokio_rustls::{client::TlsStream, TlsConnector};
60
61use crate::config::{SslMode, TlsConfig};
62use crate::error::{PgWireError, Result};
63use crate::protocol::framing::write_ssl_request;
64
65#[derive(Debug)]
70pub enum MaybeTlsStream {
71 Plain(TcpStream),
73 Tls(Box<TlsStream<TcpStream>>),
75}
76
77impl MaybeTlsStream {
78 #[inline]
80 pub fn is_tls(&self) -> bool {
81 matches!(self, MaybeTlsStream::Tls(_))
82 }
83
84 #[inline]
86 pub fn is_plain(&self) -> bool {
87 matches!(self, MaybeTlsStream::Plain(_))
88 }
89
90 pub fn get_ref(&self) -> &TcpStream {
94 match self {
95 MaybeTlsStream::Plain(s) => s,
96 MaybeTlsStream::Tls(s) => s.get_ref().0,
97 }
98 }
99}
100
101impl AsyncRead for MaybeTlsStream {
102 fn poll_read(
103 self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 buf: &mut ReadBuf<'_>,
106 ) -> Poll<std::io::Result<()>> {
107 match self.get_mut() {
108 MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
109 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
110 }
111 }
112}
113
114impl AsyncWrite for MaybeTlsStream {
115 fn poll_write(
116 self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 buf: &[u8],
119 ) -> Poll<std::io::Result<usize>> {
120 match self.get_mut() {
121 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
122 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
123 }
124 }
125
126 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
127 match self.get_mut() {
128 MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
129 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
130 }
131 }
132
133 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
134 match self.get_mut() {
135 MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
136 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
137 }
138 }
139}
140
141pub async fn maybe_upgrade_to_tls(
162 mut tcp: TcpStream,
163 tls: &TlsConfig,
164 host: &str,
165) -> Result<MaybeTlsStream> {
166 match tls.mode {
167 SslMode::Disable => return Ok(MaybeTlsStream::Plain(tcp)),
168 SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {}
169 }
170
171 let _ = rustls::crypto::ring::default_provider().install_default();
174
175 write_ssl_request(&mut tcp).await?;
177
178 let mut resp = [0u8; 1];
179 use tokio::io::AsyncReadExt;
180 tcp.read_exact(&mut resp).await?;
181
182 if resp[0] != b'S' {
183 return match tls.mode {
185 SslMode::Prefer => Ok(MaybeTlsStream::Plain(tcp)),
186 _ => Err(PgWireError::Tls(
187 "server does not support TLS (SSLRequest rejected)".into(),
188 )),
189 };
190 }
191
192 let verify_chain = matches!(tls.mode, SslMode::VerifyCa | SslMode::VerifyFull);
194 let verify_hostname = matches!(tls.mode, SslMode::VerifyFull);
195
196 let cfg = build_rustls_config(tls, verify_chain, verify_hostname, host)?;
197 let connector = TlsConnector::from(Arc::new(cfg));
198
199 let sni = tls.sni_hostname.as_deref().unwrap_or(host);
201 let server_name = rustls::pki_types::ServerName::try_from(sni.to_string())
202 .map_err(|_| PgWireError::Tls(format!("invalid SNI hostname '{sni}'")))?;
203
204 let tls_stream = connector
205 .connect(server_name, tcp)
206 .await
207 .map_err(|e| PgWireError::Tls(format!("TLS handshake failed: {e}")))?;
208
209 Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
210}
211
212fn build_rustls_config(
214 tls: &TlsConfig,
215 verify_chain: bool,
216 verify_hostname: bool,
217 host: &str,
218) -> Result<ClientConfig> {
219 let has_cert = tls.client_cert_pem_path.is_some();
221 let has_key = tls.client_key_pem_path.is_some();
222 if has_cert ^ has_key {
223 return Err(PgWireError::Tls(format!(
224 "TLS config error: mTLS requires both client_cert_pem_path and client_key_pem_path \
225 (got cert={has_cert} key={has_key})"
226 )));
227 }
228
229 if verify_hostname && host.parse::<std::net::IpAddr>().is_ok() && tls.sni_hostname.is_none() {
231 return Err(PgWireError::Tls(format!(
232 "TLS config error: VerifyFull enabled but host '{host}' is an IP address. \
233 Hint: use a DNS name matching the certificate, or set tls.sni_hostname, \
234 or use VerifyCa mode."
235 )));
236 }
237
238 let roots = build_root_store(tls)?;
240 let roots_arc = Arc::new(roots.clone());
241
242 let builder = ClientConfig::builder().with_root_certificates(roots);
244
245 let mut cfg: ClientConfig = if has_cert {
247 let cert_path = tls.client_cert_pem_path.as_ref().unwrap();
248 let key_path = tls.client_key_pem_path.as_ref().unwrap();
249
250 let cert_chain = load_cert_chain(cert_path)?;
251 let key = load_private_key(key_path)?;
252
253 builder
254 .with_client_auth_cert(cert_chain, key)
255 .map_err(|e| {
256 PgWireError::Tls(format!("TLS config error: invalid client cert/key: {e}"))
257 })?
258 } else {
259 builder.with_no_client_auth()
260 };
261
262 if !verify_chain {
264 cfg.dangerous()
266 .set_certificate_verifier(Arc::new(NoVerifier));
267 return Ok(cfg);
268 }
269
270 if verify_chain && !verify_hostname {
271 let inner = rustls::client::WebPkiServerVerifier::builder(roots_arc)
273 .build()
274 .map_err(|e| PgWireError::Tls(format!("TLS config error: build verifier: {e}")))?;
275
276 cfg.dangerous()
277 .set_certificate_verifier(Arc::new(VerifyChainOnly { inner }));
278 }
279
280 Ok(cfg)
282}
283
284fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {
286 use rustls::pki_types::CertificateDer;
287
288 let mut roots = RootCertStore::empty();
289
290 if let Some(path) = &tls.ca_pem_path {
291 use rustls::pki_types::pem::PemObject;
293
294 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(path)
295 .map_err(|e| {
296 PgWireError::Tls(format!(
297 "TLS config error: failed to open CA PEM '{}': {e}",
298 path.display()
299 ))
300 })?
301 .collect::<std::result::Result<Vec<_>, _>>()
302 .map_err(|e| {
303 PgWireError::Tls(format!(
304 "TLS config error: failed to parse CA PEM '{}': {e}",
305 path.display()
306 ))
307 })?;
308
309 let (added, _ignored) = roots.add_parsable_certificates(certs);
310 if added == 0 {
311 return Err(PgWireError::Tls(format!(
312 "TLS config error: no valid CA certificates found in '{}'",
313 path.display()
314 )));
315 }
316 } else {
317 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
319 }
320
321 Ok(roots)
322}
323
324fn load_cert_chain(
326 path: &std::path::Path,
327) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
328 use rustls::pki_types::pem::PemObject;
329 use rustls::pki_types::CertificateDer;
330
331 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(path)
332 .map_err(|e| {
333 PgWireError::Tls(format!(
334 "TLS config error: failed to open client certificate '{}': {e}",
335 path.display()
336 ))
337 })?
338 .collect::<std::result::Result<Vec<_>, _>>()
339 .map_err(|e| {
340 PgWireError::Tls(format!(
341 "TLS config error: failed to parse client certificate '{}': {e}",
342 path.display()
343 ))
344 })?;
345
346 if certs.is_empty() {
347 return Err(PgWireError::Tls(format!(
348 "TLS config error: no certificates found in '{}'",
349 path.display()
350 )));
351 }
352
353 Ok(certs)
354}
355
356fn load_private_key(path: &std::path::Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
360 use rustls::pki_types::pem::PemObject;
361 use rustls::pki_types::PrivateKeyDer;
362
363 PrivateKeyDer::from_pem_file(path).map_err(|e| {
364 PgWireError::Tls(format!(
365 "TLS config error: failed to load private key from '{}': {e}. \
366 Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)",
367 path.display()
368 ))
369 })
370}
371
372#[derive(Debug)]
381struct NoVerifier;
382
383impl rustls::client::danger::ServerCertVerifier for NoVerifier {
384 fn verify_server_cert(
385 &self,
386 _end_entity: &rustls::pki_types::CertificateDer<'_>,
387 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
388 _server_name: &rustls::pki_types::ServerName<'_>,
389 _ocsp: &[u8],
390 _now: rustls::pki_types::UnixTime,
391 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
392 Ok(rustls::client::danger::ServerCertVerified::assertion())
393 }
394
395 fn verify_tls12_signature(
396 &self,
397 _message: &[u8],
398 _cert: &rustls::pki_types::CertificateDer<'_>,
399 _dss: &rustls::DigitallySignedStruct,
400 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
401 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
402 }
403
404 fn verify_tls13_signature(
405 &self,
406 _message: &[u8],
407 _cert: &rustls::pki_types::CertificateDer<'_>,
408 _dss: &rustls::DigitallySignedStruct,
409 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
410 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
411 }
412
413 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
414 vec![
416 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
417 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
418 rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
419 rustls::SignatureScheme::ED25519,
420 rustls::SignatureScheme::RSA_PKCS1_SHA256,
421 rustls::SignatureScheme::RSA_PKCS1_SHA384,
422 rustls::SignatureScheme::RSA_PKCS1_SHA512,
423 rustls::SignatureScheme::RSA_PSS_SHA256,
424 rustls::SignatureScheme::RSA_PSS_SHA384,
425 rustls::SignatureScheme::RSA_PSS_SHA512,
426 ]
427 }
428}
429
430#[derive(Debug)]
434struct VerifyChainOnly {
435 inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
436}
437
438impl rustls::client::danger::ServerCertVerifier for VerifyChainOnly {
439 fn verify_server_cert(
440 &self,
441 end_entity: &rustls::pki_types::CertificateDer<'_>,
442 intermediates: &[rustls::pki_types::CertificateDer<'_>],
443 server_name: &rustls::pki_types::ServerName<'_>,
444 ocsp: &[u8],
445 now: rustls::pki_types::UnixTime,
446 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
447 match self
448 .inner
449 .verify_server_cert(end_entity, intermediates, server_name, ocsp, now)
450 {
451 Ok(ok) => Ok(ok),
452 Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) => {
454 Ok(rustls::client::danger::ServerCertVerified::assertion())
455 }
456 Err(e) => Err(e),
457 }
458 }
459
460 fn verify_tls12_signature(
461 &self,
462 message: &[u8],
463 cert: &rustls::pki_types::CertificateDer<'_>,
464 dss: &rustls::DigitallySignedStruct,
465 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
466 self.inner.verify_tls12_signature(message, cert, dss)
467 }
468
469 fn verify_tls13_signature(
470 &self,
471 message: &[u8],
472 cert: &rustls::pki_types::CertificateDer<'_>,
473 dss: &rustls::DigitallySignedStruct,
474 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
475 self.inner.verify_tls13_signature(message, cert, dss)
476 }
477
478 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
479 self.inner.supported_verify_schemes()
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// rustls needs a process-wide crypto provider before `ClientConfig::builder()`.
    /// This mirrors what `maybe_upgrade_to_tls` does; repeated installs just return `Err`.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Setting only one half of the client cert/key pair is a configuration error.
    #[test]
    fn mtls_requires_both_cert_and_key() {
        let tls = TlsConfig {
            client_cert_pem_path: Some("/path/to/cert.pem".into()),
            client_key_pem_path: None,
            ..Default::default()
        };
        let err = build_rustls_config(&tls, false, false, "localhost").unwrap_err();
        assert!(err.to_string().contains("mTLS requires both"));

        let tls = TlsConfig {
            client_cert_pem_path: None,
            client_key_pem_path: Some("/path/to/key.pem".into()),
            ..Default::default()
        };
        let err = build_rustls_config(&tls, false, false, "localhost").unwrap_err();
        assert!(err.to_string().contains("mTLS requires both"));
    }

    /// VerifyFull against a bare IP with no SNI override cannot match a DNS certificate.
    #[test]
    fn verify_full_rejects_ip_without_sni_override() {
        let tls = TlsConfig {
            mode: SslMode::VerifyFull,
            ..Default::default()
        };

        let err = build_rustls_config(&tls, true, true, "192.168.1.1").unwrap_err();
        assert!(err.to_string().contains("IP address"));
    }

    /// With an SNI override, an IP host is acceptable for VerifyFull.
    #[test]
    fn verify_full_accepts_ip_with_sni_override() {
        install_crypto_provider();
        let tls = TlsConfig {
            mode: SslMode::VerifyFull,
            sni_hostname: Some("db.example.com".into()),
            ..Default::default()
        };

        if let Err(e) = build_rustls_config(&tls, true, true, "192.168.1.1") {
            panic!("unexpected error: {e}");
        }
    }

    /// VerifyCa does not check the hostname, so an IP host is fine without SNI.
    #[test]
    fn verify_ca_accepts_ip_host() {
        install_crypto_provider();
        let tls = TlsConfig {
            mode: SslMode::VerifyCa,
            ..Default::default()
        };

        if let Err(e) = build_rustls_config(&tls, true, false, "192.168.1.1") {
            panic!("unexpected error: {e}");
        }
    }

    #[test]
    fn missing_ca_file_gives_clear_error() {
        let tls = TlsConfig {
            ca_pem_path: Some("/nonexistent/ca.pem".into()),
            ..Default::default()
        };

        let err = build_root_store(&tls).unwrap_err().to_string();
        assert!(err.contains("failed to open"));
        assert!(err.contains("ca.pem"));
    }

    #[test]
    fn empty_ca_file_gives_clear_error() {
        let f = NamedTempFile::new().unwrap();
        let tls = TlsConfig {
            ca_pem_path: Some(f.path().to_path_buf()),
            ..Default::default()
        };

        let err = build_root_store(&tls).unwrap_err().to_string();
        assert!(err.contains("no valid CA certificates"));
    }

    #[test]
    fn empty_key_file_gives_clear_error() {
        let f = NamedTempFile::new().unwrap();

        let err = load_private_key(f.path()).unwrap_err().to_string();
        assert!(err.contains("failed to load private key"));
    }

    /// Non-PEM input must fail with a message that names the problem and the file.
    #[test]
    fn invalid_pem_gives_clear_error() {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(b"this is not a valid PEM file").unwrap();
        let shown_path = f.path().display().to_string();

        let key_err = load_private_key(f.path()).unwrap_err().to_string();
        assert!(key_err.contains("failed to load private key"), "{key_err}");
        assert!(key_err.contains(&shown_path), "{key_err}");

        let cert_err = load_cert_chain(f.path()).unwrap_err().to_string();
        assert!(cert_err.contains("TLS config error"), "{cert_err}");
        assert!(cert_err.contains(&shown_path), "{cert_err}");
    }
}