patronus/expr/
context.rs

1// Copyright 2023 The Regents of the University of California
2// Copyright 2024 Cornell University
3// released under BSD 3-Clause License
4// author: Kevin Laeufer <laeufer@cornell.edu>
5
6//! # IR Context
7//!
8//! The [`Context`] is used to create and access bit-vector and array expressions.
9//! It ensures that the same expression always maps to the same expression reference.
10//! Thus, if two references are equal, we can be certain that the expressions they point to are
11//! equivalent.
12//!
13//! Users are expected to generally use a single Context for all their expressions. There
//! are no checks to ensure that an [`ExprRef`] or [`StringRef`] created by one context is never
//! used with another. Thus working with more than one [`Context`] object can be dangerous.
16
17use crate::expr::TypeCheck;
18use crate::expr::nodes::*;
19use baa::{
20    ArrayOps, BitVecValue, BitVecValueIndex, BitVecValueRef, IndexToRef, SparseArrayValue, Value,
21};
22use rustc_hash::FxBuildHasher;
23use std::borrow::Borrow;
24use std::cell::RefCell;
25use std::fmt::{Debug, Formatter};
26use std::num::NonZeroU32;
27use std::ops::Index;
28
/// Uniquely identifies a [`String`] stored in a [`Context`].
///
/// The zero-based index into the string table is stored off by one inside a [`NonZeroU32`],
/// so a raw value of 0 never denotes a valid reference.
#[derive(PartialEq, Eq, Clone, Copy, Hash, PartialOrd, Ord)]
pub struct StringRef(NonZeroU32);

impl Debug for StringRef {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // show the zero-based index rather than the raw, off-by-one value
        write!(f, "StringRef({})", self.index())
    }
}

impl StringRef {
    /// Creates a reference from a zero-based index into the string table.
    ///
    /// # Panics
    /// Panics if `index + 1` does not fit into a `u32`. A truncating `as u32` cast would
    /// instead wrap around and silently alias an unrelated string.
    fn from_index(index: usize) -> Self {
        let raw = index
            .checked_add(1)
            .and_then(|one_based| u32::try_from(one_based).ok())
            .and_then(NonZeroU32::new)
            .expect("StringRef index does not fit into 32 bits");
        Self(raw)
    }

    /// Returns the zero-based index into the string table.
    fn index(&self) -> usize {
        (self.0.get() - 1) as usize
    }
}
48
/// Uniquely identifies an [`Expr`] stored in a [`Context`].
///
/// The zero-based intern index is stored off by one inside a [`NonZeroU32`], so a raw value
/// of 0 never denotes a valid reference.
#[derive(PartialEq, Eq, Clone, Copy, Hash, Ord, PartialOrd)]
pub struct ExprRef(NonZeroU32);

impl Debug for ExprRef {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // we need a custom implementation in order to show the zero based index
        write!(f, "ExprRef({})", self.index())
    }
}

impl ExprRef {
    /// Creates a reference from a zero-based intern index.
    ///
    /// # Panics
    /// Panics if `index + 1` does not fit into a `u32`. A truncating `as u32` cast would
    /// instead wrap around and silently alias an unrelated expression.
    // TODO: reduce visibility to pub(crate)
    pub fn from_index(index: usize) -> Self {
        let raw = index
            .checked_add(1)
            .and_then(|one_based| u32::try_from(one_based).ok())
            .and_then(NonZeroU32::new)
            .expect("ExprRef index does not fit into 32 bits");
        ExprRef(raw)
    }

    /// Returns the zero-based intern index.
    pub(crate) fn index(&self) -> usize {
        (self.0.get() - 1) as usize
    }
}
70
71/// Context which is used to create all SMT expressions. Expressions are interned such that
72/// reference equivalence implies structural equivalence.
73#[derive(Clone)]
74pub struct Context {
75    strings: indexmap::IndexSet<String, FxBuildHasher>,
76    exprs: indexmap::IndexSet<Expr, FxBuildHasher>,
77    values: baa::ValueInterner,
78    // cached special values
79    true_expr_ref: ExprRef,
80    false_expr_ref: ExprRef,
81}
82
83impl Default for Context {
84    // TODO: should probably rename this to "new" at some point.
85    fn default() -> Self {
86        let mut out = Self {
87            strings: Default::default(),
88            exprs: Default::default(),
89            values: Default::default(),
90            true_expr_ref: ExprRef::from_index(0),
91            false_expr_ref: ExprRef::from_index(0),
92        };
93        // create valid cached expressions
94        out.false_expr_ref = out.zero(1);
95        out.true_expr_ref = out.one(1);
96        out
97    }
98}
99
100/// Adding and removing nodes.
101impl Context {
102    pub fn get_symbol_name(&self, reference: ExprRef) -> Option<&str> {
103        self[reference].get_symbol_name(self)
104    }
105
106    pub(crate) fn add_expr(&mut self, value: Expr) -> ExprRef {
107        let (index, _) = self.exprs.insert_full(value);
108        ExprRef::from_index(index)
109    }
110
111    pub fn string(&mut self, value: std::borrow::Cow<str>) -> StringRef {
112        if let Some(index) = self.strings.get_index_of(value.as_ref()) {
113            StringRef::from_index(index)
114        } else {
115            let (index, _) = self.strings.insert_full(value.into_owned());
116            StringRef::from_index(index)
117        }
118    }
119
120    pub(crate) fn get_bv_value(&self, index: impl Borrow<BitVecValueIndex>) -> BitVecValueRef<'_> {
121        self.values.words().get_ref(index)
122    }
123}
124
125impl Index<ExprRef> for Context {
126    type Output = Expr;
127
128    fn index(&self, index: ExprRef) -> &Self::Output {
129        self.exprs
130            .get_index(index.index())
131            .expect("Invalid ExprRef!")
132    }
133}
134
135impl Index<StringRef> for Context {
136    type Output = String;
137
138    fn index(&self, index: StringRef) -> &Self::Output {
139        self.strings
140            .get_index(index.index())
141            .expect("Invalid StringRef!")
142    }
143}
144
145impl Context {
146    /// Returns the number of interned expressions in this context.
147    pub fn num_exprs(&self) -> usize {
148        self.exprs.len()
149    }
150
151    /// Returns a reference to the expression for the given reference.
152    /// Panics if the reference is invalid (use indices in range 0..num_exprs()).
153    pub fn get_expr(&self, r: ExprRef) -> &Expr {
154        &self[r]
155    }
156
157    /// Returns the zero-based intern index of the given expression reference.
158    pub fn expr_index(&self, r: ExprRef) -> usize {
159        r.index()
160    }
161}
162
163/// Convenience methods to construct IR nodes.
164impl Context {
165    // helper functions to construct expressions
166    pub fn bv_symbol(&mut self, name: &str, width: WidthInt) -> ExprRef {
167        assert!(width > 0, "0-bit bitvectors are not allowed");
168        let name_ref = self.string(name.into());
169        self.add_expr(Expr::BVSymbol {
170            name: name_ref,
171            width,
172        })
173    }
174
175    pub fn array_symbol(
176        &mut self,
177        name: &str,
178        index_width: WidthInt,
179        data_width: WidthInt,
180    ) -> ExprRef {
181        assert!(index_width > 0, "0-bit bitvectors are not allowed");
182        assert!(data_width > 0, "0-bit bitvectors are not allowed");
183        let name_ref = self.string(name.into());
184        self.add_expr(Expr::ArraySymbol {
185            name: name_ref,
186            index_width,
187            data_width,
188        })
189    }
190    pub fn symbol(&mut self, name: StringRef, tpe: Type) -> ExprRef {
191        assert_ne!(tpe, Type::BV(0), "0-bit bitvectors are not allowed");
192        self.add_expr(Expr::symbol(name, tpe))
193    }
194    pub fn lit(&mut self, value: impl Borrow<Value>) -> ExprRef {
195        match value.borrow() {
196            Value::BitVec(value) => self.bv_lit(value),
197            Value::Array(value) => {
198                let sparse: SparseArrayValue = value.into();
199                let default = self.bv_lit(&sparse.default());
200                let base = self.array_const(default, sparse.index_width());
201                sparse
202                    .non_default_entries()
203                    .fold(base, |array, (index, data)| {
204                        let index = self.bv_lit(&index);
205                        let data = self.bv_lit(&data);
206                        self.array_store(array, index, data)
207                    })
208            }
209        }
210    }
211    pub fn bv_lit<'a>(&mut self, value: impl Into<BitVecValueRef<'a>>) -> ExprRef {
212        let index = self.values.get_index(value);
213        self.add_expr(Expr::BVLiteral(BVLitValue::new(index)))
214    }
215    pub fn bit_vec_val(
216        &mut self,
217        value: impl TryInto<u128>,
218        width: impl TryInto<WidthInt>,
219    ) -> ExprRef {
220        let (value, width) = match (value.try_into(), width.try_into()) {
221            (Ok(value), Ok(width)) => (value, width),
222            _ => panic!("failed to convert value or width! Both must be positive!"),
223        };
224        let value = BitVecValue::from_u128(value, width);
225        self.bv_lit(&value)
226    }
227    pub fn zero(&mut self, width: WidthInt) -> ExprRef {
228        self.bv_lit(&BitVecValue::zero(width))
229    }
230
231    pub fn zero_array(&mut self, tpe: ArrayType) -> ExprRef {
232        let data = self.zero(tpe.data_width);
233        self.array_const(data, tpe.index_width)
234    }
235
236    pub fn get_true(&self) -> ExprRef {
237        self.true_expr_ref
238    }
239
240    pub fn get_false(&self) -> ExprRef {
241        self.false_expr_ref
242    }
243
244    pub fn one(&mut self, width: WidthInt) -> ExprRef {
245        self.bv_lit(&BitVecValue::from_u64(1, width))
246    }
247    pub fn ones(&mut self, width: WidthInt) -> ExprRef {
248        self.bv_lit(&BitVecValue::ones(width))
249    }
250    pub fn equal(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
251        debug_assert_eq!(a.get_type(self), b.get_type(self));
252        if a.get_type(self).is_bit_vector() {
253            self.add_expr(Expr::BVEqual(a, b))
254        } else {
255            self.add_expr(Expr::ArrayEqual(a, b))
256        }
257    }
258    pub fn ite(&mut self, cond: ExprRef, tru: ExprRef, fals: ExprRef) -> ExprRef {
259        debug_assert_eq!(cond.get_bv_type(self).unwrap(), 1);
260        debug_assert_eq!(tru.get_type(self), fals.get_type(self));
261        if tru.get_type(self).is_bit_vector() {
262            self.add_expr(Expr::BVIte { cond, tru, fals })
263        } else {
264            self.add_expr(Expr::ArrayIte { cond, tru, fals })
265        }
266    }
267    pub fn implies(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
268        debug_assert_eq!(a.get_bv_type(self).unwrap(), 1);
269        debug_assert_eq!(b.get_bv_type(self).unwrap(), 1);
270        self.add_expr(Expr::BVImplies(a, b))
271    }
272    pub fn greater_signed(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
273        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
274        self.add_expr(Expr::BVGreaterSigned(a, b, b.get_bv_type(self).unwrap()))
275    }
276
277    pub fn greater(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
278        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
279        self.add_expr(Expr::BVGreater(a, b))
280    }
281    pub fn greater_or_equal_signed(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
282        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
283        self.add_expr(Expr::BVGreaterEqualSigned(
284            a,
285            b,
286            b.get_bv_type(self).unwrap(),
287        ))
288    }
289
290    pub fn greater_or_equal(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
291        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
292        self.add_expr(Expr::BVGreaterEqual(a, b))
293    }
294    pub fn not(&mut self, e: ExprRef) -> ExprRef {
295        debug_assert!(e.get_type(self).is_bit_vector());
296        self.add_expr(Expr::BVNot(e, e.get_bv_type(self).unwrap()))
297    }
298    pub fn negate(&mut self, e: ExprRef) -> ExprRef {
299        debug_assert!(e.get_type(self).is_bit_vector());
300        self.add_expr(Expr::BVNegate(e, e.get_bv_type(self).unwrap()))
301    }
302    pub fn and(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
303        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
304        self.add_expr(Expr::BVAnd(a, b, b.get_bv_type(self).unwrap()))
305    }
306    pub fn or(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
307        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
308        self.add_expr(Expr::BVOr(a, b, b.get_bv_type(self).unwrap()))
309    }
310    pub fn xor(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
311        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
312        self.add_expr(Expr::BVXor(a, b, b.get_bv_type(self).unwrap()))
313    }
314    pub fn shift_left(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
315        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
316        self.add_expr(Expr::BVShiftLeft(a, b, b.get_bv_type(self).unwrap()))
317    }
318    pub fn arithmetic_shift_right(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
319        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
320        self.add_expr(Expr::BVArithmeticShiftRight(
321            a,
322            b,
323            b.get_bv_type(self).unwrap(),
324        ))
325    }
326    pub fn shift_right(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
327        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
328        self.add_expr(Expr::BVShiftRight(a, b, b.get_bv_type(self).unwrap()))
329    }
330    pub fn add(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
331        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
332        self.add_expr(Expr::BVAdd(a, b, b.get_bv_type(self).unwrap()))
333    }
334    pub fn sub(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
335        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
336        self.add_expr(Expr::BVSub(a, b, b.get_bv_type(self).unwrap()))
337    }
338    pub fn mul(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
339        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
340        self.add_expr(Expr::BVMul(a, b, b.get_bv_type(self).unwrap()))
341    }
342    pub fn div(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
343        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
344        self.add_expr(Expr::BVUnsignedDiv(a, b, b.get_bv_type(self).unwrap()))
345    }
346    pub fn signed_div(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
347        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
348        self.add_expr(Expr::BVSignedDiv(a, b, b.get_bv_type(self).unwrap()))
349    }
350    pub fn signed_mod(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
351        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
352        self.add_expr(Expr::BVSignedMod(a, b, b.get_bv_type(self).unwrap()))
353    }
354    pub fn signed_remainder(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
355        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
356        self.add_expr(Expr::BVSignedRem(a, b, b.get_bv_type(self).unwrap()))
357    }
358    pub fn remainder(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
359        debug_assert_eq!(a.get_bv_type(self).unwrap(), b.get_bv_type(self).unwrap());
360        self.add_expr(Expr::BVUnsignedRem(a, b, b.get_bv_type(self).unwrap()))
361    }
362    pub fn concat(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
363        debug_assert!(a.get_type(self).is_bit_vector());
364        debug_assert!(b.get_type(self).is_bit_vector());
365        let width = a.get_bv_type(self).unwrap() + b.get_bv_type(self).unwrap();
366        self.add_expr(Expr::BVConcat(a, b, width))
367    }
368    pub fn slice(&mut self, e: ExprRef, hi: WidthInt, lo: WidthInt) -> ExprRef {
369        if lo == 0 && hi + 1 == e.get_bv_type(self).unwrap() {
370            e
371        } else {
372            assert!(hi >= lo, "{hi} < {lo} ... not allowed!");
373            self.add_expr(Expr::BVSlice { e, hi, lo })
374        }
375    }
376    pub fn zero_extend(&mut self, e: ExprRef, by: WidthInt) -> ExprRef {
377        if by == 0 {
378            e
379        } else {
380            let width = e.get_bv_type(self).unwrap() + by;
381            self.add_expr(Expr::BVZeroExt { e, by, width })
382        }
383    }
384    pub fn sign_extend(&mut self, e: ExprRef, by: WidthInt) -> ExprRef {
385        if by == 0 {
386            e
387        } else {
388            let width = e.get_bv_type(self).unwrap() + by;
389            self.add_expr(Expr::BVSignExt { e, by, width })
390        }
391    }
392
393    /// Sign or zero extends depending on the value of `signed`.
394    pub fn extend(&mut self, e: ExprRef, by: WidthInt, signed: bool) -> ExprRef {
395        if signed {
396            self.sign_extend(e, by)
397        } else {
398            self.zero_extend(e, by)
399        }
400    }
401
402    pub fn array_store(&mut self, array: ExprRef, index: ExprRef, data: ExprRef) -> ExprRef {
403        self.add_expr(Expr::ArrayStore { array, index, data })
404    }
405
406    pub fn array_const(&mut self, e: ExprRef, index_width: WidthInt) -> ExprRef {
407        let data_width = e.get_bv_type(self).unwrap();
408        self.add_expr(Expr::ArrayConstant {
409            e,
410            index_width,
411            data_width,
412        })
413    }
414
415    pub fn array_read(&mut self, array: ExprRef, index: ExprRef) -> ExprRef {
416        let width = array.get_type(self).get_array_data_width().unwrap();
417        self.add_expr(Expr::BVArrayRead {
418            array,
419            index,
420            width,
421        })
422    }
423
424    pub fn build(&mut self, foo: impl FnOnce(Builder) -> ExprRef) -> ExprRef {
425        let builder = Builder::new(self);
426        foo(builder)
427    }
428}
429
430/// Makes it possible to build up expressions while using dynamically checked borrowing rules
431/// to work around a shortcoming of the Rust borrow checker.
432/// Thus, with a builder you will be able to build up nested expressions easily!
433pub struct Builder<'a> {
434    ctx: RefCell<&'a mut Context>,
435}
436
437impl<'a> Builder<'a> {
438    fn new(ctx: &'a mut Context) -> Self {
439        Self {
440            ctx: RefCell::new(ctx),
441        }
442    }
443}
444
445impl<'a> Builder<'a> {
446    pub fn bv_symbol(&self, name: &str, width: WidthInt) -> ExprRef {
447        self.ctx.borrow_mut().bv_symbol(name, width)
448    }
449    pub fn symbol(&self, name: StringRef, tpe: Type) -> ExprRef {
450        self.ctx.borrow_mut().symbol(name, tpe)
451    }
452    pub fn bv_lit<'b>(&self, value: impl Into<BitVecValueRef<'b>>) -> ExprRef {
453        self.ctx.borrow_mut().bv_lit(value)
454    }
455    pub fn bit_vec_val(&self, value: impl TryInto<u128>, width: impl TryInto<WidthInt>) -> ExprRef {
456        self.ctx.borrow_mut().bit_vec_val(value, width)
457    }
458    pub fn zero(&self, width: WidthInt) -> ExprRef {
459        self.ctx.borrow_mut().zero(width)
460    }
461
462    pub fn get_true(&self) -> ExprRef {
463        self.ctx.borrow().get_true()
464    }
465
466    pub fn get_false(&self) -> ExprRef {
467        self.ctx.borrow().get_false()
468    }
469
470    pub fn zero_array(&self, tpe: ArrayType) -> ExprRef {
471        self.ctx.borrow_mut().zero_array(tpe)
472    }
473
474    pub fn one(&self, width: WidthInt) -> ExprRef {
475        self.ctx.borrow_mut().one(width)
476    }
477    pub fn ones(&self, width: WidthInt) -> ExprRef {
478        self.ctx.borrow_mut().ones(width)
479    }
480    pub fn equal(&self, a: ExprRef, b: ExprRef) -> ExprRef {
481        self.ctx.borrow_mut().equal(a, b)
482    }
483    pub fn ite(&self, cond: ExprRef, tru: ExprRef, fals: ExprRef) -> ExprRef {
484        self.ctx.borrow_mut().ite(cond, tru, fals)
485    }
486    pub fn implies(&self, a: ExprRef, b: ExprRef) -> ExprRef {
487        self.ctx.borrow_mut().implies(a, b)
488    }
489    pub fn greater_signed(&self, a: ExprRef, b: ExprRef) -> ExprRef {
490        self.ctx.borrow_mut().greater_signed(a, b)
491    }
492
493    pub fn greater(&self, a: ExprRef, b: ExprRef) -> ExprRef {
494        self.ctx.borrow_mut().greater(a, b)
495    }
496    pub fn greater_or_equal_signed(&self, a: ExprRef, b: ExprRef) -> ExprRef {
497        self.ctx.borrow_mut().greater_or_equal_signed(a, b)
498    }
499
500    pub fn greater_or_equal(&self, a: ExprRef, b: ExprRef) -> ExprRef {
501        self.ctx.borrow_mut().greater_or_equal(a, b)
502    }
503    pub fn not(&self, e: ExprRef) -> ExprRef {
504        self.ctx.borrow_mut().not(e)
505    }
506    pub fn negate(&self, e: ExprRef) -> ExprRef {
507        self.ctx.borrow_mut().negate(e)
508    }
509    pub fn and(&self, a: ExprRef, b: ExprRef) -> ExprRef {
510        self.ctx.borrow_mut().and(a, b)
511    }
512    pub fn or(&self, a: ExprRef, b: ExprRef) -> ExprRef {
513        self.ctx.borrow_mut().or(a, b)
514    }
515    pub fn xor(&self, a: ExprRef, b: ExprRef) -> ExprRef {
516        self.ctx.borrow_mut().xor(a, b)
517    }
518    pub fn shift_left(&self, a: ExprRef, b: ExprRef) -> ExprRef {
519        self.ctx.borrow_mut().shift_left(a, b)
520    }
521    pub fn arithmetic_shift_right(&self, a: ExprRef, b: ExprRef) -> ExprRef {
522        self.ctx.borrow_mut().arithmetic_shift_right(a, b)
523    }
524    pub fn shift_right(&self, a: ExprRef, b: ExprRef) -> ExprRef {
525        self.ctx.borrow_mut().shift_right(a, b)
526    }
527    pub fn add(&self, a: ExprRef, b: ExprRef) -> ExprRef {
528        self.ctx.borrow_mut().add(a, b)
529    }
530    pub fn sub(&self, a: ExprRef, b: ExprRef) -> ExprRef {
531        self.ctx.borrow_mut().sub(a, b)
532    }
533    pub fn mul(&self, a: ExprRef, b: ExprRef) -> ExprRef {
534        self.ctx.borrow_mut().mul(a, b)
535    }
536    pub fn div(&self, a: ExprRef, b: ExprRef) -> ExprRef {
537        self.ctx.borrow_mut().div(a, b)
538    }
539    pub fn signed_div(&self, a: ExprRef, b: ExprRef) -> ExprRef {
540        self.ctx.borrow_mut().signed_div(a, b)
541    }
542    pub fn signed_mod(&self, a: ExprRef, b: ExprRef) -> ExprRef {
543        self.ctx.borrow_mut().signed_mod(a, b)
544    }
545    pub fn signed_remainder(&self, a: ExprRef, b: ExprRef) -> ExprRef {
546        self.ctx.borrow_mut().signed_remainder(a, b)
547    }
548    pub fn remainder(&self, a: ExprRef, b: ExprRef) -> ExprRef {
549        self.ctx.borrow_mut().remainder(a, b)
550    }
551    pub fn concat(&self, a: ExprRef, b: ExprRef) -> ExprRef {
552        self.ctx.borrow_mut().concat(a, b)
553    }
554    pub fn slice(&self, e: ExprRef, hi: WidthInt, lo: WidthInt) -> ExprRef {
555        self.ctx.borrow_mut().slice(e, hi, lo)
556    }
557    pub fn zero_extend(&self, e: ExprRef, by: WidthInt) -> ExprRef {
558        self.ctx.borrow_mut().zero_extend(e, by)
559    }
560    pub fn sign_extend(&self, e: ExprRef, by: WidthInt) -> ExprRef {
561        self.ctx.borrow_mut().sign_extend(e, by)
562    }
563
564    /// Sign or zero extends depending on the value of `signed`.
565    pub fn extend(&mut self, e: ExprRef, by: WidthInt, signed: bool) -> ExprRef {
566        self.ctx.borrow_mut().extend(e, by, signed)
567    }
568
569    pub fn array_store(&self, array: ExprRef, index: ExprRef, data: ExprRef) -> ExprRef {
570        self.ctx.borrow_mut().array_store(array, index, data)
571    }
572
573    pub fn array_const(&self, e: ExprRef, index_width: WidthInt) -> ExprRef {
574        self.ctx.borrow_mut().array_const(e, index_width)
575    }
576
577    pub fn array_read(&self, array: ExprRef, index: ExprRef) -> ExprRef {
578        self.ctx.borrow_mut().array_read(array, index)
579    }
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::SerializableIrNode;

    /// Both reference types must stay 4 bytes wide.
    #[test]
    fn ir_type_size() {
        assert_eq!(std::mem::size_of::<StringRef>(), 4);
        assert_eq!(std::mem::size_of::<ExprRef>(), 4);
    }

    #[test]
    fn reference_ids() {
        let mut ctx = Context::default();

        // raw ids 1 and 2 are taken by the cached false and true literals
        assert_eq!(ctx.get_false().0.get(), 1);
        assert_eq!(ctx.get_true().0.get(), 2);

        let name = ctx.string("a".into());
        let symbol = |width| Expr::BVSymbol { name, width };

        let first = ctx.add_expr(symbol(1));
        assert_eq!(first.0.get(), 3, "ids start at three (for now)");

        let again = ctx.add_expr(symbol(1));
        assert_eq!(first.0, again.0, "ids should be interned!");

        let wider = ctx.add_expr(symbol(2));
        assert_eq!(first.0.get() + 1, wider.0.get(), "ids should increment!");
    }

    /// make sure that we can intern a lot of strings before running out of IDs
    #[test]
    fn intern_lots_of_strings() {
        let mut ctx = Context::default();
        // we lose 1 ID since 0 is not a valid ID value
        let max_strings = (1u64 << 16) - 1;
        (0..max_strings).for_each(|ii| {
            let value = format!("{ii}AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA");
            let _id = ctx.string(value.into());
        });
        // now that we have used up all the IDs, we should still be able to "add" strings that
        // are already part of the context
        let first = "0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
        assert_eq!(ctx.string(first.into()).index(), 0);
    }

    #[test]
    fn test_builder() {
        let mut ctx = Context::default();
        // nested builder calls are the whole point of `Builder`
        let anded = ctx.build(|b| b.and(b.bv_symbol("a", 1), b.bv_symbol("b", 1)));
        assert_eq!(anded.serialize_to_str(&ctx), "and(a, b)");
    }

    /// A 128-bit literal built from a small integer must not panic.
    #[test]
    fn test_bit_vec_val() {
        let mut ctx = Context::default();
        let _v0 = ctx.bit_vec_val(1, 128);
    }
}