Skip to main content

aster/oauth/
mod.rs

1mod persist;
2
3use axum::extract::{Query, State};
4use axum::response::Html;
5use axum::routing::get;
6use axum::Router;
7use minijinja::render;
8use rmcp::transport::auth::{CredentialStore, OAuthState, StoredCredentials};
9use rmcp::transport::AuthorizationManager;
10use serde::Deserialize;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::{oneshot, Mutex};
14use tracing::warn;
15
16use crate::oauth::persist::AsterCredentialStore;
17
18const CALLBACK_TEMPLATE: &str = include_str!("oauth_callback.html");
19
20#[derive(Clone)]
21struct AppState {
22    code_receiver: Arc<Mutex<Option<oneshot::Sender<CallbackParams>>>>,
23}
24
25#[derive(Debug, Deserialize)]
26struct CallbackParams {
27    code: String,
28    state: String,
29}
30
31pub async fn oauth_flow(
32    mcp_server_url: &String,
33    name: &String,
34) -> Result<AuthorizationManager, anyhow::Error> {
35    let credential_store = AsterCredentialStore::new(name.clone());
36    let mut auth_manager = AuthorizationManager::new(mcp_server_url).await?;
37    auth_manager.set_credential_store(credential_store.clone());
38
39    if auth_manager.initialize_from_store().await? {
40        if auth_manager.refresh_token().await.is_ok() {
41            return Ok(auth_manager);
42        }
43
44        if let Err(e) = credential_store.clear().await {
45            warn!("error clearing bad credentials: {}", e);
46        }
47    }
48
49    // No existing credentials or they were invalid - need to do the full oauth flow
50    let (code_sender, code_receiver) = oneshot::channel::<CallbackParams>();
51    let app_state = AppState {
52        code_receiver: Arc::new(Mutex::new(Some(code_sender))),
53    };
54
55    let rendered = render!(CALLBACK_TEMPLATE, name => name);
56    let handler = move |Query(params): Query<CallbackParams>, State(state): State<AppState>| {
57        let rendered = rendered.clone();
58        async move {
59            if let Some(sender) = state.code_receiver.lock().await.take() {
60                let _ = sender.send(params);
61            }
62            Html(rendered)
63        }
64    };
65    let app = Router::new()
66        .route("/oauth_callback", get(handler))
67        .with_state(app_state);
68
69    let addr = SocketAddr::from(([127, 0, 0, 1], 0));
70    let listener = tokio::net::TcpListener::bind(addr).await?;
71    let used_addr = listener.local_addr()?;
72    tokio::spawn(async move {
73        let result = axum::serve(listener, app).await;
74        if let Err(e) = result {
75            eprintln!("Callback server error: {}", e);
76        }
77    });
78
79    let mut oauth_state = OAuthState::new(mcp_server_url, None).await?;
80
81    let redirect_uri = format!("http://localhost:{}/oauth_callback", used_addr.port());
82    oauth_state
83        .start_authorization(&[], redirect_uri.as_str(), Some("aster"))
84        .await?;
85
86    let authorization_url = oauth_state.get_authorization_url().await?;
87    if webbrowser::open(authorization_url.as_str()).is_err() {
88        eprintln!("Open the following URL to authorize {}:", name);
89        eprintln!("  {}", authorization_url);
90    }
91
92    let CallbackParams {
93        code: auth_code,
94        state: csrf_token,
95    } = code_receiver.await?;
96    oauth_state.handle_callback(&auth_code, &csrf_token).await?;
97
98    let (client_id, token_response) = oauth_state.get_credentials().await?;
99
100    let mut auth_manager = oauth_state
101        .into_authorization_manager()
102        .ok_or_else(|| anyhow::anyhow!("Failed to get authorization manager"))?;
103
104    credential_store
105        .save(StoredCredentials {
106            client_id,
107            token_response,
108        })
109        .await?;
110
111    auth_manager.set_credential_store(credential_store);
112
113    Ok(auth_manager)
114}