xitca_postgres/driver/
connect.rs

1use core::net::SocketAddr;
2
3use xitca_io::net::TcpStream;
4
5use crate::{
6    config::{Config, Host},
7    error::Error,
8    session::{Addr, ConnectInfo, Session},
9};
10
11use super::{
12    dns_resolve,
13    generic::{DriverTx, GenericDriver},
14    prepare_driver, should_connect_tls, Driver,
15};
16
17#[cold]
18#[inline(never)]
19pub(super) async fn connect_host(host: Host, cfg: &mut Config) -> Result<(DriverTx, Session, Driver), Error> {
20    async fn connect_tcp(host: &str, ports: &[u16]) -> Result<(TcpStream, SocketAddr), Error> {
21        let addrs = dns_resolve(host, ports).await?;
22
23        let mut err = None;
24
25        for addr in addrs {
26            match TcpStream::connect(addr).await {
27                Ok(stream) => {
28                    let _ = stream.set_nodelay(true);
29                    return Ok((stream, addr));
30                }
31                Err(e) => err = Some(e),
32            }
33        }
34
35        Err(err.unwrap().into())
36    }
37
38    let ssl_mode = cfg.get_ssl_mode();
39    let ssl_negotiation = cfg.get_ssl_negotiation();
40
41    match host {
42        Host::Tcp(host) => {
43            let (mut io, addr) = connect_tcp(&host, cfg.get_ports()).await?;
44            if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
45                #[cfg(feature = "tls")]
46                {
47                    let io = super::tls::connect_tls(io, &host, cfg).await?;
48                    let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode, ssl_negotiation);
49                    prepare_driver(info, io, cfg)
50                        .await
51                        .map(|(tx, session, drv)| (tx, session, Driver::Tls(drv)))
52                }
53                #[cfg(not(feature = "tls"))]
54                {
55                    Err(crate::error::FeatureError::Tls.into())
56                }
57            } else {
58                let info = ConnectInfo::new(Addr::Tcp(host, addr), ssl_mode, ssl_negotiation);
59                prepare_driver(info, io, cfg)
60                    .await
61                    .map(|(tx, session, drv)| (tx, session, Driver::Tcp(drv)))
62            }
63        }
64        #[cfg(not(unix))]
65        Host::Unix(_) => Err(crate::error::SystemError::Unix.into()),
66        #[cfg(unix)]
67        Host::Unix(host) => {
68            let mut io = xitca_io::net::UnixStream::connect(&host).await?;
69            let host_str: Box<str> = host.to_string_lossy().into();
70            if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
71                #[cfg(feature = "tls")]
72                {
73                    let io = super::tls::connect_tls(io, host_str.as_ref(), cfg).await?;
74                    let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode, ssl_negotiation);
75                    prepare_driver(info, io, cfg)
76                        .await
77                        .map(|(tx, session, drv)| (tx, session, Driver::UnixTls(drv)))
78                }
79                #[cfg(not(feature = "tls"))]
80                {
81                    Err(crate::error::FeatureError::Tls.into())
82                }
83            } else {
84                let info = ConnectInfo::new(Addr::Unix(host_str, host), ssl_mode, ssl_negotiation);
85                prepare_driver(info, io, cfg)
86                    .await
87                    .map(|(tx, session, drv)| (tx, session, Driver::Unix(drv)))
88            }
89        }
90        #[cfg(not(feature = "quic"))]
91        Host::Quic(_) => Err(crate::error::FeatureError::Quic.into()),
92        #[cfg(feature = "quic")]
93        Host::Quic(host) => {
94            let (io, addr) = super::quic::connect_quic(&host, cfg.get_ports()).await?;
95            let info = ConnectInfo::new(Addr::Quic(host, addr), ssl_mode, ssl_negotiation);
96            prepare_driver(info, io, cfg)
97                .await
98                .map(|(tx, session, drv)| (tx, session, Driver::Quic(drv)))
99        }
100    }
101}
102
103#[cold]
104#[inline(never)]
105pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
106    let ConnectInfo {
107        addr,
108        ssl_mode,
109        ssl_negotiation,
110    } = info;
111    match addr {
112        Addr::Tcp(_host, addr) => {
113            let mut io = TcpStream::connect(addr).await?;
114            let _ = io.set_nodelay(true);
115
116            if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
117                #[cfg(feature = "tls")]
118                {
119                    let io = super::tls::connect_tls(io, &_host, &mut Config::default()).await?;
120                    let (io, tx) = GenericDriver::new(io);
121                    Ok((tx, Driver::Tls(io)))
122                }
123                #[cfg(not(feature = "tls"))]
124                {
125                    Err(crate::error::FeatureError::Tls.into())
126                }
127            } else {
128                let (io, tx) = GenericDriver::new(io);
129                Ok((tx, Driver::Tcp(io)))
130            }
131        }
132        #[cfg(unix)]
133        Addr::Unix(_host, path) => {
134            let mut io = xitca_io::net::UnixStream::connect(path).await?;
135            if should_connect_tls(&mut io, ssl_mode, ssl_negotiation).await? {
136                #[cfg(feature = "tls")]
137                {
138                    let io = super::tls::connect_tls(io, &_host, &mut Config::default()).await?;
139                    let (io, tx) = GenericDriver::new(io);
140                    Ok((tx, Driver::UnixTls(io)))
141                }
142                #[cfg(not(feature = "tls"))]
143                {
144                    Err(crate::error::FeatureError::Tls.into())
145                }
146            } else {
147                let (io, tx) = GenericDriver::new(io);
148                Ok((tx, Driver::Unix(io)))
149            }
150        }
151        #[cfg(feature = "quic")]
152        Addr::Quic(host, addr) => {
153            let io = super::quic::connect_quic_addr(&host, addr).await?;
154            let (io, tx) = GenericDriver::new(io);
155            Ok((tx, Driver::Quic(io)))
156        }
157        Addr::None => Err(Error::todo()),
158    }
159}