Skip to main content

pgwire/tokio/
server.rs

1use std::io;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use futures::{SinkExt, StreamExt};
7#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
8use rustls_pki_types::CertificateDer;
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::net::TcpStream;
11#[cfg(unix)]
12use tokio::net::UnixStream;
13use tokio::time::{Duration, sleep};
14#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
15use tokio_rustls::server::TlsStream;
16use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
17
18use crate::api::auth::StartupHandler;
19use crate::api::cancel::CancelHandler;
20use crate::api::copy::CopyHandler;
21use crate::api::query::{ExtendedQueryHandler, SimpleQueryHandler, send_ready_for_query};
22use crate::api::{
23    ClientInfo, ClientPortalStore, DefaultClient, ErrorHandler, PgWireConnectionState,
24    PgWireServerHandlers,
25};
26use crate::error::{ErrorInfo, PgWireError, PgWireResult};
27use crate::messages::response::{GssEncResponse, ReadyForQuery, SslResponse, TransactionStatus};
28use crate::messages::startup::SecretKey;
29use crate::messages::{
30    DecodeContext, PgWireBackendMessage, PgWireFrontendMessage, ProtocolVersion,
31    SslNegotiationMetaMessage,
32};
33
/// How long, in milliseconds, a client may spend on TLS negotiation and the
/// startup/authentication exchange before the server stops serving it.
const STARTUP_TIMEOUT_MILLIS: u64 = 60 * 1_000;
36
/// Tokio codec for the server side of a connection: decodes frontend
/// (client → server) messages and encodes backend (server → client) messages.
///
/// The codec also owns the per-connection [`DefaultClient`] state, which the
/// `ClientInfo` implementation on `Framed<_, PgWireMessageServerCodec<_>>`
/// exposes to the handlers.
#[non_exhaustive]
#[derive(Debug, new)]
pub struct PgWireMessageServerCodec<S> {
    /// Per-connection client state (peer address, TLS flag, protocol version,
    /// connection state, metadata, portal store, ...).
    pub client_info: DefaultClient<S>,
    /// Decoder flags; refreshed from `client_info` before every decode.
    #[new(default)]
    decode_context: DecodeContext,
}
44
45impl<S> Decoder for PgWireMessageServerCodec<S> {
46    type Item = PgWireFrontendMessage;
47    type Error = PgWireError;
48
49    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
50        self.decode_context.protocol_version = self.client_info.protocol_version;
51
52        match self.client_info.state() {
53            PgWireConnectionState::AwaitingSslRequest => {
54                self.decode_context.awaiting_frontend_ssl = true;
55                self.decode_context.awaiting_frontend_startup = true;
56            }
57
58            PgWireConnectionState::AwaitingStartup => {
59                self.decode_context.awaiting_frontend_ssl = false;
60                self.decode_context.awaiting_frontend_startup = true;
61            }
62
63            _ => {
64                self.decode_context.awaiting_frontend_startup = false;
65                self.decode_context.awaiting_frontend_ssl = false;
66            }
67        }
68
69        PgWireFrontendMessage::decode(src, &self.decode_context)
70    }
71}
72
73impl<S> Encoder<PgWireBackendMessage> for PgWireMessageServerCodec<S> {
74    type Error = io::Error;
75
76    fn encode(
77        &mut self,
78        item: PgWireBackendMessage,
79        dst: &mut bytes::BytesMut,
80    ) -> Result<(), Self::Error> {
81        item.encode(dst).map_err(Into::into)
82    }
83}
84
85impl<T: 'static, S> ClientInfo for Framed<T, PgWireMessageServerCodec<S>> {
86    fn socket_addr(&self) -> std::net::SocketAddr {
87        self.codec().client_info.socket_addr
88    }
89
90    fn is_secure(&self) -> bool {
91        self.codec().client_info.is_secure
92    }
93
94    fn pid_and_secret_key(&self) -> (i32, SecretKey) {
95        self.codec().client_info.pid_and_secret_key()
96    }
97
98    fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey) {
99        self.codec_mut()
100            .client_info
101            .set_pid_and_secret_key(pid, secret_key);
102    }
103
104    fn protocol_version(&self) -> ProtocolVersion {
105        self.codec().client_info.protocol_version()
106    }
107
108    fn set_protocol_version(&mut self, version: ProtocolVersion) {
109        self.codec_mut().client_info.set_protocol_version(version);
110    }
111
112    fn state(&self) -> PgWireConnectionState {
113        self.codec().client_info.state
114    }
115
116    fn set_state(&mut self, new_state: PgWireConnectionState) {
117        self.codec_mut().client_info.set_state(new_state);
118    }
119
120    fn metadata(&self) -> &std::collections::HashMap<String, String> {
121        self.codec().client_info.metadata()
122    }
123
124    fn metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
125        self.codec_mut().client_info.metadata_mut()
126    }
127
128    fn transaction_status(&self) -> TransactionStatus {
129        self.codec().client_info.transaction_status()
130    }
131
132    fn set_transaction_status(&mut self, new_status: TransactionStatus) {
133        self.codec_mut()
134            .client_info
135            .set_transaction_status(new_status);
136    }
137
138    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
139    fn sni_server_name(&self) -> Option<&str> {
140        self.codec().client_info.sni_server_name()
141    }
142
143    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
144    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
145        if !self.is_secure() {
146            None
147        } else {
148            let socket =
149                <dyn std::any::Any>::downcast_ref::<TlsStream<TcpStream>>(self.get_ref()).unwrap();
150            let (_, tls_session) = socket.get_ref();
151            tls_session.peer_certificates()
152        }
153    }
154}
155
156impl<T, S> ClientPortalStore for Framed<T, PgWireMessageServerCodec<S>> {
157    type PortalStore = <DefaultClient<S> as ClientPortalStore>::PortalStore;
158
159    fn portal_store(&self) -> &Self::PortalStore {
160        self.codec().client_info.portal_store()
161    }
162}
163
164pub async fn process_message<S, A, Q, EQ, C, CR>(
165    message: PgWireFrontendMessage,
166    socket: &mut Framed<S, PgWireMessageServerCodec<EQ::Statement>>,
167    authenticator: Arc<A>,
168    query_handler: Arc<Q>,
169    extended_query_handler: Arc<EQ>,
170    copy_handler: Arc<C>,
171    cancel_handler: Arc<CR>,
172) -> PgWireResult<()>
173where
174    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
175    A: StartupHandler,
176    Q: SimpleQueryHandler,
177    EQ: ExtendedQueryHandler,
178    C: CopyHandler,
179    CR: CancelHandler,
180{
181    // CancelRequest is from a dedicated connection, process it and close it.
182    if let PgWireFrontendMessage::CancelRequest(cancel) = message {
183        cancel_handler.on_cancel_request(cancel).await;
184        socket.close().await?;
185        return Ok(());
186    }
187
188    match socket.state() {
189        PgWireConnectionState::AwaitingStartup
190        | PgWireConnectionState::AuthenticationInProgress => {
191            authenticator.on_startup(socket, message).await?;
192        }
193        // From Postgres docs:
194        // When an error is detected while processing any extended-query
195        // message, the backend issues ErrorResponse, then reads and discards
196        // messages until a Sync is reached, then issues ReadyForQuery and
197        // returns to normal message processing.
198        PgWireConnectionState::AwaitingSync => {
199            if let PgWireFrontendMessage::Sync(sync) = message {
200                extended_query_handler.on_sync(socket, sync).await?;
201                // TODO: confirm if we need to track transaction state there
202                socket.set_state(PgWireConnectionState::ReadyForQuery);
203            }
204        }
205        PgWireConnectionState::CopyInProgress(is_extended_query) => {
206            // query or query in progress
207            match message {
208                PgWireFrontendMessage::CopyData(copy_data) => {
209                    copy_handler.on_copy_data(socket, copy_data).await?;
210                }
211                PgWireFrontendMessage::CopyDone(copy_done) => {
212                    let result = copy_handler.on_copy_done(socket, copy_done).await;
213                    if !is_extended_query {
214                        // If the copy was initiated from a simple protocol
215                        // query, we should leave the CopyInProgress state
216                        // before returning the error in order to resume normal
217                        // operation after handling it in process_error.
218                        socket.set_state(PgWireConnectionState::ReadyForQuery);
219                    }
220                    match result {
221                        Ok(_) => {
222                            if !is_extended_query {
223                                // If the copy was initiated from a simple protocol
224                                // query, notify the client that we are not ready
225                                // for the next query.
226                                send_ready_for_query(socket, TransactionStatus::Idle).await?
227                            } else {
228                                // In the extended protocol (at least as
229                                // implemented by rust-postgres) we get a Sync
230                                // after the CopyDone, so we should let the
231                                // on_sync handler send the ReadyForQuery.
232                            }
233                        }
234                        err => return err,
235                    }
236                }
237                PgWireFrontendMessage::CopyFail(copy_fail) => {
238                    let error = copy_handler.on_copy_fail(socket, copy_fail).await;
239                    if !is_extended_query {
240                        // If the copy was initiated from a simple protocol query,
241                        // we should leave the CopyInProgress state
242                        // before returning the error in order to resume normal
243                        // operation after handling it in process_error.
244                        socket.set_state(PgWireConnectionState::ReadyForQuery);
245                    }
246                    return Err(error);
247                }
248                _ => {}
249            }
250        }
251        _ => {
252            // query or query in progress
253            match message {
254                PgWireFrontendMessage::Query(query) => {
255                    query_handler.on_query(socket, query).await?;
256                }
257                PgWireFrontendMessage::Parse(parse) => {
258                    extended_query_handler.on_parse(socket, parse).await?;
259                }
260                PgWireFrontendMessage::Bind(bind) => {
261                    extended_query_handler.on_bind(socket, bind).await?;
262                }
263                PgWireFrontendMessage::Execute(execute) => {
264                    extended_query_handler.on_execute(socket, execute).await?;
265                }
266                PgWireFrontendMessage::Describe(describe) => {
267                    extended_query_handler.on_describe(socket, describe).await?;
268                }
269                PgWireFrontendMessage::Flush(flush) => {
270                    extended_query_handler.on_flush(socket, flush).await?;
271                }
272                PgWireFrontendMessage::Sync(sync) => {
273                    extended_query_handler.on_sync(socket, sync).await?;
274                }
275                PgWireFrontendMessage::Close(close) => {
276                    extended_query_handler.on_close(socket, close).await?;
277                }
278                _ => {}
279            }
280        }
281    }
282    Ok(())
283}
284
285pub async fn process_error<S, ST>(
286    socket: &mut Framed<S, PgWireMessageServerCodec<ST>>,
287    error: PgWireError,
288    wait_for_sync: bool,
289) -> Result<(), io::Error>
290where
291    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
292{
293    let error_info: ErrorInfo = error.into();
294    let is_fatal = error_info.is_fatal();
295    socket
296        .send(PgWireBackendMessage::ErrorResponse(error_info.into()))
297        .await?;
298
299    let transaction_status = socket.transaction_status().to_error_state();
300    socket.set_transaction_status(transaction_status);
301
302    if wait_for_sync {
303        socket.set_state(PgWireConnectionState::AwaitingSync);
304    } else {
305        socket.set_state(PgWireConnectionState::ReadyForQuery);
306        socket
307            .feed(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
308                transaction_status,
309            )))
310            .await?;
311    }
312    socket.flush().await?;
313
314    if is_fatal {
315        return socket.close().await;
316    }
317
318    Ok(())
319}
320
/// Outcome of the encryption negotiation performed on a new TCP connection.
#[derive(Debug, PartialEq, Eq)]
enum SslNegotiationType {
    /// The client sent an `SSLRequest` and the server accepted it; a TLS
    /// handshake follows.
    Postgres,
    /// The first byte from the client already looked like a TLS handshake
    /// record (direct TLS, without an `SSLRequest`).
    Direct,
    /// No TLS: the request(s) were refused or the client did not ask for
    /// encryption.
    None,
}
327
328async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result<bool, io::Error> {
329    let mut buf = [0u8; 1];
330    let n = tcp_socket.peek(&mut buf).await?;
331
332    Ok(n > 0 && buf[0] == 0x16)
333}
334
335async fn peek_for_sslrequest<ST>(
336    socket: &mut Framed<TcpStream, PgWireMessageServerCodec<ST>>,
337    ssl_supported: bool,
338) -> Result<SslNegotiationType, io::Error> {
339    if check_ssl_direct_negotiation(socket.get_ref()).await? {
340        Ok(SslNegotiationType::Direct)
341    } else {
342        let mut ssl_done = false;
343        let mut gss_done = false;
344
345        loop {
346            match socket.next().await {
347                // postgres ssl
348                Some(Ok(PgWireFrontendMessage::SslNegotiation(
349                    SslNegotiationMetaMessage::PostgresSsl(_),
350                ))) => {
351                    // ssl request
352                    if ssl_supported {
353                        socket
354                            .send(PgWireBackendMessage::SslResponse(SslResponse::Accept))
355                            .await?;
356                        return Ok(SslNegotiationType::Postgres);
357                    } else {
358                        socket
359                            .send(PgWireBackendMessage::SslResponse(SslResponse::Refuse))
360                            .await?;
361                        ssl_done = true;
362
363                        if gss_done {
364                            return Ok(SslNegotiationType::None);
365                        } else {
366                            // Continue to check for more requests (e.g., GssEncRequest after SSL refuse)
367                            continue;
368                        }
369                    }
370                }
371
372                // postgres gss
373                Some(Ok(PgWireFrontendMessage::SslNegotiation(
374                    SslNegotiationMetaMessage::PostgresGss(_),
375                ))) => {
376                    socket
377                        .send(PgWireBackendMessage::GssEncResponse(GssEncResponse::Refuse))
378                        .await?;
379                    gss_done = true;
380
381                    if ssl_done {
382                        return Ok(SslNegotiationType::None);
383                    } else {
384                        // Continue to check for more requests (e.g., SSL request after GSSAPI refuse)
385                        continue;
386                    }
387                }
388
389                // not a handshake request or connection is broken
390                _ => {
391                    return Ok(SslNegotiationType::None);
392                }
393            }
394        }
395    }
396}
397
/// Reject a direct-TLS connection unless ALPN negotiated the PostgreSQL
/// protocol identifier (`super::POSTGRESQL_ALPN_NAME`).
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), io::Error> {
    let (_, session) = tls_socket.get_ref();
    match session.alpn_protocol() {
        Some(protocol) if protocol == super::POSTGRESQL_ALPN_NAME => Ok(()),
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "received direct SSL connection request without ALPN protocol negotiation extension",
        )),
    }
}
418
/// Transport underneath a served connection: plain TCP, a Unix domain socket
/// (Unix targets only), or TCP wrapped in TLS (when a TLS backend feature is
/// enabled).
///
/// `AsyncRead`/`AsyncWrite` are implemented by delegating to the active variant.
#[non_exhaustive]
pub enum MaybeTls {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// Unix domain socket connection.
    #[cfg(unix)]
    Unix(UnixStream),
    /// TLS-encrypted TCP connection.
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    Tls(Box<TlsStream<TcpStream>>),
}
427
/// Forward a `poll_*` call to whichever stream the `MaybeTls` variant holds.
///
/// Usage: `maybe_tls!(self, poll_write(cx, buf))`, where `self` is a
/// `Pin<&mut MaybeTls>`. The inner stream is re-pinned with `Pin::new`, which
/// requires every wrapped stream type to be `Unpin`.
macro_rules! maybe_tls {
    ($this:ident, $method:ident($($arg:expr),*)) => {
        match $this.get_mut() {
            MaybeTls::Plain(stream) => Pin::new(stream).$method($($arg),*),
            #[cfg(unix)]
            MaybeTls::Unix(stream) => Pin::new(stream).$method($($arg),*),
            #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
            MaybeTls::Tls(stream) => Pin::new(stream).$method($($arg),*),
        }
    };
}
439
440impl AsyncRead for MaybeTls {
441    fn poll_read(
442        self: Pin<&mut Self>,
443        cx: &mut Context<'_>,
444        buf: &mut tokio::io::ReadBuf<'_>,
445    ) -> Poll<io::Result<()>> {
446        maybe_tls!(self, poll_read(cx, buf))
447    }
448}
449
450impl AsyncWrite for MaybeTls {
451    fn poll_write(
452        self: Pin<&mut Self>,
453        cx: &mut Context<'_>,
454        buf: &[u8],
455    ) -> Poll<io::Result<usize>> {
456        maybe_tls!(self, poll_write(cx, buf))
457    }
458
459    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
460        maybe_tls!(self, poll_flush(cx))
461    }
462
463    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
464        maybe_tls!(self, poll_shutdown(cx))
465    }
466}
467
468/// Negotiate TLS with the given client stream.
469///
470/// Returns `Ok(None)` if the client sent a direct TLS negotiation but
471/// `tls_acceptor` was `None`.
472pub async fn negotiate_tls<S>(
473    tcp_socket: TcpStream,
474    tls_acceptor: Option<crate::tokio::TlsAcceptor>,
475) -> io::Result<Option<Framed<MaybeTls, PgWireMessageServerCodec<S>>>> {
476    let addr = tcp_socket.peer_addr()?;
477    tcp_socket.set_nodelay(true)?;
478
479    let client_info = DefaultClient::new(addr, false);
480    let mut tcp_socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info));
481
482    // this function will process postgres ssl negotiation and consume the first
483    // SslRequest packet if detected.
484    let ssl = peek_for_sslrequest(&mut tcp_socket, tls_acceptor.is_some()).await?;
485
486    let old_parts = tcp_socket.into_parts();
487
488    if ssl == SslNegotiationType::None {
489        let mut parts = FramedParts::new(MaybeTls::Plain(old_parts.io), old_parts.codec);
490        parts.read_buf = old_parts.read_buf;
491        parts.write_buf = old_parts.write_buf;
492        let mut socket = Framed::from_parts(parts);
493
494        socket.set_state(PgWireConnectionState::AwaitingStartup);
495
496        return Ok(Some(socket));
497    }
498    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
499    if let Some(tls_acceptor) = tls_acceptor {
500        // mention the use of ssl
501        let mut client_info = DefaultClient::new(addr, true);
502
503        let ssl_socket = Box::new(tls_acceptor.accept(old_parts.io).await?);
504
505        // check alpn for direct ssl connection
506        if ssl == SslNegotiationType::Direct {
507            check_alpn_for_direct_ssl(&ssl_socket)?;
508        }
509
510        // capture SNI (server name) from the underlying TLS connection
511        let sni = {
512            let (_, conn) = ssl_socket.get_ref();
513            conn.server_name().map(|s| s.to_string())
514        };
515        if let Some(s) = sni {
516            client_info.sni_server_name = Some(s);
517        }
518
519        let mut parts = FramedParts::new(
520            MaybeTls::Tls(ssl_socket),
521            PgWireMessageServerCodec::new(client_info),
522        );
523        parts.read_buf = old_parts.read_buf;
524        parts.write_buf = old_parts.write_buf;
525        let mut socket = Framed::from_parts(parts);
526
527        socket.set_state(PgWireConnectionState::AwaitingStartup);
528
529        return Ok(Some(socket));
530    }
531    Ok(None)
532}
533
/// Process messages on an already-negotiated socket.
///
/// This is the common message processing loop shared by both TCP and Unix sockets.
///
/// While the connection is in startup/authentication, each read is raced
/// against `$startup_timeout`. The loop stops when that timer fires, the stream
/// ends, or reading/decoding a frame fails.
///
/// Handler errors are passed to the error handler and then reported to the
/// client with `process_error`. An error raised during startup/authentication
/// always ends the loop: the client never became an authenticated session, so
/// nothing else it sends may be handled on this connection.
///
/// Must be expanded inside a function returning `Result<_, io::Error>`, because
/// I/O errors from `process_error` are propagated with `?`.
macro_rules! process_socket_messages {
    ($socket:expr, $startup_timeout:expr, $handlers:expr) => {{
        let startup_handler = $handlers.startup_handler();
        let simple_query_handler = $handlers.simple_query_handler();
        let extended_query_handler = $handlers.extended_query_handler();
        let copy_handler = $handlers.copy_handler();
        let cancel_handler = $handlers.cancel_handler();
        let error_handler = $handlers.error_handler();

        let socket = &mut $socket;
        loop {
            let in_startup = matches!(
                socket.state(),
                PgWireConnectionState::AwaitingStartup
                    | PgWireConnectionState::AuthenticationInProgress
            );

            let msg = if in_startup {
                tokio::select! {
                    _ = &mut $startup_timeout => None,
                    msg = socket.next() => msg,
                }
            } else {
                socket.next().await
            };

            // Startup timeout, end of stream, or a read/decode error: stop serving.
            let Some(Ok(msg)) = msg else {
                break;
            };

            let is_extended_query = match socket.state() {
                PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
                _ => msg.is_extended_query(),
            };
            if let Err(mut e) = process_message(
                msg,
                socket,
                startup_handler.clone(),
                simple_query_handler.clone(),
                extended_query_handler.clone(),
                copy_handler.clone(),
                cancel_handler.clone(),
            )
            .await
            {
                error_handler.on_error(socket, &mut e);
                process_error(socket, e, is_extended_query).await?;
                if in_startup {
                    // Errors during startup/authentication are terminal (as in
                    // PostgreSQL). Stop reading, so that messages the client
                    // keeps sending (possibly already buffered) are never
                    // handled on an unauthenticated connection.
                    break;
                }
            }
        }
    }};
}
586
587/// Process Unix domain socket connection.
588#[cfg(unix)]
589pub async fn process_socket_unix<H>(unix_socket: UnixStream, handlers: H) -> Result<(), io::Error>
590where
591    H: PgWireServerHandlers,
592{
593    let startup_timeout = sleep(Duration::from_millis(STARTUP_TIMEOUT_MILLIS));
594    tokio::pin!(startup_timeout);
595
596    // Use a dummy socket address for Unix domain socket connections
597    // This is consistent with how PostgreSQL handles Unix socket connections
598    let addr = "127.0.0.1:0".parse().unwrap();
599
600    let client_info = DefaultClient::new(addr, false);
601    let mut socket = Framed::new(
602        MaybeTls::Unix(unix_socket),
603        PgWireMessageServerCodec::new(client_info),
604    );
605
606    socket.set_state(PgWireConnectionState::AwaitingStartup);
607
608    process_socket_messages!(socket, startup_timeout, handlers);
609    Ok(())
610}
611
612pub async fn process_socket<H>(
613    tcp_socket: TcpStream,
614    tls_acceptor: Option<crate::tokio::TlsAcceptor>,
615    handlers: H,
616) -> Result<(), io::Error>
617where
618    H: PgWireServerHandlers,
619{
620    // start a timer for startup process, if the client couldn't finish startup
621    // within the timeout, it has to be dropped.
622    let startup_timeout = sleep(Duration::from_millis(STARTUP_TIMEOUT_MILLIS));
623    tokio::pin!(startup_timeout);
624
625    // this function will process postgres ssl negotiation and consume the first
626    // SslRequest packet if detected.
627    let socket = tokio::select! {
628        _ = &mut startup_timeout => {
629            return Ok(())
630        },
631        socket = negotiate_tls(tcp_socket, tls_acceptor) => {
632            socket?
633        }
634    };
635    let Some(mut socket) = socket else {
636        // no tls_acceptor configured. But the client sends direct tls
637        // negotiation. this is typically an invalid connection
638        return Ok(());
639    };
640
641    process_socket_messages!(socket, startup_timeout, handlers);
642    Ok(())
643}
644
#[cfg(all(test, any(feature = "_ring", feature = "_aws-lc-rs")))]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::{BufReader, Error as IOError};
    use std::sync::Arc;
    use tokio::sync::oneshot;
    use tokio_rustls::TlsAcceptor;
    use tokio_rustls::TlsConnector;
    use tokio_rustls::rustls;
    use tokio_rustls::rustls::crypto::CryptoProvider;

    /// Install a process-wide rustls crypto provider for the enabled backend
    /// (aws-lc-rs is preferred when both backend features are enabled).
    ///
    /// `install_default` returns an error when a provider is already installed,
    /// e.g. by another test running in the same process. An installed provider
    /// is all these tests need, so that error is deliberately ignored instead of
    /// unwrapped.
    fn install_crypto_provider() {
        #[cfg(feature = "_aws-lc-rs")]
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        #[cfg(all(feature = "_ring", not(feature = "_aws-lc-rs")))]
        let provider = rustls::crypto::ring::default_provider();
        let _ = CryptoProvider::install_default(provider);
    }

    fn load_test_server_config() -> Result<rustls::ServerConfig, IOError> {
        use rustls_pemfile::{certs, pkcs8_private_keys};
        use rustls_pki_types::{CertificateDer, PrivateKeyDer};

        let certs = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?))
            .collect::<Result<Vec<CertificateDer>, _>>()?;
        let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?))
            .map(|key| key.map(PrivateKeyDer::from))
            .collect::<Result<Vec<PrivateKeyDer>, _>>()?
            .remove(0);

        let mut cfg = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
        // ALPN is optional for this test; SNI extraction doesn't depend on it.
        cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
        Ok(cfg)
    }

    /// Certificate verifier that accepts any server certificate.
    ///
    /// These tests only validate SNI plumbing, not certificate validation.
    #[derive(Debug)]
    struct NoCertVerifier;

    impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &rustls::pki_types::CertificateDer<'_>,
            _intermediates: &[rustls::pki_types::CertificateDer<'_>],
            _server_name: &rustls::pki_types::ServerName<'_>,
            _ocsp_response: &[u8],
            _now: rustls::pki_types::UnixTime,
        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
            Ok(rustls::client::danger::ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &rustls::pki_types::CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
            vec![
                rustls::SignatureScheme::RSA_PKCS1_SHA256,
                rustls::SignatureScheme::RSA_PKCS1_SHA384,
                rustls::SignatureScheme::RSA_PKCS1_SHA512,
                rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
                rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
                rustls::SignatureScheme::RSA_PSS_SHA256,
                rustls::SignatureScheme::RSA_PSS_SHA384,
                rustls::SignatureScheme::RSA_PSS_SHA512,
            ]
        }
    }

    /// Client config that trusts any server certificate and offers the
    /// PostgreSQL ALPN identifier (aligned with the server to reduce
    /// negotiation variance).
    fn make_test_client_config() -> rustls::ClientConfig {
        let mut cfg = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoCertVerifier))
            .with_no_client_auth();
        cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
        cfg
    }

    fn make_test_client_connector() -> Result<TlsConnector, IOError> {
        Ok(TlsConnector::from(Arc::new(make_test_client_config())))
    }

    #[tokio::test]
    #[ignore]
    async fn server_name_metadata_is_set_from_tls_sni() {
        use std::net::SocketAddr;
        use tokio::io::duplex;

        install_crypto_provider();

        // set up TLS server and client configs
        let server_cfg = load_test_server_config().expect("server config");
        let acceptor = TlsAcceptor::from(Arc::new(server_cfg));
        let connector = make_test_client_connector().expect("client connector");

        // in-memory full-duplex stream pair (use a larger buffer for TLS handshake)
        let (server_io, client_io) = duplex(64 * 1024);

        let (tx, rx) = oneshot::channel::<Option<String>>();

        // spawn server task to accept TLS over in-memory IO
        tokio::spawn(async move {
            let tls = acceptor.accept(server_io).await.unwrap();

            // mimic production path: capture SNI and store on client_info
            let sni = {
                let (_, conn) = tls.get_ref();
                conn.server_name().map(|s| s.to_string())
            };
            let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
            if let Some(s) = sni {
                ci.sni_server_name = Some(s);
            }
            let framed = Framed::new(tls, PgWireMessageServerCodec::new(ci));
            let server_name = framed.sni_server_name().map(str::to_string);
            let _ = tx.send(server_name);
        });

        // client side: connect with SNI=localhost over in-memory IO
        let server_name = rustls_pki_types::ServerName::try_from("localhost").unwrap();
        let _ = connector.connect(server_name, client_io).await.unwrap();

        // verify server observed SNI and stored as `server_name`
        let observed = rx.await.expect("server_name from server");
        assert_eq!(observed.as_deref(), Some("localhost"));
    }

    #[tokio::test]
    async fn server_name_metadata_is_set_from_tls_sni_in_memory() {
        use std::net::SocketAddr;

        install_crypto_provider();

        // server and client rustls configs
        let server_cfg = Arc::new(load_test_server_config().expect("server config"));
        let client_cfg = Arc::new(make_test_client_config());

        // build rustls connections directly and drive handshake in-memory
        let mut server_conn = rustls::ServerConnection::new(server_cfg).unwrap();
        let mut client_conn = rustls::ClientConnection::new(
            client_cfg,
            rustls_pki_types::ServerName::try_from("localhost").unwrap(),
        )
        .unwrap();

        // in-memory pipes for TLS records
        let mut c2s = Vec::new();
        let mut s2c = Vec::new();

        // drive handshake until both sides complete
        for _ in 0..1000 {
            // client -> server
            let _ = client_conn.write_tls(&mut c2s);
            if !c2s.is_empty() {
                let mut cur = std::io::Cursor::new(&c2s);
                let _ = server_conn.read_tls(&mut cur);
                c2s.clear();
                server_conn.process_new_packets().unwrap();
            }

            // server -> client
            let _ = server_conn.write_tls(&mut s2c);
            if !s2c.is_empty() {
                let mut cur = std::io::Cursor::new(&s2c);
                let _ = client_conn.read_tls(&mut cur);
                s2c.clear();
                client_conn.process_new_packets().unwrap();
            }

            if !client_conn.is_handshaking() && !server_conn.is_handshaking() {
                break;
            }
        }

        // capture SNI from server side and store on client info
        let sni = server_conn.server_name().map(|s| s.to_string());
        let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
        if let Some(s) = sni {
            ci.sni_server_name = Some(s);
        }
        assert_eq!(ci.sni_server_name(), Some("localhost"));
    }
}