libpatron/ir/
transform.rs

1// Copyright 2023 The Regents of the University of California
2// released under BSD 3-Clause License
3// author: Kevin Laeufer <laeufer@berkeley.edu>
4
5use crate::btor2::{DEFAULT_INPUT_PREFIX, DEFAULT_STATE_PREFIX};
6use crate::ir::*;
7use std::collections::HashMap;
8
9/** Remove any inputs named `_input_[...]` and replace their use with a literal zero.
10 * This essentially gets rid of all undefined value modelling by yosys.
11 */
12pub fn replace_anonymous_inputs_with_zero(ctx: &mut Context, sys: &mut TransitionSystem) {
13    // find and remove inputs
14    let mut replace_map = HashMap::new();
15    for (expr, signal_info) in sys.get_signals(|s| s.is_input()) {
16        let name = expr.get_symbol_name(ctx).unwrap();
17        if name.starts_with(DEFAULT_INPUT_PREFIX) || name.starts_with(DEFAULT_STATE_PREFIX) {
18            let replacement = match expr.get_type(ctx) {
19                Type::BV(width) => ctx.zero(width),
20                Type::Array(tpe) => ctx.zero_array(tpe),
21            };
22            replace_map.insert(expr, replacement);
23            sys.remove_signal(expr);
24            // re-insert signal info if the input has labels
25            if !signal_info.labels.is_none() {
26                sys.add_signal(
27                    replacement,
28                    SignalKind::Node,
29                    signal_info.labels,
30                    signal_info.name,
31                );
32            }
33        }
34    }
35
36    // replace any use of the input with zero
37    do_transform(ctx, sys, |_ctx, expr, _children| {
38        replace_map.get(&expr).cloned()
39    });
40}
41
/// Applies simplifications to the expressions used in the system.
///
/// The rewrite rules live in `simplify`; this function hands them to `do_transform`, which
/// runs them over the system's expressions and updates `sys` with the results.
pub fn simplify_expressions(ctx: &mut Context, sys: &mut TransitionSystem) {
    do_transform(ctx, sys, simplify);
}
46
47fn simplify(ctx: &mut Context, expr: ExprRef, children: &[ExprRef]) -> Option<ExprRef> {
48    match (ctx.get(expr).clone(), children) {
49        (Expr::BVIte { .. }, [cond, tru, fals]) => {
50            if tru == fals {
51                // condition does not matter
52                Some(*tru)
53            } else if let Expr::BVLiteral { value, .. } = ctx.get(*cond) {
54                if *value == 0 {
55                    Some(*fals)
56                } else {
57                    Some(*tru)
58                }
59            } else {
60                None
61            }
62        }
63        (Expr::BVAnd(_, _, width), [a, b]) => {
64            if let (Expr::BVLiteral { value: va, .. }, Expr::BVLiteral { value: vb, .. }) =
65                (ctx.get(*a), ctx.get(*b))
66            {
67                Some(ctx.bv_lit(*va & *vb, width))
68            } else {
69                None
70            }
71        }
72        (Expr::BVOr(_, _, width), [a, b]) => {
73            if let (Expr::BVLiteral { value: va, .. }, Expr::BVLiteral { value: vb, .. }) =
74                (ctx.get(*a), ctx.get(*b))
75            {
76                Some(ctx.bv_lit(*va | *vb, width))
77            } else {
78                None
79            }
80        }
81        (Expr::BVNot(_, width), [e]) => {
82            match ctx.get(*e) {
83                Expr::BVNot(inner, _) => Some(*inner), // double negation
84                Expr::BVLiteral { value, .. } => {
85                    Some(ctx.bv_lit((!*value) & value::mask(width), width))
86                }
87                _ => None,
88            }
89        }
90        (Expr::BVZeroExt { width, .. }, [e]) => match ctx.get(*e) {
91            Expr::BVLiteral { value, .. } => Some(ctx.bv_lit(*value, width)),
92            _ => None,
93        },
94        // combine slices
95        (Expr::BVSlice { lo, hi, .. }, [e]) => match ctx.get(*e) {
96            Expr::BVSlice {
97                lo: inner_lo,
98                e: inner_e,
99                ..
100            } => Some(ctx.slice(*inner_e, hi + inner_lo, lo + inner_lo)),
101            _ => None,
102        },
103        _ => None, // no matching simplification
104    }
105}
106
107pub fn do_transform(
108    ctx: &mut Context,
109    sys: &mut TransitionSystem,
110    tran: impl FnMut(&mut Context, ExprRef, &[ExprRef]) -> Option<ExprRef>,
111) {
112    let todo = get_root_expressions(sys);
113    let transformed = do_transform_expr(ctx, todo, tran);
114
115    // update transition system signals
116    for (old_expr, maybe_new_expr) in transformed.iter() {
117        if let Some(new_expr) = maybe_new_expr {
118            if *new_expr != old_expr {
119                sys.update_signal_expr(old_expr, *new_expr);
120            }
121        }
122    }
123    // update states
124    for state in sys.states.iter_mut() {
125        if let Some(new_symbol) = changed(&transformed, state.symbol) {
126            state.symbol = new_symbol;
127        }
128        if let Some(old_next) = state.next {
129            if let Some(new_next) = changed(&transformed, old_next) {
130                state.next = Some(new_next);
131            }
132        }
133        if let Some(old_init) = state.init {
134            if let Some(new_init) = changed(&transformed, old_init) {
135                state.init = Some(new_init);
136            }
137        }
138    }
139}
140
141fn changed(transformed: &ExprMetaData<Option<ExprRef>>, old_expr: ExprRef) -> Option<ExprRef> {
142    if let Some(new_expr) = transformed.get(old_expr) {
143        if *new_expr != old_expr {
144            Some(*new_expr)
145        } else {
146            None
147        }
148    } else {
149        None
150    }
151}
152
/// Transforms all expressions reachable from `todo`, children before parents.
///
/// `todo` is used as an explicit work stack. An expression is only handed to `tran` once
/// all of its children have an entry in the result; until then it is pushed back onto the
/// stack together with the children that are still missing.
///
/// For every expression the replacement is chosen as follows:
/// * if `tran` returns `Some(e)`, then `e` is used,
/// * otherwise, if at least one child was replaced, the node is rebuilt with the new
///   children through `update_expr_children`,
/// * otherwise the expression is kept as is.
///
/// Returns the meta data that maps each processed expression to `Some(replacement)`.
fn do_transform_expr(
    ctx: &mut Context,
    mut todo: Vec<ExprRef>,
    mut tran: impl FnMut(&mut Context, ExprRef, &[ExprRef]) -> Option<ExprRef>,
) -> ExprMetaData<Option<ExprRef>> {
    let mut transformed = ExprMetaData::default();
    // scratch buffer for the transformed children of the current expression, reused across iterations
    let mut children = Vec::with_capacity(4);

    // NOTE(review): there is no check whether `expr_ref` already has an entry in `transformed`,
    // so an expression that ends up on the stack more than once is passed to `tran` again
    // each time it is popped.
    while let Some(expr_ref) = todo.pop() {
        // check to see if we translated all the children
        children.clear();
        let mut children_changed = false; // track whether any of the children changed
        let mut all_transformed = true; // tracks whether all children have been transformed or if there is more work to do
        ctx.get(expr_ref).for_each_child(|c| {
            match transformed.get(*c) {
                Some(new_child_expr) => {
                    if *new_child_expr != *c {
                        children_changed = true; // child changed
                    }
                    children.push(*new_child_expr);
                }
                None => {
                    // re-schedule the parent exactly once; it has to sit below its
                    // missing children on the stack so that they are processed first
                    if all_transformed {
                        todo.push(expr_ref);
                    }
                    all_transformed = false;
                    todo.push(*c);
                }
            }
        });
        if !all_transformed {
            // `children` is incomplete here; it is cleared when the next expression is popped
            continue;
        }

        // call out to the transform
        let tran_res = (tran)(ctx, expr_ref, &children);
        let new_expr_ref = match tran_res {
            Some(e) => e,
            None => {
                if children_changed {
                    update_expr_children(ctx, expr_ref, &children)
                } else {
                    // if no children changed and the transform does not want to do changes,
                    // we can just keep the old expression
                    expr_ref
                }
            }
        };
        // remember the transformed version
        *transformed.get_mut(expr_ref) = Some(new_expr_ref);
    }
    transformed
}
206
207fn get_root_expressions(sys: &TransitionSystem) -> Vec<ExprRef> {
208    // include all input, output, assertion and assumptions expressions
209    let mut out = Vec::from_iter(
210        sys.get_signals(is_usage_root_signal)
211            .iter()
212            .map(|(e, _)| *e),
213    );
214
215    // include all states
216    for (_, state) in sys.states() {
217        out.push(state.symbol);
218        if let Some(init) = state.init {
219            out.push(init);
220        }
221        if let Some(next) = state.next {
222            out.push(next);
223        }
224    }
225
226    out
227}
228
/// Re-creates the expression `expr_ref` with its children replaced by `children`.
///
/// The new node has the same variant as the original; every field that is not a child
/// reference (widths, `by`, `hi`/`lo`, `index_width`/`data_width`) is copied from the
/// original node. The node is added to `ctx` through `add_node` and the reference returned
/// by that call is the result.
///
/// `children` has to contain exactly as many entries as the variant has children, in the
/// order in which they appear in the patterns below.
///
/// # Panics
/// * Panics for symbols and literals, since they have no children to update.
/// * Hits `todo!` for any variant / number-of-children combination not listed below.
fn update_expr_children(ctx: &mut Context, expr_ref: ExprRef, children: &[ExprRef]) -> ExprRef {
    let new_expr = match (ctx.get(expr_ref), children) {
        // leaf nodes
        (Expr::BVSymbol { .. }, _) => panic!("No children, should never get here."),
        (Expr::BVLiteral { .. }, _) => panic!("No children, should never get here."),
        // unary bit-vector expressions
        (Expr::BVZeroExt { by, width, .. }, [e]) => Expr::BVZeroExt {
            e: *e,
            by: *by,
            width: *width,
        },
        (Expr::BVSignExt { by, width, .. }, [e]) => Expr::BVSignExt {
            e: *e,
            by: *by,
            width: *width,
        },
        (Expr::BVSlice { hi, lo, .. }, [e]) => Expr::BVSlice {
            e: *e,
            hi: *hi,
            lo: *lo,
        },
        (Expr::BVNot(_, width), [e]) => Expr::BVNot(*e, *width),
        (Expr::BVNegate(_, width), [e]) => Expr::BVNegate(*e, *width),
        // binary bit-vector expressions
        (Expr::BVEqual(_, _), [a, b]) => Expr::BVEqual(*a, *b),
        (Expr::BVImplies(_, _), [a, b]) => Expr::BVImplies(*a, *b),
        (Expr::BVGreater(_, _), [a, b]) => Expr::BVGreater(*a, *b),
        (Expr::BVGreaterSigned(_, _, w), [a, b]) => Expr::BVGreaterSigned(*a, *b, *w),
        (Expr::BVGreaterEqual(_, _), [a, b]) => Expr::BVGreaterEqual(*a, *b),
        (Expr::BVGreaterEqualSigned(_, _, w), [a, b]) => Expr::BVGreaterEqualSigned(*a, *b, *w),
        (Expr::BVConcat(_, _, w), [a, b]) => Expr::BVConcat(*a, *b, *w),
        (Expr::BVAnd(_, _, w), [a, b]) => Expr::BVAnd(*a, *b, *w),
        (Expr::BVOr(_, _, w), [a, b]) => Expr::BVOr(*a, *b, *w),
        (Expr::BVXor(_, _, w), [a, b]) => Expr::BVXor(*a, *b, *w),
        (Expr::BVShiftLeft(_, _, w), [a, b]) => Expr::BVShiftLeft(*a, *b, *w),
        (Expr::BVArithmeticShiftRight(_, _, w), [a, b]) => Expr::BVArithmeticShiftRight(*a, *b, *w),
        (Expr::BVShiftRight(_, _, w), [a, b]) => Expr::BVShiftRight(*a, *b, *w),
        (Expr::BVAdd(_, _, w), [a, b]) => Expr::BVAdd(*a, *b, *w),
        (Expr::BVMul(_, _, w), [a, b]) => Expr::BVMul(*a, *b, *w),
        (Expr::BVSignedDiv(_, _, w), [a, b]) => Expr::BVSignedDiv(*a, *b, *w),
        (Expr::BVUnsignedDiv(_, _, w), [a, b]) => Expr::BVUnsignedDiv(*a, *b, *w),
        (Expr::BVSignedMod(_, _, w), [a, b]) => Expr::BVSignedMod(*a, *b, *w),
        (Expr::BVSignedRem(_, _, w), [a, b]) => Expr::BVSignedRem(*a, *b, *w),
        (Expr::BVUnsignedRem(_, _, w), [a, b]) => Expr::BVUnsignedRem(*a, *b, *w),
        (Expr::BVSub(_, _, w), [a, b]) => Expr::BVSub(*a, *b, *w),
        (Expr::BVArrayRead { width, .. }, [array, index]) => Expr::BVArrayRead {
            array: *array,
            index: *index,
            width: *width,
        },
        // ternary bit-vector expression
        (Expr::BVIte { .. }, [cond, tru, fals]) => Expr::BVIte {
            cond: *cond,
            tru: *tru,
            fals: *fals,
        },
        // array expressions
        (Expr::ArraySymbol { .. }, _) => panic!("No children, should never get here."),
        (
            Expr::ArrayConstant {
                index_width,
                data_width,
                ..
            },
            [e],
        ) => Expr::ArrayConstant {
            e: *e,
            index_width: *index_width,
            data_width: *data_width,
        },
        (Expr::ArrayEqual(_, _), [a, b]) => Expr::ArrayEqual(*a, *b),
        (Expr::ArrayStore { .. }, [array, index, data]) => Expr::ArrayStore {
            array: *array,
            index: *index,
            data: *data,
        },
        (Expr::ArrayIte { .. }, [cond, tru, fals]) => Expr::ArrayIte {
            cond: *cond,
            tru: *tru,
            fals: *fals,
        },
        // anything else: a variant without an arm above, or a known variant that was
        // handed an unexpected number of children
        (other, _) => {
            todo!("implement code to re-create expression `{other:?}` with updated children")
        }
    };
    ctx.add_node(new_expr)
}
311
312/// Slightly different definition from use counts, as we want to retain all inputs in transformation passes.
313fn is_usage_root_signal(info: &SignalInfo) -> bool {
314    info.is_input()
315        || info.labels.is_output()
316        || info.labels.is_constraint()
317        || info.labels.is_bad()
318        || info.labels.is_fair()
319}