pg-api 0.1.0

A high-performance PostgreSQL REST API driver with rate limiting, connection pooling, and observability.
use crate::models::AppState;
use axum::{
    extract::{Request, State},
    http::{HeaderMap, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

/// Enforces each account's `max_connections` limit on query-style endpoints.
///
/// Only requests whose path starts with `/v1/query`, `/v1/batch` or
/// `/v1/transaction` are counted; every other request is forwarded untouched.
/// Requests without a readable API key, or whose key does not match an entry
/// in `state.accounts`, are also forwarded so the auth middleware can reject
/// them.
///
/// An account `max_connections` of `0` disables the limit. When the account
/// already has `max_connections` requests in flight, a `503 Service
/// Unavailable` JSON error with code `CONNECTION_LIMIT_EXCEEDED` is returned
/// without calling the inner service.
///
/// # Errors
///
/// Never returns `Err`; the `Result` return type is kept for compatibility
/// with existing callers.
pub async fn connection_limit_middleware(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Skip connection limiting for non-query endpoints.
    let path = request.uri().path();
    let is_limited_endpoint = path.starts_with("/v1/query")
        || path.starts_with("/v1/batch")
        || path.starts_with("/v1/transaction");
    if !is_limited_endpoint {
        return Ok(next.run(request).await);
    }

    // Let the auth middleware handle a missing API key.
    let api_key = match extract_api_key(request.headers()) {
        Ok(key) => key,
        Err(_) => return Ok(next.run(request).await),
    };

    // Copy out what we need and release the read lock *before* awaiting the
    // inner service. Holding it across `next.run` would block every writer to
    // `accounts` for the whole request, and with a fair lock also any reader
    // queued behind such a writer (e.g. downstream middleware).
    let account_limits = {
        let accounts = state.accounts.read().await;
        accounts
            .get(&api_key)
            .map(|account| (account.id.clone(), account.max_connections))
    };

    let (account_id, max_connections) = match account_limits {
        Some(limits) => limits,
        // Account not found, let auth middleware handle it.
        None => return Ok(next.run(request).await),
    };

    // Skip connection limiting if limit is 0 (unlimited).
    if max_connections == 0 {
        return Ok(next.run(request).await);
    }

    // Check and reserve a slot while holding a single map-entry guard, so
    // concurrent requests cannot all observe `count < max` and overshoot the
    // limit. The guard is confined to this block so it is released before any
    // `.await`. `max_connections > 0` here, so a freshly inserted 0 is always
    // incremented and no stray zero entries are left behind.
    let rejected_at = {
        let mut count = state
            .active_connections
            .entry(account_id.clone())
            .or_insert(0);
        if *count >= max_connections {
            Some(*count)
        } else {
            *count += 1;
            None
        }
    };

    if let Some(current_count) = rejected_at {
        let error_response = serde_json::json!({
            "success": false,
            "error": {
                "code": "CONNECTION_LIMIT_EXCEEDED",
                "message": format!("Maximum connections ({}) exceeded. Current: {}", max_connections, current_count)
            }
        });

        return Ok((
            StatusCode::SERVICE_UNAVAILABLE,
            [("Content-Type", "application/json")],
            axum::Json(error_response),
        )
            .into_response());
    }

    // Release the reserved slot when this guard is dropped. That happens after
    // the response is produced, but also if the inner service panics or this
    // future is cancelled (e.g. the client disconnects), so the count cannot leak.
    let slot = ReleaseOnDrop::new(move || {
        state
            .active_connections
            .entry(account_id)
            .and_modify(|count| *count = count.saturating_sub(1));
    });

    let response = next.run(request).await;
    drop(slot);
    Ok(response)
}

/// Runs a cleanup closure exactly once when dropped.
///
/// Used by [`connection_limit_middleware`] to release a reserved connection
/// slot on every exit path of the inner service call, including unwinding and
/// future cancellation.
struct ReleaseOnDrop<F: FnOnce()> {
    release: Option<F>,
}

impl<F: FnOnce()> ReleaseOnDrop<F> {
    /// Wraps `release` so it runs when the returned guard goes out of scope.
    fn new(release: F) -> Self {
        Self {
            release: Some(release),
        }
    }
}

impl<F: FnOnce()> Drop for ReleaseOnDrop<F> {
    fn drop(&mut self) {
        // `take` guarantees the closure runs at most once.
        if let Some(release) = self.release.take() {
            release();
        }
    }
}

/// Extracts the caller's API key from the request headers.
///
/// The `x-api-key` header takes precedence; `authorization` is consulted only
/// when `x-api-key` is absent. A leading `"Bearer "` prefix (case-sensitive)
/// is stripped from the selected value; any other value is returned verbatim.
///
/// # Errors
///
/// Returns `Err("Missing API key")` when neither header is present, or when
/// the selected header's value cannot be read as a `&str`
/// (`HeaderValue::to_str` fails). An unreadable `x-api-key` does not fall back
/// to `authorization`.
fn extract_api_key(headers: &HeaderMap) -> Result<String, String> {
    headers
        .get("x-api-key")
        .or_else(|| headers.get("authorization"))
        .and_then(|value| value.to_str().ok())
        // `strip_prefix` both tests and strips, so no separate check + unwrap.
        .map(|value| value.strip_prefix("Bearer ").unwrap_or(value).to_owned())
        .ok_or_else(|| "Missing API key".to_string())
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a `HeaderMap` from static `(name, value)` pairs.
    fn headers_with(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    #[test]
    fn extracts_x_api_key_header() {
        let headers = headers_with(&[("x-api-key", "test-key")]);
        assert_eq!(extract_api_key(&headers).unwrap(), "test-key");
    }

    #[test]
    fn strips_bearer_prefix_from_authorization() {
        let headers = headers_with(&[("authorization", "Bearer test-key-2")]);
        assert_eq!(extract_api_key(&headers).unwrap(), "test-key-2");
    }

    #[test]
    fn returns_raw_authorization_without_bearer_prefix() {
        let headers = headers_with(&[("authorization", "raw-key")]);
        assert_eq!(extract_api_key(&headers).unwrap(), "raw-key");
    }

    #[test]
    fn x_api_key_takes_precedence_over_authorization() {
        let headers = headers_with(&[
            ("x-api-key", "primary"),
            ("authorization", "Bearer secondary"),
        ]);
        assert_eq!(extract_api_key(&headers).unwrap(), "primary");
    }

    #[test]
    fn missing_key_is_an_error() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_api_key(&headers).unwrap_err(),
            "Missing API key"
        );
    }
}