// adk_tool/mcp/auth.rs

// MCP Authentication Support
//
// Provides authentication mechanisms for remote MCP servers.
// Integrates with adk-auth for SSO/OAuth support.

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

10/// Authentication configuration for MCP connections
11#[derive(Clone, Default)]
12pub enum McpAuth {
13    /// No authentication required
14    #[default]
15    None,
16    /// Static bearer token
17    Bearer(String),
18    /// API key in header
19    ApiKey { header: String, key: String },
20    /// OAuth2 with automatic token refresh
21    OAuth2(Arc<OAuth2Config>),
22}
23
24impl std::fmt::Debug for McpAuth {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            McpAuth::None => write!(f, "McpAuth::None"),
28            McpAuth::Bearer(_) => write!(f, "McpAuth::Bearer([REDACTED])"),
29            McpAuth::ApiKey { header, .. } => write!(f, "McpAuth::ApiKey {{ header: {} }}", header),
30            McpAuth::OAuth2(_) => write!(f, "McpAuth::OAuth2([CONFIG])"),
31        }
32    }
33}
34
35impl McpAuth {
36    /// Create bearer token auth
37    pub fn bearer(token: impl Into<String>) -> Self {
38        McpAuth::Bearer(token.into())
39    }
40
41    /// Create API key auth
42    pub fn api_key(header: impl Into<String>, key: impl Into<String>) -> Self {
43        McpAuth::ApiKey { header: header.into(), key: key.into() }
44    }
45
46    /// Create OAuth2 auth
47    pub fn oauth2(config: OAuth2Config) -> Self {
48        McpAuth::OAuth2(Arc::new(config))
49    }
50
51    /// Get authorization headers for a request
52    pub async fn get_headers(&self) -> Result<HashMap<String, String>, AuthError> {
53        let mut headers = HashMap::new();
54
55        match self {
56            McpAuth::None => {}
57            McpAuth::Bearer(token) => {
58                headers.insert("Authorization".to_string(), format!("Bearer {}", token));
59            }
60            McpAuth::ApiKey { header, key } => {
61                headers.insert(header.clone(), key.clone());
62            }
63            McpAuth::OAuth2(config) => {
64                let token = config.get_or_refresh_token().await?;
65                headers.insert("Authorization".to_string(), format!("Bearer {}", token));
66            }
67        }
68
69        Ok(headers)
70    }
71
72    /// Check if authentication is configured
73    pub fn is_configured(&self) -> bool {
74        !matches!(self, McpAuth::None)
75    }
76}
77
78/// OAuth2 configuration for MCP authentication
79pub struct OAuth2Config {
80    /// OAuth2 client ID
81    pub client_id: String,
82    /// OAuth2 client secret (optional for public clients)
83    pub client_secret: Option<String>,
84    /// Token endpoint URL
85    pub token_url: String,
86    /// Requested scopes
87    pub scopes: Vec<String>,
88    /// Cached token with expiry
89    token_cache: RwLock<Option<CachedToken>>,
90}
91
92impl OAuth2Config {
93    /// Create a new OAuth2 config
94    pub fn new(client_id: impl Into<String>, token_url: impl Into<String>) -> Self {
95        Self {
96            client_id: client_id.into(),
97            client_secret: None,
98            token_url: token_url.into(),
99            scopes: Vec::new(),
100            token_cache: RwLock::new(None),
101        }
102    }
103
104    /// Set client secret
105    pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
106        self.client_secret = Some(secret.into());
107        self
108    }
109
110    /// Add scopes
111    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
112        self.scopes = scopes;
113        self
114    }
115
116    /// Get or refresh the access token
117    pub async fn get_or_refresh_token(&self) -> Result<String, AuthError> {
118        // Check cache first
119        {
120            let cache = self.token_cache.read().await;
121            if let Some(ref cached) = *cache {
122                if !cached.is_expired() {
123                    return Ok(cached.access_token.clone());
124                }
125            }
126        }
127
128        // Need to refresh
129        let token = self.fetch_token().await?;
130
131        // Update cache
132        {
133            let mut cache = self.token_cache.write().await;
134            *cache = Some(token.clone());
135        }
136
137        Ok(token.access_token)
138    }
139
140    /// Fetch a new token from the token endpoint
141    async fn fetch_token(&self) -> Result<CachedToken, AuthError> {
142        // Build request body
143        let mut params = vec![
144            ("grant_type", "client_credentials".to_string()),
145            ("client_id", self.client_id.clone()),
146        ];
147
148        if let Some(ref secret) = self.client_secret {
149            params.push(("client_secret", secret.clone()));
150        }
151
152        if !self.scopes.is_empty() {
153            params.push(("scope", self.scopes.join(" ")));
154        }
155
156        // Make request (using reqwest if available)
157        #[cfg(feature = "http-transport")]
158        {
159            let client = reqwest::Client::new();
160            let response = client
161                .post(&self.token_url)
162                .form(&params)
163                .send()
164                .await
165                .map_err(|e| AuthError::TokenFetch(e.to_string()))?;
166
167            if !response.status().is_success() {
168                let status = response.status();
169                let body = response.text().await.unwrap_or_default();
170                return Err(AuthError::TokenFetch(format!(
171                    "Token request failed: {} - {}",
172                    status, body
173                )));
174            }
175
176            let token_response: TokenResponse =
177                response.json().await.map_err(|e| AuthError::TokenParse(e.to_string()))?;
178
179            Ok(CachedToken::from_response(token_response))
180        }
181
182        #[cfg(not(feature = "http-transport"))]
183        {
184            Err(AuthError::NotSupported("OAuth2 requires the 'http-transport' feature".to_string()))
185        }
186    }
187
188    /// Clear the token cache (force refresh on next request)
189    pub async fn clear_cache(&self) {
190        let mut cache = self.token_cache.write().await;
191        *cache = None;
192    }
193}
194
/// Cached OAuth2 token.
#[derive(Clone)]
#[allow(dead_code)] // Used when http-transport feature is enabled
struct CachedToken {
    /// Access token to place in the `Authorization` header.
    access_token: String,
    /// Instant after which the token is treated as expired; `None` means it
    /// is never considered expired.
    expires_at: Option<std::time::Instant>,
    /// Refresh token from the token response, if one was returned.
    refresh_token: Option<String>,
}

204#[allow(dead_code)] // Used when http-transport feature is enabled
205impl CachedToken {
206    fn from_response(response: TokenResponse) -> Self {
207        let expires_at = response.expires_in.map(|secs| {
208            // Refresh 60 seconds before actual expiry
209            std::time::Instant::now() + std::time::Duration::from_secs(secs.saturating_sub(60))
210        });
211
212        Self {
213            access_token: response.access_token,
214            expires_at,
215            refresh_token: response.refresh_token,
216        }
217    }
218
219    fn is_expired(&self) -> bool {
220        match self.expires_at {
221            Some(expires_at) => std::time::Instant::now() >= expires_at,
222            None => false, // No expiry = never expires
223        }
224    }
225}
226
227/// OAuth2 token response
228#[derive(serde::Deserialize)]
229#[allow(dead_code)] // Used when http-transport feature is enabled
230struct TokenResponse {
231    access_token: String,
232    #[serde(default)]
233    expires_in: Option<u64>,
234    #[serde(default)]
235    refresh_token: Option<String>,
236    #[serde(default)]
237    token_type: Option<String>,
238}
239
/// Authentication errors.
///
/// Every variant carries a human-readable message describing the failure.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// Failed to fetch token (request could not be sent, or the token endpoint
    /// answered with a non-success status).
    TokenFetch(String),
    /// Failed to parse token response.
    TokenParse(String),
    /// Token expired and refresh failed.
    TokenExpired(String),
    /// Feature not supported in this build.
    NotSupported(String),
}

253impl std::fmt::Display for AuthError {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        match self {
256            AuthError::TokenFetch(msg) => write!(f, "Token fetch failed: {}", msg),
257            AuthError::TokenParse(msg) => write!(f, "Token parse failed: {}", msg),
258            AuthError::TokenExpired(msg) => write!(f, "Token expired: {}", msg),
259            AuthError::NotSupported(msg) => write!(f, "Not supported: {}", msg),
260        }
261    }
262}
263
264impl std::error::Error for AuthError {}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_auth_none() {
        let auth = McpAuth::None;
        assert!(!auth.is_configured());
    }

    #[test]
    fn test_mcp_auth_default_is_none() {
        assert!(matches!(McpAuth::default(), McpAuth::None));
    }

    #[test]
    fn test_mcp_auth_bearer() {
        let auth = McpAuth::bearer("test-token");
        assert!(auth.is_configured());
    }

    #[test]
    fn test_mcp_auth_api_key() {
        let auth = McpAuth::api_key("X-API-Key", "secret-key");
        assert!(auth.is_configured());
    }

    #[test]
    fn test_mcp_auth_oauth2() {
        let auth = McpAuth::oauth2(OAuth2Config::new("client-id", "https://auth.example.com/token"));
        assert!(auth.is_configured());
    }

    // Secrets must never appear in Debug output, which commonly ends up in logs.
    #[test]
    fn test_debug_redacts_secrets() {
        let bearer = format!("{:?}", McpAuth::bearer("test-token"));
        assert!(!bearer.contains("test-token"));

        let api_key = format!("{:?}", McpAuth::api_key("X-API-Key", "secret-key"));
        assert!(api_key.contains("X-API-Key"));
        assert!(!api_key.contains("secret-key"));
    }

    #[tokio::test]
    async fn test_none_headers() {
        let headers = McpAuth::None.get_headers().await.unwrap();
        assert!(headers.is_empty());
    }

    #[tokio::test]
    async fn test_bearer_headers() {
        let auth = McpAuth::bearer("my-token");
        let headers = auth.get_headers().await.unwrap();
        assert_eq!(headers.get("Authorization"), Some(&"Bearer my-token".to_string()));
    }

    #[tokio::test]
    async fn test_api_key_headers() {
        let auth = McpAuth::api_key("X-API-Key", "secret");
        let headers = auth.get_headers().await.unwrap();
        assert_eq!(headers.get("X-API-Key"), Some(&"secret".to_string()));
    }

    #[test]
    fn test_oauth2_config() {
        let config = OAuth2Config::new("client-id", "https://auth.example.com/token")
            .with_secret("client-secret")
            .with_scopes(vec!["read".to_string(), "write".to_string()]);

        assert_eq!(config.client_id, "client-id");
        assert_eq!(config.client_secret, Some("client-secret".to_string()));
        assert_eq!(config.scopes, vec!["read", "write"]);
    }
}