openlatch-client 0.0.0

The open-source security layer for AI agents — client forwarder
Documentation
//! Daemon HTTP server — composes all leaf modules into the running service.
//!
//! This module provides:
//! - [`AppState`]: shared state cloned into every axum handler via `Arc`
//! - [`start_server`]: entry point that builds the router, binds TCP, and runs to completion
//! - Signal handling: graceful shutdown on SIGTERM (Unix) or Ctrl+C (all platforms)
//! - Route layout: authenticated POST routes + unauthenticated GET routes
pub mod auth;
pub mod dedup;
pub mod handlers;

use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex};

use axum::{
    extract::{DefaultBodyLimit, Request},
    http::{header::CONTENT_TYPE, StatusCode},
    middleware::{self, Next},
    response::Response,
    routing::get,
    routing::post,
    Router,
};
use tokio::net::TcpListener;

use crate::config::Config;
use crate::logging::EventLogger;
use crate::privacy::PrivacyFilter;
use crate::update;

/// Shared state injected into every axum handler via `Arc<AppState>`.
///
/// A single instance is created per server run and shared by reference count
/// with the router, the auth middleware, and the background tasks. Fields are
/// either inherently thread-safe or wrapped in a synchronization primitive, so
/// everything here is usable through a shared `&AppState`.
pub struct AppState {
    /// Resolved daemon configuration (port, log dir, retention, etc.)
    pub config: Arc<Config>,
    /// Bearer token for authenticating POST requests.
    /// SECURITY: Never log this value.
    pub token: String,
    /// In-memory dedup store with 100ms TTL.
    /// Expired entries are evicted by a periodic background task.
    pub dedup: dedup::DedupStore,
    /// Async event logger (sends to background writer task via mpsc).
    pub event_logger: EventLogger,
    /// Pre-compiled privacy filter for credential masking.
    pub privacy_filter: PrivacyFilter,
    /// Total events processed (not counting deduped duplicates).
    /// Read once at shutdown to produce the final summary.
    pub event_counter: AtomicU64,
    /// Oneshot sender for triggering graceful shutdown via POST /shutdown.
    /// Wrapped in `Mutex<Option<..>>` so the handler can take ownership of the
    /// sender (sending consumes it) without needing `&mut self`.
    pub shutdown_tx: tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
    /// Monotonic timestamp taken when this state was created (for uptime reporting).
    /// This is an `Instant`, not a calendar time — it is only meaningful via `elapsed()`.
    pub started_at: std::time::Instant,
    /// Latest available version string, populated by the async update check on startup.
    /// `None` means either the check has not completed yet, or the current version is latest.
    pub available_update: Mutex<Option<String>>,
}

impl AppState {
    /// Store a newly discovered available version.
    ///
    /// Overwrites any previously stored version. A poisoned lock is recovered
    /// instead of being ignored: the protected value is a plain `Option<String>`
    /// that cannot be left half-updated, so the poison flag carries no useful
    /// information and must not cause the update notice to be silently dropped.
    pub fn set_available_update(&self, version: String) {
        let mut slot = self
            .available_update
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *slot = Some(version);
    }

    /// Return the latest available version, if one has been discovered.
    ///
    /// Returns a clone so the lock is released before the caller uses the value.
    /// A poisoned lock is recovered (see [`AppState::set_available_update`]), so a
    /// previously stored version stays visible even after a panic elsewhere.
    pub fn get_available_update(&self) -> Option<String> {
        self.available_update
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}

/// Start the daemon HTTP server and run until a shutdown signal is received.
///
/// Binds to `127.0.0.1:{config.port}`. The server shuts down gracefully on:
/// - SIGTERM (Unix) or Ctrl+C (all platforms)
/// - HTTP POST /shutdown (authenticated)
///
/// After shutdown, prints a summary to stderr and waits for the event logger to drain.
///
/// # Errors
///
/// Returns an error if the TCP listener cannot be bound (e.g., port in use).
pub async fn start_server(config: Config, token: String) -> anyhow::Result<(u64, u64)> {
    // SECURITY: Bind to 127.0.0.1 by default — only 0.0.0.0 inside containers
    // where Docker network isolation provides the boundary instead.
    let bind_host = if std::env::var("OPENLATCH_BIND_ALL").is_ok() {
        "0.0.0.0"
    } else {
        "127.0.0.1"
    };
    let bind_addr = format!("{}:{}", bind_host, config.port);
    let listener = TcpListener::bind(&bind_addr).await?;

    tracing::info!(
        port = config.port,
        addr = %bind_addr,
        "daemon listening"
    );

    // Write daemon.port file so the hook binary can discover the port
    if let Err(e) = crate::config::write_port_file(config.port) {
        tracing::warn!(error = %e, "failed to write daemon.port file");
    }

    serve_with_listener(listener, config, token).await
}

/// Start the daemon with a pre-bound TCP listener.
///
/// Accepts an already-bound listener — useful for integration tests where port 0
/// is bound by the OS for a random free port, avoiding test conflicts.
///
/// Unlike [`start_server`], this entry point neither chooses a bind address nor
/// writes the `daemon.port` file; it delegates straight to the shared serving
/// routine.
///
/// # Returns
///
/// `(uptime_secs, events)` — uptime in whole seconds and the event counter value
/// at shutdown.
///
/// # Errors
///
/// Returns an error if the axum server fails during operation.
pub async fn start_server_with_listener(
    listener: TcpListener,
    config: Config,
    token: String,
) -> anyhow::Result<(u64, u64)> {
    serve_with_listener(listener, config, token).await
}

/// Internal implementation: serve HTTP on the given listener.
///
/// Steps, in order:
/// 1. Build the shared [`AppState`] (event logger, privacy filter, dedup store,
///    shutdown channel).
/// 2. Spawn the background tasks: the optional startup update check and the
///    periodic dedup eviction.
/// 3. Assemble the router (authenticated hook routes, authenticated `/shutdown`,
///    public `/health` and `/metrics`) and serve until an OS signal arrives or
///    the `/shutdown` endpoint fires.
/// 4. Stop the background tasks, log the shutdown summary, release the state,
///    and shut the event logger down.
///
/// # Returns
///
/// `(uptime_secs, events)` — uptime in whole seconds and the event counter value,
/// both captured right after the server stops.
///
/// # Errors
///
/// Returns an error if the axum server fails during operation.
async fn serve_with_listener(
    listener: TcpListener,
    config: Config,
    token: String,
) -> anyhow::Result<(u64, u64)> {
    let (event_logger, logger_handle) = EventLogger::new(config.log_dir.clone());
    let privacy_filter = PrivacyFilter::new(&config.extra_patterns);
    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();

    // Read the flag up front so `config` can be moved into the shared state
    // instead of being cloned.
    let check_for_updates = config.update.check;

    let state = Arc::new(AppState {
        config: Arc::new(config),
        token,
        dedup: dedup::DedupStore::new(),
        event_logger,
        privacy_filter,
        event_counter: AtomicU64::new(0),
        shutdown_tx: tokio::sync::Mutex::new(Some(shutdown_tx)),
        started_at: std::time::Instant::now(),
        available_update: Mutex::new(None),
    });

    // UPDT-01: Spawn async update check at startup (T-02-14: 2s timeout, non-blocking).
    // The handle is kept so the task can be stopped at shutdown.
    let update_task = check_for_updates.then(|| {
        let current = env!("CARGO_PKG_VERSION").to_string();
        let state_for_update = state.clone();
        tokio::spawn(async move {
            if let Some(latest) = update::check_for_update(&current).await {
                tracing::warn!(
                    code = crate::error::ERR_VERSION_OUTDATED,
                    latest_version = %latest,
                    "Update available: run `npx openlatch@latest`"
                );
                state_for_update.set_available_update(latest);
            }
        })
    });

    // Spawn periodic dedup eviction to prevent unbounded memory growth.
    // This loop never ends on its own, so the handle is kept and the task is
    // aborted at shutdown — otherwise its `Arc<AppState>` clone would live forever.
    let state_for_evict = state.clone();
    let evict_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
        loop {
            interval.tick().await;
            state_for_evict.dedup.evict_expired();
        }
    });

    // Hook routes — require Bearer token + JSON content type
    let hook_routes = Router::new()
        .route("/hooks/pre-tool-use", post(handlers::pre_tool_use))
        .route(
            "/hooks/user-prompt-submit",
            post(handlers::user_prompt_submit),
        )
        .route("/hooks/stop", post(handlers::stop))
        .route_layer(middleware::from_fn(require_json_content_type))
        .route_layer(middleware::from_fn_with_state(
            state.clone(),
            auth::bearer_auth,
        ));

    // Shutdown route — requires Bearer token but no JSON body
    let shutdown_route = Router::new()
        .route("/shutdown", post(handlers::shutdown_handler))
        .route_layer(middleware::from_fn_with_state(
            state.clone(),
            auth::bearer_auth,
        ));

    // Public routes — no authentication required
    let public_routes = Router::new()
        .route("/health", get(handlers::health))
        .route("/metrics", get(handlers::metrics));

    let app = Router::new()
        .merge(hook_routes)
        .merge(shutdown_route)
        .merge(public_routes)
        // SECURITY: 1MB body limit — reject oversized payloads with 413 before parsing
        .layer(DefaultBodyLimit::max(1_048_576))
        .with_state(state.clone());

    let serve_result = axum::serve(listener, app)
        .with_graceful_shutdown(async move {
            tokio::select! {
                _ = signal_handler() => {
                    tracing::info!("received OS shutdown signal");
                }
                _ = shutdown_rx => {
                    tracing::info!("received shutdown via /shutdown endpoint");
                }
            }
        })
        .await;

    // Stop the background tasks and wait until they are actually gone, so the
    // `Arc<AppState>` clones they captured are dropped before the state is
    // released below. Aborting a task that already finished is a no-op, and the
    // join result (cancelled or completed) carries nothing we need.
    evict_task.abort();
    let _ = evict_task.await;
    if let Some(task) = update_task {
        task.abort();
        let _ = task.await;
    }

    serve_result?;

    // Capture final stats before releasing state
    let uptime_secs = state.started_at.elapsed().as_secs();
    let events = state
        .event_counter
        .load(std::sync::atomic::Ordering::Relaxed);

    crate::logging::daemon_log::log_shutdown(uptime_secs, events);

    // Release Arc so EventLogger's sender is dropped, signaling the writer task to exit.
    // If handlers still hold Arc clones at shutdown, warn — log drain may be incomplete.
    match Arc::try_unwrap(state) {
        Ok(_state) => { /* sole owner — sender dropped cleanly */ }
        Err(arc) => {
            tracing::warn!(
                strong_refs = Arc::strong_count(&arc),
                "AppState still has references at shutdown — log drain may be incomplete"
            );
            drop(arc);
        }
    }
    logger_handle.shutdown().await;

    Ok((uptime_secs, events))
}

/// Render an uptime given in seconds as a compact, human-readable string.
///
/// Only the two most significant units are shown:
/// - under a minute: seconds only, e.g. `"45s"`
/// - under an hour: minutes and seconds, e.g. `"3m12s"`
/// - an hour or more: hours and minutes (seconds are dropped), e.g. `"2h14m"`
pub fn format_uptime(secs: u64) -> String {
    let (h, within_hour) = (secs / 3600, secs % 3600);
    let (m, s) = (within_hour / 60, within_hour % 60);

    match (h, m) {
        (0, 0) => format!("{}s", s),
        (0, _) => format!("{}m{}s", m, s),
        _ => format!("{}h{}m", h, m),
    }
}

/// SECURITY: Reject non-JSON content types with 415 Unsupported Media Type.
///
/// Applied to POST routes before the JSON body extractor, so wrong content types
/// are caught before axum's Json<T> returns 422.
///
/// The check compares the media type itself (`type/subtype`, with any
/// parameters such as `; charset=utf-8` stripped) against `application/json`,
/// ignoring ASCII case. This accepts `Application/JSON; charset=utf-8` and
/// rejects look-alikes that merely share the prefix (e.g. `application/jsonp`).
///
/// # Errors
///
/// Returns `StatusCode::UNSUPPORTED_MEDIA_TYPE` when the `Content-Type` header
/// is missing, is not valid visible ASCII, or names a different media type.
async fn require_json_content_type(request: Request, next: Next) -> Result<Response, StatusCode> {
    let is_json = request
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        // Keep only the part before the first `;` — parameters are irrelevant here.
        .map(|value| value.split(';').next().unwrap_or(value).trim())
        .map_or(false, |essence| {
            essence.eq_ignore_ascii_case("application/json")
        });

    if !is_json {
        return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }
    Ok(next.run(request).await)
}

/// Resolve once the operating system asks the process to stop.
///
/// On Unix the future completes on whichever arrives first: SIGTERM or Ctrl+C.
/// On every other platform only Ctrl+C is awaited.
///
/// Note that on Unix the value produced by the Ctrl+C future is discarded, so
/// that branch completes regardless of what the future yields.
///
/// # Panics
///
/// Panics if the SIGTERM handler cannot be registered (Unix), or if waiting for
/// Ctrl+C reports an error (non-Unix).
async fn signal_handler() {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};

        let mut term_stream =
            signal(SignalKind::terminate()).expect("failed to register SIGTERM handler");
        let interrupt = tokio::signal::ctrl_c();

        // First signal wins; the other branch is simply dropped.
        tokio::select! {
            _ = term_stream.recv() => {}
            _ = interrupt => {}
        }
    }
    #[cfg(not(unix))]
    {
        let ctrl_c_outcome = tokio::signal::ctrl_c().await;
        ctrl_c_outcome.expect("failed to register ctrl_c handler");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Check every `(seconds, expected)` pair against `format_uptime`.
    fn assert_uptimes(cases: &[(u64, &str)]) {
        for &(secs, expected) in cases {
            assert_eq!(format_uptime(secs), expected);
        }
    }

    // Below one minute only the seconds are shown.
    #[test]
    fn test_format_uptime_seconds_only() {
        assert_uptimes(&[(0, "0s"), (45, "45s"), (59, "59s")]);
    }

    // From one minute up to (but excluding) one hour: minutes plus seconds.
    #[test]
    fn test_format_uptime_minutes_and_seconds() {
        assert_uptimes(&[(60, "1m0s"), (192, "3m12s"), (3599, "59m59s")]);
    }

    // From one hour on: hours plus minutes, seconds are dropped.
    #[test]
    fn test_format_uptime_hours_and_minutes() {
        assert_uptimes(&[(3600, "1h0m"), (8094, "2h14m"), (7200, "2h0m")]);
    }
}