tokio_postgres/
connect.rs

1use crate::client::{Addr, SocketConfig};
2use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::MakeTlsConnect;
6use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
7use futures_util::{FutureExt, Stream};
8use rand::seq::SliceRandom;
9use std::future::{self, Future};
10use std::pin::pin;
11use std::task::Poll;
12use std::{cmp, io};
13use tokio::net;
14
15pub async fn connect<T>(
16    mut tls: T,
17    config: &Config,
18) -> Result<(Client, Connection<Socket, T::Stream>), Error>
19where
20    T: MakeTlsConnect<Socket>,
21{
22    if config.host.is_empty() && config.hostaddr.is_empty() {
23        return Err(Error::config("both host and hostaddr are missing".into()));
24    }
25
26    if !config.host.is_empty()
27        && !config.hostaddr.is_empty()
28        && config.host.len() != config.hostaddr.len()
29    {
30        let msg = format!(
31            "number of hosts ({}) is different from number of hostaddrs ({})",
32            config.host.len(),
33            config.hostaddr.len(),
34        );
35        return Err(Error::config(msg.into()));
36    }
37
38    // At this point, either one of the following two scenarios could happen:
39    // (1) either config.host or config.hostaddr must be empty;
40    // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal.
41    let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
42
43    if config.port.len() > 1 && config.port.len() != num_hosts {
44        return Err(Error::config("invalid number of ports".into()));
45    }
46
47    let mut indices = (0..num_hosts).collect::<Vec<_>>();
48    if config.load_balance_hosts == LoadBalanceHosts::Random {
49        indices.shuffle(&mut rand::rng());
50    }
51
52    let mut error = None;
53    for i in indices {
54        let host = config.host.get(i);
55        let hostaddr = config.hostaddr.get(i);
56        let port = config
57            .port
58            .get(i)
59            .or_else(|| config.port.first())
60            .copied()
61            .unwrap_or(5432);
62
63        // The value of host is used as the hostname for TLS validation,
64        let hostname = match host {
65            Some(Host::Tcp(host)) => Some(host.clone()),
66            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
67            #[cfg(unix)]
68            Some(Host::Unix(_)) => None,
69            None => None,
70        };
71
72        // Try to use the value of hostaddr to establish the TCP connection,
73        // fallback to host if hostaddr is not present.
74        let addr = match hostaddr {
75            Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
76            None => host.cloned().unwrap(),
77        };
78
79        match connect_host(addr, hostname, port, &mut tls, config).await {
80            Ok((client, connection)) => return Ok((client, connection)),
81            Err(e) => error = Some(e),
82        }
83    }
84
85    Err(error.unwrap())
86}
87
88async fn connect_host<T>(
89    host: Host,
90    hostname: Option<String>,
91    port: u16,
92    tls: &mut T,
93    config: &Config,
94) -> Result<(Client, Connection<Socket, T::Stream>), Error>
95where
96    T: MakeTlsConnect<Socket>,
97{
98    match host {
99        Host::Tcp(host) => {
100            let mut addrs = net::lookup_host((&*host, port))
101                .await
102                .map_err(Error::connect)?
103                .collect::<Vec<_>>();
104
105            if config.load_balance_hosts == LoadBalanceHosts::Random {
106                addrs.shuffle(&mut rand::rng());
107            }
108
109            let mut last_err = None;
110            for addr in addrs {
111                match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
112                    .await
113                {
114                    Ok(stream) => return Ok(stream),
115                    Err(e) => {
116                        last_err = Some(e);
117                        continue;
118                    }
119                };
120            }
121
122            Err(last_err.unwrap_or_else(|| {
123                Error::connect(io::Error::new(
124                    io::ErrorKind::InvalidInput,
125                    "could not resolve any addresses",
126                ))
127            }))
128        }
129        #[cfg(unix)]
130        Host::Unix(path) => {
131            connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
132        }
133    }
134}
135
136async fn connect_once<T>(
137    addr: Addr,
138    hostname: Option<&str>,
139    port: u16,
140    tls: &mut T,
141    config: &Config,
142) -> Result<(Client, Connection<Socket, T::Stream>), Error>
143where
144    T: MakeTlsConnect<Socket>,
145{
146    let socket = connect_socket(
147        &addr,
148        port,
149        config.connect_timeout,
150        config.tcp_user_timeout,
151        if config.keepalives {
152            Some(&config.keepalive_config)
153        } else {
154            None
155        },
156    )
157    .await?;
158
159    let tls = tls
160        .make_tls_connect(hostname.unwrap_or(""))
161        .map_err(|e| Error::tls(e.into()))?;
162    let has_hostname = hostname.is_some();
163    let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;
164
165    if config.target_session_attrs != TargetSessionAttrs::Any {
166        let mut rows = pin!(client.simple_query_raw("SHOW transaction_read_only"));
167
168        let mut rows = pin!(
169            future::poll_fn(|cx| {
170                if connection.poll_unpin(cx)?.is_ready() {
171                    return Poll::Ready(Err(Error::closed()));
172                }
173
174                rows.as_mut().poll(cx)
175            })
176            .await?
177        );
178
179        loop {
180            let next = future::poll_fn(|cx| {
181                if connection.poll_unpin(cx)?.is_ready() {
182                    return Poll::Ready(Some(Err(Error::closed())));
183                }
184
185                rows.as_mut().poll_next(cx)
186            });
187
188            match next.await.transpose()? {
189                Some(SimpleQueryMessage::Row(row)) => {
190                    let read_only_result = row.try_get(0)?;
191                    if read_only_result == Some("on")
192                        && config.target_session_attrs == TargetSessionAttrs::ReadWrite
193                    {
194                        return Err(Error::connect(io::Error::new(
195                            io::ErrorKind::PermissionDenied,
196                            "database does not allow writes",
197                        )));
198                    } else if read_only_result == Some("off")
199                        && config.target_session_attrs == TargetSessionAttrs::ReadOnly
200                    {
201                        return Err(Error::connect(io::Error::new(
202                            io::ErrorKind::PermissionDenied,
203                            "database is not read only",
204                        )));
205                    } else {
206                        break;
207                    }
208                }
209                Some(_) => {}
210                None => return Err(Error::unexpected_message()),
211            }
212        }
213    }
214
215    client.set_socket_config(SocketConfig {
216        addr,
217        hostname: hostname.map(|s| s.to_string()),
218        port,
219        connect_timeout: config.connect_timeout,
220        tcp_user_timeout: config.tcp_user_timeout,
221        keepalive: if config.keepalives {
222            Some(config.keepalive_config.clone())
223        } else {
224            None
225        },
226    });
227
228    Ok((client, connection))
229}