httproxide 0.2.0

Rusted HTTP router reverse-proxy
Documentation
use std::collections::HashMap;
use std::str::FromStr;

use axum::extract::Host;
use axum::handler::HandlerWithoutStateExt;
use axum::routing::{any_service, MethodRouter};
use axum::Router;
use http::{Method, Request, StatusCode};
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize};
use tower::ServiceExt;

use crate::config::TargetConfig;
use crate::error_responders::get_error_responder;
use crate::target::{target_config, IntoTarget, ReqBody};

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostRouterRouteConfig {
    #[serde(flatten)]
    target: TargetConfig,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostRouterConfig {
    hosts: HashMap<String, HostRouterRouteConfig>,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct LocationRouterRouteConfig {
    #[serde(default)]
    nest: bool,
    #[serde(flatten)]
    target: TargetConfig,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct LocationRouterConfig {
    locations: HashMap<String, LocationRouterRouteConfig>,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MethodRouterRouteConfig {
    #[serde(flatten)]
    target: TargetConfig,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MethodRouterConfig {
    methods: HashMap<String, MethodRouterRouteConfig>,
}

pub fn host_router_from_config(config: HostRouterConfig) -> anyhow::Result<impl IntoTarget> {
    let mut router_for_host = HashMap::new();

    if config.hosts.is_empty() {
        anyhow::bail!("empty hosts");
    }

    for (host_name, host) in config.hosts {
        let target = target_config(host.target)?;
        router_for_host.insert(host_name, target);
    }

    let handler = |Host(hostname): Host, request: Request<ReqBody>| async move {
        lazy_static! {
            static ref RE: Regex = Regex::new(r":\d+$").unwrap();
        }
        let hostname = RE.replace(&hostname, "").to_string();
        let hostrouter = match router_for_host.get(&hostname) {
            Some(v) => v.clone(),
            None => get_error_responder(StatusCode::NOT_FOUND),
        };
        hostrouter.oneshot(request).await
    };
    Ok(handler.into_service())
}

pub fn location_router_from_config(
    mut config: LocationRouterConfig,
) -> anyhow::Result<impl IntoTarget> {
    let mut router = Router::new();
    if let Some(fallback) = config.locations.remove("fallback") {
        let target = target_config(fallback.target)?;
        router = router.fallback_service(target);
    } else {
        router = router.fallback_service(get_error_responder(StatusCode::NOT_FOUND));
    }
    for (location_name, location) in config.locations {
        let target = target_config(location.target)?;
        router = if location.nest {
            router.nest(&location_name, Router::new().route("/*path", any_service(target)))
        } else {
            router.route(&location_name, any_service(target))
        };
    }
    Ok(router)
}

pub fn method_router_from_config(
    mut config: MethodRouterConfig,
) -> anyhow::Result<impl IntoTarget> {
    let mut router = {
        if let Some(fallback) = config.methods.remove("fallback") {
            let target = target_config(fallback.target)?;
            any_service(target)
        } else {
            MethodRouter::new().fallback_service(get_error_responder(StatusCode::METHOD_NOT_ALLOWED))
        }
    };
    for (method_name, route) in config.methods {
        let target = target_config(route.target)?;
        router = router.on_service(Method::from_str(&method_name)?.try_into()?, target);
    }
    Ok(router)
}