1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::future::Future;
5use std::io;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10#[cfg(feature = "use-rustls-no-provider")]
11use rustls_native_certs::load_native_certs;
12use tokio::net::{TcpSocket, TcpStream, lookup_host};
13#[cfg(feature = "use-native-tls")]
14use tokio_native_tls::native_tls::TlsConnector;
15#[cfg(feature = "use-rustls-no-provider")]
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17
18#[cfg(feature = "proxy")]
19mod proxy;
20#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
21mod tls;
22#[cfg(feature = "websocket")]
23mod websockets;
24
25#[cfg(feature = "proxy")]
26pub use proxy::{Proxy, ProxyAuth, ProxyError, ProxyType};
27#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
28pub use tls::Error as TlsError;
29#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
30pub use tls::tls_connect;
31#[cfg(all(
32 feature = "websocket",
33 feature = "use-native-tls",
34 not(feature = "use-rustls-no-provider")
35))]
36pub use tls::websocket_tls_connector;
37#[cfg(all(
38 feature = "websocket",
39 feature = "use-rustls-no-provider",
40 not(feature = "use-native-tls")
41))]
42pub use tls::websocket_tls_connector;
43#[cfg(feature = "websocket")]
44pub use websockets::{UrlError, ValidationError, WsAdapter, split_url, validate_response_headers};
45
46#[cfg(not(feature = "websocket"))]
47pub trait AsyncReadWrite:
48 tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
49{
50}
51#[cfg(not(feature = "websocket"))]
52impl<T> AsyncReadWrite for T where
53 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
54{
55}
56
57#[cfg(feature = "websocket")]
58pub trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
59#[cfg(feature = "websocket")]
60impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
61
62pub type DynAsyncReadWrite = Box<dyn AsyncReadWrite>;
63
64pub type SocketConnector = Arc<
66 dyn Fn(
67 String,
68 NetworkOptions,
69 ) -> Pin<Box<dyn Future<Output = Result<DynAsyncReadWrite, io::Error>> + Send>>
70 + Send
71 + Sync,
72>;
73
/// TLS settings used when establishing an encrypted connection.
///
/// Which variants exist depends on the enabled backend features
/// (`use-rustls-no-provider` and/or `use-native-tls`).
///
/// `Debug` is implemented by hand rather than derived, so that the `client_auth`
/// payloads are printed as `<redacted>` instead of being dumped into logs.
#[derive(Clone)]
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
pub enum TlsConfiguration {
    /// Rustls configuration assembled from raw bytes.
    #[cfg(feature = "use-rustls-no-provider")]
    Simple {
        /// CA certificate bytes used to verify the server.
        /// NOTE(review): encoding (PEM vs DER) is decided by the `tls` module — confirm.
        ca: Vec<u8>,
        /// Optional list of ALPN protocol identifiers.
        alpn: Option<Vec<Vec<u8>>>,
        /// Optional client-authentication material.
        /// NOTE(review): presumably `(certificate, private_key)` — confirm in `tls`.
        client_auth: Option<(Vec<u8>, Vec<u8>)>,
    },
    /// native-tls configuration assembled from raw bytes.
    #[cfg(feature = "use-native-tls")]
    SimpleNative {
        /// CA certificate bytes used to verify the server.
        ca: Vec<u8>,
        /// Optional client-authentication material.
        /// NOTE(review): presumably `(pkcs12_identity, password)` — confirm in `tls`.
        client_auth: Option<(Vec<u8>, String)>,
    },
    /// A fully built rustls `ClientConfig`, shared via `Arc`.
    #[cfg(feature = "use-rustls-no-provider")]
    Rustls(Arc<ClientConfig>),
    /// Use the native-tls backend without extra settings
    /// (NOTE(review): presumably platform defaults — see the `tls` module).
    #[cfg(feature = "use-native-tls")]
    Native,
    /// A caller-supplied, pre-built native-tls `TlsConnector`.
    #[cfg(feature = "use-native-tls")]
    NativeConnector(TlsConnector),
}

#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
impl std::fmt::Debug for TlsConfiguration {
    /// Formats like a derived `Debug`, except that any `client_auth` payload is
    /// replaced by a placeholder so credentials never reach log output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        match self {
            #[cfg(feature = "use-rustls-no-provider")]
            Self::Simple {
                ca,
                alpn,
                client_auth,
            } => f
                .debug_struct("Simple")
                .field("ca", ca)
                .field("alpn", alpn)
                .field("client_auth", &client_auth.as_ref().map(|_| REDACTED))
                .finish(),
            #[cfg(feature = "use-native-tls")]
            Self::SimpleNative { ca, client_auth } => f
                .debug_struct("SimpleNative")
                .field("ca", ca)
                .field("client_auth", &client_auth.as_ref().map(|_| REDACTED))
                .finish(),
            #[cfg(feature = "use-rustls-no-provider")]
            Self::Rustls(config) => f.debug_tuple("Rustls").field(config).finish(),
            #[cfg(feature = "use-native-tls")]
            Self::Native => f.write_str("Native"),
            #[cfg(feature = "use-native-tls")]
            Self::NativeConnector(connector) => {
                f.debug_tuple("NativeConnector").field(connector).finish()
            }
        }
    }
}
105
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
impl TlsConfiguration {
    /// Builds a rustls configuration that trusts the platform's native root
    /// certificates and performs no client authentication.
    ///
    /// Certificates from the platform store that rustls cannot parse are skipped
    /// rather than aborting. System stores often contain a few legacy or malformed
    /// roots, and one bad entry should not take the whole process down.
    ///
    /// # Panics
    ///
    /// - Panics if the platform certificate store cannot be loaded
    ///   (`"could not load platform certs"`).
    /// - `ClientConfig::builder()` panics if rustls cannot determine a process-default
    ///   crypto provider. NOTE(review): relevant when only `use-rustls-no-provider` is
    ///   enabled and no provider was installed — confirm against the rustls docs.
    #[cfg(feature = "use-rustls-no-provider")]
    #[must_use]
    pub fn default_rustls() -> Self {
        let mut root_cert_store = RootCertStore::empty();
        let native_certs = load_native_certs().expect("could not load platform certs");
        // Best-effort: returns (added, ignored) counts; ignored certs are intentionally dropped.
        let (_added, _ignored) = root_cert_store.add_parsable_certificates(native_certs);

        let tls_config = ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();

        Self::Rustls(Arc::new(tls_config))
    }

    /// Returns [`TlsConfiguration::Native`].
    #[cfg(feature = "use-native-tls")]
    #[must_use]
    pub fn default_native() -> Self {
        Self::Native
    }
}
135
/// When rustls is the only TLS backend compiled in, the default configuration is
/// [`TlsConfiguration::default_rustls`].
///
/// No `Default` impl exists when both backends are enabled, because the choice
/// would be ambiguous.
#[cfg(all(feature = "use-rustls-no-provider", not(feature = "use-native-tls")))]
impl Default for TlsConfiguration {
    fn default() -> Self {
        TlsConfiguration::default_rustls()
    }
}

/// When native-tls is the only TLS backend compiled in, the default configuration
/// is [`TlsConfiguration::default_native`].
#[cfg(all(feature = "use-native-tls", not(feature = "use-rustls-no-provider")))]
impl Default for TlsConfiguration {
    fn default() -> Self {
        TlsConfiguration::default_native()
    }
}
149
/// Wraps an already-built rustls `ClientConfig` as [`TlsConfiguration::Rustls`].
#[cfg(feature = "use-rustls-no-provider")]
impl From<ClientConfig> for TlsConfiguration {
    fn from(config: ClientConfig) -> Self {
        let shared = Arc::new(config);
        TlsConfiguration::Rustls(shared)
    }
}

/// Wraps a pre-configured native-tls `TlsConnector` as
/// [`TlsConfiguration::NativeConnector`].
#[cfg(feature = "use-native-tls")]
impl From<TlsConnector> for TlsConfiguration {
    fn from(connector: TlsConnector) -> Self {
        Self::NativeConnector(connector)
    }
}
163
/// Socket-level settings used when opening a TCP connection.
///
/// The socket options are applied by [`default_socket_connect`]. Create an instance
/// with [`NetworkOptions::new`]; `NetworkOptions::default()` returns exactly the same
/// values. Adjust it through the setters below.
#[derive(Clone, Debug)]
pub struct NetworkOptions {
    /// Passed to `TcpSocket::set_send_buffer_size`; `None` leaves the socket untouched.
    tcp_send_buffer_size: Option<u32>,
    /// Passed to `TcpSocket::set_recv_buffer_size`; `None` leaves the socket untouched.
    tcp_recv_buffer_size: Option<u32>,
    /// Passed to `TcpSocket::set_nodelay` (`true` disables Nagle's algorithm).
    tcp_nodelay: bool,
    /// Connection timeout exposed through [`NetworkOptions::connection_timeout`]; it is
    /// not applied in this module. NOTE(review): presumably seconds — confirm with the caller.
    conn_timeout: u64,
    /// Network interface the socket is bound to via `TcpSocket::bind_device`.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    bind_device: Option<String>,
}

/// Delegates to [`NetworkOptions::new`].
///
/// `Default` is hand-written on purpose. A derived `Default` would produce a
/// connection timeout of `0`, silently disagreeing with the `5` that `new()` uses.
impl Default for NetworkOptions {
    fn default() -> Self {
        Self::new()
    }
}

impl NetworkOptions {
    /// Returns options with OS-default buffer sizes, `TCP_NODELAY` off, no bound
    /// device, and a connection timeout of `5`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tcp_send_buffer_size: None,
            tcp_recv_buffer_size: None,
            tcp_nodelay: false,
            conn_timeout: 5,
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            bind_device: None,
        }
    }

    /// Sets whether `TCP_NODELAY` is enabled on new sockets.
    pub const fn set_tcp_nodelay(&mut self, nodelay: bool) {
        self.tcp_nodelay = nodelay;
    }

    /// Sets the TCP send buffer size requested for new sockets.
    pub const fn set_tcp_send_buffer_size(&mut self, size: u32) {
        self.tcp_send_buffer_size = Some(size);
    }

    /// Sets the TCP receive buffer size requested for new sockets.
    pub const fn set_tcp_recv_buffer_size(&mut self, size: u32) {
        self.tcp_recv_buffer_size = Some(size);
    }

    /// Sets the connection timeout and returns `self` for chaining.
    pub const fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self {
        self.conn_timeout = timeout;
        self
    }

    /// Returns the configured connection timeout.
    #[must_use]
    pub const fn connection_timeout(&self) -> u64 {
        self.conn_timeout
    }

    /// Binds new sockets to the named network interface and returns `self` for
    /// chaining. Only available on Android, Fuchsia and Linux.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))
    )]
    pub fn set_bind_device(&mut self, bind_device: &str) -> &mut Self {
        self.bind_device = Some(bind_device.to_string());
        self
    }
}
223
224pub async fn default_socket_connect(
235 host: String,
236 network_options: NetworkOptions,
237) -> io::Result<TcpStream> {
238 let addrs = lookup_host(host).await?;
239 let mut last_err = None;
240
241 for addr in addrs {
242 let socket = match addr {
243 SocketAddr::V4(_) => TcpSocket::new_v4()?,
244 SocketAddr::V6(_) => TcpSocket::new_v6()?,
245 };
246
247 socket.set_nodelay(network_options.tcp_nodelay)?;
248
249 if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
250 socket.set_send_buffer_size(send_buff_size)?;
251 }
252 if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
253 socket.set_recv_buffer_size(recv_buffer_size)?;
254 }
255
256 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
257 {
258 if let Some(bind_device) = &network_options.bind_device {
259 socket.bind_device(Some(bind_device.as_bytes()))?;
260 }
261 }
262
263 match socket.connect(addr).await {
264 Ok(s) => return Ok(s),
265 Err(e) => {
266 last_err = Some(e);
267 }
268 }
269 }
270
271 Err(last_err.unwrap_or_else(|| {
272 io::Error::new(
273 io::ErrorKind::InvalidInput,
274 "could not resolve to any address",
275 )
276 }))
277}
278
#[cfg(test)]
mod tests {
    // Glob import instead of `use super::TlsConfiguration;`. That type only exists
    // when a TLS feature is enabled, so a named import broke `cargo test` for
    // TLS-less builds. It was also an unused import when rustls had no crypto
    // provider feature enabled.
    use super::*;

    #[test]
    fn network_options_new_has_expected_defaults() {
        let options = NetworkOptions::new();
        assert_eq!(options.connection_timeout(), 5);
        assert!(!options.tcp_nodelay);
        assert_eq!(options.tcp_send_buffer_size, None);
        assert_eq!(options.tcp_recv_buffer_size, None);
    }

    #[test]
    fn set_connection_timeout_updates_value_and_chains() {
        let mut options = NetworkOptions::new();
        options.set_connection_timeout(30).set_connection_timeout(42);
        assert_eq!(options.connection_timeout(), 42);
    }

    #[test]
    fn buffer_and_nodelay_setters_store_values() {
        let mut options = NetworkOptions::new();
        options.set_tcp_nodelay(true);
        options.set_tcp_send_buffer_size(1024);
        options.set_tcp_recv_buffer_size(2048);
        assert!(options.tcp_nodelay);
        assert_eq!(options.tcp_send_buffer_size, Some(1024));
        assert_eq!(options.tcp_recv_buffer_size, Some(2048));
    }

    #[cfg(all(
        feature = "use-rustls-no-provider",
        any(feature = "use-rustls-aws-lc", feature = "use-rustls-ring")
    ))]
    #[test]
    fn default_rustls_returns_rustls_variant() {
        assert!(matches!(
            TlsConfiguration::default_rustls(),
            TlsConfiguration::Rustls(_)
        ));
    }

    #[cfg(feature = "use-native-tls")]
    #[test]
    fn default_native_returns_native_variant() {
        assert!(matches!(
            TlsConfiguration::default_native(),
            TlsConfiguration::Native
        ));
    }
}
303}