use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::connection::tcp::{connect_happy_eyeballs, connect_tcp_addr};
#[cfg(feature = "native-tls")]
use crate::connection::tls::build_native_tls_connector_for_protocols;
#[cfg(feature = "rustls")]
use crate::connection::tls::connect_async_tls_with_config;
#[cfg(feature = "native-tls")]
use crate::connection::tls::connect_native_tls_with_connector;
use crate::dns::{DnsCache, DnsConfig};
use crate::error::{Error, ErrorKind, Result};
use crate::proxy::{Proxy, ProxyAuth, build_http_connect_request};
use crate::request::TimeoutConfig;
use crate::tls::TlsConfig;
use crate::url::Url;
/// Marker trait for any bidirectional async byte stream that can be boxed and
/// sent across threads: it bundles `AsyncRead + AsyncWrite + Unpin + Send`
/// into a single name usable as a trait object.
pub(super) trait IoStream: AsyncRead + AsyncWrite + Unpin + Send {}
/// Blanket impl: every type that satisfies the bounds is an `IoStream`.
impl<T> IoStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
/// Type-erased, heap-allocated stream returned by the connection helpers in
/// this module.
pub(super) type BoxedStream = Box<dyn IoStream>;
/// Returns a stream whose reads yield the bytes of `prefix` first and only
/// afterwards read from `stream`.
///
/// When `prefix` is empty no wrapper is needed, so `stream` is returned as-is.
pub(super) fn with_read_prefix(stream: BoxedStream, prefix: Vec<u8>) -> BoxedStream {
    if !prefix.is_empty() {
        Box::new(PrefixedStream {
            inner: stream,
            prefix,
            offset: 0,
        })
    } else {
        stream
    }
}
/// Stream wrapper that serves the bytes in `prefix` to readers before
/// delegating reads to `inner`. Writes always go straight to `inner`.
struct PrefixedStream {
// Underlying stream: target of every write, and of reads once `prefix` has
// been fully handed out.
inner: BoxedStream,
// Bytes returned to readers ahead of anything read from `inner`.
prefix: Vec<u8>,
// Number of `prefix` bytes already returned by `poll_read`.
offset: usize,
}
impl AsyncRead for PrefixedStream {
    /// Serves unread prefix bytes first; once the prefix is exhausted every
    /// read is forwarded to the wrapped stream.
    ///
    /// A single call never mixes prefix bytes with bytes from the wrapped
    /// stream: while prefix bytes remain, it returns at most those.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        match this.prefix.get(this.offset..) {
            Some(pending) if !pending.is_empty() => {
                // Copy as much of the remaining prefix as the caller's buffer holds.
                let count = pending.len().min(buf.len());
                buf[..count].copy_from_slice(&pending[..count]);
                this.offset += count;
                Poll::Ready(Ok(count))
            }
            _ => Pin::new(&mut this.inner).poll_read(cx, buf),
        }
    }
}
/// The prefix only affects reads; every write-side operation is forwarded
/// unchanged to the wrapped stream.
impl AsyncWrite for PrefixedStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let inner = Pin::new(&mut self.get_mut().inner);
        inner.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = Pin::new(&mut self.get_mut().inner);
        inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = Pin::new(&mut self.get_mut().inner);
        inner.poll_close(cx)
    }
}
pub(super) async fn connect_h2_stream(
url: &Url,
timeout_config: TimeoutConfig,
tls_config: &TlsConfig,
prior_knowledge_h2c: bool,
dns_cache: Arc<DnsCache>,
dns_config: DnsConfig,
local_addr: Option<SocketAddr>,
proxy: Option<&Proxy>,
) -> Result<BoxedStream> {
let stream = if let Some(proxy) = proxy {
match proxy {
Proxy::Http { addr, auth } => {
connect_via_http_proxy_tunnel(
*addr,
auth.as_ref(),
url.host(),
url.effective_port(),
timeout_config,
local_addr,
)
.await?
}
Proxy::Socks5 { addr, auth } => {
connect_via_socks5_proxy(
*addr,
auth.as_ref(),
url.host(),
url.effective_port(),
timeout_config,
local_addr,
)
.await?
}
}
} else {
let addrs = dns_cache.resolve_socket_addrs(url.host(), url.effective_port(), dns_config)?;
let (primary, fallback): (Vec<_>, Vec<_>) = addrs.into_iter().partition(|a| a.is_ipv6());
let (primary, fallback) = if primary.is_empty() {
(fallback, vec![])
} else {
(primary, fallback)
};
connect_happy_eyeballs(primary, fallback, local_addr, timeout_config.connect).await?
};
match url.scheme() {
"https" => connect_tls_h2(stream, url, tls_config, timeout_config).await,
"http" if prior_knowledge_h2c => Ok(Box::new(stream)),
"http" => Err(Error::new(
ErrorKind::Transport,
"h2c prior knowledge is required for cleartext http2",
)),
_ => Err(Error::new(
ErrorKind::Transport,
"http2 transport requires https or explicit h2c prior knowledge over http",
)),
}
}
/// Connects to an HTTP proxy at `proxy_addr` and asks it, via `CONNECT`, to
/// open a tunnel to `target_host:target_port`.
///
/// On success the returned TCP stream is positioned immediately after the
/// proxy's response head, so the next bytes read belong to the tunnel.
///
/// # Errors
///
/// * A timeout error from `connect_tcp_addr` is passed through unchanged; any
///   other connect failure is wrapped as `Transport`.
/// * Write/read failures and timeouts come from `with_timeout_io`.
/// * `Transport` when the response head is larger than the internal limit,
///   when its status line is not a `200`, or when the proxy closes the
///   connection before finishing the response head.
async fn connect_via_http_proxy_tunnel(
    proxy_addr: SocketAddr,
    auth: Option<&ProxyAuth>,
    target_host: &str,
    target_port: u16,
    timeout_config: TimeoutConfig,
    local_addr: Option<SocketAddr>,
) -> Result<async_net::TcpStream> {
    // Cap on the response head so a misbehaving proxy cannot make us buffer
    // an unbounded amount of data while we wait for the blank line.
    const MAX_RESPONSE_HEAD_BYTES: usize = 64 * 1024;
    const HEAD_TERMINATOR: &[u8] = b"\r\n\r\n";

    let mut stream = connect_tcp_addr(proxy_addr, local_addr, timeout_config.connect)
        .await
        .map_err(|err| {
            if err.kind() == &ErrorKind::Timeout {
                err
            } else {
                Error::with_source(ErrorKind::Transport, "failed to connect to HTTP proxy", err)
            }
        })?;

    let connect_request = build_http_connect_request(target_host, target_port, auth);
    with_timeout_io(
        timeout_config.write,
        stream.write_all(connect_request.as_bytes()),
        "proxy connect write timed out",
    )
    .await?;

    // Read one byte at a time on purpose: the raw stream is handed back to the
    // caller, so consuming anything past the blank line would swallow bytes
    // that belong to the tunnelled connection.
    let mut response: Vec<u8> = Vec::with_capacity(256);
    let mut byte = [0u8; 1];
    let mut head_complete = false;
    while response.len() < MAX_RESPONSE_HEAD_BYTES {
        let n = with_timeout_io(
            timeout_config.read,
            stream.read(&mut byte),
            "proxy connect read timed out",
        )
        .await?;
        if n == 0 {
            // EOF: the proxy closed the connection.
            break;
        }
        response.push(byte[0]);
        if response.ends_with(HEAD_TERMINATOR) {
            head_complete = true;
            break;
        }
    }

    if !head_complete && response.len() >= MAX_RESPONSE_HEAD_BYTES {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("HTTP proxy CONNECT response head exceeds {MAX_RESPONSE_HEAD_BYTES} bytes"),
        ));
    }

    let response_str = String::from_utf8_lossy(&response);
    let success = response_str.starts_with("HTTP/1.1 200")
        || response_str.starts_with("HTTP/1.0 200")
        || response_str.starts_with("HTTP/2 200")
        || response_str.starts_with("HTTP/2.0 200");
    if !success {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("HTTP proxy CONNECT failed: {}", response_str.trim()),
        ));
    }
    if !head_complete {
        // A 200 status line followed by EOF is a dead connection, not a tunnel.
        return Err(Error::new(
            ErrorKind::Transport,
            "HTTP proxy closed the connection before completing its CONNECT response",
        ));
    }
    Ok(stream)
}
async fn connect_via_socks5_proxy(
proxy_addr: SocketAddr,
auth: Option<&ProxyAuth>,
target_host: &str,
target_port: u16,
timeout_config: TimeoutConfig,
local_addr: Option<SocketAddr>,
) -> Result<async_net::TcpStream> {
let mut stream = connect_tcp_addr(proxy_addr, local_addr, timeout_config.connect)
.await
.map_err(|err| {
if err.kind() == &ErrorKind::Timeout {
err
} else {
Error::with_source(
ErrorKind::Transport,
"failed to connect to SOCKS5 proxy",
err,
)
}
})?;
let methods: &[u8] = if auth.is_some() {
&[0x00, 0x02]
} else {
&[0x00]
};
let greeting = [&[0x05u8, methods.len() as u8][..], methods].concat();
with_timeout_io(
timeout_config.write,
stream.write_all(&greeting),
"socks5 write timed out",
)
.await?;
let mut method_resp = [0u8; 2];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut method_resp),
"socks5 read timed out",
)
.await?;
if method_resp[0] != 0x05 {
return Err(Error::new(
ErrorKind::Transport,
"SOCKS5 proxy returned unexpected version",
));
}
match method_resp[1] {
0x00 => {} 0x02 => {
let auth = auth.ok_or_else(|| {
Error::new(ErrorKind::Transport, "SOCKS5 proxy requires authentication")
})?;
let mut auth_req = vec![0x01u8];
auth_req.push(auth.username.len() as u8);
auth_req.extend_from_slice(auth.username.as_bytes());
auth_req.push(auth.password.len() as u8);
auth_req.extend_from_slice(auth.password.as_bytes());
with_timeout_io(
timeout_config.write,
stream.write_all(&auth_req),
"socks5 write timed out",
)
.await?;
let mut auth_resp = [0u8; 2];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut auth_resp),
"socks5 read timed out",
)
.await?;
if auth_resp[1] != 0x00 {
return Err(Error::new(
ErrorKind::Transport,
"SOCKS5 proxy authentication failed",
));
}
}
0xFF => {
return Err(Error::new(
ErrorKind::Transport,
"SOCKS5 proxy rejected all auth methods",
));
}
m => {
return Err(Error::new(
ErrorKind::Transport,
format!("SOCKS5 proxy selected unknown auth method: {m}"),
));
}
}
let host_bytes = target_host.as_bytes();
let mut connect_req = vec![
0x05,
0x01,
0x00, 0x03, host_bytes.len() as u8,
];
connect_req.extend_from_slice(host_bytes);
connect_req.extend_from_slice(&target_port.to_be_bytes());
with_timeout_io(
timeout_config.write,
stream.write_all(&connect_req),
"socks5 write timed out",
)
.await?;
let mut resp_hdr = [0u8; 4];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut resp_hdr),
"socks5 read timed out",
)
.await?;
if resp_hdr[1] != 0x00 {
return Err(Error::new(
ErrorKind::Transport,
format!(
"SOCKS5 proxy connection failed with code: {:#x}",
resp_hdr[1]
),
));
}
match resp_hdr[3] {
0x01 => {
let mut _a = [0u8; 6];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut _a),
"socks5 read timed out",
)
.await?;
}
0x03 => {
let mut len_buf = [0u8; 1];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut len_buf),
"socks5 read timed out",
)
.await?;
let mut _a = vec![0u8; len_buf[0] as usize + 2];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut _a),
"socks5 read timed out",
)
.await?;
}
0x04 => {
let mut _a = [0u8; 18];
with_timeout_io(
timeout_config.read,
stream.read_exact(&mut _a),
"socks5 read timed out",
)
.await?;
}
t => {
return Err(Error::new(
ErrorKind::Transport,
format!("SOCKS5 proxy returned unknown address type: {t}"),
));
}
}
Ok(stream)
}
/// Wraps `stream` in TLS for an HTTP/2 connection to `url`, dispatching to the
/// backend named in `tls_config.backend` or, when that is `None`, to the
/// compile-time default picked by `default_backend`.
///
/// This definition is compiled when at least one of the `rustls`,
/// `native-tls` or `btls-backend` features is enabled.
#[cfg(any(feature = "rustls", feature = "native-tls", feature = "btls-backend"))]
async fn connect_tls_h2(
stream: async_net::TcpStream,
url: &Url,
tls_config: &TlsConfig,
timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
// Picks the backend used when the configuration does not name one.
// The cfg predicates below are mutually exclusive, so exactly one block is
// compiled in: `rustls` wins, then `native-tls`, then `btls-backend`.
fn default_backend() -> crate::tls::TlsBackend {
#[cfg(feature = "rustls")]
{
return crate::tls::TlsBackend::Rustls;
}
#[cfg(all(not(feature = "rustls"), feature = "native-tls"))]
{
return crate::tls::TlsBackend::Native;
}
#[cfg(all(
not(feature = "rustls"),
not(feature = "native-tls"),
feature = "btls-backend"
))]
{
return crate::tls::TlsBackend::Boring;
}
}
let backend = tls_config.backend.unwrap_or_else(default_backend);
// Only the arms whose feature is enabled are compiled.
// NOTE(review): presumably `TlsBackend` gates its variants on the same
// features, otherwise this match would not be exhaustive; confirm in
// `crate::tls`.
match backend {
#[cfg(feature = "rustls")]
crate::tls::TlsBackend::Rustls => {
connect_tls_h2_async_tls(stream, url, tls_config, timeout_config).await
}
#[cfg(feature = "native-tls")]
crate::tls::TlsBackend::Native => {
connect_tls_h2_native(stream, url, tls_config, timeout_config).await
}
#[cfg(feature = "btls-backend")]
crate::tls::TlsBackend::Boring => {
connect_tls_h2_boring(stream, url, tls_config, timeout_config).await
}
}
}
#[cfg(all(
not(feature = "rustls"),
not(feature = "native-tls"),
not(feature = "btls-backend")
))]
async fn connect_tls_h2(
_stream: async_net::TcpStream,
_url: &Url,
_tls_config: &TlsConfig,
_timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
Err(Error::new(
ErrorKind::Transport,
"https support requires a tls feature",
))
}
/// Rustls path: builds the HTTP/2 client configuration from `tls_config` and
/// performs the handshake over `stream` for `url`'s host, passing
/// `timeout_config.connect` along to the handshake helper.
///
/// A failure to build the configuration is reported as a `Transport` error
/// carrying the builder's message.
#[cfg(feature = "rustls")]
async fn connect_tls_h2_async_tls(
    stream: async_net::TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    let client_config = match tls_config.build_h2_client_config() {
        Ok(built) => built,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    let handshake =
        connect_async_tls_with_config(stream, url.host(), client_config, timeout_config.connect);
    let secured = handshake.await?;
    Ok(Box::new(secured))
}
/// native-tls path: validates the HTTP/2 ALPN settings in `tls_config`, builds
/// a connector for the resulting protocol list, and performs the handshake
/// over `stream` for `url`'s host, passing `timeout_config.connect` along to
/// the handshake helper.
///
/// A failed ALPN validation is reported as a `Transport` error carrying the
/// validator's message.
#[cfg(feature = "native-tls")]
async fn connect_tls_h2_native(
    stream: async_net::TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    let alpn = match tls_config.validate_h2_alpn() {
        Ok(protocols) => protocols,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    let connector = build_native_tls_connector_for_protocols(tls_config, &alpn)?;
    let handshake =
        connect_native_tls_with_connector(stream, url.host(), connector, timeout_config.connect);
    let secured = handshake.await?;
    Ok(Box::new(secured))
}
/// BoringSSL path: validates the HTTP/2 ALPN settings in `tls_config`, builds
/// a connector for the resulting protocol list, and performs the handshake
/// over `stream` for `url`'s host, passing `tls_config` and
/// `timeout_config.connect` along to the handshake helper.
///
/// A failed ALPN validation is reported as a `Transport` error carrying the
/// validator's message.
#[cfg(feature = "btls-backend")]
async fn connect_tls_h2_boring(
    stream: async_net::TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    let alpn = match tls_config.validate_h2_alpn() {
        Ok(protocols) => protocols,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    let connector =
        crate::connection::tls::build_boring_tls_connector_for_protocols(tls_config, &alpn)?;
    let handshake = crate::connection::tls::connect_boring_tls_with_connector(
        stream,
        url.host(),
        connector,
        tls_config,
        timeout_config.connect,
    );
    let secured = handshake.await?;
    Ok(Box::new(secured))
}
/// Awaits `future`, giving up once `timeout` has elapsed.
///
/// With `timeout == None` the future is simply awaited and its output
/// returned in `Ok`. Otherwise the future is raced against a timer: if the
/// future finishes first its output is returned in `Ok`, and if the timer
/// finishes first the result is an error built from `kind` and `message`.
pub(super) async fn with_timeout<F, T>(
    timeout: Option<Duration>,
    future: F,
    kind: ErrorKind,
    message: &'static str,
) -> Result<T>
where
    F: std::future::Future<Output = T>,
{
    let duration = match timeout {
        Some(duration) => duration,
        None => return Ok(future.await),
    };

    let completed = async move { Ok(future.await) };
    let expired = async move {
        async_io::Timer::after(duration).await;
        Err(Error::new(kind, message))
    };
    // The work future is passed first, matching the original argument order.
    futures_lite::future::or(completed, expired).await
}
pub(super) async fn with_timeout_io<F, T>(
timeout: Option<Duration>,
future: F,
message: &'static str,
) -> Result<T>
where
F: std::future::Future<Output = std::io::Result<T>>,
{
match with_timeout(timeout, future, ErrorKind::Timeout, message).await {
Ok(result) => result.map_err(|err| Error::with_source(ErrorKind::Transport, message, err)),
Err(err) => Err(err),
}
}