Skip to main content

lean_ctx/core/
client_capabilities.rs

1use std::sync::{Mutex, OnceLock};
2
/// MCP features lean-ctx assumes a connected client supports.
///
/// When built by `ClientMcpCapabilities::detect`, the flags come from a static
/// per-client table keyed by `client_id`, not from protocol negotiation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMcpCapabilities {
    /// Canonical client id such as `"cursor"` or `"claude-code"`; `"unknown"` when unrecognised.
    pub client_id: String,
    /// Client is treated as supporting MCP resources.
    pub resources: bool,
    /// Client is treated as supporting MCP prompts.
    pub prompts: bool,
    /// Client is treated as supporting MCP elicitation.
    pub elicitation: bool,
    /// Client is treated as supporting MCP sampling.
    pub sampling: bool,
    /// Client is treated as supporting dynamic tools.
    pub dynamic_tools: bool,
    /// Tool-count limit for the client, if any (rendered as "max N tools" in summaries).
    pub max_tools: Option<usize>,
}
13
14impl Default for ClientMcpCapabilities {
15    fn default() -> Self {
16        Self {
17            client_id: "unknown".to_string(),
18            resources: false,
19            prompts: false,
20            elicitation: false,
21            sampling: false,
22            dynamic_tools: false,
23            max_tools: None,
24        }
25    }
26}
27
28impl ClientMcpCapabilities {
29    pub fn detect(client_name: &str) -> Self {
30        let hint = std::env::var("LEAN_CTX_CLIENT_HINT").ok();
31        Self::detect_with_hint(client_name, hint.as_deref())
32    }
33
34    fn detect_with_hint(client_name: &str, hint: Option<&str>) -> Self {
35        let effective = match hint {
36            Some(h) if !h.trim().is_empty() => h.trim().to_lowercase(),
37            _ => client_name.to_lowercase(),
38        };
39        let id = identify_client(&effective);
40
41        match id.as_str() {
42            "cursor" | "kiro" => Self {
43                client_id: id,
44                resources: true,
45                prompts: true,
46                elicitation: true,
47                sampling: false,
48                dynamic_tools: true,
49                max_tools: None,
50            },
51            "claude-code" => Self {
52                client_id: id,
53                resources: true,
54                prompts: true,
55                elicitation: true,
56                sampling: true,
57                dynamic_tools: true,
58                max_tools: None,
59            },
60            "windsurf" => Self {
61                client_id: id,
62                resources: false,
63                prompts: false,
64                elicitation: false,
65                sampling: false,
66                dynamic_tools: true,
67                max_tools: Some(100),
68            },
69            "zed" => Self {
70                client_id: id,
71                resources: false,
72                prompts: true,
73                elicitation: false,
74                sampling: false,
75                dynamic_tools: true,
76                max_tools: None,
77            },
78            "vscode-copilot" => Self {
79                client_id: id,
80                resources: true,
81                prompts: true,
82                elicitation: false,
83                sampling: false,
84                dynamic_tools: true,
85                max_tools: None,
86            },
87            "codex" => Self {
88                client_id: id,
89                resources: true,
90                prompts: false,
91                elicitation: false,
92                sampling: false,
93                dynamic_tools: true,
94                max_tools: None,
95            },
96            "antigravity" | "gemini-cli" => Self {
97                client_id: id,
98                resources: false,
99                prompts: false,
100                elicitation: false,
101                sampling: false,
102                dynamic_tools: false,
103                max_tools: None,
104            },
105            _ => Self {
106                client_id: id,
107                ..Default::default()
108            },
109        }
110    }
111
112    pub fn tier(&self) -> u8 {
113        let score = [
114            self.resources,
115            self.prompts,
116            self.elicitation,
117            self.sampling,
118            self.dynamic_tools,
119        ]
120        .iter()
121        .filter(|&&v| v)
122        .count();
123
124        match score {
125            4..=5 => 1,
126            2..=3 => 2,
127            1 => 3,
128            _ => 4,
129        }
130    }
131
132    pub fn format_summary(&self) -> String {
133        let features: Vec<&str> = [
134            ("resources", self.resources),
135            ("prompts", self.prompts),
136            ("elicitation", self.elicitation),
137            ("sampling", self.sampling),
138            ("dynamic_tools", self.dynamic_tools),
139        ]
140        .iter()
141        .filter(|(_, v)| *v)
142        .map(|(k, _)| *k)
143        .collect();
144
145        let tools_note = self
146            .max_tools
147            .map(|n| format!(" (max {n} tools)"))
148            .unwrap_or_default();
149
150        format!(
151            "{} (tier {}): [{}]{}",
152            self.client_id,
153            self.tier(),
154            features.join(", "),
155            tools_note,
156        )
157    }
158}
159
/// Maps a lower-cased client name (or `LEAN_CTX_CLIENT_HINT` value) to a canonical client id.
///
/// Rules are checked in order and the first match wins. For example, a name containing
/// both "cursor" and "vscode" resolves to `"cursor"`. Most rules are plain substring tests.
/// `"zed"` must appear as a whole token (delimited by non-alphanumeric characters) because
/// the bare substring also occurs inside ordinary words such as "customized" or "dockerized".
/// Returns `"unknown"` when nothing matches.
fn identify_client(lower: &str) -> String {
    // True when `word` is one complete alphanumeric token of `lower`.
    let has_token = |word: &str| {
        lower
            .split(|c: char| !c.is_ascii_alphanumeric())
            .any(|token| token == word)
    };

    let id = if lower.contains("cursor") {
        "cursor"
    } else if lower.contains("claude") {
        "claude-code"
    } else if lower.contains("windsurf") || lower.contains("codeium") {
        "windsurf"
    } else if has_token("zed") {
        "zed"
    } else if lower.contains("copilot")
        || lower.contains("github")
        || lower.contains("visual studio code")
        || lower.contains("vscode")
    {
        "vscode-copilot"
    } else if lower.contains("kiro") {
        "kiro"
    } else if lower.contains("codex") || lower.contains("openai") {
        "codex"
    } else if lower.contains("antigravity") {
        "antigravity"
    } else if lower.contains("gemini") {
        "gemini-cli"
    } else {
        "unknown"
    };
    id.to_string()
}
187
188static GLOBAL: OnceLock<Mutex<ClientMcpCapabilities>> = OnceLock::new();
189
190pub fn global() -> &'static Mutex<ClientMcpCapabilities> {
191    GLOBAL.get_or_init(|| Mutex::new(ClientMcpCapabilities::default()))
192}
193
194pub fn set_detected(caps: &ClientMcpCapabilities) {
195    if let Ok(mut g) = global().lock() {
196        *g = caps.clone();
197    }
198    persist_to_disk(caps);
199}
200
201pub fn current() -> ClientMcpCapabilities {
202    global().lock().map(|g| g.clone()).unwrap_or_default()
203}
204
205/// Load persisted client info from disk (for cross-process use, e.g. dashboard).
206/// Returns `None` if file missing or older than `max_age_secs`.
207pub fn load_persisted(max_age_secs: u64) -> Option<ClientMcpCapabilities> {
208    let path = persisted_path()?;
209    let content = std::fs::read_to_string(&path).ok()?;
210    let val: serde_json::Value = serde_json::from_str(&content).ok()?;
211
212    let ts = val.get("ts").and_then(serde_json::Value::as_u64)?;
213    let now = std::time::SystemTime::now()
214        .duration_since(std::time::UNIX_EPOCH)
215        .map_or(0, |d| d.as_secs());
216    if now.saturating_sub(ts) > max_age_secs {
217        return None;
218    }
219
220    let client_id = val
221        .get("client_id")
222        .and_then(|v| v.as_str())
223        .unwrap_or("unknown")
224        .to_string();
225
226    if client_id == "unknown" {
227        return None;
228    }
229
230    Some(ClientMcpCapabilities::detect(&client_id))
231}
232
233fn persisted_path() -> Option<std::path::PathBuf> {
234    Some(
235        super::data_dir::lean_ctx_data_dir()
236            .ok()?
237            .join("client-id.json"),
238    )
239}
240
241fn persist_to_disk(caps: &ClientMcpCapabilities) {
242    let Some(path) = persisted_path() else {
243        return;
244    };
245    let ts = std::time::SystemTime::now()
246        .duration_since(std::time::UNIX_EPOCH)
247        .map_or(0, |d| d.as_secs());
248    let payload = serde_json::json!({
249        "client_id": caps.client_id,
250        "tier": caps.tier(),
251        "features": caps.format_summary(),
252        "ts": ts,
253    });
254    let tmp = path.with_extension("tmp");
255    if let Ok(json) = serde_json::to_string_pretty(&payload) {
256        if std::fs::write(&tmp, &json).is_ok() {
257            let _ = std::fs::rename(&tmp, &path);
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Detect without consulting `LEAN_CTX_CLIENT_HINT`. The results then do not
    /// depend on the environment (developer shell, CI) the tests run in.
    fn detect_no_hint(client_name: &str) -> ClientMcpCapabilities {
        ClientMcpCapabilities::detect_with_hint(client_name, None)
    }

    #[test]
    fn cursor_detection() {
        let caps = detect_no_hint("Cursor");
        assert_eq!(caps.client_id, "cursor");
        assert!(caps.resources);
        assert!(caps.prompts);
        assert!(caps.elicitation);
        assert!(caps.dynamic_tools);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn kiro_shares_cursor_profile() {
        let caps = detect_no_hint("Kiro");
        assert_eq!(caps.client_id, "kiro");
        assert!(caps.elicitation);
        assert!(!caps.sampling);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn claude_code_detection() {
        let caps = detect_no_hint("claude-code");
        assert_eq!(caps.client_id, "claude-code");
        assert!(caps.sampling);
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn windsurf_detection() {
        let caps = detect_no_hint("Windsurf");
        assert_eq!(caps.client_id, "windsurf");
        assert!(!caps.resources);
        assert!(!caps.prompts);
        assert_eq!(caps.max_tools, Some(100));
        assert_eq!(caps.tier(), 3);
    }

    #[test]
    fn zed_detection() {
        let caps = detect_no_hint("Zed");
        assert_eq!(caps.client_id, "zed");
        assert!(!caps.resources);
        assert!(caps.prompts);
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn codex_detection() {
        let caps = detect_no_hint("Codex");
        assert_eq!(caps.client_id, "codex");
        assert!(caps.resources);
        assert!(!caps.prompts);
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn gemini_and_antigravity_have_no_features() {
        for (name, id) in [("gemini-cli", "gemini-cli"), ("Antigravity", "antigravity")] {
            let caps = detect_no_hint(name);
            assert_eq!(caps.client_id, id);
            assert!(!caps.dynamic_tools);
            assert_eq!(caps.max_tools, None);
            assert_eq!(caps.tier(), 4);
        }
    }

    #[test]
    fn unknown_client_tier4() {
        let caps = detect_no_hint("random-editor");
        assert_eq!(caps.client_id, "unknown");
        assert_eq!(caps.tier(), 4);
    }

    #[test]
    fn copilot_detection() {
        let caps = detect_no_hint("GitHub Copilot");
        assert_eq!(caps.client_id, "vscode-copilot");
        assert!(caps.resources);
        assert!(caps.prompts);
        assert!(caps.dynamic_tools);
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn vscode_plain_detection() {
        let caps = detect_no_hint("Visual Studio Code");
        assert_eq!(caps.client_id, "vscode-copilot");
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn vscode_lowercase_detection() {
        let caps = detect_no_hint("vscode");
        assert_eq!(caps.client_id, "vscode-copilot");
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn client_hint_override() {
        let caps = ClientMcpCapabilities::detect_with_hint(
            "random-unknown-editor",
            Some("vscode-copilot"),
        );
        assert_eq!(caps.client_id, "vscode-copilot");
        assert_eq!(caps.tier(), 2);
    }

    #[test]
    fn client_hint_is_trimmed() {
        let caps = ClientMcpCapabilities::detect_with_hint("random", Some("  Cursor  "));
        assert_eq!(caps.client_id, "cursor");
    }

    #[test]
    fn client_hint_empty_falls_back() {
        let caps = ClientMcpCapabilities::detect_with_hint("Cursor", Some(""));
        assert_eq!(caps.client_id, "cursor");
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn client_hint_whitespace_only_falls_back() {
        let caps = ClientMcpCapabilities::detect_with_hint("Windsurf", Some("   "));
        assert_eq!(caps.client_id, "windsurf");
    }

    #[test]
    fn client_hint_none_falls_back() {
        let caps = ClientMcpCapabilities::detect_with_hint("Cursor", None);
        assert_eq!(caps.client_id, "cursor");
        assert_eq!(caps.tier(), 1);
    }

    #[test]
    fn format_summary() {
        let caps = detect_no_hint("Cursor");
        let s = caps.format_summary();
        assert!(s.contains("cursor"));
        assert!(s.contains("tier 1"));
    }

    #[test]
    fn format_summary_exact_output() {
        assert_eq!(
            detect_no_hint("Cursor").format_summary(),
            "cursor (tier 1): [resources, prompts, elicitation, dynamic_tools]"
        );
        assert_eq!(
            detect_no_hint("Windsurf").format_summary(),
            "windsurf (tier 3): [dynamic_tools] (max 100 tools)"
        );
        assert_eq!(
            detect_no_hint("random-editor").format_summary(),
            "unknown (tier 4): []"
        );
    }
}