Skip to main content

xitca_postgres/
driver.rs

1//! client driver module.
2
3pub(crate) mod codec;
4pub(crate) mod generic;
5
6mod connect;
7
8pub(crate) use generic::DriverTx;
9
10#[cfg(feature = "tls")]
11mod tls;
12
13#[cfg(feature = "quic")]
14pub(crate) mod quic;
15
16#[cfg(feature = "io-uring")]
17pub(crate) mod io_uring;
18
19#[cfg(feature = "compio")]
20pub(crate) mod compio;
21
22use core::{
23    future::{Future, IntoFuture},
24    net::SocketAddr,
25    pin::Pin,
26};
27
28use std::io;
29
30use xitca_io::{
31    bytes::{Buf, BytesMut},
32    io::{AsyncIo, AsyncIoDyn, Interest},
33    net::TcpStream,
34};
35
36use super::{
37    client::Client,
38    config::{Config, SslMode, SslNegotiation},
39    error::{ConfigError, Error, unexpected_eof_err},
40    iter::AsyncLendingIterator,
41    protocol::message::{backend, frontend},
42    session::{ConnectInfo, Session},
43};
44
45use self::generic::GenericDriver;
46
47#[cfg(feature = "tls")]
48use xitca_tls::rustls::{ClientConnection, TlsStream};
49
50#[cfg(unix)]
51use xitca_io::net::UnixStream;
52
53pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
54    if cfg.get_hosts().is_empty() {
55        return Err(ConfigError::EmptyHost.into());
56    }
57
58    if cfg.get_ports().is_empty() {
59        return Err(ConfigError::EmptyPort.into());
60    }
61
62    let mut err = None;
63    let hosts = cfg.get_hosts().to_vec();
64    for host in hosts {
65        match self::connect::connect_host(host, cfg).await {
66            Ok((tx, session, drv)) => return Ok((Client::new(tx, session), drv)),
67            Err(e) => err = Some(e),
68        }
69    }
70
71    Err(err.unwrap())
72}
73
74pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
75where
76    Io: AsyncIo + Send + 'static,
77{
78    let (tx, session, drv) = prepare_driver(ConnectInfo::default(), Box::new(io) as _, cfg).await?;
79    Ok((Client::new(tx, session), Driver::Dynamic(drv)))
80}
81
82pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
83    self::connect::connect_info(info).await
84}
85
86async fn prepare_driver<Io>(
87    info: ConnectInfo,
88    io: Io,
89    cfg: &mut Config,
90) -> Result<(DriverTx, Session, GenericDriver<Io>), Error>
91where
92    Io: AsyncIo + Send + 'static,
93{
94    let (mut drv, tx) = GenericDriver::new(io);
95    let session = Session::prepare_session(info, &mut drv, cfg).await?;
96    Ok((tx, session, drv))
97}
98
99async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
100where
101    Io: AsyncIo,
102{
103    async fn query_tls_availability<Io>(io: &mut Io) -> std::io::Result<bool>
104    where
105        Io: AsyncIo,
106    {
107        let mut buf = BytesMut::new();
108        frontend::ssl_request(&mut buf);
109
110        while !buf.is_empty() {
111            match io.write(&buf) {
112                Ok(0) => return Err(unexpected_eof_err()),
113                Ok(n) => buf.advance(n),
114                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
115                    io.ready(Interest::WRITABLE).await?;
116                }
117                Err(e) => return Err(e),
118            }
119        }
120
121        let mut buf = [0];
122        loop {
123            match io.read(&mut buf) {
124                Ok(0) => return Err(unexpected_eof_err()),
125                Ok(_) => return Ok(buf[0] == b'S'),
126                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
127                    io.ready(Interest::READABLE).await?;
128                }
129                Err(e) => return Err(e),
130            }
131        }
132    }
133
134    match ssl_mode {
135        SslMode::Disable => Ok(false),
136        _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
137        mode => match (query_tls_availability(io).await?, mode) {
138            (false, SslMode::Require) => Err(Error::todo()),
139            (bool, _) => Ok(bool),
140        },
141    }
142}
143
144async fn dns_resolve<'p>(host: &'p str, ports: &'p [u16]) -> Result<impl Iterator<Item = SocketAddr> + 'p, Error> {
145    let addrs = tokio::net::lookup_host((host, 0)).await?.flat_map(|mut addr| {
146        ports.iter().map(move |port| {
147            addr.set_port(*port);
148            addr
149        })
150    });
151    Ok(addrs)
152}
153
154/// async driver of [`Client`]
155///
156/// it handles IO and emit server sent message that do not belong to any query with [`AsyncLendingIterator`]
157/// trait impl.
158///
159/// # Examples
160/// ```
161/// use std::future::IntoFuture;
162/// use xitca_postgres::{iter::AsyncLendingIterator, Driver};
163///
164/// // drive the client and listen to server notify at the same time.
165/// fn drive_with_server_notify(mut drv: Driver) {
166///     tokio::spawn(async move {
167///         while let Ok(Some(msg)) = drv.try_next().await {
168///             // *Note:
169///             // handle message must be non-blocking to prevent starvation of driver.
170///         }
171///     });
172/// }
173///
174/// // drive client without handling notify.
175/// fn drive_only(drv: Driver) {
176///     tokio::spawn(drv.into_future());
177/// }
178/// ```
179///
180/// # Lifetime
181/// Driver and [`Client`] have a dependent lifetime where either side can trigger the other part to shutdown.
182/// From Driver side it's in the form of dropping ownership.
183/// ## Examples
184/// ```
185/// # use xitca_postgres::{error::Error, Config, Execute, Postgres};
186/// # async fn shut_down(cfg: Config) -> Result<(), Error> {
187/// // connect to a database
188/// let (cli, drv) = Postgres::new(cfg).connect().await?;
189///
190/// // drop driver
191/// drop(drv);
192///
193/// // client will always return error when it's driver is gone.
194/// let e = "SELECT 1".query(&cli).await.unwrap_err();
195/// // a shortcut method can be used to determine if the error is caused by a shutdown driver.
196/// assert!(e.is_driver_down());
197///
198/// # Ok(())
199/// # }
200/// ```
201///
202// TODO: use Box<dyn AsyncIterator> when life time GAT is object safe.
203pub enum Driver {
204    Tcp(GenericDriver<TcpStream>),
205    Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
206    #[cfg(feature = "tls")]
207    Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
208    #[cfg(unix)]
209    Unix(GenericDriver<UnixStream>),
210    #[cfg(all(unix, feature = "tls"))]
211    UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
212    #[cfg(feature = "quic")]
213    Quic(GenericDriver<crate::driver::quic::QuicStream>),
214}
215
216impl Driver {
217    #[inline]
218    pub(crate) async fn send(&mut self, buf: BytesMut) -> Result<(), Error> {
219        match self {
220            Self::Tcp(drv) => drv.send(buf).await,
221            Self::Dynamic(drv) => drv.send(buf).await,
222            #[cfg(feature = "tls")]
223            Self::Tls(drv) => drv.send(buf).await,
224            #[cfg(unix)]
225            Self::Unix(drv) => drv.send(buf).await,
226            #[cfg(all(unix, feature = "tls"))]
227            Self::UnixTls(drv) => drv.send(buf).await,
228            #[cfg(feature = "quic")]
229            Self::Quic(drv) => drv.send(buf).await,
230        }
231    }
232
233    // try to unwrap driver that using unencrypted tcp connection
234    pub fn try_into_tcp(self) -> Option<GenericDriver<TcpStream>> {
235        match self {
236            Self::Tcp(drv) => Some(drv),
237            _ => None,
238        }
239    }
240
241    #[cfg(feature = "io-uring")]
242    pub fn try_into_uring(self) -> Option<io_uring::UringDriver> {
243        self.try_into_tcp().map(io_uring::UringDriver::from_tcp)
244    }
245}
246
247impl AsyncLendingIterator for Driver {
248    type Ok<'i>
249        = backend::Message
250    where
251        Self: 'i;
252    type Err = Error;
253
254    #[inline]
255    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
256        match self {
257            Self::Tcp(drv) => drv.try_next().await,
258            Self::Dynamic(drv) => drv.try_next().await,
259            #[cfg(feature = "tls")]
260            Self::Tls(drv) => drv.try_next().await,
261            #[cfg(unix)]
262            Self::Unix(drv) => drv.try_next().await,
263            #[cfg(all(unix, feature = "tls"))]
264            Self::UnixTls(drv) => drv.try_next().await,
265            #[cfg(feature = "quic")]
266            Self::Quic(drv) => drv.try_next().await,
267        }
268    }
269}
270
271impl IntoFuture for Driver {
272    type Output = Result<(), Error>;
273    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
274
275    fn into_future(mut self) -> Self::IntoFuture {
276        Box::pin(async move {
277            while self.try_next().await?.is_some() {}
278            Ok(())
279        })
280    }
281}