use super::ir::{HypergraphRule, VertexId};
use std::collections::BTreeMap;
use xlog_core::ScalarType;
/// Maximum number of join keys accepted under [`ExecutorContext::HashFallback`].
///
/// The check in `analyze` is strict (`count > limit`), so a rule with exactly
/// this many join keys is still within the limit.
pub const BINARY_FALLBACK_KEY_LIMIT: usize = 4;

/// Maximum number of join keys accepted under [`ExecutorContext::WcojEligible`].
///
/// The check in `analyze` is strict (`count > limit`), so a rule with exactly
/// this many join keys is still within the limit.
pub const WCOJ_ELIGIBLE_KEY_LIMIT: usize = 8;

/// Execution context an eligibility analysis is performed for.
///
/// The context selects which join-key limit applies to a rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutorContext {
    /// Uses [`BINARY_FALLBACK_KEY_LIMIT`] as the join-key limit.
    HashFallback,
    /// Uses [`WCOJ_ELIGIBLE_KEY_LIMIT`] as the join-key limit.
    WcojEligible,
}

impl ExecutorContext {
    /// Returns the join-key limit associated with this context.
    fn join_key_limit(self) -> usize {
        // Exhaustive on purpose: a new context must pick its limit explicitly.
        match self {
            ExecutorContext::WcojEligible => WCOJ_ELIGIBLE_KEY_LIMIT,
            ExecutorContext::HashFallback => BINARY_FALLBACK_KEY_LIMIT,
        }
    }
}
/// Result of an eligibility analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Eligibility {
    /// No boundary was found.
    Eligible,
    /// One or more boundaries were found; they are carried in report order.
    Ineligible(Vec<Boundary>),
}

impl Eligibility {
    /// Returns the boundaries that were found.
    ///
    /// The slice is empty for [`Eligibility::Eligible`]; for
    /// [`Eligibility::Ineligible`] it borrows the stored vector.
    pub fn boundaries(&self) -> &[Boundary] {
        if let Eligibility::Ineligible(found) = self {
            found.as_slice()
        } else {
            &[]
        }
    }
}
/// A single reason why a rule was judged ineligible.
///
/// Values are produced by `analyze` and `analyze_typed` and carried inside
/// `Eligibility::Ineligible`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Boundary {
/// Reported by `analyze` when the rule's `is_fact` flag is set.
GroundFact,
/// Reported by `analyze` when the rule's `head_has_aggregation` flag is set.
HeadAggregation,
/// Reported by `analyze` when the rule's `has_negation` flag is set.
BodyNegation,
/// Reported by `analyze` when the rule's `has_is_expr` flag is set.
BodyIsExpr,
/// Reported by `analyze` for a non-fact rule whose hyperedge count is below 2.
InsufficientPositiveAtoms {
/// The hyperedge count that was observed.
positive_count: usize,
},
/// Reported by `analyze` when the join-key count is strictly greater than
/// the limit of the chosen context.
///
/// NOTE: despite the name, this variant is emitted for whichever
/// `ExecutorContext` was passed in, not only for the hash-fallback one;
/// `context` records which limit was applied.
JoinKeysExceedBinaryFallbackLimit {
/// Context whose limit was exceeded.
context: ExecutorContext,
/// Number of join keys counted in the rule.
count: usize,
/// The limit that `count` exceeded.
limit: usize,
},
/// Reported by `analyze_typed` for a join-key variable whose known type is
/// not listed in `WCOJ_SUPPORTED_KEY_TYPES`.
UnsupportedKeyType {
/// Name of the offending variable.
var: String,
/// The type found for that variable.
ty: ScalarType,
},
}
pub fn analyze(hg: &HypergraphRule, context: ExecutorContext) -> Eligibility {
let mut boundaries = Vec::new();
if hg.is_fact {
boundaries.push(Boundary::GroundFact);
}
if hg.head_has_aggregation {
boundaries.push(Boundary::HeadAggregation);
}
if hg.has_negation {
boundaries.push(Boundary::BodyNegation);
}
if hg.has_is_expr {
boundaries.push(Boundary::BodyIsExpr);
}
let positive_count = hg.hyperedge_count();
if !hg.is_fact && positive_count < 2 {
boundaries.push(Boundary::InsufficientPositiveAtoms { positive_count });
}
let join_key_count = count_join_keys(hg);
let join_key_limit = context.join_key_limit();
if join_key_count > join_key_limit {
boundaries.push(Boundary::JoinKeysExceedBinaryFallbackLimit {
context,
count: join_key_count,
limit: join_key_limit,
});
}
if boundaries.is_empty() {
Eligibility::Eligible
} else {
Eligibility::Ineligible(boundaries)
}
}
/// Convenience wrapper around [`analyze`]: returns `true` exactly when the
/// analysis yields [`Eligibility::Eligible`].
pub fn is_eligible(hg: &HypergraphRule, context: ExecutorContext) -> bool {
    match analyze(hg, context) {
        Eligibility::Eligible => true,
        Eligibility::Ineligible(_) => false,
    }
}
fn count_join_keys(hg: &HypergraphRule) -> usize {
let mut occurrences: Vec<usize> = vec![0; hg.vertex_count()];
for edge in &hg.hyperedges {
for vid in edge.vertices() {
let VertexId(idx) = vid;
occurrences[idx] += 1;
}
}
occurrences.iter().filter(|c| **c >= 2).count()
}
/// Scalar types accepted for join-key variables by `analyze_typed`.
///
/// A join key whose known type is not in this list is reported as
/// `Boundary::UnsupportedKeyType`.
pub const WCOJ_SUPPORTED_KEY_TYPES: &[ScalarType] =
&[ScalarType::U32, ScalarType::U64, ScalarType::Symbol];
/// Runs [`analyze`] and additionally checks the types of the join keys.
///
/// * `hg` - the rule to analyze.
/// * `vertex_types` - known scalar type per variable name.
/// * `context` - executor context, forwarded to [`analyze`].
///
/// Every join-key variable whose type is present in `vertex_types` but not
/// listed in [`WCOJ_SUPPORTED_KEY_TYPES`] adds a
/// [`Boundary::UnsupportedKeyType`] after the boundaries reported by
/// [`analyze`]. Join keys without an entry in `vertex_types` are skipped, not
/// reported.
///
/// Returns [`Eligibility::Eligible`] when no boundary was found at all,
/// otherwise [`Eligibility::Ineligible`] with every boundary.
pub fn analyze_typed(
    hg: &HypergraphRule,
    vertex_types: &BTreeMap<String, ScalarType>,
    context: ExecutorContext,
) -> Eligibility {
    // Take ownership of the structural boundaries instead of cloning them:
    // the base result is not needed afterwards.
    let mut boundaries = match analyze(hg, context) {
        Eligibility::Eligible => Vec::new(),
        Eligibility::Ineligible(found) => found,
    };

    for vid in join_key_vertex_ids(hg) {
        let name = &hg.vertex(vid).name;
        if let Some(&ty) = vertex_types.get(name) {
            if !WCOJ_SUPPORTED_KEY_TYPES.contains(&ty) {
                boundaries.push(Boundary::UnsupportedKeyType {
                    var: name.clone(),
                    ty,
                });
            }
        }
    }

    if boundaries.is_empty() {
        Eligibility::Eligible
    } else {
        Eligibility::Ineligible(boundaries)
    }
}
/// Returns the ids of all join-key vertices of `hg`, in ascending index order.
///
/// A vertex is a join key when it is yielded at least twice in total by the
/// `vertices()` of the rule's hyperedges.
///
/// # Panics
///
/// Panics if a hyperedge yields a vertex index that is not below
/// `hg.vertex_count()`.
fn join_key_vertex_ids(hg: &HypergraphRule) -> Vec<VertexId> {
    // One usage counter per vertex, indexed by the vertex id's inner value.
    let mut uses = vec![0usize; hg.vertex_count()];
    for edge in &hg.hyperedges {
        for VertexId(slot) in edge.vertices() {
            uses[slot] += 1;
        }
    }

    uses.into_iter()
        .enumerate()
        .filter_map(|(slot, seen)| if seen >= 2 { Some(VertexId(slot)) } else { None })
        .collect()
}