axocoatl-server 0.1.0

Axum HTTP/WebSocket API server for Axocoatl
//! HTTP middleware for the Axocoatl API server.
//! Rate limiting, request logging, and CORS.

use std::time::Instant;

use axum::{
    extract::Request,
    http::{header, Method},
    middleware::Next,
    response::Response,
};
use dashmap::DashMap;

/// Settings that control the in-memory rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Upper bound on the number of requests accepted within one window.
    pub max_requests: u32,
    /// Length of a counting window, expressed in seconds.
    pub window_secs: u64,
    /// Master switch: rate limiting is only applied when this is `true`.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Builds the stock configuration: 100 requests per 60-second window,
    /// with rate limiting switched off.
    fn default() -> Self {
        let max_requests = 100;
        let window_secs = 60;
        let enabled = false;

        Self {
            max_requests,
            window_secs,
            enabled,
        }
    }
}

/// In-memory rate limiter state.
///
/// Pairs the limiter configuration with one counter entry per client key.
/// The counters are consulted and updated by `check`.
pub struct RateLimiter {
    /// Limits and on/off switch applied to every `check` call.
    config: RateLimitConfig,
    /// IP → (count, window_start)
    ///
    /// `count` is the number of requests accepted in the current window and
    /// `window_start` is the instant at which that window began.
    state: DashMap<String, (u32, Instant)>,
}

impl RateLimiter {
    /// Creates a limiter that applies `config` and has no recorded clients yet.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            state: DashMap::new(),
        }
    }

    /// Check if a request from the given IP should be allowed.
    ///
    /// Returns `true` unconditionally when rate limiting is disabled in the
    /// configuration. Otherwise every distinct `ip` string gets its own fixed
    /// window of `window_secs` seconds, starting at its first request (or at
    /// the first request after the previous window elapsed). At most
    /// `max_requests` calls per window return `true`; further calls return
    /// `false` until the window has elapsed. A `max_requests` of zero denies
    /// every request.
    ///
    /// NOTE(review): this method only inserts and updates entries; nothing in
    /// this impl block removes them, so the map grows with the number of
    /// distinct keys seen. Confirm whether eviction happens elsewhere.
    pub fn check(&self, ip: &str) -> bool {
        if !self.config.enabled {
            return true;
        }

        let now = Instant::now();
        let window = std::time::Duration::from_secs(self.config.window_secs);

        let mut entry = self.state.entry(ip.to_string()).or_insert((0, now));
        let (count, window_start) = entry.value_mut();

        // `now` was sampled before the entry was acquired, so a concurrent
        // caller may already have stored a later `window_start`. Saturate to
        // zero instead of relying on the behaviour of a negative span.
        if now.saturating_duration_since(*window_start) > window {
            // New window: clear the counter and let the shared limit check
            // below decide, so a zero limit is honoured on rollover too.
            *count = 0;
            *window_start = now;
        }

        if *count < self.config.max_requests {
            *count += 1;
            true
        } else {
            false
        }
    }
}

/// Middleware that emits one `tracing` info event per handled request.
///
/// The method and URI are copied before the request is handed to `next`,
/// which takes the request by value. Once the rest of the stack has produced
/// a response, its status and the elapsed time in milliseconds are logged
/// and the response is returned unchanged.
pub async fn request_logging(request: Request, next: Next) -> Response {
    let (method, uri) = (request.method().clone(), request.uri().clone());
    let started_at = Instant::now();

    let response = next.run(request).await;

    let latency_ms = started_at.elapsed().as_millis();
    let status = response.status();
    tracing::info!(
        method = %method,
        uri = %uri,
        status = %status,
        latency_ms = latency_ms,
        "HTTP request"
    );

    response
}

/// Builds the CORS layer applied to the Axocoatl API.
///
/// Any origin is accepted. The allowed methods are GET, POST, PUT and DELETE,
/// and the allowed request headers are `Content-Type` and `Authorization`.
pub fn cors_headers() -> tower_http::cors::CorsLayer {
    use tower_http::cors::{Any, CorsLayer};

    let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE];
    let request_headers = [header::CONTENT_TYPE, header::AUTHORIZATION];

    CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(methods)
        .allow_headers(request_headers)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address used as the client key in single-client tests.
    const LOCALHOST: &str = "127.0.0.1";

    /// Builds an enabled limiter with a 60-second window and the given cap.
    fn enabled_limiter(max_requests: u32) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            max_requests,
            window_secs: 60,
            enabled: true,
        })
    }

    /// Every request up to the cap is accepted.
    #[test]
    fn rate_limiter_allows_under_limit() {
        let limiter = enabled_limiter(5);

        let mut remaining = 5;
        while remaining > 0 {
            assert!(limiter.check(LOCALHOST));
            remaining -= 1;
        }
    }

    /// The first request beyond the cap is rejected.
    #[test]
    fn rate_limiter_blocks_over_limit() {
        let limiter = enabled_limiter(2);

        let first = limiter.check(LOCALHOST);
        let second = limiter.check(LOCALHOST);
        let third = limiter.check(LOCALHOST);

        assert!(first);
        assert!(second);
        assert!(!third);
    }

    /// The default configuration is disabled, so nothing is ever rejected.
    #[test]
    fn rate_limiter_disabled_allows_all() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        for _ in 0..1000 {
            assert!(limiter.check(LOCALHOST));
        }
    }

    /// Counters are tracked per client key, not globally.
    #[test]
    fn rate_limiter_separate_ips() {
        let limiter = enabled_limiter(1);

        // First client uses up its single slot and is then blocked.
        assert!(limiter.check("1.1.1.1"));
        assert!(!limiter.check("1.1.1.1"));
        // A different client still has its own slot available.
        assert!(limiter.check("2.2.2.2"));
    }
}