// ctx_ast/lib.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use tree_sitter::{Node, Parser, Tree};
4
/// Category assigned to an extracted symbol.
///
/// Serialized by serde in `snake_case` (for example `Function` is written as
/// `"function"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SymbolKind {
    /// Emitted by the regex fallback for Markdown headings.
    Module,
    /// Rust structs and enums, and Python/JavaScript classes.
    Class,
    /// Functions, methods, and variables bound to a function value.
    Function,
    /// Test functions and `test`/`it`/`describe` calls.
    Test,
    /// `use` declarations and `import` statements.
    Import,
}
14
/// A named item discovered in a source file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Symbol {
    /// The `file_path` argument the extractor was called with, copied as-is.
    pub file_path: String,
    /// Identifier of the item; for imports this is the first line of the
    /// statement, and for JS test calls the string passed as the test title.
    pub name: String,
    /// Category of the item.
    pub kind: SymbolKind,
    /// First line of the item's source text (tree-sitter path) or a signature
    /// rebuilt from the regex captures (fallback path).
    pub signature: String,
}
22
/// The source text belonging to one symbol, as returned by `slice_symbols`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SymbolSlice {
    /// The `file_path` argument the slicer was called with.
    pub file_path: String,
    /// Name of the symbol this slice was cut for.
    pub symbol_name: String,
    /// Text of the symbol's byte span; empty when the span is not a valid
    /// range of the source string.
    pub content: String,
    /// 1-based line on which the span starts.
    pub start_line: usize,
    /// 1-based line on which the span ends.
    pub end_line: usize,
}
31
/// Internal pairing of a [`Symbol`] with the location it was found at.
#[derive(Debug, Clone)]
struct RawSymbol {
    /// The public description of the item.
    symbol: Symbol,
    /// Byte offset into the source where the item's span begins.
    start_byte: usize,
    /// Byte offset into the source where the item's span ends.
    end_byte: usize,
    /// 1-based line of `start_byte`.
    start_line: usize,
    /// 1-based line of `end_byte`.
    end_line: usize,
}
40
/// Languages for which a tree-sitter grammar is configured.
///
/// The variant is chosen from the file name suffix by
/// `source_language_for_file`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SourceLanguage {
    /// `.rs` files.
    Rust,
    /// `.py` files.
    Python,
    /// `.js`, `.jsx`, `.mjs` and `.cjs` files.
    JavaScript,
    /// `.ts` files.
    TypeScript,
    /// `.tsx` files.
    Tsx,
}
49
50pub fn extract_symbols(code: &str, file_path: &str) -> Vec<Symbol> {
51    if let Some(raw) = extract_symbols_tree_sitter(code, file_path) {
52        return raw.into_iter().map(|entry| entry.symbol).collect();
53    }
54
55    extract_symbols_regex_fallback(code, file_path)
56}
57
58pub fn slice_symbols(code: &str, file_path: &str, symbol_names: &[&str]) -> Vec<SymbolSlice> {
59    let names = symbol_names
60        .iter()
61        .map(|name| name.trim())
62        .filter(|name| !name.is_empty())
63        .collect::<Vec<_>>();
64
65    if names.is_empty() {
66        return Vec::new();
67    }
68
69    let raws = extract_symbols_tree_sitter(code, file_path)
70        .unwrap_or_else(|| fallback_raw_symbols(code, file_path));
71
72    raws.into_iter()
73        .filter(|entry| names.iter().any(|name| *name == entry.symbol.name))
74        .map(|entry| {
75            let slice = code
76                .get(entry.start_byte..entry.end_byte)
77                .unwrap_or_default();
78            SymbolSlice {
79                file_path: entry.symbol.file_path,
80                symbol_name: entry.symbol.name,
81                content: slice.to_string(),
82                start_line: entry.start_line,
83                end_line: entry.end_line,
84            }
85        })
86        .collect()
87}
88
89fn extract_symbols_tree_sitter(code: &str, file_path: &str) -> Option<Vec<RawSymbol>> {
90    let mut parser = Parser::new();
91    let language = source_language_for_file(file_path)?;
92    let language_set = match language {
93        SourceLanguage::Rust => parser.set_language(&tree_sitter_rust::LANGUAGE.into()).ok(),
94        SourceLanguage::Python => parser
95            .set_language(&tree_sitter_python::LANGUAGE.into())
96            .ok(),
97        SourceLanguage::JavaScript => parser
98            .set_language(&tree_sitter_javascript::LANGUAGE.into())
99            .ok(),
100        SourceLanguage::TypeScript => parser
101            .set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
102            .ok(),
103        SourceLanguage::Tsx => parser
104            .set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
105            .ok(),
106    };
107
108    language_set?;
109    let tree = parser.parse(code, None)?;
110
111    Some(match language {
112        SourceLanguage::Rust => extract_rust_symbols(code, &tree, file_path),
113        SourceLanguage::Python => extract_python_symbols(code, &tree, file_path),
114        SourceLanguage::JavaScript | SourceLanguage::TypeScript | SourceLanguage::Tsx => {
115            extract_javascript_symbols(code, &tree, file_path)
116        }
117    })
118}
119
120fn extract_rust_symbols(code: &str, tree: &Tree, file_path: &str) -> Vec<RawSymbol> {
121    let mut symbols = Vec::new();
122    let root = tree.root_node();
123    let mut stack = vec![root];
124
125    while let Some(node) = stack.pop() {
126        match node.kind() {
127            "function_item" => {
128                if let Some(name_node) = node.child_by_field_name("name") {
129                    let name = node_text(code, name_node);
130                    let signature = first_line(node_text(code, node));
131                    let kind = if name.starts_with("test_") || has_test_attribute(code, node) {
132                        SymbolKind::Test
133                    } else {
134                        SymbolKind::Function
135                    };
136                    symbols.push(raw_symbol(file_path, &name, kind, &signature, node));
137                }
138            }
139            "struct_item" | "enum_item" => {
140                if let Some(name_node) = node.child_by_field_name("name") {
141                    let name = node_text(code, name_node);
142                    let signature = first_line(node_text(code, node));
143                    symbols.push(raw_symbol(
144                        file_path,
145                        &name,
146                        SymbolKind::Class,
147                        &signature,
148                        node,
149                    ));
150                }
151            }
152            "use_declaration" => {
153                let import = first_line(node_text(code, node));
154                symbols.push(raw_symbol(
155                    file_path,
156                    &import,
157                    SymbolKind::Import,
158                    &import,
159                    node,
160                ));
161            }
162            _ => {}
163        }
164
165        let mut cursor = node.walk();
166        for child in node.children(&mut cursor) {
167            stack.push(child);
168        }
169    }
170
171    symbols
172}
173
174fn extract_python_symbols(code: &str, tree: &Tree, file_path: &str) -> Vec<RawSymbol> {
175    let mut symbols = Vec::new();
176    let root = tree.root_node();
177    let mut stack = vec![root];
178
179    while let Some(node) = stack.pop() {
180        match node.kind() {
181            "function_definition" => {
182                if let Some(name_node) = node.child_by_field_name("name") {
183                    let name = node_text(code, name_node);
184                    let signature = first_line(node_text(code, node));
185                    let kind = if name.starts_with("test_") {
186                        SymbolKind::Test
187                    } else {
188                        SymbolKind::Function
189                    };
190                    symbols.push(raw_symbol(file_path, &name, kind, &signature, node));
191                }
192            }
193            "class_definition" => {
194                if let Some(name_node) = node.child_by_field_name("name") {
195                    let name = node_text(code, name_node);
196                    let signature = first_line(node_text(code, node));
197                    symbols.push(raw_symbol(
198                        file_path,
199                        &name,
200                        SymbolKind::Class,
201                        &signature,
202                        node,
203                    ));
204                }
205            }
206            "import_statement" | "import_from_statement" => {
207                let import = first_line(node_text(code, node));
208                symbols.push(raw_symbol(
209                    file_path,
210                    &import,
211                    SymbolKind::Import,
212                    &import,
213                    node,
214                ));
215            }
216            _ => {}
217        }
218
219        let mut cursor = node.walk();
220        for child in node.children(&mut cursor) {
221            stack.push(child);
222        }
223    }
224
225    symbols
226}
227
228fn extract_javascript_symbols(code: &str, tree: &Tree, file_path: &str) -> Vec<RawSymbol> {
229    let mut symbols = Vec::new();
230    let root = tree.root_node();
231    let mut stack = vec![root];
232
233    while let Some(node) = stack.pop() {
234        match node.kind() {
235            "function_declaration" => {
236                if let Some(name_node) = node.child_by_field_name("name") {
237                    let name = node_text(code, name_node);
238                    let signature = first_line(node_text(code, node));
239                    symbols.push(raw_symbol(
240                        file_path,
241                        &name,
242                        SymbolKind::Function,
243                        &signature,
244                        node,
245                    ));
246                }
247            }
248            "class_declaration" => {
249                if let Some(name_node) = node.child_by_field_name("name") {
250                    let name = node_text(code, name_node);
251                    let signature = first_line(node_text(code, node));
252                    symbols.push(raw_symbol(
253                        file_path,
254                        &name,
255                        SymbolKind::Class,
256                        &signature,
257                        node,
258                    ));
259                }
260            }
261            "method_definition" => {
262                if let Some(name_node) = node.child_by_field_name("name") {
263                    let name = node_text(code, name_node);
264                    let signature = first_line(node_text(code, node));
265                    let kind = if is_js_test_name(&name) {
266                        SymbolKind::Test
267                    } else {
268                        SymbolKind::Function
269                    };
270                    symbols.push(raw_symbol(file_path, &name, kind, &signature, node));
271                }
272            }
273            "import_statement" => {
274                let import = first_line(node_text(code, node));
275                symbols.push(raw_symbol(
276                    file_path,
277                    &import,
278                    SymbolKind::Import,
279                    &import,
280                    node,
281                ));
282            }
283            "lexical_declaration" | "variable_declaration" => {
284                extract_js_variable_symbols(code, file_path, node, &mut symbols);
285            }
286            "call_expression" => {
287                if let Some(test_symbol) = extract_js_test_call(code, file_path, node) {
288                    symbols.push(test_symbol);
289                }
290            }
291            _ => {}
292        }
293
294        let mut cursor = node.walk();
295        for child in node.children(&mut cursor) {
296            stack.push(child);
297        }
298    }
299
300    symbols
301}
302
303fn raw_symbol(
304    file_path: &str,
305    name: &str,
306    kind: SymbolKind,
307    signature: &str,
308    node: Node<'_>,
309) -> RawSymbol {
310    RawSymbol {
311        symbol: Symbol {
312            file_path: file_path.to_string(),
313            name: name.to_string(),
314            kind,
315            signature: signature.to_string(),
316        },
317        start_byte: node.start_byte(),
318        end_byte: node.end_byte(),
319        start_line: node.start_position().row + 1,
320        end_line: node.end_position().row + 1,
321    }
322}
323
324fn raw_symbol_with_span(
325    file_path: &str,
326    name: &str,
327    kind: SymbolKind,
328    signature: &str,
329    span_node: Node<'_>,
330) -> RawSymbol {
331    RawSymbol {
332        symbol: Symbol {
333            file_path: file_path.to_string(),
334            name: name.to_string(),
335            kind,
336            signature: signature.to_string(),
337        },
338        start_byte: span_node.start_byte(),
339        end_byte: span_node.end_byte(),
340        start_line: span_node.start_position().row + 1,
341        end_line: span_node.end_position().row + 1,
342    }
343}
344
345fn has_test_attribute(code: &str, node: Node<'_>) -> bool {
346    let start = node.start_byte();
347    if start == 0 {
348        return false;
349    }
350
351    let prefix = &code[..start];
352    prefix
353        .lines()
354        .rev()
355        .take(3)
356        .any(|line| line.trim().starts_with("#[test]"))
357}
358
359fn node_text(code: &str, node: Node<'_>) -> String {
360    code.get(node.byte_range()).unwrap_or_default().to_string()
361}
362
/// Returns the first line of `text` with surrounding whitespace removed, or
/// an empty string when `text` has no lines.
fn first_line(text: String) -> String {
    match text.lines().next() {
        Some(line) => line.trim().to_owned(),
        None => String::new(),
    }
}
366
367fn extract_symbols_regex_fallback(code: &str, file_path: &str) -> Vec<Symbol> {
368    fallback_raw_symbols(code, file_path)
369        .into_iter()
370        .map(|entry| entry.symbol)
371        .collect()
372}
373
374fn fallback_raw_symbols(code: &str, file_path: &str) -> Vec<RawSymbol> {
375    let rust_fn =
376        Regex::new(r"(?m)^\s*(?:pub\s+)?fn\s+([a-zA-Z0-9_]+)\s*\(([^)]*)\)").expect("regex");
377    let py_fn = Regex::new(r"(?m)^\s*def\s+([a-zA-Z0-9_]+)\s*\(([^)]*)\)").expect("regex");
378    let py_class = Regex::new(r"(?m)^\s*class\s+([a-zA-Z0-9_]+)").expect("regex");
379    let js_import = Regex::new(r"(?m)^\s*import\s+.+$").expect("regex");
380    let js_fn =
381        Regex::new(r"(?m)^\s*(?:export\s+)?(?:async\s+)?function\s+([a-zA-Z0-9_]+)\s*\(([^)]*)\)")
382            .expect("regex");
383    let js_class = Regex::new(r"(?m)^\s*(?:export\s+)?class\s+([a-zA-Z0-9_]+)").expect("regex");
384    let js_arrow = Regex::new(
385        r"(?m)^\s*(?:export\s+)?(?:const|let|var)\s+([a-zA-Z0-9_]+)\s*=\s*(?:async\s*)?\(([^)]*)\)\s*=>",
386    )
387    .expect("regex");
388    let js_test =
389        Regex::new(r#"(?m)^\s*(?:test|it|describe)\(\s*["']([^"']+)["']\s*,"#).expect("regex");
390    let md_heading = Regex::new(r"(?m)^(#{1,3})\s+(.+?)\s*$").expect("regex");
391    let mut out = Vec::new();
392
393    for captures in rust_fn.captures_iter(code) {
394        let Some(m) = captures.get(0) else {
395            continue;
396        };
397        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
398        let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
399
400        out.push(RawSymbol {
401            symbol: Symbol {
402                file_path: file_path.to_string(),
403                name: name.to_string(),
404                kind: SymbolKind::Function,
405                signature: format!("fn {name}({args})"),
406            },
407            start_byte: m.start(),
408            end_byte: m.end(),
409            start_line: line_of_byte(code, m.start()),
410            end_line: line_of_byte(code, m.end()),
411        });
412    }
413
414    for captures in py_fn.captures_iter(code) {
415        let Some(m) = captures.get(0) else {
416            continue;
417        };
418        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
419        let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
420
421        out.push(RawSymbol {
422            symbol: Symbol {
423                file_path: file_path.to_string(),
424                name: name.to_string(),
425                kind: if name.starts_with("test_") {
426                    SymbolKind::Test
427                } else {
428                    SymbolKind::Function
429                },
430                signature: format!("def {name}({args})"),
431            },
432            start_byte: m.start(),
433            end_byte: m.end(),
434            start_line: line_of_byte(code, m.start()),
435            end_line: line_of_byte(code, m.end()),
436        });
437    }
438
439    for captures in py_class.captures_iter(code) {
440        let Some(m) = captures.get(0) else {
441            continue;
442        };
443        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
444        out.push(RawSymbol {
445            symbol: Symbol {
446                file_path: file_path.to_string(),
447                name: name.to_string(),
448                kind: SymbolKind::Class,
449                signature: format!("class {name}"),
450            },
451            start_byte: m.start(),
452            end_byte: m.end(),
453            start_line: line_of_byte(code, m.start()),
454            end_line: line_of_byte(code, m.end()),
455        });
456    }
457
458    for captures in js_import.captures_iter(code) {
459        let Some(m) = captures.get(0) else {
460            continue;
461        };
462        let import = m.as_str().trim();
463        out.push(RawSymbol {
464            symbol: Symbol {
465                file_path: file_path.to_string(),
466                name: import.to_string(),
467                kind: SymbolKind::Import,
468                signature: import.to_string(),
469            },
470            start_byte: m.start(),
471            end_byte: m.end(),
472            start_line: line_of_byte(code, m.start()),
473            end_line: line_of_byte(code, m.end()),
474        });
475    }
476
477    for captures in js_fn.captures_iter(code) {
478        let Some(m) = captures.get(0) else {
479            continue;
480        };
481        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
482        let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
483        out.push(RawSymbol {
484            symbol: Symbol {
485                file_path: file_path.to_string(),
486                name: name.to_string(),
487                kind: if is_js_test_name(name) {
488                    SymbolKind::Test
489                } else {
490                    SymbolKind::Function
491                },
492                signature: format!("function {name}({args})"),
493            },
494            start_byte: m.start(),
495            end_byte: m.end(),
496            start_line: line_of_byte(code, m.start()),
497            end_line: line_of_byte(code, m.end()),
498        });
499    }
500
501    for captures in js_class.captures_iter(code) {
502        let Some(m) = captures.get(0) else {
503            continue;
504        };
505        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
506        out.push(RawSymbol {
507            symbol: Symbol {
508                file_path: file_path.to_string(),
509                name: name.to_string(),
510                kind: SymbolKind::Class,
511                signature: format!("class {name}"),
512            },
513            start_byte: m.start(),
514            end_byte: m.end(),
515            start_line: line_of_byte(code, m.start()),
516            end_line: line_of_byte(code, m.end()),
517        });
518    }
519
520    for captures in js_arrow.captures_iter(code) {
521        let Some(m) = captures.get(0) else {
522            continue;
523        };
524        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
525        let args = captures.get(2).map(|v| v.as_str()).unwrap_or_default();
526        out.push(RawSymbol {
527            symbol: Symbol {
528                file_path: file_path.to_string(),
529                name: name.to_string(),
530                kind: if is_js_test_name(name) {
531                    SymbolKind::Test
532                } else {
533                    SymbolKind::Function
534                },
535                signature: format!("const {name} = ({args}) =>"),
536            },
537            start_byte: m.start(),
538            end_byte: m.end(),
539            start_line: line_of_byte(code, m.start()),
540            end_line: line_of_byte(code, m.end()),
541        });
542    }
543
544    for captures in js_test.captures_iter(code) {
545        let Some(m) = captures.get(0) else {
546            continue;
547        };
548        let name = captures.get(1).map(|v| v.as_str()).unwrap_or_default();
549        out.push(RawSymbol {
550            symbol: Symbol {
551                file_path: file_path.to_string(),
552                name: name.to_string(),
553                kind: SymbolKind::Test,
554                signature: m.as_str().trim().to_string(),
555            },
556            start_byte: m.start(),
557            end_byte: m.end(),
558            start_line: line_of_byte(code, m.start()),
559            end_line: line_of_byte(code, m.end()),
560        });
561    }
562
563    if file_path.ends_with(".md") {
564        for captures in md_heading.captures_iter(code) {
565            let Some(m) = captures.get(0) else {
566                continue;
567            };
568            let name = captures
569                .get(2)
570                .map(|v| v.as_str())
571                .unwrap_or_default()
572                .trim();
573            if name.is_empty() {
574                continue;
575            }
576
577            out.push(RawSymbol {
578                symbol: Symbol {
579                    file_path: file_path.to_string(),
580                    name: name.to_string(),
581                    kind: SymbolKind::Module,
582                    signature: m.as_str().trim().to_string(),
583                },
584                start_byte: m.start(),
585                end_byte: m.end(),
586                start_line: line_of_byte(code, m.start()),
587                end_line: line_of_byte(code, m.end()),
588            });
589        }
590    }
591
592    out
593}
594
595fn source_language_for_file(file_path: &str) -> Option<SourceLanguage> {
596    if file_path.ends_with(".rs") {
597        Some(SourceLanguage::Rust)
598    } else if file_path.ends_with(".py") {
599        Some(SourceLanguage::Python)
600    } else if file_path.ends_with(".tsx") {
601        Some(SourceLanguage::Tsx)
602    } else if file_path.ends_with(".ts") {
603        Some(SourceLanguage::TypeScript)
604    } else if file_path.ends_with(".js")
605        || file_path.ends_with(".jsx")
606        || file_path.ends_with(".mjs")
607        || file_path.ends_with(".cjs")
608    {
609        Some(SourceLanguage::JavaScript)
610    } else {
611        None
612    }
613}
614
615fn extract_js_variable_symbols(
616    code: &str,
617    file_path: &str,
618    declaration_node: Node<'_>,
619    out: &mut Vec<RawSymbol>,
620) {
621    let mut cursor = declaration_node.walk();
622    for child in declaration_node.children(&mut cursor) {
623        if child.kind() != "variable_declarator" {
624            continue;
625        }
626        let Some(name_node) = child.child_by_field_name("name") else {
627            continue;
628        };
629        let Some(value_node) = child.child_by_field_name("value") else {
630            continue;
631        };
632        if value_node.kind() != "arrow_function" && value_node.kind() != "function" {
633            continue;
634        }
635
636        let name = node_text(code, name_node);
637        let signature = first_line(node_text(code, declaration_node));
638        let kind = if is_js_test_name(&name) {
639            SymbolKind::Test
640        } else {
641            SymbolKind::Function
642        };
643        out.push(raw_symbol_with_span(
644            file_path,
645            &name,
646            kind,
647            &signature,
648            declaration_node,
649        ));
650    }
651}
652
653fn extract_js_test_call(code: &str, file_path: &str, node: Node<'_>) -> Option<RawSymbol> {
654    let function_node = node.child_by_field_name("function")?;
655    let callee = node_text(code, function_node);
656    if callee != "test" && callee != "it" && callee != "describe" {
657        return None;
658    }
659
660    let arguments_node = node.child_by_field_name("arguments")?;
661    let mut cursor = arguments_node.walk();
662    let first_argument = arguments_node
663        .named_children(&mut cursor)
664        .find(|child| child.kind() == "string")?;
665    let raw_name = node_text(code, first_argument);
666    let name = raw_name
667        .trim()
668        .trim_matches('"')
669        .trim_matches('\'')
670        .to_string();
671    let signature = first_line(node_text(code, node));
672    Some(raw_symbol(
673        file_path,
674        &name,
675        SymbolKind::Test,
676        &signature,
677        node,
678    ))
679}
680
/// Heuristic for JavaScript test names: the name begins with `test` or ends
/// with `Test` (both case-sensitive).
fn is_js_test_name(name: &str) -> bool {
    if name.strip_prefix("test").is_some() {
        return true;
    }
    name.strip_suffix("Test").is_some()
}
684
/// Returns the 1-based line number that contains byte offset `byte_idx`.
///
/// Offsets past the end of `code` are clamped to its length. Newlines are
/// counted on the raw bytes, so an offset that falls inside a multi-byte
/// character is handled without panicking.
fn line_of_byte(code: &str, byte_idx: usize) -> usize {
    let end = byte_idx.min(code.len());
    let newlines_before = code.as_bytes()[..end]
        .iter()
        .filter(|&&byte| byte == b'\n')
        .count();
    newlines_before + 1
}