hyperapi 0.2.2

An easy to use API Gateway
Documentation
use hyper::{Body, Request, Response, Uri, header::HeaderValue};
use hyper::client::HttpConnector;
use hyper::client::Client;
use hyper_rustls::HttpsConnector;
use rustls::ClientConfig;
use tower::Service;
use std::pin::Pin;
use std::task::{Poll, Context};
use std::future::Future;
use std::time::Duration;
use tracing::{event, Level};
use crate::{config::Upstream, middleware::GatewayError};


lazy_static::lazy_static! {

    static ref HTTP_REQ_INPROGRESS: prometheus::IntGaugeVec = prometheus::register_int_gauge_vec!(
        "gateway_requests_in_progress",
        "Request in progress count",
        &["service", "upstream", "version"]
    ).unwrap();

}


#[derive(Debug, Clone)]
pub struct ProxyHandler {
    service_id: String,
    upstream_id: String,
    upstream: String,
    version: String,
    timeout: Duration,
    client: Client<HttpsConnector<HttpConnector>, Body>,
}

impl ProxyHandler {

    pub fn new(service_id: &str, upstream: &Upstream, timeout: u32) -> Self {
        let mut connector = HttpConnector::new();
        let timeout = Duration::from_secs(timeout as u64);
        connector.set_connect_timeout(Some(timeout));
        connector.set_keepalive(Some(Duration::from_secs(30)));

        let mut tls_config = ClientConfig::new();
        tls_config.root_store = match rustls_native_certs::load_native_certs() {
            Ok(store) => store,
            Err((Some(store), err)) => {
                log::warn!("Could not load all certificates: {:?}", err);
                store
            }
            Err((None, err)) => Err(err).expect("cannot access native cert store"),
        };
        if tls_config.root_store.is_empty() {
            panic!("no CA certificates found");
        }

        let tls = HttpsConnector::from((connector, tls_config));
        let client = Client::builder()
            .pool_idle_timeout(timeout)
            .build::<_, Body>(tls);

        ProxyHandler { 
            service_id: String::from(service_id), 
            client, 
            timeout,
            upstream: upstream.target.clone(), 
            upstream_id: upstream.id.clone(),
            version: upstream.version.clone(),
        }
    }

    fn alter_request(req: Request<Body>, endpoint: &str) -> Request<Body> {
        let (mut parts, body) = req.into_parts();
        parts.version = hyper::http::Version::HTTP_11;
        let path_and_query = parts.uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
        let path = path_and_query.strip_prefix("/").unwrap_or("/");
        let path_left = if let Some(offset) = path.find("/") {
            let (_service_id, path_left) = path.split_at(offset);
            path_left
        } else {
            ""
        };
        let mut new_uri = String::from(endpoint.trim_end_matches('/'));
        new_uri.push_str(path_left);

        parts.uri = new_uri.parse::<Uri>().unwrap();
        Request::from_parts(parts, body)
    }
}

impl Service<Request<Body>> for ProxyHandler
{
    type Response = Response<Body>;
    type Error = Box<dyn std::error::Error + Send + Sync>;
    type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, _c: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let req = ProxyHandler::alter_request(req, &self.upstream);
        event!(Level::DEBUG, "{:?}", req.uri());
        let upstream_id = self.upstream_id.to_string();
        let version = self.version.to_string();
        let service_id = self.service_id.clone();
        HTTP_REQ_INPROGRESS.with_label_values(&[
            &service_id, 
            &upstream_id,
            &version,
        ]).inc();

        let sleep = tokio::time::sleep(self.timeout.clone());
        let fut = self.client.request(req);
        Box::pin(async move {
            let result: Result<Response<Body>, GatewayError> = tokio::select! {
                resp = fut => {
                    Ok(resp?)
                },
                _ = sleep => {
                    Err(GatewayError::TimeoutError)
                },
            };

            HTTP_REQ_INPROGRESS.with_label_values(&[
                &service_id,
                &upstream_id,
                &version,
            ]).dec();

            let mut resp = result?;
            let header = resp.headers_mut();
            let us_id = HeaderValue::from_str(&upstream_id).unwrap();
            let us_version = HeaderValue::from_str(&version).unwrap();
            header.append("X-UPSTREAM-ID", us_id);
            header.append("X-UPSTREAM-VERSION", us_version);
            Ok(resp)
        })
    }
}