1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10#[derive(Clone, Default)]
12pub enum McpAuth {
13 #[default]
15 None,
16 Bearer(String),
18 ApiKey { header: String, key: String },
20 OAuth2(Arc<OAuth2Config>),
22}
23
24impl std::fmt::Debug for McpAuth {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 McpAuth::None => write!(f, "McpAuth::None"),
28 McpAuth::Bearer(_) => write!(f, "McpAuth::Bearer([REDACTED])"),
29 McpAuth::ApiKey { header, .. } => write!(f, "McpAuth::ApiKey {{ header: {} }}", header),
30 McpAuth::OAuth2(_) => write!(f, "McpAuth::OAuth2([CONFIG])"),
31 }
32 }
33}
34
35impl McpAuth {
36 pub fn bearer(token: impl Into<String>) -> Self {
38 McpAuth::Bearer(token.into())
39 }
40
41 pub fn api_key(header: impl Into<String>, key: impl Into<String>) -> Self {
43 McpAuth::ApiKey { header: header.into(), key: key.into() }
44 }
45
46 pub fn oauth2(config: OAuth2Config) -> Self {
48 McpAuth::OAuth2(Arc::new(config))
49 }
50
51 pub async fn get_headers(&self) -> Result<HashMap<String, String>, AuthError> {
53 let mut headers = HashMap::new();
54
55 match self {
56 McpAuth::None => {}
57 McpAuth::Bearer(token) => {
58 headers.insert("Authorization".to_string(), format!("Bearer {}", token));
59 }
60 McpAuth::ApiKey { header, key } => {
61 headers.insert(header.clone(), key.clone());
62 }
63 McpAuth::OAuth2(config) => {
64 let token = config.get_or_refresh_token().await?;
65 headers.insert("Authorization".to_string(), format!("Bearer {}", token));
66 }
67 }
68
69 Ok(headers)
70 }
71
72 pub fn is_configured(&self) -> bool {
74 !matches!(self, McpAuth::None)
75 }
76}
77
78pub struct OAuth2Config {
80 pub client_id: String,
82 pub client_secret: Option<String>,
84 pub token_url: String,
86 pub scopes: Vec<String>,
88 token_cache: RwLock<Option<CachedToken>>,
90}
91
92impl OAuth2Config {
93 pub fn new(client_id: impl Into<String>, token_url: impl Into<String>) -> Self {
95 Self {
96 client_id: client_id.into(),
97 client_secret: None,
98 token_url: token_url.into(),
99 scopes: Vec::new(),
100 token_cache: RwLock::new(None),
101 }
102 }
103
104 pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
106 self.client_secret = Some(secret.into());
107 self
108 }
109
110 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
112 self.scopes = scopes;
113 self
114 }
115
116 pub async fn get_or_refresh_token(&self) -> Result<String, AuthError> {
118 {
120 let cache = self.token_cache.read().await;
121 if let Some(ref cached) = *cache {
122 if !cached.is_expired() {
123 return Ok(cached.access_token.clone());
124 }
125 }
126 }
127
128 let token = self.fetch_token().await?;
130
131 {
133 let mut cache = self.token_cache.write().await;
134 *cache = Some(token.clone());
135 }
136
137 Ok(token.access_token)
138 }
139
140 async fn fetch_token(&self) -> Result<CachedToken, AuthError> {
142 let mut params = vec![
144 ("grant_type", "client_credentials".to_string()),
145 ("client_id", self.client_id.clone()),
146 ];
147
148 if let Some(ref secret) = self.client_secret {
149 params.push(("client_secret", secret.clone()));
150 }
151
152 if !self.scopes.is_empty() {
153 params.push(("scope", self.scopes.join(" ")));
154 }
155
156 #[cfg(feature = "http-transport")]
158 {
159 let client = reqwest::Client::new();
160 let response = client
161 .post(&self.token_url)
162 .form(¶ms)
163 .send()
164 .await
165 .map_err(|e| AuthError::TokenFetch(e.to_string()))?;
166
167 if !response.status().is_success() {
168 let status = response.status();
169 let body = response.text().await.unwrap_or_default();
170 return Err(AuthError::TokenFetch(format!(
171 "Token request failed: {} - {}",
172 status, body
173 )));
174 }
175
176 let token_response: TokenResponse =
177 response.json().await.map_err(|e| AuthError::TokenParse(e.to_string()))?;
178
179 Ok(CachedToken::from_response(token_response))
180 }
181
182 #[cfg(not(feature = "http-transport"))]
183 {
184 Err(AuthError::NotSupported("OAuth2 requires the 'http-transport' feature".to_string()))
185 }
186 }
187
188 pub async fn clear_cache(&self) {
190 let mut cache = self.token_cache.write().await;
191 *cache = None;
192 }
193}
194
/// An access token paired with the local instant after which it is treated as
/// expired.
// `refresh_token` is populated but never read in this module, hence the
// dead-code allowance.
#[derive(Clone)]
#[allow(dead_code)]
struct CachedToken {
    /// Token value used to build the `Authorization: Bearer …` header.
    access_token: String,
    /// Local deadline for reuse; `None` means the token never expires locally.
    expires_at: Option<std::time::Instant>,
    /// Refresh token from the response, if one was issued.
    refresh_token: Option<String>,
}
203
204#[allow(dead_code)] impl CachedToken {
206 fn from_response(response: TokenResponse) -> Self {
207 let expires_at = response.expires_in.map(|secs| {
208 std::time::Instant::now() + std::time::Duration::from_secs(secs.saturating_sub(60))
210 });
211
212 Self {
213 access_token: response.access_token,
214 expires_at,
215 refresh_token: response.refresh_token,
216 }
217 }
218
219 fn is_expired(&self) -> bool {
220 match self.expires_at {
221 Some(expires_at) => std::time::Instant::now() >= expires_at,
222 None => false, }
224 }
225}
226
227#[derive(serde::Deserialize)]
229#[allow(dead_code)] struct TokenResponse {
231 access_token: String,
232 #[serde(default)]
233 expires_in: Option<u64>,
234 #[serde(default)]
235 refresh_token: Option<String>,
236 #[serde(default)]
237 token_type: Option<String>,
238}
239
/// Errors that can occur while producing authentication headers.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// The token request could not be sent, or returned a non-success status.
    TokenFetch(String),
    /// The token endpoint's response body could not be decoded.
    TokenParse(String),
    /// A token expired (not constructed within this module).
    TokenExpired(String),
    /// The requested auth flow is unavailable in this build.
    NotSupported(String),
}

impl std::fmt::Display for AuthError {
    /// Renders `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (category, detail) = match self {
            AuthError::TokenFetch(detail) => ("Token fetch failed", detail),
            AuthError::TokenParse(detail) => ("Token parse failed", detail),
            AuthError::TokenExpired(detail) => ("Token expired", detail),
            AuthError::NotSupported(detail) => ("Not supported", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for AuthError {}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_auth_none() {
        let auth = McpAuth::None;
        assert!(!auth.is_configured());
    }

    #[test]
    fn test_mcp_auth_default_is_none() {
        assert!(!McpAuth::default().is_configured());
    }

    #[test]
    fn test_mcp_auth_bearer() {
        let auth = McpAuth::bearer("test-token");
        assert!(auth.is_configured());
    }

    #[test]
    fn test_mcp_auth_api_key() {
        let auth = McpAuth::api_key("X-API-Key", "secret-key");
        assert!(auth.is_configured());
    }

    #[tokio::test]
    async fn test_none_headers_empty() {
        let headers = McpAuth::None.get_headers().await.unwrap();
        assert!(headers.is_empty());
    }

    #[tokio::test]
    async fn test_bearer_headers() {
        let auth = McpAuth::bearer("my-token");
        let headers = auth.get_headers().await.unwrap();
        assert_eq!(headers.get("Authorization"), Some(&"Bearer my-token".to_string()));
    }

    #[tokio::test]
    async fn test_api_key_headers() {
        let auth = McpAuth::api_key("X-API-Key", "secret");
        let headers = auth.get_headers().await.unwrap();
        assert_eq!(headers.get("X-API-Key"), Some(&"secret".to_string()));
    }

    /// Secrets must never leak through `{:?}` formatting (e.g. into logs).
    #[test]
    fn test_debug_redacts_credentials() {
        let bearer = format!("{:?}", McpAuth::bearer("super-secret-token"));
        assert!(!bearer.contains("super-secret-token"));
        assert_eq!(bearer, "McpAuth::Bearer([REDACTED])");

        let api_key = format!("{:?}", McpAuth::api_key("X-API-Key", "super-secret-key"));
        assert!(!api_key.contains("super-secret-key"));
        assert!(api_key.contains("X-API-Key"));

        let oauth = format!(
            "{:?}",
            McpAuth::oauth2(
                OAuth2Config::new("id", "https://auth.example.com/token")
                    .with_secret("client-secret")
            )
        );
        assert!(!oauth.contains("client-secret"));
    }

    #[test]
    fn test_cached_token_expiry() {
        let never = CachedToken {
            access_token: "t".to_string(),
            expires_at: None,
            refresh_token: None,
        };
        assert!(!never.is_expired());

        let past = CachedToken {
            access_token: "t".to_string(),
            expires_at: Some(std::time::Instant::now()),
            refresh_token: None,
        };
        assert!(past.is_expired());
    }

    /// A valid cached token is served without contacting the token endpoint.
    #[tokio::test]
    async fn test_oauth2_headers_use_cached_token() {
        let config = OAuth2Config::new("id", "http://127.0.0.1:0/unused");
        *config.token_cache.write().await = Some(CachedToken {
            access_token: "cached".to_string(),
            expires_at: None,
            refresh_token: None,
        });

        let headers = McpAuth::oauth2(config).get_headers().await.unwrap();
        assert_eq!(headers.get("Authorization"), Some(&"Bearer cached".to_string()));
    }

    #[test]
    fn test_oauth2_config() {
        let config = OAuth2Config::new("client-id", "https://auth.example.com/token")
            .with_secret("client-secret")
            .with_scopes(vec!["read".to_string(), "write".to_string()]);

        assert_eq!(config.client_id, "client-id");
        assert_eq!(config.client_secret, Some("client-secret".to_string()));
        assert_eq!(config.scopes, vec!["read", "write"]);
    }
}