1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::future::Future;
5use std::io;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10#[cfg(feature = "use-rustls-no-provider")]
11use rustls_native_certs::load_native_certs;
12use tokio::net::{TcpSocket, TcpStream, lookup_host};
13#[cfg(feature = "use-native-tls")]
14use tokio_native_tls::native_tls::TlsConnector;
15#[cfg(feature = "use-rustls-no-provider")]
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17
18#[cfg(feature = "proxy")]
19mod proxy;
20#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
21mod tls;
22#[cfg(feature = "websocket")]
23mod websockets;
24
25#[cfg(feature = "proxy")]
26pub use proxy::{Proxy, ProxyAuth, ProxyError, ProxyType};
27#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
28pub use tls::Error as TlsError;
29#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
30pub use tls::tls_connect;
31#[cfg(all(
32 feature = "websocket",
33 feature = "use-native-tls",
34 not(feature = "use-rustls-no-provider")
35))]
36pub use tls::websocket_tls_connector;
37#[cfg(all(
38 feature = "websocket",
39 feature = "use-rustls-no-provider",
40 not(feature = "use-native-tls")
41))]
42pub use tls::websocket_tls_connector;
43#[cfg(feature = "websocket")]
44pub use websockets::{UrlError, ValidationError, WsAdapter, split_url, validate_response_headers};
45
/// Marker trait for a bidirectional async byte stream.
///
/// This definition is compiled when the `websocket` feature is disabled and
/// requires the stream to be `Send + Sync + Unpin` in addition to tokio's
/// `AsyncRead` and `AsyncWrite`.
#[cfg(not(feature = "websocket"))]
pub trait AsyncReadWrite:
    tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
{
}
// Blanket impl: every type that satisfies the bounds is an `AsyncReadWrite`,
// so callers never have to implement the trait by hand.
#[cfg(not(feature = "websocket"))]
impl<T> AsyncReadWrite for T where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
{
}
56
/// Marker trait for a bidirectional async byte stream.
///
/// This definition is compiled when the `websocket` feature is enabled. Unlike
/// the non-websocket definition it does not require `Sync`; the stream only has
/// to be `Send + Unpin` in addition to tokio's `AsyncRead` and `AsyncWrite`.
#[cfg(feature = "websocket")]
pub trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
// Blanket impl: every type that satisfies the bounds is an `AsyncReadWrite`.
#[cfg(feature = "websocket")]
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
61
/// Heap-allocated, type-erased [`AsyncReadWrite`] stream.
pub type DynAsyncReadWrite = Box<dyn AsyncReadWrite>;
63
/// Shareable (`Arc`, `Send + Sync`) callback that opens a connection.
///
/// The callback receives a `String` and a [`NetworkOptions`] value and returns a
/// boxed, `Send` future that resolves to either a [`DynAsyncReadWrite`] stream
/// or an [`io::Error`].
///
/// NOTE(review): the `String` argument is presumably the `host:port` to connect
/// to, mirroring `default_socket_connect` — confirm against the callers.
pub type SocketConnector = Arc<
    dyn Fn(
            String,
            NetworkOptions,
        ) -> Pin<Box<dyn Future<Output = Result<DynAsyncReadWrite, io::Error>> + Send>>
        + Send
        + Sync,
>;
73
/// TLS settings for a connection.
///
/// The set of variants depends on which TLS backend features are enabled:
/// `use-rustls-no-provider` enables the rustls variants and `use-native-tls`
/// enables the native-tls variants.
#[derive(Clone, Debug)]
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
pub enum TlsConfiguration {
    /// Raw-bytes configuration for the rustls backend.
    #[cfg(feature = "use-rustls-no-provider")]
    Simple {
        /// CA certificate bytes.
        /// NOTE(review): the expected encoding is decided where this is consumed
        /// (the `tls` module), not here — confirm before documenting it.
        ca: Vec<u8>,
        /// Optional list of byte strings; presumably ALPN protocol identifiers — confirm.
        alpn: Option<Vec<Vec<u8>>>,
        /// Optional client-authentication material as two byte buffers
        /// (presumably certificate chain and private key — confirm against the `tls` module).
        client_auth: Option<(Vec<u8>, Vec<u8>)>,
    },
    /// Raw-bytes configuration for the native-tls backend.
    #[cfg(feature = "use-native-tls")]
    SimpleNative {
        /// CA certificate bytes (encoding not visible here — confirm against the `tls` module).
        ca: Vec<u8>,
        /// Optional client-authentication material as a byte buffer plus a string
        /// (presumably an identity archive and its password — confirm against the `tls` module).
        client_auth: Option<(Vec<u8>, String)>,
    },
    /// A fully built, shared rustls `ClientConfig`.
    #[cfg(feature = "use-rustls-no-provider")]
    Rustls(Arc<ClientConfig>),
    /// native-tls without any extra settings carried by this value.
    #[cfg(feature = "use-native-tls")]
    Native,
    /// A caller-supplied native-tls `TlsConnector`.
    #[cfg(feature = "use-native-tls")]
    NativeConnector(TlsConnector),
}
105
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
impl TlsConfiguration {
    /// Builds a rustls configuration that trusts the platform's native root
    /// certificates and performs no client authentication.
    ///
    /// Certificates from the platform store that rustls cannot parse are skipped
    /// rather than treated as fatal: platform stores commonly contain outdated or
    /// malformed entries, and a single bad entry should not prevent the client
    /// from starting.
    ///
    /// NOTE(review): with a rustls build that ships no default crypto provider,
    /// `ClientConfig::builder()` is expected to require a process-level provider
    /// to be installed beforehand — confirm against the rustls version in use.
    ///
    /// # Panics
    ///
    /// Panics with `"could not load platform certs"` if the platform certificate
    /// store cannot be loaded.
    #[cfg(feature = "use-rustls-no-provider")]
    #[must_use]
    pub fn default_rustls() -> Self {
        let native_certs = load_native_certs().expect("could not load platform certs");

        let mut root_cert_store = RootCertStore::empty();
        // Returns (added, ignored) counts; unparsable certificates are ignored
        // instead of aborting the whole process.
        let (_added, _ignored) = root_cert_store.add_parsable_certificates(native_certs);

        let tls_config = ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();

        Self::Rustls(Arc::new(tls_config))
    }

    /// Builds a [`TlsConfiguration::SimpleNative`] from the given CA bytes and
    /// optional client-authentication pair; the values are stored unchanged.
    #[cfg(feature = "use-native-tls")]
    #[must_use]
    pub fn simple_native(ca: Vec<u8>, client_auth: Option<(Vec<u8>, String)>) -> Self {
        Self::SimpleNative { ca, client_auth }
    }

    /// Returns [`TlsConfiguration::Native`].
    #[cfg(feature = "use-native-tls")]
    #[must_use]
    pub fn default_native() -> Self {
        Self::Native
    }
}
143
// `Default` is only provided when exactly one TLS backend is enabled, so the
// choice of backend is never ambiguous. This is the rustls-only case.
#[cfg(all(feature = "use-rustls-no-provider", not(feature = "use-native-tls")))]
impl Default for TlsConfiguration {
    /// Same as [`TlsConfiguration::default_rustls`].
    fn default() -> Self {
        Self::default_rustls()
    }
}
150
// `Default` is only provided when exactly one TLS backend is enabled, so the
// choice of backend is never ambiguous. This is the native-tls-only case.
#[cfg(all(feature = "use-native-tls", not(feature = "use-rustls-no-provider")))]
impl Default for TlsConfiguration {
    /// Same as [`TlsConfiguration::default_native`].
    fn default() -> Self {
        Self::default_native()
    }
}
157
#[cfg(feature = "use-rustls-no-provider")]
impl From<ClientConfig> for TlsConfiguration {
    /// Wraps an owned rustls `ClientConfig` in an `Arc` and stores it as
    /// [`TlsConfiguration::Rustls`].
    fn from(config: ClientConfig) -> Self {
        Self::Rustls(Arc::new(config))
    }
}
164
#[cfg(feature = "use-native-tls")]
impl From<TlsConnector> for TlsConfiguration {
    /// Stores a caller-built native-tls connector as
    /// [`TlsConfiguration::NativeConnector`].
    fn from(connector: TlsConnector) -> Self {
        Self::NativeConnector(connector)
    }
}
171
/// Socket-level settings applied to a TCP connection before it is established.
///
/// [`NetworkOptions::new`] and [`Default::default`] produce the same value: no
/// buffer-size overrides, `tcp_nodelay` off, a connection timeout of `5`, and no
/// local bind address (nor bind device on platforms that support one).
#[derive(Clone, Debug)]
pub struct NetworkOptions {
    /// Send buffer size override; `None` leaves the socket's setting untouched.
    tcp_send_buffer_size: Option<u32>,
    /// Receive buffer size override; `None` leaves the socket's setting untouched.
    tcp_recv_buffer_size: Option<u32>,
    /// Value passed to the socket's `set_nodelay`.
    tcp_nodelay: bool,
    /// Connection timeout. This type only stores the number.
    /// NOTE(review): presumably seconds — confirm against the code that reads it.
    conn_timeout: u64,
    /// Local address the socket is bound to before connecting, if any.
    bind_addr: Option<SocketAddr>,
    /// Name of the network device the socket is bound to, if any.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    bind_device: Option<String>,
}

impl Default for NetworkOptions {
    /// Same as [`NetworkOptions::new`].
    // Hand-written rather than derived: a derived impl would set `conn_timeout`
    // to 0 and silently disagree with `new()`, which uses 5.
    fn default() -> Self {
        Self::new()
    }
}

impl NetworkOptions {
    /// Creates options with no buffer-size overrides, `tcp_nodelay` disabled, a
    /// connection timeout of `5`, and no bind address or bind device.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tcp_send_buffer_size: None,
            tcp_recv_buffer_size: None,
            tcp_nodelay: false,
            conn_timeout: 5,
            bind_addr: None,
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            bind_device: None,
        }
    }

    /// Sets the value later passed to the socket's `set_nodelay`.
    pub const fn set_tcp_nodelay(&mut self, nodelay: bool) {
        self.tcp_nodelay = nodelay;
    }

    /// Requests a specific send buffer size for the socket.
    pub const fn set_tcp_send_buffer_size(&mut self, size: u32) {
        self.tcp_send_buffer_size = Some(size);
    }

    /// Requests a specific receive buffer size for the socket.
    pub const fn set_tcp_recv_buffer_size(&mut self, size: u32) {
        self.tcp_recv_buffer_size = Some(size);
    }

    /// Stores the connection timeout and returns `self` for chaining.
    ///
    /// NOTE(review): the unit is presumably seconds — confirm against the code
    /// that consumes [`NetworkOptions::connection_timeout`].
    pub const fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self {
        self.conn_timeout = timeout;
        self
    }

    /// Returns the stored connection timeout.
    #[must_use]
    pub const fn connection_timeout(&self) -> u64 {
        self.conn_timeout
    }

    /// Sets the local address the socket is bound to before connecting and
    /// returns `self` for chaining.
    pub const fn set_bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
        self.bind_addr = Some(bind_addr);
        self
    }

    /// Returns the configured local bind address, if any.
    #[must_use]
    pub const fn bind_addr(&self) -> Option<SocketAddr> {
        self.bind_addr
    }

    /// Sets the name of the network device the socket is bound to and returns
    /// `self` for chaining. Only available on Android, Fuchsia and Linux.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))
    )]
    pub fn set_bind_device(&mut self, bind_device: &str) -> &mut Self {
        self.bind_device = Some(bind_device.to_string());
        self
    }
}
253
254fn configure_tcp_socket(socket: &TcpSocket, network_options: &NetworkOptions) -> io::Result<()> {
255 socket.set_nodelay(network_options.tcp_nodelay)?;
256
257 if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
258 socket.set_send_buffer_size(send_buff_size)?;
259 }
260 if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
261 socket.set_recv_buffer_size(recv_buffer_size)?;
262 }
263
264 if let Some(bind_addr) = network_options.bind_addr {
265 socket.bind(bind_addr)?;
266 }
267
268 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
269 {
270 if let Some(bind_device) = &network_options.bind_device {
271 socket.bind_device(Some(bind_device.as_bytes()))?;
272 }
273 }
274
275 Ok(())
276}
277
278pub async fn connect_socket_addr(
287 addr: SocketAddr,
288 network_options: NetworkOptions,
289) -> io::Result<TcpStream> {
290 let socket = match addr {
291 SocketAddr::V4(_) => TcpSocket::new_v4()?,
292 SocketAddr::V6(_) => TcpSocket::new_v6()?,
293 };
294
295 configure_tcp_socket(&socket, &network_options)?;
296 socket.connect(addr).await
297}
298
299pub async fn default_socket_connect(
310 host: String,
311 network_options: NetworkOptions,
312) -> io::Result<TcpStream> {
313 let addrs = lookup_host(host).await?;
314 let mut last_err = None;
315
316 for addr in addrs {
317 match connect_socket_addr(addr, network_options.clone()).await {
318 Ok(stream) => return Ok(stream),
319 Err(err) => {
320 last_err = Some(err);
321 }
322 }
323 }
324
325 Err(last_err.unwrap_or_else(|| {
326 io::Error::new(
327 io::ErrorKind::InvalidInput,
328 "could not resolve to any address",
329 )
330 }))
331}
332
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
    use super::TlsConfiguration;
    use super::{NetworkOptions, connect_socket_addr, default_socket_connect};
    use std::io;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    use tokio::net::TcpListener;

    /// Returns fresh options whose only change from `new()` is the bind address.
    fn options_bound_to(local: SocketAddr) -> NetworkOptions {
        let mut options = NetworkOptions::new();
        options.set_bind_addr(local);
        options
    }

    /// IPv6 loopback with port 0, flow info 0 and scope id 0.
    fn ipv6_loopback_port_zero() -> SocketAddr {
        SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 0, 0, 0))
    }

    #[cfg(all(
        feature = "use-rustls-no-provider",
        any(feature = "use-rustls-aws-lc", feature = "use-rustls-ring")
    ))]
    #[test]
    fn default_rustls_returns_rustls_variant() {
        let config = TlsConfiguration::default_rustls();
        assert!(matches!(config, TlsConfiguration::Rustls(_)));
    }

    #[cfg(feature = "use-native-tls")]
    #[test]
    fn default_native_returns_native_variant() {
        let config = TlsConfiguration::default_native();
        assert!(matches!(config, TlsConfiguration::Native));
    }

    #[cfg(feature = "use-native-tls")]
    #[test]
    fn simple_native_returns_simple_native_variant() {
        let identity = vec![1, 2, 3];
        let password = String::from("secret");
        let config = TlsConfiguration::simple_native(
            Vec::from("Test CA"),
            Some((identity.clone(), password.clone())),
        );

        let TlsConfiguration::SimpleNative { ca, client_auth } = config else {
            panic!("expected the SimpleNative variant");
        };
        assert_eq!(ca, b"Test CA");
        assert_eq!(client_auth, Some((identity, password)));
    }

    #[tokio::test]
    async fn connect_socket_addr_succeeds_with_ipv4_bind_addr() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        // Accept exactly one connection and report who connected.
        let accepted_peer = tokio::spawn(async move {
            let (conn, remote) = listener.accept().await.unwrap();
            drop(conn);
            remote
        });

        let local = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
        let client = connect_socket_addr(server_addr, options_bound_to(local))
            .await
            .unwrap();
        let client_addr = client.local_addr().unwrap();
        assert!(client_addr.ip().is_loopback());
        drop(client);

        let remote = accepted_peer.await.unwrap();
        assert_eq!(remote.ip(), client_addr.ip());
    }

    #[tokio::test]
    async fn connect_socket_addr_returns_error_for_mismatched_bind_addr_family() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        // IPv4 destination combined with an IPv6 local bind address.
        let options = options_bound_to(ipv6_loopback_port_zero());

        let failure = connect_socket_addr(server_addr, options)
            .await
            .unwrap_err();
        assert_ne!(failure.kind(), io::ErrorKind::WouldBlock);
    }

    #[tokio::test]
    async fn connect_socket_addr_succeeds_with_ipv6_bind_addr() {
        // If the IPv6 loopback cannot be bound on this host, the test is a no-op.
        let Ok(listener) = TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await else {
            return;
        };
        let server_addr = listener.local_addr().unwrap();

        let accepted_peer = tokio::spawn(async move {
            let (conn, remote) = listener.accept().await.unwrap();
            drop(conn);
            remote
        });

        let options = options_bound_to(ipv6_loopback_port_zero());

        let client = connect_socket_addr(server_addr, options).await.unwrap();
        let client_addr = client.local_addr().unwrap();
        assert_eq!(client_addr.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
        drop(client);

        let remote = accepted_peer.await.unwrap();
        assert_eq!(remote.ip(), client_addr.ip());
    }

    #[tokio::test]
    async fn default_socket_connect_still_connects_without_bind_addr() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let accepted = tokio::spawn(async move {
            let (conn, _) = listener.accept().await.unwrap();
            drop(conn);
        });

        let client = default_socket_connect(server_addr.to_string(), NetworkOptions::new())
            .await
            .unwrap();
        assert!(client.local_addr().unwrap().ip().is_loopback());
        drop(client);
        accepted.await.unwrap();
    }

    #[test]
    fn bind_addr_returns_configured_socket_addr() {
        let wanted = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1883));
        let options = options_bound_to(wanted);

        assert_eq!(options.bind_addr(), Some(wanted));
    }
}