1use ast_grep_language::SupportLang;
4use serde::Deserialize;
5use std::collections::HashMap;
6
/// One entry of a language's `symbols.yml`: describes how nodes of a given kind become symbols.
#[derive(Debug, Clone, Deserialize)]
pub struct SymbolRule {
    /// Node kind this rule matches; `LanguageRules::build_indexes` keys `symbol_index` on it.
    pub kind: String,
    /// Symbol kind reported for matching nodes (the value vocabulary lives in the YAML files).
    pub symbol_kind: String,
    /// Name of the field that presumably holds the symbol's name; `"name"` when omitted.
    #[serde(default = "default_name_field")]
    pub name_field: String,
    /// Defaults to `false`. NOTE(review): presumably reports the symbol as a method when it is
    /// nested inside a method scope — confirm against the extractor.
    #[serde(default)]
    pub method_when_scoped: bool,
    /// Defaults to `false`. NOTE(review): presumably marks nodes that open a new naming
    /// scope — confirm against the extractor.
    #[serde(default)]
    pub is_scope: bool,
    /// Optional tag selecting special-case handling, interpreted outside this file;
    /// `None` when omitted.
    #[serde(default)]
    pub special: Option<String>,
}
21
/// One `scope_containers` entry of a rules file: a node kind treated as a scope container.
#[derive(Debug, Clone, Deserialize)]
pub struct ScopeContainerRule {
    /// Node kind of the container. Expected to be unique within one list; duplicates trip a
    /// debug assertion in `LanguageRules::build_indexes`.
    pub kind: String,
    /// Name of the field that presumably holds the container's name; `"name"` when omitted.
    #[serde(default = "default_name_field")]
    pub name_field: String,
    /// Name of the field that presumably holds the container's body; `"body"` when omitted.
    #[serde(default = "default_body_field")]
    pub body_field: String,
    /// Defaults to `false`. NOTE(review): presumably flags containers whose members should be
    /// treated as methods (cf. `SymbolRule::method_when_scoped`) — confirm.
    #[serde(default)]
    pub is_method_scope: bool,
    /// Optional tag selecting special-case handling, interpreted outside this file;
    /// `None` when omitted.
    #[serde(default)]
    pub special: Option<String>,
}
35
/// One entry of a language's `references.yml`: describes how nodes of a given kind become
/// references.
#[derive(Debug, Clone, Deserialize)]
pub struct ReferenceRule {
    /// Node kind this rule matches; `LanguageRules::build_indexes` keys `reference_index` on it.
    pub kind: String,
    /// Reference kind reported for matching nodes (the value vocabulary lives in the YAML files).
    pub reference_kind: String,
    /// Optional field that presumably holds the referenced name; `None` when omitted (unlike
    /// `SymbolRule::name_field`, there is no `"name"` default here).
    #[serde(default)]
    pub name_field: Option<String>,
    /// Optional tag selecting special-case handling, interpreted outside this file;
    /// `None` when omitted.
    #[serde(default)]
    pub special: Option<String>,
}
46
/// Top-level shape of a language's `symbols.yml`, as deserialized by `load_all_rules`.
#[derive(Debug, Clone, Deserialize)]
pub struct SymbolRulesFile {
    /// Symbol rules; this key is required (no serde default).
    pub symbols: Vec<SymbolRule>,
    /// Scope containers used during symbol extraction; empty when omitted.
    #[serde(default)]
    pub scope_containers: Vec<ScopeContainerRule>,
    /// Node kinds collected into `LanguageRules::symbol_unwrap_set`; empty when omitted.
    #[serde(default)]
    pub unwrap_nodes: Vec<String>,
}
56
/// Top-level shape of a language's `references.yml`, as deserialized by `load_all_rules`.
#[derive(Debug, Clone, Deserialize)]
pub struct ReferenceRulesFile {
    /// Reference rules; this key is required (no serde default).
    pub references: Vec<ReferenceRule>,
    /// Scope containers used during reference extraction; empty when omitted.
    #[serde(default)]
    pub scope_containers: Vec<ScopeContainerRule>,
    /// Node kinds collected into `LanguageRules::reference_unwrap_set`; empty when omitted.
    #[serde(default)]
    pub unwrap_nodes: Vec<String>,
}
66
/// Serde default for the `name_field` attributes: the field name `"name"`.
fn default_name_field() -> String {
    String::from("name")
}
70
/// Serde default for `ScopeContainerRule::body_field`: the field name `"body"`.
fn default_body_field() -> String {
    String::from("body")
}
74
/// Fully loaded extraction rules for one language entry of `LANGUAGE_RULES`.
///
/// `load_all_rules` creates every index/set field empty and then fills them with
/// `build_indexes`.
pub struct LanguageRules {
    /// Identifier copied from the embedded table (e.g. `"rust"`, `"tsx"`).
    pub name: &'static str,
    /// ast-grep language used for this entry.
    pub lang: SupportLang,
    /// File extensions (without the dot) associated with this entry.
    pub extensions: &'static [&'static str],
    /// Scope separator for this language (`"::"` or `"."` in `LANGUAGE_RULES`); presumably
    /// used to join nested scope names into qualified names — confirm at the call sites.
    pub scope_separator: &'static str,
    /// Rules from `symbols.yml`.
    pub symbol_rules: Vec<SymbolRule>,
    /// `scope_containers` from `symbols.yml`.
    pub symbol_scope_containers: Vec<ScopeContainerRule>,
    /// `unwrap_nodes` from `symbols.yml`.
    pub symbol_unwrap_nodes: Vec<String>,
    /// Rules from `references.yml`.
    pub reference_rules: Vec<ReferenceRule>,
    /// `scope_containers` from `references.yml`.
    pub reference_scope_containers: Vec<ScopeContainerRule>,
    /// `unwrap_nodes` from `references.yml`.
    pub reference_unwrap_nodes: Vec<String>,
    /// Node kind -> positions in `symbol_rules`, in ascending order (several rules may share
    /// a kind).
    pub symbol_index: HashMap<String, Vec<usize>>,
    /// Node kind -> positions in `reference_rules`, in ascending order.
    pub reference_index: HashMap<String, Vec<usize>>,
    /// Node kind -> position in `symbol_scope_containers` (kinds expected unique).
    pub symbol_scope_index: HashMap<String, usize>,
    /// Node kind -> position in `reference_scope_containers` (kinds expected unique).
    pub reference_scope_index: HashMap<String, usize>,
    /// Set view of `symbol_unwrap_nodes` for membership tests.
    pub symbol_unwrap_set: std::collections::HashSet<String>,
    /// Set view of `reference_unwrap_nodes` for membership tests.
    pub reference_unwrap_set: std::collections::HashSet<String>,
}
100
101impl LanguageRules {
102 fn build_indexes(&mut self) {
103 for (i, rule) in self.symbol_rules.iter().enumerate() {
105 self.symbol_index
106 .entry(rule.kind.clone())
107 .or_default()
108 .push(i);
109 }
110 for (i, rule) in self.reference_rules.iter().enumerate() {
112 self.reference_index
113 .entry(rule.kind.clone())
114 .or_default()
115 .push(i);
116 }
117 for (i, sc) in self.symbol_scope_containers.iter().enumerate() {
119 if let Some(prev) = self.symbol_scope_index.insert(sc.kind.clone(), i) {
120 debug_assert!(
121 false,
122 "Duplicate symbol scope container kind '{}': index {} overwrites {}",
123 sc.kind, i, prev
124 );
125 }
126 }
127 for (i, sc) in self.reference_scope_containers.iter().enumerate() {
129 if let Some(prev) = self.reference_scope_index.insert(sc.kind.clone(), i) {
130 debug_assert!(
131 false,
132 "Duplicate reference scope container kind '{}': index {} overwrites {}",
133 sc.kind, i, prev
134 );
135 }
136 }
137 self.symbol_unwrap_set = self.symbol_unwrap_nodes.iter().cloned().collect();
139 self.reference_unwrap_set = self.reference_unwrap_nodes.iter().cloned().collect();
140 }
141}
142
/// Compile-time description of one language entry: static metadata plus the raw YAML text of
/// its rule files (embedded with `include_str!` in `LANGUAGE_RULES`).
struct EmbeddedRules {
    /// Identifier copied into `LanguageRules::name`.
    name: &'static str,
    /// ast-grep language for this entry.
    lang: SupportLang,
    /// File extensions (without the dot) for this entry.
    extensions: &'static [&'static str],
    /// Copied into `LanguageRules::scope_separator`.
    scope_separator: &'static str,
    /// Raw contents of the entry's `symbols.yml`, parsed into `SymbolRulesFile`.
    symbols_yaml: &'static str,
    /// Raw contents of the entry's `references.yml`, parsed into `ReferenceRulesFile`.
    references_yaml: &'static str,
}
152
/// Every supported language entry, in the order `load_all_rules` returns them.
///
/// Rule files are embedded at compile time; `include_str!` paths are resolved relative to this
/// source file. NOTE(review): several entries list overlapping needs (e.g. `.h` for C and C++);
/// confirm that extension routing elsewhere expects exactly these mappings.
static LANGUAGE_RULES: &[EmbeddedRules] = &[
    EmbeddedRules {
        name: "rust",
        lang: SupportLang::Rust,
        extensions: &["rs"],
        scope_separator: "::",
        symbols_yaml: include_str!("../../rules/rust/symbols.yml"),
        references_yaml: include_str!("../../rules/rust/references.yml"),
    },
    EmbeddedRules {
        name: "typescript",
        lang: SupportLang::TypeScript,
        extensions: &["ts"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
        references_yaml: include_str!("../../rules/typescript/references.yml"),
    },
    // TSX grammar, reusing the TypeScript rule files; `.jsx` is also parsed with this grammar.
    EmbeddedRules {
        name: "tsx",
        lang: SupportLang::Tsx,
        extensions: &["tsx", "jsx"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
        references_yaml: include_str!("../../rules/typescript/references.yml"),
    },
    // JavaScript grammar, also reusing the TypeScript rule files.
    EmbeddedRules {
        name: "javascript",
        lang: SupportLang::JavaScript,
        extensions: &["js"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
        references_yaml: include_str!("../../rules/typescript/references.yml"),
    },
    EmbeddedRules {
        name: "python",
        lang: SupportLang::Python,
        extensions: &["py"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/python/symbols.yml"),
        references_yaml: include_str!("../../rules/python/references.yml"),
    },
    EmbeddedRules {
        name: "go",
        lang: SupportLang::Go,
        extensions: &["go"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/go/symbols.yml"),
        references_yaml: include_str!("../../rules/go/references.yml"),
    },
    EmbeddedRules {
        name: "java",
        lang: SupportLang::Java,
        extensions: &["java"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/java/symbols.yml"),
        references_yaml: include_str!("../../rules/java/references.yml"),
    },
    // C sources and headers (`.c`, `.h`) are routed to the C++ grammar and rules as well.
    EmbeddedRules {
        name: "cpp",
        lang: SupportLang::Cpp,
        extensions: &["c", "h", "cpp", "hpp", "cc", "cxx", "hxx"],
        scope_separator: "::",
        symbols_yaml: include_str!("../../rules/cpp/symbols.yml"),
        references_yaml: include_str!("../../rules/cpp/references.yml"),
    },
    EmbeddedRules {
        name: "csharp",
        lang: SupportLang::CSharp,
        extensions: &["cs"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/csharp/symbols.yml"),
        references_yaml: include_str!("../../rules/csharp/references.yml"),
    },
    EmbeddedRules {
        name: "ruby",
        lang: SupportLang::Ruby,
        extensions: &["rb"],
        scope_separator: "::",
        symbols_yaml: include_str!("../../rules/ruby/symbols.yml"),
        references_yaml: include_str!("../../rules/ruby/references.yml"),
    },
    EmbeddedRules {
        name: "kotlin",
        lang: SupportLang::Kotlin,
        extensions: &["kt", "kts"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/kotlin/symbols.yml"),
        references_yaml: include_str!("../../rules/kotlin/references.yml"),
    },
    EmbeddedRules {
        name: "swift",
        lang: SupportLang::Swift,
        extensions: &["swift"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/swift/symbols.yml"),
        references_yaml: include_str!("../../rules/swift/references.yml"),
    },
    EmbeddedRules {
        name: "php",
        lang: SupportLang::Php,
        extensions: &["php"],
        scope_separator: "::",
        symbols_yaml: include_str!("../../rules/php/symbols.yml"),
        references_yaml: include_str!("../../rules/php/references.yml"),
    },
    EmbeddedRules {
        name: "scala",
        lang: SupportLang::Scala,
        extensions: &["scala", "sc"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/scala/symbols.yml"),
        references_yaml: include_str!("../../rules/scala/references.yml"),
    },
    // HCL grammar covers Terraform files (`.tf`, `.tfvars`) as well as plain `.hcl`.
    EmbeddedRules {
        name: "hcl",
        lang: SupportLang::Hcl,
        extensions: &["tf", "hcl", "tfvars"],
        scope_separator: ".",
        symbols_yaml: include_str!("../../rules/hcl/symbols.yml"),
        references_yaml: include_str!("../../rules/hcl/references.yml"),
    },
];
278
279pub fn load_all_rules() -> Vec<LanguageRules> {
287 LANGUAGE_RULES
288 .iter()
289 .map(|embedded| {
290 let sym_file: SymbolRulesFile = serde_yaml::from_str(embedded.symbols_yaml)
291 .unwrap_or_else(|e| {
292 panic!("Failed to parse symbols.yml for {}: {}", embedded.name, e)
293 });
294 let ref_file: ReferenceRulesFile = serde_yaml::from_str(embedded.references_yaml)
295 .unwrap_or_else(|e| {
296 panic!(
297 "Failed to parse references.yml for {}: {}",
298 embedded.name, e
299 )
300 });
301
302 let mut rules = LanguageRules {
303 name: embedded.name,
304 lang: embedded.lang,
305 extensions: embedded.extensions,
306 scope_separator: embedded.scope_separator,
307 symbol_rules: sym_file.symbols,
308 symbol_scope_containers: sym_file.scope_containers,
309 symbol_unwrap_nodes: sym_file.unwrap_nodes,
310 reference_rules: ref_file.references,
311 reference_scope_containers: ref_file.scope_containers,
312 reference_unwrap_nodes: ref_file.unwrap_nodes,
313 symbol_index: HashMap::new(),
314 reference_index: HashMap::new(),
315 symbol_scope_index: HashMap::new(),
316 reference_scope_index: HashMap::new(),
317 symbol_unwrap_set: std::collections::HashSet::new(),
318 reference_unwrap_set: std::collections::HashSet::new(),
319 };
320 rules.build_indexes();
321 rules
322 })
323 .collect()
324}