1use std::collections::HashMap;
8use std::sync::Arc;
9
10use fastmcp_core::{AccessToken, AuthContext, McpContext, McpError, McpErrorCode, McpResult};
11
12#[derive(Debug, Clone, Copy)]
14pub struct AuthRequest<'a> {
15 pub method: &'a str,
17 pub params: Option<&'a serde_json::Value>,
19 pub request_id: u64,
21}
22
23impl AuthRequest<'_> {
24 #[must_use]
26 pub fn access_token(&self) -> Option<AccessToken> {
27 extract_access_token(self.params)
28 }
29}
30
31fn extract_access_token(params: Option<&serde_json::Value>) -> Option<AccessToken> {
33 let params = params?;
34 match params {
35 serde_json::Value::String(value) => AccessToken::parse(value),
36 serde_json::Value::Object(map) => {
37 if let Some(token) = extract_from_map(map) {
38 return Some(token);
39 }
40 if let Some(meta) = map.get("_meta").and_then(serde_json::Value::as_object) {
41 if let Some(token) = extract_from_map(meta) {
42 return Some(token);
43 }
44 }
45 if let Some(headers) = map.get("headers").and_then(serde_json::Value::as_object) {
46 if let Some(token) = extract_from_map(headers) {
47 return Some(token);
48 }
49 }
50 None
51 }
52 _ => None,
53 }
54}
55
56fn extract_from_map(map: &serde_json::Map<String, serde_json::Value>) -> Option<AccessToken> {
57 for key in [
58 "authorization",
59 "Authorization",
60 "auth",
61 "token",
62 "access_token",
63 "accessToken",
64 ] {
65 if let Some(value) = map.get(key) {
66 if let Some(token) = extract_from_value(value) {
67 return Some(token);
68 }
69 }
70 }
71 None
72}
73
74fn extract_from_value(value: &serde_json::Value) -> Option<AccessToken> {
75 match value {
76 serde_json::Value::String(value) => AccessToken::parse(value),
77 serde_json::Value::Object(map) => {
78 if let (Some(scheme), Some(token)) = (
79 map.get("scheme").and_then(serde_json::Value::as_str),
80 map.get("token").and_then(serde_json::Value::as_str),
81 ) {
82 if !scheme.trim().is_empty() && !token.trim().is_empty() {
83 return Some(AccessToken {
84 scheme: scheme.trim().to_string(),
85 token: token.trim().to_string(),
86 });
87 }
88 }
89 for key in ["authorization", "token", "access_token", "accessToken"] {
90 if let Some(value) = map.get(key).and_then(serde_json::Value::as_str) {
91 if let Some(token) = AccessToken::parse(value) {
92 return Some(token);
93 }
94 }
95 }
96 None
97 }
98 _ => None,
99 }
100}
101
102pub trait AuthProvider: Send + Sync {
107 fn authenticate(&self, ctx: &McpContext, request: AuthRequest<'_>) -> McpResult<AuthContext>;
111}
112
113pub trait TokenVerifier: Send + Sync {
115 fn verify(
117 &self,
118 ctx: &McpContext,
119 request: AuthRequest<'_>,
120 token: &AccessToken,
121 ) -> McpResult<AuthContext>;
122}
123
124#[derive(Clone)]
126pub struct TokenAuthProvider {
127 verifier: Arc<dyn TokenVerifier>,
128 missing_token_error: McpError,
129}
130
131impl TokenAuthProvider {
132 #[must_use]
134 pub fn new<V: TokenVerifier + 'static>(verifier: V) -> Self {
135 Self {
136 verifier: Arc::new(verifier),
137 missing_token_error: auth_error("Missing access token"),
138 }
139 }
140
141 #[must_use]
143 pub fn with_missing_token_error(mut self, error: McpError) -> Self {
144 self.missing_token_error = error;
145 self
146 }
147}
148
149impl AuthProvider for TokenAuthProvider {
150 fn authenticate(&self, ctx: &McpContext, request: AuthRequest<'_>) -> McpResult<AuthContext> {
151 let access = request
152 .access_token()
153 .ok_or_else(|| self.missing_token_error.clone())?;
154 self.verifier.verify(ctx, request, &access)
155 }
156}
157
158#[derive(Debug, Clone)]
160pub struct StaticTokenVerifier {
161 tokens: HashMap<String, AuthContext>,
162 allowed_schemes: Option<Vec<String>>,
163}
164
165impl StaticTokenVerifier {
166 pub fn new<I, K>(tokens: I) -> Self
168 where
169 I: IntoIterator<Item = (K, AuthContext)>,
170 K: Into<String>,
171 {
172 let tokens = tokens
173 .into_iter()
174 .map(|(token, ctx)| (token.into(), ctx))
175 .collect();
176 Self {
177 tokens,
178 allowed_schemes: None,
179 }
180 }
181
182 #[must_use]
184 pub fn with_allowed_schemes<I, S>(mut self, schemes: I) -> Self
185 where
186 I: IntoIterator<Item = S>,
187 S: Into<String>,
188 {
189 self.allowed_schemes = Some(schemes.into_iter().map(Into::into).collect());
190 self
191 }
192}
193
194impl TokenVerifier for StaticTokenVerifier {
195 fn verify(
196 &self,
197 _ctx: &McpContext,
198 _request: AuthRequest<'_>,
199 token: &AccessToken,
200 ) -> McpResult<AuthContext> {
201 if let Some(allowed) = &self.allowed_schemes {
202 if !allowed
203 .iter()
204 .any(|scheme| scheme.eq_ignore_ascii_case(&token.scheme))
205 {
206 return Err(auth_error("Unsupported auth scheme"));
207 }
208 }
209
210 let Some(auth) = self.tokens.get(&token.token) else {
211 return Err(auth_error("Invalid access token"));
212 };
213
214 let mut ctx = auth.clone();
215 ctx.token.get_or_insert_with(|| token.clone());
216 Ok(ctx)
217 }
218}
219
220fn auth_error(message: impl Into<String>) -> McpError {
221 McpError::new(McpErrorCode::ResourceForbidden, message)
222}
223
#[cfg(feature = "jwt")]
mod jwt {
    //! JWT-backed token verification, compiled only with the `jwt` feature.

    use super::{AuthContext, AuthRequest, TokenVerifier, auth_error};
    use fastmcp_core::{AccessToken, McpContext, McpResult};
    use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};

    /// [`TokenVerifier`] that decodes and validates bearer tokens as JWTs.
    #[derive(Clone, Debug)]
    pub struct JwtTokenVerifier {
        /// Key used to check the token signature.
        decoding_key: DecodingKey,
        /// Validation settings passed to `jsonwebtoken::decode`.
        validation: Validation,
    }

    impl JwtTokenVerifier {
        /// Creates a verifier for HS256-signed tokens using `secret` as the
        /// shared key and `Validation::new(Algorithm::HS256)` as settings.
        #[must_use]
        pub fn hs256(secret: impl AsRef<[u8]>) -> Self {
            let decoding_key = DecodingKey::from_secret(secret.as_ref());
            let validation = Validation::new(Algorithm::HS256);
            Self {
                decoding_key,
                validation,
            }
        }

        /// Replaces the validation settings used when decoding tokens.
        #[must_use]
        pub fn with_validation(self, validation: Validation) -> Self {
            Self { validation, ..self }
        }
    }

    impl TokenVerifier for JwtTokenVerifier {
        /// Decodes `token` as a JWT and converts its claims into an
        /// [`AuthContext`].
        ///
        /// Only the `Bearer` scheme (ASCII case-insensitive) is accepted. The
        /// `sub` claim, when it is a string, becomes the subject; scopes come
        /// from the `scope` and `scopes` claims; the full claim set is kept
        /// alongside the presented token.
        fn verify(
            &self,
            _ctx: &McpContext,
            _request: AuthRequest<'_>,
            token: &AccessToken,
        ) -> McpResult<AuthContext> {
            if !token.scheme.eq_ignore_ascii_case("Bearer") {
                return Err(auth_error("Unsupported auth scheme"));
            }

            let claims = match decode::<serde_json::Value>(
                &token.token,
                &self.decoding_key,
                &self.validation,
            ) {
                Ok(decoded) => decoded.claims,
                Err(err) => return Err(auth_error(format!("Invalid token: {err}"))),
            };

            // A non-string `sub` is treated the same as a missing one.
            let subject = match claims.get("sub") {
                Some(serde_json::Value::String(sub)) => Some(sub.clone()),
                _ => None,
            };
            let scopes = extract_scopes(&claims);

            Ok(AuthContext {
                subject,
                scopes,
                token: Some(token.clone()),
                claims: Some(claims),
            })
        }
    }

    /// Collects scopes from the claim set.
    ///
    /// Entries from the whitespace-separated `scope` string come first,
    /// followed by the string elements of the `scopes` array. Non-string array
    /// elements are ignored and duplicates are kept.
    fn extract_scopes(claims: &serde_json::Value) -> Vec<String> {
        let from_string = claims
            .get("scope")
            .and_then(serde_json::Value::as_str)
            .into_iter()
            .flat_map(str::split_whitespace);

        let from_list = claims
            .get("scopes")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(serde_json::Value::as_str);

        from_string.chain(from_list).map(str::to_string).collect()
    }
}
300
301#[cfg(feature = "jwt")]
302pub use jwt::JwtTokenVerifier;
303
/// [`AuthProvider`] that accepts every request without inspecting it.
///
/// A zero-sized marker type: it carries no configuration and is freely
/// copyable.
#[derive(Clone, Copy, Debug, Default)]
pub struct AllowAllAuthProvider;
307
308impl AuthProvider for AllowAllAuthProvider {
309 fn authenticate(&self, _ctx: &McpContext, _request: AuthRequest<'_>) -> McpResult<AuthContext> {
310 Ok(AuthContext::anonymous())
311 }
312}