dynamic-mcp 1.5.0

MCP proxy server that reduces LLM context overhead with on-demand tool loading from multiple upstream servers.
use anyhow::{anyhow, bail, Context, Result};
use chrono::Utc;
use oauth2::{
    basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge,
    RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr, TcpListener};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use url::Url;

use super::store::{AuthStore, OAuthTokens};

const CALLBACK_PATH: &str = "/oauth/callback";
const DISCOVERY_PATH: &str = "/.well-known/oauth-authorization-server";

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthServerMetadata {
    pub authorization_endpoint: String,
    pub token_endpoint: String,
    #[serde(default)]
    pub scopes_supported: Vec<String>,
}

pub struct OAuthClient {
    store: AuthStore,
}

impl OAuthClient {
    pub fn new() -> Result<Self> {
        Ok(Self {
            store: AuthStore::new()?,
        })
    }

    pub async fn discover_oauth_endpoints(server_url: &str) -> Result<OAuthServerMetadata> {
        let base_url = Url::parse(server_url).context("Invalid server URL")?;
        let discovery_url = base_url
            .join(DISCOVERY_PATH)
            .context("Failed to construct discovery URL")?;

        tracing::debug!("Discovering OAuth endpoints at: {}", discovery_url);

        let client = reqwest::Client::new();
        let response = client
            .get(discovery_url.as_str())
            .send()
            .await
            .context("Failed to fetch OAuth discovery endpoint")?;

        if !response.status().is_success() {
            bail!("OAuth discovery failed with status: {}", response.status());
        }

        let metadata: OAuthServerMetadata = response
            .json()
            .await
            .context("Failed to parse OAuth discovery response")?;

        tracing::debug!("Discovered OAuth endpoints: {:?}", metadata);
        Ok(metadata)
    }

    pub async fn authenticate(
        &self,
        server_name: &str,
        server_url: &str,
        client_id: &str,
        scopes: Option<Vec<String>>,
    ) -> Result<OAuthTokens> {
        let existing_token = self.store.load_token(server_name).await?;

        if let Some(token) = existing_token {
            if !token.is_expired() {
                tracing::debug!("Using existing valid token for {}", server_name);
                return Ok(token);
            }

            if token.needs_refresh() {
                if let Some(refresh_token) = &token.refresh_token {
                    tracing::info!("Refreshing expired token for {}", server_name);
                    match self
                        .refresh_token(server_name, server_url, client_id, refresh_token)
                        .await
                    {
                        Ok(new_token) => return Ok(new_token),
                        Err(e) => {
                            tracing::warn!("Token refresh failed: {}, re-authenticating", e);
                        }
                    }
                }
            }
        }

        tracing::info!("Performing OAuth authentication for {}", server_name);
        let metadata = Self::discover_oauth_endpoints(server_url).await?;

        let tokens = self
            .perform_oauth_flow(server_name, server_url, &metadata, client_id, scopes)
            .await?;

        self.store.save_token(server_name, &tokens).await?;
        Ok(tokens)
    }

    async fn perform_oauth_flow(
        &self,
        server_name: &str,
        server_url: &str,
        metadata: &OAuthServerMetadata,
        client_id: &str,
        scopes: Option<Vec<String>>,
    ) -> Result<OAuthTokens> {
        let (listener, redirect_url) = Self::create_callback_server()?;

        let oauth_client = BasicClient::new(
            ClientId::new(client_id.to_string()),
            None,
            AuthUrl::new(metadata.authorization_endpoint.clone())
                .context("Invalid authorization endpoint")?,
            Some(TokenUrl::new(metadata.token_endpoint.clone()).context("Invalid token endpoint")?),
        )
        .set_redirect_uri(redirect_url);

        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();

        let mut auth_request = oauth_client
            .authorize_url(CsrfToken::new_random)
            .set_pkce_challenge(pkce_challenge)
            .add_extra_param("resource", server_url);

        if let Some(scope_list) = scopes {
            for scope in scope_list {
                auth_request = auth_request.add_scope(Scope::new(scope));
            }
        }

        let (auth_url, csrf_token) = auth_request.url();

        tracing::info!("Opening browser for OAuth authorization: {}", auth_url);
        open::that(auth_url.as_str()).context("Failed to open browser")?;

        let (code, state) = Self::wait_for_callback(listener).await?;

        if &state != csrf_token.secret() {
            bail!("CSRF token mismatch");
        }

        tracing::debug!("Exchanging authorization code for token");
        let token_result = oauth_client
            .exchange_code(AuthorizationCode::new(code))
            .set_pkce_verifier(pkce_verifier)
            .request_async(oauth2::reqwest::async_http_client)
            .await
            .context("Token exchange failed")?;

        let expires_at = token_result
            .expires_in()
            .map(|duration| Utc::now() + chrono::Duration::seconds(duration.as_secs() as i64));

        let tokens = OAuthTokens {
            access_token: token_result.access_token().secret().clone(),
            refresh_token: token_result.refresh_token().map(|t| t.secret().clone()),
            expires_at,
        };

        tracing::info!("Successfully authenticated for {}", server_name);
        Ok(tokens)
    }

    async fn refresh_token(
        &self,
        server_name: &str,
        server_url: &str,
        client_id: &str,
        refresh_token: &str,
    ) -> Result<OAuthTokens> {
        let metadata = Self::discover_oauth_endpoints(server_url).await?;

        let oauth_client = BasicClient::new(
            ClientId::new(client_id.to_string()),
            None,
            AuthUrl::new(metadata.authorization_endpoint.clone())
                .context("Invalid authorization endpoint")?,
            Some(TokenUrl::new(metadata.token_endpoint.clone()).context("Invalid token endpoint")?),
        );

        let token_result = oauth_client
            .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
            .request_async(oauth2::reqwest::async_http_client)
            .await
            .context("Token refresh failed")?;

        let expires_at = token_result
            .expires_in()
            .map(|duration| Utc::now() + chrono::Duration::seconds(duration.as_secs() as i64));

        // OAuth2 token rotation: Use new refresh token if provided, keep old one otherwise
        let new_refresh_token = token_result
            .refresh_token()
            .map(|t| t.secret().clone())
            .or_else(|| Some(refresh_token.to_string()));

        let tokens = OAuthTokens {
            access_token: token_result.access_token().secret().clone(),
            refresh_token: new_refresh_token.clone(),
            expires_at,
        };

        self.store.save_token(server_name, &tokens).await?;

        if token_result.refresh_token().is_some() {
            tracing::info!(
                "Successfully refreshed token for {} (received new refresh token)",
                server_name
            );
        } else {
            tracing::info!(
                "Successfully refreshed token for {} (reusing refresh token)",
                server_name
            );
        }

        Ok(tokens)
    }

    fn create_callback_server() -> Result<(TcpListener, RedirectUrl)> {
        let listener = TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .context("Failed to bind callback server")?;

        let port = listener.local_addr()?.port();
        let redirect_url = RedirectUrl::new(format!("http://localhost:{}{}", port, CALLBACK_PATH))
            .context("Failed to create redirect URL")?;

        tracing::debug!("Callback server listening on port {}", port);
        Ok((listener, redirect_url))
    }

    async fn wait_for_callback(listener: TcpListener) -> Result<(String, String)> {
        listener
            .set_nonblocking(true)
            .context("Failed to set non-blocking")?;

        let listener = tokio::net::TcpListener::from_std(listener)
            .context("Failed to convert to tokio listener")?;

        let (stream, _) = listener
            .accept()
            .await
            .context("Failed to accept connection")?;

        Self::handle_callback(stream).await
    }

    async fn handle_callback(mut stream: TcpStream) -> Result<(String, String)> {
        let mut reader = BufReader::new(&mut stream);
        let mut request_line = String::new();
        reader
            .read_line(&mut request_line)
            .await
            .context("Failed to read request line")?;

        let parts: Vec<&str> = request_line.split_whitespace().collect();
        if parts.len() < 2 {
            bail!("Invalid HTTP request");
        }

        let path_and_query = parts[1];
        let url = Url::parse(&format!("http://localhost{}", path_and_query))
            .context("Failed to parse callback URL")?;

        let params: HashMap<_, _> = url.query_pairs().collect();

        let code = params
            .get("code")
            .ok_or_else(|| anyhow!("No code parameter in callback"))?
            .to_string();

        let state = params
            .get("state")
            .ok_or_else(|| anyhow!("No state parameter in callback"))?
            .to_string();

        let response = "HTTP/1.1 200 OK\r\n\
                       Content-Type: text/html\r\n\
                       \r\n\
                       <html><body>\
                       <h1>Authentication Successful</h1>\
                       <p>You can close this window and return to the application.</p>\
                       </body></html>";

        stream
            .write_all(response.as_bytes())
            .await
            .context("Failed to write response")?;

        stream.flush().await.context("Failed to flush stream")?;

        Ok((code, state))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_callback_server() {
        let result = OAuthClient::create_callback_server();
        assert!(result.is_ok());

        let (_listener, redirect_url) = result.unwrap();
        assert!(redirect_url.as_str().starts_with("http://localhost:"));
        assert!(redirect_url.as_str().ends_with("/oauth/callback"));
    }

    #[tokio::test]
    async fn test_oauth_client_creation() {
        let client = OAuthClient::new();
        assert!(client.is_ok());
    }
}