Skip to main content

fastmcp_server/
auth.rs

1//! Authentication provider hooks for MCP servers.
2//!
3//! Auth providers are transport-agnostic and operate on the JSON-RPC
4//! request payload. They may populate [`AuthContext`] to be stored in
5//! session state for downstream handlers.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use fastmcp_core::{AccessToken, AuthContext, McpContext, McpError, McpErrorCode, McpResult};
11
12/// Authentication request view used by providers.
13#[derive(Debug, Clone, Copy)]
14pub struct AuthRequest<'a> {
15    /// JSON-RPC method name.
16    pub method: &'a str,
17    /// Raw params payload (if present).
18    pub params: Option<&'a serde_json::Value>,
19    /// Internal request ID (u64) used for tracing.
20    pub request_id: u64,
21}
22
23impl AuthRequest<'_> {
24    /// Attempts to extract an access token from the raw request params.
25    #[must_use]
26    pub fn access_token(&self) -> Option<AccessToken> {
27        extract_access_token(self.params)
28    }
29}
30
31/// Extracts an access token from request params using common field names.
32fn extract_access_token(params: Option<&serde_json::Value>) -> Option<AccessToken> {
33    let params = params?;
34    match params {
35        serde_json::Value::String(value) => AccessToken::parse(value),
36        serde_json::Value::Object(map) => {
37            if let Some(token) = extract_from_map(map) {
38                return Some(token);
39            }
40            if let Some(meta) = map.get("_meta").and_then(serde_json::Value::as_object) {
41                if let Some(token) = extract_from_map(meta) {
42                    return Some(token);
43                }
44            }
45            if let Some(headers) = map.get("headers").and_then(serde_json::Value::as_object) {
46                if let Some(token) = extract_from_map(headers) {
47                    return Some(token);
48                }
49            }
50            None
51        }
52        _ => None,
53    }
54}
55
56fn extract_from_map(map: &serde_json::Map<String, serde_json::Value>) -> Option<AccessToken> {
57    for key in [
58        "authorization",
59        "Authorization",
60        "auth",
61        "token",
62        "access_token",
63        "accessToken",
64    ] {
65        if let Some(value) = map.get(key) {
66            if let Some(token) = extract_from_value(value) {
67                return Some(token);
68            }
69        }
70    }
71    None
72}
73
74fn extract_from_value(value: &serde_json::Value) -> Option<AccessToken> {
75    match value {
76        serde_json::Value::String(value) => AccessToken::parse(value),
77        serde_json::Value::Object(map) => {
78            if let (Some(scheme), Some(token)) = (
79                map.get("scheme").and_then(serde_json::Value::as_str),
80                map.get("token").and_then(serde_json::Value::as_str),
81            ) {
82                if !scheme.trim().is_empty() && !token.trim().is_empty() {
83                    return Some(AccessToken {
84                        scheme: scheme.trim().to_string(),
85                        token: token.trim().to_string(),
86                    });
87                }
88            }
89            for key in ["authorization", "token", "access_token", "accessToken"] {
90                if let Some(value) = map.get(key).and_then(serde_json::Value::as_str) {
91                    if let Some(token) = AccessToken::parse(value) {
92                        return Some(token);
93                    }
94                }
95            }
96            None
97        }
98        _ => None,
99    }
100}
101
102/// Authentication provider interface.
103///
104/// Implementations decide whether a request is allowed and may return
105/// an [`AuthContext`] describing the authenticated subject.
106pub trait AuthProvider: Send + Sync {
107    /// Authenticate an incoming request.
108    ///
109    /// Return `Ok(AuthContext)` to allow, or an `Err(McpError)` to deny.
110    fn authenticate(&self, ctx: &McpContext, request: AuthRequest<'_>) -> McpResult<AuthContext>;
111}
112
113/// Token verifier interface used by token-based auth providers.
114pub trait TokenVerifier: Send + Sync {
115    /// Verify an access token and return an auth context if valid.
116    fn verify(
117        &self,
118        ctx: &McpContext,
119        request: AuthRequest<'_>,
120        token: &AccessToken,
121    ) -> McpResult<AuthContext>;
122}
123
124/// Token-based authentication provider.
125#[derive(Clone)]
126pub struct TokenAuthProvider {
127    verifier: Arc<dyn TokenVerifier>,
128    missing_token_error: McpError,
129}
130
131impl TokenAuthProvider {
132    /// Creates a new token auth provider with the given verifier.
133    #[must_use]
134    pub fn new<V: TokenVerifier + 'static>(verifier: V) -> Self {
135        Self {
136            verifier: Arc::new(verifier),
137            missing_token_error: auth_error("Missing access token"),
138        }
139    }
140
141    /// Overrides the error returned when a token is missing.
142    #[must_use]
143    pub fn with_missing_token_error(mut self, error: McpError) -> Self {
144        self.missing_token_error = error;
145        self
146    }
147}
148
149impl AuthProvider for TokenAuthProvider {
150    fn authenticate(&self, ctx: &McpContext, request: AuthRequest<'_>) -> McpResult<AuthContext> {
151        let access = request
152            .access_token()
153            .ok_or_else(|| self.missing_token_error.clone())?;
154        self.verifier.verify(ctx, request, &access)
155    }
156}
157
158/// Static token verifier backed by an in-memory token map.
159#[derive(Debug, Clone)]
160pub struct StaticTokenVerifier {
161    tokens: HashMap<String, AuthContext>,
162    allowed_schemes: Option<Vec<String>>,
163}
164
165impl StaticTokenVerifier {
166    /// Creates a new static verifier from a token → context map.
167    pub fn new<I, K>(tokens: I) -> Self
168    where
169        I: IntoIterator<Item = (K, AuthContext)>,
170        K: Into<String>,
171    {
172        let tokens = tokens
173            .into_iter()
174            .map(|(token, ctx)| (token.into(), ctx))
175            .collect();
176        Self {
177            tokens,
178            allowed_schemes: None,
179        }
180    }
181
182    /// Restricts accepted token schemes (case-insensitive).
183    #[must_use]
184    pub fn with_allowed_schemes<I, S>(mut self, schemes: I) -> Self
185    where
186        I: IntoIterator<Item = S>,
187        S: Into<String>,
188    {
189        self.allowed_schemes = Some(schemes.into_iter().map(Into::into).collect());
190        self
191    }
192}
193
194impl TokenVerifier for StaticTokenVerifier {
195    fn verify(
196        &self,
197        _ctx: &McpContext,
198        _request: AuthRequest<'_>,
199        token: &AccessToken,
200    ) -> McpResult<AuthContext> {
201        if let Some(allowed) = &self.allowed_schemes {
202            if !allowed
203                .iter()
204                .any(|scheme| scheme.eq_ignore_ascii_case(&token.scheme))
205            {
206                return Err(auth_error("Unsupported auth scheme"));
207            }
208        }
209
210        let Some(auth) = self.tokens.get(&token.token) else {
211            return Err(auth_error("Invalid access token"));
212        };
213
214        let mut ctx = auth.clone();
215        ctx.token.get_or_insert_with(|| token.clone());
216        Ok(ctx)
217    }
218}
219
220fn auth_error(message: impl Into<String>) -> McpError {
221    McpError::new(McpErrorCode::ResourceForbidden, message)
222}
223
#[cfg(feature = "jwt")]
mod jwt {
    use super::{AuthContext, AuthRequest, TokenVerifier, auth_error};
    use fastmcp_core::{AccessToken, McpContext, McpResult};
    use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};

    /// Token verifier that decodes and validates JWTs.
    ///
    /// The provided constructor sets it up for HMAC-SHA (HS256) tokens.
    #[derive(Clone, Debug)]
    pub struct JwtTokenVerifier {
        /// Key material used to check token signatures.
        decoding_key: DecodingKey,
        /// Validation rules applied while decoding.
        validation: Validation,
    }

    impl JwtTokenVerifier {
        /// Builds an HS256 verifier keyed by the shared `secret`.
        #[must_use]
        pub fn hs256(secret: impl AsRef<[u8]>) -> Self {
            let decoding_key = DecodingKey::from_secret(secret.as_ref());
            let validation = Validation::new(Algorithm::HS256);
            Self {
                decoding_key,
                validation,
            }
        }

        /// Swaps in custom JWT validation settings, keeping the key.
        #[must_use]
        pub fn with_validation(self, validation: Validation) -> Self {
            Self { validation, ..self }
        }
    }

    impl TokenVerifier for JwtTokenVerifier {
        /// Decodes `token` as a JWT and turns its claims into an
        /// [`AuthContext`].
        ///
        /// The subject comes from a string `sub` claim (if any) and the
        /// scopes from the `scope` / `scopes` claims. The presented token
        /// and the full claim set are stored in the returned context.
        ///
        /// # Errors
        ///
        /// * `"Unsupported auth scheme"` unless the scheme is `Bearer`
        ///   (ASCII case-insensitive).
        /// * `"Invalid token: …"` when decoding or validation fails.
        fn verify(
            &self,
            _ctx: &McpContext,
            _request: AuthRequest<'_>,
            token: &AccessToken,
        ) -> McpResult<AuthContext> {
            let is_bearer = token.scheme.eq_ignore_ascii_case("Bearer");
            if !is_bearer {
                return Err(auth_error("Unsupported auth scheme"));
            }

            let decoded =
                decode::<serde_json::Value>(&token.token, &self.decoding_key, &self.validation);
            let claims = match decoded {
                Ok(data) => data.claims,
                Err(err) => return Err(auth_error(format!("Invalid token: {err}"))),
            };

            let scopes = extract_scopes(&claims);
            let subject = claims
                .get("sub")
                .and_then(serde_json::Value::as_str)
                .map(str::to_string);

            Ok(AuthContext {
                subject,
                scopes,
                token: Some(token.clone()),
                claims: Some(claims),
            })
        }
    }

    /// Collects scopes from the claim set.
    ///
    /// Entries from a whitespace-separated `scope` string come first,
    /// followed by the string elements of a `scopes` array. Non-string
    /// array elements are skipped, and duplicates are not removed.
    fn extract_scopes(claims: &serde_json::Value) -> Vec<String> {
        let space_delimited = claims
            .get("scope")
            .and_then(serde_json::Value::as_str)
            .into_iter()
            .flat_map(str::split_whitespace);

        let listed = claims
            .get("scopes")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(serde_json::Value::as_str);

        space_delimited.chain(listed).map(str::to_string).collect()
    }
}
300
301#[cfg(feature = "jwt")]
302pub use jwt::JwtTokenVerifier;
303
/// Provider that admits every request, yielding an anonymous auth context.
#[derive(Clone, Copy, Debug, Default)]
pub struct AllowAllAuthProvider;
307
308impl AuthProvider for AllowAllAuthProvider {
309    fn authenticate(&self, _ctx: &McpContext, _request: AuthRequest<'_>) -> McpResult<AuthContext> {
310        Ok(AuthContext::anonymous())
311    }
312}