// cairo_lang_plugins/plugins/config.rs
use std::vec;
2
3use cairo_lang_defs::patcher::PatchBuilder;
4use cairo_lang_defs::plugin::{
5 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
8use cairo_lang_syntax::attribute::structured::{
9 Attribute, AttributeArg, AttributeArgVariant, AttributeStructurize,
10};
11use cairo_lang_syntax::node::db::SyntaxGroup;
12use cairo_lang_syntax::node::helpers::{BodyItems, GetIdentifier, QueryAttrs};
13use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast};
14use cairo_lang_utils::try_extract_matches;
15use itertools::Itertools;
16
17#[derive(Debug, Clone)]
21enum PredicateTree {
22 Cfg(Cfg),
23 Not(Box<PredicateTree>),
24 And(Vec<PredicateTree>),
25 Or(Vec<PredicateTree>),
26}
27
28impl PredicateTree {
29 fn evaluate(&self, cfg_set: &CfgSet) -> bool {
33 match self {
34 PredicateTree::Cfg(cfg) => cfg_set.contains(cfg),
35 PredicateTree::Not(inner) => !inner.evaluate(cfg_set),
36 PredicateTree::And(predicates) => predicates.iter().all(|p| p.evaluate(cfg_set)),
37 PredicateTree::Or(predicates) => predicates.iter().any(|p| p.evaluate(cfg_set)),
38 }
39 }
40}
41
42pub enum ConfigPredicatePart {
44 Cfg(Cfg),
46 Call(ast::ExprFunctionCall),
48}
49
/// Macro plugin handling the `cfg` attribute.
///
/// Items whose `cfg` predicate does not hold for the active configuration set are removed; trait
/// and impl bodies are rewritten when only some of their inner items are removed.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct ConfigPlugin;
57
/// Name of the attribute this plugin declares and inspects.
const CFG_ATTR: &str = "cfg";
59
60impl MacroPlugin for ConfigPlugin {
61 fn generate_code(
62 &self,
63 db: &dyn SyntaxGroup,
64 item_ast: ast::ModuleItem,
65 metadata: &MacroPluginMetadata<'_>,
66 ) -> PluginResult {
67 let mut diagnostics = vec![];
68
69 if should_drop(db, metadata.cfg_set, &item_ast, &mut diagnostics) {
70 PluginResult { code: None, diagnostics, remove_original_item: true }
71 } else if let Some(builder) =
72 handle_undropped_item(db, metadata.cfg_set, item_ast, &mut diagnostics)
73 {
74 let (content, code_mappings) = builder.build();
75 PluginResult {
76 code: Some(PluginGeneratedFile {
77 name: "config".into(),
78 content,
79 code_mappings,
80 aux_data: None,
81 diagnostics_note: Default::default(),
82 is_unhygienic: false,
83 }),
84 diagnostics,
85 remove_original_item: true,
86 }
87 } else {
88 PluginResult { code: None, diagnostics, remove_original_item: false }
89 }
90 }
91
92 fn declared_attributes(&self) -> Vec<String> {
93 vec![CFG_ATTR.to_string()]
94 }
95}
96
97pub trait HasItemsInCfgEx<Item: QueryAttrs>: BodyItems<Item = Item> {
99 fn iter_items_in_cfg<'a>(
100 &self,
101 db: &'a dyn SyntaxGroup,
102 cfg_set: &'a CfgSet,
103 ) -> impl Iterator<Item = Item> + 'a;
104}
105
106impl<Item: QueryAttrs, Body: BodyItems<Item = Item>> HasItemsInCfgEx<Item> for Body {
107 fn iter_items_in_cfg<'a>(
108 &self,
109 db: &'a dyn SyntaxGroup,
110 cfg_set: &'a CfgSet,
111 ) -> impl Iterator<Item = Item> + 'a {
112 self.iter_items(db).filter(move |item| !should_drop(db, cfg_set, item, &mut vec![]))
113 }
114}
115
116fn handle_undropped_item<'a>(
120 db: &'a dyn SyntaxGroup,
121 cfg_set: &CfgSet,
122 item_ast: ast::ModuleItem,
123 diagnostics: &mut Vec<PluginDiagnostic>,
124) -> Option<PatchBuilder<'a>> {
125 match item_ast {
126 ast::ModuleItem::Trait(trait_item) => {
127 let body = try_extract_matches!(trait_item.body(db), ast::MaybeTraitBody::Some)?;
128 let items = get_kept_items_nodes(db, cfg_set, body.iter_items(db), diagnostics)?;
129 let mut builder = PatchBuilder::new(db, &trait_item);
130 builder.add_node(trait_item.attributes(db).as_syntax_node());
131 builder.add_node(trait_item.trait_kw(db).as_syntax_node());
132 builder.add_node(trait_item.name(db).as_syntax_node());
133 builder.add_node(trait_item.generic_params(db).as_syntax_node());
134 builder.add_node(body.lbrace(db).as_syntax_node());
135 for item in items {
136 builder.add_node(item);
137 }
138 builder.add_node(body.rbrace(db).as_syntax_node());
139 Some(builder)
140 }
141 ast::ModuleItem::Impl(impl_item) => {
142 let body = try_extract_matches!(impl_item.body(db), ast::MaybeImplBody::Some)?;
143 let items = get_kept_items_nodes(db, cfg_set, body.iter_items(db), diagnostics)?;
144 let mut builder = PatchBuilder::new(db, &impl_item);
145 builder.add_node(impl_item.attributes(db).as_syntax_node());
146 builder.add_node(impl_item.impl_kw(db).as_syntax_node());
147 builder.add_node(impl_item.name(db).as_syntax_node());
148 builder.add_node(impl_item.generic_params(db).as_syntax_node());
149 builder.add_node(impl_item.of_kw(db).as_syntax_node());
150 builder.add_node(impl_item.trait_path(db).as_syntax_node());
151 builder.add_node(body.lbrace(db).as_syntax_node());
152 for item in items {
153 builder.add_node(item);
154 }
155 builder.add_node(body.rbrace(db).as_syntax_node());
156 Some(builder)
157 }
158 _ => None,
159 }
160}
161
162fn get_kept_items_nodes<Item: QueryAttrs + TypedSyntaxNode>(
165 db: &dyn SyntaxGroup,
166 cfg_set: &CfgSet,
167 all_items: impl Iterator<Item = Item>,
168 diagnostics: &mut Vec<PluginDiagnostic>,
169) -> Option<Vec<cairo_lang_syntax::node::SyntaxNode>> {
170 let mut any_dropped = false;
171 let mut kept_items_nodes = vec![];
172 for item in all_items {
173 if should_drop(db, cfg_set, &item, diagnostics) {
174 any_dropped = true;
175 } else {
176 kept_items_nodes.push(item.as_syntax_node());
177 }
178 }
179 if any_dropped { Some(kept_items_nodes) } else { None }
180}
181
182fn should_drop<Item: QueryAttrs>(
184 db: &dyn SyntaxGroup,
185 cfg_set: &CfgSet,
186 item: &Item,
187 diagnostics: &mut Vec<PluginDiagnostic>,
188) -> bool {
189 item.query_attr(db, CFG_ATTR).into_iter().any(|attr| {
190 match parse_predicate(db, attr.structurize(db), diagnostics) {
191 Some(predicate_tree) => !predicate_tree.evaluate(cfg_set),
192 None => false,
193 }
194 })
195}
196
197fn parse_predicate(
199 db: &dyn SyntaxGroup,
200 attr: Attribute,
201 diagnostics: &mut Vec<PluginDiagnostic>,
202) -> Option<PredicateTree> {
203 Some(PredicateTree::And(
204 attr.args
205 .into_iter()
206 .filter_map(|arg| parse_predicate_item(db, arg, diagnostics))
207 .collect(),
208 ))
209}
210
211fn parse_predicate_item(
213 db: &dyn SyntaxGroup,
214 item: AttributeArg,
215 diagnostics: &mut Vec<PluginDiagnostic>,
216) -> Option<PredicateTree> {
217 match extract_config_predicate_part(db, &item) {
218 Some(ConfigPredicatePart::Cfg(cfg)) => Some(PredicateTree::Cfg(cfg)),
219 Some(ConfigPredicatePart::Call(call)) => {
220 let operator = call.path(db).as_syntax_node().get_text(db);
221 let args = call
222 .arguments(db)
223 .arguments(db)
224 .elements(db)
225 .map(|arg| AttributeArg::from_ast(arg, db))
226 .collect_vec();
227
228 match operator.as_str() {
229 "not" => {
230 if args.len() != 1 {
231 diagnostics.push(PluginDiagnostic::error(
232 call.stable_ptr(db),
233 "`not` operator expects exactly one argument.".into(),
234 ));
235 None
236 } else {
237 Some(PredicateTree::Not(Box::new(parse_predicate_item(
238 db,
239 args[0].clone(),
240 diagnostics,
241 )?)))
242 }
243 }
244 "and" => {
245 if args.len() < 2 {
246 diagnostics.push(PluginDiagnostic::error(
247 call.stable_ptr(db),
248 "`and` operator expects at least two arguments.".into(),
249 ));
250 None
251 } else {
252 Some(PredicateTree::And(
253 args.into_iter()
254 .filter_map(|arg| parse_predicate_item(db, arg, diagnostics))
255 .collect(),
256 ))
257 }
258 }
259 "or" => {
260 if args.len() < 2 {
261 diagnostics.push(PluginDiagnostic::error(
262 call.stable_ptr(db),
263 "`or` operator expects at least two arguments.".into(),
264 ));
265 None
266 } else {
267 Some(PredicateTree::Or(
268 args.into_iter()
269 .filter_map(|arg| parse_predicate_item(db, arg, diagnostics))
270 .collect(),
271 ))
272 }
273 }
274 _ => {
275 diagnostics.push(PluginDiagnostic::error(
276 call.stable_ptr(db),
277 format!("Unsupported operator: `{operator}`."),
278 ));
279 None
280 }
281 }
282 }
283 None => {
284 diagnostics.push(PluginDiagnostic::error(
285 item.arg.stable_ptr(db).untyped(),
286 "Invalid configuration argument.".into(),
287 ));
288 None
289 }
290 }
291}
292
293fn extract_config_predicate_part(
295 db: &dyn SyntaxGroup,
296 arg: &AttributeArg,
297) -> Option<ConfigPredicatePart> {
298 match &arg.variant {
299 AttributeArgVariant::Unnamed(ast::Expr::Path(path)) => {
300 if let Some([ast::PathSegment::Simple(segment)]) =
301 path.segments(db).elements(db).collect_array()
302 {
303 Some(ConfigPredicatePart::Cfg(Cfg::name(segment.identifier(db).to_string())))
304 } else {
305 None
306 }
307 }
308 AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(call)) => {
309 Some(ConfigPredicatePart::Call(call.clone()))
310 }
311 AttributeArgVariant::Named { name, value } => {
312 let value_text = match value {
313 ast::Expr::String(terminal) => terminal.string_value(db).unwrap_or_default(),
314 ast::Expr::ShortString(terminal) => terminal.string_value(db).unwrap_or_default(),
315 _ => return None,
316 };
317
318 Some(ConfigPredicatePart::Cfg(Cfg::kv(name.text.to_string(), value_text)))
319 }
320 _ => None,
321 }
322}