use std::collections::HashMap;
use xlog_core::RelId;
use xlog_ir::rir::{
ColumnSwap, CostPredictionRecord as RirCostPredictionRecord, KCliqueVariableOrder,
MultiwayPlan, PlannedHashReason, ProjectExpr, SortedLayoutSpec, StreamGroupId, VariableOrder,
K_CLIQUE_MAX_EDGES, K_CLIQUE_MAX_K,
};
use xlog_ir::{ExecutionPlan, JoinType, RirNode};
use xlog_stats::{StatsManager, StatsSnapshot};
use crate::compiler_config::CompilerConfig;
use crate::hypergraph::var_order::{
plan_kclique_var_order, FullVariableOrder, KCliqueEdge, KCliqueShape,
};
use crate::hypergraph::VertexId;
use crate::wcoj_var_ordering::{wcoj_cost_gate_predicts_wcoj, WcojVariableOrderingModel};
/// Replaces rule bodies in `plan` with specialised join nodes wherever a known shape matches.
///
/// Only SCC slots that have a corresponding entry in `plan.sccs` are visited. For each rule the
/// recognisers run in a fixed order and the first success wins:
///
/// 1. two-atom chain,
/// 2. triangle (after an optional rewrite to left-deep form),
/// 3. 4-cycle (after an optional rewrite to bushy form),
/// 4. k-clique for k = 5, 6, 7, 8 in that order.
///
/// A rule that matches none of them keeps its body unchanged.
///
/// `_rel_ids` is accepted but not read here; `stats` and `config` are forwarded to the
/// recognisers that take them.
pub fn promote_multiway(
    plan: &mut ExecutionPlan,
    _rel_ids: &HashMap<String, RelId>,
    stats: &StatsManager,
    config: &CompilerConfig,
) {
    for (scc_id, rules) in plan.rules_by_scc.iter_mut().enumerate() {
        if plan.sccs.get(scc_id).is_none() {
            continue;
        }
        for rule in rules.iter_mut() {
            // Each stage is evaluated lazily, so later recognisers only run when every earlier
            // one has declined.
            let replacement = try_promote_chain(&rule.body)
                .or_else(|| {
                    let left_deep = normalize_triangle_to_left_deep(&rule.body);
                    let candidate = left_deep.as_ref().unwrap_or(&rule.body);
                    try_promote_triangle(candidate, stats, config)
                })
                .or_else(|| {
                    let bushy = normalize_4cycle_to_bushy(&rule.body);
                    let candidate = bushy.as_ref().unwrap_or(&rule.body);
                    try_promote_4cycle(candidate, stats, config)
                })
                .or_else(|| (5..=8).find_map(|k| try_promote_clique_k(&rule.body, k, stats)));

            if let Some(node) = replacement {
                rule.body = node;
            }
        }
    }
}
/// Flattens an `(atom, column)` pair of a three-atom body into a slot index.
///
/// Slots are atom-major: column `c` of atom `a` lives at `2 * a + c`. Debug builds assert
/// `atom_idx < 3` and `col_idx < 2`, which keeps the result in `0..6`.
fn ac_idx(atom_idx: u8, col_idx: u8) -> u8 {
    debug_assert!(atom_idx < 3);
    debug_assert!(col_idx < 2);
    let atom_base = atom_idx * 2;
    atom_base + col_idx
}
/// Maps output column `k` of the inner (two-atom) join to its `(atom, column)` origin.
///
/// Columns 0 and 1 come from atom 0, columns 2 and 3 from atom 1. Any other `k` yields `None`.
fn inner_output_ac(k: usize) -> Option<(u8, u8)> {
    if k >= 4 {
        return None;
    }
    // Each atom contributes two adjacent columns.
    Some(((k / 2) as u8, (k % 2) as u8))
}
/// Maps output column `k` of the outer (three-atom) join to its `(atom, column)` origin.
///
/// Columns 0..=3 are delegated to [`inner_output_ac`]; columns 4 and 5 belong to atom 2.
/// Any other `k` yields `None`.
fn outer_output_ac(k: usize) -> Option<(u8, u8)> {
    if k < 4 {
        inner_output_ac(k)
    } else if k == 4 {
        Some((2, 0))
    } else if k == 5 {
        Some((2, 1))
    } else {
        None
    }
}
/// Returns the representative of `x` in the six-slot union-find forest `parent`.
///
/// After the lookup every node on the path from `x` is re-pointed directly at the
/// representative, so `parent` may be mutated even though the partition is unchanged.
fn uf_find(parent: &mut [u8; 6], x: u8) -> u8 {
    // First pass: climb until a self-parented node is reached.
    let mut representative = x;
    loop {
        let up = parent[representative as usize];
        if up == representative {
            break;
        }
        representative = up;
    }

    // Second pass: flatten the path that was just walked.
    let mut node = x;
    while node != representative {
        let up = parent[node as usize];
        parent[node as usize] = representative;
        node = up;
    }
    representative
}
/// Merges the classes of `a` and `b` in the six-slot union-find forest.
///
/// The representative of `a` survives; the representative of `b` is attached beneath it.
/// Nothing is written when both already share a representative.
fn uf_union(parent: &mut [u8; 6], a: u8, b: u8) {
    let kept = uf_find(parent, a);
    let absorbed = uf_find(parent, b);
    if kept == absorbed {
        return;
    }
    parent[absorbed as usize] = kept;
}
/// Works out which relation plays the `xy`, `yz` and `xz` role of a triangle.
///
/// The input describes a left-deep body `(inner_left ⋈ inner_right) ⋈ outer_third` of binary
/// atoms:
///
/// * `lk2` / `rk2` — the single key pair of the inner join (columns of the two inner atoms),
/// * `lk1` / `rk1` — the two key pairs of the outer join (`lk1` indexes the four inner output
///   columns, `rk1` the two columns of the third atom),
/// * `project_cols` — exactly three plain column references into the six outer output columns;
///   they name the variables `x`, `y`, `z` in that order.
///
/// The six atom columns are grouped into equivalence classes by the join keys. The body is a
/// triangle only if this produces exactly three classes of two columns each, the three projected
/// columns fall into three different classes, and every atom touches exactly two of them.
///
/// Returns `(rel_xy, rel_yz, rel_xz)` on success and `None` for any other shape.
#[allow(clippy::too_many_arguments)]
fn infer_triangle_semantics(
    inner_left_rel: RelId,
    inner_right_rel: RelId,
    outer_third_rel: RelId,
    lk2: &[usize],
    rk2: &[usize],
    lk1: &[usize],
    rk1: &[usize],
    project_cols: &[ProjectExpr],
) -> Option<(RelId, RelId, RelId)> {
    // Arity guards: one inner key, two outer keys, three projected columns.
    let arity_ok = lk2.len() == 1
        && rk2.len() == 1
        && lk1.len() == 2
        && rk1.len() == 2
        && project_cols.len() == 3;
    if !arity_ok {
        return None;
    }

    // Range guards: inner keys address a binary atom, outer keys address the inner output
    // (four columns) on the left and a binary atom on the right.
    let (inner_left_key, inner_right_key) = (lk2[0], rk2[0]);
    if inner_left_key >= 2 || inner_right_key >= 2 {
        return None;
    }
    if !lk1.iter().all(|&k| k < 4) || !rk1.iter().all(|&k| k < 2) {
        return None;
    }

    // Union the columns that the join keys equate.
    let mut parent: [u8; 6] = std::array::from_fn(|slot| slot as u8);
    uf_union(
        &mut parent,
        ac_idx(0, inner_left_key as u8),
        ac_idx(1, inner_right_key as u8),
    );
    for (&left_key, &right_key) in lk1.iter().zip(rk1.iter()) {
        let (atom, col) = inner_output_ac(left_key)?;
        uf_union(&mut parent, ac_idx(atom, col), ac_idx(2, right_key as u8));
    }
    let roots: [u8; 6] = std::array::from_fn(|slot| uf_find(&mut parent, slot as u8));

    // Exactly three classes, each holding exactly two columns.
    let mut class_sizes = [0u8; 6];
    for &root in &roots {
        class_sizes[usize::from(root)] += 1;
    }
    let class_count = class_sizes.iter().filter(|&&size| size != 0).count();
    if class_count != 3 || class_sizes.iter().any(|&size| size != 0 && size != 2) {
        return None;
    }

    // The projection names the variables: position 0 is x, 1 is y, 2 is z.
    let mut head_classes = [0u8; 3];
    for (class, expr) in head_classes.iter_mut().zip(project_cols) {
        let ProjectExpr::Column(k) = expr else {
            return None;
        };
        let (atom, col) = outer_output_ac(*k)?;
        *class = uf_find(&mut parent, ac_idx(atom, col));
    }
    let [x_class, y_class, z_class] = head_classes;
    if x_class == y_class || x_class == z_class || y_class == z_class {
        return None;
    }

    // Assign each atom to the edge whose two variables it binds.
    let mut rel_xy: Option<RelId> = None;
    let mut rel_yz: Option<RelId> = None;
    let mut rel_xz: Option<RelId> = None;
    let atom_rels = [inner_left_rel, inner_right_rel, outer_third_rel];
    for (atom_idx, rel) in (0u8..).zip(atom_rels) {
        let first = roots[usize::from(ac_idx(atom_idx, 0))];
        let second = roots[usize::from(ac_idx(atom_idx, 1))];
        let binds = |class: u8| first == class || second == class;
        let role = match (binds(x_class), binds(y_class), binds(z_class)) {
            (true, true, false) => &mut rel_xy,
            (false, true, true) => &mut rel_yz,
            (true, false, true) => &mut rel_xz,
            _ => return None,
        };
        *role = Some(rel);
    }
    Some((rel_xy?, rel_yz?, rel_xz?))
}
/// Rewrites a right-nested triangle body `Project(Scan ⋈ (Scan ⋈ Scan))` into the left-deep
/// form `Project((Scan ⋈ Scan) ⋈ Scan)`.
///
/// The outer join's operands and key lists are swapped, and every projected column index is
/// remapped to the new layout: the lone scan's two columns move from positions 0..2 to 4..6,
/// and the nested join's four columns move from 2..6 to 0..4. Non-column projection
/// expressions are copied unchanged.
///
/// Returns `None` when the node does not have this shape, when the outer join is not an inner
/// join, or when a projected column index lies outside the six available columns. (A plain
/// `(k + 4) % 6` remap would silently fold such an index back into range and make a malformed
/// projection look valid.)
fn normalize_triangle_to_left_deep(node: &RirNode) -> Option<RirNode> {
    /// Width of a single binary scan.
    const SCAN_WIDTH: usize = 2;
    /// Width of the outer join output (three binary scans).
    const OUTER_WIDTH: usize = 6;

    let RirNode::Project {
        input: outer_input,
        columns,
    } = node
    else {
        return None;
    };
    let RirNode::Join {
        left: outer_l,
        right: outer_r,
        left_keys: outer_lk,
        right_keys: outer_rk,
        join_type: outer_jt,
    } = outer_input.as_ref()
    else {
        return None;
    };
    if !matches!(outer_jt, JoinType::Inner) {
        return None;
    }
    if !matches!(outer_l.as_ref(), RirNode::Scan { .. }) {
        return None;
    }
    let RirNode::Join {
        left: inner_l,
        right: inner_r,
        ..
    } = outer_r.as_ref()
    else {
        return None;
    };
    if !matches!(inner_l.as_ref(), RirNode::Scan { .. })
        || !matches!(inner_r.as_ref(), RirNode::Scan { .. })
    {
        return None;
    }

    // Remap the projection before building anything, so an invalid column aborts cheaply.
    let new_columns = columns
        .iter()
        .map(|expr| match expr {
            // Columns of the lone scan now follow the nested join's four columns.
            ProjectExpr::Column(k) if *k < SCAN_WIDTH => {
                Some(ProjectExpr::Column(*k + (OUTER_WIDTH - SCAN_WIDTH)))
            }
            // Columns of the nested join now come first.
            ProjectExpr::Column(k) if *k < OUTER_WIDTH => {
                Some(ProjectExpr::Column(*k - SCAN_WIDTH))
            }
            // Out-of-range reference: refuse to normalise rather than wrap around.
            ProjectExpr::Column(_) => None,
            other => Some(other.clone()),
        })
        .collect::<Option<Vec<ProjectExpr>>>()?;

    let new_outer = RirNode::Join {
        left: outer_r.clone(),
        right: outer_l.clone(),
        left_keys: outer_rk.clone(),
        right_keys: outer_lk.clone(),
        join_type: JoinType::Inner,
    };
    Some(RirNode::Project {
        input: Box::new(new_outer),
        columns: new_columns,
    })
}
fn normalize_4cycle_to_bushy(node: &RirNode) -> Option<RirNode> {
let RirNode::Project {
input: outer_input,
columns,
} = node
else {
return None;
};
let RirNode::Join {
left: outer_l,
right: outer_r,
left_keys: outer_lk,
right_keys: outer_rk,
join_type: outer_jt,
} = outer_input.as_ref()
else {
return None;
};
if !matches!(outer_jt, JoinType::Inner) {
return None;
}
let RirNode::Scan { rel: r0 } = outer_l.as_ref() else {
return None;
};
let RirNode::Join {
left: middle_l,
right: middle_r,
left_keys: middle_lk,
right_keys: middle_rk,
join_type: middle_jt,
} = outer_r.as_ref()
else {
return None;
};
if !matches!(middle_jt, JoinType::Inner) {
return None;
}
let RirNode::Scan { rel: r1 } = middle_l.as_ref() else {
return None;
};
let RirNode::Join {
left: deep_l,
right: deep_r,
left_keys: deep_lk,
right_keys: deep_rk,
join_type: deep_jt,
} = middle_r.as_ref()
else {
return None;
};
if !matches!(deep_jt, JoinType::Inner) {
return None;
}
let RirNode::Scan { rel: r2 } = deep_l.as_ref() else {
return None;
};
let RirNode::Scan { rel: r3 } = deep_r.as_ref() else {
return None;
};
if outer_lk.as_slice() != [0, 1] || outer_rk.as_slice() != [5, 0] {
return None;
}
if middle_lk.as_slice() != [1] || middle_rk.as_slice() != [0] {
return None;
}
if deep_lk.as_slice() != [1] || deep_rk.as_slice() != [0] {
return None;
}
let inner_left = RirNode::Join {
left: Box::new(RirNode::Scan { rel: *r0 }),
right: Box::new(RirNode::Scan { rel: *r1 }),
left_keys: vec![1],
right_keys: vec![0],
join_type: JoinType::Inner,
};
let inner_right = RirNode::Join {
left: Box::new(RirNode::Scan { rel: *r2 }),
right: Box::new(RirNode::Scan { rel: *r3 }),
left_keys: vec![1],
right_keys: vec![0],
join_type: JoinType::Inner,
};
let new_outer = RirNode::Join {
left: Box::new(inner_left),
right: Box::new(inner_right),
left_keys: vec![3, 0],
right_keys: vec![0, 3],
join_type: JoinType::Inner,
};
Some(RirNode::Project {
input: Box::new(new_outer),
columns: columns.clone(),
})
}
/// Tries to turn a left-deep triangle body into a `MultiWayJoin` node.
///
/// The node must be `Project((Scan ⋈ Scan) ⋈ Scan)` with inner joins only, and
/// [`infer_triangle_semantics`] must accept its keys and projection. On success the result has:
///
/// * `inputs` — scans of the `xy`, `yz`, `xz` relations in that order,
/// * `slot_vars` — `[(0, 1), (1, 2), (0, 2)]`,
/// * `output_columns` — the original projection list,
/// * `fallback` — a clone of `node`,
/// * `plan` — `None`,
/// * `var_order` — built from the leader chosen by the model selected in
///   `config.wcoj_variable_ordering`, or `None` when ordering is disabled or the model
///   returns no leader.
///
/// Returns `None` when the shape does not match.
fn try_promote_triangle(
    node: &RirNode,
    stats: &StatsManager,
    config: &CompilerConfig,
) -> Option<RirNode> {
    use crate::compiler_config::WcojVarOrderingKind;
    use crate::wcoj_var_ordering::{
        build_triangle_var_order, HeatAwareLeaderModel, LeaderCardinalityModel,
    };

    // Shape: Project over an inner join whose left side is itself an inner join of two scans
    // and whose right side is a scan.
    let RirNode::Project {
        input: top,
        columns,
    } = node
    else {
        return None;
    };
    let RirNode::Join {
        left: pair,
        right: third,
        left_keys: lk1,
        right_keys: rk1,
        join_type: JoinType::Inner,
    } = top.as_ref()
    else {
        return None;
    };
    let RirNode::Scan { rel: rel_third } = third.as_ref() else {
        return None;
    };
    let RirNode::Join {
        left: first,
        right: second,
        left_keys: lk2,
        right_keys: rk2,
        join_type: JoinType::Inner,
    } = pair.as_ref()
    else {
        return None;
    };
    let RirNode::Scan { rel: rel_first } = first.as_ref() else {
        return None;
    };
    let RirNode::Scan { rel: rel_second } = second.as_ref() else {
        return None;
    };

    let (rel_xy, rel_yz, rel_xz) = infer_triangle_semantics(
        *rel_first,
        *rel_second,
        *rel_third,
        lk2,
        rk2,
        lk1,
        rk1,
        columns,
    )?;

    // Optional variable order, driven by the configured leader model.
    let edge_rels = [rel_xy, rel_yz, rel_xz];
    let leader = match config.wcoj_variable_ordering {
        WcojVarOrderingKind::Disabled => None,
        WcojVarOrderingKind::LeaderCardinality => {
            LeaderCardinalityModel.pick_triangle_leader(edge_rels, stats, config)
        }
        WcojVarOrderingKind::HeatAware => {
            HeatAwareLeaderModel.pick_triangle_leader(edge_rels, stats, config)
        }
    };

    Some(RirNode::MultiWayJoin {
        inputs: vec![
            RirNode::Scan { rel: rel_xy },
            RirNode::Scan { rel: rel_yz },
            RirNode::Scan { rel: rel_xz },
        ],
        slot_vars: [(0u32, 1u32), (1, 2), (0, 2)]
            .iter()
            .map(|&(a, b)| vec![Some(a), Some(b)])
            .collect(),
        output_columns: columns.clone(),
        fallback: Box::new(node.clone()),
        plan: None,
        var_order: leader.map(build_triangle_var_order),
    })
}
/// Tries to turn `Project(Scan ⋈ Scan)` into a `ChainJoin` node.
///
/// The join must be an inner join with exactly one key on each side, and both keys must be
/// below 2. The resulting node carries fresh scans of the two relations, the two key columns,
/// the original projection list, and a clone of `node` as `fallback`. Any other shape yields
/// `None`.
fn try_promote_chain(node: &RirNode) -> Option<RirNode> {
    let RirNode::Project { input, columns } = node else {
        return None;
    };
    let RirNode::Join {
        left,
        right,
        left_keys,
        right_keys,
        join_type: JoinType::Inner,
    } = input.as_ref()
    else {
        return None;
    };

    // Exactly one key per side, each addressing a binary relation.
    let ([left_key], [right_key]) = (left_keys.as_slice(), right_keys.as_slice()) else {
        return None;
    };
    let (left_key, right_key) = (*left_key, *right_key);
    if left_key >= 2 || right_key >= 2 {
        return None;
    }

    // Both operands must be plain scans.
    let (RirNode::Scan { rel: rel_left }, RirNode::Scan { rel: rel_right }) =
        (left.as_ref(), right.as_ref())
    else {
        return None;
    };

    Some(RirNode::ChainJoin {
        left: Box::new(RirNode::Scan { rel: *rel_left }),
        right: Box::new(RirNode::Scan { rel: *rel_right }),
        left_key,
        right_key,
        output_columns: columns.clone(),
        fallback: Box::new(node.clone()),
    })
}
/// Flattens an `(atom, column)` pair of a four-atom body into a slot index.
///
/// Slots are atom-major: column `c` of atom `a` lives at `2 * a + c`. Debug builds assert
/// `atom_idx < 4` and `col_idx < 2`, which keeps the result in `0..8`.
fn ac_idx_4(atom_idx: u8, col_idx: u8) -> u8 {
    debug_assert!(atom_idx < 4);
    debug_assert!(col_idx < 2);
    let atom_base = atom_idx * 2;
    atom_base + col_idx
}
/// Maps output column `k` of the left inner join of a bushy 4-cycle to `(atom, column)`.
///
/// Columns 0 and 1 belong to atom 0, columns 2 and 3 to atom 1. Any other `k` yields `None`.
fn outer_left_inner_output_ac(k: usize) -> Option<(u8, u8)> {
    if k >= 4 {
        return None;
    }
    let atom = (k / 2) as u8;
    let col = (k % 2) as u8;
    Some((atom, col))
}
/// Maps output column `k` of the right inner join of a bushy 4-cycle to `(atom, column)`.
///
/// Columns 0 and 1 belong to atom 2, columns 2 and 3 to atom 3. Any other `k` yields `None`.
fn outer_right_inner_output_ac(k: usize) -> Option<(u8, u8)> {
    if k >= 4 {
        return None;
    }
    // The right pair holds atoms 2 and 3, hence the offset of two.
    let atom = 2 + (k / 2) as u8;
    let col = (k % 2) as u8;
    Some((atom, col))
}
/// Maps output column `k` of the outer join of a bushy 4-cycle to `(atom, column)`.
///
/// Columns 0..=3 are resolved against the left inner join, columns 4..=7 against the right
/// inner join (after subtracting 4). Any other `k` yields `None`.
fn outer_4cycle_output_ac(k: usize) -> Option<(u8, u8)> {
    if k < 4 {
        outer_left_inner_output_ac(k)
    } else if k < 8 {
        outer_right_inner_output_ac(k - 4)
    } else {
        None
    }
}
/// Returns the representative of `x` in the eight-slot union-find forest `parent`.
///
/// After the lookup every node on the path from `x` is re-pointed directly at the
/// representative, so `parent` may be mutated even though the partition is unchanged.
fn uf_find_8(parent: &mut [u8; 8], x: u8) -> u8 {
    // First pass: climb until a self-parented node is reached.
    let mut representative = x;
    loop {
        let up = parent[representative as usize];
        if up == representative {
            break;
        }
        representative = up;
    }

    // Second pass: flatten the path that was just walked.
    let mut node = x;
    while node != representative {
        let up = parent[node as usize];
        parent[node as usize] = representative;
        node = up;
    }
    representative
}
/// Merges the classes of `a` and `b` in the eight-slot union-find forest.
///
/// The representative of `a` survives; the representative of `b` is attached beneath it.
/// Nothing is written when both already share a representative.
fn uf_union_8(parent: &mut [u8; 8], a: u8, b: u8) {
    let kept = uf_find_8(parent, a);
    let absorbed = uf_find_8(parent, b);
    if kept == absorbed {
        return;
    }
    parent[absorbed as usize] = kept;
}
/// Works out which relation plays the `wx`, `xy`, `yz` and `zw` role of a 4-cycle.
///
/// The input describes a bushy body `(ll ⋈ lr) ⋈ (rl ⋈ rr)` of binary atoms:
///
/// * `ilk_l` / `irk_l` — the single key pair of the left inner join,
/// * `ilk_r` / `irk_r` — the single key pair of the right inner join,
/// * `olk` / `ork` — the two key pairs of the outer join, each indexing the four output
///   columns of the corresponding inner join,
/// * `project_cols` — exactly four plain column references into the eight outer output
///   columns; they name the variables `w`, `x`, `y`, `z` in that order.
///
/// The eight atom columns are grouped into equivalence classes by the join keys. The body is
/// accepted only if this produces exactly four classes of two columns each, the four projected
/// columns fall into four different classes, and every atom binds one of the pairs
/// `(w, x)`, `(x, y)`, `(y, z)`, `(z, w)`.
///
/// Returns `(rel_wx, rel_xy, rel_yz, rel_zw)` on success and `None` for any other shape.
#[allow(clippy::too_many_arguments)]
fn infer_4cycle_semantics(
    rel_ll: RelId,
    rel_lr: RelId,
    rel_rl: RelId,
    rel_rr: RelId,
    ilk_l: &[usize],
    irk_l: &[usize],
    ilk_r: &[usize],
    irk_r: &[usize],
    olk: &[usize],
    ork: &[usize],
    project_cols: &[ProjectExpr],
) -> Option<(RelId, RelId, RelId, RelId)> {
    // Arity guards: one key per inner join, two for the outer join, four projected columns.
    let inner_arity_ok =
        ilk_l.len() == 1 && irk_l.len() == 1 && ilk_r.len() == 1 && irk_r.len() == 1;
    let outer_arity_ok = olk.len() == 2 && ork.len() == 2;
    if !inner_arity_ok || !outer_arity_ok || project_cols.len() != 4 {
        return None;
    }

    // Range guards: inner keys address binary atoms, outer keys address four-column outputs.
    let inner_keys = [ilk_l[0], irk_l[0], ilk_r[0], irk_r[0]];
    if inner_keys.iter().any(|&k| k >= 2) {
        return None;
    }
    if !olk.iter().all(|&k| k < 4) || !ork.iter().all(|&k| k < 4) {
        return None;
    }

    // Union the columns that the join keys equate.
    let mut parent: [u8; 8] = std::array::from_fn(|slot| slot as u8);
    uf_union_8(
        &mut parent,
        ac_idx_4(0, inner_keys[0] as u8),
        ac_idx_4(1, inner_keys[1] as u8),
    );
    uf_union_8(
        &mut parent,
        ac_idx_4(2, inner_keys[2] as u8),
        ac_idx_4(3, inner_keys[3] as u8),
    );
    for (&left_key, &right_key) in olk.iter().zip(ork.iter()) {
        let (left_atom, left_col) = outer_left_inner_output_ac(left_key)?;
        let (right_atom, right_col) = outer_right_inner_output_ac(right_key)?;
        uf_union_8(
            &mut parent,
            ac_idx_4(left_atom, left_col),
            ac_idx_4(right_atom, right_col),
        );
    }
    let roots: [u8; 8] = std::array::from_fn(|slot| uf_find_8(&mut parent, slot as u8));

    // Exactly four classes, each holding exactly two columns.
    let mut class_sizes = [0u8; 8];
    for &root in &roots {
        class_sizes[usize::from(root)] += 1;
    }
    let class_count = class_sizes.iter().filter(|&&size| size != 0).count();
    if class_count != 4 || class_sizes.iter().any(|&size| size != 0 && size != 2) {
        return None;
    }

    // The projection names the variables: positions 0..4 are w, x, y, z.
    let mut head_classes = [0u8; 4];
    for (class, expr) in head_classes.iter_mut().zip(project_cols) {
        let ProjectExpr::Column(k) = expr else {
            return None;
        };
        let (atom, col) = outer_4cycle_output_ac(*k)?;
        *class = uf_find_8(&mut parent, ac_idx_4(atom, col));
    }
    let has_repeat = head_classes
        .iter()
        .enumerate()
        .any(|(pos, class)| head_classes[..pos].contains(class));
    if has_repeat {
        return None;
    }
    let [w_class, x_class, y_class, z_class] = head_classes;

    // Assign each atom to the cycle edge whose two variables it binds.
    let mut rel_wx: Option<RelId> = None;
    let mut rel_xy: Option<RelId> = None;
    let mut rel_yz: Option<RelId> = None;
    let mut rel_zw: Option<RelId> = None;
    let atom_rels = [rel_ll, rel_lr, rel_rl, rel_rr];
    for (atom_idx, rel) in (0u8..).zip(atom_rels) {
        let first = roots[usize::from(ac_idx_4(atom_idx, 0))];
        let second = roots[usize::from(ac_idx_4(atom_idx, 1))];
        let binds = |class: u8| first == class || second == class;
        let role = match (
            binds(w_class),
            binds(x_class),
            binds(y_class),
            binds(z_class),
        ) {
            (true, true, false, false) => &mut rel_wx,
            (false, true, true, false) => &mut rel_xy,
            (false, false, true, true) => &mut rel_yz,
            (true, false, false, true) => &mut rel_zw,
            _ => return None,
        };
        *role = Some(rel);
    }
    Some((rel_wx?, rel_xy?, rel_yz?, rel_zw?))
}
/// Tries to turn a bushy 4-cycle body into a `MultiWayJoin` node.
///
/// The node must be `Project((Scan ⋈ Scan) ⋈ (Scan ⋈ Scan))` with inner joins only, and
/// [`infer_4cycle_semantics`] must accept its keys and projection. On success the result has:
///
/// * `inputs` — scans of the `wx`, `xy`, `yz`, `zw` relations in that order,
/// * `slot_vars` — `[(0, 1), (1, 2), (2, 3), (3, 0)]`,
/// * `output_columns` — the original projection list,
/// * `fallback` — a clone of `node`,
/// * `plan` — `None`,
/// * `var_order` — built from the leader chosen by the model selected in
///   `config.wcoj_variable_ordering`, or `None` when ordering is disabled or the model
///   returns no leader.
///
/// Returns `None` when the shape does not match.
fn try_promote_4cycle(
    node: &RirNode,
    stats: &StatsManager,
    config: &CompilerConfig,
) -> Option<RirNode> {
    use crate::compiler_config::WcojVarOrderingKind;
    use crate::wcoj_var_ordering::{
        build_cycle4_var_order, HeatAwareLeaderModel, LeaderCardinalityModel,
    };

    // Shape: Project over an inner join of two inner joins of scans.
    let RirNode::Project {
        input: top,
        columns,
    } = node
    else {
        return None;
    };
    let RirNode::Join {
        left: left_pair,
        right: right_pair,
        left_keys: olk,
        right_keys: ork,
        join_type: JoinType::Inner,
    } = top.as_ref()
    else {
        return None;
    };

    let RirNode::Join {
        left: ll,
        right: lr,
        left_keys: ilk_l,
        right_keys: irk_l,
        join_type: JoinType::Inner,
    } = left_pair.as_ref()
    else {
        return None;
    };
    let (RirNode::Scan { rel: rel_ll }, RirNode::Scan { rel: rel_lr }) =
        (ll.as_ref(), lr.as_ref())
    else {
        return None;
    };

    let RirNode::Join {
        left: rl,
        right: rr,
        left_keys: ilk_r,
        right_keys: irk_r,
        join_type: JoinType::Inner,
    } = right_pair.as_ref()
    else {
        return None;
    };
    let (RirNode::Scan { rel: rel_rl }, RirNode::Scan { rel: rel_rr }) =
        (rl.as_ref(), rr.as_ref())
    else {
        return None;
    };

    let (rel_wx, rel_xy, rel_yz, rel_zw) = infer_4cycle_semantics(
        *rel_ll, *rel_lr, *rel_rl, *rel_rr, ilk_l, irk_l, ilk_r, irk_r, olk, ork, columns,
    )?;

    // Optional variable order, driven by the configured leader model.
    let edge_rels = [rel_wx, rel_xy, rel_yz, rel_zw];
    let leader = match config.wcoj_variable_ordering {
        WcojVarOrderingKind::Disabled => None,
        WcojVarOrderingKind::LeaderCardinality => {
            LeaderCardinalityModel.pick_4cycle_leader(edge_rels, stats, config)
        }
        WcojVarOrderingKind::HeatAware => {
            HeatAwareLeaderModel.pick_4cycle_leader(edge_rels, stats, config)
        }
    };

    Some(RirNode::MultiWayJoin {
        inputs: vec![
            RirNode::Scan { rel: rel_wx },
            RirNode::Scan { rel: rel_xy },
            RirNode::Scan { rel: rel_yz },
            RirNode::Scan { rel: rel_zw },
        ],
        slot_vars: [(0u32, 1u32), (1, 2), (2, 3), (3, 0)]
            .iter()
            .map(|&(a, b)| vec![Some(a), Some(b)])
            .collect(),
        output_columns: columns.clone(),
        fallback: Box::new(node.clone()),
        plan: None,
        var_order: leader.map(build_cycle4_var_order),
    })
}
/// Returns the position of edge `(i, j)` when the edges of a `k`-vertex clique are listed in
/// row-major order: `(0,1), (0,2), …, (0,k-1), (1,2), …`.
///
/// Debug builds assert `i < j < k`.
fn clique_edge_idx(i: usize, j: usize, k: usize) -> usize {
    debug_assert!(i < j && j < k);
    // Number of edges whose first endpoint is smaller than `i`.
    let row_start = i * (2 * k - i - 1) / 2;
    // Position of `j` among the partners of `i`.
    let within_row = j - i - 1;
    row_start + within_row
}
/// Returns the representative of `x` in the union-find forest `parent`.
///
/// While climbing, each visited node is re-pointed at its grandparent, so `parent` may be
/// mutated even though the partition is unchanged. Panics if an index is out of bounds.
fn uf_find_clique(parent: &mut [usize], mut x: usize) -> usize {
    loop {
        let up = parent[x];
        if up == x {
            return x;
        }
        // Skip one level: attach `x` to its grandparent and continue from there.
        let grandparent = parent[up];
        parent[x] = grandparent;
        x = grandparent;
    }
}
/// Merges the classes of `a` and `b` in the union-find forest `parent`.
///
/// The representative of `a` survives; the representative of `b` is attached beneath it.
/// Nothing is written when both already share a representative.
fn uf_union_clique(parent: &mut [usize], a: usize, b: usize) {
    let kept = uf_find_clique(parent, a);
    let absorbed = uf_find_clique(parent, b);
    if kept == absorbed {
        return;
    }
    parent[absorbed] = kept;
}
#[allow(clippy::type_complexity)]
fn flatten_clique_body(body: &RirNode) -> Option<(Vec<RelId>, Vec<(usize, usize)>, Vec<usize>)> {
let RirNode::Project { input, columns } = body else {
return None;
};
let mut scans: Vec<RelId> = Vec::new();
let mut key_pairs: Vec<(usize, usize)> = Vec::new();
let _width = walk_clique_node(input, &mut scans, &mut key_pairs)?;
let mut project_globals: Vec<usize> = Vec::with_capacity(columns.len());
for c in columns {
let xlog_ir::rir::ProjectExpr::Column(k) = c else {
return None;
};
project_globals.push(*k);
}
Some((scans, key_pairs, project_globals))
}
/// Recursively collects the leaves and key pairs of a tree of inner joins over scans.
///
/// Scans are appended to `scans` in left-to-right order. Each join key pair is appended to
/// `key_pairs` as `(global_left_column, global_right_column)`, where the global index of a
/// column is its offset in the concatenation of all leaves visited so far (two columns per
/// leaf).
///
/// Returns the number of output columns of `node`, or `None` when the subtree contains a node
/// that is neither a scan nor an inner join, when the key lists differ in length, or when a
/// key is outside the width of its side.
fn walk_clique_node(
    node: &RirNode,
    scans: &mut Vec<RelId>,
    key_pairs: &mut Vec<(usize, usize)>,
) -> Option<usize> {
    let (left, right, left_keys, right_keys) = match node {
        RirNode::Scan { rel } => {
            scans.push(*rel);
            return Some(2);
        }
        RirNode::Join {
            left,
            right,
            left_keys,
            right_keys,
            join_type: JoinType::Inner,
        } => (left, right, left_keys, right_keys),
        _ => return None,
    };

    // The left subtree starts where the leaves collected so far end.
    let left_base = 2 * scans.len();
    let left_width = walk_clique_node(left, scans, key_pairs)?;
    let right_base = left_base + left_width;
    let right_width = walk_clique_node(right, scans, key_pairs)?;

    if left_keys.len() != right_keys.len() {
        return None;
    }
    for (&left_key, &right_key) in left_keys.iter().zip(right_keys.iter()) {
        if left_key >= left_width || right_key >= right_width {
            return None;
        }
        key_pairs.push((left_base + left_key, right_base + right_key));
    }
    Some(left_width + right_width)
}
/// Tries to recognise `body` as a `k`-clique over binary relations and promote it to a
/// `MultiWayJoin` node.
///
/// `k` must lie in `5..=8`. The body is flattened with [`flatten_clique_body`] and accepted
/// only when all of the following hold:
///
/// * it scans exactly `k * (k - 1) / 2` relations and projects exactly `k` columns,
/// * the join keys partition the scan columns into exactly `k` classes of `k - 1` columns,
/// * the projected columns fall into `k` distinct classes (projection position = variable id),
/// * every scan binds two different variables, its first column binding the lower-numbered
///   one, and no two scans bind the same pair.
///
/// On success the scans are reordered into row-major edge order `(0,1), (0,2), …, (k-2,k-1)`,
/// `slot_vars` lists the matching variable pairs, `output_columns` is the original projection,
/// and `fallback` is a clone of `body`. The `plan` / `var_order` pair is chosen as follows:
///
/// * the k-clique planner returns an order and the cost gate predicts WCOJ →
///   `WcojWithPlan` plus a k-clique variable order,
/// * the planner returns an order but the gate predicts hash →
///   `PlannedHashRoute` with `PlannerPredictsHashWins` and the planner's cost evidence,
/// * the planner returns nothing → `PlannedHashRoute` with `IncompleteStatsSafeDefault`
///   and empty evidence.
///
/// Returns `None` whenever the body is not such a clique, or when the planner's order cannot
/// be converted by [`kclique_variable_order_from_plan`].
fn try_promote_clique_k(body: &RirNode, k: usize, stats: &StatsManager) -> Option<RirNode> {
    if !(5..=8).contains(&k) {
        return None;
    }
    let expected_edges = k * (k - 1) / 2;
    let (scans, key_pairs, project_globals) = flatten_clique_body(body)?;
    if scans.len() != expected_edges || project_globals.len() != k {
        return None;
    }

    // Group scan columns into equivalence classes using the join keys.
    let n_slots = 2 * expected_edges;
    let mut parent: Vec<usize> = (0..n_slots).collect();
    for &(a, b) in &key_pairs {
        if a >= n_slots || b >= n_slots {
            return None;
        }
        uf_union_clique(&mut parent, a, b);
    }

    // The i-th projected column defines variable i; all k must be distinct classes.
    let mut head_class: Vec<usize> = Vec::with_capacity(k);
    for &col in &project_globals {
        if col >= n_slots {
            return None;
        }
        head_class.push(uf_find_clique(&mut parent, col));
    }
    let mut distinct_heads = head_class.clone();
    distinct_heads.sort_unstable();
    distinct_heads.dedup();
    if distinct_heads.len() != k {
        return None;
    }

    // Every variable of a k-clique occurs in exactly k - 1 edges.
    let mut class_sizes: HashMap<usize, usize> = HashMap::new();
    for slot in 0..n_slots {
        let root = uf_find_clique(&mut parent, slot);
        *class_sizes.entry(root).or_insert(0) += 1;
    }
    if class_sizes.len() != k || class_sizes.values().any(|&size| size != k - 1) {
        return None;
    }

    let class_to_head_idx: HashMap<usize, usize> = head_class
        .iter()
        .enumerate()
        .map(|(head_idx, &cls)| (cls, head_idx))
        .collect();

    // Map each variable pair to the scan that binds it.
    let mut canonical_to_scan_idx: HashMap<(usize, usize), usize> =
        HashMap::with_capacity(expected_edges);
    for atom_i in 0..scans.len() {
        let cls_a = uf_find_clique(&mut parent, 2 * atom_i);
        let cls_b = uf_find_clique(&mut parent, 2 * atom_i + 1);
        if cls_a == cls_b {
            return None;
        }
        let head_a = *class_to_head_idx.get(&cls_a)?;
        let head_b = *class_to_head_idx.get(&cls_b)?;
        // Only atoms whose first column binds the lower-numbered variable are accepted.
        if head_a > head_b {
            return None;
        }
        if canonical_to_scan_idx
            .insert((head_a, head_b), atom_i)
            .is_some()
        {
            return None;
        }
    }

    // Emit scans and slot variables in row-major edge order; a missing pair means the body
    // is not a complete clique.
    let mut reordered_scans: Vec<RelId> = Vec::with_capacity(expected_edges);
    let mut slot_vars: Vec<Vec<Option<u32>>> = Vec::with_capacity(expected_edges);
    for i in 0..k {
        for j in (i + 1)..k {
            let scan_idx = *canonical_to_scan_idx.get(&(i, j))?;
            // The emission position must agree with the closed-form edge index.
            debug_assert_eq!(clique_edge_idx(i, j, k), reordered_scans.len());
            reordered_scans.push(scans[scan_idx]);
            slot_vars.push(vec![Some(i as u32), Some(j as u32)]);
        }
    }
    let inputs: Vec<RirNode> = reordered_scans
        .iter()
        .map(|rel| RirNode::Scan { rel: *rel })
        .collect();

    let RirNode::Project { columns, .. } = body else {
        return None;
    };
    let output_columns = columns.clone();
    let fallback = Box::new(body.clone());

    // Let the planner decide between WCOJ and the hash route.
    let shape = build_kclique_shape(k, &reordered_scans)?;
    let planner_stats = kclique_planner_stats(stats);
    let (plan, var_order) = match plan_kclique_var_order(&shape, &planner_stats) {
        Some(full_order) => {
            let evidence = rir_cost_prediction(&full_order);
            if wcoj_cost_gate_predicts_wcoj(evidence.wcoj_cost, evidence.hash_cost) {
                let kclique_order = kclique_variable_order_from_plan(&shape, &full_order)?;
                (
                    MultiwayPlan::WcojWithPlan(kclique_order.clone()),
                    Some(VariableOrder::kclique(kclique_order)),
                )
            } else {
                (
                    MultiwayPlan::PlannedHashRoute {
                        reason: PlannedHashReason::PlannerPredictsHashWins,
                        planner_evidence: evidence,
                    },
                    None,
                )
            }
        }
        None => (
            MultiwayPlan::PlannedHashRoute {
                reason: PlannedHashReason::IncompleteStatsSafeDefault,
                planner_evidence: RirCostPredictionRecord::empty(),
            },
            None,
        ),
    };

    Some(RirNode::MultiWayJoin {
        inputs,
        slot_vars,
        output_columns,
        fallback,
        plan: Some(plan),
        var_order,
    })
}
/// Builds the planner's description of a `k`-clique from relations in row-major edge order.
///
/// `rels` is consumed front to back: its first entry becomes edge `(0, 1)`, the next `(0, 2)`,
/// and so on up to `(k-2, k-1)`. Every edge is recorded with `left_col = 0` and
/// `right_col = 1`. Returns `None` if `rels` runs out before all pairs are covered, or if
/// `KCliqueShape::from_edges` rejects the edge list. Surplus entries in `rels` are ignored.
fn build_kclique_shape(k: usize, rels: &[RelId]) -> Option<KCliqueShape> {
    let mut remaining = rels.iter();
    let mut edges = Vec::with_capacity(rels.len());
    for left in 0..k {
        for right in (left + 1)..k {
            let &rel_id = remaining.next()?;
            edges.push(KCliqueEdge {
                rel_id,
                left: VertexId(left),
                right: VertexId(right),
                left_col: 0,
                right_col: 1,
            });
        }
    }
    KCliqueShape::from_edges(k as u8, edges)
}
/// Returns the statistics snapshot handed to the k-clique planner.
///
/// This simply forwards to `StatsManager::snapshot`.
fn kclique_planner_stats(stats: &StatsManager) -> StatsSnapshot {
stats.snapshot()
}
/// Copies the planner's WCOJ and hash cost estimates into the IR-level cost record.
fn rir_cost_prediction(plan: &FullVariableOrder) -> RirCostPredictionRecord {
    let predicted = &plan.cost_prediction;
    let wcoj_cost = predicted.wcoj_cost;
    let hash_cost = predicted.hash_cost;
    RirCostPredictionRecord {
        wcoj_cost,
        hash_cost,
    }
}
/// Converts the planner's `FullVariableOrder` into the IR-level `KCliqueVariableOrder`.
///
/// The plan must list exactly `k` variables and `k * (k - 1) / 2` edges, where `k` is
/// `shape.variable_count()`. From it this function derives:
///
/// * `variable_positions[v]` — position of variable `v` in the plan's order (`u8::MAX` when
///   unassigned),
/// * `edge_permutation[slot]` — index into `shape.edges()` of the edge placed at `slot`,
/// * a `ColumnSwap` for every slot whose edge has its left endpoint ordered after its right,
/// * the leader slot — the slot whose endpoints occupy positions 0 and 1 — which becomes the
///   single entry of the sorted-layout requirement with key columns `[0, 1]`. Slot 0 is used
///   when no such edge is seen.
///
/// Helper split specs are cloned from the plan and the stream group is `StreamGroupId(0)`.
///
/// Returns `None` when the lengths do not match, a variable id is not below `k`, an edge index
/// is not present in the shape, or an edge endpoint has no assigned position.
fn kclique_variable_order_from_plan(
    shape: &KCliqueShape,
    plan: &FullVariableOrder,
) -> Option<KCliqueVariableOrder> {
    let k = shape.variable_count();
    let variable_count = usize::from(k);
    let expected_edges = variable_count * usize::from(k - 1) / 2;
    let lengths_match = plan.variable_order.len() == variable_count
        && plan.edge_permutation.len() == expected_edges;
    if !lengths_match {
        return None;
    }

    // Invert the order: variable id -> position.
    let mut variable_positions = [u8::MAX; K_CLIQUE_MAX_K];
    for (position, variable) in plan.variable_order.iter().enumerate() {
        let vertex = variable.0;
        if vertex >= variable_count {
            return None;
        }
        variable_positions[vertex] = position as u8;
    }

    let mut edge_permutation = [u8::MAX; K_CLIQUE_MAX_EDGES];
    let mut column_swaps = Vec::new();
    let mut leader_slot = None;
    for (slot, &edge_idx) in plan.edge_permutation.iter().enumerate() {
        let edge = shape.edges().get(edge_idx)?;
        let left_pos = variable_positions[edge.left.0];
        let right_pos = variable_positions[edge.right.0];
        if left_pos == u8::MAX || right_pos == u8::MAX {
            return None;
        }
        edge_permutation[slot] = edge_idx as u8;

        // The edge stores (left, right); swap when the order visits `right` first.
        if left_pos > right_pos {
            column_swaps.push(ColumnSwap {
                edge_slot: slot as u8,
                swap_cols: true,
            });
        }

        // The leader edge connects the first two variables of the order.
        let connects_first_two = left_pos.min(right_pos) == 0 && left_pos.max(right_pos) == 1;
        if connects_first_two {
            leader_slot = Some(slot as u8);
        }
    }

    let sorted_layout_requirements = SortedLayoutSpec {
        edge_slots: vec![leader_slot.unwrap_or(0)],
        key_columns: vec![vec![0, 1]],
    };
    Some(KCliqueVariableOrder::new(
        k,
        variable_positions,
        edge_permutation,
        column_swaps,
        sorted_layout_requirements,
        plan.helper_split_specs.clone(),
        StreamGroupId(0),
    ))
}
#[cfg(test)]
mod tests {
use super::*;
use xlog_core::RelId;
use xlog_ir::{CompiledRule, ExecutionPlan, PlanBuilder, Scc};
fn canonical_triangle_tree() -> RirNode {
let inner = RirNode::Join {
left: Box::new(RirNode::Scan { rel: RelId(1) }),
right: Box::new(RirNode::Scan { rel: RelId(2) }),
left_keys: vec![1],
right_keys: vec![0],
join_type: JoinType::Inner,
};
let outer = RirNode::Join {
left: Box::new(inner),
right: Box::new(RirNode::Scan { rel: RelId(3) }),
left_keys: vec![0, 3],
right_keys: vec![0, 1],
join_type: JoinType::Inner,
};
RirNode::Project {
input: Box::new(outer),
columns: vec![
ProjectExpr::Column(0),
ProjectExpr::Column(1),
ProjectExpr::Column(3),
],
}
}
/// Wraps `body` in a minimal execution plan: one non-recursive SCC with id 0 for predicate
/// `t`, holding a single rule with head `t`, the given body and default metadata.
fn plan_with_body(body: RirNode) -> ExecutionPlan {
    let scc = Scc {
        id: 0,
        predicates: vec![String::from("t")],
        is_recursive: false,
    };
    let rule = CompiledRule {
        head: String::from("t"),
        body,
        meta: Default::default(),
    };

    let mut builder = PlanBuilder::new();
    builder.add_scc(scc);
    builder.add_rule(0, rule);
    builder.build()
}
fn canonical_chain_tree() -> RirNode {
let join = RirNode::Join {
left: Box::new(RirNode::Scan { rel: RelId(1) }),
right: Box::new(RirNode::Scan { rel: RelId(2) }),
left_keys: vec![1],
right_keys: vec![0],
join_type: JoinType::Inner,
};
RirNode::Project {
input: Box::new(join),
columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
}
}
/// The reference chain body is promoted to a `ChainJoin` that preserves the operands, keys and
/// projection, and keeps a projection node as fallback.
#[test]
fn promotes_canonical_chain() {
    let mut plan = plan_with_body(canonical_chain_tree());
    promote_multiway(
        &mut plan,
        &HashMap::new(),
        &StatsManager::new(),
        &CompilerConfig::default(),
    );

    let body = &plan.rules_by_scc[0][0].body;
    let RirNode::ChainJoin {
        left,
        right,
        left_key,
        right_key,
        output_columns,
        fallback,
    } = body
    else {
        panic!("expected ChainJoin, got {:?}", body);
    };

    assert!(matches!(left.as_ref(), RirNode::Scan { rel: RelId(1) }));
    assert!(matches!(right.as_ref(), RirNode::Scan { rel: RelId(2) }));
    assert_eq!(*left_key, 1);
    assert_eq!(*right_key, 0);
    let expected_columns = vec![ProjectExpr::Column(0), ProjectExpr::Column(3)];
    assert_eq!(output_columns, &expected_columns);
    assert!(matches!(fallback.as_ref(), RirNode::Project { .. }));
}
/// A chain whose join is a left outer join must not be promoted.
#[test]
fn chain_promotion_rejects_non_inner_join() {
    let mut body = canonical_chain_tree();

    // Reach through the projection and flip the join type in place.
    let below_project = match &mut body {
        RirNode::Project { input, .. } => Some(&mut **input),
        _ => None,
    };
    if let Some(RirNode::Join { join_type, .. }) = below_project {
        *join_type = JoinType::LeftOuter;
    }

    assert!(try_promote_chain(&body).is_none());
}
/// A chain whose join uses two key columns per side must not be promoted.
#[test]
fn chain_promotion_rejects_multi_key_join() {
    let mut body = canonical_chain_tree();

    // Reach through the projection and widen both key lists in place.
    let below_project = match &mut body {
        RirNode::Project { input, .. } => Some(&mut **input),
        _ => None,
    };
    if let Some(RirNode::Join {
        left_keys,
        right_keys,
        ..
    }) = below_project
    {
        *left_keys = vec![0, 1];
        *right_keys = vec![0, 1];
    }

    assert!(try_promote_chain(&body).is_none());
}
/// The reference triangle body is promoted to a `MultiWayJoin` whose inputs are `R1, R2, R3`,
/// whose slot variables are the canonical triangle pairs, and whose projection is unchanged.
#[test]
fn promotes_canonical_triangle() {
    let mut plan = plan_with_body(canonical_triangle_tree());
    promote_multiway(
        &mut plan,
        &HashMap::new(),
        &StatsManager::new(),
        &CompilerConfig::default(),
    );

    let body = &plan.rules_by_scc[0][0].body;
    let RirNode::MultiWayJoin {
        inputs,
        slot_vars,
        output_columns,
        fallback,
        var_order: _,
        ..
    } = body
    else {
        panic!("expected MultiWayJoin, got {:?}", body);
    };

    assert_eq!(inputs.len(), 3);
    assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(1) }));
    assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(2) }));
    assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(3) }));

    let expected_slots = vec![
        vec![Some(0u32), Some(1)],
        vec![Some(1u32), Some(2)],
        vec![Some(0u32), Some(2)],
    ];
    assert_eq!(slot_vars, &expected_slots);

    let expected_columns = vec![
        ProjectExpr::Column(0),
        ProjectExpr::Column(1),
        ProjectExpr::Column(3),
    ];
    assert_eq!(output_columns, &expected_columns);

    assert!(matches!(fallback.as_ref(), RirNode::Project { .. }));
}
/// The `fallback` stored in the promoted node has the same debug rendering as the original
/// body.
#[test]
fn fallback_is_structurally_equal_to_input() {
    let original = canonical_triangle_tree();
    let original_rendering = format!("{:?}", original);

    let mut plan = plan_with_body(original);
    promote_multiway(
        &mut plan,
        &HashMap::new(),
        &StatsManager::new(),
        &CompilerConfig::default(),
    );

    let RirNode::MultiWayJoin { fallback, .. } = &plan.rules_by_scc[0][0].body else {
        panic!("expected MultiWayJoin");
    };
    assert_eq!(format!("{:?}", fallback.as_ref()), original_rendering);
}
/// Running the pass a second time leaves the already-promoted body unchanged.
#[test]
fn idempotent_under_repeat_calls() {
    let run_pass = |plan: &mut ExecutionPlan| {
        promote_multiway(
            plan,
            &HashMap::new(),
            &StatsManager::new(),
            &CompilerConfig::default(),
        );
    };
    let render = |plan: &ExecutionPlan| format!("{:?}", &plan.rules_by_scc[0][0].body);

    let mut plan = plan_with_body(canonical_triangle_tree());
    run_pass(&mut plan);
    let first = render(&plan);
    run_pass(&mut plan);
    let second = render(&plan);

    assert_eq!(first, second);
}
/// When the inner pair is joined on its first columns, the relations are assigned to the
/// triangle roles as `R1, R3, R2`.
#[test]
fn promotes_triangle_with_x_shared_inner_pair() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });

    // Inner join equates R1.0 with R2.0.
    let inner_pair = RirNode::Join {
        left: scan(1),
        right: scan(2),
        left_keys: vec![0],
        right_keys: vec![0],
        join_type: JoinType::Inner,
    };
    // Outer join equates inner columns 1 and 3 with R3.0 and R3.1.
    let with_third = RirNode::Join {
        left: Box::new(inner_pair),
        right: scan(3),
        left_keys: vec![1, 3],
        right_keys: vec![0, 1],
        join_type: JoinType::Inner,
    };
    let body = RirNode::Project {
        input: Box::new(with_third),
        columns: vec![
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
        ],
    };

    let mut plan = plan_with_body(body);
    promote_multiway(
        &mut plan,
        &HashMap::new(),
        &StatsManager::new(),
        &CompilerConfig::default(),
    );

    let RirNode::MultiWayJoin {
        inputs, slot_vars, ..
    } = &plan.rules_by_scc[0][0].body
    else {
        panic!("expected MultiWayJoin after promotion");
    };

    let mut scan_rels: Vec<RelId> = Vec::with_capacity(inputs.len());
    for input in inputs {
        let RirNode::Scan { rel } = input else {
            panic!("expected Scan in MultiWayJoin inputs");
        };
        scan_rels.push(*rel);
    }
    assert_eq!(scan_rels, vec![RelId(1), RelId(3), RelId(2)]);

    let expected_slots = vec![
        vec![Some(0u32), Some(1)],
        vec![Some(1u32), Some(2)],
        vec![Some(0u32), Some(2)],
    ];
    assert_eq!(slot_vars, &expected_slots);
}
/// When the inner pair is joined on its second columns, the relations are assigned to the
/// triangle roles as `R3, R2, R1`.
#[test]
fn promotes_triangle_with_z_shared_inner_pair() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });

    // Inner join equates R1.1 with R2.1.
    let inner_pair = RirNode::Join {
        left: scan(1),
        right: scan(2),
        left_keys: vec![1],
        right_keys: vec![1],
        join_type: JoinType::Inner,
    };
    // Outer join equates inner columns 0 and 2 with R3.0 and R3.1.
    let with_third = RirNode::Join {
        left: Box::new(inner_pair),
        right: scan(3),
        left_keys: vec![0, 2],
        right_keys: vec![0, 1],
        join_type: JoinType::Inner,
    };
    let body = RirNode::Project {
        input: Box::new(with_third),
        columns: vec![
            ProjectExpr::Column(0),
            ProjectExpr::Column(2),
            ProjectExpr::Column(3),
        ],
    };

    let mut plan = plan_with_body(body);
    promote_multiway(
        &mut plan,
        &HashMap::new(),
        &StatsManager::new(),
        &CompilerConfig::default(),
    );

    let RirNode::MultiWayJoin {
        inputs, slot_vars, ..
    } = &plan.rules_by_scc[0][0].body
    else {
        panic!("expected MultiWayJoin after promotion");
    };

    let mut scan_rels: Vec<RelId> = Vec::with_capacity(inputs.len());
    for input in inputs {
        let RirNode::Scan { rel } = input else {
            panic!("expected Scan in MultiWayJoin inputs");
        };
        scan_rels.push(*rel);
    }
    assert_eq!(scan_rels, vec![RelId(3), RelId(2), RelId(1)]);

    let expected_slots = vec![
        vec![Some(0u32), Some(1)],
        vec![Some(1u32), Some(2)],
        vec![Some(0u32), Some(2)],
    ];
    assert_eq!(slot_vars, &expected_slots);
}
/// Swapping the first two projected columns still promotes the body; the slot variables stay
/// canonical and the projection list is carried over as written.
#[test]
fn promotes_triangle_with_rotated_projection_columns() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });
    let rotated_columns = || {
        vec![
            ProjectExpr::Column(1),
            ProjectExpr::Column(0),
            ProjectExpr::Column(3),
        ]
    };

    let inner_pair = RirNode::Join {
        left: scan(1),
        right: scan(2),
        left_keys: vec![1],
        right_keys: vec![0],
        join_type: JoinType::Inner,
    };
    let with_third = RirNode::Join {
        left: Box::new(inner_pair),
        right: scan(3),
        left_keys: vec![0, 3],
        right_keys: vec![0, 1],
        join_type: JoinType::Inner,
    };
    let body = RirNode::Project {
        input: Box::new(with_third),
        columns: rotated_columns(),
    };

    let mut plan = plan_with_body(body);
    promote_multiway(
        &mut plan,
        &HashMap::new(),
        &StatsManager::new(),
        &CompilerConfig::default(),
    );

    let RirNode::MultiWayJoin {
        slot_vars,
        output_columns,
        ..
    } = &plan.rules_by_scc[0][0].body
    else {
        panic!("expected MultiWayJoin after promotion");
    };

    let expected_slots = vec![
        vec![Some(0u32), Some(1)],
        vec![Some(1u32), Some(2)],
        vec![Some(0u32), Some(2)],
    ];
    assert_eq!(slot_vars, &expected_slots);
    assert_eq!(output_columns, &rotated_columns());
}
/// A triangle-shaped tree whose inner join is `LeftOuter` must be left alone:
/// after the pass the rule body is still the original `Project`.
#[test]
fn rejects_non_inner_join() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });

    // Only difference from a promotable triangle in this test: the inner join type.
    let outer_flavoured_inner = RirNode::Join {
        left: scan(1),
        right: scan(2),
        left_keys: vec![1],
        right_keys: vec![0],
        join_type: JoinType::LeftOuter,
    };
    let closing_join = RirNode::Join {
        left: Box::new(outer_flavoured_inner),
        right: scan(3),
        left_keys: vec![0, 3],
        right_keys: vec![0, 1],
        join_type: JoinType::Inner,
    };
    let projected = RirNode::Project {
        input: Box::new(closing_join),
        columns: vec![
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
        ],
    };

    let mut plan = plan_with_body(projected);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::Project { .. }
    ));
}
/// Inserting a `Filter` between the projection and the outermost join must
/// block promotion: the rule body is expected to remain a `Project`.
#[test]
fn rejects_filter_above_outer_join() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });

    let inner_pair = RirNode::Join {
        left: scan(1),
        right: scan(2),
        left_keys: vec![1],
        right_keys: vec![0],
        join_type: JoinType::Inner,
    };
    let closing_join = RirNode::Join {
        left: Box::new(inner_pair),
        right: scan(3),
        left_keys: vec![0, 3],
        right_keys: vec![0, 1],
        join_type: JoinType::Inner,
    };
    // The filter sits directly on top of the join tree, underneath the projection.
    let filtered_join = RirNode::Filter {
        input: Box::new(closing_join),
        predicate: xlog_ir::Expr::Column(0),
    };
    let projected = RirNode::Project {
        input: Box::new(filtered_join),
        columns: vec![
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
        ],
    };

    let mut plan = plan_with_body(projected);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::Project { .. }
    ));
}
/// Running the promotion pass must not touch a rule's `meta`: its `Debug`
/// rendering after the pass has to equal that of the metadata it was built with.
#[test]
fn meta_preserved_byte_for_byte() {
    use xlog_core::Schema;
    use xlog_ir::metadata::RirMeta;

    // Three U32 columns named x, y and z.
    let columns = ["x", "y", "z"]
        .into_iter()
        .map(|name| (name.to_string(), xlog_core::ScalarType::U32))
        .collect();
    let meta_pre = RirMeta::with_schema(Schema::new(columns)).with_rows(100, 250);

    let mut builder = PlanBuilder::new();
    let scc = Scc {
        id: 0,
        predicates: vec!["t".to_string()],
        is_recursive: false,
    };
    builder.add_scc(scc);
    let rule = CompiledRule {
        head: "t".to_string(),
        body: canonical_triangle_tree(),
        meta: meta_pre.clone(),
    };
    builder.add_rule(0, rule);
    let mut plan = builder.build();

    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    let meta_after = format!("{:?}", &plan.rules_by_scc[0][0].meta);
    let meta_before = format!("{:?}", meta_pre);
    assert_eq!(meta_after, meta_before);
}
/// The canonical triangle body is promoted even when its SCC is flagged
/// `is_recursive`, given an empty `rel_ids` map.
#[test]
fn promotes_stable_triangle_in_recursive_scc() {
    let mut builder = PlanBuilder::new();

    let recursive_scc = Scc {
        id: 0,
        predicates: vec![String::from("tri")],
        is_recursive: true,
    };
    builder.add_scc(recursive_scc);

    let rule = CompiledRule {
        head: String::from("tri"),
        body: canonical_triangle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, rule);

    let mut plan = builder.build();
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::MultiWayJoin { .. }
    ));
}
/// Same recursive-SCC triangle setup, but with a `rel_ids` map that binds the
/// head predicate `tri` to `RelId(2)`; the body must still be promoted.
#[test]
fn promotes_linear_recursive_triangle() {
    let mut builder = PlanBuilder::new();

    let recursive_scc = Scc {
        id: 0,
        predicates: vec![String::from("tri")],
        is_recursive: true,
    };
    builder.add_scc(recursive_scc);

    let rule = CompiledRule {
        head: String::from("tri"),
        body: canonical_triangle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, rule);

    let mut plan = builder.build();

    let rel_ids = HashMap::from([("tri".to_string(), RelId(2))]);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &rel_ids, &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::MultiWayJoin { .. }
    ));
}
/// A recursive SCC with two predicates (`tri_a`, `tri_b`), both registered in
/// `rel_ids`, still gets its single triangle rule promoted.
#[test]
fn promotes_multirec_triangle_in_recursive_scc() {
    let mut builder = PlanBuilder::new();

    let two_predicate_scc = Scc {
        id: 0,
        predicates: vec![String::from("tri_a"), String::from("tri_b")],
        is_recursive: true,
    };
    builder.add_scc(two_predicate_scc);

    // Only `tri_a` gets a rule in this plan.
    let rule = CompiledRule {
        head: String::from("tri_a"),
        body: canonical_triangle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, rule);

    let mut plan = builder.build();

    let rel_ids = HashMap::from([
        ("tri_a".to_string(), RelId(1)),
        ("tri_b".to_string(), RelId(2)),
    ]);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &rel_ids, &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::MultiWayJoin { .. }
    ));
}
/// A plan with one recursive SCC (`rec`) and one non-recursive SCC (`nonrec`),
/// each holding a canonical triangle rule, must have both rules promoted.
#[test]
fn promotes_linear_rec_and_non_rec_sccs_in_mixed_plan() {
    let mut builder = PlanBuilder::new();

    // SCC 0: recursive.
    let recursive_scc = Scc {
        id: 0,
        predicates: vec![String::from("rec")],
        is_recursive: true,
    };
    builder.add_scc(recursive_scc);
    let recursive_rule = CompiledRule {
        head: String::from("rec"),
        body: canonical_triangle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, recursive_rule);

    // SCC 1: non-recursive.
    let plain_scc = Scc {
        id: 1,
        predicates: vec![String::from("nonrec")],
        is_recursive: false,
    };
    builder.add_scc(plain_scc);
    let plain_rule = CompiledRule {
        head: String::from("nonrec"),
        body: canonical_triangle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(1, plain_rule);

    let mut plan = builder.build();

    let rel_ids = HashMap::from([("rec".to_string(), RelId(1))]);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &rel_ids, &stats, &config);

    // Check SCC 0 first, then SCC 1.
    for scc_idx in 0..2 {
        assert!(matches!(
            &plan.rules_by_scc[scc_idx][0].body,
            RirNode::MultiWayJoin { .. }
        ));
    }
}
/// Builds the reference 4-cycle body shared by the 4-cycle tests in this
/// module: two two-scan inner joins (`1 ⋈ 2` and `3 ⋈ 4`) combined by an outer
/// inner join, all under a four-column projection.
fn canonical_4cycle_tree() -> RirNode {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });
    // Both inner pairs use the same key layout: column 1 on the left, column 0 on the right.
    let pair = |left, right| RirNode::Join {
        left,
        right,
        left_keys: vec![1],
        right_keys: vec![0],
        join_type: JoinType::Inner,
    };

    let left_pair = pair(scan(1), scan(2));
    let right_pair = pair(scan(3), scan(4));

    let top_join = RirNode::Join {
        left: Box::new(left_pair),
        right: Box::new(right_pair),
        left_keys: vec![0, 3],
        right_keys: vec![3, 0],
        join_type: JoinType::Inner,
    };

    let columns = vec![
        ProjectExpr::Column(0),
        ProjectExpr::Column(1),
        ProjectExpr::Column(3),
        ProjectExpr::Column(5),
    ];
    RirNode::Project {
        input: Box::new(top_join),
        columns,
    }
}
/// Wraps `body` in a minimal plan: a single non-recursive SCC (id 0) whose
/// only predicate, `cycle4`, has exactly one rule with default metadata.
fn plan_with_4cycle_body(body: RirNode) -> ExecutionPlan {
    let scc = Scc {
        id: 0,
        predicates: vec![String::from("cycle4")],
        is_recursive: false,
    };
    let rule = CompiledRule {
        head: String::from("cycle4"),
        body,
        meta: Default::default(),
    };

    let mut builder = PlanBuilder::new();
    builder.add_scc(scc);
    builder.add_rule(0, rule);
    builder.build()
}
/// The canonical 4-cycle is promoted to a `MultiWayJoin` with four scan inputs
/// in order 1..=4, a ring-shaped slot layout, the original projection as
/// `output_columns`, and a `Project`-rooted fallback.
#[test]
fn promotes_canonical_4cycle() {
    let mut plan = plan_with_4cycle_body(canonical_4cycle_tree());
    let rel_ids = HashMap::new();
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &rel_ids, &stats, &config);

    let promoted = &plan.rules_by_scc[0][0].body;
    let RirNode::MultiWayJoin {
        inputs,
        slot_vars,
        output_columns,
        fallback,
        ..
    } = promoted
    else {
        panic!("expected MultiWayJoin, got {:?}", promoted);
    };

    // Inputs: exactly the four scans, in relation-id order.
    assert_eq!(inputs.len(), 4);
    assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(1) }));
    assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(2) }));
    assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(3) }));
    assert!(matches!(inputs[3], RirNode::Scan { rel: RelId(4) }));

    // Slot layout: each input shares its second slot with the next input's
    // first, and the last input wraps around to variable 0.
    let expected_slots = vec![
        vec![Some(0u32), Some(1)],
        vec![Some(1u32), Some(2)],
        vec![Some(2u32), Some(3)],
        vec![Some(3u32), Some(0)],
    ];
    assert_eq!(slot_vars, &expected_slots);

    let expected_outputs = vec![
        ProjectExpr::Column(0),
        ProjectExpr::Column(1),
        ProjectExpr::Column(3),
        ProjectExpr::Column(5),
    ];
    assert_eq!(output_columns, &expected_outputs);

    assert!(matches!(fallback.as_ref(), RirNode::Project { .. }));
}
/// The `fallback` stored on the promoted node must render (via `Debug`)
/// exactly like the tree that was handed to the pass.
#[test]
fn fallback_4cycle_is_structurally_equal_to_input() {
    let original = canonical_4cycle_tree();
    let mut plan = plan_with_4cycle_body(original.clone());

    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    match &plan.rules_by_scc[0][0].body {
        RirNode::MultiWayJoin { fallback, .. } => {
            let fallback_repr = format!("{:?}", fallback.as_ref());
            let original_repr = format!("{:?}", original);
            assert_eq!(fallback_repr, original_repr);
        }
        _ => panic!("expected MultiWayJoin"),
    }
}
/// Running the pass a second time over an already-processed plan must leave
/// the rule body's `Debug` rendering unchanged.
#[test]
fn idempotent_4cycle_under_repeat_calls() {
    // One pass with fresh default stats/config each time, as a reusable step.
    let run_pass = |plan: &mut ExecutionPlan| {
        promote_multiway(
            plan,
            &HashMap::new(),
            &StatsManager::new(),
            &CompilerConfig::default(),
        );
    };
    let snapshot = |plan: &ExecutionPlan| format!("{:?}", &plan.rules_by_scc[0][0].body);

    let mut plan = plan_with_4cycle_body(canonical_4cycle_tree());

    run_pass(&mut plan);
    let after_first = snapshot(&plan);

    run_pass(&mut plan);
    let after_second = snapshot(&plan);

    assert_eq!(after_first, after_second);
}
/// A left-deep three-scan tree under a four-column projection is not a
/// 4-cycle; `try_promote_4cycle` must return `None` for it.
#[test]
fn rejects_4cycle_with_left_deep_shape() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });

    let inner_pair = RirNode::Join {
        left: scan(1),
        right: scan(2),
        left_keys: vec![1],
        right_keys: vec![0],
        join_type: JoinType::Inner,
    };
    // The right child of the top join is a bare scan rather than a second pair.
    let left_deep = RirNode::Join {
        left: Box::new(inner_pair),
        right: scan(3),
        left_keys: vec![0, 3],
        right_keys: vec![0, 1],
        join_type: JoinType::Inner,
    };
    let candidate = RirNode::Project {
        input: Box::new(left_deep),
        columns: vec![
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
            ProjectExpr::Column(5),
        ],
    };

    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    let outcome = try_promote_4cycle(&candidate, &stats, &config);
    assert!(outcome.is_none());
}
/// The same four relations grouped as `(2 ⋈ 3)` and `(4 ⋈ 1)`, with the
/// projection listing column 5 first, must still be promoted, and the promoted
/// inputs are expected in order 1, 2, 3, 4 with the ring-shaped slot layout.
#[test]
fn promotes_4cycle_with_alternative_inner_grouping() {
    let scan = |id| Box::new(RirNode::Scan { rel: RelId(id) });
    let pair = |left, right| RirNode::Join {
        left,
        right,
        left_keys: vec![1],
        right_keys: vec![0],
        join_type: JoinType::Inner,
    };

    let left_pair = pair(scan(2), scan(3));
    let right_pair = pair(scan(4), scan(1));

    let top_join = RirNode::Join {
        left: Box::new(left_pair),
        right: Box::new(right_pair),
        left_keys: vec![0, 3],
        right_keys: vec![3, 0],
        join_type: JoinType::Inner,
    };
    let projected = RirNode::Project {
        input: Box::new(top_join),
        columns: vec![
            ProjectExpr::Column(5),
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
        ],
    };

    let mut plan = plan_with_body(projected);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    let (inputs, slot_vars) = match &plan.rules_by_scc[0][0].body {
        RirNode::MultiWayJoin {
            inputs, slot_vars, ..
        } => (inputs, slot_vars),
        _ => panic!("expected MultiWayJoin after promotion"),
    };

    // Every promoted input must be a bare scan; collect their relation ids in order.
    let mut scan_rels = Vec::with_capacity(inputs.len());
    for node in inputs.iter() {
        let RirNode::Scan { rel } = node else {
            panic!("expected Scan in MultiWayJoin inputs");
        };
        scan_rels.push(*rel);
    }
    assert_eq!(
        scan_rels,
        vec![RelId(1), RelId(2), RelId(3), RelId(4)],
        "inputs must be in semantic order regardless of positional layout"
    );

    let expected_slots = vec![
        vec![Some(0u32), Some(1)],
        vec![Some(1u32), Some(2)],
        vec![Some(2u32), Some(3)],
        vec![Some(3u32), Some(0)],
    ];
    assert_eq!(slot_vars, &expected_slots);
}
/// Swapping the first two projection columns of the canonical 4-cycle must
/// make `try_promote_4cycle` return `None`.
#[test]
fn rejects_4cycle_with_rotated_columns() {
    let mut body = canonical_4cycle_tree();
    // Fail loudly if the fixture's root ever stops being a `Project`; a silent
    // skip here would leave the tree unmodified and the test would no longer
    // exercise the rotated-columns case it is named after.
    let RirNode::Project { columns, .. } = &mut body else {
        panic!("canonical 4-cycle fixture must be rooted at a Project");
    };
    columns.swap(0, 1);
    assert!(
        try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
    );
}
/// Changing the outermost join of the canonical 4-cycle to `LeftOuter` must
/// make `try_promote_4cycle` return `None`.
#[test]
fn rejects_4cycle_with_non_inner_outer_join() {
    let mut body = canonical_4cycle_tree();
    // Both destructurings panic on mismatch instead of silently skipping the
    // mutation, so a fixture change cannot leave this test checking an
    // unmodified tree.
    let RirNode::Project { input, .. } = &mut body else {
        panic!("canonical 4-cycle fixture must be rooted at a Project");
    };
    let RirNode::Join { join_type, .. } = input.as_mut() else {
        panic!("canonical 4-cycle fixture must project over a Join");
    };
    *join_type = JoinType::LeftOuter;
    assert!(
        try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
    );
}
/// Replacing the outermost join's `left_keys` (`[0, 3]` in the fixture) with
/// `[0, 4]` must make `try_promote_4cycle` return `None`.
#[test]
fn rejects_4cycle_with_wrong_outer_keys() {
    let mut body = canonical_4cycle_tree();
    // Both destructurings panic on mismatch instead of silently skipping the
    // mutation, so a fixture change cannot leave this test checking an
    // unmodified tree.
    let RirNode::Project { input, .. } = &mut body else {
        panic!("canonical 4-cycle fixture must be rooted at a Project");
    };
    let RirNode::Join { left_keys, .. } = input.as_mut() else {
        panic!("canonical 4-cycle fixture must project over a Join");
    };
    *left_keys = vec![0, 4];
    assert!(
        try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
    );
}
/// The canonical 4-cycle body is promoted even when its SCC is flagged
/// `is_recursive`, given an empty `rel_ids` map.
#[test]
fn promotes_stable_4cycle_in_recursive_scc() {
    let mut builder = PlanBuilder::new();

    let recursive_scc = Scc {
        id: 0,
        predicates: vec![String::from("rec_cycle")],
        is_recursive: true,
    };
    builder.add_scc(recursive_scc);

    let rule = CompiledRule {
        head: String::from("rec_cycle"),
        body: canonical_4cycle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, rule);

    let mut plan = builder.build();
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &HashMap::new(), &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::MultiWayJoin { .. }
    ));
}
/// Same recursive-SCC 4-cycle setup, but with a `rel_ids` map that binds the
/// head predicate `rec_cycle` to `RelId(2)`; the body must still be promoted.
#[test]
fn promotes_linear_recursive_4cycle() {
    let mut builder = PlanBuilder::new();

    let recursive_scc = Scc {
        id: 0,
        predicates: vec![String::from("rec_cycle")],
        is_recursive: true,
    };
    builder.add_scc(recursive_scc);

    let rule = CompiledRule {
        head: String::from("rec_cycle"),
        body: canonical_4cycle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, rule);

    let mut plan = builder.build();

    let rel_ids = HashMap::from([("rec_cycle".to_string(), RelId(2))]);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &rel_ids, &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::MultiWayJoin { .. }
    ));
}
/// A recursive SCC with two predicates (`rc_a`, `rc_b`), both registered in
/// `rel_ids`, still gets its single 4-cycle rule promoted.
#[test]
fn promotes_multirec_4cycle_in_recursive_scc() {
    let mut builder = PlanBuilder::new();

    let two_predicate_scc = Scc {
        id: 0,
        predicates: vec![String::from("rc_a"), String::from("rc_b")],
        is_recursive: true,
    };
    builder.add_scc(two_predicate_scc);

    // Only `rc_a` gets a rule in this plan.
    let rule = CompiledRule {
        head: String::from("rc_a"),
        body: canonical_4cycle_tree(),
        meta: Default::default(),
    };
    builder.add_rule(0, rule);

    let mut plan = builder.build();

    let rel_ids = HashMap::from([
        ("rc_a".to_string(), RelId(1)),
        ("rc_b".to_string(), RelId(2)),
    ]);
    let stats = StatsManager::new();
    let config = CompilerConfig::default();
    promote_multiway(&mut plan, &rel_ids, &stats, &config);

    assert!(matches!(
        &plan.rules_by_scc[0][0].body,
        RirNode::MultiWayJoin { .. }
    ));
}
/// The two shape-specific promoters must not cross-match: the 4-cycle promoter
/// returns `None` for the canonical triangle, and the triangle promoter returns
/// `None` for the canonical 4-cycle.
#[test]
fn triangle_does_not_match_4cycle_promoter() {
    // Triangle fed to the 4-cycle promoter.
    let tri_body = canonical_triangle_tree();
    let tri_as_4cycle =
        try_promote_4cycle(&tri_body, &StatsManager::new(), &CompilerConfig::default());
    assert!(tri_as_4cycle.is_none());

    // 4-cycle fed to the triangle promoter.
    let cycle_body = canonical_4cycle_tree();
    let cycle_as_triangle =
        try_promote_triangle(&cycle_body, &StatsManager::new(), &CompilerConfig::default());
    assert!(cycle_as_triangle.is_none());
}
}