use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::RwLock;
use super::super::api_key_validation::validate_api_key;
use super::super::config::AuthProviderType;
use super::super::context::AuthContext;
use super::super::types::{AuthCredentials, AuthProvider, TokenInfo, UserInfo};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
/// Authentication provider that accepts statically registered API keys.
///
/// Each registered key maps to the [`UserInfo`] of the principal it
/// authenticates. The key table sits behind an `Arc<RwLock<..>>`, so lookups
/// share a read lock while registration and removal take the write lock.
pub struct ApiKeyProvider {
    /// Provider name, returned by `AuthProvider::name` and recorded as the
    /// provider on every `AuthContext` this provider builds.
    name: String,
    /// Registered keys (stored exactly as passed to `add_api_key`) and the
    /// user each one identifies.
    api_keys: Arc<RwLock<HashMap<String, UserInfo>>>,
}

// Hand-written rather than derived: a derived `Debug` would format the key
// table (it requires `HashMap<String, UserInfo>: Debug`), writing live API keys
// into any log line that prints the provider.
impl std::fmt::Debug for ApiKeyProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyProvider")
            .field("name", &self.name)
            .field("api_keys", &format_args!("<redacted>"))
            .finish()
    }
}
impl ApiKeyProvider {
    /// Builds a provider called `name` that starts with no registered keys.
    #[must_use]
    pub fn new(name: String) -> Self {
        let api_keys = Arc::new(RwLock::new(HashMap::new()));
        Self { name, api_keys }
    }

    /// Registers `key` as authenticating `user_info`.
    ///
    /// If `key` is already registered, its previous `UserInfo` is replaced.
    pub async fn add_api_key(&self, key: String, user_info: UserInfo) {
        let mut table = self.api_keys.write().await;
        table.insert(key, user_info);
    }

    /// Unregisters `key` (exact match), returning `true` if it was present.
    pub async fn remove_api_key(&self, key: &str) -> bool {
        let mut table = self.api_keys.write().await;
        let previous = table.remove(key);
        previous.is_some()
    }

    /// Returns every registered key, in arbitrary (hash map) order.
    ///
    /// These are the raw key strings as stored, so treat the result as secret.
    pub async fn list_api_keys(&self) -> Vec<String> {
        let table = self.api_keys.read().await;
        table.keys().map(String::clone).collect()
    }
}
/// [`AuthProvider`] implementation backed by the in-memory API key table.
///
/// A successful authentication yields an `AuthContext` carrying the fixed role
/// `api_user` and permission `api_access`.
///
/// NOTE(review): authentication failures are reported via `McpError::internal`;
/// if the protocol crate offers a dedicated auth/unauthorized error kind it may
/// be a better fit — confirm before changing, since callers may match on it.
impl AuthProvider for ApiKeyProvider {
    fn name(&self) -> &str {
        &self.name
    }

    fn provider_type(&self) -> AuthProviderType {
        AuthProviderType::ApiKey
    }

    /// Authenticates `AuthCredentials::ApiKey` credentials against the key table.
    ///
    /// On success the presented key is echoed back as the context's token with
    /// `token_type` `"ApiKey"` and no expiry, refresh token, or scope.
    ///
    /// # Errors
    /// - `"Invalid credentials for API key provider"` for any other credential variant.
    /// - `"Invalid API key"` when no stored key matches.
    /// - Any failure from the `AuthContext` builder, mapped to an internal error.
    fn authenticate(
        &self,
        credentials: AuthCredentials,
    ) -> Pin<Box<dyn Future<Output = McpResult<AuthContext>> + Send + '_>> {
        Box::pin(async move {
            let key = match credentials {
                AuthCredentials::ApiKey { key } => key,
                _ => {
                    return Err(McpError::internal(
                        "Invalid credentials for API key provider".to_string(),
                    ));
                }
            };

            // The helper releases the read lock before we build the context.
            let user_info = match self.find_user_by_key(&key).await {
                Some(user_info) => user_info,
                None => return Err(McpError::internal("Invalid API key".to_string())),
            };

            let token = TokenInfo {
                access_token: key,
                token_type: "ApiKey".to_string(),
                refresh_token: None,
                expires_in: None,
                scope: None,
            };

            // `subject`'s argument is evaluated before `user` takes ownership,
            // so only the id needs cloning.
            AuthContext::builder()
                .subject(user_info.id.clone())
                .user(user_info)
                .roles(vec!["api_user".to_string()])
                .permissions(vec!["api_access".to_string()])
                .request_id(uuid::Uuid::new_v4().to_string())
                .token(token)
                .provider(self.name.clone())
                .build()
                .map_err(|e| McpError::internal(e.to_string()))
        })
    }

    /// Treats `token` as an API key and authenticates it (see `authenticate`).
    fn validate_token(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<AuthContext>> + Send + '_>> {
        // Owned copy: the returned future may only borrow `self`, not `token`.
        let token = token.to_string();
        Box::pin(async move {
            self.authenticate(AuthCredentials::ApiKey { key: token })
                .await
        })
    }

    /// Always fails: API keys have no refresh flow.
    fn refresh_token(
        &self,
        _refresh_token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<TokenInfo>> + Send + '_>> {
        Box::pin(async {
            Err(McpError::internal(
                "API keys do not support token refresh".to_string(),
            ))
        })
    }

    /// Revokes an API key by removing it from the table.
    ///
    /// # Errors
    /// `"API key not found"` when `token` is not a registered key.
    ///
    /// NOTE(review): removal is an exact `HashMap` lookup, whereas
    /// `authenticate` matches via `validate_api_key`; if that function ever
    /// accepts keys that are not byte-equal, the two will disagree — confirm.
    fn revoke_token(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + '_>> {
        let token = token.to_string();
        Box::pin(async move {
            let removed = self.remove_api_key(&token).await;
            if removed {
                Ok(())
            } else {
                Err(McpError::internal("API key not found".to_string()))
            }
        })
    }

    /// Returns the `UserInfo` registered for the API key `token`.
    ///
    /// # Errors
    /// `"Invalid API key"` when no stored key matches.
    fn get_user_info(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<UserInfo>> + Send + '_>> {
        let token = token.to_string();
        Box::pin(async move {
            self.find_user_by_key(&token)
                .await
                .ok_or_else(|| McpError::internal("Invalid API key".to_string()))
        })
    }
}

impl ApiKeyProvider {
    /// Returns a clone of the `UserInfo` for the first stored key (in hash map
    /// iteration order) that `validate_api_key` accepts for `candidate`.
    ///
    /// Shared by `authenticate` and `get_user_info` so both apply the same
    /// matching rule. The read guard is a temporary of this expression, so it
    /// is released before the function returns and callers never hold it while
    /// doing further work.
    async fn find_user_by_key(&self, candidate: &str) -> Option<UserInfo> {
        self.api_keys
            .read()
            .await
            .iter()
            .find(|&(stored_key, _)| validate_api_key(candidate, stored_key))
            .map(|(_, user_info)| user_info.clone())
    }
}