#include "internal.hpp"
namespace CaDiCaL {
// Checks whether 'c' is an actual binary clause with respect to the current
// assignment. This is the case when, apart from 'first', it contains exactly
// one unassigned literal; that literal is returned, and zero otherwise.
// Falsified literals are ignored. If a satisfied literal is met before a
// second unassigned literal, the clause is marked as garbage, the eliminator
// is told about the removal, and zero is returned.
int Internal::second_literal_in_binary_clause (Eliminator &eliminator,
                                               Clause *c, int first) {
  assert (!c->garbage);
  int other = 0;
  for (const auto &lit : *c) {
    if (lit == first)
      continue;
    const signed char value = val (lit);
    if (value > 0) {
      mark_garbage (c);
      elim_update_removed_clause (eliminator, c);
      return 0;
    }
    if (value < 0)
      continue;
    if (other)
      return 0; // second unassigned literal: not an actual binary clause
    other = lit;
  }
  if (!other)
    return 0;
  assert (active (other));
#ifdef LOGGING
  if (c->size == 2)
    LOG (c, "found binary");
  else
    LOG (c, "found actual binary %d %d", first, other);
#endif
  return other;
}
// Side-effect free variant of 'second_literal_in_binary_clause'. It returns
// the single unassigned literal of 'c' other than 'first', or zero if 'c' is
// garbage, is satisfied, or does not have exactly one such literal.
// Falsified literals are ignored and 'c' is never modified.
int Internal::second_literal_in_binary_clause_lrat (Clause *c, int first) {
  if (c->garbage)
    return 0;
  int other = 0;
  for (const auto &lit : *c) {
    if (lit == first)
      continue;
    const signed char value = val (lit);
    if (value > 0)
      return 0; // satisfied clause
    if (value < 0)
      continue; // falsified literal, ignored
    if (other)
      return 0; // more than one unassigned literal besides 'first'
    other = lit;
  }
  return other;
}
// Returns a non-garbage clause that is an actual binary clause
// '(first, second)' under the current assignment, or zero if none exists.
// Only the shorter of the two occurrence lists is scanned ('first' wins ties).
// The clauses inspected are not modified.
Clause *Internal::find_binary_clause (int first, int second) {
  const bool second_is_shorter =
      occs (first).size () > occs (second).size ();
  const int scanned = second_is_shorter ? second : first;
  const int partner = second_is_shorter ? first : second;
  for (const auto &c : occs (scanned)) {
    if (second_literal_in_binary_clause_lrat (c, scanned) == partner)
      return c;
  }
  return 0;
}
// Marks the other literal 'second' of every actual binary clause
// '(first, second)' found in 'occs (first)' (see
// 'second_literal_in_binary_clause'). Each marked literal is recorded in
// 'eliminator.marked', so callers can test for such a binary clause with
// 'marked (second) > 0'. Callers in this file pair it with
// 'unmark_binary_literals'.
//
// Two special cases are handled while marking:
//  - If the complement '-second' is already marked, both '(first, -second)'
//    and '(first, second)' exist and resolve to the unit 'first'. The unit
//    is assigned (with an LRAT chain if 'lrat' is set) and propagated, and
//    the function returns immediately.
//  - If 'second' itself is already marked, the clause duplicates an earlier
//    binary clause and is marked as garbage.
//
// Nothing is done if the formula is already inconsistent, 'first' is
// assigned, or gate clauses have already been collected.
void Internal::mark_binary_literals (Eliminator &eliminator, int first) {
if (unsat)
return;
if (val (first))
return;
if (!eliminator.gates.empty ())
return;
assert (!marked (first));
assert (eliminator.marked.empty ());
const Occs &os = occs (first);
for (const auto &c : os) {
if (c->garbage)
continue;
const int second =
second_literal_in_binary_clause (eliminator, c, first);
if (!second)
continue;
const int tmp = marked (second);
// Negative: '-second' was marked earlier, i.e. '(first, -second)' exists.
if (tmp < 0) {
LOG ("found binary resolved unit %d", first);
if (lrat) {
// LRAT chain for the unit 'first': for every falsified literal of the two
// binary clauses (each variable once, tracked via the 'seen' flag) push
// 'unit_id' of its negation, then the ids of both binary clauses.
Clause *d = find_binary_clause (first, -second);
assert (d);
for (auto &lit : *d) {
if (lit == first || lit == -second)
continue;
assert (val (lit) < 0);
Flags &f = flags (lit);
if (f.seen)
continue;
analyzed.push_back (lit);
f.seen = true;
int64_t id = unit_id (-lit);
lrat_chain.push_back (id);
}
for (auto &lit : *c) {
if (lit == first || lit == second)
continue;
assert (val (lit) < 0);
Flags &f = flags (lit);
if (f.seen)
continue;
analyzed.push_back (lit);
f.seen = true;
int64_t id = unit_id (-lit);
lrat_chain.push_back (id);
}
lrat_chain.push_back (c->id);
lrat_chain.push_back (d->id);
clear_analyzed_literals ();
}
assign_unit (first);
elim_propagate (eliminator, first);
// 'first' is now assigned; marks set so far are left for the caller to clear.
return;
}
// Positive: 'second' is already marked, so 'c' duplicates a binary clause.
if (tmp > 0) {
LOG (c, "duplicated actual binary clause");
elim_update_removed_clause (eliminator, c);
mark_garbage (c);
continue;
}
eliminator.marked.push_back (second);
mark (second);
LOG ("marked second literal %d in binary clause %d %d", second, first,
second);
}
}
// Clears the marks of all literals recorded in 'eliminator.marked' (as set
// by 'mark_binary_literals') and empties that list.
void Internal::unmark_binary_literals (Eliminator &eliminator) {
  // 'size ()' returns 'size_t', whose matching specifier is '%zu'.
  LOG ("unmarking %zu literals", eliminator.marked.size ());
  for (const auto &lit : eliminator.marked)
    unmark (lit);
  eliminator.marked.clear ();
}
// Tries to find an equivalence gate for 'pivot'. Such a gate consists of the
// actual binary clauses '(-pivot, second)' and '(pivot, -second)', which
// together force 'pivot' and 'second' to the same value.
//
// First the partners of all binary clauses '(pivot, other)' are marked.
// Then the binary clauses '(-pivot, second)' in 'occs (-pivot)' are scanned:
//  - 'marked (second) > 0': '(pivot, second)' also exists. Resolving gives
//    the unit 'second', which is assigned (with an LRAT chain if 'lrat' is
//    set) and propagated. The scan stops if 'pivot' became assigned or the
//    formula became inconsistent.
//  - 'marked (second) < 0': '(pivot, -second)' exists. Both clauses are
//    flagged as gate clauses and pushed on 'eliminator.gates', and the gate
//    type becomes EQUI.
// The binary marks are always cleared before returning.
void Internal::find_equivalence (Eliminator &eliminator, int pivot) {
  if (!opts.elimequivs)
    return;
  assert (opts.elimsubst);
  if (unsat)
    return;
  if (val (pivot))
    return;
  if (!eliminator.gates.empty ())
    return;

  mark_binary_literals (eliminator, pivot);
  if (unsat || val (pivot))
    goto DONE;

  for (const auto &c : occs (-pivot)) {
    if (c->garbage)
      continue;
    const int second =
        second_literal_in_binary_clause (eliminator, c, -pivot);
    if (!second)
      continue;
    const int tmp = marked (second);
    if (tmp > 0) {
      LOG ("found binary resolved unit %d", second);
      if (lrat) {
        // LRAT chain for the unit 'second': for every falsified literal of
        // both binary clauses (each variable once, tracked via 'seen') push
        // 'unit_id' of its negation, then the ids of the two clauses.
        Clause *d = find_binary_clause (pivot, second);
        assert (d);
        for (auto &lit : *d) {
          if (lit == pivot || lit == second)
            continue;
          assert (val (lit) < 0);
          Flags &f = flags (lit);
          if (f.seen)
            continue;
          analyzed.push_back (lit);
          f.seen = true;
          int64_t id = unit_id (-lit);
          lrat_chain.push_back (id);
        }
        for (auto &lit : *c) {
          if (lit == -pivot || lit == second)
            continue;
          assert (val (lit) < 0);
          Flags &f = flags (lit);
          if (f.seen)
            continue;
          analyzed.push_back (lit);
          f.seen = true;
          int64_t id = unit_id (-lit);
          lrat_chain.push_back (id);
        }
        lrat_chain.push_back (c->id);
        lrat_chain.push_back (d->id);
        clear_analyzed_literals ();
      }
      assign_unit (second);
      elim_propagate (eliminator, second);
      if (val (pivot))
        break;
      if (unsat)
        break;
    }
    if (tmp >= 0)
      continue;

    // '(-pivot, second)' and '(pivot, -second)' exist, so 'pivot' and
    // 'second' are equivalent (the old message wrongly printed '-second').
    LOG ("found equivalence %d = %d", pivot, second);
    stats.elimequivs++;
    stats.elimgates++;

    LOG (c, "first gate clause");
    assert (!c->gate);
    c->gate = true;
    eliminator.gates.push_back (c);

    // Locate the matching binary clause '(pivot, -second)'.
    Clause *d = 0;
    const Occs &ps = occs (pivot);
    for (const auto &e : ps) {
      if (e->garbage)
        continue;
      const int other =
          second_literal_in_binary_clause (eliminator, e, pivot);
      if (other == -second) {
        d = e;
        break;
      }
    }
    assert (d);

    LOG (d, "second gate clause");
    assert (!d->gate);
    d->gate = true;
    eliminator.gates.push_back (d);
    eliminator.gatetype = EQUI;
    break;
  }

DONE:
  unmark_binary_literals (eliminator);
}
// Tries to find an AND gate '-pivot = -l1 & ... & -lk' (as printed by the
// LOGGING block below). The gate has:
//  - a base clause '(-pivot, l1, ..., lk)' in 'occs (-pivot)' with at least
//    three literals, and
//  - for every unassigned 'li' of the base clause, a binary side clause
//    '(pivot, -li)', detected through 'marked (li) < 0' after
//    'mark_binary_literals (eliminator, pivot)'.
// On success the base clause and all matching binary side clauses are flagged
// as gate clauses and pushed on 'eliminator.gates', and the gate type becomes
// AND. Candidate base clauses whose scan reaches a satisfied literal are
// marked as garbage instead. The binary marks are always cleared before
// returning.
void Internal::find_and_gate (Eliminator &eliminator, int pivot) {
if (!opts.elimands)
return;
assert (opts.elimsubst);
if (unsat)
return;
if (val (pivot))
return;
if (!eliminator.gates.empty ())
return;
mark_binary_literals (eliminator, pivot);
if (unsat || val (pivot))
goto DONE;
for (const auto &c : occs (-pivot)) {
if (c->garbage)
continue;
if (c->size < 3)
continue;
// Check that every unassigned literal (other than '-pivot') has its
// negation as binary partner of 'pivot'; 'arity' counts them.
bool all_literals_marked = true;
unsigned arity = 0;
int satisfied = 0;
for (const auto &lit : *c) {
if (lit == -pivot)
continue;
assert (lit != pivot);
signed char tmp = val (lit);
if (tmp < 0)
continue;
if (tmp > 0) {
satisfied = lit;
break;
}
tmp = marked (lit);
if (tmp < 0) {
arity++;
continue;
}
all_literals_marked = false;
break;
}
if (!all_literals_marked)
continue;
if (satisfied) {
LOG (c, "satisfied by %d candidate base clause", satisfied);
mark_garbage (c);
continue;
}
#ifdef LOGGING
if (opts.log) {
Logger::print_log_prefix (this);
tout.magenta ();
printf ("found arity %u AND gate %d = ", arity, -pivot);
bool first = true;
for (const auto &lit : *c) {
if (lit == -pivot)
continue;
assert (lit != pivot);
if (!first)
fputs (" & ", stdout);
printf ("%d", -lit);
first = false;
}
fputc ('\n', stdout);
tout.normal ();
fflush (stdout);
}
#endif
stats.elimands++;
stats.elimgates++;
eliminator.gatetype = AND;
(void) arity;
assert (!c->gate);
c->gate = true;
eliminator.gates.push_back (c);
// Double the marks of the base clause literals' complements. Afterwards
// 'marked (other) == 2' singles out exactly the side clauses
// '(pivot, -li)' of this base clause.
for (const auto &lit : *c) {
if (lit == -pivot)
continue;
assert (lit != pivot);
signed char tmp = val (lit);
if (tmp < 0)
continue;
assert (!tmp);
assert (marked (lit) < 0);
marks[vidx (lit)] *= 2;
}
// Collect the binary side clauses '(pivot, other)' with doubled marks.
unsigned count = 0;
for (const auto &d : occs (pivot)) {
if (d->garbage)
continue;
const int other =
second_literal_in_binary_clause (eliminator, d, pivot);
if (!other)
continue;
const int tmp = marked (other);
if (tmp != 2)
continue;
LOG (d, "AND gate binary side clause");
assert (!d->gate);
d->gate = true;
eliminator.gates.push_back (d);
count++;
}
assert (count >= arity);
(void) count;
break;
}
DONE:
unmark_binary_literals (eliminator);
}
// Extracts the unassigned literals of 'd' into 'a', 'b' and 'c'. Returns
// true only if 'd' is not garbage, has at least three literals, contains no
// satisfied literal, and has exactly three unassigned literals. Falsified
// literals are ignored. The output arguments are zeroed before the literals
// are scanned and are only meaningful when true is returned.
bool Internal::get_ternary_clause (Clause *d, int &a, int &b, int &c) {
  if (d->garbage)
    return false;
  if (d->size < 3)
    return false;
  int found = 0;
  a = b = c = 0;
  for (const auto &lit : *d) {
    const signed char tmp = val (lit);
    // A satisfied clause does not constrain its remaining literals, so it
    // must not be used as gate clause (consistent with 'get_clause').
    if (tmp > 0)
      return false;
    if (tmp < 0)
      continue;
    if (++found == 1)
      a = lit;
    else if (found == 2)
      b = lit;
    else if (found == 3)
      c = lit;
    else
      return false;
  }
  return found == 3;
}
// Returns true if 'd' is not garbage, contains no satisfied literal, and its
// unassigned literals are exactly three literals, each among 'a', 'b' and
// 'c'. Falsified literals are ignored.
bool Internal::match_ternary_clause (Clause *d, int a, int b, int c) {
  if (d->garbage)
    return false;
  int found = 0;
  for (const auto &lit : *d) {
    const signed char tmp = val (lit);
    // Reject satisfied clauses, consistent with 'get_ternary_clause'.
    if (tmp > 0)
      return false;
    if (tmp < 0)
      continue;
    if (a != lit && b != lit && c != lit)
      return false;
    found++;
  }
  return found == 3;
}
// Returns the first clause matching the ternary clause '(a, b, c)' (see
// 'match_ternary_clause'), or zero if none exists. Only the shortest of the
// three occurrence lists is scanned; ties prefer 'a', then 'b'.
Clause *Internal::find_ternary_clause (int a, int b, int c) {
  const int shorter_of_bc = occs (b).size () > occs (c).size () ? c : b;
  const int start =
      occs (a).size () > occs (shorter_of_bc).size () ? shorter_of_bc : a;
  for (const auto &d : occs (start)) {
    if (match_ternary_clause (d, a, b, c))
      return d;
  }
  return 0;
}
// Tries to find an if-then-else gate 'pivot = (-bi ? -ci : -cj)' (as logged
// below). It pairs two ternary clauses '(pivot, bi, ci)' and
// '(pivot, -bi, cj)' from 'occs (pivot)' and then looks up the two clauses
// '(-pivot, bi, -ci)' and '(-pivot, -bi, -cj)'. "Ternary" means exactly three
// unassigned literals as reported by 'get_ternary_clause'. If all four
// clauses are found, they are flagged as gate clauses and pushed on
// 'eliminator.gates', and the gate type becomes ITE.
void Internal::find_if_then_else (Eliminator &eliminator, int pivot) {
if (!opts.elimites)
return;
assert (opts.elimsubst);
if (unsat)
return;
if (val (pivot))
return;
if (!eliminator.gates.empty ())
return;
const Occs &os = occs (pivot);
const auto end = os.end ();
for (auto i = os.begin (); i != end; i++) {
Clause *di = *i;
int ai, bi, ci;
if (!get_ternary_clause (di, ai, bi, ci))
continue;
// Normalize so that 'ai' is the pivot.
if (bi == pivot)
swap (ai, bi);
if (ci == pivot)
swap (ai, ci);
assert (ai == pivot);
// Only later clauses are tried, so each unordered pair is considered once.
for (auto j = i + 1; j != end; j++) {
Clause *dj = *j;
int aj, bj, cj;
if (!get_ternary_clause (dj, aj, bj, cj))
continue;
if (bj == pivot)
swap (aj, bj);
if (cj == pivot)
swap (aj, cj);
assert (aj == pivot);
// Align 'bj' with the variable of 'bi' if that variable is in position 'cj'.
if (abs (bi) == abs (cj))
swap (bj, cj);
// Skip pairs whose third literals share a variable.
if (abs (ci) == abs (cj))
continue;
// The condition literal must occur with opposite signs in both clauses.
if (bi != -bj)
continue;
Clause *d1 = find_ternary_clause (-pivot, bi, -ci);
if (!d1)
continue;
Clause *d2 = find_ternary_clause (-pivot, bj, -cj);
if (!d2)
continue;
LOG (di, "1st if-then-else");
LOG (dj, "2nd if-then-else");
LOG (d1, "3rd if-then-else");
LOG (d2, "4th if-then-else");
LOG ("found ITE gate %d == (%d ? %d : %d)", pivot, -bi, -ci, -cj);
assert (!di->gate);
assert (!dj->gate);
assert (!d1->gate);
assert (!d2->gate);
di->gate = true;
dj->gate = true;
d1->gate = true;
d2->gate = true;
eliminator.gates.push_back (di);
eliminator.gates.push_back (dj);
eliminator.gates.push_back (d1);
eliminator.gates.push_back (d2);
stats.elimgates++;
stats.elimites++;
eliminator.gatetype = ITE;
return;
}
}
}
// Copies the unassigned literals of 'c' into 'l' and returns true. Returns
// false (leaving 'l' untouched) if 'c' is garbage. Returns false with 'l'
// cleared if 'c' contains a satisfied literal. Falsified literals are
// skipped.
bool Internal::get_clause (Clause *c, vector<int> &l) {
  if (c->garbage)
    return false;
  l.clear ();
  for (const auto &lit : *c) {
    const auto value = val (lit);
    if (value > 0) {
      l.clear ();
      return false;
    }
    if (!value)
      l.push_back (lit);
  }
  return true;
}
// Returns true if 'c' is not garbage, has no satisfied literal, and its
// unassigned literals are exactly 'lits.size ()' literals, each contained in
// 'lits'. Falsified literals are ignored.
bool Internal::is_clause (Clause *c, const vector<int> &lits) {
  if (c->garbage)
    return false;
  const int expected = lits.size ();
  if (c->size < expected)
    return false;
  const auto lits_end = lits.end ();
  int matched = 0;
  for (const auto &lit : *c) {
    const auto value = val (lit);
    if (value < 0)
      continue;
    if (value > 0 || find (lits.begin (), lits_end, lit) == lits_end)
      return false;
    if (++matched > expected)
      return false;
  }
  return matched == expected;
}
// Returns the first clause matching 'lits' (see 'is_clause'), or zero if none
// exists. Only the occurrence list of the literal with the fewest
// occurrences is scanned; the earliest such literal wins ties.
Clause *Internal::find_clause (const vector<int> &lits) {
  int start = 0;
  size_t shortest = 0;
  for (const auto &lit : lits) {
    const size_t length = occs (lit).size ();
    if (!start || length < shortest) {
      start = lit;
      shortest = length;
    }
  }
  for (const auto &candidate : occs (start)) {
    if (is_clause (candidate, lits))
      return candidate;
  }
  return 0;
}
// Tries to find an XOR gate with base clause 'd' in 'occs (pivot)'.
// A candidate must be non-garbage and non-satisfied, with 'size' unassigned
// literals, where 'size' >= 3 and 'arity = size - 1' <= 'opts.elimxorlim'.
// The remaining '(1u << arity) - 1' clauses are then searched. Each is
// obtained from the base literals by negating a subset given by a sign mask
// 'signs' with 'parity (signs)' zero (presumably an even number of negated
// literals). If all are found, the gate clauses are flagged and kept in
// 'eliminator.gates' (each clause once), and the gate type becomes XOR.
// Otherwise the collected clauses are discarded and the next base clause
// is tried.
void Internal::find_xor_gate (Eliminator &eliminator, int pivot) {
  if (!opts.elimxors)
    return;
  assert (opts.elimsubst);
  if (unsat)
    return;
  if (val (pivot))
    return;
  if (!eliminator.gates.empty ())
    return;

  vector<int> lits;

  for (auto d : occs (pivot)) {
    if (!get_clause (d, lits))
      continue;

    const int size = lits.size (); // number of unassigned literals
    const int arity = size - 1;    // arity of the XOR (literals besides pivot)

    if (size < 3)
      continue;
    if (arity > opts.elimxorlim)
      continue;

    assert (eliminator.gates.empty ());

    unsigned needed = (1u << arity) - 1; // additional clauses to find
    unsigned signs = 0;                  // mask of currently negated literals

    do {
      const unsigned prev = signs;
      // Advance to the next sign mask with zero parity.
      while (parity (++signs))
        ;
      // Flip exactly those literals whose sign bit changed.
      for (int j = 0; j < size; j++) {
        const unsigned bit = 1u << j;
        if ((prev & bit) != (signs & bit))
          lits[j] = -lits[j];
      }
      Clause *e = find_clause (lits);
      if (!e)
        break;
      eliminator.gates.push_back (e);
    } while (--needed);

    if (needed) {
      // Some required clause is missing: no XOR gate with this base clause.
      eliminator.gates.clear ();
      continue;
    }

    eliminator.gates.push_back (d);
    assert (eliminator.gates.size () == (1u << arity));

#ifdef LOGGING
    if (opts.log) {
      Logger::print_log_prefix (this);
      tout.magenta ();
      // 'arity' is an 'int', hence '%d' (was '%u').
      printf ("found arity %d XOR gate %d = ", arity, -pivot);
      bool first = true;
      for (const auto &lit : *d) {
        if (lit == pivot)
          continue;
        assert (lit != -pivot);
        if (!first)
          fputs (" ^ ", stdout);
        printf ("%d", lit);
        first = false;
      }
      fputc ('\n', stdout);
      tout.normal ();
      fflush (stdout);
    }
#endif
    stats.elimgates++;
    stats.elimxors++;

    // Flag the gate clauses, compacting away any clause whose 'gate' flag is
    // already set so that each clause is kept only once.
    const auto end = eliminator.gates.end ();
    auto j = eliminator.gates.begin ();
    for (auto i = j; i != end; i++) {
      Clause *e = *i;
      if (e->gate)
        continue;
      e->gate = true;
      LOG (e, "contributing");
      *j++ = e;
    }
    eliminator.gates.resize (j - eliminator.gates.begin ());
    eliminator.gatetype = XOR;
    break;
  }
}
// Runs the gate extraction procedures for 'pivot' in a fixed order:
// equivalences, AND gates for both polarities of 'pivot', if-then-else
// gates, XOR gates, and finally 'find_definition' (defined elsewhere). Each
// of the procedures in this file returns immediately once 'eliminator.gates'
// is non-empty. Nothing is done if gate substitution ('elimsubst') is
// disabled, the formula is inconsistent, or 'pivot' is assigned.
void Internal::find_gate_clauses (Eliminator &eliminator, int pivot) {
  if (!opts.elimsubst || unsat || val (pivot))
    return;

  assert (eliminator.gates.empty ());

  find_equivalence (eliminator, pivot);
  find_and_gate (eliminator, pivot);
  find_and_gate (eliminator, -pivot);
  find_if_then_else (eliminator, pivot);
  find_xor_gate (eliminator, pivot);
  find_definition (eliminator, pivot);
}
// Clears the 'gate' flag of every clause in 'eliminator.gates', empties that
// list and resets 'eliminator.definition_unit' to zero.
void Internal::unmark_gate_clauses (Eliminator &eliminator) {
  // 'size ()' returns 'size_t', whose matching specifier is '%zu'.
  LOG ("unmarking %zu gate clauses", eliminator.gates.size ());
  for (const auto &c : eliminator.gates) {
    assert (c->gate);
    c->gate = false;
  }
  eliminator.gates.clear ();
  eliminator.definition_unit = 0;
}
}