openauth_plugins/mcp/client/
mod.rs1use http::{header, HeaderValue, Request, Response, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::future::Future;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone)]
11pub struct McpAuthClientOptions {
12 pub auth_url: String,
13 pub resource: Option<String>,
14 pub allowed_origin: Option<String>,
15 pub discovery_cache_ttl: Duration,
16 pub http_client: reqwest::Client,
17}
18
19impl Default for McpAuthClientOptions {
20 fn default() -> Self {
21 Self {
22 auth_url: String::new(),
23 resource: None,
24 allowed_origin: None,
25 discovery_cache_ttl: Duration::from_secs(60),
26 http_client: reqwest::Client::new(),
27 }
28 }
29}
30
31#[derive(Debug, Clone)]
32pub struct McpAuthClient {
33 auth_url: String,
34 resource: Option<String>,
35 allowed_origin: Option<String>,
36 discovery_cache_ttl: Duration,
37 http_client: reqwest::Client,
38 discovery_cache: Arc<Mutex<Option<CachedMetadata>>>,
39}
40
41#[derive(Debug, Clone)]
42struct CachedMetadata {
43 value: Value,
44 cached_at: Instant,
45}
46
47#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
48pub struct McpSession {
49 pub record: Value,
50}
51
52#[derive(Debug, thiserror::Error)]
53pub enum McpClientError {
54 #[error("http response build failed: {0}")]
55 Http(#[from] http::Error),
56 #[error("token verification failed: {0}")]
57 Verify(#[from] reqwest::Error),
58}
59
60#[derive(Debug, Serialize)]
61struct JsonRpcUnauthorized<'a> {
62 jsonrpc: &'static str,
63 error: JsonRpcError<'a>,
64 id: Option<&'a str>,
65}
66
67#[derive(Debug, Serialize)]
68struct JsonRpcError<'a> {
69 code: i64,
70 message: &'a str,
71 #[serde(rename = "www-authenticate")]
72 www_authenticate: &'a str,
73}
74
75impl McpAuthClient {
76 pub fn new(options: McpAuthClientOptions) -> Self {
77 Self {
78 auth_url: options.auth_url.trim_end_matches('/').to_owned(),
79 resource: options.resource,
80 allowed_origin: options.allowed_origin,
81 discovery_cache_ttl: options.discovery_cache_ttl,
82 http_client: options.http_client,
83 discovery_cache: Arc::new(Mutex::new(None)),
84 }
85 }
86
87 pub async fn verify_token(&self, token: &str) -> Result<Option<McpSession>, reqwest::Error> {
88 let response = self
89 .http_client
90 .get(format!("{}/mcp/get-session", self.auth_url))
91 .bearer_auth(token)
92 .send()
93 .await?;
94 if !response.status().is_success() {
95 return Ok(None);
96 }
97 let value = response.json::<Value>().await?;
98 if value.is_null() || value.get("userId").is_none() {
99 return Ok(None);
100 }
101 Ok(Some(McpSession { record: value }))
102 }
103
104 pub async fn discovery_metadata(&self) -> Result<Value, reqwest::Error> {
105 if let Some(cached) = self
106 .discovery_cache
107 .lock()
108 .ok()
109 .and_then(|cache| cache.clone())
110 {
111 if cached.cached_at.elapsed() < self.discovery_cache_ttl {
112 return Ok(cached.value);
113 }
114 }
115 let value = self
116 .http_client
117 .get(format!(
118 "{}/.well-known/oauth-authorization-server",
119 self.auth_url
120 ))
121 .send()
122 .await?
123 .error_for_status()?
124 .json::<Value>()
125 .await?;
126 if let Ok(mut cache) = self.discovery_cache.lock() {
127 *cache = Some(CachedMetadata {
128 value: value.clone(),
129 cached_at: Instant::now(),
130 });
131 }
132 Ok(value)
133 }
134
135 pub fn protected_resource_metadata(&self, server_url: &str) -> Value {
136 json!({
137 "resource": self.resource.clone().unwrap_or_else(|| origin_from_url(server_url)),
138 "authorization_servers": [self.auth_url.clone()],
139 "bearer_methods_supported": ["header"],
140 "scopes_supported": ["openid", "profile", "email", "offline_access"],
141 })
142 }
143
144 pub fn www_authenticate(&self) -> String {
145 let base = self.resource.as_deref().unwrap_or(&self.auth_url);
146 format!("Bearer resource_metadata=\"{base}/.well-known/oauth-protected-resource\"")
147 }
148
149 pub fn unauthorized_response(&self) -> Result<Response<Vec<u8>>, http::Error> {
150 let authenticate = self.www_authenticate();
151 let body = serde_json::to_vec(&JsonRpcUnauthorized {
152 jsonrpc: "2.0",
153 error: JsonRpcError {
154 code: -32000,
155 message: "Unauthorized: Authentication required",
156 www_authenticate: &authenticate,
157 },
158 id: None,
159 })
160 .unwrap_or_default();
161 Response::builder()
162 .status(StatusCode::UNAUTHORIZED)
163 .header(header::CONTENT_TYPE, "application/json")
164 .header(header::WWW_AUTHENTICATE, authenticate)
165 .header("Access-Control-Expose-Headers", "WWW-Authenticate")
166 .body(body)
167 }
168
169 pub fn cors_preflight_response(&self) -> Result<Response<Vec<u8>>, http::Error> {
170 Response::builder()
171 .status(StatusCode::NO_CONTENT)
172 .header(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.allowed_origin())
173 .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST, OPTIONS")
174 .header(
175 header::ACCESS_CONTROL_ALLOW_HEADERS,
176 "Content-Type, Authorization",
177 )
178 .header(header::ACCESS_CONTROL_MAX_AGE, "86400")
179 .body(Vec::new())
180 }
181
182 pub fn bearer_token<B>(&self, request: &Request<B>) -> Option<String> {
183 request
184 .headers()
185 .get(header::AUTHORIZATION)
186 .and_then(|value| value.to_str().ok())
187 .and_then(|value| value.strip_prefix("Bearer "))
188 .map(str::to_owned)
189 }
190
191 pub async fn authorize_request<B>(
192 &self,
193 request: &Request<B>,
194 ) -> Result<Option<McpSession>, reqwest::Error> {
195 let Some(token) = self.bearer_token(request) else {
196 return Ok(None);
197 };
198 self.verify_token(&token).await
199 }
200
201 pub async fn handle_request<F, Fut>(
202 &self,
203 request: Request<Vec<u8>>,
204 handler: F,
205 ) -> Result<Response<Vec<u8>>, McpClientError>
206 where
207 F: FnOnce(Request<Vec<u8>>, McpSession) -> Fut,
208 Fut: Future<Output = Result<Response<Vec<u8>>, http::Error>>,
209 {
210 if request.method() == http::Method::OPTIONS {
211 return Ok(self.cors_preflight_response()?);
212 }
213 let Some(token) = self.bearer_token(&request) else {
214 return Ok(self.unauthorized_response()?);
215 };
216 let Some(session) = self.verify_token(&token).await? else {
217 return Ok(self.unauthorized_response()?);
218 };
219 Ok(handler(request, session).await?)
220 }
221
222 fn allowed_origin(&self) -> HeaderValue {
223 if let Some(origin) = &self.allowed_origin {
224 return HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*"));
225 }
226 url::Url::parse(&self.auth_url)
227 .ok()
228 .and_then(|url| {
229 let scheme = url.scheme();
230 let host = url.host_str()?;
231 let port = url
232 .port()
233 .map(|port| format!(":{port}"))
234 .unwrap_or_default();
235 HeaderValue::from_str(&format!("{scheme}://{host}{port}")).ok()
236 })
237 .unwrap_or_else(|| HeaderValue::from_static("*"))
238 }
239}
240
241fn origin_from_url(url: &str) -> String {
242 url::Url::parse(url)
243 .ok()
244 .and_then(|url| {
245 let scheme = url.scheme();
246 let host = url.host_str()?;
247 let port = url
248 .port()
249 .map(|port| format!(":{port}"))
250 .unwrap_or_default();
251 Some(format!("{scheme}://{host}{port}"))
252 })
253 .unwrap_or_else(|| url.trim_end_matches('/').to_owned())
254}