1use std::sync::{Mutex, OnceLock};
2
/// Optional MCP features that a particular client is assumed to support.
///
/// Values are normally produced by [`ClientMcpCapabilities::detect`], which picks
/// a hard-coded profile from the client's reported name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMcpCapabilities {
    /// Canonical client id (e.g. `"cursor"`, `"claude-code"`), or `"unknown"`.
    pub client_id: String,
    /// Client is treated as supporting MCP resources.
    pub resources: bool,
    /// Client is treated as supporting MCP prompts.
    pub prompts: bool,
    /// Client is treated as supporting MCP elicitation.
    pub elicitation: bool,
    /// Client is treated as supporting MCP sampling.
    pub sampling: bool,
    /// Client is treated as handling a changing tool list.
    pub dynamic_tools: bool,
    /// Maximum number of tools for this client, if limited (`None` = no known limit).
    pub max_tools: Option<usize>,
}
13
14impl Default for ClientMcpCapabilities {
15 fn default() -> Self {
16 Self {
17 client_id: "unknown".to_string(),
18 resources: false,
19 prompts: false,
20 elicitation: false,
21 sampling: false,
22 dynamic_tools: false,
23 max_tools: None,
24 }
25 }
26}
27
28impl ClientMcpCapabilities {
29 pub fn detect(client_name: &str) -> Self {
30 let lower = client_name.to_lowercase();
31 let id = identify_client(&lower);
32
33 match id.as_str() {
34 "cursor" | "kiro" => Self {
35 client_id: id,
36 resources: true,
37 prompts: true,
38 elicitation: true,
39 sampling: false,
40 dynamic_tools: true,
41 max_tools: None,
42 },
43 "claude-code" => Self {
44 client_id: id,
45 resources: true,
46 prompts: true,
47 elicitation: true,
48 sampling: true,
49 dynamic_tools: true,
50 max_tools: None,
51 },
52 "windsurf" => Self {
53 client_id: id,
54 resources: false,
55 prompts: false,
56 elicitation: false,
57 sampling: false,
58 dynamic_tools: true,
59 max_tools: Some(100),
60 },
61 "zed" => Self {
62 client_id: id,
63 resources: false,
64 prompts: true,
65 elicitation: false,
66 sampling: false,
67 dynamic_tools: true,
68 max_tools: None,
69 },
70 "vscode-copilot" => Self {
71 client_id: id,
72 resources: true,
73 prompts: true,
74 elicitation: false,
75 sampling: false,
76 dynamic_tools: true,
77 max_tools: None,
78 },
79 "codex" => Self {
80 client_id: id,
81 resources: true,
82 prompts: false,
83 elicitation: false,
84 sampling: false,
85 dynamic_tools: true,
86 max_tools: None,
87 },
88 "antigravity" | "gemini-cli" => Self {
89 client_id: id,
90 resources: false,
91 prompts: false,
92 elicitation: false,
93 sampling: false,
94 dynamic_tools: false,
95 max_tools: None,
96 },
97 _ => Self {
98 client_id: id,
99 ..Default::default()
100 },
101 }
102 }
103
104 pub fn tier(&self) -> u8 {
105 let score = [
106 self.resources,
107 self.prompts,
108 self.elicitation,
109 self.sampling,
110 self.dynamic_tools,
111 ]
112 .iter()
113 .filter(|&&v| v)
114 .count();
115
116 match score {
117 4..=5 => 1,
118 2..=3 => 2,
119 1 => 3,
120 _ => 4,
121 }
122 }
123
124 pub fn format_summary(&self) -> String {
125 let features: Vec<&str> = [
126 ("resources", self.resources),
127 ("prompts", self.prompts),
128 ("elicitation", self.elicitation),
129 ("sampling", self.sampling),
130 ("dynamic_tools", self.dynamic_tools),
131 ]
132 .iter()
133 .filter(|(_, v)| *v)
134 .map(|(k, _)| *k)
135 .collect();
136
137 let tools_note = self
138 .max_tools
139 .map(|n| format!(" (max {n} tools)"))
140 .unwrap_or_default();
141
142 format!(
143 "{} (tier {}): [{}]{}",
144 self.client_id,
145 self.tier(),
146 features.join(", "),
147 tools_note,
148 )
149 }
150}
151
/// Maps a lowercased MCP client name to a canonical client id, or `"unknown"`.
///
/// Rules are tried in order and the first hit wins, so a name mentioning
/// several clients resolves to the earliest rule. Most rules are substring
/// matches. `"zed"` must appear as a whole alphanumeric token, because as a
/// bare substring it hits ordinary words such as "customized" or "authorized".
///
/// `lower` is expected to be lowercase already; this function does not fold case.
fn identify_client(lower: &str) -> String {
    let has = |needle: &str| lower.contains(needle);
    let has_token = |needle: &str| {
        lower
            .split(|c: char| !c.is_alphanumeric())
            .any(|token| token == needle)
    };

    let id = if has("cursor") {
        "cursor"
    } else if has("claude") {
        "claude-code"
    } else if has("windsurf") || has("codeium") {
        "windsurf"
    } else if has_token("zed") {
        "zed"
    } else if has("copilot") || has("github") {
        "vscode-copilot"
    } else if has("kiro") {
        "kiro"
    } else if has("codex") || has("openai") {
        "codex"
    } else if has("antigravity") {
        "antigravity"
    } else if has("gemini") {
        "gemini-cli"
    } else {
        "unknown"
    };
    id.to_string()
}
175
176static GLOBAL: OnceLock<Mutex<ClientMcpCapabilities>> = OnceLock::new();
177
178pub fn global() -> &'static Mutex<ClientMcpCapabilities> {
179 GLOBAL.get_or_init(|| Mutex::new(ClientMcpCapabilities::default()))
180}
181
182pub fn set_detected(caps: &ClientMcpCapabilities) {
183 if let Ok(mut g) = global().lock() {
184 *g = caps.clone();
185 }
186 persist_to_disk(caps);
187}
188
189pub fn current() -> ClientMcpCapabilities {
190 global().lock().map(|g| g.clone()).unwrap_or_default()
191}
192
193pub fn load_persisted(max_age_secs: u64) -> Option<ClientMcpCapabilities> {
196 let path = persisted_path()?;
197 let content = std::fs::read_to_string(&path).ok()?;
198 let val: serde_json::Value = serde_json::from_str(&content).ok()?;
199
200 let ts = val.get("ts").and_then(serde_json::Value::as_u64)?;
201 let now = std::time::SystemTime::now()
202 .duration_since(std::time::UNIX_EPOCH)
203 .map_or(0, |d| d.as_secs());
204 if now.saturating_sub(ts) > max_age_secs {
205 return None;
206 }
207
208 let client_id = val
209 .get("client_id")
210 .and_then(|v| v.as_str())
211 .unwrap_or("unknown")
212 .to_string();
213
214 if client_id == "unknown" {
215 return None;
216 }
217
218 Some(ClientMcpCapabilities::detect(&client_id))
219}
220
221fn persisted_path() -> Option<std::path::PathBuf> {
222 Some(
223 super::data_dir::lean_ctx_data_dir()
224 .ok()?
225 .join("client-id.json"),
226 )
227}
228
229fn persist_to_disk(caps: &ClientMcpCapabilities) {
230 let Some(path) = persisted_path() else {
231 return;
232 };
233 let ts = std::time::SystemTime::now()
234 .duration_since(std::time::UNIX_EPOCH)
235 .map_or(0, |d| d.as_secs());
236 let payload = serde_json::json!({
237 "client_id": caps.client_id,
238 "tier": caps.tier(),
239 "features": caps.format_summary(),
240 "ts": ts,
241 });
242 let tmp = path.with_extension("tmp");
243 if let Ok(json) = serde_json::to_string_pretty(&payload) {
244 if std::fs::write(&tmp, &json).is_ok() {
245 let _ = std::fs::rename(&tmp, &path);
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cursor_detection() {
        let caps = ClientMcpCapabilities::detect("Cursor");
        assert_eq!(caps.client_id, "cursor");
        assert!(caps.resources);
        assert!(caps.prompts);
        assert!(caps.elicitation);
        assert!(caps.dynamic_tools);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn claude_code_detection() {
        let caps = ClientMcpCapabilities::detect("claude-code");
        assert_eq!(caps.client_id, "claude-code");
        assert!(caps.sampling);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn windsurf_detection() {
        let caps = ClientMcpCapabilities::detect("Windsurf");
        assert_eq!(caps.client_id, "windsurf");
        assert!(!caps.resources);
        assert!(!caps.prompts);
        assert_eq!(caps.max_tools, Some(100));
        assert_eq!(caps.tier(), 3);
    }

    #[test]
    fn zed_detection() {
        let caps = ClientMcpCapabilities::detect("Zed");
        assert_eq!(caps.client_id, "zed");
        assert!(caps.prompts);
        assert!(!caps.resources);
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn gemini_cli_has_no_optional_features() {
        let caps = ClientMcpCapabilities::detect("Gemini CLI");
        assert_eq!(caps.client_id, "gemini-cli");
        assert_eq!(caps.tier(), 4);
    }

    #[test]
    fn unknown_client_tier4() {
        let caps = ClientMcpCapabilities::detect("random-editor");
        assert_eq!(caps.client_id, "unknown");
        assert_eq!(caps.tier(), 4);
    }

    #[test]
    fn canonical_ids_round_trip_through_detect() {
        // load_persisted re-runs detect() on the stored client_id, so every
        // canonical id must map back to itself.
        for id in [
            "cursor",
            "claude-code",
            "windsurf",
            "zed",
            "vscode-copilot",
            "kiro",
            "codex",
            "antigravity",
            "gemini-cli",
        ] {
            assert_eq!(ClientMcpCapabilities::detect(id).client_id, id);
        }
    }

    #[test]
    fn format_summary() {
        let caps = ClientMcpCapabilities::detect("Cursor");
        let s = caps.format_summary();
        assert!(s.contains("cursor"));
        assert!(s.contains("tier 1"));
    }

    #[test]
    fn format_summary_lists_features_and_tool_cap() {
        let caps = ClientMcpCapabilities::detect("Windsurf");
        assert_eq!(
            caps.format_summary(),
            "windsurf (tier 3): [dynamic_tools] (max 100 tools)"
        );
    }
}