use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use pingora::{upstreams::peer::HttpPeer, ErrorType::HTTPStatus};
use pingora_load_balancing::{selection::RoundRobin, LoadBalancer};
use pingora_proxy::{ProxyHttp, Session};
use tracing::info;
/// Shared, reference-counted handle to a round-robin load balancer; cloned into each
/// request's context when a route matches.
type ArcedLB = Arc<LoadBalancer<RoundRobin>>;

/// Reverse proxy that chooses an upstream load balancer based on the request's host.
pub struct Router {
    /// Route table: host (without port) -> load balancer serving that host.
    /// Lookups are exact `HashMap` matches.
    routes: HashMap<String, ArcedLB>,
}
impl Router {
pub fn new() -> Self {
Router {
routes: HashMap::new(),
}
}
pub fn add_route(&mut self, route: String, upstream: ArcedLB) {
self.routes.insert(route, upstream);
}
}
/// Per-request routing state.
///
/// `request_filter` fills this in and `upstream_peer` reads it.
pub struct RouterContext {
    /// Host the request was routed on, with any port removed; also passed as the peer's SNI.
    pub host: String,
    /// Load balancer of the matched route, or `None` if no route has been matched yet.
    pub current_lb: Option<ArcedLB>,
}
#[async_trait]
impl ProxyHttp for Router {
type CTX = RouterContext;
fn new_ctx(&self) -> Self::CTX {
RouterContext {
host: String::new(),
current_lb: None,
}
}
async fn request_filter(
&self,
session: &mut Session,
ctx: &mut Self::CTX,
) -> pingora::Result<bool> {
let req_host = get_host(session);
let host_without_port = req_host.split(':').collect::<Vec<&str>>()[0].to_string();
let upstream_lb = self.routes.get(&host_without_port);
if upstream_lb.is_none() {
return Err(pingora::Error::new(HTTPStatus(404)));
}
ctx.host = host_without_port;
ctx.current_lb = Some(upstream_lb.unwrap().clone());
Ok(false)
}
async fn upstream_peer(
&self,
_session: &mut Session,
ctx: &mut Self::CTX,
) -> pingora::Result<Box<HttpPeer>> {
let upstream = ctx.current_lb.as_ref();
if upstream.is_none() {
return Err(pingora::Error::new(HTTPStatus(404)));
}
let healthy_upstream = upstream.unwrap().select(b"", 256);
if healthy_upstream.is_none() {
return Err(pingora::Error::new(HTTPStatus(503)));
}
info!(host = ctx.host, "Upstream selected");
let peer = HttpPeer::new(healthy_upstream.unwrap(), false, ctx.host.clone());
Ok(Box::new(peer))
}
}
/// Returns the host requested by the client, including a port if the client sent one.
///
/// Prefers the `Host` header when it is present and convertible with `to_str`. Otherwise
/// falls back to the host component of the request URI. Returns an empty string when
/// neither source yields a host.
///
/// Only reads from `session`, so a shared borrow is sufficient; callers holding
/// `&mut Session` coerce automatically.
fn get_host(session: &Session) -> String {
    session
        .get_header(http::header::HOST)
        .and_then(|value| value.to_str().ok())
        .or_else(|| session.req_header().uri.host())
        .unwrap_or_default()
        .to_string()
}