Skip to main content

oxi/
mcp_credentials.rs

1//! File-backed MCP credential provider (v2.2).
2//!
//! Implements [`McpCredentialProvider`] using the OAuth2 `client_credentials`
//! grant. For each server with an [`OAuthConfig`] in `mcp.json`, the provider
//! POSTs to `token_url` to exchange `client_id`/`client_secret` for an access
//! token, caches it in `mcp-tokens.json` inside the config directory passed to
//! [`FileMcpCredentialProvider::new`] (e.g. `~/.config/oxi/`), and returns it
//! from [`McpCredentialProvider::access_token`]. A token is refreshed when it
//! is missing or within 30 seconds of the expiry derived from `expires_in`.
9//!
//! The browser-based authorization-code flow (with a local callback server)
//! is **not** implemented in v2.2. Servers that only support user-interactive
//! consent, and therefore issue no machine `client_id` / `client_secret`,
//! cannot be authenticated automatically. For now, use the server's own
//! hosted login to create a machine client and put its credentials in the
//! server's `oauth` config.
15//!
16//! Trigger a manual refresh via `/mcp reauth <server>` in the TUI.
17
18use anyhow::{Context, Result};
19use oxi_agent::mcp::auth::{Credential, McpCredentialProvider};
20use oxi_agent::mcp::types::OAuthConfig;
21use parking_lot::RwLock;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::path::PathBuf;
25use std::sync::Arc;
26use std::time::{Duration, SystemTime, UNIX_EPOCH};
27
28const TOKEN_FILE: &str = "mcp-tokens.json";
29
30#[derive(Debug, Default, Clone, Serialize, Deserialize)]
31struct StoredToken {
32    access_token: String,
33    /// Unix seconds at which the token is known to be expired.
34    /// `None` means unknown / never expires (rely on refresh errors).
35    expires_at: Option<u64>,
36}
37
38#[derive(Debug, Default, Clone, Serialize, Deserialize)]
39struct TokenStore {
40    #[serde(default)]
41    tokens: HashMap<String, StoredToken>,
42}
43
44/// File-backed credential provider. Constructed once in the bootstrap
45/// and shared with [`McpManager`] via `Arc<dyn McpCredentialProvider>`.
46pub struct FileMcpCredentialProvider {
47    /// Per-server OAuth client config (from `mcp.json`).
48    oauth: HashMap<String, OAuthConfig>,
49    /// Cached tokens, loaded from disk on construction and saved
50    /// after every successful refresh (atomic temp + rename).
51    store: RwLock<TokenStore>,
52    store_path: PathBuf,
53    http: reqwest::Client,
54}
55
56impl FileMcpCredentialProvider {
57    /// Create a new file-backed credential provider.
58    ///
59    /// Loads any cached tokens from `config_dir` (or starts with an empty
60    /// store) and configures a 15-second-timeout HTTP client for refreshes.
61    pub fn new(oauth: HashMap<String, OAuthConfig>, config_dir: PathBuf) -> Result<Arc<Self>> {
62        let store_path = config_dir.join(TOKEN_FILE);
63        let store = if store_path.exists() {
64            match std::fs::read_to_string(&store_path) {
65                Ok(s) => serde_json::from_str::<TokenStore>(&s).unwrap_or_default(),
66                Err(_) => TokenStore::default(),
67            }
68        } else {
69            TokenStore::default()
70        };
71        let http = reqwest::Client::builder()
72            .timeout(Duration::from_secs(15))
73            .build()
74            .context("build reqwest client for MCP credential provider")?;
75        Ok(Arc::new(Self {
76            oauth,
77            store: RwLock::new(store),
78            store_path,
79            http,
80        }))
81    }
82
83    /// Force a refresh for `server` and persist the new token. Used by
84    /// `/mcp reauth <server>`.
85    pub async fn force_refresh(&self, server: &str) -> Result<()> {
86        let new = self.do_refresh(server).await?;
87        self.store_token(server, &new);
88        Ok(())
89    }
90
91    async fn do_refresh(&self, server: &str) -> Result<StoredToken> {
92        let cfg = self.oauth.get(server).cloned().ok_or_else(|| {
93            anyhow::anyhow!("Server '{}' has no OAuth config in mcp.json", server)
94        })?;
95        let mut form = vec![
96            ("grant_type", "client_credentials".to_string()),
97            ("client_id", cfg.client_id),
98            ("client_secret", cfg.client_secret),
99        ];
100        if let Some(scope) = cfg.scope.as_deref() {
101            form.push(("scope", scope.to_string()));
102        }
103        let resp = self
104            .http
105            .post(&cfg.token_url)
106            .header("Accept", "application/json")
107            .form(&form)
108            .send()
109            .await
110            .with_context(|| format!("OAuth token request to {} failed", cfg.token_url))?;
111        let status = resp.status();
112        let body: serde_json::Value = resp
113            .json()
114            .await
115            .context("OAuth token response was not JSON")?;
116        if !status.is_success() {
117            anyhow::bail!(
118                "OAuth token endpoint returned {}: {}",
119                status.as_u16(),
120                body
121            );
122        }
123        let access_token = body
124            .get("access_token")
125            .and_then(|v| v.as_str())
126            .ok_or_else(|| anyhow::anyhow!("OAuth response missing access_token"))?
127            .to_string();
128        let expires_at = body
129            .get("expires_in")
130            .and_then(|v| v.as_u64())
131            .and_then(|secs| now_secs().checked_add(secs));
132        Ok(StoredToken {
133            access_token,
134            expires_at,
135        })
136    }
137
138    fn store_token(&self, server: &str, token: &StoredToken) {
139        {
140            let mut s = self.store.write();
141            s.tokens.insert(server.to_string(), token.clone());
142        }
143        // Best-effort atomic write.
144        let snapshot = self.store.read().clone();
145        if let Ok(json) = serde_json::to_string_pretty(&snapshot) {
146            if let Some(parent) = self.store_path.parent() {
147                let _ = std::fs::create_dir_all(parent);
148            }
149            let tmp = self.store_path.with_extension("json.tmp");
150            if std::fs::write(&tmp, &json).is_ok() {
151                let _ = std::fs::rename(&tmp, &self.store_path);
152            }
153        }
154    }
155
156    fn token_is_fresh(&self, server: &str) -> Option<Credential> {
157        let s = self.store.read();
158        let t = s.tokens.get(server)?;
159        if let Some(exp) = t.expires_at {
160            // Treat tokens within 30s of expiry as stale so a request
161            // doesn't use a token that will expire mid-flight.
162            if now_secs() + 30 >= exp {
163                return None;
164            }
165        }
166        Some(Credential {
167            access_token: t.access_token.clone(),
168        })
169    }
170}
171
172#[async_trait::async_trait]
173impl McpCredentialProvider for FileMcpCredentialProvider {
174    async fn access_token(&self, server: &str, _url: &str) -> Option<Credential> {
175        if let Some(c) = self.token_is_fresh(server) {
176            return Some(c);
177        }
178        // Try to refresh; if refresh fails, return None (no auth).
179        match self.do_refresh(server).await {
180            Ok(token) => {
181                self.store_token(server, &token);
182                Some(Credential {
183                    access_token: token.access_token,
184                })
185            }
186            Err(e) => {
187                tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
188                None
189            }
190        }
191    }
192
193    async fn refresh(&self, server: &str, _url: &str) -> Option<Credential> {
194        match self.do_refresh(server).await {
195            Ok(token) => {
196                self.store_token(server, &token);
197                Some(Credential {
198                    access_token: token.access_token,
199                })
200            }
201            Err(e) => {
202                tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
203                None
204            }
205        }
206    }
207}
208
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before 1970, rather than
/// panicking.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}