use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use axum::{
Form, Json,
body::Body,
extract::{Query, State},
http::{
HeaderValue, Request, StatusCode,
header::{AUTHORIZATION, WWW_AUTHENTICATE},
request::Parts,
},
middleware::Next,
response::{IntoResponse, Redirect, Response},
};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::{
app_state::{
AccessTokenRecord, AppState, AuthorizationCodeGrant, AuthorizedSession,
PendingAuthorizationRequest, RegisteredClient,
},
auth::oauth::{
OAuthCallbackFailure, OAuthCallbackQuery, build_gyazo_authorize_url,
exchange_code_for_token,
},
gyazo_api::fetch_authenticated_user,
};
/// Scope advertised in metadata documents and `WWW-Authenticate` challenges; also used as
/// the granted scope when a client does not request one (see `normalize_scope`).
const REQUIRED_SCOPE: &str = "gyazo";
/// JSON body served by `protected_resource_metadata_handler`.
///
/// Field names follow OAuth 2.0 Protected Resource Metadata (RFC 9728).
#[derive(Debug, Serialize)]
pub(crate) struct ProtectedResourceMetadata {
    /// The MCP endpoint URL (`RuntimeConfig::mcp_url`).
    resource: String,
    /// Issuer identifiers of the authorization servers for this resource.
    authorization_servers: Vec<String>,
    /// Scopes that clients may request.
    scopes_supported: Vec<String>,
}
/// JSON body served by `authorization_server_metadata_handler`.
///
/// Field names follow OAuth 2.0 Authorization Server Metadata (RFC 8414).
#[derive(Debug, Serialize)]
pub(crate) struct AuthorizationServerMetadata {
    issuer: String,
    authorization_endpoint: String,
    token_endpoint: String,
    /// Dynamic client registration endpoint (handled by `register_client_handler`).
    registration_endpoint: String,
    response_types_supported: Vec<String>,
    grant_types_supported: Vec<String>,
    token_endpoint_auth_methods_supported: Vec<String>,
    code_challenge_methods_supported: Vec<String>,
}
/// Query parameters accepted by `authorize_handler`.
///
/// Every field is optional at the deserialization layer so that missing values produce the
/// descriptive errors from `validate_authorization_request` rather than a generic extractor
/// rejection.
#[derive(Debug, Deserialize)]
pub(crate) struct AuthorizationRequestQuery {
    /// Must be `code`.
    response_type: Option<String>,
    /// Must refer to a client registered through `register_client`.
    client_id: Option<String>,
    /// Must exactly match one of the client's registered redirect URIs.
    redirect_uri: Option<String>,
    /// Opaque client state, echoed back on the final redirect.
    state: Option<String>,
    /// PKCE challenge (required).
    code_challenge: Option<String>,
    /// PKCE method; only `S256` is accepted.
    code_challenge_method: Option<String>,
    /// Requested scope; `REQUIRED_SCOPE` is used when absent or blank.
    scope: Option<String>,
    /// RFC 8707 resource indicator; must equal the MCP URL when present.
    resource: Option<String>,
}
/// Form fields accepted by `token_handler` (authorization-code grant only).
///
/// All fields are optional here; `exchange_authorization_code` reports which required
/// field is missing.
#[derive(Debug, Deserialize)]
pub(crate) struct TokenRequestForm {
    /// Must be `authorization_code`.
    grant_type: Option<String>,
    /// Authorization code previously issued by this server.
    code: Option<String>,
    /// Must match the redirect URI bound to the code.
    redirect_uri: Option<String>,
    /// Must match the client the code was issued to.
    client_id: Option<String>,
    /// PKCE verifier checked against the stored S256 challenge.
    code_verifier: Option<String>,
    /// Optional RFC 8707 resource indicator.
    resource: Option<String>,
}
/// Successful token endpoint response body.
#[derive(Debug, Serialize)]
pub(crate) struct TokenResponse {
    /// Access token issued by this server (not the backend Gyazo token).
    access_token: String,
    /// Always `Bearer`.
    token_type: String,
    /// Scope recorded on the authorization code grant.
    scope: String,
}
/// Dynamic client registration request body (RFC 7591 client metadata subset).
///
/// Only public clients using the authorization-code grant are supported; see
/// `register_client` for the validation rules.
#[derive(Debug, Deserialize)]
pub(crate) struct DynamicClientRegistrationRequest {
    /// Required, non-empty list of redirect URIs.
    redirect_uris: Option<Vec<String>>,
    /// Optional human-readable client name; blank names are dropped.
    client_name: Option<String>,
    /// When present, must include `authorization_code`.
    grant_types: Option<Vec<String>>,
    /// When present, must include `code`.
    response_types: Option<Vec<String>>,
    /// When present, must be `none`.
    token_endpoint_auth_method: Option<String>,
}
#[derive(Debug, Serialize)]
pub(crate) struct DynamicClientRegistrationResponse {
client_id: String,
redirect_uris: Vec<String>,
client_name: Option<String>,
grant_types: Vec<String>,
response_types: Vec<String>,
token_endpoint_auth_method: String,
}
/// Axum middleware guarding the MCP endpoint.
///
/// Resolution order:
/// 1. A `Bearer` token in `Authorization` that maps to a stored session is accepted, and the
///    session is inserted into the request extensions.
/// 2. When the request carries no `Authorization` header at all, the server-side verified
///    session from [`get_verified_session`] is used as a fallback.
/// 3. Otherwise a `401` challenge is returned. An `Authorization` header that is present but
///    unusable never falls back, so an invalid client token cannot ride on the server's
///    session.
///
/// RFC 6750 §3.1 says a request that presented no credentials should get a challenge without
/// an `error` attribute, so that case passes `None`. `invalid_token` is reserved for
/// credentials that were presented and rejected.
pub(crate) async fn require_mcp_bearer_token(
    State(app_state): State<Arc<AppState>>,
    mut request: Request<Body>,
    next: Next,
) -> Response {
    let state = app_state.as_ref();
    let has_authorization_header = request.headers().contains_key(AUTHORIZATION);
    match authorized_session_from_request(state, &request) {
        Ok(Some(session)) => {
            request.extensions_mut().insert(session);
            next.run(request).await
        }
        Ok(None) if !has_authorization_header => match get_verified_session(state).await {
            Some(session) => {
                request.extensions_mut().insert(session);
                next.run(request).await
            }
            // No credentials were presented, so the challenge carries no error code.
            None => unauthorized_response(state, None),
        },
        Ok(None) => unauthorized_response(state, Some("invalid_token")),
        Err(_) => unauthorized_response(state, Some("server_error")),
    }
}
/// Serves the protected-resource metadata document.
///
/// Responds `404` when a server-side verified session already exists. Presumably this makes
/// MCP clients skip OAuth discovery and rely on the middleware fallback; confirm with callers.
pub(crate) async fn protected_resource_metadata_handler(
    State(app_state): State<Arc<AppState>>,
) -> Response {
    let state = app_state.as_ref();
    match get_verified_session(state).await {
        Some(_) => StatusCode::NOT_FOUND.into_response(),
        None => Json(build_protected_resource_metadata(state)).into_response(),
    }
}
/// Serves the authorization-server metadata document.
///
/// Like the protected-resource metadata handler, responds `404` while a server-side verified
/// session exists.
pub(crate) async fn authorization_server_metadata_handler(
    State(app_state): State<Arc<AppState>>,
) -> Response {
    let state = app_state.as_ref();
    match get_verified_session(state).await {
        Some(_) => StatusCode::NOT_FOUND.into_response(),
        None => Json(build_authorization_server_metadata(state)).into_response(),
    }
}
/// Authorization endpoint.
///
/// Redirects (`307`) either straight back to the client with a code, or to Gyazo OAuth.
/// Validation failures become `400` with the error message as a plain-text body.
pub(crate) async fn authorize_handler(
    State(app_state): State<Arc<AppState>>,
    Query(query): Query<AuthorizationRequestQuery>,
) -> impl IntoResponse {
    let outcome = start_authorization(app_state.as_ref(), query);
    match outcome {
        Err(error) => (StatusCode::BAD_REQUEST, error.to_string()).into_response(),
        Ok(AuthorizationStart::Redirect(location)) => {
            Redirect::temporary(&location).into_response()
        }
    }
}
pub(crate) async fn token_handler(
State(app_state): State<Arc<AppState>>,
Form(form): Form<TokenRequestForm>,
) -> impl IntoResponse {
match exchange_authorization_code(app_state.as_ref(), form).await {
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error.to_string()).into_response(),
}
}
pub(crate) async fn register_client_handler(
State(app_state): State<Arc<AppState>>,
Json(request): Json<DynamicClientRegistrationRequest>,
) -> impl IntoResponse {
match register_client(app_state.as_ref(), request) {
Ok(response) => (StatusCode::CREATED, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error.to_string()).into_response(),
}
}
/// Completes an MCP authorization when a Gyazo OAuth callback belongs to one.
///
/// Returns `Ok(None)` when the callback has no `state`, or when that `state` does not identify
/// a pending MCP authorization request. The caller should then handle the callback itself.
/// Otherwise the pending request is taken, the Gyazo code is exchanged for a backend token, an
/// MCP authorization code is issued, and a redirect back to the MCP client is returned.
///
/// # Errors
/// - `bad_request` when the pending request can no longer be taken, when Gyazo reported an
///   error, or when the callback lacks a code.
/// - `bad_gateway` when exchanging the Gyazo code fails.
/// - `internal` for pending-request or code-issuing storage failures.
pub(crate) async fn maybe_complete_mcp_authorization(
    app_state: &AppState,
    query: &OAuthCallbackQuery,
) -> Result<Option<Response>, OAuthCallbackFailure> {
    let Some(gyazo_state) = query.state.as_deref() else {
        return Ok(None);
    };
    let is_mcp_callback = app_state
        .has_pending_authorization(gyazo_state)
        .map_err(|error| OAuthCallbackFailure::internal(error.to_string()))?;
    if !is_mcp_callback {
        return Ok(None);
    }
    // The pending request is taken before the callback is inspected, so an error callback
    // also uses it up.
    let Some(pending) = app_state
        .take_pending_authorization(gyazo_state)
        .map_err(|error| OAuthCallbackFailure::internal(error.to_string()))?
    else {
        return Err(OAuthCallbackFailure::bad_request(
            "保留中の MCP authorization request が見つかりません",
        ));
    };
    if let Some(error) = query.error.as_deref() {
        let detail = match query.error_description.as_deref() {
            Some(description) if !description.is_empty() => format!(": {description}"),
            _ => String::new(),
        };
        return Err(OAuthCallbackFailure::bad_request(format!(
            "Gyazo OAuth がエラーを返しました ({error}{detail})"
        )));
    }
    let Some(gyazo_code) = query.code.as_deref() else {
        return Err(OAuthCallbackFailure::bad_request(
            "callback に Gyazo authorization code が含まれていません",
        ));
    };
    let token = exchange_code_for_token(app_state, gyazo_code)
        .await
        .map_err(|error| OAuthCallbackFailure::bad_gateway(error.to_string()))?;
    // `pending` is moved into the grant below, so keep what the final redirect needs.
    let client_redirect_uri = pending.redirect_uri.clone();
    let client_state = pending.state.clone();
    let mcp_code = issue_authorization_code(app_state, pending, token.access_token)
        .map_err(|error| OAuthCallbackFailure::internal(error.to_string()))?;
    let location =
        build_client_redirect_url(&client_redirect_uri, &mcp_code, client_state.as_deref());
    Ok(Some(Redirect::temporary(&location).into_response()))
}
/// Validates an authorization request and decides where to send the user agent.
///
/// If the server already holds a backend Gyazo credential, a code is issued immediately and
/// the redirect goes straight back to the client. Otherwise the request is parked under a
/// fresh random `state` and the redirect goes to Gyazo's authorize URL.
fn start_authorization(
    app_state: &AppState,
    query: AuthorizationRequestQuery,
) -> Result<AuthorizationStart> {
    let pending = validate_authorization_request(app_state, query)?;
    if !app_state.has_backend_api_credential()? {
        // No stored credential: delegate to Gyazo OAuth and resume in the callback.
        let gyazo_state = uuid::Uuid::new_v4().to_string();
        app_state.insert_pending_authorization(gyazo_state.clone(), pending)?;
        let gyazo_url = build_gyazo_authorize_url(app_state, &gyazo_state)?;
        return Ok(AuthorizationStart::Redirect(gyazo_url));
    }
    let backend_access_token = app_state
        .resolve_backend_access_token()?
        .ok_or_else(|| anyhow!("Gyazo backend access token が見つかりません"))?;
    let client_redirect_uri = pending.redirect_uri.clone();
    let client_state = pending.state.clone();
    let code = issue_authorization_code(app_state, pending, backend_access_token)?;
    Ok(AuthorizationStart::Redirect(build_client_redirect_url(
        &client_redirect_uri,
        &code,
        client_state.as_deref(),
    )))
}
fn validate_authorization_request(
app_state: &AppState,
query: AuthorizationRequestQuery,
) -> Result<PendingAuthorizationRequest> {
let response_type = query
.response_type
.as_deref()
.ok_or_else(|| anyhow!("response_type が必要です"))?;
if response_type != "code" {
bail!("response_type には code のみ指定できます");
}
let client_id = query
.client_id
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| anyhow!("client_id が必要です"))?;
let redirect_uri = query
.redirect_uri
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| anyhow!("redirect_uri が必要です"))?;
let registered_client = app_state
.registered_client(&client_id)?
.ok_or_else(|| anyhow!("client_id が登録されていません"))?;
if !registered_client
.redirect_uris
.iter()
.any(|uri| uri == &redirect_uri)
{
bail!("redirect_uri が登録内容と一致しません");
}
let code_challenge = query
.code_challenge
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| anyhow!("code_challenge が必要です"))?;
let code_challenge_method = query
.code_challenge_method
.unwrap_or_else(|| "plain".to_string());
if code_challenge_method != "S256" {
bail!("code_challenge_method には S256 を指定してください");
}
if let Some(resource) = query.resource.as_deref()
&& resource != app_state.runtime_config().mcp_url()
{
bail!(
"resource には {} を指定してください",
app_state.runtime_config().mcp_url()
);
}
Ok(PendingAuthorizationRequest {
client_id,
redirect_uri,
state: query.state,
code_challenge,
resource: query.resource,
requested_scope: query.scope,
})
}
fn issue_authorization_code(
app_state: &AppState,
pending: PendingAuthorizationRequest,
backend_access_token: String,
) -> Result<String> {
let grant = AuthorizationCodeGrant {
client_id: pending.client_id,
redirect_uri: pending.redirect_uri,
code_challenge: pending.code_challenge,
resource: pending.resource,
scope: normalize_scope(pending.requested_scope.as_deref()),
backend_access_token,
};
app_state.issue_authorization_code(grant)
}
async fn exchange_authorization_code(
app_state: &AppState,
form: TokenRequestForm,
) -> Result<TokenResponse> {
let grant_type = form
.grant_type
.as_deref()
.ok_or_else(|| anyhow!("grant_type が必要です"))?;
if grant_type != "authorization_code" {
bail!("grant_type には authorization_code のみ指定できます");
}
let code = form
.code
.as_deref()
.ok_or_else(|| anyhow!("code が必要です"))?;
let client_id = form
.client_id
.as_deref()
.ok_or_else(|| anyhow!("client_id が必要です"))?;
let registered_client = app_state
.registered_client(client_id)?
.ok_or_else(|| anyhow!("client_id が登録されていません"))?;
let redirect_uri = form
.redirect_uri
.as_deref()
.ok_or_else(|| anyhow!("redirect_uri が必要です"))?;
if !registered_client
.redirect_uris
.iter()
.any(|registered| registered == redirect_uri)
{
bail!("redirect_uri が登録内容と一致しません");
}
let code_verifier = form
.code_verifier
.as_deref()
.ok_or_else(|| anyhow!("code_verifier が必要です"))?;
let grant = app_state
.take_authorization_code(code)?
.ok_or_else(|| anyhow!("authorization code が見つからないか、すでに使用されています"))?;
if grant.client_id != client_id {
bail!("client_id が一致しません");
}
if grant.redirect_uri != redirect_uri {
bail!("redirect_uri が一致しません");
}
if let Some(resource) = form.resource.as_deref()
&& Some(resource) != grant.resource.as_deref()
&& resource != app_state.runtime_config().mcp_url()
{
bail!("resource が一致しません");
}
verify_pkce(code_verifier, &grant.code_challenge)?;
let gyazo_user = fetch_authenticated_user(&grant.backend_access_token).await?;
let access_token = app_state.issue_access_token(AccessTokenRecord {
backend_access_token: grant.backend_access_token,
gyazo_user,
})?;
Ok(TokenResponse {
access_token,
token_type: "Bearer".to_string(),
scope: grant.scope,
})
}
/// Builds a `401 Unauthorized` response with a Bearer `WWW-Authenticate` challenge.
///
/// The challenge advertises the protected-resource metadata URL and `REQUIRED_SCOPE`, plus an
/// `error` attribute when `error` is `Some`. If the assembled value is not a valid header
/// value, the header is silently omitted.
fn unauthorized_response(app_state: &AppState, error: Option<&str>) -> Response {
    let metadata_url = app_state.runtime_config().protected_resource_metadata_url();
    let error_param = error
        .map(|code| format!(r#", error="{code}""#))
        .unwrap_or_default();
    let challenge = format!(
        r#"Bearer resource_metadata="{metadata_url}", scope="{REQUIRED_SCOPE}"{error_param}"#
    );
    let mut response = (
        StatusCode::UNAUTHORIZED,
        "/mcp には Bearer token が必要です。先にこのサーバーに対して MCP login を実行してください。",
    )
        .into_response();
    if let Ok(value) = HeaderValue::from_str(&challenge) {
        response.headers_mut().insert(WWW_AUTHENTICATE, value);
    }
    response
}
/// Looks up the session for the request's Bearer token.
///
/// Returns `Ok(None)` when there is no usable Bearer token, or when the token is unknown to
/// `AppState::authorized_session`.
pub(crate) fn authorized_session_from_request(
    app_state: &AppState,
    request: &Request<Body>,
) -> Result<Option<AuthorizedSession>> {
    match extract_bearer_token(request.headers()) {
        Some(token) => app_state.authorized_session(token),
        None => Ok(None),
    }
}
/// Same as `authorized_session_from_request`, but for already-split request [`Parts`].
pub(crate) fn authorized_session_from_parts(
    app_state: &AppState,
    parts: &Parts,
) -> Result<Option<AuthorizedSession>> {
    extract_bearer_token(&parts.headers)
        .map_or(Ok(None), |token| app_state.authorized_session(token))
}
pub(crate) async fn get_verified_session(app_state: &AppState) -> Option<AuthorizedSession> {
{
let cache = app_state
.verified_session_cache()
.read()
.expect("verified session cache lock is poisoned");
if let Some((checked_at, ref cached)) = *cache
&& checked_at.elapsed() < app_state.verified_session_ttl()
{
return cached.clone();
}
}
let tokens = app_state.collect_backend_access_tokens().ok()?;
let tokens_was_empty = tokens.is_empty();
let mut session = None;
for token in tokens {
match fetch_authenticated_user(&token).await {
Ok(gyazo_user) => {
tracing::debug!("トークン疎通確認成功");
session = Some(AuthorizedSession {
record: AccessTokenRecord {
backend_access_token: token,
gyazo_user,
},
});
break;
}
Err(e) => {
tracing::debug!("トークン疎通確認スキップ: {e}");
}
}
}
if session.is_none() && !tokens_was_empty {
tracing::warn!("有効なトークンが見つかりませんでした");
}
let mut cache = app_state
.verified_session_cache()
.write()
.expect("verified session cache lock is poisoned");
*cache = Some((std::time::Instant::now(), session.clone()));
session
}
/// Returns the token from an `Authorization: Bearer <token>` header, if there is one.
///
/// The auth-scheme is matched case-insensitively (RFC 9110 §11.1 defines it as a
/// case-insensitive token), so `bearer abc` is accepted as well as `Bearer abc`. Whitespace
/// around the token is ignored. `None` is returned when the header is missing, is not visible
/// ASCII, uses another scheme, or has an empty token.
fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
    let value = headers.get(AUTHORIZATION)?.to_str().ok()?;
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}
/// Assembles the protected-resource metadata from the runtime configuration.
///
/// This server names itself as the single authorization server.
fn build_protected_resource_metadata(app_state: &AppState) -> ProtectedResourceMetadata {
    let config = app_state.runtime_config();
    let resource = config.mcp_url();
    let issuer = config.authorization_server_issuer();
    ProtectedResourceMetadata {
        resource,
        authorization_servers: vec![issuer],
        scopes_supported: vec![REQUIRED_SCOPE.to_owned()],
    }
}
/// Assembles the authorization-server metadata from the runtime configuration.
///
/// It advertises the only supported profile: the authorization-code grant for public clients
/// with S256 PKCE.
fn build_authorization_server_metadata(app_state: &AppState) -> AuthorizationServerMetadata {
    let config = app_state.runtime_config();
    let single = |value: &str| vec![value.to_owned()];
    AuthorizationServerMetadata {
        issuer: config.authorization_server_issuer(),
        authorization_endpoint: config.authorization_endpoint_url(),
        token_endpoint: config.token_endpoint_url(),
        registration_endpoint: config.registration_endpoint_url(),
        response_types_supported: single("code"),
        grant_types_supported: single("authorization_code"),
        // Registration only accepts `none`, so no client secret is ever issued.
        token_endpoint_auth_methods_supported: single("none"),
        code_challenge_methods_supported: single("S256"),
    }
}
/// Registers a public OAuth client (RFC 7591 dynamic client registration).
///
/// Requirements:
/// - `redirect_uris` must be a non-empty list with no blank entries and no entries containing a
///   fragment. RFC 6749 §3.1.2 forbids fragments, and a code appended after `#` would never
///   reach the client's server.
/// - `token_endpoint_auth_method`, if given, must be `none`.
/// - `grant_types` / `response_types`, if given, must include `authorization_code` / `code`.
///
/// The response always reports the single supported configuration, regardless of any extra
/// values the client asked for.
fn register_client(
    app_state: &AppState,
    request: DynamicClientRegistrationRequest,
) -> Result<DynamicClientRegistrationResponse> {
    let redirect_uris = request
        .redirect_uris
        .filter(|uris| !uris.is_empty())
        .ok_or_else(|| anyhow!("redirect_uris が必要です"))?;
    if redirect_uris.iter().any(|uri| uri.trim().is_empty()) {
        bail!("redirect_uris に空の URI は指定できません");
    }
    if redirect_uris.iter().any(|uri| uri.contains('#')) {
        bail!("redirect_uris にフラグメント (#) を含む URI は指定できません");
    }
    if let Some(method) = request.token_endpoint_auth_method.as_deref()
        && method != "none"
    {
        bail!("token_endpoint_auth_method には none のみ指定できます");
    }
    if let Some(grant_types) = request.grant_types.as_ref()
        && !grant_types
            .iter()
            .any(|grant| grant == "authorization_code")
    {
        bail!("grant_types には authorization_code が必要です");
    }
    if let Some(response_types) = request.response_types.as_ref()
        && !response_types.iter().any(|response| response == "code")
    {
        bail!("response_types には code が必要です");
    }
    let client_name = request.client_name.filter(|name| !name.trim().is_empty());
    let client_id = app_state.register_client(RegisteredClient {
        redirect_uris: redirect_uris.clone(),
    })?;
    Ok(DynamicClientRegistrationResponse {
        client_id,
        redirect_uris,
        client_name,
        grant_types: vec!["authorization_code".to_string()],
        response_types: vec!["code".to_string()],
        token_endpoint_auth_method: "none".to_string(),
    })
}
/// Builds the redirect back to the client, carrying `code` and, when the client sent one,
/// `state`.
///
/// The parameters are appended to any existing query string. If `base` has a `#fragment`, it
/// is moved to the end so the new parameters land in the query component and not inside the
/// fragment, which is never sent to the server. Values are percent-encoded.
fn build_client_redirect_url(base: &str, code: &str, state: Option<&str>) -> String {
    let (target, fragment) = match base.split_once('#') {
        Some((target, fragment)) => (target, Some(fragment)),
        None => (base, None),
    };
    let separator = if target.contains('?') { '&' } else { '?' };
    let mut url = format!("{target}{separator}code={}", percent_encode(code));
    if let Some(state) = state {
        url.push_str("&state=");
        url.push_str(&percent_encode(state));
    }
    if let Some(fragment) = fragment {
        url.push('#');
        url.push_str(fragment);
    }
    url
}
/// Verifies an RFC 7636 `S256` PKCE pair.
///
/// # Errors
/// Fails when `BASE64URL(SHA-256(code_verifier))` (unpadded) differs from `code_challenge`.
fn verify_pkce(code_verifier: &str, code_challenge: &str) -> Result<()> {
    let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
    if expected == code_challenge {
        Ok(())
    } else {
        bail!("code_verifier が一致しません")
    }
}
/// Returns the requested scope verbatim, or `REQUIRED_SCOPE` when it is absent or blank.
fn normalize_scope(requested_scope: Option<&str>) -> String {
    requested_scope
        .filter(|scope| !scope.trim().is_empty())
        .unwrap_or(REQUIRED_SCOPE)
        .to_owned()
}
/// Percent-encodes `value` byte by byte.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) are kept as-is; every other byte,
/// including each byte of a multi-byte UTF-8 sequence, becomes `%XX` with uppercase hex.
fn percent_encode(value: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut out, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            out
        })
}
/// Outcome of `start_authorization`.
enum AuthorizationStart {
    /// Location to redirect the user agent to: either the client's redirect URI carrying a
    /// freshly issued code, or Gyazo's authorize URL.
    Redirect(String),
}
#[cfg(test)]
mod tests {
    // Unit tests for redirect building, PKCE verification, bearer-token extraction, and the
    // fallback rules of the MCP bearer middleware.
    use axum::http::HeaderMap;
    use axum::http::header::AUTHORIZATION;
    use super::{build_client_redirect_url, extract_bearer_token, verify_pkce};
    #[test]
    fn appends_code_and_state_to_redirect_uri() {
        let url = build_client_redirect_url(
            "http://127.0.0.1:3000/callback",
            "code-123",
            Some("state-456"),
        );
        assert_eq!(
            url,
            "http://127.0.0.1:3000/callback?code=code-123&state=state-456"
        );
    }
    // Verifier/challenge pair from the RFC 7636 Appendix B example.
    #[test]
    fn accepts_matching_s256_pkce() {
        verify_pkce(
            "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
        )
        .unwrap();
    }
    #[test]
    fn extract_bearer_token_returns_none_when_no_authorization_header() {
        let headers = HeaderMap::new();
        assert!(extract_bearer_token(&headers).is_none());
    }
    #[test]
    fn extract_bearer_token_returns_token_for_valid_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, "Bearer my-token".parse().unwrap());
        assert_eq!(extract_bearer_token(&headers), Some("my-token"));
    }
    #[test]
    fn extract_bearer_token_returns_none_for_empty_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, "Bearer ".parse().unwrap());
        assert!(extract_bearer_token(&headers).is_none());
    }
    // Documents the header-presence check the middleware uses to decide whether the
    // server-side session fallback is allowed.
    #[test]
    fn has_authorization_header_distinguishes_present_vs_absent() {
        let empty = HeaderMap::new();
        assert!(
            !empty.contains_key(AUTHORIZATION),
            "Authorization なし → fallback 許可"
        );
        let mut with_invalid = HeaderMap::new();
        with_invalid.insert(AUTHORIZATION, "Bearer invalid-xyz".parse().unwrap());
        assert!(
            with_invalid.contains_key(AUTHORIZATION),
            "無効な Bearer あり → fallback 禁止"
        );
    }
    // Middleware integration tests: a one-route router wrapped in `require_mcp_bearer_token`.
    use std::sync::Arc;
    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::middleware;
    use axum::routing::get;
    use tower::util::ServiceExt;
    use crate::app_state::AppState;
    use crate::runtime_config::RuntimeConfig;
    fn test_app_state() -> Arc<AppState> {
        Arc::new(AppState::new_for_test(RuntimeConfig::for_test()))
    }
    fn test_router(app_state: Arc<AppState>) -> Router {
        async fn ok_handler() -> &'static str {
            "ok"
        }
        Router::new()
            .route("/test", get(ok_handler))
            .route_layer(middleware::from_fn_with_state(
                app_state.clone(),
                super::require_mcp_bearer_token,
            ))
            .with_state(app_state)
    }
    // Presumably the test AppState holds no backend tokens, so the fallback yields no session.
    #[tokio::test]
    async fn middleware_rejects_request_without_bearer_when_no_saved_token() {
        let app = test_router(test_app_state());
        let response = app
            .oneshot(Request::builder().uri("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
    #[tokio::test]
    async fn middleware_rejects_invalid_bearer_even_with_no_saved_token() {
        let app = test_router(test_app_state());
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/test")
                    .header("Authorization", "Bearer invalid-token-xyz")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
    // Seeds a fresh verified-session cache entry, then sends an unknown Bearer token: the
    // middleware must reject it rather than fall back to the cached session.
    #[tokio::test]
    async fn middleware_does_not_fallback_for_invalid_bearer() {
        let app_state = test_app_state();
        {
            let mut cache = app_state.verified_session_cache().write().unwrap();
            *cache = Some((
                std::time::Instant::now(),
                Some(crate::app_state::AuthorizedSession {
                    record: crate::app_state::AccessTokenRecord {
                        backend_access_token: "fake-backend-token".to_string(),
                        gyazo_user: crate::gyazo_api::GyazoUserProfile {
                            email: String::new(),
                            name: String::new(),
                            profile_image: String::new(),
                            uid: String::new(),
                        },
                    },
                }),
            ));
        }
        let app = test_router(app_state);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/test")
                    .header("Authorization", "Bearer invalid-token-xyz")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
    // Same seeded cache, but with no Authorization header: the cached session is used and the
    // request reaches the handler.
    #[tokio::test]
    async fn middleware_falls_back_to_verified_session_without_authorization_header() {
        let app_state = test_app_state();
        {
            let mut cache = app_state.verified_session_cache().write().unwrap();
            *cache = Some((
                std::time::Instant::now(),
                Some(crate::app_state::AuthorizedSession {
                    record: crate::app_state::AccessTokenRecord {
                        backend_access_token: "fake-backend-token".to_string(),
                        gyazo_user: crate::gyazo_api::GyazoUserProfile {
                            email: String::new(),
                            name: String::new(),
                            profile_image: String::new(),
                            uid: String::new(),
                        },
                    },
                }),
            ));
        }
        let app = test_router(app_state);
        let response = app
            .oneshot(Request::builder().uri("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}