use hyper::{Request, Response, body::Incoming};
use hyper_util::rt::TokioIo;
use snafu::ResultExt;
use tokio::net::TcpStream;
use tracing::Instrument as _;
use crate::{
Error, ForwardConnectSnafu, ForwardHandshakeSnafu, ForwardInvalidHostSnafu,
ForwardMissingHostSnafu, ForwardSendRequestSnafu,
};
pub async fn forward_http(req: Request<Incoming>) -> Result<Response<Incoming>, Error> {
let host_port = if let Some(authority) = req.uri().authority() {
authority.to_string()
} else if let Some(host_header) = req.headers().get(http::header::HOST) {
host_header
.to_str()
.context(ForwardInvalidHostSnafu)?
.to_string()
} else {
return ForwardMissingHostSnafu.fail();
};
let addr = if host_port.contains(':') {
host_port
} else {
format!("{}:80", host_port)
};
let stream = TcpStream::connect(&addr)
.await
.context(ForwardConnectSnafu { addr: addr.clone() })?;
crate::configure_tcp_keepalive(&stream);
let io = TokioIo::new(stream);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
.await
.context(ForwardHandshakeSnafu { addr: addr.clone() })?;
tokio::spawn(conn.in_current_span());
let resp = sender
.send_request(req)
.await
.context(ForwardSendRequestSnafu)?;
Ok(resp)
}