Skip to main content

patronus/expr/
transform.rs

1// Copyright 2023 The Regents of the University of California
2// released under BSD 3-Clause License
3// author: Kevin Laeufer <laeufer@berkeley.edu>
4
5use crate::expr::meta::get_fixed_point;
6use crate::expr::*;
7
/// Controls how the transform callback is applied while rewriting an expression graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExprTransformMode {
    /// Apply the transform once to every visited expression.
    SingleStep,
    /// Also apply the transform to every newly produced expression that has not been
    /// transformed yet, working towards a fixed point.
    FixedPoint,
}
13
14/// transform an expression with a single step (no fixed point) and no persistent cache
15#[inline]
16pub fn simple_transform_expr(
17    ctx: &mut Context,
18    e: ExprRef,
19    tran: impl FnMut(&mut Context, ExprRef, &[ExprRef]) -> Option<ExprRef>,
20) -> ExprRef {
21    let mut cache = SparseExprMap::default();
22    do_transform_expr(
23        ctx,
24        ExprTransformMode::SingleStep,
25        &mut cache,
26        vec![e],
27        tran,
28    );
29    cache[e].unwrap()
30}
31
32#[inline]
33pub(crate) fn do_transform_expr<T: ExprMap<Option<ExprRef>>>(
34    ctx: &mut Context,
35    mode: ExprTransformMode,
36    transformed: &mut T,
37    mut todo: Vec<ExprRef>,
38    mut tran: impl FnMut(&mut Context, ExprRef, &[ExprRef]) -> Option<ExprRef>,
39) {
40    let mut children = Vec::with_capacity(4);
41
42    while let Some(expr_ref) = todo.pop() {
43        // check to see if we translated all the children
44        children.clear();
45        let mut children_changed = false; // track whether any of the children changed
46        let mut all_transformed = true; // tracks whether all children have been transformed or if there is more work to do
47        ctx[expr_ref].for_each_child(|c| {
48            let transformed_child = if mode == ExprTransformMode::FixedPoint {
49                get_fixed_point(transformed, *c)
50            } else {
51                transformed[*c]
52            };
53            match transformed_child {
54                Some(new_child_expr) => {
55                    if new_child_expr != *c {
56                        children_changed = true; // child changed
57                    }
58                    children.push(new_child_expr);
59                }
60                None => {
61                    if all_transformed {
62                        todo.push(expr_ref);
63                    }
64                    all_transformed = false;
65                    todo.push(*c);
66                }
67            }
68        });
69        if !all_transformed {
70            continue;
71        }
72
73        // call out to the transform
74        let tran_res = tran(ctx, expr_ref, &children);
75        let new_expr_ref = match tran_res {
76            Some(e) => e,
77            None => {
78                if children_changed {
79                    update_expr_children(ctx, expr_ref, &children)
80                } else {
81                    // if no children changed and the transform does not want to do changes,
82                    // we can just keep the old expression
83                    expr_ref
84                }
85            }
86        };
87        // remember the transformed version
88        transformed[expr_ref] = Some(new_expr_ref);
89
90        // in fixed point mode, we might not be done yet
91        let is_at_fixed_point = expr_ref == new_expr_ref;
92        if mode == ExprTransformMode::FixedPoint
93            && !is_at_fixed_point
94            && transformed[new_expr_ref].is_none()
95        {
96            // see if we can further simplify the new expression
97            todo.push(new_expr_ref);
98        }
99    }
100}
101
102fn update_expr_children(ctx: &mut Context, expr_ref: ExprRef, children: &[ExprRef]) -> ExprRef {
103    let new_expr = match (&ctx[expr_ref], children) {
104        (Expr::BVSymbol { .. }, _) => panic!("No children, should never get here."),
105        (Expr::BVLiteral { .. }, _) => panic!("No children, should never get here."),
106        (Expr::BVZeroExt { by, width, .. }, [e]) => Expr::BVZeroExt {
107            e: *e,
108            by: *by,
109            width: *width,
110        },
111        (Expr::BVSignExt { by, width, .. }, [e]) => Expr::BVSignExt {
112            e: *e,
113            by: *by,
114            width: *width,
115        },
116        (Expr::BVSlice { hi, lo, .. }, [e]) => Expr::BVSlice {
117            e: *e,
118            hi: *hi,
119            lo: *lo,
120        },
121        (Expr::BVNot(_, width), [e]) => Expr::BVNot(*e, *width),
122        (Expr::BVNegate(_, width), [e]) => Expr::BVNegate(*e, *width),
123        (Expr::BVEqual(_, _), [a, b]) => Expr::BVEqual(*a, *b),
124        (Expr::BVImplies(_, _), [a, b]) => Expr::BVImplies(*a, *b),
125        (Expr::BVGreater(_, _), [a, b]) => Expr::BVGreater(*a, *b),
126        (Expr::BVGreaterSigned(_, _, w), [a, b]) => Expr::BVGreaterSigned(*a, *b, *w),
127        (Expr::BVGreaterEqual(_, _), [a, b]) => Expr::BVGreaterEqual(*a, *b),
128        (Expr::BVGreaterEqualSigned(_, _, w), [a, b]) => Expr::BVGreaterEqualSigned(*a, *b, *w),
129        (Expr::BVConcat(_, _, w), [a, b]) => Expr::BVConcat(*a, *b, *w),
130        (Expr::BVAnd(_, _, w), [a, b]) => Expr::BVAnd(*a, *b, *w),
131        (Expr::BVOr(_, _, w), [a, b]) => Expr::BVOr(*a, *b, *w),
132        (Expr::BVXor(_, _, w), [a, b]) => Expr::BVXor(*a, *b, *w),
133        (Expr::BVShiftLeft(_, _, w), [a, b]) => Expr::BVShiftLeft(*a, *b, *w),
134        (Expr::BVArithmeticShiftRight(_, _, w), [a, b]) => Expr::BVArithmeticShiftRight(*a, *b, *w),
135        (Expr::BVShiftRight(_, _, w), [a, b]) => Expr::BVShiftRight(*a, *b, *w),
136        (Expr::BVAdd(_, _, w), [a, b]) => Expr::BVAdd(*a, *b, *w),
137        (Expr::BVMul(_, _, w), [a, b]) => Expr::BVMul(*a, *b, *w),
138        (Expr::BVSignedDiv(_, _, w), [a, b]) => Expr::BVSignedDiv(*a, *b, *w),
139        (Expr::BVUnsignedDiv(_, _, w), [a, b]) => Expr::BVUnsignedDiv(*a, *b, *w),
140        (Expr::BVSignedMod(_, _, w), [a, b]) => Expr::BVSignedMod(*a, *b, *w),
141        (Expr::BVSignedRem(_, _, w), [a, b]) => Expr::BVSignedRem(*a, *b, *w),
142        (Expr::BVUnsignedRem(_, _, w), [a, b]) => Expr::BVUnsignedRem(*a, *b, *w),
143        (Expr::BVSub(_, _, w), [a, b]) => Expr::BVSub(*a, *b, *w),
144        (Expr::BVArrayRead { width, .. }, [array, index]) => Expr::BVArrayRead {
145            array: *array,
146            index: *index,
147            width: *width,
148        },
149        (Expr::BVIte { .. }, [cond, tru, fals]) => Expr::BVIte {
150            cond: *cond,
151            tru: *tru,
152            fals: *fals,
153        },
154        (Expr::ArraySymbol { .. }, _) => panic!("No children, should never get here."),
155        (
156            Expr::ArrayConstant {
157                index_width,
158                data_width,
159                ..
160            },
161            [e],
162        ) => Expr::ArrayConstant {
163            e: *e,
164            index_width: *index_width,
165            data_width: *data_width,
166        },
167        (Expr::ArrayEqual(_, _), [a, b]) => Expr::ArrayEqual(*a, *b),
168        (Expr::ArrayStore { .. }, [array, index, data]) => Expr::ArrayStore {
169            array: *array,
170            index: *index,
171            data: *data,
172        },
173        (Expr::ArrayIte { .. }, [cond, tru, fals]) => Expr::ArrayIte {
174            cond: *cond,
175            tru: *tru,
176            fals: *fals,
177        },
178        (other, _) => {
179            todo!("implement code to re-create expression `{other:?}` with updated children")
180        }
181    };
182    ctx.add_expr(new_expr)
183}