1use anyhow::{Context, Result};
2use std::collections::HashMap;
3use std::path::Path;
4use tree_sitter::{Language, Parser, Query, QueryCursor};
5
/// The source languages this module knows how to parse.
///
/// The mapping from file extension to variant lives in
/// `SupportedLanguage::from_path`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SupportedLanguage {
    /// Rust sources (`.rs`).
    Rust,
    /// Python sources (`.py`).
    Python,
    /// JavaScript sources (`.js`, `.jsx`).
    JavaScript,
    /// TypeScript sources (`.ts`, `.tsx`).
    TypeScript,
    /// Ruby sources (`.rb`).
    Ruby,
    /// C# sources (`.cs`).
    CSharp,
}
17
18impl SupportedLanguage {
19 pub fn from_path(path: &Path) -> Option<Self> {
20 match path.extension()?.to_str()? {
21 "rs" => Some(Self::Rust),
22 "py" => Some(Self::Python),
23 "js" | "jsx" => Some(Self::JavaScript),
24 "ts" | "tsx" => Some(Self::TypeScript),
25 "rb" => Some(Self::Ruby),
26 "cs" => Some(Self::CSharp),
27 _ => None,
29 }
30 }
31
32 pub fn language(&self) -> Language {
33 match self {
34 Self::Rust => tree_sitter_rust::language(),
35 Self::Python => tree_sitter_python::language(),
36 Self::JavaScript => tree_sitter_javascript::language(),
37 Self::TypeScript => tree_sitter_typescript::language_typescript(),
38 Self::Ruby => tree_sitter_ruby::language(),
39 Self::CSharp => tree_sitter_c_sharp::language(),
40 }
42 }
43}
44
45pub struct Sitter {
47 parsers: HashMap<SupportedLanguage, Parser>,
48 queries: HashMap<SupportedLanguage, Query>,
49}
50
51impl Default for Sitter {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl Sitter {
58 pub fn new() -> Self {
59 Self {
60 parsers: HashMap::new(),
61 queries: HashMap::new(),
62 }
63 }
64
65 pub fn is_supported(&self, path: &Path) -> bool {
67 SupportedLanguage::from_path(path).is_some()
68 }
69
70 fn get_parser(&mut self, lang: SupportedLanguage) -> Result<&mut Parser> {
72 if let std::collections::hash_map::Entry::Vacant(e) = self.parsers.entry(lang) {
73 let mut parser = Parser::new();
74 parser
75 .set_language(lang.language())
76 .context("Failed to set parser language")?;
77 e.insert(parser);
78 }
79 Ok(self.parsers.get_mut(&lang).unwrap())
80 }
81
82 fn get_query(&mut self, lang: SupportedLanguage) -> Result<&Query> {
84 if let std::collections::hash_map::Entry::Vacant(e) = self.queries.entry(lang) {
85 let query_str = match lang {
86 SupportedLanguage::Rust => {
87 r#"
88 (function_item name: (identifier) @name)
89 (function_signature_item name: (identifier) @name)
90 "#
91 }
92 SupportedLanguage::Python => {
93 r#"
94 (function_definition name: (identifier) @name)
95 "#
96 }
97 SupportedLanguage::JavaScript | SupportedLanguage::TypeScript => {
98 r#"
99 (function_declaration name: (identifier) @name)
100 (export_statement (function_declaration name: (identifier) @name))
101 (method_definition name: (property_identifier) @name)
102 (arrow_function) @arrow
103 (variable_declarator
104 name: (identifier) @name
105 value: (arrow_function))
106 "#
107 }
108 SupportedLanguage::Ruby => {
109 r#"
110 (method name: (identifier) @name)
111 (singleton_method name: (identifier) @name)
112 "#
113 }
114 SupportedLanguage::CSharp => {
115 r#"
116 (method_declaration name: (identifier) @name)
117 (local_function_statement name: (identifier) @name)
118 "#
119 } };
121
122 let query = Query::new(lang.language(), query_str)
123 .map_err(|e| anyhow::anyhow!("Failed to create query: {:?}", e))?;
124 e.insert(query);
125 }
126 Ok(self.queries.get(&lang).unwrap())
127 }
128
129 pub fn find_functions(&mut self, path: &Path, code: &str) -> Result<Vec<FunctionMatch>> {
131 let lang = match SupportedLanguage::from_path(path) {
132 Some(l) => l,
133 None => return Ok(Vec::new()), };
135
136 let parser = self.get_parser(lang)?;
137 let tree = parser.parse(code, None).context("Failed to parse code")?;
138
139 let query = self.get_query(lang)?;
140 let mut cursor = QueryCursor::new();
141 let matches = cursor.matches(query, tree.root_node(), code.as_bytes());
142
143 let mut functions = Vec::new();
144 let name_idx = query.capture_index_for_name("name").unwrap_or(0);
146
147 for m in matches {
148 for capture in m.captures {
149 if capture.index == name_idx {
150 let range = capture.node.range();
151 let start_line = range.start_point.row + 1; let end_line = range.end_point.row + 1;
153
154 let name = capture.node.utf8_text(code.as_bytes())?.to_string();
155
156 functions.push(FunctionMatch {
157 name,
158 start_line,
159 end_line,
160 });
161 }
162 }
163 }
164
165 Ok(functions)
166 }
167}
168
/// A function or method definition found by `Sitter::find_functions`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionMatch {
    /// The definition's name exactly as written in the source.
    pub name: String,
    /// 1-based line on which the match starts.
    pub start_line: usize,
    /// 1-based line on which the match ends (inclusive).
    pub end_line: usize,
}