// monster/engine/symbolic_state.rs

1use crate::solver::{
2    BVOperator, BitVector, Formula, FormulaVisitor, OperandSide, Solver, SolverError, Symbol,
3    SymbolId,
4};
5use log::{debug, trace, Level};
6pub use petgraph::graph::{EdgeIndex, NodeIndex};
7use petgraph::visit::EdgeRef;
8use petgraph::{
9    dot::Dot,
10    graph::{Neighbors, NodeIndices},
11    Direction,
12};
13use riscu::Instruction;
14use std::{collections::HashMap, fmt, ops::Index};
15
16pub enum Query {
17    Equals((SymbolicValue, u64)),
18    NotEquals((SymbolicValue, u64)),
19    Reachable,
20}
21
22pub enum QueryResult {
23    Sat(Witness),
24    UnSat,
25    Unknown,
26}
27
28pub type SymbolicValue = NodeIndex;
29pub type DataFlowGraph = petgraph::Graph<Symbol, OperandSide>;
30
31fn instruction_to_bv_operator(instruction: Instruction) -> BVOperator {
32    match instruction {
33        Instruction::Add(_) | Instruction::Addi(_) => BVOperator::Add,
34        Instruction::Sub(_) => BVOperator::Sub,
35        Instruction::Mul(_) => BVOperator::Mul,
36        Instruction::Divu(_) => BVOperator::Divu,
37        Instruction::Remu(_) => BVOperator::Remu,
38        Instruction::Sltu(_) => BVOperator::Sltu,
39        _ => unimplemented!("can not translate {:?} to Operator", instruction),
40    }
41}
42
43#[derive(Debug)]
44pub struct SymbolicState<'a, S>
45where
46    S: Solver,
47{
48    data_flow: DataFlowGraph,
49    path_condition: SymbolicValue,
50    solver: &'a S,
51}
52
53impl<'a, S> Clone for SymbolicState<'a, S>
54where
55    S: Solver,
56{
57    fn clone(&self) -> Self {
58        Self {
59            data_flow: self.data_flow.clone(),
60            path_condition: self.path_condition,
61            solver: self.solver,
62        }
63    }
64}
65
66impl<'a, S> SymbolicState<'a, S>
67where
68    S: Solver,
69{
70    pub fn new(solver: &'a S) -> Self {
71        let mut data_flow = DataFlowGraph::new();
72
73        let constant = Symbol::Constant(BitVector(1));
74
75        let path_condition = data_flow.add_node(constant);
76
77        Self {
78            data_flow,
79            path_condition,
80            solver,
81        }
82    }
83
84    pub fn create_const(&mut self, value: u64) -> SymbolicValue {
85        let constant = Symbol::Constant(BitVector(value));
86
87        let i = self.data_flow.add_node(constant);
88
89        trace!("new constant: x{} := {:#x}", i.index(), value);
90
91        i
92    }
93
94    pub fn create_instruction(
95        &mut self,
96        instruction: Instruction,
97        lhs: SymbolicValue,
98        rhs: SymbolicValue,
99    ) -> SymbolicValue {
100        let op = instruction_to_bv_operator(instruction);
101
102        let root = self.create_operator(op, lhs, rhs);
103
104        // constrain divisor to be not zero,
105        // as division by zero is allowed in SMT bit-vector formulas
106        if matches!(op, BVOperator::Divu)
107            && matches!(self.data_flow[rhs], Symbol::Operator(_) | Symbol::Input(_))
108        {
109            let zero = self.create_const(0);
110            let negated_condition = self.create_operator(BVOperator::Equals, rhs, zero);
111            let condition = self.create_unary_operator(BVOperator::Not, negated_condition);
112
113            self.add_path_condition(condition);
114        }
115
116        root
117    }
118
119    pub fn create_operator(
120        &mut self,
121        op: BVOperator,
122        lhs: SymbolicValue,
123        rhs: SymbolicValue,
124    ) -> SymbolicValue {
125        assert!(op.is_binary(), "has to be a binary operator");
126
127        let n = Symbol::Operator(op);
128        let n_idx = self.data_flow.add_node(n);
129
130        assert!(!(
131                matches!(self.data_flow[lhs], Symbol::Constant(_))
132                && matches!(self.data_flow[rhs], Symbol::Constant(_))
133            ),
134            "every operand has to be derived from an input or has to be an (already folded) constant"
135        );
136
137        self.connect_operator(lhs, rhs, n_idx);
138
139        trace!(
140            "new operator: x{} := x{} {} x{}",
141            n_idx.index(),
142            lhs.index(),
143            op,
144            rhs.index()
145        );
146
147        n_idx
148    }
149
150    fn create_unary_operator(&mut self, op: BVOperator, v: SymbolicValue) -> SymbolicValue {
151        assert!(op.is_unary(), "has to be a unary operator");
152
153        let op_id = self.data_flow.add_node(Symbol::Operator(op));
154
155        self.data_flow.add_edge(v, op_id, OperandSide::Lhs);
156
157        op_id
158    }
159
160    pub fn create_input(&mut self, name: &str) -> SymbolicValue {
161        let node = Symbol::Input(String::from(name));
162
163        let idx = self.data_flow.add_node(node);
164
165        trace!("new input: x{} := {:?}", idx.index(), name);
166
167        idx
168    }
169
170    pub fn create_beq_path_condition(
171        &mut self,
172        decision: bool,
173        lhs: SymbolicValue,
174        rhs: SymbolicValue,
175    ) {
176        let mut pc_idx = self.create_operator(BVOperator::Equals, lhs, rhs);
177
178        if !decision {
179            pc_idx = self.create_unary_operator(BVOperator::Not, pc_idx);
180        }
181
182        self.add_path_condition(pc_idx)
183    }
184
185    fn add_path_condition(&mut self, condition: SymbolicValue) {
186        self.path_condition =
187            self.create_operator(BVOperator::BitwiseAnd, self.path_condition, condition);
188    }
189
190    pub fn execute_query(&mut self, query: Query) -> Result<QueryResult, SolverError> {
191        // prepare graph for query
192        let (root, cleanup_nodes, cleanup_edges) = match query {
193            Query::Equals(_) | Query::NotEquals(_) => self.prepare_query(query),
194            Query::Reachable => (self.path_condition, vec![], vec![]),
195        };
196
197        let formula = FormulaView::new(&self.data_flow, root);
198
199        if log::log_enabled!(Level::Debug) {
200            debug!("query to solve:");
201
202            let root = formula.print_recursive();
203
204            debug!("assert x{} is 1", root);
205        }
206
207        let result = match self.solver.solve(&formula) {
208            Ok(Some(ref assignment)) => Ok(QueryResult::Sat(formula.build_witness(assignment))),
209            Ok(None) => Ok(QueryResult::UnSat),
210            Err(SolverError::SatUnknown) | Err(SolverError::Timeout) => Ok(QueryResult::Unknown),
211            Err(e) => Err(e),
212        };
213
214        cleanup_edges.iter().for_each(|e| {
215            self.data_flow.remove_edge(*e);
216        });
217        cleanup_nodes.iter().for_each(|n| {
218            self.data_flow.remove_node(*n);
219        });
220
221        result
222    }
223
224    fn append_path_condition(
225        &mut self,
226        r: SymbolicValue,
227        mut ns: Vec<SymbolicValue>,
228        mut es: Vec<EdgeIndex>,
229    ) -> (SymbolicValue, Vec<SymbolicValue>, Vec<EdgeIndex>) {
230        let con_idx = self
231            .data_flow
232            .add_node(Symbol::Operator(BVOperator::BitwiseAnd));
233        let (con_edge_idx1, con_edge_idx2) = self.connect_operator(self.path_condition, r, con_idx);
234
235        ns.push(con_idx);
236        es.push(con_edge_idx1);
237        es.push(con_edge_idx2);
238
239        (con_idx, ns, es)
240    }
241
242    fn prepare_query(
243        &mut self,
244        query: Query,
245    ) -> (SymbolicValue, Vec<SymbolicValue>, Vec<EdgeIndex>) {
246        match query {
247            Query::Equals((sym, c)) | Query::NotEquals((sym, c)) => {
248                let root_idx = self
249                    .data_flow
250                    .add_node(Symbol::Operator(BVOperator::Equals));
251
252                let const_idx = self.data_flow.add_node(Symbol::Constant(BitVector(c)));
253                let const_edge_idx = self
254                    .data_flow
255                    .add_edge(const_idx, root_idx, OperandSide::Lhs);
256
257                let sym_edge_idx = self.data_flow.add_edge(sym, root_idx, OperandSide::Rhs);
258
259                if let Query::NotEquals(_) = query {
260                    let not_idx = self.data_flow.add_node(Symbol::Operator(BVOperator::Not));
261                    let not_edge_idx = self.data_flow.add_edge(root_idx, not_idx, OperandSide::Lhs);
262
263                    self.append_path_condition(
264                        not_idx,
265                        vec![root_idx, const_idx, not_idx],
266                        vec![const_edge_idx, sym_edge_idx, not_edge_idx],
267                    )
268                } else {
269                    self.append_path_condition(
270                        root_idx,
271                        vec![root_idx, const_idx],
272                        vec![const_edge_idx, sym_edge_idx],
273                    )
274                }
275            }
276            Query::Reachable => panic!("nothing to be prepeared for that query"),
277        }
278    }
279
280    fn connect_operator(
281        &mut self,
282        lhs: SymbolicValue,
283        rhs: SymbolicValue,
284        op: SymbolicValue,
285    ) -> (EdgeIndex, EdgeIndex) {
286        // assert: right hand side edge has to be inserted first
287        // solvers depend on edge insertion order!!!
288        (
289            self.data_flow.add_edge(rhs, op, OperandSide::Rhs),
290            self.data_flow.add_edge(lhs, op, OperandSide::Lhs),
291        )
292    }
293}
294
295impl<'a, S: Solver> fmt::Display for SymbolicState<'a, S> {
296    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297        let dot_graph = Dot::with_config(&self.data_flow, &[]);
298
299        write!(f, "{:?}", dot_graph)
300    }
301}
302
303pub struct FormulaView<'a> {
304    data_flow: &'a DataFlowGraph,
305    root: SymbolicValue,
306}
307
308impl<'a> FormulaView<'a> {
309    pub fn new(data_flow: &'a DataFlowGraph, root: SymbolicValue) -> Self {
310        Self { data_flow, root }
311    }
312
313    pub fn print_recursive(&self) -> SymbolId {
314        let mut visited = HashMap::<SymbolId, SymbolId>::new();
315        let mut printer = Printer {};
316
317        self.traverse(self.root(), &mut visited, &mut printer)
318    }
319
320    fn build_witness(&self, assignment: &HashMap<SymbolId, BitVector>) -> Witness {
321        let mut visited = HashMap::<SymbolId, usize>::new();
322
323        let mut witness = Witness::new();
324        let mut builder = WitnessBuilder {
325            witness: &mut witness,
326            assignment,
327        };
328
329        self.traverse(self.root(), &mut visited, &mut builder);
330
331        witness
332    }
333}
334
335impl<'a> Index<SymbolId> for FormulaView<'a> {
336    type Output = Symbol;
337
338    fn index(&self, idx: SymbolId) -> &Self::Output {
339        &self.data_flow[NodeIndex::new(idx)]
340    }
341}
342
343impl<'a> Formula for FormulaView<'a> {
344    type DependencyIter = std::iter::Map<Neighbors<'a, OperandSide>, fn(NodeIndex) -> usize>;
345    type SymbolIdsIter = std::iter::Map<NodeIndices, fn(NodeIndex) -> usize>;
346
347    fn root(&self) -> SymbolId {
348        self.root.index()
349    }
350
351    fn operands(&self, sym: SymbolId) -> (SymbolId, Option<SymbolId>) {
352        let mut iter = self
353            .data_flow
354            .neighbors_directed(NodeIndex::new(sym), Direction::Incoming)
355            .detach();
356
357        let lhs = iter
358            .next(self.data_flow)
359            .expect("get_operands() should not be called on operators without operands")
360            .1
361            .index();
362
363        let rhs = iter.next(self.data_flow).map(|n| n.1.index());
364
365        assert!(
366            iter.next(self.data_flow) == None,
367            "operators with arity 1 or 2 are supported only"
368        );
369
370        (lhs, rhs)
371    }
372
373    fn operand(&self, sym: SymbolId) -> SymbolId {
374        self.data_flow
375            .edges_directed(NodeIndex::new(sym), Direction::Incoming)
376            .next()
377            .expect("every unary operator must have an operand")
378            .source()
379            .index()
380    }
381
382    fn dependencies(&self, sym: SymbolId) -> Self::DependencyIter {
383        self.data_flow
384            .neighbors_directed(NodeIndex::new(sym), Direction::Outgoing)
385            .map(|idx| idx.index())
386    }
387
388    fn symbol_ids(&self) -> Self::SymbolIdsIter {
389        self.data_flow.node_indices().map(|i| i.index())
390    }
391
392    fn is_operand(&self, sym: SymbolId) -> bool {
393        !matches!(self.data_flow[NodeIndex::new(sym)], Symbol::Operator(_))
394    }
395
396    fn traverse<V, R>(&self, n: SymbolId, visit_map: &mut HashMap<SymbolId, R>, v: &mut V) -> R
397    where
398        V: FormulaVisitor<R>,
399        R: Clone,
400    {
401        if let Some(result) = visit_map.get(&n) {
402            return (*result).clone();
403        }
404
405        let result = match &self.data_flow[NodeIndex::new(n)] {
406            Symbol::Operator(op) => {
407                let mut operands = self
408                    .data_flow
409                    .neighbors_directed(NodeIndex::new(n), Direction::Incoming)
410                    .detach();
411
412                if op.is_unary() {
413                    let x = operands
414                        .next(self.data_flow)
415                        .expect("every unary operator must have 1 operand")
416                        .1
417                        .index();
418
419                    let x = self.traverse(x, visit_map, v);
420
421                    v.unary(n, *op, x)
422                } else {
423                    let lhs = operands
424                        .next(self.data_flow)
425                        .expect("every binary operator must have an lhs operand")
426                        .1
427                        .index();
428
429                    let rhs = operands
430                        .next(self.data_flow)
431                        .expect("every binary operator must have an rhs operand")
432                        .1
433                        .index();
434
435                    let lhs = self.traverse(lhs, visit_map, v);
436                    let rhs = self.traverse(rhs, visit_map, v);
437
438                    v.binary(n, *op, lhs, rhs)
439                }
440            }
441            Symbol::Constant(c) => v.constant(n, *c),
442            Symbol::Input(name) => v.input(n, name.as_str()),
443        };
444
445        visit_map.insert(n, result.clone());
446
447        result
448    }
449}
450
/// Formula visitor that logs every visited term at debug level.
struct Printer {}
452
453impl<'a> FormulaVisitor<SymbolId> for Printer {
454    fn input(&mut self, idx: SymbolId, name: &str) -> SymbolId {
455        debug!("x{} := {:?}", idx, name);
456        idx
457    }
458    fn constant(&mut self, idx: SymbolId, v: BitVector) -> SymbolId {
459        debug!("x{} := {}", idx, v.0);
460        idx
461    }
462    fn unary(&mut self, idx: SymbolId, op: BVOperator, v: SymbolId) -> SymbolId {
463        debug!("x{} := {}x{}", idx, op, v);
464        idx
465    }
466    fn binary(&mut self, idx: SymbolId, op: BVOperator, lhs: SymbolId, rhs: SymbolId) -> SymbolId {
467        debug!("x{} := x{} {} x{}", idx, lhs, op, rhs);
468        idx
469    }
470}
471
472struct WitnessBuilder<'a> {
473    witness: &'a mut Witness,
474    assignment: &'a HashMap<SymbolId, BitVector>,
475}
476
477impl<'a> FormulaVisitor<usize> for WitnessBuilder<'a> {
478    fn input(&mut self, idx: SymbolId, name: &str) -> usize {
479        self.witness.add_variable(
480            name,
481            *self
482                .assignment
483                .get(&idx)
484                .expect("assignment should be available"),
485        )
486    }
487    fn constant(&mut self, _idx: SymbolId, v: BitVector) -> usize {
488        self.witness.add_constant(v)
489    }
490    fn unary(&mut self, idx: SymbolId, op: BVOperator, v: usize) -> usize {
491        self.witness.add_unary(
492            op,
493            v,
494            *self
495                .assignment
496                .get(&idx)
497                .expect("assignment should be available"),
498        )
499    }
500    fn binary(&mut self, idx: SymbolId, op: BVOperator, lhs: usize, rhs: usize) -> usize {
501        self.witness.add_binary(
502            lhs,
503            op,
504            rhs,
505            *self
506                .assignment
507                .get(&idx)
508                .expect("assignment should be available"),
509        )
510    }
511}
512
513#[derive(Debug, Clone)]
514pub(crate) enum Term {
515    Constant(u64),
516    Variable(String, u64),
517    Unary(BVOperator, usize, u64),
518    Binary(usize, BVOperator, usize, u64),
519}
520
521#[derive(Debug, Clone)]
522pub struct Witness {
523    assignments: Vec<Term>,
524}
525
526impl Default for Witness {
527    fn default() -> Self {
528        Self {
529            assignments: Vec::new(),
530        }
531    }
532}
533
534impl Witness {
535    pub fn new() -> Self {
536        Witness::default()
537    }
538
539    pub fn add_constant(&mut self, value: BitVector) -> usize {
540        self.assignments.push(Term::Constant(value.0));
541
542        self.assignments.len() - 1
543    }
544
545    pub fn add_variable(&mut self, name: &str, result: BitVector) -> usize {
546        self.assignments
547            .push(Term::Variable(name.to_owned(), result.0));
548
549        self.assignments.len() - 1
550    }
551
552    pub fn add_unary(&mut self, op: BVOperator, v: usize, result: BitVector) -> usize {
553        self.assignments.push(Term::Unary(op, v, result.0));
554
555        self.assignments.len() - 1
556    }
557
558    pub fn add_binary(
559        &mut self,
560        lhs: usize,
561        op: BVOperator,
562        rhs: usize,
563        result: BitVector,
564    ) -> usize {
565        self.assignments.push(Term::Binary(lhs, op, rhs, result.0));
566
567        self.assignments.len() - 1
568    }
569}
570
571impl fmt::Display for Witness {
572    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
573        writeln!(f, "[").and_then(|_| {
574            self.assignments
575                .clone()
576                .into_iter()
577                .enumerate()
578                .try_for_each(|(id, a)| match a {
579                    Term::Constant(c) => writeln!(f, "  x{} := {},", id, c),
580                    Term::Variable(name, v) => writeln!(f, "  x{} := {:?} ({}),", id, name, v),
581                    Term::Unary(op, x, v) => writeln!(f, "  x{} := {}x{} ({}),", id, op, x, v),
582                    Term::Binary(lhs, op, rhs, v) => {
583                        writeln!(f, "  x{} := x{} {} x{} ({}),", id, lhs, op, rhs, v)
584                    }
585                })
586                .and_then(|_| writeln!(f, "]"))
587        })
588    }
589}