1use std::collections::HashMap;
2
3use crate::Diagnostic;
4use thiserror::Error;
5use wgsl_parse::{Decorated, span::Spanned, syntax::*};
6
7#[derive(Clone, Debug, Error)]
9pub enum CondCompError {
10 #[error("invalid feature flag: `{0}`")]
11 InvalidFeatureFlag(String),
12 #[error("unexpected feature flag: `{0}`")]
13 UnexpectedFeatureFlag(String),
14 #[error("invalid if attribute expression: `{0}`")]
15 InvalidExpression(Expression),
16 #[error("an @elif or @else attribute must be preceded by a @if or @elif on the previous node")]
17 NoPrecedingIf,
18 #[error("cannot have multiple @if/@elif/@else attributes on the same node")]
19 DuplicateIf,
20}
21
/// Crate-wide error type, aliased for brevity in this module's signatures.
type E = crate::Error;
23
/// How a single feature flag is treated when an `@if` / `@elif` expression is evaluated.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Feature {
    /// The flag evaluates to `true`.
    Enable,
    /// The flag evaluates to `false`. This is the default treatment.
    #[default]
    Disable,
    /// The flag is left unevaluated, so an attribute that depends on it can remain in the
    /// output as a (simplified) expression.
    Keep,
    /// Meeting the flag in an attribute expression is an error
    /// ([`CondCompError::UnexpectedFeatureFlag`]).
    Error,
}
40
41#[derive(Clone, Debug, Default, PartialEq, Eq)]
46pub struct Features {
47 pub default: Feature,
48 pub flags: HashMap<String, Feature>,
49}
50
51impl From<bool> for Feature {
52 fn from(value: bool) -> Self {
53 if value {
54 Feature::Enable
55 } else {
56 Feature::Disable
57 }
58 }
59}
60
/// The literal `true`: what an attribute expression collapses to once it is known to hold.
const EXPR_TRUE: Expression = Expression::Literal(LiteralExpression::Bool(true));
/// The literal `false`: what an attribute expression collapses to once it is known not to hold.
const EXPR_FALSE: Expression = Expression::Literal(LiteralExpression::Bool(false));
63
64pub fn eval_attr(expr: &Expression, features: &Features) -> Result<Expression, E> {
65 fn eval_rec(expr: &ExpressionNode, features: &Features) -> Result<Expression, E> {
66 eval_attr(expr, features).map_err(|e| Diagnostic::from(e).with_span(expr.span()).into())
67 }
68
69 match expr {
70 Expression::Literal(LiteralExpression::Bool(_)) => Ok(expr.clone()),
71 Expression::Parenthesized(paren) => {
72 let expr = eval_rec(&paren.expression, features)?;
73 Ok(match expr {
74 Expression::Binary(_) => ParenthesizedExpression {
75 expression: Spanned::new(expr, paren.expression.span()),
76 }
77 .into(),
78 _ => expr,
79 })
80 }
81 Expression::Unary(unary) => {
82 let operand = eval_rec(&unary.operand, features)?;
83 match &unary.operator {
84 UnaryOperator::LogicalNegation => {
85 let expr = if operand == EXPR_TRUE {
86 EXPR_FALSE.clone()
87 } else if operand == EXPR_FALSE {
88 EXPR_TRUE.clone()
89 } else {
90 expr.clone()
91 };
92 Ok(expr)
93 }
94 _ => Err(CondCompError::InvalidExpression(expr.clone()).into()),
95 }
96 }
97 Expression::Binary(binary) => {
98 let left = eval_rec(&binary.left, features)?;
99 let right = eval_rec(&binary.right, features)?;
100 match &binary.operator {
101 BinaryOperator::ShortCircuitOr => {
102 let expr = if left == EXPR_TRUE || right == EXPR_TRUE {
103 EXPR_TRUE.clone()
104 } else if left == EXPR_FALSE && right == EXPR_FALSE {
105 left } else if left == EXPR_FALSE {
107 right
108 } else if right == EXPR_FALSE {
109 left
110 } else {
111 BinaryExpression {
112 operator: binary.operator,
113 left: Spanned::new(left, binary.left.span()),
114 right: Spanned::new(right, binary.right.span()),
115 }
116 .into()
117 };
118 Ok(expr)
119 }
120 BinaryOperator::ShortCircuitAnd => {
121 let expr = if left == EXPR_TRUE && right == EXPR_TRUE {
122 left } else if left == EXPR_FALSE || right == EXPR_FALSE {
124 EXPR_FALSE.clone()
125 } else if left == EXPR_TRUE {
126 right
127 } else if right == EXPR_TRUE {
128 left
129 } else {
130 BinaryExpression {
131 operator: binary.operator,
132 left: Spanned::new(left, binary.left.span()),
133 right: Spanned::new(right, binary.right.span()),
134 }
135 .into()
136 };
137 Ok(expr)
138 }
139 _ => Err(CondCompError::InvalidExpression(expr.clone()).into()),
140 }
141 }
142 Expression::TypeOrIdentifier(ty) => {
143 if ty.template_args.is_some() {
144 return Err(CondCompError::InvalidFeatureFlag(ty.to_string()).into());
145 }
146 let feat = features
147 .flags
148 .get(&*ty.ident.name())
149 .unwrap_or(&features.default);
150 let expr = match feat {
151 Feature::Enable => EXPR_TRUE.clone(),
152 Feature::Disable => EXPR_FALSE.clone(),
153 Feature::Keep => expr.clone(),
154 Feature::Error => {
155 return Err(
156 CondCompError::UnexpectedFeatureFlag(ty.ident.name().to_string()).into(),
157 );
158 }
159 };
160 Ok(expr)
161 }
162 _ => Err(CondCompError::InvalidExpression(expr.clone()).into()),
163 }
164}
165
166fn get_single_attr(attrs: &mut [AttributeNode]) -> Result<Option<&mut AttributeNode>, E> {
167 let mut it = attrs.iter_mut().filter(|attr| {
168 matches!(
169 attr.node(),
170 Attribute::If(_) | Attribute::Elif(_) | Attribute::Else
171 )
172 });
173 let attr = it.next();
174
175 if it.next().is_some() {
176 Err(CondCompError::DuplicateIf.into())
177 } else {
178 Ok(attr)
179 }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq)]
183struct PrevEval {
184 has_if: bool,
185 is_true: bool,
186 removed: bool,
187}
188
189fn eval_if_attr(
195 node: &mut impl Decorated,
196 prev: &mut PrevEval,
197 features: &Features,
198) -> Result<(), E> {
199 let attr = get_single_attr(node.attributes_mut())?;
200 if let Some(attr) = attr {
201 let mut has_if = false;
202 if let Attribute::If(expr) = attr.node_mut() {
203 **expr = eval_attr(expr, features)?;
204 has_if = true;
205 prev.is_true = false;
206 } else if let Attribute::Elif(expr) = attr.node_mut() {
207 if !prev.has_if {
208 return Err(CondCompError::NoPrecedingIf.into());
209 } else {
210 **expr = eval_attr(expr, features)?;
211 has_if = true;
212 }
213 } else if let Attribute::Else = attr.node() {
214 if !prev.has_if {
215 return Err(CondCompError::NoPrecedingIf.into());
216 }
217 }
218 prev.has_if = has_if;
219 } else {
220 prev.has_if = false;
221 }
222
223 let mut remove_node = false;
224 let mut remove_attr = false;
225 let mut is_true = false;
226 node.retain_attributes_mut(|attr| {
227 if let Attribute::If(expr) = attr {
228 if **expr == EXPR_TRUE {
229 remove_attr = true; is_true = true;
231 } else if **expr == EXPR_FALSE {
232 remove_node = true; }
234 } else if let Attribute::Elif(expr) = attr {
235 if prev.is_true || **expr == EXPR_FALSE {
236 remove_node = true;
237 } else if **expr == EXPR_TRUE {
238 is_true = true;
239 if prev.removed {
240 remove_attr = true;
241 } else {
242 *attr = Attribute::Else;
243 }
244 } else if prev.removed {
245 *attr = Attribute::If(expr.clone()); }
247 } else if let Attribute::Else = attr {
248 if prev.is_true {
249 remove_node = true; } else if prev.removed {
251 remove_attr = true; }
253 } else {
254 return true;
256 }
257
258 !remove_attr
259 });
260
261 prev.is_true = is_true || prev.is_true;
262 prev.removed = remove_node;
263 Ok(())
264}
265
266fn eval_opt_attr(
267 opt_node: &mut Option<impl Decorated>,
268 prev: &mut PrevEval,
269 features: &Features,
270) -> Result<(), E> {
271 if let Some(node) = opt_node {
272 eval_if_attr(node, prev, features)?;
273 if prev.removed {
274 *opt_node = None;
275 }
276 }
277 Ok(())
278}
279
280fn eval_if_attrs(nodes: &mut Vec<impl Decorated>, features: &Features) -> Result<PrevEval, E> {
281 let mut prev = PrevEval {
282 has_if: false,
283 is_true: false,
284 removed: false,
285 };
286 let mut err = None;
287
288 nodes.retain_mut(|node| {
290 let res = eval_if_attr(node, &mut prev, features);
291 if let Err(e) = res {
292 err = Some(e);
293 }
294 !prev.removed });
296
297 if let Some(e) = err {
298 Err(e)
299 } else {
300 Ok(prev)
301 }
302}
303
304fn stmt_eval_if_attrs(statements: &mut Vec<StatementNode>, features: &Features) -> Result<(), E> {
305 fn rec_one(stmt: &mut StatementNode, feats: &Features) -> Result<(), E> {
306 match stmt.node_mut() {
307 Statement::Compound(stmt) => {
308 rec(&mut stmt.statements, feats)?;
309 }
310 Statement::If(stmt) => {
311 rec(&mut stmt.if_clause.body.statements, feats)?;
312 for elif in &mut stmt.else_if_clauses {
313 rec(&mut elif.body.statements, feats)?;
314 }
315 if let Some(el) = &mut stmt.else_clause {
316 rec(&mut el.body.statements, feats)?;
317 }
318 }
319 Statement::Switch(stmt) => {
320 eval_if_attrs(&mut stmt.clauses, feats)?;
321 for clause in &mut stmt.clauses {
322 rec(&mut clause.body.statements, feats)?;
323 }
324 }
325 Statement::Loop(stmt) => {
326 let mut prev = rec(&mut stmt.body.statements, feats)?;
327 eval_opt_attr(&mut stmt.continuing, &mut prev, feats)?;
328 if let Some(cont) = &mut stmt.continuing {
329 rec(&mut cont.body.statements, feats)?;
330 eval_opt_attr(&mut cont.break_if, &mut prev, feats)?;
331 }
332 rec(&mut stmt.body.statements, feats)?;
333 }
334 Statement::For(stmt) => {
335 if let Some(init) = &mut stmt.initializer {
336 rec_one(&mut *init, feats)?
337 }
338 if let Some(updt) = &mut stmt.update {
339 rec_one(&mut *updt, feats)?
340 }
341 rec(&mut stmt.body.statements, feats)?;
342 }
343 Statement::While(stmt) => {
344 rec(&mut stmt.body.statements, feats)?;
345 }
346 _ => (),
347 };
348 Ok(())
349 }
350 fn rec(stats: &mut Vec<StatementNode>, feats: &Features) -> Result<PrevEval, E> {
351 let prev = eval_if_attrs(stats, feats)?;
352 for stmt in stats {
353 rec_one(stmt, feats)?;
354 }
355 Ok(prev)
356 }
357 rec(statements, features).map(|_| ())
358}
359
360pub fn run(wesl: &mut TranslationUnit, features: &Features) -> Result<(), E> {
361 wesl.remove_voids();
362 eval_if_attrs(&mut wesl.imports, features)?;
363 eval_if_attrs(&mut wesl.global_directives, features)?;
364 eval_if_attrs(&mut wesl.global_declarations, features)?;
365
366 for decl in &mut wesl.global_declarations {
367 if let GlobalDeclaration::Struct(decl) = decl.node_mut() {
368 eval_if_attrs(&mut decl.members, features)
369 .map_err(|e| Diagnostic::from(e).with_declaration(decl.ident.to_string()))?;
370 } else if let GlobalDeclaration::Function(decl) = decl.node_mut() {
371 eval_if_attrs(&mut decl.parameters, features)
372 .map_err(|e| Diagnostic::from(e).with_declaration(decl.ident.to_string()))?;
373 stmt_eval_if_attrs(&mut decl.body.statements, features)
374 .map_err(|e| Diagnostic::from(e).with_declaration(decl.ident.to_string()))?;
375 }
376 }
377
378 Ok(())
379}