use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Redirect, Response},
Json,
};
use serde::{Deserialize, Serialize};
use crate::handlers::{
check_auth_token, chrono_now, AccessTokenRecord, AppState, OAuthClientRecord,
RefreshTokenRecord,
};
use crate::rate_limit::{ReadRateLimit, RegistrationRateLimit};
use chrono;
/// Serves OAuth 2.0 Protected Resource Metadata for the MCP endpoint.
///
/// The protected resource is `{base}/mcp`, the sole authorization server is the
/// configured base URL itself (trailing `/` removed), the only advertised scope
/// is `mcp:tools`, and bearer tokens are accepted via the `header` method.
pub async fn handle_protected_resource_metadata(
    State(state): State<AppState>,
) -> impl IntoResponse {
    let issuer = base_url(&state);
    let resource = format!("{issuer}/mcp");
    let metadata = serde_json::json!({
        "resource": resource,
        "authorization_servers": [issuer],
        "scopes_supported": ["mcp:tools"],
        "bearer_methods_supported": ["header"]
    });
    let content_type = [(axum::http::header::CONTENT_TYPE, "application/json")];
    (StatusCode::OK, content_type, axum::Json(metadata))
}
/// Serves OAuth 2.0 Authorization Server Metadata.
///
/// The issuer is the configured base URL (trailing `/` removed). The document
/// advertises the following:
/// - the `/authorize`, `/oauth/token`, and `/oauth/register` endpoints;
/// - the `code` response type with S256-only PKCE;
/// - the `mcp:tools` and `offline_access` scopes;
/// - `none` and `client_secret_post` token-endpoint authentication;
/// - the `authorization_code`, `client_credentials`, and `refresh_token` grants.
pub async fn handle_authorization_server_metadata(
    State(state): State<AppState>,
) -> impl IntoResponse {
    let issuer = base_url(&state);
    let authorize_url = format!("{issuer}/authorize");
    let token_url = format!("{issuer}/oauth/token");
    let register_url = format!("{issuer}/oauth/register");
    let metadata = serde_json::json!({
        "issuer": issuer,
        "authorization_endpoint": authorize_url,
        "token_endpoint": token_url,
        "registration_endpoint": register_url,
        "response_types_supported": ["code"],
        "code_challenge_methods_supported": ["S256"],
        "scopes_supported": ["mcp:tools", "offline_access"],
        "token_endpoint_auth_methods_supported": ["none", "client_secret_post"],
        "grant_types_supported": ["authorization_code", "client_credentials", "refresh_token"]
    });
    let content_type = [(axum::http::header::CONTENT_TYPE, "application/json")];
    (StatusCode::OK, content_type, axum::Json(metadata))
}
/// JSON body accepted by the dynamic client registration endpoint
/// (`handle_register`). Every field is optional in the request.
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
/// Human-readable name; `handle_register` substitutes `"unnamed-client"` when absent.
pub client_name: Option<String>,
/// Redirect URIs; `handle_register` rejects the request if this is absent or empty.
pub redirect_uris: Option<Vec<String>>,
/// Defaults to `["authorization_code"]` in `handle_register`.
pub grant_types: Option<Vec<String>>,
/// Defaults to `["code"]` in `handle_register`.
pub response_types: Option<Vec<String>>,
/// Defaults to `"none"` in `handle_register`; the token handlers treat only
/// `"none"` as a public (secret-less) client.
pub token_endpoint_auth_method: Option<String>,
/// Requested scope; deserialized but not read by `handle_register`.
pub scope: Option<String>,
}
/// Successful registration response (sent with status 201 by `handle_register`).
/// Echoes the stored client metadata, including any defaults that were filled in.
#[derive(Debug, Serialize)]
struct RegisterResponse {
/// Freshly generated identifier in random (version 4) UUID format.
client_id: String,
client_name: String,
redirect_uris: Vec<String>,
grant_types: Vec<String>,
response_types: Vec<String>,
token_endpoint_auth_method: String,
}
pub async fn handle_register(
State(state): State<AppState>,
_rl: RegistrationRateLimit,
Json(req): Json<RegisterRequest>,
) -> Response {
let client_name = req.client_name.unwrap_or_else(|| "unnamed-client".to_string());
let redirect_uris = req.redirect_uris.unwrap_or_default();
if redirect_uris.is_empty() {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_client_metadata",
error_description: "redirect_uris is required and must not be empty",
}),
)
.into_response();
}
let grant_types = req
.grant_types
.unwrap_or_else(|| vec!["authorization_code".to_string()]);
let response_types = req
.response_types
.unwrap_or_else(|| vec!["code".to_string()]);
let token_endpoint_auth_method = req
.token_endpoint_auth_method
.unwrap_or_else(|| "none".to_string());
let client_id = new_uuid();
let now = chrono_now();
let cutoff_24h = {
let cutoff = chrono::Utc::now() - chrono::Duration::hours(24);
cutoff.format("%Y-%m-%dT%H:%M:%SZ").to_string()
};
let record = OAuthClientRecord {
client_id: client_id.clone(),
client_name: client_name.clone(),
redirect_uris: redirect_uris.clone(),
grant_types: grant_types.clone(),
response_types: response_types.clone(),
token_endpoint_auth_method: token_endpoint_auth_method.clone(),
created_at: now,
};
{
let mut clients = state
.oauth_clients
.lock()
.unwrap_or_else(|e| e.into_inner());
clients.retain(|_, v| v.created_at.as_str() >= cutoff_24h.as_str());
if clients.len() >= 1_000 {
tracing::warn!("OAuth client registration limit reached ({} entries)", clients.len());
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "server_error",
"error_description": "Registration limit reached"
})),
)
.into_response();
}
clients.insert(client_id.clone(), record);
}
tracing::info!(client_id = %client_id, client_name = %client_name, "OAuth dynamic client registered");
let resp = RegisterResponse {
client_id,
client_name,
redirect_uris,
grant_types,
response_types,
token_endpoint_auth_method,
};
(StatusCode::CREATED, Json(resp)).into_response()
}
/// Query parameters of the authorization endpoint (`handle_authorize`).
/// All are optional at the type level; `handle_authorize` enforces which are required.
#[derive(Debug, Deserialize)]
pub struct AuthorizeParams {
/// Must be `"code"`.
pub response_type: Option<String>,
/// Required and non-empty.
pub client_id: Option<String>,
/// Required and non-empty; checked against the client's registration or `is_safe_redirect_uri`.
pub redirect_uri: Option<String>,
/// Opaque client value echoed back (percent-encoded) on the redirect when non-empty.
pub state: Option<String>,
/// PKCE challenge; required and non-empty.
pub code_challenge: Option<String>,
/// Must be `"S256"` or absent.
pub code_challenge_method: Option<String>,
/// Optional resource indicator; if present it must be repeated at the token endpoint.
pub resource: Option<String>,
/// Requested scope, stored with the code and carried into the issued tokens.
pub scope: Option<String>,
}
/// OAuth 2.0 authorization endpoint for the authorization-code flow with PKCE.
///
/// Checks, in this order. The first failure is returned as a 400 JSON
/// `OAuthError`, never as a redirect.
/// 1. `response_type` is `code` (`unsupported_response_type` otherwise).
/// 2. `client_id`, `redirect_uri`, and `code_challenge` are present and non-empty.
/// 3. `code_challenge_method` is `S256` or omitted. An omitted method is accepted
///    and the challenge is still verified as S256 at the token endpoint.
/// 4. If `client_id` is in `state.oauth_clients`, `redirect_uri` must exactly match
///    one of its registered URIs. Otherwise it must pass `is_safe_redirect_uri`.
///
/// On success the handler stores a single-use code (32 random bytes, hex-encoded,
/// expiring after 5 minutes) together with the client id, redirect URI, PKCE
/// challenge, `resource`, and `scope`. It then redirects to `redirect_uri` with
/// `code` appended as a query parameter, plus the percent-encoded `state` when
/// `state` is non-empty.
///
/// NOTE(review): no end-user authentication or consent happens in this handler; a
/// code is issued to any caller whose parameters pass the checks above. Confirm
/// that this is intended.
///
/// NOTE(review): `handle_register` prunes registrations older than 24 hours. After
/// that, a previously registered `client_id` takes the unregistered branch of
/// check 4.
///
/// NOTE(review): a `redirect_uri` containing `#` is not rejected. The code and
/// state would then be appended inside the fragment.
pub async fn handle_authorize(
State(state): State<AppState>,
_rl: ReadRateLimit,
Query(params): Query<AuthorizeParams>,
) -> Response {
match params.response_type.as_deref() {
Some("code") => {}
_ => {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "unsupported_response_type",
error_description: "Only response_type=code is supported",
}),
)
.into_response();
}
}
let client_id = match params.client_id {
Some(ref id) if !id.is_empty() => id.clone(),
_ => {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_request",
error_description: "client_id is required",
}),
)
.into_response();
}
};
let redirect_uri = match params.redirect_uri {
Some(ref uri) if !uri.is_empty() => uri.clone(),
_ => {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_request",
error_description: "redirect_uri is required",
}),
)
.into_response();
}
};
let code_challenge = match params.code_challenge {
Some(ref c) if !c.is_empty() => c.clone(),
_ => {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_request",
error_description: "code_challenge is required (S256 PKCE mandatory)",
}),
)
.into_response();
}
};
// An omitted method is tolerated; the token endpoint always verifies with S256.
match params.code_challenge_method.as_deref() {
Some("S256") | None => {}
Some(_) => {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_request",
error_description: "Only code_challenge_method=S256 is supported",
}),
)
.into_response();
}
}
// Redirect URI policy: exact match for registered clients, otherwise only the
// HTTPS-or-localhost rule. The lock is released at the end of this scope.
{
let clients = state
.oauth_clients
.lock()
.unwrap_or_else(|e| e.into_inner());
if let Some(client) = clients.get(&client_id) {
if !client.redirect_uris.contains(&redirect_uri) {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_request",
error_description: "redirect_uri does not match registered redirect URIs",
}),
)
.into_response();
}
} else {
if !is_safe_redirect_uri(&redirect_uri) {
return (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "invalid_request",
error_description: "redirect_uri must use HTTPS (localhost is permitted over HTTP)",
}),
)
.into_response();
}
}
}
// 32 bytes from `rand::thread_rng`, hex-encoded (64 characters, URL-safe as-is).
let code = {
use rand::RngCore;
let mut bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut bytes);
hex_encode(&bytes)
};
let now = chrono_now();
let expires_at = {
let future = chrono::Utc::now() + chrono::Duration::minutes(5);
future.format("%Y-%m-%dT%H:%M:%SZ").to_string()
};
{
let mut codes = state
.auth_codes
.lock()
.unwrap_or_else(|e| e.into_inner());
// Opportunistically drop expired codes before inserting. The lexicographic
// comparison presumes `chrono_now()` uses the same fixed-width UTC layout as
// `expires_at` — NOTE(review): confirm.
codes.retain(|_, v| v.expires_at.as_str() >= now.as_str());
codes.insert(
code.clone(),
crate::handlers::AuthCodeRecord {
client_id: client_id.clone(),
redirect_uri: redirect_uri.clone(),
expires_at,
code_challenge,
resource: params.resource.clone(),
scope: params.scope.clone(),
},
);
}
tracing::info!(client_id = %client_id, scope = ?params.scope, "OAuth authorization_code issued");
// Append `code` (and `state` when non-empty, percent-encoded). Use `&` if
// redirect_uri already contains `?`, otherwise start a query with `?`.
let redirect_url = match params.state.as_deref() {
Some(s) if !s.is_empty() => format!(
"{}{}code={}&state={}",
redirect_uri,
if redirect_uri.contains('?') { "&" } else { "?" },
code,
percent_encode(s),
),
_ => format!(
"{}{}code={}",
redirect_uri,
if redirect_uri.contains('?') { "&" } else { "?" },
code,
),
};
Redirect::to(&redirect_url).into_response()
}
/// Token endpoint request.
///
/// It is built either by deserializing a JSON body or by `parse_form` for
/// `application/x-www-form-urlencoded` bodies. Which fields are required depends
/// on `grant_type` and is enforced by the per-grant handlers.
#[derive(Debug, Deserialize)]
struct TokenRequest {
/// `client_credentials`, `authorization_code`, or `refresh_token`; anything else is rejected.
grant_type: String,
client_id: Option<String>,
/// Client secret sent in the body (`client_secret_post`).
client_secret: Option<String>,
/// Authorization code (authorization_code grant).
code: Option<String>,
/// Must equal the redirect URI used at the authorization endpoint (authorization_code grant).
redirect_uri: Option<String>,
/// PKCE verifier (authorization_code grant).
code_verifier: Option<String>,
/// Resource indicator; must repeat the one bound to the authorization code, if any.
resource: Option<String>,
/// Refresh token being redeemed (refresh_token grant).
refresh_token: Option<String>,
/// Requested scope; only echoed back by the client_credentials grant.
scope: Option<String>,
}
/// Successful token endpoint response body.
/// `refresh_token` and `scope` are omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
struct TokenResponse {
access_token: String,
/// Always `"bearer"` in this file.
token_type: &'static str,
/// Lifetime in seconds; every grant in this file reports 3600.
expires_in: u32,
#[serde(skip_serializing_if = "Option::is_none")]
refresh_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
}
/// OAuth 2.0 token endpoint.
///
/// The body is parsed according to `Content-Type`:
/// - `application/json` bodies are deserialized as JSON;
/// - anything else is parsed as `application/x-www-form-urlencoded`.
///
/// The request is then dispatched on `grant_type` to the `client_credentials`,
/// `authorization_code`, or `refresh_token` handler. Any other grant type yields
/// `unsupported_grant_type`.
///
/// The media-type test is ASCII case-insensitive, because media type and subtype
/// tokens are case-insensitive; `Application/JSON` is therefore also treated as JSON.
///
/// Errors (all 400 `invalid_request`):
/// - malformed JSON or JSON without `grant_type`;
/// - a form body that is not valid UTF-8;
/// - a form body that lacks `grant_type`.
pub async fn handle_oauth_token(
    State(state): State<AppState>,
    _rl: ReadRateLimit,
    headers: HeaderMap,
    body: axum::body::Bytes,
) -> Response {
    let is_json = headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|ct| {
            ct.trim_start()
                .to_ascii_lowercase()
                .starts_with("application/json")
        })
        .unwrap_or(false);

    let req = if is_json {
        match serde_json::from_slice::<TokenRequest>(&body) {
            Ok(r) => r,
            Err(_) => return invalid_request("Missing or malformed request parameters"),
        }
    } else {
        let body_str = match std::str::from_utf8(&body) {
            Ok(s) => s,
            Err(_) => return invalid_request("Request body is not valid UTF-8"),
        };
        match parse_form(body_str) {
            Some(r) => r,
            None => return invalid_request("grant_type is required"),
        }
    };

    match req.grant_type.as_str() {
        "client_credentials" => handle_client_credentials(state, req).await,
        "authorization_code" => handle_authorization_code(state, req).await,
        "refresh_token" => handle_refresh_token(state, req).await,
        _ => unsupported_grant_type(),
    }
}
/// Handles `grant_type=client_credentials`.
///
/// A non-empty `client_secret` is required (400 `invalid_request` otherwise). It
/// must be accepted by `check_auth_token`, or the handler returns 401
/// `invalid_client`.
///
/// On success the same secret is returned as the access token, with
/// `expires_in` 3600, no refresh token, and the request's `scope` echoed back.
/// `client_id` is used only for logging.
async fn handle_client_credentials(state: AppState, req: TokenRequest) -> Response {
    let client_id = req.client_id.as_deref().unwrap_or("<unset>");
    let secret = match req.client_secret.as_deref().filter(|s| !s.is_empty()) {
        Some(s) => s.to_owned(),
        None => return invalid_request("client_secret is required for client_credentials grant"),
    };

    if check_auth_token(&state, &secret).await.is_err() {
        tracing::warn!(client_id = %client_id, "OAuth client_credentials denied");
        return invalid_client();
    }

    tracing::info!(client_id = %client_id, "OAuth client_credentials grant issued");
    let token = TokenResponse {
        access_token: secret,
        token_type: "bearer",
        expires_in: 3600,
        refresh_token: None,
        scope: req.scope,
    };
    (StatusCode::OK, Json(token)).into_response()
}
/// Handles `grant_type=authorization_code`: redeems a code issued by
/// `handle_authorize` for an access token.
///
/// `code`, `client_id`, `redirect_uri`, and `code_verifier` are all required
/// (400 `invalid_request`). The code is removed from `state.auth_codes` before any
/// other check, so it is single-use even when the exchange then fails.
///
/// The request is then rejected with `invalid_grant` in any of these cases:
/// - the code has expired;
/// - `client_id` or `redirect_uri` differs from the authorization request;
/// - the code was bound to a `resource` that this request omits or changes;
/// - `code_verifier` fails S256 PKCE verification.
///
/// How the access token is produced depends on the client:
/// - A registered client with `token_endpoint_auth_method == "none"` receives a
///   fresh random token, stored in `state.access_tokens` for 1 hour.
/// - Any other client must send a `client_secret` accepted by `check_auth_token`
///   (401 `invalid_client` otherwise). That secret is itself returned as the
///   access token.
///
/// If the stored scope contains `offline_access`, a 30-day refresh token is also
/// issued. The response echoes the scope from the authorization request.
///
/// NOTE(review): unregistered `client_id`s are treated as confidential here. This
/// includes registrations pruned after 24 hours by `handle_register`.
async fn handle_authorization_code(state: AppState, req: TokenRequest) -> Response {
let code = match req.code.as_deref() {
Some(c) if !c.is_empty() => c.to_string(),
_ => return invalid_request("code is required"),
};
let client_id = match req.client_id.as_deref() {
Some(id) if !id.is_empty() => id.to_string(),
_ => return invalid_request("client_id is required"),
};
let redirect_uri = match req.redirect_uri.as_deref() {
Some(uri) if !uri.is_empty() => uri.to_string(),
_ => return invalid_request("redirect_uri is required"),
};
let code_verifier = match req.code_verifier.as_deref() {
Some(v) if !v.is_empty() => v.to_string(),
_ => return invalid_request("code_verifier is required (PKCE mandatory)"),
};
// Consume the code up front so it can never be redeemed twice, even if one of
// the checks below rejects this request.
let record = {
let mut codes = state
.auth_codes
.lock()
.unwrap_or_else(|e| e.into_inner());
match codes.remove(&code) {
Some(r) => r,
None => {
tracing::warn!(client_id = %client_id, "OAuth authorization_code not found");
return invalid_grant("Authorization code not found or already used");
}
}
};
let now = chrono_now();
// Lexicographic timestamp comparison; presumes `chrono_now()` matches the
// fixed-width UTC layout of `expires_at` — NOTE(review): confirm.
if record.expires_at.as_str() < now.as_str() {
tracing::warn!(client_id = %client_id, "OAuth authorization_code expired");
return invalid_grant("Authorization code has expired");
}
if record.client_id != client_id {
tracing::warn!(client_id = %client_id, "OAuth authorization_code client_id mismatch");
return invalid_grant("client_id does not match authorization request");
}
if record.redirect_uri != redirect_uri {
tracing::warn!(client_id = %client_id, "OAuth authorization_code redirect_uri mismatch");
return invalid_grant("redirect_uri does not match authorization request");
}
// If the authorization request named a resource, the token request must repeat
// it exactly. If it did not, any `resource` here is accepted.
if let Some(ref stored_resource) = record.resource {
match req.resource.as_deref() {
Some(req_resource) if req_resource == stored_resource => {}
Some(_) => {
tracing::warn!(client_id = %client_id, "OAuth resource parameter mismatch");
return invalid_grant("resource parameter does not match authorization request");
}
None => {
tracing::warn!(client_id = %client_id, "OAuth token request missing required resource parameter");
return invalid_grant("resource parameter is required for this authorization code");
}
}
}
if !verify_pkce_s256(&code_verifier, &record.code_challenge) {
tracing::warn!(client_id = %client_id, "OAuth PKCE verification failed");
return invalid_grant("code_verifier does not match code_challenge");
}
// Only registered clients with auth method "none" are public; unknown
// client_ids fall through to the client_secret branch.
let is_public_client = {
let clients = state
.oauth_clients
.lock()
.unwrap_or_else(|e| e.into_inner());
clients
.get(&client_id)
.map(|c| c.token_endpoint_auth_method == "none")
.unwrap_or(false)
};
let access_token = if is_public_client {
let at = new_uuid();
let at_expires = {
let future = chrono::Utc::now() + chrono::Duration::hours(1);
future.format("%Y-%m-%dT%H:%M:%SZ").to_string()
};
{
let mut tokens = state
.access_tokens
.lock()
.unwrap_or_else(|e| e.into_inner());
tokens.retain(|_, v| v.expires_at.as_str() >= now.as_str());
tokens.insert(
at.clone(),
AccessTokenRecord {
client_id: client_id.clone(),
scope: record.scope.clone(),
expires_at: at_expires,
},
);
}
at
} else {
// Confidential path: the validated client secret itself becomes the access token.
let client_secret = match req.client_secret.as_deref() {
Some(s) if !s.is_empty() => s.to_string(),
_ => return invalid_client(),
};
match check_auth_token(&state, &client_secret).await {
Ok(_) => client_secret,
Err(_) => {
tracing::warn!(client_id = %client_id, "OAuth authorization_code denied: invalid secret");
return invalid_client();
}
}
};
// A refresh token is issued only when the granted scope includes `offline_access`.
let wants_offline = record
.scope
.as_deref()
.map(|s| s.split_whitespace().any(|tok| tok == "offline_access"))
.unwrap_or(false);
let refresh_tok = if wants_offline {
let rt = new_uuid();
let rt_expires = {
let future = chrono::Utc::now() + chrono::Duration::days(30);
future.format("%Y-%m-%dT%H:%M:%SZ").to_string()
};
let rt_record = RefreshTokenRecord {
client_id: client_id.clone(),
access_token: access_token.clone(),
scope: record.scope.clone(),
expires_at: rt_expires,
};
let mut tokens = state
.refresh_tokens
.lock()
.unwrap_or_else(|e| e.into_inner());
tokens.retain(|_, v| v.expires_at.as_str() >= now.as_str());
tokens.insert(rt.clone(), rt_record);
Some(rt)
} else {
None
};
tracing::info!(client_id = %client_id, has_refresh = refresh_tok.is_some(), "OAuth authorization_code grant issued");
(
StatusCode::OK,
Json(TokenResponse {
access_token,
token_type: "bearer",
expires_in: 3600,
refresh_token: refresh_tok,
scope: record.scope,
}),
)
.into_response()
}
/// Handles `grant_type=refresh_token` with rotation.
///
/// `refresh_token` and `client_id` are required (400 `invalid_request`). The
/// presented refresh token is removed from `state.refresh_tokens` first, so it is
/// single-use. The request is rejected with `invalid_grant` if the token is
/// unknown, expired, or was issued to a different `client_id`.
///
/// The access token is produced exactly as in the authorization-code grant:
/// - registered `"none"` clients get a fresh 1-hour token;
/// - all other clients must present a `client_secret` accepted by
///   `check_auth_token`, which is returned as the access token.
///
/// A new 30-day refresh token carrying the original scope is always issued. The
/// request's `scope` parameter is not read. This handler does not touch the
/// access token previously issued alongside the consumed refresh token.
///
/// NOTE(review): refresh tokens live 30 days, but `handle_register` prunes client
/// registrations after 24 hours. After that, a public client is treated as
/// confidential here and fails with `invalid_client`. Because the token was
/// already removed, that failed attempt also consumes it. Confirm the intended
/// lifetimes.
async fn handle_refresh_token(state: AppState, req: TokenRequest) -> Response {
let rt_value = match req.refresh_token.as_deref() {
Some(t) if !t.is_empty() => t.to_string(),
_ => return invalid_request("refresh_token is required"),
};
let client_id = match req.client_id.as_deref() {
Some(id) if !id.is_empty() => id.to_string(),
_ => return invalid_request("client_id is required"),
};
// Rotation: remove the presented token before validating, so it can be used at most once.
let record = {
let mut tokens = state
.refresh_tokens
.lock()
.unwrap_or_else(|e| e.into_inner());
match tokens.remove(&rt_value) {
Some(r) => r,
None => {
tracing::warn!(client_id = %client_id, "OAuth refresh_token not found");
return invalid_grant("Refresh token not found or already used");
}
}
};
let now = chrono_now();
// Lexicographic timestamp comparison; presumes `chrono_now()` matches the
// fixed-width UTC layout of `expires_at` — NOTE(review): confirm.
if record.expires_at.as_str() < now.as_str() {
tracing::warn!(client_id = %client_id, "OAuth refresh_token expired");
return invalid_grant("Refresh token has expired");
}
if record.client_id != client_id {
tracing::warn!(client_id = %client_id, "OAuth refresh_token client_id mismatch");
return invalid_grant("client_id does not match refresh token");
}
// Only registered clients with auth method "none" are public; unknown
// client_ids fall through to the client_secret branch.
let is_public_client = {
let clients = state
.oauth_clients
.lock()
.unwrap_or_else(|e| e.into_inner());
clients
.get(&client_id)
.map(|c| c.token_endpoint_auth_method == "none")
.unwrap_or(false)
};
let access_token = if is_public_client {
let at = new_uuid();
let at_expires = {
let future = chrono::Utc::now() + chrono::Duration::hours(1);
future.format("%Y-%m-%dT%H:%M:%SZ").to_string()
};
{
let mut tokens = state
.access_tokens
.lock()
.unwrap_or_else(|e| e.into_inner());
tokens.retain(|_, v| v.expires_at.as_str() >= now.as_str());
tokens.insert(
at.clone(),
AccessTokenRecord {
client_id: client_id.clone(),
scope: record.scope.clone(),
expires_at: at_expires,
},
);
}
at
} else {
// Confidential path: the validated client secret itself becomes the access token.
let client_secret = match req.client_secret.as_deref() {
Some(s) if !s.is_empty() => s.to_string(),
_ => return invalid_client(),
};
match check_auth_token(&state, &client_secret).await {
Ok(_) => client_secret,
Err(_) => return invalid_client(),
}
};
let new_rt = new_uuid();
let new_rt_expires = {
let future = chrono::Utc::now() + chrono::Duration::days(30);
future.format("%Y-%m-%dT%H:%M:%SZ").to_string()
};
{
let mut tokens = state
.refresh_tokens
.lock()
.unwrap_or_else(|e| e.into_inner());
tokens.retain(|_, v| v.expires_at.as_str() >= now.as_str());
tokens.insert(
new_rt.clone(),
RefreshTokenRecord {
client_id: client_id.clone(),
access_token: access_token.clone(),
scope: record.scope.clone(),
expires_at: new_rt_expires,
},
);
}
tracing::info!(client_id = %client_id, "OAuth refresh_token grant issued (rotated)");
(
StatusCode::OK,
Json(TokenResponse {
access_token,
token_type: "bearer",
expires_in: 3600,
refresh_token: Some(new_rt),
scope: record.scope,
}),
)
.into_response()
}
/// Verifies a PKCE `S256` challenge.
///
/// Returns `true` when the unpadded base64url encoding of SHA-256 over
/// `code_verifier`'s bytes equals `stored_challenge`. The final comparison goes
/// through `constant_time_str_eq`.
fn verify_pkce_s256(code_verifier: &str, stored_challenge: &str) -> bool {
    use base64::Engine;
    use sha2::{Digest, Sha256};
    let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let expected = engine.encode(Sha256::digest(code_verifier.as_bytes()));
    constant_time_str_eq(expected.as_str(), stored_challenge)
}
/// Compares two strings for equality using `subtle`'s constant-time byte comparison.
///
/// Strings of different lengths short-circuit to `false`, so only their lengths
/// (not their contents) can be observed through timing.
fn constant_time_str_eq(a: &str, b: &str) -> bool {
    use subtle::ConstantTimeEq;
    a.len() == b.len() && bool::from(a.as_bytes().ct_eq(b.as_bytes()))
}
/// Returns the configured public base URL with every trailing `/` removed, so
/// callers can append paths such as `/mcp` without producing `//`.
fn base_url(state: &AppState) -> String {
    let configured: &str = &state.config.base_url;
    String::from(configured.trim_end_matches('/'))
}
/// Generates a random identifier formatted as a version-4 UUID.
///
/// The output is lowercase hex in `8-4-4-4-12` groups. The version nibble is
/// forced to `4` and the variant bits to `10`.
fn new_uuid() -> String {
    use rand::RngCore;
    let mut raw = [0u8; 16];
    rand::thread_rng().fill_bytes(&mut raw);
    raw[6] = (raw[6] & 0x0f) | 0x40; // version 4
    raw[8] = (raw[8] & 0x3f) | 0x80; // RFC 4122 variant
    let mut uuid = String::with_capacity(36);
    for (idx, byte) in raw.iter().enumerate() {
        // Group boundaries fall before bytes 4, 6, 8 and 10.
        if matches!(idx, 4 | 6 | 8 | 10) {
            uuid.push('-');
        }
        uuid.push_str(&format!("{byte:02x}"));
    }
    uuid
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const LOWER_HEX: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(LOWER_HEX[usize::from(byte >> 4)]));
        encoded.push(char::from(LOWER_HEX[usize::from(byte & 0x0f)]));
    }
    encoded
}
/// Percent-encodes `s` byte by byte over its UTF-8 representation.
///
/// Unreserved characters (`A-Z`, `a-z`, `0-9`, `-`, `.`, `_`, `~`) pass through
/// unchanged. Every other byte becomes `%XX` with uppercase hex digits.
fn percent_encode(s: &str) -> String {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for &b in s.as_bytes() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~');
        if unreserved {
            encoded.push(char::from(b));
        } else {
            encoded.push('%');
            encoded.push(char::from(UPPER_HEX[usize::from(b >> 4)]));
            encoded.push(char::from(UPPER_HEX[usize::from(b & 0x0f)]));
        }
    }
    encoded
}
/// Returns `true` if `uri` is an acceptable redirect target for a client whose
/// redirect URIs are not registered.
///
/// - Any `https://` URI is accepted.
/// - Plain `http://` is accepted only when the host is exactly `localhost` or
///   `127.0.0.1`, optionally followed by a numeric port.
///
/// The authority is parsed rather than prefix-matched. This rejects look-alike
/// hosts (`http://localhost.evil.example/`, `http://127.0.0.1.evil.example/`)
/// and userinfo tricks (`http://localhost@evil.example/`,
/// `http://localhost:80@evil.example/`), whose real host is not loopback.
fn is_safe_redirect_uri(uri: &str) -> bool {
    if uri.starts_with("https://") {
        return true;
    }
    let rest = match uri.strip_prefix("http://") {
        Some(rest) => rest,
        None => return false,
    };
    // The authority ends at the first path, query, or fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // With userinfo present, the actual host is whatever follows the `@`.
    if authority.contains('@') {
        return false;
    }
    let (host, port) = match authority.split_once(':') {
        Some((host, port)) => (host, port),
        None => (authority, ""),
    };
    matches!(host, "localhost" | "127.0.0.1") && port.bytes().all(|b| b.is_ascii_digit())
}
/// Decodes one `application/x-www-form-urlencoded` component.
///
/// - `+` becomes a space.
/// - `%XX` (either hex case) becomes the byte `0xXX`.
/// - A `%` not followed by two hex digits is kept literally.
///
/// The decoded bytes are then interpreted as UTF-8, so `%C3%A9` yields `é`.
/// Invalid sequences are replaced with U+FFFD rather than producing mojibake.
/// The previous implementation mapped every byte to a Latin-1 `char`, which
/// corrupted multi-byte characters.
fn url_decode(s: &str) -> String {
    let input = s.as_bytes();
    // Collect raw bytes first: a multi-byte UTF-8 character may be spread across
    // several `%XX` escapes and must be reassembled before becoming text.
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        match input[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' if i + 2 < input.len() => {
                match (from_hex(input[i + 1]), from_hex(input[i + 2])) {
                    (Some(high), Some(low)) => {
                        decoded.push((high << 4) | low);
                        i += 3;
                    }
                    _ => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Returns the value of a single ASCII hex digit (`0-9`, `a-f`, `A-F`), or
/// `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
/// Parses an `application/x-www-form-urlencoded` token request body into a
/// `TokenRequest`.
///
/// Pairs are split on `&` and then on the first `=`. Pairs without `=` are
/// ignored. Keys are decoded with `url_decode`; for recognised keys the value is
/// decoded the same way, and unrecognised keys are skipped. If a key repeats, its
/// last value wins. Returns `None` when no `grant_type` pair is present.
fn parse_form(body: &str) -> Option<TokenRequest> {
    let mut grant_type: Option<String> = None;
    let mut client_id: Option<String> = None;
    let mut client_secret: Option<String> = None;
    let mut code: Option<String> = None;
    let mut redirect_uri: Option<String> = None;
    let mut code_verifier: Option<String> = None;
    let mut resource: Option<String> = None;
    let mut refresh_token: Option<String> = None;
    let mut scope: Option<String> = None;

    for (raw_key, raw_value) in body.split('&').filter_map(|pair| pair.split_once('=')) {
        // Select the field this key populates.
        let slot = match url_decode(raw_key).as_str() {
            "grant_type" => &mut grant_type,
            "client_id" => &mut client_id,
            "client_secret" => &mut client_secret,
            "code" => &mut code,
            "redirect_uri" => &mut redirect_uri,
            "code_verifier" => &mut code_verifier,
            "resource" => &mut resource,
            "refresh_token" => &mut refresh_token,
            "scope" => &mut scope,
            _ => continue,
        };
        *slot = Some(url_decode(raw_value));
    }

    let grant_type = grant_type?;
    Some(TokenRequest {
        grant_type,
        client_id,
        client_secret,
        code,
        redirect_uri,
        code_verifier,
        resource,
        refresh_token,
        scope,
    })
}
/// JSON error body returned by the OAuth endpoints in this file,
/// serialized as `{"error": ..., "error_description": ...}`.
#[derive(Debug, Serialize)]
struct OAuthError {
/// Machine-readable error code, e.g. `invalid_request`, `invalid_grant`, `invalid_client`.
error: &'static str,
/// Human-readable explanation.
error_description: &'static str,
}
/// Builds a 401 `invalid_client` response. Used when the client's secret is
/// missing or rejected.
fn invalid_client() -> Response {
    let body = OAuthError {
        error: "invalid_client",
        error_description: "Invalid client credentials",
    };
    (StatusCode::UNAUTHORIZED, Json(body)).into_response()
}
/// Builds a 400 `invalid_request` response carrying `description`. Used for
/// missing or malformed request parameters.
fn invalid_request(description: &'static str) -> Response {
    let body = OAuthError {
        error: "invalid_request",
        error_description: description,
    };
    (StatusCode::BAD_REQUEST, Json(body)).into_response()
}
/// Builds a 400 `unsupported_grant_type` response that lists the grant types
/// this token endpoint accepts.
fn unsupported_grant_type() -> Response {
    let body = OAuthError {
        error: "unsupported_grant_type",
        error_description: "Supported: authorization_code, client_credentials, refresh_token",
    };
    (StatusCode::BAD_REQUEST, Json(body)).into_response()
}
/// Builds a 400 `invalid_grant` response carrying `description`. Used for
/// unknown, expired, mismatched, or PKCE-failing codes and refresh tokens.
fn invalid_grant(description: &'static str) -> Response {
    let body = OAuthError {
        error: "invalid_grant",
        error_description: description,
    };
    (StatusCode::BAD_REQUEST, Json(body)).into_response()
}