mcp_guard_core/auth/
mod.rs1mod jwt;
30mod mtls;
31mod oauth;
32
33pub use jwt::JwtProvider;
34pub use mtls::{
35 ClientCertInfo, MtlsAuthProvider, HEADER_CLIENT_CERT_CN, HEADER_CLIENT_CERT_VERIFIED,
36};
37pub use oauth::OAuthAuthProvider;
38
39use async_trait::async_trait;
40use std::collections::HashMap;
41use std::sync::Arc;
42
43#[derive(Debug, thiserror::Error)]
49pub enum AuthError {
50 #[error("Missing authentication credentials")]
51 MissingCredentials,
52
53 #[error("Invalid API key")]
54 InvalidApiKey,
55
56 #[error("Invalid JWT: {0}")]
57 InvalidJwt(String),
58
59 #[error("Token expired")]
60 TokenExpired,
61
62 #[error("OAuth error: {0}")]
63 OAuth(String),
64
65 #[error("Invalid client certificate: {0}")]
66 InvalidClientCert(String),
67
68 #[error("Internal error: {0}")]
69 Internal(String),
70}
71
72#[derive(Debug, Clone)]
78pub struct Identity {
79 pub id: String,
81
82 pub name: Option<String>,
84
85 pub allowed_tools: Option<Vec<String>>,
87
88 pub rate_limit: Option<u32>,
90
91 pub claims: std::collections::HashMap<String, serde_json::Value>,
93}
94
/// Translates granted scopes into the set of tools an identity may call.
///
/// # Arguments
/// * `scopes` - the scopes granted to the caller.
/// * `scope_tool_mapping` - configured scope → tool-name list. A list that
///   contains `"*"` grants every tool.
///
/// # Returns
/// * `None` when no mapping is configured at all, or when any granted scope
///   maps to `"*"`: no scope-based tool restriction applies.
/// * `Some(tools)` otherwise: the sorted, de-duplicated union of the tools of
///   every granted scope found in the mapping. This is `Some(vec![])` (no
///   tools) when none of the scopes appear in the mapping.
pub fn map_scopes_to_tools(
    scopes: &[String],
    scope_tool_mapping: &HashMap<String, Vec<String>>,
) -> Option<Vec<String>> {
    if scope_tool_mapping.is_empty() {
        return None;
    }

    let mut tools = Vec::new();
    for scope_tools in scopes
        .iter()
        .filter_map(|scope| scope_tool_mapping.get(scope))
    {
        // Compare against a `&str` so the wildcard check does not allocate per scope.
        if scope_tools.iter().any(|tool| tool == "*") {
            return None;
        }
        tools.extend(scope_tools.iter().cloned());
    }

    // Equal strings are indistinguishable, so an unstable (non-allocating) sort
    // yields the same result as a stable one before de-duplication.
    tools.sort_unstable();
    tools.dedup();
    Some(tools)
}
135
136#[async_trait]
142pub trait AuthProvider: Send + Sync {
143 async fn authenticate(&self, token: &str) -> Result<Identity, AuthError>;
145
146 fn name(&self) -> &str;
148}
149
150pub struct ApiKeyProvider {
161 keys: Vec<crate::config::ApiKeyConfig>,
162}
163
164impl ApiKeyProvider {
165 pub fn new(configs: Vec<crate::config::ApiKeyConfig>) -> Self {
166 Self { keys: configs }
167 }
168
169 fn hash_key(key: &str) -> String {
170 use sha2::{Digest, Sha256};
171 let mut hasher = Sha256::new();
172 hasher.update(key.as_bytes());
173 base64::Engine::encode(
174 &base64::engine::general_purpose::STANDARD,
175 hasher.finalize(),
176 )
177 }
178
179 fn constant_time_compare(a: &str, b: &str) -> bool {
184 use subtle::ConstantTimeEq;
185
186 let len_eq = a.len().ct_eq(&b.len());
188
189 let bytes_eq = if a.len() == b.len() {
192 a.as_bytes().ct_eq(b.as_bytes())
193 } else {
194 let dummy = vec![0u8; a.len()];
196 a.as_bytes().ct_eq(&dummy)
197 };
198
199 (len_eq & bytes_eq).into()
201 }
202}
203
204#[async_trait]
205impl AuthProvider for ApiKeyProvider {
206 async fn authenticate(&self, token: &str) -> Result<Identity, AuthError> {
207 let provided_hash = Self::hash_key(token);
208
209 let mut matched_config: Option<&crate::config::ApiKeyConfig> = None;
213
214 for config in &self.keys {
215 if Self::constant_time_compare(&provided_hash, &config.key_hash) {
216 matched_config = Some(config);
217 }
219 }
220
221 matched_config
222 .map(|config| Identity {
223 id: config.id.clone(),
224 name: Some(config.id.clone()),
225 allowed_tools: if config.allowed_tools.is_empty() {
226 None
227 } else {
228 Some(config.allowed_tools.clone())
229 },
230 rate_limit: config.rate_limit,
231 claims: std::collections::HashMap::new(),
232 })
233 .ok_or(AuthError::InvalidApiKey)
234 }
235
236 fn name(&self) -> &str {
237 "api_key"
238 }
239}
240
241pub struct MultiProvider {
247 providers: Vec<Arc<dyn AuthProvider>>,
249}
250
251impl MultiProvider {
252 pub fn new(providers: Vec<Arc<dyn AuthProvider>>) -> Self {
253 Self { providers }
254 }
255}
256
257#[async_trait]
258impl AuthProvider for MultiProvider {
259 async fn authenticate(&self, token: &str) -> Result<Identity, AuthError> {
260 if self.providers.is_empty() {
261 return Err(AuthError::MissingCredentials);
262 }
263
264 let mut last_error: Option<AuthError> = None;
265
266 for provider in &self.providers {
267 match provider.authenticate(token).await {
268 Ok(identity) => return Ok(identity),
269 Err(e) => {
270 let should_replace = match (&last_error, &e) {
272 (None, _) => true,
273 (Some(AuthError::InvalidApiKey), AuthError::TokenExpired) => true,
275 (Some(AuthError::InvalidApiKey), AuthError::InvalidJwt(_)) => true,
276 (Some(AuthError::InvalidApiKey), AuthError::OAuth(_)) => true,
277 (Some(AuthError::MissingCredentials), _) => true,
278 _ => false,
280 };
281
282 if should_replace {
283 last_error = Some(e);
284 }
285 }
286 }
287 }
288
289 Err(last_error.unwrap_or(AuthError::MissingCredentials))
290 }
291
292 fn name(&self) -> &str {
293 "multi"
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects string literals into owned `String`s to keep fixtures short.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Test double whose `authenticate` always fails with `make_error()`.
    struct FailingProvider {
        make_error: fn() -> AuthError,
    }

    #[async_trait::async_trait]
    impl AuthProvider for FailingProvider {
        async fn authenticate(&self, _token: &str) -> Result<Identity, AuthError> {
            Err((self.make_error)())
        }

        fn name(&self) -> &str {
            "failing"
        }
    }

    /// Wraps a `FailingProvider` as the trait object `MultiProvider` expects.
    fn failing(make_error: fn() -> AuthError) -> Arc<dyn AuthProvider> {
        Arc::new(FailingProvider { make_error })
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = "abc123XYZ";
        let b = "abc123XYZ";
        assert!(ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_constant_time_compare_different_content() {
        let a = "abc123XYZ";
        let b = "abc123XYy";
        assert!(!ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = "abc123";
        let b = "abc123XYZ";
        assert!(!ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_constant_time_compare_empty() {
        assert!(ApiKeyProvider::constant_time_compare("", ""));
        assert!(!ApiKeyProvider::constant_time_compare("", "a"));
        assert!(!ApiKeyProvider::constant_time_compare("a", ""));
    }

    #[test]
    fn test_constant_time_compare_first_char_different() {
        let a = "Xbc123XYZ";
        let b = "abc123XYZ";
        assert!(!ApiKeyProvider::constant_time_compare(a, b));
    }

    #[test]
    fn test_map_scopes_empty_mapping_is_unrestricted() {
        let mapping: HashMap<String, Vec<String>> = HashMap::new();
        assert_eq!(map_scopes_to_tools(&strings(&["read"]), &mapping), None);
    }

    #[test]
    fn test_map_scopes_wildcard_is_unrestricted() {
        let mut mapping = HashMap::new();
        mapping.insert("read".to_string(), strings(&["read_file"]));
        mapping.insert("admin".to_string(), strings(&["*"]));
        assert_eq!(
            map_scopes_to_tools(&strings(&["read", "admin"]), &mapping),
            None
        );
    }

    #[test]
    fn test_map_scopes_unmapped_scopes_grant_nothing() {
        let mut mapping = HashMap::new();
        mapping.insert("read".to_string(), strings(&["read_file"]));
        assert_eq!(
            map_scopes_to_tools(&strings(&["write"]), &mapping),
            Some(vec![])
        );
        assert_eq!(
            map_scopes_to_tools(&Vec::<String>::new(), &mapping),
            Some(vec![])
        );
    }

    #[test]
    fn test_map_scopes_union_is_sorted_and_deduped() {
        let mut mapping = HashMap::new();
        mapping.insert("read".to_string(), strings(&["read_file", "list_dir"]));
        mapping.insert("write".to_string(), strings(&["write_file", "read_file"]));
        assert_eq!(
            map_scopes_to_tools(&strings(&["write", "read"]), &mapping),
            Some(strings(&["list_dir", "read_file", "write_file"]))
        );
    }

    #[tokio::test]
    async fn test_api_key_provider_valid_key() {
        let key = "test-api-key-12345";
        let hash = ApiKeyProvider::hash_key(key);

        let config = crate::config::ApiKeyConfig {
            id: "test-user".to_string(),
            key_hash: hash,
            allowed_tools: vec!["read".to_string()],
            rate_limit: Some(100),
        };

        let provider = ApiKeyProvider::new(vec![config]);
        let result = provider.authenticate(key).await;

        assert!(result.is_ok());
        let identity = result.unwrap();
        assert_eq!(identity.id, "test-user");
        assert_eq!(identity.allowed_tools, Some(vec!["read".to_string()]));
        assert_eq!(identity.rate_limit, Some(100));
    }

    #[tokio::test]
    async fn test_api_key_provider_empty_allowed_tools_is_unrestricted() {
        let key = "unrestricted-key";
        let config = crate::config::ApiKeyConfig {
            id: "svc".to_string(),
            key_hash: ApiKeyProvider::hash_key(key),
            allowed_tools: vec![],
            rate_limit: None,
        };

        let provider = ApiKeyProvider::new(vec![config]);
        let identity = provider.authenticate(key).await.unwrap();

        assert_eq!(identity.id, "svc");
        assert_eq!(identity.name.as_deref(), Some("svc"));
        assert_eq!(identity.allowed_tools, None);
        assert_eq!(identity.rate_limit, None);
        assert!(identity.claims.is_empty());
    }

    #[tokio::test]
    async fn test_api_key_provider_invalid_key() {
        let valid_key = "valid-key";
        let hash = ApiKeyProvider::hash_key(valid_key);

        let config = crate::config::ApiKeyConfig {
            id: "test-user".to_string(),
            key_hash: hash,
            allowed_tools: vec![],
            rate_limit: None,
        };

        let provider = ApiKeyProvider::new(vec![config]);
        let result = provider.authenticate("wrong-key").await;

        assert!(matches!(result, Err(AuthError::InvalidApiKey)));
    }

    #[tokio::test]
    async fn test_multi_provider_without_providers_reports_missing_credentials() {
        let multi = MultiProvider::new(vec![]);
        assert!(matches!(
            multi.authenticate("token").await,
            Err(AuthError::MissingCredentials)
        ));
    }

    #[tokio::test]
    async fn test_multi_provider_returns_first_success() {
        let key = "multi-key";
        let api: Arc<dyn AuthProvider> = Arc::new(ApiKeyProvider::new(vec![
            crate::config::ApiKeyConfig {
                id: "svc".to_string(),
                key_hash: ApiKeyProvider::hash_key(key),
                allowed_tools: vec![],
                rate_limit: None,
            },
        ]));

        let multi = MultiProvider::new(vec![
            failing(|| AuthError::InvalidJwt("bad".to_string())),
            api,
        ]);

        let identity = multi.authenticate(key).await.unwrap();
        assert_eq!(identity.id, "svc");
    }

    #[tokio::test]
    async fn test_multi_provider_prefers_token_error_over_invalid_api_key() {
        let multi = MultiProvider::new(vec![
            failing(|| AuthError::InvalidApiKey),
            failing(|| AuthError::TokenExpired),
        ]);
        assert!(matches!(
            multi.authenticate("token").await,
            Err(AuthError::TokenExpired)
        ));
    }

    #[tokio::test]
    async fn test_multi_provider_keeps_first_specific_error() {
        let multi = MultiProvider::new(vec![
            failing(|| AuthError::InvalidJwt("first".to_string())),
            failing(|| AuthError::OAuth("second".to_string())),
        ]);
        let result = multi.authenticate("token").await;
        assert!(matches!(&result, Err(AuthError::InvalidJwt(msg)) if msg == "first"));
    }

    #[tokio::test]
    async fn test_multi_provider_replaces_missing_credentials() {
        let multi = MultiProvider::new(vec![
            failing(|| AuthError::MissingCredentials),
            failing(|| AuthError::InvalidApiKey),
        ]);
        assert!(matches!(
            multi.authenticate("token").await,
            Err(AuthError::InvalidApiKey)
        ));
    }
}