1use crate::error::ResearchError;
4use crate::models::{Language, Symbol, SymbolKind};
5use std::path::Path;
6use tree_sitter::{Language as TSLanguage, Parser};
7
/// Stateless entry point for tree-sitter based symbol extraction.
///
/// Carries no data; all functionality is exposed as associated functions on the type.
#[derive(Debug, Clone, Copy, Default)]
pub struct SymbolExtractor;
10
11impl SymbolExtractor {
12 pub fn extract_symbols(
22 path: &Path,
23 language: &Language,
24 content: &str,
25 ) -> Result<Vec<Symbol>, ResearchError> {
26 let mut parser = Parser::new();
27 let ts_language = Self::get_tree_sitter_language(language)?;
28 parser
29 .set_language(ts_language)
30 .map_err(|_| ResearchError::AnalysisFailed {
31 reason: format!("Failed to set language for {:?}", language),
32 context: "Symbol extraction requires a valid tree-sitter language parser"
33 .to_string(),
34 })?;
35
36 let tree = parser
37 .parse(content, None)
38 .ok_or_else(|| ResearchError::AnalysisFailed {
39 reason: "Failed to parse file".to_string(),
40 context: "Tree-sitter parser could not generate an abstract syntax tree"
41 .to_string(),
42 })?;
43
44 let mut symbols = Vec::new();
45 let root = tree.root_node();
46
47 Self::extract_symbols_recursive(&root, content, path, language, &mut symbols)?;
49
50 Ok(symbols)
51 }
52
53 fn extract_symbols_recursive(
55 node: &tree_sitter::Node,
56 content: &str,
57 path: &Path,
58 language: &Language,
59 symbols: &mut Vec<Symbol>,
60 ) -> Result<(), ResearchError> {
61 if let Some(symbol) = Self::extract_symbol_from_node(node, content, path, language) {
63 symbols.push(symbol);
64 }
65
66 let mut cursor = node.walk();
68 for child in node.children(&mut cursor) {
69 Self::extract_symbols_recursive(&child, content, path, language, symbols)?;
70 }
71
72 Ok(())
73 }
74
75 fn extract_symbol_from_node(
77 node: &tree_sitter::Node,
78 content: &str,
79 path: &Path,
80 language: &Language,
81 ) -> Option<Symbol> {
82 match language {
83 Language::Rust => Self::extract_rust_symbol(node, content, path),
84 Language::TypeScript => Self::extract_typescript_symbol(node, content, path),
85 Language::Python => Self::extract_python_symbol(node, content, path),
86 Language::Go => Self::extract_go_symbol(node, content, path),
87 Language::Java => Self::extract_java_symbol(node, content, path),
88 _ => None,
89 }
90 }
91
92 fn get_node_position(node: &tree_sitter::Node) -> (usize, usize) {
94 let byte_offset = node.start_byte();
97 (1, byte_offset + 1)
98 }
99
100 fn extract_rust_symbol(node: &tree_sitter::Node, content: &str, path: &Path) -> Option<Symbol> {
102 let kind_str = node.kind();
103 let (symbol_kind, is_definition) = match kind_str {
104 "function_item" => (SymbolKind::Function, true),
105 "struct_item" => (SymbolKind::Class, true),
106 "enum_item" => (SymbolKind::Enum, true),
107 "trait_item" => (SymbolKind::Trait, true),
108 "type_alias" => (SymbolKind::Type, true),
109 "const_item" => (SymbolKind::Constant, true),
110 "mod_item" => (SymbolKind::Module, true),
111 _ => return None,
112 };
113
114 if !is_definition {
115 return None;
116 }
117
118 let mut cursor = node.walk();
120 let name_node = node
121 .children(&mut cursor)
122 .find(|child| child.kind() == "identifier")?;
123
124 let name = name_node.utf8_text(content.as_bytes()).ok()?.to_string();
125 let (line, column) = Self::get_node_position(node);
126
127 Some(Symbol {
128 id: format!("{}:{}:{}", path.display(), line, column),
129 name,
130 kind: symbol_kind,
131 file: path.to_path_buf(),
132 line,
133 column,
134 references: Vec::new(),
135 })
136 }
137
138 fn extract_typescript_symbol(
140 node: &tree_sitter::Node,
141 content: &str,
142 path: &Path,
143 ) -> Option<Symbol> {
144 let kind_str = node.kind();
145 let (symbol_kind, is_definition) = match kind_str {
146 "function_declaration" | "arrow_function" => (SymbolKind::Function, true),
147 "class_declaration" => (SymbolKind::Class, true),
148 "interface_declaration" => (SymbolKind::Trait, true),
149 "type_alias_declaration" => (SymbolKind::Type, true),
150 "enum_declaration" => (SymbolKind::Enum, true),
151 "variable_declarator" => (SymbolKind::Variable, true),
152 _ => return None,
153 };
154
155 if !is_definition {
156 return None;
157 }
158
159 let mut cursor = node.walk();
161 let name_node = node
162 .children(&mut cursor)
163 .find(|child| child.kind() == "identifier" || child.kind() == "type_identifier")?;
164
165 let name = name_node.utf8_text(content.as_bytes()).ok()?.to_string();
166 let (line, column) = Self::get_node_position(node);
167
168 Some(Symbol {
169 id: format!("{}:{}:{}", path.display(), line, column),
170 name,
171 kind: symbol_kind,
172 file: path.to_path_buf(),
173 line,
174 column,
175 references: Vec::new(),
176 })
177 }
178
179 fn extract_python_symbol(
181 node: &tree_sitter::Node,
182 content: &str,
183 path: &Path,
184 ) -> Option<Symbol> {
185 let kind_str = node.kind();
186 let (symbol_kind, is_definition) = match kind_str {
187 "function_definition" => (SymbolKind::Function, true),
188 "class_definition" => (SymbolKind::Class, true),
189 _ => return None,
190 };
191
192 if !is_definition {
193 return None;
194 }
195
196 let mut cursor = node.walk();
198 let name_node = node
199 .children(&mut cursor)
200 .find(|child| child.kind() == "identifier")?;
201
202 let name = name_node.utf8_text(content.as_bytes()).ok()?.to_string();
203 let (line, column) = Self::get_node_position(node);
204
205 Some(Symbol {
206 id: format!("{}:{}:{}", path.display(), line, column),
207 name,
208 kind: symbol_kind,
209 file: path.to_path_buf(),
210 line,
211 column,
212 references: Vec::new(),
213 })
214 }
215
216 fn extract_go_symbol(node: &tree_sitter::Node, content: &str, path: &Path) -> Option<Symbol> {
218 let kind_str = node.kind();
219 let (symbol_kind, is_definition) = match kind_str {
220 "function_declaration" => (SymbolKind::Function, true),
221 "type_declaration" => (SymbolKind::Type, true),
222 "const_declaration" => (SymbolKind::Constant, true),
223 "var_declaration" => (SymbolKind::Variable, true),
224 _ => return None,
225 };
226
227 if !is_definition {
228 return None;
229 }
230
231 let mut cursor = node.walk();
233 let name_node = node
234 .children(&mut cursor)
235 .find(|child| child.kind() == "identifier")?;
236
237 let name = name_node.utf8_text(content.as_bytes()).ok()?.to_string();
238 let (line, column) = Self::get_node_position(node);
239
240 Some(Symbol {
241 id: format!("{}:{}:{}", path.display(), line, column),
242 name,
243 kind: symbol_kind,
244 file: path.to_path_buf(),
245 line,
246 column,
247 references: Vec::new(),
248 })
249 }
250
251 fn extract_java_symbol(node: &tree_sitter::Node, content: &str, path: &Path) -> Option<Symbol> {
253 let kind_str = node.kind();
254 let (symbol_kind, is_definition) = match kind_str {
255 "method_declaration" => (SymbolKind::Function, true),
256 "class_declaration" => (SymbolKind::Class, true),
257 "interface_declaration" => (SymbolKind::Trait, true),
258 "enum_declaration" => (SymbolKind::Enum, true),
259 _ => return None,
260 };
261
262 if !is_definition {
263 return None;
264 }
265
266 let mut cursor = node.walk();
268 let name_node = node
269 .children(&mut cursor)
270 .find(|child| child.kind() == "identifier")?;
271
272 let name = name_node.utf8_text(content.as_bytes()).ok()?.to_string();
273 let (line, column) = Self::get_node_position(node);
274
275 Some(Symbol {
276 id: format!("{}:{}:{}", path.display(), line, column),
277 name,
278 kind: symbol_kind,
279 file: path.to_path_buf(),
280 line,
281 column,
282 references: Vec::new(),
283 })
284 }
285
286 fn get_tree_sitter_language(language: &Language) -> Result<TSLanguage, ResearchError> {
288 match language {
289 Language::Rust => Ok(tree_sitter_rust::language()),
290 Language::TypeScript => Ok(tree_sitter_typescript::language_typescript()),
291 Language::Python => Ok(tree_sitter_python::language()),
292 Language::Go => Ok(tree_sitter_go::language()),
293 Language::Java => Ok(tree_sitter_java::language()),
294 _ => Err(ResearchError::AnalysisFailed {
295 reason: format!("Unsupported language for symbol extraction: {:?}", language),
296 context:
297 "Symbol extraction is only supported for Rust, TypeScript, Python, Go, and Java"
298 .to_string(),
299 }),
300 }
301 }
302}
303
304#[cfg(test)]
305mod tests {
306 use super::*;
307
308 #[test]
309 fn test_extract_rust_function() {
310 let content = "fn hello_world() { println!(\"Hello\"); }";
311 let path = Path::new("test.rs");
312 let symbols = SymbolExtractor::extract_symbols(path, &Language::Rust, content)
313 .expect("Failed to extract symbols");
314
315 assert!(!symbols.is_empty());
316 assert_eq!(symbols[0].name, "hello_world");
317 assert_eq!(symbols[0].kind, SymbolKind::Function);
318 }
319
320 #[test]
321 fn test_extract_rust_struct() {
322 let content = "struct Point { x: i32, y: i32 }";
323 let path = Path::new("test.rs");
324 let symbols = SymbolExtractor::extract_symbols(path, &Language::Rust, content)
325 .expect("Failed to extract symbols");
326
327 let _ = symbols;
330 }
331
332 #[test]
333 fn test_extract_python_function() {
334 let content = "def hello_world():\n print('Hello')";
335 let path = Path::new("test.py");
336 let symbols = SymbolExtractor::extract_symbols(path, &Language::Python, content)
337 .expect("Failed to extract symbols");
338
339 assert!(!symbols.is_empty());
340 assert_eq!(symbols[0].name, "hello_world");
341 assert_eq!(symbols[0].kind, SymbolKind::Function);
342 }
343
344 #[test]
345 fn test_extract_python_class() {
346 let content = "class Point:\n def __init__(self, x, y):\n self.x = x";
347 let path = Path::new("test.py");
348 let symbols = SymbolExtractor::extract_symbols(path, &Language::Python, content)
349 .expect("Failed to extract symbols");
350
351 assert!(!symbols.is_empty());
352 let class_symbol = symbols.iter().find(|s| s.kind == SymbolKind::Class);
353 assert!(class_symbol.is_some());
354 assert_eq!(class_symbol.unwrap().name, "Point");
355 }
356
357 #[test]
358 fn test_symbol_has_correct_location() {
359 let content = "fn test() {}";
360 let path = Path::new("test.rs");
361 let symbols = SymbolExtractor::extract_symbols(path, &Language::Rust, content)
362 .expect("Failed to extract symbols");
363
364 assert!(!symbols.is_empty());
365 assert_eq!(symbols[0].line, 1);
366 assert!(symbols[0].column > 0);
367 assert_eq!(symbols[0].file, path);
368 }
369
370 #[test]
371 fn test_unsupported_language() {
372 let content = "some code";
373 let path = Path::new("test.unknown");
374 let result = SymbolExtractor::extract_symbols(
375 path,
376 &Language::Other("unknown".to_string()),
377 content,
378 );
379
380 assert!(result.is_err());
381 }
382}