pub mod callback;
pub mod discovery;
pub mod storage;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use http::{HeaderName, HeaderValue};
use reqwest::Client as HttpClient;
use rmcp::transport::auth::{AuthClient, OAuthState};
use crate::cli::Cli;
use crate::session::CredentialKey;
use self::callback::CallbackServer;
use self::discovery::{AuthRequirement, OAuthDiscovery};
use self::storage::SecureCredentialStore;
/// Result of [`acquire_auth_client`]: how requests to the remote server should be sent.
pub enum AuthOutcome {
    /// No OAuth is applied: either `--no-auth` was given, or discovery reported that
    /// the server accepts unauthenticated requests.
    Anonymous { http_client: HttpClient },
    /// An OAuth-authorized client built from cached or freshly obtained credentials.
    Authorized { client: AuthClient<HttpClient> },
}
/// Decides how outbound requests to the remote MCP server will be authenticated.
///
/// Resolution order:
/// 1. `--no-auth` returns [`AuthOutcome::Anonymous`] before the credential store is
///    opened or any discovery request is made.
/// 2. `--reset-auth` clears the credential store opened for `cred_key` before continuing.
/// 3. OAuth discovery against `cli.server_url` (sending `headers` and the optional
///    `--resource`) decides whether OAuth is needed at all; if not, the outcome is
///    anonymous.
/// 4. Cached credentials that contain a token response are reused. If the cached access
///    token is stale (see [`cached_access_token_is_stale`]), it is refreshed once.
/// 5. Otherwise, an interactive browser-based authorization flow is run.
///
/// # Errors
///
/// Fails if:
/// - the HTTP client cannot be built;
/// - the credential store cannot be opened, or cleared for `--reset-auth`;
/// - discovery fails;
/// - the OAuth state machine cannot be initialised;
/// - cached credentials cannot be loaded or applied;
/// - a stale cached token cannot be refreshed (the cache is cleared first, best effort);
/// - the interactive flow fails.
pub async fn acquire_auth_client(
    cli: &Cli,
    cred_key: &CredentialKey,
    headers: &HashMap<HeaderName, HeaderValue>,
) -> Result<AuthOutcome> {
    let http_client = build_http_client()?;
    if cli.no_auth {
        tracing::info!(
            "--no-auth specified; skipping OAuth discovery and using an anonymous HTTP client"
        );
        return Ok(AuthOutcome::Anonymous { http_client });
    }

    let store = Arc::new(SecureCredentialStore::new(cred_key)?);
    if cli.reset_auth {
        tracing::info!("--reset-auth specified; clearing any cached credentials");
        store
            .clear_sync()
            .context("failed to clear cached credentials")?;
    }

    let requirement = discovery::discover(
        &http_client,
        &cli.server_url,
        headers,
        cli.resource.as_deref(),
    )
    .await
    .context("failed to discover OAuth requirements")?;
    let discovery = match requirement {
        AuthRequirement::None => {
            tracing::info!("remote server accepts unauthenticated requests");
            return Ok(AuthOutcome::Anonymous { http_client });
        }
        AuthRequirement::Required(d) => d,
    };
    tracing::info!(
        authorization_server = %discovery.authorization_server,
        scopes = ?discovery.scopes,
        "remote requires OAuth"
    );

    let mut state = OAuthState::new(discovery.authorization_server.as_str(), None)
        .await
        .context("failed to initialize OAuth state machine")?;
    // The store must be installed while `state` is still `Unauthorized`;
    // `install_credential_store` rejects any other state.
    install_credential_store(&mut state, store.clone()).await?;

    if let Some(cached) = store
        .as_ref()
        .load_via_trait()
        .await
        .context("failed to load cached credentials")?
        && let Some(token) = cached.token_response.clone()
    {
        // Judge staleness from the record as loaded, before `set_credentials` can touch
        // the stored receipt time (see the restore below).
        let stale = cached_access_token_is_stale(&cached);
        tracing::info!(
            stale_cached_token = stale,
            token_received_at = ?cached.token_received_at,
            "found cached OAuth credentials; using them"
        );
        state
            .set_credentials(&cached.client_id, token)
            .await
            .context("failed to apply cached credentials")?;
        // Re-save the record exactly as loaded so its original `token_received_at`
        // survives. This is best effort: the credentials are already applied to `state`.
        if let Err(e) = store.as_ref().save_via_trait(cached).await {
            tracing::warn!(
                error = %e,
                "could not restore genuine token_received_at after set_credentials; \
                 future launches may incorrectly treat an expired token as fresh"
            );
        }
        if stale {
            tracing::info!(
                "cached access token is expired or within refresh buffer; refreshing now"
            );
            match state.refresh_token().await {
                Ok(()) => tracing::info!("refresh succeeded; cached credentials are current"),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "could not refresh expired cached credentials; clearing cache"
                    );
                    // Don't hide a failed clear: the advice in the error below differs
                    // when the unusable entry is still on disk.
                    let next_step = match store.clear_sync() {
                        Ok(_) => {
                            "The credential cache has been cleared; re-run \
                             hyper-mcp-remote to start a fresh OAuth flow"
                        }
                        Err(clear_err) => {
                            tracing::warn!(
                                error = %clear_err,
                                "failed to clear cached credentials after refresh failure"
                            );
                            "The credential cache could not be cleared; re-run \
                             hyper-mcp-remote with --reset-auth to start a fresh OAuth flow"
                        }
                    };
                    anyhow::bail!(
                        "cached OAuth credentials are expired and the refresh \
                         token could not be exchanged ({e}). {next_step}"
                    );
                }
            }
        }
        return Ok(AuthOutcome::Authorized {
            client: into_auth_client(state, http_client)?,
        });
    }

    run_interactive_flow(cli, &mut state, &discovery).await?;
    Ok(AuthOutcome::Authorized {
        client: into_auth_client(state, http_client)?,
    })
}
/// Method-call access to the `rmcp` `CredentialStore` implementation's `load` and
/// `save`. The `rmcp` trait itself is not imported into this module.
#[async_trait::async_trait]
trait CredentialStoreExt {
    /// Loads persisted credentials via the `rmcp` `CredentialStore` implementation.
    async fn load_via_trait(
        &self,
    ) -> Result<Option<rmcp::transport::auth::StoredCredentials>, rmcp::transport::auth::AuthError>;

    /// Persists `creds` via the `rmcp` `CredentialStore` implementation.
    async fn save_via_trait(
        &self,
        creds: rmcp::transport::auth::StoredCredentials,
    ) -> Result<(), rmcp::transport::auth::AuthError>;
}

// Both methods dispatch explicitly through the trait path, so they always reach the
// `rmcp` `CredentialStore` impl for `SecureCredentialStore`.
#[async_trait::async_trait]
impl CredentialStoreExt for SecureCredentialStore {
    async fn load_via_trait(
        &self,
    ) -> Result<Option<rmcp::transport::auth::StoredCredentials>, rmcp::transport::auth::AuthError>
    {
        rmcp::transport::auth::CredentialStore::load(self).await
    }

    async fn save_via_trait(
        &self,
        stored: rmcp::transport::auth::StoredCredentials,
    ) -> Result<(), rmcp::transport::auth::AuthError> {
        rmcp::transport::auth::CredentialStore::save(self, stored).await
    }
}
async fn install_credential_store(
state: &mut OAuthState,
store: Arc<SecureCredentialStore>,
) -> Result<()> {
match state {
OAuthState::Unauthorized(manager) => {
manager.set_credential_store(ArcStore(store));
Ok(())
}
_ => anyhow::bail!(
"internal error: OAuthState must be Unauthorized when installing credential store"
),
}
}
/// Newtype that lets a shared [`SecureCredentialStore`] be handed to `rmcp` as its
/// `CredentialStore`, while the caller keeps its own `Arc` handle to the same store.
struct ArcStore(Arc<SecureCredentialStore>);

// Pure delegation: each operation forwards to the wrapped store's method of the same name.
#[async_trait::async_trait]
impl rmcp::transport::auth::CredentialStore for ArcStore {
    async fn load(
        &self,
    ) -> Result<Option<rmcp::transport::auth::StoredCredentials>, rmcp::transport::auth::AuthError>
    {
        let Self(inner) = self;
        inner.load().await
    }

    async fn save(
        &self,
        credentials: rmcp::transport::auth::StoredCredentials,
    ) -> Result<(), rmcp::transport::auth::AuthError> {
        let Self(inner) = self;
        inner.save(credentials).await
    }

    async fn clear(&self) -> Result<(), rmcp::transport::auth::AuthError> {
        let Self(inner) = self;
        inner.clear().await
    }
}
async fn run_interactive_flow(
cli: &Cli,
state: &mut OAuthState,
discovery: &OAuthDiscovery,
) -> Result<()> {
let callback = CallbackServer::bind(&cli.callback_host, cli.callback_port.unwrap_or(0))
.await
.context("failed to start local OAuth callback server")?;
let scopes = effective_scopes(cli, discovery);
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
tracing::info!(scopes = ?scopes, "starting OAuth authorization");
state
.start_authorization(
&scope_refs,
&callback.redirect_uri,
Some(cli.client_name.as_str()),
)
.await
.context("OAuth authorization start failed (dynamic registration?)")?;
let auth_url = state
.get_authorization_url()
.await
.context("failed to build authorization URL")?;
eprintln!("\nOpen this URL in your browser to authorize hyper-mcp-remote:\n{auth_url}\n");
match webbrowser::open(&auth_url) {
Ok(_) => tracing::info!("opened authorization URL in default browser"),
Err(e) => {
tracing::warn!(error = %e, "couldn't open browser automatically; please open the URL above manually")
}
}
let timeout = Duration::from_secs(cli.auth_timeout_secs);
let code = callback
.wait(timeout)
.await
.context("OAuth callback wait failed")?;
state
.handle_callback(&code.code, &code.state)
.await
.context("OAuth code exchange failed")?;
tracing::info!("OAuth authorization complete");
Ok(())
}
/// Chooses the OAuth scopes to request.
///
/// An explicit `--scope` wins: it is split on commas and/or whitespace, and empty
/// pieces are dropped (so an all-separator value yields an empty list). Without
/// `--scope`, the scopes reported by discovery are used as-is.
fn effective_scopes(cli: &Cli, discovery: &OAuthDiscovery) -> Vec<String> {
    match cli.scope.as_deref() {
        Some(raw) => raw
            .split(|ch: char| ch == ',' || ch.is_whitespace())
            .filter(|piece| !piece.is_empty())
            .map(String::from)
            .collect(),
        None => discovery.scopes.clone(),
    }
}
/// Consumes an authorized `state` and pairs its authorization manager with
/// `http_client`.
///
/// # Errors
///
/// Fails when `state` yields no authorization manager, i.e. the flow did not leave
/// it authorized.
fn into_auth_client(state: OAuthState, http_client: HttpClient) -> Result<AuthClient<HttpClient>> {
    state
        .into_authorization_manager()
        .context("OAuthState was not authorized after flow")
        .map(|manager| AuthClient::new(http_client, manager))
}
/// Safety margin in seconds: a cached access token whose remaining lifetime is below
/// this is treated as stale.
const REFRESH_BUFFER_SECS: u64 = 30;

/// Returns `true` when the cached access token has expired, or will expire within
/// [`REFRESH_BUFFER_SECS`], judged from its recorded receipt time and `expires_in`.
///
/// Returns `false` whenever that cannot be determined: there is no `token_received_at`,
/// no token response, or no `expires_in` on the token.
fn cached_access_token_is_stale(cached: &rmcp::transport::auth::StoredCredentials) -> bool {
    use oauth2::TokenResponse;

    let (Some(received_at), Some(lifetime)) = (
        cached.token_received_at,
        cached
            .token_response
            .as_ref()
            .and_then(|token| token.expires_in()),
    ) else {
        return false;
    };

    // A clock reading before the epoch falls back to 0.
    let now_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs());
    // Both subtractions saturate: a receipt time in the future counts as zero elapsed
    // time, and an overrun lifetime counts as zero remaining.
    let elapsed = now_secs.saturating_sub(received_at);
    let remaining = lifetime.as_secs().saturating_sub(elapsed);
    remaining < REFRESH_BUFFER_SECS
}
/// Builds the `reqwest` client used for discovery, OAuth, and the returned
/// [`AuthOutcome`]. Its only customisation is the user-agent below.
fn build_http_client() -> Result<HttpClient> {
    // Identifies this tool and its crate version to remote servers.
    const USER_AGENT: &str = concat!(
        "hyper-mcp-remote/",
        env!("CARGO_PKG_VERSION"),
        " (+https://github.com/hyper-mcp-rs/hyper-mcp-remote)"
    );
    let builder = HttpClient::builder().user_agent(USER_AGENT);
    builder.build().context("failed to build HTTP client")
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use clap::Parser;
    use discovery::OAuthDiscovery;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Parses `args` as if they followed the `hyper-mcp-remote` binary name.
    fn parse_cli(args: &[&str]) -> Cli {
        let mut full = vec!["hyper-mcp-remote"];
        full.extend_from_slice(args);
        Cli::parse_from(full)
    }

    /// A discovery result for a fixed auth server and resource with the given scopes.
    fn discovery_with(scopes: &[&str]) -> OAuthDiscovery {
        OAuthDiscovery {
            authorization_server: "https://auth.example.com".to_string(),
            scopes: scopes.iter().map(|s| s.to_string()).collect(),
            resource: "https://example.com/mcp".to_string(),
        }
    }

    #[test]
    fn effective_scopes_uses_cli_override_when_set() {
        let cli = parse_cli(&["--scope", "read,write", "https://example.com/mcp"]);
        let d = discovery_with(&["discovered"]);
        assert_eq!(
            effective_scopes(&cli, &d),
            vec!["read".to_string(), "write".to_string()],
            "CLI --scope must take precedence over discovery"
        );
    }

    #[test]
    fn effective_scopes_falls_back_to_discovery() {
        let cli = parse_cli(&["https://example.com/mcp"]);
        let d = discovery_with(&["a", "b", "c"]);
        assert_eq!(
            effective_scopes(&cli, &d),
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
    }

    #[test]
    fn effective_scopes_handles_mixed_whitespace_and_commas() {
        let cli = parse_cli(&[
            "--scope",
            " read , write\tadmin",
            "https://example.com/mcp",
        ]);
        let scopes = effective_scopes(&cli, &discovery_with(&[]));
        assert_eq!(
            scopes,
            vec!["read".to_string(), "write".to_string(), "admin".to_string()],
        );
    }

    #[test]
    fn effective_scopes_empty_cli_returns_empty() {
        let cli = parse_cli(&["--scope", " , , ", "https://example.com/mcp"]);
        let scopes = effective_scopes(&cli, &discovery_with(&["unused"]));
        assert!(
            scopes.is_empty(),
            "CLI override of all-whitespace must produce empty list, got {scopes:?}"
        );
    }

    /// Builds stored credentials with a bearer token, optionally carrying a receipt time
    /// and an `expires_in` lifetime.
    fn sample_stored(
        token_received_at: Option<u64>,
        expires_in_secs: Option<u64>,
    ) -> rmcp::transport::auth::StoredCredentials {
        let mut token = serde_json::json!({
            "access_token": "cached-access",
            "token_type": "bearer",
            "refresh_token": "cached-refresh",
        });
        if let Some(secs) = expires_in_secs {
            token["expires_in"] = serde_json::Value::from(secs);
        }
        let stored = serde_json::json!({
            "client_id": "client-abc",
            "token_response": token,
            "granted_scopes": [],
            "token_received_at": token_received_at,
        });
        serde_json::from_value(stored).expect("sample StoredCredentials must deserialize")
    }

    fn now_epoch_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system time before UNIX epoch")
            .as_secs()
    }

    #[test]
    fn stale_when_received_long_ago_with_short_expiry() {
        let stored = sample_stored(Some(now_epoch_secs() - 7200), Some(3600));
        assert!(
            cached_access_token_is_stale(&stored),
            "a 1h token received 2h ago must be flagged stale"
        );
    }

    #[test]
    fn stale_when_within_refresh_buffer() {
        let stored = sample_stored(
            Some(now_epoch_secs() - (3600 - (REFRESH_BUFFER_SECS - 1))),
            Some(3600),
        );
        assert!(
            cached_access_token_is_stale(&stored),
            "token within REFRESH_BUFFER_SECS of expiry must be flagged stale"
        );
    }

    #[test]
    fn stale_when_lifetime_is_shorter_than_refresh_buffer() {
        let stored = sample_stored(Some(now_epoch_secs()), Some(REFRESH_BUFFER_SECS - 1));
        assert!(
            cached_access_token_is_stale(&stored),
            "a token that never leaves the refresh buffer must be flagged stale immediately"
        );
    }

    #[test]
    fn fresh_when_well_inside_validity_window() {
        let stored = sample_stored(Some(now_epoch_secs() - 60), Some(3600));
        assert!(
            !cached_access_token_is_stale(&stored),
            "a 1h token received 1min ago must be fresh"
        );
    }

    #[test]
    fn fresh_when_received_at_is_in_the_future() {
        // Clock skew: a receipt time ahead of "now" counts as zero elapsed time.
        let stored = sample_stored(Some(now_epoch_secs() + 600), Some(3600));
        assert!(!cached_access_token_is_stale(&stored));
    }

    #[test]
    fn fresh_when_no_received_at_to_compare_against() {
        let stored = sample_stored(None, Some(3600));
        assert!(!cached_access_token_is_stale(&stored));
    }

    #[test]
    fn fresh_when_no_expires_in_to_compare_against() {
        let stored = sample_stored(Some(now_epoch_secs() - 86_400), None);
        assert!(!cached_access_token_is_stale(&stored));
    }

    #[test]
    fn build_http_client_succeeds() {
        let _c = build_http_client().expect("http client must build");
    }

    /// Serves `GET /mcp -> 200 "ok"` on an ephemeral localhost port. Returns the
    /// endpoint URL and the server task's handle.
    async fn spawn_anonymous_mock() -> (String, tokio::task::JoinHandle<()>) {
        let app = Router::new().route("/mcp", get(|| async { "ok" }));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("local_addr");
        // The socket is already listening, so a connection made before the server task
        // is first polled waits in the accept backlog; no startup delay is needed.
        let handle = tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });
        (format!("http://{addr}/mcp"), handle)
    }

    #[tokio::test]
    async fn acquire_auth_client_returns_anonymous_when_server_does_not_require_oauth() {
        let (url, _h) = spawn_anonymous_mock().await;
        let cli = parse_cli(&["--allow-http", &url]);
        let headers = HashMap::new();
        let cred_key = CredentialKey::new(&url, None);
        let outcome = acquire_auth_client(&cli, &cred_key, &headers)
            .await
            .expect("acquire");
        assert!(
            matches!(outcome, AuthOutcome::Anonymous { .. }),
            "server returned 200 — we should not have started an OAuth flow"
        );
    }

    #[tokio::test]
    async fn acquire_auth_client_short_circuits_on_no_auth_without_touching_network() {
        // Port 1 on localhost is not expected to be serving; any network I/O would fail.
        let url = "http://127.0.0.1:1/mcp";
        let cli = parse_cli(&["--no-auth", "--allow-http", url]);
        let headers = HashMap::new();
        let cred_key = CredentialKey::new(url, None);
        let outcome = acquire_auth_client(&cli, &cred_key, &headers)
            .await
            .expect("--no-auth must not perform any network I/O");
        assert!(
            matches!(outcome, AuthOutcome::Anonymous { .. }),
            "--no-auth must yield AuthOutcome::Anonymous"
        );
    }

    #[tokio::test]
    async fn arc_store_load_returns_none_on_empty() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cred_key = CredentialKey::new("https://example.com/arc-empty", None);
        let store =
            SecureCredentialStore::with_data_dir(&cred_key, dir.path()).expect("with_data_dir");
        let arc_store = ArcStore(Arc::new(store));
        let loaded = rmcp::transport::auth::CredentialStore::load(&arc_store)
            .await
            .expect("load");
        assert!(
            loaded.is_none(),
            "a store backed by a fresh, empty data dir must load no credentials"
        );
    }
}