mod token_store;
pub use token_store::{ChainId, MintedPair, RefreshError, TokenStore};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::{
extract::{Query, State},
http::{header, HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, post},
Form, Json, Router,
};
use base64::Engine;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
/// Lifetime of an authorization code issued by `/authorize`, in seconds (5 minutes).
const AUTH_CODE_TTL_SECS: u64 = 5 * 60;
/// Redirect URI prefixes `/authorize` is willing to send codes to.
const ALLOWED_REDIRECT_URI_PREFIXES: &[&str] = &[
    "https://claude.ai/api/mcp/",
    "https://claude.com/api/mcp/",
];
/// Access-token lifetime used when `access_token_ttl_secs` is unset (7 days).
pub const DEFAULT_ACCESS_TOKEN_TTL_SECS: u64 = 60 * 60 * 24 * 7;
/// Refresh-token lifetime used when `refresh_token_ttl_secs` is unset (90 days).
pub const DEFAULT_REFRESH_TOKEN_TTL_SECS: u64 = 60 * 60 * 24 * 90;
/// Persisted OAuth client configuration, stored as TOML (`oauth.toml`).
///
/// `Debug` is implemented by hand so that `client_secret` is never written to
/// logs or panic messages via `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// Client identifier the connector must present.
    pub client_id: String,
    /// Shared secret checked at the token endpoint for confidential clients.
    pub client_secret: String,
    /// Public base URL; discovery documents and endpoint URLs are built from it.
    pub issuer: String,
    /// Optional access-token lifetime in seconds; `None` selects the default.
    #[serde(default)]
    pub access_token_ttl_secs: Option<u64>,
    /// Optional refresh-token lifetime in seconds; `None` selects the default.
    #[serde(default)]
    pub refresh_token_ttl_secs: Option<u64>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            // Never expose the shared secret through Debug output.
            .field("client_secret", &"<redacted>")
            .field("issuer", &self.issuer)
            .field("access_token_ttl_secs", &self.access_token_ttl_secs)
            .field("refresh_token_ttl_secs", &self.refresh_token_ttl_secs)
            .finish()
    }
}
impl OAuthConfig {
    /// Access-token lifetime: the configured `access_token_ttl_secs`, or
    /// `DEFAULT_ACCESS_TOKEN_TTL_SECS` when it is not set.
    pub fn effective_access_ttl(&self) -> std::time::Duration {
        let secs = match self.access_token_ttl_secs {
            Some(configured) => configured,
            None => DEFAULT_ACCESS_TOKEN_TTL_SECS,
        };
        std::time::Duration::from_secs(secs)
    }

    /// Refresh-token lifetime: the configured `refresh_token_ttl_secs`, or
    /// `DEFAULT_REFRESH_TOKEN_TTL_SECS` when it is not set.
    pub fn effective_refresh_ttl(&self) -> std::time::Duration {
        let secs = match self.refresh_token_ttl_secs {
            Some(configured) => configured,
            None => DEFAULT_REFRESH_TOKEN_TTL_SECS,
        };
        std::time::Duration::from_secs(secs)
    }
}
/// Location of `oauth.toml` inside the platform config directory for
/// `dev.things-mcp.things-mcp`, or `None` if that directory cannot be resolved.
pub fn config_path() -> Option<PathBuf> {
    let project = directories::ProjectDirs::from("dev", "things-mcp", "things-mcp")?;
    Some(project.config_dir().join("oauth.toml"))
}
impl OAuthConfig {
    /// Loads the OAuth configuration from [`config_path`], or generates and
    /// persists fresh credentials when no file exists yet.
    ///
    /// Returns `Ok(None)` (OAuth disabled) when the config directory cannot be
    /// resolved, or when no file exists and `issuer_hint` is `None`.
    ///
    /// # Errors
    /// Fails if an existing file cannot be read, is not UTF-8, or is not valid
    /// TOML for this struct, or if newly generated credentials cannot be written.
    pub fn load_or_generate(issuer_hint: Option<String>) -> anyhow::Result<Option<Self>> {
        let Some(path) = config_path() else {
            tracing::warn!("could not resolve ProjectDirs for OAuth config; OAuth disabled");
            return Ok(None);
        };
        if path.exists() {
            // `read_to_string` also rejects non-UTF-8 content, and the error now
            // carries the file path.
            let text = std::fs::read_to_string(&path)
                .map_err(|e| anyhow::anyhow!("read {}: {e}", path.display()))?;
            let config: OAuthConfig = toml::from_str(&text)
                .map_err(|e| anyhow::anyhow!("parse {}: {e}", path.display()))?;
            tracing::info!(path = %path.display(), "loaded OAuth config");
            return Ok(Some(config));
        }
        let Some(issuer) = issuer_hint else {
            tracing::warn!(
                path = %path.display(),
                "OAuth config not found and no THINGS_MCP_OAUTH_ISSUER set; OAuth disabled"
            );
            return Ok(None);
        };
        let config = OAuthConfig {
            client_id: format!("things-mcp-{}", short_id()),
            client_secret: format!("{:032x}", rand::random::<u128>()),
            issuer,
            access_token_ttl_secs: None,
            refresh_token_ttl_secs: None,
        };
        Self::write_secure(&path, &config)?;
        tracing::warn!(
            path = %path.display(),
            client_id = %config.client_id,
            "generated OAuth credentials — paste these into the Claude.ai connector's Advanced fields"
        );
        eprintln!(
            "\n=== things-mcp OAuth credentials generated at {} ===\n client_id = {}\n client_secret = {}\n issuer = {}\n → paste client_id + client_secret into Claude.ai connector → Advanced → OAuth fields\n",
            path.display(),
            config.client_id,
            config.client_secret,
            config.issuer,
        );
        Ok(Some(config))
    }

    /// Serializes `config` as TOML to `path`, creating parent directories.
    ///
    /// On unix the file is created with mode 0600 from the start, so the secret
    /// is never briefly world-readable. The mode is then re-applied, because the
    /// `open` mode is ignored for an already-existing file.
    fn write_secure(path: &std::path::Path, config: &OAuthConfig) -> anyhow::Result<()> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| anyhow::anyhow!("mkdir {}: {e}", parent.display()))?;
        }
        let serialized = toml::to_string_pretty(config)?;
        let mut options = std::fs::OpenOptions::new();
        options.write(true).create(true).truncate(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            options.mode(0o600);
        }
        let mut file = options
            .open(path)
            .map_err(|e| anyhow::anyhow!("write {}: {e}", path.display()))?;
        std::io::Write::write_all(&mut file, serialized.as_bytes())
            .map_err(|e| anyhow::anyhow!("write {}: {e}", path.display()))?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let perms = std::fs::Permissions::from_mode(0o600);
            std::fs::set_permissions(path, perms)
                .map_err(|e| anyhow::anyhow!("chmod 0600 {}: {e}", path.display()))?;
        }
        Ok(())
    }
}
/// Random 8-hex-digit suffix used when generating a `client_id`.
fn short_id() -> String {
    let bits: u32 = rand::random();
    format!("{bits:08x}")
}
/// Shared OAuth server state handed to the axum router. Cloning only clones
/// the inner `Arc`, so all clones see the same codes and tokens.
#[derive(Clone)]
pub struct OAuthState {
    inner: Arc<Inner>,
}
/// Data behind `OAuthState`'s `Arc`.
struct Inner {
    /// Client configuration loaded at startup.
    config: OAuthConfig,
    /// Outstanding authorization codes, keyed by the code string returned from `/authorize`.
    codes: RwLock<HashMap<String, AuthCode>>,
    /// Store of issued access/refresh tokens.
    tokens: TokenStore,
}
/// What is remembered about an issued authorization code until it is redeemed.
#[derive(Clone)]
struct AuthCode {
    /// PKCE S256 challenge sent with the authorize request.
    code_challenge: String,
    /// Redirect URI the code was issued for; the token request must repeat it exactly.
    redirect_uri: String,
    /// Unix time (seconds) after which the code is rejected.
    expires_at: u64,
}
impl OAuthState {
pub fn with_tokens_path(config: OAuthConfig, tokens_path: PathBuf) -> anyhow::Result<Self> {
let access_ttl = config.effective_access_ttl();
let refresh_ttl = config.effective_refresh_ttl();
let tokens = TokenStore::load(tokens_path, &config.client_id, access_ttl, refresh_ttl)?;
Ok(Self {
inner: Arc::new(Inner {
config,
codes: RwLock::new(HashMap::new()),
tokens,
}),
})
}
pub fn from_default_path(config: OAuthConfig) -> anyhow::Result<Self> {
let dir = directories::ProjectDirs::from("dev", "things-mcp", "things-mcp")
.ok_or_else(|| anyhow::anyhow!("could not resolve ProjectDirs for tokens.json"))?
.config_dir()
.to_path_buf();
Self::with_tokens_path(config, dir.join("tokens.json"))
}
pub fn issuer(&self) -> &str {
&self.inner.config.issuer
}
pub fn resource_metadata_url(&self) -> String {
format!(
"{}/.well-known/oauth-protected-resource",
self.inner.config.issuer
)
}
pub async fn validate_token(&self, token: &str) -> bool {
self.inner.tokens.validate_access(token).await
}
#[cfg(test)]
pub(crate) fn token_store(&self) -> &TokenStore {
&self.inner.tokens
}
}
/// Current Unix time in whole seconds (0 if the clock reads before the epoch).
fn unix_now() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs()
}
/// JSON body of the authorization-server metadata document (RFC 8414 field names).
#[derive(Serialize)]
struct AuthorizationServerMetadata {
    issuer: String,
    authorization_endpoint: String,
    token_endpoint: String,
    grant_types_supported: &'static [&'static str],
    token_endpoint_auth_methods_supported: &'static [&'static str],
    response_types_supported: &'static [&'static str],
    code_challenge_methods_supported: &'static [&'static str],
    scopes_supported: &'static [&'static str],
}
/// JSON body of the protected-resource metadata document
/// (`/.well-known/oauth-protected-resource`).
#[derive(Serialize)]
struct ProtectedResourceMetadata {
    resource: String,
    authorization_servers: Vec<String>,
    bearer_methods_supported: &'static [&'static str],
    scopes_supported: &'static [&'static str],
}
/// Serves the authorization-server metadata document. The router mounts it at
/// both the OAuth and the OpenID discovery paths.
async fn authorization_server_metadata(
    State(state): State<OAuthState>,
) -> Json<AuthorizationServerMetadata> {
    let issuer = state.issuer().to_string();
    Json(AuthorizationServerMetadata {
        authorization_endpoint: format!("{issuer}/authorize"),
        token_endpoint: format!("{issuer}/oauth/token"),
        issuer,
        grant_types_supported: &["authorization_code", "refresh_token", "client_credentials"],
        token_endpoint_auth_methods_supported: &["client_secret_post", "client_secret_basic"],
        // Only the authorization-code flow is implemented: `authorize_handler`
        // answers any other response_type with `unsupported_response_type`, so
        // advertising "token" (implicit flow) would be false.
        response_types_supported: &["code"],
        code_challenge_methods_supported: &["S256"],
        scopes_supported: &["mcp"],
    })
}
/// Serves the protected-resource metadata document. The issuer is both the
/// resource identifier and its sole authorization server.
async fn protected_resource_metadata(
    State(state): State<OAuthState>,
) -> Json<ProtectedResourceMetadata> {
    let resource = state.issuer().to_owned();
    let authorization_servers = vec![resource.clone()];
    Json(ProtectedResourceMetadata {
        resource,
        authorization_servers,
        bearer_methods_supported: &["header"],
        scopes_supported: &["mcp"],
    })
}
/// Form body of a `POST /oauth/token` request; which fields are required
/// depends on `grant_type`.
#[derive(Deserialize)]
struct TokenRequest {
    grant_type: String,
    client_id: Option<String>,
    client_secret: Option<String>,
    #[serde(default)]
    code: Option<String>,
    #[serde(default)]
    code_verifier: Option<String>,
    #[serde(default)]
    redirect_uri: Option<String>,
    // Accepted so clients that send it are not rejected, but never read.
    #[allow(dead_code)]
    #[serde(default)]
    resource: Option<String>,
    // Accepted but never read; responses always report scope "mcp".
    #[allow(dead_code)]
    #[serde(default)]
    scope: Option<String>,
    #[serde(default)]
    refresh_token: Option<String>,
}
/// Successful token endpoint response body.
#[derive(Serialize)]
struct TokenResponse {
    access_token: String,
    token_type: &'static str,
    /// Access-token lifetime in seconds.
    expires_in: u64,
    /// Omitted from the JSON when `None` (e.g. client_credentials).
    #[serde(skip_serializing_if = "Option::is_none")]
    refresh_token: Option<String>,
    /// Refresh-token lifetime in seconds; a non-standard extension field.
    #[serde(skip_serializing_if = "Option::is_none")]
    refresh_expires_in: Option<u64>,
    scope: &'static str,
}
/// OAuth error response body (`error` plus optional human-readable description).
#[derive(Serialize)]
struct OAuthError {
    error: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    error_description: Option<&'static str>,
}
async fn token_handler(
State(state): State<OAuthState>,
headers: HeaderMap,
Form(body): Form<TokenRequest>,
) -> axum::response::Response {
match body.grant_type.as_str() {
"authorization_code" => handle_authorization_code(state, headers, body).await,
"client_credentials" => handle_client_credentials(state, headers, body).await,
"refresh_token" => handle_refresh_token(state, headers, body).await,
_ => (
StatusCode::BAD_REQUEST,
Json(OAuthError {
error: "unsupported_grant_type",
error_description: Some(
"supported grant types: authorization_code, refresh_token, client_credentials",
),
}),
)
.into_response(),
}
}
async fn handle_client_credentials(
state: OAuthState,
headers: HeaderMap,
body: TokenRequest,
) -> axum::response::Response {
let Some((client_id, client_secret)) = resolve_client_credentials(&headers, &body) else {
return invalid_client();
};
let expected = &state.inner.config;
if !constant_time_eq(client_id.as_bytes(), expected.client_id.as_bytes())
|| !constant_time_eq(client_secret.as_bytes(), expected.client_secret.as_bytes())
{
return invalid_client();
}
let pair = match state.inner.tokens.mint_pair(None).await {
Ok(p) => p,
Err(e) => {
tracing::error!(error = %e, "mint_pair failed for client_credentials");
return server_error();
}
};
let token = pair.access_token;
let ttl = pair.access_ttl.as_secs();
tracing::info!(
client_id = %client_id,
grant = "client_credentials",
expires_in = ttl,
"OAuth token minted"
);
token_ok_access_only(token, ttl)
}
/// Handles `grant_type=authorization_code` with PKCE S256.
///
/// Client authentication:
/// - If an id/secret pair is present (form body or HTTP Basic), both must match.
/// - Otherwise, a bare body `client_id` must match.
/// - If neither is sent, no client check is performed.
///
/// NOTE(review): in the last case a valid code + verifier alone suffices, even
/// though discovery only advertises secret-based auth — confirm the
/// public-client path is intended.
///
/// The code is removed from the pending map *before* the expiry, redirect_uri
/// and PKCE checks, so any redemption attempt consumes it. On success a new
/// token chain is started and an access + refresh pair is returned.
async fn handle_authorization_code(
    state: OAuthState,
    headers: HeaderMap,
    body: TokenRequest,
) -> axum::response::Response {
    let Some(code) = body.code.as_deref() else {
        return invalid_grant("missing code");
    };
    let Some(verifier) = body.code_verifier.as_deref() else {
        return invalid_grant("missing code_verifier");
    };
    let Some(redirect_uri) = body.redirect_uri.as_deref() else {
        return invalid_grant("missing redirect_uri");
    };
    if let Some((client_id, client_secret)) = resolve_client_credentials(&headers, &body) {
        // Confidential client: id and secret are verified together.
        let expected = &state.inner.config;
        if !constant_time_eq(client_id.as_bytes(), expected.client_id.as_bytes())
            || !constant_time_eq(client_secret.as_bytes(), expected.client_secret.as_bytes())
        {
            return invalid_client();
        }
    } else if let Some(client_id) = body.client_id.as_deref() {
        // No secret supplied: only the identifier can be checked.
        if !constant_time_eq(
            client_id.as_bytes(),
            state.inner.config.client_id.as_bytes(),
        ) {
            return invalid_client();
        }
    }
    // Remove first so the code is single-use even when a later check fails; the
    // write guard is a temporary released at the end of this statement.
    let info = state.inner.codes.write().await.remove(code);
    let Some(info) = info else {
        return invalid_grant("unknown or already-used code");
    };
    if info.expires_at < unix_now() {
        return invalid_grant("code expired");
    }
    // Must be byte-identical to the redirect_uri used at /authorize.
    if info.redirect_uri != redirect_uri {
        return invalid_grant("redirect_uri mismatch");
    }
    let computed = pkce_s256(verifier);
    if !constant_time_eq(computed.as_bytes(), info.code_challenge.as_bytes()) {
        return invalid_grant("PKCE verification failed");
    }
    // `None` = start a new refresh chain.
    let pair = match state.inner.tokens.mint_pair(None).await {
        Ok(p) => p,
        Err(e) => {
            tracing::error!(error = %e, "mint_pair failed for authorization_code");
            return server_error();
        }
    };
    tracing::info!(
        grant = "authorization_code",
        chain_id = %pair.chain_id,
        expires_in = pair.access_ttl.as_secs(),
        "OAuth token pair minted"
    );
    token_ok_pair(pair)
}
/// Handles `grant_type=refresh_token` with rotation and replay detection.
///
/// Client authentication follows the same rules as the authorization_code grant:
/// - A full id/secret pair is checked when present.
/// - A bare body `client_id` is checked when sent alone.
/// - Otherwise no client check is performed.
///
/// The presented token is consumed through `TokenStore::consume_refresh`:
/// - A replayed token makes this handler call `revoke_chain` for that token's
///   chain and reject the request.
/// - Expired or unknown tokens are rejected with `invalid_grant`.
///
/// On success a new pair is minted on the same chain.
async fn handle_refresh_token(
    state: OAuthState,
    headers: HeaderMap,
    body: TokenRequest,
) -> axum::response::Response {
    if let Some((client_id, client_secret)) = resolve_client_credentials(&headers, &body) {
        // Confidential client: id and secret are verified together.
        let expected = &state.inner.config;
        if !constant_time_eq(client_id.as_bytes(), expected.client_id.as_bytes())
            || !constant_time_eq(client_secret.as_bytes(), expected.client_secret.as_bytes())
        {
            return invalid_client();
        }
    } else if let Some(client_id) = body.client_id.as_deref() {
        // No secret supplied: only the identifier can be checked.
        if !constant_time_eq(
            client_id.as_bytes(),
            state.inner.config.client_id.as_bytes(),
        ) {
            return invalid_client();
        }
    }
    let Some(presented) = body.refresh_token.as_deref() else {
        return invalid_grant("missing refresh_token");
    };
    let chain_id = match state.inner.tokens.consume_refresh(presented).await {
        Ok(chain) => chain,
        Err(RefreshError::Replayed(chain)) => {
            // Reuse of an already-rotated refresh token suggests theft, so the
            // chain is revoked rather than just rejecting this request.
            tracing::warn!(chain_id = %chain, "refresh-token replay detected; revoking chain");
            state.inner.tokens.revoke_chain(chain).await;
            return invalid_grant("refresh token replay");
        }
        Err(RefreshError::Expired) => return invalid_grant("refresh token expired"),
        Err(RefreshError::Unknown) => return invalid_grant("unknown refresh token"),
    };
    // Continue the existing chain so a later replay can revoke all of it.
    let pair = match state.inner.tokens.mint_pair(Some(chain_id.clone())).await {
        Ok(p) => p,
        Err(e) => {
            tracing::error!(error = %e, "mint_pair failed during refresh_token grant");
            return server_error();
        }
    };
    tracing::info!(
        grant = "refresh_token",
        chain_id = %chain_id,
        expires_in = pair.access_ttl.as_secs(),
        "OAuth token pair minted (refreshed)"
    );
    token_ok_pair(pair)
}
fn token_ok_access_only(token: String, ttl: u64) -> axum::response::Response {
(
StatusCode::OK,
Json(TokenResponse {
access_token: token,
token_type: "Bearer",
expires_in: ttl,
refresh_token: None,
refresh_expires_in: None,
scope: "mcp",
}),
)
.into_response()
}
fn token_ok_pair(pair: MintedPair) -> axum::response::Response {
(
StatusCode::OK,
Json(TokenResponse {
access_token: pair.access_token,
token_type: "Bearer",
expires_in: pair.access_ttl.as_secs(),
refresh_token: Some(pair.refresh_token),
refresh_expires_in: Some(pair.refresh_ttl.as_secs()),
scope: "mcp",
}),
)
.into_response()
}
/// Logs and returns a 400 `invalid_grant` error with `detail` as the description.
fn invalid_grant(detail: &'static str) -> axum::response::Response {
    tracing::info!(detail, "OAuth grant rejected");
    let payload = OAuthError {
        error: "invalid_grant",
        error_description: Some(detail),
    };
    (StatusCode::BAD_REQUEST, Json(payload)).into_response()
}
/// PKCE S256 transform: base64url (no padding) of SHA-256 over the verifier bytes.
fn pkce_s256(verifier: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(verifier.as_bytes());
    let hash = hasher.finalize();
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
}
/// Query string of `GET /authorize`. Every non-`Option` field is required, so a
/// request missing one (including `state`) fails extraction before the handler runs.
#[derive(Deserialize)]
struct AuthorizeQuery {
    response_type: String,
    client_id: String,
    redirect_uri: String,
    code_challenge: String,
    code_challenge_method: String,
    state: String,
    // Accepted but never read by the handler.
    #[allow(dead_code)]
    #[serde(default)]
    scope: Option<String>,
    // Accepted but never read by the handler.
    #[allow(dead_code)]
    #[serde(default)]
    resource: Option<String>,
}
/// `GET /authorize`: issues a single-use authorization code bound to the
/// redirect URI and PKCE challenge, then redirects back with `code` and `state`.
///
/// A redirect URI outside `ALLOWED_REDIRECT_URI_PREFIXES` gets a plain 400 with
/// no redirect. Other validation failures redirect with an OAuth `error`.
/// Expired codes are pruned whenever a new one is issued, so abandoned
/// authorizations cannot accumulate without bound.
async fn authorize_handler(
    State(state): State<OAuthState>,
    Query(q): Query<AuthorizeQuery>,
) -> axum::response::Response {
    let redirect_ok = ALLOWED_REDIRECT_URI_PREFIXES
        .iter()
        .any(|p| q.redirect_uri.starts_with(p));
    if !redirect_ok {
        tracing::warn!(redirect_uri = %q.redirect_uri, "authorize: redirect_uri not allowed");
        return (StatusCode::BAD_REQUEST, "invalid_redirect_uri").into_response();
    }
    if q.response_type != "code" {
        return redirect_with_error(&q.redirect_uri, &q.state, "unsupported_response_type");
    }
    if q.code_challenge_method != "S256" {
        return redirect_with_error(&q.redirect_uri, &q.state, "invalid_request");
    }
    if !constant_time_eq(
        q.client_id.as_bytes(),
        state.inner.config.client_id.as_bytes(),
    ) {
        return redirect_with_error(&q.redirect_uri, &q.state, "unauthorized_client");
    }
    if q.code_challenge.is_empty() {
        return redirect_with_error(&q.redirect_uri, &q.state, "invalid_request");
    }
    let code = format!("{:032x}", rand::random::<u128>());
    let now = unix_now();
    let info = AuthCode {
        code_challenge: q.code_challenge,
        redirect_uri: q.redirect_uri.clone(),
        expires_at: now + AUTH_CODE_TTL_SECS,
    };
    {
        let mut codes = state.inner.codes.write().await;
        // Codes are otherwise only removed when redeemed. Keep exactly those the
        // token endpoint would still accept (`expires_at >= now`).
        codes.retain(|_, pending| pending.expires_at >= now);
        codes.insert(code.clone(), info);
    }
    tracing::info!(redirect_uri = %q.redirect_uri, "authorization code issued");
    // RFC 6749 §3.1.2: an existing query component must be kept when adding parameters.
    let separator = if q.redirect_uri.contains('?') { '&' } else { '?' };
    let location = format!(
        "{}{separator}code={}&state={}",
        q.redirect_uri,
        urlencoding_minimal(&code),
        urlencoding_minimal(&q.state),
    );
    (StatusCode::FOUND, [(header::LOCATION, location.as_str())]).into_response()
}
/// 302 redirect to `redirect_uri` carrying an OAuth `error` code and the
/// client's `state`, both percent-encoded.
///
/// Any query component already in `redirect_uri` is preserved (RFC 6749 §3.1.2):
/// parameters are appended with `&` when a `?` is present.
fn redirect_with_error(redirect_uri: &str, state: &str, error: &str) -> axum::response::Response {
    let separator = if redirect_uri.contains('?') { '&' } else { '?' };
    let location = format!(
        "{redirect_uri}{separator}error={}&state={}",
        urlencoding_minimal(error),
        urlencoding_minimal(state)
    );
    (StatusCode::FOUND, [(header::LOCATION, location.as_str())]).into_response()
}
/// Percent-encodes `s` for use in a URL query component.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte of the UTF-8 encoding becomes `%XX`. Encoding
/// bytes rather than `char` values is what makes non-ASCII input (e.g. `é` →
/// `%C3%A9`) come out correct.
fn urlencoding_minimal(s: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}
/// Extracts `(client_id, client_secret)` for client authentication.
///
/// A body carrying both `client_id` and `client_secret` (client_secret_post)
/// wins. Otherwise an HTTP Basic `Authorization` header (client_secret_basic)
/// is decoded. The scheme name is matched case-insensitively, per RFC 7235 §2.1.
/// Returns `None` when neither form is present or the header is malformed.
fn resolve_client_credentials(
    headers: &HeaderMap,
    body: &TokenRequest,
) -> Option<(String, String)> {
    if let (Some(id), Some(secret)) = (body.client_id.as_ref(), body.client_secret.as_ref()) {
        return Some((id.clone(), secret.clone()));
    }
    let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
    let (scheme, encoded) = auth.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Basic") {
        return None;
    }
    let bytes = base64::engine::general_purpose::STANDARD
        .decode(encoded.trim())
        .ok()?;
    let decoded = String::from_utf8(bytes).ok()?;
    // The secret may itself contain ':', so split only on the first one.
    let (id, secret) = decoded.split_once(':')?;
    Some((id.to_string(), secret.to_string()))
}
/// 401 `invalid_client` response with a Basic `WWW-Authenticate` challenge.
fn invalid_client() -> axum::response::Response {
    let challenge = [(header::WWW_AUTHENTICATE, "Basic realm=\"oauth/token\"")];
    let payload = OAuthError {
        error: "invalid_client",
        error_description: None,
    };
    (StatusCode::UNAUTHORIZED, challenge, Json(payload)).into_response()
}
/// 500 `server_error` response with no description (details are only logged).
fn server_error() -> axum::response::Response {
    let payload = OAuthError {
        error: "server_error",
        error_description: None,
    };
    (StatusCode::INTERNAL_SERVER_ERROR, Json(payload)).into_response()
}
/// Byte-slice equality whose running time does not depend on where the first
/// difference occurs. A length mismatch still returns early, so the length
/// itself is not hidden.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Builds the OAuth sub-router: discovery documents, `/authorize`, and
/// `/oauth/token`, all sharing `state`.
pub fn router(state: OAuthState) -> Router {
    Router::new()
        // Authorization-server metadata, also served at the OpenID discovery path.
        .route("/.well-known/oauth-authorization-server", get(authorization_server_metadata))
        .route("/.well-known/openid-configuration", get(authorization_server_metadata))
        .route("/.well-known/oauth-protected-resource", get(protected_resource_metadata))
        .route("/authorize", get(authorize_handler))
        .route("/oauth/token", post(token_handler))
        .with_state(state)
}
// Tests for config (de)serialization and on-disk permissions, discovery
// documents, the /authorize endpoint, and each token grant, driven in-process
// through the router with `tower::ServiceExt::oneshot`.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::{to_bytes, Body};
    use axum::http::Request;
    use tower::ServiceExt;

    // Missing TTL keys deserialize as `None` and fall back to the 7d / 90d defaults.
    #[test]
    fn config_loads_with_default_ttls_when_unset() {
        let toml_str = r#"
client_id = "x"
client_secret = "y"
issuer = "https://example.test"
"#;
        let cfg: OAuthConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.access_token_ttl_secs, None);
        assert_eq!(cfg.refresh_token_ttl_secs, None);
        assert_eq!(cfg.effective_access_ttl().as_secs(), 7 * 24 * 3600);
        assert_eq!(cfg.effective_refresh_ttl().as_secs(), 90 * 24 * 3600);
    }

    // Explicit TTL keys override the defaults.
    #[test]
    fn config_loads_with_explicit_ttls() {
        let toml_str = r#"
client_id = "x"
client_secret = "y"
issuer = "https://example.test"
access_token_ttl_secs = 3600
refresh_token_ttl_secs = 86400
"#;
        let cfg: OAuthConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.effective_access_ttl().as_secs(), 3600);
        assert_eq!(cfg.effective_refresh_ttl().as_secs(), 86400);
    }

    // `write_secure` output parses back identically and, on unix, is mode 0600.
    #[test]
    fn config_roundtrips_through_disk_with_secure_perms() {
        let dir = tempdir();
        let path = dir.join("oauth.toml");
        let original = OAuthConfig {
            client_id: "id-x".into(),
            client_secret: "secret-y".into(),
            issuer: "https://example.test".into(),
            access_token_ttl_secs: None,
            refresh_token_ttl_secs: None,
        };
        OAuthConfig::write_secure(&path, &original).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        let parsed: OAuthConfig = toml::from_str(std::str::from_utf8(&bytes).unwrap()).unwrap();
        assert_eq!(parsed.client_id, "id-x");
        assert_eq!(parsed.client_secret, "secret-y");
        assert_eq!(parsed.issuer, "https://example.test");
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777;
            assert_eq!(mode, 0o600, "file should be readable only by owner");
        }
    }

    // Fresh, randomly named directory under the system temp dir (never cleaned up).
    fn tempdir() -> PathBuf {
        let p = std::env::temp_dir().join(format!("things-mcp-test-{}", rand::random::<u64>()));
        std::fs::create_dir_all(&p).unwrap();
        p
    }

    // State with fixed credentials ("test-id" / "test-secret") and a private token file.
    fn test_state() -> OAuthState {
        let dir = tempdir();
        OAuthState::with_tokens_path(
            OAuthConfig {
                client_id: "test-id".into(),
                client_secret: "test-secret".into(),
                issuer: "https://example.test".into(),
                access_token_ttl_secs: None,
                refresh_token_ttl_secs: None,
            },
            dir.join("tokens.json"),
        )
        .unwrap()
    }

    // Collects a response body (capped at 64 KiB) into a UTF-8 string.
    async fn body_string(resp: axum::response::Response) -> String {
        let bytes = to_bytes(resp.into_body(), 64 * 1024).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    // client_credentials with client_secret_post returns a 7-day Bearer token.
    #[tokio::test]
    async fn token_endpoint_issues_for_valid_credentials_via_body() {
        let app = router(test_state());
        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(
                        "grant_type=client_credentials&client_id=test-id&client_secret=test-secret",
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_string(resp).await;
        assert!(body.contains("\"access_token\""), "body was: {body}");
        assert!(body.contains("\"token_type\":\"Bearer\""));
        assert!(body.contains("\"expires_in\":604800"));
    }

    // client_credentials with client_secret_basic (Authorization header) succeeds.
    #[tokio::test]
    async fn token_endpoint_accepts_basic_auth() {
        let app = router(test_state());
        let basic = base64::engine::general_purpose::STANDARD.encode("test-id:test-secret");
        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .header("authorization", format!("Basic {basic}"))
                    .body(Body::from("grant_type=client_credentials"))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // A wrong secret yields 401 invalid_client.
    #[tokio::test]
    async fn token_endpoint_rejects_bad_secret() {
        let app = router(test_state());
        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(
                        "grant_type=client_credentials&client_id=test-id&client_secret=WRONG",
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = body_string(resp).await;
        assert!(body.contains("\"error\":\"invalid_client\""));
    }

    // Grants other than the three supported ones yield 400 unsupported_grant_type.
    #[tokio::test]
    async fn token_endpoint_rejects_unsupported_grant() {
        let app = router(test_state());
        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(
                        "grant_type=password&client_id=test-id&client_secret=test-secret",
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = body_string(resp).await;
        assert!(body.contains("unsupported_grant_type"));
    }

    // A minted access token validates; an unissued string does not.
    // NOTE(review): despite the name, expiry is not exercised here.
    #[tokio::test]
    async fn minted_token_validates_then_expires() {
        let state = test_state();
        let pair = state.inner.tokens.mint_pair(None).await.unwrap();
        assert!(state.validate_token(&pair.access_token).await);
        assert!(!state.validate_token("not-issued").await);
    }

    // Both discovery documents advertise endpoints derived from the issuer.
    #[tokio::test]
    async fn discovery_documents_advertise_correct_endpoints() {
        let app = router(test_state());
        let resp = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/.well-known/oauth-authorization-server")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_string(resp).await;
        assert!(body.contains("\"issuer\":\"https://example.test\""));
        assert!(body.contains("\"authorization_endpoint\":\"https://example.test/authorize\""));
        assert!(body.contains("\"token_endpoint\":\"https://example.test/oauth/token\""));
        assert!(body.contains("\"authorization_code\""));
        assert!(body.contains("\"client_credentials\""));
        assert!(
            body.contains("\"refresh_token\""),
            "discovery must advertise refresh_token grant; body was: {body}"
        );
        assert!(body.contains("\"code_challenge_methods_supported\":[\"S256\"]"));
        let resp = app
            .oneshot(
                Request::builder()
                    .uri("/.well-known/oauth-protected-resource")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_string(resp).await;
        assert!(body.contains("\"resource\":\"https://example.test\""));
        assert!(body.contains("\"authorization_servers\":[\"https://example.test\"]"));
    }

    // Test vector from RFC 7636 Appendix B.
    #[test]
    fn pkce_s256_matches_rfc7636_example() {
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let expected = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
        assert_eq!(pkce_s256(verifier), expected);
    }

    // S256 code_challenge for `verifier`, as a client would compute it.
    fn challenge_for(verifier: &str) -> String {
        pkce_s256(verifier)
    }

    // A valid authorize request redirects to the callback with a code and the echoed state.
    #[tokio::test]
    async fn authorize_endpoint_redirects_with_code() {
        let app = router(test_state());
        let verifier = "test-verifier-string-of-reasonable-length-1234";
        let challenge = challenge_for(verifier);
        let uri = format!(
            "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge={challenge}&code_challenge_method=S256&state=xyz&scope=mcp",
        );
        let resp = app
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(location.starts_with("https://claude.ai/api/mcp/auth_callback?code="));
        assert!(location.contains("&state=xyz"));
    }

    // A redirect_uri outside the allow-list gets a plain 400, not a redirect.
    #[tokio::test]
    async fn authorize_rejects_disallowed_redirect_uri() {
        let app = router(test_state());
        let uri = "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fattacker.example%2Fcb&code_challenge=abc&code_challenge_method=S256&state=z";
        let resp = app
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    // An unknown client_id redirects back with error=unauthorized_client.
    #[tokio::test]
    async fn authorize_rejects_unknown_client_id() {
        let app = router(test_state());
        let uri = "/authorize?response_type=code&client_id=WRONG&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge=abc&code_challenge_method=S256&state=z";
        let resp = app
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(location.contains("error=unauthorized_client"));
    }

    // End-to-end: /authorize code, then exchange at /oauth/token with the right verifier.
    #[tokio::test]
    async fn auth_code_grant_full_flow_succeeds() {
        let state = test_state();
        let verifier = "the-verifier-anthropic-would-have-generated";
        let challenge = challenge_for(verifier);
        let redirect_uri = "https://claude.ai/api/mcp/auth_callback";
        let auth_uri = format!(
            "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge={challenge}&code_challenge_method=S256&state=opaque-state",
        );
        let resp = router(state.clone())
            .oneshot(
                Request::builder()
                    .uri(auth_uri)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let code = location
            .split_once("code=")
            .and_then(|(_, rest)| rest.split('&').next())
            .unwrap()
            .to_string();
        let body = format!(
            "grant_type=authorization_code&code={code}&redirect_uri={}&code_verifier={verifier}&client_id=test-id",
            urlencoding_minimal(redirect_uri)
        );
        let resp = router(state)
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_string(resp).await;
        assert!(body.contains("\"access_token\""));
        assert!(body.contains("\"token_type\":\"Bearer\""));
    }

    // A wrong PKCE verifier is rejected with invalid_grant.
    #[tokio::test]
    async fn auth_code_rejects_bad_verifier() {
        let state = test_state();
        let verifier = "correct-verifier";
        let challenge = challenge_for(verifier);
        let auth_uri = format!(
            "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge={challenge}&code_challenge_method=S256&state=s",
        );
        let resp = router(state.clone())
            .oneshot(
                Request::builder()
                    .uri(auth_uri)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let code = location
            .split_once("code=")
            .and_then(|(_, rest)| rest.split('&').next())
            .unwrap()
            .to_string();
        let body = format!(
            "grant_type=authorization_code&code={code}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_verifier=WRONG&client_id=test-id"
        );
        let resp = router(state)
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = body_string(resp).await;
        assert!(body.contains("invalid_grant"));
    }

    // The authorization_code grant also returns a refresh token with its 90-day lifetime.
    #[tokio::test]
    async fn auth_code_response_includes_refresh_token() {
        let state = test_state();
        let verifier = "the-verifier-of-reasonable-length";
        let challenge = challenge_for(verifier);
        let auth_uri = format!(
            "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge={challenge}&code_challenge_method=S256&state=s",
        );
        let resp = router(state.clone())
            .oneshot(
                Request::builder()
                    .uri(auth_uri)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let code = location
            .split_once("code=")
            .and_then(|(_, r)| r.split('&').next())
            .unwrap()
            .to_string();
        let body = format!(
            "grant_type=authorization_code&code={code}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_verifier={verifier}&client_id=test-id"
        );
        let resp = router(state)
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_string(resp).await;
        assert!(body.contains("\"access_token\""), "body was: {body}");
        assert!(body.contains("\"refresh_token\""), "body was: {body}");
        assert!(
            body.contains("\"refresh_expires_in\":7776000"),
            "body was: {body}"
        );
    }

    // Runs authorize + code exchange against `state`; returns (access_token, refresh_token).
    async fn auth_code_full_flow(state: OAuthState) -> (String, String) {
        let verifier = "verifier-string-of-decent-length";
        let challenge = challenge_for(verifier);
        let auth_uri = format!(
            "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge={challenge}&code_challenge_method=S256&state=s",
        );
        let resp = router(state.clone())
            .oneshot(
                Request::builder()
                    .uri(auth_uri)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let code = location
            .split_once("code=")
            .and_then(|(_, r)| r.split('&').next())
            .unwrap()
            .to_string();
        let body = format!(
            "grant_type=authorization_code&code={code}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_verifier={verifier}&client_id=test-id"
        );
        let resp = router(state)
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        let body_str = body_string(resp).await;
        let parsed: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        let access = parsed["access_token"].as_str().unwrap().to_string();
        let refresh = parsed["refresh_token"].as_str().unwrap().to_string();
        (access, refresh)
    }

    // POSTs a form-encoded `body` to /oauth/token on a fresh router over `state`.
    async fn post_token(state: OAuthState, body: &str) -> axum::response::Response {
        router(state)
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/oauth/token")
                    .header("content-type", "application/x-www-form-urlencoded")
                    .body(Body::from(body.to_string()))
                    .unwrap(),
            )
            .await
            .unwrap()
    }

    // Refreshing rotates both tokens, and the new access token validates.
    #[tokio::test]
    async fn refresh_token_grant_returns_new_access_and_refresh() {
        let state = test_state();
        let (orig_access, refresh) = auth_code_full_flow(state.clone()).await;
        let resp = post_token(
            state.clone(),
            &format!("grant_type=refresh_token&refresh_token={refresh}&client_id=test-id"),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_string(resp).await;
        let parsed: serde_json::Value = serde_json::from_str(&body).unwrap();
        let new_access = parsed["access_token"].as_str().unwrap();
        let new_refresh = parsed["refresh_token"].as_str().unwrap();
        assert_ne!(new_access, orig_access);
        assert_ne!(new_refresh, refresh);
        assert!(state.validate_token(new_access).await);
    }

    // A refresh token cannot be used twice.
    #[tokio::test]
    async fn refresh_token_grant_invalidates_old_refresh_token() {
        let state = test_state();
        let (_, refresh) = auth_code_full_flow(state.clone()).await;
        let r1 = post_token(
            state.clone(),
            &format!("grant_type=refresh_token&refresh_token={refresh}&client_id=test-id"),
        )
        .await;
        assert_eq!(r1.status(), StatusCode::OK);
        let r2 = post_token(
            state,
            &format!("grant_type=refresh_token&refresh_token={refresh}&client_id=test-id"),
        )
        .await;
        assert_eq!(r2.status(), StatusCode::BAD_REQUEST);
        let body = body_string(r2).await;
        assert!(body.contains("invalid_grant"));
    }

    // Replaying a rotated refresh token invalidates every access token in its chain.
    #[tokio::test]
    async fn refresh_token_replay_revokes_chain() {
        let state = test_state();
        let (orig_access, refresh) = auth_code_full_flow(state.clone()).await;
        let r1 = post_token(
            state.clone(),
            &format!("grant_type=refresh_token&refresh_token={refresh}&client_id=test-id"),
        )
        .await;
        let body = body_string(r1).await;
        let parsed: serde_json::Value = serde_json::from_str(&body).unwrap();
        let new_access = parsed["access_token"].as_str().unwrap().to_string();
        let _ = post_token(
            state.clone(),
            &format!("grant_type=refresh_token&refresh_token={refresh}&client_id=test-id"),
        )
        .await;
        assert!(
            !state.validate_token(&new_access).await,
            "new access should be revoked after replay"
        );
        assert!(
            !state.validate_token(&orig_access).await,
            "original access should be revoked after replay"
        );
    }

    // A refresh token that was never issued yields invalid_grant.
    #[tokio::test]
    async fn refresh_token_grant_with_unknown_token_returns_invalid_grant() {
        let state = test_state();
        let resp = post_token(
            state,
            "grant_type=refresh_token&refresh_token=never-issued&client_id=test-id",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = body_string(resp).await;
        assert!(body.contains("invalid_grant"));
    }

    // The same authorization code succeeds once and is rejected the second time.
    #[tokio::test]
    async fn auth_code_is_single_use() {
        let state = test_state();
        let verifier = "vvv";
        let challenge = challenge_for(verifier);
        let auth_uri = format!(
            "/authorize?response_type=code&client_id=test-id&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_challenge={challenge}&code_challenge_method=S256&state=s",
        );
        let resp = router(state.clone())
            .oneshot(
                Request::builder()
                    .uri(auth_uri)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let location = resp
            .headers()
            .get(header::LOCATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let code = location
            .split_once("code=")
            .and_then(|(_, r)| r.split('&').next())
            .unwrap()
            .to_string();
        let body = format!(
            "grant_type=authorization_code&code={code}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&code_verifier={verifier}&client_id=test-id"
        );
        let make_req = || {
            Request::builder()
                .method("POST")
                .uri("/oauth/token")
                .header("content-type", "application/x-www-form-urlencoded")
                .body(Body::from(body.clone()))
                .unwrap()
        };
        let first = router(state.clone()).oneshot(make_req()).await.unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        let second = router(state).oneshot(make_req()).await.unwrap();
        assert_eq!(second.status(), StatusCode::BAD_REQUEST);
    }
}