use std::hash::Hash;
use ahash::HashSet;
use ahash::HashSetExt;
use indexmap::IndexMap;
use itertools::Itertools;
use mago_atom::Atom;
use mago_atom::AtomSet;
use mago_codex::assertion::Assertion;
use mago_span::Span;
use crate::assertion_set::AssertionSet;
use crate::clause::Clause;
pub mod assertion_set;
pub mod clause;
/// Per-variable assertions extracted from a formula, keyed by variable id.
///
/// Built by [`find_satisfying_assignments`], which appends one group of assertions to a
/// variable's [`AssertionSet`] for every reconcilable unit clause over that variable.
pub type SatisfyingAssignments = IndexMap<Atom, AssertionSet>;
/// For each variable, the positions in its [`SatisfyingAssignments`] entry that came from
/// clauses whose `condition_span` equals the `creating_conditional_id` passed to
/// [`find_satisfying_assignments`].
pub type ActiveTruths = IndexMap<Atom, HashSet<usize>>;
/// Default for [`AlgebraThresholds::saturation_complexity`].
pub const DEFAULT_SATURATION_COMPLEXITY: u16 = 8_192;
/// Default for [`AlgebraThresholds::disjunction_complexity`].
pub const DEFAULT_DISJUNCTION_COMPLEXITY: u16 = 4_096;
/// Default for [`AlgebraThresholds::negation_complexity`].
pub const DEFAULT_NEGATION_COMPLEXITY: u16 = 4_096;
/// Default for [`AlgebraThresholds::consensus_limit`].
pub const DEFAULT_CONSENSUS_LIMIT: u16 = 256;

/// Size limits that stop the clause-algebra operations from doing unbounded work.
///
/// When a limit is exceeded, the operation gives up instead of producing a (potentially
/// exponentially large) result. Each field documents what giving up means for its operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlgebraThresholds {
    /// Maximum number of distinct input clauses `saturate_clauses` will process; larger
    /// inputs produce an empty clause list.
    pub saturation_complexity: u16,
    /// Maximum number of clauses on either side of `disjoin_clauses`. When both sides are
    /// non-empty and either one is larger than this, the disjunction is an empty clause list.
    pub disjunction_complexity: u16,
    /// Work budget for distributing clause negations in `negate_formula`; when it is
    /// exceeded the negation is abandoned and `None` is returned.
    pub negation_complexity: u16,
    /// The consensus pass of `saturate_clauses` only runs when the simplified formula has
    /// more than two and strictly fewer than this many clauses.
    pub consensus_limit: u16,
}

impl Default for AlgebraThresholds {
    /// Builds thresholds from the `DEFAULT_*` constants declared next to this type.
    fn default() -> Self {
        Self {
            saturation_complexity: DEFAULT_SATURATION_COMPLEXITY,
            disjunction_complexity: DEFAULT_DISJUNCTION_COMPLEXITY,
            negation_complexity: DEFAULT_NEGATION_COMPLEXITY,
            consensus_limit: DEFAULT_CONSENSUS_LIMIT,
        }
    }
}
#[inline]
pub fn saturate_clauses<'a>(
clauses: impl IntoIterator<Item = &'a Clause>,
thresholds: &AlgebraThresholds,
) -> Vec<Clause> {
fn saturate_clauses_inner(
unique_clauses: Vec<&Clause>,
saturation_complexity: usize,
consensus_limit: usize,
) -> Vec<Clause> {
let unique_clauses_len = unique_clauses.len();
if unique_clauses_len == 0 || unique_clauses_len > saturation_complexity {
return vec![];
}
let mut removed_indices: HashSet<usize> = HashSet::default();
let mut added_clauses: Vec<Clause> = Vec::new();
'outer: for (clause_a_idx, clause_a) in unique_clauses.iter().enumerate() {
if !clause_a.reconcilable || clause_a.wedge {
continue;
}
let is_clause_a_simple = clause_a.possibilities.len() == 1
&& clause_a.possibilities.values().next().is_some_and(|p| p.len() == 1);
if is_clause_a_simple {
let (clause_var, var_possibilities) = clause_a.possibilities.iter().next().unwrap();
let only_type = var_possibilities.values().next().unwrap();
let negated_clause_type = only_type.get_negation();
let negated_hash = negated_clause_type.to_hash();
for (clause_b_idx, clause_b) in unique_clauses.iter().enumerate() {
if clause_a_idx == clause_b_idx || removed_indices.contains(&clause_b_idx) {
continue;
}
if !clause_b.reconcilable || clause_b.wedge {
continue;
}
let Some(matching_clause_possibilities) = clause_b.possibilities.get(clause_var) else {
continue;
};
if !matching_clause_possibilities.contains_key(&negated_hash) {
continue;
}
let mut clause_var_possibilities = matching_clause_possibilities.clone();
clause_var_possibilities.retain(|k, _| k != &negated_hash);
removed_indices.insert(clause_b_idx);
if clause_var_possibilities.is_empty() {
if let Some(updated_clause) = clause_b.remove_possibilities(clause_var) {
added_clauses.push(updated_clause);
}
} else {
let updated_clause = clause_b.add_possibility(*clause_var, clause_var_possibilities);
added_clauses.push(updated_clause);
}
}
} else {
let clause_a_size = clause_a.possibilities.len();
'inner: for (clause_b_idx, clause_b) in unique_clauses.iter().enumerate() {
if clause_a_idx >= clause_b_idx || removed_indices.contains(&clause_b_idx) {
continue;
}
if !clause_b.reconcilable || clause_b.wedge {
continue;
}
if clause_b.possibilities.len() != clause_a_size {
continue;
}
let mut opposing_key = None;
let mut mismatch = false;
for (key, a_possibilities) in &clause_a.possibilities {
if let Some(b_possibilities) = clause_b.possibilities.get(key) {
if index_keys_match(a_possibilities, b_possibilities) {
continue;
}
if a_possibilities.len() == 1
&& b_possibilities.len() == 1
&& a_possibilities.values().next().is_some_and(|a| {
b_possibilities.values().next().is_some_and(|b| a.is_negation_of(b))
})
{
if opposing_key.is_some() {
mismatch = true;
break;
}
opposing_key = Some(key);
} else {
mismatch = true;
break;
}
} else {
mismatch = true;
break;
}
}
if mismatch {
continue 'inner;
}
if let Some(key_to_remove) = opposing_key {
removed_indices.insert(clause_a_idx);
let maybe_new_clause = clause_a.remove_possibilities(key_to_remove);
if let Some(new_clause) = maybe_new_clause {
added_clauses.push(new_clause);
} else {
continue 'outer;
}
}
}
}
}
let mut seen_hashes: HashSet<u32> = HashSet::with_capacity(unique_clauses_len);
let mut combined_clauses: Vec<Clause> = Vec::with_capacity(unique_clauses_len);
for (idx, clause) in unique_clauses.iter().enumerate() {
if !removed_indices.contains(&idx) && seen_hashes.insert(clause.hash) {
combined_clauses.push((*clause).clone());
}
}
for clause in added_clauses {
if seen_hashes.insert(clause.hash) {
combined_clauses.push(clause);
}
}
let mut simplified_clauses: Vec<Clause> = Vec::with_capacity(combined_clauses.len());
for clause_a in &combined_clauses {
if clause_a.wedge {
simplified_clauses.push(clause_a.clone());
continue;
}
let mut is_redundant = false;
for clause_b in &combined_clauses {
if std::ptr::eq(clause_a, clause_b) {
continue;
}
if !clause_b.reconcilable || clause_b.wedge {
continue;
}
if clause_b.possibilities.len() >= clause_a.possibilities.len() {
continue;
}
if clause_a.contains(clause_b) {
is_redundant = true;
break;
}
}
if !is_redundant {
simplified_clauses.push(clause_a.clone());
}
}
let simplified_clauses_len = simplified_clauses.len();
if simplified_clauses_len > 2 && simplified_clauses_len < consensus_limit {
let mut compared_clauses: HashSet<(u32, u32)> = HashSet::default();
let mut removed_hashes: HashSet<u32> = HashSet::default();
for (clause_a_idx, clause_a) in simplified_clauses.iter().enumerate() {
for clause_b in simplified_clauses.iter().skip(clause_a_idx + 1) {
if compared_clauses.contains(&(clause_b.hash, clause_a.hash)) {
continue;
}
compared_clauses.insert((clause_a.hash, clause_b.hash));
let common_keys: Vec<_> =
clause_a.possibilities.keys().filter(|k| clause_b.possibilities.contains_key(*k)).collect();
if common_keys.is_empty() {
continue;
}
let mut common_negated_keys: HashSet<&Atom> = HashSet::default();
for common_key in &common_keys {
let clause_a_possibilities = &clause_a.possibilities[*common_key];
let clause_b_possibilities = &clause_b.possibilities[*common_key];
if clause_a_possibilities.len() == 1
&& clause_b_possibilities.len() == 1
&& clause_a_possibilities.values().next().is_some_and(|a| {
clause_b_possibilities.values().next().is_some_and(|b| a.is_negation_of(b))
})
{
common_negated_keys.insert(*common_key);
}
}
if !common_negated_keys.is_empty() {
let mut new_possibilities: IndexMap<Atom, IndexMap<u64, Assertion>> = IndexMap::default();
for (var_id, possibilities) in &clause_a.possibilities {
if !common_negated_keys.contains(var_id) {
new_possibilities
.entry(*var_id)
.or_default()
.extend(possibilities.iter().map(|(&k, v)| (k, v.clone())));
}
}
for (var_id, possibilities) in &clause_b.possibilities {
if !common_negated_keys.contains(var_id) {
new_possibilities
.entry(*var_id)
.or_default()
.extend(possibilities.iter().map(|(&k, v)| (k, v.clone())));
}
}
let conflict_clause =
Clause::new(new_possibilities, clause_a.condition_span, clause_a.span, None, None, None);
removed_hashes.insert(conflict_clause.hash);
}
}
}
simplified_clauses.retain(|f| !removed_hashes.contains(&f.hash));
}
simplified_clauses
}
let unique_clauses = clauses.into_iter().unique().collect::<Vec<_>>();
saturate_clauses_inner(unique_clauses, thresholds.saturation_complexity.into(), thresholds.consensus_limit.into())
}
/// Collects, per variable, the assertions that must hold for the CNF `clauses` to be true.
///
/// Only reconcilable clauses over exactly one variable contribute. Each such clause appends
/// its assertions, as one group, to that variable's entry. Variables whose id starts with `*`
/// never receive assignments.
///
/// Side effect: every variable id not starting with `*` that appears in a clause which is not
/// `generated` is inserted into `conditionally_referenced_var_ids`.
///
/// The second tuple element maps each variable to the positions in its assignment list that
/// came from clauses whose `condition_span` equals `creating_conditional_id`.
#[inline]
pub fn find_satisfying_assignments(
    clauses: &[Clause],
    creating_conditional_id: Option<Span>,
    conditionally_referenced_var_ids: &mut AtomSet,
) -> (SatisfyingAssignments, ActiveTruths) {
    let mut assignments: SatisfyingAssignments = IndexMap::default();
    let mut active: ActiveTruths = IndexMap::default();

    for clause in clauses {
        if !clause.generated {
            let user_vars = clause.possibilities.keys().filter(|id| !id.as_str().starts_with('*'));
            for var_id in user_vars {
                conditionally_referenced_var_ids.insert(*var_id);
            }
        }

        let is_unit_clause = clause.reconcilable && clause.possibilities.len() == 1;
        if !is_unit_clause {
            continue;
        }

        // A unit clause always has exactly one entry, so this never takes the `else` path.
        let Some((var_id, possible_types)) = clause.possibilities.iter().next() else {
            continue;
        };
        if var_id.as_str().starts_with('*') {
            continue;
        }

        let group: Vec<_> = possible_types.values().cloned().collect();
        let var_assignments = assignments.entry(*var_id).or_default();
        let position = var_assignments.len();
        var_assignments.push(group);

        let from_creating_conditional =
            creating_conditional_id.is_some_and(|span| span == clause.condition_span);
        if from_creating_conditional {
            active.entry(*var_id).or_default().insert(position);
        }
    }

    (assignments, active)
}
/// Computes the CNF of `left_clauses || right_clauses` by distributing `||` over every
/// (left, right) clause pair.
///
/// - If either side is empty, the other side is returned unchanged.
/// - If both sides are non-empty and either has more than `thresholds.disjunction_complexity`
///   clauses, an empty vector is returned.
/// - A pair of two wedge clauses contributes nothing. A pair in which exactly one clause is a
///   wedge contributes a clone of the other clause.
/// - A merged clause in which some variable holds an assertion together with its negation is
///   always true, so it is dropped.
///
/// Merged clauses use `conditional_object_id` for both of their spans.
#[inline]
#[must_use]
pub fn disjoin_clauses(
    left_clauses: Vec<Clause>,
    right_clauses: Vec<Clause>,
    conditional_object_id: Span,
    thresholds: &AlgebraThresholds,
) -> Vec<Clause> {
    let left_clauses_len = left_clauses.len();
    let right_clauses_len = right_clauses.len();
    if left_clauses_len == 0 {
        return right_clauses;
    }
    if right_clauses_len == 0 {
        return left_clauses;
    }
    // Distribution creates up to `left * right` clauses; refuse overly large inputs.
    if left_clauses_len > usize::from(thresholds.disjunction_complexity)
        || right_clauses_len > usize::from(thresholds.disjunction_complexity)
    {
        return vec![];
    }

    let mut clauses = vec![];
    for left_clause in left_clauses {
        for right_clause in &right_clauses {
            if left_clause.wedge && right_clause.wedge {
                continue;
            }
            // NOTE(review): if a wedge stands for an opaque condition `W`, then `W || C` does
            // not imply `C`; confirm that keeping `C` unchanged (including its reconcilable
            // flag) is intended here.
            if left_clause.wedge {
                clauses.push(right_clause.clone());
                continue;
            }
            if right_clause.wedge {
                clauses.push(left_clause.clone());
                continue;
            }

            let mut possibilities = left_clause.possibilities.clone();
            for (var, possible_types) in &right_clause.possibilities {
                possibilities.entry(*var).or_default().extend(possible_types.iter().map(|(&k, v)| (k, v.clone())));
            }

            // `x is T || x is !T` is a tautology and adds no information to a CNF.
            let is_tautology = possibilities.values().any(|var_possibilities| {
                if var_possibilities.len() > 1 {
                    let vals = var_possibilities.values().collect::<Vec<_>>();
                    for (i, v1) in vals.iter().enumerate() {
                        for v2 in &vals[i + 1..] {
                            if v1.is_negation_of(v2) {
                                return true;
                            }
                        }
                    }
                }
                false
            });
            if is_tautology {
                continue;
            }

            // Flags presumably map to (wedge, reconcilable, generated) in `Clause::new`;
            // confirm against its definition.
            clauses.push(Clause::new(
                possibilities,
                conditional_object_id,
                conditional_object_id,
                Some(false),
                Some(left_clause.reconcilable && right_clause.reconcilable),
                Some(true),
            ));
        }
    }

    clauses
}
/// Negates a CNF formula and returns the result in CNF.
///
/// Non-reconcilable clauses are discarded first. The remaining clauses are distributed by
/// `group_impossibilities` using each clause's impossibilities. The result is then
/// simplified with [`saturate_clauses`].
///
/// Returns `Some(vec![])` when no reconcilable clauses remain or no impossibility clauses
/// are produced. Returns `None` when distribution exceeds `thresholds.negation_complexity`.
#[inline]
#[must_use]
pub fn negate_formula(mut clauses: Vec<Clause>, thresholds: &AlgebraThresholds) -> Option<Vec<Clause>> {
    clauses.retain(|clause| clause.reconcilable);
    if clauses.is_empty() {
        return Some(vec![]);
    }

    let impossible_clauses = group_impossibilities(clauses, thresholds.negation_complexity.into())?;
    if impossible_clauses.is_empty() {
        return Some(vec![]);
    }

    Some(saturate_clauses(&impossible_clauses, thresholds))
}
/// Distributes the impossibilities of the conjunction `clauses` into CNF.
///
/// The last clause seeds the result with one unit clause per impossibility (none if it is a
/// wedge). Every remaining clause, taken from the back, then multiplies the current seeds by
/// its impossibilities. Each seed is extended by each impossibility in turn, and extensions
/// that would become a tautology are skipped.
///
/// Returns `Some(vec![])` when `clauses` is empty, the seed clause yields nothing, or a
/// later clause has no impossibilities. Returns `None` when the estimated upper bound or the
/// actual number of generated clauses exceeds `max_complexity`.
#[inline]
fn group_impossibilities(mut clauses: Vec<Clause>, max_complexity: usize) -> Option<Vec<Clause>> {
    let mut seed_clauses = Vec::new();
    let mut complexity = 1usize;
    let Some(clause) = clauses.pop() else {
        return Some(seed_clauses);
    };

    if !clause.wedge {
        let impossibilities = clause.get_impossibilities();
        for (var, impossible_types) in &impossibilities {
            for impossible_type in impossible_types {
                let mut seed_clause_possibilities = IndexMap::new();
                seed_clause_possibilities
                    .insert(*var, IndexMap::from([(impossible_type.to_hash(), impossible_type.clone())]));
                let seed_clause =
                    Clause::new(seed_clause_possibilities, clause.condition_span, clause.span, None, None, None);
                seed_clauses.push(seed_clause);
            }
        }
    }

    if clauses.is_empty() || seed_clauses.is_empty() {
        return Some(seed_clauses);
    }

    // Bail out early if the full product could exceed the budget.
    let mut complexity_upper_bound = seed_clauses.len();
    for clause in &clauses {
        let mut possibilities_count = 0;
        let impossibilities = clause.get_impossibilities();
        for impossible_types in impossibilities.values() {
            possibilities_count += impossible_types.len();
        }
        complexity_upper_bound = complexity_upper_bound.saturating_mul(possibilities_count);
        if complexity_upper_bound > max_complexity {
            return None;
        }
    }

    while let Some(clause) = clauses.pop() {
        let mut new_clauses = Vec::with_capacity(seed_clauses.len() * 4);
        let clause_impossibilities = clause.get_impossibilities();
        for grouped_clause in &seed_clauses {
            for (var, impossible_types) in &clause_impossibilities {
                for impossible_type in impossible_types {
                    complexity += 1;
                    if complexity > max_complexity {
                        return None;
                    }

                    // `grouped || x is T` where `grouped` already has `x is !T` is always true.
                    // Skip only this extension; the other impossibilities for `x` still give
                    // valid clauses.
                    let is_tautology = grouped_clause
                        .possibilities
                        .get(var)
                        .is_some_and(|existing| existing.values().any(|a| a.is_negation_of(impossible_type)));
                    if is_tautology {
                        continue;
                    }

                    let mut new_clause_possibilities = grouped_clause.possibilities.clone();
                    new_clause_possibilities
                        .entry(*var)
                        .or_insert_with(IndexMap::new)
                        .insert(impossible_type.to_hash(), impossible_type.clone());
                    new_clauses.push(Clause::new(
                        new_clause_possibilities,
                        grouped_clause.condition_span,
                        clause.span,
                        Some(false),
                        Some(true),
                        Some(true),
                    ));
                }
            }
        }
        seed_clauses = new_clauses;
    }

    seed_clauses.reverse();
    Some(seed_clauses)
}
/// Returns `true` when `map1` and `map2` contain exactly the same set of keys, ignoring
/// insertion order and values.
///
/// Keys of an `IndexMap` are unique, so equal lengths plus "every key of `map1` is in `map2`"
/// implies set equality. Only `Hash + Eq` is required for the lookups.
#[inline]
fn index_keys_match<T: Eq + Hash, U, V>(map1: &IndexMap<T, U>, map2: &IndexMap<T, V>) -> bool {
    map1.len() == map2.len() && map1.keys().all(|k| map2.contains_key(k))
}