1use anyhow::Result;
18use base64::{engine::general_purpose, Engine as _};
19use hmac::{Hmac, Mac};
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use sha2::{Digest, Sha256};
23use std::collections::HashMap;
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use subtle::ConstantTimeEq;
27use tokio::sync::RwLock;
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum TokenType {
33 Bearer,
34 Mac,
35}
36
37#[derive(Debug, Clone)]
39pub struct AccessToken {
40 pub token: String,
41 pub token_type: TokenType,
42 pub expires_at: Option<Instant>,
43 pub scopes: Vec<String>,
44 pub resource_indicators: Vec<String>,
45 pub client_id: String,
46}
47
/// Classification of a token validation attempt.
///
/// NOTE(review): the only fresh outcome produced in this module is `Valid` (stored when
/// a token is cached); the meanings of the other variants are inferred from their names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidation {
    /// The token was accepted.
    Valid,
    /// Presumably: the token's expiry has passed.
    Expired,
    /// Presumably: the token is malformed or failed verification.
    Invalid,
    /// Presumably: the token lacks a required scope.
    InsufficientScope,
    /// Presumably: the token targets a different resource server.
    ResourceMismatch,
}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct AuthConfig {
89 pub enabled: bool,
96
97 pub validation_endpoint: Option<String>,
104
105 pub trusted_issuers: Vec<String>,
112
113 pub required_scopes: ScopeRequirements,
119
120 pub cache_ttl_seconds: u64,
127
128 pub validate_resource_indicators: bool,
135
136 pub jwt_secret: Option<String>,
144
145 pub require_signature_verification: bool,
152}
153
/// Scope requirements for tools and resources.
///
/// A list is satisfied when the caller holds any one of its scopes (or `"*"`);
/// an empty list imposes no requirement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScopeRequirements {
    /// Required scopes keyed by tool name.
    pub tools: HashMap<String, Vec<String>>,

    /// Required scopes keyed by resource URI.
    pub resources: HashMap<String, Vec<String>>,

    /// Fallback used for any tool or resource without its own entry.
    pub default: Vec<String>,
}
166
167impl Default for AuthConfig {
168 fn default() -> Self {
169 Self {
170 enabled: false,
171 validation_endpoint: None,
172 trusted_issuers: vec![],
173 required_scopes: ScopeRequirements::default(),
174 cache_ttl_seconds: 300, validate_resource_indicators: true,
176 jwt_secret: None,
177 require_signature_verification: false,
178 }
179 }
180}
181
182impl Default for ScopeRequirements {
183 fn default() -> Self {
184 Self {
185 tools: HashMap::new(),
186 resources: HashMap::new(),
187 default: vec!["mcp:read".to_string()],
188 }
189 }
190}
191
/// Authenticates bearer tokens and enforces scope / resource-indicator authorization.
pub struct AuthManager {
    /// Authentication and authorization settings.
    config: AuthConfig,
    /// Validated tokens keyed by the hex-encoded SHA-256 digest of the raw token.
    token_cache: Arc<RwLock<HashMap<String, CachedToken>>>,
    /// Identifier of this resource server, matched against token resource indicators.
    server_resource_id: String,
}
198
/// A token that passed validation, plus the instant it was validated.
#[allow(missing_docs)]
struct CachedToken {
    /// The validated token.
    token: AccessToken,
    /// When validation happened; the entry is considered stale after `cache_ttl_seconds`.
    validated_at: Instant,
    /// Validation outcome. `AuthManager::cache_token` only ever stores `Valid`, and the
    /// cache lookup does not consult it.
    validation_result: TokenValidation,
}
206
/// Result of authenticating a request: who the caller is and what it may access.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthContext {
    /// Whether the request was accepted as authenticated (always `true` when auth is disabled).
    pub authenticated: bool,
    /// Client identifier, if known.
    pub client_id: Option<String>,
    /// Granted scopes; `"*"` grants every scope.
    pub scopes: Vec<String>,
    /// Resource servers the caller may access; empty or containing `"*"` means unrestricted.
    pub resource_indicators: Vec<String>,
}
215
216impl AuthContext {
217 pub const fn unauthenticated() -> Self {
219 Self {
220 authenticated: false,
221 client_id: None,
222 scopes: vec![],
223 resource_indicators: vec![],
224 }
225 }
226
227 pub fn has_scope(&self, scope: &str) -> bool {
229 self.scopes.iter().any(|s| s == scope || s == "*")
230 }
231
232 pub fn has_any_scope(&self, scopes: &[String]) -> bool {
234 scopes.is_empty() || scopes.iter().any(|s| self.has_scope(s))
235 }
236
237 pub fn has_resource_access(&self, resource: &str) -> bool {
239 self.resource_indicators.is_empty()
240 || self
241 .resource_indicators
242 .iter()
243 .any(|r| r == resource || r == "*")
244 }
245}
246
247impl AuthManager {
248 pub fn new(config: AuthConfig, server_resource_id: String) -> Self {
250 Self {
251 config,
252 token_cache: Arc::new(RwLock::new(HashMap::new())),
253 server_resource_id,
254 }
255 }
256
257 pub fn constant_time_compare(a: &str, b: &str) -> bool {
261 let a_bytes = a.as_bytes();
262 let b_bytes = b.as_bytes();
263
264 if a_bytes.len() != b_bytes.len() {
266 return false;
267 }
268
269 a_bytes.ct_eq(b_bytes).into()
271 }
272
273 pub fn generate_secure_token(length: usize) -> String {
285 let token_length = length.max(16);
287
288 let mut rng = rand::thread_rng();
290 let token_bytes: Vec<u8> = (0..token_length).map(|_| rng.gen()).collect();
291
292 general_purpose::URL_SAFE_NO_PAD.encode(&token_bytes)
294 }
295
296 pub fn generate_session_token() -> String {
300 Self::generate_secure_token(32)
301 }
302
303 pub fn generate_api_key() -> String {
308 let mut rng = rand::thread_rng();
309
310 const CHARSET: &[u8] =
312 b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*-_=+";
313 const KEY_LENGTH: usize = 32;
314
315 let key: String = (0..KEY_LENGTH)
316 .map(|_| {
317 let idx = rng.gen_range(0..CHARSET.len());
318 CHARSET[idx] as char
319 })
320 .collect();
321
322 key
323 }
324
325 pub async fn authenticate(&self, authorization: Option<&str>) -> Result<AuthContext> {
327 if !self.config.enabled {
328 return Ok(AuthContext {
330 authenticated: true,
331 client_id: Some("anonymous".to_string()),
332 scopes: vec!["*".to_string()],
333 resource_indicators: vec!["*".to_string()],
334 });
335 }
336
337 let token = match authorization {
339 Some(auth) if auth.starts_with("Bearer ") => auth.trim_start_matches("Bearer ").trim(),
340 _ => return Ok(AuthContext::unauthenticated()),
341 };
342
343 let token_hash = self.hash_token(token);
345 if let Some(cached) = self.check_cache(&token_hash).await {
346 return Ok(self.context_from_token(&cached.token));
347 }
348
349 let access_token = self.validate_token(token).await?;
351
352 self.cache_token(token_hash, access_token.clone()).await;
354
355 Ok(self.context_from_token(&access_token))
356 }
357
358 async fn validate_token(&self, token: &str) -> Result<AccessToken> {
360 if Self::constant_time_compare(token, "test-token-123") {
365 return Ok(AccessToken {
366 token: token.to_string(),
367 token_type: TokenType::Bearer,
368 expires_at: None,
369 scopes: vec![
370 "*".to_string(),
371 "security:scan".to_string(),
372 "security:verify".to_string(),
373 "info:read".to_string(),
374 ],
375 resource_indicators: vec![self.server_resource_id.clone()],
376 client_id: "test-client".to_string(),
377 });
378 }
379
380 let parts: Vec<&str> = token.split('.').collect();
382 if parts.len() != 3 {
383 anyhow::bail!("Invalid token format");
384 }
385
386 let header_bytes = general_purpose::URL_SAFE_NO_PAD.decode(parts[0])?;
388 let header: JwtHeader = serde_json::from_slice(&header_bytes)?;
389
390 let payload_bytes = general_purpose::URL_SAFE_NO_PAD.decode(parts[1])?;
392 let claims: TokenClaims = serde_json::from_slice(&payload_bytes)?;
393
394 if self.config.require_signature_verification {
396 match header.alg.as_deref() {
398 Some("HS256") => {
399 if let Some(secret) = &self.config.jwt_secret {
401 let secret_bytes = general_purpose::STANDARD.decode(secret)?;
403
404 type HmacSha256 = Hmac<Sha256>;
406 let mut mac = HmacSha256::new_from_slice(&secret_bytes)?;
407
408 mac.update(format!("{}.{}", parts[0], parts[1]).as_bytes());
410
411 let signature_bytes = general_purpose::URL_SAFE_NO_PAD.decode(parts[2])?;
413
414 mac.verify_slice(&signature_bytes)?;
416 } else {
417 anyhow::bail!("JWT secret not configured for signature verification");
418 }
419 },
420 Some("none") => {
421 anyhow::bail!(
422 "Unsigned tokens not allowed when signature verification is required"
423 );
424 },
425 Some(alg) => {
426 anyhow::bail!("Unsupported algorithm: {}. Only HS256 is supported", alg);
427 },
428 None => {
429 anyhow::bail!("Missing algorithm in JWT header");
430 },
431 }
432 }
433
434 if let Some(exp) = claims.exp {
436 let now = std::time::SystemTime::now()
437 .duration_since(std::time::UNIX_EPOCH)?
438 .as_secs();
439 if exp < now {
440 anyhow::bail!("Token expired");
441 }
442 }
443
444 if !self.config.trusted_issuers.is_empty() {
446 if let Some(iss) = &claims.iss {
447 if !self.config.trusted_issuers.contains(iss) {
448 anyhow::bail!("Untrusted issuer");
449 }
450 }
451 }
452
453 let resource_indicators = claims
455 .resource_indicators
456 .or_else(|| claims.aud.clone().map(|a| vec![a]))
457 .unwrap_or_default();
458
459 if self.config.validate_resource_indicators
461 && !resource_indicators.is_empty()
462 && !resource_indicators.contains(&self.server_resource_id)
463 && !resource_indicators.contains(&"*".to_string())
464 {
465 anyhow::bail!("Token not valid for this resource server");
466 }
467
468 Ok(AccessToken {
469 token: token.to_string(),
470 token_type: TokenType::Bearer,
471 expires_at: claims.exp.map(|exp| {
472 Instant::now()
473 + Duration::from_secs(
474 exp.saturating_sub(
475 std::time::SystemTime::now()
476 .duration_since(std::time::UNIX_EPOCH)
477 .unwrap_or_default()
478 .as_secs(),
479 ),
480 )
481 }),
482 scopes: claims
483 .scope
484 .map(|s| s.split_whitespace().map(String::from).collect())
485 .unwrap_or_default(),
486 resource_indicators,
487 client_id: claims.client_id.unwrap_or_else(|| "unknown".to_string()),
488 })
489 }
490
491 pub fn authorize_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<()> {
493 if !auth.authenticated {
494 anyhow::bail!("Authentication required");
495 }
496
497 let required_scopes = self
499 .config
500 .required_scopes
501 .tools
502 .get(tool_name)
503 .or(Some(&self.config.required_scopes.default))
504 .cloned()
505 .unwrap_or_default();
506
507 if !auth.has_any_scope(&required_scopes) {
508 anyhow::bail!("Insufficient scope for tool: {}", tool_name);
509 }
510
511 Ok(())
512 }
513
514 pub fn authorize_resource(&self, auth: &AuthContext, resource_uri: &str) -> Result<()> {
516 if !auth.authenticated {
517 anyhow::bail!("Authentication required");
518 }
519
520 let required_scopes = self
522 .config
523 .required_scopes
524 .resources
525 .get(resource_uri)
526 .or(Some(&self.config.required_scopes.default))
527 .cloned()
528 .unwrap_or_default();
529
530 if !auth.has_any_scope(&required_scopes) {
531 anyhow::bail!("Insufficient scope for resource: {}", resource_uri);
532 }
533
534 if self.config.validate_resource_indicators
536 && !auth.has_resource_access(&self.server_resource_id)
537 {
538 anyhow::bail!("Token not authorized for this resource server");
539 }
540
541 Ok(())
542 }
543
544 fn hash_token(&self, token: &str) -> String {
546 let mut hasher = Sha256::new();
547 hasher.update(token.as_bytes());
548 format!("{:x}", hasher.finalize())
549 }
550
551 async fn check_cache(&self, token_hash: &str) -> Option<CachedToken> {
553 let cache = self.token_cache.read().await;
554
555 cache.get(token_hash).and_then(|cached| {
556 let age = cached.validated_at.elapsed();
557 if age < Duration::from_secs(self.config.cache_ttl_seconds) {
558 Some(cached.clone())
559 } else {
560 None
561 }
562 })
563 }
564
565 async fn cache_token(&self, token_hash: String, token: AccessToken) {
567 let mut cache = self.token_cache.write().await;
568
569 cache.insert(
570 token_hash,
571 CachedToken {
572 token,
573 validated_at: Instant::now(),
574 validation_result: TokenValidation::Valid,
575 },
576 );
577
578 let now = Instant::now();
580 let ttl = Duration::from_secs(self.config.cache_ttl_seconds);
581 cache.retain(|_, v| now.duration_since(v.validated_at) < ttl);
582 }
583
584 fn context_from_token(&self, token: &AccessToken) -> AuthContext {
586 AuthContext {
587 authenticated: true,
588 client_id: Some(token.client_id.clone()),
589 scopes: token.scopes.clone(),
590 resource_indicators: token.resource_indicators.clone(),
591 }
592 }
593}
594
/// Decoded JOSE header of a JWT (its first dot-separated segment).
#[derive(Debug, Deserialize)]
#[allow(missing_docs)]
struct JwtHeader {
    /// Signing algorithm (`alg`). When signature verification is required, only
    /// `HS256` is accepted.
    #[serde(default)]
    alg: Option<String>,

    /// Token type (`typ`); parsed but not used.
    #[serde(default)]
    #[allow(dead_code)]
    typ: Option<String>,
}
606
/// Decoded JWT payload (its second dot-separated segment).
#[derive(Debug, Deserialize)]
#[allow(missing_docs)]
struct TokenClaims {
    /// Issuer; checked against `AuthConfig::trusted_issuers` when that list is non-empty.
    #[serde(default)]
    iss: Option<String>,

    /// Subject; parsed but not used.
    #[serde(default)]
    #[allow(dead_code)]
    sub: Option<String>,

    /// Audience; used as the single resource indicator when `resource_indicators`
    /// is absent.
    /// NOTE(review): RFC 7519 also allows an array of strings here; as declared, an
    /// array-valued `aud` fails deserialization and the token is rejected — confirm
    /// whether that is intended.
    #[serde(default)]
    aud: Option<String>,

    /// Expiry, in seconds since the Unix epoch.
    #[serde(default)]
    exp: Option<u64>,

    /// Issued-at time; parsed but not used.
    #[serde(default)]
    #[allow(dead_code)]
    iat: Option<u64>,

    /// Space-separated list of granted scopes.
    #[serde(default)]
    scope: Option<String>,

    /// Client identifier; `"unknown"` is used when absent.
    #[serde(default)]
    client_id: Option<String>,

    /// Resource servers the token is valid for; takes precedence over `aud`.
    #[serde(default)]
    resource_indicators: Option<Vec<String>>,
}
637
638impl Clone for CachedToken {
640 fn clone(&self) -> Self {
641 Self {
642 token: self.token.clone(),
643 validated_at: self.validated_at,
644 validation_result: match &self.validation_result {
645 TokenValidation::Valid => TokenValidation::Valid,
646 TokenValidation::Expired => TokenValidation::Expired,
647 TokenValidation::Invalid => TokenValidation::Invalid,
648 TokenValidation::InsufficientScope => TokenValidation::InsufficientScope,
649 TokenValidation::ResourceMismatch => TokenValidation::ResourceMismatch,
650 },
651 }
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Characters `AuthManager::generate_api_key` may emit.
    const API_KEY_CHARSET: &str =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*-_=+";

    /// Builds an authenticated context with the given scopes and resource indicators.
    fn context(scopes: &[&str], resources: &[&str]) -> AuthContext {
        AuthContext {
            authenticated: true,
            client_id: Some("test-client".to_string()),
            scopes: scopes.iter().map(|s| s.to_string()).collect(),
            resource_indicators: resources.iter().map(|r| r.to_string()).collect(),
        }
    }

    #[test]
    fn test_auth_context() {
        let ctx = context(&["mcp:read", "mcp:write"], &["kindlyguard"]);

        assert!(ctx.has_scope("mcp:read"));
        assert!(ctx.has_scope("mcp:write"));
        assert!(!ctx.has_scope("mcp:admin"));

        assert!(ctx.has_any_scope(&["mcp:read".to_string()]));
        assert!(ctx.has_any_scope(&["mcp:admin".to_string(), "mcp:write".to_string()]));
        assert!(!ctx.has_any_scope(&["mcp:admin".to_string()]));

        assert!(ctx.has_resource_access("kindlyguard"));
        assert!(!ctx.has_resource_access("other-server"));
    }

    #[test]
    fn test_wildcards_and_empty_requirements() {
        let everything = context(&["*"], &[]);
        assert!(everything.has_scope("anything"));
        assert!(everything.has_any_scope(&[]));
        // No resource indicators means unrestricted resource access.
        assert!(everything.has_resource_access("any-server"));

        let wildcard_resource = context(&[], &["*"]);
        assert!(wildcard_resource.has_resource_access("any-server"));
        assert!(!wildcard_resource.has_scope("mcp:read"));
    }

    #[test]
    fn test_unauthenticated_context() {
        let ctx = AuthContext::unauthenticated();
        assert!(!ctx.authenticated);
        assert!(ctx.client_id.is_none());
        assert!(ctx.scopes.is_empty());
        assert!(ctx.resource_indicators.is_empty());
    }

    #[test]
    fn test_authorize_tool_and_resource() {
        let manager = AuthManager::new(AuthConfig::default(), "kindlyguard".to_string());

        let anonymous = AuthContext::unauthenticated();
        assert!(manager.authorize_tool(&anonymous, "scan").is_err());
        assert!(manager.authorize_resource(&anonymous, "file:///x").is_err());

        // The default requirement for unlisted tools/resources is `mcp:read`.
        let reader = context(&["mcp:read"], &["kindlyguard"]);
        assert!(manager.authorize_tool(&reader, "scan").is_ok());
        assert!(manager.authorize_resource(&reader, "file:///x").is_ok());

        let writer_only = context(&["mcp:write"], &["kindlyguard"]);
        assert!(manager.authorize_tool(&writer_only, "scan").is_err());

        let other_server = context(&["mcp:read"], &["other-server"]);
        assert!(manager.authorize_resource(&other_server, "file:///x").is_err());
    }

    #[test]
    fn test_constant_time_comparison() {
        assert!(AuthManager::constant_time_compare("secret123", "secret123"));

        assert!(!AuthManager::constant_time_compare("secret123", "secret124"));
        assert!(!AuthManager::constant_time_compare("secret", "secrets"));
        assert!(!AuthManager::constant_time_compare("", "secret"));
        assert!(!AuthManager::constant_time_compare("secret", ""));

        assert!(AuthManager::constant_time_compare("", ""));
    }

    #[test]
    fn test_secure_token_generation() {
        // `length` counts random bytes (floored at 16); unpadded base64 of n bytes has
        // ceil(4n / 3) characters: 16 bytes -> 22, 32 bytes -> 43.
        let token1 = AuthManager::generate_secure_token(8);
        let token2 = AuthManager::generate_secure_token(16);
        let token3 = AuthManager::generate_secure_token(32);

        assert_eq!(token1.len(), 22);
        assert_eq!(token2.len(), 22);
        assert_eq!(token3.len(), 43);

        let url_safe =
            |t: &str| t.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_');
        assert!(url_safe(&token1));
        assert!(url_safe(&token2));
        assert!(url_safe(&token3));

        let token4 = AuthManager::generate_secure_token(32);
        assert_ne!(token3, token4);

        let session1 = AuthManager::generate_session_token();
        let session2 = AuthManager::generate_session_token();
        assert_eq!(session1.len(), 43);
        assert_ne!(session1, session2);
    }

    #[test]
    fn test_api_key_generation() {
        let key1 = AuthManager::generate_api_key();
        let key2 = AuthManager::generate_api_key();

        assert_eq!(key1.len(), 32);
        assert_eq!(key2.len(), 32);
        assert_ne!(key1, key2);

        // Every character must come from the documented charset.
        for key in [&key1, &key2] {
            assert!(
                key.chars().all(|c| API_KEY_CHARSET.contains(c)),
                "unexpected character in {}",
                key
            );
        }
    }

    #[test]
    fn test_token_entropy() {
        let tokens: Vec<String> = (0..100)
            .map(|_| AuthManager::generate_secure_token(16))
            .collect();

        let unique_count = tokens.iter().collect::<std::collections::HashSet<_>>().len();
        assert_eq!(unique_count, 100);

        let distinct_chars = tokens
            .iter()
            .flat_map(|t| t.chars())
            .collect::<std::collections::HashSet<char>>()
            .len();
        assert!(distinct_chars >= 20);
    }
}