Skip to main content

a3s_code_core/mcp/
oauth.rs

//! MCP OAuth Token Exchange
//!
//! Implements the OAuth 2.0 Client Credentials flow for machine-to-machine
//! authentication with MCP servers that require bearer tokens.
5
use std::time::Duration;

use anyhow::{anyhow, Context, Result};
7
8/// Token response from an OAuth token endpoint.
9#[derive(Debug, serde::Deserialize)]
10pub struct TokenResponse {
11    pub access_token: String,
12    #[allow(dead_code)]
13    pub token_type: String,
14    #[allow(dead_code)]
15    pub expires_in: Option<u64>,
16    #[allow(dead_code)]
17    pub refresh_token: Option<String>,
18}
19
20/// Exchange client credentials for an access token (OAuth 2.0 Client Credentials flow).
21///
22/// Sends a `POST` to `token_url` with:
23/// ```text
24/// grant_type=client_credentials&client_id=...&client_secret=...&scope=...
25/// ```
26///
27/// Returns the raw `access_token` string on success.
28pub async fn exchange_client_credentials(
29    token_url: &str,
30    client_id: &str,
31    client_secret: &str,
32    scopes: &[String],
33) -> Result<String> {
34    let client = reqwest::Client::builder()
35        .build()
36        .context("Failed to build HTTP client for OAuth token exchange")?;
37
38    // Build application/x-www-form-urlencoded body
39    let scope_str = scopes.join(" ");
40    let params = [
41        ("grant_type", "client_credentials"),
42        ("client_id", client_id),
43        ("client_secret", client_secret),
44        ("scope", &scope_str),
45    ];
46
47    let response = client
48        .post(token_url)
49        .form(&params)
50        .send()
51        .await
52        .with_context(|| format!("OAuth token request to {} failed", token_url))?;
53
54    if !response.status().is_success() {
55        let status = response.status();
56        let body = response.text().await.unwrap_or_default();
57        return Err(anyhow!(
58            "OAuth token exchange failed at {} (HTTP {}): {}",
59            token_url,
60            status,
61            body
62        ));
63    }
64
65    let token_resp: TokenResponse = response
66        .json()
67        .await
68        .context("Failed to parse OAuth token response")?;
69
70    Ok(token_resp.access_token)
71}
72
// ============================================================================
// Tests
// ============================================================================
76
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal, typical response without a refresh token deserializes fully.
    #[test]
    fn test_token_response_deserialize() {
        let json = r#"{
            "access_token": "eyJhbGciOiJSUzI1NiJ9...",
            "token_type": "Bearer",
            "expires_in": 3600
        }"#;
        let resp: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.access_token, "eyJhbGciOiJSUzI1NiJ9...");
        assert_eq!(resp.token_type, "Bearer");
        assert_eq!(resp.expires_in, Some(3600));
        assert!(resp.refresh_token.is_none());
    }

    /// Every field is populated when the server also issues a refresh token.
    #[test]
    fn test_token_response_with_refresh_token() {
        let json = r#"{
            "access_token": "access123",
            "token_type": "Bearer",
            "expires_in": 7200,
            "refresh_token": "refresh456"
        }"#;
        let resp: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.access_token, "access123");
        assert_eq!(resp.token_type, "Bearer");
        assert_eq!(resp.expires_in, Some(7200));
        assert_eq!(resp.refresh_token, Some("refresh456".to_string()));
    }

    /// Extra fields such as an echoed `scope` must not break deserialization,
    /// and an absent `expires_in` maps to `None`.
    #[test]
    fn test_token_response_ignores_unknown_fields() {
        let json = r#"{
            "access_token": "abc",
            "token_type": "Bearer",
            "scope": "read write",
            "ext_expires_in": 3599
        }"#;
        let resp: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.access_token, "abc");
        assert!(resp.expires_in.is_none());
        assert!(resp.refresh_token.is_none());
    }

    /// `access_token` is required; a response without it is rejected.
    #[test]
    fn test_token_response_missing_access_token_is_error() {
        let json = r#"{ "token_type": "Bearer", "expires_in": 60 }"#;
        assert!(serde_json::from_str::<TokenResponse>(json).is_err());
    }

    /// The URL is well-formed, but nothing is expected to listen on loopback
    /// port 1, so the exchange should fail gracefully rather than panic.
    #[tokio::test]
    async fn test_exchange_client_credentials_connection_refused() {
        let result = exchange_client_credentials(
            "http://127.0.0.1:1/token",
            "client_id",
            "client_secret",
            &[],
        )
        .await;
        assert!(result.is_err());
    }
}