use crate::prelude::*;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use std::convert::Infallible;
use std::str::FromStr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::oneshot::Receiver;
use tokio::{pin, select};
impl Proxy {
    /// Accepts TCP connections on `listener` and serves each one over HTTP/1 until
    /// `shutdown_rx` resolves.
    ///
    /// Every accepted connection is served on its own spawned task. Each request on that
    /// connection is dispatched to [`Proxy::handler`] with shared (`Arc`) handles to `routes`
    /// and `handler`. Accept errors and connection errors are logged to stderr and do not stop
    /// the loop.
    ///
    /// The receiver's result is ignored, so shutdown is triggered both when the sender fires
    /// and when it is dropped. Only the accept loop stops on shutdown. Connection tasks that
    /// were already spawned are detached: they are neither awaited nor cancelled here.
    pub(crate) async fn serve(
        routes: ProxyRoutes,
        handler: ProxyHandler,
        listener: TcpListener,
        shutdown_rx: Receiver<()>,
    ) {
        // `let _ =` discards RecvError, so a dropped sender also ends the loop.
        let shutdown_future = async {
            let _ = shutdown_rx.await;
        };
        pin!(shutdown_future);
        loop {
            select! {
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((stream, _)) => {
                            let io = TokioIo::new(stream);
                            let routes_clone = Arc::clone(&routes);
                            let handler_clone = Arc::clone(&handler);
                            tokio::spawn(async move {
                                let service = service_fn(move |req| {
                                    // Refcount bumps only: each request future owns its own handles.
                                    let routes_clone = Arc::clone(&routes_clone);
                                    let handler_clone = Arc::clone(&handler_clone);
                                    Self::handler(routes_clone, handler_clone, req)
                                });
                                if let Err(err) = hyper::server::conn::http1::Builder::new()
                                    .serve_connection(io, service)
                                    .await
                                {
                                    eprintln!("Error serving connection: {:?}", err);
                                }
                            });
                        }
                        Err(err) => {
                            eprintln!("Failed to accept connection: {:?}", err);
                        }
                    }
                }
                _ = &mut shutdown_future => {
                    println!("Shutdown signal received, stopping server");
                    break;
                }
            }
        }
    }

    /// Handles one inbound request by proxying it to the first route whose path matches.
    ///
    /// For a matched route:
    /// 1. The entire request body is buffered.
    /// 2. The upstream URI is built as `origin` immediately followed by the matched `path`.
    /// 3. The user `handler` is called first. If it returns `Ok(Some(response))`, that response
    ///    is returned as-is and nothing is forwarded.
    /// 4. Otherwise the request is sent with `Proxy::forward_request`. The upstream body is
    ///    buffered, the upstream status and headers are copied, the route callback (if any) is
    ///    invoked, and the rebuilt response is returned.
    ///
    /// The following failures produce a 502 from [`Proxy::error_response`] instead of
    /// panicking the connection task:
    /// - unreadable request body
    /// - invalid upstream URI
    /// - handler error
    /// - forwarding error
    /// - unreadable upstream body
    /// - response build error
    ///
    /// A request whose path matches no route gets a 200 `Hello, World!` response. The error
    /// type is [`Infallible`]: every outcome is expressed as a response.
    pub(crate) async fn handler(
        routes: ProxyRoutes,
        handler: ProxyHandler,
        mut req: Request<hyper::body::Incoming>,
    ) -> Result<Response<Full<Bytes>>, Infallible> {
        let (origin, path, callback) = match Self::match_route(routes, req.uri().path()) {
            Some(matched) => matched,
            None => {
                // NOTE(review): placeholder-looking fallback kept for compatibility; a 404 is
                // likely more appropriate for unmatched paths — confirm with callers.
                return Ok(Response::new(Full::new(Bytes::from("Hello, World!"))));
            }
        };

        // Report a read failure rather than silently forwarding the request with no body.
        let body = match req.body_mut().collect().await {
            Ok(collected) => Some(collected.to_bytes()),
            Err(e) => {
                eprintln!("Failed to read request body: {:?}", e);
                return Ok(Self::error_response(format!(
                    "Failed to read request body: {:?}",
                    e
                )));
            }
        };

        // `path` comes from the client request, so the concatenation may be an invalid URI.
        let uri_str = format!("{}{}", origin, path);
        let uri = match Uri::from_str(&uri_str) {
            Ok(uri) => uri,
            Err(e) => {
                eprintln!("Invalid upstream URI {:?}: {:?}", uri_str, e);
                return Ok(Self::error_response(format!(
                    "Invalid upstream URI {:?}: {:?}",
                    uri_str, e
                )));
            }
        };

        let payload = ProxyForwardRequestPayload {
            method: req.method(),
            headers: req.headers(),
            uri,
            body,
        };

        // The user handler may answer directly; `Ok(None)` means "forward as usual".
        match handler(ProxyHandlerPayload {
            request: &payload,
            origin: &origin,
            path: &path,
        }) {
            Ok(Some(res)) => return Ok(res),
            Ok(None) => {}
            Err(e) => {
                eprintln!("Proxy handler failed: {:?}", e);
                return Ok(Self::error_response(format!("Proxy handler error: {:?}", e)));
            }
        }

        let (forward_req, mut response) = match Self::forward_request(payload, origin).await {
            Ok(exchange) => exchange,
            Err(e) => return Ok(Self::error_response(format!("Proxy error: {:?}", e))),
        };

        let status = response.status();
        // Collecting the body does not touch the headers, so they can be read afterwards
        // without cloning the map.
        let body_bytes = match response.body_mut().collect().await {
            Ok(collected) => collected.to_bytes(),
            Err(e) => {
                eprintln!("Failed to read response body: {:?}", e);
                return Ok(Self::error_response(format!(
                    "Failed to read response body: {:?}",
                    e
                )));
            }
        };

        let mut res_builder = Response::builder().status(status);
        for (name, value) in response.headers() {
            res_builder = res_builder.header(name, value);
        }

        if let Some(callback) = callback {
            // The callback's return value is discarded.
            let _ = callback(ProxyRouteCallbackPayload {
                request: &forward_req,
                status,
                headers: res_builder.headers_ref(),
                body: &body_bytes,
            });
        }

        let res = res_builder
            .body(Full::new(body_bytes))
            .unwrap_or_else(|_| Self::error_response("Failed to build response".into()));
        Ok(res)
    }

    /// Finds the first route whose `matching_path` accepts `uri`.
    ///
    /// Returns `(origin, matched path, optional callback)`, or `None` if no route matches.
    /// Routes are checked in table order while holding the table's read lock.
    ///
    /// # Panics
    /// Panics if the read lock cannot be acquired. For a `std::sync::RwLock`, that happens
    /// when the lock is poisoned.
    fn match_route(
        routes: ProxyRoutes,
        uri: &str,
    ) -> Option<(String, String, Option<ProxyRouteCallback>)> {
        let table = routes.read().expect("proxy route table lock poisoned");
        for route in table.iter() {
            if let Some((path, callback)) = route.matching_path(uri) {
                return Some((route.origin(), path, callback));
            }
        }
        None
    }

    /// Builds a `502 Bad Gateway` response whose body is `msg`.
    fn error_response(msg: String) -> Response<Full<Bytes>> {
        Response::builder()
            .status(502)
            .body(Full::new(Bytes::from(msg)))
            .expect("Error response builder always valid")
    }
}