btdt-server 0.2.0

Server component for "been there, done that" - a tool for flexible CI caching
Documentation
use crate::app::Options;
use crate::config::BtdtServerConfig;
use chrono::Local;
use data_encoding::BASE64;
use poem::listener::{BoxListener, Listener, NativeTlsConfig};
use poem::{
    Endpoint, EndpointExt, IntoResponse, Middleware, Request, Response, Server,
    listener::TcpListener,
};
use std::borrow::Cow;
use std::convert::Infallible;
use std::error::Error;
use std::fs::File;
use std::io::Read;

mod app;
mod config;

struct AccessLogMiddleware {}

struct AccessLogMiddlewareImpl<E: Endpoint> {
    ep: E,
}

impl<E: Endpoint> Middleware<E> for AccessLogMiddleware {
    type Output = AccessLogMiddlewareImpl<E>;

    fn transform(&self, ep: E) -> Self::Output {
        AccessLogMiddlewareImpl { ep }
    }
}

impl<E: Endpoint> Endpoint for AccessLogMiddlewareImpl<E> {
    type Output = Response;

    async fn call(&self, req: Request) -> poem::Result<Self::Output> {
        let version = req.version();
        let method = req.method();
        let original_uri = req.original_uri();
        let remote_addr = req
            .remote_addr()
            .as_socket_addr()
            .map(|addr| Cow::Owned(addr.ip().to_string()))
            .unwrap_or(Cow::Borrowed("-"));
        let referer = req
            .headers()
            .get("Referer")
            .and_then(|v| v.to_str().map(|r| Cow::Owned(r.replace('"', "\\\""))).ok())
            .unwrap_or(Cow::Borrowed("-"));
        let user_agent = req
            .headers()
            .get("User-Agent")
            .and_then(|v| {
                v.to_str()
                    .map(|ua| Cow::Owned(ua.replace('"', "\\\"")))
                    .ok()
            })
            .unwrap_or(Cow::Borrowed("-"));
        let basic_auth_user = req
            .headers()
            .get("Authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(|auth| auth.split_once(' '))
            .and_then(|(scheme, credential)| {
                if scheme.eq_ignore_ascii_case("basic") {
                    Some(credential)
                } else {
                    None
                }
            })
            .and_then(|auth| BASE64.decode(auth.as_bytes()).ok())
            .and_then(|decoded_auth| {
                decoded_auth
                    .split(|&c| c == b':')
                    .next()
                    .map(|u| Cow::Owned(String::from_utf8_lossy(u).into_owned()))
            })
            .unwrap_or(Cow::Borrowed("-"));
        let time = Local::now().format("%d/%b/%Y:%H:%M:%S %z");

        let log_start = format!(
            "{remote_addr} - {basic_auth_user} [{time}] \"{method} {original_uri} {version:?}\""
        );
        let log_end = format!("\"{referer}\" \"{user_agent}\"");

        let result = self.ep.call(req).await.map(|res| res.into_response());
        let status = result
            .as_ref()
            .map(|res| Cow::Owned(res.status().as_u16().to_string()))
            .or_else(|err| {
                Result::<_, Infallible>::Ok(Cow::Owned(err.status().as_u16().to_string()))
            })
            .unwrap_or(Cow::Borrowed("-"));

        println!("{log_start} {status} - {log_end}");
        result
    }
}

struct ErrorLogMiddleware {}

struct ErrorLogMiddlewareImpl<E: Endpoint> {
    ep: E,
}

impl<E: Endpoint> Middleware<E> for ErrorLogMiddleware {
    type Output = ErrorLogMiddlewareImpl<E>;

    fn transform(&self, ep: E) -> Self::Output {
        ErrorLogMiddlewareImpl { ep }
    }
}

impl<E: Endpoint> Endpoint for ErrorLogMiddlewareImpl<E> {
    type Output = E::Output;

    async fn call(&self, req: Request) -> poem::Result<Self::Output> {
        let method = req.method().to_string();
        let original_uri = req.original_uri().clone();
        match self.ep.call(req).await {
            Ok(response) => Ok(response),
            Err(mut err) => {
                let source = err.source().unwrap_or(&err);
                eprintln!("Error in request for {method} {original_uri}: {source:?}");
                if err.status().is_server_error() {
                    err.set_error_message("Internal Server Error");
                }
                Err(err)
            }
        }
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("{} v{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));

    let settings = BtdtServerConfig::load()?;
    let mut listener: BoxListener = settings
        .bind_addrs
        .iter()
        .cloned()
        .map(|addr| TcpListener::bind(addr).boxed())
        .reduce(|a, b| a.combine(b).boxed())
        .ok_or("No bind addresses provided")?;

    let enable_tls = !settings.tls_keystore.is_empty();
    if enable_tls {
        let mut cert_buf = Vec::new();
        File::open(settings.tls_keystore)?.read_to_end(&mut cert_buf)?;
        listener = listener
            .native_tls(
                NativeTlsConfig::new()
                    .pkcs12(cert_buf)
                    .password(settings.tls_keystore_password),
            )
            .boxed()
    }

    let protocol = if enable_tls { "https" } else { "http" };
    for addr in &settings.bind_addrs {
        println!("Listening on {protocol}://{addr}");
    }

    Server::new(listener)
        .run(
            app::create_route(
                Options::builder()
                    .enable_api_docs(settings.enable_api_docs)
                    .build(),
                &settings.caches,
            )
            .with(AccessLogMiddleware {})
            .with(ErrorLogMiddleware {}),
        )
        .await?;
    Ok(())
}