madsim_tokio_postgres/
connect.rs1use crate::client::SocketConfig;
2use crate::config::{Host, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::{MakeTlsConnect, TlsConnect};
6use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
7use futures::{future, pin_mut, Future, FutureExt, Stream};
8use std::io;
9use std::task::Poll;
10
11pub async fn connect<T>(
12 mut tls: T,
13 config: &Config,
14) -> Result<(Client, Connection<Socket, T::Stream>), Error>
15where
16 T: MakeTlsConnect<Socket>,
17{
18 if config.host.is_empty() {
19 return Err(Error::config("host missing".into()));
20 }
21
22 if config.port.len() > 1 && config.port.len() != config.host.len() {
23 return Err(Error::config("invalid number of ports".into()));
24 }
25
26 let mut error = None;
27 for (i, host) in config.host.iter().enumerate() {
28 let port = config
29 .port
30 .get(i)
31 .or_else(|| config.port.get(0))
32 .copied()
33 .unwrap_or(5432);
34
35 let hostname = match host {
36 Host::Tcp(host) => host.as_str(),
37 #[cfg(all(unix, not(madsim)))]
39 Host::Unix(_) => "",
40 };
41
42 let tls = tls
43 .make_tls_connect(hostname)
44 .map_err(|e| Error::tls(e.into()))?;
45
46 match connect_once(host, port, tls, config).await {
47 Ok((client, connection)) => return Ok((client, connection)),
48 Err(e) => error = Some(e),
49 }
50 }
51
52 Err(error.unwrap())
53}
54
55async fn connect_once<T>(
56 host: &Host,
57 port: u16,
58 tls: T,
59 config: &Config,
60) -> Result<(Client, Connection<Socket, T::Stream>), Error>
61where
62 T: TlsConnect<Socket>,
63{
64 let socket = connect_socket(
65 host,
66 port,
67 config.connect_timeout,
68 config.keepalives,
69 config.keepalives_idle,
70 )
71 .await?;
72 let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
73
74 if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
75 let rows = client.simple_query_raw("SHOW transaction_read_only");
76 pin_mut!(rows);
77
78 let rows = future::poll_fn(|cx| {
79 if connection.poll_unpin(cx)?.is_ready() {
80 return Poll::Ready(Err(Error::closed()));
81 }
82
83 rows.as_mut().poll(cx)
84 })
85 .await?;
86 pin_mut!(rows);
87
88 loop {
89 let next = future::poll_fn(|cx| {
90 if connection.poll_unpin(cx)?.is_ready() {
91 return Poll::Ready(Some(Err(Error::closed())));
92 }
93
94 rows.as_mut().poll_next(cx)
95 });
96
97 match next.await.transpose()? {
98 Some(SimpleQueryMessage::Row(row)) => {
99 if row.try_get(0)? == Some("on") {
100 return Err(Error::connect(io::Error::new(
101 io::ErrorKind::PermissionDenied,
102 "database does not allow writes",
103 )));
104 } else {
105 break;
106 }
107 }
108 Some(_) => {}
109 None => return Err(Error::unexpected_message()),
110 }
111 }
112 }
113
114 client.set_socket_config(SocketConfig {
115 host: host.clone(),
116 port,
117 connect_timeout: config.connect_timeout,
118 keepalives: config.keepalives,
119 keepalives_idle: config.keepalives_idle,
120 });
121
122 Ok((client, connection))
123}