1use super::{
2    Assignment, BVOperator, BitVector, Formula, FormulaVisitor, SmtSolver, Solver, SolverError,
3    SymbolId,
4};
5use bytesize::ByteSize;
6use std::{
7    collections::HashMap,
8    io::{stdout, BufWriter, Write},
9    sync::{Arc, Mutex},
10};
11use strum::{EnumString, EnumVariantNames, IntoStaticStr};
12
/// Kind of SMT target selectable through [`SmtGenerationOptions`].
///
/// The strum derives generate the string conversions for the variants; their
/// textual names are rendered in kebab-case (see `serialize_all` below).
///
/// NOTE(review): presumably this selects which solver the generated SMT-LIB
/// output is tailored for — confirm against the callers.
#[derive(Debug, EnumString, EnumVariantNames, IntoStaticStr)]
#[strum(serialize_all = "kebab_case")]
pub enum SmtType {
    /// Solver-independent variant; always available.
    Generic,
    /// Only compiled in when the `boolector` feature is enabled.
    #[cfg(feature = "boolector")]
    Boolector,
    /// Only compiled in when the `z3` feature is enabled.
    #[cfg(feature = "z3")]
    Z3,
}
22
/// Options bundle for SMT generation.
///
/// The `Default` implementation fills every field from the constants in the
/// `defaults` module. How the individual fields are consumed is not visible in
/// this file.
pub struct SmtGenerationOptions {
    /// Selected SMT target.
    pub smt_type: SmtType,
    /// Memory size setting (presumably the memory of the analysed program — TODO confirm).
    pub memory_size: ByteSize,
    /// Upper bound on the execution depth (unit not visible here — TODO confirm).
    pub max_execution_depth: u64,
}
28
/// Default values used by `SmtGenerationOptions::default`.
pub mod defaults {
    use super::*;

    /// Default memory size: `bytesize::MIB` bytes, i.e. one mebibyte.
    pub const MEMORY_SIZE: ByteSize = ByteSize(bytesize::MIB);
    /// Default maximum execution depth.
    pub const MAX_EXECUTION_DEPTH: u64 = 1000;
    /// Default SMT type: the solver-independent variant.
    pub const SMT_TYPE: SmtType = SmtType::Generic;
}
36
37impl Default for SmtGenerationOptions {
38    fn default() -> Self {
39        Self {
40            memory_size: defaults::MEMORY_SIZE,
41            max_execution_depth: defaults::MAX_EXECUTION_DEPTH,
42            smt_type: defaults::SMT_TYPE,
43        }
44    }
45}
46
/// `Solver` implementation that does not solve anything itself: it renders each
/// formula as SMT-LIB text into the wrapped writer.
pub struct SmtWriter {
    /// Shared, lockable output sink; it is also handed to the `SmtPrinter`
    /// visitor while a formula is being written.
    output: Arc<Mutex<dyn Write + Send>>,
}
50
/// SMT-LIB command selecting the quantifier-free bit-vector logic; written once,
/// right after the optional prefix, when an `SmtWriter` is created.
const SET_LOGIC_STATEMENT: &str = "(set-logic QF_BV)";
52
53impl SmtWriter {
54    pub fn new<W: 'static>(write: W) -> Result<Self, SolverError>
55    where
56        W: Write + Send,
57    {
58        Self::new_with_smt_prefix(write, "")
59    }
60
61    pub fn new_for_solver<S, W: 'static>(write: W) -> Result<Self, SolverError>
62    where
63        S: SmtSolver,
64        W: Write + Send,
65    {
66        Self::new_with_smt_prefix(write, S::smt_options())
67    }
68
69    fn new_with_smt_prefix<W: 'static>(write: W, prefix: &str) -> Result<Self, SolverError>
70    where
71        W: Write + Send,
72    {
73        let mut writer = BufWriter::new(write);
74
75        writeln!(writer, "{}{}", prefix, SET_LOGIC_STATEMENT).map_err(SolverError::from)?;
76
77        let output = Arc::new(Mutex::new(writer));
78
79        Ok(Self { output })
80    }
81}
82
83impl Default for SmtWriter {
84    fn default() -> Self {
85        Self::new_with_smt_prefix(stdout(), "").expect("stdout should not fail")
86    }
87}
88
89impl Solver for SmtWriter {
90    fn name() -> &'static str {
91        "External"
92    }
93
94    fn solve_impl<F: Formula>(&self, formula: &F) -> Result<Option<Assignment>, SolverError> {
95        {
96            let mut output = self.output.lock().expect("no other thread should fail");
97
98            writeln!(output, "(push 1)")?;
99
100            }
102
103        let mut printer = SmtPrinter {
104            output: self.output.clone(),
105        };
106        let mut visited = HashMap::<SymbolId, Result<SymbolId, SolverError>>::new();
107
108        formula.traverse(formula.root(), &mut visited, &mut printer)?;
109
110        let mut output = self.output.lock().expect("no other thread should fail");
111
112        writeln!(output, "(check-sat)\n(get-model)\n(pop 1)")?;
113
114        output.flush().expect("can flush SMT write buffer");
115
116        Err(SolverError::SatUnknown)
117    }
118}
119
/// Formula visitor that prints every visited node as SMT-LIB text.
struct SmtPrinter {
    /// Shared output sink; locked separately for each node that is written.
    output: Arc<Mutex<dyn Write>>,
}
123
124impl FormulaVisitor<Result<SymbolId, SolverError>> for SmtPrinter {
125    fn input(&mut self, idx: SymbolId, name: &str) -> Result<SymbolId, SolverError> {
126        let mut o = self.output.lock().expect("no other thread should fail");
127
128        writeln!(o, "(declare-fun x{} () (_ BitVec 64)); {:?}", idx, name)?;
129
130        Ok(idx)
131    }
132
133    fn constant(&mut self, idx: SymbolId, v: BitVector) -> Result<SymbolId, SolverError> {
134        let mut o = self.output.lock().expect("no other thread should fail");
135
136        writeln!(
137            o,
138            "(declare-fun x{} () (_ BitVec 64))\n(assert (= x{} (_ bv{} 64)))",
139            idx, idx, v.0
140        )?;
141
142        Ok(idx)
143    }
144
145    fn unary(
146        &mut self,
147        idx: SymbolId,
148        op: BVOperator,
149        v: Result<SymbolId, SolverError>,
150    ) -> Result<SymbolId, SolverError> {
151        let mut o = self.output.lock().expect("no other thread should fail");
152
153        match op {
154            BVOperator::Not => {
155                writeln!(
156                    o,
157                    "(declare-fun x{} () (_ BitVec 64))\n(assert (= x{} (ite (= x{} (_ bv0 64)) (_ bv1 64) (_ bv0 64))))",
158                    idx, idx, v?,
159                )?;
160            }
161            _ => unreachable!("operator {} not supported as unary operator", op),
162        }
163
164        Ok(idx)
165    }
166
167    fn binary(
168        &mut self,
169        idx: SymbolId,
170        op: BVOperator,
171        lhs: Result<SymbolId, SolverError>,
172        rhs: Result<SymbolId, SolverError>,
173    ) -> Result<SymbolId, SolverError> {
174        let mut o = self.output.lock().expect("no other thread should fail");
175
176        match op {
177            BVOperator::Equals | BVOperator::Sltu => {
178                writeln!(
179                    o,
180                    "(declare-fun x{} () (_ BitVec 64))\n(assert (= x{} (ite ({} x{} x{}) (_ bv1 64) (_ bv0 64))))",
181                    idx,
182                    idx,
183                    to_smt(op),
184                    lhs?,
185                    rhs?
186                )?;
187            }
188            _ => {
189                writeln!(
190                    o,
191                    "(declare-fun x{} () (_ BitVec 64))\n(assert (= x{} ({} x{} x{})))",
192                    idx,
193                    idx,
194                    to_smt(op),
195                    lhs?,
196                    rhs?
197                )?;
198            }
199        };
200
201        Ok(idx)
202    }
203}
204
205fn to_smt(op: BVOperator) -> &'static str {
206    match op {
207        BVOperator::Add => "bvadd",
208        BVOperator::Sub => "bvsub",
209        BVOperator::Not => panic!("no direct translation"),
210        BVOperator::Mul => "bvmul",
211        BVOperator::Divu => "bvudiv",
212        BVOperator::Remu => "bvurem",
213        BVOperator::Equals => "=",
214        BVOperator::BitwiseAnd => "bvand",
215        BVOperator::Sltu => "bvult",
216    }
217}