mcp_guard_core/auth/
mod.rs

1// Copyright (c) 2025 Austin Green
2// SPDX-License-Identifier: AGPL-3.0
3//
4// This file is part of MCP-Guard.
5//
6// MCP-Guard is free software: you can redistribute it and/or modify
7// it under the terms of the GNU Affero General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10//
11// MCP-Guard is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU Affero General Public License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with MCP-Guard. If not, see <https://www.gnu.org/licenses/>.
18//! Authentication providers for mcp-guard
19//!
20//! This module provides pluggable authentication for MCP requests:
21//! - API Key: Simple hash-based key validation
22//! - JWT: HS256 (simple) or RS256/ES256 (JWKS) token validation
23//! - OAuth 2.1: Token introspection and userinfo validation with PKCE
24//! - mTLS: Client certificate authentication via reverse proxy headers
25//!
26//! All providers implement the [`AuthProvider`] trait, allowing them to be
27//! combined via [`MultiProvider`] for fallback authentication.
28
29mod jwt;
30mod mtls;
31mod oauth;
32
33pub use jwt::JwtProvider;
34pub use mtls::{
35    ClientCertInfo, MtlsAuthProvider, HEADER_CLIENT_CERT_CN, HEADER_CLIENT_CERT_VERIFIED,
36};
37pub use oauth::OAuthAuthProvider;
38
39use async_trait::async_trait;
40use std::collections::HashMap;
41use std::sync::Arc;
42
43// ============================================================================
44// Error Types
45// ============================================================================
46
47/// Authentication error type
48#[derive(Debug, thiserror::Error)]
49pub enum AuthError {
50    #[error("Missing authentication credentials")]
51    MissingCredentials,
52
53    #[error("Invalid API key")]
54    InvalidApiKey,
55
56    #[error("Invalid JWT: {0}")]
57    InvalidJwt(String),
58
59    #[error("Token expired")]
60    TokenExpired,
61
62    #[error("OAuth error: {0}")]
63    OAuth(String),
64
65    #[error("Invalid client certificate: {0}")]
66    InvalidClientCert(String),
67
68    #[error("Internal error: {0}")]
69    Internal(String),
70}
71
72// ============================================================================
73// Types
74// ============================================================================
75
76/// Authenticated identity representing a user or service that has been verified
77#[derive(Debug, Clone)]
78pub struct Identity {
79    /// Unique identifier for the user/service
80    pub id: String,
81
82    /// Display name
83    pub name: Option<String>,
84
85    /// Allowed tools (None means all allowed)
86    pub allowed_tools: Option<Vec<String>>,
87
88    /// Custom rate limit for this identity
89    pub rate_limit: Option<u32>,
90
91    /// Additional claims/metadata from the authentication token
92    pub claims: std::collections::HashMap<String, serde_json::Value>,
93}
94
95// ============================================================================
96// Utility Functions
97// ============================================================================
98
/// Map OAuth/JWT scopes to allowed tools based on a scope-to-tool mapping.
///
/// Scopes that do not appear in the mapping are ignored. The tool lists of all
/// matching scopes are merged, sorted and de-duplicated.
///
/// # Arguments
/// * `scopes` - List of scopes from the token
/// * `scope_tool_mapping` - Mapping from scope names to tool names
///
/// # Returns
/// * `None` - No restrictions: the mapping is empty, or one of the token's
///   scopes maps to a tool list containing the wildcard entry `"*"`
/// * `Some(vec![])` - No tools allowed (none of the scopes are in the mapping)
/// * `Some(tools)` - The specific tools allowed, sorted and without duplicates
pub fn map_scopes_to_tools(
    scopes: &[String],
    scope_tool_mapping: &HashMap<String, Vec<String>>,
) -> Option<Vec<String>> {
    if scope_tool_mapping.is_empty() {
        return None; // No mapping = all tools allowed
    }

    let mut tools: Vec<String> = Vec::new();
    for scope_tools in scopes
        .iter()
        .filter_map(|scope| scope_tool_mapping.get(scope))
    {
        // Compare against the literal directly instead of allocating a
        // temporary `String` for every matched scope.
        if scope_tools.iter().any(|tool| tool == "*") {
            return None; // Wildcard = all tools
        }
        tools.extend(scope_tools.iter().cloned());
    }

    // Sorting and de-duplicating an empty vector is a no-op, so the
    // "nothing matched" case naturally yields `Some(vec![])`.
    tools.sort_unstable();
    tools.dedup();
    Some(tools)
}
135
136// ============================================================================
137// Traits
138// ============================================================================
139
140/// Authentication provider trait
141#[async_trait]
142pub trait AuthProvider: Send + Sync {
143    /// Authenticate a request and return the identity
144    async fn authenticate(&self, token: &str) -> Result<Identity, AuthError>;
145
146    /// Provider name for logging and metrics
147    fn name(&self) -> &str;
148}
149
150// ============================================================================
151// Providers
152// ============================================================================
153
154/// API key authentication provider
155///
156/// Validates requests using pre-shared API keys. Keys are stored as SHA-256
157/// hashes to prevent exposure of plaintext keys in configuration.
158///
159/// SECURITY: Uses constant-time comparison to prevent timing attacks.
160pub struct ApiKeyProvider {
161    keys: Vec<crate::config::ApiKeyConfig>,
162}
163
164impl ApiKeyProvider {
165    pub fn new(configs: Vec<crate::config::ApiKeyConfig>) -> Self {
166        Self { keys: configs }
167    }
168
169    fn hash_key(key: &str) -> String {
170        use sha2::{Digest, Sha256};
171        let mut hasher = Sha256::new();
172        hasher.update(key.as_bytes());
173        base64::Engine::encode(
174            &base64::engine::general_purpose::STANDARD,
175            hasher.finalize(),
176        )
177    }
178
179    /// Constant-time comparison of two hash strings.
180    ///
181    /// SECURITY: Prevents timing attacks by ensuring comparison takes the same
182    /// amount of time regardless of where the hashes differ.
183    fn constant_time_compare(a: &str, b: &str) -> bool {
184        use subtle::ConstantTimeEq;
185
186        // First, compare lengths in constant time
187        let len_eq = a.len().ct_eq(&b.len());
188
189        // If lengths match, compare bytes in constant time
190        // If lengths differ, still compare to maintain constant time
191        let bytes_eq = if a.len() == b.len() {
192            a.as_bytes().ct_eq(b.as_bytes())
193        } else {
194            // Compare with dummy to maintain timing
195            let dummy = vec![0u8; a.len()];
196            a.as_bytes().ct_eq(&dummy)
197        };
198
199        // Both length and content must match
200        (len_eq & bytes_eq).into()
201    }
202}
203
204#[async_trait]
205impl AuthProvider for ApiKeyProvider {
206    async fn authenticate(&self, token: &str) -> Result<Identity, AuthError> {
207        let provided_hash = Self::hash_key(token);
208
209        // SECURITY: Iterate through ALL keys to prevent timing-based enumeration.
210        // The loop always runs for the same number of iterations regardless of
211        // which key matches (or if any matches at all).
212        let mut matched_config: Option<&crate::config::ApiKeyConfig> = None;
213
214        for config in &self.keys {
215            if Self::constant_time_compare(&provided_hash, &config.key_hash) {
216                matched_config = Some(config);
217                // Don't break - continue iterating to maintain constant time
218            }
219        }
220
221        matched_config
222            .map(|config| Identity {
223                id: config.id.clone(),
224                name: Some(config.id.clone()),
225                allowed_tools: if config.allowed_tools.is_empty() {
226                    None
227                } else {
228                    Some(config.allowed_tools.clone())
229                },
230                rate_limit: config.rate_limit,
231                claims: std::collections::HashMap::new(),
232            })
233            .ok_or(AuthError::InvalidApiKey)
234    }
235
236    fn name(&self) -> &str {
237        "api_key"
238    }
239}
240
241/// Combined authentication provider that tries multiple providers in sequence
242///
243/// Attempts authentication against each configured provider until one succeeds.
244/// Returns the most informative error if all providers fail (e.g., prefers
245/// "token expired" over "invalid API key").
246pub struct MultiProvider {
247    /// List of providers to try, in order of precedence
248    providers: Vec<Arc<dyn AuthProvider>>,
249}
250
251impl MultiProvider {
252    pub fn new(providers: Vec<Arc<dyn AuthProvider>>) -> Self {
253        Self { providers }
254    }
255}
256
257#[async_trait]
258impl AuthProvider for MultiProvider {
259    async fn authenticate(&self, token: &str) -> Result<Identity, AuthError> {
260        if self.providers.is_empty() {
261            return Err(AuthError::MissingCredentials);
262        }
263
264        let mut last_error: Option<AuthError> = None;
265
266        for provider in &self.providers {
267            match provider.authenticate(token).await {
268                Ok(identity) => return Ok(identity),
269                Err(e) => {
270                    // Prioritize more informative errors
271                    let should_replace = match (&last_error, &e) {
272                        (None, _) => true,
273                        // Token expired is more specific than generic errors
274                        (Some(AuthError::InvalidApiKey), AuthError::TokenExpired) => true,
275                        (Some(AuthError::InvalidApiKey), AuthError::InvalidJwt(_)) => true,
276                        (Some(AuthError::InvalidApiKey), AuthError::OAuth(_)) => true,
277                        (Some(AuthError::MissingCredentials), _) => true,
278                        // Keep the current error in other cases
279                        _ => false,
280                    };
281
282                    if should_replace {
283                        last_error = Some(e);
284                    }
285                }
286            }
287        }
288
289        Err(last_error.unwrap_or(AuthError::MissingCredentials))
290    }
291
292    fn name(&self) -> &str {
293        "multi"
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ApiKeyConfig` whose stored hash corresponds to the plaintext `key`.
    fn api_key_config(
        id: &str,
        key: &str,
        allowed_tools: &[&str],
        rate_limit: Option<u32>,
    ) -> crate::config::ApiKeyConfig {
        crate::config::ApiKeyConfig {
            id: id.to_string(),
            key_hash: ApiKeyProvider::hash_key(key),
            allowed_tools: strings(allowed_tools),
            rate_limit,
        }
    }

    /// Converts a slice of string literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// Scope mapping shared by the `map_scopes_to_tools` tests.
    fn sample_mapping() -> HashMap<String, Vec<String>> {
        let mut mapping = HashMap::new();
        mapping.insert("read".to_string(), strings(&["list", "get"]));
        mapping.insert("write".to_string(), strings(&["put", "get"]));
        mapping.insert("admin".to_string(), strings(&["*"]));
        mapping
    }

    // ------------------------------------------------------------------
    // constant_time_compare
    // ------------------------------------------------------------------

    #[test]
    fn test_constant_time_compare_equal() {
        let a = "abc123XYZ";
        let b = "abc123XYZ";
        assert!(ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_constant_time_compare_different_content() {
        let a = "abc123XYZ";
        let b = "abc123XYy"; // Last char different
        assert!(!ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = "abc123";
        let b = "abc123XYZ";
        assert!(!ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_constant_time_compare_empty() {
        assert!(ApiKeyProvider::constant_time_compare("", ""));
        assert!(!ApiKeyProvider::constant_time_compare("", "a"));
        assert!(!ApiKeyProvider::constant_time_compare("a", ""));
    }

    #[test]
    fn test_constant_time_compare_first_char_different() {
        let a = "Xbc123XYZ";
        let b = "abc123XYZ";
        assert!(!ApiKeyProvider::constant_time_compare(a, b));
    }

    // ------------------------------------------------------------------
    // map_scopes_to_tools
    // ------------------------------------------------------------------

    #[test]
    fn test_map_scopes_empty_mapping_is_unrestricted() {
        let result = map_scopes_to_tools(&strings(&["read"]), &HashMap::new());
        assert_eq!(result, None);
    }

    #[test]
    fn test_map_scopes_unmapped_scope_allows_nothing() {
        let mapping = sample_mapping();
        assert_eq!(
            map_scopes_to_tools(&strings(&["unknown"]), &mapping),
            Some(vec![])
        );
        assert_eq!(map_scopes_to_tools(&[], &mapping), Some(vec![]));
    }

    #[test]
    fn test_map_scopes_merges_sorts_and_dedups() {
        let mapping = sample_mapping();
        assert_eq!(
            map_scopes_to_tools(&strings(&["write", "read"]), &mapping),
            Some(strings(&["get", "list", "put"]))
        );
    }

    #[test]
    fn test_map_scopes_wildcard_is_unrestricted() {
        let mapping = sample_mapping();
        assert_eq!(
            map_scopes_to_tools(&strings(&["read", "admin"]), &mapping),
            None
        );
    }

    // ------------------------------------------------------------------
    // ApiKeyProvider
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_api_key_provider_valid_key() {
        let key = "test-api-key-12345";
        let provider =
            ApiKeyProvider::new(vec![api_key_config("test-user", key, &["read"], Some(100))]);

        let result = provider.authenticate(key).await;

        assert!(result.is_ok());
        let identity = result.unwrap();
        assert_eq!(identity.id, "test-user");
        assert_eq!(identity.allowed_tools, Some(vec!["read".to_string()]));
    }

    #[tokio::test]
    async fn test_api_key_provider_invalid_key() {
        let provider =
            ApiKeyProvider::new(vec![api_key_config("test-user", "valid-key", &[], None)]);

        let result = provider.authenticate("wrong-key").await;

        assert!(matches!(result, Err(AuthError::InvalidApiKey)));
    }

    #[tokio::test]
    async fn test_api_key_provider_empty_allowed_tools_is_unrestricted() {
        let key = "unrestricted-key";
        let provider = ApiKeyProvider::new(vec![api_key_config("svc", key, &[], None)]);

        let identity = provider.authenticate(key).await.unwrap();

        assert_eq!(identity.id, "svc");
        assert_eq!(identity.name, Some("svc".to_string()));
        assert_eq!(identity.allowed_tools, None);
        assert_eq!(identity.rate_limit, None);
        assert!(identity.claims.is_empty());
    }

    // ------------------------------------------------------------------
    // MultiProvider
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_multi_provider_without_providers() {
        let multi = MultiProvider::new(vec![]);

        let result = multi.authenticate("anything").await;

        assert!(matches!(result, Err(AuthError::MissingCredentials)));
    }

    #[tokio::test]
    async fn test_multi_provider_falls_back_to_later_provider() {
        let first = Arc::new(ApiKeyProvider::new(vec![api_key_config(
            "alice",
            "alice-key",
            &[],
            None,
        )])) as Arc<dyn AuthProvider>;
        let second = Arc::new(ApiKeyProvider::new(vec![api_key_config(
            "bob",
            "bob-key",
            &[],
            None,
        )])) as Arc<dyn AuthProvider>;
        let multi = MultiProvider::new(vec![first, second]);

        let identity = multi.authenticate("bob-key").await.unwrap();

        assert_eq!(identity.id, "bob");
    }

    #[tokio::test]
    async fn test_multi_provider_all_providers_fail() {
        let only = Arc::new(ApiKeyProvider::new(vec![api_key_config(
            "alice",
            "alice-key",
            &[],
            None,
        )])) as Arc<dyn AuthProvider>;
        let multi = MultiProvider::new(vec![only]);

        let result = multi.authenticate("wrong-key").await;

        assert!(matches!(result, Err(AuthError::InvalidApiKey)));
    }
}