1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime};
8
9use tokio::sync::RwLock;
10
11use super::config::AuthConfig;
12use super::types::{AuthContext, AuthCredentials, AuthProvider};
13use turbomcp_protocol::{Error as McpError, Result as McpResult};
14
/// Central authentication manager: holds the authentication configuration,
/// the registry of named authentication providers, and an in-memory session
/// store.
#[derive(Debug)]
pub struct AuthManager {
    /// Authentication configuration (enable flag, default roles, role → permission rules, ...).
    config: AuthConfig,
    /// Registered providers, keyed by the name each provider reports when added.
    providers: Arc<RwLock<HashMap<String, Arc<dyn AuthProvider>>>>,
    /// Sessions created by successful authentication, keyed by session id.
    sessions: Arc<RwLock<HashMap<String, AuthContext>>>,
    /// Handle to the background task that periodically prunes expired sessions.
    ///
    /// NOTE(review): dropping a Tokio `JoinHandle` detaches the task rather
    /// than aborting it, so this handle on its own does not stop the task.
    _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
27
28impl AuthManager {
29 #[must_use]
31 pub fn new(config: AuthConfig) -> Self {
32 let manager = Self {
33 config,
34 providers: Arc::new(RwLock::new(HashMap::new())),
35 sessions: Arc::new(RwLock::new(HashMap::new())),
36 _cleanup_handle: None,
37 };
38
39 let sessions_clone = manager.sessions.clone();
41 let cleanup_handle = tokio::spawn(async move {
42 let mut interval = tokio::time::interval(Duration::from_secs(300)); loop {
44 interval.tick().await;
45 let now = SystemTime::now();
46 let mut sessions = sessions_clone.write().await;
47 sessions
48 .retain(|_, context| context.expires_at.is_none_or(|expires| expires > now));
49 }
50 });
51
52 Self {
53 _cleanup_handle: Some(cleanup_handle),
54 ..manager
55 }
56 }
57
58 pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
60 let name = provider.name().to_string();
61 self.providers.write().await.insert(name, provider);
62 }
63
64 pub async fn remove_provider(&self, name: &str) -> bool {
66 self.providers.write().await.remove(name).is_some()
67 }
68
69 pub async fn list_providers(&self) -> Vec<String> {
71 self.providers.read().await.keys().cloned().collect()
72 }
73
74 pub async fn authenticate(
76 &self,
77 provider_name: &str,
78 credentials: AuthCredentials,
79 ) -> McpResult<AuthContext> {
80 if !self.config.enabled {
81 return Err(McpError::internal("Authentication is disabled".to_string()));
82 }
83
84 let providers = self.providers.read().await;
85 let provider = providers
86 .get(provider_name)
87 .ok_or_else(|| McpError::internal(format!("Provider '{provider_name}' not found")))?;
88
89 let mut auth_context = provider.authenticate(credentials).await?;
90
91 if auth_context.roles.is_empty() {
93 auth_context.roles = self.config.authorization.default_roles.clone();
94 }
95
96 let session_id = auth_context.session_id.clone();
98 self.sessions
99 .write()
100 .await
101 .insert(session_id, auth_context.clone());
102
103 Ok(auth_context)
104 }
105
106 pub async fn validate_token(
108 &self,
109 token: &str,
110 provider_name: Option<&str>,
111 ) -> McpResult<AuthContext> {
112 if !self.config.enabled {
113 return Err(McpError::internal("Authentication is disabled".to_string()));
114 }
115
116 let providers = self.providers.read().await;
117
118 if let Some(provider_name) = provider_name {
119 let provider = providers.get(provider_name).ok_or_else(|| {
120 McpError::internal(format!("Provider '{provider_name}' not found"))
121 })?;
122 provider.validate_token(token).await
123 } else {
124 for provider in providers.values() {
126 if let Ok(context) = provider.validate_token(token).await {
127 return Ok(context);
128 }
129 }
130 Err(McpError::internal("Token validation failed".to_string()))
131 }
132 }
133
134 pub async fn get_session(&self, session_id: &str) -> Option<AuthContext> {
136 self.sessions.read().await.get(session_id).cloned()
137 }
138
139 pub async fn revoke_session(&self, session_id: &str) -> McpResult<()> {
141 let context = self
142 .sessions
143 .write()
144 .await
145 .remove(session_id)
146 .ok_or_else(|| McpError::internal("Session not found".to_string()))?;
147
148 let providers = self.providers.read().await;
150 if let Some(provider) = providers.get(&context.provider)
151 && let Some(token) = &context.token
152 {
153 let _ = provider.revoke_token(&token.access_token).await;
154 }
155
156 Ok(())
157 }
158
159 #[must_use]
161 pub fn check_permission(&self, context: &AuthContext, permission: &str) -> bool {
162 context.permissions.contains(&permission.to_string())
163 || context.roles.iter().any(|role| {
164 self.config
165 .authorization
166 .inheritance_rules
167 .get(role)
168 .is_some_and(|perms| perms.contains(&permission.to_string()))
169 })
170 }
171
172 #[must_use]
174 pub fn check_role(&self, context: &AuthContext, role: &str) -> bool {
175 context.roles.contains(&role.to_string())
176 }
177}
178
179static GLOBAL_AUTH_MANAGER: once_cell::sync::Lazy<tokio::sync::RwLock<Option<Arc<AuthManager>>>> =
184 once_cell::sync::Lazy::new(|| tokio::sync::RwLock::new(None));
185
186pub async fn set_global_auth_manager(manager: Arc<AuthManager>) {
188 *GLOBAL_AUTH_MANAGER.write().await = Some(manager);
189}
190
191pub async fn global_auth_manager() -> Option<Arc<AuthManager>> {
193 GLOBAL_AUTH_MANAGER.read().await.clone()
194}
195
196pub async fn check_auth(token: &str) -> McpResult<AuthContext> {
198 if let Some(manager) = global_auth_manager().await {
199 manager.validate_token(token, None).await
200 } else {
201 Err(McpError::internal(
202 "Authentication manager not initialized".to_string(),
203 ))
204 }
205}
206
#[cfg(test)]
mod tests {
    //! Unit tests for the auth manager and the pieces it is built from.

    use super::*;
    use crate::{
        config::{
            AuthorizationConfig, OAuth2Config, OAuth2FlowType, SecurityLevel, SessionConfig,
            SessionStorageType,
        },
        providers::ApiKeyProvider,
        types::UserInfo,
    };
    use std::collections::HashMap;

    /// A hand-built `OAuth2Config` keeps the values it was constructed with.
    #[test]
    fn test_oauth2_config() {
        let scopes: Vec<String> = ["read", "write"].into_iter().map(String::from).collect();

        let config = OAuth2Config {
            client_id: "test_client".to_string(),
            client_secret: "test_secret".to_string(),
            auth_url: "https://auth.example.com/oauth/authorize".to_string(),
            token_url: "https://auth.example.com/oauth/token".to_string(),
            redirect_uri: "http://localhost:8080/callback".to_string(),
            scopes,
            flow_type: OAuth2FlowType::AuthorizationCode,
            additional_params: HashMap::new(),
            security_level: SecurityLevel::Standard,
            mcp_resource_uri: None,
            auto_resource_indicators: false,
            #[cfg(feature = "dpop")]
            dpop_config: None,
        };

        assert_eq!(config.client_id, "test_client");
        assert_eq!(config.flow_type, OAuth2FlowType::AuthorizationCode);
    }

    /// Two freshly generated PKCE challenges are non-empty and differ.
    #[test]
    fn test_oauth2_pkce_integration() {
        let (first_challenge, _first_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
        let (second_challenge, _second_verifier) =
            oauth2::PkceCodeChallenge::new_random_sha256();

        let (first, second) = (first_challenge.as_str(), second_challenge.as_str());
        assert_ne!(first, second);
        assert!(!first.is_empty());
        assert!(!second.is_empty());
    }

    /// A registered API key authenticates and yields the stored user.
    #[tokio::test]
    async fn test_api_key_provider() {
        let provider = ApiKeyProvider::new("test_api".to_string());

        provider
            .add_api_key(
                "test_key_123".to_string(),
                UserInfo {
                    id: "user123".to_string(),
                    username: "testuser".to_string(),
                    email: Some("test@example.com".to_string()),
                    display_name: Some("Test User".to_string()),
                    avatar_url: None,
                    metadata: HashMap::new(),
                },
            )
            .await;

        let context = provider
            .authenticate(AuthCredentials::ApiKey {
                key: "test_key_123".to_string(),
            })
            .await
            .expect("a registered API key should authenticate");

        assert_eq!(context.user.username, "testuser");
        assert_eq!(context.provider, "test_api");
    }

    /// A provider added to the manager shows up in `list_providers`.
    #[tokio::test]
    async fn test_auth_manager() {
        let session = SessionConfig {
            timeout_seconds: 3600,
            secure_cookies: true,
            cookie_domain: None,
            storage: SessionStorageType::Memory,
            max_sessions_per_user: Some(5),
        };
        let authorization = AuthorizationConfig {
            rbac_enabled: true,
            default_roles: vec!["user".to_string()],
            inheritance_rules: HashMap::new(),
            resource_permissions: HashMap::new(),
        };

        let manager = AuthManager::new(AuthConfig {
            enabled: true,
            providers: Vec::new(),
            session,
            authorization,
        });

        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
        manager.add_provider(Arc::clone(&api_provider)).await;

        let names = manager.list_providers().await;
        assert!(names.iter().any(|name| name == "api"));
    }
}