use std::future::Future;
use std::pin::Pin;
use crate::body::RequestBodyLocal;
use crate::error::Error;
use crate::pool::PooledConnection;
use crate::proxy::ProxyConfig;
use crate::runtime::{ConnectorLocal, RuntimeLocal, SocketConfig};
use super::HttpEngineLocal;
impl<R: RuntimeLocal, C: ConnectorLocal + Clone> HttpEngineLocal<R, C> {
    /// Opens a connection to `target_authority` through `proxy`.
    ///
    /// The proxy's own authority is resolved and dialled, honouring the engine's
    /// local address, Linux interface, keepalive and fast-open settings. Then:
    ///
    /// * SOCKS5 / SOCKS4a: a handshake asks the proxy to connect to the target,
    ///   after which TLS (`is_https`) or plaintext HTTP runs over the stream.
    /// * any other scheme with `is_https`: an HTTP `CONNECT` tunnel is opened and
    ///   TLS is negotiated with the target through it.
    /// * any other scheme without `is_https`: plaintext HTTP is spoken to the
    ///   proxy itself.
    ///
    /// # Errors
    /// Propagates resolution, socket, proxy-handshake and TLS failures.
    pub(super) async fn connect_via_proxy_local(
        &self,
        proxy: &ProxyConfig,
        target_authority: &http::uri::Authority,
        is_https: bool,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error> {
        let proxy_authority = proxy.authority()?;
        let default_port = proxy.default_port();
        let proxy_addr = self
            .core
            .resolve_authority(proxy_authority, default_port)
            .await?;
        let tcp_stream = if let Some(local_addr) = self.core.local_address {
            self.connector
                .connect_bound(proxy_addr, local_addr)
                .await
                .map_err(Error::Io)?
        } else {
            self.connector
                .connect(proxy_addr)
                .await
                .map_err(Error::Io)?
        };
        #[cfg(target_os = "linux")]
        if let Some(ref iface) = self.core.interface {
            tcp_stream.bind_device(iface).map_err(Error::Io)?;
        }
        if let Some(time) = self.core.tcp_keepalive {
            tcp_stream
                .set_keepalive(
                    time,
                    self.core.tcp_keepalive_interval,
                    self.core.tcp_keepalive_retries,
                )
                .map_err(Error::Io)?;
        }
        if self.core.tcp_fast_open {
            // Best effort: a failure to enable fast open is deliberately ignored.
            let _ = tcp_stream.set_fast_open();
        }
        let socks5 = proxy.scheme == crate::proxy::ProxyScheme::Socks5;
        if socks5 || proxy.scheme == crate::proxy::ProxyScheme::Socks4 {
            // NOTE(review): `Authority::host()` keeps the brackets of an IPv6
            // literal (e.g. "[::1]"); confirm the SOCKS handshakes and TLS
            // server-name handling accept that form.
            let host = target_authority.host();
            let port = target_authority
                .port_u16()
                .unwrap_or(if is_https { 443 } else { 80 });
            let tcp_stream =
                self.socks_handshake_blocking_local(tcp_stream, proxy, host, port, socks5)?;
            if is_https {
                self.connect_tls_local(tcp_stream, host).await
            } else {
                // Same h1/h2c selection as the plain-HTTP proxy path below, so
                // `http2_prior_knowledge` is honoured over SOCKS too.
                self.connect_plaintext_local(tcp_stream).await
            }
        } else if is_https {
            self.connect_tunnel_local(tcp_stream, proxy, target_authority)
                .await
        } else {
            self.connect_plaintext_local(tcp_stream).await
        }
    }

    /// Runs a SOCKS5 (`socks5 == true`) or SOCKS4a handshake on `tcp_stream`,
    /// asking the proxy to connect to `host:port`. Returns the stream converted
    /// back to the connector's stream type.
    ///
    /// The handshake is synchronous: the stream is converted to a std stream
    /// and the handshake call is not awaited, so it runs on the calling thread.
    /// When `connect_timeout` is set, it is applied as the read/write timeout
    /// for the duration of the handshake and cleared afterwards.
    ///
    /// NOTE(review): without `connect_timeout`, an unresponsive proxy can block
    /// this thread (and with it the local executor) indefinitely.
    fn socks_handshake_blocking_local(
        &self,
        tcp_stream: C::Stream,
        proxy: &ProxyConfig,
        host: &str,
        port: u16,
        socks5: bool,
    ) -> Result<C::Stream, Error> {
        let mut std_stream = self.connector.into_std_tcp(tcp_stream).map_err(Error::Io)?;
        if let Some(timeout) = self.core.connect_timeout {
            std_stream
                .set_read_timeout(Some(timeout))
                .map_err(Error::Io)?;
            std_stream
                .set_write_timeout(Some(timeout))
                .map_err(Error::Io)?;
        }
        if socks5 {
            crate::socks5::socks5_handshake(&mut std_stream, host, port, proxy.auth.as_ref())
                .map_err(Error::Io)?;
        } else {
            crate::socks4::socks4a_handshake(&mut std_stream, host, port, proxy.auth.as_ref())
                .map_err(Error::Io)?;
        }
        if self.core.connect_timeout.is_some() {
            std_stream.set_read_timeout(None).map_err(Error::Io)?;
            std_stream.set_write_timeout(None).map_err(Error::Io)?;
        }
        self.connector.from_std_tcp(std_stream).map_err(Error::Io)
    }

    /// Opens an HTTP `CONNECT` tunnel to `target_authority` over `tcp_stream`,
    /// which is already connected to `proxy`. Then negotiates TLS with the
    /// target through the tunnel.
    ///
    /// # Errors
    /// Fails on I/O errors, or if the proxy closes the connection
    /// mid-handshake. Also fails if the response head exceeds 8 KiB, or if the
    /// proxy answers with any status other than 200. TLS errors from
    /// `connect_tls_local` propagate.
    async fn connect_tunnel_local(
        &self,
        mut tcp_stream: C::Stream,
        proxy: &ProxyConfig,
        target_authority: &http::uri::Authority,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error> {
        use hyper::rt::{Read, Write};
        // Upper bound on the proxy's response head, so a misbehaving proxy cannot
        // make us buffer without limit.
        const MAX_RESPONSE_HEAD: usize = 8192;

        // CONNECT's request-target must be authority-form `host:port`
        // (RFC 9110 §9.3.6). This path only serves HTTPS, so a missing port
        // means 443. Building from `host()` also leaves out any userinfo.
        let target = format!(
            "{}:{}",
            target_authority.host(),
            target_authority.port_u16().unwrap_or(443)
        );
        let auth_line = proxy
            .connect_header(&target)
            .map(|auth_value| format!("Proxy-Authorization: {auth_value}\r\n"))
            .unwrap_or_default();
        let request =
            format!("CONNECT {target} HTTP/1.1\r\nHost: {target}\r\n{auth_line}\r\n").into_bytes();

        let mut written = 0;
        while written < request.len() {
            let n = std::future::poll_fn(|cx| {
                Pin::new(&mut tcp_stream).poll_write(cx, &request[written..])
            })
            .await
            .map_err(Error::Io)?;
            if n == 0 {
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "proxy closed connection during CONNECT handshake",
                )));
            }
            written += n;
        }
        // Push the request past any buffering layer before waiting for a reply.
        std::future::poll_fn(|cx| Pin::new(&mut tcp_stream).poll_flush(cx))
            .await
            .map_err(Error::Io)?;

        // Read one byte at a time so that nothing beyond the blank line ending
        // the response head is consumed. Any later bytes belong to the TLS
        // session that follows on the same stream.
        let mut head = Vec::with_capacity(256);
        loop {
            let mut byte = [0u8; 1];
            let mut read_buf = hyper::rt::ReadBuf::new(&mut byte);
            std::future::poll_fn(|cx| Pin::new(&mut tcp_stream).poll_read(cx, read_buf.unfilled()))
                .await
                .map_err(Error::Io)?;
            if read_buf.filled().is_empty() {
                return Err(Error::Other("proxy closed connection".into()));
            }
            head.push(byte[0]);
            if head.ends_with(b"\r\n\r\n") {
                break;
            }
            if head.len() > MAX_RESPONSE_HEAD {
                return Err(Error::Other("CONNECT response too large".into()));
            }
        }

        let head_text = String::from_utf8_lossy(&head);
        let status_line = head_text
            .lines()
            .next()
            .ok_or_else(|| Error::Other("empty CONNECT response".into()))?;
        let status_code = super::connect::parse_connect_status(status_line)?;
        if status_code != 200 {
            return Err(Error::Other(
                format!("CONNECT tunnel failed: {status_line}").into(),
            ));
        }
        self.connect_tls_local(tcp_stream, target_authority.host())
            .await
    }

    /// Starts a plaintext HTTP connection on `stream`. Uses HTTP/2 prior
    /// knowledge if the engine is configured for it, and HTTP/1.1 otherwise.
    pub(super) fn connect_plaintext_local<S>(
        &self,
        stream: S,
    ) -> Pin<Box<dyn Future<Output = Result<PooledConnection<RequestBodyLocal>, Error>> + '_>>
    where
        S: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    {
        self.connect_plaintext_local_with_hint(stream, false)
    }

    /// Like `connect_plaintext_local`, but `force_h2c` also selects HTTP/2
    /// prior knowledge regardless of the engine configuration.
    ///
    /// The two protocol paths are different future types, so the result is
    /// boxed to give both branches a single return type.
    pub(super) fn connect_plaintext_local_with_hint<S>(
        &self,
        stream: S,
        force_h2c: bool,
    ) -> Pin<Box<dyn Future<Output = Result<PooledConnection<RequestBodyLocal>, Error>> + '_>>
    where
        S: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    {
        if self.core.http2_prior_knowledge || force_h2c {
            Box::pin(self.connect_h2_prior_knowledge_local(stream))
        } else {
            Box::pin(self.connect_h1_local(stream))
        }
    }

    /// Performs an HTTP/1.1 client handshake on `stream` and spawns the
    /// connection driver on the local runtime.
    ///
    /// When the driver finishes through `without_shutdown`, the underlying IO
    /// and any already-read bytes are handed to the connection's upgrade
    /// handle. If the driver errors, the handle is marked failed.
    pub(super) async fn connect_h1_local<S>(
        &self,
        stream: S,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error>
    where
        S: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    {
        let (sender, conn) = hyper::client::conn::http1::handshake(stream).await?;
        let handle = crate::upgrade::UpgradeHandleLocal::new();
        let handle_clone = handle.clone();
        R::spawn_local(async move {
            match conn.without_shutdown().await {
                Ok(parts) => {
                    let upgraded = crate::upgrade::UpgradedLocal::new(parts.io, parts.read_buf);
                    handle_clone.fulfill(upgraded);
                }
                Err(_) => {
                    handle_clone.fail();
                }
            }
        });
        let mut pooled = PooledConnection::new_h1(sender);
        pooled.upgrade_handle_local = Some(handle);
        Ok(pooled)
    }

    /// Performs an HTTP/2 client handshake on `stream` with no protocol
    /// negotiation. This covers h2c prior knowledge on plaintext, and TLS
    /// streams where ALPN already chose `h2`.
    ///
    /// The engine's HTTP/2 settings are applied when configured. The connection
    /// driver is spawned on the local runtime, and its result is discarded.
    pub(super) async fn connect_h2_prior_knowledge_local<S>(
        &self,
        stream: S,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error>
    where
        S: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    {
        let mut builder = hyper::client::conn::http2::Builder::new(
            crate::runtime::executor::completion_executor::<R>(),
        );
        if let Some(ref h2) = self.core.http2 {
            h2.apply(&mut builder);
        }
        let (sender, conn) = builder.handshake(stream).await?;
        R::spawn_local(async move {
            let _ = conn.await;
        });
        Ok(PooledConnection::new_h2(sender))
    }

    /// Negotiates TLS for `host` over `tcp_stream` with the engine's rustls
    /// connector. Then runs HTTP/2 if ALPN selected `h2`, and HTTP/1.1
    /// otherwise.
    ///
    /// The returned connection records the TLS info and the TLS handshake
    /// duration. The duration is measured before the HTTP handshake.
    ///
    /// # Errors
    /// `Error::Tls` if no TLS connector is configured or the TLS handshake
    /// fails. HTTP handshake errors propagate.
    #[cfg(all(feature = "rustls", feature = "compio"))]
    pub(super) async fn connect_tls_local(
        &self,
        tcp_stream: C::Stream,
        host: &str,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error> {
        use crate::tls::TlsConnectLocal;
        use std::time::Instant;
        let tls_start = Instant::now();
        let tls_connector = self
            .core
            .tls
            .as_ref()
            .ok_or_else(|| Error::Tls("no TLS connector configured".into()))?;
        let tls_stream =
            <crate::tls::RustlsConnector as TlsConnectLocal<C::Stream>>::connect_local(
                tls_connector,
                host,
                tcp_stream,
            )
            .await
            .map_err(|e| Error::Tls(Box::new(e)))?;
        let tls_duration = tls_start.elapsed();
        let alpn = crate::tls::RustlsConnector::negotiated_protocol(tls_stream.tls_connection());
        let tls_info = tls_stream.tls_info();
        // Reuse the plaintext protocol setup: once TLS is up, the handshakes are
        // identical apart from the TLS metadata recorded below.
        let mut pooled = match alpn {
            Some(crate::tls::AlpnProtocol::H2) => {
                self.connect_h2_prior_knowledge_local(tls_stream).await?
            }
            _ => self.connect_h1_local(tls_stream).await?,
        };
        pooled.tls_info = Some(tls_info);
        pooled.tls_handshake_duration = Some(tls_duration);
        Ok(pooled)
    }

    /// Stub for builds with rustls but without compio. This path always fails,
    /// because TLS over `!Send` streams is only implemented for compio.
    #[cfg(all(feature = "rustls", not(feature = "compio")))]
    pub(super) async fn connect_tls_local(
        &self,
        _tcp_stream: C::Stream,
        _host: &str,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error> {
        Err(Error::Tls(
            "TLS with !Send streams requires the compio feature".into(),
        ))
    }

    /// Stub for builds without a TLS backend. This path always fails.
    #[cfg(not(feature = "rustls"))]
    pub(super) async fn connect_tls_local(
        &self,
        _tcp_stream: C::Stream,
        _host: &str,
    ) -> Result<PooledConnection<RequestBodyLocal>, Error> {
        Err(Error::Tls(
            "HTTPS requires the `rustls` TLS backend feature".into(),
        ))
    }
}
#[cfg(all(test, feature = "compio"))]
mod compio_tests {
    use super::super::HttpEngineLocal;
    use crate::runtime::compio_rt::{CompioIo, CompioRuntime, TcpConnector};

    fn make_local_engine() -> HttpEngineLocal<CompioRuntime, TcpConnector> {
        HttpEngineLocal::<CompioRuntime, TcpConnector>::new()
    }

    fn make_h2_local_engine() -> HttpEngineLocal<CompioRuntime, TcpConnector> {
        HttpEngineLocal::<CompioRuntime, TcpConnector>::builder()
            .http2_prior_knowledge()
            .build_local()
            .unwrap()
    }

    /// Returns a connected `(client, server)` pair of loopback TCP sockets.
    async fn tcp_pair() -> (
        async_io::Async<std::net::TcpStream>,
        async_io::Async<std::net::TcpStream>,
    ) {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        listener.set_nonblocking(true).unwrap();
        let listener = async_io::Async::new(listener).unwrap();
        let client = async_io::Async::<std::net::TcpStream>::connect(addr)
            .await
            .unwrap();
        let (server, _) = listener.accept().await.unwrap();
        (client, server)
    }

    /// Spawns a detached task that reads and discards everything the client
    /// sends until EOF or an error. The server never replies.
    fn spawn_draining_server(server_tcp: async_io::Async<std::net::TcpStream>) {
        compio_runtime::spawn(async move {
            use futures_io::AsyncRead;
            let mut server = server_tcp;
            let mut buf = [0u8; 4096];
            while std::future::poll_fn(|cx| {
                std::pin::Pin::new(&mut server).poll_read(cx, &mut buf)
            })
            .await
            .unwrap_or(0)
                > 0
            {}
        })
        .detach();
    }

    /// Spawns a detached hyper HTTP/2 server on `server_tcp`. It answers every
    /// request with `Response::new` and an empty body.
    fn spawn_h2_server(server_tcp: async_io::Async<std::net::TcpStream>) {
        compio_runtime::spawn(async move {
            let io = CompioIo::new(server_tcp);
            let builder = hyper::server::conn::http2::Builder::new(
                crate::runtime::executor::completion_executor::<CompioRuntime>(),
            );
            let _ = builder
                .serve_connection(
                    io,
                    hyper::service::service_fn(|_req| async {
                        Ok::<_, std::convert::Infallible>(hyper::Response::new(
                            http_body_util::Empty::<bytes::Bytes>::new(),
                        ))
                    }),
                )
                .await;
        })
        .detach();
    }

    #[test]
    fn connect_h1_local_succeeds() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            spawn_draining_server(server_tcp);
            let engine = make_local_engine();
            let result = engine.connect_h1_local(CompioIo::new(client_tcp)).await;
            let pooled = result.expect("HTTP/1.1 handshake should succeed");
            assert!(matches!(pooled.conn, crate::pool::HttpConnection::H1(_)));
        });
    }

    #[test]
    fn connect_h2_prior_knowledge_local_succeeds() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            spawn_h2_server(server_tcp);
            let engine = make_local_engine();
            let result = engine
                .connect_h2_prior_knowledge_local(CompioIo::new(client_tcp))
                .await;
            let pooled = result.expect("HTTP/2 handshake should succeed");
            assert!(matches!(pooled.conn, crate::pool::HttpConnection::H2(_)));
        });
    }

    #[test]
    fn connect_plaintext_local_defaults_to_h1() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            spawn_draining_server(server_tcp);
            let engine = make_local_engine();
            let result = engine.connect_plaintext_local(CompioIo::new(client_tcp)).await;
            let pooled = result.expect("plaintext handshake should succeed");
            assert!(matches!(pooled.conn, crate::pool::HttpConnection::H1(_)));
        });
    }

    #[test]
    fn connect_plaintext_local_with_hint_false_uses_h1() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            spawn_draining_server(server_tcp);
            let engine = make_local_engine();
            let result = engine
                .connect_plaintext_local_with_hint(CompioIo::new(client_tcp), false)
                .await;
            let pooled = result.expect("plaintext handshake should succeed");
            assert!(matches!(pooled.conn, crate::pool::HttpConnection::H1(_)));
        });
    }

    #[test]
    fn connect_plaintext_local_with_hint_true_uses_h2() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            spawn_h2_server(server_tcp);
            let engine = make_local_engine();
            let result = engine
                .connect_plaintext_local_with_hint(CompioIo::new(client_tcp), true)
                .await;
            let pooled = result.expect("forced h2c handshake should succeed");
            assert!(matches!(pooled.conn, crate::pool::HttpConnection::H2(_)));
        });
    }

    #[test]
    fn connect_plaintext_local_with_http2_prior_knowledge_uses_h2() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            spawn_h2_server(server_tcp);
            let engine = make_h2_local_engine();
            let result = engine.connect_plaintext_local(CompioIo::new(client_tcp)).await;
            let pooled = result.expect("prior-knowledge handshake should succeed");
            assert!(matches!(pooled.conn, crate::pool::HttpConnection::H2(_)));
        });
    }

    #[test]
    fn connect_h1_local_server_closes_immediately() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            drop(server_tcp);
            let engine = make_local_engine();
            // The HTTP/1.1 handshake itself does no I/O, so it still succeeds.
            let result = engine.connect_h1_local(CompioIo::new(client_tcp)).await;
            assert!(result.is_ok());
        });
    }

    #[test]
    fn connect_h2_local_server_closes_immediately_fails() {
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let (client_tcp, server_tcp) = tcp_pair().await;
            drop(server_tcp);
            let engine = make_local_engine();
            // Whether the handshake reports the closed peer is not asserted. If
            // it does return a connection, that connection must be HTTP/2.
            if let Ok(pooled) = engine
                .connect_h2_prior_knowledge_local(CompioIo::new(client_tcp))
                .await
            {
                assert!(matches!(pooled.conn, crate::pool::HttpConnection::H2(_)));
            }
        });
    }

    #[cfg(feature = "rustls")]
    #[test]
    fn connect_tls_local_on_plain_stream_fails() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        compio_runtime::Runtime::new().unwrap().block_on(async {
            // Plain blocking server: read once (the ClientHello), then hang up,
            // so the TLS handshake cannot complete.
            let accept_handle = std::thread::spawn(move || {
                use std::io::Read;
                let (mut conn, _) = listener.accept().unwrap();
                let mut buf = [0u8; 4096];
                let _ = conn.read(&mut buf);
            });
            crate::tls::install_default_crypto_provider();
            let engine = HttpEngineLocal::<CompioRuntime, TcpConnector>::builder()
                .tls(crate::tls::RustlsConnector::with_webpki_roots())
                .build_local()
                .unwrap();
            let connector = TcpConnector;
            let stream = crate::runtime::ConnectorLocal::connect(&connector, addr)
                .await
                .unwrap();
            let result = engine.connect_tls_local(stream, "example.com").await;
            assert!(result.is_err());
            // Join instead of dropping the handle (which only detaches the
            // thread), so a panic in the server thread fails the test.
            accept_handle.join().expect("accept thread panicked");
        });
    }
}
#[cfg(all(test, feature = "tokio"))]
mod tokio_tests {
    use super::super::connect::parse_connect_status;

    /// The local CONNECT path shares its status-line parser with the `Send`
    /// engine; check that parser accepts a plain `200 OK` status line.
    #[test]
    fn connect_local_uses_same_parse_connect_status() {
        let status_line = "HTTP/1.1 200 OK";
        let parsed = parse_connect_status(status_line);
        assert_eq!(parsed.unwrap(), 200);
    }
}