//! forge-runtime 0.9.0
//!
//! Runtime executors and gateway for the Forge framework.
use std::time::Duration;

use chrono::{DateTime, Utc};
use dashmap::DashMap;
use sqlx::PgPool;

use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};

/// Rate limiter using PostgreSQL for state storage.
///
/// Implements a token bucket algorithm with atomic updates: refill and
/// consumption happen in a single `INSERT ... ON CONFLICT` upsert, so
/// concurrent callers — including other processes sharing the database —
/// observe consistent bucket state.
pub struct RateLimiter {
    // Shared sqlx connection pool; all bucket state lives in the
    // `forge_rate_limits` table.
    pool: PgPool,
}

impl RateLimiter {
    /// Create a new rate limiter.
    ///
    /// `pool` is the PostgreSQL connection pool used for all bucket state.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Check rate limit for a bucket key.
    ///
    /// Performs one atomic upsert implementing the token bucket:
    /// - First sight of `bucket_key`: a row is inserted already charged for
    ///   this request (`$2 - 1` tokens).
    /// - Existing row: tokens are refilled proportionally to the seconds
    ///   elapsed since `last_refill` (capped at `max_tokens`), then one
    ///   token is consumed.
    ///
    /// The request is allowed when the post-consumption count is still
    /// non-negative (the `tokens >= 0` column aliased `allowed!` — the `!`
    /// tells sqlx the expression is non-null).
    ///
    /// NOTE(review): the UPDATE branch subtracts 1 unconditionally, so a
    /// continuously denied bucket drives `tokens` ever more negative and
    /// inflates `retry_after` far beyond one token's worth — confirm this
    /// is intended rather than flooring at -1.
    ///
    /// NOTE(review): if `config.refill_rate()` can return 0, the divisions
    /// below go non-finite and `Duration::from_secs_f64` /
    /// `chrono::Duration::seconds` may panic — verify config validation
    /// upstream.
    pub async fn check(
        &self,
        bucket_key: &str,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult> {
        let max_tokens = config.requests as f64;
        let refill_rate = config.refill_rate();

        // Atomic upsert with token bucket logic
        let result = sqlx::query!(
            r#"
            INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
            VALUES ($1, $2 - 1, NOW(), $2, $3)
            ON CONFLICT (bucket_key) DO UPDATE SET
                tokens = LEAST(
                    forge_rate_limits.max_tokens::double precision,
                    forge_rate_limits.tokens +
                        (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
                ) - 1,
                last_refill = NOW()
            RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
            "#,
            bucket_key,
            // NOTE(review): u32 -> f64 -> i32 round trip; saturates for
            // configs above i32::MAX — confirm `requests` is bounded.
            max_tokens as i32,
            refill_rate
        )
        .fetch_one(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        let tokens = result.tokens;
        let last_refill = result.last_refill;
        let allowed = result.allowed;

        // Whole tokens still available; negative balances report as 0.
        let remaining = tokens.max(0.0) as u32;
        // Projected instant at which the bucket would be full again.
        let reset_at =
            last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);

        if allowed {
            Ok(RateLimitResult::allowed(remaining, reset_at))
        } else {
            // Time until one whole token is available again (tokens may be
            // negative here, lengthening the wait accordingly).
            let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
        }
    }

    /// Build a bucket key for the given parameters.
    ///
    /// Keys are `<scope>:<identity>:<action>` strings (or `global:<action>`),
    /// with fallbacks of `anonymous` / `unknown` / `none` when the identity
    /// cannot be determined from the auth context or request metadata.
    ///
    /// NOTE(review): `User` and `UserAction` produce keys of identical
    /// granularity (both embed `action_name`), differing only by prefix —
    /// confirm `User` wasn't meant to be per-user across all actions.
    pub fn build_key(
        &self,
        key_type: RateLimitKey,
        action_name: &str,
        auth: &AuthContext,
        request: &RequestMetadata,
    ) -> String {
        match key_type {
            RateLimitKey::User => {
                let user_id = auth
                    .user_id()
                    .map(|u| u.to_string())
                    .unwrap_or_else(|| "anonymous".to_string());
                format!("user:{}:{}", user_id, action_name)
            }
            RateLimitKey::Ip => {
                let ip = request.client_ip.as_deref().unwrap_or("unknown");
                format!("ip:{}:{}", ip, action_name)
            }
            RateLimitKey::Tenant => {
                // Tenant id comes from a JWT-style claim rather than the
                // user id — presumably set by the auth middleware; verify.
                let tenant_id = auth
                    .claim("tenant_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("none");
                format!("tenant:{}:{}", tenant_id, action_name)
            }
            RateLimitKey::UserAction => {
                let user_id = auth
                    .user_id()
                    .map(|u| u.to_string())
                    .unwrap_or_else(|| "anonymous".to_string());
                format!("user_action:{}:{}", user_id, action_name)
            }
            RateLimitKey::Global => {
                format!("global:{}", action_name)
            }
        }
    }

    /// Check rate limit and return an error if exceeded.
    ///
    /// On denial, emits a `rate_limit.exceeded` diagnostic before returning
    /// `ForgeError::RateLimitExceeded`; on success, passes the
    /// [`RateLimitResult`] through so callers can surface headers.
    pub async fn enforce(
        &self,
        bucket_key: &str,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult> {
        let result = self.check(bucket_key, config).await?;
        if !result.allowed {
            crate::signals::emit_diagnostic(
                "rate_limit.exceeded",
                serde_json::json!({
                    "bucket": bucket_key,
                    "limit": config.requests,
                    "remaining": result.remaining,
                    // Fallback of 1s mirrors the error construction below.
                    "retry_after_ms": result
                        .retry_after
                        .unwrap_or(Duration::from_secs(1))
                        .as_millis() as u64,
                }),
                None,
                None,
                None,
                None,
                false,
            );
            return Err(ForgeError::RateLimitExceeded {
                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
                limit: config.requests,
                remaining: result.remaining,
            });
        }
        Ok(result)
    }

    /// Reset a rate limit bucket.
    ///
    /// Deletes the row; the next `check` recreates the bucket at full
    /// capacity via the upsert's INSERT branch.
    pub async fn reset(&self, bucket_key: &str) -> Result<()> {
        sqlx::query!(
            "DELETE FROM forge_rate_limits WHERE bucket_key = $1",
            bucket_key
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Clean up old rate limit entries.
    ///
    /// Returns the number of rows deleted.
    ///
    /// NOTE(review): filters on `created_at` (a column this file never
    /// writes — presumably a DB default), not `last_refill`. An old but
    /// still-active bucket will be deleted and thereby refilled to full
    /// capacity on its next check — confirm `last_refill` wasn't intended.
    pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
        let result = sqlx::query!(
            r#"
            DELETE FROM forge_rate_limits
            WHERE created_at < $1
            "#,
            older_than,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.rows_affected())
    }
}

/// In-memory token bucket used by the hybrid limiter's fast path.
struct LocalBucket {
    /// Current balance; fractional between refills, may dip below 1.0.
    tokens: f64,
    /// Capacity the bucket refills toward.
    max_tokens: f64,
    /// Tokens restored per second of elapsed wall time.
    refill_rate: f64,
    /// Monotonic timestamp of the last refill computation.
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Create a bucket at full capacity with the given refill rate
    /// (tokens per second).
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Refill proportionally to elapsed time, then attempt to consume one
    /// token. Returns `true` when a whole token was available.
    fn try_consume(&mut self) -> bool {
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        // Refill is capped at capacity; a non-positive rate simply adds nothing.
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Whole tokens currently available (never negative).
    fn remaining(&self) -> u32 {
        self.tokens.max(0.0) as u32
    }

    /// Time until at least one whole token is available.
    ///
    /// Returns `Duration::MAX` when the bucket can never refill
    /// (`refill_rate <= 0`). The previous implementation divided by the
    /// rate unconditionally, producing a non-finite value on which
    /// `Duration::from_secs_f64` panics.
    fn time_until_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            Duration::ZERO
        } else if self.refill_rate <= 0.0 {
            // No refill will ever occur; report the maximum representable
            // wait instead of panicking on a division by zero.
            Duration::MAX
        } else {
            Duration::from_secs_f64((1.0 - self.tokens) / self.refill_rate)
        }
    }
}

/// Maximum number of local rate limit buckets to prevent unbounded memory growth.
///
/// Guards against key churn (e.g. many distinct per-IP keys). When
/// `HybridRateLimiter::check` observes the map above this size, it evicts
/// entries idle for more than five minutes via `cleanup_local`.
const MAX_LOCAL_BUCKETS: usize = 100_000;

/// Hybrid rate limiter with in-memory fast path and periodic DB sync.
///
/// Per-user/per-IP checks use a local DashMap for sub-microsecond decisions.
/// Global keys always hit the database for strict cross-node consistency.
///
/// NOTE(review): no periodic sync of local buckets to the database is
/// visible in this file despite the summary above — confirm it exists
/// elsewhere or trim the doc. As written, non-global limits are enforced
/// per process, not cluster-wide.
pub struct HybridRateLimiter {
    // In-process buckets, keyed by bucket key; size-bounded by
    // MAX_LOCAL_BUCKETS with idle eviction.
    local: DashMap<String, LocalBucket>,
    // Database-backed limiter used for Global keys and key construction.
    db_limiter: RateLimiter,
}

impl HybridRateLimiter {
    pub fn new(pool: PgPool) -> Self {
        Self {
            local: DashMap::new(),
            db_limiter: RateLimiter::new(pool),
        }
    }

    /// Check rate limit. Uses local fast path for per-user/per-IP keys,
    /// database for global keys.
    pub async fn check(
        &self,
        bucket_key: &str,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult> {
        if config.key == RateLimitKey::Global {
            return self.db_limiter.check(bucket_key, config).await;
        }

        let max_tokens = config.requests as f64;
        let refill_rate = config.refill_rate();

        // Evict idle buckets when the map gets too large to prevent memory exhaustion
        if self.local.len() > MAX_LOCAL_BUCKETS {
            self.cleanup_local(Duration::from_secs(300)); // evict entries idle > 5 min
        }

        let mut bucket = self
            .local
            .entry(bucket_key.to_string())
            .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));

        let allowed = bucket.try_consume();
        let remaining = bucket.remaining();
        let reset_at = Utc::now()
            + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);

        if allowed {
            Ok(RateLimitResult::allowed(remaining, reset_at))
        } else {
            let retry_after = bucket.time_until_token();
            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
        }
    }

    pub fn build_key(
        &self,
        key_type: RateLimitKey,
        action_name: &str,
        auth: &AuthContext,
        request: &RequestMetadata,
    ) -> String {
        self.db_limiter
            .build_key(key_type, action_name, auth, request)
    }

    pub async fn enforce(
        &self,
        bucket_key: &str,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult> {
        let result = self.check(bucket_key, config).await?;
        if !result.allowed {
            return Err(ForgeError::RateLimitExceeded {
                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
                limit: config.requests,
                remaining: result.remaining,
            });
        }
        Ok(result)
    }

    /// Clean up expired local buckets (call periodically).
    pub fn cleanup_local(&self, max_idle: Duration) {
        let cutoff = std::time::Instant::now() - max_idle;
        self.local.retain(|_, bucket| bucket.last_refill > cutoff);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Lazily-connected pool pointing at a local test database; no network
    /// I/O happens until a query is actually executed.
    fn lazy_test_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_test_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_test_pool());
        let auth = AuthContext::unauthenticated();
        let request = RequestMetadata::default();

        let global_key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(global_key, "global:test_action");

        let user_key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(user_key.starts_with("user:"));
    }
}