1use anyhow::{Context, Result};
19use oxi_agent::mcp::auth::{Credential, McpCredentialProvider};
20use oxi_agent::mcp::types::OAuthConfig;
21use parking_lot::RwLock;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::path::PathBuf;
25use std::sync::Arc;
26use std::time::{Duration, SystemTime, UNIX_EPOCH};
27
/// File name, relative to the config directory, of the persisted MCP token cache.
const TOKEN_FILE: &str = "mcp-tokens.json";
29
30#[derive(Debug, Default, Clone, Serialize, Deserialize)]
31struct StoredToken {
32 access_token: String,
33 expires_at: Option<u64>,
36}
37
38#[derive(Debug, Default, Clone, Serialize, Deserialize)]
39struct TokenStore {
40 #[serde(default)]
41 tokens: HashMap<String, StoredToken>,
42}
43
44pub struct FileMcpCredentialProvider {
47 oauth: HashMap<String, OAuthConfig>,
49 store: RwLock<TokenStore>,
52 store_path: PathBuf,
53 http: reqwest::Client,
54}
55
56impl FileMcpCredentialProvider {
57 pub fn new(oauth: HashMap<String, OAuthConfig>, config_dir: PathBuf) -> Result<Arc<Self>> {
58 let store_path = config_dir.join(TOKEN_FILE);
59 let store = if store_path.exists() {
60 match std::fs::read_to_string(&store_path) {
61 Ok(s) => serde_json::from_str::<TokenStore>(&s).unwrap_or_default(),
62 Err(_) => TokenStore::default(),
63 }
64 } else {
65 TokenStore::default()
66 };
67 let http = reqwest::Client::builder()
68 .timeout(Duration::from_secs(15))
69 .build()
70 .context("build reqwest client for MCP credential provider")?;
71 Ok(Arc::new(Self {
72 oauth,
73 store: RwLock::new(store),
74 store_path,
75 http,
76 }))
77 }
78
79 pub async fn force_refresh(&self, server: &str) -> Result<()> {
82 let new = self.do_refresh(server).await?;
83 self.store_token(server, &new);
84 Ok(())
85 }
86
87 async fn do_refresh(&self, server: &str) -> Result<StoredToken> {
88 let cfg = self.oauth.get(server).cloned().ok_or_else(|| {
89 anyhow::anyhow!("Server '{}' has no OAuth config in mcp.json", server)
90 })?;
91 let mut form = vec![
92 ("grant_type", "client_credentials".to_string()),
93 ("client_id", cfg.client_id),
94 ("client_secret", cfg.client_secret),
95 ];
96 if let Some(scope) = cfg.scope.as_deref() {
97 form.push(("scope", scope.to_string()));
98 }
99 let resp = self
100 .http
101 .post(&cfg.token_url)
102 .header("Accept", "application/json")
103 .form(&form)
104 .send()
105 .await
106 .with_context(|| format!("OAuth token request to {} failed", cfg.token_url))?;
107 let status = resp.status();
108 let body: serde_json::Value = resp
109 .json()
110 .await
111 .context("OAuth token response was not JSON")?;
112 if !status.is_success() {
113 anyhow::bail!(
114 "OAuth token endpoint returned {}: {}",
115 status.as_u16(),
116 body
117 );
118 }
119 let access_token = body
120 .get("access_token")
121 .and_then(|v| v.as_str())
122 .ok_or_else(|| anyhow::anyhow!("OAuth response missing access_token"))?
123 .to_string();
124 let expires_at = body
125 .get("expires_in")
126 .and_then(|v| v.as_u64())
127 .and_then(|secs| now_secs().checked_add(secs));
128 Ok(StoredToken {
129 access_token,
130 expires_at,
131 })
132 }
133
134 fn store_token(&self, server: &str, token: &StoredToken) {
135 {
136 let mut s = self.store.write();
137 s.tokens.insert(server.to_string(), token.clone());
138 }
139 let snapshot = self.store.read().clone();
141 if let Ok(json) = serde_json::to_string_pretty(&snapshot) {
142 if let Some(parent) = self.store_path.parent() {
143 let _ = std::fs::create_dir_all(parent);
144 }
145 let tmp = self.store_path.with_extension("json.tmp");
146 if std::fs::write(&tmp, &json).is_ok() {
147 let _ = std::fs::rename(&tmp, &self.store_path);
148 }
149 }
150 }
151
152 fn token_is_fresh(&self, server: &str) -> Option<Credential> {
153 let s = self.store.read();
154 let t = s.tokens.get(server)?;
155 if let Some(exp) = t.expires_at {
156 if now_secs() + 30 >= exp {
159 return None;
160 }
161 }
162 Some(Credential {
163 access_token: t.access_token.clone(),
164 })
165 }
166}
167
168#[async_trait::async_trait]
169impl McpCredentialProvider for FileMcpCredentialProvider {
170 async fn access_token(&self, server: &str, _url: &str) -> Option<Credential> {
171 if let Some(c) = self.token_is_fresh(server) {
172 return Some(c);
173 }
174 match self.do_refresh(server).await {
176 Ok(token) => {
177 self.store_token(server, &token);
178 Some(Credential {
179 access_token: token.access_token,
180 })
181 }
182 Err(e) => {
183 tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
184 None
185 }
186 }
187 }
188
189 async fn refresh(&self, server: &str, _url: &str) -> Option<Credential> {
190 match self.do_refresh(server).await {
191 Ok(token) => {
192 self.store_token(server, &token);
193 Some(Credential {
194 access_token: token.access_token,
195 })
196 }
197 Err(e) => {
198 tracing::warn!("MCP credential refresh for '{}' failed: {}", server, e);
199 None
200 }
201 }
202 }
203}
204
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}