// codemem_engine/index/rule_loader.rs

//! Compile-time YAML rule embedding and deserialization for per-language extraction rules.
2
3use ast_grep_language::SupportLang;
4use serde::Deserialize;
5use std::collections::HashMap;
6
7/// A symbol extraction rule from YAML.
8#[derive(Debug, Clone, Deserialize)]
9pub struct SymbolRule {
10    pub kind: String,
11    pub symbol_kind: String,
12    #[serde(default = "default_name_field")]
13    pub name_field: String,
14    #[serde(default)]
15    pub method_when_scoped: bool,
16    #[serde(default)]
17    pub is_scope: bool,
18    #[serde(default)]
19    pub special: Option<String>,
20}
21
22/// A scope container rule from YAML.
23#[derive(Debug, Clone, Deserialize)]
24pub struct ScopeContainerRule {
25    pub kind: String,
26    #[serde(default = "default_name_field")]
27    pub name_field: String,
28    #[serde(default = "default_body_field")]
29    pub body_field: String,
30    #[serde(default)]
31    pub is_method_scope: bool,
32    #[serde(default)]
33    pub special: Option<String>,
34}
35
36/// A reference extraction rule from YAML.
37#[derive(Debug, Clone, Deserialize)]
38pub struct ReferenceRule {
39    pub kind: String,
40    pub reference_kind: String,
41    #[serde(default)]
42    pub name_field: Option<String>,
43    #[serde(default)]
44    pub special: Option<String>,
45}
46
47/// Deserialized symbols YAML file.
48#[derive(Debug, Clone, Deserialize)]
49pub struct SymbolRulesFile {
50    pub symbols: Vec<SymbolRule>,
51    #[serde(default)]
52    pub scope_containers: Vec<ScopeContainerRule>,
53    #[serde(default)]
54    pub unwrap_nodes: Vec<String>,
55}
56
57/// Deserialized references YAML file.
58#[derive(Debug, Clone, Deserialize)]
59pub struct ReferenceRulesFile {
60    pub references: Vec<ReferenceRule>,
61    #[serde(default)]
62    pub scope_containers: Vec<ScopeContainerRule>,
63    #[serde(default)]
64    pub unwrap_nodes: Vec<String>,
65}
66
/// Serde fallback for a rule's `name_field` key when the YAML omits it.
fn default_name_field() -> String {
    String::from("name")
}
70
/// Serde fallback for a scope container's `body_field` key when the YAML omits it.
fn default_body_field() -> String {
    String::from("body")
}
74
75/// Compiled rules for a single language, ready for the engine.
76pub struct LanguageRules {
77    pub name: &'static str,
78    pub lang: SupportLang,
79    pub extensions: &'static [&'static str],
80    pub scope_separator: &'static str,
81    pub symbol_rules: Vec<SymbolRule>,
82    pub symbol_scope_containers: Vec<ScopeContainerRule>,
83    pub symbol_unwrap_nodes: Vec<String>,
84    pub reference_rules: Vec<ReferenceRule>,
85    pub reference_scope_containers: Vec<ScopeContainerRule>,
86    pub reference_unwrap_nodes: Vec<String>,
87    /// Index: node_kind → list of symbol rules
88    pub symbol_index: HashMap<String, Vec<usize>>,
89    /// Index: node_kind → list of reference rules
90    pub reference_index: HashMap<String, Vec<usize>>,
91    /// Index: node_kind → scope container index (symbols)
92    pub symbol_scope_index: HashMap<String, usize>,
93    /// Index: node_kind → scope container index (references)
94    pub reference_scope_index: HashMap<String, usize>,
95    /// Set of node kinds to unwrap (symbols)
96    pub symbol_unwrap_set: std::collections::HashSet<String>,
97    /// Set of node kinds to unwrap (references)
98    pub reference_unwrap_set: std::collections::HashSet<String>,
99}
100
101impl LanguageRules {
102    fn build_indexes(&mut self) {
103        // Symbol rule index
104        for (i, rule) in self.symbol_rules.iter().enumerate() {
105            self.symbol_index
106                .entry(rule.kind.clone())
107                .or_default()
108                .push(i);
109        }
110        // Reference rule index
111        for (i, rule) in self.reference_rules.iter().enumerate() {
112            self.reference_index
113                .entry(rule.kind.clone())
114                .or_default()
115                .push(i);
116        }
117        // Symbol scope container index
118        for (i, sc) in self.symbol_scope_containers.iter().enumerate() {
119            if let Some(prev) = self.symbol_scope_index.insert(sc.kind.clone(), i) {
120                debug_assert!(
121                    false,
122                    "Duplicate symbol scope container kind '{}': index {} overwrites {}",
123                    sc.kind, i, prev
124                );
125            }
126        }
127        // Reference scope container index
128        for (i, sc) in self.reference_scope_containers.iter().enumerate() {
129            if let Some(prev) = self.reference_scope_index.insert(sc.kind.clone(), i) {
130                debug_assert!(
131                    false,
132                    "Duplicate reference scope container kind '{}': index {} overwrites {}",
133                    sc.kind, i, prev
134                );
135            }
136        }
137        // Unwrap sets
138        self.symbol_unwrap_set = self.symbol_unwrap_nodes.iter().cloned().collect();
139        self.reference_unwrap_set = self.reference_unwrap_nodes.iter().cloned().collect();
140    }
141}
142
143/// Raw embedded rules before deserialization.
144struct EmbeddedRules {
145    name: &'static str,
146    lang: SupportLang,
147    extensions: &'static [&'static str],
148    scope_separator: &'static str,
149    symbols_yaml: &'static str,
150    references_yaml: &'static str,
151}
152
153/// All language rule definitions embedded at compile time.
154static LANGUAGE_RULES: &[EmbeddedRules] = &[
155    EmbeddedRules {
156        name: "rust",
157        lang: SupportLang::Rust,
158        extensions: &["rs"],
159        scope_separator: "::",
160        symbols_yaml: include_str!("../../rules/rust/symbols.yml"),
161        references_yaml: include_str!("../../rules/rust/references.yml"),
162    },
163    EmbeddedRules {
164        name: "typescript",
165        lang: SupportLang::TypeScript,
166        extensions: &["ts"],
167        scope_separator: ".",
168        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
169        references_yaml: include_str!("../../rules/typescript/references.yml"),
170    },
171    // TSX shares rules with TypeScript but uses a different SupportLang
172    EmbeddedRules {
173        name: "tsx",
174        lang: SupportLang::Tsx,
175        extensions: &["tsx", "jsx"],
176        scope_separator: ".",
177        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
178        references_yaml: include_str!("../../rules/typescript/references.yml"),
179    },
180    // JavaScript also uses TypeScript/TSX grammar
181    EmbeddedRules {
182        name: "javascript",
183        lang: SupportLang::JavaScript,
184        extensions: &["js"],
185        scope_separator: ".",
186        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
187        references_yaml: include_str!("../../rules/typescript/references.yml"),
188    },
189    EmbeddedRules {
190        name: "python",
191        lang: SupportLang::Python,
192        extensions: &["py"],
193        scope_separator: ".",
194        symbols_yaml: include_str!("../../rules/python/symbols.yml"),
195        references_yaml: include_str!("../../rules/python/references.yml"),
196    },
197    EmbeddedRules {
198        name: "go",
199        lang: SupportLang::Go,
200        extensions: &["go"],
201        scope_separator: ".",
202        symbols_yaml: include_str!("../../rules/go/symbols.yml"),
203        references_yaml: include_str!("../../rules/go/references.yml"),
204    },
205    EmbeddedRules {
206        name: "java",
207        lang: SupportLang::Java,
208        extensions: &["java"],
209        scope_separator: ".",
210        symbols_yaml: include_str!("../../rules/java/symbols.yml"),
211        references_yaml: include_str!("../../rules/java/references.yml"),
212    },
213    EmbeddedRules {
214        name: "cpp",
215        lang: SupportLang::Cpp,
216        extensions: &["c", "h", "cpp", "hpp", "cc", "cxx", "hxx"],
217        scope_separator: "::",
218        symbols_yaml: include_str!("../../rules/cpp/symbols.yml"),
219        references_yaml: include_str!("../../rules/cpp/references.yml"),
220    },
221    EmbeddedRules {
222        name: "csharp",
223        lang: SupportLang::CSharp,
224        extensions: &["cs"],
225        scope_separator: ".",
226        symbols_yaml: include_str!("../../rules/csharp/symbols.yml"),
227        references_yaml: include_str!("../../rules/csharp/references.yml"),
228    },
229    EmbeddedRules {
230        name: "ruby",
231        lang: SupportLang::Ruby,
232        extensions: &["rb"],
233        scope_separator: "::",
234        symbols_yaml: include_str!("../../rules/ruby/symbols.yml"),
235        references_yaml: include_str!("../../rules/ruby/references.yml"),
236    },
237    EmbeddedRules {
238        name: "kotlin",
239        lang: SupportLang::Kotlin,
240        extensions: &["kt", "kts"],
241        scope_separator: ".",
242        symbols_yaml: include_str!("../../rules/kotlin/symbols.yml"),
243        references_yaml: include_str!("../../rules/kotlin/references.yml"),
244    },
245    EmbeddedRules {
246        name: "swift",
247        lang: SupportLang::Swift,
248        extensions: &["swift"],
249        scope_separator: ".",
250        symbols_yaml: include_str!("../../rules/swift/symbols.yml"),
251        references_yaml: include_str!("../../rules/swift/references.yml"),
252    },
253    EmbeddedRules {
254        name: "php",
255        lang: SupportLang::Php,
256        extensions: &["php"],
257        scope_separator: "::",
258        symbols_yaml: include_str!("../../rules/php/symbols.yml"),
259        references_yaml: include_str!("../../rules/php/references.yml"),
260    },
261    EmbeddedRules {
262        name: "scala",
263        lang: SupportLang::Scala,
264        extensions: &["scala", "sc"],
265        scope_separator: ".",
266        symbols_yaml: include_str!("../../rules/scala/symbols.yml"),
267        references_yaml: include_str!("../../rules/scala/references.yml"),
268    },
269    EmbeddedRules {
270        name: "hcl",
271        lang: SupportLang::Hcl,
272        extensions: &["tf", "hcl", "tfvars"],
273        scope_separator: ".",
274        symbols_yaml: include_str!("../../rules/hcl/symbols.yml"),
275        references_yaml: include_str!("../../rules/hcl/references.yml"),
276    },
277];
278
279/// Load and deserialize all language rules.
280///
281/// # Panics
282///
283/// Panics if any embedded YAML rule file fails to deserialize. This is intentional:
284/// these files are compiled into the binary, so a parse failure indicates a build-time
285/// error that must be fixed before shipping.
286pub fn load_all_rules() -> Vec<LanguageRules> {
287    LANGUAGE_RULES
288        .iter()
289        .map(|embedded| {
290            let sym_file: SymbolRulesFile = serde_yaml::from_str(embedded.symbols_yaml)
291                .unwrap_or_else(|e| {
292                    panic!("Failed to parse symbols.yml for {}: {}", embedded.name, e)
293                });
294            let ref_file: ReferenceRulesFile = serde_yaml::from_str(embedded.references_yaml)
295                .unwrap_or_else(|e| {
296                    panic!(
297                        "Failed to parse references.yml for {}: {}",
298                        embedded.name, e
299                    )
300                });
301
302            let mut rules = LanguageRules {
303                name: embedded.name,
304                lang: embedded.lang,
305                extensions: embedded.extensions,
306                scope_separator: embedded.scope_separator,
307                symbol_rules: sym_file.symbols,
308                symbol_scope_containers: sym_file.scope_containers,
309                symbol_unwrap_nodes: sym_file.unwrap_nodes,
310                reference_rules: ref_file.references,
311                reference_scope_containers: ref_file.scope_containers,
312                reference_unwrap_nodes: ref_file.unwrap_nodes,
313                symbol_index: HashMap::new(),
314                reference_index: HashMap::new(),
315                symbol_scope_index: HashMap::new(),
316                reference_scope_index: HashMap::new(),
317                symbol_unwrap_set: std::collections::HashSet::new(),
318                reference_unwrap_set: std::collections::HashSet::new(),
319            };
320            rules.build_indexes();
321            rules
322        })
323        .collect()
324}