use std::{collections::HashMap, sync::Arc};
use actix_web::http::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use fuel_streams_store::db::Db;
use super::{
rate_limiter::RateLimitsController,
ApiKeyError,
ApiKeyId,
ApiKeyRole,
ApiKeyStorageError,
ApiKeyValue,
};
use crate::api_key::{ApiKey, InMemoryApiKeyStorage, KeyStorage};
/// Error type for API key manager failures.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// returns this type (the manager's methods return `ApiKeyError`); confirm
/// whether it is still used elsewhere.
#[derive(Debug, thiserror::Error)]
pub enum ApiKeyManagerError {
/// A database operation failed. `#[from]` lets a `sqlx::Error` convert into
/// this variant, so `?` can be used on sqlx results.
#[error("Database error: {0}")]
DatabaseError(#[from] sqlx::Error),
/// The supplied API key is not valid.
#[error("Invalid API key")]
InvalidApiKey,
}
/// Scheme prefix that an `Authorization` header value must start with for the
/// API key to be extracted from it.
const BEARER: &str = "Bearer";
/// Bundles the in-memory API key storage with the rate-limit controller and
/// exposes key extraction, validation and limit checks on top of them.
///
/// Both fields are `Arc`s, so cloning the manager only copies the handles;
/// clones share the same storage and rate limiter.
#[derive(Debug, Clone)]
pub struct ApiKeysManager {
/// In-memory key storage; `validate_api_key` looks here before querying the
/// database.
storage: Arc<InMemoryApiKeyStorage>,
/// Rate-limit controller used by `check_subscriptions` and
/// `check_rate_limit`.
rate_limiter_controller: Arc<RateLimitsController>,
}
impl Default for ApiKeysManager {
fn default() -> Self {
Self::new()
}
}
impl ApiKeysManager {
    /// Creates a manager backed by a freshly constructed in-memory key
    /// storage and a default rate-limit controller.
    pub fn new() -> Self {
        Self {
            storage: Arc::new(InMemoryApiKeyStorage::new()),
            rate_limiter_controller: RateLimitsController::default().arc(),
        }
    }

    /// Returns the shared in-memory key storage.
    pub fn storage(&self) -> &Arc<InMemoryApiKeyStorage> {
        &self.storage
    }

    /// Returns the shared rate-limit controller.
    pub fn rate_limiter(&self) -> &Arc<RateLimitsController> {
        &self.rate_limiter_controller
    }

    /// Fetches every API key from the database.
    ///
    /// The keys are handed back to the caller; this method does not insert
    /// them into the in-memory storage itself.
    ///
    /// # Errors
    ///
    /// Propagates any failure of `ApiKey::fetch_all`, converted into
    /// [`ApiKeyError`].
    pub async fn load_from_db(
        &self,
        db: &Arc<Db>,
    ) -> Result<Vec<ApiKey>, ApiKeyError> {
        let db_keys = ApiKey::fetch_all(db.pool_ref()).await?;
        Ok(db_keys.into_iter().collect())
    }

    /// Looks up a single API key in the database by its value.
    ///
    /// Takes `self` by value even though no manager state is read; the
    /// receiver is kept as-is so existing call sites continue to compile.
    ///
    /// # Errors
    ///
    /// Propagates any failure of `ApiKey::fetch_by_key`, converted into
    /// [`ApiKeyError`].
    pub async fn get_api_key_from_db(
        self,
        key: &ApiKeyValue,
        db: &Arc<Db>,
    ) -> Result<ApiKey, ApiKeyError> {
        let api_key = ApiKey::fetch_by_key(db.pool_ref(), key).await?;
        Ok(api_key)
    }

    /// Resolves `key` to its [`ApiKey`] record.
    ///
    /// The in-memory storage is consulted first. Only when it reports
    /// `KeyNotFound` is the database queried; the database result is
    /// returned as-is and is not written back to the storage here.
    ///
    /// # Errors
    ///
    /// Returns any storage error other than `KeyNotFound` unchanged, and
    /// propagates database lookup failures.
    pub async fn validate_api_key(
        &self,
        key: &ApiKeyValue,
        db: &Arc<Db>,
    ) -> Result<ApiKey, ApiKeyError> {
        match self.storage.find_by_key(key) {
            Ok(api_key) => {
                tracing::debug!("Cache hit for API key");
                Ok(api_key)
            }
            Err(ApiKeyError::Storage(ApiKeyStorageError::KeyNotFound)) => {
                tracing::debug!("Cache miss for API key");
                // Query the database directly: the lookup needs no manager
                // state, so there is no reason to clone the manager just to
                // call the by-value `get_api_key_from_db`.
                let api_key =
                    ApiKey::fetch_by_key(db.pool_ref(), key).await?;
                Ok(api_key)
            }
            Err(e) => Err(e),
        }
    }

    /// Checks the subscription limit for the key `id` with the given `role`.
    ///
    /// Delegates to `RateLimitsController::check_subscriptions`.
    ///
    /// # Errors
    ///
    /// Propagates controller errors, and returns
    /// `ApiKeyError::RateLimitExceeded` (carrying the limit rendered as a
    /// string) when the controller reports the request as not allowed.
    pub fn check_subscriptions(
        &self,
        id: &ApiKeyId,
        role: &ApiKeyRole,
    ) -> Result<(), ApiKeyError> {
        let (allowed, limit) =
            self.rate_limiter().check_subscriptions(id, role)?;
        if allowed {
            Ok(())
        } else {
            Err(ApiKeyError::RateLimitExceeded(limit.to_string()))
        }
    }

    /// Checks the request rate limit for the key `id` with the given `role`.
    ///
    /// Delegates to `RateLimitsController::check_rate_limit`.
    ///
    /// # Errors
    ///
    /// Propagates controller errors, and returns
    /// `ApiKeyError::RateLimitExceeded` (carrying the limit rendered as a
    /// string) when the controller reports the request as not allowed.
    pub fn check_rate_limit(
        &self,
        id: &ApiKeyId,
        role: &ApiKeyRole,
    ) -> Result<(), ApiKeyError> {
        let (allowed, limit) =
            self.rate_limiter().check_rate_limit(id, role)?;
        if allowed {
            Ok(())
        } else {
            Err(ApiKeyError::RateLimitExceeded(limit.to_string()))
        }
    }

    /// Extracts the API key value from request headers and query parameters.
    ///
    /// A query parameter named `api_key` (matched ASCII case-insensitively)
    /// is rewritten into an `Authorization: Bearer <value>` header, replacing
    /// any `Authorization` header already present. The key is then parsed
    /// from that header.
    ///
    /// # Errors
    ///
    /// * `ApiKeyError::InvalidHeader` when the query value cannot be turned
    ///   into a header value.
    /// * `ApiKeyError::NotFound` for every other failure, including a
    ///   present but malformed `Authorization` header.
    pub fn key_from_headers(
        &self,
        (mut headers, query_map): (HeaderMap, HashMap<String, String>),
    ) -> Result<ApiKeyValue, ApiKeyError> {
        for (name, value) in &query_map {
            if name.eq_ignore_ascii_case("api_key") {
                // Build the value from the same constant the parser checks
                // for, so the two can never drift apart.
                let token = format!("{} {}", BEARER, value);
                let header_value = HeaderValue::from_str(&token)
                    .map_err(ApiKeyError::InvalidHeader)?;
                headers.insert(AUTHORIZATION, header_value);
            }
        }
        // Callers only see "no usable key"; parse details are dropped.
        Self::from_query_string(&headers).map_err(|_| ApiKeyError::NotFound)
    }

    /// Parses the API key out of the `Authorization` header.
    ///
    /// The value must start with the `Bearer` scheme prefix. The prefix is
    /// removed exactly once, the remainder is URL-decoded, and surrounding
    /// whitespace is trimmed.
    ///
    /// # Errors
    ///
    /// * `ApiKeyError::NotFound` when there is no `Authorization` header.
    /// * `ApiKeyError::Invalid` when the header is not valid visible ASCII,
    ///   lacks the `Bearer` prefix, or fails URL-decoding.
    fn from_query_string(
        headers: &HeaderMap,
    ) -> Result<ApiKeyValue, ApiKeyError> {
        let raw = headers
            .get(AUTHORIZATION)
            .ok_or(ApiKeyError::NotFound)?
            .to_str()
            .map_err(|_| ApiKeyError::Invalid)?;
        // `strip_prefix` removes a single leading `Bearer`. Repeated
        // stripping would also swallow a key that itself begins with
        // `Bearer` right after the scheme, and would accept
        // `BearerBearer <key>`.
        let encoded = raw.strip_prefix(BEARER).ok_or(ApiKeyError::Invalid)?;
        let decoded =
            urlencoding::decode(encoded).map_err(|_| ApiKeyError::Invalid)?;
        Ok(decoded.trim().to_string().into())
    }

    /// Creates a manager for tests; identical to [`ApiKeysManager::new`].
    #[cfg(any(test, feature = "test-helpers"))]
    pub fn new_for_testing() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use actix_web::http::header::{HeaderMap, HeaderValue, AUTHORIZATION};
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds a header map whose only entry is `Authorization: <value>`.
    fn authorization_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_str(value).unwrap());
        headers
    }

    /// Runs `key_from_headers` on a fresh test manager.
    fn extract_key(
        headers: HeaderMap,
        query_map: HashMap<String, String>,
    ) -> Result<ApiKeyValue, ApiKeyError> {
        ApiKeysManager::new_for_testing().key_from_headers((headers, query_map))
    }

    #[test]
    fn test_key_from_headers_with_authorization_header() {
        let headers = authorization_headers("Bearer test_api_key");

        let result = extract_key(headers, HashMap::new());

        assert!(
            result.is_ok(),
            "Should extract API key from Authorization header"
        );
        assert_eq!(result.unwrap().to_string(), "test_api_key");
    }

    #[test]
    fn test_key_from_headers_with_query_param() {
        let query_map = HashMap::from([(
            "api_key".to_string(),
            "test_api_key".to_string(),
        )]);

        let result = extract_key(HeaderMap::new(), query_map);

        assert!(
            result.is_ok(),
            "Should extract API key from query parameters"
        );
        assert_eq!(result.unwrap().to_string(), "test_api_key");
    }

    #[test]
    fn test_key_from_headers_missing_key() {
        let result = extract_key(HeaderMap::new(), HashMap::new());

        assert!(result.is_err(), "Should fail when no API key is provided");
        assert!(matches!(result, Err(ApiKeyError::NotFound)));
    }

    #[test]
    fn test_key_from_headers_invalid_format() {
        // `Basic` is not the scheme the manager accepts.
        let headers = authorization_headers("Basic test_api_key");

        let result = extract_key(headers, HashMap::new());

        assert!(
            result.is_err(),
            "Should fail with invalid Authorization format"
        );
    }
}