use super::OAuthError;
use super::credential_store::OAuthCredentialStore;
use super::handler::OAuthHandler;
use rmcp::transport::auth::{AuthClient, AuthorizationManager, OAuthState};
const OAUTH_CLIENT_NAME: &str = "Aether MCP Client";
fn create_credential_store(server_id: &str) -> OAuthCredentialStore {
OAuthCredentialStore::new(server_id)
}
pub async fn create_auth_manager_from_store(
server_id: &str,
base_url: &str,
) -> Result<Option<AuthorizationManager>, OAuthError> {
let credential_store = create_credential_store(server_id);
let mut auth_manager = AuthorizationManager::new(base_url).await.map_err(|e| OAuthError::Rmcp(e.to_string()))?;
auth_manager.set_credential_store(credential_store);
if auth_manager.initialize_from_store().await.map_err(|e| OAuthError::Rmcp(e.to_string()))? {
Ok(Some(auth_manager))
} else {
Ok(None)
}
}
pub async fn perform_oauth_flow(
server_id: &str,
base_url: &str,
handler: &dyn OAuthHandler,
) -> Result<AuthClient<reqwest::Client>, OAuthError> {
let mut oauth_state =
OAuthState::new(base_url, None).await.map_err(|e| OAuthError::Rmcp(format!("OAuth init failed: {e}")))?;
let credential_store = create_credential_store(server_id);
if let OAuthState::Unauthorized(ref mut manager) = oauth_state {
manager.set_credential_store(credential_store);
}
oauth_state
.start_authorization(&[], handler.redirect_uri(), Some(OAUTH_CLIENT_NAME))
.await
.map_err(|e| OAuthError::Rmcp(format!("start_authorization failed: {e}")))?;
let auth_url = oauth_state
.get_authorization_url()
.await
.map_err(|e| OAuthError::Rmcp(format!("get_authorization_url failed: {e}")))?;
let auth_url = dedupe_query_params(&auth_url);
let callback = handler.authorize(&auth_url).await?;
oauth_state
.handle_callback(&callback.code, &callback.state)
.await
.map_err(|e| OAuthError::Rmcp(format!("handle_callback failed: {e}")))?;
let auth_manager = oauth_state
.into_authorization_manager()
.ok_or_else(|| OAuthError::Rmcp("OAuth flow did not produce an AuthorizationManager".into()))?;
Ok(AuthClient::new(reqwest::Client::default(), auth_manager))
}
fn dedupe_query_params(url_str: &str) -> String {
let Ok(mut url) = url::Url::parse(url_str) else {
return url_str.to_string();
};
let pairs: std::collections::HashMap<String, String> =
url.query_pairs().map(|(k, v)| (k.into_owned(), v.into_owned())).collect();
url.query_pairs_mut().clear().extend_pairs(&pairs);
url.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dedupes_duplicate_resource_param() {
let input = "https://example.com/oauth/authorize?resource=https%3A%2F%2Fa%2Fb&response_type=code&client_id=x&resource=https%3A%2F%2Fa%2Fb";
let out = dedupe_query_params(input);
let url = url::Url::parse(&out).unwrap();
let resources: Vec<_> = url.query_pairs().filter(|(k, _)| k == "resource").collect();
assert_eq!(resources.len(), 1);
assert_eq!(resources[0].1, "https://a/b");
assert!(url.query_pairs().any(|(k, _)| k == "response_type"));
assert!(url.query_pairs().any(|(k, _)| k == "client_id"));
}
#[test]
fn preserves_unique_params() {
let input = "https://example.com/?resource=x&other=y";
let out = dedupe_query_params(input);
let url = url::Url::parse(&out).unwrap();
let pairs: Vec<_> = url.query_pairs().collect();
assert_eq!(pairs.len(), 2);
}
}