use anyhow::Context;
use http_body_util::{BodyExt, Full};
use hyper::body::{Bytes, Incoming};
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tracing::{error, info};
use crate::net;
pub async fn handle_connect(
req: Request<Incoming>,
fast_open: bool,
) -> anyhow::Result<Response<Full<Bytes>>> {
let authority = req
.uri()
.authority()
.map(|a| a.to_string())
.unwrap_or_else(|| {
req.uri().to_string()
});
let addr = if authority.contains(':') {
authority.clone()
} else {
format!("{authority}:443")
};
info!("CONNECT tunnel to {addr}");
tokio::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
let mut client = TokioIo::new(upgraded);
match net::connect(&addr, fast_open).await {
Ok(mut target) => {
if let Err(e) =
tokio::io::copy_bidirectional(&mut client, &mut target).await
{
error!("tunnel {addr} io error: {e}");
}
}
Err(e) => {
error!("failed to connect to {addr}: {e}");
}
}
}
Err(e) => {
error!("upgrade failed for {addr}: {e}");
}
}
});
Ok(Response::builder()
.status(StatusCode::OK)
.body(Full::new(Bytes::new()))
.unwrap())
}
/// Forwards a plain (non-`CONNECT`) proxy request to its origin server.
///
/// What the function does, in order:
///
/// 1. Reads the target from the request's absolute-form URI. If the URI has
///    no port, it uses 443 for `https` and 80 otherwise.
/// 2. Rewrites the URI to origin-form (path and query only).
/// 3. Strips the `proxy-authorization` and `proxy-connection` headers.
/// 4. Ensures a `Host` header is present.
/// 5. Sends the request over a fresh HTTP/1 connection from `net::connect`.
///
/// The upstream response body is read fully into memory, because the return
/// type is `Full<Bytes>`. Large responses are therefore buffered entirely.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// * the URI has no host;
/// * the path cannot be re-parsed;
/// * the upstream connection, handshake or request fails;
/// * the upstream body cannot be read.
pub async fn handle_forward(
    mut req: Request<Incoming>,
    fast_open: bool,
) -> anyhow::Result<Response<Full<Bytes>>> {
    let uri = req.uri().clone();
    // `Uri::host` strips `user:pass@` userinfo and keeps IPv6 brackets. Using
    // the raw authority string with a `contains(':')` test would dial
    // "user:pw@host" or a port-less "[::1]".
    let host = uri
        .host()
        .context("missing authority in forward request")?;
    let port = uri.port_u16().unwrap_or(match uri.scheme_str() {
        Some("https") => 443,
        _ => 80,
    });
    let addr = format!("{host}:{port}");
    // The Host header carries the port only if the client spelled one out.
    let host_header = uri
        .port_u16()
        .map_or_else(|| host.to_string(), |p| format!("{host}:{p}"));
    info!("forward {} {} -> {addr}", req.method(), uri);

    let path_and_query = uri
        .path_and_query()
        .map(|pq| pq.to_string())
        .unwrap_or_else(|| "/".to_string());
    *req.uri_mut() = path_and_query.parse()?;

    let headers = req.headers_mut();
    headers.remove("proxy-authorization");
    headers.remove("proxy-connection");
    // After the origin-form rewrite, `Host` is the only place the target is
    // named. HTTP/1.0 clients may omit it, and hyper's low-level client does
    // not add it.
    if !headers.contains_key(hyper::header::HOST) {
        let value = hyper::header::HeaderValue::from_str(&host_header)
            .context("build Host header")?;
        headers.insert(hyper::header::HOST, value);
    }

    let stream = net::connect(&addr, fast_open)
        .await
        .with_context(|| format!("connect to {addr}"))?;
    let io = TokioIo::new(stream);
    let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
        .await
        .context("upstream handshake")?;
    // The connection future drives the socket and must be polled
    // concurrently with `send_request`.
    tokio::spawn(async move {
        if let Err(e) = conn.await {
            error!("upstream connection error: {e}");
        }
    });

    let resp = sender
        .send_request(req)
        .await
        .context("upstream send_request")?;
    let (parts, body) = resp.into_parts();
    let body_bytes = body
        .collect()
        .await
        .context("read upstream body")?
        .to_bytes();
    Ok(Response::from_parts(parts, Full::new(body_bytes)))
}