pg-api 0.1.0

A high-performance PostgreSQL REST API driver with rate limiting, connection pooling, and observability
use crate::models::AppState;
use axum::{
    extract::Request,
    http::{HeaderMap, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

/// Sliding-window request log for a single rate-limited client.
///
/// A request is admitted only while fewer than `limit` requests have been admitted
/// within the trailing `window_seconds`.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// Instants at which requests were admitted, oldest first (appended by `is_allowed`).
    pub requests: Vec<std::time::Instant>,
    /// Maximum number of requests admitted per window.
    pub limit: u32,
    /// Length of the sliding window, in seconds.
    pub window_seconds: u64,
}

impl RateLimitInfo {
    /// Creates an empty log allowing `limit` requests per `window_seconds`.
    pub fn new(limit: u32, window_seconds: u64) -> Self {
        Self {
            requests: Vec::new(),
            limit,
            window_seconds,
        }
    }

    /// Returns the exclusive lower bound of the window ending at `now`.
    ///
    /// `None` means the window reaches further back than `Instant` can represent
    /// (e.g. a huge `window_seconds`). In that case every recorded request is still
    /// inside the window. Plain `now - duration` would panic here instead.
    fn window_start(&self, now: std::time::Instant) -> Option<std::time::Instant> {
        now.checked_sub(std::time::Duration::from_secs(self.window_seconds))
    }

    /// Iterates over the recorded requests that fall inside the window ending at `now`.
    fn requests_in_window(
        &self,
        now: std::time::Instant,
    ) -> impl Iterator<Item = std::time::Instant> + '_ {
        let start = self.window_start(now);
        self.requests
            .iter()
            .copied()
            .filter(move |&time| match start {
                Some(start) => time > start,
                None => true,
            })
    }

    /// Records a request and returns `true` if it fits within the quota.
    ///
    /// Expired entries are evicted first. A rejected request is not recorded.
    pub fn is_allowed(&mut self) -> bool {
        let now = std::time::Instant::now();

        // Evict requests that have slid out of the window. If the window start is not
        // representable, nothing can have expired yet.
        if let Some(start) = self.window_start(now) {
            self.requests.retain(|&time| time > start);
        }

        if self.requests.len() < self.limit as usize {
            self.requests.push(now);
            true
        } else {
            false
        }
    }

    /// Number of further requests that would currently be admitted.
    pub fn remaining(&self) -> u32 {
        let used = self.requests_in_window(std::time::Instant::now()).count();
        // Saturate instead of truncating with `as` if the count ever exceeds u32.
        self.limit
            .saturating_sub(u32::try_from(used).unwrap_or(u32::MAX))
    }

    /// Whole seconds until the oldest in-window request expires.
    ///
    /// Returns `window_seconds` when no request is in the window.
    pub fn reset_time(&self) -> u64 {
        let now = std::time::Instant::now();
        match self.requests_in_window(now).min() {
            Some(oldest) => {
                let elapsed = now.duration_since(oldest).as_secs();
                self.window_seconds.saturating_sub(elapsed)
            }
            None => self.window_seconds,
        }
    }
}

pub async fn rate_limit_middleware(
    axum::extract::State(state): axum::extract::State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Skip rate limiting for health check
    if request.uri().path() == "/health" {
        return Ok(next.run(request).await);
    }

    // Extract API key
    let api_key = extract_api_key(request.headers()).map_err(|_| StatusCode::UNAUTHORIZED)?;
    
    // Get account info
    let accounts = state.accounts.read().await;
    if let Some(account) = accounts.get(&api_key) {
        let account_id = account.id.clone();
        let rate_limit = account.rate_limit;
        drop(accounts); // Release the read lock
        
        // Skip rate limiting if limit is 0 (unlimited)
        if rate_limit == 0 {
            return Ok(next.run(request).await);
        }
        
        // Check rate limit
        let rate_limiter = state.rate_limiter.clone();
        let mut limiters = rate_limiter.lock().await;
        
        let limiter = limiters.entry(account_id.clone()).or_insert_with(|| {
            RateLimitInfo::new(rate_limit, 60) // 60 second window
        });
        
        // Update limit if it has changed
        if limiter.limit != rate_limit {
            limiter.limit = rate_limit;
        }
        
        if limiter.is_allowed() {
            let remaining = limiter.remaining();
            let reset = limiter.reset_time();
            drop(limiters); // Release the lock
            
            let mut response = next.run(request).await;
            
            // Add rate limit headers
            response.headers_mut().insert(
                "X-RateLimit-Limit", 
                rate_limit.to_string().parse().unwrap()
            );
            response.headers_mut().insert(
                "X-RateLimit-Remaining", 
                remaining.to_string().parse().unwrap()
            );
            response.headers_mut().insert(
                "X-RateLimit-Reset", 
                reset.to_string().parse().unwrap()
            );
            
            Ok(response)
        } else {
            let reset = limiter.reset_time();
            drop(limiters); // Release the lock
            
            // Return 429 Too Many Requests
            let error_response = serde_json::json!({
                "success": false,
                "error": {
                    "code": "RATE_LIMIT_EXCEEDED",
                    "message": "Rate limit exceeded. Try again later."
                }
            });
            
            let body = axum::Json(error_response);
            
            Ok((StatusCode::TOO_MANY_REQUESTS, 
                [
                    ("Content-Type", "application/json"),
                    ("X-RateLimit-Limit", &rate_limit.to_string()),
                    ("X-RateLimit-Remaining", "0"),
                    ("X-RateLimit-Reset", &reset.to_string()),
                ],
                body).into_response())
        }
    } else {
        // Account not found, let auth middleware handle it
        Ok(next.run(request).await)
    }
}

/// Extracts the caller's API key from the request headers.
///
/// The `x-api-key` header takes precedence and `authorization` is the fallback. A
/// leading `Bearer ` prefix is stripped from whichever header supplied the value.
///
/// # Errors
///
/// Returns `Err("Missing API key")` in two cases:
/// - Neither header is present.
/// - The selected header is not a valid string according to `HeaderValue::to_str`.
///   The fallback header is not tried in that case.
fn extract_api_key(headers: &HeaderMap) -> Result<String, String> {
    headers
        .get("x-api-key")
        .or_else(|| headers.get("authorization"))
        .and_then(|value| value.to_str().ok())
        .map(|value| value.strip_prefix("Bearer ").unwrap_or(value).to_string())
        .ok_or_else(|| "Missing API key".to_string())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_basic() {
        let mut limiter = RateLimitInfo::new(5, 60); // 5 requests per minute

        // The first `limit` requests are admitted.
        for i in 1..=5 {
            assert!(limiter.is_allowed(), "Request {} should be allowed", i);
        }

        // The next one is rejected.
        assert!(!limiter.is_allowed(), "Request 6 should be rate limited");

        assert_eq!(limiter.remaining(), 0, "Remaining requests should be 0");
    }

    #[test]
    fn test_fresh_limiter_has_full_quota() {
        let limiter = RateLimitInfo::new(5, 60);
        assert_eq!(limiter.remaining(), 5);
        // With no requests in the window, the reset time is the whole window.
        assert_eq!(limiter.reset_time(), 60);
    }

    #[test]
    fn test_remaining_counts_down() {
        let mut limiter = RateLimitInfo::new(5, 60);
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert_eq!(limiter.remaining(), 3);
        let reset = limiter.reset_time();
        assert!(reset > 0 && reset <= 60, "reset {} out of range", reset);
    }

    #[test]
    fn test_zero_limit_rejects_everything() {
        let mut limiter = RateLimitInfo::new(0, 60);
        assert!(!limiter.is_allowed());
        assert_eq!(limiter.remaining(), 0);
        assert!(limiter.requests.is_empty(), "rejected requests must not be recorded");
    }
}