cs/parse/
sitter.rs

1use anyhow::{Context, Result};
2use std::collections::HashMap;
3use std::path::Path;
4use tree_sitter::{Language, Parser, Query, QueryCursor};
5
/// Languages for which a Tree-sitter grammar is wired up.
///
/// The set is closed: every variant has a grammar and a function query
/// associated with it elsewhere in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SupportedLanguage {
    /// Rust source files.
    Rust,
    /// Python source files.
    Python,
    /// JavaScript source files (including JSX).
    JavaScript,
    /// TypeScript source files (including TSX).
    TypeScript,
    /// Ruby source files.
    Ruby,
    /// C# source files.
    CSharp,
    // An `Erb` variant is temporarily disabled due to a tree-sitter version conflict.
}
17
18impl SupportedLanguage {
19    pub fn from_path(path: &Path) -> Option<Self> {
20        match path.extension()?.to_str()? {
21            "rs" => Some(Self::Rust),
22            "py" => Some(Self::Python),
23            "js" | "jsx" => Some(Self::JavaScript),
24            "ts" | "tsx" => Some(Self::TypeScript),
25            "rb" => Some(Self::Ruby),
26            "cs" => Some(Self::CSharp),
27            // "erb" => Some(Self::Erb), // Temporarily disabled
28            _ => None,
29        }
30    }
31
32    pub fn language(&self) -> Language {
33        match self {
34            Self::Rust => tree_sitter_rust::language(),
35            Self::Python => tree_sitter_python::language(),
36            Self::JavaScript => tree_sitter_javascript::language(),
37            Self::TypeScript => tree_sitter_typescript::language_typescript(),
38            Self::Ruby => tree_sitter_ruby::language(),
39            Self::CSharp => tree_sitter_c_sharp::language(),
40            // Self::Erb => tree_sitter_embedded_template::language(),
41        }
42    }
43}
44
45/// Sitter handles Tree-sitter parsing for multiple languages
46pub struct Sitter {
47    parsers: HashMap<SupportedLanguage, Parser>,
48    queries: HashMap<SupportedLanguage, Query>,
49}
50
51impl Default for Sitter {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl Sitter {
58    pub fn new() -> Self {
59        Self {
60            parsers: HashMap::new(),
61            queries: HashMap::new(),
62        }
63    }
64
65    /// Check if the file at the given path is supported by Tree-sitter
66    pub fn is_supported(&self, path: &Path) -> bool {
67        SupportedLanguage::from_path(path).is_some()
68    }
69
70    /// Get or create a parser for the given language
71    fn get_parser(&mut self, lang: SupportedLanguage) -> Result<&mut Parser> {
72        if let std::collections::hash_map::Entry::Vacant(e) = self.parsers.entry(lang) {
73            let mut parser = Parser::new();
74            parser
75                .set_language(lang.language())
76                .context("Failed to set parser language")?;
77            e.insert(parser);
78        }
79        Ok(self.parsers.get_mut(&lang).unwrap())
80    }
81
82    /// Get or create a query for the given language
83    fn get_query(&mut self, lang: SupportedLanguage) -> Result<&Query> {
84        if let std::collections::hash_map::Entry::Vacant(e) = self.queries.entry(lang) {
85            let query_str = match lang {
86                SupportedLanguage::Rust => {
87                    r#"
88                    (function_item name: (identifier) @name)
89                    (function_signature_item name: (identifier) @name)
90                "#
91                }
92                SupportedLanguage::Python => {
93                    r#"
94                    (function_definition name: (identifier) @name)
95                "#
96                }
97                SupportedLanguage::JavaScript | SupportedLanguage::TypeScript => {
98                    r#"
99                    (function_declaration name: (identifier) @name)
100                    (export_statement (function_declaration name: (identifier) @name))
101                    (method_definition name: (property_identifier) @name)
102                    (arrow_function) @arrow
103                    (variable_declarator
104                        name: (identifier) @name
105                        value: (arrow_function))
106                "#
107                }
108                SupportedLanguage::Ruby => {
109                    r#"
110                    (method name: (identifier) @name)
111                    (singleton_method name: (identifier) @name)
112                "#
113                }
114                SupportedLanguage::CSharp => {
115                    r#"
116                    (method_declaration name: (identifier) @name)
117                    (local_function_statement name: (identifier) @name)
118                "#
119                } // SupportedLanguage::Erb => "", // ERB usually doesn't define functions
120            };
121
122            let query = Query::new(lang.language(), query_str)
123                .map_err(|e| anyhow::anyhow!("Failed to create query: {:?}", e))?;
124            e.insert(query);
125        }
126        Ok(self.queries.get(&lang).unwrap())
127    }
128
129    /// Find function definitions in the given file
130    pub fn find_functions(&mut self, path: &Path, code: &str) -> Result<Vec<FunctionMatch>> {
131        let lang = match SupportedLanguage::from_path(path) {
132            Some(l) => l,
133            None => return Ok(Vec::new()), // Unsupported language
134        };
135
136        let parser = self.get_parser(lang)?;
137        let tree = parser.parse(code, None).context("Failed to parse code")?;
138
139        let query = self.get_query(lang)?;
140        let mut cursor = QueryCursor::new();
141        let matches = cursor.matches(query, tree.root_node(), code.as_bytes());
142
143        let mut functions = Vec::new();
144        // Capture index for @name is usually 0 if it's the first capture
145        let name_idx = query.capture_index_for_name("name").unwrap_or(0);
146
147        for m in matches {
148            for capture in m.captures {
149                if capture.index == name_idx {
150                    let range = capture.node.range();
151                    let start_line = range.start_point.row + 1; // 1-based
152                    let end_line = range.end_point.row + 1;
153
154                    let name = capture.node.utf8_text(code.as_bytes())?.to_string();
155
156                    functions.push(FunctionMatch {
157                        name,
158                        start_line,
159                        end_line,
160                    });
161                }
162            }
163        }
164
165        Ok(functions)
166    }
167}
168
/// One function definition located in a source file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionMatch {
    /// Name of the function as written in the source.
    pub name: String,
    /// 1-based line on which the reported range starts.
    pub start_line: usize,
    /// 1-based line on which the reported range ends (inclusive).
    pub end_line: usize,
}