Skip to main content

openauth_plugins/mcp/client/
mod.rs

//! Small framework-neutral helpers for MCP resource servers.
2
3use http::{header, HeaderValue, Request, Response, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::future::Future;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone)]
11pub struct McpAuthClientOptions {
12    pub auth_url: String,
13    pub resource: Option<String>,
14    pub allowed_origin: Option<String>,
15    pub discovery_cache_ttl: Duration,
16    pub http_client: reqwest::Client,
17}
18
19impl Default for McpAuthClientOptions {
20    fn default() -> Self {
21        Self {
22            auth_url: String::new(),
23            resource: None,
24            allowed_origin: None,
25            discovery_cache_ttl: Duration::from_secs(60),
26            http_client: reqwest::Client::new(),
27        }
28    }
29}
30
31#[derive(Debug, Clone)]
32pub struct McpAuthClient {
33    auth_url: String,
34    resource: Option<String>,
35    allowed_origin: Option<String>,
36    discovery_cache_ttl: Duration,
37    http_client: reqwest::Client,
38    discovery_cache: Arc<Mutex<Option<CachedMetadata>>>,
39}
40
41#[derive(Debug, Clone)]
42struct CachedMetadata {
43    value: Value,
44    cached_at: Instant,
45}
46
47#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
48pub struct McpSession {
49    pub record: Value,
50}
51
52#[derive(Debug, thiserror::Error)]
53pub enum McpClientError {
54    #[error("http response build failed: {0}")]
55    Http(#[from] http::Error),
56    #[error("token verification failed: {0}")]
57    Verify(#[from] reqwest::Error),
58}
59
60#[derive(Debug, Serialize)]
61struct JsonRpcUnauthorized<'a> {
62    jsonrpc: &'static str,
63    error: JsonRpcError<'a>,
64    id: Option<&'a str>,
65}
66
67#[derive(Debug, Serialize)]
68struct JsonRpcError<'a> {
69    code: i64,
70    message: &'a str,
71    #[serde(rename = "www-authenticate")]
72    www_authenticate: &'a str,
73}
74
75impl McpAuthClient {
76    pub fn new(options: McpAuthClientOptions) -> Self {
77        Self {
78            auth_url: options.auth_url.trim_end_matches('/').to_owned(),
79            resource: options.resource,
80            allowed_origin: options.allowed_origin,
81            discovery_cache_ttl: options.discovery_cache_ttl,
82            http_client: options.http_client,
83            discovery_cache: Arc::new(Mutex::new(None)),
84        }
85    }
86
87    pub async fn verify_token(&self, token: &str) -> Result<Option<McpSession>, reqwest::Error> {
88        let response = self
89            .http_client
90            .get(format!("{}/mcp/get-session", self.auth_url))
91            .bearer_auth(token)
92            .send()
93            .await?;
94        if !response.status().is_success() {
95            return Ok(None);
96        }
97        let value = response.json::<Value>().await?;
98        if value.is_null() || value.get("userId").is_none() {
99            return Ok(None);
100        }
101        Ok(Some(McpSession { record: value }))
102    }
103
104    pub async fn discovery_metadata(&self) -> Result<Value, reqwest::Error> {
105        if let Some(cached) = self
106            .discovery_cache
107            .lock()
108            .ok()
109            .and_then(|cache| cache.clone())
110        {
111            if cached.cached_at.elapsed() < self.discovery_cache_ttl {
112                return Ok(cached.value);
113            }
114        }
115        let value = self
116            .http_client
117            .get(format!(
118                "{}/.well-known/oauth-authorization-server",
119                self.auth_url
120            ))
121            .send()
122            .await?
123            .error_for_status()?
124            .json::<Value>()
125            .await?;
126        if let Ok(mut cache) = self.discovery_cache.lock() {
127            *cache = Some(CachedMetadata {
128                value: value.clone(),
129                cached_at: Instant::now(),
130            });
131        }
132        Ok(value)
133    }
134
135    pub fn protected_resource_metadata(&self, server_url: &str) -> Value {
136        json!({
137            "resource": self.resource.clone().unwrap_or_else(|| origin_from_url(server_url)),
138            "authorization_servers": [self.auth_url.clone()],
139            "bearer_methods_supported": ["header"],
140            "scopes_supported": ["openid", "profile", "email", "offline_access"],
141        })
142    }
143
144    pub fn www_authenticate(&self) -> String {
145        let base = self.resource.as_deref().unwrap_or(&self.auth_url);
146        format!("Bearer resource_metadata=\"{base}/.well-known/oauth-protected-resource\"")
147    }
148
149    pub fn unauthorized_response(&self) -> Result<Response<Vec<u8>>, http::Error> {
150        let authenticate = self.www_authenticate();
151        let body = serde_json::to_vec(&JsonRpcUnauthorized {
152            jsonrpc: "2.0",
153            error: JsonRpcError {
154                code: -32000,
155                message: "Unauthorized: Authentication required",
156                www_authenticate: &authenticate,
157            },
158            id: None,
159        })
160        .unwrap_or_default();
161        Response::builder()
162            .status(StatusCode::UNAUTHORIZED)
163            .header(header::CONTENT_TYPE, "application/json")
164            .header(header::WWW_AUTHENTICATE, authenticate)
165            .header("Access-Control-Expose-Headers", "WWW-Authenticate")
166            .body(body)
167    }
168
169    pub fn cors_preflight_response(&self) -> Result<Response<Vec<u8>>, http::Error> {
170        Response::builder()
171            .status(StatusCode::NO_CONTENT)
172            .header(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.allowed_origin())
173            .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST, OPTIONS")
174            .header(
175                header::ACCESS_CONTROL_ALLOW_HEADERS,
176                "Content-Type, Authorization",
177            )
178            .header(header::ACCESS_CONTROL_MAX_AGE, "86400")
179            .body(Vec::new())
180    }
181
182    pub fn bearer_token<B>(&self, request: &Request<B>) -> Option<String> {
183        request
184            .headers()
185            .get(header::AUTHORIZATION)
186            .and_then(|value| value.to_str().ok())
187            .and_then(|value| value.strip_prefix("Bearer "))
188            .map(str::to_owned)
189    }
190
191    pub async fn authorize_request<B>(
192        &self,
193        request: &Request<B>,
194    ) -> Result<Option<McpSession>, reqwest::Error> {
195        let Some(token) = self.bearer_token(request) else {
196            return Ok(None);
197        };
198        self.verify_token(&token).await
199    }
200
201    pub async fn handle_request<F, Fut>(
202        &self,
203        request: Request<Vec<u8>>,
204        handler: F,
205    ) -> Result<Response<Vec<u8>>, McpClientError>
206    where
207        F: FnOnce(Request<Vec<u8>>, McpSession) -> Fut,
208        Fut: Future<Output = Result<Response<Vec<u8>>, http::Error>>,
209    {
210        if request.method() == http::Method::OPTIONS {
211            return Ok(self.cors_preflight_response()?);
212        }
213        let Some(token) = self.bearer_token(&request) else {
214            return Ok(self.unauthorized_response()?);
215        };
216        let Some(session) = self.verify_token(&token).await? else {
217            return Ok(self.unauthorized_response()?);
218        };
219        Ok(handler(request, session).await?)
220    }
221
222    fn allowed_origin(&self) -> HeaderValue {
223        if let Some(origin) = &self.allowed_origin {
224            return HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*"));
225        }
226        url::Url::parse(&self.auth_url)
227            .ok()
228            .and_then(|url| {
229                let scheme = url.scheme();
230                let host = url.host_str()?;
231                let port = url
232                    .port()
233                    .map(|port| format!(":{port}"))
234                    .unwrap_or_default();
235                HeaderValue::from_str(&format!("{scheme}://{host}{port}")).ok()
236            })
237            .unwrap_or_else(|| HeaderValue::from_static("*"))
238    }
239}
240
241fn origin_from_url(url: &str) -> String {
242    url::Url::parse(url)
243        .ok()
244        .and_then(|url| {
245            let scheme = url.scheme();
246            let host = url.host_str()?;
247            let port = url
248                .port()
249                .map(|port| format!(":{port}"))
250                .unwrap_or_default();
251            Some(format!("{scheme}://{host}{port}"))
252        })
253        .unwrap_or_else(|| url.trim_end_matches('/').to_owned())
254}