Skip to main content

aster/parser/
lsp_manager.rs

//! LSP Server Manager
//!
//! Manages the installation checks, startup, and lifecycle of LSP servers.
4
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::process::Command;
8use std::sync::Arc;
9use tokio::sync::{broadcast, RwLock};
10
11use super::lsp_client::{LspClient, LspClientConfig, LspServerState};
12
/// Static description of one language server: how to launch it, how to probe
/// for it, and which files it is responsible for.
#[derive(Clone, Debug)]
pub struct LspServerInfo {
    /// Language name; in the built-in table this matches the table key.
    pub language: String,
    /// Human-readable server name, used in progress and error messages.
    pub name: String,
    /// Executable used to launch the server.
    pub command: String,
    /// Arguments passed to `command` when launching the server.
    pub args: Vec<String>,
    /// Command line suggested to the user when the server is missing.
    pub install_command: String,
    /// Command line run through a shell to probe whether the server is available.
    pub check_command: String,
    /// File extensions served by this server, written with the leading dot (e.g. `.rs`).
    pub extensions: Vec<String>,
    /// Language identifier for this server (presumably the LSP `languageId` — confirm against the client).
    pub language_id: String,
}
33
34/// LSP 服务器配置表
35pub static LSP_SERVERS: once_cell::sync::Lazy<HashMap<&'static str, LspServerInfo>> =
36    once_cell::sync::Lazy::new(|| {
37        let mut m = HashMap::new();
38
39        m.insert(
40            "typescript",
41            LspServerInfo {
42                language: "typescript".to_string(),
43                name: "TypeScript Language Server".to_string(),
44                command: "typescript-language-server".to_string(),
45                args: vec!["--stdio".to_string()],
46                install_command: "npm install -g typescript-language-server typescript".to_string(),
47                check_command: "typescript-language-server --version".to_string(),
48                extensions: vec![".ts".to_string(), ".tsx".to_string()],
49                language_id: "typescript".to_string(),
50            },
51        );
52
53        m.insert(
54            "javascript",
55            LspServerInfo {
56                language: "javascript".to_string(),
57                name: "TypeScript Language Server (JavaScript)".to_string(),
58                command: "typescript-language-server".to_string(),
59                args: vec!["--stdio".to_string()],
60                install_command: "npm install -g typescript-language-server typescript".to_string(),
61                check_command: "typescript-language-server --version".to_string(),
62                extensions: vec![".js".to_string(), ".jsx".to_string()],
63                language_id: "javascript".to_string(),
64            },
65        );
66
67        m.insert(
68            "python",
69            LspServerInfo {
70                language: "python".to_string(),
71                name: "Pyright".to_string(),
72                command: "pyright-langserver".to_string(),
73                args: vec!["--stdio".to_string()],
74                install_command: "npm install -g pyright".to_string(),
75                check_command: "pyright-langserver --version".to_string(),
76                extensions: vec![".py".to_string(), ".pyi".to_string()],
77                language_id: "python".to_string(),
78            },
79        );
80
81        m.insert(
82            "rust",
83            LspServerInfo {
84                language: "rust".to_string(),
85                name: "rust-analyzer".to_string(),
86                command: "rust-analyzer".to_string(),
87                args: vec![],
88                install_command: "rustup component add rust-analyzer".to_string(),
89                check_command: "rust-analyzer --version".to_string(),
90                extensions: vec![".rs".to_string()],
91                language_id: "rust".to_string(),
92            },
93        );
94
95        m.insert(
96            "go",
97            LspServerInfo {
98                language: "go".to_string(),
99                name: "gopls".to_string(),
100                command: "gopls".to_string(),
101                args: vec!["serve".to_string()],
102                install_command: "go install golang.org/x/tools/gopls@latest".to_string(),
103                check_command: "gopls version".to_string(),
104                extensions: vec![".go".to_string()],
105                language_id: "go".to_string(),
106            },
107        );
108
109        m
110    });
111
/// Stage reported in a progress event while a server is being made available.
///
/// NOTE(review): the per-variant descriptions follow the variant names; confirm
/// against the code that emits each one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InstallStatus {
    /// Availability of the server is being checked.
    Checking,
    /// The server is being installed.
    Installing,
    /// The server is available.
    Installed,
    /// The server could not be made available.
    Failed,
    /// The step was skipped.
    Skipped,
}
121
122/// 进度事件
123#[derive(Debug, Clone)]
124pub struct ProgressEvent {
125    pub language: String,
126    pub status: InstallStatus,
127    pub message: String,
128    pub progress: Option<u8>,
129}
130
131/// LSP 管理器事件
132#[derive(Debug, Clone)]
133pub enum LspManagerEvent {
134    Progress(ProgressEvent),
135    ClientStateChange {
136        language: String,
137        state: LspServerState,
138    },
139    ClientError {
140        language: String,
141        error: String,
142    },
143}
144
145/// LSP 服务器管理器
146pub struct LspManager {
147    clients: Arc<RwLock<HashMap<String, Arc<LspClient>>>>,
148    installed_servers: Arc<RwLock<std::collections::HashSet<String>>>,
149    workspace_root: PathBuf,
150    event_sender: broadcast::Sender<LspManagerEvent>,
151}
152
153impl LspManager {
154    /// 创建新的 LSP 管理器
155    pub fn new(workspace_root: Option<PathBuf>) -> Self {
156        let (event_sender, _) = broadcast::channel(64);
157        Self {
158            clients: Arc::new(RwLock::new(HashMap::new())),
159            installed_servers: Arc::new(RwLock::new(std::collections::HashSet::new())),
160            workspace_root: workspace_root
161                .unwrap_or_else(|| std::env::current_dir().unwrap_or_default()),
162            event_sender,
163        }
164    }
165
166    /// 订阅事件
167    pub fn subscribe(&self) -> broadcast::Receiver<LspManagerEvent> {
168        self.event_sender.subscribe()
169    }
170
171    /// 检查 LSP 服务器是否已安装
172    pub fn is_server_installed(&self, language: &str) -> bool {
173        let server = match LSP_SERVERS.get(language) {
174            Some(s) => s,
175            None => return false,
176        };
177
178        let output = Command::new("sh")
179            .arg("-c")
180            .arg(&server.check_command)
181            .output();
182
183        matches!(output, Ok(o) if o.status.success())
184    }
185
186    /// 确保 LSP 服务器已安装
187    pub async fn ensure_server(&self, language: &str) -> Result<(), String> {
188        let server = LSP_SERVERS
189            .get(language)
190            .ok_or_else(|| format!("Unsupported language: {}", language))?;
191
192        if self.installed_servers.read().await.contains(language) {
193            return Ok(());
194        }
195
196        let _ = self
197            .event_sender
198            .send(LspManagerEvent::Progress(ProgressEvent {
199                language: language.to_string(),
200                status: InstallStatus::Checking,
201                message: format!("Checking {}...", server.name),
202                progress: None,
203            }));
204
205        if self.is_server_installed(language) {
206            self.installed_servers
207                .write()
208                .await
209                .insert(language.to_string());
210            let _ = self
211                .event_sender
212                .send(LspManagerEvent::Progress(ProgressEvent {
213                    language: language.to_string(),
214                    status: InstallStatus::Installed,
215                    message: format!("{} is ready", server.name),
216                    progress: Some(100),
217                }));
218            return Ok(());
219        }
220
221        Err(format!(
222            "{} is not installed. Install with: {}",
223            server.name, server.install_command
224        ))
225    }
226
227    /// 获取或创建 LSP 客户端
228    pub async fn get_client(&self, language: &str) -> Result<Arc<LspClient>, String> {
229        // 检查是否已有客户端
230        if let Some(client) = self.clients.read().await.get(language) {
231            if client.get_state().await == LspServerState::Running {
232                return Ok(client.clone());
233            }
234        }
235
236        // 确保服务器已安装
237        self.ensure_server(language).await?;
238
239        let server = LSP_SERVERS
240            .get(language)
241            .ok_or_else(|| format!("Unsupported language: {}", language))?;
242
243        // 构建 root URI
244        let root_uri = format!("file://{}", self.workspace_root.display());
245
246        let config = LspClientConfig {
247            command: server.command.clone(),
248            args: server.args.clone(),
249            root_uri: Some(root_uri),
250            initialization_options: None,
251        };
252
253        let client = Arc::new(LspClient::new(language, config));
254
255        // 启动客户端
256        client.start().await?;
257
258        self.clients
259            .write()
260            .await
261            .insert(language.to_string(), client.clone());
262
263        Ok(client)
264    }
265
266    /// 根据文件扩展名获取语言
267    pub fn get_language_by_extension(&self, ext: &str) -> Option<String> {
268        for (lang, server) in LSP_SERVERS.iter() {
269            if server.extensions.contains(&ext.to_string()) {
270                return Some(lang.to_string());
271            }
272        }
273        None
274    }
275
276    /// 获取语言 ID
277    pub fn get_language_id(&self, language: &str) -> String {
278        LSP_SERVERS
279            .get(language)
280            .map(|s| s.language_id.clone())
281            .unwrap_or_else(|| language.to_string())
282    }
283
284    /// 停止所有客户端
285    pub async fn stop_all(&self) {
286        let clients = self.clients.read().await;
287        for client in clients.values() {
288            client.stop().await;
289        }
290    }
291
292    /// 获取所有支持的语言
293    pub fn get_supported_languages(&self) -> Vec<String> {
294        LSP_SERVERS.keys().map(|s| s.to_string()).collect()
295    }
296
297    /// 获取服务器信息
298    pub fn get_server_info(&self, language: &str) -> Option<&LspServerInfo> {
299        LSP_SERVERS.get(language)
300    }
301}
302
303impl Default for LspManager {
304    fn default() -> Self {
305        Self::new(None)
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in table contains the core languages.
    #[test]
    fn test_lsp_servers_config() {
        for language in ["typescript", "rust", "python"] {
            assert!(LSP_SERVERS.contains_key(language));
        }
    }

    /// Known extensions map to their language; unknown ones map to `None`.
    #[test]
    fn test_get_language_by_extension() {
        let manager = LspManager::default();
        let expectations = [
            (".ts", Some("typescript")),
            (".rs", Some("rust")),
            (".py", Some("python")),
            (".unknown", None),
        ];
        for (ext, expected) in expectations {
            assert_eq!(
                manager.get_language_by_extension(ext),
                expected.map(str::to_string)
            );
        }
    }

    /// For these table entries the language ID equals the language name.
    #[test]
    fn test_get_language_id() {
        let manager = LspManager::default();
        for language in ["typescript", "rust"] {
            assert_eq!(manager.get_language_id(language), language);
        }
    }

    /// The supported-language list includes the core languages.
    #[test]
    fn test_get_supported_languages() {
        let manager = LspManager::default();
        let languages = manager.get_supported_languages();
        for expected in ["typescript", "rust"] {
            assert!(languages.iter().any(|language| language == expected));
        }
    }

    /// Server info lookups expose the launch command.
    #[test]
    fn test_get_server_info() {
        let manager = LspManager::default();
        let command = manager
            .get_server_info("rust")
            .map(|info| info.command.as_str());
        assert_eq!(command, Some("rust-analyzer"));
    }
}