Skip to main content

oxi/
mcp_credentials.rs

1//! File-backed MCP credential provider (v2.2).
2//!
3//! Implements [`McpCredentialProvider`] using OAuth2 client_credentials.
4//! For each server with an [`OAuthConfig`] in `mcp.json`, the provider
5//! POSTs to `token_url` to exchange `client_id`/`client_secret` for
6//! an access token, caches it in `~/.config/oxi/mcp-tokens.json`, and
7//! returns it on [`McpCredentialProvider::access_token`]. Tokens are
8//! refreshed when missing or expired (based on `expires_in`).
9//!
//! Browser-based authorization-code flow (with a local callback server)
//! is **not** implemented in v2.2. Servers that have no pre-registered
//! `client_id` / `client_secret` pair (e.g. MCP servers that require
//! user-interactive consent) cannot be auto-authenticated. For now, use the
//! server's own hosted login and configure `oauth` with a machine (client-credentials) client.
15//!
16//! Trigger a manual refresh via `/mcp reauth <server>` in the TUI.
17
18use anyhow::{Context, Result};
19use oxi_agent::mcp::auth::{Credential, McpCredentialProvider};
20use oxi_agent::mcp::types::OAuthConfig;
21use parking_lot::RwLock;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::path::PathBuf;
25use std::sync::Arc;
26use std::time::{Duration, SystemTime, UNIX_EPOCH};
27
28const TOKEN_FILE: &str = "mcp-tokens.json";
29
30#[derive(Debug, Default, Clone, Serialize, Deserialize)]
31struct StoredToken {
32    access_token: String,
33    /// Unix seconds at which the token is known to be expired.
34    /// `None` means unknown / never expires (rely on refresh errors).
35    expires_at: Option<u64>,
36}
37
38#[derive(Debug, Default, Clone, Serialize, Deserialize)]
39struct TokenStore {
40    #[serde(default)]
41    tokens: HashMap<String, StoredToken>,
42}
43
44/// File-backed credential provider. Constructed once in the bootstrap
45/// and shared with [`McpManager`] via `Arc<dyn McpCredentialProvider>`.
46pub struct FileMcpCredentialProvider {
47    /// Per-server OAuth client config (from `mcp.json`).
48    oauth: HashMap<String, OAuthConfig>,
49    /// Cached tokens, loaded from disk on construction and saved
50    /// after every successful refresh (atomic temp + rename).
51    store: RwLock<TokenStore>,
52    store_path: PathBuf,
53    http: reqwest::Client,
54}
55
56impl FileMcpCredentialProvider {
57    pub fn new(oauth: HashMap<String, OAuthConfig>, config_dir: PathBuf) -> Result<Arc<Self>> {
58        let store_path = config_dir.join(TOKEN_FILE);
59        let store = if store_path.exists() {
60            match std::fs::read_to_string(&store_path) {
61                Ok(s) => serde_json::from_str::<TokenStore>(&s).unwrap_or_default(),
62                Err(_) => TokenStore::default(),
63            }
64        } else {
65            TokenStore::default()
66        };
67        let http = reqwest::Client::builder()
68            .timeout(Duration::from_secs(15))
69            .build()
70            .context("build reqwest client for MCP credential provider")?;
71        Ok(Arc::new(Self {
72            oauth,
73            store: RwLock::new(store),
74            store_path,
75            http,
76        }))
77    }
78
79    /// Force a refresh for `server` and persist the new token. Used by
80    /// `/mcp reauth <server>`.
81    pub async fn force_refresh(&self, server: &str) -> Result<()> {
82        let new = self.do_refresh(server).await?;
83        self.store_token(server, &new);
84        Ok(())
85    }
86
87    async fn do_refresh(&self, server: &str) -> Result<StoredToken> {
88        let cfg = self.oauth.get(server).cloned().ok_or_else(|| {
89            anyhow::anyhow!("Server '{}' has no OAuth config in mcp.json", server)
90        })?;
91        let mut form = vec![
92            ("grant_type", "client_credentials".to_string()),
93            ("client_id", cfg.client_id),
94            ("client_secret", cfg.client_secret),
95        ];
96        if let Some(scope) = cfg.scope.as_deref() {
97            form.push(("scope", scope.to_string()));
98        }
99        let resp = self
100            .http
101            .post(&cfg.token_url)
102            .header("Accept", "application/json")
103            .form(&form)
104            .send()
105            .await
106            .with_context(|| format!("OAuth token request to {} failed", cfg.token_url))?;
107        let status = resp.status();
108        let body: serde_json::Value = resp
109            .json()
110            .await
111            .context("OAuth token response was not JSON")?;
112        if !status.is_success() {
113            anyhow::bail!(
114                "OAuth token endpoint returned {}: {}",
115                status.as_u16(),
116                body
117            );
118        }
119        let access_token = body
120            .get("access_token")
121            .and_then(|v| v.as_str())
122            .ok_or_else(|| anyhow::anyhow!("OAuth response missing access_token"))?
123            .to_string();
124        let expires_at = body
125            .get("expires_in")
126            .and_then(|v| v.as_u64())
127            .and_then(|secs| now_secs().checked_add(secs));
128        Ok(StoredToken {
129            access_token,
130            expires_at,
131        })
132    }
133
134    fn store_token(&self, server: &str, token: &StoredToken) {
135        {
136            let mut s = self.store.write();
137            s.tokens.insert(server.to_string(), token.clone());
138        }
139        // Best-effort atomic write.
140        let snapshot = self.store.read().clone();
141        if let Ok(json) = serde_json::to_string_pretty(&snapshot) {
142            if let Some(parent) = self.store_path.parent() {
143                let _ = std::fs::create_dir_all(parent);
144            }
145            let tmp = self.store_path.with_extension("json.tmp");
146            if std::fs::write(&tmp, &json).is_ok() {
147                let _ = std::fs::rename(&tmp, &self.store_path);
148            }
149        }
150    }
151
152    fn token_is_fresh(&self, server: &str) -> Option<Credential> {
153        let s = self.store.read();
154        let t = s.tokens.get(server)?;
155        if let Some(exp) = t.expires_at {
156            // Treat tokens within 30s of expiry as stale so a request
157            // doesn't use a token that will expire mid-flight.
158            if now_secs() + 30 >= exp {
159                return None;
160            }
161        }
162        Some(Credential {
163            access_token: t.access_token.clone(),
164        })
165    }
166}
167
168#[async_trait::async_trait]
169impl McpCredentialProvider for FileMcpCredentialProvider {
170    async fn access_token(&self, server: &str, _url: &str) -> Option<Credential> {
171        if let Some(c) = self.token_is_fresh(server) {
172            return Some(c);
173        }
174        // Try to refresh; if refresh fails, return None (no auth).
175        match self.do_refresh(server).await {
176            Ok(token) => {
177                self.store_token(server, &token);
178                Some(Credential {
179                    access_token: token.access_token,
180                })
181            }
182            Err(e) => {
183                tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
184                None
185            }
186        }
187    }
188
189    async fn refresh(&self, server: &str, _url: &str) -> Option<Credential> {
190        match self.do_refresh(server).await {
191            Ok(token) => {
192                self.store_token(server, &token);
193                Some(Credential {
194                    access_token: token.access_token,
195                })
196            }
197            Err(e) => {
198                tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
199                None
200            }
201        }
202    }
203}
204
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}