Skip to main content

mcp_utils/client/
oauth_handler.rs

1use crate::client::manager::{ElicitationRequest, McpClientEvent, OAuthHandlerContext, UrlElicitationCompleteParams};
2use aether_auth::{OAuthCallback, OAuthError, OAuthHandler, accept_oauth_callback};
3use futures::future::BoxFuture;
4use rmcp::model::{CreateElicitationRequestParams, ElicitationAction};
5use tokio::net::TcpListener;
6use tokio::sync::{mpsc, oneshot};
7
/// Elicitation id used both for the URL elicitation that carries the OAuth
/// authorization URL to the host and for the matching
/// `UrlElicitationComplete` notification sent once the redirect arrives.
pub const AETHER_OAUTH_ELICITATION_ID: &str = "aether-oauth";
9
10/// `OAuthHandler` that dispatches the OAuth authorization URL to the host
11pub struct ElicitingOAuthHandler {
12    listener: TcpListener,
13    redirect_uri: String,
14    server_name: String,
15    event_sender: mpsc::Sender<McpClientEvent>,
16}
17
18impl ElicitingOAuthHandler {
19    pub fn new(ctx: OAuthHandlerContext) -> Result<Self, std::io::Error> {
20        let (port, listener) = {
21            let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?;
22            let port = std_listener.local_addr()?.port();
23            std_listener.set_nonblocking(true)?;
24            (port, TcpListener::from_std(std_listener)?)
25        };
26
27        Ok(Self {
28            listener,
29            redirect_uri: format!("http://127.0.0.1:{port}/oauth2callback"),
30            server_name: ctx.server_name,
31            event_sender: ctx.tx,
32        })
33    }
34}
35
36impl OAuthHandler for ElicitingOAuthHandler {
37    fn redirect_uri(&self) -> &str {
38        &self.redirect_uri
39    }
40
41    fn authorize(&self, auth_url: &str) -> BoxFuture<'_, Result<OAuthCallback, OAuthError>> {
42        let auth_url = auth_url.to_string();
43        Box::pin(async move {
44            let (response_sender, response_rx) = oneshot::channel();
45            let request = ElicitationRequest {
46                server_name: self.server_name.clone(),
47                request: CreateElicitationRequestParams::UrlElicitationParams {
48                    meta: None,
49                    message: "Open this URL to authorize MCP server access.".to_string(),
50                    url: auth_url,
51                    elicitation_id: AETHER_OAUTH_ELICITATION_ID.to_string(),
52                },
53                response_sender,
54            };
55
56            self.event_sender
57                .send(McpClientEvent::Elicitation(request))
58                .await
59                .map_err(|_| OAuthError::Rmcp("OAuth prompt channel closed".to_string()))?;
60
61            let callback = tokio::select! {
62                callback = accept_oauth_callback(&self.listener) => callback,
63                response = response_rx => match response {
64                    Ok(result) if matches!(result.action, ElicitationAction::Decline | ElicitationAction::Cancel) => {
65                        Err(OAuthError::UserCancelled)
66                    }
67                    Ok(_) | Err(_) => accept_oauth_callback(&self.listener).await,
68                },
69            }?;
70
71            let complete = UrlElicitationCompleteParams {
72                server_name: self.server_name.clone(),
73                elicitation_id: AETHER_OAUTH_ELICITATION_ID.to_string(),
74            };
75
76            if self.event_sender.send(McpClientEvent::UrlElicitationComplete(complete)).await.is_err() {
77                tracing::warn!("Failed to send OAuth URL elicitation completion: receiver dropped");
78            }
79
80            Ok(callback)
81        })
82    }
83}