use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;

use tree_sitter::{Parser, Tree};

use super::languages::{extractor_for_extension, get_extractor, LanguageExtractor};
use super::{ExtractedSymbol, FunctionCall, Import};
use crate::error::{AcpError, Result};
13
14pub struct AstParser {
17 parsers: Mutex<HashMap<String, Parser>>,
19}
20
21impl AstParser {
22 pub fn new() -> Result<Self> {
24 Ok(Self {
25 parsers: Mutex::new(HashMap::new()),
26 })
27 }
28
29 pub fn parse_and_extract(&self, source: &str, language: &str) -> Result<Vec<ExtractedSymbol>> {
31 let extractor = get_extractor(language)
32 .ok_or_else(|| AcpError::UnsupportedLanguage(language.to_string()))?;
33
34 let tree = self.parse(source, extractor.as_ref())?;
35 extractor.extract_symbols(&tree, source)
36 }
37
38 pub fn parse_by_extension(&self, source: &str, ext: &str) -> Result<Vec<ExtractedSymbol>> {
40 let extractor = extractor_for_extension(ext)
41 .ok_or_else(|| AcpError::UnsupportedLanguage(format!(".{}", ext)))?;
42
43 let tree = self.parse(source, extractor.as_ref())?;
44 extractor.extract_symbols(&tree, source)
45 }
46
47 pub fn parse_file(&self, path: &Path, source: &str) -> Result<Vec<ExtractedSymbol>> {
49 let ext = path
50 .extension()
51 .and_then(|e| e.to_str())
52 .ok_or_else(|| AcpError::UnsupportedLanguage("no extension".to_string()))?;
53 self.parse_by_extension(source, ext)
54 }
55
56 pub fn parse_calls(&self, path: &Path, source: &str) -> Result<Vec<FunctionCall>> {
58 let ext = path
59 .extension()
60 .and_then(|e| e.to_str())
61 .ok_or_else(|| AcpError::UnsupportedLanguage("no extension".to_string()))?;
62
63 let extractor = extractor_for_extension(ext)
64 .ok_or_else(|| AcpError::UnsupportedLanguage(format!(".{}", ext)))?;
65
66 let tree = self.parse(source, extractor.as_ref())?;
67 extractor.extract_calls(&tree, source, None)
68 }
69
70 pub fn extract_imports(&self, source: &str, language: &str) -> Result<Vec<Import>> {
72 let extractor = get_extractor(language)
73 .ok_or_else(|| AcpError::UnsupportedLanguage(language.to_string()))?;
74
75 let tree = self.parse(source, extractor.as_ref())?;
76 extractor.extract_imports(&tree, source)
77 }
78
79 pub fn extract_calls_by_language(
81 &self,
82 source: &str,
83 language: &str,
84 current_function: Option<&str>,
85 ) -> Result<Vec<FunctionCall>> {
86 let extractor = get_extractor(language)
87 .ok_or_else(|| AcpError::UnsupportedLanguage(language.to_string()))?;
88
89 let tree = self.parse(source, extractor.as_ref())?;
90 extractor.extract_calls(&tree, source, current_function)
91 }
92
93 fn parse(&self, source: &str, extractor: &dyn LanguageExtractor) -> Result<Tree> {
95 let lang_name = extractor.name().to_string();
96
97 let mut parsers = self
99 .parsers
100 .lock()
101 .map_err(|_| AcpError::parse("Parser lock poisoned".to_string()))?;
102
103 let parser = parsers.entry(lang_name.clone()).or_insert_with(|| {
105 let mut p = Parser::new();
106 p.set_language(&extractor.language())
107 .expect("Failed to set language");
108 p
109 });
110
111 parser
112 .parse(source, None)
113 .ok_or_else(|| AcpError::parse(format!("Failed to parse {} source", lang_name)))
114 }
115
116 pub fn supported_languages() -> &'static [&'static str] {
118 &["typescript", "javascript", "rust", "python", "go", "java"]
119 }
120
121 pub fn supported_extensions() -> &'static [&'static str] {
123 &[
124 "ts", "tsx", "js", "jsx", "mjs", "cjs", "rs", "py", "pyi", "go", "java",
125 ]
126 }
127
128 pub fn is_language_supported(language: &str) -> bool {
130 get_extractor(language).is_some()
131 }
132
133 pub fn is_extension_supported(ext: &str) -> bool {
135 extractor_for_extension(ext).is_some()
136 }
137}
138
139impl Default for AstParser {
140 fn default() -> Self {
141 Self::new().expect("Failed to create AST parser")
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser. `AstParser::new` currently never fails.
    fn parser() -> AstParser {
        AstParser::new().expect("constructing AstParser should not fail")
    }

    #[test]
    fn test_supported_languages() {
        let langs = AstParser::supported_languages();
        assert!(langs.contains(&"typescript"));
        assert!(langs.contains(&"rust"));
        assert!(langs.contains(&"python"));
    }

    #[test]
    fn test_is_language_supported() {
        assert!(AstParser::is_language_supported("typescript"));
        assert!(AstParser::is_language_supported("rust"));
        assert!(AstParser::is_language_supported("python"));
        assert!(!AstParser::is_language_supported("cobol"));
    }

    #[test]
    fn test_is_extension_supported() {
        assert!(AstParser::is_extension_supported("ts"));
        assert!(AstParser::is_extension_supported("rs"));
        assert!(AstParser::is_extension_supported("py"));
        assert!(!AstParser::is_extension_supported("cob"));
    }

    /// The advertised lists are hand-maintained. Every entry must actually have
    /// an extractor behind it.
    #[test]
    fn test_supported_lists_match_extractor_lookup() {
        for lang in AstParser::supported_languages() {
            assert!(
                AstParser::is_language_supported(lang),
                "language {} is listed but has no extractor",
                lang
            );
        }
        for ext in AstParser::supported_extensions() {
            assert!(
                AstParser::is_extension_supported(ext),
                "extension {} is listed but has no extractor",
                ext
            );
        }
    }

    #[test]
    fn test_unsupported_language_is_reported() {
        let result = parser().parse_and_extract("", "cobol");
        assert!(matches!(
            result,
            Err(AcpError::UnsupportedLanguage(ref lang)) if lang == "cobol"
        ));
    }

    #[test]
    fn test_unsupported_extension_is_reported() {
        let result = parser().parse_by_extension("", "cob");
        assert!(matches!(
            result,
            Err(AcpError::UnsupportedLanguage(ref ext)) if ext == ".cob"
        ));
    }

    #[test]
    fn test_path_without_extension_is_reported() {
        let parser = parser();
        let path = Path::new("Makefile");
        assert!(matches!(
            parser.parse_file(path, ""),
            Err(AcpError::UnsupportedLanguage(_))
        ));
        assert!(matches!(
            parser.parse_calls(path, ""),
            Err(AcpError::UnsupportedLanguage(_))
        ));
    }
}