use hyper::header::{CONNECTION, HeaderValue, UPGRADE};
use hyper::upgrade::OnUpgrade;
use hyper::{Method, Request};
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncRead, AsyncWrite};
pub mod ws;
pub use ws::MaskMode;
/// Marker trait for a bidirectional byte stream obtained from a protocol
/// upgrade: anything that is `AsyncRead + AsyncWrite + Send + Unpin`.
///
/// It exists so the combination can be named as a single trait object
/// (see [`BoxedUpgradedStream`]).
pub trait UpgradedStream: AsyncRead + AsyncWrite + Send + Unpin {}

/// Blanket implementation: every type meeting the supertrait bounds
/// (sized or not) is an [`UpgradedStream`].
impl<T: AsyncRead + AsyncWrite + Send + Unpin + ?Sized> UpgradedStream for T {}

/// Type-erased, heap-allocated [`UpgradedStream`].
pub type BoxedUpgradedStream = Box<dyn UpgradedStream>;
/// Copies bytes in both directions between `a` and `b` via
/// [`tokio::io::copy_bidirectional`].
///
/// # Returns
/// `(bytes copied a -> b, bytes copied b -> a)`.
///
/// # Errors
/// Propagates the I/O error reported by `copy_bidirectional`.
pub async fn pump(
    mut a: BoxedUpgradedStream,
    mut b: BoxedUpgradedStream,
) -> std::io::Result<(u64, u64)> {
    let (a_to_b, b_to_a) = tokio::io::copy_bidirectional(&mut a, &mut b).await?;
    Ok((a_to_b, b_to_a))
}
/// Relays a WebSocket connection between `inbound` (client side) and
/// `upstream` (server side).
///
/// * client -> server bytes go through [`ws::translate_masking`] with the
///   given `mode`;
/// * server -> client bytes are copied verbatim.
///
/// Each direction shuts down its write half once it finishes (shutdown
/// errors are ignored), and both directions are always driven to
/// completion before this function returns.
///
/// # Errors
/// Returns the client -> server error if there is one, otherwise the
/// server -> client result.
pub async fn pump_websocket(
    inbound: BoxedUpgradedStream,
    upstream: BoxedUpgradedStream,
    mode: MaskMode,
) -> std::io::Result<()> {
    use tokio::io::AsyncWriteExt;

    let (mut client_rx, mut client_tx) = tokio::io::split(inbound);
    let (mut backend_rx, mut backend_tx) = tokio::io::split(upstream);

    let to_backend = async {
        let outcome =
            ws::translate_masking(&mut client_rx, &mut backend_tx, mode).await;
        // Best effort: signal EOF to the upstream regardless of outcome.
        backend_tx.shutdown().await.ok();
        outcome
    };

    let to_client = async {
        let outcome = match tokio::io::copy(&mut backend_rx, &mut client_tx).await {
            Ok(_copied) => Ok(()),
            Err(err) => Err(err),
        };
        // Best effort: signal EOF to the client regardless of outcome.
        client_tx.shutdown().await.ok();
        outcome
    };

    match tokio::join!(to_backend, to_client) {
        (Err(forward_err), _) => Err(forward_err),
        (Ok(_), backward) => backward,
    }
}
/// Description of a detected protocol-upgrade request, as produced by
/// `detect_h1_upgrade` / `detect_h2_upgrade`.
#[derive(Clone)]
pub struct UpgradeRequest {
/// Requested protocol: the `Upgrade` header value for HTTP/1, or the
/// extended-CONNECT `Protocol` extension converted to a header value
/// for HTTP/2.
pub protocol: HeaderValue,
/// HTTP version the request arrived on.
pub inbound: InboundProtocol,
/// Pending hyper upgrade future for the inbound connection. It is wrapped
/// in `Arc<Mutex<Option<..>>>` so the struct can be `Clone`; presumably a
/// single consumer `take()`s it — confirm against callers.
pub on_upgrade: Arc<Mutex<Option<OnUpgrade>>>,
/// `Sec-WebSocket-Key` header copied from an HTTP/1 request, if present;
/// always `None` for HTTP/2 requests.
pub ws_key: Option<HeaderValue>,
}
/// HTTP version on which an upgrade request was received.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InboundProtocol {
/// HTTP/1.x `Connection: Upgrade` request (set by `detect_h1_upgrade`).
H1,
/// HTTP/2 extended CONNECT request (set by `detect_h2_upgrade`).
H2,
/// HTTP/3. Not constructed in the visible part of this file; presumably
/// set by the HTTP/3 listener — verify against callers.
H3,
}
pub fn detect_h1_upgrade<B>(
req: &mut Request<B>,
) -> Option<UpgradeRequest>
where
OnUpgrade: From<hyper::upgrade::OnUpgrade>,
{
if !connection_has_upgrade(req.headers().get(CONNECTION)?) {
return None;
}
let protocol = req.headers().get(UPGRADE)?.clone();
if protocol.as_bytes().is_empty() {
return None;
}
let ws_key = req
.headers()
.get(hyper::header::SEC_WEBSOCKET_KEY)
.cloned();
let on_upgrade = hyper::upgrade::on(req);
Some(UpgradeRequest {
protocol,
inbound: InboundProtocol::H1,
on_upgrade: Arc::new(Mutex::new(Some(on_upgrade))),
ws_key,
})
}
/// Detects an HTTP/2 extended CONNECT request.
///
/// The request must use the `CONNECT` method and carry a
/// `hyper::ext::Protocol` extension whose text is a valid header value.
/// On a match, hyper's upgrade future is extracted from `req`; `ws_key` is
/// always `None` on this path.
///
/// Returns `None` when the request does not qualify.
pub fn detect_h2_upgrade<B>(req: &mut Request<B>) -> Option<UpgradeRequest> {
    if req.method() != Method::CONNECT {
        return None;
    }

    let protocol = req
        .extensions()
        .get::<hyper::ext::Protocol>()
        .and_then(|p| HeaderValue::from_bytes(p.as_str().as_bytes()).ok())?;

    let pending = hyper::upgrade::on(req);
    let request = UpgradeRequest {
        protocol,
        inbound: InboundProtocol::H2,
        on_upgrade: Arc::new(Mutex::new(Some(pending))),
        ws_key: None,
    };
    Some(request)
}
/// Returns `true` when the comma-separated `Connection` header value `v`
/// contains the token `upgrade` (ASCII case-insensitive, surrounding
/// whitespace ignored). Values that are not visible ASCII yield `false`.
fn connection_has_upgrade(v: &HeaderValue) -> bool {
    match v.to_str() {
        Ok(text) => text
            .split(',')
            .map(str::trim)
            .any(|token| token.eq_ignore_ascii_case("upgrade")),
        Err(_) => false,
    }
}
/// Opens an HTTP/1.1 connection to `upstream`, sends `req` on it and, if
/// the upstream answers `101 Switching Protocols`, returns the upgraded
/// byte stream.
///
/// Supported schemes are `http` (plain TCP, default port 80) and, on Unix,
/// `unix` (the URI path is used as the socket path). A missing scheme is
/// treated as `http`.
///
/// # Returns
/// The response head plus a stream:
/// * on `101`, the upgraded connection;
/// * on any other status, a placeholder duplex half whose peer has already
///   been dropped. The response body is discarded.
///
/// # Errors
/// Fails on an unsupported scheme, a missing host, connect/handshake
/// errors, a failed request, or a failed upgrade hand-off.
pub async fn open_h1_upstream_tunnel<B>(
    upstream: &hyper::Uri,
    req: hyper::Request<B>,
) -> anyhow::Result<(hyper::http::response::Parts, BoxedUpgradedStream)>
where
    B: hyper::body::Body + Unpin + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    use anyhow::{Context, anyhow, bail};
    use hyper::client::conn::http1;
    use hyper_util::rt::TokioIo;

    let scheme = upstream.scheme_str().unwrap_or("http");
    let io: Box<dyn UpgradedStream> = match scheme {
        "http" => {
            let host = upstream.host().ok_or_else(|| {
                anyhow!("upgrade upstream missing host: {upstream}")
            })?;
            // `Uri::host()` keeps the square brackets of an IPv6 literal
            // (`[::1]`), which the socket-address resolver cannot parse.
            // Strip them for the connect call; keep the bracketed form for
            // the error message so `host:port` stays unambiguous.
            let connect_host = host
                .strip_prefix('[')
                .and_then(|h| h.strip_suffix(']'))
                .unwrap_or(host);
            let port = upstream.port_u16().unwrap_or(80);
            Box::new(
                tokio::net::TcpStream::connect((connect_host, port))
                    .await
                    .with_context(|| {
                        format!("connect upgrade upstream {host}:{port}")
                    })?,
            )
        }
        #[cfg(unix)]
        "unix" => {
            let path = upstream.path();
            Box::new(
                tokio::net::UnixStream::connect(path)
                    .await
                    .with_context(|| {
                        format!("connect upgrade upstream unix:{path}")
                    })?,
            )
        }
        s => bail!(
            "h1 upgrade tunnel: scheme `{s}://` is not yet supported \
             on the upgrade path (TLS / h2 follow up in #29)"
        ),
    };

    let (mut sender, conn) = http1::handshake(TokioIo::new(io))
        .await
        .context("h1 client handshake for upgrade")?;
    // The connection must be polled (with upgrade support) for the request
    // to make progress and for the upgrade hand-off to complete.
    let conn_task = tokio::spawn(conn.with_upgrades());

    let mut resp = sender
        .send_request(req)
        .await
        .context("sending upgrade request to upstream")?;

    let status = resp.status();
    if status != hyper::StatusCode::SWITCHING_PROTOCOLS {
        // Dropping a `JoinHandle` detaches the task; it does not cancel it.
        drop(conn_task);
        let (parts, _) = resp.into_parts();
        // Upstream refused the upgrade: hand back a stream whose peer is
        // already gone so the caller still gets the response head.
        let (_, never) = tokio::io::duplex(0);
        return Ok((parts, Box::new(never)));
    }

    let upgraded = hyper::upgrade::on(&mut resp)
        .await
        .map_err(|e| anyhow!("upstream upgrade handoff: {e}"))?;
    // The HTTP layer is no longer needed once the raw stream is ours.
    drop(sender);
    drop(conn_task);

    let (parts, _) = resp.into_parts();
    Ok((parts, h1_upgraded(upgraded)))
}
/// Opens a prior-knowledge cleartext HTTP/2 connection to `upstream` and
/// issues `req` as an extended CONNECT carrying `protocol` in the
/// `:protocol` pseudo-header.
///
/// The request method is forced to `CONNECT`. When `upstream` has an
/// authority, the request URI is rewritten to
/// `http://<authority><original path-and-query or "/">`; if the rewritten
/// URI cannot be assembled the original URI is kept.
///
/// # Returns
/// The response head plus a stream:
/// * on `200 OK`, the upgraded stream;
/// * on any other status, a placeholder duplex half whose peer has already
///   been dropped. The response body is discarded.
///
/// # Errors
/// Fails on a non-`http` scheme, a missing host, connect/handshake errors,
/// a non-ASCII `protocol`, a failed request, or a failed upgrade hand-off.
pub async fn open_h2c_upstream_tunnel<B>(
    upstream: &hyper::Uri,
    mut req: hyper::Request<B>,
    protocol: &hyper::header::HeaderValue,
) -> anyhow::Result<(hyper::http::response::Parts, BoxedUpgradedStream)>
where
    B: hyper::body::Body + Unpin + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    use anyhow::{Context, anyhow, bail};
    use hyper::client::conn::http2;
    use hyper_util::rt::{TokioExecutor, TokioIo};

    let scheme = upstream.scheme_str().unwrap_or("http");
    if scheme != "http" {
        bail!(
            "h2c upgrade tunnel: scheme `{scheme}://` is not valid; \
             use http:// for prior-knowledge h2 (TLS+h2 ALPN \
             follow-up)"
        );
    }

    let host = upstream
        .host()
        .ok_or_else(|| anyhow!("upstream missing host: {upstream}"))?;
    // `Uri::host()` keeps the square brackets of an IPv6 literal (`[::1]`),
    // which the socket-address resolver cannot parse. Strip them for the
    // connect call; keep the bracketed form for the error message.
    let connect_host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    let port = upstream.port_u16().unwrap_or(80);
    let tcp = tokio::net::TcpStream::connect((connect_host, port))
        .await
        .with_context(|| {
            format!("connect h2c upgrade upstream {host}:{port}")
        })?;

    let (mut sender, conn) = http2::Builder::new(TokioExecutor::new())
        .handshake(TokioIo::new(tcp))
        .await
        .context("h2c client handshake for upgrade")?;
    // The connection future must be polled for any stream to make progress;
    // the task is detached and ends when the connection does.
    let _conn_task = tokio::spawn(conn);

    *req.method_mut() = hyper::Method::CONNECT;
    let proto_str = protocol.to_str().map_err(|_| {
        anyhow!("protocol header is not ASCII: {protocol:?}")
    })?;
    // Build the extension from a borrowed string. Leaking a boxed copy to
    // obtain `&'static str` would lose memory on every upgrade request,
    // with the size chosen by the client.
    req.extensions_mut()
        .insert(hyper::ext::Protocol::from(proto_str));

    if let Some(authority) = upstream.authority() {
        let mut new_uri_parts = req.uri().clone().into_parts();
        new_uri_parts.authority = Some(authority.clone());
        new_uri_parts.scheme = Some(hyper::http::uri::Scheme::HTTP);
        // A URI with scheme + authority must also have a path.
        if new_uri_parts.path_and_query.is_none() {
            new_uri_parts.path_and_query =
                Some(hyper::http::uri::PathAndQuery::from_static("/"));
        }
        if let Ok(new_uri) = hyper::Uri::from_parts(new_uri_parts) {
            *req.uri_mut() = new_uri;
        }
    }

    let mut resp = sender
        .send_request(req)
        .await
        .context("sending h2c upgrade request to upstream")?;

    let status = resp.status();
    if status != hyper::StatusCode::OK {
        let (parts, _) = resp.into_parts();
        // Upstream refused the tunnel: hand back a stream whose peer is
        // already gone so the caller still gets the response head.
        let (_, never) = tokio::io::duplex(0);
        return Ok((parts, Box::new(never)));
    }

    let upgraded = hyper::upgrade::on(&mut resp)
        .await
        .map_err(|e| anyhow!("h2c upstream upgrade handoff: {e}"))?;
    let (parts, _) = resp.into_parts();
    // `h1_upgraded` just boxes a hyper `Upgraded`; it is not h1-specific.
    Ok((parts, h1_upgraded(upgraded)))
}
use bytes::Bytes;
/// Adapts a split HTTP/3 stream (`send` / `recv` halves) into a single
/// [`BoxedUpgradedStream`].
///
/// An in-memory `tokio::io::duplex` pipe (64 KiB buffer) is created; one
/// end is returned to the caller, the other is serviced by a spawned
/// driver task that:
/// * writes every chunk received from `recv` into the pipe, and
/// * forwards every byte read from the pipe to `send`.
///
/// The driver task is detached; it ends when either side finishes or
/// fails.
pub fn h3_upgraded_from_split<S, R>(
mut send: S,
mut recv: R,
) -> BoxedUpgradedStream
where
S: H3SendData + Send + 'static,
R: H3RecvData + Send + 'static,
{
// `adapter_io` is handed to the caller; `driver_io` stays with the task.
let (adapter_io, mut driver_io) = tokio::io::duplex(64 * 1024);
tokio::spawn(async move {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Scratch buffer for data travelling from the adapter to the H3 peer.
let mut write_buf = vec![0u8; 16 * 1024];
loop {
// NOTE(review): `select!` drops the losing future on every iteration;
// this assumes `recv_data_owned` is cancel-safe — confirm with the
// trait implementors.
tokio::select! {
// Peer -> adapter direction.
res = recv.recv_data_owned() => match res {
Ok(Some(chunk)) => {
// The adapter side has gone away: stop driving.
if driver_io.write_all(&chunk).await.is_err() {
break;
}
}
Ok(None) => {
// Peer finished sending: propagate EOF to the adapter, then keep
// forwarding whatever the adapter still writes until it reaches
// EOF (or an error), and finally finish the send side.
let _ = driver_io.shutdown().await;
let mut drain_buf = vec![0u8; 16 * 1024];
while let Ok(n) =
driver_io.read(&mut drain_buf).await
{
if n == 0 { break; }
if send.send_data_bytes(
Bytes::copy_from_slice(&drain_buf[..n]),
).await.is_err() {
break;
}
}
let _ = send.finish_stream().await;
break;
}
Err(e) => {
tracing::debug!("h3 upgrade: recv error: {e:?}");
break;
}
},
// Adapter -> peer direction.
n = driver_io.read(&mut write_buf) => match n {
Ok(0) => {
// Adapter closed its write side: finish the send half and stop.
// NOTE(review): unlike the `Ok(None)` arm above, this exits without
// continuing to receive, so data the peer sends afterwards is not
// forwarded — confirm this is intended.
let _ = send.finish_stream().await;
break;
}
Ok(n) => {
if send.send_data_bytes(
Bytes::copy_from_slice(&write_buf[..n]),
).await.is_err() {
break;
}
}
Err(_) => break,
},
}
}
});
Box::new(adapter_io)
}
/// Send half of an HTTP/3 stream as consumed by `h3_upgraded_from_split`.
#[async_trait::async_trait]
pub trait H3SendData {
/// Sends one chunk of payload bytes to the peer.
async fn send_data_bytes(
&mut self,
bytes: Bytes,
) -> Result<(), anyhow::Error>;
/// Signals that no more data will be sent on this stream.
async fn finish_stream(&mut self) -> Result<(), anyhow::Error>;
}
/// Receive half of an HTTP/3 stream as consumed by
/// `h3_upgraded_from_split`.
#[async_trait::async_trait]
pub trait H3RecvData {
/// Receives the next chunk of payload bytes.
///
/// `h3_upgraded_from_split` treats `Ok(None)` as "the peer has finished
/// sending" and `Err(_)` as a fatal stream error.
async fn recv_data_owned(
&mut self,
) -> Result<Option<Bytes>, anyhow::Error>;
}
/// Wraps a hyper [`Upgraded`](hyper::upgrade::Upgraded) connection in
/// `TokioIo` so it exposes tokio's `AsyncRead`/`AsyncWrite`, and boxes it
/// as a [`BoxedUpgradedStream`].
///
/// Despite the name, this file also uses it for the h2c upstream path.
pub fn h1_upgraded(
    upgraded: hyper::upgrade::Upgraded,
) -> BoxedUpgradedStream {
    let tokio_compatible = hyper_util::rt::TokioIo::new(upgraded);
    Box::new(tokio_compatible)
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// `pump` must relay bytes in both directions between two in-memory
/// duplex pipes and report the number of bytes moved each way.
///
/// Both payloads are checked (a -> b and b -> a), as is the byte-count
/// tuple returned by `pump`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pump_round_trips_bytes_both_ways_renamed() {
    let (a_inner, a_outer) = tokio::io::duplex(64);
    let (b_inner, b_outer) = tokio::io::duplex(64);
    let bidi = tokio::spawn(pump(Box::new(a_inner), Box::new(b_inner)));

    let peers = tokio::spawn({
        let mut a = a_outer;
        let mut b = b_outer;
        async move {
            // a -> b, then half-close so `b` observes EOF.
            a.write_all(b"hello-from-a").await.unwrap();
            a.shutdown().await.unwrap();
            let mut seen_by_b = Vec::new();
            b.read_to_end(&mut seen_by_b).await.unwrap();

            // b -> a, then half-close so `a` observes EOF. Reading on `a`
            // still works: `shutdown` only closed its write side.
            b.write_all(b"reply-from-b").await.unwrap();
            b.shutdown().await.unwrap();
            let mut seen_by_a = Vec::new();
            a.read_to_end(&mut seen_by_a).await.unwrap();

            (seen_by_b, seen_by_a)
        }
    });

    let (seen_by_b, seen_by_a) = peers.await.unwrap();
    assert_eq!(seen_by_b, b"hello-from-a");
    assert_eq!(seen_by_a, b"reply-from-b");

    let (a_to_b, b_to_a) = bidi
        .await
        .unwrap()
        .expect("pump should finish cleanly after both sides half-close");
    assert_eq!((a_to_b, b_to_a), (12, 12));
}
/// End-to-end check: a WebSocket client connects through the proxy to a
/// plain HTTP/1 WebSocket echo backend and gets its text message echoed.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn websocket_h1_round_trip_through_proxy() {
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio_tungstenite::{
accept_async, connect_async,
tungstenite::protocol::Message,
};
// Backend: echoes every text/binary message, stops on a close frame.
let backend = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = backend.local_addr().unwrap();
tokio::spawn(async move {
while let Ok((sock, _)) = backend.accept().await {
tokio::spawn(async move {
let mut ws =
accept_async(sock).await.unwrap();
while let Some(Ok(msg)) = ws.next().await {
if msg.is_text() || msg.is_binary() {
ws.send(msg).await.unwrap();
} else if msg.is_close() {
break;
}
}
});
}
});
// Proxy config. `{{addr}}` survives `format!` as a literal `{addr}`;
// presumably `TestServer::start` substitutes the listen address — verify.
let template = format!(
r#"
listener "tcp://{{addr}}" {{ }}
vhost "example.com" {{
location "/" {{
proxy {{ upstream "http://{backend_addr}" }}
}}
}}
"#,
);
let srv = crate::test::TestServer::start(&template).await;
let hypershunt_addr = srv.addr;
let url = format!("ws://{hypershunt_addr}/echo");
let (mut ws, response) =
connect_async(&url).await.expect("ws connect");
// The proxy must answer the HTTP/1 handshake with 101.
assert_eq!(response.status(), 101);
ws.send(Message::text("ping")).await.unwrap();
let reply = ws.next().await.expect("got reply").unwrap();
assert_eq!(reply.into_text().unwrap().as_str(), "ping");
ws.close(None).await.unwrap();
}
/// End-to-end check of protocol bridging: an HTTP/1 WebSocket client
/// connects through the proxy (configured with `scheme="h2c"`) to an
/// HTTP/2 backend that accepts extended CONNECT and echoes raw bytes.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn h1_inbound_to_h2c_outbound_round_trip() {
use bytes::Bytes;
use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http2;
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::convert::Infallible;
use tokio::net::TcpListener;
let backend = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = backend.local_addr().unwrap();
tokio::spawn(async move {
while let Ok((sock, _)) = backend.accept().await {
let svc = service_fn(|mut req: Request<Incoming>| async move {
// Anything that is not an extended CONNECT (CONNECT + `:protocol`)
// is rejected with 400.
if req.method() != Method::CONNECT
|| req
.extensions()
.get::<hyper::ext::Protocol>()
.is_none()
{
return Ok::<_, Infallible>(
Response::builder()
.status(400)
.body(
Empty::<Bytes>::new()
.map_err(|_| {
std::io::Error::other("never")
})
.boxed_unsync(),
)
.unwrap(),
);
}
let upgrade = hyper::upgrade::on(&mut req);
// Once upgraded, echo raw bytes back until EOF or an I/O error.
tokio::spawn(async move {
let Ok(upgraded) = upgrade.await else {
return;
};
use tokio::io::{
AsyncReadExt, AsyncWriteExt,
};
let mut io =
hyper_util::rt::TokioIo::new(upgraded);
let mut buf = vec![0u8; 4096];
loop {
match io.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
if io
.write_all(&buf[..n])
.await
.is_err()
{
break;
}
}
}
}
});
// 200 accepts the extended CONNECT.
Ok(Response::builder()
.status(200)
.body(
Empty::<Bytes>::new()
.map_err(|_| {
std::io::Error::other("never")
})
.boxed_unsync(),
)
.unwrap())
});
tokio::spawn(async move {
let mut builder =
http2::Builder::new(TokioExecutor::new());
// Required for the server to accept extended CONNECT requests.
builder.enable_connect_protocol();
let _ = builder
.serve_connection(TokioIo::new(sock), svc)
.await;
});
}
});
// Proxy config. `{{addr}}` survives `format!` as a literal `{addr}`;
// presumably `TestServer::start` substitutes the listen address — verify.
let template = format!(
r#"
listener "tcp://{{addr}}" {{ }}
vhost "example.com" {{
location "/" {{
proxy scheme="h2c" {{
upstream "http://{backend_addr}"
}}
}}
}}
"#,
);
let srv = crate::test::TestServer::start(&template).await;
let hypershunt_addr = srv.addr;
use futures_util::{SinkExt as _, StreamExt as _};
use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message,
};
let url = format!("ws://{hypershunt_addr}/echo");
let (mut ws, response) =
connect_async(&url).await.expect("ws connect");
// The HTTP/1 client must still see a 101 even though the backend said 200.
assert_eq!(response.status(), 101);
ws.send(Message::text("cross-proto-ping")).await.unwrap();
let reply = ws.next().await.expect("got reply").unwrap();
assert_eq!(
reply.into_text().unwrap().as_str(),
"cross-proto-ping"
);
ws.close(None).await.unwrap();
}
/// End-to-end check that WebSocket frames sent by an HTTP/1 client reach
/// an h2c backend unmasked.
///
/// The backend parses the first frame it receives and replies with a text
/// frame prefixed `verified:` (frame was unmasked) or `MASKED:` (frame was
/// still masked); the client asserts on the `verified:` form.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn h1_to_h2c_backend_frames_arrive_unmasked() {
use bytes::Bytes;
use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http2;
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::convert::Infallible;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let backend = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = backend.local_addr().unwrap();
tokio::spawn(async move {
while let Ok((sock, _)) = backend.accept().await {
let svc = service_fn(|mut req: Request<Incoming>| async move {
// Anything that is not an extended CONNECT (CONNECT + `:protocol`)
// is rejected with 400.
if req.method() != Method::CONNECT
|| req
.extensions()
.get::<hyper::ext::Protocol>()
.is_none()
{
return Ok::<_, Infallible>(
Response::builder()
.status(400)
.body(
Empty::<Bytes>::new()
.map_err(|_| {
std::io::Error::other("never")
})
.boxed_unsync(),
)
.unwrap(),
);
}
let upgrade = hyper::upgrade::on(&mut req);
tokio::spawn(async move {
let Ok(upgraded) = upgrade.await else {
return;
};
let mut io =
hyper_util::rt::TokioIo::new(upgraded);
// Accumulate bytes until one complete frame (header + payload) is
// buffered.
let mut acc = Vec::new();
let mut buf = vec![0u8; 4096];
let header = loop {
if let Some(h) =
super::ws::parse_header(&acc).unwrap()
{
let end =
h.header_len + h.payload_len as usize;
if acc.len() >= end {
break h;
}
}
match io.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => acc.extend_from_slice(&buf[..n]),
}
};
let payload = &acc[header.header_len
..header.header_len
+ header.payload_len as usize];
// Report back to the client whether the frame arrived masked.
let reply_text = if header.masked {
format!(
"MASKED:{}",
String::from_utf8_lossy(payload)
)
} else {
format!(
"verified:{}",
String::from_utf8_lossy(payload)
)
};
// Hand-build the reply frame; 0x81 is presumably FIN + text opcode and
// `None` presumably means "no mask" — confirm against `ws::emit_header`.
let mut out = Vec::new();
super::ws::emit_header(
&mut out,
0x81,
reply_text.len() as u64,
None,
);
out.extend_from_slice(reply_text.as_bytes());
let _ = io.write_all(&out).await;
let _ = io.flush().await;
});
// 200 accepts the extended CONNECT.
Ok(Response::builder()
.status(200)
.body(
Empty::<Bytes>::new()
.map_err(|_| {
std::io::Error::other("never")
})
.boxed_unsync(),
)
.unwrap())
});
tokio::spawn(async move {
let mut builder =
http2::Builder::new(TokioExecutor::new());
// Required for the server to accept extended CONNECT requests.
builder.enable_connect_protocol();
let _ = builder
.serve_connection(TokioIo::new(sock), svc)
.await;
});
}
});
// Proxy config. `{{addr}}` survives `format!` as a literal `{addr}`;
// presumably `TestServer::start` substitutes the listen address — verify.
let template = format!(
r#"
listener "tcp://{{addr}}" {{ }}
vhost "example.com" {{
location "/" {{
proxy scheme="h2c" {{
upstream "http://{backend_addr}"
}}
}}
}}
"#,
);
let srv = crate::test::TestServer::start(&template).await;
let hypershunt_addr = srv.addr;
use futures_util::{SinkExt as _, StreamExt as _};
use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message,
};
let url = format!("ws://{hypershunt_addr}/echo");
let (mut ws, response) =
connect_async(&url).await.expect("ws connect");
assert_eq!(response.status(), 101);
ws.send(Message::text("payload-42")).await.unwrap();
let reply = ws.next().await.expect("got reply").unwrap();
assert_eq!(
reply.into_text().unwrap().as_str(),
"verified:payload-42",
"backend must have seen an unmasked frame"
);
ws.close(None).await.unwrap();
}
}