1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::future::Future;
5use std::io;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10#[cfg(feature = "use-rustls-no-provider")]
11use rustls_native_certs::load_native_certs;
12use tokio::net::{TcpSocket, TcpStream, lookup_host};
13#[cfg(feature = "use-native-tls")]
14use tokio_native_tls::native_tls::TlsConnector;
15#[cfg(feature = "use-rustls-no-provider")]
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17
18#[cfg(feature = "proxy")]
19mod proxy;
20mod scheduler;
21#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
22mod tls;
23#[cfg(feature = "websocket")]
24mod websockets;
25
26#[cfg(feature = "proxy")]
27pub use proxy::{Proxy, ProxyAuth, ProxyError, ProxyType};
28pub use scheduler::{OutboundScheduler, RequestClass, RequestReadiness, ScheduledRequest};
29#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
30pub use tls::Error as TlsError;
31#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
32pub use tls::tls_connect;
33#[cfg(all(
34 feature = "websocket",
35 feature = "use-native-tls",
36 not(feature = "use-rustls-no-provider")
37))]
38pub use tls::websocket_tls_connector;
39#[cfg(all(
40 feature = "websocket",
41 feature = "use-rustls-no-provider",
42 not(feature = "use-native-tls")
43))]
44pub use tls::websocket_tls_connector;
45#[cfg(feature = "websocket")]
46pub use websockets::{UrlError, ValidationError, WsAdapter, split_url, validate_response_headers};
47
48#[cfg(not(feature = "websocket"))]
49pub trait AsyncReadWrite:
50 tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
51{
52}
53#[cfg(not(feature = "websocket"))]
54impl<T> AsyncReadWrite for T where
55 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
56{
57}
58
59#[cfg(feature = "websocket")]
60pub trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
61#[cfg(feature = "websocket")]
62impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
63
64pub type DynAsyncReadWrite = Box<dyn AsyncReadWrite>;
65
66pub type SocketConnector = Arc<
68 dyn Fn(
69 String,
70 NetworkOptions,
71 ) -> Pin<Box<dyn Future<Output = Result<DynAsyncReadWrite, io::Error>> + Send>>
72 + Send
73 + Sync,
74>;
75
/// TLS settings used when establishing an encrypted connection.
///
/// Which variants exist depends on the enabled backend features
/// (`use-rustls-no-provider` and/or `use-native-tls`).
///
/// NOTE(review): `Debug` is derived, so `{:?}` on a `Simple` or
/// `SimpleNative` value prints the raw `client_auth` contents (key bytes,
/// password). Avoid logging this type when client authentication is set.
#[derive(Clone, Debug)]
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
pub enum TlsConfiguration {
    /// rustls configuration assembled from raw certificate material.
    #[cfg(feature = "use-rustls-no-provider")]
    Simple {
        /// CA certificate bytes. NOTE(review): the expected encoding (PEM vs
        /// DER) is decided by the `tls` module; confirm there.
        ca: Vec<u8>,
        /// Optional ALPN protocol list (presumably one protocol name per entry).
        alpn: Option<Vec<Vec<u8>>>,
        /// Optional client certificate and key bytes. NOTE(review): tuple
        /// order is presumably `(cert, key)`; confirm in the `tls` module.
        client_auth: Option<(Vec<u8>, Vec<u8>)>,
    },
    /// native-tls configuration assembled from raw certificate material;
    /// built by `TlsConfiguration::simple_native`.
    #[cfg(feature = "use-native-tls")]
    SimpleNative {
        /// CA certificate bytes.
        ca: Vec<u8>,
        /// Optional client identity bytes plus its password string.
        /// NOTE(review): presumably a PKCS#12 archive; confirm in `tls`.
        client_auth: Option<(Vec<u8>, String)>,
    },
    /// A fully prepared, shared rustls client configuration.
    #[cfg(feature = "use-rustls-no-provider")]
    Rustls(Arc<ClientConfig>),
    /// native-tls without any custom certificate material; returned by
    /// `TlsConfiguration::default_native`.
    #[cfg(feature = "use-native-tls")]
    Native,
    /// A fully prepared native-tls connector.
    #[cfg(feature = "use-native-tls")]
    NativeConnector(TlsConnector),
}
107
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
impl TlsConfiguration {
    /// Builds a rustls configuration that trusts the platform's native root
    /// certificates and performs no client authentication.
    ///
    /// Platform certificates that rustls cannot parse are skipped. Previously
    /// any single bad certificate aborted the whole configuration.
    ///
    /// NOTE(review): this feature is the "no provider" rustls build, so
    /// `ClientConfig::builder()` presumably relies on a process-default
    /// crypto provider being available; confirm.
    ///
    /// # Panics
    ///
    /// Panics if the platform certificate store cannot be loaded.
    #[cfg(feature = "use-rustls-no-provider")]
    #[must_use]
    pub fn default_rustls() -> Self {
        let mut root_cert_store = RootCertStore::empty();
        // Some OS trust stores ship certificates webpki rejects. Keep every
        // usable one instead of panicking via `add(..).unwrap()`.
        let (_added, _skipped) = root_cert_store
            .add_parsable_certificates(load_native_certs().expect("could not load platform certs"));

        let tls_config = ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();

        Self::Rustls(Arc::new(tls_config))
    }

    /// Builds a [`TlsConfiguration::SimpleNative`] from CA bytes and an
    /// optional `(identity bytes, password)` pair for client authentication.
    #[cfg(feature = "use-native-tls")]
    #[must_use]
    pub fn simple_native(ca: Vec<u8>, client_auth: Option<(Vec<u8>, String)>) -> Self {
        Self::SimpleNative { ca, client_auth }
    }

    /// Returns [`TlsConfiguration::Native`], i.e. native-tls without custom
    /// certificate material.
    #[cfg(feature = "use-native-tls")]
    #[must_use]
    pub fn default_native() -> Self {
        Self::Native
    }
}
145
/// With only the rustls backend enabled, the default configuration is
/// [`TlsConfiguration::default_rustls`] (platform roots, no client auth).
#[cfg(all(feature = "use-rustls-no-provider", not(feature = "use-native-tls")))]
impl Default for TlsConfiguration {
    fn default() -> Self {
        TlsConfiguration::default_rustls()
    }
}
152
/// With only the native-tls backend enabled, the default configuration is
/// [`TlsConfiguration::default_native`].
#[cfg(all(feature = "use-native-tls", not(feature = "use-rustls-no-provider")))]
impl Default for TlsConfiguration {
    fn default() -> Self {
        TlsConfiguration::default_native()
    }
}
159
/// Wraps a ready-made rustls [`ClientConfig`] as [`TlsConfiguration::Rustls`].
#[cfg(feature = "use-rustls-no-provider")]
impl From<ClientConfig> for TlsConfiguration {
    fn from(config: ClientConfig) -> Self {
        let shared: Arc<ClientConfig> = config.into();
        TlsConfiguration::Rustls(shared)
    }
}
166
/// Wraps a ready-made native-tls [`TlsConnector`] as
/// [`TlsConfiguration::NativeConnector`].
#[cfg(feature = "use-native-tls")]
impl From<TlsConnector> for TlsConfiguration {
    fn from(connector: TlsConnector) -> Self {
        Self::NativeConnector(connector)
    }
}
173
/// Socket-level options applied when opening outbound TCP connections.
///
/// Create one with [`NetworkOptions::new`] or `Default::default()`. Both
/// produce the same value. Adjust it with the setters.
#[derive(Clone, Debug)]
pub struct NetworkOptions {
    /// Value for `TcpSocket::set_send_buffer_size`; `None` keeps the OS default.
    tcp_send_buffer_size: Option<u32>,
    /// Value for `TcpSocket::set_recv_buffer_size`; `None` keeps the OS default.
    tcp_recv_buffer_size: Option<u32>,
    /// Whether to enable `TCP_NODELAY` (disable Nagle's algorithm).
    tcp_nodelay: bool,
    /// Connection timeout; `new()` sets it to 5.
    /// NOTE(review): presumably seconds; confirm at the consuming call site.
    conn_timeout: u64,
    /// Optional local address to bind the socket to before connecting.
    bind_addr: Option<SocketAddr>,
    /// Optional network device (interface) name to bind the socket to.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    bind_device: Option<String>,
}

impl Default for NetworkOptions {
    /// Delegates to [`NetworkOptions::new`].
    ///
    /// A derived `Default` would zero `conn_timeout` and silently disagree
    /// with the 5 set by `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl NetworkOptions {
    /// Creates options with OS-default buffer sizes, `TCP_NODELAY` off, a
    /// connection timeout of 5, and no local address or device binding.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tcp_send_buffer_size: None,
            tcp_recv_buffer_size: None,
            tcp_nodelay: false,
            conn_timeout: 5,
            bind_addr: None,
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            bind_device: None,
        }
    }

    /// Enables or disables `TCP_NODELAY` on sockets created with these options.
    pub const fn set_tcp_nodelay(&mut self, nodelay: bool) {
        self.tcp_nodelay = nodelay;
    }

    /// Overrides the socket send buffer size.
    pub const fn set_tcp_send_buffer_size(&mut self, size: u32) {
        self.tcp_send_buffer_size = Some(size);
    }

    /// Overrides the socket receive buffer size.
    pub const fn set_tcp_recv_buffer_size(&mut self, size: u32) {
        self.tcp_recv_buffer_size = Some(size);
    }

    /// Sets the connection timeout and returns `self` for chaining.
    pub const fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self {
        self.conn_timeout = timeout;
        self
    }

    /// Returns the configured connection timeout.
    #[must_use]
    pub const fn connection_timeout(&self) -> u64 {
        self.conn_timeout
    }

    /// Sets the local address the socket is bound to before connecting and
    /// returns `self` for chaining.
    ///
    /// The address family must match the remote address being connected to.
    /// Otherwise binding fails when the connection is attempted.
    pub const fn set_bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
        self.bind_addr = Some(bind_addr);
        self
    }

    /// Returns the configured local bind address, if any.
    #[must_use]
    pub const fn bind_addr(&self) -> Option<SocketAddr> {
        self.bind_addr
    }

    /// Sets the network device name the socket is bound to and returns
    /// `self` for chaining. Only available on Android, Fuchsia and Linux.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))
    )]
    pub fn set_bind_device(&mut self, bind_device: &str) -> &mut Self {
        self.bind_device = Some(bind_device.to_string());
        self
    }
}
255
256fn configure_tcp_socket(socket: &TcpSocket, network_options: &NetworkOptions) -> io::Result<()> {
257 socket.set_nodelay(network_options.tcp_nodelay)?;
258
259 if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
260 socket.set_send_buffer_size(send_buff_size)?;
261 }
262 if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
263 socket.set_recv_buffer_size(recv_buffer_size)?;
264 }
265
266 if let Some(bind_addr) = network_options.bind_addr {
267 socket.bind(bind_addr)?;
268 }
269
270 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
271 {
272 if let Some(bind_device) = &network_options.bind_device {
273 socket.bind_device(Some(bind_device.as_bytes()))?;
274 }
275 }
276
277 Ok(())
278}
279
280pub async fn connect_socket_addr(
289 addr: SocketAddr,
290 network_options: NetworkOptions,
291) -> io::Result<TcpStream> {
292 let socket = match addr {
293 SocketAddr::V4(_) => TcpSocket::new_v4()?,
294 SocketAddr::V6(_) => TcpSocket::new_v6()?,
295 };
296
297 configure_tcp_socket(&socket, &network_options)?;
298 socket.connect(addr).await
299}
300
301pub async fn default_socket_connect(
312 host: String,
313 network_options: NetworkOptions,
314) -> io::Result<TcpStream> {
315 let addrs = lookup_host(host).await?;
316 let mut last_err = None;
317
318 for addr in addrs {
319 match connect_socket_addr(addr, network_options.clone()).await {
320 Ok(stream) => return Ok(stream),
321 Err(err) => {
322 last_err = Some(err);
323 }
324 }
325 }
326
327 Err(last_err.unwrap_or_else(|| {
328 io::Error::new(
329 io::ErrorKind::InvalidInput,
330 "could not resolve to any address",
331 )
332 }))
333}
334
// Tests for the TLS configuration constructors and the TCP connect helpers.
// The socket tests talk to real loopback listeners on OS-assigned ports.
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
    use super::TlsConfiguration;
    use super::{NetworkOptions, connect_socket_addr, default_socket_connect};
    use std::io;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    use tokio::net::TcpListener;

    // Gated on a concrete rustls provider feature in addition to
    // `use-rustls-no-provider`. NOTE(review): presumably because
    // `ClientConfig::builder()` needs a crypto provider at runtime; confirm.
    #[cfg(all(
        feature = "use-rustls-no-provider",
        any(feature = "use-rustls-aws-lc", feature = "use-rustls-ring")
    ))]
    #[test]
    fn default_rustls_returns_rustls_variant() {
        assert!(matches!(
            TlsConfiguration::default_rustls(),
            TlsConfiguration::Rustls(_)
        ));
    }

    #[cfg(feature = "use-native-tls")]
    #[test]
    fn default_native_returns_native_variant() {
        assert!(matches!(
            TlsConfiguration::default_native(),
            TlsConfiguration::Native
        ));
    }

    // `simple_native` must store its arguments verbatim in `SimpleNative`.
    #[cfg(feature = "use-native-tls")]
    #[test]
    fn simple_native_returns_simple_native_variant() {
        let config = TlsConfiguration::simple_native(
            Vec::from("Test CA"),
            Some((vec![1, 2, 3], String::from("secret"))),
        );

        assert!(matches!(
            config,
            TlsConfiguration::SimpleNative {
                ca,
                client_auth: Some((identity, password))
            } if ca == b"Test CA" && identity == vec![1, 2, 3] && password == "secret"
        ));
    }

    // Binding the client to 127.0.0.1 with port 0 (any free port) must still
    // reach a loopback listener, and the listener must observe that IP as peer.
    #[tokio::test]
    async fn connect_socket_addr_succeeds_with_ipv4_bind_addr() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let listener_addr = listener.local_addr().unwrap();

        // Accept in the background and hand back the peer address seen by the server.
        let accept = tokio::spawn(async move {
            let (stream, peer_addr) = listener.accept().await.unwrap();
            drop(stream);
            peer_addr
        });

        let mut network_options = NetworkOptions::new();
        network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)));

        let stream = connect_socket_addr(listener_addr, network_options)
            .await
            .unwrap();
        let local_addr = stream.local_addr().unwrap();
        assert!(local_addr.ip().is_loopback());
        drop(stream);

        let peer_addr = accept.await.unwrap();
        assert_eq!(peer_addr.ip(), local_addr.ip());
    }

    // An IPv6 bind address on the IPv4 socket created for an IPv4 target must
    // surface as an error instead of a successful connection.
    #[tokio::test]
    async fn connect_socket_addr_returns_error_for_mismatched_bind_addr_family() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let listener_addr = listener.local_addr().unwrap();

        let mut network_options = NetworkOptions::new();
        network_options.set_bind_addr(SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::LOCALHOST,
            0,
            0,
            0,
        )));

        let err = connect_socket_addr(listener_addr, network_options)
            .await
            .unwrap_err();
        // Only rules out `WouldBlock`. NOTE(review): the concrete error kind
        // presumably differs between platforms, hence the loose check.
        assert_ne!(err.kind(), io::ErrorKind::WouldBlock);
    }

    // IPv6 counterpart of the IPv4 bind test. Returns early (and thus passes)
    // when an IPv6 loopback listener cannot be created on this host.
    #[tokio::test]
    async fn connect_socket_addr_succeeds_with_ipv6_bind_addr() {
        let listener = match TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await {
            Ok(listener) => listener,
            Err(_) => return,
        };
        let listener_addr = listener.local_addr().unwrap();

        let accept = tokio::spawn(async move {
            let (stream, peer_addr) = listener.accept().await.unwrap();
            drop(stream);
            peer_addr
        });

        let mut network_options = NetworkOptions::new();
        network_options.set_bind_addr(SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::LOCALHOST,
            0,
            0,
            0,
        )));

        let stream = connect_socket_addr(listener_addr, network_options)
            .await
            .unwrap();
        let local_addr = stream.local_addr().unwrap();
        assert_eq!(local_addr.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
        drop(stream);

        let peer_addr = accept.await.unwrap();
        assert_eq!(peer_addr.ip(), local_addr.ip());
    }

    // Resolution + connect through `default_socket_connect` with default options
    // (no bind address) using a literal `ip:port` string.
    #[tokio::test]
    async fn default_socket_connect_still_connects_without_bind_addr() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let accept = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            drop(stream);
        });

        let stream = default_socket_connect(addr.to_string(), NetworkOptions::new())
            .await
            .unwrap();
        assert!(stream.local_addr().unwrap().ip().is_loopback());
        drop(stream);
        accept.await.unwrap();
    }

    // The getter must return exactly what the setter stored.
    #[test]
    fn bind_addr_returns_configured_socket_addr() {
        let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1883));
        let mut network_options = NetworkOptions::new();
        network_options.set_bind_addr(bind_addr);

        assert_eq!(network_options.bind_addr(), Some(bind_addr));
    }
}