use crate::error::Result;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;
/// Identity and authorization data attached to a request.
///
/// Built either directly (`AuthContext::new` / `AuthContext::anonymous`) or,
/// presumably, by a token validator from verified token claims — confirm
/// against the implementors of `TokenValidator`.
///
/// Serde notes: `subject`, `scopes` and `claims` carry no `#[serde(default)]`,
/// so they must be present when deserializing; `authenticated` falls back to
/// `false`; the `Option` fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthContext {
    /// Identifier of the principal (`"anonymous"` for `AuthContext::anonymous`).
    pub subject: String,
    /// Scopes granted to the principal; checked by `has_scope` and friends.
    pub scopes: Vec<String>,
    /// Raw claims keyed by claim name; the accessor methods read from here.
    pub claims: HashMap<String, serde_json::Value>,
    /// The raw token, if it was retained.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token: Option<String>,
    /// Client / application identifier, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Expiry as whole seconds since the Unix epoch (compared against the
    /// current time by `is_expired`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<u64>,
    /// Whether the principal was authenticated; `false` when missing from
    /// deserialized input.
    #[serde(default)]
    pub authenticated: bool,
}
impl AuthContext {
    /// Creates an authenticated context for `subject`.
    ///
    /// All other fields take their `Default` values: no scopes, no claims and
    /// no token, client id or expiry.
    pub fn new(subject: impl Into<String>) -> Self {
        Self {
            subject: subject.into(),
            authenticated: true,
            ..Default::default()
        }
    }

    /// Creates an unauthenticated context with subject `"anonymous"` and no
    /// scopes or claims.
    pub fn anonymous() -> Self {
        Self {
            subject: "anonymous".to_string(),
            authenticated: false,
            ..Default::default()
        }
    }

    /// Returns the subject identifier (the same value as the `subject` field).
    #[inline]
    pub fn user_id(&self) -> &str {
        &self.subject
    }

    /// Deserializes the claim named `key` into `T`.
    ///
    /// Returns `None` when the claim is absent or its JSON value cannot be
    /// converted into `T`.
    pub fn claim<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        // `&serde_json::Value` is itself a deserializer, so there is no need
        // to clone the (possibly large) JSON tree first.
        self.claims.get(key).and_then(|v| T::deserialize(v).ok())
    }

    /// Returns the value of the first claim among `keys`, tried in order, that
    /// is present *and* holds a JSON string.
    ///
    /// A claim that exists with a non-string value (e.g. `null`) does not end
    /// the search; the next key is tried instead.
    fn first_str_claim(&self, keys: &[&str]) -> Option<&str> {
        keys.iter()
            .find_map(|key| self.claims.get(*key).and_then(serde_json::Value::as_str))
    }

    /// Email address: the first string-valued claim among `email`,
    /// `preferred_username` and `upn`.
    pub fn email(&self) -> Option<&str> {
        self.first_str_claim(&["email", "preferred_username", "upn"])
    }

    /// Display name from the `name` claim, if it is a string.
    pub fn name(&self) -> Option<&str> {
        self.first_str_claim(&["name"])
    }

    /// Tenant identifier: the first string-valued claim among `tenant_id`,
    /// `tid`, `custom:tenant_id`, `custom:tenant` and `org_id`.
    pub fn tenant_id(&self) -> Option<&str> {
        self.first_str_claim(&[
            "tenant_id",
            "tid",
            "custom:tenant_id",
            "custom:tenant",
            "org_id",
        ])
    }

    /// Group names: the first claim among `groups`, `cognito:groups` and
    /// `roles` that deserializes as a list of strings.
    ///
    /// Returns an empty vector when none of those claims does.
    pub fn groups(&self) -> Vec<String> {
        ["groups", "cognito:groups", "roles"]
            .iter()
            .find_map(|key| {
                self.claims
                    .get(*key)
                    .and_then(|v| Vec::<String>::deserialize(v).ok())
            })
            .unwrap_or_default()
    }

    /// Returns `true` if `scope` is among the granted scopes (exact match).
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scopes.iter().any(|s| s == scope)
    }

    /// Returns `true` if every entry of `scopes` is granted (vacuously `true`
    /// for an empty slice).
    pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
        scopes.iter().all(|scope| self.has_scope(scope))
    }

    /// Returns `true` if at least one entry of `scopes` is granted (`false`
    /// for an empty slice).
    pub fn has_any_scope(&self, scopes: &[&str]) -> bool {
        scopes.iter().any(|scope| self.has_scope(scope))
    }

    /// Succeeds if `scope` is granted.
    ///
    /// # Errors
    /// Returns `"Insufficient scope"` when the scope is missing.
    pub fn require_scope(&self, scope: &str) -> std::result::Result<(), &'static str> {
        if self.has_scope(scope) {
            Ok(())
        } else {
            Err("Insufficient scope")
        }
    }

    /// Succeeds if the context is authenticated.
    ///
    /// # Errors
    /// Returns `"Authentication required"` when `authenticated` is `false`.
    pub fn require_auth(&self) -> std::result::Result<(), &'static str> {
        if self.authenticated {
            Ok(())
        } else {
            Err("Authentication required")
        }
    }

    /// Returns `true` once the expiry time has been reached.
    ///
    /// Follows JWT `exp` semantics (RFC 7519 §4.1.4): a token must not be
    /// accepted *on or after* its expiry, so `now == expires_at` counts as
    /// expired. Contexts without `expires_at` never expire.
    ///
    /// If the system clock reads earlier than the Unix epoch the current time
    /// cannot be trusted; this fails closed (reports expired) rather than
    /// panicking.
    pub fn is_expired(&self) -> bool {
        let expires_at = match self.expires_at {
            Some(expires_at) => expires_at,
            None => return false,
        };
        match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(now) => now.as_secs() >= expires_at,
            Err(_) => true,
        }
    }

    /// Returns `true` if `group` is one of the names returned by `groups()`.
    pub fn in_group(&self, group: &str) -> bool {
        self.groups().iter().any(|g| g == group)
    }
}
/// Maps identity-provider-specific claim names onto the standard keys that
/// `ClaimMappings::normalize_claims` writes (`sub`, `tenant_id`, `email`,
/// `name`, `groups`, plus any `custom` targets).
///
/// NOTE(review): when this struct is deserialized, a missing `email` or `name`
/// becomes `None`, whereas `ClaimMappings::default()` sets them to `"email"` /
/// `"name"` — confirm the difference is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaimMappings {
    /// Provider claim whose value is copied to `sub`; deserializes to `"sub"`
    /// when absent.
    #[serde(default = "default_user_id_claim")]
    pub user_id: String,
    /// Provider claim copied to `tenant_id`, if any (omitted from serialized
    /// output when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Provider claim copied to `email`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Provider claim copied to `name`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Provider claim copied to `groups`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub groups: Option<String>,
    /// Extra `standard_name -> provider_claim` mappings. Because the field is
    /// flattened, these appear as top-level keys next to the fields above in
    /// the serialized form, and every such extra key must map to a string.
    #[serde(flatten)]
    pub custom: HashMap<String, String>,
}
/// Default provider claim for the user id (`"sub"`); used both as the serde
/// default for `ClaimMappings::user_id` and by `ClaimMappings::default()`.
fn default_user_id_claim() -> String {
    String::from("sub")
}
impl Default for ClaimMappings {
    /// Provider-agnostic mapping: user id from `default_user_id_claim()`
    /// (`"sub"`), `email` and `name` read from identically named claims, and
    /// no tenant, group or custom mappings.
    fn default() -> Self {
        let same_name = |claim: &str| Some(claim.to_owned());
        Self {
            user_id: default_user_id_claim(),
            tenant_id: None,
            email: same_name("email"),
            name: same_name("name"),
            groups: None,
            custom: HashMap::default(),
        }
    }
}
impl ClaimMappings {
    /// Shared constructor for the built-in provider presets.
    ///
    /// Every preset reads the display name from `"name"`, always sets an
    /// email claim, and starts with no custom mappings.
    fn preset(
        user_id: &str,
        tenant_id: Option<&str>,
        email: &str,
        groups: Option<&str>,
    ) -> Self {
        Self {
            user_id: user_id.to_owned(),
            tenant_id: tenant_id.map(str::to_owned),
            email: Some(email.to_owned()),
            name: Some("name".to_owned()),
            groups: groups.map(str::to_owned),
            custom: HashMap::new(),
        }
    }

    /// AWS Cognito: `sub`, tenant from `custom:tenant_id`, groups from
    /// `cognito:groups`.
    pub fn cognito() -> Self {
        Self::preset(
            "sub",
            Some("custom:tenant_id"),
            "email",
            Some("cognito:groups"),
        )
    }

    /// Microsoft Entra ID: user id from `oid`, tenant from `tid`, email from
    /// `preferred_username`, groups from `groups`.
    pub fn entra() -> Self {
        Self::preset("oid", Some("tid"), "preferred_username", Some("groups"))
    }

    /// Google: `sub` and `email`, with no tenant or group claim.
    pub fn google() -> Self {
        Self::preset("sub", None, "email", None)
    }

    /// Okta: user id from `uid`, tenant from `org_id`, groups from `groups`.
    pub fn okta() -> Self {
        Self::preset("uid", Some("org_id"), "email", Some("groups"))
    }

    /// Auth0: `sub`, tenant from `org_id`, groups from `roles`.
    pub fn auth0() -> Self {
        Self::preset("sub", Some("org_id"), "email", Some("roles"))
    }

    /// Produces a claim map with provider-specific names copied onto the
    /// standard keys.
    ///
    /// Every original claim is kept verbatim. Then each configured standard
    /// mapping (`sub`, `tenant_id`, `email`, `name`, `groups`) and finally each
    /// `custom` mapping overwrites its target key when the provider claim is
    /// present. Returns an empty map if `claims` is not a JSON object.
    pub fn normalize_claims(
        &self,
        claims: &serde_json::Value,
    ) -> HashMap<String, serde_json::Value> {
        let source = match claims.as_object() {
            Some(object) => object,
            None => return HashMap::new(),
        };

        // Start from a verbatim copy so unmapped claims stay available.
        let mut normalized: HashMap<String, serde_json::Value> = source
            .iter()
            .map(|(key, value)| (key.clone(), value.clone()))
            .collect();

        let standard: [(&str, Option<&String>); 5] = [
            ("sub", Some(&self.user_id)),
            ("tenant_id", self.tenant_id.as_ref()),
            ("email", self.email.as_ref()),
            ("name", self.name.as_ref()),
            ("groups", self.groups.as_ref()),
        ];
        let custom = self
            .custom
            .iter()
            .map(|(standard_name, provider_name)| (standard_name.as_str(), Some(provider_name)));

        // Custom mappings come last so they can override the standard ones.
        for (target, provider) in standard.iter().copied().chain(custom) {
            if let Some(value) = provider.and_then(|claim| source.get(claim.as_str())) {
                normalized.insert(target.to_owned(), value.clone());
            }
        }
        normalized
    }
}
/// Authenticates requests from the value of their `Authorization` header.
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Validates a request given its `Authorization` header value.
    ///
    /// `authorization_header` is presumably `None` when the request carried no
    /// such header. NOTE(review): the meaning of `Ok(None)` (e.g. "no
    /// credentials supplied") is left to implementors — confirm per provider.
    async fn validate_request(
        &self,
        authorization_header: Option<&str>,
    ) -> Result<Option<AuthContext>>;
    /// Name of the authentication scheme; defaults to `"Bearer"`.
    fn auth_scheme(&self) -> &'static str {
        "Bearer"
    }
    /// Whether authentication is mandatory; defaults to `true`.
    fn is_required(&self) -> bool {
        true
    }
}
#[async_trait]
pub trait TokenValidator: Send + Sync {
async fn validate(&self, token: &str) -> Result<AuthContext>;
async fn validate_with_context(
&self,
token: &str,
required_scopes: Option<&[&str]>,
) -> Result<AuthContext> {
let auth_context = self.validate(token).await?;
if let Some(scopes) = required_scopes {
if !auth_context.has_all_scopes(scopes) {
return Err(crate::error::Error::protocol(
crate::error::ErrorCode::INVALID_REQUEST,
"Insufficient scopes",
));
}
}
Ok(auth_context)
}
}
/// Stores authenticated contexts under string session ids.
#[async_trait]
pub trait SessionManager: Send + Sync {
    /// Creates a session holding `auth`; the returned `String` is presumably
    /// the new session id.
    async fn create_session(&self, auth: AuthContext) -> Result<String>;
    /// Looks up `session_id`; `Ok(None)` presumably means no such (live)
    /// session — implementor-defined.
    async fn get_session(&self, session_id: &str) -> Result<Option<AuthContext>>;
    /// Stores `auth` for the existing session `session_id`.
    async fn update_session(&self, session_id: &str, auth: AuthContext) -> Result<()>;
    /// Invalidates the session `session_id`.
    async fn invalidate_session(&self, session_id: &str) -> Result<()>;
    /// Removes expired sessions; the count is presumably the number removed.
    ///
    /// The default implementation does nothing and returns `Ok(0)`.
    async fn cleanup_expired(&self) -> Result<usize> {
        Ok(0)
    }
}
/// Decides whether an authenticated caller may use a given tool.
#[async_trait]
pub trait ToolAuthorizer: Send + Sync {
    /// Returns whether `auth` may access the tool named `tool_name`.
    async fn can_access_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<bool>;
    /// Returns the scopes required to access the tool named `tool_name`.
    async fn required_scopes_for_tool(&self, tool_name: &str) -> Result<Vec<String>>;
}
/// Tool authorizer driven by scopes: a caller may use a tool when it holds
/// every scope configured for that tool, or the default scopes when the tool
/// has no rule of its own.
#[derive(Debug, Clone)]
pub struct ScopeBasedAuthorizer {
    /// Per-tool required scopes, keyed by tool name.
    tool_scopes: HashMap<String, Vec<String>>,
    /// Scopes required for any tool without an entry in `tool_scopes`.
    default_scopes: Vec<String>,
}

impl ScopeBasedAuthorizer {
    /// Creates an authorizer with no per-tool rules whose default
    /// requirement is the single scope `mcp:tools:use`.
    pub fn new() -> Self {
        let default_scopes = vec![String::from("mcp:tools:use")];
        Self {
            tool_scopes: HashMap::default(),
            default_scopes,
        }
    }

    /// Builder: requires exactly `scopes` for `tool_name`, replacing any
    /// previous rule for that tool.
    pub fn require_scopes<S, I>(mut self, tool_name: impl Into<String>, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let owned: Vec<String> = scopes
            .into_iter()
            .map(|scope| String::from(scope.as_ref()))
            .collect();
        self.tool_scopes.insert(tool_name.into(), owned);
        self
    }

    /// Builder: replaces the scopes required for tools without their own rule.
    pub fn default_scopes(self, scopes: Vec<String>) -> Self {
        Self {
            default_scopes: scopes,
            ..self
        }
    }
}
#[async_trait]
impl ToolAuthorizer for ScopeBasedAuthorizer {
    /// Grants access iff `auth` holds every scope required for `tool_name`
    /// (its own rule, else the default scopes). An empty requirement grants
    /// access.
    async fn can_access_tool(&self, auth: &AuthContext, tool_name: &str) -> Result<bool> {
        let needed = match self.tool_scopes.get(tool_name) {
            Some(scopes) => scopes,
            None => &self.default_scopes,
        };
        // Same check as `AuthContext::has_all_scopes`, without first
        // collecting the scopes into a temporary `Vec<&str>`.
        Ok(needed.iter().all(|scope| auth.has_scope(scope)))
    }

    /// Returns a copy of the scopes required for `tool_name`, falling back to
    /// the default scopes when the tool has no rule.
    async fn required_scopes_for_tool(&self, tool_name: &str) -> Result<Vec<String>> {
        let scopes = self
            .tool_scopes
            .get(tool_name)
            .unwrap_or(&self.default_scopes);
        Ok(scopes.to_vec())
    }
}
impl Default for ScopeBasedAuthorizer {
fn default() -> Self {
Self::new()
}
}