use serde::{Deserialize, Serialize};
use crate::error::ConnectorError;
/// Subset of the OAuth protected resource metadata document (RFC 9728) that
/// [`discover`] fetches from the MCP server's `oauth-protected-resource`
/// well-known location.
///
/// Both fields default when absent, so a sparse document still deserializes.
#[derive(Debug, Clone, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// Resource identifier advertised by the server. `discover` falls back to
    /// the MCP URL itself when this is missing.
    #[serde(default)]
    pub resource: Option<String>,
    /// Authorization server (issuer) URLs. Empty when absent; `discover` only
    /// uses the first entry and falls back to the MCP server's origin.
    #[serde(default)]
    pub authorization_servers: Vec<String>,
}
/// Authorization server metadata, as served at the
/// `oauth-authorization-server` or `openid-configuration` well-known locations.
///
/// `authorization_endpoint` and `token_endpoint` have no serde default, so a
/// document missing either one fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthServerMetadata {
    /// Endpoint the user is sent to; base URL for [`authorization_url`].
    pub authorization_endpoint: String,
    /// Endpoint used for code exchange and refresh.
    pub token_endpoint: String,
    /// Dynamic client registration endpoint; [`register_client`] fails when
    /// this is `None`.
    #[serde(default)]
    pub registration_endpoint: Option<String>,
    /// Scopes advertised by the server; empty when absent.
    #[serde(default)]
    pub scopes_supported: Vec<String>,
    /// PKCE methods advertised by the server; empty when absent. Note that
    /// [`authorization_url`] always sends `S256` and does not consult this.
    #[serde(default)]
    pub code_challenge_methods_supported: Vec<String>,
}
/// Successful response from the dynamic client registration endpoint
/// (see [`register_client`]).
#[derive(Debug, Clone, Deserialize)]
pub struct ClientRegistration {
    /// Identifier issued to this client.
    pub client_id: String,
    /// Secret issued to this client, or `None` when the server returned none.
    #[serde(default)]
    pub client_secret: Option<String>,
}
/// Successful response body from the token endpoint, covering both the
/// authorization-code exchange and refresh.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Token type; `"Bearer"` when the server omits it.
    #[serde(default = "default_token_type")]
    pub token_type: String,
    /// Access-token lifetime. [`StoredTokens::from_response`] treats it as
    /// seconds by adding it to a Unix-seconds timestamp.
    #[serde(default)]
    pub expires_in: Option<u64>,
    /// New refresh token, if the server issued one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Granted scope string, if returned. Not carried into [`StoredTokens`].
    #[serde(default)]
    pub scope: Option<String>,
}
/// Serde default for [`TokenResponse::token_type`] when the field is missing.
fn default_token_type() -> String {
    String::from("Bearer")
}
/// Token set in a serializable form suitable for storage between requests.
///
/// Built from a [`TokenResponse`] via [`StoredTokens::from_response`]. The
/// optional fields default to `None`, so documents missing those keys still
/// deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredTokens {
    /// The current access token.
    pub access_token: String,
    /// Token type as reported by the server (e.g. `"Bearer"`).
    pub token_type: String,
    /// Refresh token, if one is known.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Absolute expiry as a Unix timestamp in seconds; `None` means the server
    /// gave no lifetime and the token is never treated as expired.
    #[serde(default)]
    pub expires_at: Option<u64>,
}
impl StoredTokens {
pub fn from_response(resp: TokenResponse, now_unix: u64, prior_refresh: Option<String>) -> Self {
Self {
access_token: resp.access_token,
token_type: resp.token_type,
refresh_token: resp.refresh_token.or(prior_refresh),
expires_at: resp.expires_in.map(|e| now_unix.saturating_add(e)),
}
}
pub fn is_expired(&self, now_unix: u64, skew_secs: u64) -> bool {
match self.expires_at {
Some(exp) => now_unix.saturating_add(skew_secs) >= exp,
None => false,
}
}
}
/// Returns the ASCII-serialized origin (`scheme://host[:port]`) of `url`,
/// dropping path, query and fragment.
///
/// # Errors
///
/// Returns [`ConnectorError::Protocol`] if `url` does not parse, or if it has
/// an opaque origin (serialized as `"null"`, e.g. `data:` URLs).
pub fn origin_of(url: &str) -> Result<String, ConnectorError> {
    let origin = reqwest::Url::parse(url)
        .map(|parsed| parsed.origin().ascii_serialization())
        .map_err(|e| ConnectorError::Protocol(format!("bad url: {e}")))?;
    if origin != "null" {
        Ok(origin)
    } else {
        Err(ConnectorError::Protocol(format!(
            "url has no usable origin: {url}"
        )))
    }
}
/// Discovers the OAuth configuration protecting the MCP server at `mcp_url`.
///
/// 1. Protected resource metadata is probed at the candidates returned by
///    [`protected_resource_metadata_urls`]: the path-specific location first,
///    then the origin root. This step is best-effort. If no candidate answers,
///    the MCP server's own origin is assumed to be the authorization server.
/// 2. Authorization server metadata is then fetched for the first advertised
///    authorization server, or for that origin fallback.
///
/// Returns the resource indicator and the authorization server metadata. The
/// resource indicator is the metadata's `resource`, else `mcp_url` itself, so
/// it is currently always `Some`.
///
/// # Errors
///
/// Fails if `mcp_url` has no usable origin, or if no authorization server
/// metadata can be fetched.
pub async fn discover(
    http: &reqwest::Client,
    mcp_url: &str,
) -> Result<(Option<String>, AuthServerMetadata), ConnectorError> {
    let origin = origin_of(mcp_url)?;

    let mut metadata = None;
    for url in protected_resource_metadata_urls(mcp_url, &origin) {
        // A miss on one location (typically a 404) is expected; try the next.
        if let Ok(found) = fetch_json::<ProtectedResourceMetadata>(http, &url).await {
            metadata = Some(found);
            break;
        }
    }

    let (resource, auth_server) = match metadata {
        Some(prm) => {
            let auth_server = prm
                .authorization_servers
                .into_iter()
                .next()
                .unwrap_or(origin);
            let resource = prm.resource.unwrap_or_else(|| mcp_url.to_string());
            (resource, auth_server)
        }
        None => (mcp_url.to_string(), origin),
    };

    let asm = fetch_auth_server_metadata(http, &auth_server).await?;
    Ok((Some(resource), asm))
}

/// Lists the protected-resource-metadata URLs to probe for `mcp_url`, in order.
///
/// RFC 9728 §3.1 builds the well-known URI by inserting
/// `/.well-known/oauth-protected-resource` between the host and the
/// resource's path/query. So `https://mcp.example.com/mcp` maps to
/// `https://mcp.example.com/.well-known/oauth-protected-resource/mcp`.
///
/// That location is tried first whenever the URL has a path or query. The
/// origin-root location `{origin}/.well-known/oauth-protected-resource` is
/// always tried last, which keeps the behavior for servers that publish a
/// single document there.
fn protected_resource_metadata_urls(mcp_url: &str, origin: &str) -> Vec<String> {
    const WELL_KNOWN: &str = "/.well-known/oauth-protected-resource";

    let mut urls = Vec::with_capacity(2);
    if let Ok(parsed) = reqwest::Url::parse(mcp_url) {
        // A bare "/" is only the slash after the host, not a resource path.
        let path = match parsed.path() {
            "/" => "",
            other => other,
        };
        let query = parsed
            .query()
            .map(|q| format!("?{q}"))
            .unwrap_or_default();
        if !path.is_empty() || !query.is_empty() {
            urls.push(format!("{origin}{WELL_KNOWN}{path}{query}"));
        }
    }
    urls.push(format!("{origin}{WELL_KNOWN}"));
    urls
}
/// Fetches authorization server metadata for the issuer `auth_server`.
///
/// The candidate well-known URLs from [`auth_server_metadata_urls`] are probed
/// in order, and the first that answers with a successful, parseable
/// [`AuthServerMetadata`] document wins. Failures on individual candidates are
/// expected, since servers usually publish only one of the locations.
///
/// # Errors
///
/// Returns [`ConnectorError::Protocol`] when no candidate yields usable metadata.
async fn fetch_auth_server_metadata(
    http: &reqwest::Client,
    auth_server: &str,
) -> Result<AuthServerMetadata, ConnectorError> {
    for url in auth_server_metadata_urls(auth_server) {
        if let Ok(asm) = fetch_json::<AuthServerMetadata>(http, &url).await {
            return Ok(asm);
        }
    }
    Err(ConnectorError::Protocol(format!(
        "no authorization server metadata at {auth_server}"
    )))
}

/// Lists the metadata URLs to probe for an issuer, most spec-compliant first.
///
/// Issuer without a path (`https://as.example.com`):
/// `{origin}/.well-known/oauth-authorization-server`, then
/// `{origin}/.well-known/openid-configuration`.
///
/// Issuer with a path (`https://as.example.com/tenant1`). RFC 8414 §3.1
/// *inserts* the well-known segment between host and path, while OpenID
/// Connect Discovery *appends* it to the issuer. The candidates are:
/// 1. `{origin}/.well-known/oauth-authorization-server/tenant1` (RFC 8414)
/// 2. `{origin}/.well-known/openid-configuration/tenant1` (OIDC, path insertion)
/// 3. `{origin}/tenant1/.well-known/openid-configuration` (OIDC, path appending)
/// 4. `{origin}/tenant1/.well-known/oauth-authorization-server` (non-standard
///    appending, kept last for servers that publish it there)
///
/// Trailing `/` on the issuer is ignored. If the issuer does not parse, or
/// has an opaque origin, both suffixes are appended to it verbatim. Those
/// probes simply fail, and the caller then reports the missing metadata.
fn auth_server_metadata_urls(auth_server: &str) -> Vec<String> {
    const OAUTH: &str = "/.well-known/oauth-authorization-server";
    const OIDC: &str = "/.well-known/openid-configuration";

    let appended = |base: &str| vec![format!("{base}{OAUTH}"), format!("{base}{OIDC}")];

    let issuer = auth_server.trim_end_matches('/');
    let parsed = match reqwest::Url::parse(issuer) {
        Ok(parsed) => parsed,
        Err(_) => return appended(issuer),
    };
    let origin = parsed.origin().ascii_serialization();
    if origin == "null" {
        return appended(issuer);
    }

    let path = parsed.path().trim_end_matches('/');
    if path.is_empty() {
        appended(&origin)
    } else {
        vec![
            format!("{origin}{OAUTH}{path}"),
            format!("{origin}{OIDC}{path}"),
            format!("{origin}{path}{OIDC}"),
            format!("{origin}{path}{OAUTH}"),
        ]
    }
}
/// Registers this application as a public OAuth client via dynamic client
/// registration.
///
/// The request asks for the `authorization_code` and `refresh_token` grants,
/// the `code` response type, `token_endpoint_auth_method = "none"`, and the
/// single `redirect_uri`.
///
/// # Errors
///
/// - [`ConnectorError::Protocol`] if `asm` has no `registration_endpoint`, or
///   if the response body cannot be parsed.
/// - [`ConnectorError::Http`] on transport failure or a non-success status.
///   The status-failure message includes the response body.
pub async fn register_client(
    http: &reqwest::Client,
    asm: &AuthServerMetadata,
    redirect_uri: &str,
    client_name: &str,
) -> Result<ClientRegistration, ConnectorError> {
    let endpoint = match asm.registration_endpoint.as_deref() {
        Some(url) => url,
        None => {
            return Err(ConnectorError::Protocol(
                "authorization server has no registration_endpoint (dynamic client registration unsupported)".into(),
            ))
        }
    };

    let registration_request = serde_json::json!({
        "client_name": client_name,
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    });

    let response = http
        .post(endpoint)
        .json(&registration_request)
        .send()
        .await
        .map_err(|err| ConnectorError::Http(err.to_string()))?;
    let status = response.status();
    let payload = response
        .text()
        .await
        .map_err(|err| ConnectorError::Http(err.to_string()))?;

    if status.is_success() {
        serde_json::from_str(&payload).map_err(|err| {
            ConnectorError::Protocol(format!("parse registration response: {err}"))
        })
    } else {
        Err(ConnectorError::Http(format!(
            "client registration failed: HTTP {status}: {payload}"
        )))
    }
}
/// Builds the URL that starts an authorization-code flow with PKCE.
///
/// The following parameters are appended to `asm.authorization_endpoint`, in
/// this order:
/// - `response_type=code`, `client_id`, `redirect_uri` and `state`
/// - `code_challenge` with `code_challenge_method=S256`
/// - a space-joined `scope`, when `scopes` is non-empty
/// - `resource`, when provided
///
/// Any query already present on the endpoint is preserved.
///
/// # Errors
///
/// Returns [`ConnectorError::Protocol`] if `authorization_endpoint` is not a
/// valid URL.
pub fn authorization_url(
    asm: &AuthServerMetadata,
    client_id: &str,
    redirect_uri: &str,
    state: &str,
    challenge: &str,
    resource: Option<&str>,
    scopes: &[String],
) -> Result<String, ConnectorError> {
    let mut endpoint = match reqwest::Url::parse(&asm.authorization_endpoint) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(ConnectorError::Protocol(format!(
                "bad authorization_endpoint: {e}"
            )))
        }
    };

    let joined_scopes = (!scopes.is_empty()).then(|| scopes.join(" "));
    let mandatory = [
        ("response_type", "code"),
        ("client_id", client_id),
        ("redirect_uri", redirect_uri),
        ("state", state),
        ("code_challenge", challenge),
        ("code_challenge_method", "S256"),
    ];
    let optional = [("scope", joined_scopes.as_deref()), ("resource", resource)];
    let params = mandatory.iter().copied().chain(
        optional
            .iter()
            .filter_map(|&(key, value)| value.map(|v| (key, v))),
    );

    // The query serializer mutably borrows `endpoint`; scope it so that
    // borrow ends before the URL is rendered below.
    {
        let mut query = endpoint.query_pairs_mut();
        for (key, value) in params {
            query.append_pair(key, value);
        }
    }
    Ok(endpoint.to_string())
}
/// Exchanges an authorization `code` for tokens at `token_endpoint`.
///
/// Sends an `authorization_code` grant with `code`, `redirect_uri`,
/// `client_id` and the PKCE `code_verifier`. `client_secret` and the
/// `resource` indicator are added when provided.
///
/// # Errors
///
/// Propagates the transport, non-success-status and parse errors of the token
/// request.
// Every argument maps to a distinct token-request parameter; bundling them in
// a struct would only move the same list somewhere else.
#[allow(clippy::too_many_arguments)]
pub async fn exchange_code(
    http: &reqwest::Client,
    token_endpoint: &str,
    client_id: &str,
    client_secret: Option<&str>,
    redirect_uri: &str,
    code: &str,
    verifier: &str,
    resource: Option<&str>,
) -> Result<TokenResponse, ConnectorError> {
    let mut form: Vec<(&str, &str)> = Vec::with_capacity(7);
    form.extend_from_slice(&[
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", verifier),
    ]);
    form.extend(client_secret.map(|secret| ("client_secret", secret)));
    form.extend(resource.map(|indicator| ("resource", indicator)));
    post_token(http, token_endpoint, &form).await
}
/// Obtains a new access token using a `refresh_token` grant.
///
/// `client_secret` and the `resource` indicator are sent only when provided.
///
/// # Errors
///
/// Propagates the transport, non-success-status and parse errors of the token
/// request.
pub async fn refresh(
    http: &reqwest::Client,
    token_endpoint: &str,
    client_id: &str,
    client_secret: Option<&str>,
    refresh_token: &str,
    resource: Option<&str>,
) -> Result<TokenResponse, ConnectorError> {
    let optional = [("client_secret", client_secret), ("resource", resource)];
    let mut form = vec![
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
    ];
    form.extend(
        optional
            .iter()
            .filter_map(|&(key, value)| value.map(|v| (key, v))),
    );
    post_token(http, token_endpoint, &form).await
}
/// POSTs `form` as `application/x-www-form-urlencoded` to `token_endpoint`
/// and parses the body as a [`TokenResponse`].
///
/// # Errors
///
/// - [`ConnectorError::Http`] on transport failure or a non-success status.
///   The status-failure message includes the response body.
/// - [`ConnectorError::Protocol`] if a successful body is not a valid token
///   response.
async fn post_token(
    http: &reqwest::Client,
    token_endpoint: &str,
    form: &[(&str, &str)],
) -> Result<TokenResponse, ConnectorError> {
    let response = http
        .post(token_endpoint)
        .form(form)
        .send()
        .await
        .map_err(|err| ConnectorError::Http(err.to_string()))?;
    let status = response.status();
    let body = response
        .text()
        .await
        .map_err(|err| ConnectorError::Http(err.to_string()))?;

    if status.is_success() {
        serde_json::from_str(&body)
            .map_err(|err| ConnectorError::Protocol(format!("parse token response: {err}")))
    } else {
        Err(ConnectorError::Http(format!(
            "token endpoint failed: HTTP {status}: {body}"
        )))
    }
}
async fn fetch_json<T: for<'de> Deserialize<'de>>(
http: &reqwest::Client,
url: &str,
) -> Result<T, ConnectorError> {
let resp = http
.get(url)
.send()
.await
.map_err(|e| ConnectorError::Http(e.to_string()))?;
if !resp.status().is_success() {
return Err(ConnectorError::Http(format!(
"GET {url}: HTTP {}",
resp.status()
)));
}
let text = resp
.text()
.await
.map_err(|e| ConnectorError::Http(e.to_string()))?;
serde_json::from_str(&text)
.map_err(|e| ConnectorError::Protocol(format!("parse {url}: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn origin_strips_path() {
        let cases = [
            ("https://mcp.example.com/sse/v1?x=1", "https://mcp.example.com"),
            ("http://127.0.0.1:8931/mcp", "http://127.0.0.1:8931"),
        ];
        for (input, expected) in cases {
            assert_eq!(origin_of(input).unwrap(), expected);
        }
    }

    #[test]
    fn authorize_url_has_pkce_and_resource() {
        let metadata = AuthServerMetadata {
            authorization_endpoint: "https://auth.example.com/authorize".into(),
            token_endpoint: "https://auth.example.com/token".into(),
            registration_endpoint: None,
            scopes_supported: vec![],
            code_challenge_methods_supported: vec!["S256".into()],
        };
        let authorize = authorization_url(
            &metadata,
            "client-123",
            "http://127.0.0.1:7777/cb",
            "state-abc",
            "challenge-xyz",
            Some("https://mcp.example.com/"),
            &["read".into(), "write".into()],
        )
        .unwrap();

        assert!(authorize.starts_with("https://auth.example.com/authorize?"));
        let expected_fragments = [
            "response_type=code",
            "client_id=client-123",
            "code_challenge=challenge-xyz",
            "code_challenge_method=S256",
            "scope=read+write",
            "resource=https",
            "state=state-abc",
        ];
        for fragment in expected_fragments {
            assert!(
                authorize.contains(fragment),
                "missing {fragment} in {authorize}"
            );
        }
    }

    #[test]
    fn stored_tokens_expiry() {
        let response = TokenResponse {
            access_token: "a".into(),
            token_type: "Bearer".into(),
            expires_in: Some(3600),
            refresh_token: Some("r".into()),
            scope: None,
        };
        let tokens = StoredTokens::from_response(response, 1_000, None);
        assert_eq!(tokens.expires_at, Some(4_600));
        assert!(!tokens.is_expired(4_000, 60));
        assert!(tokens.is_expired(4_550, 60));
        assert!(tokens.is_expired(4_600, 0));
    }

    #[test]
    fn refresh_keeps_prior_refresh_token_when_omitted() {
        let response = TokenResponse {
            access_token: "a2".into(),
            token_type: "Bearer".into(),
            expires_in: Some(60),
            refresh_token: None,
            scope: None,
        };
        let tokens = StoredTokens::from_response(response, 100, Some("old-refresh".into()));
        assert_eq!(tokens.refresh_token.as_deref(), Some("old-refresh"));
    }
}