mod upgrade;
pub mod util;
use std::net::SocketAddr;
use std::sync::{Arc, Weak};
use std::time::Instant;
use bytes::Bytes;
use http::header::{CONNECTION, UPGRADE};
use http_body_util::combinators::BoxBody;
use hyper::{body::Incoming, Method, StatusCode};
use hyper::{Request, Response};
use util::proxy_request;
use crate::cache::{Cache, CacheEntry, CloneableRes};
use crate::config::rule::Rule;
use crate::config::Config;
use crate::{cfg_logging, UpstreamAndConnPool, Upstreams};
/// Routes an incoming request to the first configured rule that matches it.
///
/// The rules in `config.rules` are tried in order. For the first rule whose
/// `matches` check accepts the request, the rule's upstream is looked up in
/// `upstreams`, the request is passed through `util::authenticate`, and — unless
/// authentication already produced a response — the request is handed to
/// [`handle_match`].
///
/// # Returns
///
/// * the response produced by `util::authenticate`, if it returned one;
/// * otherwise the result of [`handle_match`] for the matching rule;
/// * an empty `404 Not Found` response when no rule matches.
///
/// # Errors
///
/// Propagates errors from `util::authenticate` and [`handle_match`].
///
/// # Panics
///
/// Panics if the matching rule's `upstream_key` has no entry in `upstreams`.
#[cfg_attr(
    feature = "logging",
    tracing::instrument(level = "trace", skip(req, config, cache))
)]
pub(crate) async fn handle_req(
    req: Request<hyper::body::Incoming>,
    peer_addr: SocketAddr,
    config: Arc<Config>,
    cache: Arc<Cache>,
    upstreams: Arc<Upstreams>,
) -> Result<Response<BoxBody<Bytes, crate::Error>>, crate::Error> {
    // Only the first matching rule is ever considered.
    let Some(rule) = config.rules.iter().find(|candidate| candidate.matches(&req)) else {
        let not_found = Response::builder()
            .status(StatusCode::NOT_FOUND)
            .body(util::empty())
            .unwrap();
        return Ok(not_found);
    };

    let upstream = upstreams.get(rule.upstream_key).expect("`upstream` in a rule should match a key in the `upstreams` property at the root of the config.");

    // `authenticate` yields `Some(response)` when it wants to answer the client
    // itself; in that case the request is never proxied.
    let auth_outcome = util::authenticate(&upstreams, upstream, peer_addr, &req).await?;
    match auth_outcome {
        Some(early_response) => Ok(early_response),
        None => {
            handle_match(
                req,
                peer_addr,
                rule,
                upstream,
                cache,
                &upstreams,
                config.max_connections,
            )
            .await
        }
    }
}
/// Serves a request that matched `rule`, either from the cache or by proxying
/// it to `upstream`.
///
/// Steps, in order:
///
/// 1. `CONNECT` requests are rejected with `405 Method Not Allowed`.
/// 2. The request URI is replaced by the result of `rule.remove_match` applied
///    to the request path. If that result does not parse as a URI the client
///    receives `400 Bad Request`.
/// 3. Requests asking for a protocol upgrade (`Connection` contains the
///    `upgrade` option and `Upgrade` is non-empty) are delegated to
///    `upgrade::handle_upgrade`.
/// 4. If the rule has cache settings that include the request method:
///    * a fresh cache entry is returned directly;
///    * if another request for the same entry is still in flight, this request
///      waits for that request's result and returns it when it is `Some`;
///    * otherwise an empty entry is inserted and this request becomes
///      responsible for filling it.
/// 5. The request is proxied via `util::proxy_request`. When this request is
///    responsible for the cache entry, a successful response is cloned,
///    published to waiters and stored; a non-successful response publishes
///    `None` and is not stored.
///
/// # Errors
///
/// Propagates errors from `upgrade::handle_upgrade` and `util::clone_response`.
#[cfg_attr(
    feature = "logging",
    tracing::instrument(level = "trace", skip(req, cache, peer_addr))
)]
async fn handle_match(
    mut req: Request<Incoming>,
    peer_addr: SocketAddr,
    rule: &Rule,
    upstream: &UpstreamAndConnPool,
    cache: Arc<Cache>,
    upstreams: &Upstreams,
    max_connections: usize,
) -> Result<Response<BoxBody<Bytes, crate::Error>>, crate::Error> {
    if Method::CONNECT == req.method() {
        return Ok(Response::builder()
            .status(StatusCode::METHOD_NOT_ALLOWED)
            .body(util::empty())
            .unwrap());
    }

    // The path is client-controlled, so a rewrite that no longer parses as a
    // URI is answered with 400 instead of panicking.
    // NOTE(review): only `path()` is handed to `remove_match`; whether the
    // query string is meant to survive the rewrite depends on that helper —
    // confirm.
    let rewritten_uri = match rule.remove_match(req.uri().path()).parse::<http::Uri>() {
        Ok(uri) => uri,
        Err(_) => {
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body(util::empty())
                .unwrap());
        }
    };
    *req.uri_mut() = rewritten_uri;

    // Connection options are case-insensitive and may appear in a
    // comma-separated list (e.g. `Connection: keep-alive, Upgrade`), possibly
    // spread over several header lines, so compare token-wise.
    let connection_requests_upgrade = req
        .headers()
        .get_all(CONNECTION)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("upgrade"));
    let upgrading = connection_requests_upgrade
        && req.headers().get(UPGRADE).is_some_and(|v| !v.is_empty());
    if upgrading {
        return upgrade::handle_upgrade(req, upstream, peer_addr).await;
    }

    // `Some(sender)` means this request must populate the cache entry and
    // publish the result; `None` means the response is passed through as-is.
    let refresh_cache = if let Some(cache_settings) = rule.cache.as_ref() {
        if cache_settings.methods.contains(req.method()) {
            let entry = cache.get_entry(rule, req.uri()).await;
            if let Some(entry) = entry {
                let entry = entry.lock().await;
                if let Some(cached_res) = entry.extract_fresh_data(cache_settings.max_age) {
                    cfg_logging! {trace!("Cache hit for {}", req.uri());}
                    return Ok(cached_res);
                }
                let inflight = entry.inflight.clone();
                // Release the entry lock before waiting on another request.
                drop(entry);
                if let Some(inflight) = inflight.as_ref().and_then(Weak::upgrade) {
                    cfg_logging! {trace!("No cache found for {}, waiting on inflight request...", req.uri());}
                    if let Ok(Some(res)) = inflight.subscribe().recv().await {
                        return Ok((*res).clone().0.map(|b| util::full(b)));
                    } else {
                        // The in-flight request produced nothing usable; proxy
                        // this request ourselves without touching the cache.
                        None
                    }
                } else {
                    cfg_logging! {debug!("Stale cache for {}, updating...", req.uri());}
                    Some(
                        cache
                            .insert_empty_entry(rule, req.uri(), max_connections)
                            .await,
                    )
                }
            } else {
                cfg_logging! {debug!("No cache found for {}, creating...", req.uri());}
                Some(
                    cache
                        .insert_empty_entry(rule, req.uri(), max_connections)
                        .await,
                )
            }
        } else {
            None
        }
    } else {
        None
    };

    // `req` is consumed by the proxy call, so keep the URI for the cache lookup.
    let req_uri = req.uri().clone();
    let resp = util::proxy_request(req, upstream, peer_addr, false).await;
    cfg_logging! {
        trace!("Got res from upstream {}", peer_addr);
    }

    let Some(refresh_cache) = refresh_cache else {
        cfg_logging! {
            trace!("Returning res form upstream {}", peer_addr);
        }
        return Ok(resp);
    };

    let status = resp.status();
    let resp = if let Some(entry) = cache.get_entry(rule, &req_uri).await {
        // The entry still exists: update it in place while holding its lock.
        let mut entry = entry.lock().await;
        let resp = if status.is_success() {
            let (send_res, cloned_res) = util::clone_response(resp).await?;
            let cloneable = CloneableRes(cloned_res);
            // A send error only means nobody is waiting; ignore it.
            refresh_cache.send(Some(Arc::new(cloneable.clone()))).ok();
            entry.cached_at = Some(Instant::now());
            entry.value = Some(cloneable.0);
            send_res
        } else {
            // Tell waiters there is nothing to share; they proxy on their own.
            refresh_cache.send(None).ok();
            resp
        };
        entry.inflight = None;
        resp
    } else if status.is_success() {
        // The entry is gone from the cache: insert a populated one.
        let (send_res, cloned_res) = util::clone_response(resp).await?;
        let cloneable = CloneableRes(cloned_res);
        refresh_cache.send(Some(Arc::new(cloneable.clone()))).ok();
        cache
            .insert_populated_entry(rule, req_uri, cloneable.0)
            .await;
        send_res
    } else {
        refresh_cache.send(None).ok();
        resp
    };
    Ok(resp)
}