1use std::collections::HashMap;
36use std::fmt;
37use std::ops::{BitAnd, BitOr, BitXor, Add, Sub, Shl, Shr};
38
39use super::Sym;
40use crate::bitvector::b64::B64;
41use crate::bitvector::BV;
42use crate::ir::EnumMember;
43
/// The type of an [`Exp`], as computed by [`Exp::infer`].
#[derive(Clone, Debug)]
pub enum Ty {
    /// Boolean type; displayed as `Bool`.
    Bool,
    /// Bitvector type of the given width in bits; displayed as `(_ BitVec <width>)`.
    BitVec(u32),
    /// Enumeration type identified by a numeric id; displayed as `Enum<id>`.
    Enum(usize),
    /// Array type. The first component is the index (domain) type, the second
    /// is the element (codomain) type, which is what `Select` yields in `infer`.
    Array(Box<Ty>, Box<Ty>),
}
51
52impl fmt::Display for Ty {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 use Ty::*;
55 match self {
56 Bool => write!(f, "Bool"),
57 BitVec(sz) => write!(f, "(_ BitVec {})", sz),
58 Enum(e) => write!(f, "Enum{}", e),
59 Array(dom, codom) => {
60 write!(f, "(Array ")?;
61 dom.fmt(f)?;
62 write!(f, " ")?;
63 codom.fmt(f)?;
64 write!(f, ")")
65 }
66 }
67 }
68}
69
/// An expression tree over booleans, bitvectors, enumeration members and
/// arrays.
///
/// Operator variants own their operands through `Box`. The bitvector operator
/// names (`Bvand`, `Bvudiv`, `Bvashr`, ...) look like the SMT-LIB bitvector
/// operators of the same names -- NOTE(review): confirm against the code that
/// serialises these expressions.
#[derive(Clone, Debug)]
pub enum Exp {
    /// A variable; `infer` looks its type up in the variable type context.
    Var(Sym),
    /// Bitvector literal stored one `bool` per bit. `bits64` and
    /// `eval_extract` treat index 0 as the least significant bit.
    Bits(Vec<bool>),
    /// Bitvector literal stored in a `B64`; `bits64` uses this form for widths
    /// of at most 64 bits.
    Bits64(B64),
    /// A member of an enumeration; its type is `Ty::Enum` of the member's `enum_id`.
    Enum(EnumMember),
    /// Boolean literal.
    Bool(bool),
    // Boolean-valued operators (`infer` reports `Ty::Bool` for these).
    Eq(Box<Exp>, Box<Exp>),
    Neq(Box<Exp>, Box<Exp>),
    And(Box<Exp>, Box<Exp>),
    Or(Box<Exp>, Box<Exp>),
    Not(Box<Exp>),
    // Bitvector operators whose type `infer` takes from the (first) operand.
    Bvnot(Box<Exp>),
    Bvand(Box<Exp>, Box<Exp>),
    Bvor(Box<Exp>, Box<Exp>),
    Bvxor(Box<Exp>, Box<Exp>),
    Bvnand(Box<Exp>, Box<Exp>),
    Bvnor(Box<Exp>, Box<Exp>),
    Bvxnor(Box<Exp>, Box<Exp>),
    Bvneg(Box<Exp>),
    Bvadd(Box<Exp>, Box<Exp>),
    Bvsub(Box<Exp>, Box<Exp>),
    Bvmul(Box<Exp>, Box<Exp>),
    Bvudiv(Box<Exp>, Box<Exp>),
    Bvsdiv(Box<Exp>, Box<Exp>),
    Bvurem(Box<Exp>, Box<Exp>),
    Bvsrem(Box<Exp>, Box<Exp>),
    Bvsmod(Box<Exp>, Box<Exp>),
    // Bitvector comparisons (`infer` reports `Ty::Bool` for these).
    Bvult(Box<Exp>, Box<Exp>),
    Bvslt(Box<Exp>, Box<Exp>),
    Bvule(Box<Exp>, Box<Exp>),
    Bvsle(Box<Exp>, Box<Exp>),
    Bvuge(Box<Exp>, Box<Exp>),
    Bvsge(Box<Exp>, Box<Exp>),
    Bvugt(Box<Exp>, Box<Exp>),
    Bvsgt(Box<Exp>, Box<Exp>),
    /// `Extract(hi, lo, exp)`: bits `hi` down to `lo` (both inclusive) of
    /// `exp`, a bitvector of width `hi - lo + 1`.
    Extract(u32, u32, Box<Exp>),
    /// `ZeroExtend(n, exp)`: `exp` widened by `n` bits.
    ZeroExtend(u32, Box<Exp>),
    /// `SignExtend(n, exp)`: `exp` widened by `n` bits.
    SignExtend(u32, Box<Exp>),
    // Shifts; the result type is that of the first operand.
    Bvshl(Box<Exp>, Box<Exp>),
    Bvlshr(Box<Exp>, Box<Exp>),
    Bvashr(Box<Exp>, Box<Exp>),
    /// Concatenation; the result width is the sum of the operand widths.
    Concat(Box<Exp>, Box<Exp>),
    /// `Ite(cond, then_exp, else_exp)`; `infer` takes the type of `then_exp`.
    Ite(Box<Exp>, Box<Exp>, Box<Exp>),
    /// Application of a declared function to arguments; `infer` takes the
    /// result type from the function type context.
    App(Sym, Vec<Exp>),
    /// `Select(array, index)`: has the array's element type.
    Select(Box<Exp>, Box<Exp>),
    /// `Store(array, index, value)`: has the array's type.
    Store(Box<Exp>, Box<Exp>, Box<Exp>),
}
118
119#[allow(clippy::needless_range_loop)]
120pub fn bits64(bits: u64, size: u32) -> Exp {
121 if size <= 64 {
122 Exp::Bits64(B64::new(bits, size))
123 } else {
124 let mut bitvec = [false; 64];
125 for n in 0..64 {
126 if (bits >> n & 1) == 1 {
127 bitvec[n] = true
128 }
129 }
130 Exp::Bits(bitvec.to_vec())
131 }
132}
133
134fn is_bits64(exp: &Exp) -> bool {
135 matches!(exp, Exp::Bits64(_))
136}
137
138fn is_bits(exp: &Exp) -> bool {
139 matches!(exp, Exp::Bits(_))
140}
141
142fn extract_bits64(exp: &Exp) -> B64 {
143 match *exp {
144 Exp::Bits64(bv) => bv,
145 _ => unreachable!(),
146 }
147}
148
149fn extract_bits(exp: Exp) -> Vec<bool> {
150 match exp {
151 Exp::Bits(bv) => bv,
152 _ => unreachable!(),
153 }
154}
155
// Evaluates both operands of a binary operator in place and folds the
// operator when both results are `Bits64` literals.
//
// Arguments:
// * `$eval`     - function used to evaluate each operand (`Exp -> Exp`)
// * `$exp_op`   - constructor that rebuilds the expression when it cannot be folded
// * `$small_op` - the operation on two `B64` values
// * `$lhs`, `$rhs` - names of mutable `Box<Exp>` bindings holding the operands
//
// Expands to an expression of type `Exp`.
macro_rules! binary_eval {
    ($eval:path, $exp_op:path, $small_op:path, $lhs:ident, $rhs:ident) => {{
        // Use the evaluator the caller passed in rather than a hard-coded method.
        *$lhs = $eval(*$lhs);
        *$rhs = $eval(*$rhs);
        if is_bits64(&$lhs) && is_bits64(&$rhs) {
            Exp::Bits64($small_op(extract_bits64(&$lhs), extract_bits64(&$rhs)))
        } else {
            $exp_op($lhs, $rhs)
        }
    }};
}
167
168fn eval_extract(hi: u32, lo: u32, exp: Box<Exp>) -> Exp {
169 if is_bits64(&exp) {
170 Exp::Bits64(extract_bits64(&exp).extract(hi, lo).unwrap())
171 } else if is_bits(&exp) {
172 let orig_vec = extract_bits(*exp);
173 let len = (hi - lo) + 1;
174 if len <= 64 {
175 let mut bv = B64::zeros(len);
176 for n in 0..len {
177 if orig_vec[(n + lo) as usize] {
178 bv = bv.set_slice(n, B64::ones(1))
179 }
180 }
181 Exp::Bits64(bv)
182 } else {
183 let mut vec = vec![false; len as usize];
184 for n in 0..len {
185 if orig_vec[(n + lo) as usize] {
186 vec[n as usize] = true
187 }
188 }
189 Exp::Bits(vec)
190 }
191 } else {
192 Exp::Extract(hi, lo, exp)
193 }
194}
195
196fn eval_zero_extend(len: u32, exp: Box<Exp>) -> Exp {
197 if is_bits64(&exp) {
198 let bv = extract_bits64(&exp);
199 Exp::Bits64(bv.zero_extend(bv.len() + len))
200 } else {
201 Exp::ZeroExtend(len, exp)
202 }
203}
204
205fn eval_sign_extend(len: u32, exp: Box<Exp>) -> Exp {
206 if is_bits64(&exp) {
207 let bv = extract_bits64(&exp);
208 Exp::Bits64(bv.sign_extend(bv.len() + len))
209 } else {
210 Exp::SignExtend(len, exp)
211 }
212}
213
214impl Exp {
215 pub fn eval(self) -> Self {
216 use Exp::*;
217 match self {
218 Bvnot(mut exp) => {
219 *exp = exp.eval();
220 match *exp {
221 Bits64(bv) => Bits64(!bv),
222 Bits(mut vec) => {
223 vec.iter_mut().for_each(|b| *b = !*b);
224 Bits(vec)
225 }
226 _ => Bvnot(exp),
227 }
228 }
229 Eq(mut lhs, mut rhs) => {
230 *lhs = lhs.eval();
231 *rhs = rhs.eval();
232 Eq(lhs, rhs)
233 }
234 Bvand(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvand, B64::bitand, lhs, rhs),
235 Bvor(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvor, B64::bitor, lhs, rhs),
236 Bvxor(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvxor, B64::bitxor, lhs, rhs),
237 Bvadd(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvadd, B64::add, lhs, rhs),
238 Bvsub(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvsub, B64::sub, lhs, rhs),
239 Bvlshr(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvlshr, B64::shr, lhs, rhs),
240 Bvshl(mut lhs, mut rhs) => binary_eval!(Exp::eval, Bvshl, B64::shl, lhs, rhs),
241 Extract(hi, lo, mut exp) => {
242 *exp = exp.eval();
243 eval_extract(hi, lo, exp)
244 }
245 ZeroExtend(len, mut exp) => {
246 *exp = exp.eval();
247 eval_zero_extend(len, exp)
248 }
249 SignExtend(len, mut exp) => {
250 *exp = exp.eval();
251 eval_sign_extend(len, exp)
252 }
253 _ => self,
254 }
255 }
256
257 pub fn modify<F>(&mut self, f: &F)
259 where
260 F: Fn(&mut Exp),
261 {
262 use Exp::*;
263 match self {
264 Var(_) | Bits(_) | Bits64(_) | Enum { .. } | Bool(_) => (),
265 Not(exp) | Bvnot(exp) | Bvneg(exp) | Extract(_, _, exp) | ZeroExtend(_, exp) | SignExtend(_, exp) => {
266 exp.modify(f)
267 }
268 Eq(lhs, rhs)
269 | Neq(lhs, rhs)
270 | And(lhs, rhs)
271 | Or(lhs, rhs)
272 | Bvand(lhs, rhs)
273 | Bvor(lhs, rhs)
274 | Bvxor(lhs, rhs)
275 | Bvnand(lhs, rhs)
276 | Bvnor(lhs, rhs)
277 | Bvxnor(lhs, rhs)
278 | Bvadd(lhs, rhs)
279 | Bvsub(lhs, rhs)
280 | Bvmul(lhs, rhs)
281 | Bvudiv(lhs, rhs)
282 | Bvsdiv(lhs, rhs)
283 | Bvurem(lhs, rhs)
284 | Bvsrem(lhs, rhs)
285 | Bvsmod(lhs, rhs)
286 | Bvult(lhs, rhs)
287 | Bvslt(lhs, rhs)
288 | Bvule(lhs, rhs)
289 | Bvsle(lhs, rhs)
290 | Bvuge(lhs, rhs)
291 | Bvsge(lhs, rhs)
292 | Bvugt(lhs, rhs)
293 | Bvsgt(lhs, rhs)
294 | Bvshl(lhs, rhs)
295 | Bvlshr(lhs, rhs)
296 | Bvashr(lhs, rhs)
297 | Concat(lhs, rhs) => {
298 lhs.modify(f);
299 rhs.modify(f);
300 }
301 Ite(cond, then_exp, else_exp) => {
302 cond.modify(f);
303 then_exp.modify(f);
304 else_exp.modify(f)
305 }
306 App(_, args) => {
307 for exp in args {
308 exp.modify(f)
309 }
310 }
311 Select(array, index) => {
312 array.modify(f);
313 index.modify(f);
314 }
315 Store(array, index, val) => {
316 array.modify(f);
317 index.modify(f);
318 val.modify(f);
319 }
320 };
321 f(self)
322 }
323
324 pub fn modify_top_down<F>(&mut self, f: &F)
326 where
327 F: Fn(&mut Exp),
328 {
329 use Exp::*;
330 f(self);
331 match self {
332 Var(_) | Bits(_) | Bits64(_) | Enum { .. } | Bool(_) => (),
333 Not(exp) | Bvnot(exp) | Bvneg(exp) | Extract(_, _, exp) | ZeroExtend(_, exp) | SignExtend(_, exp) => {
334 exp.modify(f)
335 }
336 Eq(lhs, rhs)
337 | Neq(lhs, rhs)
338 | And(lhs, rhs)
339 | Or(lhs, rhs)
340 | Bvand(lhs, rhs)
341 | Bvor(lhs, rhs)
342 | Bvxor(lhs, rhs)
343 | Bvnand(lhs, rhs)
344 | Bvnor(lhs, rhs)
345 | Bvxnor(lhs, rhs)
346 | Bvadd(lhs, rhs)
347 | Bvsub(lhs, rhs)
348 | Bvmul(lhs, rhs)
349 | Bvudiv(lhs, rhs)
350 | Bvsdiv(lhs, rhs)
351 | Bvurem(lhs, rhs)
352 | Bvsrem(lhs, rhs)
353 | Bvsmod(lhs, rhs)
354 | Bvult(lhs, rhs)
355 | Bvslt(lhs, rhs)
356 | Bvule(lhs, rhs)
357 | Bvsle(lhs, rhs)
358 | Bvuge(lhs, rhs)
359 | Bvsge(lhs, rhs)
360 | Bvugt(lhs, rhs)
361 | Bvsgt(lhs, rhs)
362 | Bvshl(lhs, rhs)
363 | Bvlshr(lhs, rhs)
364 | Bvashr(lhs, rhs)
365 | Concat(lhs, rhs) => {
366 lhs.modify(f);
367 rhs.modify(f);
368 }
369 Ite(cond, then_exp, else_exp) => {
370 cond.modify(f);
371 then_exp.modify(f);
372 else_exp.modify(f)
373 }
374 App(_, args) => {
375 for exp in args {
376 exp.modify(f)
377 }
378 }
379 Select(array, index) => {
380 array.modify(f);
381 index.modify(f);
382 }
383 Store(array, index, val) => {
384 array.modify(f);
385 index.modify(f);
386 val.modify(f);
387 }
388 }
389 }
390
391 fn binary_commute_extract(self) -> Result<(fn (Box<Self>, Box<Self>) -> Self, Box<Self>, Box<Self>), Self> {
392 use Exp::*;
393 match self {
394 Bvand(lhs, rhs) => Ok((Bvand, lhs, rhs)),
395 Bvor(lhs, rhs) => Ok((Bvor, lhs, rhs)),
396 Bvxor(lhs, rhs) => Ok((Bvxor, lhs, rhs)),
397 Bvnand(lhs, rhs) => Ok((Bvnand, lhs, rhs)),
398 Bvnor(lhs, rhs) => Ok((Bvnor, lhs, rhs)),
399 Bvxnor(lhs, rhs) => Ok((Bvxnor, lhs, rhs)),
400 Bvadd(lhs, rhs) => Ok((Bvadd, lhs, rhs)),
401 Bvsub(lhs, rhs) => Ok((Bvsub, lhs, rhs)),
402 _ => Err(self),
403 }
404 }
405
406 pub fn commute_extract(&mut self) {
407 use Exp::*;
408 if let Extract(hi, lo, exp) = self {
409 match std::mem::replace(&mut **exp, Bool(false)).binary_commute_extract() {
410 Ok((op, lhs, rhs)) => {
411 *self = op(Box::new(Extract(*hi, *lo, lhs)), Box::new(Extract(*hi, *lo, rhs)))
412 }
413 Err(mut orig_exp) => {
414 std::mem::swap(&mut **exp, &mut orig_exp);
415 }
416 }
417 }
418 }
419
420 pub fn subst_once_in_place(&mut self, substs: &mut HashMap<Sym, Option<Exp>>) {
421 use Exp::*;
422 match self {
423 Var(v) => {
424 if let Some(exp) = substs.get_mut(v) {
425 if let Some(exp) = exp.take() {
426 *self = exp
427 } else {
428 panic!("Tried to substitute twice in subst_once_in_place")
429 }
430 }
431 }
432 Bits(_) | Bits64(_) | Enum { .. } | Bool(_) => (),
433 Not(exp) | Bvnot(exp) | Bvneg(exp) | Extract(_, _, exp) | ZeroExtend(_, exp) | SignExtend(_, exp) => {
434 exp.subst_once_in_place(substs)
435 }
436 Eq(lhs, rhs)
437 | Neq(lhs, rhs)
438 | And(lhs, rhs)
439 | Or(lhs, rhs)
440 | Bvand(lhs, rhs)
441 | Bvor(lhs, rhs)
442 | Bvxor(lhs, rhs)
443 | Bvnand(lhs, rhs)
444 | Bvnor(lhs, rhs)
445 | Bvxnor(lhs, rhs)
446 | Bvadd(lhs, rhs)
447 | Bvsub(lhs, rhs)
448 | Bvmul(lhs, rhs)
449 | Bvudiv(lhs, rhs)
450 | Bvsdiv(lhs, rhs)
451 | Bvurem(lhs, rhs)
452 | Bvsrem(lhs, rhs)
453 | Bvsmod(lhs, rhs)
454 | Bvult(lhs, rhs)
455 | Bvslt(lhs, rhs)
456 | Bvule(lhs, rhs)
457 | Bvsle(lhs, rhs)
458 | Bvuge(lhs, rhs)
459 | Bvsge(lhs, rhs)
460 | Bvugt(lhs, rhs)
461 | Bvsgt(lhs, rhs)
462 | Bvshl(lhs, rhs)
463 | Bvlshr(lhs, rhs)
464 | Bvashr(lhs, rhs)
465 | Concat(lhs, rhs) => {
466 lhs.subst_once_in_place(substs);
467 rhs.subst_once_in_place(substs);
468 }
469 Ite(cond, then_exp, else_exp) => {
470 cond.subst_once_in_place(substs);
471 then_exp.subst_once_in_place(substs);
472 else_exp.subst_once_in_place(substs)
473 }
474 App(_, args) => {
475 for exp in args {
476 exp.subst_once_in_place(substs)
477 }
478 }
479 Select(array, index) => {
480 array.subst_once_in_place(substs);
481 index.subst_once_in_place(substs);
482 }
483 Store(array, index, val) => {
484 array.subst_once_in_place(substs);
485 index.subst_once_in_place(substs);
486 val.subst_once_in_place(substs);
487 }
488 }
489 }
490
491 pub fn infer(&self, tcx: &HashMap<Sym, Ty>, ftcx: &HashMap<Sym, (Vec<Ty>, Ty)>) -> Option<Ty> {
493 use Exp::*;
494 match self {
495 Var(v) => tcx.get(v).map(Ty::clone),
496 Bits(bv) => Some(Ty::BitVec(bv.len() as u32)),
497 Bits64(bv) => Some(Ty::BitVec(bv.len())),
498 Enum(e) => Some(Ty::Enum(e.enum_id)),
499 Bool(_)
500 | Not(_)
501 | Eq(_, _)
502 | Neq(_, _)
503 | And(_, _)
504 | Or(_, _)
505 | Bvult(_, _)
506 | Bvslt(_, _)
507 | Bvule(_, _)
508 | Bvsle(_, _)
509 | Bvuge(_, _)
510 | Bvsge(_, _)
511 | Bvugt(_, _)
512 | Bvsgt(_, _) => Some(Ty::Bool),
513 Bvnot(exp) | Bvneg(exp) => exp.infer(tcx, ftcx),
514 Extract(i, j, _) => Some(Ty::BitVec((i - j) + 1)),
515 ZeroExtend(ext, exp) | SignExtend(ext, exp) => match exp.infer(tcx, ftcx) {
516 Some(Ty::BitVec(sz)) => Some(Ty::BitVec(sz + ext)),
517 _ => None,
518 },
519 Bvand(lhs, _)
520 | Bvor(lhs, _)
521 | Bvxor(lhs, _)
522 | Bvnand(lhs, _)
523 | Bvnor(lhs, _)
524 | Bvxnor(lhs, _)
525 | Bvadd(lhs, _)
526 | Bvsub(lhs, _)
527 | Bvmul(lhs, _)
528 | Bvudiv(lhs, _)
529 | Bvsdiv(lhs, _)
530 | Bvurem(lhs, _)
531 | Bvsrem(lhs, _)
532 | Bvsmod(lhs, _)
533 | Bvshl(lhs, _)
534 | Bvlshr(lhs, _)
535 | Bvashr(lhs, _) => lhs.infer(tcx, ftcx),
536 Concat(lhs, rhs) => match (lhs.infer(tcx, ftcx), rhs.infer(tcx, ftcx)) {
537 (Some(Ty::BitVec(lsz)), Some(Ty::BitVec(rsz))) => Some(Ty::BitVec(lsz + rsz)),
538 (_, _) => None,
539 },
540 Ite(_, then_exp, _) => then_exp.infer(tcx, ftcx),
541 App(f, _) => ftcx.get(f).map(|x| x.1.clone()),
542 Select(array, _) => match array.infer(tcx, ftcx) {
543 Some(Ty::Array(_, codom_ty)) => Some(*codom_ty),
544 _ => None,
545 },
546 Store(array, _, _) => array.infer(tcx, ftcx),
547 }
548 }
549}
550
/// A top-level definition or assertion.
///
/// NOTE(review): the variant names suggest these mirror the SMT-LIB commands
/// `declare-const`, `declare-fun`, `define-const` and `assert` -- confirm
/// against the code that emits them.
#[derive(Clone, Debug)]
pub enum Def {
    /// Declares a constant with the given symbol and type.
    DeclareConst(Sym, Ty),
    /// Declares a function with the given symbol, argument types and result type.
    DeclareFun(Sym, Vec<Ty>, Ty),
    /// Defines a constant with the given symbol as an expression.
    DefineConst(Sym, Exp),
    /// Defines an enumeration under the given symbol. The meaning of the
    /// `usize` is not visible here -- presumably the number of members; verify.
    DefineEnum(Sym, usize),
    /// Asserts an expression.
    Assert(Exp),
}