1use std::sync::{Mutex, OnceLock};
2
/// Capability profile recorded for a connected MCP client.
///
/// Values are compared field by field (`PartialEq`/`Eq`), so two profiles
/// are equal only when the identifier, every flag and the tool limit match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMcpCapabilities {
    /// Canonical client identifier (for example `"cursor"`), or `"unknown"`.
    pub client_id: String,
    /// Flag for the `resources` capability.
    pub resources: bool,
    /// Flag for the `prompts` capability.
    pub prompts: bool,
    /// Flag for the `elicitation` capability.
    pub elicitation: bool,
    /// Flag for the `sampling` capability.
    pub sampling: bool,
    /// Flag for the `dynamic_tools` capability.
    pub dynamic_tools: bool,
    /// Optional upper bound on the number of tools; `None` means no limit is recorded.
    pub max_tools: Option<usize>,
}
13
14impl Default for ClientMcpCapabilities {
15 fn default() -> Self {
16 Self {
17 client_id: "unknown".to_string(),
18 resources: false,
19 prompts: false,
20 elicitation: false,
21 sampling: false,
22 dynamic_tools: false,
23 max_tools: None,
24 }
25 }
26}
27
28impl ClientMcpCapabilities {
29 pub fn detect(client_name: &str) -> Self {
30 let hint = std::env::var("LEAN_CTX_CLIENT_HINT").ok();
31 Self::detect_with_hint(client_name, hint.as_deref())
32 }
33
34 fn detect_with_hint(client_name: &str, hint: Option<&str>) -> Self {
35 let effective = match hint {
36 Some(h) if !h.trim().is_empty() => h.trim().to_lowercase(),
37 _ => client_name.to_lowercase(),
38 };
39 let id = identify_client(&effective);
40
41 match id.as_str() {
42 "cursor" | "kiro" => Self {
43 client_id: id,
44 resources: true,
45 prompts: true,
46 elicitation: true,
47 sampling: false,
48 dynamic_tools: true,
49 max_tools: None,
50 },
51 "claude-code" => Self {
52 client_id: id,
53 resources: true,
54 prompts: true,
55 elicitation: true,
56 sampling: true,
57 dynamic_tools: true,
58 max_tools: None,
59 },
60 "windsurf" => Self {
61 client_id: id,
62 resources: false,
63 prompts: false,
64 elicitation: false,
65 sampling: false,
66 dynamic_tools: true,
67 max_tools: Some(100),
68 },
69 "zed" => Self {
70 client_id: id,
71 resources: false,
72 prompts: true,
73 elicitation: false,
74 sampling: false,
75 dynamic_tools: true,
76 max_tools: None,
77 },
78 "vscode-copilot" => Self {
79 client_id: id,
80 resources: true,
81 prompts: true,
82 elicitation: false,
83 sampling: false,
84 dynamic_tools: true,
85 max_tools: None,
86 },
87 "codex" => Self {
88 client_id: id,
89 resources: true,
90 prompts: false,
91 elicitation: false,
92 sampling: false,
93 dynamic_tools: true,
94 max_tools: None,
95 },
96 "antigravity" | "gemini-cli" => Self {
97 client_id: id,
98 resources: false,
99 prompts: false,
100 elicitation: false,
101 sampling: false,
102 dynamic_tools: false,
103 max_tools: None,
104 },
105 _ => Self {
106 client_id: id,
107 ..Default::default()
108 },
109 }
110 }
111
112 pub fn tier(&self) -> u8 {
113 let score = [
114 self.resources,
115 self.prompts,
116 self.elicitation,
117 self.sampling,
118 self.dynamic_tools,
119 ]
120 .iter()
121 .filter(|&&v| v)
122 .count();
123
124 match score {
125 4..=5 => 1,
126 2..=3 => 2,
127 1 => 3,
128 _ => 4,
129 }
130 }
131
132 pub fn format_summary(&self) -> String {
133 let features: Vec<&str> = [
134 ("resources", self.resources),
135 ("prompts", self.prompts),
136 ("elicitation", self.elicitation),
137 ("sampling", self.sampling),
138 ("dynamic_tools", self.dynamic_tools),
139 ]
140 .iter()
141 .filter(|(_, v)| *v)
142 .map(|(k, _)| *k)
143 .collect();
144
145 let tools_note = self
146 .max_tools
147 .map(|n| format!(" (max {n} tools)"))
148 .unwrap_or_default();
149
150 format!(
151 "{} (tier {}): [{}]{}",
152 self.client_id,
153 self.tier(),
154 features.join(", "),
155 tools_note,
156 )
157 }
158}
159
/// Maps an already-lowercased client name to a canonical client id.
///
/// Rules are tried from top to bottom and the first rule with any keyword
/// occurring as a substring of `lower` wins, so order matters when a name
/// contains several keywords. Returns `"unknown"` when no rule matches.
fn identify_client(lower: &str) -> String {
    // (keywords, canonical id), in priority order.
    const RULES: &[(&[&str], &str)] = &[
        (&["cursor"], "cursor"),
        (&["claude"], "claude-code"),
        (&["windsurf", "codeium"], "windsurf"),
        (&["zed"], "zed"),
        (
            &["copilot", "github", "visual studio code", "vscode"],
            "vscode-copilot",
        ),
        (&["kiro"], "kiro"),
        (&["codex", "openai"], "codex"),
        (&["antigravity"], "antigravity"),
        (&["gemini"], "gemini-cli"),
    ];

    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lower.contains(*needle)))
        .map_or("unknown", |&(_, id)| id)
        .to_string()
}
187
188static GLOBAL: OnceLock<Mutex<ClientMcpCapabilities>> = OnceLock::new();
189
190pub fn global() -> &'static Mutex<ClientMcpCapabilities> {
191 GLOBAL.get_or_init(|| Mutex::new(ClientMcpCapabilities::default()))
192}
193
194pub fn set_detected(caps: &ClientMcpCapabilities) {
195 if let Ok(mut g) = global().lock() {
196 *g = caps.clone();
197 }
198 persist_to_disk(caps);
199}
200
201pub fn current() -> ClientMcpCapabilities {
202 global().lock().map(|g| g.clone()).unwrap_or_default()
203}
204
205pub fn load_persisted(max_age_secs: u64) -> Option<ClientMcpCapabilities> {
208 let path = persisted_path()?;
209 let content = std::fs::read_to_string(&path).ok()?;
210 let val: serde_json::Value = serde_json::from_str(&content).ok()?;
211
212 let ts = val.get("ts").and_then(serde_json::Value::as_u64)?;
213 let now = std::time::SystemTime::now()
214 .duration_since(std::time::UNIX_EPOCH)
215 .map_or(0, |d| d.as_secs());
216 if now.saturating_sub(ts) > max_age_secs {
217 return None;
218 }
219
220 let client_id = val
221 .get("client_id")
222 .and_then(|v| v.as_str())
223 .unwrap_or("unknown")
224 .to_string();
225
226 if client_id == "unknown" {
227 return None;
228 }
229
230 Some(ClientMcpCapabilities::detect(&client_id))
231}
232
233fn persisted_path() -> Option<std::path::PathBuf> {
234 Some(
235 super::data_dir::lean_ctx_data_dir()
236 .ok()?
237 .join("client-id.json"),
238 )
239}
240
241fn persist_to_disk(caps: &ClientMcpCapabilities) {
242 let Some(path) = persisted_path() else {
243 return;
244 };
245 let ts = std::time::SystemTime::now()
246 .duration_since(std::time::UNIX_EPOCH)
247 .map_or(0, |d| d.as_secs());
248 let payload = serde_json::json!({
249 "client_id": caps.client_id,
250 "tier": caps.tier(),
251 "features": caps.format_summary(),
252 "ts": ts,
253 });
254 let tmp = path.with_extension("tmp");
255 if let Ok(json) = serde_json::to_string_pretty(&payload) {
256 if std::fs::write(&tmp, &json).is_ok() {
257 let _ = std::fs::rename(&tmp, &path);
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Detects capabilities without consulting `LEAN_CTX_CLIENT_HINT`.
    ///
    /// `ClientMcpCapabilities::detect` reads that environment variable, so
    /// calling it here would make the assertions depend on the environment
    /// the test run happens in.
    fn detect_no_hint(client_name: &str) -> ClientMcpCapabilities {
        ClientMcpCapabilities::detect_with_hint(client_name, None)
    }

    #[test]
    fn cursor_detection() {
        let caps = detect_no_hint("Cursor");
        assert_eq!(caps.client_id, "cursor");
        assert!(caps.resources);
        assert!(caps.prompts);
        assert!(caps.elicitation);
        assert!(caps.dynamic_tools);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn claude_code_detection() {
        let caps = detect_no_hint("claude-code");
        assert_eq!(caps.client_id, "claude-code");
        assert!(caps.sampling);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn windsurf_detection() {
        let caps = detect_no_hint("Windsurf");
        assert_eq!(caps.client_id, "windsurf");
        assert!(!caps.resources);
        assert!(!caps.prompts);
        assert_eq!(caps.max_tools, Some(100));
        assert_eq!(caps.tier(), 3);
    }

    #[test]
    fn unknown_client_tier4() {
        let caps = detect_no_hint("random-editor");
        assert_eq!(caps.client_id, "unknown");
        assert_eq!(caps.tier(), 4);
    }

    #[test]
    fn copilot_detection() {
        let caps = detect_no_hint("GitHub Copilot");
        assert_eq!(caps.client_id, "vscode-copilot");
        assert!(caps.resources);
        assert!(caps.prompts);
        assert!(caps.dynamic_tools);
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn vscode_plain_detection() {
        let caps = detect_no_hint("Visual Studio Code");
        assert_eq!(caps.client_id, "vscode-copilot");
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn vscode_lowercase_detection() {
        let caps = detect_no_hint("vscode");
        assert_eq!(caps.client_id, "vscode-copilot");
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn client_hint_override() {
        let caps = ClientMcpCapabilities::detect_with_hint(
            "random-unknown-editor",
            Some("vscode-copilot"),
        );
        assert_eq!(caps.client_id, "vscode-copilot");
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn client_hint_empty_falls_back() {
        let caps = ClientMcpCapabilities::detect_with_hint("Cursor", Some(""));
        assert_eq!(caps.client_id, "cursor");
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn client_hint_none_falls_back() {
        let caps = ClientMcpCapabilities::detect_with_hint("Cursor", None);
        assert_eq!(caps.client_id, "cursor");
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn format_summary() {
        let caps = detect_no_hint("Cursor");
        let s = caps.format_summary();
        assert!(s.contains("cursor"));
        assert!(s.contains("tier 1"));
    }
}