Skip to main content

patronus/expr/
context.rs

1// Copyright 2023 The Regents of the University of California
2// Copyright 2024 Cornell University
3// released under BSD 3-Clause License
4// author: Kevin Laeufer <laeufer@cornell.edu>
5
6//! # IR Context
7//!
8//! The [`Context`] is used to create and access bit-vector and array expressions.
9//! It ensures that the same expression always maps to the same expression reference.
10//! Thus, if two references are equal, we can be certain that the expressions they point to are
11//! equivalent.
12//!
//! Users are expected to generally use a single Context for all their expressions. There
//! are no checks to ensure that an [`ExprRef`] or [`StringRef`] created by one context is
//! not used with a different one. Thus working with more than one [`Context`] object can be
//! dangerous.
16
17use crate::expr::TypeCheck;
18use crate::expr::nodes::*;
19use baa::{
20    ArrayOps, BitVecOps, BitVecValue, BitVecValueIndex, BitVecValueRef, IndexToRef,
21    SparseArrayValue, Value,
22};
23use rustc_hash::FxBuildHasher;
24use std::borrow::Borrow;
25use std::cell::RefCell;
26use std::fmt::{Debug, Formatter};
27use std::num::NonZeroU32;
28use std::ops::Index;
29
/// Uniquely identifies a [`String`] stored in a [`Context`].
///
/// The zero-based index of the string is stored incremented by one inside a
/// [`NonZeroU32`], which keeps the reference four bytes wide.
#[derive(PartialEq, Eq, Clone, Copy, Hash, PartialOrd, Ord)]
pub struct StringRef(NonZeroU32);

impl Debug for StringRef {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // show the zero-based index instead of the internal one-based representation
        write!(f, "StringRef({})", self.index())
    }
}

impl StringRef {
    /// Creates a reference from a zero-based `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1` cannot be represented as a `u32`. A plain `as u32` cast would
    /// silently truncate a too large index and thereby alias the reference of a different
    /// string, so the conversion is checked instead.
    fn from_index(index: usize) -> Self {
        let id = index
            .checked_add(1)
            .and_then(|id| u32::try_from(id).ok())
            .and_then(NonZeroU32::new)
            .expect("StringRef index does not fit into 32 bits");
        Self(id)
    }

    /// Returns the zero-based index this reference was created from.
    fn index(&self) -> usize {
        (self.0.get() - 1) as usize
    }
}
49
/// Uniquely identifies an [`Expr`] stored in a [`Context`].
///
/// The zero-based index of the expression is stored incremented by one inside a
/// [`NonZeroU32`], which keeps the reference four bytes wide.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Ord, PartialOrd)]
pub struct ExprRef(NonZeroU32);

impl Debug for ExprRef {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // we need a custom implementation in order to show the zero based index
        let index: usize = (*self).into();
        write!(f, "ExprRef({})", index)
    }
}

impl From<ExprRef> for usize {
    /// Returns the zero-based index of the expression.
    fn from(value: ExprRef) -> Self {
        (value.0.get() - 1) as usize
    }
}

impl From<usize> for ExprRef {
    /// Creates a reference from a zero-based `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1` cannot be represented as a `u32`. A plain `as u32` cast would
    /// silently truncate a too large index and thereby alias the reference of a different
    /// expression, so the conversion is checked instead.
    fn from(index: usize) -> Self {
        let id = index
            .checked_add(1)
            .and_then(|id| u32::try_from(id).ok())
            .and_then(NonZeroU32::new)
            .expect("ExprRef index does not fit into 32 bits");
        ExprRef(id)
    }
}
73
74/// Context which is used to create all SMT expressions. Expressions are interned such that
75/// reference equivalence implies structural equivalence.
76#[derive(Clone)]
77pub struct Context {
78    strings: indexmap::IndexSet<String, FxBuildHasher>,
79    exprs: indexmap::IndexSet<Expr, FxBuildHasher>,
80    values: baa::ValueInterner,
81    // cached special values
82    true_expr_ref: ExprRef,
83    false_expr_ref: ExprRef,
84}
85
86impl Default for Context {
87    // TODO: should probably rename this to "new" at some point.
88    fn default() -> Self {
89        let mut out = Self {
90            strings: Default::default(),
91            exprs: Default::default(),
92            values: Default::default(),
93            true_expr_ref: 0.into(),  // only a placeholder!
94            false_expr_ref: 0.into(), // only a placeholder!
95        };
96        // create valid cached expressions
97        out.false_expr_ref = out.zero(1);
98        out.true_expr_ref = out.one(1);
99        out
100    }
101}
102
103/// Adding and removing nodes.
104impl Context {
105    pub fn get_symbol_name(&self, reference: ExprRef) -> Option<&str> {
106        self[reference].get_symbol_name(self)
107    }
108
109    pub(crate) fn add_expr(&mut self, value: Expr) -> ExprRef {
110        let (index, _) = self.exprs.insert_full(value);
111        index.into()
112    }
113
114    pub fn string(&mut self, value: std::borrow::Cow<str>) -> StringRef {
115        if let Some(index) = self.strings.get_index_of(value.as_ref()) {
116            StringRef::from_index(index)
117        } else {
118            let (index, _) = self.strings.insert_full(value.into_owned());
119            StringRef::from_index(index)
120        }
121    }
122
123    pub(crate) fn get_bv_value(&self, index: impl Borrow<BitVecValueIndex>) -> BitVecValueRef<'_> {
124        self.values.words().get_ref(index)
125    }
126}
127
128impl Index<ExprRef> for Context {
129    type Output = Expr;
130
131    fn index(&self, index: ExprRef) -> &Self::Output {
132        self.exprs
133            .get_index(index.into())
134            .expect("Invalid ExprRef!")
135    }
136}
137
138impl Index<StringRef> for Context {
139    type Output = String;
140
141    fn index(&self, index: StringRef) -> &Self::Output {
142        self.strings
143            .get_index(index.index())
144            .expect("Invalid StringRef!")
145    }
146}
147
148/// Convenience methods to inspect IR nodes.
149impl Context {
150    /// Returns whether `e` represents a bit vector literal `0` of any width.
151    pub fn is_zero(&self, e: ExprRef) -> bool {
152        if let Expr::BVLiteral(value) = self[e] {
153            value.get(self).is_zero()
154        } else {
155            false
156        }
157    }
158}
159
160/// Convenience methods to construct IR nodes.
161impl Context {
162    // helper functions to construct expressions
163    pub fn bv_symbol(&mut self, name: &str, width: WidthInt) -> ExprRef {
164        assert!(width > 0, "0-bit bitvectors are not allowed");
165        let name_ref = self.string(name.into());
166        self.add_expr(Expr::BVSymbol {
167            name: name_ref,
168            width,
169        })
170    }
171
172    pub fn array_symbol(
173        &mut self,
174        name: &str,
175        index_width: WidthInt,
176        data_width: WidthInt,
177    ) -> ExprRef {
178        assert!(index_width > 0, "0-bit bitvectors are not allowed");
179        assert!(data_width > 0, "0-bit bitvectors are not allowed");
180        let name_ref = self.string(name.into());
181        self.add_expr(Expr::ArraySymbol {
182            name: name_ref,
183            index_width,
184            data_width,
185        })
186    }
187    pub fn symbol(&mut self, name: StringRef, tpe: Type) -> ExprRef {
188        assert_ne!(tpe, Type::BV(0), "0-bit bitvectors are not allowed");
189        self.add_expr(Expr::symbol(name, tpe))
190    }
191    pub fn lit(&mut self, value: impl Borrow<Value>) -> ExprRef {
192        match value.borrow() {
193            Value::BitVec(value) => self.bv_lit(value),
194            Value::Array(value) => {
195                let sparse: SparseArrayValue = value.into();
196                let default = self.bv_lit(&sparse.default());
197                let base = self.array_const(default, sparse.index_width());
198                sparse
199                    .non_default_entries()
200                    .fold(base, |array, (index, data)| {
201                        let index = self.bv_lit(&index);
202                        let data = self.bv_lit(&data);
203                        self.array_store(array, index, data)
204                    })
205            }
206        }
207    }
208    pub fn bv_lit<'a>(&mut self, value: impl Into<BitVecValueRef<'a>>) -> ExprRef {
209        let index = self.values.get_index(value);
210        self.add_expr(Expr::BVLiteral(BVLitValue::new(index)))
211    }
212    pub fn bit_vec_val(
213        &mut self,
214        value: impl TryInto<u128>,
215        width: impl TryInto<WidthInt>,
216    ) -> ExprRef {
217        let (value, width) = match (value.try_into(), width.try_into()) {
218            (Ok(value), Ok(width)) => (value, width),
219            _ => panic!("failed to convert value or width! Both must be positive!"),
220        };
221        let value = BitVecValue::from_u128(value, width);
222        self.bv_lit(&value)
223    }
224    pub fn zero(&mut self, width: WidthInt) -> ExprRef {
225        self.bv_lit(&BitVecValue::zero(width))
226    }
227
228    pub fn zero_array(&mut self, tpe: ArrayType) -> ExprRef {
229        let data = self.zero(tpe.data_width);
230        self.array_const(data, tpe.index_width)
231    }
232
233    pub fn get_true(&self) -> ExprRef {
234        self.true_expr_ref
235    }
236
237    pub fn get_false(&self) -> ExprRef {
238        self.false_expr_ref
239    }
240
241    pub fn one(&mut self, width: WidthInt) -> ExprRef {
242        self.bv_lit(&BitVecValue::from_u64(1, width))
243    }
244    pub fn ones(&mut self, width: WidthInt) -> ExprRef {
245        self.bv_lit(&BitVecValue::ones(width))
246    }
247
248    pub fn distinct(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
249        let is_eq = self.equal(a, b);
250        self.not(is_eq)
251    }
252    pub fn equal(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
253        debug_assert_eq!(a.get_type(self), b.get_type(self));
254        if a.get_type(self).is_bit_vector() {
255            self.add_expr(Expr::BVEqual(a, b))
256        } else {
257            self.add_expr(Expr::ArrayEqual(a, b))
258        }
259    }
260    pub fn ite(&mut self, cond: ExprRef, tru: ExprRef, fals: ExprRef) -> ExprRef {
261        debug_assert_eq!(cond.get_bv_type(self).unwrap(), 1);
262        debug_assert_eq!(tru.get_type(self), fals.get_type(self));
263        if tru.get_type(self).is_bit_vector() {
264            self.add_expr(Expr::BVIte { cond, tru, fals })
265        } else {
266            self.add_expr(Expr::ArrayIte { cond, tru, fals })
267        }
268    }
269    pub fn implies(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
270        debug_assert_eq!(a.get_bv_type(self).unwrap(), 1);
271        debug_assert_eq!(b.get_bv_type(self).unwrap(), 1);
272        self.add_expr(Expr::BVImplies(a, b))
273    }
274    pub fn greater_signed(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
275        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
276        self.add_expr(Expr::BVGreaterSigned(a, b, b.get_bv_type(self).unwrap()))
277    }
278
279    pub fn greater(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
280        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
281        self.add_expr(Expr::BVGreater(a, b))
282    }
283    pub fn greater_or_equal_signed(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
284        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
285        self.add_expr(Expr::BVGreaterEqualSigned(
286            a,
287            b,
288            b.get_bv_type(self).unwrap(),
289        ))
290    }
291
292    pub fn greater_or_equal(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
293        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
294        self.add_expr(Expr::BVGreaterEqual(a, b))
295    }
296    pub fn not(&mut self, e: ExprRef) -> ExprRef {
297        debug_assert!(e.get_type(self).is_bit_vector());
298        self.add_expr(Expr::BVNot(e, e.get_bv_type(self).unwrap()))
299    }
300    pub fn negate(&mut self, e: ExprRef) -> ExprRef {
301        debug_assert!(e.get_type(self).is_bit_vector());
302        self.add_expr(Expr::BVNegate(e, e.get_bv_type(self).unwrap()))
303    }
304    pub fn and(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
305        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
306        self.add_expr(Expr::BVAnd(a, b, b.get_bv_type(self).unwrap()))
307    }
308    pub fn or(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
309        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
310        self.add_expr(Expr::BVOr(a, b, b.get_bv_type(self).unwrap()))
311    }
312    pub fn xor(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
313        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
314        self.add_expr(Expr::BVXor(a, b, b.get_bv_type(self).unwrap()))
315    }
316
317    pub fn xor3(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
318        let x = self.xor(a, b);
319        self.xor(x, c)
320    }
321
322    pub fn majority(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
323        let a_and_b = self.and(a, b);
324        let a_and_c = self.and(a, c);
325        let b_and_c = self.and(b, c);
326        let x = self.or(a_and_b, a_and_c);
327        self.or(x, b_and_c)
328    }
329
330    pub fn shift_left(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
331        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
332        self.add_expr(Expr::BVShiftLeft(a, b, b.get_bv_type(self).unwrap()))
333    }
334    pub fn arithmetic_shift_right(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
335        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
336        self.add_expr(Expr::BVArithmeticShiftRight(
337            a,
338            b,
339            b.get_bv_type(self).unwrap(),
340        ))
341    }
342    pub fn shift_right(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
343        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
344        self.add_expr(Expr::BVShiftRight(a, b, b.get_bv_type(self).unwrap()))
345    }
346    pub fn add(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
347        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
348        self.add_expr(Expr::BVAdd(a, b, b.get_bv_type(self).unwrap()))
349    }
350    pub fn sub(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
351        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
352        self.add_expr(Expr::BVSub(a, b, b.get_bv_type(self).unwrap()))
353    }
354    pub fn mul(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
355        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
356        self.add_expr(Expr::BVMul(a, b, b.get_bv_type(self).unwrap()))
357    }
358    pub fn div(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
359        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
360        self.add_expr(Expr::BVUnsignedDiv(a, b, b.get_bv_type(self).unwrap()))
361    }
362    pub fn signed_div(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
363        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
364        self.add_expr(Expr::BVSignedDiv(a, b, b.get_bv_type(self).unwrap()))
365    }
366    pub fn signed_mod(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
367        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
368        self.add_expr(Expr::BVSignedMod(a, b, b.get_bv_type(self).unwrap()))
369    }
370    pub fn signed_remainder(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
371        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
372        self.add_expr(Expr::BVSignedRem(a, b, b.get_bv_type(self).unwrap()))
373    }
374    pub fn remainder(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
375        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
376        self.add_expr(Expr::BVUnsignedRem(a, b, b.get_bv_type(self).unwrap()))
377    }
378    pub fn concat(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
379        debug_assert!(a.get_type(self).is_bit_vector());
380        debug_assert!(b.get_type(self).is_bit_vector());
381        let width = a.get_bv_type(self).unwrap() + b.get_bv_type(self).unwrap();
382        self.add_expr(Expr::BVConcat(a, b, width))
383    }
384    pub fn slice(&mut self, e: ExprRef, hi: WidthInt, lo: WidthInt) -> ExprRef {
385        if lo == 0 && hi + 1 == e.get_bv_type(self).unwrap() {
386            e
387        } else {
388            assert!(hi >= lo, "{hi} < {lo} ... not allowed!");
389            self.add_expr(Expr::BVSlice { e, hi, lo })
390        }
391    }
392    pub fn zero_extend(&mut self, e: ExprRef, by: WidthInt) -> ExprRef {
393        if by == 0 {
394            e
395        } else {
396            let width = e.get_bv_type(self).unwrap() + by;
397            self.add_expr(Expr::BVZeroExt { e, by, width })
398        }
399    }
400    pub fn sign_extend(&mut self, e: ExprRef, by: WidthInt) -> ExprRef {
401        if by == 0 {
402            e
403        } else {
404            let width = e.get_bv_type(self).unwrap() + by;
405            self.add_expr(Expr::BVSignExt { e, by, width })
406        }
407    }
408
409    /// Sign or zero extends depending on the value of `signed`.
410    pub fn extend(&mut self, e: ExprRef, by: WidthInt, signed: bool) -> ExprRef {
411        if signed {
412            self.sign_extend(e, by)
413        } else {
414            self.zero_extend(e, by)
415        }
416    }
417
418    pub fn array_store(&mut self, array: ExprRef, index: ExprRef, data: ExprRef) -> ExprRef {
419        self.add_expr(Expr::ArrayStore { array, index, data })
420    }
421
422    pub fn array_const(&mut self, e: ExprRef, index_width: WidthInt) -> ExprRef {
423        let data_width = e.get_bv_type(self).unwrap();
424        self.add_expr(Expr::ArrayConstant {
425            e,
426            index_width,
427            data_width,
428        })
429    }
430
431    pub fn array_read(&mut self, array: ExprRef, index: ExprRef) -> ExprRef {
432        let width = array.get_type(self).get_array_data_width().unwrap();
433        self.add_expr(Expr::BVArrayRead {
434            array,
435            index,
436            width,
437        })
438    }
439
440    pub fn build(&mut self, foo: impl FnOnce(Builder) -> ExprRef) -> ExprRef {
441        let builder = Builder::new(self);
442        foo(builder)
443    }
444}
445
446/// Makes it possible to build up expressions while using dynamically checked borrowing rules
447/// to work around a shortcoming of the Rust borrow checker.
448/// Thus, with a builder you will be able to build up nested expressions easily!
449pub struct Builder<'a> {
450    ctx: RefCell<&'a mut Context>,
451}
452
453impl<'a> Builder<'a> {
454    fn new(ctx: &'a mut Context) -> Self {
455        Self {
456            ctx: RefCell::new(ctx),
457        }
458    }
459}
460
461impl<'a> Builder<'a> {
462    pub fn bv_symbol(&self, name: &str, width: WidthInt) -> ExprRef {
463        self.ctx.borrow_mut().bv_symbol(name, width)
464    }
465    pub fn symbol(&self, name: StringRef, tpe: Type) -> ExprRef {
466        self.ctx.borrow_mut().symbol(name, tpe)
467    }
468    pub fn bv_lit<'b>(&self, value: impl Into<BitVecValueRef<'b>>) -> ExprRef {
469        self.ctx.borrow_mut().bv_lit(value)
470    }
471    pub fn bit_vec_val(&self, value: impl TryInto<u128>, width: impl TryInto<WidthInt>) -> ExprRef {
472        self.ctx.borrow_mut().bit_vec_val(value, width)
473    }
474    pub fn zero(&self, width: WidthInt) -> ExprRef {
475        self.ctx.borrow_mut().zero(width)
476    }
477
478    pub fn get_true(&self) -> ExprRef {
479        self.ctx.borrow().get_true()
480    }
481
482    pub fn get_false(&self) -> ExprRef {
483        self.ctx.borrow().get_false()
484    }
485
486    pub fn zero_array(&self, tpe: ArrayType) -> ExprRef {
487        self.ctx.borrow_mut().zero_array(tpe)
488    }
489
490    pub fn one(&self, width: WidthInt) -> ExprRef {
491        self.ctx.borrow_mut().one(width)
492    }
493    pub fn ones(&self, width: WidthInt) -> ExprRef {
494        self.ctx.borrow_mut().ones(width)
495    }
496    pub fn equal(&self, a: ExprRef, b: ExprRef) -> ExprRef {
497        self.ctx.borrow_mut().equal(a, b)
498    }
499    pub fn ite(&self, cond: ExprRef, tru: ExprRef, fals: ExprRef) -> ExprRef {
500        self.ctx.borrow_mut().ite(cond, tru, fals)
501    }
502    pub fn implies(&self, a: ExprRef, b: ExprRef) -> ExprRef {
503        self.ctx.borrow_mut().implies(a, b)
504    }
505    pub fn greater_signed(&self, a: ExprRef, b: ExprRef) -> ExprRef {
506        self.ctx.borrow_mut().greater_signed(a, b)
507    }
508
509    pub fn greater(&self, a: ExprRef, b: ExprRef) -> ExprRef {
510        self.ctx.borrow_mut().greater(a, b)
511    }
512    pub fn greater_or_equal_signed(&self, a: ExprRef, b: ExprRef) -> ExprRef {
513        self.ctx.borrow_mut().greater_or_equal_signed(a, b)
514    }
515
516    pub fn greater_or_equal(&self, a: ExprRef, b: ExprRef) -> ExprRef {
517        self.ctx.borrow_mut().greater_or_equal(a, b)
518    }
519    pub fn not(&self, e: ExprRef) -> ExprRef {
520        self.ctx.borrow_mut().not(e)
521    }
522    pub fn negate(&self, e: ExprRef) -> ExprRef {
523        self.ctx.borrow_mut().negate(e)
524    }
525    pub fn and(&self, a: ExprRef, b: ExprRef) -> ExprRef {
526        self.ctx.borrow_mut().and(a, b)
527    }
528    pub fn or(&self, a: ExprRef, b: ExprRef) -> ExprRef {
529        self.ctx.borrow_mut().or(a, b)
530    }
531    pub fn xor(&self, a: ExprRef, b: ExprRef) -> ExprRef {
532        self.ctx.borrow_mut().xor(a, b)
533    }
534    pub fn xor3(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
535        self.ctx.borrow_mut().xor3(a, b, c)
536    }
537    pub fn majority(&mut self, a: ExprRef, b: ExprRef, c: ExprRef) -> ExprRef {
538        self.ctx.borrow_mut().majority(a, b, c)
539    }
540    pub fn shift_left(&self, a: ExprRef, b: ExprRef) -> ExprRef {
541        self.ctx.borrow_mut().shift_left(a, b)
542    }
543    pub fn arithmetic_shift_right(&self, a: ExprRef, b: ExprRef) -> ExprRef {
544        self.ctx.borrow_mut().arithmetic_shift_right(a, b)
545    }
546    pub fn shift_right(&self, a: ExprRef, b: ExprRef) -> ExprRef {
547        self.ctx.borrow_mut().shift_right(a, b)
548    }
549    pub fn add(&self, a: ExprRef, b: ExprRef) -> ExprRef {
550        self.ctx.borrow_mut().add(a, b)
551    }
552    pub fn sub(&self, a: ExprRef, b: ExprRef) -> ExprRef {
553        self.ctx.borrow_mut().sub(a, b)
554    }
555    pub fn mul(&self, a: ExprRef, b: ExprRef) -> ExprRef {
556        self.ctx.borrow_mut().mul(a, b)
557    }
558    pub fn div(&self, a: ExprRef, b: ExprRef) -> ExprRef {
559        self.ctx.borrow_mut().div(a, b)
560    }
561    pub fn signed_div(&self, a: ExprRef, b: ExprRef) -> ExprRef {
562        self.ctx.borrow_mut().signed_div(a, b)
563    }
564    pub fn signed_mod(&self, a: ExprRef, b: ExprRef) -> ExprRef {
565        self.ctx.borrow_mut().signed_mod(a, b)
566    }
567    pub fn signed_remainder(&self, a: ExprRef, b: ExprRef) -> ExprRef {
568        self.ctx.borrow_mut().signed_remainder(a, b)
569    }
570    pub fn remainder(&self, a: ExprRef, b: ExprRef) -> ExprRef {
571        self.ctx.borrow_mut().remainder(a, b)
572    }
573    pub fn concat(&self, a: ExprRef, b: ExprRef) -> ExprRef {
574        self.ctx.borrow_mut().concat(a, b)
575    }
576    pub fn slice(&self, e: ExprRef, hi: WidthInt, lo: WidthInt) -> ExprRef {
577        self.ctx.borrow_mut().slice(e, hi, lo)
578    }
579    pub fn zero_extend(&self, e: ExprRef, by: WidthInt) -> ExprRef {
580        self.ctx.borrow_mut().zero_extend(e, by)
581    }
582    pub fn sign_extend(&self, e: ExprRef, by: WidthInt) -> ExprRef {
583        self.ctx.borrow_mut().sign_extend(e, by)
584    }
585
586    /// Sign or zero extends depending on the value of `signed`.
587    pub fn extend(&mut self, e: ExprRef, by: WidthInt, signed: bool) -> ExprRef {
588        self.ctx.borrow_mut().extend(e, by, signed)
589    }
590
591    pub fn array_store(&self, array: ExprRef, index: ExprRef, data: ExprRef) -> ExprRef {
592        self.ctx.borrow_mut().array_store(array, index, data)
593    }
594
595    pub fn array_const(&self, e: ExprRef, index_width: WidthInt) -> ExprRef {
596        self.ctx.borrow_mut().array_const(e, index_width)
597    }
598
599    pub fn array_read(&self, array: ExprRef, index: ExprRef) -> ExprRef {
600        self.ctx.borrow_mut().array_read(array, index)
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::SerializableIrNode;

    /// Both reference types are expected to occupy exactly four bytes.
    #[test]
    fn ir_type_size() {
        assert_eq!(std::mem::size_of::<StringRef>(), 4);
        assert_eq!(std::mem::size_of::<ExprRef>(), 4);
    }

    /// Expressions are interned and receive consecutive ids.
    #[test]
    fn reference_ids() {
        let mut ctx = Context::default();

        // ids 1 and 2 are reserved for false and true
        assert_eq!(ctx.get_false().0.get(), 1);
        assert_eq!(ctx.get_true().0.get(), 2);

        let name = ctx.string("a".into());
        let symbol_of_width = |width| Expr::BVSymbol { name, width };

        let first = ctx.add_expr(symbol_of_width(1));
        assert_eq!(first.0.get(), 3, "ids start at three (for now)");

        // adding the same expression again must not create a new id
        let first_again = ctx.add_expr(symbol_of_width(1));
        assert_eq!(first.0, first_again.0, "ids should be interned!");

        // a different width makes it a different expression
        let second = ctx.add_expr(symbol_of_width(2));
        assert_eq!(first.0.get() + 1, second.0.get(), "ids should increment!");
    }

    /// make sure that we can intern a lot of strings before running out of IDs
    #[test]
    fn intern_lots_of_strings() {
        let mut ctx = Context::default();
        // one less than 2^16, since we lose 1 ID because 0 is not a valid ID value
        let max_strings = (1u64 << 16) - 1;
        for ii in 0..max_strings {
            let value = format!("{ii}AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA");
            let _id = ctx.string(value.into());
        }
        // after interning that many strings, "adding" a string that is already part of the
        // context must still return the original reference
        let first = "0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
        let first_ref = ctx.string(first.into());
        assert_eq!(first_ref.index(), 0);
    }

    /// Nested expressions can be created through the builder.
    #[test]
    fn test_builder() {
        let mut ctx = Context::default();
        let expr = ctx.build(|b| {
            let lhs = b.bv_symbol("a", 1);
            let rhs = b.bv_symbol("b", 1);
            b.and(lhs, rhs)
        });
        assert_eq!(expr.serialize_to_str(&ctx), "and(a, b)");
    }

    /// Smoke test: creating a 128-bit literal from integers must not panic.
    #[test]
    fn test_bit_vec_val() {
        let mut ctx = Context::default();
        let _v0 = ctx.bit_vec_val(1, 128);
    }
}
664    #[test]
665    fn test_bit_vec_val() {
666        let mut ctx = Context::default();
667        let _v0 = ctx.bit_vec_val(1, 128);
668    }
669}