turbomcp-auth 3.0.12

OAuth 2.1 and authentication for TurboMCP with MCP protocol compliance
//! API Key Authentication Provider
//!
//! Simple API key-based authentication for service-to-service communication.
//!
//! ## Security
//!
//! This provider uses constant-time comparison to prevent timing attacks on API keys.
//! See [`crate::api_key_validation`] for implementation details.

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::super::api_key_validation::validate_api_key;
use super::super::config::AuthProviderType;
use super::super::context::AuthContext;
use super::super::types::{AuthCredentials, AuthProvider, TokenInfo, UserInfo};
use turbomcp_protocol::{Error as McpError, Result as McpResult};

/// API Key authentication provider.
///
/// Holds a shared, async-locked map from raw API key string to the [`UserInfo`]
/// that key authenticates as.
///
/// `Debug` is implemented by hand so that the key map is never formatted:
/// the stored keys are secrets and must not end up in logs.
pub struct ApiKeyProvider {
    /// Provider name
    name: String,
    /// Valid API keys with associated user info
    api_keys: Arc<RwLock<HashMap<String, UserInfo>>>,
}

impl std::fmt::Debug for ApiKeyProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived impl would delegate to the lock's `Debug`, which prints the
        // map contents (i.e. every raw API key). Redact the map entirely.
        f.debug_struct("ApiKeyProvider")
            .field("name", &self.name)
            .field("api_keys", &"<redacted>")
            .finish()
    }
}

impl ApiKeyProvider {
    /// Build a provider called `name` that starts out with no registered keys.
    #[must_use]
    pub fn new(name: String) -> Self {
        let api_keys = Arc::new(RwLock::new(HashMap::new()));
        Self { name, api_keys }
    }

    /// Register `key` so that it authenticates as `user_info`.
    ///
    /// Registering a key that is already present replaces its previous user info.
    pub async fn add_api_key(&self, key: String, user_info: UserInfo) {
        let mut map = self.api_keys.write().await;
        map.insert(key, user_info);
    }

    /// Unregister `key`.
    ///
    /// Returns `true` if the key was present and has been removed, `false`
    /// otherwise. The lookup uses ordinary `HashMap` key equality.
    pub async fn remove_api_key(&self, key: &str) -> bool {
        let mut map = self.api_keys.write().await;
        matches!(map.remove(key), Some(_))
    }

    /// Return every registered API key.
    ///
    /// Only the keys are returned, not the associated [`UserInfo`]. These are
    /// the raw secret values themselves, so treat the result as sensitive.
    /// Ordering follows `HashMap` iteration order and is unspecified.
    pub async fn list_api_keys(&self) -> Vec<String> {
        let map = self.api_keys.read().await;
        map.keys().map(String::clone).collect()
    }
}

impl AuthProvider for ApiKeyProvider {
    fn name(&self) -> &str {
        &self.name
    }

    fn provider_type(&self) -> AuthProviderType {
        AuthProviderType::ApiKey
    }

    /// Authenticate an `AuthCredentials::ApiKey` credential.
    ///
    /// On success the returned [`AuthContext`] has:
    /// - the matched user's id as subject;
    /// - the fixed role `api_user` and the fixed permission `api_access`;
    /// - a fresh UUID request id;
    /// - a non-expiring, non-refreshable `ApiKey` token wrapping the key.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the credentials are not an API key, if the
    /// key is not registered, or if the context builder fails.
    fn authenticate(
        &self,
        credentials: AuthCredentials,
    ) -> Pin<Box<dyn Future<Output = McpResult<AuthContext>> + Send + '_>> {
        Box::pin(async move {
            let key = match credentials {
                AuthCredentials::ApiKey { key } => key,
                _ => {
                    return Err(McpError::internal(
                        "Invalid credentials for API key provider".to_string(),
                    ));
                }
            };

            // Scope the read guard so the lock is released before the context is built.
            let matched_user_info = {
                let api_keys = self.api_keys.read().await;
                lookup_user_by_key(&api_keys, &key)
            };
            let Some(user_info) = matched_user_info else {
                return Err(McpError::internal("Invalid API key".to_string()));
            };

            let token = TokenInfo {
                access_token: key,
                token_type: "ApiKey".to_string(),
                refresh_token: None,
                expires_in: None,
                scope: None,
            };

            AuthContext::builder()
                .subject(user_info.id.clone())
                .user(user_info)
                .roles(vec!["api_user".to_string()])
                .permissions(vec!["api_access".to_string()])
                .request_id(uuid::Uuid::new_v4().to_string())
                .token(token)
                .provider(self.name.clone())
                .build()
                .map_err(|e| McpError::internal(e.to_string()))
        })
    }

    /// Validate `token` by treating it as an API key and authenticating it.
    fn validate_token(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<AuthContext>> + Send + '_>> {
        // `authenticate` already returns the boxed future we need; no extra wrapper.
        self.authenticate(AuthCredentials::ApiKey {
            key: token.to_string(),
        })
    }

    /// API keys cannot be refreshed; always resolves to an internal error.
    fn refresh_token(
        &self,
        _refresh_token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<TokenInfo>> + Send + '_>> {
        Box::pin(async {
            Err(McpError::internal(
                "API keys do not support token refresh".to_string(),
            ))
        })
    }

    /// Revoke an API key by removing it from the provider.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the key was not registered.
    fn revoke_token(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + '_>> {
        let token = token.to_string();
        Box::pin(async move {
            // NOTE(review): this goes through `remove_api_key`, i.e. plain
            // `HashMap` key equality rather than `validate_api_key`.
            let removed = self.remove_api_key(&token).await;
            if removed {
                Ok(())
            } else {
                Err(McpError::internal("API key not found".to_string()))
            }
        })
    }

    /// Return the [`UserInfo`] registered for the API key `token`.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the key is not registered.
    fn get_user_info(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<UserInfo>> + Send + '_>> {
        let token = token.to_string();
        Box::pin(async move {
            let api_keys = self.api_keys.read().await;
            lookup_user_by_key(&api_keys, &token)
                .ok_or_else(|| McpError::internal("Invalid API key".to_string()))
        })
    }
}

/// Find the user bound to `candidate` by comparing it against every stored key
/// with [`validate_api_key`].
///
/// `HashMap::get` is avoided because it compares keys with ordinary string
/// equality. The scan also deliberately visits every entry instead of stopping
/// at the first match. As a result, the number of comparisons does not depend
/// on where in the map a matching key sits.
fn lookup_user_by_key(
    api_keys: &HashMap<String, UserInfo>,
    candidate: &str,
) -> Option<UserInfo> {
    let mut matched: Option<&UserInfo> = None;
    for (stored_key, user_info) in api_keys {
        if validate_api_key(candidate, stored_key) {
            matched = Some(user_info);
        }
    }
    matched.cloned()
}