#include "internal.hpp"
namespace CaDiCaL {
// Map a non-zero literal to a dense unsigned index: the variable 'v' of the
// literal selects the slot pair starting at '2*(v-1)', the positive literal
// uses the even slot and the negative literal the following odd slot.
//
// The arithmetic is carried out in 'unsigned' on purpose.  Evaluating
// '2 * (abs (lit) - 1)' in 'int' overflows (undefined behaviour) as soon as
// the variable index exceeds 2^30, even though the final result always fits
// into an 'unsigned' for every literal different from 'INT_MIN'.
//
inline unsigned Checker::l2u (int lit) {
  assert (lit);
  assert (lit != INT_MIN);
  const unsigned idx = (unsigned) abs (lit);
  unsigned res = 2u * (idx - 1u);
  if (lit < 0)
    res++;
  return res;
}
// Look up the value stored for 'lit' in the signed-literal indexed 'vals'
// table ('assign' stores 1 / -1 for a literal and its negation, 'backtrack'
// resets both to 0).  The literal has to be non-zero and within the
// currently allocated variable range.
//
inline signed char Checker::val (int lit) {
  assert (lit);
  assert (lit != INT_MIN);
  assert (abs (lit) < size_vars);
  const signed char res = vals[lit];
  assert (res == -vals[-lit]);
  return res;
}
// Return a writable reference to the mark flag of 'lit'.  The flags live in
// 'marks' and are addressed through the dense index computed by 'l2u'.
//
signed char &Checker::mark (int lit) {
  const unsigned slot = l2u (lit);
  assert (slot < marks.size ());
  return marks[slot];
}
// Return the watch list of 'lit'.  The lists are stored in 'watchers' and
// addressed through the dense index computed by 'l2u'.
//
inline CheckerWatcher &Checker::watcher (int lit) {
  const unsigned slot = l2u (lit);
  assert (slot < watchers.size ());
  return watchers[slot];
}
// Allocate a clause holding the literals currently in 'simplified', tag it
// with 'last_hash' and watch its first two literals.  The clause is NOT
// linked into the hash table here; the caller does that (see 'insert').
//
// Requires at least two literals and at least two of them unassigned.
//
CheckerClause *Checker::new_clause () {
  const size_t num_lits = simplified.size ();
  assert (num_lits > 1), assert (num_lits <= UINT_MAX);

  // The clause is carved out of a raw 'char' array so the literals can be
  // stored inline ('delete_clause' releases it as 'char[]' again).
  // NOTE(review): the '- 2' presumably accounts for two literal slots that
  // are already part of 'sizeof (CheckerClause)' -- confirm in the header.
  //
  const size_t bytes =
      sizeof (CheckerClause) + (num_lits - 2) * sizeof (int);
  CheckerClause *clause = (CheckerClause *) new char[bytes];
  clause->next = 0;
  clause->hash = last_hash;
  clause->size = num_lits;

  int *lits = clause->literals;
  size_t pos = 0;
  for (const auto &lit : simplified)
    lits[pos++] = lit;
  num_clauses++;

  // The first two positions are used as watches and must be unassigned.
  // For each of them which is assigned swap in the first unassigned
  // literal found further to the right.
  //
  for (unsigned w = 0; w < 2; w++) {
    if (!val (lits[w]))
      continue;
    for (unsigned k = w + 1; k < num_lits; k++) {
      if (val (lits[k]))
        continue;
      swap (lits[w], lits[k]);
      break;
    }
  }
  assert (!val (lits[0]));
  assert (!val (lits[1]));

  // Each watch stores the other watched literal as blocking literal.
  //
  watcher (lits[0]).push_back (CheckerWatch (lits[1], clause));
  watcher (lits[1]).push_back (CheckerWatch (lits[0], clause));

  return clause;
}
// Release the memory of a single clause and update the matching counter.
// A clause with 'size == 0' has been marked as garbage before and is
// accounted for in 'num_garbage', all others in 'num_clauses'.
//
void Checker::delete_clause (CheckerClause *c) {
  if (!c->size) {
    assert (num_garbage);
    num_garbage--;
  } else {
    assert (c->size > 1);
    assert (num_clauses);
    num_clauses--;
  }
  // Clauses are allocated as 'char' arrays in 'new_clause'.
  delete[] (char *) c;
}
void Checker::enlarge_clauses () {
assert (num_clauses == size_clauses);
const uint64_t new_size_clauses = size_clauses ? 2 * size_clauses : 1;
LOG ("CHECKER enlarging clauses of checker from %" PRIu64 " to %" PRIu64,
(uint64_t) size_clauses, (uint64_t) new_size_clauses);
CheckerClause **new_clauses;
new_clauses = new CheckerClause *[new_size_clauses];
clear_n (new_clauses, new_size_clauses);
for (uint64_t i = 0; i < size_clauses; i++) {
for (CheckerClause *c = clauses[i], *next; c; c = next) {
next = c->next;
const uint64_t h = reduce_hash (c->hash, new_size_clauses);
c->next = new_clauses[h];
new_clauses[h] = c;
}
}
delete[] clauses;
clauses = new_clauses;
size_clauses = new_size_clauses;
}
// Check whether at least one literal of 'c' is currently assigned to true.
//
bool Checker::clause_satisfied (CheckerClause *c) {
  const int *lit = c->literals;
  const int *const end = lit + c->size;
  for (; lit != end; lit++)
    if (val (*lit) > 0)
      return true;
  return false;
}
void Checker::collect_garbage_clauses () {
stats.collections++;
for (size_t i = 0; i < size_clauses; i++) {
CheckerClause **p = clauses + i, *c;
while ((c = *p)) {
if (clause_satisfied (c)) {
c->size = 0; *p = c->next;
c->next = garbage;
garbage = c;
num_garbage++;
assert (num_clauses);
num_clauses--;
} else
p = &c->next;
}
}
LOG ("CHECKER collecting %" PRIu64 " garbage clauses %.0f%%", num_garbage,
percent (num_garbage, num_clauses));
for (int lit = -size_vars + 1; lit < size_vars; lit++) {
if (!lit)
continue;
CheckerWatcher &ws = watcher (lit);
const auto end = ws.end ();
auto j = ws.begin (), i = j;
for (; i != end; i++) {
CheckerWatch &w = *i;
if (w.clause->size)
*j++ = w;
}
if (j == ws.end ())
continue;
if (j == ws.begin ())
erase_vector (ws);
else
ws.resize (j - ws.begin ());
}
for (CheckerClause *c = garbage, *next; c; c = next)
next = c->next, delete_clause (c);
assert (!num_garbage);
garbage = 0;
}
// Construct an empty checker: no variables, no clauses, consistent.
//
Checker::Checker (Internal *i)
    : internal (i), size_vars (0), vals (0), inconsistent (false),
      num_clauses (0), num_garbage (0), size_clauses (0), clauses (0),
      garbage (0), next_to_propagate (0), last_hash (0) {

  // Fill the table of hash multipliers ('compute_hash' reads it) from a
  // generator with fixed seed.  Setting the lowest bit forces every entry
  // to be odd and thus non-zero.
  //
  Random random (42);
  for (unsigned n = 0; n < num_nonces; n++) {
    const uint64_t nonce = random.next () | 1;
    assert (nonce), assert (nonce & 1);
    nonces[n] = nonce;
  }

  memset (&stats, 0, sizeof (stats)); // all statistics start at zero
}
// Replace the 'Internal' object this checker is attached to.
//
void Checker::connect_internal (Internal *i) {
  internal = i;
  LOG ("CHECKER connected to internal");
}
// Release the value table, all clauses still in the hash table, all
// clauses on the garbage list and finally the bucket array itself.
//
Checker::~Checker () {
  LOG ("CHECKER delete");

  // 'vals' points into the middle of its allocation (see 'enlarge_vars'),
  // so it has to be shifted back to the start before it is released.
  //
  vals -= size_vars;
  delete[] vals;

  for (size_t i = 0; i < size_clauses; i++) {
    CheckerClause *c = clauses[i];
    while (c) {
      CheckerClause *const next = c->next;
      delete_clause (c);
      c = next;
    }
  }

  CheckerClause *g = garbage;
  while (g) {
    CheckerClause *const next = g->next;
    delete_clause (g);
    g = next;
  }

  delete[] clauses;
}
// Grow all per-literal data structures such that variable index 'idx' is
// valid afterwards.  The new size is the smallest value strictly larger
// than 'idx' reachable by repeatedly doubling the old size (starting from
// two for an empty checker).
//
// 'vals' is kept pointing to the middle of its allocation so that it can
// be indexed directly by signed literals, which makes copying and freeing
// it slightly more involved.
//
void Checker::enlarge_vars (int64_t idx) {

  assert (0 < idx), assert (idx <= INT_MAX);

  int64_t new_size_vars = size_vars ? 2 * size_vars : 2;
  while (idx >= new_size_vars)
    new_size_vars *= 2;
  LOG ("CHECKER enlarging variables of checker from %" PRId64 " to %" PRId64
       "",
       size_vars, new_size_vars);

  signed char *block = new signed char[2 * new_size_vars];
  clear_n (block, 2 * new_size_vars);
  signed char *const new_vals = block + new_size_vars;

  // Copy the old values centered around the new middle.  The copy is
  // skipped while there is no old table yet.
  //
  if (size_vars)
    memcpy ((void *) (new_vals - size_vars), (void *) (vals - size_vars),
            2 * size_vars);

  // Release the old table through the start of its allocation.
  //
  vals -= size_vars;
  delete[] vals;
  vals = new_vals;

  watchers.resize (2 * new_size_vars);
  marks.resize (2 * new_size_vars);

  assert (idx < new_size_vars);
  size_vars = new_size_vars;
}
// Append 'lit' to both working copies of the current clause, enlarging
// the variable tables first if its variable has not been seen before.
//
inline void Checker::import_literal (int lit) {
  assert (lit);
  assert (lit != INT_MIN);
  const int idx = abs (lit);
  if (size_vars <= idx)
    enlarge_vars (idx);
  simplified.push_back (lit);
  unsimplified.push_back (lit);
}
// Import all literals of 'c' (in order) as the current clause.
//
void Checker::import_clause (const vector<int> &c) {
  for (auto it = c.begin (); it != c.end (); ++it)
    import_literal (*it);
}
// Strict ordering of literals: primarily by variable index (absolute
// value) and for the same variable the negative literal comes first.
//
struct lit_smaller {
  bool operator() (int a, int b) const {
    const int var_a = abs (a), var_b = abs (b);
    return var_a != var_b ? var_a < var_b : a < b;
  }
};
// Sort 'simplified' and remove duplicated literals in place.  Returns
// 'true' if the clause contains both phases of a variable or a literal
// which is currently assigned to true, and 'false' otherwise.  On an
// early 'true' return the vector is left sorted but only partially
// compacted.
//
bool Checker::tautological () {
  sort (simplified.begin (), simplified.end (), lit_smaller ());

  const size_t num_lits = simplified.size ();
  size_t kept = 0;
  int prev = 0;

  // After sorting equal literals as well as the two phases of the same
  // variable are adjacent, so comparing with the last kept literal is
  // enough.
  //
  for (size_t i = 0; i < num_lits; i++) {
    const int lit = simplified[i];
    if (lit == prev)
      continue; // duplicated literal
    if (lit == -prev)
      return true; // both phases occur
    if (val (lit) > 0)
      return true; // literal and thus clause satisfied
    simplified[kept++] = prev = lit;
  }

  simplified.resize (kept);
  return false;
}
// Reduce a 64-bit hash value to a bucket index below 'size'.  Higher bits
// are folded into lower bits by xor-shifting with 32, 16, ... bits as long
// as the corresponding power of two is larger than 'size', and the result
// is masked with 'size - 1'.  Requires 'size > 0'.
//
// NOTE(review): the mask covers all buckets only if 'size' is a power of
// two, which matches how 'enlarge_clauses' grows the table (1, 2, 4, ...).
//
uint64_t Checker::reduce_hash (uint64_t hash, uint64_t size) {
  assert (size > 0);
  uint64_t res = hash;
  for (unsigned shift = 32; (((uint64_t) 1) << shift) > size; shift >>= 1)
    res ^= res >> shift;
  res &= size - 1;
  assert (res < size);
  return res;
}
// Compute the hash of the current clause from its identifier 'last_id':
// the identifier selects one of the 'nonces' and is multiplied with it.
// The result is also remembered in 'last_hash' (read by 'new_clause').
//
uint64_t Checker::compute_hash () {
  const unsigned slot = last_id % num_nonces;
  last_hash = nonces[slot] * last_id;
  return last_hash;
}
// Search the hash table for a clause with the hash of the current clause,
// the same size and only literals occurring in 'simplified'.  Returns the
// address of the link pointing to the matching clause, or the address of
// the terminating null link of the collision chain if there is no match.
// Thus '*result' is zero exactly if nothing was found and otherwise the
// caller can unlink the clause through the returned link.
//
// Requires a non-empty table since 'reduce_hash' needs 'size_clauses > 0'.
//
CheckerClause **Checker::find () {
  stats.searches++;
  const uint64_t hash = compute_hash ();
  const unsigned size = simplified.size ();
  const uint64_t h = reduce_hash (hash, size_clauses);

  // Mark the literals searched for, to compare candidates in linear time.
  //
  for (const auto &lit : simplified)
    mark (lit) = true;

  CheckerClause **res = clauses + h;
  while (CheckerClause *c = *res) {
    if (c->hash == hash && c->size == size) {
      const int *lits = c->literals;
      unsigned matched = 0;
      while (matched != size && mark (lits[matched]))
        matched++;
      if (matched == size)
        break;
    }
    stats.collisions++;
    res = &c->next;
  }

  for (const auto &lit : simplified)
    mark (lit) = false;

  return res;
}
void Checker::insert () {
stats.insertions++;
if (num_clauses == size_clauses)
enlarge_clauses ();
const uint64_t h = reduce_hash (compute_hash (), size_clauses);
CheckerClause *c = new_clause ();
c->next = clauses[h];
clauses[h] = c;
}
// Assign the unassigned literal 'lit' to true (and its negation to false)
// and push it on the trail.
//
inline void Checker::assign (int lit) {
  assert (!val (lit));
  vals[-lit] = -1;
  vals[lit] = 1;
  trail.push_back (lit);
}
// Assume 'lit' to be true.  Literals which are true already are skipped
// silently, all others are expected to be unassigned and get assigned.
//
inline void Checker::assume (int lit) {
  const signed char current = val (lit);
  if (current <= 0) {
    assert (!current);
    stats.assumptions++;
    assign (lit);
  }
}
// Undo all assignments beyond the first 'previously_propagated' trail
// entries, shrink the trail to that length and continue propagation from
// there.
//
void Checker::backtrack (unsigned previously_propagated) {

  assert (previously_propagated <= trail.size ());

  // Unassign from the end of the trail towards the front.
  //
  for (size_t n = trail.size (); n > previously_propagated; n--) {
    const int lit = trail[n - 1];
    assert (val (lit) > 0);
    assert (val (-lit) < 0);
    vals[-lit] = 0;
    vals[lit] = 0;
  }

  trail.resize (previously_propagated);
  next_to_propagate = previously_propagated;
  assert (trail.size () == next_to_propagate);
}
// Unit propagation over two watched literals with blocking literals.
// Processes all trail literals not propagated yet and returns 'false' as
// soon as a clause becomes falsified (conflict) and 'true' otherwise.
//
// While traversing the watch list of the falsified literal '-lit' the
// iterator 'i' reads watches and 'j' writes the watches which are kept.
// Every watch is copied eagerly ('*j++ = *i') and 'j' is decremented again
// if the watch turns out to be dropped.
//
bool Checker::propagate () {
  bool ok = true;
  while (ok && next_to_propagate < trail.size ()) {
    const int lit = trail[next_to_propagate++];
    stats.propagations++;
    assert (val (lit) > 0);
    assert (abs (lit) < size_vars);
    CheckerWatcher &ws = watcher (-lit);
    const auto end = ws.end ();
    auto j = ws.begin (), i = j;
    for (; ok && i != end; i++) {
      CheckerWatch &w = *j++ = *i;
      const int blit = w.blit;
      assert (blit != -lit);
      const signed char blit_val = val (blit);
      if (blit_val > 0)
        continue; // blocking literal true, clause satisfied

      const unsigned size = w.size;
      if (size == 2) {
        // Binary clauses are handled through the watch alone without
        // dereferencing the clause, hence without a garbage check.
        //
        if (blit_val < 0)
          ok = false;
        else
          assign (w.blit);
        continue;
      }

      assert (size > 2);
      CheckerClause *c = w.clause;
      if (!c->size) {
        j--; // drop watch of garbage clause
        continue;
      }
      assert (size == c->size);

      // The watched literals are at the first two positions, thus the
      // xor of both with '-lit' yields the other watched literal.
      //
      int *lits = c->literals;
      const int other = lits[0] ^ lits[1] ^ (-lit);
      assert (other != -lit);
      const signed char other_val = val (other);
      if (other_val > 0) {
        j[-1].blit = other; // use the true literal as blocking literal
        continue;
      }

      // Normalize such that '-lit' is at the second position and search
      // the remaining literals for a non-false replacement.
      //
      lits[0] = other, lits[1] = -lit;
      unsigned k;
      int replacement = 0;
      signed char replacement_val = -1;
      for (k = 2; k < size; k++) {
        replacement = lits[k];
        replacement_val = val (replacement);
        if (replacement_val >= 0)
          break;
      }

      if (replacement_val >= 0) {
        // Watch the replacement instead of '-lit'.
        watcher (replacement).push_back (CheckerWatch (-lit, c));
        swap (lits[1], lits[k]);
        j--;
      } else if (!other_val)
        assign (other); // all but 'other' false, so 'other' is unit
      else
        ok = false; // all literals false
    }

    // On conflict the inner loop stops early: keep the remaining watches.
    //
    while (i != end)
      *j++ = *i++;
    ws.resize (j - ws.begin ());
  }
  return ok;
}
// Check that the current clause is implied through unit propagation:
// assume the negation of all its literals and propagate.  The check
// succeeds if this yields a conflict (or trivially if the checker is
// inconsistent already).  All temporary assignments are undone again.
//
bool Checker::check () {
  stats.checks++;
  if (inconsistent)
    return true;
  const unsigned saved = next_to_propagate;
  for (const auto &lit : simplified)
    assume (-lit);
  const bool conflict = !propagate ();
  backtrack (saved);
  return conflict;
}
// Add the current (simplified, not satisfied) clause to the checker.
// Depending on the number of literals which are not false the clause
// makes the checker inconsistent (none), is assigned and propagated as a
// unit (exactly one) or is inserted into the hash table (two or more).
// 'type' is only used in log messages.
//
void Checker::add_clause (const char *type) {
#ifndef LOGGING
  (void) type;
#endif

  // Count the non-false literals but stop after the second one.  With
  // exactly one such literal 'unit' holds it afterwards.
  //
  int unit = 0;
  unsigned open = 0;
  for (const auto &lit : simplified) {
    const signed char tmp = val (lit);
    if (tmp < 0)
      continue;
    assert (!tmp);
    if (++open == 2)
      break;
    unit = lit;
  }

  // Note that the empty clause also takes the 'falsified' branch below,
  // thus both messages are logged for it.
  //
  if (simplified.empty ()) {
    LOG ("CHECKER added empty %s clause", type);
    inconsistent = true;
  }

  if (!open) {
    LOG ("CHECKER added and checked falsified %s clause", type);
    inconsistent = true;
  } else if (open == 1) {
    LOG ("CHECKER added and checked %s unit clause %d", type, unit);
    assign (unit);
    stats.units++;
    if (!propagate ()) {
      LOG ("CHECKER inconsistent after propagating %s unit", type);
      inconsistent = true;
    }
  } else {
    insert ();
  }
}
// Add an original clause with identifier 'id'.  Original clauses are
// trusted and added without check unless they are tautological or
// satisfied.  The two unnamed 'bool' parameters are ignored.  Nothing is
// done once the checker is inconsistent.
//
void Checker::add_original_clause (uint64_t id, bool, const vector<int> &c,
                                   bool) {
  if (inconsistent)
    return;
  START (checking);
  LOG (c, "CHECKER addition of original clause");
  stats.added++;
  stats.original++;
  import_clause (c);
  last_id = id; // read by 'compute_hash'
  if (!tautological ()) {
    add_clause ("original");
  } else {
    LOG ("CHECKER ignoring satisfied original clause");
  }
  simplified.clear ();
  unsimplified.clear ();
  STOP (checking);
}
// Add a derived clause with identifier 'id'.  Unless it is tautological
// or satisfied the clause has to pass the propagation based 'check'
// before it is added, otherwise a fatal error message listing the
// original literals is produced.  The 'bool' flag and the chain of
// antecedent identifiers are ignored.  Nothing is done once the checker
// is inconsistent.
//
void Checker::add_derived_clause (uint64_t id, bool, const vector<int> &c,
                                  const vector<uint64_t> &) {
  if (inconsistent)
    return;
  START (checking);
  LOG (c, "CHECKER addition of derived clause");
  stats.added++;
  stats.derived++;
  import_clause (c);
  last_id = id; // read by 'compute_hash'
  if (tautological ()) {
    LOG ("CHECKER ignoring satisfied derived clause");
  } else if (check ()) {
    add_clause ("derived");
  } else {
    fatal_message_start ();
    fputs ("failed to check derived clause:\n", stderr);
    for (const auto &lit : unsimplified)
      fprintf (stderr, "%d ", lit);
    fputc ('0', stderr);
    fatal_message_end ();
  }
  simplified.clear ();
  unsimplified.clear ();
  STOP (checking);
}
// Check and perform the deletion of the clause 'c' with identifier 'id'.
// Tautological and satisfied clauses are ignored.  Otherwise the clause is
// looked up in the hash table, unlinked, marked as garbage ('size == 0')
// and moved to the 'garbage' list, where it stays until the next garbage
// collection (watches still point to it).  If no matching clause exists a
// fatal error message listing the original literals is produced.  The
// unnamed 'bool' parameter is ignored.  Nothing is done once the checker
// is inconsistent.
//
void Checker::delete_clause (uint64_t id, bool, const vector<int> &c) {
  if (inconsistent)
    return;
  START (checking);
  LOG (c, "CHECKER checking deletion of clause");
  stats.deleted++;
  import_clause (c);
  last_id = id; // read by 'compute_hash'
  if (!tautological ()) {

    // As long as no clause has been inserted there is no bucket array and
    // 'find' must not be called: it would reduce the hash modulo a table
    // of size zero, which fails an assertion in 'reduce_hash' and does not
    // terminate without assertions.  An empty table trivially does not
    // contain the clause, which is reported below as usual.
    //
    CheckerClause **p = size_clauses ? find () : 0;
    CheckerClause *d = p ? *p : 0;

    if (d) {
      assert (d->size > 1);

      // Remove from hash table, mark as garbage, connect to garbage list.
      //
      num_garbage++;
      assert (num_clauses);
      num_clauses--;
      *p = d->next;
      d->next = garbage;
      garbage = d;
      d->size = 0;

      // Collect garbage once there is enough of it compared to the size
      // of the hash table and the number of variables.
      //
      if (num_garbage >
          0.5 * max ((size_t) size_clauses, (size_t) size_vars))
        collect_garbage_clauses ();
    } else {
      fatal_message_start ();
      fputs ("deleted clause not in proof:\n", stderr);
      for (const auto &lit : unsimplified)
        fprintf (stderr, "%d ", lit);
      fputc ('0', stderr);
      fatal_message_end ();
    }
  }
  simplified.clear ();
  unsimplified.clear ();
  STOP (checking);
}
// An assumption clause is only checked but not kept: it is added as a
// derived clause and deleted again right away under the same identifier.
//
void Checker::add_assumption_clause (uint64_t id, const vector<int> &c,
                                     const vector<uint64_t> &chain) {
  const bool flag = true; // passed to the ignored 'bool' parameters
  add_derived_clause (id, flag, c, chain);
  delete_clause (id, flag, c);
}
// Print all clauses currently in the hash table to 'stdout' in DIMACS
// format: a 'p cnf' header with the largest variable index occurring in
// these clauses and 'num_clauses', followed by one zero terminated line
// per clause.
//
void Checker::dump () {
  // First pass: determine the largest variable index.
  //
  int max_var = 0;
  for (uint64_t bucket = 0; bucket < size_clauses; bucket++)
    for (CheckerClause *c = clauses[bucket]; c; c = c->next)
      for (unsigned k = 0; k < c->size; k++) {
        const int idx = abs (c->literals[k]);
        if (idx > max_var)
          max_var = idx;
      }

  printf ("p cnf %d %" PRIu64 "\n", max_var, num_clauses);

  // Second pass: print the clauses.
  //
  for (uint64_t bucket = 0; bucket < size_clauses; bucket++)
    for (CheckerClause *c = clauses[bucket]; c; c = c->next) {
      for (unsigned k = 0; k < c->size; k++)
        printf ("%d ", c->literals[k]);
      printf ("0\n");
    }
}
}