tokio_postgres/
connect.rs1use crate::client::{Addr, SocketConfig};
2use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::MakeTlsConnect;
6use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
7use futures_util::{FutureExt, Stream};
8use rand::seq::SliceRandom;
9use std::future::{self, Future};
10use std::pin::pin;
11use std::task::Poll;
12use std::{cmp, io};
13use tokio::net;
14
15pub async fn connect<T>(
16 mut tls: T,
17 config: &Config,
18) -> Result<(Client, Connection<Socket, T::Stream>), Error>
19where
20 T: MakeTlsConnect<Socket>,
21{
22 if config.host.is_empty() && config.hostaddr.is_empty() {
23 return Err(Error::config("both host and hostaddr are missing".into()));
24 }
25
26 if !config.host.is_empty()
27 && !config.hostaddr.is_empty()
28 && config.host.len() != config.hostaddr.len()
29 {
30 let msg = format!(
31 "number of hosts ({}) is different from number of hostaddrs ({})",
32 config.host.len(),
33 config.hostaddr.len(),
34 );
35 return Err(Error::config(msg.into()));
36 }
37
38 let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
42
43 if config.port.len() > 1 && config.port.len() != num_hosts {
44 return Err(Error::config("invalid number of ports".into()));
45 }
46
47 let mut indices = (0..num_hosts).collect::<Vec<_>>();
48 if config.load_balance_hosts == LoadBalanceHosts::Random {
49 indices.shuffle(&mut rand::rng());
50 }
51
52 let mut error = None;
53 for i in indices {
54 let host = config.host.get(i);
55 let hostaddr = config.hostaddr.get(i);
56 let port = config
57 .port
58 .get(i)
59 .or_else(|| config.port.first())
60 .copied()
61 .unwrap_or(5432);
62
63 let hostname = match host {
65 Some(Host::Tcp(host)) => Some(host.clone()),
66 #[cfg(unix)]
68 Some(Host::Unix(_)) => None,
69 None => None,
70 };
71
72 let addr = match hostaddr {
75 Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
76 None => host.cloned().unwrap(),
77 };
78
79 match connect_host(addr, hostname, port, &mut tls, config).await {
80 Ok((client, connection)) => return Ok((client, connection)),
81 Err(e) => error = Some(e),
82 }
83 }
84
85 Err(error.unwrap())
86}
87
88async fn connect_host<T>(
89 host: Host,
90 hostname: Option<String>,
91 port: u16,
92 tls: &mut T,
93 config: &Config,
94) -> Result<(Client, Connection<Socket, T::Stream>), Error>
95where
96 T: MakeTlsConnect<Socket>,
97{
98 match host {
99 Host::Tcp(host) => {
100 let mut addrs = net::lookup_host((&*host, port))
101 .await
102 .map_err(Error::connect)?
103 .collect::<Vec<_>>();
104
105 if config.load_balance_hosts == LoadBalanceHosts::Random {
106 addrs.shuffle(&mut rand::rng());
107 }
108
109 let mut last_err = None;
110 for addr in addrs {
111 match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
112 .await
113 {
114 Ok(stream) => return Ok(stream),
115 Err(e) => {
116 last_err = Some(e);
117 continue;
118 }
119 };
120 }
121
122 Err(last_err.unwrap_or_else(|| {
123 Error::connect(io::Error::new(
124 io::ErrorKind::InvalidInput,
125 "could not resolve any addresses",
126 ))
127 }))
128 }
129 #[cfg(unix)]
130 Host::Unix(path) => {
131 connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
132 }
133 }
134}
135
136async fn connect_once<T>(
137 addr: Addr,
138 hostname: Option<&str>,
139 port: u16,
140 tls: &mut T,
141 config: &Config,
142) -> Result<(Client, Connection<Socket, T::Stream>), Error>
143where
144 T: MakeTlsConnect<Socket>,
145{
146 let socket = connect_socket(
147 &addr,
148 port,
149 config.connect_timeout,
150 config.tcp_user_timeout,
151 if config.keepalives {
152 Some(&config.keepalive_config)
153 } else {
154 None
155 },
156 )
157 .await?;
158
159 let tls = tls
160 .make_tls_connect(hostname.unwrap_or(""))
161 .map_err(|e| Error::tls(e.into()))?;
162 let has_hostname = hostname.is_some();
163 let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;
164
165 if config.target_session_attrs != TargetSessionAttrs::Any {
166 let mut rows = pin!(client.simple_query_raw("SHOW transaction_read_only"));
167
168 let mut rows = pin!(
169 future::poll_fn(|cx| {
170 if connection.poll_unpin(cx)?.is_ready() {
171 return Poll::Ready(Err(Error::closed()));
172 }
173
174 rows.as_mut().poll(cx)
175 })
176 .await?
177 );
178
179 loop {
180 let next = future::poll_fn(|cx| {
181 if connection.poll_unpin(cx)?.is_ready() {
182 return Poll::Ready(Some(Err(Error::closed())));
183 }
184
185 rows.as_mut().poll_next(cx)
186 });
187
188 match next.await.transpose()? {
189 Some(SimpleQueryMessage::Row(row)) => {
190 let read_only_result = row.try_get(0)?;
191 if read_only_result == Some("on")
192 && config.target_session_attrs == TargetSessionAttrs::ReadWrite
193 {
194 return Err(Error::connect(io::Error::new(
195 io::ErrorKind::PermissionDenied,
196 "database does not allow writes",
197 )));
198 } else if read_only_result == Some("off")
199 && config.target_session_attrs == TargetSessionAttrs::ReadOnly
200 {
201 return Err(Error::connect(io::Error::new(
202 io::ErrorKind::PermissionDenied,
203 "database is not read only",
204 )));
205 } else {
206 break;
207 }
208 }
209 Some(_) => {}
210 None => return Err(Error::unexpected_message()),
211 }
212 }
213 }
214
215 client.set_socket_config(SocketConfig {
216 addr,
217 hostname: hostname.map(|s| s.to_string()),
218 port,
219 connect_timeout: config.connect_timeout,
220 tcp_user_timeout: config.tcp_user_timeout,
221 keepalive: if config.keepalives {
222 Some(config.keepalive_config.clone())
223 } else {
224 None
225 },
226 });
227
228 Ok((client, connection))
229}