rok-core 0.6.1

Core primitives for the rok ecosystem — errors, crypto, i18n, config, DI, and more
Documentation
//! Request lifecycle hooks — before/after interceptors for request handlers.
//!
//! Provides traits and built-in interceptors for cross-cutting concerns like
//! logging, rate limiting, and request validation.

use std::{
    collections::HashMap,
    net::IpAddr,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    },
    time::Instant,
};

use axum::http::{Method, Uri};

use crate::api::ApiResponse;

/// Contextual information about the current HTTP request.
///
/// Built with [`RequestContext::new`] plus the `with_*` builder methods and
/// passed by shared reference to every [`BeforeHandler`] and [`AfterHandler`].
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// HTTP method (GET, POST, etc.)
    pub method: Method,
    /// Request URI (typically path + query)
    pub uri: Uri,
    /// Client IP address (if available). `ThrottleInterceptor` lets requests
    /// without an IP through without throttling them.
    pub ip: Option<IpAddr>,
    /// Authenticated user ID (if available)
    pub user_id: Option<i64>,
    /// User-Agent header value
    pub user_agent: Option<String>,
    /// Monotonic instant captured when the context was created by `new`
    /// (intended as the request start time); `LogRequest` measures the
    /// request duration from it.
    pub started_at: Instant,
}

impl RequestContext {
    /// Builds a context for a request with the given `method` and `uri`.
    ///
    /// The client IP, user ID and user agent start out unset; `started_at`
    /// is stamped with the current instant.
    pub fn new(method: Method, uri: Uri) -> Self {
        let started_at = Instant::now();
        Self {
            method,
            uri,
            ip: None,
            user_id: None,
            user_agent: None,
            started_at,
        }
    }

    /// Returns this context with the client IP set to `ip`.
    pub fn with_ip(self, ip: IpAddr) -> Self {
        Self { ip: Some(ip), ..self }
    }

    /// Returns this context with the authenticated user ID set to `user_id`.
    pub fn with_user_id(self, user_id: i64) -> Self {
        Self {
            user_id: Some(user_id),
            ..self
        }
    }

    /// Returns this context with the User-Agent value set to `agent`.
    pub fn with_user_agent(self, agent: String) -> Self {
        Self {
            user_agent: Some(agent),
            ..self
        }
    }
}

/// Trait for interceptors that run **before** a request handler.
///
/// The `Send + Sync` bound lets one instance be shared between threads, e.g.
/// behind the `Arc<dyn BeforeHandler>` values stored by [`HookChain`].
/// Because the hook only receives `&self`, stateful implementations need
/// interior mutability (as `ThrottleInterceptor` does with a `Mutex`).
pub trait BeforeHandler: Send + Sync {
    /// Called before the request handler.
    ///
    /// Return `Ok(())` to allow the request to proceed, or `Err(ApiResponse)`
    /// to short-circuit with an error response.
    ///
    /// # Errors
    ///
    /// The returned `ApiResponse` is the response to send instead of running
    /// the handler; [`HookChain::run_before`] stops at the first such error.
    fn before(&self, cx: &RequestContext) -> Result<(), ApiResponse>;
}

/// Trait for interceptors that run **after** a request handler.
///
/// Like [`BeforeHandler`], implementors must be `Send + Sync` so they can be
/// shared (e.g. as `Arc<dyn AfterHandler>` inside [`HookChain`]).
pub trait AfterHandler: Send + Sync {
    /// Called after the request handler produces a response.
    ///
    /// Can inspect and/or modify the response (e.g. add headers, wrap in envelope).
    /// The response is taken by value and must be returned (possibly replaced);
    /// [`HookChain::run_after`] threads it through its hooks in reverse
    /// registration order.
    fn after(&self, cx: &RequestContext, response: ApiResponse) -> ApiResponse;
}

// ── Built-in interceptors ─────────────────────────────────────────────────────

/// Request/response logging interceptor.
///
/// When the crate's `tracing` feature is enabled, it emits an `info`-level
/// "incoming request" event (method, URI, client IP, user agent) before the
/// handler runs. After the handler it emits a "request completed" event
/// (method, URI, status, duration in milliseconds). Without that feature it
/// is a no-op.
///
/// It never rejects a request and returns the response unchanged.
#[derive(Debug, Default, Clone, Copy)]
pub struct LogRequest;

impl LogRequest {
    /// Create a new logging interceptor (equivalent to `LogRequest::default()`).
    pub fn new() -> Self {
        Self
    }
}

impl BeforeHandler for LogRequest {
    /// Logs an `info` "incoming request" event when the `tracing` feature is
    /// enabled, then always lets the request proceed.
    fn before(&self, cx: &RequestContext) -> Result<(), ApiResponse> {
        // Without tracing there is nothing to log; mark `cx` as used so the
        // build stays warning-free.
        #[cfg(not(feature = "tracing"))]
        let _ = cx;

        #[cfg(feature = "tracing")]
        {
            tracing::info!(
                method = %cx.method,
                path = %cx.uri,
                ip = ?cx.ip,
                user_agent = ?cx.user_agent,
                "incoming request"
            );
        }

        Ok(())
    }
}

impl AfterHandler for LogRequest {
    /// Logs an `info` "request completed" event with the response status and
    /// the time elapsed since `cx.started_at` (only with the `tracing`
    /// feature), then hands `response` back untouched.
    fn after(&self, cx: &RequestContext, response: ApiResponse) -> ApiResponse {
        #[cfg(feature = "tracing")]
        {
            let status_code = response.status_code().as_u16();
            let duration_ms = cx.started_at.elapsed().as_secs_f64() * 1000.0;
            tracing::info!(
                method = %cx.method,
                path = %cx.uri,
                status = status_code,
                duration_ms = duration_ms,
                "request completed"
            );
        }

        // Without tracing `cx` would otherwise be unused.
        #[cfg(not(feature = "tracing"))]
        let _ = cx;

        response
    }
}

/// Simple in-memory, fixed-window rate limiter keyed by client IP address.
///
/// Each IP may make at most `max_requests` requests per `window_secs`-second
/// window. An IP's window starts at its first request and restarts on the
/// first request after it has fully elapsed. Excess requests are rejected
/// with an `E_RATE_LIMIT_EXCEEDED` error response (429) and do not consume
/// quota.
///
/// Requests whose [`RequestContext::ip`] is `None` are not throttled.
///
/// State lives in process memory only, so limits are per process. Buckets
/// whose window has expired are evicted at most once per `window_secs`, which
/// keeps memory bounded by the number of IPs seen in roughly the last two
/// windows.
pub struct ThrottleInterceptor {
    /// Requests allowed per IP per window.
    max_requests: u64,
    /// Window length in whole seconds.
    window_secs: u64,
    /// Per-IP counters plus sweep bookkeeping, guarded by one mutex.
    state: Mutex<ThrottleState>,
}

/// Mutable limiter state; only accessed while the mutex is held.
struct ThrottleState {
    /// Per-IP request counters.
    buckets: HashMap<IpAddr, RateBucket>,
    /// When expired buckets were last evicted.
    last_sweep: Instant,
}

/// Request counter for one IP within its current window.
struct RateBucket {
    /// Requests admitted in the current window (only touched under the mutex).
    count: AtomicU64,
    /// When the current window started.
    window_start: Instant,
}

impl ThrottleInterceptor {
    /// Create a new throttle interceptor.
    ///
    /// `max_requests` — max requests per `window_secs` per IP.
    ///
    /// A `max_requests` of `0` rejects every request that carries an IP. A
    /// `window_secs` of `0` starts a fresh window on every request, so only
    /// `max_requests == 0` can then reject anything.
    pub fn new(max_requests: u64, window_secs: u64) -> Self {
        Self {
            max_requests,
            window_secs,
            state: Mutex::new(ThrottleState {
                buckets: HashMap::new(),
                last_sweep: Instant::now(),
            }),
        }
    }

    /// Whether a window that began at `start` has fully elapsed by `now`.
    ///
    /// Saturates to "not elapsed" if `start` is somehow later than `now`,
    /// instead of panicking.
    fn window_elapsed(&self, start: Instant, now: Instant) -> bool {
        now.saturating_duration_since(start).as_secs() >= self.window_secs
    }
}

impl BeforeHandler for ThrottleInterceptor {
    fn before(&self, cx: &RequestContext) -> Result<(), ApiResponse> {
        let ip = match cx.ip {
            Some(ip) => ip,
            // No client IP: nothing to key on, so let the request through.
            None => return Ok(()),
        };

        // A panic while holding the lock cannot leave the counters in a
        // dangerous state, so recover from poisoning instead of failing
        // every subsequent request.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Read the clock only after acquiring the lock, so `now` is never
        // earlier than a timestamp another thread stored under the lock.
        let now = Instant::now();

        // Evict expired buckets at most once per window so the map cannot
        // grow without bound as new client IPs arrive.
        if self.window_elapsed(state.last_sweep, now) {
            state
                .buckets
                .retain(|_, bucket| !self.window_elapsed(bucket.window_start, now));
            state.last_sweep = now;
        }

        let bucket = state.buckets.entry(ip).or_insert_with(|| RateBucket {
            count: AtomicU64::new(0),
            window_start: now,
        });

        if self.window_elapsed(bucket.window_start, now) {
            bucket.count.store(0, Ordering::SeqCst);
            bucket.window_start = now;
        }

        // Check the limit on every path (including a freshly reset window)
        // so `max_requests == 0` really rejects everything.
        let admitted = bucket.count.load(Ordering::SeqCst);
        if admitted >= self.max_requests {
            return Err(ApiResponse::error(
                "E_RATE_LIMIT_EXCEEDED",
                "Too many requests. Please try again later.",
                429,
            ));
        }

        // Cannot overflow: `admitted < max_requests <= u64::MAX`.
        bucket.count.store(admitted + 1, Ordering::SeqCst);
        Ok(())
    }
}

/// Request body validation interceptor (placeholder / trait-based).
///
/// In practice, this would be integrated with `rok_validate` to automatically
/// validate request bodies of type `T` before they reach the handler. Its
/// [`BeforeHandler`] impl currently performs no validation.
pub struct ValidateBody<T> {
    /// Ties the interceptor to the body type `T` without storing a value.
    _marker: std::marker::PhantomData<T>,
}

impl<T> ValidateBody<T> {
    /// Create a new (currently no-op) body-validation interceptor for `T`.
    pub fn new() -> Self {
        Self {
            _marker: std::marker::PhantomData,
        }
    }
}

// Implemented by hand: `#[derive(Default)]` would needlessly require `T: Default`.
impl<T> Default for ValidateBody<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> BeforeHandler for ValidateBody<T>
where
    T: Send + Sync,
{
    /// Placeholder hook: unconditionally lets the request through.
    ///
    /// Integration with `rok_validate` would go here. For now, actual
    /// validation is done via the `Valid<T>` extractor in the handler itself.
    fn before(&self, _context: &RequestContext) -> Result<(), ApiResponse> {
        Ok(())
    }
}

// ── Composite hook ────────────────────────────────────────────────────────────

/// Ordered set of [`BeforeHandler`]s and [`AfterHandler`]s run around a
/// request handler.
///
/// Before hooks run in registration order and stop at the first rejection.
/// After hooks run in reverse registration order, so the first-registered
/// after hook is the last to see (and return) the response.
#[derive(Default)]
pub struct HookChain {
    before_hooks: Vec<Arc<dyn BeforeHandler>>,
    after_hooks: Vec<Arc<dyn AfterHandler>>,
}

impl HookChain {
    /// Creates an empty chain with no hooks registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `hook` to the before hooks and returns the chain (builder style).
    pub fn push_before(mut self, hook: impl BeforeHandler + 'static) -> Self {
        let shared: Arc<dyn BeforeHandler> = Arc::new(hook);
        self.before_hooks.push(shared);
        self
    }

    /// Appends `hook` to the after hooks and returns the chain (builder style).
    pub fn push_after(mut self, hook: impl AfterHandler + 'static) -> Self {
        let shared: Arc<dyn AfterHandler> = Arc::new(hook);
        self.after_hooks.push(shared);
        self
    }

    /// Runs every before hook in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first hook's `Err` response; later hooks are not run.
    pub fn run_before(&self, cx: &RequestContext) -> Result<(), ApiResponse> {
        self.before_hooks.iter().try_for_each(|hook| hook.before(cx))
    }

    /// Threads `response` through every after hook, last-registered first,
    /// and returns the final response.
    pub fn run_after(&self, cx: &RequestContext, response: ApiResponse) -> ApiResponse {
        self.after_hooks
            .iter()
            .rev()
            .fold(response, |current, hook| hook.after(cx, current))
    }
}