madsim_tokio_postgres/
connect_socket.rs

1use crate::config::Host;
2use crate::{Error, Socket};
3#[cfg(not(madsim))]
4use socket2::{SockRef, TcpKeepalive};
5use std::future::Future;
6use std::io;
7use std::time::Duration;
8#[cfg(all(unix, not(madsim)))]
9use tokio::net::UnixStream;
10use tokio::net::{self, TcpStream};
11use tokio::time;
12
13#[cfg_attr(madsim, allow(unused_variables))]
14pub(crate) async fn connect_socket(
15    host: &Host,
16    port: u16,
17    connect_timeout: Option<Duration>,
18    keepalives: bool,
19    keepalives_idle: Duration,
20) -> Result<Socket, Error> {
21    match host {
22        Host::Tcp(host) => {
23            let addrs = net::lookup_host((&**host, port))
24                .await
25                .map_err(Error::connect)?;
26
27            let mut last_err = None;
28
29            for addr in addrs {
30                let stream =
31                    match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
32                        Ok(stream) => stream,
33                        Err(e) => {
34                            last_err = Some(e);
35                            continue;
36                        }
37                    };
38
39                stream.set_nodelay(true).map_err(Error::connect)?;
40                #[cfg(not(madsim))] // TODO: simulate keep alive
41                if keepalives {
42                    SockRef::from(&stream)
43                        .set_tcp_keepalive(&TcpKeepalive::new().with_time(keepalives_idle))
44                        .map_err(Error::connect)?;
45                }
46
47                return Ok(Socket::new_tcp(stream));
48            }
49
50            Err(last_err.unwrap_or_else(|| {
51                Error::connect(io::Error::new(
52                    io::ErrorKind::InvalidInput,
53                    "could not resolve any addresses",
54                ))
55            }))
56        }
57        #[cfg(all(unix, not(madsim)))]
58        Host::Unix(path) => {
59            let path = path.join(format!(".s.PGSQL.{}", port));
60            let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?;
61            Ok(Socket::new_unix(socket))
62        }
63    }
64}
65
66async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
67where
68    F: Future<Output = io::Result<T>>,
69{
70    match timeout {
71        Some(timeout) => match time::timeout(timeout, connect).await {
72            Ok(Ok(socket)) => Ok(socket),
73            Ok(Err(e)) => Err(Error::connect(e)),
74            Err(_) => Err(Error::connect(io::Error::new(
75                io::ErrorKind::TimedOut,
76                "connection timed out",
77            ))),
78        },
79        None => match connect.await {
80            Ok(socket) => Ok(socket),
81            Err(e) => Err(Error::connect(e)),
82        },
83    }
84}