isla_lib/
smt.rs

1// BSD 2-Clause License
2//
3// Copyright (c) 2019, 2020 Alasdair Armstrong
4// Copyright (c) 2020 Brian Campbell
5//
6// All rights reserved.
7//
8// Redistribution and use in source and binary forms, with or without
9// modification, are permitted provided that the following conditions are
10// met:
11//
12// 1. Redistributions of source code must retain the above copyright
13// notice, this list of conditions and the following disclaimer.
14//
15// 2. Redistributions in binary form must reproduce the above copyright
16// notice, this list of conditions and the following disclaimer in the
17// documentation and/or other materials provided with the distribution.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31//! This module defines an interface with the SMT solver, primarily
32//! via the [Solver] type. It provides a safe abstraction over the
33//! [z3_sys] crate. In addition, all the interaction with the SMT
34//! solver is logged as a [Trace] in an SMTLIB-like format, expanded
35//! with additional events marking e.g. memory events, the start and
//! end of processor cycles, etc. (see the [Event] type). Points in
37//! these traces can be snapshotted and shared between threads via the
38//! [Checkpoint] type.
39
40use libc::{c_int, c_uint};
41use serde::{Deserialize, Serialize};
42use z3_sys::*;
43
44use std::collections::HashMap;
45use std::convert::TryInto;
46use std::error::Error;
47use std::ffi::{CStr, CString};
48use std::fmt;
49use std::io::Write;
50use std::mem;
51use std::ptr;
52use std::sync::Arc;
53
54use crate::bitvector::b64::B64;
55use crate::bitvector::BV;
56use crate::error::ExecError;
57use crate::ir::{EnumMember, Name, Symtab, Val};
58use crate::zencode;
59
/// A symbolic variable. Internally it is nothing more than a `u32`
/// identifier wrapped in a newtype.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct Sym {
    pub(crate) id: u32,
}

impl Sym {
    /// Wrap the raw identifier `id` as a symbolic variable.
    pub fn from_u32(id: u32) -> Self {
        Self { id }
    }
}
72
73impl<B> Into<Result<Val<B>, ExecError>> for Sym {
74    fn into(self) -> Result<Val<B>, ExecError> {
75        Ok(Val::Symbolic(self))
76    }
77}
78
79impl fmt::Display for Sym {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        write!(f, "{}", self.id)
82    }
83}
84
85pub mod smtlib;
86use smtlib::*;
87
88/// Snapshot of interaction with underlying solver that can be
89/// efficiently cloned and shared between threads.
90#[derive(Clone, Default)]
91pub struct Checkpoint<B> {
92    num: usize,
93    next_var: u32,
94    trace: Arc<Option<Trace<B>>>,
95}
96
97impl<B> Checkpoint<B> {
98    pub fn new() -> Self {
99        Checkpoint { num: 0, next_var: 0, trace: Arc::new(None) }
100    }
101
102    pub fn trace(&self) -> &Option<Trace<B>> {
103        &self.trace
104    }
105}
106
107/// For the concurrency models, register accesses must be logged at a
108/// subfield level granularity (e.g. for PSTATE in ARM ASL), which is
109/// what the Accessor type is for.
110#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
111pub enum Accessor {
112    Field(Name),
113}
114
115impl Accessor {
116    pub fn to_string(&self, symtab: &Symtab) -> String {
117        match self {
118            Accessor::Field(name) => format!("(_ field |{}|)", zencode::decode(symtab.to_str(*name))),
119        }
120    }
121
122    pub fn pretty(&self, buf: &mut dyn Write, symtab: &Symtab) -> Result<(), Box<dyn Error>> {
123        match self {
124            Accessor::Field(name) => write!(buf, ".{}", zencode::decode(symtab.to_str(*name)))?,
125        }
126        Ok(())
127    }
128}
129
130#[derive(Clone, Debug)]
131pub enum Event<B> {
132    Smt(Def),
133    Fork(u32, Sym, String),
134    ReadReg(Name, Vec<Accessor>, Val<B>),
135    WriteReg(Name, Vec<Accessor>, Val<B>),
136    ReadMem {
137        value: Val<B>,
138        read_kind: Val<B>,
139        address: Val<B>,
140        bytes: u32,
141        tag_value: Option<Val<B>>,
142        kind: &'static str,
143    },
144    WriteMem {
145        value: Sym,
146        write_kind: Val<B>,
147        address: Val<B>,
148        data: Val<B>,
149        bytes: u32,
150        tag_value: Option<Val<B>>,
151        kind: &'static str,
152    },
153    Branch {
154        address: Val<B>,
155    },
156    Barrier {
157        barrier_kind: Val<B>,
158    },
159    CacheOp {
160        cache_op_kind: Val<B>,
161        address: Val<B>,
162    },
163    MarkReg {
164        regs: Vec<Name>,
165        mark: String,
166    },
167    Cycle,
168    Instr(Val<B>),
169    Sleeping(Sym),
170    SleepRequest,
171    WakeupRequest,
172}
173
174impl<B: BV> Event<B> {
175    pub fn is_smt(&self) -> bool {
176        matches!(self, Event::Smt(_))
177    }
178
179    pub fn is_reg(&self) -> bool {
180        matches!(self, Event::ReadReg(_, _, _) | Event::WriteReg(_, _, _) | Event::MarkReg { .. })
181    }
182
183    pub fn is_write_reg(&self) -> bool {
184        matches!(self, Event::WriteReg(_, _, _))
185    }
186
187    pub fn is_cycle(&self) -> bool {
188        matches!(self, Event::Cycle)
189    }
190
191    pub fn is_instr(&self) -> bool {
192        matches!(self, Event::Instr(_))
193    }
194
195    pub fn is_branch(&self) -> bool {
196        matches!(self, Event::Branch { .. })
197    }
198
199    pub fn is_barrier(&self) -> bool {
200        matches!(self, Event::Barrier { .. })
201    }
202
203    pub fn is_fork(&self) -> bool {
204        matches!(self, Event::Fork(_, _, _))
205    }
206
207    pub fn is_memory(&self) -> bool {
208        matches!(self, Event::ReadMem { .. } | Event::WriteMem { .. } | Event::Barrier { .. } | Event::CacheOp { .. })
209    }
210
211    pub fn is_memory_read(&self) -> bool {
212        matches!(self, Event::ReadMem { .. })
213    }
214
215    pub fn is_memory_write(&self) -> bool {
216        matches!(self, Event::WriteMem { .. })
217    }
218
219    pub fn is_cache_op(&self) -> bool {
220        matches!(self, Event::CacheOp { .. })
221    }
222
223    pub fn has_barrier_kind(&self, bk: usize) -> bool {
224        match self {
225            Event::Barrier { barrier_kind: Val::Enum(e) } => e.member == bk,
226            _ => false,
227        }
228    }
229
230    pub fn has_read_kind(&self, rk: usize) -> bool {
231        match self {
232            Event::ReadMem { read_kind: Val::Enum(e), .. } => e.member == rk,
233            _ => false,
234        }
235    }
236
237    pub fn has_write_kind(&self, wk: usize) -> bool {
238        match self {
239            Event::WriteMem { write_kind: Val::Enum(e), .. } => e.member == wk,
240            _ => false,
241        }
242    }
243
244    pub fn has_cache_op_kind(&self, ck: usize) -> bool {
245        match self {
246            Event::CacheOp { cache_op_kind: Val::Enum(e), .. } => e.member == ck,
247            _ => false,
248        }
249    }
250}
251
252pub type EvPath<B> = Vec<Event<B>>;
253
254/// Abstractly represents a sequence of events in such a way that
255/// checkpoints can be created and shared.
256#[derive(Debug)]
257pub struct Trace<B> {
258    checkpoints: usize,
259    head: Vec<Event<B>>,
260    tail: Arc<Option<Trace<B>>>,
261}
262
263impl<B: BV> Trace<B> {
264    #[allow(clippy::new_without_default)]
265    pub fn new() -> Self {
266        Trace { checkpoints: 0, head: Vec::new(), tail: Arc::new(None) }
267    }
268
269    pub fn checkpoint(&mut self, next_var: u32) -> Checkpoint<B> {
270        let mut head = Vec::new();
271        mem::swap(&mut self.head, &mut head);
272        let tail = Arc::new(Some(Trace { checkpoints: self.checkpoints, head, tail: self.tail.clone() }));
273        self.checkpoints += 1;
274        self.tail = tail.clone();
275        Checkpoint { num: self.checkpoints, trace: tail, next_var }
276    }
277
278    pub fn to_vec<'a>(&'a self) -> Vec<&'a Event<B>> {
279        let mut vec: Vec<&'a Event<B>> = Vec::new();
280
281        let mut current_head = &self.head;
282        let mut current_tail = self.tail.as_ref();
283        loop {
284            for def in current_head.iter().rev() {
285                vec.push(def)
286            }
287            match current_tail {
288                Some(trace) => {
289                    current_head = &trace.head;
290                    current_tail = trace.tail.as_ref();
291                }
292                None => return vec,
293            }
294        }
295    }
296}
297
298/// Config is a wrapper around the `Z3_config` type from the C
299/// API. `Z3_del_config` is called when it is dropped.
300pub struct Config {
301    z3_cfg: Z3_config,
302}
303
304impl Config {
305    pub fn new() -> Self {
306        unsafe { Config { z3_cfg: Z3_mk_config() } }
307    }
308}
309
310impl Drop for Config {
311    fn drop(&mut self) {
312        unsafe { Z3_del_config(self.z3_cfg) }
313    }
314}
315
316impl Default for Config {
317    fn default() -> Self {
318        Self::new()
319    }
320}
321
322impl Config {
323    pub fn set_param_value(&mut self, id: &str, value: &str) {
324        let id = CString::new(id).unwrap();
325        let value = CString::new(value).unwrap();
326        unsafe { Z3_set_param_value(self.z3_cfg, id.as_ptr(), value.as_ptr()) }
327    }
328}
329
330pub fn global_set_param_value(id: &str, value: &str) {
331    let id = CString::new(id).unwrap();
332    let value = CString::new(value).unwrap();
333    unsafe { Z3_global_param_set(id.as_ptr(), value.as_ptr()) }
334}
335
336/// Context is a wrapper around `Z3_context`.
337pub struct Context {
338    z3_ctx: Z3_context,
339}
340
341impl Context {
342    pub fn new(cfg: Config) -> Self {
343        unsafe { Context { z3_ctx: Z3_mk_context_rc(cfg.z3_cfg) } }
344    }
345
346    fn error(&self) -> ExecError {
347        unsafe {
348            let code = Z3_get_error_code(self.z3_ctx);
349            let msg = Z3_get_error_msg(self.z3_ctx, code);
350            let str: String = CStr::from_ptr(msg).to_string_lossy().to_string();
351            ExecError::Z3Error(str)
352        }
353    }
354}
355
356impl Drop for Context {
357    fn drop(&mut self) {
358        unsafe { Z3_del_context(self.z3_ctx) }
359    }
360}
361
362struct Enum {
363    sort: Z3_sort,
364    size: usize,
365    consts: Vec<Z3_func_decl>,
366    testers: Vec<Z3_func_decl>,
367}
368
369struct Enums<'ctx> {
370    enums: Vec<Enum>,
371    ctx: &'ctx Context,
372}
373
374impl<'ctx> Enums<'ctx> {
375    fn new(ctx: &'ctx Context) -> Self {
376        Enums { enums: Vec::new(), ctx }
377    }
378
379    fn add_enum(&mut self, name: Sym, members: &[Sym]) {
380        unsafe {
381            let ctx = self.ctx.z3_ctx;
382            let size = members.len();
383
384            let name = Z3_mk_int_symbol(ctx, name.id as c_int);
385            let members: Vec<Z3_symbol> = members.iter().map(|m| Z3_mk_int_symbol(ctx, m.id as c_int)).collect();
386
387            let mut consts = mem::ManuallyDrop::new(Vec::with_capacity(size));
388            let mut testers = mem::ManuallyDrop::new(Vec::with_capacity(size));
389
390            let sort = Z3_mk_enumeration_sort(
391                ctx,
392                name,
393                size as c_uint,
394                members.as_ptr(),
395                consts.as_mut_ptr(),
396                testers.as_mut_ptr(),
397            );
398
399            let consts = Vec::from_raw_parts(consts.as_mut_ptr(), size, size);
400            let testers = Vec::from_raw_parts(testers.as_mut_ptr(), size, size);
401
402            for i in 0..size {
403                Z3_inc_ref(ctx, Z3_func_decl_to_ast(ctx, consts[i]));
404                Z3_inc_ref(ctx, Z3_func_decl_to_ast(ctx, testers[i]))
405            }
406            Z3_inc_ref(ctx, Z3_sort_to_ast(ctx, sort));
407
408            self.enums.push(Enum { sort, size, consts, testers })
409        }
410    }
411}
412
413impl<'ctx> Drop for Enums<'ctx> {
414    fn drop(&mut self) {
415        unsafe {
416            let ctx = self.ctx.z3_ctx;
417            for e in self.enums.drain(..) {
418                for i in 0..e.size {
419                    Z3_dec_ref(ctx, Z3_func_decl_to_ast(ctx, e.consts[i]));
420                    Z3_dec_ref(ctx, Z3_func_decl_to_ast(ctx, e.testers[i]))
421                }
422                Z3_dec_ref(ctx, Z3_sort_to_ast(ctx, e.sort))
423            }
424        }
425    }
426}
427
428struct Sort<'ctx> {
429    z3_sort: Z3_sort,
430    ctx: &'ctx Context,
431}
432
433impl<'ctx> Sort<'ctx> {
434    fn bitvec(ctx: &'ctx Context, sz: u32) -> Self {
435        unsafe {
436            let z3_sort = Z3_mk_bv_sort(ctx.z3_ctx, sz);
437            Z3_inc_ref(ctx.z3_ctx, Z3_sort_to_ast(ctx.z3_ctx, z3_sort));
438            Sort { z3_sort, ctx }
439        }
440    }
441
442    fn new(ctx: &'ctx Context, enums: &Enums<'ctx>, ty: &Ty) -> Self {
443        unsafe {
444            match ty {
445                Ty::Bool => {
446                    let z3_sort = Z3_mk_bool_sort(ctx.z3_ctx);
447                    Z3_inc_ref(ctx.z3_ctx, Z3_sort_to_ast(ctx.z3_ctx, z3_sort));
448                    Sort { z3_sort, ctx }
449                }
450                Ty::BitVec(sz) => Self::bitvec(ctx, *sz),
451                Ty::Enum(e) => {
452                    let z3_sort = enums.enums[*e].sort;
453                    Z3_inc_ref(ctx.z3_ctx, Z3_sort_to_ast(ctx.z3_ctx, z3_sort));
454                    Sort { z3_sort, ctx }
455                }
456                Ty::Array(dom, codom) => {
457                    let dom_s = Self::new(ctx, enums, dom);
458                    let codom_s = Self::new(ctx, enums, codom);
459                    let z3_sort = Z3_mk_array_sort(ctx.z3_ctx, dom_s.z3_sort, codom_s.z3_sort);
460                    Z3_inc_ref(ctx.z3_ctx, Z3_sort_to_ast(ctx.z3_ctx, z3_sort));
461                    Sort { z3_sort, ctx }
462                }
463            }
464        }
465    }
466}
467
468impl<'ctx> Drop for Sort<'ctx> {
469    fn drop(&mut self) {
470        unsafe {
471            let ctx = self.ctx.z3_ctx;
472            Z3_dec_ref(ctx, Z3_sort_to_ast(ctx, self.z3_sort))
473        }
474    }
475}
476
477struct FuncDecl<'ctx> {
478    z3_func_decl: Z3_func_decl,
479    ctx: &'ctx Context,
480}
481
482impl<'ctx> FuncDecl<'ctx> {
483    fn new(ctx: &'ctx Context, v: Sym, enums: &Enums<'ctx>, arg_tys: &[Ty], ty: &Ty) -> Self {
484        unsafe {
485            let name = Z3_mk_int_symbol(ctx.z3_ctx, v.id as c_int);
486            let arg_sorts: Vec<Sort> = arg_tys.iter().map(|ty| Sort::new(ctx, enums, ty)).collect();
487            let arg_z3_sorts: Vec<Z3_sort> = arg_sorts.iter().map(|s| s.z3_sort).collect();
488            let args: u32 = arg_sorts.len() as u32;
489            let z3_func_decl =
490                Z3_mk_func_decl(ctx.z3_ctx, name, args, arg_z3_sorts.as_ptr(), Sort::new(ctx, enums, ty).z3_sort);
491            Z3_inc_ref(ctx.z3_ctx, Z3_func_decl_to_ast(ctx.z3_ctx, z3_func_decl));
492            FuncDecl { z3_func_decl, ctx }
493        }
494    }
495}
496
497impl<'ctx> Drop for FuncDecl<'ctx> {
498    fn drop(&mut self) {
499        unsafe {
500            let ctx = self.ctx.z3_ctx;
501            Z3_dec_ref(ctx, Z3_func_decl_to_ast(ctx, self.z3_func_decl))
502        }
503    }
504}
505
506struct Ast<'ctx> {
507    z3_ast: Z3_ast,
508    ctx: &'ctx Context,
509}
510
511impl<'ctx> Clone for Ast<'ctx> {
512    fn clone(&self) -> Self {
513        unsafe {
514            let z3_ast = self.z3_ast;
515            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
516            Ast { z3_ast, ctx: self.ctx }
517        }
518    }
519}
520
// `z3_unary_op!(Z3_fn, x)` applies the one-operand Z3 constructor `Z3_fn`
// to `x`, takes a reference on the result, and wraps it in an `Ast` that
// shares `x`'s context. Expands to an `unsafe` block.
macro_rules! z3_unary_op {
    ($z3_fn:ident, $operand:ident) => {
        unsafe {
            let ctx = $operand.ctx;
            let z3_ast = $z3_fn(ctx.z3_ctx, $operand.z3_ast);
            Z3_inc_ref(ctx.z3_ctx, z3_ast);
            Ast { z3_ast, ctx }
        }
    };
}

// `z3_binary_op!(Z3_fn, l, r)` does the same for a two-operand constructor,
// using the context of the left operand.
macro_rules! z3_binary_op {
    ($z3_fn:ident, $left:ident, $right:ident) => {
        unsafe {
            let ctx = $left.ctx;
            let z3_ast = $z3_fn(ctx.z3_ctx, $left.z3_ast, $right.z3_ast);
            Z3_inc_ref(ctx.z3_ctx, z3_ast);
            Ast { z3_ast, ctx }
        }
    };
}
540
541impl<'ctx> Ast<'ctx> {
542    fn mk_constant(fd: &FuncDecl<'ctx>) -> Self {
543        unsafe {
544            let z3_ast = Z3_mk_app(fd.ctx.z3_ctx, fd.z3_func_decl, 0, ptr::null());
545            Z3_inc_ref(fd.ctx.z3_ctx, z3_ast);
546            Ast { z3_ast, ctx: fd.ctx }
547        }
548    }
549
550    fn mk_app(fd: &FuncDecl<'ctx>, args: &[Ast<'ctx>]) -> Self {
551        unsafe {
552            let z3_args: Vec<Z3_ast> = args.iter().map(|ast| ast.z3_ast).collect();
553            let len = z3_args.len() as u32;
554            let z3_ast = Z3_mk_app(fd.ctx.z3_ctx, fd.z3_func_decl, len, z3_args.as_ptr());
555            Z3_inc_ref(fd.ctx.z3_ctx, z3_ast);
556            Ast { z3_ast, ctx: fd.ctx }
557        }
558    }
559
560    fn mk_enum_member(enums: &Enums<'ctx>, enum_id: usize, member: usize) -> Self {
561        unsafe {
562            let func_decl = enums.enums[enum_id].consts[member];
563            let z3_ast = Z3_mk_app(enums.ctx.z3_ctx, func_decl, 0, ptr::null());
564            Z3_inc_ref(enums.ctx.z3_ctx, z3_ast);
565            Ast { z3_ast, ctx: enums.ctx }
566        }
567    }
568
569    fn mk_bv_u64(ctx: &'ctx Context, sz: u32, bits: u64) -> Self {
570        unsafe {
571            let sort = Sort::bitvec(ctx, sz);
572            let z3_ast = Z3_mk_unsigned_int64(ctx.z3_ctx, bits, sort.z3_sort);
573            Z3_inc_ref(ctx.z3_ctx, z3_ast);
574            Ast { z3_ast, ctx }
575        }
576    }
577
578    fn mk_bv(ctx: &'ctx Context, sz: u32, bits: &[bool]) -> Self {
579        unsafe {
580            let z3_ast = Z3_mk_bv_numeral(ctx.z3_ctx, sz, bits.as_ptr());
581            Z3_inc_ref(ctx.z3_ctx, z3_ast);
582            Ast { z3_ast, ctx }
583        }
584    }
585
586    fn mk_bool(ctx: &'ctx Context, b: bool) -> Self {
587        unsafe {
588            let z3_ast = if b { Z3_mk_true(ctx.z3_ctx) } else { Z3_mk_false(ctx.z3_ctx) };
589            Z3_inc_ref(ctx.z3_ctx, z3_ast);
590            Ast { z3_ast, ctx }
591        }
592    }
593
594    fn mk_not(&self) -> Self {
595        z3_unary_op!(Z3_mk_not, self)
596    }
597
598    fn mk_eq(&self, rhs: &Ast<'ctx>) -> Self {
599        z3_binary_op!(Z3_mk_eq, self, rhs)
600    }
601
602    fn mk_and(&self, rhs: &Ast<'ctx>) -> Self {
603        unsafe {
604            let z3_ast = Z3_mk_and(self.ctx.z3_ctx, 2, &[self.z3_ast, rhs.z3_ast] as *const Z3_ast);
605            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
606            Ast { z3_ast, ctx: self.ctx }
607        }
608    }
609
610    fn mk_or(&self, rhs: &Ast<'ctx>) -> Self {
611        unsafe {
612            let z3_ast = Z3_mk_or(self.ctx.z3_ctx, 2, &[self.z3_ast, rhs.z3_ast] as *const Z3_ast);
613            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
614            Ast { z3_ast, ctx: self.ctx }
615        }
616    }
617
618    fn extract(&self, hi: u32, lo: u32) -> Self {
619        unsafe {
620            let z3_ast = Z3_mk_extract(self.ctx.z3_ctx, hi, lo, self.z3_ast);
621            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
622            Ast { z3_ast, ctx: self.ctx }
623        }
624    }
625
626    fn zero_extend(&self, i: u32) -> Self {
627        unsafe {
628            let z3_ast = Z3_mk_zero_ext(self.ctx.z3_ctx, i, self.z3_ast);
629            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
630            Ast { z3_ast, ctx: self.ctx }
631        }
632    }
633
634    fn sign_extend(&self, i: u32) -> Self {
635        unsafe {
636            let z3_ast = Z3_mk_sign_ext(self.ctx.z3_ctx, i, self.z3_ast);
637            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
638            Ast { z3_ast, ctx: self.ctx }
639        }
640    }
641
642    fn ite(&self, true_exp: &Ast<'ctx>, false_exp: &Ast<'ctx>) -> Self {
643        unsafe {
644            let z3_ast = Z3_mk_ite(self.ctx.z3_ctx, self.z3_ast, true_exp.z3_ast, false_exp.z3_ast);
645            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
646            Ast { z3_ast, ctx: self.ctx }
647        }
648    }
649
650    fn mk_bvnot(&self) -> Self {
651        z3_unary_op!(Z3_mk_bvnot, self)
652    }
653
654    fn mk_bvand(&self, rhs: &Ast<'ctx>) -> Self {
655        z3_binary_op!(Z3_mk_bvand, self, rhs)
656    }
657
658    fn mk_bvor(&self, rhs: &Ast<'ctx>) -> Self {
659        z3_binary_op!(Z3_mk_bvor, self, rhs)
660    }
661
662    fn mk_bvxor(&self, rhs: &Ast<'ctx>) -> Self {
663        z3_binary_op!(Z3_mk_bvxor, self, rhs)
664    }
665
666    fn mk_bvnand(&self, rhs: &Ast<'ctx>) -> Self {
667        z3_binary_op!(Z3_mk_bvnand, self, rhs)
668    }
669
670    fn mk_bvnor(&self, rhs: &Ast<'ctx>) -> Self {
671        z3_binary_op!(Z3_mk_bvnor, self, rhs)
672    }
673
674    fn mk_bvxnor(&self, rhs: &Ast<'ctx>) -> Self {
675        z3_binary_op!(Z3_mk_bvxnor, self, rhs)
676    }
677
678    fn mk_bvneg(&self) -> Self {
679        z3_unary_op!(Z3_mk_bvneg, self)
680    }
681
682    fn mk_bvadd(&self, rhs: &Ast<'ctx>) -> Self {
683        z3_binary_op!(Z3_mk_bvadd, self, rhs)
684    }
685
686    fn mk_bvsub(&self, rhs: &Ast<'ctx>) -> Self {
687        z3_binary_op!(Z3_mk_bvsub, self, rhs)
688    }
689
690    fn mk_bvmul(&self, rhs: &Ast<'ctx>) -> Self {
691        z3_binary_op!(Z3_mk_bvmul, self, rhs)
692    }
693
694    fn mk_bvudiv(&self, rhs: &Ast<'ctx>) -> Self {
695        z3_binary_op!(Z3_mk_bvudiv, self, rhs)
696    }
697
698    fn mk_bvsdiv(&self, rhs: &Ast<'ctx>) -> Self {
699        z3_binary_op!(Z3_mk_bvsdiv, self, rhs)
700    }
701
702    fn mk_bvurem(&self, rhs: &Ast<'ctx>) -> Self {
703        z3_binary_op!(Z3_mk_bvurem, self, rhs)
704    }
705
706    fn mk_bvsrem(&self, rhs: &Ast<'ctx>) -> Self {
707        z3_binary_op!(Z3_mk_bvsrem, self, rhs)
708    }
709
710    fn mk_bvsmod(&self, rhs: &Ast<'ctx>) -> Self {
711        z3_binary_op!(Z3_mk_bvsmod, self, rhs)
712    }
713
714    fn mk_bvult(&self, rhs: &Ast<'ctx>) -> Self {
715        z3_binary_op!(Z3_mk_bvult, self, rhs)
716    }
717
718    fn mk_bvslt(&self, rhs: &Ast<'ctx>) -> Self {
719        z3_binary_op!(Z3_mk_bvslt, self, rhs)
720    }
721
722    fn mk_bvule(&self, rhs: &Ast<'ctx>) -> Self {
723        z3_binary_op!(Z3_mk_bvule, self, rhs)
724    }
725
726    fn mk_bvsle(&self, rhs: &Ast<'ctx>) -> Self {
727        z3_binary_op!(Z3_mk_bvsle, self, rhs)
728    }
729
730    fn mk_bvuge(&self, rhs: &Ast<'ctx>) -> Self {
731        z3_binary_op!(Z3_mk_bvuge, self, rhs)
732    }
733
734    fn mk_bvsge(&self, rhs: &Ast<'ctx>) -> Self {
735        z3_binary_op!(Z3_mk_bvsge, self, rhs)
736    }
737
738    fn mk_bvugt(&self, rhs: &Ast<'ctx>) -> Self {
739        z3_binary_op!(Z3_mk_bvugt, self, rhs)
740    }
741
742    fn mk_bvsgt(&self, rhs: &Ast<'ctx>) -> Self {
743        z3_binary_op!(Z3_mk_bvsgt, self, rhs)
744    }
745
746    fn mk_bvshl(&self, rhs: &Ast<'ctx>) -> Self {
747        z3_binary_op!(Z3_mk_bvshl, self, rhs)
748    }
749
750    fn mk_bvlshr(&self, rhs: &Ast<'ctx>) -> Self {
751        z3_binary_op!(Z3_mk_bvlshr, self, rhs)
752    }
753
754    fn mk_bvashr(&self, rhs: &Ast<'ctx>) -> Self {
755        z3_binary_op!(Z3_mk_bvashr, self, rhs)
756    }
757
758    fn mk_concat(&self, rhs: &Ast<'ctx>) -> Self {
759        z3_binary_op!(Z3_mk_concat, self, rhs)
760    }
761
762    fn mk_select(&self, index: &Ast<'ctx>) -> Self {
763        z3_binary_op!(Z3_mk_select, self, index)
764    }
765
766    fn mk_store(&self, index: &Ast<'ctx>, val: &Ast<'ctx>) -> Self {
767        unsafe {
768            let z3_ast = Z3_mk_store(self.ctx.z3_ctx, self.z3_ast, index.z3_ast, val.z3_ast);
769            Z3_inc_ref(self.ctx.z3_ctx, z3_ast);
770            Ast { z3_ast, ctx: self.ctx }
771        }
772    }
773
774    fn get_bool_value(&self) -> Option<bool> {
775        unsafe {
776            match Z3_get_bool_value(self.ctx.z3_ctx, self.z3_ast) {
777                Z3_L_TRUE => Some(true),
778                Z3_L_FALSE => Some(false),
779                _ => None,
780            }
781        }
782    }
783
784    fn get_numeral_u64(&self) -> Result<u64, ExecError> {
785        let mut v: u64 = 0;
786        unsafe {
787            if Z3_get_numeral_uint64(self.ctx.z3_ctx, self.z3_ast, &mut v) {
788                Ok(v)
789            } else {
790                Err(self.ctx.error())
791            }
792        }
793    }
794}
795
796impl<'ctx> Drop for Ast<'ctx> {
797    fn drop(&mut self) {
798        unsafe { Z3_dec_ref(self.ctx.z3_ctx, self.z3_ast) }
799    }
800}
801
802/// The Solver type handles all interaction with Z3. It mimics
803/// interacting with Z3 via the subset of the SMTLIB 2.0 format we
804/// care about.
805///
806/// For example:
807/// ```
808/// # use isla_lib::bitvector::b64::B64;
809/// # use isla_lib::smt::smtlib::Exp::*;
810/// # use isla_lib::smt::smtlib::Def::*;
811/// # use isla_lib::smt::smtlib::*;
812/// # use isla_lib::smt::*;
813/// # let x = Sym::from_u32(0);
814/// let cfg = Config::new();
815/// let ctx = Context::new(cfg);
816/// let mut solver = Solver::<B64>::new(&ctx);
817/// // (declare-const v0 Bool)
818/// solver.add(DeclareConst(x, Ty::Bool));
819/// // (assert v0)
820/// solver.add(Assert(Var(x)));
821/// // (check-sat)
822/// assert!(solver.check_sat() == SmtResult::Sat)
823/// ```
824///
825/// The other thing the Solver type does is maintain a trace of
826/// interactions with Z3, which can be checkpointed and replayed by
827/// another solver. This `Checkpoint` type is safe to be sent between
828/// threads.
829///
830/// For example:
831/// ```
832/// # use isla_lib::bitvector::b64::B64;
833/// # use isla_lib::smt::smtlib::Exp::*;
834/// # use isla_lib::smt::smtlib::Def::*;
835/// # use isla_lib::smt::smtlib::*;
836/// # use isla_lib::smt::*;
837/// # let x = Sym::from_u32(0);
838/// let point = {
839///     let cfg = Config::new();
840///     let ctx = Context::new(cfg);
841///     let mut solver = Solver::<B64>::new(&ctx);
842///     solver.add(DeclareConst(x, Ty::Bool));
843///     solver.add(Assert(Var(x)));
844///     solver.add(Assert(Not(Box::new(Var(x)))));
845///     checkpoint(&mut solver)
846/// };
847/// let cfg = Config::new();
848/// let ctx = Context::new(cfg);
849/// let mut solver = Solver::from_checkpoint(&ctx, point);
850/// assert!(solver.check_sat() == SmtResult::Unsat);
851pub struct Solver<'ctx, B> {
852    trace: Trace<B>,
853    next_var: u32,
854    cycles: i128,
855    decls: HashMap<Sym, Ast<'ctx>>,
856    func_decls: HashMap<Sym, FuncDecl<'ctx>>,
857    enums: Enums<'ctx>,
858    enum_map: HashMap<usize, usize>,
859    z3_solver: Z3_solver,
860    ctx: &'ctx Context,
861}
862
863impl<'ctx, B> Drop for Solver<'ctx, B> {
864    fn drop(&mut self) {
865        unsafe {
866            Z3_solver_dec_ref(self.ctx.z3_ctx, self.z3_solver);
867        }
868    }
869}
870
871/// Interface for extracting information from Z3 models.
872///
873/// Model generation should be turned on in advance.  This is
874/// currently Z3's default, but it's best to make sure:
875///
876/// ```
877/// # use isla_lib::bitvector::b64::B64;
878/// # use isla_lib::smt::smtlib::Exp::*;
879/// # use isla_lib::smt::smtlib::Def::*;
880/// # use isla_lib::smt::smtlib::*;
881/// # use isla_lib::smt::*;
882/// # let x = Sym::from_u32(0);
883/// let mut cfg = Config::new();
884/// cfg.set_param_value("model", "true");
885/// let ctx = Context::new(cfg);
886/// let mut solver = Solver::<B64>::new(&ctx);
887/// solver.add(DeclareConst(x, Ty::BitVec(4)));
888/// solver.add(Assert(Bvsgt(Box::new(Var(x)), Box::new(Bits(vec![false,false,true,false])))));
889/// assert!(solver.check_sat() == SmtResult::Sat);
890/// let mut model = Model::new(&solver);
891/// let var0 = model.get_var(x).unwrap().unwrap();
892/// ```
893pub struct Model<'ctx, B> {
894    z3_model: Z3_model,
895    solver: &'ctx Solver<'ctx, B>,
896    ctx: &'ctx Context,
897}
898
899impl<'ctx, B> Drop for Model<'ctx, B> {
900    fn drop(&mut self) {
901        unsafe {
902            Z3_model_dec_ref(self.ctx.z3_ctx, self.z3_model);
903        }
904    }
905}
906
907// This implements Debug rather than Display because it displays the internal
908// variable names (albeit with the same numbers that appear in the trace).
909impl<'ctx, B> fmt::Debug for Model<'ctx, B> {
910    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
911        unsafe {
912            let z3_string = CStr::from_ptr(Z3_model_to_string(self.ctx.z3_ctx, self.z3_model));
913            write!(f, "{}", z3_string.to_string_lossy())
914        }
915    }
916}
917
918impl<'ctx, B: BV> Model<'ctx, B> {
919    pub fn new(solver: &'ctx Solver<'ctx, B>) -> Self {
920        unsafe {
921            let z3_model = Z3_solver_get_model(solver.ctx.z3_ctx, solver.z3_solver);
922            Z3_model_inc_ref(solver.ctx.z3_ctx, z3_model);
923            Model { z3_model, solver, ctx: solver.ctx }
924        }
925    }
926
927    #[allow(clippy::needless_range_loop)]
928    fn get_large_bv(&mut self, ast: Ast, size: u32) -> Result<Vec<bool>, ExecError> {
929        let mut i = 0;
930        let size = size.try_into().unwrap();
931        let mut result = vec![false; size];
932        while i < size {
933            let hi = std::cmp::min(size, i + 64);
934            let hi32: u32 = hi.try_into().unwrap();
935            let extract_ast = ast.extract(hi32 - 1, i.try_into().unwrap());
936            let result_ast: Ast;
937
938            unsafe {
939                let mut result_z3_ast: Z3_ast = ptr::null_mut();
940                if !Z3_model_eval(self.ctx.z3_ctx, self.z3_model, extract_ast.z3_ast, true, &mut result_z3_ast) {
941                    return Err(self.ctx.error());
942                }
943                Z3_inc_ref(self.ctx.z3_ctx, result_z3_ast);
944                result_ast = Ast { z3_ast: result_z3_ast, ctx: self.ctx };
945            }
946            let v = result_ast.get_numeral_u64()?;
947            for j in i..hi {
948                result[j] = (v >> (j - i) & 1) == 1;
949            }
950            i += 64;
951        }
952        Ok(result)
953    }
954
955    pub fn get_var(&mut self, var: Sym) -> Result<Option<Exp>, ExecError> {
956        let var_ast = match self.solver.decls.get(&var) {
957            None => return Err(ExecError::Type(format!("Unbound variable {:?}", &var))),
958            Some(ast) => ast.clone(),
959        };
960        self.get_ast(var_ast)
961    }
962
963    pub fn get_exp(&mut self, exp: &Exp) -> Result<Option<Exp>, ExecError> {
964        let ast = self.solver.translate_exp(exp);
965        self.get_ast(ast)
966    }
967
968    // Requiring the model to be mutable as I expect Z3 will alter the underlying data
969    fn get_ast(&mut self, var_ast: Ast) -> Result<Option<Exp>, ExecError> {
970        unsafe {
971            let z3_ctx = self.ctx.z3_ctx;
972            let mut z3_ast: Z3_ast = ptr::null_mut();
973            if !Z3_model_eval(z3_ctx, self.z3_model, var_ast.z3_ast, false, &mut z3_ast) {
974                return Err(self.ctx.error());
975            }
976            Z3_inc_ref(z3_ctx, z3_ast);
977
978            let ast = Ast { z3_ast, ctx: self.ctx };
979
980            let sort = Z3_get_sort(z3_ctx, ast.z3_ast);
981            Z3_inc_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, sort));
982            let sort_kind = Z3_get_sort_kind(z3_ctx, sort);
983
984            let result = if sort_kind == SortKind::BV && Z3_is_numeral_ast(z3_ctx, z3_ast) {
985                let size = Z3_get_bv_sort_size(z3_ctx, sort);
986                if size > 64 {
987                    let v = self.get_large_bv(ast, size)?;
988                    Ok(Some(Exp::Bits(v)))
989                } else {
990                    let result = ast.get_numeral_u64()?;
991                    Ok(Some(Exp::Bits64(B64::new(result, size))))
992                }
993            } else if sort_kind == SortKind::Bool && Z3_is_numeral_ast(z3_ctx, z3_ast) {
994                Ok(Some(Exp::Bool(ast.get_bool_value().unwrap())))
995            } else if sort_kind == SortKind::Bool || sort_kind == SortKind::BV {
996                // Model did not need to assign an interpretation to this variable
997                Ok(None)
998            } else if sort_kind == SortKind::Datatype {
999                let func_decl = Z3_get_app_decl(z3_ctx, Z3_to_app(z3_ctx, z3_ast));
1000                Z3_inc_ref(z3_ctx, Z3_func_decl_to_ast(z3_ctx, func_decl));
1001
1002                let mut result = Ok(None);
1003
1004                // Scan all enumerations to find the enum_id (which is
1005                // the index in the enums vector) and member number.
1006                'outer: for (enum_id, enumeration) in self.solver.enums.enums.iter().enumerate() {
1007                    for (i, member) in enumeration.consts.iter().enumerate() {
1008                        if Z3_is_eq_func_decl(z3_ctx, func_decl, *member) {
1009                            result = Ok(Some(Exp::Enum(EnumMember { enum_id, member: i })));
1010                            break 'outer;
1011                        }
1012                    }
1013                }
1014
1015                Z3_dec_ref(z3_ctx, Z3_func_decl_to_ast(z3_ctx, func_decl));
1016                result
1017            } else {
1018                Err(ExecError::Type("get_ast".to_string()))
1019            };
1020
1021            Z3_dec_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, sort));
1022            result
1023        }
1024    }
1025}
1026
/// Outcome of a satisfiability check, mirroring Z3's three-valued result.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SmtResult {
    /// The current assertions are satisfiable.
    Sat,
    /// The current assertions are unsatisfiable.
    Unsat,
    /// Z3 could not decide satisfiability.
    Unknown,
}
1033
1034use SmtResult::*;
1035
1036impl SmtResult {
1037    pub fn is_sat(self) -> Result<bool, ExecError> {
1038        match self {
1039            Sat => Ok(true),
1040            Unsat => Ok(false),
1041            Unknown => Err(ExecError::Z3Unknown),
1042        }
1043    }
1044
1045    pub fn is_unsat(self) -> Result<bool, ExecError> {
1046        match self {
1047            Sat => Ok(false),
1048            Unsat => Ok(true),
1049            Unknown => Err(ExecError::Z3Unknown),
1050        }
1051    }
1052
1053    pub fn is_unknown(self) -> bool {
1054        self == Unknown
1055    }
1056}
1057
/// NUL-terminated name of the Z3 `qfaufbv` tactic, passed to `Z3_mk_tactic`
/// by `Solver::new` to build the underlying Z3 solver.
static QFAUFBV_STR: &[u8] = b"qfaufbv\0";
1059
1060impl<'ctx, B: BV> Solver<'ctx, B> {
1061    pub fn new(ctx: &'ctx Context) -> Self {
1062        unsafe {
1063            let mut major: c_uint = 0;
1064            let mut minor: c_uint = 0;
1065            let mut build: c_uint = 0;
1066            let mut revision: c_uint = 0;
1067            Z3_get_version(&mut major, &mut minor, &mut build, &mut revision);
1068
1069            // The QF_AUFBV solver has good performance on our problems, but we need to initialise it
1070            // using a tactic rather than the logic name to ensure that the enumerations are supported,
1071            // otherwise Z3 may crash.
1072            let qfaufbv_tactic = Z3_mk_tactic(ctx.z3_ctx, CStr::from_bytes_with_nul_unchecked(QFAUFBV_STR).as_ptr());
1073            Z3_tactic_inc_ref(ctx.z3_ctx, qfaufbv_tactic);
1074            let z3_solver = Z3_mk_solver_from_tactic(ctx.z3_ctx, qfaufbv_tactic);
1075            Z3_solver_inc_ref(ctx.z3_ctx, z3_solver);
1076
1077            Solver {
1078                ctx,
1079                z3_solver,
1080                next_var: 0,
1081                cycles: 0,
1082                trace: Trace::new(),
1083                decls: HashMap::new(),
1084                func_decls: HashMap::new(),
1085                enums: Enums::new(ctx),
1086                enum_map: HashMap::new(),
1087            }
1088        }
1089    }
1090
1091    pub fn fresh(&mut self) -> Sym {
1092        let n = self.next_var;
1093        self.next_var += 1;
1094        Sym { id: n }
1095    }
1096
1097    fn translate_exp(&self, exp: &Exp) -> Ast<'ctx> {
1098        use Exp::*;
1099        match exp {
1100            Var(v) => match self.decls.get(v) {
1101                None => panic!("Could not get Z3 func_decl {}", *v),
1102                Some(ast) => ast.clone(),
1103            },
1104            Bits(bv) => Ast::mk_bv(self.ctx, bv.len().try_into().unwrap(), &bv),
1105            Bits64(bv) => Ast::mk_bv_u64(self.ctx, bv.len(), bv.lower_u64()),
1106            Enum(e) => Ast::mk_enum_member(&self.enums, e.enum_id, e.member),
1107            Bool(b) => Ast::mk_bool(self.ctx, *b),
1108            Not(exp) => Ast::mk_not(&self.translate_exp(exp)),
1109            Eq(lhs, rhs) => Ast::mk_eq(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1110            Neq(lhs, rhs) => Ast::mk_not(&Ast::mk_eq(&self.translate_exp(lhs), &self.translate_exp(rhs))),
1111            And(lhs, rhs) => Ast::mk_and(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1112            Or(lhs, rhs) => Ast::mk_or(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1113            Bvnot(exp) => Ast::mk_bvnot(&self.translate_exp(exp)),
1114            Bvand(lhs, rhs) => Ast::mk_bvand(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1115            Bvor(lhs, rhs) => Ast::mk_bvor(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1116            Bvxor(lhs, rhs) => Ast::mk_bvxor(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1117            Bvnand(lhs, rhs) => Ast::mk_bvnand(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1118            Bvnor(lhs, rhs) => Ast::mk_bvnor(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1119            Bvxnor(lhs, rhs) => Ast::mk_bvxnor(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1120            Bvneg(exp) => Ast::mk_bvneg(&self.translate_exp(exp)),
1121            Bvadd(lhs, rhs) => Ast::mk_bvadd(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1122            Bvsub(lhs, rhs) => Ast::mk_bvsub(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1123            Bvmul(lhs, rhs) => Ast::mk_bvmul(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1124            Bvudiv(lhs, rhs) => Ast::mk_bvudiv(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1125            Bvsdiv(lhs, rhs) => Ast::mk_bvsdiv(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1126            Bvurem(lhs, rhs) => Ast::mk_bvurem(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1127            Bvsrem(lhs, rhs) => Ast::mk_bvsrem(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1128            Bvsmod(lhs, rhs) => Ast::mk_bvsmod(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1129            Bvult(lhs, rhs) => Ast::mk_bvult(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1130            Bvslt(lhs, rhs) => Ast::mk_bvslt(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1131            Bvule(lhs, rhs) => Ast::mk_bvule(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1132            Bvsle(lhs, rhs) => Ast::mk_bvsle(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1133            Bvuge(lhs, rhs) => Ast::mk_bvuge(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1134            Bvsge(lhs, rhs) => Ast::mk_bvsge(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1135            Bvugt(lhs, rhs) => Ast::mk_bvugt(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1136            Bvsgt(lhs, rhs) => Ast::mk_bvsgt(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1137            Extract(hi, lo, bv) => self.translate_exp(bv).extract(*hi, *lo),
1138            ZeroExtend(i, bv) => self.translate_exp(bv).zero_extend(*i),
1139            SignExtend(i, bv) => self.translate_exp(bv).sign_extend(*i),
1140            Bvshl(lhs, rhs) => Ast::mk_bvshl(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1141            Bvlshr(lhs, rhs) => Ast::mk_bvlshr(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1142            Bvashr(lhs, rhs) => Ast::mk_bvashr(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1143            Concat(lhs, rhs) => Ast::mk_concat(&self.translate_exp(lhs), &self.translate_exp(rhs)),
1144            Ite(cond, t, f) => self.translate_exp(cond).ite(&self.translate_exp(t), &self.translate_exp(f)),
1145            App(f, args) => {
1146                let args_ast: Vec<_> = args.iter().map(|arg| self.translate_exp(arg)).collect();
1147                match self.func_decls.get(f) {
1148                    None => panic!("Could not get Z3 func_decl {}", *f),
1149                    Some(fd) => Ast::mk_app(&fd, &args_ast),
1150                }
1151            }
1152            Select(array, index) => Ast::mk_select(&self.translate_exp(array), &self.translate_exp(index)),
1153            Store(array, index, val) => {
1154                Ast::mk_store(&self.translate_exp(array), &self.translate_exp(index), &self.translate_exp(val))
1155            }
1156        }
1157    }
1158
1159    fn assert(&mut self, exp: &Exp) {
1160        let ast = self.translate_exp(exp);
1161        unsafe {
1162            Z3_solver_assert(self.ctx.z3_ctx, self.z3_solver, ast.z3_ast);
1163        }
1164    }
1165
1166    pub fn get_enum(&mut self, size: usize) -> usize {
1167        match self.enum_map.get(&size) {
1168            Some(enum_id) => *enum_id,
1169            None => {
1170                let name = self.fresh();
1171                self.add(Def::DefineEnum(name, size));
1172                self.enums.enums.len() - 1
1173            }
1174        }
1175    }
1176
1177    fn add_internal(&mut self, def: &Def) {
1178        match &def {
1179            Def::Assert(exp) => self.assert(exp),
1180            Def::DeclareConst(v, ty) => {
1181                let fd = FuncDecl::new(&self.ctx, *v, &self.enums, &[], ty);
1182                self.decls.insert(*v, Ast::mk_constant(&fd));
1183            }
1184            Def::DeclareFun(v, arg_tys, result_ty) => {
1185                let fd = FuncDecl::new(&self.ctx, *v, &self.enums, arg_tys, result_ty);
1186                self.func_decls.insert(*v, fd);
1187            }
1188            Def::DefineConst(v, exp) => {
1189                let ast = self.translate_exp(exp);
1190                self.decls.insert(*v, ast);
1191            }
1192            Def::DefineEnum(name, size) => {
1193                let members: Vec<Sym> = (0..*size).map(|_| self.fresh()).collect();
1194                self.enums.add_enum(*name, &members);
1195                self.enum_map.insert(*size, self.enums.enums.len() - 1);
1196            }
1197        }
1198    }
1199
1200    pub fn length(&mut self, v: Sym) -> Option<u32> {
1201        match self.decls.get(&v) {
1202            Some(ast) => unsafe {
1203                let z3_ctx = self.ctx.z3_ctx;
1204                let z3_sort = Z3_get_sort(z3_ctx, ast.z3_ast);
1205                Z3_inc_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, z3_sort));
1206                if Z3_get_sort_kind(z3_ctx, z3_sort) == SortKind::BV {
1207                    let sz = Z3_get_bv_sort_size(z3_ctx, z3_sort);
1208                    Z3_dec_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, z3_sort));
1209                    Some(sz)
1210                } else {
1211                    Z3_dec_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, z3_sort));
1212                    None
1213                }
1214            },
1215            None => None,
1216        }
1217    }
1218
1219    pub fn is_bitvector(&mut self, v: Sym) -> bool {
1220        match self.decls.get(&v) {
1221            Some(ast) => unsafe {
1222                let z3_ctx = self.ctx.z3_ctx;
1223                let z3_sort = Z3_get_sort(z3_ctx, ast.z3_ast);
1224                Z3_inc_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, z3_sort));
1225                let result = Z3_get_sort_kind(z3_ctx, z3_sort) == SortKind::BV;
1226                Z3_dec_ref(z3_ctx, Z3_sort_to_ast(z3_ctx, z3_sort));
1227                result
1228            },
1229            None => false,
1230        }
1231    }
1232
1233    pub fn add(&mut self, def: Def) {
1234        self.add_internal(&def);
1235        self.trace.head.push(Event::Smt(def))
1236    }
1237
1238    pub fn declare_const(&mut self, ty: Ty) -> Sym {
1239        let sym = self.fresh();
1240        self.add(Def::DeclareConst(sym, ty));
1241        sym
1242    }
1243
1244    pub fn define_const(&mut self, exp: Exp) -> Sym {
1245        let sym = self.fresh();
1246        self.add(Def::DefineConst(sym, exp));
1247        sym
1248    }
1249
1250    pub fn assert_eq(&mut self, lhs: Exp, rhs: Exp) {
1251        self.add(Def::Assert(Exp::Eq(Box::new(lhs), Box::new(rhs))))
1252    }
1253
1254    pub fn cycle_count(&mut self) {
1255        self.cycles += 1;
1256        self.add_event(Event::Cycle)
1257    }
1258
1259    pub fn get_cycle_count(&self) -> i128 {
1260        self.cycles
1261    }
1262
1263    fn add_event_internal(&mut self, event: &Event<B>) {
1264        if let Event::Smt(def) = event {
1265            self.add_internal(def)
1266        };
1267    }
1268
1269    pub fn add_event(&mut self, event: Event<B>) {
1270        self.add_event_internal(&event);
1271        self.trace.head.push(event)
1272    }
1273
1274    fn replay(&mut self, num: usize, trace: Arc<Option<Trace<B>>>) {
1275        // Some extra work would be required to replay on top of
1276        // another trace, so until we need to do that we'll check it's
1277        // empty:
1278        assert!(self.trace.checkpoints == 0 && self.trace.head.is_empty());
1279        let mut checkpoints: Vec<&[Event<B>]> = Vec::with_capacity(num);
1280        let mut next = &*trace;
1281        loop {
1282            match next {
1283                None => break,
1284                Some(tr) => {
1285                    checkpoints.push(&tr.head);
1286                    next = &*tr.tail
1287                }
1288            }
1289        }
1290        assert!(checkpoints.len() == num);
1291        for events in checkpoints.iter().rev() {
1292            for event in *events {
1293                self.add_event_internal(&event)
1294            }
1295        }
1296        self.trace.checkpoints = num;
1297        self.trace.tail = trace
1298    }
1299
1300    pub fn from_checkpoint(ctx: &'ctx Context, Checkpoint { num, next_var, trace }: Checkpoint<B>) -> Self {
1301        let mut solver = Solver::new(ctx);
1302        solver.replay(num, trace);
1303        solver.next_var = next_var;
1304        solver
1305    }
1306
1307    pub fn check_sat_with(&mut self, exp: &Exp) -> SmtResult {
1308        let ast = self.translate_exp(exp);
1309        unsafe {
1310            let result = Z3_solver_check_assumptions(self.ctx.z3_ctx, self.z3_solver, 1, &ast.z3_ast);
1311            if result == Z3_L_TRUE {
1312                Sat
1313            } else if result == Z3_L_FALSE {
1314                Unsat
1315            } else {
1316                Unknown
1317            }
1318        }
1319    }
1320
1321    pub fn trace(&self) -> &Trace<B> {
1322        &self.trace
1323    }
1324
1325    pub fn check_sat(&mut self) -> SmtResult {
1326        unsafe {
1327            let result = Z3_solver_check(self.ctx.z3_ctx, self.z3_solver);
1328            if result == Z3_L_TRUE {
1329                Sat
1330            } else if result == Z3_L_FALSE {
1331                Unsat
1332            } else {
1333                Unknown
1334            }
1335        }
1336    }
1337
1338    pub fn dump_solver(&mut self, filename: &str) {
1339        let mut file = std::fs::File::create(filename).expect("Failed to open solver dump file");
1340        unsafe {
1341            let s = Z3_solver_to_string(self.ctx.z3_ctx, self.z3_solver);
1342            let cs = CStr::from_ptr(s);
1343            file.write_all(cs.to_bytes()).expect("Failed to write solver dump");
1344        }
1345    }
1346
1347    pub fn dump_solver_with(&mut self, filename: &str, exp: &Exp) {
1348        let mut file = std::fs::File::create(filename).expect("Failed to open solver dump file");
1349        unsafe {
1350            let s = Z3_solver_to_string(self.ctx.z3_ctx, self.z3_solver);
1351            let cs = CStr::from_ptr(s);
1352            file.write_all(cs.to_bytes()).expect("Failed to write solver dump");
1353            writeln!(file, "{}", self.exp_to_str(exp)).expect("Failed to write exp");
1354        }
1355    }
1356
1357    pub fn exp_to_str(&mut self, exp: &Exp) -> String {
1358        let ast = self.translate_exp(exp);
1359        let cs;
1360        unsafe {
1361            let s = Z3_ast_to_string(ast.ctx.z3_ctx, ast.z3_ast);
1362            cs = CStr::from_ptr(s);
1363        }
1364        cs.to_string_lossy().to_string()
1365    }
1366}
1367
1368pub fn checkpoint<B: BV>(solver: &mut Solver<B>) -> Checkpoint<B> {
1369    solver.trace.checkpoint(solver.next_var)
1370}
1371
1372/// This function just calls Z3_finalize_memory(). It's useful because
1373/// by calling it before we exit, we can check whether we are leaking
1374/// memory while interacting with Z3 objects.
1375///
1376/// # Safety
1377///
1378/// Shoud only be called just before exiting.
1379pub unsafe fn finalize_solver() {
1380    Z3_finalize_memory()
1381}
1382
#[cfg(test)]
mod tests {
    use crate::bitvector::b64::B64;

    use super::Def::*;
    use super::Exp::*;
    use super::*;

    // Build an `Exp::Bits` from a string of '0'/'1' characters written
    // most-significant bit first (like a binary literal); any other
    // characters are ignored.
    macro_rules! bv {
        ( $bv_string:expr ) => {{
            let bits: Vec<bool> = $bv_string
                .chars()
                .rev()
                .filter_map(|c| match c {
                    '1' => Some(true),
                    '0' => Some(false),
                    _ => None,
                })
                .collect();
            Bits(bits)
        }};
    }

    fn var(id: u32) -> Exp {
        Var(Sym::from_u32(id))
    }

    // Two distinct bit-vector literals can never be equal.
    #[test]
    fn bv_macro() {
        let cfg = Config::new();
        let ctx = Context::new(cfg);
        let mut solver = Solver::<B64>::new(&ctx);
        solver.add(Assert(Eq(Box::new(bv!("0110")), Box::new(bv!("1001")))));
        assert_eq!(solver.check_sat(), Unsat);
    }

    // Values read back from a model (including a >64-bit vector and an
    // unconstrained variable) can be asserted again and stay satisfiable.
    #[test]
    fn get_const() {
        let mut cfg = Config::new();
        cfg.set_param_value("model", "true");
        let ctx = Context::new(cfg);
        let mut solver = Solver::<B64>::new(&ctx);
        solver.add(DeclareConst(Sym::from_u32(0), Ty::BitVec(4)));
        solver.add(DeclareConst(Sym::from_u32(1), Ty::BitVec(1)));
        solver.add(DeclareConst(Sym::from_u32(2), Ty::BitVec(5)));
        solver.add(DeclareConst(Sym::from_u32(3), Ty::BitVec(5)));
        solver.add(DeclareConst(Sym::from_u32(4), Ty::BitVec(257)));
        solver.add(Assert(Eq(Box::new(bv!("0110")), Box::new(var(0)))));
        solver.add(Assert(Eq(Box::new(var(2)), Box::new(var(3)))));
        let big_bv = Box::new(SignExtend(251, Box::new(Bits(vec![true, false, false, true, false, true]))));
        solver.add(Assert(Eq(Box::new(var(4)), big_bv)));
        assert_eq!(solver.check_sat(), Sat);
        let (v0, v2, v3, v4);
        {
            let mut model = Model::new(&solver);
            v0 = model.get_var(Sym::from_u32(0)).unwrap().unwrap();
            assert!(model.get_var(Sym::from_u32(1)).unwrap().is_none());
            v2 = model.get_var(Sym::from_u32(2)).unwrap().unwrap();
            v3 = model.get_var(Sym::from_u32(3)).unwrap().unwrap();
            v4 = model.get_var(Sym::from_u32(4)).unwrap().unwrap();
        }
        solver.add(Assert(Eq(Box::new(var(0)), Box::new(v0))));
        solver.add(Assert(Eq(Box::new(var(2)), Box::new(v2))));
        solver.add(Assert(Eq(Box::new(var(3)), Box::new(v3))));
        solver.add(Assert(Eq(Box::new(var(4)), Box::new(v4))));
        match solver.check_sat() {
            Sat => (),
            _ => panic!("Round-trip failed, trace {:?}", solver.trace()),
        }
    }

    // Enumeration values read back from a model round-trip, and an
    // unconstrained enum variable is reported as `None`.
    #[test]
    fn get_enum_const() {
        let mut cfg = Config::new();
        cfg.set_param_value("model", "true");
        let ctx = Context::new(cfg);
        let mut solver = Solver::<B64>::new(&ctx);
        let e = solver.get_enum(3);
        let v0 = solver.declare_const(Ty::Enum(e));
        let v1 = solver.declare_const(Ty::Enum(e));
        let v2 = solver.declare_const(Ty::Enum(e));
        solver.assert_eq(Var(v0), Var(v1));
        assert_eq!(solver.check_sat(), Sat);
        let (m0, m1) = {
            let mut model = Model::new(&solver);
            assert!(model.get_var(v2).unwrap().is_none());
            (model.get_var(v0).unwrap().unwrap(), model.get_var(v1).unwrap().unwrap())
        };
        solver.assert_eq(Var(v0), m0);
        solver.assert_eq(Var(v1), m1);
        match solver.check_sat() {
            Sat => (),
            _ => panic!("Round-trip failed, trace {:?}", solver.trace()),
        }
    }

    // An uninterpreted function applied to equal arguments yields equal results.
    #[test]
    fn smt_func() {
        let mut cfg = Config::new();
        cfg.set_param_value("model", "true");
        let ctx = Context::new(cfg);
        let mut solver = Solver::<B64>::new(&ctx);
        solver.add(DeclareFun(Sym::from_u32(0), vec![Ty::BitVec(2), Ty::BitVec(4)], Ty::BitVec(8)));
        solver.add(DeclareConst(Sym::from_u32(1), Ty::BitVec(8)));
        solver.add(DeclareConst(Sym::from_u32(2), Ty::BitVec(2)));
        solver
            .add(Assert(Eq(Box::new(App(Sym::from_u32(0), vec![bv!("10"), bv!("0110")])), Box::new(bv!("01011011")))));
        solver.add(Assert(Eq(Box::new(App(Sym::from_u32(0), vec![var(2), bv!("0110")])), Box::new(var(1)))));
        solver.add(Assert(Eq(Box::new(var(2)), Box::new(bv!("10")))));
        assert_eq!(solver.check_sat(), Sat);
        let mut model = Model::new(&solver);
        let val = model.get_var(Sym::from_u32(1)).unwrap().unwrap();
        assert!(matches!(val, Bits64(bv) if bv == B64::new(0b01011011, 8)));
    }

    // Selecting at the index just stored to must return the stored value.
    #[test]
    fn array() {
        let cfg = Config::new();
        let ctx = Context::new(cfg);
        let mut solver = Solver::<B64>::new(&ctx);
        solver.add(DeclareConst(Sym::from_u32(0), Ty::Array(Box::new(Ty::BitVec(3)), Box::new(Ty::BitVec(4)))));
        solver.add(DeclareConst(Sym::from_u32(1), Ty::BitVec(3)));
        solver.add(Assert(Neq(
            Box::new(Select(
                Box::new(Store(Box::new(var(0)), Box::new(var(1)), Box::new(bv!("0101")))),
                Box::new(var(1)),
            )),
            Box::new(bv!("0101")),
        )));
        assert_eq!(solver.check_sat(), Unsat);
    }
}