1use std::pin::Pin;
53use std::task::{Context, Poll};
54use std::{fs::File, io::BufReader, sync::Arc};
55
56use rustls::{ClientConfig, RootCertStore};
57use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
58use tokio::net::TcpStream;
59use tokio_rustls::{client::TlsStream, TlsConnector};
60
61use crate::config::{SslMode, TlsConfig};
62use crate::error::{PgWireError, Result};
63use crate::protocol::framing::write_ssl_request;
64
65#[derive(Debug)]
70pub enum MaybeTlsStream {
71 Plain(TcpStream),
73 Tls(Box<TlsStream<TcpStream>>),
75}
76
77impl MaybeTlsStream {
78 #[inline]
80 pub fn is_tls(&self) -> bool {
81 matches!(self, MaybeTlsStream::Tls(_))
82 }
83
84 #[inline]
86 pub fn is_plain(&self) -> bool {
87 matches!(self, MaybeTlsStream::Plain(_))
88 }
89
90 pub fn get_ref(&self) -> &TcpStream {
94 match self {
95 MaybeTlsStream::Plain(s) => s,
96 MaybeTlsStream::Tls(s) => s.get_ref().0,
97 }
98 }
99}
100
101impl AsyncRead for MaybeTlsStream {
102 fn poll_read(
103 self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 buf: &mut ReadBuf<'_>,
106 ) -> Poll<std::io::Result<()>> {
107 match self.get_mut() {
108 MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
109 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
110 }
111 }
112}
113
114impl AsyncWrite for MaybeTlsStream {
115 fn poll_write(
116 self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 buf: &[u8],
119 ) -> Poll<std::io::Result<usize>> {
120 match self.get_mut() {
121 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
122 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
123 }
124 }
125
126 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
127 match self.get_mut() {
128 MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
129 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
130 }
131 }
132
133 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
134 match self.get_mut() {
135 MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
136 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
137 }
138 }
139}
140
141pub async fn maybe_upgrade_to_tls(
162 mut tcp: TcpStream,
163 tls: &TlsConfig,
164 host: &str,
165) -> Result<MaybeTlsStream> {
166 match tls.mode {
167 SslMode::Disable => return Ok(MaybeTlsStream::Plain(tcp)),
168 SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {}
169 }
170
171 let _ = rustls::crypto::ring::default_provider().install_default();
174
175 write_ssl_request(&mut tcp).await?;
177
178 let mut resp = [0u8; 1];
179 use tokio::io::AsyncReadExt;
180 tcp.read_exact(&mut resp).await?;
181
182 if resp[0] != b'S' {
183 return match tls.mode {
185 SslMode::Prefer => Ok(MaybeTlsStream::Plain(tcp)),
186 _ => Err(PgWireError::Tls(
187 "server does not support TLS (SSLRequest rejected)".into(),
188 )),
189 };
190 }
191
192 let verify_chain = matches!(tls.mode, SslMode::VerifyCa | SslMode::VerifyFull);
194 let verify_hostname = matches!(tls.mode, SslMode::VerifyFull);
195
196 let cfg = build_rustls_config(tls, verify_chain, verify_hostname, host)?;
197 let connector = TlsConnector::from(Arc::new(cfg));
198
199 let sni = tls.sni_hostname.as_deref().unwrap_or(host);
201 let server_name = rustls::pki_types::ServerName::try_from(sni.to_string())
202 .map_err(|_| PgWireError::Tls(format!("invalid SNI hostname '{sni}'")))?;
203
204 let tls_stream = connector
205 .connect(server_name, tcp)
206 .await
207 .map_err(|e| PgWireError::Tls(format!("TLS handshake failed: {e}")))?;
208
209 Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
210}
211
212fn build_rustls_config(
214 tls: &TlsConfig,
215 verify_chain: bool,
216 verify_hostname: bool,
217 host: &str,
218) -> Result<ClientConfig> {
219 let has_cert = tls.client_cert_pem_path.is_some();
221 let has_key = tls.client_key_pem_path.is_some();
222 if has_cert ^ has_key {
223 return Err(PgWireError::Tls(format!(
224 "TLS config error: mTLS requires both client_cert_pem_path and client_key_pem_path \
225 (got cert={has_cert} key={has_key})"
226 )));
227 }
228
229 if verify_hostname && host.parse::<std::net::IpAddr>().is_ok() && tls.sni_hostname.is_none() {
231 return Err(PgWireError::Tls(format!(
232 "TLS config error: VerifyFull enabled but host '{host}' is an IP address. \
233 Hint: use a DNS name matching the certificate, or set tls.sni_hostname, \
234 or use VerifyCa mode."
235 )));
236 }
237
238 let roots = build_root_store(tls)?;
240 let roots_arc = Arc::new(roots.clone());
241
242 let builder = ClientConfig::builder().with_root_certificates(roots);
244
245 let mut cfg: ClientConfig = if has_cert {
247 let cert_path = tls.client_cert_pem_path.as_ref().unwrap();
248 let key_path = tls.client_key_pem_path.as_ref().unwrap();
249
250 let cert_chain = load_cert_chain(cert_path)?;
251 let key = load_private_key(key_path)?;
252
253 builder
254 .with_client_auth_cert(cert_chain, key)
255 .map_err(|e| {
256 PgWireError::Tls(format!("TLS config error: invalid client cert/key: {e}"))
257 })?
258 } else {
259 builder.with_no_client_auth()
260 };
261
262 if !verify_chain {
264 cfg.dangerous()
266 .set_certificate_verifier(Arc::new(NoVerifier));
267 return Ok(cfg);
268 }
269
270 if verify_chain && !verify_hostname {
271 let inner = rustls::client::WebPkiServerVerifier::builder(roots_arc)
273 .build()
274 .map_err(|e| PgWireError::Tls(format!("TLS config error: build verifier: {e}")))?;
275
276 cfg.dangerous()
277 .set_certificate_verifier(Arc::new(VerifyChainOnly { inner }));
278 }
279
280 Ok(cfg)
282}
283
284fn build_root_store(tls: &TlsConfig) -> Result<RootCertStore> {
286 use rustls::pki_types::CertificateDer;
287
288 let mut roots = RootCertStore::empty();
289
290 if let Some(path) = &tls.ca_pem_path {
291 let f = File::open(path).map_err(|e| {
293 PgWireError::Tls(format!(
294 "TLS config error: failed to open CA PEM '{}': {e}",
295 path.display()
296 ))
297 })?;
298 let mut rd = BufReader::new(f);
299
300 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut rd)
301 .collect::<std::result::Result<Vec<_>, _>>()
302 .map_err(|e| {
303 PgWireError::Tls(format!(
304 "TLS config error: failed to parse CA PEM '{}': {e}",
305 path.display()
306 ))
307 })?
308 .into_iter()
309 .map(|c| c.into_owned())
310 .collect();
311
312 let (added, _ignored) = roots.add_parsable_certificates(certs);
313 if added == 0 {
314 return Err(PgWireError::Tls(format!(
315 "TLS config error: no valid CA certificates found in '{}'",
316 path.display()
317 )));
318 }
319 } else {
320 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
322 }
323
324 Ok(roots)
325}
326
327fn load_cert_chain(
329 path: &std::path::Path,
330) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
331 use rustls::pki_types::CertificateDer;
332
333 let f = File::open(path).map_err(|e| {
334 PgWireError::Tls(format!(
335 "TLS config error: failed to open client certificate '{}': {e}",
336 path.display()
337 ))
338 })?;
339 let mut rd = BufReader::new(f);
340
341 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut rd)
342 .collect::<std::result::Result<Vec<_>, _>>()
343 .map_err(|e| {
344 PgWireError::Tls(format!(
345 "TLS config error: failed to parse client certificate '{}': {e}",
346 path.display()
347 ))
348 })?
349 .into_iter()
350 .map(|c| c.into_owned())
351 .collect();
352
353 if certs.is_empty() {
354 return Err(PgWireError::Tls(format!(
355 "TLS config error: no certificates found in '{}'",
356 path.display()
357 )));
358 }
359
360 Ok(certs)
361}
362
363fn load_private_key(path: &std::path::Path) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
367 if let Some(key) = try_load_pkcs8_key(path)? {
369 return Ok(key);
370 }
371
372 if let Some(key) = try_load_rsa_key(path)? {
374 return Ok(key);
375 }
376
377 if let Some(key) = try_load_ec_key(path)? {
379 return Ok(key);
380 }
381
382 Err(PgWireError::Tls(format!(
383 "TLS config error: no private key found in '{}'. \
384 Supported formats: PKCS#8, PKCS#1 (RSA), SEC1 (EC)",
385 path.display()
386 )))
387}
388
389fn try_load_pkcs8_key(
390 path: &std::path::Path,
391) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
392 use rustls::pki_types::PrivateKeyDer;
393
394 let f = File::open(path).map_err(|e| {
395 PgWireError::Tls(format!(
396 "TLS config error: failed to open private key '{}': {e}",
397 path.display()
398 ))
399 })?;
400 let mut rd = BufReader::new(f);
401
402 let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::pkcs8_private_keys(&mut rd)
403 .filter_map(|r| r.ok())
404 .map(PrivateKeyDer::from)
405 .collect();
406
407 match keys.len() {
408 0 => Ok(None),
409 1 => Ok(Some(keys.into_iter().next().unwrap())),
410 n => Err(PgWireError::Tls(format!(
411 "TLS config error: found {n} PKCS#8 keys in '{}', expected 1",
412 path.display()
413 ))),
414 }
415}
416
417fn try_load_rsa_key(
418 path: &std::path::Path,
419) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
420 use rustls::pki_types::PrivateKeyDer;
421
422 let f = File::open(path).map_err(|e| {
423 PgWireError::Tls(format!(
424 "TLS config error: failed to open private key '{}': {e}",
425 path.display()
426 ))
427 })?;
428 let mut rd = BufReader::new(f);
429
430 let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::rsa_private_keys(&mut rd)
431 .filter_map(|r| r.ok())
432 .map(PrivateKeyDer::from)
433 .collect();
434
435 match keys.len() {
436 0 => Ok(None),
437 1 => Ok(Some(keys.into_iter().next().unwrap())),
438 n => Err(PgWireError::Tls(format!(
439 "TLS config error: found {n} RSA keys in '{}', expected 1",
440 path.display()
441 ))),
442 }
443}
444
445fn try_load_ec_key(
446 path: &std::path::Path,
447) -> Result<Option<rustls::pki_types::PrivateKeyDer<'static>>> {
448 use rustls::pki_types::PrivateKeyDer;
449
450 let f = File::open(path).map_err(|e| {
451 PgWireError::Tls(format!(
452 "TLS config error: failed to open private key '{}': {e}",
453 path.display()
454 ))
455 })?;
456 let mut rd = BufReader::new(f);
457
458 let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::ec_private_keys(&mut rd)
459 .filter_map(|r| r.ok())
460 .map(PrivateKeyDer::from)
461 .collect();
462
463 match keys.len() {
464 0 => Ok(None),
465 1 => Ok(Some(keys.into_iter().next().unwrap())),
466 n => Err(PgWireError::Tls(format!(
467 "TLS config error: found {n} EC keys in '{}', expected 1",
468 path.display()
469 ))),
470 }
471}
472
473#[derive(Debug)]
482struct NoVerifier;
483
484impl rustls::client::danger::ServerCertVerifier for NoVerifier {
485 fn verify_server_cert(
486 &self,
487 _end_entity: &rustls::pki_types::CertificateDer<'_>,
488 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
489 _server_name: &rustls::pki_types::ServerName<'_>,
490 _ocsp: &[u8],
491 _now: rustls::pki_types::UnixTime,
492 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
493 Ok(rustls::client::danger::ServerCertVerified::assertion())
494 }
495
496 fn verify_tls12_signature(
497 &self,
498 _message: &[u8],
499 _cert: &rustls::pki_types::CertificateDer<'_>,
500 _dss: &rustls::DigitallySignedStruct,
501 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
502 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
503 }
504
505 fn verify_tls13_signature(
506 &self,
507 _message: &[u8],
508 _cert: &rustls::pki_types::CertificateDer<'_>,
509 _dss: &rustls::DigitallySignedStruct,
510 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
511 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
512 }
513
514 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
515 vec![
517 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
518 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
519 rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
520 rustls::SignatureScheme::ED25519,
521 rustls::SignatureScheme::RSA_PKCS1_SHA256,
522 rustls::SignatureScheme::RSA_PKCS1_SHA384,
523 rustls::SignatureScheme::RSA_PKCS1_SHA512,
524 rustls::SignatureScheme::RSA_PSS_SHA256,
525 rustls::SignatureScheme::RSA_PSS_SHA384,
526 rustls::SignatureScheme::RSA_PSS_SHA512,
527 ]
528 }
529}
530
531#[derive(Debug)]
535struct VerifyChainOnly {
536 inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
537}
538
539impl rustls::client::danger::ServerCertVerifier for VerifyChainOnly {
540 fn verify_server_cert(
541 &self,
542 end_entity: &rustls::pki_types::CertificateDer<'_>,
543 intermediates: &[rustls::pki_types::CertificateDer<'_>],
544 server_name: &rustls::pki_types::ServerName<'_>,
545 ocsp: &[u8],
546 now: rustls::pki_types::UnixTime,
547 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
548 match self
549 .inner
550 .verify_server_cert(end_entity, intermediates, server_name, ocsp, now)
551 {
552 Ok(ok) => Ok(ok),
553 Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) => {
555 Ok(rustls::client::danger::ServerCertVerified::assertion())
556 }
557 Err(e) => Err(e),
558 }
559 }
560
561 fn verify_tls12_signature(
562 &self,
563 message: &[u8],
564 cert: &rustls::pki_types::CertificateDer<'_>,
565 dss: &rustls::DigitallySignedStruct,
566 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
567 self.inner.verify_tls12_signature(message, cert, dss)
568 }
569
570 fn verify_tls13_signature(
571 &self,
572 message: &[u8],
573 cert: &rustls::pki_types::CertificateDer<'_>,
574 dss: &rustls::DigitallySignedStruct,
575 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
576 self.inner.verify_tls13_signature(message, cert, dss)
577 }
578
579 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
580 self.inner.supported_verify_schemes()
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Supplying only one half of the client cert/key pair must be rejected.
    #[test]
    fn mtls_requires_both_cert_and_key() {
        let cert_only = TlsConfig {
            client_cert_pem_path: Some("/path/to/cert.pem".into()),
            client_key_pem_path: None,
            ..Default::default()
        };
        let key_only = TlsConfig {
            client_cert_pem_path: None,
            client_key_pem_path: Some("/path/to/key.pem".into()),
            ..Default::default()
        };

        for incomplete in [cert_only, key_only] {
            let message = build_rustls_config(&incomplete, false, false, "localhost")
                .unwrap_err()
                .to_string();
            assert!(message.contains("mTLS requires both"));
        }
    }

    /// Hostname verification against an IP literal needs an explicit SNI override.
    #[test]
    fn verify_full_rejects_ip_without_sni_override() {
        let config = TlsConfig {
            mode: SslMode::VerifyFull,
            ..Default::default()
        };

        let message = build_rustls_config(&config, true, true, "192.168.1.1")
            .unwrap_err()
            .to_string();
        assert!(message.contains("IP address"));
    }

    /// A CA path that does not exist is reported together with the file name.
    #[test]
    fn missing_ca_file_gives_clear_error() {
        let config = TlsConfig {
            ca_pem_path: Some("/nonexistent/ca.pem".into()),
            ..Default::default()
        };

        let message = build_root_store(&config).unwrap_err().to_string();
        assert!(message.contains("failed to open"));
        assert!(message.contains("ca.pem"));
    }

    /// An empty CA file contributes no certificates and must be rejected.
    #[test]
    fn empty_ca_file_gives_clear_error() {
        let empty = NamedTempFile::new().unwrap();
        let config = TlsConfig {
            ca_pem_path: Some(empty.path().to_path_buf()),
            ..Default::default()
        };

        let message = build_root_store(&config).unwrap_err().to_string();
        assert!(message.contains("no valid CA certificates"));
    }

    /// An empty key file yields the "no private key" error.
    #[test]
    fn empty_key_file_gives_clear_error() {
        let empty = NamedTempFile::new().unwrap();

        let message = load_private_key(empty.path()).unwrap_err().to_string();
        assert!(message.contains("no private key"));
    }

    /// Content that is not PEM at all is rejected by both the key and the
    /// certificate loaders.
    #[test]
    fn invalid_pem_gives_clear_error() {
        let mut garbage = NamedTempFile::new().unwrap();
        garbage
            .write_all(b"this is not a valid PEM file")
            .unwrap();

        let path = garbage.path();
        assert!(load_private_key(path).is_err());
        assert!(load_cert_chain(path).is_err());
    }
}