wesl/
condcomp.rs

1use std::collections::HashMap;
2
3use crate::Diagnostic;
4use thiserror::Error;
5use wgsl_parse::{Decorated, span::Spanned, syntax::*};
6
7/// Conditional translation error.
8#[derive(Clone, Debug, Error)]
9pub enum CondCompError {
10    #[error("invalid feature flag: `{0}`")]
11    InvalidFeatureFlag(String),
12    #[error("unexpected feature flag: `{0}`")]
13    UnexpectedFeatureFlag(String),
14    #[error("invalid if attribute expression: `{0}`")]
15    InvalidExpression(Expression),
16    #[error("an @elif or @else attribute must be preceded by a @if or @elif on the previous node")]
17    NoPrecedingIf,
18    #[error("cannot have multiple @if/@elif/@else attributes on the same node")]
19    DuplicateIf,
20}
21
22type E = crate::Error;
23
/// Set the behavior for a feature flag during conditional translation.
///
/// * `Enable` means that the feature flag evaluates to `true`.
/// * `Disable` means that the feature flag evaluates to `false`.
/// * `Keep` means that the feature flag will be left as-is. This is useful for
///   incremental compilation, e.g. for generating shader variants.
/// * `Error` means that encountering the feature flag triggers a
///   [`CondCompError::UnexpectedFeatureFlag`]. When used as the `default` of
///   [`Features`], every flag not listed in `flags` is rejected.
///
/// Default is `Disable`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Feature {
    /// The flag evaluates to `true`.
    Enable,
    /// The flag evaluates to `false`.
    #[default]
    Disable,
    /// The flag is left unevaluated in the output.
    Keep,
    /// Encountering the flag is an error.
    Error,
}

/// Toggle conditional compilation feature flags.
///
/// Each entry of `flags` maps a feature flag name to its [`Feature`] behavior. Feature flags
/// not present in `flags` are treated according to `default`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Features {
    /// Behavior applied to feature flags that are not listed in `flags`.
    pub default: Feature,
    /// Per-flag behavior, keyed by feature flag name.
    pub flags: HashMap<String, Feature>,
}

/// `true` maps to [`Feature::Enable`] and `false` maps to [`Feature::Disable`].
impl From<bool> for Feature {
    fn from(value: bool) -> Self {
        if value {
            Feature::Enable
        } else {
            Feature::Disable
        }
    }
}
60
61const EXPR_TRUE: Expression = Expression::Literal(LiteralExpression::Bool(true));
62const EXPR_FALSE: Expression = Expression::Literal(LiteralExpression::Bool(false));
63
64pub fn eval_attr(expr: &Expression, features: &Features) -> Result<Expression, E> {
65    fn eval_rec(expr: &ExpressionNode, features: &Features) -> Result<Expression, E> {
66        eval_attr(expr, features).map_err(|e| Diagnostic::from(e).with_span(expr.span()).into())
67    }
68
69    match expr {
70        Expression::Literal(LiteralExpression::Bool(_)) => Ok(expr.clone()),
71        Expression::Parenthesized(paren) => {
72            let expr = eval_rec(&paren.expression, features)?;
73            Ok(match expr {
74                Expression::Binary(_) => ParenthesizedExpression {
75                    expression: Spanned::new(expr, paren.expression.span()),
76                }
77                .into(),
78                _ => expr,
79            })
80        }
81        Expression::Unary(unary) => {
82            let operand = eval_rec(&unary.operand, features)?;
83            match &unary.operator {
84                UnaryOperator::LogicalNegation => {
85                    let expr = if operand == EXPR_TRUE {
86                        EXPR_FALSE.clone()
87                    } else if operand == EXPR_FALSE {
88                        EXPR_TRUE.clone()
89                    } else {
90                        expr.clone()
91                    };
92                    Ok(expr)
93                }
94                _ => Err(CondCompError::InvalidExpression(expr.clone()).into()),
95            }
96        }
97        Expression::Binary(binary) => {
98            let left = eval_rec(&binary.left, features)?;
99            let right = eval_rec(&binary.right, features)?;
100            match &binary.operator {
101                BinaryOperator::ShortCircuitOr => {
102                    let expr = if left == EXPR_TRUE || right == EXPR_TRUE {
103                        EXPR_TRUE.clone()
104                    } else if left == EXPR_FALSE && right == EXPR_FALSE {
105                        left // false
106                    } else if left == EXPR_FALSE {
107                        right
108                    } else if right == EXPR_FALSE {
109                        left
110                    } else {
111                        BinaryExpression {
112                            operator: binary.operator,
113                            left: Spanned::new(left, binary.left.span()),
114                            right: Spanned::new(right, binary.right.span()),
115                        }
116                        .into()
117                    };
118                    Ok(expr)
119                }
120                BinaryOperator::ShortCircuitAnd => {
121                    let expr = if left == EXPR_TRUE && right == EXPR_TRUE {
122                        left // true
123                    } else if left == EXPR_FALSE || right == EXPR_FALSE {
124                        EXPR_FALSE.clone()
125                    } else if left == EXPR_TRUE {
126                        right
127                    } else if right == EXPR_TRUE {
128                        left
129                    } else {
130                        BinaryExpression {
131                            operator: binary.operator,
132                            left: Spanned::new(left, binary.left.span()),
133                            right: Spanned::new(right, binary.right.span()),
134                        }
135                        .into()
136                    };
137                    Ok(expr)
138                }
139                _ => Err(CondCompError::InvalidExpression(expr.clone()).into()),
140            }
141        }
142        Expression::TypeOrIdentifier(ty) => {
143            if ty.template_args.is_some() {
144                return Err(CondCompError::InvalidFeatureFlag(ty.to_string()).into());
145            }
146            let feat = features
147                .flags
148                .get(&*ty.ident.name())
149                .unwrap_or(&features.default);
150            let expr = match feat {
151                Feature::Enable => EXPR_TRUE.clone(),
152                Feature::Disable => EXPR_FALSE.clone(),
153                Feature::Keep => expr.clone(),
154                Feature::Error => {
155                    return Err(
156                        CondCompError::UnexpectedFeatureFlag(ty.ident.name().to_string()).into(),
157                    );
158                }
159            };
160            Ok(expr)
161        }
162        _ => Err(CondCompError::InvalidExpression(expr.clone()).into()),
163    }
164}
165
166fn get_single_attr(attrs: &mut [AttributeNode]) -> Result<Option<&mut AttributeNode>, E> {
167    let mut it = attrs.iter_mut().filter(|attr| {
168        matches!(
169            attr.node(),
170            Attribute::If(_) | Attribute::Elif(_) | Attribute::Else
171        )
172    });
173    let attr = it.next();
174
175    if it.next().is_some() {
176        Err(CondCompError::DuplicateIf.into())
177    } else {
178        Ok(attr)
179    }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq)]
183struct PrevEval {
184    has_if: bool,
185    is_true: bool,
186    removed: bool,
187}
188
189/// * ensure there is at most one if/elif/else node.
190/// * ensure elif/else nodes are preceded by if/elif.
191/// * remove the attributes which evaluate to true.
192/// * turn elifs into ifs when previous node was deleted.
193/// * turn elifs into elses when it evaluates to true.
194fn eval_if_attr(
195    node: &mut impl Decorated,
196    prev: &mut PrevEval,
197    features: &Features,
198) -> Result<(), E> {
199    let attr = get_single_attr(node.attributes_mut())?;
200    if let Some(attr) = attr {
201        let mut has_if = false;
202        if let Attribute::If(expr) = attr.node_mut() {
203            **expr = eval_attr(expr, features)?;
204            has_if = true;
205            prev.is_true = false;
206        } else if let Attribute::Elif(expr) = attr.node_mut() {
207            if !prev.has_if {
208                return Err(CondCompError::NoPrecedingIf.into());
209            } else {
210                **expr = eval_attr(expr, features)?;
211                has_if = true;
212            }
213        } else if let Attribute::Else = attr.node() {
214            if !prev.has_if {
215                return Err(CondCompError::NoPrecedingIf.into());
216            }
217        }
218        prev.has_if = has_if;
219    } else {
220        prev.has_if = false;
221    }
222
223    let mut remove_node = false;
224    let mut remove_attr = false;
225    let mut is_true = false;
226    node.retain_attributes_mut(|attr| {
227        if let Attribute::If(expr) = attr {
228            if **expr == EXPR_TRUE {
229                remove_attr = true; // if(true) => remove the attribute
230                is_true = true;
231            } else if **expr == EXPR_FALSE {
232                remove_node = true; // if(false) => remove the node
233            }
234        } else if let Attribute::Elif(expr) = attr {
235            if prev.is_true || **expr == EXPR_FALSE {
236                remove_node = true;
237            } else if **expr == EXPR_TRUE {
238                is_true = true;
239                if prev.removed {
240                    remove_attr = true;
241                } else {
242                    *attr = Attribute::Else;
243                }
244            } else if prev.removed {
245                *attr = Attribute::If(expr.clone()); // previous node was deleted, make this an if
246            }
247        } else if let Attribute::Else = attr {
248            if prev.is_true {
249                remove_node = true; // previous node was chosen, delete the whole node
250            } else if prev.removed {
251                remove_attr = true; // previous node was deleted, delete this attribute
252            }
253        } else {
254            // we keep non-condcomp attributes
255            return true;
256        }
257
258        !remove_attr
259    });
260
261    prev.is_true = is_true || prev.is_true;
262    prev.removed = remove_node;
263    Ok(())
264}
265
266fn eval_opt_attr(
267    opt_node: &mut Option<impl Decorated>,
268    prev: &mut PrevEval,
269    features: &Features,
270) -> Result<(), E> {
271    if let Some(node) = opt_node {
272        eval_if_attr(node, prev, features)?;
273        if prev.removed {
274            *opt_node = None;
275        }
276    }
277    Ok(())
278}
279
280fn eval_if_attrs(nodes: &mut Vec<impl Decorated>, features: &Features) -> Result<PrevEval, E> {
281    let mut prev = PrevEval {
282        has_if: false,
283        is_true: false,
284        removed: false,
285    };
286    let mut err = None;
287
288    // remove the nodes for which the attr evaluate to false.
289    nodes.retain_mut(|node| {
290        let res = eval_if_attr(node, &mut prev, features);
291        if let Err(e) = res {
292            err = Some(e);
293        }
294        !prev.removed // keep the node if attr is unresolved or true.
295    });
296
297    if let Some(e) = err {
298        Err(e)
299    } else {
300        Ok(prev)
301    }
302}
303
304fn stmt_eval_if_attrs(statements: &mut Vec<StatementNode>, features: &Features) -> Result<(), E> {
305    fn rec_one(stmt: &mut StatementNode, feats: &Features) -> Result<(), E> {
306        match stmt.node_mut() {
307            Statement::Compound(stmt) => {
308                rec(&mut stmt.statements, feats)?;
309            }
310            Statement::If(stmt) => {
311                rec(&mut stmt.if_clause.body.statements, feats)?;
312                for elif in &mut stmt.else_if_clauses {
313                    rec(&mut elif.body.statements, feats)?;
314                }
315                if let Some(el) = &mut stmt.else_clause {
316                    rec(&mut el.body.statements, feats)?;
317                }
318            }
319            Statement::Switch(stmt) => {
320                eval_if_attrs(&mut stmt.clauses, feats)?;
321                for clause in &mut stmt.clauses {
322                    rec(&mut clause.body.statements, feats)?;
323                }
324            }
325            Statement::Loop(stmt) => {
326                let mut prev = rec(&mut stmt.body.statements, feats)?;
327                eval_opt_attr(&mut stmt.continuing, &mut prev, feats)?;
328                if let Some(cont) = &mut stmt.continuing {
329                    rec(&mut cont.body.statements, feats)?;
330                    eval_opt_attr(&mut cont.break_if, &mut prev, feats)?;
331                }
332                rec(&mut stmt.body.statements, feats)?;
333            }
334            Statement::For(stmt) => {
335                if let Some(init) = &mut stmt.initializer {
336                    rec_one(&mut *init, feats)?
337                }
338                if let Some(updt) = &mut stmt.update {
339                    rec_one(&mut *updt, feats)?
340                }
341                rec(&mut stmt.body.statements, feats)?;
342            }
343            Statement::While(stmt) => {
344                rec(&mut stmt.body.statements, feats)?;
345            }
346            _ => (),
347        };
348        Ok(())
349    }
350    fn rec(stats: &mut Vec<StatementNode>, feats: &Features) -> Result<PrevEval, E> {
351        let prev = eval_if_attrs(stats, feats)?;
352        for stmt in stats {
353            rec_one(stmt, feats)?;
354        }
355        Ok(prev)
356    }
357    rec(statements, features).map(|_| ())
358}
359
360pub fn run(wesl: &mut TranslationUnit, features: &Features) -> Result<(), E> {
361    wesl.remove_voids();
362    eval_if_attrs(&mut wesl.imports, features)?;
363    eval_if_attrs(&mut wesl.global_directives, features)?;
364    eval_if_attrs(&mut wesl.global_declarations, features)?;
365
366    for decl in &mut wesl.global_declarations {
367        if let GlobalDeclaration::Struct(decl) = decl.node_mut() {
368            eval_if_attrs(&mut decl.members, features)
369                .map_err(|e| Diagnostic::from(e).with_declaration(decl.ident.to_string()))?;
370        } else if let GlobalDeclaration::Function(decl) = decl.node_mut() {
371            eval_if_attrs(&mut decl.parameters, features)
372                .map_err(|e| Diagnostic::from(e).with_declaration(decl.ident.to_string()))?;
373            stmt_eval_if_attrs(&mut decl.body.statements, features)
374                .map_err(|e| Diagnostic::from(e).with_declaration(decl.ident.to_string()))?;
375        }
376    }
377
378    Ok(())
379}