1mod persist;
2
3use axum::extract::{Query, State};
4use axum::response::Html;
5use axum::routing::get;
6use axum::Router;
7use minijinja::render;
8use rmcp::transport::auth::{CredentialStore, OAuthState, StoredCredentials};
9use rmcp::transport::AuthorizationManager;
10use serde::Deserialize;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::{oneshot, Mutex};
14use tracing::warn;
15
16use crate::oauth::persist::AsterCredentialStore;
17
/// HTML page returned to the browser by the `/oauth_callback` route.
///
/// Embedded at compile time from `oauth_callback.html` (resolved relative to this source
/// file) and rendered with minijinja in `oauth_flow`, which supplies a single `name`
/// variable (the server name passed to `oauth_flow`).
const CALLBACK_TEMPLATE: &str = include_str!("oauth_callback.html");
19
20#[derive(Clone)]
21struct AppState {
22 code_receiver: Arc<Mutex<Option<oneshot::Sender<CallbackParams>>>>,
23}
24
25#[derive(Debug, Deserialize)]
26struct CallbackParams {
27 code: String,
28 state: String,
29}
30
31pub async fn oauth_flow(
32 mcp_server_url: &String,
33 name: &String,
34) -> Result<AuthorizationManager, anyhow::Error> {
35 let credential_store = AsterCredentialStore::new(name.clone());
36 let mut auth_manager = AuthorizationManager::new(mcp_server_url).await?;
37 auth_manager.set_credential_store(credential_store.clone());
38
39 if auth_manager.initialize_from_store().await? {
40 if auth_manager.refresh_token().await.is_ok() {
41 return Ok(auth_manager);
42 }
43
44 if let Err(e) = credential_store.clear().await {
45 warn!("error clearing bad credentials: {}", e);
46 }
47 }
48
49 let (code_sender, code_receiver) = oneshot::channel::<CallbackParams>();
51 let app_state = AppState {
52 code_receiver: Arc::new(Mutex::new(Some(code_sender))),
53 };
54
55 let rendered = render!(CALLBACK_TEMPLATE, name => name);
56 let handler = move |Query(params): Query<CallbackParams>, State(state): State<AppState>| {
57 let rendered = rendered.clone();
58 async move {
59 if let Some(sender) = state.code_receiver.lock().await.take() {
60 let _ = sender.send(params);
61 }
62 Html(rendered)
63 }
64 };
65 let app = Router::new()
66 .route("/oauth_callback", get(handler))
67 .with_state(app_state);
68
69 let addr = SocketAddr::from(([127, 0, 0, 1], 0));
70 let listener = tokio::net::TcpListener::bind(addr).await?;
71 let used_addr = listener.local_addr()?;
72 tokio::spawn(async move {
73 let result = axum::serve(listener, app).await;
74 if let Err(e) = result {
75 eprintln!("Callback server error: {}", e);
76 }
77 });
78
79 let mut oauth_state = OAuthState::new(mcp_server_url, None).await?;
80
81 let redirect_uri = format!("http://localhost:{}/oauth_callback", used_addr.port());
82 oauth_state
83 .start_authorization(&[], redirect_uri.as_str(), Some("aster"))
84 .await?;
85
86 let authorization_url = oauth_state.get_authorization_url().await?;
87 if webbrowser::open(authorization_url.as_str()).is_err() {
88 eprintln!("Open the following URL to authorize {}:", name);
89 eprintln!(" {}", authorization_url);
90 }
91
92 let CallbackParams {
93 code: auth_code,
94 state: csrf_token,
95 } = code_receiver.await?;
96 oauth_state.handle_callback(&auth_code, &csrf_token).await?;
97
98 let (client_id, token_response) = oauth_state.get_credentials().await?;
99
100 let mut auth_manager = oauth_state
101 .into_authorization_manager()
102 .ok_or_else(|| anyhow::anyhow!("Failed to get authorization manager"))?;
103
104 credential_store
105 .save(StoredCredentials {
106 client_id,
107 token_response,
108 })
109 .await?;
110
111 auth_manager.set_credential_store(credential_store);
112
113 Ok(auth_manager)
114}