Skip to main content

aster/parser/
symbol_extractor.rs

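//! LSP Symbol Extractor
//!
//! Extracts code symbols, references, and definitions by talking to language servers over the Language Server Protocol (LSP).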
4
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use super::lsp_manager::LspManager;
11use super::types::*;
12
13/// 代码符号类型
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "lowercase")]
16pub enum SymbolKind {
17    Function,
18    Class,
19    Method,
20    Property,
21    Variable,
22    Constant,
23    Interface,
24    Type,
25    Enum,
26    Module,
27    Import,
28    Export,
29}
30
31impl From<LspSymbolKind> for SymbolKind {
32    fn from(kind: LspSymbolKind) -> Self {
33        match kind {
34            LspSymbolKind::Function => SymbolKind::Function,
35            LspSymbolKind::Class => SymbolKind::Class,
36            LspSymbolKind::Method => SymbolKind::Method,
37            LspSymbolKind::Property | LspSymbolKind::Field => SymbolKind::Property,
38            LspSymbolKind::Variable => SymbolKind::Variable,
39            LspSymbolKind::Constant => SymbolKind::Constant,
40            LspSymbolKind::Interface => SymbolKind::Interface,
41            LspSymbolKind::Enum => SymbolKind::Enum,
42            LspSymbolKind::Module | LspSymbolKind::Namespace | LspSymbolKind::Package => {
43                SymbolKind::Module
44            }
45            LspSymbolKind::TypeParameter | LspSymbolKind::Struct => SymbolKind::Type,
46            _ => SymbolKind::Variable,
47        }
48    }
49}
50
51/// 符号位置
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct SymbolLocation {
54    /// 文件路径
55    pub file: String,
56    /// 起始行 (1-indexed)
57    pub start_line: u32,
58    /// 起始列
59    pub start_column: u32,
60    /// 结束行
61    pub end_line: u32,
62    /// 结束列
63    pub end_column: u32,
64}
65
66/// 代码符号
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct CodeSymbol {
69    /// 符号名称
70    pub name: String,
71    /// 符号类型
72    pub kind: SymbolKind,
73    /// 位置
74    pub location: SymbolLocation,
75    /// 子符号
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub children: Option<Vec<CodeSymbol>>,
78    /// 签名
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub signature: Option<String>,
81    /// 文档
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub documentation: Option<String>,
84}
85
86/// 引用信息
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct Reference {
89    /// 文件路径
90    pub file: String,
91    /// 行号
92    pub line: u32,
93    /// 列号
94    pub column: u32,
95    /// 行文本
96    pub text: String,
97    /// 是否是定义
98    pub is_definition: bool,
99}
100
101/// LSP 符号提取器
102pub struct LspSymbolExtractor {
103    manager: Arc<RwLock<LspManager>>,
104    document_versions: Arc<RwLock<std::collections::HashMap<String, i32>>>,
105}
106
107impl LspSymbolExtractor {
108    /// 创建新的符号提取器
109    pub fn new(manager: LspManager) -> Self {
110        Self {
111            manager: Arc::new(RwLock::new(manager)),
112            document_versions: Arc::new(RwLock::new(std::collections::HashMap::new())),
113        }
114    }
115
116    /// 文件路径转 URI
117    fn file_to_uri(file_path: &str) -> String {
118        let normalized = file_path.replace('\\', "/");
119        if normalized.starts_with('/') {
120            format!("file://{}", normalized)
121        } else {
122            format!("file:///{}", normalized)
123        }
124    }
125
126    /// URI 转文件路径
127    fn uri_to_file(uri: &str) -> String {
128        let path = uri.trim_start_matches("file://").trim_start_matches('/');
129        if cfg!(windows) {
130            path.to_string()
131        } else {
132            format!("/{}", path)
133        }
134    }
135
136    /// 提取文件中的符号
137    pub async fn extract_symbols(&self, file_path: &str) -> Result<Vec<CodeSymbol>, String> {
138        let ext = Path::new(file_path)
139            .extension()
140            .and_then(|e| e.to_str())
141            .map(|e| format!(".{}", e))
142            .unwrap_or_default();
143
144        let manager = self.manager.read().await;
145        let language = manager
146            .get_language_by_extension(&ext)
147            .ok_or_else(|| format!("Unsupported file type: {}", ext))?;
148
149        let client = manager.get_client(&language).await?;
150
151        // 读取文件内容
152        let content = std::fs::read_to_string(file_path)
153            .map_err(|e| format!("Failed to read file: {}", e))?;
154
155        let uri = Self::file_to_uri(file_path);
156        let language_id = manager.get_language_id(&language);
157
158        // 获取或更新文档版本
159        let version = {
160            let mut versions = self.document_versions.write().await;
161            let v = versions.entry(uri.clone()).or_insert(0);
162            *v += 1;
163            *v
164        };
165
166        // 打开文档
167        client
168            .open_document(&uri, &language_id, version, &content)
169            .await;
170
171        // 获取符号
172        let _symbols = client.get_document_symbols(&uri).await?;
173
174        // 关闭文档
175        client.close_document(&uri).await;
176
177        // 转换符号 (简化实现)
178        Ok(Vec::new())
179    }
180
181    /// 查找引用
182    pub async fn find_references(
183        &self,
184        file_path: &str,
185        line: u32,
186        column: u32,
187    ) -> Result<Vec<Reference>, String> {
188        let ext = Path::new(file_path)
189            .extension()
190            .and_then(|e| e.to_str())
191            .map(|e| format!(".{}", e))
192            .unwrap_or_default();
193
194        let manager = self.manager.read().await;
195        let language = manager
196            .get_language_by_extension(&ext)
197            .ok_or_else(|| format!("Unsupported file type: {}", ext))?;
198
199        let client = manager.get_client(&language).await?;
200        let uri = Self::file_to_uri(file_path);
201
202        let position = LspPosition {
203            line: line.saturating_sub(1), // 转为 0-indexed
204            character: column,
205        };
206
207        let locations = client.find_references(&uri, position).await?;
208
209        // 转换结果
210        let references: Vec<Reference> = locations
211            .iter()
212            .map(|loc| {
213                let file = Self::uri_to_file(&loc.uri);
214                let ref_line = loc.range.start.line + 1;
215
216                // 尝试读取行文本
217                let text = std::fs::read_to_string(&file)
218                    .ok()
219                    .and_then(|content| {
220                        content
221                            .lines()
222                            .nth(ref_line as usize - 1)
223                            .map(|s| s.to_string())
224                    })
225                    .unwrap_or_default();
226
227                Reference {
228                    file,
229                    line: ref_line,
230                    column: loc.range.start.character,
231                    text,
232                    is_definition: false,
233                }
234            })
235            .collect();
236
237        Ok(references)
238    }
239
240    /// 跳转到定义
241    pub async fn get_definition(
242        &self,
243        file_path: &str,
244        line: u32,
245        column: u32,
246    ) -> Result<Option<Reference>, String> {
247        let ext = Path::new(file_path)
248            .extension()
249            .and_then(|e| e.to_str())
250            .map(|e| format!(".{}", e))
251            .unwrap_or_default();
252
253        let manager = self.manager.read().await;
254        let language = manager
255            .get_language_by_extension(&ext)
256            .ok_or_else(|| format!("Unsupported file type: {}", ext))?;
257
258        let client = manager.get_client(&language).await?;
259        let uri = Self::file_to_uri(file_path);
260
261        let position = LspPosition {
262            line: line.saturating_sub(1),
263            character: column,
264        };
265
266        let location = client.get_definition(&uri, position).await?;
267
268        Ok(location.map(|loc| {
269            let file = Self::uri_to_file(&loc.uri);
270            let def_line = loc.range.start.line + 1;
271
272            let text = std::fs::read_to_string(&file)
273                .ok()
274                .and_then(|content| {
275                    content
276                        .lines()
277                        .nth(def_line as usize - 1)
278                        .map(|s| s.to_string())
279                })
280                .unwrap_or_default();
281
282            Reference {
283                file,
284                line: def_line,
285                column: loc.range.start.character,
286                text,
287                is_definition: true,
288            }
289        }))
290    }
291
292    /// 扁平化符号树
293    pub fn flatten_symbols(symbols: &[CodeSymbol]) -> Vec<CodeSymbol> {
294        let mut result = Vec::new();
295        for sym in symbols {
296            result.push(sym.clone());
297            if let Some(ref children) = sym.children {
298                result.extend(Self::flatten_symbols(children));
299            }
300        }
301        result
302    }
303
304    /// 停止所有 LSP 客户端
305    pub async fn shutdown(&self) {
306        self.manager.read().await.stop_all().await;
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a location in `test.rs` spanning `start_line..=end_line`, columns 0.
    fn span(start_line: u32, end_line: u32) -> SymbolLocation {
        SymbolLocation {
            file: "test.rs".to_string(),
            start_line,
            start_column: 0,
            end_line,
            end_column: 0,
        }
    }

    /// Builds a symbol without signature or documentation.
    fn symbol(
        name: &str,
        kind: SymbolKind,
        location: SymbolLocation,
        children: Option<Vec<CodeSymbol>>,
    ) -> CodeSymbol {
        CodeSymbol {
            name: name.to_string(),
            kind,
            location,
            children,
            signature: None,
            documentation: None,
        }
    }

    #[test]
    fn test_symbol_kind_from_lsp() {
        let cases = [
            (LspSymbolKind::Function, SymbolKind::Function),
            (LspSymbolKind::Class, SymbolKind::Class),
            (LspSymbolKind::Method, SymbolKind::Method),
            (LspSymbolKind::Interface, SymbolKind::Interface),
        ];
        for (lsp_kind, expected) in cases {
            assert_eq!(SymbolKind::from(lsp_kind), expected);
        }
    }

    #[test]
    fn test_file_to_uri() {
        let uri = LspSymbolExtractor::file_to_uri("/tmp/test.rs");
        assert!(uri.starts_with("file://"), "unexpected scheme: {}", uri);
        assert!(uri.contains("tmp"), "path segment missing: {}", uri);
    }

    #[test]
    fn test_uri_to_file() {
        let path = LspSymbolExtractor::uri_to_file("file:///tmp/test.rs");
        assert!(path.contains("tmp"), "path segment missing: {}", path);
    }

    #[test]
    fn test_flatten_symbols() {
        let child = symbol("child", SymbolKind::Method, span(2, 5), None);
        let parent = symbol("Parent", SymbolKind::Class, span(1, 10), Some(vec![child]));

        let flat = LspSymbolExtractor::flatten_symbols(&[parent]);
        let names: Vec<&str> = flat.iter().map(|s| s.name.as_str()).collect();
        assert_eq!(names, ["Parent", "child"]);
    }

    #[test]
    fn test_reference_struct() {
        let definition = Reference {
            file: "test.rs".to_string(),
            line: 10,
            column: 5,
            text: "fn test()".to_string(),
            is_definition: true,
        };
        assert_eq!(definition.line, 10);
        assert!(definition.is_definition);
    }
}
389}