use ast_grep_language::SupportLang;
use codemem_core::CodememError;
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
8#[derive(Debug, Clone, Deserialize)]
10pub struct SymbolRule {
11 pub kind: String,
12 pub symbol_kind: String,
13 #[serde(default = "default_name_field")]
14 pub name_field: String,
15 #[serde(default)]
16 pub method_when_scoped: bool,
17 #[serde(default)]
18 pub is_scope: bool,
19 #[serde(default)]
20 pub special: Option<String>,
21}
23#[derive(Debug, Clone, Deserialize)]
25pub struct ScopeContainerRule {
26 pub kind: String,
27 #[serde(default = "default_name_field")]
28 pub name_field: String,
29 #[serde(default = "default_body_field")]
30 pub body_field: String,
31 #[serde(default)]
32 pub is_method_scope: bool,
33 #[serde(default)]
34 pub special: Option<String>,
35}
37#[derive(Debug, Clone, Deserialize)]
39pub struct ReferenceRule {
40 pub kind: String,
41 pub reference_kind: String,
42 #[serde(default)]
43 pub name_field: Option<String>,
44 #[serde(default)]
45 pub special: Option<String>,
46}
48#[derive(Debug, Clone, Deserialize)]
50pub struct SymbolRulesFile {
51 pub symbols: Vec<SymbolRule>,
52 #[serde(default)]
53 pub scope_containers: Vec<ScopeContainerRule>,
54 #[serde(default)]
55 pub unwrap_nodes: Vec<String>,
56}
58#[derive(Debug, Clone, Deserialize)]
60pub struct ReferenceRulesFile {
61 pub references: Vec<ReferenceRule>,
62 #[serde(default)]
63 pub scope_containers: Vec<ScopeContainerRule>,
64 #[serde(default)]
65 pub unwrap_nodes: Vec<String>,
66}
/// Serde default for `name_field` on symbol and scope-container rules:
/// returns `"name"`.
fn default_name_field() -> String {
    String::from("name")
}
/// Serde default for `ScopeContainerRule::body_field`: returns `"body"`.
fn default_body_field() -> String {
    String::from("body")
}
76pub struct LanguageRules {
78 pub name: &'static str,
79 pub lang: SupportLang,
80 pub extensions: &'static [&'static str],
81 pub scope_separator: &'static str,
82 pub symbol_rules: Vec<SymbolRule>,
83 pub symbol_scope_containers: Vec<ScopeContainerRule>,
84 pub symbol_unwrap_nodes: Vec<String>,
85 pub reference_rules: Vec<ReferenceRule>,
86 pub reference_scope_containers: Vec<ScopeContainerRule>,
87 pub reference_unwrap_nodes: Vec<String>,
88 pub symbol_index: HashMap<String, Vec<usize>>,
90 pub reference_index: HashMap<String, Vec<usize>>,
92 pub symbol_scope_index: HashMap<String, usize>,
94 pub reference_scope_index: HashMap<String, usize>,
96 pub symbol_unwrap_set: std::collections::HashSet<String>,
98 pub reference_unwrap_set: std::collections::HashSet<String>,
100}
102impl LanguageRules {
103 fn build_indexes(&mut self) {
104 for (i, rule) in self.symbol_rules.iter().enumerate() {
106 self.symbol_index
107 .entry(rule.kind.clone())
108 .or_default()
109 .push(i);
110 }
111 for (i, rule) in self.reference_rules.iter().enumerate() {
113 self.reference_index
114 .entry(rule.kind.clone())
115 .or_default()
116 .push(i);
117 }
118 for (i, sc) in self.symbol_scope_containers.iter().enumerate() {
120 if let Some(prev) = self.symbol_scope_index.insert(sc.kind.clone(), i) {
121 debug_assert!(
122 false,
123 "Duplicate symbol scope container kind '{}': index {} overwrites {}",
124 sc.kind, i, prev
125 );
126 }
127 }
128 for (i, sc) in self.reference_scope_containers.iter().enumerate() {
130 if let Some(prev) = self.reference_scope_index.insert(sc.kind.clone(), i) {
131 debug_assert!(
132 false,
133 "Duplicate reference scope container kind '{}': index {} overwrites {}",
134 sc.kind, i, prev
135 );
136 }
137 }
138 self.symbol_unwrap_set = self.symbol_unwrap_nodes.iter().cloned().collect();
140 self.reference_unwrap_set = self.reference_unwrap_nodes.iter().cloned().collect();
141 }
142}
144struct EmbeddedRules {
146 name: &'static str,
147 lang: SupportLang,
148 extensions: &'static [&'static str],
149 scope_separator: &'static str,
150 symbols_yaml: &'static str,
151 references_yaml: &'static str,
152}
154static LANGUAGE_RULES: &[EmbeddedRules] = &[
156 EmbeddedRules {
157 name: "rust",
158 lang: SupportLang::Rust,
159 extensions: &["rs"],
160 scope_separator: "::",
161 symbols_yaml: include_str!("../../rules/rust/symbols.yml"),
162 references_yaml: include_str!("../../rules/rust/references.yml"),
163 },
164 EmbeddedRules {
165 name: "typescript",
166 lang: SupportLang::TypeScript,
167 extensions: &["ts"],
168 scope_separator: ".",
169 symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
170 references_yaml: include_str!("../../rules/typescript/references.yml"),
171 },
172 EmbeddedRules {
174 name: "tsx",
175 lang: SupportLang::Tsx,
176 extensions: &["tsx", "jsx"],
177 scope_separator: ".",
178 symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
179 references_yaml: include_str!("../../rules/typescript/references.yml"),
180 },
181 EmbeddedRules {
183 name: "javascript",
184 lang: SupportLang::JavaScript,
185 extensions: &["js"],
186 scope_separator: ".",
187 symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
188 references_yaml: include_str!("../../rules/typescript/references.yml"),
189 },
190 EmbeddedRules {
191 name: "python",
192 lang: SupportLang::Python,
193 extensions: &["py"],
194 scope_separator: ".",
195 symbols_yaml: include_str!("../../rules/python/symbols.yml"),
196 references_yaml: include_str!("../../rules/python/references.yml"),
197 },
198 EmbeddedRules {
199 name: "go",
200 lang: SupportLang::Go,
201 extensions: &["go"],
202 scope_separator: ".",
203 symbols_yaml: include_str!("../../rules/go/symbols.yml"),
204 references_yaml: include_str!("../../rules/go/references.yml"),
205 },
206 EmbeddedRules {
207 name: "java",
208 lang: SupportLang::Java,
209 extensions: &["java"],
210 scope_separator: ".",
211 symbols_yaml: include_str!("../../rules/java/symbols.yml"),
212 references_yaml: include_str!("../../rules/java/references.yml"),
213 },
214 EmbeddedRules {
215 name: "cpp",
216 lang: SupportLang::Cpp,
217 extensions: &["c", "h", "cpp", "hpp", "cc", "cxx", "hxx"],
218 scope_separator: "::",
219 symbols_yaml: include_str!("../../rules/cpp/symbols.yml"),
220 references_yaml: include_str!("../../rules/cpp/references.yml"),
221 },
222 EmbeddedRules {
223 name: "csharp",
224 lang: SupportLang::CSharp,
225 extensions: &["cs"],
226 scope_separator: ".",
227 symbols_yaml: include_str!("../../rules/csharp/symbols.yml"),
228 references_yaml: include_str!("../../rules/csharp/references.yml"),
229 },
230 EmbeddedRules {
231 name: "ruby",
232 lang: SupportLang::Ruby,
233 extensions: &["rb"],
234 scope_separator: "::",
235 symbols_yaml: include_str!("../../rules/ruby/symbols.yml"),
236 references_yaml: include_str!("../../rules/ruby/references.yml"),
237 },
238 EmbeddedRules {
239 name: "kotlin",
240 lang: SupportLang::Kotlin,
241 extensions: &["kt", "kts"],
242 scope_separator: ".",
243 symbols_yaml: include_str!("../../rules/kotlin/symbols.yml"),
244 references_yaml: include_str!("../../rules/kotlin/references.yml"),
245 },
246 EmbeddedRules {
247 name: "swift",
248 lang: SupportLang::Swift,
249 extensions: &["swift"],
250 scope_separator: ".",
251 symbols_yaml: include_str!("../../rules/swift/symbols.yml"),
252 references_yaml: include_str!("../../rules/swift/references.yml"),
253 },
254 EmbeddedRules {
255 name: "php",
256 lang: SupportLang::Php,
257 extensions: &["php"],
258 scope_separator: "::",
259 symbols_yaml: include_str!("../../rules/php/symbols.yml"),
260 references_yaml: include_str!("../../rules/php/references.yml"),
261 },
262 EmbeddedRules {
263 name: "scala",
264 lang: SupportLang::Scala,
265 extensions: &["scala", "sc"],
266 scope_separator: ".",
267 symbols_yaml: include_str!("../../rules/scala/symbols.yml"),
268 references_yaml: include_str!("../../rules/scala/references.yml"),
269 },
270 EmbeddedRules {
271 name: "hcl",
272 lang: SupportLang::Hcl,
273 extensions: &["tf", "hcl", "tfvars"],
274 scope_separator: ".",
275 symbols_yaml: include_str!("../../rules/hcl/symbols.yml"),
276 references_yaml: include_str!("../../rules/hcl/references.yml"),
277 },
278];
280pub fn load_all_rules() -> Result<Vec<LanguageRules>, CodememError> {
284 LANGUAGE_RULES
285 .iter()
286 .map(|embedded| {
287 let sym_file: SymbolRulesFile =
288 serde_yaml::from_str(embedded.symbols_yaml).map_err(|e| {
289 CodememError::Config(format!(
290 "Failed to parse symbols.yml for {}: {}",
291 embedded.name, e
292 ))
293 })?;
294 let ref_file: ReferenceRulesFile = serde_yaml::from_str(embedded.references_yaml)
295 .map_err(|e| {
296 CodememError::Config(format!(
297 "Failed to parse references.yml for {}: {}",
298 embedded.name, e
299 ))
300 })?;
301
302 let mut rules = LanguageRules {
303 name: embedded.name,
304 lang: embedded.lang,
305 extensions: embedded.extensions,
306 scope_separator: embedded.scope_separator,
307 symbol_rules: sym_file.symbols,
308 symbol_scope_containers: sym_file.scope_containers,
309 symbol_unwrap_nodes: sym_file.unwrap_nodes,
310 reference_rules: ref_file.references,
311 reference_scope_containers: ref_file.scope_containers,
312 reference_unwrap_nodes: ref_file.unwrap_nodes,
313 symbol_index: HashMap::new(),
314 reference_index: HashMap::new(),
315 symbol_scope_index: HashMap::new(),
316 reference_scope_index: HashMap::new(),
317 symbol_unwrap_set: std::collections::HashSet::new(),
318 reference_unwrap_set: std::collections::HashSet::new(),
319 };
320 rules.build_indexes();
321 Ok(rules)
322 })
323 .collect()
324}