1use std::io;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use futures::{SinkExt, StreamExt};
7#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
8use rustls_pki_types::CertificateDer;
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::net::TcpStream;
11#[cfg(unix)]
12use tokio::net::UnixStream;
13use tokio::time::{Duration, sleep};
14#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
15use tokio_rustls::server::TlsStream;
16use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
17
18use crate::api::auth::StartupHandler;
19use crate::api::cancel::CancelHandler;
20use crate::api::copy::CopyHandler;
21use crate::api::query::{ExtendedQueryHandler, SimpleQueryHandler, send_ready_for_query};
22use crate::api::{
23 ClientInfo, ClientPortalStore, DefaultClient, ErrorHandler, PgWireConnectionState,
24 PgWireServerHandlers,
25};
26use crate::error::{ErrorInfo, PgWireError, PgWireResult};
27use crate::messages::response::{GssEncResponse, ReadyForQuery, SslResponse, TransactionStatus};
28use crate::messages::startup::SecretKey;
29use crate::messages::{
30 DecodeContext, PgWireBackendMessage, PgWireFrontendMessage, ProtocolVersion,
31 SslNegotiationMetaMessage,
32};
33
/// Deadline, in milliseconds, for a client to finish connection setup.
///
/// On TCP this covers SSL negotiation plus startup/authentication; on Unix
/// sockets it covers startup/authentication. When it expires the connection
/// handler returns and the socket is dropped.
const STARTUP_TIMEOUT_MILLIS: u64 = 60_000;
36
/// `tokio_util` codec for the server side of the PostgreSQL wire protocol.
///
/// Besides framing messages, the codec owns the per-connection client state
/// (`client_info`), which the `Framed` wrapper exposes through the
/// `ClientInfo` and `ClientPortalStore` impls. In `process_message`, `S` is
/// instantiated with the extended-query handler's statement type.
#[non_exhaustive]
#[derive(Debug, new)]
pub struct PgWireMessageServerCodec<S> {
    /// Per-connection state (peer address, TLS flag, protocol state, metadata, ...).
    pub client_info: DefaultClient<S>,
    /// Decoder flags; refreshed from `client_info` before every decode call.
    #[new(default)]
    decode_context: DecodeContext,
}
44
45impl<S> Decoder for PgWireMessageServerCodec<S> {
46 type Item = PgWireFrontendMessage;
47 type Error = PgWireError;
48
49 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
50 self.decode_context.protocol_version = self.client_info.protocol_version;
51
52 match self.client_info.state() {
53 PgWireConnectionState::AwaitingSslRequest => {
54 self.decode_context.awaiting_frontend_ssl = true;
55 self.decode_context.awaiting_frontend_startup = true;
56 }
57
58 PgWireConnectionState::AwaitingStartup => {
59 self.decode_context.awaiting_frontend_ssl = false;
60 self.decode_context.awaiting_frontend_startup = true;
61 }
62
63 _ => {
64 self.decode_context.awaiting_frontend_startup = false;
65 self.decode_context.awaiting_frontend_ssl = false;
66 }
67 }
68
69 PgWireFrontendMessage::decode(src, &self.decode_context)
70 }
71}
72
73impl<S> Encoder<PgWireBackendMessage> for PgWireMessageServerCodec<S> {
74 type Error = io::Error;
75
76 fn encode(
77 &mut self,
78 item: PgWireBackendMessage,
79 dst: &mut bytes::BytesMut,
80 ) -> Result<(), Self::Error> {
81 item.encode(dst).map_err(Into::into)
82 }
83}
84
85impl<T: 'static, S> ClientInfo for Framed<T, PgWireMessageServerCodec<S>> {
86 fn socket_addr(&self) -> std::net::SocketAddr {
87 self.codec().client_info.socket_addr
88 }
89
90 fn is_secure(&self) -> bool {
91 self.codec().client_info.is_secure
92 }
93
94 fn pid_and_secret_key(&self) -> (i32, SecretKey) {
95 self.codec().client_info.pid_and_secret_key()
96 }
97
98 fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey) {
99 self.codec_mut()
100 .client_info
101 .set_pid_and_secret_key(pid, secret_key);
102 }
103
104 fn protocol_version(&self) -> ProtocolVersion {
105 self.codec().client_info.protocol_version()
106 }
107
108 fn set_protocol_version(&mut self, version: ProtocolVersion) {
109 self.codec_mut().client_info.set_protocol_version(version);
110 }
111
112 fn state(&self) -> PgWireConnectionState {
113 self.codec().client_info.state
114 }
115
116 fn set_state(&mut self, new_state: PgWireConnectionState) {
117 self.codec_mut().client_info.set_state(new_state);
118 }
119
120 fn metadata(&self) -> &std::collections::HashMap<String, String> {
121 self.codec().client_info.metadata()
122 }
123
124 fn metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
125 self.codec_mut().client_info.metadata_mut()
126 }
127
128 fn transaction_status(&self) -> TransactionStatus {
129 self.codec().client_info.transaction_status()
130 }
131
132 fn set_transaction_status(&mut self, new_status: TransactionStatus) {
133 self.codec_mut()
134 .client_info
135 .set_transaction_status(new_status);
136 }
137
138 #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
139 fn sni_server_name(&self) -> Option<&str> {
140 self.codec().client_info.sni_server_name()
141 }
142
143 #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
144 fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
145 if !self.is_secure() {
146 None
147 } else {
148 let socket =
149 <dyn std::any::Any>::downcast_ref::<TlsStream<TcpStream>>(self.get_ref()).unwrap();
150 let (_, tls_session) = socket.get_ref();
151 tls_session.peer_certificates()
152 }
153 }
154}
155
156impl<T, S> ClientPortalStore for Framed<T, PgWireMessageServerCodec<S>> {
157 type PortalStore = <DefaultClient<S> as ClientPortalStore>::PortalStore;
158
159 fn portal_store(&self) -> &Self::PortalStore {
160 self.codec().client_info.portal_store()
161 }
162}
163
164pub async fn process_message<S, A, Q, EQ, C, CR>(
165 message: PgWireFrontendMessage,
166 socket: &mut Framed<S, PgWireMessageServerCodec<EQ::Statement>>,
167 authenticator: Arc<A>,
168 query_handler: Arc<Q>,
169 extended_query_handler: Arc<EQ>,
170 copy_handler: Arc<C>,
171 cancel_handler: Arc<CR>,
172) -> PgWireResult<()>
173where
174 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
175 A: StartupHandler,
176 Q: SimpleQueryHandler,
177 EQ: ExtendedQueryHandler,
178 C: CopyHandler,
179 CR: CancelHandler,
180{
181 if let PgWireFrontendMessage::CancelRequest(cancel) = message {
183 cancel_handler.on_cancel_request(cancel).await;
184 socket.close().await?;
185 return Ok(());
186 }
187
188 match socket.state() {
189 PgWireConnectionState::AwaitingStartup
190 | PgWireConnectionState::AuthenticationInProgress => {
191 authenticator.on_startup(socket, message).await?;
192 }
193 PgWireConnectionState::AwaitingSync => {
199 if let PgWireFrontendMessage::Sync(sync) = message {
200 extended_query_handler.on_sync(socket, sync).await?;
201 socket.set_state(PgWireConnectionState::ReadyForQuery);
203 }
204 }
205 PgWireConnectionState::CopyInProgress(is_extended_query) => {
206 match message {
208 PgWireFrontendMessage::CopyData(copy_data) => {
209 copy_handler.on_copy_data(socket, copy_data).await?;
210 }
211 PgWireFrontendMessage::CopyDone(copy_done) => {
212 let result = copy_handler.on_copy_done(socket, copy_done).await;
213 if !is_extended_query {
214 socket.set_state(PgWireConnectionState::ReadyForQuery);
219 }
220 match result {
221 Ok(_) => {
222 if !is_extended_query {
223 send_ready_for_query(socket, TransactionStatus::Idle).await?
227 } else {
228 }
233 }
234 err => return err,
235 }
236 }
237 PgWireFrontendMessage::CopyFail(copy_fail) => {
238 let error = copy_handler.on_copy_fail(socket, copy_fail).await;
239 if !is_extended_query {
240 socket.set_state(PgWireConnectionState::ReadyForQuery);
245 }
246 return Err(error);
247 }
248 _ => {}
249 }
250 }
251 _ => {
252 match message {
254 PgWireFrontendMessage::Query(query) => {
255 query_handler.on_query(socket, query).await?;
256 }
257 PgWireFrontendMessage::Parse(parse) => {
258 extended_query_handler.on_parse(socket, parse).await?;
259 }
260 PgWireFrontendMessage::Bind(bind) => {
261 extended_query_handler.on_bind(socket, bind).await?;
262 }
263 PgWireFrontendMessage::Execute(execute) => {
264 extended_query_handler.on_execute(socket, execute).await?;
265 }
266 PgWireFrontendMessage::Describe(describe) => {
267 extended_query_handler.on_describe(socket, describe).await?;
268 }
269 PgWireFrontendMessage::Flush(flush) => {
270 extended_query_handler.on_flush(socket, flush).await?;
271 }
272 PgWireFrontendMessage::Sync(sync) => {
273 extended_query_handler.on_sync(socket, sync).await?;
274 }
275 PgWireFrontendMessage::Close(close) => {
276 extended_query_handler.on_close(socket, close).await?;
277 }
278 _ => {}
279 }
280 }
281 }
282 Ok(())
283}
284
285pub async fn process_error<S, ST>(
286 socket: &mut Framed<S, PgWireMessageServerCodec<ST>>,
287 error: PgWireError,
288 wait_for_sync: bool,
289) -> Result<(), io::Error>
290where
291 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
292{
293 let error_info: ErrorInfo = error.into();
294 let is_fatal = error_info.is_fatal();
295 socket
296 .send(PgWireBackendMessage::ErrorResponse(error_info.into()))
297 .await?;
298
299 let transaction_status = socket.transaction_status().to_error_state();
300 socket.set_transaction_status(transaction_status);
301
302 if wait_for_sync {
303 socket.set_state(PgWireConnectionState::AwaitingSync);
304 } else {
305 socket.set_state(PgWireConnectionState::ReadyForQuery);
306 socket
307 .feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
308 transaction_status,
309 )))
310 .await?;
311 }
312 socket.flush().await?;
313
314 if is_fatal {
315 return socket.close().await;
316 }
317
318 Ok(())
319}
320
/// Outcome of the pre-startup encryption negotiation on a TCP connection.
#[derive(Debug, PartialEq, Eq)]
enum SslNegotiationType {
    /// The client sent a PostgreSQL SSLRequest and the server accepted it.
    Postgres,
    /// The client opened with a raw TLS handshake instead of an SSLRequest.
    Direct,
    /// No TLS: the client did not ask for it, or the request was refused.
    None,
}
327
328async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result<bool, io::Error> {
329 let mut buf = [0u8; 1];
330 let n = tcp_socket.peek(&mut buf).await?;
331
332 Ok(n > 0 && buf[0] == 0x16)
333}
334
335async fn peek_for_sslrequest<ST>(
336 socket: &mut Framed<TcpStream, PgWireMessageServerCodec<ST>>,
337 ssl_supported: bool,
338) -> Result<SslNegotiationType, io::Error> {
339 if check_ssl_direct_negotiation(socket.get_ref()).await? {
340 Ok(SslNegotiationType::Direct)
341 } else {
342 let mut ssl_done = false;
343 let mut gss_done = false;
344
345 loop {
346 match socket.next().await {
347 Some(Ok(PgWireFrontendMessage::SslNegotiation(
349 SslNegotiationMetaMessage::PostgresSsl(_),
350 ))) => {
351 if ssl_supported {
353 socket
354 .send(PgWireBackendMessage::SslResponse(SslResponse::Accept))
355 .await?;
356 return Ok(SslNegotiationType::Postgres);
357 } else {
358 socket
359 .send(PgWireBackendMessage::SslResponse(SslResponse::Refuse))
360 .await?;
361 ssl_done = true;
362
363 if gss_done {
364 return Ok(SslNegotiationType::None);
365 } else {
366 continue;
368 }
369 }
370 }
371
372 Some(Ok(PgWireFrontendMessage::SslNegotiation(
374 SslNegotiationMetaMessage::PostgresGss(_),
375 ))) => {
376 socket
377 .send(PgWireBackendMessage::GssEncResponse(GssEncResponse::Refuse))
378 .await?;
379 gss_done = true;
380
381 if ssl_done {
382 return Ok(SslNegotiationType::None);
383 } else {
384 continue;
386 }
387 }
388
389 _ => {
391 return Ok(SslNegotiationType::None);
392 }
393 }
394 }
395 }
396}
397
/// Verifies that a direct-TLS connection negotiated the PostgreSQL ALPN
/// protocol.
///
/// # Errors
/// `InvalidData` when no ALPN protocol was negotiated or it differs from
/// `POSTGRESQL_ALPN_NAME`.
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), io::Error> {
    let (_, session) = tls_socket.get_ref();
    let negotiated_postgres = session
        .alpn_protocol()
        .is_some_and(|proto| proto == super::POSTGRESQL_ALPN_NAME);

    if negotiated_postgres {
        Ok(())
    } else {
        Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "received direct SSL connection request without ALPN protocol negotiation extension",
        ))
    }
}
418
/// Transport underlying a server connection.
///
/// The variants are plain TCP, a Unix domain socket (Unix targets only),
/// and TLS over TCP (only when a TLS backend feature is enabled). This type
/// implements `AsyncRead`/`AsyncWrite` by delegating to the active variant.
#[non_exhaustive]
pub enum MaybeTls {
    Plain(TcpStream),
    #[cfg(unix)]
    Unix(UnixStream),
    // Boxed TLS stream.
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    Tls(Box<TlsStream<TcpStream>>),
}
427
// Forwards a `poll_*` call on a `Pin<&mut MaybeTls>` to the inner stream of
// whichever variant is active, e.g. `maybe_tls!(self, poll_read(cx, buf))`.
// Relies on `MaybeTls: Unpin` (for `Pin::get_mut`) and on each inner stream
// being `Unpin`, so it can be re-pinned with `Pin::new`.
macro_rules! maybe_tls {
    ($self:ident, $poll_x:ident($($args:expr),*)) => {
        match $self.get_mut() {
            MaybeTls::Plain(io) => Pin::new(io).$poll_x($($args),*),
            #[cfg(unix)]
            MaybeTls::Unix(io) => Pin::new(io).$poll_x($($args),*),
            #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
            MaybeTls::Tls(io) => Pin::new(io).$poll_x($($args),*),
        }
    };
}
439
440impl AsyncRead for MaybeTls {
441 fn poll_read(
442 self: Pin<&mut Self>,
443 cx: &mut Context<'_>,
444 buf: &mut tokio::io::ReadBuf<'_>,
445 ) -> Poll<io::Result<()>> {
446 maybe_tls!(self, poll_read(cx, buf))
447 }
448}
449
450impl AsyncWrite for MaybeTls {
451 fn poll_write(
452 self: Pin<&mut Self>,
453 cx: &mut Context<'_>,
454 buf: &[u8],
455 ) -> Poll<io::Result<usize>> {
456 maybe_tls!(self, poll_write(cx, buf))
457 }
458
459 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
460 maybe_tls!(self, poll_flush(cx))
461 }
462
463 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
464 maybe_tls!(self, poll_shutdown(cx))
465 }
466}
467
468pub async fn negotiate_tls<S>(
473 tcp_socket: TcpStream,
474 tls_acceptor: Option<crate::tokio::TlsAcceptor>,
475) -> io::Result<Option<Framed<MaybeTls, PgWireMessageServerCodec<S>>>> {
476 let addr = tcp_socket.peer_addr()?;
477 tcp_socket.set_nodelay(true)?;
478
479 let client_info = DefaultClient::new(addr, false);
480 let mut tcp_socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info));
481
482 let ssl = peek_for_sslrequest(&mut tcp_socket, tls_acceptor.is_some()).await?;
485
486 let old_parts = tcp_socket.into_parts();
487
488 if ssl == SslNegotiationType::None {
489 let mut parts = FramedParts::new(MaybeTls::Plain(old_parts.io), old_parts.codec);
490 parts.read_buf = old_parts.read_buf;
491 parts.write_buf = old_parts.write_buf;
492 let mut socket = Framed::from_parts(parts);
493
494 socket.set_state(PgWireConnectionState::AwaitingStartup);
495
496 return Ok(Some(socket));
497 }
498 #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
499 if let Some(tls_acceptor) = tls_acceptor {
500 let mut client_info = DefaultClient::new(addr, true);
502
503 let ssl_socket = Box::new(tls_acceptor.accept(old_parts.io).await?);
504
505 if ssl == SslNegotiationType::Direct {
507 check_alpn_for_direct_ssl(&ssl_socket)?;
508 }
509
510 let sni = {
512 let (_, conn) = ssl_socket.get_ref();
513 conn.server_name().map(|s| s.to_string())
514 };
515 if let Some(s) = sni {
516 client_info.sni_server_name = Some(s);
517 }
518
519 let mut parts = FramedParts::new(
520 MaybeTls::Tls(ssl_socket),
521 PgWireMessageServerCodec::new(client_info),
522 );
523 parts.read_buf = old_parts.read_buf;
524 parts.write_buf = old_parts.write_buf;
525 let mut socket = Framed::from_parts(parts);
526
527 socket.set_state(PgWireConnectionState::AwaitingStartup);
528
529 return Ok(Some(socket));
530 }
531 Ok(None)
532}
533
// Drives the per-connection message loop.
//
// Arguments:
// - `$socket`: the `Framed` connection; it is borrowed mutably for the whole
//   loop.
// - `$startup_timeout`: a pinned timer future (callers use `tokio::pin!`).
//   It is raced only while the connection is in `AwaitingStartup` or
//   `AuthenticationInProgress`, which bounds how long startup and
//   authentication may take.
// - `$handlers`: the server handlers (a `PgWireServerHandlers` at the call
//   sites) supplying every protocol handler.
//
// The loop ends when:
// - the timer fires during startup,
// - the stream ends, or
// - a message fails to decode.
//
// Errors from `process_message` go to the error handler and then to
// `process_error`. An I/O error from `process_error` is propagated with `?`,
// so the expansion must sit in a function returning `Result<_, io::Error>`.
macro_rules! process_socket_messages {
    ($socket:expr, $startup_timeout:expr, $handlers:expr) => {{
        let startup_handler = $handlers.startup_handler();
        let simple_query_handler = $handlers.simple_query_handler();
        let extended_query_handler = $handlers.extended_query_handler();
        let copy_handler = $handlers.copy_handler();
        let cancel_handler = $handlers.cancel_handler();
        let error_handler = $handlers.error_handler();

        let socket = &mut $socket;
        loop {
            // Race the startup deadline only until startup/auth is finished.
            let msg = if matches!(
                socket.state(),
                PgWireConnectionState::AwaitingStartup
                    | PgWireConnectionState::AuthenticationInProgress
            ) {
                tokio::select! {
                    _ = &mut $startup_timeout => None,
                    msg = socket.next() => msg,
                }
            } else {
                socket.next().await
            };

            // `None` (timeout / end of stream) and decode errors end the loop.
            if let Some(Ok(msg)) = msg {
                // Captured before dispatch because `process_message` may change
                // the state. It decides whether `process_error` waits for Sync.
                let is_extended_query = match socket.state() {
                    PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
                    _ => msg.is_extended_query(),
                };
                if let Err(mut e) = process_message(
                    msg,
                    socket,
                    startup_handler.clone(),
                    simple_query_handler.clone(),
                    extended_query_handler.clone(),
                    copy_handler.clone(),
                    cancel_handler.clone(),
                )
                .await
                {
                    // The error handler may inspect or rewrite the error before
                    // it is reported to the client.
                    error_handler.on_error(socket, &mut e);
                    process_error(socket, e, is_extended_query).await?;
                }
            } else {
                break;
            }
        }
    }};
}
586
587#[cfg(unix)]
589pub async fn process_socket_unix<H>(unix_socket: UnixStream, handlers: H) -> Result<(), io::Error>
590where
591 H: PgWireServerHandlers,
592{
593 let startup_timeout = sleep(Duration::from_millis(STARTUP_TIMEOUT_MILLIS));
594 tokio::pin!(startup_timeout);
595
596 let addr = "127.0.0.1:0".parse().unwrap();
599
600 let client_info = DefaultClient::new(addr, false);
601 let mut socket = Framed::new(
602 MaybeTls::Unix(unix_socket),
603 PgWireMessageServerCodec::new(client_info),
604 );
605
606 socket.set_state(PgWireConnectionState::AwaitingStartup);
607
608 process_socket_messages!(socket, startup_timeout, handlers);
609 Ok(())
610}
611
612pub async fn process_socket<H>(
613 tcp_socket: TcpStream,
614 tls_acceptor: Option<crate::tokio::TlsAcceptor>,
615 handlers: H,
616) -> Result<(), io::Error>
617where
618 H: PgWireServerHandlers,
619{
620 let startup_timeout = sleep(Duration::from_millis(STARTUP_TIMEOUT_MILLIS));
623 tokio::pin!(startup_timeout);
624
625 let socket = tokio::select! {
628 _ = &mut startup_timeout => {
629 return Ok(())
630 },
631 socket = negotiate_tls(tcp_socket, tls_acceptor) => {
632 socket?
633 }
634 };
635 let Some(mut socket) = socket else {
636 return Ok(());
639 };
640
641 process_socket_messages!(socket, startup_timeout, handlers);
642 Ok(())
643}
644
#[cfg(all(test, any(feature = "_ring", feature = "_aws-lc-rs")))]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::{BufReader, Error as IOError};
    use std::sync::Arc;
    use tokio::sync::oneshot;
    use tokio_rustls::TlsAcceptor;
    use tokio_rustls::TlsConnector;
    use tokio_rustls::rustls;
    use tokio_rustls::rustls::crypto::CryptoProvider;

    /// Installs a process-wide default rustls `CryptoProvider` (aws-lc-rs
    /// preferred).
    ///
    /// rustls needs an explicit provider when more than one crypto backend is
    /// compiled in (e.g. `--all-features`). `install_default` fails if a
    /// provider is already installed, for instance by another test in the
    /// same process. That is harmless, so the error is deliberately ignored
    /// rather than unwrapped.
    fn install_crypto_provider() {
        #[cfg(feature = "_aws-lc-rs")]
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        #[cfg(all(feature = "_ring", not(feature = "_aws-lc-rs")))]
        let provider = rustls::crypto::ring::default_provider();

        let _ = CryptoProvider::install_default(provider);
    }

    /// Builds a server TLS config from the example certificate and key,
    /// advertising the PostgreSQL ALPN protocol.
    fn load_test_server_config() -> Result<rustls::ServerConfig, IOError> {
        use rustls_pemfile::{certs, pkcs8_private_keys};
        use rustls_pki_types::{CertificateDer, PrivateKeyDer};

        let certs = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?))
            .collect::<Result<Vec<CertificateDer>, _>>()?;
        let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?))
            .map(|key| key.map(PrivateKeyDer::from))
            .collect::<Result<Vec<PrivateKeyDer>, _>>()?
            .remove(0);

        let mut cfg = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
        cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
        Ok(cfg)
    }

    /// Test-only verifier that accepts any server certificate, because the
    /// fixtures use a self-signed certificate. Never use outside tests.
    #[derive(Debug)]
    struct NoCertVerifier;

    impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &rustls::pki_types::CertificateDer<'_>,
            _intermediates: &[rustls::pki_types::CertificateDer<'_>],
            _server_name: &rustls::pki_types::ServerName<'_>,
            _ocsp_response: &[u8],
            _now: rustls::pki_types::UnixTime,
        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
            Ok(rustls::client::danger::ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
            vec![
                rustls::SignatureScheme::RSA_PKCS1_SHA256,
                rustls::SignatureScheme::RSA_PKCS1_SHA384,
                rustls::SignatureScheme::RSA_PKCS1_SHA512,
                rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
                rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
                rustls::SignatureScheme::RSA_PSS_SHA256,
                rustls::SignatureScheme::RSA_PSS_SHA384,
                rustls::SignatureScheme::RSA_PSS_SHA512,
            ]
        }
    }

    /// Client TLS config that skips certificate verification and offers the
    /// PostgreSQL ALPN protocol.
    fn make_test_client_config() -> rustls::ClientConfig {
        let mut cfg = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoCertVerifier))
            .with_no_client_auth();
        cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
        cfg
    }

    fn make_test_client_connector() -> Result<TlsConnector, IOError> {
        Ok(TlsConnector::from(Arc::new(make_test_client_config())))
    }

    #[tokio::test]
    #[ignore]
    async fn server_name_metadata_is_set_from_tls_sni() {
        use std::net::SocketAddr;
        use tokio::io::duplex;

        install_crypto_provider();

        let server_cfg = load_test_server_config().expect("server config");
        let acceptor = TlsAcceptor::from(Arc::new(server_cfg));
        let connector = make_test_client_connector().expect("client connector");

        let (server_io, client_io) = duplex(64 * 1024);

        let (tx, rx) = oneshot::channel::<Option<String>>();

        tokio::spawn(async move {
            let tls = acceptor.accept(server_io).await.unwrap();

            let sni = {
                let (_, conn) = tls.get_ref();
                conn.server_name().map(|s| s.to_string())
            };
            let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
            if let Some(s) = sni {
                ci.sni_server_name = Some(s);
            }
            let framed = Framed::new(tls, PgWireMessageServerCodec::new(ci));
            let server_name = framed.sni_server_name().map(str::to_string);
            let _ = tx.send(server_name);
        });

        let server_name = rustls_pki_types::ServerName::try_from("localhost").unwrap();
        let _ = connector.connect(server_name, client_io).await.unwrap();

        let observed = rx.await.expect("server_name from server");
        assert_eq!(observed.as_deref(), Some("localhost"));
    }

    #[tokio::test]
    async fn server_name_metadata_is_set_from_tls_sni_in_memory() {
        use std::net::SocketAddr;

        install_crypto_provider();

        let server_cfg = Arc::new(load_test_server_config().expect("server config"));
        let client_cfg = Arc::new(make_test_client_config());

        let mut server_conn = rustls::ServerConnection::new(server_cfg).unwrap();
        let mut client_conn = rustls::ClientConnection::new(
            client_cfg,
            rustls_pki_types::ServerName::try_from("localhost").unwrap(),
        )
        .unwrap();

        // Pump handshake records between the two in-memory endpoints.
        let mut c2s = Vec::new();
        let mut s2c = Vec::new();

        for _ in 0..1000 {
            let _ = client_conn.write_tls(&mut c2s);
            if !c2s.is_empty() {
                let mut cur = std::io::Cursor::new(&c2s);
                let _ = server_conn.read_tls(&mut cur);
                c2s.clear();
                server_conn.process_new_packets().unwrap();
            }

            let _ = server_conn.write_tls(&mut s2c);
            if !s2c.is_empty() {
                let mut cur = std::io::Cursor::new(&s2c);
                let _ = client_conn.read_tls(&mut cur);
                s2c.clear();
                client_conn.process_new_packets().unwrap();
            }

            if !client_conn.is_handshaking() && !server_conn.is_handshaking() {
                break;
            }
        }

        let sni = server_conn.server_name().map(|s| s.to_string());
        let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
        if let Some(s) = sni {
            ci.sni_server_name = Some(s);
        }
        assert_eq!(ci.sni_server_name(), Some("localhost"));
    }
}