1use crate::solver::{
2 BVOperator, BitVector, Formula, FormulaVisitor, OperandSide, Solver, SolverError, Symbol,
3 SymbolId,
4};
5use log::{debug, trace, Level};
6pub use petgraph::graph::{EdgeIndex, NodeIndex};
7use petgraph::visit::EdgeRef;
8use petgraph::{
9 dot::Dot,
10 graph::{Neighbors, NodeIndices},
11 Direction,
12};
13use riscu::Instruction;
14use std::{collections::HashMap, fmt, ops::Index};
15
16pub enum Query {
17 Equals((SymbolicValue, u64)),
18 NotEquals((SymbolicValue, u64)),
19 Reachable,
20}
21
22pub enum QueryResult {
23 Sat(Witness),
24 UnSat,
25 Unknown,
26}
27
28pub type SymbolicValue = NodeIndex;
29pub type DataFlowGraph = petgraph::Graph<Symbol, OperandSide>;
30
31fn instruction_to_bv_operator(instruction: Instruction) -> BVOperator {
32 match instruction {
33 Instruction::Add(_) | Instruction::Addi(_) => BVOperator::Add,
34 Instruction::Sub(_) => BVOperator::Sub,
35 Instruction::Mul(_) => BVOperator::Mul,
36 Instruction::Divu(_) => BVOperator::Divu,
37 Instruction::Remu(_) => BVOperator::Remu,
38 Instruction::Sltu(_) => BVOperator::Sltu,
39 _ => unimplemented!("can not translate {:?} to Operator", instruction),
40 }
41}
42
43#[derive(Debug)]
44pub struct SymbolicState<'a, S>
45where
46 S: Solver,
47{
48 data_flow: DataFlowGraph,
49 path_condition: SymbolicValue,
50 solver: &'a S,
51}
52
53impl<'a, S> Clone for SymbolicState<'a, S>
54where
55 S: Solver,
56{
57 fn clone(&self) -> Self {
58 Self {
59 data_flow: self.data_flow.clone(),
60 path_condition: self.path_condition,
61 solver: self.solver,
62 }
63 }
64}
65
66impl<'a, S> SymbolicState<'a, S>
67where
68 S: Solver,
69{
70 pub fn new(solver: &'a S) -> Self {
71 let mut data_flow = DataFlowGraph::new();
72
73 let constant = Symbol::Constant(BitVector(1));
74
75 let path_condition = data_flow.add_node(constant);
76
77 Self {
78 data_flow,
79 path_condition,
80 solver,
81 }
82 }
83
84 pub fn create_const(&mut self, value: u64) -> SymbolicValue {
85 let constant = Symbol::Constant(BitVector(value));
86
87 let i = self.data_flow.add_node(constant);
88
89 trace!("new constant: x{} := {:#x}", i.index(), value);
90
91 i
92 }
93
94 pub fn create_instruction(
95 &mut self,
96 instruction: Instruction,
97 lhs: SymbolicValue,
98 rhs: SymbolicValue,
99 ) -> SymbolicValue {
100 let op = instruction_to_bv_operator(instruction);
101
102 let root = self.create_operator(op, lhs, rhs);
103
104 if matches!(op, BVOperator::Divu)
107 && matches!(self.data_flow[rhs], Symbol::Operator(_) | Symbol::Input(_))
108 {
109 let zero = self.create_const(0);
110 let negated_condition = self.create_operator(BVOperator::Equals, rhs, zero);
111 let condition = self.create_unary_operator(BVOperator::Not, negated_condition);
112
113 self.add_path_condition(condition);
114 }
115
116 root
117 }
118
119 pub fn create_operator(
120 &mut self,
121 op: BVOperator,
122 lhs: SymbolicValue,
123 rhs: SymbolicValue,
124 ) -> SymbolicValue {
125 assert!(op.is_binary(), "has to be a binary operator");
126
127 let n = Symbol::Operator(op);
128 let n_idx = self.data_flow.add_node(n);
129
130 assert!(!(
131 matches!(self.data_flow[lhs], Symbol::Constant(_))
132 && matches!(self.data_flow[rhs], Symbol::Constant(_))
133 ),
134 "every operand has to be derived from an input or has to be an (already folded) constant"
135 );
136
137 self.connect_operator(lhs, rhs, n_idx);
138
139 trace!(
140 "new operator: x{} := x{} {} x{}",
141 n_idx.index(),
142 lhs.index(),
143 op,
144 rhs.index()
145 );
146
147 n_idx
148 }
149
150 fn create_unary_operator(&mut self, op: BVOperator, v: SymbolicValue) -> SymbolicValue {
151 assert!(op.is_unary(), "has to be a unary operator");
152
153 let op_id = self.data_flow.add_node(Symbol::Operator(op));
154
155 self.data_flow.add_edge(v, op_id, OperandSide::Lhs);
156
157 op_id
158 }
159
160 pub fn create_input(&mut self, name: &str) -> SymbolicValue {
161 let node = Symbol::Input(String::from(name));
162
163 let idx = self.data_flow.add_node(node);
164
165 trace!("new input: x{} := {:?}", idx.index(), name);
166
167 idx
168 }
169
170 pub fn create_beq_path_condition(
171 &mut self,
172 decision: bool,
173 lhs: SymbolicValue,
174 rhs: SymbolicValue,
175 ) {
176 let mut pc_idx = self.create_operator(BVOperator::Equals, lhs, rhs);
177
178 if !decision {
179 pc_idx = self.create_unary_operator(BVOperator::Not, pc_idx);
180 }
181
182 self.add_path_condition(pc_idx)
183 }
184
185 fn add_path_condition(&mut self, condition: SymbolicValue) {
186 self.path_condition =
187 self.create_operator(BVOperator::BitwiseAnd, self.path_condition, condition);
188 }
189
190 pub fn execute_query(&mut self, query: Query) -> Result<QueryResult, SolverError> {
191 let (root, cleanup_nodes, cleanup_edges) = match query {
193 Query::Equals(_) | Query::NotEquals(_) => self.prepare_query(query),
194 Query::Reachable => (self.path_condition, vec![], vec![]),
195 };
196
197 let formula = FormulaView::new(&self.data_flow, root);
198
199 if log::log_enabled!(Level::Debug) {
200 debug!("query to solve:");
201
202 let root = formula.print_recursive();
203
204 debug!("assert x{} is 1", root);
205 }
206
207 let result = match self.solver.solve(&formula) {
208 Ok(Some(ref assignment)) => Ok(QueryResult::Sat(formula.build_witness(assignment))),
209 Ok(None) => Ok(QueryResult::UnSat),
210 Err(SolverError::SatUnknown) | Err(SolverError::Timeout) => Ok(QueryResult::Unknown),
211 Err(e) => Err(e),
212 };
213
214 cleanup_edges.iter().for_each(|e| {
215 self.data_flow.remove_edge(*e);
216 });
217 cleanup_nodes.iter().for_each(|n| {
218 self.data_flow.remove_node(*n);
219 });
220
221 result
222 }
223
224 fn append_path_condition(
225 &mut self,
226 r: SymbolicValue,
227 mut ns: Vec<SymbolicValue>,
228 mut es: Vec<EdgeIndex>,
229 ) -> (SymbolicValue, Vec<SymbolicValue>, Vec<EdgeIndex>) {
230 let con_idx = self
231 .data_flow
232 .add_node(Symbol::Operator(BVOperator::BitwiseAnd));
233 let (con_edge_idx1, con_edge_idx2) = self.connect_operator(self.path_condition, r, con_idx);
234
235 ns.push(con_idx);
236 es.push(con_edge_idx1);
237 es.push(con_edge_idx2);
238
239 (con_idx, ns, es)
240 }
241
242 fn prepare_query(
243 &mut self,
244 query: Query,
245 ) -> (SymbolicValue, Vec<SymbolicValue>, Vec<EdgeIndex>) {
246 match query {
247 Query::Equals((sym, c)) | Query::NotEquals((sym, c)) => {
248 let root_idx = self
249 .data_flow
250 .add_node(Symbol::Operator(BVOperator::Equals));
251
252 let const_idx = self.data_flow.add_node(Symbol::Constant(BitVector(c)));
253 let const_edge_idx = self
254 .data_flow
255 .add_edge(const_idx, root_idx, OperandSide::Lhs);
256
257 let sym_edge_idx = self.data_flow.add_edge(sym, root_idx, OperandSide::Rhs);
258
259 if let Query::NotEquals(_) = query {
260 let not_idx = self.data_flow.add_node(Symbol::Operator(BVOperator::Not));
261 let not_edge_idx = self.data_flow.add_edge(root_idx, not_idx, OperandSide::Lhs);
262
263 self.append_path_condition(
264 not_idx,
265 vec![root_idx, const_idx, not_idx],
266 vec![const_edge_idx, sym_edge_idx, not_edge_idx],
267 )
268 } else {
269 self.append_path_condition(
270 root_idx,
271 vec![root_idx, const_idx],
272 vec![const_edge_idx, sym_edge_idx],
273 )
274 }
275 }
276 Query::Reachable => panic!("nothing to be prepeared for that query"),
277 }
278 }
279
280 fn connect_operator(
281 &mut self,
282 lhs: SymbolicValue,
283 rhs: SymbolicValue,
284 op: SymbolicValue,
285 ) -> (EdgeIndex, EdgeIndex) {
286 (
289 self.data_flow.add_edge(rhs, op, OperandSide::Rhs),
290 self.data_flow.add_edge(lhs, op, OperandSide::Lhs),
291 )
292 }
293}
294
295impl<'a, S: Solver> fmt::Display for SymbolicState<'a, S> {
296 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297 let dot_graph = Dot::with_config(&self.data_flow, &[]);
298
299 write!(f, "{:?}", dot_graph)
300 }
301}
302
303pub struct FormulaView<'a> {
304 data_flow: &'a DataFlowGraph,
305 root: SymbolicValue,
306}
307
308impl<'a> FormulaView<'a> {
309 pub fn new(data_flow: &'a DataFlowGraph, root: SymbolicValue) -> Self {
310 Self { data_flow, root }
311 }
312
313 pub fn print_recursive(&self) -> SymbolId {
314 let mut visited = HashMap::<SymbolId, SymbolId>::new();
315 let mut printer = Printer {};
316
317 self.traverse(self.root(), &mut visited, &mut printer)
318 }
319
320 fn build_witness(&self, assignment: &HashMap<SymbolId, BitVector>) -> Witness {
321 let mut visited = HashMap::<SymbolId, usize>::new();
322
323 let mut witness = Witness::new();
324 let mut builder = WitnessBuilder {
325 witness: &mut witness,
326 assignment,
327 };
328
329 self.traverse(self.root(), &mut visited, &mut builder);
330
331 witness
332 }
333}
334
335impl<'a> Index<SymbolId> for FormulaView<'a> {
336 type Output = Symbol;
337
338 fn index(&self, idx: SymbolId) -> &Self::Output {
339 &self.data_flow[NodeIndex::new(idx)]
340 }
341}
342
343impl<'a> Formula for FormulaView<'a> {
344 type DependencyIter = std::iter::Map<Neighbors<'a, OperandSide>, fn(NodeIndex) -> usize>;
345 type SymbolIdsIter = std::iter::Map<NodeIndices, fn(NodeIndex) -> usize>;
346
347 fn root(&self) -> SymbolId {
348 self.root.index()
349 }
350
351 fn operands(&self, sym: SymbolId) -> (SymbolId, Option<SymbolId>) {
352 let mut iter = self
353 .data_flow
354 .neighbors_directed(NodeIndex::new(sym), Direction::Incoming)
355 .detach();
356
357 let lhs = iter
358 .next(self.data_flow)
359 .expect("get_operands() should not be called on operators without operands")
360 .1
361 .index();
362
363 let rhs = iter.next(self.data_flow).map(|n| n.1.index());
364
365 assert!(
366 iter.next(self.data_flow) == None,
367 "operators with arity 1 or 2 are supported only"
368 );
369
370 (lhs, rhs)
371 }
372
373 fn operand(&self, sym: SymbolId) -> SymbolId {
374 self.data_flow
375 .edges_directed(NodeIndex::new(sym), Direction::Incoming)
376 .next()
377 .expect("every unary operator must have an operand")
378 .source()
379 .index()
380 }
381
382 fn dependencies(&self, sym: SymbolId) -> Self::DependencyIter {
383 self.data_flow
384 .neighbors_directed(NodeIndex::new(sym), Direction::Outgoing)
385 .map(|idx| idx.index())
386 }
387
388 fn symbol_ids(&self) -> Self::SymbolIdsIter {
389 self.data_flow.node_indices().map(|i| i.index())
390 }
391
392 fn is_operand(&self, sym: SymbolId) -> bool {
393 !matches!(self.data_flow[NodeIndex::new(sym)], Symbol::Operator(_))
394 }
395
396 fn traverse<V, R>(&self, n: SymbolId, visit_map: &mut HashMap<SymbolId, R>, v: &mut V) -> R
397 where
398 V: FormulaVisitor<R>,
399 R: Clone,
400 {
401 if let Some(result) = visit_map.get(&n) {
402 return (*result).clone();
403 }
404
405 let result = match &self.data_flow[NodeIndex::new(n)] {
406 Symbol::Operator(op) => {
407 let mut operands = self
408 .data_flow
409 .neighbors_directed(NodeIndex::new(n), Direction::Incoming)
410 .detach();
411
412 if op.is_unary() {
413 let x = operands
414 .next(self.data_flow)
415 .expect("every unary operator must have 1 operand")
416 .1
417 .index();
418
419 let x = self.traverse(x, visit_map, v);
420
421 v.unary(n, *op, x)
422 } else {
423 let lhs = operands
424 .next(self.data_flow)
425 .expect("every binary operator must have an lhs operand")
426 .1
427 .index();
428
429 let rhs = operands
430 .next(self.data_flow)
431 .expect("every binary operator must have an rhs operand")
432 .1
433 .index();
434
435 let lhs = self.traverse(lhs, visit_map, v);
436 let rhs = self.traverse(rhs, visit_map, v);
437
438 v.binary(n, *op, lhs, rhs)
439 }
440 }
441 Symbol::Constant(c) => v.constant(n, *c),
442 Symbol::Input(name) => v.input(n, name.as_str()),
443 };
444
445 visit_map.insert(n, result.clone());
446
447 result
448 }
449}
450
451struct Printer {}
452
453impl<'a> FormulaVisitor<SymbolId> for Printer {
454 fn input(&mut self, idx: SymbolId, name: &str) -> SymbolId {
455 debug!("x{} := {:?}", idx, name);
456 idx
457 }
458 fn constant(&mut self, idx: SymbolId, v: BitVector) -> SymbolId {
459 debug!("x{} := {}", idx, v.0);
460 idx
461 }
462 fn unary(&mut self, idx: SymbolId, op: BVOperator, v: SymbolId) -> SymbolId {
463 debug!("x{} := {}x{}", idx, op, v);
464 idx
465 }
466 fn binary(&mut self, idx: SymbolId, op: BVOperator, lhs: SymbolId, rhs: SymbolId) -> SymbolId {
467 debug!("x{} := x{} {} x{}", idx, lhs, op, rhs);
468 idx
469 }
470}
471
472struct WitnessBuilder<'a> {
473 witness: &'a mut Witness,
474 assignment: &'a HashMap<SymbolId, BitVector>,
475}
476
477impl<'a> FormulaVisitor<usize> for WitnessBuilder<'a> {
478 fn input(&mut self, idx: SymbolId, name: &str) -> usize {
479 self.witness.add_variable(
480 name,
481 *self
482 .assignment
483 .get(&idx)
484 .expect("assignment should be available"),
485 )
486 }
487 fn constant(&mut self, _idx: SymbolId, v: BitVector) -> usize {
488 self.witness.add_constant(v)
489 }
490 fn unary(&mut self, idx: SymbolId, op: BVOperator, v: usize) -> usize {
491 self.witness.add_unary(
492 op,
493 v,
494 *self
495 .assignment
496 .get(&idx)
497 .expect("assignment should be available"),
498 )
499 }
500 fn binary(&mut self, idx: SymbolId, op: BVOperator, lhs: usize, rhs: usize) -> usize {
501 self.witness.add_binary(
502 lhs,
503 op,
504 rhs,
505 *self
506 .assignment
507 .get(&idx)
508 .expect("assignment should be available"),
509 )
510 }
511}
512
513#[derive(Debug, Clone)]
514pub(crate) enum Term {
515 Constant(u64),
516 Variable(String, u64),
517 Unary(BVOperator, usize, u64),
518 Binary(usize, BVOperator, usize, u64),
519}
520
521#[derive(Debug, Clone)]
522pub struct Witness {
523 assignments: Vec<Term>,
524}
525
526impl Default for Witness {
527 fn default() -> Self {
528 Self {
529 assignments: Vec::new(),
530 }
531 }
532}
533
534impl Witness {
535 pub fn new() -> Self {
536 Witness::default()
537 }
538
539 pub fn add_constant(&mut self, value: BitVector) -> usize {
540 self.assignments.push(Term::Constant(value.0));
541
542 self.assignments.len() - 1
543 }
544
545 pub fn add_variable(&mut self, name: &str, result: BitVector) -> usize {
546 self.assignments
547 .push(Term::Variable(name.to_owned(), result.0));
548
549 self.assignments.len() - 1
550 }
551
552 pub fn add_unary(&mut self, op: BVOperator, v: usize, result: BitVector) -> usize {
553 self.assignments.push(Term::Unary(op, v, result.0));
554
555 self.assignments.len() - 1
556 }
557
558 pub fn add_binary(
559 &mut self,
560 lhs: usize,
561 op: BVOperator,
562 rhs: usize,
563 result: BitVector,
564 ) -> usize {
565 self.assignments.push(Term::Binary(lhs, op, rhs, result.0));
566
567 self.assignments.len() - 1
568 }
569}
570
571impl fmt::Display for Witness {
572 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
573 writeln!(f, "[").and_then(|_| {
574 self.assignments
575 .clone()
576 .into_iter()
577 .enumerate()
578 .try_for_each(|(id, a)| match a {
579 Term::Constant(c) => writeln!(f, " x{} := {},", id, c),
580 Term::Variable(name, v) => writeln!(f, " x{} := {:?} ({}),", id, name, v),
581 Term::Unary(op, x, v) => writeln!(f, " x{} := {}x{} ({}),", id, op, x, v),
582 Term::Binary(lhs, op, rhs, v) => {
583 writeln!(f, " x{} := x{} {} x{} ({}),", id, lhs, op, rhs, v)
584 }
585 })
586 .and_then(|_| writeln!(f, "]"))
587 })
588 }
589}