fuel-web-utils 0.0.29

Fuel library for web utils
use std::{collections::HashMap, sync::Arc};

use axum::{
    body::Body,
    extract::{FromRequestParts, Query, Request, State},
    http::{header::AUTHORIZATION, request::Parts},
    middleware::Next,
    response::Response,
};
use fuel_streams_domains::infra::Db;

use super::{
    rate_limiter::RateLimitsController,
    ApiKey,
    ApiKeyError,
    ApiKeyId,
    ApiKeyRole,
    ApiKeyValue,
    InMemoryApiKeyStorage,
    KeyStorage,
};

/// Validates API keys carried by incoming requests and enforces the
/// subscription and rate limits associated with each key.
///
/// Key lookups consult the in-memory `storage` first and fall back to the
/// database on a cache miss (see `ApiKeysManager::validate_api_key`).
#[derive(Debug, Clone)]
pub struct ApiKeysManager {
    /// In-memory key cache consulted before the database.
    storage: Arc<InMemoryApiKeyStorage>,
    /// Shared controller backing the subscription and rate-limit checks.
    rate_limiter_controller: Arc<RateLimitsController>,
}

impl Default for ApiKeysManager {
    /// Builds a manager around a freshly created in-memory key storage and a
    /// default-configured rate-limit controller.
    fn default() -> Self {
        Self {
            storage: Arc::new(InMemoryApiKeyStorage::new()),
            rate_limiter_controller: RateLimitsController::default().arc(),
        }
    }
}

impl ApiKeysManager {
    /// Returns the in-memory key storage consulted by
    /// [`Self::validate_api_key`] before falling back to the database.
    pub fn storage(&self) -> &Arc<InMemoryApiKeyStorage> {
        &self.storage
    }

    /// Returns the controller backing [`Self::check_subscriptions`] and
    /// [`Self::check_rate_limit`].
    pub fn rate_limiter(&self) -> &Arc<RateLimitsController> {
        &self.rate_limiter_controller
    }

    /// Fetches every API key stored in the database.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ApiKey::fetch_all`, converted into
    /// [`ApiKeyError`].
    pub async fn load_from_db(
        &self,
        db: &Arc<Db>,
    ) -> Result<Vec<ApiKey>, ApiKeyError> {
        let pool = db.pool_ref();
        let db_keys = ApiKey::fetch_all(pool).await?;
        Ok(db_keys)
    }

    /// Fetches a single API key by its value directly from the database,
    /// bypassing the in-memory storage.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ApiKey::fetch_by_key`, converted into
    /// [`ApiKeyError`].
    pub async fn get_api_key_from_db(
        &self,
        key: &ApiKeyValue,
        db: &Arc<Db>,
    ) -> Result<ApiKey, ApiKeyError> {
        let pool = db.pool_ref();
        let api_key = ApiKey::fetch_by_key(pool, key).await?;
        Ok(api_key)
    }

    /// Resolves `key` to its [`ApiKey`], trying the in-memory storage first
    /// and the database on a cache miss.
    ///
    /// # Errors
    ///
    /// Returns any storage error other than [`ApiKeyError::NotFound`]
    /// unchanged. On a cache miss, returns whatever the database lookup
    /// returns.
    pub async fn validate_api_key(
        &self,
        key: &ApiKeyValue,
        db: &Arc<Db>,
    ) -> Result<ApiKey, ApiKeyError> {
        match self.storage.find_by_key(key) {
            Ok(key) => {
                tracing::debug!("Cache hit for API key");
                Ok(key)
            }
            Err(ApiKeyError::NotFound) => {
                tracing::debug!("Cache miss for API key");
                // NOTE(review): the key fetched from the DB is not written
                // back into `self.storage`, so repeated requests with the
                // same key keep missing the cache. Confirm this is intended.
                self.get_api_key_from_db(key, db).await
            }
            Err(e) => Err(e),
        }
    }

    /// Maps a limiter verdict to a `Result`, reporting `limit` when the
    /// request is not allowed.
    fn ensure_allowed(
        allowed: bool,
        limit: impl ToString,
    ) -> Result<(), ApiKeyError> {
        if allowed {
            Ok(())
        } else {
            Err(ApiKeyError::RateLimitExceeded(limit.to_string()))
        }
    }

    /// Checks the subscription limit for the key `id` with role `role`.
    ///
    /// # Errors
    ///
    /// Returns [`ApiKeyError::RateLimitExceeded`] carrying the limit when the
    /// controller disallows the request, or propagates the controller's own
    /// error.
    pub fn check_subscriptions(
        &self,
        id: &ApiKeyId,
        role: &ApiKeyRole,
    ) -> Result<(), ApiKeyError> {
        let (allowed, limit) =
            self.rate_limiter().check_subscriptions(id, role)?;
        Self::ensure_allowed(allowed, limit)
    }

    /// Checks the request rate limit for the key `id` with role `role`.
    ///
    /// # Errors
    ///
    /// Returns [`ApiKeyError::RateLimitExceeded`] carrying the limit when the
    /// controller disallows the request, or propagates the controller's own
    /// error.
    pub fn check_rate_limit(
        &self,
        id: &ApiKeyId,
        role: &ApiKeyRole,
    ) -> Result<(), ApiKeyError> {
        let (allowed, limit) =
            self.rate_limiter().check_rate_limit(id, role)?;
        Self::ensure_allowed(allowed, limit)
    }

    /// Extracts the API key from the request.
    ///
    /// The `Authorization: Bearer <key>` header is preferred. If that header
    /// is absent or uses another scheme, the `api_key` query parameter is
    /// used instead.
    ///
    /// # Errors
    ///
    /// - [`ApiKeyError::InvalidHeader`] if the `Authorization` header is not
    ///   valid visible ASCII.
    /// - [`ApiKeyError::Invalid`] if the query string cannot be parsed.
    /// - [`ApiKeyError::NotFound`] if neither source provides a key.
    pub async fn extract_api_key(
        parts: &mut Parts,
    ) -> Result<ApiKeyValue, ApiKeyError> {
        if let Some(auth_header) = parts.headers.get(AUTHORIZATION) {
            let header_value = auth_header.to_str().map_err(|_| {
                ApiKeyError::InvalidHeader("Invalid header".to_string())
            })?;
            // `strip_prefix` removes the scheme exactly once, so a token that
            // itself begins with "Bearer " is kept intact.
            if let Some(token) = header_value.strip_prefix("Bearer ") {
                return Ok(ApiKeyValue::new(token.to_string()));
            }
        }

        let query =
            Query::<HashMap<String, String>>::from_request_parts(parts, &())
                .await
                .map_err(|_| ApiKeyError::Invalid)?;
        match query.get("api_key") {
            Some(key) => Ok(ApiKeyValue::new(key.clone())),
            None => Err(ApiKeyError::NotFound),
        }
    }

    /// Axum middleware that authenticates a request by API key.
    ///
    /// The middleware performs these steps in order:
    /// 1. Extracts the key.
    /// 2. Resolves it via [`Self::validate_api_key`].
    /// 3. Checks the key's status.
    /// 4. Applies the subscription and rate limits.
    ///
    /// On success, the resolved [`ApiKey`] is inserted into the request
    /// extensions. The original request, including its body, is then
    /// forwarded to `next`.
    ///
    /// # Errors
    ///
    /// Short-circuits with the first [`ApiKeyError`] produced by any step.
    pub async fn middleware(
        State(manager): State<Arc<Self>>,
        State(db): State<Arc<Db>>,
        req: Request,
        next: Next,
    ) -> Result<Response, ApiKeyError> {
        // Keep the body so it can be forwarded unchanged; rebuilding the
        // request with an empty body would drop POST/PUT payloads.
        let (mut parts, body): (Parts, Body) = req.into_parts();
        let api_key_str = Self::extract_api_key(&mut parts).await?;
        let api_key = manager.validate_api_key(&api_key_str, &db).await?;
        // Reject keys with an invalid status before consulting the rate
        // limiter, so they never reach the limiter checks.
        api_key.validate_status()?;
        manager.check_subscriptions(api_key.id(), api_key.role())?;
        manager.check_rate_limit(api_key.id(), api_key.role())?;
        let mut req = Request::from_parts(parts, body);
        req.extensions_mut().insert(api_key);
        Ok(next.run(req).await)
    }

    /// Builds a manager for tests; equivalent to [`Default::default`].
    #[cfg(any(test, feature = "test-helpers"))]
    pub fn new_for_testing() -> Self {
        Self::default()
    }
}

#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds request parts for `uri`, optionally adding an `Authorization`
    /// header with the given value.
    fn parts_for(uri: &str, authorization: Option<&str>) -> Parts {
        let mut builder = axum::http::Request::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header(AUTHORIZATION, value);
        }
        builder.body(()).unwrap().into_parts().0
    }

    #[tokio::test]
    async fn test_key_extraction_from_header() {
        let mut parts = parts_for("/test", Some("Bearer test_api_key"));
        let result = ApiKeysManager::extract_api_key(&mut parts).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().to_string(), "test_api_key");
    }

    #[tokio::test]
    async fn test_key_extraction_from_query() {
        let mut parts = parts_for("/test?api_key=test_api_key", None);
        let result = ApiKeysManager::extract_api_key(&mut parts).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().to_string(), "test_api_key");
    }

    #[tokio::test]
    async fn test_key_extraction_header_takes_precedence_over_query() {
        let mut parts =
            parts_for("/test?api_key=query_key", Some("Bearer header_key"));
        let result = ApiKeysManager::extract_api_key(&mut parts).await;
        assert_eq!(result.unwrap().to_string(), "header_key");
    }

    #[tokio::test]
    async fn test_key_extraction_non_bearer_header_falls_back_to_query() {
        let mut parts = parts_for("/test?api_key=query_key", Some("Basic abc"));
        let result = ApiKeysManager::extract_api_key(&mut parts).await;
        assert_eq!(result.unwrap().to_string(), "query_key");
    }

    #[tokio::test]
    async fn test_key_extraction_missing() {
        let mut parts = parts_for("/test", None);
        let result = ApiKeysManager::extract_api_key(&mut parts).await;
        assert!(matches!(result, Err(ApiKeyError::NotFound)));
    }
}