#[cfg(feature = "oauth")]
use super::types::AuthorizationResult;
use do_memory_mcp::protocol::OAuthConfig;
#[cfg(feature = "oauth")]
use tracing::debug;
/// Builds an [`OAuthConfig`] from process environment variables.
///
/// Variables read:
/// - `MCP_OAUTH_ENABLED`: OAuth is enabled when the value, lowercased, is
///   exactly `true`, `1` or `yes`. Unset, unreadable or any other value
///   leaves it disabled.
/// - `MCP_OAUTH_AUDIENCE`, `MCP_OAUTH_ISSUER`, `MCP_OAUTH_JWKS_URI`: optional;
///   each becomes `None` when the variable cannot be read.
/// - `MCP_OAUTH_SCOPES`: comma-separated scope list, defaulting to
///   `mcp:read,mcp:write`. Entries are trimmed and empty entries are dropped.
pub fn load_oauth_config() -> OAuthConfig {
    let oauth_enabled = std::env::var("MCP_OAUTH_ENABLED")
        .map(|raw| matches!(raw.to_lowercase().as_str(), "true" | "1" | "yes"))
        .unwrap_or(false);

    let scope_list =
        std::env::var("MCP_OAUTH_SCOPES").unwrap_or_else(|_| String::from("mcp:read,mcp:write"));

    OAuthConfig {
        enabled: oauth_enabled,
        audience: std::env::var("MCP_OAUTH_AUDIENCE").ok(),
        issuer: std::env::var("MCP_OAUTH_ISSUER").ok(),
        scopes: scope_list
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(str::to_string)
            .collect(),
        jwks_uri: std::env::var("MCP_OAUTH_JWKS_URI").ok(),
    }
}
#[cfg(feature = "oauth")]
/// Performs structural and claim-level validation of a bearer token in JWT
/// compact form (`header.payload.signature`).
///
/// Checks, in order:
/// 1. The token has exactly three dot-separated segments.
/// 2. The payload segment is base64url, UTF-8 and JSON decodable.
/// 3. If `config.issuer` is set and the token carries a non-empty string
///    `iss`, the two must be equal.
/// 4. If `config.audience` is set and the token carries at least one non-empty
///    audience value, the configured audience must be among them. `aud` may be
///    a single string or an array of strings.
/// 5. If `exp` is a JSON number (integer or fractional, any sign), it must not
///    be earlier than the current Unix time in whole seconds.
/// 6. `sub` must be a non-empty string.
///
/// Returns [`AuthorizationResult::Authorized`] when every check passes and
/// [`AuthorizationResult::InvalidToken`] with a description otherwise.
///
/// # Security
///
/// NOTE(review): the header and signature segments are never inspected and
/// `config.jwks_uri` is not consulted here, so this function does not prove
/// that the token is authentic. Tokens that omit `iss`, `aud` or `exp` also
/// pass the corresponding check. Confirm that signature verification happens
/// elsewhere before relying on this result.
pub fn validate_bearer_token(token: &str, config: &OAuthConfig) -> AuthorizationResult {
    // A compact JWT is exactly three segments; only the payload is examined.
    let mut segments = token.split('.');
    let payload_b64 = match (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) {
        (Some(_), Some(payload), Some(_), None) => payload,
        _ => return AuthorizationResult::InvalidToken("Invalid token format".to_string()),
    };

    let payload = match base64url_decode(payload_b64) {
        Ok(bytes) => bytes,
        Err(e) => {
            return AuthorizationResult::InvalidToken(format!("Invalid token payload: {}", e));
        }
    };
    let payload_str = match String::from_utf8(payload) {
        Ok(text) => text,
        Err(e) => {
            return AuthorizationResult::InvalidToken(format!("Invalid token encoding: {}", e));
        }
    };
    let claims: serde_json::Value = match serde_json::from_str(&payload_str) {
        Ok(parsed) => parsed,
        Err(e) => return AuthorizationResult::InvalidToken(format!("Invalid token JSON: {}", e)),
    };

    if let Some(expected_iss) = &config.issuer {
        // An absent or empty `iss` is tolerated; only a differing one is rejected.
        let token_iss = claims.get("iss").and_then(|v| v.as_str()).unwrap_or("");
        if !token_iss.is_empty() && token_iss != expected_iss.as_str() {
            return AuthorizationResult::InvalidToken(format!(
                "Invalid token issuer: expected {}, got {}",
                expected_iss, token_iss
            ));
        }
    }

    if let Some(expected_aud) = &config.audience {
        // `aud` is either one string or an array of strings. Treating an array
        // as "no audience" would let tokens minted for other services through.
        let token_auds: Vec<&str> = match claims.get("aud") {
            Some(serde_json::Value::String(single)) => vec![single.as_str()],
            Some(serde_json::Value::Array(items)) => {
                items.iter().filter_map(|item| item.as_str()).collect()
            }
            _ => Vec::new(),
        }
        .into_iter()
        .filter(|aud| !aud.is_empty())
        .collect();

        // An absent or empty audience is tolerated, mirroring the issuer check.
        let matches_expected = token_auds.iter().any(|aud| *aud == expected_aud.as_str());
        if !token_auds.is_empty() && !matches_expected {
            return AuthorizationResult::InvalidToken(format!(
                "Invalid token audience: expected {}, got {}",
                expected_aud,
                token_auds.join(", ")
            ));
        }
    }

    // Read `exp` as f64 so fractional or negative timestamps are still
    // compared instead of silently skipping the expiry check.
    if let Some(exp) = claims.get("exp").and_then(|v| v.as_f64()) {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Unix timestamps are far below 2^53, so the conversion is exact.
        if exp < now as f64 {
            return AuthorizationResult::InvalidToken("Token expired".to_string());
        }
    }

    let sub = claims.get("sub").and_then(|v| v.as_str()).unwrap_or("");
    if sub.is_empty() {
        return AuthorizationResult::InvalidToken("Token missing subject claim".to_string());
    }

    debug!("Token validated for subject: {}", sub);
    AuthorizationResult::Authorized
}
#[cfg(feature = "oauth")]
/// Decodes a base64url (URL-safe alphabet) string, with or without trailing
/// `=` padding.
///
/// All whitespace characters are removed before decoding. Trailing `=`
/// characters are stripped because the `URL_SAFE_NO_PAD` engine is configured
/// to reject padded input; adding padding before decoding with it, as this
/// function previously did, fails for every input whose length is not a
/// multiple of four.
///
/// # Errors
///
/// Returns the engine's [`base64::DecodeError`] when the remaining text is not
/// valid unpadded base64url.
pub fn base64url_decode(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
    let compact: String = input.chars().filter(|c| !c.is_whitespace()).collect();
    // Only trailing padding is removed; a stray '=' elsewhere is still an error.
    let unpadded = compact.trim_end_matches('=');
    base64::Engine::decode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, unpadded)
}
#[cfg(feature = "oauth")]
/// Checks whether the scopes granted to a token cover every required scope.
///
/// `token_scope` is the raw scope string. It is split on single space
/// characters, each piece is trimmed, and empty pieces are ignored. `None` is
/// treated as no granted scopes.
///
/// Returns [`AuthorizationResult::Authorized`] when `required_scopes` is empty
/// or every required scope is granted. Otherwise returns
/// [`AuthorizationResult::InsufficientScope`] holding the required scopes that
/// are missing, in the order they appear in `required_scopes`.
pub fn check_scopes(token_scope: Option<&str>, required_scopes: &[String]) -> AuthorizationResult {
    // Nothing required means nothing can be missing.
    if required_scopes.is_empty() {
        return AuthorizationResult::Authorized;
    }

    let granted: Vec<&str> = token_scope
        .map(|raw| {
            raw.split(' ')
                .map(str::trim)
                .filter(|piece| !piece.is_empty())
                .collect()
        })
        .unwrap_or_default();

    // With no granted scopes this yields every required scope, in order.
    let missing: Vec<String> = required_scopes
        .iter()
        .filter(|needed| !granted.iter().any(|have| *have == needed.as_str()))
        .cloned()
        .collect();

    if missing.is_empty() {
        return AuthorizationResult::Authorized;
    }
    AuthorizationResult::InsufficientScope(missing)
}
#[cfg(feature = "oauth")]
/// Extracts a bearer token from header text.
///
/// The argument is currently ignored and the function always returns `None`.
/// NOTE(review): this looks like a placeholder awaiting a real header parser;
/// confirm the intended input format before implementing it.
pub fn extract_bearer_token(_headers: &str) -> Option<String> {
    // No parsing is performed: every input yields "no token".
    Option::<String>::None
}
#[cfg(feature = "oauth")]
/// Builds the value of a `WWW-Authenticate` header for the `Bearer` scheme.
///
/// The result has the form
/// `Bearer error="...", error_description="...", realm="..."`, where the
/// `error_description` and `realm` parameters appear only when supplied.
///
/// Every value is emitted as an HTTP quoted-string: `"` and `\` are
/// backslash-escaped and control characters (including CR and LF) are replaced
/// with a space. Descriptions may contain text taken from an untrusted token,
/// so this keeps such text from terminating the quoted value early or
/// injecting additional header lines. Values without those characters are
/// emitted unchanged.
pub fn create_www_authenticate_header(
    error: &str,
    error_description: Option<&str>,
    realm: Option<&str>,
) -> String {
    /// Wraps `value` in double quotes, escaping or neutralising unsafe characters.
    fn quoted(value: &str) -> String {
        let mut out = String::with_capacity(value.len() + 2);
        out.push('"');
        for ch in value.chars() {
            match ch {
                '"' | '\\' => {
                    out.push('\\');
                    out.push(ch);
                }
                c if c.is_control() => out.push(' '),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }

    let mut params = vec![format!("error={}", quoted(error))];
    if let Some(desc) = error_description {
        params.push(format!("error_description={}", quoted(desc)));
    }
    if let Some(r) = realm {
        params.push(format!("realm={}", quoted(r)));
    }
    format!("Bearer {}", params.join(", "))
}