use std::marker::PhantomData;
use hashbrown::HashMap;
use bitflags::bitflags;
use crate::expr::{Expression, Node, NodeId};
bitflags! {
    /// Set-theoretic facts known about an ordered pair of operands `(a, b)`.
    ///
    /// Every flag that is present is a proven fact; an absent flag means "not
    /// established", not "false".
    #[derive(Clone, Copy, PartialEq)]
    pub(crate) struct MergeRelation: u8 {
        /// Nothing is known about the pair.
        const TRIVIAL = 0b0000;
        /// `a ⊆ b`.
        const SUBSET = 0b0001;
        /// `a ⊇ b`.
        const SUPERSET = 0b0010;
        /// `a ∩ b = ∅`.
        const DISJOINT = 0b0100;
        /// `a ∪ b` is the whole universe.
        const COVER = 0b1000;
        /// `a = b` (both subset and superset).
        const EQUAL = Self::SUPERSET.bits() | Self::SUBSET.bits();
        /// `a = ¬b` (both disjoint and covering).
        const COMPLEMENTARY = Self::COVER.bits() | Self::DISJOINT.bits();
    }
}
impl MergeRelation {
pub(crate) fn flip(self) -> Self {
match self {
MergeRelation::SUBSET => MergeRelation::SUPERSET,
MergeRelation::SUPERSET => MergeRelation::SUBSET,
_ => self,
}
}
pub(crate) fn is_subset(&self) -> bool {
self.contains(Self::SUBSET)
}
pub(crate) fn is_superset(&self) -> bool {
self.contains(Self::SUPERSET)
}
pub(crate) fn is_disjoint(&self) -> bool {
self.contains(Self::DISJOINT)
}
pub(crate) fn is_cover(&self) -> bool {
self.contains(Self::COVER)
}
}
/// Relation between two leaf sets `a` and `b`, as reported by
/// [`Mergeable::get_relation`]; it is converted into [`MergeRelation`] flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetRelation {
    /// Nothing is known about the pair.
    Trivial,
    /// `a ⊆ b`.
    Subset,
    /// `a ⊇ b`.
    Superset,
    /// `a ∩ b = ∅`.
    Disjoint,
    /// `a ∪ b` is the whole universe.
    Cover,
    /// `a` is exactly the complement of `b`.
    Complementary,
    /// `a = b`.
    Equal,
}
/// Maps a leaf-level [`SetRelation`] onto the equivalent [`MergeRelation`] flags;
/// the compound relations become the union of their component flags.
impl From<SetRelation> for MergeRelation {
    fn from(relation: SetRelation) -> Self {
        let flags = match relation {
            // Compound relations.
            SetRelation::Equal => Self::EQUAL,
            SetRelation::Complementary => Self::COMPLEMENTARY,
            // Single-flag relations.
            SetRelation::Subset => Self::SUBSET,
            SetRelation::Superset => Self::SUPERSET,
            SetRelation::Disjoint => Self::DISJOINT,
            SetRelation::Cover => Self::COVER,
            // No information.
            SetRelation::Trivial => Self::TRIVIAL,
        };
        flags
    }
}
/// Outcome of merging two sets via [`Mergeable::merge_union`] or
/// [`Mergeable::merge_intersection`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeResult<T> {
    /// The merge produced the empty set.
    Empty,
    /// The merge produced the universal set.
    Universal,
    /// The merge produced a single set value.
    ///
    /// NOTE(review): the `bool` presumably marks the value as negated (it mirrors the
    /// `_neg` flags on `Mergeable`'s merge methods) — confirm against consumers.
    Set(T, bool),
}
impl<T> From<T> for MergeResult<T> {
fn from(value: T) -> Self {
MergeResult::Set(value, false)
}
}
/// Domain hook that describes how leaf sets of type `T` relate to and combine with
/// one another.
///
/// Every method has a default that reports no knowledge (`SetRelation::Trivial` or
/// `None`), so an implementor only overrides what it can actually decide.
pub trait Mergeable<T> {
    /// Reports how set `_a` relates to set `_b` (e.g. `Subset` means `_a ⊆ _b`).
    ///
    /// Default: `SetRelation::Trivial`.
    fn get_relation(&mut self, _a: &T, _b: &T) -> SetRelation {
        SetRelation::Trivial
    }

    /// Optionally merges `_a` and `_b` under union.
    ///
    /// NOTE(review): `_a_neg` / `_b_neg` presumably say the operand is negated —
    /// confirm. Default: `None`, i.e. no merge is offered.
    fn merge_union(&mut self, _a: &T, _a_neg: bool, _b: &T, _b_neg: bool) -> Option<MergeResult<T>> {
        None
    }

    /// Optionally merges `_a` and `_b` under intersection.
    ///
    /// NOTE(review): `_a_neg` / `_b_neg` presumably say the operand is negated —
    /// confirm. Default: `None`, i.e. no merge is offered.
    fn merge_intersection(
        &mut self,
        _a: &T,
        _a_neg: bool,
        _b: &T,
        _b_neg: bool,
    ) -> Option<MergeResult<T>> {
        None
    }
}

/// The unit type relies on every default: it never reports a relation or a merge.
impl<T> Mergeable<T> for () {}
/// Computes and memoizes set relations between nodes of an [`Expression`],
/// delegating relations between leaf sets to `mergeable`.
pub(crate) struct Merger<'a, T, M>
where
    M: Mergeable<T>,
{
    /// Domain hook consulted for leaf-set relations.
    pub mergeable: &'a mut M,
    /// Memoized relations keyed by `(smaller node index, larger node index)`; the
    /// value holds the relation and the search depth it is valid for.
    cache: HashMap<(usize, usize), (MergeRelation, usize)>,
    /// Marks the otherwise unused `T` parameter as part of the type.
    _mergeable_type: PhantomData<T>,
}
impl<'a, T, M: Mergeable<T>> Merger<'a, T, M> {
pub(crate) fn new(mergeable: &'a mut M) -> Self {
Self {
mergeable,
cache: HashMap::new(),
_mergeable_type: PhantomData,
}
}
pub(crate) fn get_relation(
&mut self,
expr: &Expression<T>,
a: NodeId,
b: NodeId,
depth: usize,
) -> MergeRelation {
if a == b {
return MergeRelation::EQUAL;
}
if a == b.not() {
return MergeRelation::COMPLEMENTARY;
}
self.get_relation_recursive(expr, a, b, depth)
}
fn get_relation_recursive(
&mut self,
expr: &Expression<T>,
a: NodeId,
b: NodeId,
depth: usize,
) -> MergeRelation
where
M: Mergeable<T>,
{
if a == b {
return MergeRelation::EQUAL;
}
if a == b.not() {
return MergeRelation::COMPLEMENTARY;
}
let (min, max) = if a.idx() <= b.idx() { (a, b) } else { (b, a) };
let key = (min.idx(), max.idx());
if let Some(&(cached_rel, cached_depth)) = self.cache.get(&key)
&& cached_depth >= depth
{
let mut final_rel = cached_rel;
if a != min {
final_rel = final_rel.flip();
}
return self.apply_negation_logic(final_rel, a.is_neg(), b.is_neg());
}
if depth == 0 {
return MergeRelation::TRIVIAL;
}
let node_min = &expr.nodes[min.idx()];
let node_max = &expr.nodes[max.idx()];
let rel = match (node_min, node_max) {
(Node::Empty, Node::Empty) => MergeRelation::EQUAL, (Node::Empty, _) | (_, Node::Empty) => MergeRelation::DISJOINT,
(Node::Set(set_min), Node::Set(set_max)) => {
self.mergeable.get_relation(set_min, set_max).into()
}
(Node::Set(_), Node::Union(kids_b)) | (Node::Set(_), Node::Intersection(kids_b)) => {
let is_union = matches!(node_max, Node::Union(_));
self.get_groups_relation(expr, &[min], is_union, kids_b, is_union, depth - 1)
}
(Node::Union(kids_a), Node::Set(_)) | (Node::Intersection(kids_a), Node::Set(_)) => {
let is_union = matches!(node_min, Node::Union(_));
self.get_groups_relation(expr, kids_a, is_union, &[max], is_union, depth - 1)
}
(Node::Union(kids_a), Node::Union(kids_b))
| (Node::Union(kids_a), Node::Intersection(kids_b))
| (Node::Intersection(kids_a), Node::Union(kids_b))
| (Node::Intersection(kids_a), Node::Intersection(kids_b)) => self.get_groups_relation(
expr,
kids_a,
matches!(node_min, Node::Union(_)),
kids_b,
matches!(node_max, Node::Union(_)),
depth - 1,
),
};
let stored_depth = if rel == MergeRelation::EQUAL || rel == MergeRelation::COMPLEMENTARY {
usize::MAX
} else {
depth
};
self.cache.insert(key, (rel, stored_depth));
let mut final_rel = rel;
if a != min {
final_rel = final_rel.flip();
}
self.apply_negation_logic(final_rel, a.is_neg(), b.is_neg())
}
fn apply_negation_logic(&self, rel: MergeRelation, neg_a: bool, neg_b: bool) -> MergeRelation {
if !neg_a && !neg_b {
return rel;
}
let mut result = MergeRelation::TRIVIAL;
if rel == MergeRelation::EQUAL {
return if neg_a == neg_b {
MergeRelation::EQUAL
} else {
MergeRelation::COMPLEMENTARY
};
}
if rel == MergeRelation::COMPLEMENTARY {
return if neg_a == neg_b {
MergeRelation::COMPLEMENTARY
} else {
MergeRelation::EQUAL
};
}
if rel.is_subset() {
match (neg_a, neg_b) {
(true, true) => result |= MergeRelation::SUPERSET, (false, true) => result |= MergeRelation::DISJOINT, _ => {}
}
}
if rel.is_superset() {
match (neg_a, neg_b) {
(true, true) => result |= MergeRelation::SUBSET, (true, false) => result |= MergeRelation::DISJOINT, _ => {}
}
}
if rel.is_disjoint() {
match (neg_a, neg_b) {
(false, true) => result |= MergeRelation::SUBSET, (true, false) => result |= MergeRelation::SUPERSET, _ => {}
}
}
if rel.is_cover() {
match (neg_a, neg_b) {
(false, true) => result |= MergeRelation::SUPERSET, (true, false) => result |= MergeRelation::SUBSET, _ => {}
}
}
result
}
fn get_groups_relation(
&mut self,
expr: &Expression<T>,
kids_a: &[NodeId],
is_union_a: bool,
kids_b: &[NodeId],
is_union_b: bool,
depth: usize,
) -> MergeRelation
where
M: Mergeable<T>,
{
let mut result = MergeRelation::TRIVIAL;
let is_disjoint = match (is_union_a, is_union_b) {
(false, false) =>
{
kids_a.iter().any(|&a| {
kids_b
.iter()
.any(|&b| self.get_relation_recursive(expr, a, b, depth).is_disjoint())
})
}
(true, false) =>
{
kids_a.iter().all(|&a| {
kids_b
.iter()
.any(|&b| self.get_relation_recursive(expr, a, b, depth).is_disjoint())
})
}
(false, true) =>
{
kids_b.iter().all(|&b| {
kids_a
.iter()
.any(|&a| self.get_relation_recursive(expr, a, b, depth).is_disjoint())
})
}
(true, true) =>
{
kids_a.iter().all(|&a| {
kids_b
.iter()
.all(|&b| self.get_relation_recursive(expr, a, b, depth).is_disjoint())
})
}
};
if is_disjoint {
result |= MergeRelation::DISJOINT;
}
let is_subset = match (is_union_a, is_union_b) {
(true, true) =>
{
kids_a.iter().all(|&a| {
kids_b
.iter()
.any(|&b| self.get_relation_recursive(expr, a, b, depth).is_subset())
})
}
(true, false) =>
{
kids_a.iter().all(|&a| {
kids_b
.iter()
.all(|&b| self.get_relation_recursive(expr, a, b, depth).is_subset())
})
}
(false, true) =>
{
kids_a.iter().any(|&a| {
kids_b
.iter()
.any(|&b| self.get_relation_recursive(expr, a, b, depth).is_subset())
})
}
(false, false) =>
{
kids_b.iter().all(|&b| {
kids_a
.iter()
.any(|&a| self.get_relation_recursive(expr, a, b, depth).is_superset())
})
}
};
if is_subset {
result |= MergeRelation::SUBSET;
}
let is_superset = match (is_union_a, is_union_b) {
(true, true) =>
{
kids_b.iter().all(|&b| {
kids_a
.iter()
.any(|&a| self.get_relation_recursive(expr, a, b, depth).is_subset())
})
}
(true, false) =>
{
kids_a.iter().any(|&a| {
kids_b
.iter()
.any(|&b| self.get_relation_recursive(expr, a, b, depth).is_superset())
})
}
(false, true) =>
{
kids_a.iter().all(|&a| {
kids_b
.iter()
.all(|&b| self.get_relation_recursive(expr, a, b, depth).is_superset())
})
}
(false, false) =>
{
kids_a.iter().all(|&a| {
kids_b
.iter()
.any(|&b| self.get_relation_recursive(expr, a, b, depth).is_superset())
})
}
};
if is_superset {
result |= MergeRelation::SUPERSET;
}
result
}
}