use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::{McpError, McpResult};
/// JWT claim set carried by the access tokens that `OAuthManager` issues and validates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
/// Subject: the `client_id` the token was issued to.
pub sub: String,
/// Issued-at time, in seconds since the Unix epoch.
pub iat: i64,
/// Expiry time, in seconds since the Unix epoch.
pub exp: i64,
/// Issuer; validation requires it to equal the manager's configured issuer.
pub iss: String,
/// Audience; validation requires it to equal the manager's configured audience.
pub aud: String,
/// Scopes granted to the bearer of this token.
pub scopes: Vec<String>,
/// Token kind marker (`OAuthManager` writes `"access"`); left out of the serialized JWT
/// payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_type: Option<String>,
}
/// Token endpoint response produced by the grant methods of `OAuthManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
/// Signed JWT access token.
pub access_token: String,
/// Token scheme; the grant methods in this module always set `"Bearer"`.
pub token_type: String,
/// Lifetime of `access_token`, in seconds.
pub expires_in: i64,
/// Opaque refresh token. The client-credentials grant sets it; the refresh grant leaves it
/// `None`. Omitted from the serialized response when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
/// Granted scopes joined with single spaces. Omitted from the serialized response when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
}
/// A registered OAuth client. Only a hash of the client secret is kept, never the secret itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientCredentials {
/// Identifier the client authenticates with; also the key it is registered under.
pub client_id: String,
/// Base64-encoded SHA-256 hash of the client secret, as produced by `hash_secret`.
pub client_secret_hash: String,
/// Scopes this client may be granted; requested scopes outside this list are dropped.
pub allowed_scopes: Vec<String>,
/// Descriptive name for the client (not consulted by the token logic in this module).
pub name: String,
/// Registration timestamp; `ClientCredentials::new` sets it to the current UTC time.
pub created_at: DateTime<Utc>,
}
impl ClientCredentials {
    /// Creates credentials for `client_id`, storing only a hash of `client_secret`.
    ///
    /// `scopes` becomes the set of scopes the client may be granted, and `created_at` is set
    /// to the current UTC time.
    pub fn new(client_id: String, client_secret: &str, name: String, scopes: Vec<String>) -> Self {
        Self {
            client_id,
            client_secret_hash: hash_secret(client_secret),
            allowed_scopes: scopes,
            name,
            created_at: Utc::now(),
        }
    }

    /// Returns `true` if `secret` hashes to the stored `client_secret_hash`.
    ///
    /// The two hashes are compared without an early exit, so how long the comparison takes
    /// does not reveal how many leading bytes matched.
    pub fn verify_secret(&self, secret: &str) -> bool {
        let candidate = hash_secret(secret);
        Self::constant_time_eq(self.client_secret_hash.as_bytes(), candidate.as_bytes())
    }

    /// Compares two byte strings, visiting every byte when the lengths are equal.
    ///
    /// The length check may return early, but that only reveals the length. The encoded
    /// hashes being compared always have the same fixed length anyway.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        // OR together the XOR of every byte pair instead of stopping at the first difference.
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }
}
/// Server-side record of an issued refresh token, stored in `OAuthManager::refresh_tokens`
/// under the token string.
#[derive(Debug, Clone)]
struct RefreshToken {
/// Copy of the token string. It duplicates the map key and is never read in the visible
/// code, hence the `dead_code` allowance.
#[allow(dead_code)]
token: String,
/// Client the refresh token was issued to; becomes `sub` of refreshed access tokens.
client_id: String,
/// Scopes copied into access tokens minted from this refresh token.
scopes: Vec<String>,
/// Expiry instant; after it the token is rejected and removed.
expires_at: DateTime<Utc>,
}
/// Issues and validates HS256 JWT access tokens and opaque refresh tokens for registered clients.
///
/// Client registrations and refresh tokens live in memory behind `tokio` `RwLock`s.
pub struct OAuthManager {
/// Key used to sign access tokens, derived from the shared secret.
encoding_key: EncodingKey,
/// Key used to verify access-token signatures, derived from the same secret.
decoding_key: DecodingKey,
/// Registered clients, keyed by `client_id`.
clients: Arc<RwLock<HashMap<String, ClientCredentials>>>,
/// Outstanding refresh tokens, keyed by the token string.
refresh_tokens: Arc<RwLock<HashMap<String, RefreshToken>>>,
/// `iss` claim written into, and required on, access tokens.
issuer: String,
/// `aud` claim written into, and required on, access tokens.
audience: String,
/// Access-token lifetime, in seconds.
access_token_ttl: i64,
/// Refresh-token lifetime, in seconds.
refresh_token_ttl: i64,
}
impl OAuthManager {
pub fn new(secret_key: &str, issuer: String, audience: String) -> Self {
Self {
encoding_key: EncodingKey::from_secret(secret_key.as_bytes()),
decoding_key: DecodingKey::from_secret(secret_key.as_bytes()),
clients: Arc::new(RwLock::new(HashMap::new())),
refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
issuer,
audience,
access_token_ttl: 3600, refresh_token_ttl: 86400 * 7, }
}
pub async fn register_client(&self, credentials: ClientCredentials) -> McpResult<()> {
let mut clients = self.clients.write().await;
clients.insert(credentials.client_id.clone(), credentials);
Ok(())
}
pub async fn client_credentials_grant(
&self,
client_id: &str,
client_secret: &str,
requested_scopes: Vec<String>,
) -> McpResult<TokenResponse> {
let clients = self.clients.read().await;
let client = clients
.get(client_id)
.ok_or_else(|| McpError::AuthenticationError {
message: "Invalid client credentials".to_string(),
})?;
if !client.verify_secret(client_secret) {
return Err(McpError::AuthenticationError {
message: "Invalid client credentials".to_string(),
});
}
let granted_scopes: Vec<String> = requested_scopes
.into_iter()
.filter(|scope| client.allowed_scopes.contains(scope))
.collect();
if granted_scopes.is_empty() {
return Err(McpError::AuthenticationError {
message: "No valid scopes requested".to_string(),
});
}
drop(clients);
let access_token = self.generate_access_token(client_id, &granted_scopes)?;
let refresh_token = self
.generate_refresh_token(client_id, &granted_scopes)
.await?;
Ok(TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: self.access_token_ttl,
refresh_token: Some(refresh_token),
scope: Some(granted_scopes.join(" ")),
})
}
pub async fn refresh_token_grant(&self, refresh_token: &str) -> McpResult<TokenResponse> {
let mut tokens = self.refresh_tokens.write().await;
let token_data = tokens
.get(refresh_token)
.ok_or_else(|| McpError::AuthenticationError {
message: "Invalid refresh token".to_string(),
})?
.clone();
if Utc::now() > token_data.expires_at {
tokens.remove(refresh_token);
return Err(McpError::AuthenticationError {
message: "Refresh token expired".to_string(),
});
}
drop(tokens);
let access_token = self.generate_access_token(&token_data.client_id, &token_data.scopes)?;
Ok(TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: self.access_token_ttl,
refresh_token: None, scope: Some(token_data.scopes.join(" ")),
})
}
pub fn validate_token(&self, token: &str) -> McpResult<TokenClaims> {
let mut validation = Validation::new(Algorithm::HS256);
validation.set_issuer(&[&self.issuer]);
validation.set_audience(&[&self.audience]);
let token_data =
decode::<TokenClaims>(token, &self.decoding_key, &validation).map_err(|e| {
McpError::AuthenticationError {
message: format!("Invalid token: {}", e),
}
})?;
Ok(token_data.claims)
}
pub async fn revoke_refresh_token(&self, refresh_token: &str) -> McpResult<()> {
let mut tokens = self.refresh_tokens.write().await;
tokens.remove(refresh_token);
Ok(())
}
pub async fn cleanup_expired_tokens(&self) {
let mut tokens = self.refresh_tokens.write().await;
let now = Utc::now();
tokens.retain(|_, token| token.expires_at > now);
}
fn generate_access_token(&self, client_id: &str, scopes: &[String]) -> McpResult<String> {
let now = Utc::now();
let expires_at = now + Duration::seconds(self.access_token_ttl);
let claims = TokenClaims {
sub: client_id.to_string(),
iat: now.timestamp(),
exp: expires_at.timestamp(),
iss: self.issuer.clone(),
aud: self.audience.clone(),
scopes: scopes.to_vec(),
token_type: Some("access".to_string()),
};
let header = Header::new(Algorithm::HS256);
encode(&header, &claims, &self.encoding_key).map_err(|e| McpError::InternalError {
message: format!("Failed to generate token: {}", e),
})
}
async fn generate_refresh_token(
&self,
client_id: &str,
scopes: &[String],
) -> McpResult<String> {
let mut rng = rand::thread_rng();
let token_bytes: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
let token = BASE64.encode(&token_bytes);
let expires_at = Utc::now() + Duration::seconds(self.refresh_token_ttl);
let refresh_token_data = RefreshToken {
token: token.clone(),
client_id: client_id.to_string(),
scopes: scopes.to_vec(),
expires_at,
};
let mut tokens = self.refresh_tokens.write().await;
tokens.insert(token.clone(), refresh_token_data);
Ok(token)
}
#[cfg(test)]
pub async fn get_client(&self, client_id: &str) -> Option<ClientCredentials> {
let clients = self.clients.read().await;
clients.get(client_id).cloned()
}
#[cfg(test)]
pub async fn refresh_token_count(&self) -> usize {
self.refresh_tokens.read().await.len()
}
}
/// Hashes a client secret for storage and comparison.
///
/// Returns the standard-alphabet Base64 encoding of the secret's SHA-256 digest. No salt or key
/// stretching is applied. Consider this suitable only for high-entropy, machine-generated
/// secrets, not human-chosen passwords.
fn hash_secret(secret: &str) -> String {
    let digest = Sha256::digest(secret.as_bytes());
    BASE64.encode(digest)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with the key, issuer and audience shared by every test.
    fn new_manager() -> OAuthManager {
        OAuthManager::new(
            "test-secret-key",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        )
    }

    /// Builds credentials for `test-client` (secret `test-secret`) allowed the given scopes.
    fn test_credentials(scopes: &[&str]) -> ClientCredentials {
        ClientCredentials::new(
            "test-client".to_string(),
            "test-secret",
            "Test Client".to_string(),
            scopes.iter().map(|scope| scope.to_string()).collect(),
        )
    }

    /// Registers `test-client` on `manager` with the given allowed scopes.
    async fn register(manager: &OAuthManager, scopes: &[&str]) {
        manager
            .register_client(test_credentials(scopes))
            .await
            .unwrap();
    }

    /// Runs a client-credentials grant for `client_id`/`secret` requesting `scopes`.
    async fn grant(
        manager: &OAuthManager,
        client_id: &str,
        secret: &str,
        scopes: &[&str],
    ) -> McpResult<TokenResponse> {
        manager
            .client_credentials_grant(
                client_id,
                secret,
                scopes.iter().map(|scope| scope.to_string()).collect(),
            )
            .await
    }

    #[tokio::test]
    async fn test_client_registration() {
        let manager = new_manager();
        register(&manager, &["read", "write"]).await;

        let stored_client = manager.get_client("test-client").await.unwrap();
        assert_eq!(stored_client.client_id, "test-client");
        assert_eq!(stored_client.allowed_scopes, vec!["read", "write"]);
        assert!(stored_client.verify_secret("test-secret"));
        assert!(!stored_client.verify_secret("wrong-secret"));
    }

    #[tokio::test]
    async fn test_client_credentials_grant() {
        let manager = new_manager();
        register(&manager, &["read", "write"]).await;

        let response = grant(&manager, "test-client", "test-secret", &["read"])
            .await
            .unwrap();
        assert_eq!(response.token_type, "Bearer");
        assert_eq!(response.expires_in, 3600);
        assert!(response.refresh_token.is_some());
        assert_eq!(response.scope, Some("read".to_string()));

        let claims = manager.validate_token(&response.access_token).unwrap();
        assert_eq!(claims.sub, "test-client");
        assert_eq!(claims.scopes, vec!["read"]);
    }

    #[tokio::test]
    async fn test_invalid_credentials() {
        let manager = new_manager();
        register(&manager, &["read"]).await;

        assert!(grant(&manager, "wrong-client", "test-secret", &["read"])
            .await
            .is_err());
        assert!(grant(&manager, "test-client", "wrong-secret", &["read"])
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_refresh_token_grant() {
        let manager = new_manager();
        register(&manager, &["read"]).await;

        let initial_response = grant(&manager, "test-client", "test-secret", &["read"])
            .await
            .unwrap();
        let refresh_token = initial_response.refresh_token.unwrap();

        let refreshed_response = manager.refresh_token_grant(&refresh_token).await.unwrap();
        assert_eq!(refreshed_response.token_type, "Bearer");
        assert!(refreshed_response.refresh_token.is_none());

        let claims = manager
            .validate_token(&refreshed_response.access_token)
            .unwrap();
        assert_eq!(claims.sub, "test-client");
        assert_eq!(claims.scopes, vec!["read"]);
    }

    #[tokio::test]
    async fn test_token_revocation() {
        let manager = new_manager();
        register(&manager, &["read"]).await;

        let response = grant(&manager, "test-client", "test-secret", &["read"])
            .await
            .unwrap();
        let refresh_token = response.refresh_token.unwrap();

        manager.revoke_refresh_token(&refresh_token).await.unwrap();
        assert!(manager.refresh_token_grant(&refresh_token).await.is_err());
    }

    #[tokio::test]
    async fn test_scope_filtering() {
        let manager = new_manager();
        // Only "read" is allowed, so a request for "write" must be dropped.
        register(&manager, &["read"]).await;

        let response = grant(&manager, "test-client", "test-secret", &["read", "write"])
            .await
            .unwrap();
        assert_eq!(response.scope, Some("read".to_string()));

        let claims = manager.validate_token(&response.access_token).unwrap();
        assert_eq!(claims.scopes, vec!["read"]);
        assert!(!claims.scopes.contains(&"write".to_string()));
    }

    #[tokio::test]
    async fn test_cleanup_expired_tokens() {
        let mut manager = new_manager();
        // A negative TTL issues refresh tokens that are already expired. That exercises
        // cleanup without sleeping past a real TTL.
        manager.refresh_token_ttl = -1;
        register(&manager, &["read"]).await;

        grant(&manager, "test-client", "test-secret", &["read"])
            .await
            .unwrap();
        assert_eq!(manager.refresh_token_count().await, 1);

        manager.cleanup_expired_tokens().await;
        assert_eq!(manager.refresh_token_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_live_tokens() {
        let manager = new_manager();
        register(&manager, &["read"]).await;

        grant(&manager, "test-client", "test-secret", &["read"])
            .await
            .unwrap();
        manager.cleanup_expired_tokens().await;
        assert_eq!(manager.refresh_token_count().await, 1);
    }

    #[tokio::test]
    async fn test_expired_refresh_token_is_rejected() {
        let mut manager = new_manager();
        manager.refresh_token_ttl = -1;
        register(&manager, &["read"]).await;

        let response = grant(&manager, "test-client", "test-secret", &["read"])
            .await
            .unwrap();
        let refresh_token = response.refresh_token.unwrap();

        assert!(manager.refresh_token_grant(&refresh_token).await.is_err());
        // Presenting an expired token also removes it from the store.
        assert_eq!(manager.refresh_token_count().await, 0);
    }
}