1use anyhow::{Context, Result};
19use oxi_agent::mcp::auth::{Credential, McpCredentialProvider};
20use oxi_agent::mcp::types::OAuthConfig;
21use parking_lot::RwLock;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::path::PathBuf;
25use std::sync::Arc;
26use std::time::{Duration, SystemTime, UNIX_EPOCH};
27
/// File name of the persisted token cache; it is joined onto the config
/// directory handed to `FileMcpCredentialProvider::new`.
const TOKEN_FILE: &str = "mcp-tokens.json";
29
/// One cached access token, as kept in memory and serialized into the token file.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
struct StoredToken {
    /// Value of the `access_token` field taken from the token endpoint response.
    access_token: String,
    /// Expiry as seconds since the Unix epoch (request time plus `expires_in`).
    /// `None` when the response carried no usable `expires_in`; such a token is
    /// never considered expired by the local freshness check.
    expires_at: Option<u64>,
}
37
/// Serialized layout of the token file: the set of cached tokens.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
struct TokenStore {
    /// Cached tokens keyed by server name. Falls back to an empty map when the
    /// field is absent from the JSON document.
    #[serde(default)]
    tokens: HashMap<String, StoredToken>,
}
43
/// Credential provider that obtains OAuth access tokens per server and caches
/// them both in memory and in a JSON file under the config directory.
pub struct FileMcpCredentialProvider {
    /// OAuth settings looked up by server name when a token must be requested.
    oauth: HashMap<String, OAuthConfig>,
    /// In-memory token cache; its contents are what gets written to `store_path`.
    store: RwLock<TokenStore>,
    /// Location of the token file (`<config_dir>/mcp-tokens.json`).
    store_path: PathBuf,
    /// HTTP client used for token endpoint requests (built with a 15 second timeout).
    http: reqwest::Client,
}
55
56impl FileMcpCredentialProvider {
57 pub fn new(oauth: HashMap<String, OAuthConfig>, config_dir: PathBuf) -> Result<Arc<Self>> {
62 let store_path = config_dir.join(TOKEN_FILE);
63 let store = if store_path.exists() {
64 match std::fs::read_to_string(&store_path) {
65 Ok(s) => serde_json::from_str::<TokenStore>(&s).unwrap_or_default(),
66 Err(_) => TokenStore::default(),
67 }
68 } else {
69 TokenStore::default()
70 };
71 let http = reqwest::Client::builder()
72 .timeout(Duration::from_secs(15))
73 .build()
74 .context("build reqwest client for MCP credential provider")?;
75 Ok(Arc::new(Self {
76 oauth,
77 store: RwLock::new(store),
78 store_path,
79 http,
80 }))
81 }
82
83 pub async fn force_refresh(&self, server: &str) -> Result<()> {
86 let new = self.do_refresh(server).await?;
87 self.store_token(server, &new);
88 Ok(())
89 }
90
91 async fn do_refresh(&self, server: &str) -> Result<StoredToken> {
92 let cfg = self.oauth.get(server).cloned().ok_or_else(|| {
93 anyhow::anyhow!("Server '{}' has no OAuth config in mcp.json", server)
94 })?;
95 let mut form = vec![
96 ("grant_type", "client_credentials".to_string()),
97 ("client_id", cfg.client_id),
98 ("client_secret", cfg.client_secret),
99 ];
100 if let Some(scope) = cfg.scope.as_deref() {
101 form.push(("scope", scope.to_string()));
102 }
103 let resp = self
104 .http
105 .post(&cfg.token_url)
106 .header("Accept", "application/json")
107 .form(&form)
108 .send()
109 .await
110 .with_context(|| format!("OAuth token request to {} failed", cfg.token_url))?;
111 let status = resp.status();
112 let body: serde_json::Value = resp
113 .json()
114 .await
115 .context("OAuth token response was not JSON")?;
116 if !status.is_success() {
117 anyhow::bail!(
118 "OAuth token endpoint returned {}: {}",
119 status.as_u16(),
120 body
121 );
122 }
123 let access_token = body
124 .get("access_token")
125 .and_then(|v| v.as_str())
126 .ok_or_else(|| anyhow::anyhow!("OAuth response missing access_token"))?
127 .to_string();
128 let expires_at = body
129 .get("expires_in")
130 .and_then(|v| v.as_u64())
131 .and_then(|secs| now_secs().checked_add(secs));
132 Ok(StoredToken {
133 access_token,
134 expires_at,
135 })
136 }
137
138 fn store_token(&self, server: &str, token: &StoredToken) {
139 {
140 let mut s = self.store.write();
141 s.tokens.insert(server.to_string(), token.clone());
142 }
143 let snapshot = self.store.read().clone();
145 if let Ok(json) = serde_json::to_string_pretty(&snapshot) {
146 if let Some(parent) = self.store_path.parent() {
147 let _ = std::fs::create_dir_all(parent);
148 }
149 let tmp = self.store_path.with_extension("json.tmp");
150 if std::fs::write(&tmp, &json).is_ok() {
151 let _ = std::fs::rename(&tmp, &self.store_path);
152 }
153 }
154 }
155
156 fn token_is_fresh(&self, server: &str) -> Option<Credential> {
157 let s = self.store.read();
158 let t = s.tokens.get(server)?;
159 if let Some(exp) = t.expires_at {
160 if now_secs() + 30 >= exp {
163 return None;
164 }
165 }
166 Some(Credential {
167 access_token: t.access_token.clone(),
168 })
169 }
170}
171
172#[async_trait::async_trait]
173impl McpCredentialProvider for FileMcpCredentialProvider {
174 async fn access_token(&self, server: &str, _url: &str) -> Option<Credential> {
175 if let Some(c) = self.token_is_fresh(server) {
176 return Some(c);
177 }
178 match self.do_refresh(server).await {
180 Ok(token) => {
181 self.store_token(server, &token);
182 Some(Credential {
183 access_token: token.access_token,
184 })
185 }
186 Err(e) => {
187 tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
188 None
189 }
190 }
191 }
192
193 async fn refresh(&self, server: &str, _url: &str) -> Option<Credential> {
194 match self.do_refresh(server).await {
195 Ok(token) => {
196 self.store_token(server, &token);
197 Some(Credential {
198 access_token: token.access_token,
199 })
200 }
201 Err(e) => {
202 tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
203 None
204 }
205 }
206 }
207}
208
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}