madsim_tokio_postgres/
connect.rs

1use crate::client::SocketConfig;
2use crate::config::{Host, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::{MakeTlsConnect, TlsConnect};
6use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
7use futures::{future, pin_mut, Future, FutureExt, Stream};
8use std::io;
9use std::task::Poll;
10
11pub async fn connect<T>(
12    mut tls: T,
13    config: &Config,
14) -> Result<(Client, Connection<Socket, T::Stream>), Error>
15where
16    T: MakeTlsConnect<Socket>,
17{
18    if config.host.is_empty() {
19        return Err(Error::config("host missing".into()));
20    }
21
22    if config.port.len() > 1 && config.port.len() != config.host.len() {
23        return Err(Error::config("invalid number of ports".into()));
24    }
25
26    let mut error = None;
27    for (i, host) in config.host.iter().enumerate() {
28        let port = config
29            .port
30            .get(i)
31            .or_else(|| config.port.get(0))
32            .copied()
33            .unwrap_or(5432);
34
35        let hostname = match host {
36            Host::Tcp(host) => host.as_str(),
37            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
38            #[cfg(all(unix, not(madsim)))]
39            Host::Unix(_) => "",
40        };
41
42        let tls = tls
43            .make_tls_connect(hostname)
44            .map_err(|e| Error::tls(e.into()))?;
45
46        match connect_once(host, port, tls, config).await {
47            Ok((client, connection)) => return Ok((client, connection)),
48            Err(e) => error = Some(e),
49        }
50    }
51
52    Err(error.unwrap())
53}
54
55async fn connect_once<T>(
56    host: &Host,
57    port: u16,
58    tls: T,
59    config: &Config,
60) -> Result<(Client, Connection<Socket, T::Stream>), Error>
61where
62    T: TlsConnect<Socket>,
63{
64    let socket = connect_socket(
65        host,
66        port,
67        config.connect_timeout,
68        config.keepalives,
69        config.keepalives_idle,
70    )
71    .await?;
72    let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
73
74    if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
75        let rows = client.simple_query_raw("SHOW transaction_read_only");
76        pin_mut!(rows);
77
78        let rows = future::poll_fn(|cx| {
79            if connection.poll_unpin(cx)?.is_ready() {
80                return Poll::Ready(Err(Error::closed()));
81            }
82
83            rows.as_mut().poll(cx)
84        })
85        .await?;
86        pin_mut!(rows);
87
88        loop {
89            let next = future::poll_fn(|cx| {
90                if connection.poll_unpin(cx)?.is_ready() {
91                    return Poll::Ready(Some(Err(Error::closed())));
92                }
93
94                rows.as_mut().poll_next(cx)
95            });
96
97            match next.await.transpose()? {
98                Some(SimpleQueryMessage::Row(row)) => {
99                    if row.try_get(0)? == Some("on") {
100                        return Err(Error::connect(io::Error::new(
101                            io::ErrorKind::PermissionDenied,
102                            "database does not allow writes",
103                        )));
104                    } else {
105                        break;
106                    }
107                }
108                Some(_) => {}
109                None => return Err(Error::unexpected_message()),
110            }
111        }
112    }
113
114    client.set_socket_config(SocketConfig {
115        host: host.clone(),
116        port,
117        connect_timeout: config.connect_timeout,
118        keepalives: config.keepalives,
119        keepalives_idle: config.keepalives_idle,
120    });
121
122    Ok((client, connection))
123}