use std::net::ToSocketAddrs;
use std::{borrow::Cow, collections::HashMap, sync::Arc};
use async_trait::async_trait;
use http::{HeaderValue, Uri};
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::proxy::{ProxyHttp, Session};
use pingora::{upstreams::peer::HttpPeer, ErrorType::HTTPStatus};
use crate::config::RouteUpstream;
use crate::stores::{self, routes::RouteStoreContainer};
use super::{
middleware::{
execute_request_plugins, execute_response_plugins, execute_upstream_request_plugins,
execute_upstream_response_plugins,
},
DEFAULT_PEER_OPTIONS,
};
/// Proxy service type that implements [`ProxyHttp`]; it holds no fields of its own.
pub struct Router {
}
/// Returns a shared handle to the route container stored in `ctx`.
///
/// Only the `Arc` is cloned (a reference-count bump); the container itself is
/// not copied.
///
/// # Panics
///
/// Panics if `ctx.route_container` is `None`. `request_filter` stores the
/// container before any phase that calls this helper, so a `None` here means
/// that ordering invariant was broken.
fn process_route(ctx: &RouterContext) -> Arc<RouteStoreContainer> {
    let container = ctx
        .route_container
        .as_ref()
        .expect("route_container must be set by request_filter before later proxy phases run");
    Arc::clone(container)
}
/// Per-request state shared between the proxy phases of [`Router`].
pub struct RouterContext {
    /// Request host with everything from the first `:` onward removed; filled in by
    /// `request_filter`.
    pub host: String,
    /// Route matched for the request host; `None` until `request_filter` finds one.
    pub route_container: Option<Arc<RouteStoreContainer>>,
    /// Configured upstream chosen in `upstream_peer`; `None` until then.
    pub upstream: Option<RouteUpstream>,
    /// Free-form key/value data attached to the request; the access log reads
    /// the `"request_id_header"` entry from it.
    pub extensions: HashMap<Cow<'static, str>, String>,
    /// Timestamps used to compute the request duration for the access log.
    pub timings: RouterTimings,
}
/// Timing data recorded for a single request.
pub struct RouterTimings {
    /// Instant captured when the request context is created (`new_ctx`); the
    /// access log reports the time elapsed since this point as `duration_ms`.
    request_filter_start: std::time::Instant,
}
#[async_trait]
impl ProxyHttp for Router {
type CTX = RouterContext;
fn new_ctx(&self) -> Self::CTX {
RouterContext {
host: String::new(),
route_container: None,
upstream: None,
extensions: HashMap::new(),
timings: RouterTimings {
request_filter_start: std::time::Instant::now(),
},
}
}
async fn request_filter(
&self,
session: &mut Session,
ctx: &mut Self::CTX,
) -> pingora::Result<bool> {
let req_host = get_host(session);
let host_without_port = req_host.split(':').collect::<Vec<_>>()[0];
host_without_port.clone_into(&mut ctx.host);
let Some(route_container) = stores::get_route_by_key(host_without_port) else {
session.respond_error(404).await;
return Ok(true);
};
let arced = Arc::new(route_container);
let uri = get_uri(session);
match &arced.path_matcher.pattern {
Some(pattern) if pattern.find(uri.path()).is_none() => {
session.respond_error(404).await;
return Ok(true);
}
_ => {}
}
ctx.route_container = Some(Arc::clone(&arced));
if let Ok(true) = execute_request_plugins(session, ctx, &arced.plugins).await {
return Ok(true);
}
Ok(false)
}
async fn upstream_peer(
&self,
_session: &mut Session,
ctx: &mut Self::CTX,
) -> pingora::Result<Box<HttpPeer>> {
let route_container = process_route(ctx);
let Some(healthy_upstream) = route_container.load_balancer.select(b"", 128) else {
return Err(pingora::Error::new(HTTPStatus(503)));
};
let (healthy_ip, healthy_port) = if let Some(scr) = healthy_upstream.addr.as_inet() {
(scr.ip().to_string(), scr.port())
} else {
return Err(pingora::Error::new(HTTPStatus(503)));
};
let Some(upstream) = route_container.upstreams.iter().find(|u| {
format!("{}:{}", u.ip, u.port)
.to_socket_addrs()
.unwrap()
.any(|s| s.ip().to_string() == healthy_ip && s.port() == healthy_port)
}) else {
return Err(pingora::Error::new(HTTPStatus(503)));
};
ctx.upstream = Some(upstream.clone());
let mut peer = HttpPeer::new(
healthy_upstream,
healthy_port == 443,
upstream.sni.clone().unwrap_or(String::new()),
);
peer.options = DEFAULT_PEER_OPTIONS;
Ok(Box::new(peer))
}
async fn response_filter(
&self,
session: &mut Session,
upstream_response: &mut ResponseHeader,
ctx: &mut Self::CTX,
) -> pingora::Result<()> {
let route_container = process_route(ctx);
for (name, value) in &route_container.host_header_add {
upstream_response.insert_header(name, value)?;
}
for name in &route_container.host_header_remove {
upstream_response.remove_header(name);
}
execute_response_plugins(&route_container, session, ctx).await?;
Ok(())
}
async fn upstream_request_filter(
&self,
session: &mut Session,
upstream_request: &mut RequestHeader,
ctx: &mut Self::CTX,
) -> pingora::Result<()> {
let route_container = process_route(ctx);
let Some(upstream) = ctx.upstream.as_ref() else {
return Err(pingora::Error::new(HTTPStatus(503)));
};
if let Some(headers) = upstream.headers.as_ref() {
if let Some(add) = headers.add.as_ref() {
for header_add in add {
upstream_request
.insert_header(header_add.name.to_string(), header_add.value.to_string())
.ok();
}
}
}
execute_upstream_request_plugins(&route_container, session, upstream_request, ctx)
.await
.ok();
Ok(())
}
fn upstream_response_filter(
&self,
session: &mut Session,
upstream_response: &mut ResponseHeader,
ctx: &mut Self::CTX,
) {
let route_container = process_route(ctx);
execute_upstream_response_plugins(&route_container, session, upstream_response, ctx);
}
async fn logging(
&self,
session: &mut Session,
_: Option<&pingora::Error>,
ctx: &mut Self::CTX,
) {
let http_version = if session.is_http2() {
"http/2"
} else {
"http/1.1"
};
let method = session.req_header().method.to_string();
let query = session.req_header().uri.query().unwrap_or_default();
let path = session.req_header().uri.path();
let empty_header = HeaderValue::from_static("");
let host = session.req_header().uri.host();
let referer = session
.req_header()
.headers
.get("referer")
.unwrap_or(&empty_header);
let user_agent = session
.req_header()
.headers
.get("user-agent")
.unwrap_or(&empty_header);
let client_ip = session
.client_addr()
.map(ToString::to_string)
.unwrap_or_default();
let status_code = session
.response_written()
.map(|v| v.status.as_u16())
.unwrap_or_default();
let duration_ms = ctx.timings.request_filter_start.elapsed().as_millis();
tracing::info!(
method,
path,
query,
host,
duration_ms,
user_agent = user_agent.to_str().unwrap_or(""),
referer = referer.to_str().unwrap_or(""),
client_ip,
status_code,
http_version,
request_id = ctx.extensions.get("request_id_header"),
access_log = true
);
}
}
/// Returns an owned copy of the URI from the session's request header.
fn get_uri(session: &mut Session) -> Uri {
    let header = session.req_header();
    header.uri.clone()
}
/// Returns the host the request was addressed to.
///
/// The `Host` header takes precedence; if it is present but cannot be decoded
/// with `to_str`, the result is `""` and the URI is not consulted. Without a
/// `Host` header the host component of the request URI is used, and `""` is
/// returned when that is missing too.
fn get_host(session: &mut Session) -> &str {
    match session.get_header(http::header::HOST) {
        Some(value) => value.to_str().unwrap_or(""),
        None => session.req_header().uri.host().unwrap_or(""),
    }
}