1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use tree_sitter::{Node, Parser, Tree};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum SymbolKind {
8 Module,
9 Class,
10 Function,
11 Test,
12 Import,
13}
14
15#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
16pub struct Symbol {
17 pub file_path: String,
18 pub name: String,
19 pub kind: SymbolKind,
20 pub signature: String,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
24pub struct SymbolSlice {
25 pub file_path: String,
26 pub symbol_name: String,
27 pub content: String,
28 pub start_line: usize,
29 pub end_line: usize,
30}
31
32#[derive(Debug, Clone)]
33struct RawSymbol {
34 symbol: Symbol,
35 start_byte: usize,
36 end_byte: usize,
37 start_line: usize,
38 end_line: usize,
39}
40
/// Languages that have a tree-sitter grammar wired into the extractor.
///
/// `Tsx` is distinct from `TypeScript` because `.tsx` files are parsed with the
/// TSX variant of the TypeScript grammar.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SourceLanguage {
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Tsx,
}
49
50pub fn extract_symbols(code: &str, file_path: &str) -> Vec<Symbol> {
51 if let Some(raw) = extract_symbols_tree_sitter(code, file_path) {
52 return raw.into_iter().map(|entry| entry.symbol).collect();
53 }
54
55 extract_symbols_regex_fallback(code, file_path)
56}
57
58pub fn slice_symbols(code: &str, file_path: &str, symbol_names: &[&str]) -> Vec<SymbolSlice> {
59 let names = symbol_names
60 .iter()
61 .map(|name| name.trim())
62 .filter(|name| !name.is_empty())
63 .collect::<Vec<_>>();
64
65 if names.is_empty() {
66 return Vec::new();
67 }
68
69 let raws = extract_symbols_tree_sitter(code, file_path)
70 .unwrap_or_else(|| fallback_raw_symbols(code, file_path));
71
72 raws.into_iter()
73 .filter(|entry| names.iter().any(|name| *name == entry.symbol.name))
74 .map(|entry| {
75 let slice = code
76 .get(entry.start_byte..entry.end_byte)
77 .unwrap_or_default();
78 SymbolSlice {
79 file_path: entry.symbol.file_path,
80 symbol_name: entry.symbol.name,
81 content: slice.to_string(),
82 start_line: entry.start_line,
83 end_line: entry.end_line,
84 }
85 })
86 .collect()
87}
88
89fn extract_symbols_tree_sitter(code: &str, file_path: &str) -> Option<Vec<RawSymbol>> {
90 let mut parser = Parser::new();
91 let language = source_language_for_file(file_path)?;
92 let language_set = match language {
93 SourceLanguage::Rust => parser.set_language(&tree_sitter_rust::LANGUAGE.into()).ok(),
94 SourceLanguage::Python => parser
95 .set_language(&tree_sitter_python::LANGUAGE.into())
96 .ok(),
97 SourceLanguage::JavaScript => parser
98 .set_language(&tree_sitter_javascript::LANGUAGE.into())
99 .ok(),
100 SourceLanguage::TypeScript => parser
101 .set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
102 .ok(),
103 SourceLanguage::Tsx => parser
104 .set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
105 .ok(),
106 };
107
108 language_set?;
109 let tree = parser.parse(code, None)?;
110
111 Some(match language {
112 SourceLanguage::Rust => extract_rust_symbols(code, &tree, file_path),
113 SourceLanguage::Python => extract_python_symbols(code, &tree, file_path),
114 SourceLanguage::JavaScript | SourceLanguage::TypeScript | SourceLanguage::Tsx => {
115 extract_javascript_symbols(code, &tree, file_path)
116 }
117 })
118}
119
120fn extract_rust_symbols(code: &str, tree: &Tree, file_path: &str) -> Vec<RawSymbol> {
121 let mut symbols = Vec::new();
122 let root = tree.root_node();
123 let mut stack = vec![root];
124
125 while let Some(node) = stack.pop() {
126 match node.kind() {
127 "function_item" => {
128 if let Some(name_node) = node.child_by_field_name("name") {
129 let name = node_text(code, name_node);
130 let signature = first_line(node_text(code, node));
131 let kind = if name.starts_with("test_") || has_test_attribute(code, node) {
132 SymbolKind::Test
133 } else {
134 SymbolKind::Function
135 };
136 symbols.push(raw_symbol(file_path, &name, kind, &signature, node));
137 }
138 }
139 "struct_item" | "enum_item" => {
140 if let Some(name_node) = node.child_by_field_name("name") {
141 let name = node_text(code, name_node);
142 let signature = first_line(node_text(code, node));
143 symbols.push(raw_symbol(
144 file_path,
145 &name,
146 SymbolKind::Class,
147 &signature,
148 node,
149 ));
150 }
151 }
152 "use_declaration" => {
153 let import = first_line(node_text(code, node));
154 symbols.push(raw_symbol(
155 file_path,
156 &import,
157 SymbolKind::Import,
158 &import,
159 node,
160 ));
161 }
162 _ => {}
163 }
164
165 let mut cursor = node.walk();
166 for child in node.children(&mut cursor) {
167 stack.push(child);
168 }
169 }
170
171 symbols
172}
173
174fn extract_python_symbols(code: &str, tree: &Tree, file_path: &str) -> Vec<RawSymbol> {
175 let mut symbols = Vec::new();
176 let root = tree.root_node();
177 let mut stack = vec![root];
178
179 while let Some(node) = stack.pop() {
180 match node.kind() {
181 "function_definition" => {
182 if let Some(name_node) = node.child_by_field_name("name") {
183 let name = node_text(code, name_node);
184 let signature = first_line(node_text(code, node));
185 let kind = if name.starts_with("test_") {
186 SymbolKind::Test
187 } else {
188 SymbolKind::Function
189 };
190 symbols.push(raw_symbol(file_path, &name, kind, &signature, node));
191 }
192 }
193 "class_definition" => {
194 if let Some(name_node) = node.child_by_field_name("name") {
195 let name = node_text(code, name_node);
196 let signature = first_line(node_text(code, node));
197 symbols.push(raw_symbol(
198 file_path,
199 &name,
200 SymbolKind::Class,
201 &signature,
202 node,
203 ));
204 }
205 }
206 "import_statement" | "import_from_statement" => {
207 let import = first_line(node_text(code, node));
208 symbols.push(raw_symbol(
209 file_path,
210 &import,
211 SymbolKind::Import,
212 &import,
213 node,
214 ));
215 }
216 _ => {}
217 }
218
219 let mut cursor = node.walk();
220 for child in node.children(&mut cursor) {
221 stack.push(child);
222 }
223 }
224
225 symbols
226}
227
228fn extract_javascript_symbols(code: &str, tree: &Tree, file_path: &str) -> Vec<RawSymbol> {
229 let mut symbols = Vec::new();
230 let root = tree.root_node();
231 let mut stack = vec![root];
232
233 while let Some(node) = stack.pop() {
234 match node.kind() {
235 "function_declaration" => {
236 if let Some(name_node) = node.child_by_field_name("name") {
237 let name = node_text(code, name_node);
238 let signature = first_line(node_text(code, node));
239 symbols.push(raw_symbol(
240 file_path,
241 &name,
242 SymbolKind::Function,
243 &signature,
244 node,
245 ));
246 }
247 }
248 "class_declaration" => {
249 if let Some(name_node) = node.child_by_field_name("name") {
250 let name = node_text(code, name_node);
251 let signature = first_line(node_text(code, node));
252 symbols.push(raw_symbol(
253 file_path,
254 &name,
255 SymbolKind::Class,
256 &signature,
257 node,
258 ));
259 }
260 }
261 "method_definition" => {
262 if let Some(name_node) = node.child_by_field_name("name") {
263 let name = node_text(code, name_node);
264 let signature = first_line(node_text(code, node));
265 let kind = if is_js_test_name(&name) {
266 SymbolKind::Test
267 } else {
268 SymbolKind::Function
269 };
270 symbols.push(raw_symbol(file_path, &name, kind, &signature, node));
271 }
272 }
273 "import_statement" => {
274 let import = first_line(node_text(code, node));
275 symbols.push(raw_symbol(
276 file_path,
277 &import,
278 SymbolKind::Import,
279 &import,
280 node,
281 ));
282 }
283 "lexical_declaration" | "variable_declaration" => {
284 extract_js_variable_symbols(code, file_path, node, &mut symbols);
285 }
286 "call_expression" => {
287 if let Some(test_symbol) = extract_js_test_call(code, file_path, node) {
288 symbols.push(test_symbol);
289 }
290 }
291 _ => {}
292 }
293
294 let mut cursor = node.walk();
295 for child in node.children(&mut cursor) {
296 stack.push(child);
297 }
298 }
299
300 symbols
301}
302
303fn raw_symbol(
304 file_path: &str,
305 name: &str,
306 kind: SymbolKind,
307 signature: &str,
308 node: Node<'_>,
309) -> RawSymbol {
310 RawSymbol {
311 symbol: Symbol {
312 file_path: file_path.to_string(),
313 name: name.to_string(),
314 kind,
315 signature: signature.to_string(),
316 },
317 start_byte: node.start_byte(),
318 end_byte: node.end_byte(),
319 start_line: node.start_position().row + 1,
320 end_line: node.end_position().row + 1,
321 }
322}
323
324fn raw_symbol_with_span(
325 file_path: &str,
326 name: &str,
327 kind: SymbolKind,
328 signature: &str,
329 span_node: Node<'_>,
330) -> RawSymbol {
331 RawSymbol {
332 symbol: Symbol {
333 file_path: file_path.to_string(),
334 name: name.to_string(),
335 kind,
336 signature: signature.to_string(),
337 },
338 start_byte: span_node.start_byte(),
339 end_byte: span_node.end_byte(),
340 start_line: span_node.start_position().row + 1,
341 end_line: span_node.end_position().row + 1,
342 }
343}
344
345fn has_test_attribute(code: &str, node: Node<'_>) -> bool {
346 let start = node.start_byte();
347 if start == 0 {
348 return false;
349 }
350
351 let prefix = &code[..start];
352 prefix
353 .lines()
354 .rev()
355 .take(3)
356 .any(|line| line.trim().starts_with("#[test]"))
357}
358
359fn node_text(code: &str, node: Node<'_>) -> String {
360 code.get(node.byte_range()).unwrap_or_default().to_string()
361}
362
/// Returns the first line of `text` with surrounding whitespace trimmed.
///
/// Yields an empty string for empty input or when the text starts with a line
/// break. Accepts anything string-like (`&str`, `String`, ...), so callers holding
/// a borrowed slice need not allocate just to call it.
fn first_line(text: impl AsRef<str>) -> String {
    text.as_ref()
        .lines()
        .next()
        .unwrap_or_default()
        .trim()
        .to_string()
}
366
367fn extract_symbols_regex_fallback(code: &str, file_path: &str) -> Vec<Symbol> {
368 fallback_raw_symbols(code, file_path)
369 .into_iter()
370 .map(|entry| entry.symbol)
371 .collect()
372}
373
374fn fallback_raw_symbols(code: &str, file_path: &str) -> Vec<RawSymbol> {
375 let rust_fn =
376 Regex::new(r"(?m)^\s*(?:pub\s+)?fn\s+([a-zA-Z0-9_]+)\s*\(([^)]*)\)").expect("regex");
377 let py_fn = Regex::new(r"(?m)^\s*def\s+([a-zA-Z0-9_]+)\s*\(([^)]*)\)").expect("regex");
378 let py_class = Regex::new(r"(?m)^\s*class\s+([a-zA-Z0-9_]+)").expect("regex");
379 let js_import = Regex::new(r"(?m)^\s*import\s+.+$").expect("regex");
380 let js_fn =
381 Regex::new(r"(?m)^\s*(?:export\s+)?(?:async\s+)?function\s+([a-zA-Z0-9_]+)\s*\(([^)]*)\)")
382 .expect("regex");
383 let js_class = Regex::new(r"(?m)^\s*(?:export\s+)?class\s+([a-zA-Z0-9_]+)").expect("regex");
384 let js_arrow = Regex::new(
385 r"(?m)^\s*(?:export\s+)?(?:const|let|var)\s+([a-zA-Z0-9_]+)\s*=\s*(?:async\s*)?\(([^)]*)\)\s*=>",
386 )
387 .expect("regex");
388 let js_test =
389 Regex::new(r#"(?m)^\s*(?:test|it|describe)\(\s*["']([^"']+)["']\s*,"#).expect("regex");
390 let md_heading = Regex::new(r"(?m)^(#{1,3})\s+(.+?)\s*$").expect("regex");
391 let mut out = Vec::new();
392
393 for captures in rust_fn.captures_iter(code) {
394 let Some(m) = captures.get(0) else {
395 continue;
396 };
397 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
398 let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
399
400 out.push(RawSymbol {
401 symbol: Symbol {
402 file_path: file_path.to_string(),
403 name: name.to_string(),
404 kind: SymbolKind::Function,
405 signature: format!("fn {name}({args})"),
406 },
407 start_byte: m.start(),
408 end_byte: m.end(),
409 start_line: line_of_byte(code, m.start()),
410 end_line: line_of_byte(code, m.end()),
411 });
412 }
413
414 for captures in py_fn.captures_iter(code) {
415 let Some(m) = captures.get(0) else {
416 continue;
417 };
418 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
419 let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
420
421 out.push(RawSymbol {
422 symbol: Symbol {
423 file_path: file_path.to_string(),
424 name: name.to_string(),
425 kind: if name.starts_with("test_") {
426 SymbolKind::Test
427 } else {
428 SymbolKind::Function
429 },
430 signature: format!("def {name}({args})"),
431 },
432 start_byte: m.start(),
433 end_byte: m.end(),
434 start_line: line_of_byte(code, m.start()),
435 end_line: line_of_byte(code, m.end()),
436 });
437 }
438
439 for captures in py_class.captures_iter(code) {
440 let Some(m) = captures.get(0) else {
441 continue;
442 };
443 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
444 out.push(RawSymbol {
445 symbol: Symbol {
446 file_path: file_path.to_string(),
447 name: name.to_string(),
448 kind: SymbolKind::Class,
449 signature: format!("class {name}"),
450 },
451 start_byte: m.start(),
452 end_byte: m.end(),
453 start_line: line_of_byte(code, m.start()),
454 end_line: line_of_byte(code, m.end()),
455 });
456 }
457
458 for captures in js_import.captures_iter(code) {
459 let Some(m) = captures.get(0) else {
460 continue;
461 };
462 let import = m.as_str().trim();
463 out.push(RawSymbol {
464 symbol: Symbol {
465 file_path: file_path.to_string(),
466 name: import.to_string(),
467 kind: SymbolKind::Import,
468 signature: import.to_string(),
469 },
470 start_byte: m.start(),
471 end_byte: m.end(),
472 start_line: line_of_byte(code, m.start()),
473 end_line: line_of_byte(code, m.end()),
474 });
475 }
476
477 for captures in js_fn.captures_iter(code) {
478 let Some(m) = captures.get(0) else {
479 continue;
480 };
481 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
482 let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
483 out.push(RawSymbol {
484 symbol: Symbol {
485 file_path: file_path.to_string(),
486 name: name.to_string(),
487 kind: if is_js_test_name(name) {
488 SymbolKind::Test
489 } else {
490 SymbolKind::Function
491 },
492 signature: format!("function {name}({args})"),
493 },
494 start_byte: m.start(),
495 end_byte: m.end(),
496 start_line: line_of_byte(code, m.start()),
497 end_line: line_of_byte(code, m.end()),
498 });
499 }
500
501 for captures in js_class.captures_iter(code) {
502 let Some(m) = captures.get(0) else {
503 continue;
504 };
505 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
506 out.push(RawSymbol {
507 symbol: Symbol {
508 file_path: file_path.to_string(),
509 name: name.to_string(),
510 kind: SymbolKind::Class,
511 signature: format!("class {name}"),
512 },
513 start_byte: m.start(),
514 end_byte: m.end(),
515 start_line: line_of_byte(code, m.start()),
516 end_line: line_of_byte(code, m.end()),
517 });
518 }
519
520 for captures in js_arrow.captures_iter(code) {
521 let Some(m) = captures.get(0) else {
522 continue;
523 };
524 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
525 let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
526 out.push(RawSymbol {
527 symbol: Symbol {
528 file_path: file_path.to_string(),
529 name: name.to_string(),
530 kind: if is_js_test_name(name) {
531 SymbolKind::Test
532 } else {
533 SymbolKind::Function
534 },
535 signature: format!("const {name} = ({args}) =>"),
536 },
537 start_byte: m.start(),
538 end_byte: m.end(),
539 start_line: line_of_byte(code, m.start()),
540 end_line: line_of_byte(code, m.end()),
541 });
542 }
543
544 for captures in js_test.captures_iter(code) {
545 let Some(m) = captures.get(0) else {
546 continue;
547 };
548 let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
549 out.push(RawSymbol {
550 symbol: Symbol {
551 file_path: file_path.to_string(),
552 name: name.to_string(),
553 kind: SymbolKind::Test,
554 signature: m.as_str().trim().to_string(),
555 },
556 start_byte: m.start(),
557 end_byte: m.end(),
558 start_line: line_of_byte(code, m.start()),
559 end_line: line_of_byte(code, m.end()),
560 });
561 }
562
563 if file_path.ends_with(".md") {
564 for captures in md_heading.captures_iter(code) {
565 let Some(m) = captures.get(0) else {
566 continue;
567 };
568 let name = captures
569 .get(2)
570 .map(|v| v.as_str())
571 .unwrap_or_default()
572 .trim();
573 if name.is_empty() {
574 continue;
575 }
576
577 out.push(RawSymbol {
578 symbol: Symbol {
579 file_path: file_path.to_string(),
580 name: name.to_string(),
581 kind: SymbolKind::Module,
582 signature: m.as_str().trim().to_string(),
583 },
584 start_byte: m.start(),
585 end_byte: m.end(),
586 start_line: line_of_byte(code, m.start()),
587 end_line: line_of_byte(code, m.end()),
588 });
589 }
590 }
591
592 out
593}
594
595fn source_language_for_file(file_path: &str) -> Option<SourceLanguage> {
596 if file_path.ends_with(".rs") {
597 Some(SourceLanguage::Rust)
598 } else if file_path.ends_with(".py") {
599 Some(SourceLanguage::Python)
600 } else if file_path.ends_with(".tsx") {
601 Some(SourceLanguage::Tsx)
602 } else if file_path.ends_with(".ts") {
603 Some(SourceLanguage::TypeScript)
604 } else if file_path.ends_with(".js")
605 || file_path.ends_with(".jsx")
606 || file_path.ends_with(".mjs")
607 || file_path.ends_with(".cjs")
608 {
609 Some(SourceLanguage::JavaScript)
610 } else {
611 None
612 }
613}
614
615fn extract_js_variable_symbols(
616 code: &str,
617 file_path: &str,
618 declaration_node: Node<'_>,
619 out: &mut Vec<RawSymbol>,
620) {
621 let mut cursor = declaration_node.walk();
622 for child in declaration_node.children(&mut cursor) {
623 if child.kind() != "variable_declarator" {
624 continue;
625 }
626 let Some(name_node) = child.child_by_field_name("name") else {
627 continue;
628 };
629 let Some(value_node) = child.child_by_field_name("value") else {
630 continue;
631 };
632 if value_node.kind() != "arrow_function" && value_node.kind() != "function" {
633 continue;
634 }
635
636 let name = node_text(code, name_node);
637 let signature = first_line(node_text(code, declaration_node));
638 let kind = if is_js_test_name(&name) {
639 SymbolKind::Test
640 } else {
641 SymbolKind::Function
642 };
643 out.push(raw_symbol_with_span(
644 file_path,
645 &name,
646 kind,
647 &signature,
648 declaration_node,
649 ));
650 }
651}
652
653fn extract_js_test_call(code: &str, file_path: &str, node: Node<'_>) -> Option<RawSymbol> {
654 let function_node = node.child_by_field_name("function")?;
655 let callee = node_text(code, function_node);
656 if callee != "test" && callee != "it" && callee != "describe" {
657 return None;
658 }
659
660 let arguments_node = node.child_by_field_name("arguments")?;
661 let mut cursor = arguments_node.walk();
662 let first_argument = arguments_node
663 .named_children(&mut cursor)
664 .find(|child| child.kind() == "string")?;
665 let raw_name = node_text(code, first_argument);
666 let name = raw_name
667 .trim()
668 .trim_matches('"')
669 .trim_matches('\'')
670 .to_string();
671 let signature = first_line(node_text(code, node));
672 Some(raw_symbol(
673 file_path,
674 &name,
675 SymbolKind::Test,
676 &signature,
677 node,
678 ))
679}
680
/// Heuristically decides whether a JavaScript/TypeScript identifier names a test.
///
/// A name counts as a test in either of two cases:
/// * it ends with `Test` (e.g. `LoginTest`);
/// * it begins with the word `test`: `test` alone, or `test` followed by a
///   character that is not a lowercase letter (`testLogin`, `test_login`, `test1`).
///
/// The word boundary keeps identifiers such as `testimonials` or `testing` from
/// being classified as tests.
fn is_js_test_name(name: &str) -> bool {
    if name.ends_with("Test") {
        return true;
    }
    match name.strip_prefix("test") {
        Some(rest) => !rest.starts_with(|c: char| c.is_lowercase()),
        None => false,
    }
}
684
/// Returns the 1-based line number that contains byte offset `byte_idx` of `code`.
///
/// Offsets beyond the end are clamped to `code.len()`. Newlines are counted on the
/// raw bytes, so an offset inside a multi-byte UTF-8 character is handled instead of
/// panicking on a non-char-boundary `str` slice.
fn line_of_byte(code: &str, byte_idx: usize) -> usize {
    let end = byte_idx.min(code.len());
    code.as_bytes()[..end]
        .iter()
        .filter(|&&byte| byte == b'\n')
        .count()
        + 1
}