use crate::solver::{
BVOperator, BitVector, Formula, FormulaVisitor, OperandSide, Solver, SolverError, Symbol,
SymbolId,
};
use log::{debug, trace, Level};
pub use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::visit::EdgeRef;
use petgraph::{
dot::Dot,
graph::{Neighbors, NodeIndices},
Direction,
};
use riscu::Instruction;
use std::{collections::HashMap, fmt, ops::Index};
/// A question asked about the current symbolic state, answered by
/// `SymbolicState::execute_query`.
#[derive(Debug, Clone, Copy)]
pub enum Query {
    /// Is `value == constant` satisfiable together with the current path condition?
    Equals((SymbolicValue, u64)),
    /// Is `value != constant` satisfiable together with the current path condition?
    NotEquals((SymbolicValue, u64)),
    /// Is the current path condition satisfiable on its own?
    Reachable,
}
/// Outcome of `SymbolicState::execute_query`.
#[derive(Debug, Clone)]
pub enum QueryResult {
    /// Satisfiable; carries a witness built from the solver's assignment.
    Sat(Witness),
    /// Unsatisfiable.
    UnSat,
    /// The solver could not decide (it reported an unknown result or timed out).
    Unknown,
}
/// Handle of a node (a symbol) inside a [`DataFlowGraph`].
pub type SymbolicValue = NodeIndex;

/// Graph of symbols; every edge runs from an operand to the operator consuming it and is
/// labelled with the operand side it feeds.
pub type DataFlowGraph = petgraph::graph::Graph<Symbol, OperandSide>;
/// Maps a RISC-U arithmetic instruction to the bit-vector operator modelling it.
///
/// Register and immediate additions (`add`, `addi`) both map to [`BVOperator::Add`].
///
/// # Panics
///
/// Panics (via `unimplemented!`) for every instruction without a bit-vector counterpart.
fn instruction_to_bv_operator(instruction: Instruction) -> BVOperator {
    match instruction {
        Instruction::Sltu(_) => BVOperator::Sltu,
        Instruction::Remu(_) => BVOperator::Remu,
        Instruction::Divu(_) => BVOperator::Divu,
        Instruction::Mul(_) => BVOperator::Mul,
        Instruction::Sub(_) => BVOperator::Sub,
        Instruction::Addi(_) | Instruction::Add(_) => BVOperator::Add,
        unsupported => unimplemented!("can not translate {:?} to Operator", unsupported),
    }
}
/// Symbolic execution state: a data-flow graph of symbolic values plus the node holding the
/// current path condition, together with the solver used to answer queries about it.
#[derive(Debug)]
pub struct SymbolicState<'a, S: Solver> {
    /// All symbols created so far; edges point from an operand to the operator using it.
    data_flow: DataFlowGraph,
    /// Node holding the conjunction (`BitwiseAnd`) of all path constraints added so far.
    path_condition: SymbolicValue,
    /// Solver consulted by `execute_query`.
    solver: &'a S,
}
// Written by hand rather than derived: `#[derive(Clone)]` would add an `S: Clone` bound,
// while only the *reference* to the solver needs copying here.
impl<S: Solver> Clone for SymbolicState<'_, S> {
    /// Deep-copies the data-flow graph; the path-condition handle and the solver reference
    /// are copied as-is, so both states share the same solver.
    fn clone(&self) -> Self {
        Self {
            data_flow: self.data_flow.clone(),
            ..*self
        }
    }
}
impl<'a, S> SymbolicState<'a, S>
where
    S: Solver,
{
    /// Creates a state with a fresh data-flow graph whose path condition is the constant `1`
    /// (no constraints yet).
    pub fn new(solver: &'a S) -> Self {
        let mut data_flow = DataFlowGraph::new();

        let constant = Symbol::Constant(BitVector(1));
        let path_condition = data_flow.add_node(constant);

        Self {
            data_flow,
            path_condition,
            solver,
        }
    }

    /// Adds a constant node holding `value` and returns its handle.
    pub fn create_const(&mut self, value: u64) -> SymbolicValue {
        let constant = Symbol::Constant(BitVector(value));
        let i = self.data_flow.add_node(constant);

        trace!("new constant: x{} := {:#x}", i.index(), value);

        i
    }

    /// Adds an operator node for the arithmetic `instruction` applied to `lhs` and `rhs`.
    ///
    /// For `divu` whose divisor is an operator or input node, the path condition is
    /// additionally constrained to `rhs != 0`.
    ///
    /// # Panics
    ///
    /// Panics if `instruction` has no bit-vector counterpart, or if both operands are
    /// constant nodes.
    pub fn create_instruction(
        &mut self,
        instruction: Instruction,
        lhs: SymbolicValue,
        rhs: SymbolicValue,
    ) -> SymbolicValue {
        let op = instruction_to_bv_operator(instruction);

        let root = self.create_operator(op, lhs, rhs);

        // only a symbolic divisor gets the extra `rhs != 0` constraint; constants are skipped
        if matches!(op, BVOperator::Divu)
            && matches!(self.data_flow[rhs], Symbol::Operator(_) | Symbol::Input(_))
        {
            let zero = self.create_const(0);
            let negated_condition = self.create_operator(BVOperator::Equals, rhs, zero);
            let condition = self.create_unary_operator(BVOperator::Not, negated_condition);

            self.add_path_condition(condition);
        }

        root
    }

    /// Adds a binary operator node `lhs op rhs` and returns its handle.
    ///
    /// # Panics
    ///
    /// Panics if `op` is not a binary operator or if both operands are constant nodes.
    pub fn create_operator(
        &mut self,
        op: BVOperator,
        lhs: SymbolicValue,
        rhs: SymbolicValue,
    ) -> SymbolicValue {
        assert!(op.is_binary(), "has to be a binary operator");
        // checked before any node is added, so a violated precondition does not first
        // leave an unconnected operator node behind
        assert!(
            !(matches!(self.data_flow[lhs], Symbol::Constant(_))
                && matches!(self.data_flow[rhs], Symbol::Constant(_))),
            "every operand has to be derived from an input or has to be an (already folded) constant"
        );

        let n_idx = self.data_flow.add_node(Symbol::Operator(op));

        self.connect_operator(lhs, rhs, n_idx);

        trace!(
            "new operator: x{} := x{} {} x{}",
            n_idx.index(),
            lhs.index(),
            op,
            rhs.index()
        );

        n_idx
    }

    /// Adds a unary operator node applied to `v` and returns its handle.
    ///
    /// # Panics
    ///
    /// Panics if `op` is not a unary operator.
    fn create_unary_operator(&mut self, op: BVOperator, v: SymbolicValue) -> SymbolicValue {
        assert!(op.is_unary(), "has to be a unary operator");

        let op_id = self.data_flow.add_node(Symbol::Operator(op));
        self.data_flow.add_edge(v, op_id, OperandSide::Lhs);

        op_id
    }

    /// Adds an input node named `name` and returns its handle.
    pub fn create_input(&mut self, name: &str) -> SymbolicValue {
        let node = Symbol::Input(String::from(name));
        let idx = self.data_flow.add_node(node);

        trace!("new input: x{} := {:?}", idx.index(), name);

        idx
    }

    /// Records the outcome of a `beq` branch by conjoining `lhs == rhs` (if `decision` is
    /// `true`) or `!(lhs == rhs)` (otherwise) to the path condition.
    pub fn create_beq_path_condition(
        &mut self,
        decision: bool,
        lhs: SymbolicValue,
        rhs: SymbolicValue,
    ) {
        let mut pc_idx = self.create_operator(BVOperator::Equals, lhs, rhs);

        if !decision {
            pc_idx = self.create_unary_operator(BVOperator::Not, pc_idx);
        }

        self.add_path_condition(pc_idx)
    }

    /// Replaces the path condition with `path_condition & condition`.
    fn add_path_condition(&mut self, condition: SymbolicValue) {
        self.path_condition =
            self.create_operator(BVOperator::BitwiseAnd, self.path_condition, condition);
    }

    /// Asks the solver about `query` under the current path condition.
    ///
    /// `Equals`/`NotEquals` temporarily add the query term (conjoined with the path condition)
    /// to the graph; every node and edge added for that is removed again before returning.
    /// Solver outcomes `SatUnknown` and `Timeout` are reported as [`QueryResult::Unknown`].
    ///
    /// # Errors
    ///
    /// Returns any other [`SolverError`] produced by the solver.
    pub fn execute_query(&mut self, query: Query) -> Result<QueryResult, SolverError> {
        let (root, mut cleanup_nodes, mut cleanup_edges) = match query {
            Query::Equals(_) | Query::NotEquals(_) => self.prepare_query(query),
            Query::Reachable => (self.path_condition, vec![], vec![]),
        };

        let formula = FormulaView::new(&self.data_flow, root);

        if log::log_enabled!(Level::Debug) {
            debug!("query to solve:");

            let root = formula.print_recursive();

            debug!("assert x{} is 1", root);
        }

        let result = match self.solver.solve(&formula) {
            Ok(Some(ref assignment)) => Ok(QueryResult::Sat(formula.build_witness(assignment))),
            Ok(None) => Ok(QueryResult::UnSat),
            Err(SolverError::SatUnknown) | Err(SolverError::Timeout) => Ok(QueryResult::Unknown),
            Err(e) => Err(e),
        };

        // petgraph's `remove_edge`/`remove_node` move the *last* edge/node into the freed
        // slot. Removing in ascending order therefore invalidated the later (higher) indices
        // in these lists and leaked temporary nodes and edges into the graph. Removing in
        // descending order keeps every index that is still to be removed valid; since the
        // temporaries were the last elements added, no other node or edge is renumbered.
        cleanup_edges.sort_unstable_by(|a, b| b.cmp(a));
        for edge in cleanup_edges {
            self.data_flow.remove_edge(edge);
        }

        cleanup_nodes.sort_unstable_by(|a, b| b.cmp(a));
        for node in cleanup_nodes {
            self.data_flow.remove_node(node);
        }

        result
    }

    /// Adds a `BitwiseAnd` node combining the path condition with `r`, and records the new
    /// node and its two operand edges in `ns`/`es` so that the caller can remove them later.
    ///
    /// Returns the new node together with the extended `ns` and `es`.
    fn append_path_condition(
        &mut self,
        r: SymbolicValue,
        mut ns: Vec<SymbolicValue>,
        mut es: Vec<EdgeIndex>,
    ) -> (SymbolicValue, Vec<SymbolicValue>, Vec<EdgeIndex>) {
        let con_idx = self
            .data_flow
            .add_node(Symbol::Operator(BVOperator::BitwiseAnd));
        let (con_edge_idx1, con_edge_idx2) = self.connect_operator(self.path_condition, r, con_idx);

        ns.push(con_idx);
        es.push(con_edge_idx1);
        es.push(con_edge_idx2);

        (con_idx, ns, es)
    }

    /// Builds the temporary sub-graph `path_condition & (c == sym)` (or
    /// `path_condition & !(c == sym)` for `NotEquals`) and returns its root together with
    /// every node and edge that was added for it.
    ///
    /// # Panics
    ///
    /// Panics for [`Query::Reachable`], which needs no preparation.
    fn prepare_query(
        &mut self,
        query: Query,
    ) -> (SymbolicValue, Vec<SymbolicValue>, Vec<EdgeIndex>) {
        match query {
            Query::Equals((sym, c)) | Query::NotEquals((sym, c)) => {
                let root_idx = self
                    .data_flow
                    .add_node(Symbol::Operator(BVOperator::Equals));

                let const_idx = self.data_flow.add_node(Symbol::Constant(BitVector(c)));

                // Wire the operands through the shared helper so that the `OperandSide`
                // labels follow the same edge order as every other binary operator (the
                // constant is the lhs, `sym` the rhs).
                let (sym_edge_idx, const_edge_idx) =
                    self.connect_operator(const_idx, sym, root_idx);

                if let Query::NotEquals(_) = query {
                    let not_idx = self.data_flow.add_node(Symbol::Operator(BVOperator::Not));

                    let not_edge_idx = self.data_flow.add_edge(root_idx, not_idx, OperandSide::Lhs);

                    self.append_path_condition(
                        not_idx,
                        vec![root_idx, const_idx, not_idx],
                        vec![const_edge_idx, sym_edge_idx, not_edge_idx],
                    )
                } else {
                    self.append_path_condition(
                        root_idx,
                        vec![root_idx, const_idx],
                        vec![const_edge_idx, sym_edge_idx],
                    )
                }
            }
            Query::Reachable => panic!("nothing to be prepeared for that query"),
        }
    }

    /// Adds the operand edges `rhs -> op` and `lhs -> op` and returns them as
    /// `(rhs_edge, lhs_edge)`.
    ///
    /// The rhs edge is added first. `Formula::operands` and `traverse` treat the first
    /// incoming neighbour as the lhs. petgraph appears to yield incoming edges newest-first,
    /// so this order is presumably what puts the lhs first.
    /// NOTE(review): confirm before changing the order.
    fn connect_operator(
        &mut self,
        lhs: SymbolicValue,
        rhs: SymbolicValue,
        op: SymbolicValue,
    ) -> (EdgeIndex, EdgeIndex) {
        (
            self.data_flow.add_edge(rhs, op, OperandSide::Rhs),
            self.data_flow.add_edge(lhs, op, OperandSide::Lhs),
        )
    }
}
impl<S: Solver> fmt::Display for SymbolicState<'_, S> {
    /// Writes the whole data-flow graph in Graphviz DOT syntax using petgraph's `Dot` adapter
    /// with the default configuration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "{:?}",
            Dot::with_config(&self.data_flow, &[])
        ))
    }
}
/// Borrowed view of a [`DataFlowGraph`] that treats one node as the root of a formula.
pub struct FormulaView<'a> {
    /// Node the formula is rooted at.
    root: SymbolicValue,
    /// Graph the formula lives in (all of its nodes, not only those reachable from `root`).
    data_flow: &'a DataFlowGraph,
}
impl<'a> FormulaView<'a> {
pub fn new(data_flow: &'a DataFlowGraph, root: SymbolicValue) -> Self {
Self { data_flow, root }
}
pub fn print_recursive(&self) -> SymbolId {
let mut visited = HashMap::<SymbolId, SymbolId>::new();
let mut printer = Printer {};
self.traverse(self.root(), &mut visited, &mut printer)
}
fn build_witness(&self, assignment: &HashMap<SymbolId, BitVector>) -> Witness {
let mut visited = HashMap::<SymbolId, usize>::new();
let mut witness = Witness::new();
let mut builder = WitnessBuilder {
witness: &mut witness,
assignment,
};
self.traverse(self.root(), &mut visited, &mut builder);
witness
}
}
impl Index<SymbolId> for FormulaView<'_> {
    type Output = Symbol;

    /// Returns the symbol stored at node `idx` of the underlying graph.
    ///
    /// # Panics
    ///
    /// Panics if the graph has no node `idx`.
    fn index(&self, idx: SymbolId) -> &Symbol {
        let node = NodeIndex::new(idx);
        &self.data_flow[node]
    }
}
impl<'a> Formula for FormulaView<'a> {
    type DependencyIter = std::iter::Map<Neighbors<'a, OperandSide>, fn(NodeIndex) -> usize>;
    type SymbolIdsIter = std::iter::Map<NodeIndices, fn(NodeIndex) -> usize>;

    /// Id of the node this view is rooted at.
    fn root(&self) -> SymbolId {
        self.root.index()
    }

    /// Returns the operands of operator `sym` as `(first, second)`; `second` is `None` for
    /// operators with a single incoming edge.
    ///
    /// The first incoming neighbour is taken as the lhs. This presumably relies on petgraph
    /// yielding incoming edges newest-first and on `connect_operator` adding the rhs edge
    /// before the lhs edge.
    ///
    /// # Panics
    ///
    /// Panics if `sym` has no incoming edge or more than two.
    fn operands(&self, sym: SymbolId) -> (SymbolId, Option<SymbolId>) {
        let mut iter = self
            .data_flow
            .neighbors_directed(NodeIndex::new(sym), Direction::Incoming)
            .detach();

        let lhs = iter
            .next(self.data_flow)
            .expect("get_operands() should not be called on operators without operands")
            .1
            .index();

        let rhs = iter.next(self.data_flow).map(|n| n.1.index());

        assert!(
            iter.next(self.data_flow).is_none(),
            "operators with arity 1 or 2 are supported only"
        );

        (lhs, rhs)
    }

    /// Returns the single operand of unary operator `sym`: the source of its first
    /// incoming edge.
    ///
    /// # Panics
    ///
    /// Panics if `sym` has no incoming edge.
    fn operand(&self, sym: SymbolId) -> SymbolId {
        self.data_flow
            .edges_directed(NodeIndex::new(sym), Direction::Incoming)
            .next()
            .expect("every unary operator must have an operand")
            .source()
            .index()
    }

    /// Returns the nodes that use `sym` as an operand (its outgoing neighbours).
    fn dependencies(&self, sym: SymbolId) -> Self::DependencyIter {
        self.data_flow
            .neighbors_directed(NodeIndex::new(sym), Direction::Outgoing)
            .map(|idx| idx.index())
    }

    /// Returns the ids of all nodes of the underlying graph, not only those reachable from
    /// the root.
    fn symbol_ids(&self) -> Self::SymbolIdsIter {
        self.data_flow.node_indices().map(|i| i.index())
    }

    /// Whether `sym` is a leaf (a constant or an input) rather than an operator.
    fn is_operand(&self, sym: SymbolId) -> bool {
        !matches!(self.data_flow[NodeIndex::new(sym)], Symbol::Operator(_))
    }

    /// Post-order traversal from `n` that calls `v` for every symbol not yet in `visit_map`.
    /// The operands are visited first (lhs before rhs), then the operator itself.
    ///
    /// Each visitor result is cached in `visit_map`, so shared sub-terms are visited once
    /// and later encounters return the cached result.
    ///
    /// # Panics
    ///
    /// Panics if an operator node has fewer incoming edges than its arity requires.
    fn traverse<V, R>(&self, n: SymbolId, visit_map: &mut HashMap<SymbolId, R>, v: &mut V) -> R
    where
        V: FormulaVisitor<R>,
        R: Clone,
    {
        if let Some(result) = visit_map.get(&n) {
            return result.clone();
        }

        let result = match &self.data_flow[NodeIndex::new(n)] {
            Symbol::Operator(op) => {
                let mut operands = self
                    .data_flow
                    .neighbors_directed(NodeIndex::new(n), Direction::Incoming)
                    .detach();

                if op.is_unary() {
                    let x = operands
                        .next(self.data_flow)
                        .expect("every unary operator must have 1 operand")
                        .1
                        .index();

                    let x = self.traverse(x, visit_map, v);

                    v.unary(n, *op, x)
                } else {
                    let lhs = operands
                        .next(self.data_flow)
                        .expect("every binary operator must have an lhs operand")
                        .1
                        .index();

                    let rhs = operands
                        .next(self.data_flow)
                        .expect("every binary operator must have an rhs operand")
                        .1
                        .index();

                    // lhs is traversed before rhs; visitors such as WitnessBuilder depend on
                    // this visiting order
                    let lhs = self.traverse(lhs, visit_map, v);
                    let rhs = self.traverse(rhs, visit_map, v);

                    v.binary(n, *op, lhs, rhs)
                }
            }
            Symbol::Constant(c) => v.constant(n, *c),
            Symbol::Input(name) => v.input(n, name.as_str()),
        };

        visit_map.insert(n, result.clone());

        result
    }
}
/// [`FormulaVisitor`] that logs each visited symbol at `debug` level as an assignment
/// `xN := …` and returns the symbol's own id.
struct Printer {}

// The former `impl<'a>` declared a lifetime that was never used (clippy
// `extra_unused_lifetimes`); `Printer` holds no borrows.
impl FormulaVisitor<SymbolId> for Printer {
    /// Logs `xN := "name"`.
    fn input(&mut self, idx: SymbolId, name: &str) -> SymbolId {
        debug!("x{} := {:?}", idx, name);
        idx
    }

    /// Logs `xN := value`.
    fn constant(&mut self, idx: SymbolId, v: BitVector) -> SymbolId {
        debug!("x{} := {}", idx, v.0);
        idx
    }

    /// Logs `xN := <op>xM`.
    fn unary(&mut self, idx: SymbolId, op: BVOperator, v: SymbolId) -> SymbolId {
        debug!("x{} := {}x{}", idx, op, v);
        idx
    }

    /// Logs `xN := xL <op> xR`.
    fn binary(&mut self, idx: SymbolId, op: BVOperator, lhs: SymbolId, rhs: SymbolId) -> SymbolId {
        debug!("x{} := x{} {} x{}", idx, lhs, op, rhs);
        idx
    }
}
/// [`FormulaVisitor`] that appends every visited symbol, together with the value the solver
/// assigned to it, to a [`Witness`]. Each visit returns the new entry's position in the
/// witness.
struct WitnessBuilder<'a> {
    witness: &'a mut Witness,
    assignment: &'a HashMap<SymbolId, BitVector>,
}

impl WitnessBuilder<'_> {
    /// Value the solver assigned to symbol `idx`.
    ///
    /// # Panics
    ///
    /// Panics if the assignment has no entry for `idx`.
    fn assigned_value(&self, idx: SymbolId) -> BitVector {
        *self
            .assignment
            .get(&idx)
            .expect("assignment should be available")
    }
}

impl FormulaVisitor<usize> for WitnessBuilder<'_> {
    fn input(&mut self, idx: SymbolId, name: &str) -> usize {
        let value = self.assigned_value(idx);
        self.witness.add_variable(name, value)
    }

    // constants carry their own value, so no assignment lookup is needed
    fn constant(&mut self, _idx: SymbolId, v: BitVector) -> usize {
        self.witness.add_constant(v)
    }

    fn unary(&mut self, idx: SymbolId, op: BVOperator, v: usize) -> usize {
        let value = self.assigned_value(idx);
        self.witness.add_unary(op, v, value)
    }

    fn binary(&mut self, idx: SymbolId, op: BVOperator, lhs: usize, rhs: usize) -> usize {
        let value = self.assigned_value(idx);
        self.witness.add_binary(lhs, op, rhs, value)
    }
}
/// One entry of a [`Witness`]: a symbol and the concrete value it evaluated to.
///
/// `usize` fields refer to earlier entries of the same witness by position.
#[derive(Clone, Debug)]
pub(crate) enum Term {
    /// A literal constant.
    Constant(u64),
    /// A named input and the value chosen for it.
    Variable(String, u64),
    /// `op` applied to an earlier entry, followed by the result.
    Unary(BVOperator, usize, u64),
    /// `lhs op rhs` over two earlier entries, followed by the result.
    Binary(usize, BVOperator, usize, u64),
}
/// Concrete evidence for a satisfiable query: the visited symbols of the formula, in the
/// order they were recorded, each with the value the solver assigned to it.
///
/// `Default` is derived (replacing a hand-written impl that clippy's `derivable_impls`
/// flags); the default witness has no entries.
#[derive(Debug, Clone, Default)]
pub struct Witness {
    /// Recorded terms; a term's position is the id other terms use to refer to it.
    assignments: Vec<Term>,
}
impl Witness {
pub fn new() -> Self {
Witness::default()
}
pub fn add_constant(&mut self, value: BitVector) -> usize {
self.assignments.push(Term::Constant(value.0));
self.assignments.len() - 1
}
pub fn add_variable(&mut self, name: &str, result: BitVector) -> usize {
self.assignments
.push(Term::Variable(name.to_owned(), result.0));
self.assignments.len() - 1
}
pub fn add_unary(&mut self, op: BVOperator, v: usize, result: BitVector) -> usize {
self.assignments.push(Term::Unary(op, v, result.0));
self.assignments.len() - 1
}
pub fn add_binary(
&mut self,
lhs: usize,
op: BVOperator,
rhs: usize,
result: BitVector,
) -> usize {
self.assignments.push(Term::Binary(lhs, op, rhs, result.0));
self.assignments.len() - 1
}
}
impl fmt::Display for Witness {
    /// Writes one `xN := …` line per entry between a `[` line and a `]` line. Inputs and
    /// unary/binary entries also show their assigned value in parentheses.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "[")?;

        // iterate by reference: cloning the whole Vec (and every input name String) just to
        // format it was a needless allocation
        for (id, term) in self.assignments.iter().enumerate() {
            match term {
                Term::Constant(c) => writeln!(f, " x{} := {},", id, c)?,
                Term::Variable(name, v) => writeln!(f, " x{} := {:?} ({}),", id, name, v)?,
                Term::Unary(op, x, v) => writeln!(f, " x{} := {}x{} ({}),", id, op, x, v)?,
                Term::Binary(lhs, op, rhs, v) => {
                    writeln!(f, " x{} := x{} {} x{} ({}),", id, lhs, op, rhs, v)?
                }
            }
        }

        writeln!(f, "]")
    }
}