1use anyhow::{Context, Result};
7use chrono::Utc;
8use tree_sitter::{Language, Node, Parser};
9
10use crate::indexer::FileInfo;
11use crate::relations::types::{Definition, SymbolId, SymbolKind, Visibility};
12
/// Extracts symbol definitions from source files using tree-sitter grammars.
///
/// The extractor has no fields and therefore no state; it is cheap to create,
/// clone and share.
#[derive(Debug, Clone)]
pub struct SymbolExtractor {}
17
18impl SymbolExtractor {
19 pub fn new() -> Self {
21 Self {}
22 }
23
24 pub fn extract_definitions(&self, file_info: &FileInfo) -> Result<Vec<Definition>> {
26 let extension = file_info.extension.as_deref().unwrap_or("");
27
28 let (language, language_name) = match get_language_for_extension(extension) {
30 Some(lang) => lang,
31 None => return Ok(Vec::new()), };
33
34 let mut parser = Parser::new();
35 parser
36 .set_language(&language)
37 .context("Failed to set parser language")?;
38
39 let tree = parser
40 .parse(&file_info.content, None)
41 .context("Failed to parse source code")?;
42
43 let root_node = tree.root_node();
44 let mut definitions = Vec::new();
45
46 self.extract_from_node(
48 root_node,
49 &file_info.content,
50 &language_name,
51 file_info,
52 None,
53 &mut definitions,
54 );
55
56 Ok(definitions)
57 }
58
59 fn extract_from_node(
61 &self,
62 node: Node,
63 source: &str,
64 language: &str,
65 file_info: &FileInfo,
66 parent_id: Option<String>,
67 result: &mut Vec<Definition>,
68 ) {
69 let kind = node.kind();
70
71 if is_definition_node(kind, language) {
73 if let Some(def) = self.node_to_definition(node, source, language, file_info, &parent_id)
74 {
75 let new_parent_id = Some(def.to_storage_id());
76 result.push(def);
77
78 let mut cursor = node.walk();
80 for child in node.children(&mut cursor) {
81 self.extract_from_node(
82 child,
83 source,
84 language,
85 file_info,
86 new_parent_id.clone(),
87 result,
88 );
89 }
90 return;
91 }
92 }
93
94 let mut cursor = node.walk();
96 for child in node.children(&mut cursor) {
97 self.extract_from_node(child, source, language, file_info, parent_id.clone(), result);
98 }
99 }
100
101 fn node_to_definition(
103 &self,
104 node: Node,
105 source: &str,
106 language: &str,
107 file_info: &FileInfo,
108 parent_id: &Option<String>,
109 ) -> Option<Definition> {
110 let kind = node.kind();
111 let symbol_kind = SymbolKind::from_ast_kind(kind);
112
113 let name = extract_symbol_name(node, source, language)?;
115
116 let start_pos = node.start_position();
118 let end_pos = node.end_position();
119
120 let signature = extract_signature(node, source, language);
122
123 let doc_comment = extract_doc_comment(node, source, language);
125
126 let node_text = &source[node.start_byte()..node.end_byte().min(source.len())];
128 let visibility = Visibility::from_keywords(node_text);
129
130 Some(Definition {
131 symbol_id: SymbolId::new(
132 &file_info.relative_path,
133 name,
134 symbol_kind,
135 start_pos.row + 1, start_pos.column,
137 ),
138 root_path: Some(file_info.root_path.clone()),
139 project: file_info.project.clone(),
140 end_line: end_pos.row + 1,
141 end_col: end_pos.column,
142 signature,
143 doc_comment,
144 visibility,
145 parent_id: parent_id.clone(),
146 indexed_at: Utc::now().timestamp(),
147 })
148 }
149}
150
151impl Default for SymbolExtractor {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157fn get_language_for_extension(extension: &str) -> Option<(Language, String)> {
159 match extension.to_lowercase().as_str() {
160 "rs" => Some((tree_sitter_rust::LANGUAGE.into(), "Rust".to_string())),
161 "py" => Some((tree_sitter_python::LANGUAGE.into(), "Python".to_string())),
162 "js" | "mjs" | "cjs" | "jsx" => Some((
163 tree_sitter_javascript::LANGUAGE.into(),
164 "JavaScript".to_string(),
165 )),
166 "ts" | "tsx" => Some((
167 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
168 "TypeScript".to_string(),
169 )),
170 "go" => Some((tree_sitter_go::LANGUAGE.into(), "Go".to_string())),
171 "java" => Some((tree_sitter_java::LANGUAGE.into(), "Java".to_string())),
172 "swift" => Some((tree_sitter_swift::LANGUAGE.into(), "Swift".to_string())),
173 "c" | "h" => Some((tree_sitter_c::LANGUAGE.into(), "C".to_string())),
174 "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => {
175 Some((tree_sitter_cpp::LANGUAGE.into(), "C++".to_string()))
176 }
177 "cs" => Some((tree_sitter_c_sharp::LANGUAGE.into(), "C#".to_string())),
178 "rb" => Some((tree_sitter_ruby::LANGUAGE.into(), "Ruby".to_string())),
179 "php" => Some((tree_sitter_php::LANGUAGE_PHP.into(), "PHP".to_string())),
180 _ => None,
181 }
182}
183
/// Reports whether a tree-sitter node of type `kind` counts as a symbol
/// definition for `language`, using the display names returned by
/// `get_language_for_extension`.
///
/// Matching is exact and case-sensitive; unknown languages never match.
fn is_definition_node(kind: &str, language: &str) -> bool {
    let definition_kinds: &[&str] = match language {
        "Rust" => &[
            "function_item",
            "impl_item",
            "trait_item",
            "struct_item",
            "enum_item",
            "mod_item",
            "const_item",
            "static_item",
            "type_item",
        ],
        "Python" => &["function_definition", "class_definition", "decorated_definition"],
        "JavaScript" | "TypeScript" => &[
            "function_declaration",
            "function_expression",
            "arrow_function",
            "method_definition",
            "class_declaration",
            "interface_declaration",
            "type_alias_declaration",
        ],
        "Go" => &["function_declaration", "method_declaration", "type_declaration"],
        "Java" => &[
            "method_declaration",
            "class_declaration",
            "interface_declaration",
            "constructor_declaration",
            "enum_declaration",
        ],
        "Swift" => &[
            "function_declaration",
            "class_declaration",
            "struct_declaration",
            "enum_declaration",
            "protocol_declaration",
        ],
        "C" => &["function_definition", "struct_specifier", "enum_specifier"],
        "C++" => &[
            "function_definition",
            "class_specifier",
            "struct_specifier",
            "enum_specifier",
            "namespace_definition",
        ],
        "C#" => &[
            "method_declaration",
            "class_declaration",
            "struct_declaration",
            "interface_declaration",
            "enum_declaration",
            "constructor_declaration",
        ],
        "Ruby" => &["method", "singleton_method", "class", "module"],
        "PHP" => &[
            "function_definition",
            "method_declaration",
            "class_declaration",
            "interface_declaration",
            "trait_declaration",
        ],
        _ => return false,
    };

    definition_kinds.iter().any(|candidate| *candidate == kind)
}
266
267fn extract_symbol_name(node: Node, source: &str, language: &str) -> Option<String> {
269 let name_node = find_name_node(node, language)?;
271
272 let start = name_node.start_byte();
273 let end = name_node.end_byte();
274
275 if end > source.len() {
276 return None;
277 }
278
279 let name = source[start..end].to_string();
280
281 if name.trim().is_empty() {
283 return None;
284 }
285
286 Some(name)
287}
288
289fn find_name_node<'a>(node: Node<'a>, language: &str) -> Option<Node<'a>> {
291 let kind = node.kind();
292
293 match language {
295 "Rust" => {
296 if let Some(name_node) = node.child_by_field_name("name") {
298 return Some(name_node);
299 }
300 if kind == "impl_item" {
302 if let Some(type_node) = node.child_by_field_name("type") {
303 return Some(type_node);
304 }
305 }
306 }
307 "Python" => {
308 if let Some(name_node) = node.child_by_field_name("name") {
310 return Some(name_node);
311 }
312 if kind == "decorated_definition" {
314 let mut cursor = node.walk();
315 for child in node.children(&mut cursor) {
316 if child.kind() == "function_definition" || child.kind() == "class_definition" {
317 return find_name_node(child, language);
318 }
319 }
320 }
321 }
322 "JavaScript" | "TypeScript" => {
323 if let Some(name_node) = node.child_by_field_name("name") {
325 return Some(name_node);
326 }
327 if kind == "arrow_function" {
329 if let Some(parent) = node.parent() {
331 if parent.kind() == "variable_declarator" {
332 if let Some(name_node) = parent.child_by_field_name("name") {
333 return Some(name_node);
334 }
335 }
336 }
337 }
338 }
339 "Go" => {
340 if let Some(name_node) = node.child_by_field_name("name") {
341 return Some(name_node);
342 }
343 }
344 "Java" => {
345 if let Some(name_node) = node.child_by_field_name("name") {
346 return Some(name_node);
347 }
348 }
349 "Swift" => {
350 if let Some(name_node) = node.child_by_field_name("name") {
351 return Some(name_node);
352 }
353 }
354 "C" | "C++" => {
355 if let Some(declarator) = node.child_by_field_name("declarator") {
357 return find_innermost_identifier(declarator);
359 }
360 if kind == "struct_specifier" || kind == "class_specifier" || kind == "enum_specifier" {
362 if let Some(name_node) = node.child_by_field_name("name") {
363 return Some(name_node);
364 }
365 }
366 }
367 "C#" => {
368 if let Some(name_node) = node.child_by_field_name("name") {
369 return Some(name_node);
370 }
371 }
372 "Ruby" => {
373 if let Some(name_node) = node.child_by_field_name("name") {
374 return Some(name_node);
375 }
376 }
377 "PHP" => {
378 if let Some(name_node) = node.child_by_field_name("name") {
379 return Some(name_node);
380 }
381 }
382 _ => {}
383 }
384
385 let mut cursor = node.walk();
387 for child in node.children(&mut cursor) {
388 if child.kind() == "identifier"
389 || child.kind() == "type_identifier"
390 || child.kind() == "name"
391 {
392 return Some(child);
393 }
394 }
395
396 None
397}
398
399fn find_innermost_identifier<'a>(node: Node<'a>) -> Option<Node<'a>> {
401 if node.kind() == "identifier" || node.kind() == "field_identifier" {
403 return Some(node);
404 }
405
406 if let Some(name_node) = node.child_by_field_name("declarator") {
408 return find_innermost_identifier(name_node);
409 }
410
411 let mut cursor = node.walk();
413 for child in node.children(&mut cursor) {
414 if let Some(id) = find_innermost_identifier(child) {
415 return Some(id);
416 }
417 }
418
419 None
420}
421
422fn extract_signature(node: Node, source: &str, _language: &str) -> String {
424 let start = node.start_byte();
425 let end = node.end_byte().min(source.len());
426 let text = &source[start..end];
427
428 let first_line = text.lines().next().unwrap_or("");
430 if first_line.len() > 200 {
431 format!("{}...", &first_line[..200])
432 } else {
433 first_line.to_string()
434 }
435}
436
437fn extract_doc_comment(node: Node, source: &str, language: &str) -> Option<String> {
439 let mut prev = node.prev_sibling();
441
442 while let Some(sibling) = prev {
443 let kind = sibling.kind();
444
445 let is_doc_comment = match language {
447 "Rust" => kind == "line_comment" || kind == "block_comment",
448 "Python" => kind == "comment" || kind == "expression_statement", "JavaScript" | "TypeScript" => kind == "comment",
450 "Java" => kind == "line_comment" || kind == "block_comment",
451 "Go" => kind == "comment",
452 "C" | "C++" => kind == "comment",
453 "C#" => kind == "comment",
454 "Ruby" => kind == "comment",
455 "PHP" => kind == "comment",
456 _ => kind.contains("comment"),
457 };
458
459 if is_doc_comment {
460 let start = sibling.start_byte();
461 let end = sibling.end_byte().min(source.len());
462 let comment = source[start..end].trim().to_string();
463
464 let cleaned = clean_comment(&comment, language);
466 if !cleaned.is_empty() {
467 return Some(cleaned);
468 }
469 }
470
471 if !kind.contains("comment") && kind != "decorator" && kind != "attribute" {
473 break;
474 }
475
476 prev = sibling.prev_sibling();
477 }
478
479 None
480}
481
/// Strips comment and docstring delimiters from every line of `comment`, drops
/// lines that end up empty, and joins the rest with `\n`.
///
/// Leading markers (`///`, `//!`, `//`, `/*`, `*/`, `*`, `#`, `"""`, `'''`)
/// are removed in that order. Closing delimiters (`*/`, `"""`, `'''`) are also
/// removed from the end of a line, so single-line forms such as
/// `/** doc */` or `"""doc"""` come out as plain text.
fn clean_comment(comment: &str, _language: &str) -> String {
    const PREFIXES: [&str; 9] = ["///", "//!", "//", "/*", "*/", "*", "#", "\"\"\"", "'''"];
    const SUFFIXES: [&str; 3] = ["*/", "\"\"\"", "'''"];

    comment
        .lines()
        .filter_map(|line| {
            let mut text = line.trim();
            for prefix in PREFIXES {
                text = text.trim_start_matches(prefix);
            }
            text = text.trim_end();
            for suffix in SUFFIXES {
                text = text.trim_end_matches(suffix);
            }
            let text = text.trim();
            (!text.is_empty()).then(|| text.to_string())
        })
        .collect::<Vec<_>>()
        .join("\n")
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds an in-memory `FileInfo` named `test.<extension>` under root `/test`.
    fn make_file_info(content: &str, extension: &str) -> FileInfo {
        let file_name = format!("test.{}", extension);
        FileInfo {
            path: PathBuf::from(&file_name),
            relative_path: file_name,
            root_path: "/test".to_string(),
            project: None,
            extension: Some(extension.to_string()),
            language: None,
            content: content.to_string(),
            hash: "test_hash".to_string(),
        }
    }

    /// Runs a fresh extractor over `source`, treating it as a `.<extension>` file.
    fn extract(source: &str, extension: &str) -> Vec<Definition> {
        SymbolExtractor::new()
            .extract_definitions(&make_file_info(source, extension))
            .unwrap()
    }

    /// First definition named `name`, if any.
    fn find_named<'a>(definitions: &'a [Definition], name: &str) -> Option<&'a Definition> {
        definitions.iter().find(|d| d.name() == name)
    }

    #[test]
    fn test_rust_extraction() {
        let source = r#"
/// A greeting function
pub fn greet(name: &str) -> String {
    format!("Hello, {}!", name)
}

struct Person {
    name: String,
}

impl Person {
    fn new(name: String) -> Self {
        Self { name }
    }
}
"#;
        let definitions = extract(source, "rs");
        assert!(!definitions.is_empty());

        let greet = find_named(&definitions, "greet");
        assert!(greet.is_some(), "Should find greet function");

        let greet = greet.unwrap();
        assert_eq!(greet.kind(), SymbolKind::Function);
        assert_eq!(greet.visibility, Visibility::Public);
        assert!(greet.doc_comment.is_some());
    }

    #[test]
    fn test_python_extraction() {
        let source = r#"
def hello(name):
    """Say hello."""
    print(f"Hello, {name}!")

class MyClass:
    def __init__(self, value):
        self.value = value
"#;
        let definitions = extract(source, "py");
        assert!(!definitions.is_empty());

        assert!(
            find_named(&definitions, "hello").is_some(),
            "Should find hello function"
        );
        assert!(
            find_named(&definitions, "MyClass").is_some(),
            "Should find MyClass"
        );
    }

    #[test]
    fn test_javascript_extraction() {
        let source = r#"
function add(a, b) {
    return a + b;
}

class Calculator {
    constructor() {
        this.result = 0;
    }

    add(x) {
        this.result += x;
    }
}
"#;
        let definitions = extract(source, "js");
        assert!(!definitions.is_empty());

        assert!(
            find_named(&definitions, "add").is_some(),
            "Should find add function"
        );
    }

    #[test]
    fn test_unsupported_extension() {
        let definitions = extract("some content", "xyz");
        assert!(definitions.is_empty());
    }

    #[test]
    fn test_definition_storage_id() {
        let definitions = extract("fn foo() {}", "rs");
        assert!(!definitions.is_empty());

        let storage_id = definitions[0].to_storage_id();
        assert!(storage_id.contains("foo"));
    }
}
630}