arbiter-proxy 0.0.22

Async HTTP reverse proxy with middleware chain architecture
Documentation
//! Server bootstrap and signal-triggered shutdown: the accept loop stops on
//! Ctrl-C (and, on Unix, SIGTERM).

use std::net::SocketAddr;
use std::sync::Arc;

use arbiter_metrics::ArbiterMetrics;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio::signal;

use crate::config::ProxyConfig;
use crate::middleware::MiddlewareChain;
use crate::proxy::{ProxyState, build_audit, handle_request};

/// Run the proxy server. Blocks until a shutdown signal is received.
///
/// Builds the shared `ProxyState` (middleware chain, audit sink, redaction
/// config and metrics), binds a TCP listener on the configured address and
/// serves each accepted connection over HTTP/1 on its own spawned task, routing
/// every request through `handle_request`.
///
/// When the shutdown signal fires, the accept loop stops and the function
/// returns `Ok(())`. Connection tasks that were already spawned are not awaited
/// here, so in-flight requests are not drained by this function.
///
/// # Errors
///
/// Returns an error if the configured listen address is not a valid socket
/// address, if metrics creation fails, if the listener cannot be bound, or if
/// `accept` fails for a reason other than a single refused/aborted/reset
/// connection.
pub async fn run(config: ProxyConfig) -> anyhow::Result<()> {
    let addr_text = join_host_port(&config.server.listen_addr, &config.server.listen_port);
    let addr: SocketAddr = addr_text
        .parse::<SocketAddr>()
        .map_err(|e| anyhow::anyhow!("invalid listen address {addr_text:?}: {e}"))?;

    let middleware = MiddlewareChain::from_config(&config.middleware);

    let (audit_sink, redaction_config) = build_audit(&config.audit);
    let metrics = Arc::new(
        ArbiterMetrics::new().map_err(|e| anyhow::anyhow!("failed to create metrics: {e}"))?,
    );

    let state = Arc::new(ProxyState::new(
        config.upstream.url.clone(),
        middleware,
        audit_sink,
        redaction_config,
        metrics,
    ));

    let listener = TcpListener::bind(addr).await?;
    tracing::info!(%addr, upstream = %config.upstream.url, "proxy listening");

    // Pinned once so every loop iteration keeps polling the same shutdown future.
    let shutdown = shutdown_signal();
    tokio::pin!(shutdown);

    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, remote_addr) = match result {
                    Ok(accepted) => accepted,
                    // A peer that gives up before we accept it is a problem with
                    // that one connection, not with the listener: skip it and
                    // keep serving instead of tearing the whole server down.
                    Err(e) if is_connection_error(&e) => {
                        tracing::warn!(error = %e, "dropped connection during accept");
                        continue;
                    }
                    Err(e) => return Err(e.into()),
                };
                let state = Arc::clone(&state);
                tracing::debug!(%remote_addr, "accepted connection");

                tokio::spawn(async move {
                    let io = TokioIo::new(stream);
                    let svc = service_fn(move |req| {
                        let state = Arc::clone(&state);
                        handle_request(state, req)
                    });
                    if let Err(e) = http1::Builder::new()
                        .serve_connection(io, svc)
                        .await
                    {
                        tracing::error!(error = %e, %remote_addr, "connection error");
                    }
                });
            }
            _ = &mut shutdown => {
                tracing::info!("shutdown signal received, stopping");
                break;
            }
        }
    }

    Ok(())
}

/// Join a host and a port into text that `SocketAddr`'s parser accepts.
///
/// IPv6 literals must be bracketed (`[::1]:8080`); naive `host:port` joining
/// would yield `::1:8080`, which does not parse. Hosts that already start with
/// `[` are left untouched.
fn join_host_port(host: impl std::fmt::Display, port: impl std::fmt::Display) -> String {
    let host = host.to_string();
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}

/// Whether an `accept` error concerns only the single pending connection.
///
/// These kinds describe a peer that refused, aborted or reset its connection
/// before it was accepted; they say nothing about the health of the listener.
fn is_connection_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}

/// Resolve once the process is asked to stop.
///
/// Waits for Ctrl-C on every platform and, on Unix, also for `SIGTERM`;
/// whichever arrives first completes this future.
///
/// # Panics
///
/// Panics if the Ctrl-C or (on Unix) the `SIGTERM` handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        let outcome = signal::ctrl_c().await;
        outcome.expect("failed to install ctrl-c handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install SIGTERM handler");
        stream.recv().await;
    };

    // Targets without Unix signals get a future that never completes, so only
    // Ctrl-C can end the wait there.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {}
        () = sigterm => {}
    }
}