Skip to main content

codemem_engine/index/
rule_loader.rs

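//! Compile-time YAML rule embedding and deserialization for per-language extraction rules.
//!
//! Rule files are embedded with `include_str!` and deserialized at runtime by `load_all_rules`.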
2
use ast_grep_language::SupportLang;
use codemem_core::CodememError;
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
7
8/// A symbol extraction rule from YAML.
9#[derive(Debug, Clone, Deserialize)]
10pub struct SymbolRule {
11    pub kind: String,
12    pub symbol_kind: String,
13    #[serde(default = "default_name_field")]
14    pub name_field: String,
15    #[serde(default)]
16    pub method_when_scoped: bool,
17    #[serde(default)]
18    pub is_scope: bool,
19    #[serde(default)]
20    pub special: Option<String>,
21}
22
23/// A scope container rule from YAML.
24#[derive(Debug, Clone, Deserialize)]
25pub struct ScopeContainerRule {
26    pub kind: String,
27    #[serde(default = "default_name_field")]
28    pub name_field: String,
29    #[serde(default = "default_body_field")]
30    pub body_field: String,
31    #[serde(default)]
32    pub is_method_scope: bool,
33    #[serde(default)]
34    pub special: Option<String>,
35}
36
37/// A reference extraction rule from YAML.
38#[derive(Debug, Clone, Deserialize)]
39pub struct ReferenceRule {
40    pub kind: String,
41    pub reference_kind: String,
42    #[serde(default)]
43    pub name_field: Option<String>,
44    #[serde(default)]
45    pub special: Option<String>,
46}
47
48/// Deserialized symbols YAML file.
49#[derive(Debug, Clone, Deserialize)]
50pub struct SymbolRulesFile {
51    pub symbols: Vec<SymbolRule>,
52    #[serde(default)]
53    pub scope_containers: Vec<ScopeContainerRule>,
54    #[serde(default)]
55    pub unwrap_nodes: Vec<String>,
56}
57
58/// Deserialized references YAML file.
59#[derive(Debug, Clone, Deserialize)]
60pub struct ReferenceRulesFile {
61    pub references: Vec<ReferenceRule>,
62    #[serde(default)]
63    pub scope_containers: Vec<ScopeContainerRule>,
64    #[serde(default)]
65    pub unwrap_nodes: Vec<String>,
66}
67
/// Serde default for `name_field`: rules that omit it read the node's `name` field.
fn default_name_field() -> String {
    String::from("name")
}
71
/// Serde default for `body_field`: containers that omit it read the node's `body` field.
fn default_body_field() -> String {
    String::from("body")
}
75
76/// Compiled rules for a single language, ready for the engine.
77pub struct LanguageRules {
78    pub name: &'static str,
79    pub lang: SupportLang,
80    pub extensions: &'static [&'static str],
81    pub scope_separator: &'static str,
82    pub symbol_rules: Vec<SymbolRule>,
83    pub symbol_scope_containers: Vec<ScopeContainerRule>,
84    pub symbol_unwrap_nodes: Vec<String>,
85    pub reference_rules: Vec<ReferenceRule>,
86    pub reference_scope_containers: Vec<ScopeContainerRule>,
87    pub reference_unwrap_nodes: Vec<String>,
88    /// Index: node_kind → list of symbol rules
89    pub symbol_index: HashMap<String, Vec<usize>>,
90    /// Index: node_kind → list of reference rules
91    pub reference_index: HashMap<String, Vec<usize>>,
92    /// Index: node_kind → scope container index (symbols)
93    pub symbol_scope_index: HashMap<String, usize>,
94    /// Index: node_kind → scope container index (references)
95    pub reference_scope_index: HashMap<String, usize>,
96    /// Set of node kinds to unwrap (symbols)
97    pub symbol_unwrap_set: std::collections::HashSet<String>,
98    /// Set of node kinds to unwrap (references)
99    pub reference_unwrap_set: std::collections::HashSet<String>,
100}
101
102impl LanguageRules {
103    fn build_indexes(&mut self) {
104        // Symbol rule index
105        for (i, rule) in self.symbol_rules.iter().enumerate() {
106            self.symbol_index
107                .entry(rule.kind.clone())
108                .or_default()
109                .push(i);
110        }
111        // Reference rule index
112        for (i, rule) in self.reference_rules.iter().enumerate() {
113            self.reference_index
114                .entry(rule.kind.clone())
115                .or_default()
116                .push(i);
117        }
118        // Symbol scope container index
119        for (i, sc) in self.symbol_scope_containers.iter().enumerate() {
120            if let Some(prev) = self.symbol_scope_index.insert(sc.kind.clone(), i) {
121                debug_assert!(
122                    false,
123                    "Duplicate symbol scope container kind '{}': index {} overwrites {}",
124                    sc.kind, i, prev
125                );
126            }
127        }
128        // Reference scope container index
129        for (i, sc) in self.reference_scope_containers.iter().enumerate() {
130            if let Some(prev) = self.reference_scope_index.insert(sc.kind.clone(), i) {
131                debug_assert!(
132                    false,
133                    "Duplicate reference scope container kind '{}': index {} overwrites {}",
134                    sc.kind, i, prev
135                );
136            }
137        }
138        // Unwrap sets
139        self.symbol_unwrap_set = self.symbol_unwrap_nodes.iter().cloned().collect();
140        self.reference_unwrap_set = self.reference_unwrap_nodes.iter().cloned().collect();
141    }
142}
143
144/// Raw embedded rules before deserialization.
145struct EmbeddedRules {
146    name: &'static str,
147    lang: SupportLang,
148    extensions: &'static [&'static str],
149    scope_separator: &'static str,
150    symbols_yaml: &'static str,
151    references_yaml: &'static str,
152}
153
154/// All language rule definitions embedded at compile time.
155static LANGUAGE_RULES: &[EmbeddedRules] = &[
156    EmbeddedRules {
157        name: "rust",
158        lang: SupportLang::Rust,
159        extensions: &["rs"],
160        scope_separator: "::",
161        symbols_yaml: include_str!("../../rules/rust/symbols.yml"),
162        references_yaml: include_str!("../../rules/rust/references.yml"),
163    },
164    EmbeddedRules {
165        name: "typescript",
166        lang: SupportLang::TypeScript,
167        extensions: &["ts"],
168        scope_separator: ".",
169        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
170        references_yaml: include_str!("../../rules/typescript/references.yml"),
171    },
172    // TSX shares rules with TypeScript but uses a different SupportLang
173    EmbeddedRules {
174        name: "tsx",
175        lang: SupportLang::Tsx,
176        extensions: &["tsx", "jsx"],
177        scope_separator: ".",
178        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
179        references_yaml: include_str!("../../rules/typescript/references.yml"),
180    },
181    // JavaScript also uses TypeScript/TSX grammar
182    EmbeddedRules {
183        name: "javascript",
184        lang: SupportLang::JavaScript,
185        extensions: &["js"],
186        scope_separator: ".",
187        symbols_yaml: include_str!("../../rules/typescript/symbols.yml"),
188        references_yaml: include_str!("../../rules/typescript/references.yml"),
189    },
190    EmbeddedRules {
191        name: "python",
192        lang: SupportLang::Python,
193        extensions: &["py"],
194        scope_separator: ".",
195        symbols_yaml: include_str!("../../rules/python/symbols.yml"),
196        references_yaml: include_str!("../../rules/python/references.yml"),
197    },
198    EmbeddedRules {
199        name: "go",
200        lang: SupportLang::Go,
201        extensions: &["go"],
202        scope_separator: ".",
203        symbols_yaml: include_str!("../../rules/go/symbols.yml"),
204        references_yaml: include_str!("../../rules/go/references.yml"),
205    },
206    EmbeddedRules {
207        name: "java",
208        lang: SupportLang::Java,
209        extensions: &["java"],
210        scope_separator: ".",
211        symbols_yaml: include_str!("../../rules/java/symbols.yml"),
212        references_yaml: include_str!("../../rules/java/references.yml"),
213    },
214    EmbeddedRules {
215        name: "cpp",
216        lang: SupportLang::Cpp,
217        extensions: &["c", "h", "cpp", "hpp", "cc", "cxx", "hxx"],
218        scope_separator: "::",
219        symbols_yaml: include_str!("../../rules/cpp/symbols.yml"),
220        references_yaml: include_str!("../../rules/cpp/references.yml"),
221    },
222    EmbeddedRules {
223        name: "csharp",
224        lang: SupportLang::CSharp,
225        extensions: &["cs"],
226        scope_separator: ".",
227        symbols_yaml: include_str!("../../rules/csharp/symbols.yml"),
228        references_yaml: include_str!("../../rules/csharp/references.yml"),
229    },
230    EmbeddedRules {
231        name: "ruby",
232        lang: SupportLang::Ruby,
233        extensions: &["rb"],
234        scope_separator: "::",
235        symbols_yaml: include_str!("../../rules/ruby/symbols.yml"),
236        references_yaml: include_str!("../../rules/ruby/references.yml"),
237    },
238    EmbeddedRules {
239        name: "kotlin",
240        lang: SupportLang::Kotlin,
241        extensions: &["kt", "kts"],
242        scope_separator: ".",
243        symbols_yaml: include_str!("../../rules/kotlin/symbols.yml"),
244        references_yaml: include_str!("../../rules/kotlin/references.yml"),
245    },
246    EmbeddedRules {
247        name: "swift",
248        lang: SupportLang::Swift,
249        extensions: &["swift"],
250        scope_separator: ".",
251        symbols_yaml: include_str!("../../rules/swift/symbols.yml"),
252        references_yaml: include_str!("../../rules/swift/references.yml"),
253    },
254    EmbeddedRules {
255        name: "php",
256        lang: SupportLang::Php,
257        extensions: &["php"],
258        scope_separator: "::",
259        symbols_yaml: include_str!("../../rules/php/symbols.yml"),
260        references_yaml: include_str!("../../rules/php/references.yml"),
261    },
262    EmbeddedRules {
263        name: "scala",
264        lang: SupportLang::Scala,
265        extensions: &["scala", "sc"],
266        scope_separator: ".",
267        symbols_yaml: include_str!("../../rules/scala/symbols.yml"),
268        references_yaml: include_str!("../../rules/scala/references.yml"),
269    },
270    EmbeddedRules {
271        name: "hcl",
272        lang: SupportLang::Hcl,
273        extensions: &["tf", "hcl", "tfvars"],
274        scope_separator: ".",
275        symbols_yaml: include_str!("../../rules/hcl/symbols.yml"),
276        references_yaml: include_str!("../../rules/hcl/references.yml"),
277    },
278];
279
280/// Load and deserialize all language rules.
281///
282/// Returns an error if any embedded YAML rule file fails to deserialize.
283pub fn load_all_rules() -> Result<Vec<LanguageRules>, CodememError> {
284    LANGUAGE_RULES
285        .iter()
286        .map(|embedded| {
287            let sym_file: SymbolRulesFile =
288                serde_yaml::from_str(embedded.symbols_yaml).map_err(|e| {
289                    CodememError::Config(format!(
290                        "Failed to parse symbols.yml for {}: {}",
291                        embedded.name, e
292                    ))
293                })?;
294            let ref_file: ReferenceRulesFile = serde_yaml::from_str(embedded.references_yaml)
295                .map_err(|e| {
296                    CodememError::Config(format!(
297                        "Failed to parse references.yml for {}: {}",
298                        embedded.name, e
299                    ))
300                })?;
301
302            let mut rules = LanguageRules {
303                name: embedded.name,
304                lang: embedded.lang,
305                extensions: embedded.extensions,
306                scope_separator: embedded.scope_separator,
307                symbol_rules: sym_file.symbols,
308                symbol_scope_containers: sym_file.scope_containers,
309                symbol_unwrap_nodes: sym_file.unwrap_nodes,
310                reference_rules: ref_file.references,
311                reference_scope_containers: ref_file.scope_containers,
312                reference_unwrap_nodes: ref_file.unwrap_nodes,
313                symbol_index: HashMap::new(),
314                reference_index: HashMap::new(),
315                symbol_scope_index: HashMap::new(),
316                reference_scope_index: HashMap::new(),
317                symbol_unwrap_set: std::collections::HashSet::new(),
318                reference_unwrap_set: std::collections::HashSet::new(),
319            };
320            rules.build_indexes();
321            Ok(rules)
322        })
323        .collect()
324}