use std::collections::HashSet;
use xlog_core::{RelId, Result, ScalarType, Schema};
use xlog_cuda::device_runtime::StreamId;
use xlog_cuda::provider::NESTED_LOOP_TOTAL_THRESHOLD;
use xlog_cuda::CudaBuffer;
use xlog_cuda::JoinType as CudaJoinType;
use xlog_ir::{
rir::{KCliqueVariableOrder, MultiwayPlan, ProjectExpr, VariableOrder},
CompiledRule, RirNode,
};
use super::Executor;
#[cfg(feature = "wcoj-phase-timing")]
use std::time::Instant;
/// Environment variable that force-enables the WCOJ triangle dispatch.
pub const ENV_USE_WCOJ_TRIANGLE_U32: &str = "XLOG_USE_WCOJ_TRIANGLE_U32";

/// Reports whether the WCOJ triangle dispatch is force-enabled.
///
/// An explicit `config_override` always wins. Without one, the
/// [`ENV_USE_WCOJ_TRIANGLE_U32`] environment variable is consulted: `1` or a
/// case-insensitive `true` turns the gate on; any other value, an unset
/// variable, or a non-Unicode value leaves it off.
///
/// Surrounding whitespace in the variable is ignored (the numeric
/// `XLOG_WCOJ_BLOCK_WORK_UNIT` parser in this file trims as well), so values
/// such as `"1\n"` coming from scripts or env files are still honoured.
pub(super) fn wcoj_gate_enabled(config_override: Option<bool>) -> bool {
    config_override.unwrap_or_else(|| {
        std::env::var(ENV_USE_WCOJ_TRIANGLE_U32)
            .map(|raw| {
                let v = raw.trim();
                v == "1" || v.eq_ignore_ascii_case("true")
            })
            .unwrap_or(false)
    })
}
/// Environment variable that overrides the work-unit value handed to the
/// u32 triangle kernel launch.
pub const ENV_WCOJ_BLOCK_WORK_UNIT: &str = "XLOG_WCOJ_BLOCK_WORK_UNIT";
/// Work unit used when the variable is unset or holds an unusable value.
pub(super) const WCOJ_BLOCK_WORK_UNIT_DEFAULT: u32 = 1024;
/// Largest accepted work unit (inclusive upper bound).
pub(super) const WCOJ_BLOCK_WORK_UNIT_MAX: u32 = 8192;

/// Resolves the block work unit from [`ENV_WCOJ_BLOCK_WORK_UNIT`].
///
/// Returns the parsed value when it lies in `1..=WCOJ_BLOCK_WORK_UNIT_MAX`.
/// An unset variable yields the default silently; a value that does not parse
/// as `u32` or falls outside the range yields the default after printing a
/// warning to stderr.
pub(super) fn wcoj_block_work_unit() -> u32 {
    let Ok(raw) = std::env::var(ENV_WCOJ_BLOCK_WORK_UNIT) else {
        // Nothing usable in the environment: no warning, just the default.
        return WCOJ_BLOCK_WORK_UNIT_DEFAULT;
    };
    let Ok(v) = raw.trim().parse::<u32>() else {
        eprintln!(
            "warning: {ENV_WCOJ_BLOCK_WORK_UNIT}={raw:?} is not a u32; \
             using {WCOJ_BLOCK_WORK_UNIT_DEFAULT}"
        );
        return WCOJ_BLOCK_WORK_UNIT_DEFAULT;
    };
    if (1..=WCOJ_BLOCK_WORK_UNIT_MAX).contains(&v) {
        return v;
    }
    eprintln!(
        "warning: {ENV_WCOJ_BLOCK_WORK_UNIT}={v} is outside 1..={WCOJ_BLOCK_WORK_UNIT_MAX}; \
         using {WCOJ_BLOCK_WORK_UNIT_DEFAULT}"
    );
    WCOJ_BLOCK_WORK_UNIT_DEFAULT
}
/// Reports whether adaptive (cost-model) triangle dispatch is enabled.
///
/// The feature is on unless the configuration explicitly sets it to `false`.
pub(super) fn wcoj_adaptive_enabled(config_override: Option<bool>) -> bool {
    !matches!(config_override, Some(false))
}
/// Environment variable that can switch the chain-join dispatch off.
pub const ENV_WCOJ_W63_CHAIN_ENABLE: &str = "XLOG_WCOJ_W63_CHAIN_ENABLE";

/// Reports whether the chain-join dispatch is enabled.
///
/// The dispatch is on by default. It is turned off only when
/// [`ENV_WCOJ_W63_CHAIN_ENABLE`] is `0` or a case-insensitive `false`.
/// Surrounding whitespace is ignored so that a value such as `"0\n"` still
/// disables the path instead of being silently treated as "enabled".
pub(super) fn w63_chain_enabled() -> bool {
    match std::env::var(ENV_WCOJ_W63_CHAIN_ENABLE) {
        Ok(raw) => {
            let v = raw.trim();
            !(v == "0" || v.eq_ignore_ascii_case("false"))
        }
        // Unset or non-Unicode: keep the default (enabled).
        Err(_) => true,
    }
}
/// Environment variable that force-enables the WCOJ 4-cycle dispatch.
pub const ENV_USE_WCOJ_4CYCLE: &str = "XLOG_USE_WCOJ_4CYCLE";
/// Environment variable that enables cost-model-driven 4-cycle dispatch.
pub const ENV_USE_WCOJ_4CYCLE_ADAPTIVE: &str = "XLOG_USE_WCOJ_4CYCLE_ADAPTIVE";
/// Environment variable that disables the WCOJ 4-cycle dispatch entirely.
pub const ENV_DISABLE_WCOJ_4CYCLE: &str = "XLOG_DISABLE_WCOJ_4CYCLE";

/// Reads a boolean opt-in flag from the environment variable `name`.
///
/// `1` or a case-insensitive `true` (surrounding whitespace ignored) counts
/// as set; everything else, including an unset variable, counts as unset.
fn wcoj_4cycle_env_flag(name: &str) -> bool {
    std::env::var(name)
        .map(|raw| {
            let v = raw.trim();
            v == "1" || v.eq_ignore_ascii_case("true")
        })
        .unwrap_or(false)
}

/// Reports whether the 4-cycle dispatch is force-enabled.
///
/// `config_override` wins when present; otherwise [`ENV_USE_WCOJ_4CYCLE`]
/// decides (off by default).
pub(super) fn wcoj_4cycle_gate_enabled(config_override: Option<bool>) -> bool {
    config_override.unwrap_or_else(|| wcoj_4cycle_env_flag(ENV_USE_WCOJ_4CYCLE))
}

/// Reports whether adaptive 4-cycle dispatch is enabled.
///
/// `config_override` wins when present; otherwise
/// [`ENV_USE_WCOJ_4CYCLE_ADAPTIVE`] decides (off by default).
pub(super) fn wcoj_4cycle_adaptive_enabled(config_override: Option<bool>) -> bool {
    config_override.unwrap_or_else(|| wcoj_4cycle_env_flag(ENV_USE_WCOJ_4CYCLE_ADAPTIVE))
}

/// Reports whether the 4-cycle dispatch is disabled.
///
/// `config_override` wins when present; otherwise
/// [`ENV_DISABLE_WCOJ_4CYCLE`] decides (not disabled by default).
pub(super) fn wcoj_4cycle_disabled(config_override: Option<bool>) -> bool {
    config_override.unwrap_or_else(|| wcoj_4cycle_env_flag(ENV_DISABLE_WCOJ_4CYCLE))
}
/// How a WCOJ dispatch attempt was authorised.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DispatchMode {
/// The force gate is on; the cost model is not consulted.
Force,
/// The adaptive gate is on; the cost model decides whether to dispatch.
CostModel,
}
/// Result of recognising a two-scan `RirNode::ChainJoin` (see `match_chain_join`).
pub(super) struct ChainRirMatch {
/// Relation scanned by the left input.
pub rel_left: RelId,
/// Relation scanned by the right input.
pub rel_right: RelId,
/// Join-key column index on the left input (the matcher only accepts 0 or 1).
pub left_key: usize,
/// Join-key column index on the right input (the matcher only accepts 0 or 1).
pub right_key: usize,
/// Projection copied from the matched node's `output_columns`.
pub output_columns: Vec<ProjectExpr>,
}
/// Recognises a `ChainJoin` whose two inputs are plain scans.
///
/// Returns `None` when `body` is not a `ChainJoin`, when either key column
/// index is 2 or larger, or when either input is not a `Scan`.
pub(super) fn match_chain_join(body: &RirNode) -> Option<ChainRirMatch> {
    match body {
        RirNode::ChainJoin {
            left,
            right,
            left_key,
            right_key,
            output_columns,
            ..
        } => {
            let (lk, rk) = (*left_key, *right_key);
            // Only key columns 0 and 1 are accepted.
            if lk > 1 || rk > 1 {
                return None;
            }
            let rel_left = scan_rel(left)?;
            let rel_right = scan_rel(right)?;
            Some(ChainRirMatch {
                rel_left,
                rel_right,
                left_key: lk,
                right_key: rk,
                output_columns: output_columns.clone(),
            })
        }
        _ => None,
    }
}
/// Relations of a recognised canonical triangle join (see `match_multiway_triangle`).
pub(super) struct TriangleRirMatch {
/// Relation scanned by input 0 (binds the first and second variable).
pub rel_xy: RelId,
/// Relation scanned by input 1 (binds the second and third variable).
pub rel_yz: RelId,
/// Relation scanned by input 2 (binds the first and third variable).
pub rel_xz: RelId,
}
/// Recognises a three-input `MultiWayJoin` in canonical triangle form.
///
/// The node must have exactly three inputs, slot variables and output columns
/// that pass the canonical-triangle checks, and every input must be a `Scan`.
/// Anything else yields `None`.
pub(super) fn match_multiway_triangle(body: &RirNode) -> Option<TriangleRirMatch> {
    let RirNode::MultiWayJoin {
        inputs,
        slot_vars,
        output_columns,
        ..
    } = body
    else {
        return None;
    };
    let canonical = inputs.len() == 3
        && slot_vars_match_canonical_triangle(slot_vars)
        && output_columns_match_canonical_triangle(output_columns);
    if !canonical {
        return None;
    }
    Some(TriangleRirMatch {
        rel_xy: scan_rel(&inputs[0])?,
        rel_yz: scan_rel(&inputs[1])?,
        rel_xz: scan_rel(&inputs[2])?,
    })
}
/// Checks that `slot_vars` has the shape `[(a, b), (b, c), (a, c)]`.
///
/// Exactly three slots of two bound variables each are required, and `a`,
/// `b`, `c` must be pairwise distinct. Any unbound (`None`) entry fails.
fn slot_vars_match_canonical_triangle(slot_vars: &[Vec<Option<u32>>]) -> bool {
    let [first, second, third] = slot_vars else {
        return false;
    };
    let (
        [Some(a), Some(b)],
        [Some(b_again), Some(c)],
        [Some(a_again), Some(c_again)],
    ) = (first.as_slice(), second.as_slice(), third.as_slice())
    else {
        return false;
    };
    let distinct = a != b && c != a && c != b;
    let chained = b_again == b && a_again == a && c_again == c;
    distinct && chained
}
/// Checks that the projection is `Column(0)`, `Column(1)` or `Column(2)`,
/// then `Column(3)` — exactly three entries.
fn output_columns_match_canonical_triangle(cols: &[ProjectExpr]) -> bool {
    matches!(
        cols,
        [
            ProjectExpr::Column(0),
            ProjectExpr::Column(1 | 2),
            ProjectExpr::Column(3),
        ]
    )
}
/// Relations of a recognised canonical 4-cycle join (see `match_multiway_4cycle`).
pub(super) struct FourCycleRirMatch {
/// Relation scanned by input 0.
pub rel_e1: RelId,
/// Relation scanned by input 1.
pub rel_e2: RelId,
/// Relation scanned by input 2.
pub rel_e3: RelId,
/// Relation scanned by input 3.
pub rel_e4: RelId,
}
/// Recognises a four-input `MultiWayJoin` in canonical 4-cycle form.
///
/// The node must have exactly four inputs, slot variables and output columns
/// that pass the canonical 4-cycle checks, and every input must be a `Scan`.
/// Anything else yields `None`.
pub(super) fn match_multiway_4cycle(body: &RirNode) -> Option<FourCycleRirMatch> {
    let RirNode::MultiWayJoin {
        inputs,
        slot_vars,
        output_columns,
        ..
    } = body
    else {
        return None;
    };
    let canonical = inputs.len() == 4
        && slot_vars_match_canonical_4cycle(slot_vars)
        && output_columns_match_canonical_4cycle(output_columns);
    if !canonical {
        return None;
    }
    Some(FourCycleRirMatch {
        rel_e1: scan_rel(&inputs[0])?,
        rel_e2: scan_rel(&inputs[1])?,
        rel_e3: scan_rel(&inputs[2])?,
        rel_e4: scan_rel(&inputs[3])?,
    })
}
/// Checks that `slot_vars` has the shape `[(a, b), (b, c), (c, d), (d, a)]`.
///
/// Exactly four slots of two bound variables each are required, and `a`,
/// `b`, `c`, `d` must be pairwise distinct. Any unbound (`None`) entry fails.
fn slot_vars_match_canonical_4cycle(slot_vars: &[Vec<Option<u32>>]) -> bool {
    let [e1, e2, e3, e4] = slot_vars else {
        return false;
    };
    let (
        [Some(a), Some(b)],
        [Some(b_again), Some(c)],
        [Some(c_again), Some(d)],
        [Some(d_again), Some(a_again)],
    ) = (e1.as_slice(), e2.as_slice(), e3.as_slice(), e4.as_slice())
    else {
        return false;
    };
    let distinct = a != b && c != a && c != b && d != a && d != b && d != c;
    let chained = b_again == b && c_again == c && d_again == d && a_again == a;
    distinct && chained
}
/// Checks that the projection is one of the two accepted 4-column layouts:
/// columns `0, 1, 3, 5` or columns `5, 0, 1, 3`.
fn output_columns_match_canonical_4cycle(cols: &[ProjectExpr]) -> bool {
    matches!(
        cols,
        [
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
            ProjectExpr::Column(5),
        ] | [
            ProjectExpr::Column(5),
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
        ]
    )
}
/// Returns the relation id when `node` is a plain `Scan`, otherwise `None`.
fn scan_rel(node: &RirNode) -> Option<RelId> {
    if let RirNode::Scan { rel } = node {
        Some(*rel)
    } else {
        None
    }
}
/// Key width class used to pick between the u32 and u64 WCOJ kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum WcojKeyWidth {
/// Columns typed `U32` or `Symbol`.
FourByte,
/// Columns typed `U64`.
EightByte,
}
/// Classifies a two-column buffer by key width.
///
/// Returns `None` unless the buffer has arity 2, both column types are
/// known, both map to a supported width, and the two widths agree.
fn classify_two_col_wcoj_width(buf: &CudaBuffer) -> Option<WcojKeyWidth> {
    if buf.arity() != 2 {
        return None;
    }
    let first = scalar_wcoj_width(buf.schema.column_type(0)?)?;
    let second = scalar_wcoj_width(buf.schema.column_type(1)?)?;
    (first == second).then_some(first)
}
/// Maps a scalar type to its WCOJ key width: `U32`/`Symbol` are four-byte,
/// `U64` is eight-byte, and every other type is unsupported (`None`).
fn scalar_wcoj_width(ty: xlog_core::ScalarType) -> Option<WcojKeyWidth> {
    if matches!(
        ty,
        xlog_core::ScalarType::U32 | xlog_core::ScalarType::Symbol
    ) {
        Some(WcojKeyWidth::FourByte)
    } else if matches!(ty, xlog_core::ScalarType::U64) {
        Some(WcojKeyWidth::EightByte)
    } else {
        None
    }
}
/// Picks the relation pair and key columns to attribute join feedback to.
///
/// Returns `(left_rel, right_rel, left_keys, right_keys)`:
/// - fewer than two relations: `None`;
/// - no variable order: slots 0 and 1, keyed on column 1 / column 0;
/// - three relations: the pair depends on `leader_idx` (0, 1 or 2);
/// - four relations: the leader and the next slot (wrapping), keyed on
///   column 1 / column 0;
/// - any other arity or an out-of-range leader: `None`.
fn feedback_pair_from_var_order(
    slot_rels: &[RelId],
    var_order: Option<&VariableOrder>,
) -> Option<(RelId, RelId, Vec<usize>, Vec<usize>)> {
    if slot_rels.len() < 2 {
        return None;
    }
    // (left slot, right slot, key column on the right side); the left key is always column 1.
    let (left_slot, right_slot, right_key) = match var_order {
        None => (0, 1, 0),
        Some(vo) => match (slot_rels.len(), vo.leader_idx as usize) {
            (3, 0) => (0, 1, 0),
            (3, 1) => (1, 2, 1),
            (3, 2) => (2, 1, 1),
            (4, leader) if leader < 4 => (leader, (leader + 1) % 4, 0),
            _ => return None,
        },
    };
    Some((
        slot_rels[left_slot],
        slot_rels[right_slot],
        vec![1],
        vec![right_key],
    ))
}
/// Extracts the column indices from a list of `ProjectExpr::Column` entries.
///
/// # Errors
/// Returns a `Kernel` error at the first entry that is not a plain column.
fn perm_indices_from_kernel_output_cols(cols: &[ProjectExpr]) -> Result<Vec<usize>> {
    cols.iter()
        .map(|col| match col {
            ProjectExpr::Column(idx) => Ok(*idx),
            other => Err(xlog_core::XlogError::Kernel(format!(
                "perm_indices_from_kernel_output_cols: \
                 W2.1 kernel_output_cols must be ProjectExpr::Column(_), got {:?}",
                other
            ))),
        })
        .collect()
}
/// Builds the three-column head schema for a triangle result.
///
/// Column types and sort labels are taken from `buf_xy` columns 0 and 1 and
/// `buf_yz` column 1; a missing sort label falls back to `col0`/`col1`/`col2`.
///
/// # Errors
/// Returns a `Kernel` error when a required column type is missing or when
/// attaching the sort labels fails.
fn build_triangle_head_schema(buf_xy: &CudaBuffer, buf_yz: &CudaBuffer) -> Result<Schema> {
    let Some(x_type) = buf_xy.schema.column_type(0) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_triangle_head_schema: e_xy.col0 type missing".into(),
        ));
    };
    let Some(y_type) = buf_xy.schema.column_type(1) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_triangle_head_schema: e_xy.col1 type missing".into(),
        ));
    };
    let Some(z_type) = buf_yz.schema.column_type(1) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_triangle_head_schema: e_yz.col1 type missing".into(),
        ));
    };

    let x_label = buf_xy.schema.column_sort_label(0).unwrap_or("col0").to_string();
    let y_label = buf_xy.schema.column_sort_label(1).unwrap_or("col1").to_string();
    let z_label = buf_yz.schema.column_sort_label(1).unwrap_or("col2").to_string();

    let columns = vec![
        ("col0".to_string(), x_type),
        ("col1".to_string(), y_type),
        ("col2".to_string(), z_type),
    ];
    Schema::new(columns)
        .with_sort_labels(vec![x_label, y_label, z_label])
        .map_err(xlog_core::XlogError::Kernel)
}
/// Builds the four-column head schema for a 4-cycle result.
///
/// Column types and sort labels are taken from `buf_e1` columns 0 and 1,
/// `buf_e2` column 1 and `buf_e3` column 1; a missing sort label falls back
/// to `col0`..`col3`.
///
/// # Errors
/// Returns a `Kernel` error when a required column type is missing or when
/// attaching the sort labels fails.
fn build_4cycle_head_schema(
    buf_e1: &CudaBuffer,
    buf_e2: &CudaBuffer,
    buf_e3: &CudaBuffer,
) -> Result<Schema> {
    let Some(w_type) = buf_e1.schema.column_type(0) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_4cycle_head_schema: e_wx.col0 type missing".into(),
        ));
    };
    let Some(x_type) = buf_e1.schema.column_type(1) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_4cycle_head_schema: e_wx.col1 type missing".into(),
        ));
    };
    let Some(y_type) = buf_e2.schema.column_type(1) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_4cycle_head_schema: e_xy.col1 type missing".into(),
        ));
    };
    let Some(z_type) = buf_e3.schema.column_type(1) else {
        return Err(xlog_core::XlogError::Kernel(
            "build_4cycle_head_schema: e_yz.col1 type missing".into(),
        ));
    };
    // Compile-time check that the column type is `ScalarType`.
    let _: ScalarType = w_type;

    let w_label = buf_e1.schema.column_sort_label(0).unwrap_or("col0").to_string();
    let x_label = buf_e1.schema.column_sort_label(1).unwrap_or("col1").to_string();
    let y_label = buf_e2.schema.column_sort_label(1).unwrap_or("col2").to_string();
    let z_label = buf_e3.schema.column_sort_label(1).unwrap_or("col3").to_string();

    let columns = vec![
        ("col0".to_string(), w_type),
        ("col1".to_string(), x_type),
        ("col2".to_string(), y_type),
        ("col3".to_string(), z_type),
    ];
    Schema::new(columns)
        .with_sort_labels(vec![w_label, x_label, y_label, z_label])
        .map_err(xlog_core::XlogError::Kernel)
}
impl Executor {
/// Attempts the WCOJ triangle dispatch for a compiled rule by delegating to
/// the body-level entry point with the rule's body.
pub(super) fn try_dispatch_wcoj_triangle(
    &mut self,
    rule: &CompiledRule,
) -> Result<Option<CudaBuffer>> {
    let body = &rule.body;
    self.try_dispatch_wcoj_triangle_on_body(body)
}
/// Returns the buffer's cached row count widened to `u64`, or `None` when no
/// count is cached.
fn wcoj_output_rows(buf: &CudaBuffer) -> Option<u64> {
    let rows = buf.cached_row_count()?;
    Some(u64::from(rows))
}
/// Records the observed output size of a WCOJ dispatch as join feedback.
///
/// Nothing is recorded when fewer than two relations are given, the output
/// row count is unknown, no feedback pair can be derived from the variable
/// order, or either relation lacks a positive cardinality in the statistics.
/// The recorded input size is the saturating product of both cardinalities.
fn record_wcoj_feedback(
    &mut self,
    slot_rels: &[RelId],
    var_order: Option<&VariableOrder>,
    output_rows: Option<u64>,
) {
    if slot_rels.len() < 2 {
        return;
    }
    let Some(out_rows) = output_rows else {
        return;
    };
    let Some((rel_a, rel_b, left_keys, right_keys)) =
        feedback_pair_from_var_order(slot_rels, var_order)
    else {
        return;
    };

    // Only strictly positive cardinalities are usable.
    let card_a = self
        .stats
        .get_relation_stats(rel_a)
        .map(|s| s.cardinality)
        .and_then(|c| (c > 0).then_some(c));
    let card_b = self
        .stats
        .get_relation_stats(rel_b)
        .map(|s| s.cardinality)
        .and_then(|c| (c > 0).then_some(c));

    match (card_a, card_b) {
        (Some(a), Some(b)) => {
            let input_rows = a.saturating_mul(b);
            self.stats
                .record_join_result(rel_a, rel_b, left_keys, right_keys, input_rows, out_rows);
        }
        _ => {}
    }
}
/// Attempts to evaluate `body` with the WCOJ triangle kernels.
///
/// Returns `Ok(Some(buffer))` when the dispatch ran, and `Ok(None)` whenever
/// the path does not apply or cannot run: dispatch disabled or not enabled,
/// `body` not a canonical triangle, a relation name or stored buffer missing,
/// unsupported or mismatched key widths, no device runtime or stream, the
/// cost model declining, or the pipeline itself returning an error.
pub(super) fn try_dispatch_wcoj_triangle_on_body(
&mut self,
body: &RirNode,
) -> Result<Option<CudaBuffer>> {
#[cfg(feature = "wcoj-phase-timing")]
let wall_start = Instant::now();
// Hard off-switch from the configuration.
if self.config.wcoj_triangle_dispatch_disabled.unwrap_or(false) {
return Ok(None);
}
// Force gate first (config override, else environment); otherwise fall
// back to the adaptive gate unless forcing was explicitly set to false.
let force_override = self.config.wcoj_triangle_dispatch;
let force_on = wcoj_gate_enabled(force_override);
let mode = if force_on {
DispatchMode::Force
} else {
let force_explicit_off = matches!(force_override, Some(false));
if force_explicit_off {
return Ok(None);
}
let adaptive_override = self.config.wcoj_triangle_dispatch_adaptive;
if wcoj_adaptive_enabled(adaptive_override) {
DispatchMode::CostModel
} else {
return Ok(None);
}
};
let Some(matched) = match_multiway_triangle(body) else {
return Ok(None);
};
// Resolve relation ids to names, then names to stored buffers.
let name_xy = match self.get_rel_name(matched.rel_xy) {
Some(s) => s.to_string(),
None => return Ok(None),
};
let name_yz = match self.get_rel_name(matched.rel_yz) {
Some(s) => s.to_string(),
None => return Ok(None),
};
let name_xz = match self.get_rel_name(matched.rel_xz) {
Some(s) => s.to_string(),
None => return Ok(None),
};
let buf_xy = match self.store.get(&name_xy) {
Some(b) => b,
None => return Ok(None),
};
let buf_yz = match self.store.get(&name_yz) {
Some(b) => b,
None => return Ok(None),
};
let buf_xz = match self.store.get(&name_xz) {
Some(b) => b,
None => return Ok(None),
};
// All three inputs must share one supported key width.
let width = match (
classify_two_col_wcoj_width(buf_xy),
classify_two_col_wcoj_width(buf_yz),
classify_two_col_wcoj_width(buf_xz),
) {
(Some(a), Some(b), Some(c)) if a == b && b == c => a,
_ => return Ok(None),
};
if self.provider.memory().runtime().is_none() {
return Ok(None);
}
let launch_stream = match self.wcoj_dispatch_stream_or_init() {
Some(s) => s,
None => return Ok(None),
};
#[cfg(feature = "wcoj-phase-timing")]
let mut classifier_ms: f32 = 0.0;
// In cost-model mode the model has the final say on dispatching.
if mode == DispatchMode::CostModel {
#[cfg(feature = "wcoj-phase-timing")]
let cls_start = Instant::now();
let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
let slot_rels = [matched.rel_xy, matched.rel_yz, matched.rel_xz];
let ctx = super::wcoj_cost_model::WcojDispatchCtx {
stats: &self.stats,
launch_stream,
width,
slot_rels: &slot_rels,
};
let dispatch = model.should_dispatch_triangle(&ctx);
#[cfg(feature = "wcoj-phase-timing")]
{
classifier_ms = cls_start.elapsed().as_secs_f64() as f32 * 1000.0;
}
if !dispatch {
return Ok(None);
}
}
// Optional variable order carried by the join node.
let var_order_opt: Option<&VariableOrder> = match body {
RirNode::MultiWayJoin { var_order, .. } => var_order.as_ref(),
_ => None,
};
#[cfg(feature = "wcoj-phase-timing")]
let mut layout_times: [f32; 3] = [0.0; 3];
let dispatch_result = self.run_wcoj_triangle_pipeline(
buf_xy,
buf_yz,
buf_xz,
launch_stream,
width,
var_order_opt,
#[cfg(feature = "wcoj-phase-timing")]
&mut layout_times,
);
match dispatch_result {
Ok(buf) => {
// Feed the observed output size back into the statistics and
// count the successful dispatch.
let output_rows = Self::wcoj_output_rows(&buf);
let slot_rels = [matched.rel_xy, matched.rel_yz, matched.rel_xz];
self.record_wcoj_feedback(&slot_rels, var_order_opt, output_rows);
self.wcoj_triangle_dispatch_count += 1;
#[cfg(feature = "wcoj-phase-timing")]
{
let triangle_timing = self
.provider
.take_wcoj_triangle_phase_timing()
.unwrap_or_default();
let wall_ms = wall_start.elapsed().as_secs_f64() as f32 * 1000.0;
let timing = super::wcoj_phase_timing::WcojDispatchPhaseTiming::new(
classifier_ms,
layout_times[0],
layout_times[1],
layout_times[2],
triangle_timing,
wall_ms,
);
// A poisoned lock simply skips storing the timing.
if let Ok(mut g) = self.last_wcoj_phase_timing.lock() {
*g = Some(timing);
}
}
Ok(Some(buf))
}
// The pipeline error is discarded and reported as "not dispatched".
// NOTE(review): presumably so the caller can fall back to another
// join path — confirm against the caller.
Err(_) => Ok(None),
}
}
/// Runs the layout and triangle kernels for three edge buffers.
///
/// When `var_order` is present the work is delegated to
/// `run_wcoj_triangle_pipeline_w21`. Otherwise each buffer is laid out with
/// the u32 or u64 layout routine matching `width` and the corresponding
/// triangle kernel is launched on `launch_stream`.
///
/// With the `wcoj-phase-timing` feature, `layout_times_ms` receives the
/// wall-clock milliseconds of the three layout calls (xy, yz, xz). It is not
/// written on the `var_order` path, nor for a layout call that fails.
///
/// # Errors
/// Propagates any error from the layout or kernel calls.
#[allow(clippy::too_many_arguments)]
fn run_wcoj_triangle_pipeline(
&self,
buf_xy: &CudaBuffer,
buf_yz: &CudaBuffer,
buf_xz: &CudaBuffer,
launch_stream: StreamId,
width: WcojKeyWidth,
var_order: Option<&VariableOrder>,
#[cfg(feature = "wcoj-phase-timing")] layout_times_ms: &mut [f32; 3],
) -> Result<CudaBuffer> {
if let Some(vo) = var_order {
return self.run_wcoj_triangle_pipeline_w21(
buf_xy,
buf_yz,
buf_xz,
launch_stream,
width,
vo,
);
}
// Timing helper: runs `f`, and on success stores its elapsed
// milliseconds into `layout_times_ms[slot]`.
#[cfg(feature = "wcoj-phase-timing")]
let mut time_layout =
|f: &dyn Fn() -> Result<CudaBuffer>, slot: usize| -> Result<CudaBuffer> {
let s = Instant::now();
let r = f()?;
layout_times_ms[slot] = s.elapsed().as_secs_f64() as f32 * 1000.0;
Ok(r)
};
match width {
WcojKeyWidth::FourByte => {
// Timed variant of the three u32 layout calls.
#[cfg(feature = "wcoj-phase-timing")]
let (layout_xy, layout_yz, layout_xz) = {
let xy = time_layout(
&|| {
self.provider
.wcoj_layout_u32_recorded(buf_xy, launch_stream)
},
0,
)?;
let yz = time_layout(
&|| {
self.provider
.wcoj_layout_u32_recorded(buf_yz, launch_stream)
},
1,
)?;
let xz = time_layout(
&|| {
self.provider
.wcoj_layout_u32_recorded(buf_xz, launch_stream)
},
2,
)?;
(xy, yz, xz)
};
// Untimed variant of the same three calls.
#[cfg(not(feature = "wcoj-phase-timing"))]
let layout_xy = self
.provider
.wcoj_layout_u32_recorded(buf_xy, launch_stream)?;
#[cfg(not(feature = "wcoj-phase-timing"))]
let layout_yz = self
.provider
.wcoj_layout_u32_recorded(buf_yz, launch_stream)?;
#[cfg(not(feature = "wcoj-phase-timing"))]
let layout_xz = self
.provider
.wcoj_layout_u32_recorded(buf_xz, launch_stream)?;
let out = self.provider.wcoj_triangle_hg_u32_recorded(
&layout_xy,
&layout_yz,
&layout_xz,
wcoj_block_work_unit(),
launch_stream,
)?;
// Only the u32 path records an "hg" dispatch on the provider.
self.provider.record_wcoj_triangle_hg_dispatch();
Ok(out)
}
WcojKeyWidth::EightByte => {
// Timed variant of the three u64 layout calls.
#[cfg(feature = "wcoj-phase-timing")]
let (layout_xy, layout_yz, layout_xz) = {
let xy = time_layout(
&|| {
self.provider
.wcoj_layout_u64_recorded(buf_xy, launch_stream)
},
0,
)?;
let yz = time_layout(
&|| {
self.provider
.wcoj_layout_u64_recorded(buf_yz, launch_stream)
},
1,
)?;
let xz = time_layout(
&|| {
self.provider
.wcoj_layout_u64_recorded(buf_xz, launch_stream)
},
2,
)?;
(xy, yz, xz)
};
// Untimed variant of the same three calls.
#[cfg(not(feature = "wcoj-phase-timing"))]
let layout_xy = self
.provider
.wcoj_layout_u64_recorded(buf_xy, launch_stream)?;
#[cfg(not(feature = "wcoj-phase-timing"))]
let layout_yz = self
.provider
.wcoj_layout_u64_recorded(buf_yz, launch_stream)?;
#[cfg(not(feature = "wcoj-phase-timing"))]
let layout_xz = self
.provider
.wcoj_layout_u64_recorded(buf_xz, launch_stream)?;
self.provider.wcoj_triangle_u64_recorded(
&layout_xy,
&layout_yz,
&layout_xz,
launch_stream,
)
}
}
}
/// Runs the triangle kernels with inputs rearranged per `var_order`.
///
/// The three canonical buffers are turned into kernel slot inputs by
/// `prepare_leader_inputs`, laid out and joined with the kernel matching
/// `width`, and the kernel output is then projected through the permutation
/// in `var_order.kernel_output_cols` using a head schema built from
/// `buf_xy` and `buf_yz`.
///
/// # Errors
/// Propagates errors from input preparation, schema construction, the
/// permutation extraction, and every provider call; also fails when input
/// preparation does not return exactly three slots.
fn run_wcoj_triangle_pipeline_w21(
    &self,
    buf_xy: &CudaBuffer,
    buf_yz: &CudaBuffer,
    buf_xz: &CudaBuffer,
    launch_stream: StreamId,
    width: WcojKeyWidth,
    var_order: &VariableOrder,
) -> Result<CudaBuffer> {
    let slot_inputs =
        self.prepare_leader_inputs(&[buf_xy, buf_yz, buf_xz], var_order, launch_stream)?;
    let [slot0, slot1, slot2] = slot_inputs.as_slice() else {
        return Err(xlog_core::XlogError::Kernel(
            "run_wcoj_triangle_pipeline_w21: prepare_leader_inputs must return 3 slots"
                .to_string(),
        ));
    };
    let head_schema = build_triangle_head_schema(buf_xy, buf_yz)?;
    let perm = perm_indices_from_kernel_output_cols(&var_order.kernel_output_cols)?;

    let provider = &self.provider;
    let kernel_out: CudaBuffer = match width {
        WcojKeyWidth::FourByte => {
            let layout0 = provider.wcoj_layout_u32_recorded(slot0, launch_stream)?;
            let layout1 = provider.wcoj_layout_u32_recorded(slot1, launch_stream)?;
            let layout2 = provider.wcoj_layout_u32_recorded(slot2, launch_stream)?;
            let joined = provider.wcoj_triangle_hg_u32_recorded(
                &layout0,
                &layout1,
                &layout2,
                wcoj_block_work_unit(),
                launch_stream,
            )?;
            provider.record_wcoj_triangle_hg_dispatch();
            joined
        }
        WcojKeyWidth::EightByte => {
            let layout0 = provider.wcoj_layout_u64_recorded(slot0, launch_stream)?;
            let layout1 = provider.wcoj_layout_u64_recorded(slot1, launch_stream)?;
            let layout2 = provider.wcoj_layout_u64_recorded(slot2, launch_stream)?;
            provider.wcoj_triangle_u64_recorded(&layout0, &layout1, &layout2, launch_stream)?
        }
    };

    provider.wcoj_project_output_columns_recorded(&kernel_out, &perm, head_schema, launch_stream)
}
/// Number of successful WCOJ triangle dispatches so far.
pub fn wcoj_triangle_dispatch_count(&self) -> u64 {
self.wcoj_triangle_dispatch_count
}
/// Number of successful WCOJ 4-cycle dispatches so far.
pub fn wcoj_4cycle_dispatch_count(&self) -> u64 {
self.wcoj_4cycle_dispatch_count
}
/// Number of successful chain-join dispatches so far.
pub fn w63_chain_dispatch_count(&self) -> u64 {
self.w63_chain_dispatch_count
}
/// Number of successful chain-join dispatches that used the nested-loop join.
pub fn nested_loop_dispatch_count(&self) -> u64 {
self.nested_loop_dispatch_count
}
/// Attempts to evaluate a two-scan chain join with a specialised join kernel.
///
/// Strategy selection:
/// - both inputs four-byte and both sorted ascending on their key: sort-merge
///   join (the bounded variant when the row-count product exceeds
///   `NESTED_LOOP_TOTAL_THRESHOLD` or overflows);
/// - both inputs four-byte, not both sorted, product within the threshold:
///   nested-loop join;
/// - everything else: hash join.
///
/// Returns `Ok(None)` when the path is disabled, `body` does not match, a
/// relation name or buffer is missing, or the join or projection fails. On
/// success the join result is recorded in the statistics and the dispatch
/// counters are incremented.
///
/// # Errors
/// Only a failure to read a device row count is propagated.
pub(super) fn try_dispatch_w63_chain_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    if !w63_chain_enabled() {
        return Ok(None);
    }
    let Some(matched) = match_chain_join(body) else {
        return Ok(None);
    };
    let Some(name_left) = self.get_rel_name(matched.rel_left).map(|s| s.to_string()) else {
        return Ok(None);
    };
    let Some(name_right) = self.get_rel_name(matched.rel_right).map(|s| s.to_string()) else {
        return Ok(None);
    };
    let Some(left) = self.store.get(&name_left) else {
        return Ok(None);
    };
    let Some(right) = self.store.get(&name_right) else {
        return Ok(None);
    };

    let num_left = self.provider.device_row_count(left)? as u64;
    let num_right = self.provider.device_row_count(right)? as u64;
    // An overflowing product counts as "over the threshold".
    let in_threshold = matches!(
        num_left.checked_mul(num_right),
        Some(total) if total <= NESTED_LOOP_TOTAL_THRESHOLD
    );

    let is_four_byte = |buf: &CudaBuffer| {
        classify_two_col_wcoj_width(buf) == Some(WcojKeyWidth::FourByte)
    };
    let four_byte = is_four_byte(left) && is_four_byte(right);

    // Sortedness is only probed for four-byte inputs; a failed probe counts
    // as "not sorted".
    let both_sorted = four_byte && {
        let left_sorted = self
            .provider
            .is_sorted_ascending_u32(left, matched.left_key)
            .unwrap_or(false);
        let right_sorted = self
            .provider
            .is_sorted_ascending_u32(right, matched.right_key)
            .unwrap_or(false);
        left_sorted && right_sorted
    };
    let used_nested_loop = four_byte && !both_sorted && in_threshold;

    let join_result = if both_sorted {
        if in_threshold {
            self.provider.sort_merge_join_v2_inner_u32_1key(
                left,
                right,
                matched.left_key,
                matched.right_key,
            )
        } else {
            let capacity = usize::try_from(num_left.min(num_right)).unwrap_or(usize::MAX);
            self.provider.sort_merge_join_v2_inner_u32_1key_bounded(
                left,
                right,
                matched.left_key,
                matched.right_key,
                capacity,
            )
        }
    } else if used_nested_loop {
        self.provider.nested_loop_join_v2_inner_u32_1key(
            left,
            right,
            matched.left_key,
            matched.right_key,
        )
    } else {
        self.provider.hash_join_v2(
            left,
            right,
            &[matched.left_key],
            &[matched.right_key],
            CudaJoinType::Inner,
        )
    };

    let Ok(joined) = join_result else {
        return Ok(None);
    };
    let Ok(projected) = self.execute_project(&joined, &matched.output_columns) else {
        return Ok(None);
    };

    self.stats.record_join_result(
        matched.rel_left,
        matched.rel_right,
        vec![matched.left_key],
        vec![matched.right_key],
        num_left.saturating_mul(num_right),
        joined.num_rows(),
    );
    if used_nested_loop {
        self.nested_loop_dispatch_count += 1;
    }
    self.w63_chain_dispatch_count += 1;
    Ok(Some(projected))
}
/// Attempts the WCOJ 4-cycle dispatch for a compiled rule by delegating to
/// the body-level entry point with the rule's body.
pub(super) fn try_dispatch_wcoj_4cycle(
    &mut self,
    rule: &CompiledRule,
) -> Result<Option<CudaBuffer>> {
    let body = &rule.body;
    self.try_dispatch_wcoj_4cycle_on_body(body)
}
/// Attempts to evaluate `body` with the WCOJ 4-cycle kernels.
///
/// Returns `Ok(Some(buffer))` when the dispatch ran, and `Ok(None)` whenever
/// the path does not apply or cannot run: dispatch disabled or not enabled,
/// `body` not a canonical 4-cycle, a relation name or stored buffer missing,
/// unsupported or mismatched key widths, no device runtime or stream, the
/// cost model declining, or the pipeline itself returning an error.
pub(super) fn try_dispatch_wcoj_4cycle_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    if wcoj_4cycle_disabled(self.config.wcoj_4cycle_dispatch_disabled) {
        return Ok(None);
    }

    // Force gate first; an explicit `Some(false)` also rules out the
    // adaptive path.
    let force_override = self.config.wcoj_4cycle_dispatch;
    let mode = if wcoj_4cycle_gate_enabled(force_override) {
        DispatchMode::Force
    } else if force_override == Some(false)
        || !wcoj_4cycle_adaptive_enabled(self.config.wcoj_4cycle_dispatch_adaptive)
    {
        return Ok(None);
    } else {
        DispatchMode::CostModel
    };

    let Some(matched) = match_multiway_4cycle(body) else {
        return Ok(None);
    };
    let slot_rels = [
        matched.rel_e1,
        matched.rel_e2,
        matched.rel_e3,
        matched.rel_e4,
    ];

    // Resolve relation ids to names, then names to stored buffers.
    let Some(name_e1) = self.get_rel_name(matched.rel_e1).map(|s| s.to_string()) else {
        return Ok(None);
    };
    let Some(name_e2) = self.get_rel_name(matched.rel_e2).map(|s| s.to_string()) else {
        return Ok(None);
    };
    let Some(name_e3) = self.get_rel_name(matched.rel_e3).map(|s| s.to_string()) else {
        return Ok(None);
    };
    let Some(name_e4) = self.get_rel_name(matched.rel_e4).map(|s| s.to_string()) else {
        return Ok(None);
    };
    let Some(buf_e1) = self.store.get(&name_e1) else {
        return Ok(None);
    };
    let Some(buf_e2) = self.store.get(&name_e2) else {
        return Ok(None);
    };
    let Some(buf_e3) = self.store.get(&name_e3) else {
        return Ok(None);
    };
    let Some(buf_e4) = self.store.get(&name_e4) else {
        return Ok(None);
    };

    // All four inputs must share one supported key width.
    let Some(width) = classify_two_col_wcoj_width(buf_e1) else {
        return Ok(None);
    };
    let same_width = [buf_e2, buf_e3, buf_e4]
        .iter()
        .all(|buf| classify_two_col_wcoj_width(buf) == Some(width));
    if !same_width {
        return Ok(None);
    }

    if self.provider.memory().runtime().is_none() {
        return Ok(None);
    }
    let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
        return Ok(None);
    };

    // In cost-model mode the model has the final say on dispatching.
    if mode == DispatchMode::CostModel {
        let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
        let ctx = super::wcoj_cost_model::WcojDispatchCtx {
            stats: &self.stats,
            launch_stream,
            width,
            slot_rels: &slot_rels,
        };
        if !model.should_dispatch_4cycle(&ctx) {
            return Ok(None);
        }
    }

    let var_order_opt: Option<&VariableOrder> = match body {
        RirNode::MultiWayJoin { var_order, .. } => var_order.as_ref(),
        _ => None,
    };

    // A pipeline error is reported as "not dispatched".
    let Ok(buf) = self.run_wcoj_4cycle_pipeline(
        buf_e1,
        buf_e2,
        buf_e3,
        buf_e4,
        launch_stream,
        width,
        var_order_opt,
    ) else {
        return Ok(None);
    };

    let output_rows = Self::wcoj_output_rows(&buf);
    self.record_wcoj_feedback(&slot_rels, var_order_opt, output_rows);
    self.wcoj_4cycle_dispatch_count += 1;
    Ok(Some(buf))
}
/// Runs the layout and 4-cycle kernels for four edge buffers.
///
/// When `var_order` is present the work is delegated to
/// `run_wcoj_4cycle_pipeline_w21`. Otherwise each buffer is laid out with the
/// u32 or u64 layout routine matching `width` and the corresponding 4-cycle
/// kernel is launched on `launch_stream`.
///
/// # Errors
/// Propagates any error from the layout or kernel calls.
#[allow(clippy::too_many_arguments)]
fn run_wcoj_4cycle_pipeline(
    &self,
    buf_e1: &CudaBuffer,
    buf_e2: &CudaBuffer,
    buf_e3: &CudaBuffer,
    buf_e4: &CudaBuffer,
    launch_stream: StreamId,
    width: WcojKeyWidth,
    var_order: Option<&VariableOrder>,
) -> Result<CudaBuffer> {
    let provider = &self.provider;
    match (var_order, width) {
        (Some(vo), _) => self.run_wcoj_4cycle_pipeline_w21(
            buf_e1,
            buf_e2,
            buf_e3,
            buf_e4,
            launch_stream,
            width,
            vo,
        ),
        (None, WcojKeyWidth::FourByte) => {
            let first = provider.wcoj_layout_u32_recorded(buf_e1, launch_stream)?;
            let second = provider.wcoj_layout_u32_recorded(buf_e2, launch_stream)?;
            let third = provider.wcoj_layout_u32_recorded(buf_e3, launch_stream)?;
            let fourth = provider.wcoj_layout_u32_recorded(buf_e4, launch_stream)?;
            provider.wcoj_4cycle_u32_recorded(&first, &second, &third, &fourth, launch_stream)
        }
        (None, WcojKeyWidth::EightByte) => {
            let first = provider.wcoj_layout_u64_recorded(buf_e1, launch_stream)?;
            let second = provider.wcoj_layout_u64_recorded(buf_e2, launch_stream)?;
            let third = provider.wcoj_layout_u64_recorded(buf_e3, launch_stream)?;
            let fourth = provider.wcoj_layout_u64_recorded(buf_e4, launch_stream)?;
            provider.wcoj_4cycle_u64_recorded(&first, &second, &third, &fourth, launch_stream)
        }
    }
}
/// Runs the 4-cycle kernels with inputs rearranged per `var_order`.
///
/// The four canonical buffers are turned into kernel slot inputs by
/// `prepare_leader_inputs`, laid out and joined with the kernel matching
/// `width`, and the kernel output is then projected through the permutation
/// in `var_order.kernel_output_cols` using a head schema built from
/// `buf_e1`, `buf_e2` and `buf_e3`.
///
/// # Errors
/// Propagates errors from input preparation, schema construction, the
/// permutation extraction, and every provider call; also fails when input
/// preparation does not return exactly four slots.
#[allow(clippy::too_many_arguments)]
fn run_wcoj_4cycle_pipeline_w21(
    &self,
    buf_e1: &CudaBuffer,
    buf_e2: &CudaBuffer,
    buf_e3: &CudaBuffer,
    buf_e4: &CudaBuffer,
    launch_stream: StreamId,
    width: WcojKeyWidth,
    var_order: &VariableOrder,
) -> Result<CudaBuffer> {
    let slot_inputs = self.prepare_leader_inputs(
        &[buf_e1, buf_e2, buf_e3, buf_e4],
        var_order,
        launch_stream,
    )?;
    let [slot0, slot1, slot2, slot3] = slot_inputs.as_slice() else {
        return Err(xlog_core::XlogError::Kernel(
            "run_wcoj_4cycle_pipeline_w21: prepare_leader_inputs must return 4 slots"
                .to_string(),
        ));
    };
    let head_schema = build_4cycle_head_schema(buf_e1, buf_e2, buf_e3)?;
    let perm = perm_indices_from_kernel_output_cols(&var_order.kernel_output_cols)?;

    let provider = &self.provider;
    let kernel_out: CudaBuffer = match width {
        WcojKeyWidth::FourByte => {
            let layout0 = provider.wcoj_layout_u32_recorded(slot0, launch_stream)?;
            let layout1 = provider.wcoj_layout_u32_recorded(slot1, launch_stream)?;
            let layout2 = provider.wcoj_layout_u32_recorded(slot2, launch_stream)?;
            let layout3 = provider.wcoj_layout_u32_recorded(slot3, launch_stream)?;
            provider.wcoj_4cycle_u32_recorded(
                &layout0,
                &layout1,
                &layout2,
                &layout3,
                launch_stream,
            )?
        }
        WcojKeyWidth::EightByte => {
            let layout0 = provider.wcoj_layout_u64_recorded(slot0, launch_stream)?;
            let layout1 = provider.wcoj_layout_u64_recorded(slot1, launch_stream)?;
            let layout2 = provider.wcoj_layout_u64_recorded(slot2, launch_stream)?;
            let layout3 = provider.wcoj_layout_u64_recorded(slot3, launch_stream)?;
            provider.wcoj_4cycle_u64_recorded(
                &layout0,
                &layout1,
                &layout2,
                &layout3,
                launch_stream,
            )?
        }
    };

    provider.wcoj_project_output_columns_recorded(&kernel_out, &perm, head_schema, launch_stream)
}
/// Builds the kernel slot inputs for a leader-driven variable order.
///
/// Slot 0 is a copy of `canonical[var_order.leader_idx]`. Each following slot
/// comes from one entry of `var_order.lookup_perms`: a column-swapped copy of
/// the referenced input when `swap_cols` is set, otherwise a plain copy.
///
/// # Errors
/// Returns a `Kernel` error when `canonical` does not hold 3 or 4 buffers,
/// `leader_idx` or any `input_idx` is out of range, the number of lookup
/// permutations is not `arity - 1`, or a 4-input order requests a column
/// swap. Provider errors while copying are propagated.
pub fn prepare_leader_inputs(
    &self,
    canonical: &[&CudaBuffer],
    var_order: &VariableOrder,
    launch_stream: StreamId,
) -> Result<Vec<CudaBuffer>> {
    let n = canonical.len();
    if n != 3 && n != 4 {
        return Err(xlog_core::XlogError::Kernel(format!(
            "prepare_leader_inputs: canonical inputs must be 3 (triangle) or 4 (4-cycle), got {n}"
        )));
    }
    let leader_idx = var_order.leader_idx as usize;
    if leader_idx >= n {
        return Err(xlog_core::XlogError::Kernel(format!(
            "prepare_leader_inputs: leader_idx {leader_idx} out of range for arity {n}"
        )));
    }
    let lookups = &var_order.lookup_perms;
    if lookups.len() != n - 1 {
        return Err(xlog_core::XlogError::Kernel(format!(
            "prepare_leader_inputs: lookup_perms.len() = {} must equal {} (arity - 1)",
            lookups.len(),
            n - 1
        )));
    }
    // Report the first lookup whose input index is out of range.
    let out_of_range = lookups
        .iter()
        .enumerate()
        .map(|(slot, lp)| (slot, lp.input_idx as usize))
        .find(|&(_, input_idx)| input_idx >= n);
    if let Some((slot, input_idx)) = out_of_range {
        return Err(xlog_core::XlogError::Kernel(format!(
            "prepare_leader_inputs: lookup_perms[{slot}].input_idx {input_idx} out of range for arity {n}"
        )));
    }
    if n == 4 && lookups.iter().any(|lp| lp.swap_cols) {
        return Err(xlog_core::XlogError::Kernel(
            "prepare_leader_inputs: 4-cycle does not support col-swaps".to_string(),
        ));
    }

    // Leader first, then one buffer per lookup; the first failure aborts.
    let leader = self.clone_buffer_via_swap(canonical[leader_idx], launch_stream);
    std::iter::once(leader)
        .chain(lookups.iter().map(|lp| {
            let src = canonical[lp.input_idx as usize];
            if lp.swap_cols {
                self.provider
                    .wcoj_project_2col_swap_recorded(src, launch_stream)
            } else {
                self.clone_buffer_via_swap(src, launch_stream)
            }
        }))
        .collect()
}
/// Produces a new buffer from `src` by applying the two-column swap
/// projection twice on `launch_stream`.
///
/// # Errors
/// Propagates an error from either swap call.
fn clone_buffer_via_swap(
    &self,
    src: &CudaBuffer,
    launch_stream: StreamId,
) -> Result<CudaBuffer> {
    self.provider
        .wcoj_project_2col_swap_recorded(src, launch_stream)
        .and_then(|swapped| {
            self.provider
                .wcoj_project_2col_swap_recorded(&swapped, launch_stream)
        })
}
/// Returns the stream used for WCOJ dispatches, acquiring one on first use.
///
/// Yields `None` when no stream is stored yet and either the provider has no
/// runtime or acquiring a stream from its pool fails.
pub fn wcoj_dispatch_stream_or_init(&self) -> Option<StreamId> {
    match self.wcoj_dispatch_stream.get() {
        Some(existing) => Some(*existing),
        None => {
            let runtime = self.provider.memory().runtime()?;
            let stream = runtime.stream_pool().acquire().ok()?;
            // A failed `set` is ignored; whatever is stored afterwards is returned.
            let _ = self.wcoj_dispatch_stream.set(stream);
            self.wcoj_dispatch_stream.get().copied()
        }
    }
}
}
impl Executor {
/// Current value of the 5-clique dispatch counter.
pub fn wcoj_clique5_dispatch_count(&self) -> u64 {
self.wcoj_clique5_dispatch_count
}
/// Current value of the 6-clique dispatch counter.
pub fn wcoj_clique6_dispatch_count(&self) -> u64 {
self.wcoj_clique6_dispatch_count
}
/// Current value of the 7-clique dispatch counter.
pub fn wcoj_clique7_dispatch_count(&self) -> u64 {
self.wcoj_clique7_dispatch_count
}
/// Current value of the 8-clique dispatch counter.
pub fn wcoj_clique8_dispatch_count(&self) -> u64 {
self.wcoj_clique8_dispatch_count
}
/// Current value of the k-clique histogram refresh counter.
pub fn kclique_histogram_refresh_count(&self) -> u64 {
self.kclique_histogram_refresh_count
}
/// Current value of the k-clique histogram refresh time accumulator
/// (nanoseconds, going by the field name).
pub fn kclique_histogram_refresh_nanos(&self) -> u128 {
self.kclique_histogram_refresh_nanos
}
/// Attempts the 5-clique WCOJ dispatch on the rule's body.
pub(super) fn try_dispatch_wcoj_clique5(
    &mut self,
    rule: &CompiledRule,
) -> Result<Option<CudaBuffer>> {
    let body = &rule.body;
    self.try_dispatch_wcoj_clique5_on_body(body)
}
/// Attempts the 6-clique WCOJ dispatch on the rule's body.
pub(super) fn try_dispatch_wcoj_clique6(
    &mut self,
    rule: &CompiledRule,
) -> Result<Option<CudaBuffer>> {
    let body = &rule.body;
    self.try_dispatch_wcoj_clique6_on_body(body)
}
/// Attempts the 7-clique WCOJ dispatch on the rule's body.
pub(super) fn try_dispatch_wcoj_clique7(
    &mut self,
    rule: &CompiledRule,
) -> Result<Option<CudaBuffer>> {
    let body = &rule.body;
    self.try_dispatch_wcoj_clique7_on_body(body)
}
/// Attempts the 8-clique WCOJ dispatch on the rule's body.
pub(super) fn try_dispatch_wcoj_clique8(
    &mut self,
    rule: &CompiledRule,
) -> Result<Option<CudaBuffer>> {
    let body = &rule.body;
    self.try_dispatch_wcoj_clique8_on_body(body)
}
/// Attempts the k-clique WCOJ dispatch on `body` with `k = 5`.
pub(super) fn try_dispatch_wcoj_clique5_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    const K: usize = 5;
    self.try_dispatch_wcoj_clique_k_on_body(body, K)
}
/// Attempts the k-clique WCOJ dispatch on `body` with `k = 6`.
pub(super) fn try_dispatch_wcoj_clique6_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    const K: usize = 6;
    self.try_dispatch_wcoj_clique_k_on_body(body, K)
}
/// Attempts the k-clique WCOJ dispatch on `body` with `k = 7`.
pub(super) fn try_dispatch_wcoj_clique7_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    const K: usize = 7;
    self.try_dispatch_wcoj_clique_k_on_body(body, K)
}
/// Attempts the k-clique WCOJ dispatch on `body` with `k = 8`.
pub(super) fn try_dispatch_wcoj_clique8_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    const K: usize = 8;
    self.try_dispatch_wcoj_clique_k_on_body(body, K)
}
fn try_dispatch_wcoj_clique_k_on_body(
&mut self,
body: &RirNode,
k: usize,
) -> Result<Option<CudaBuffer>> {
let expected_edges = k * (k - 1) / 2;
let RirNode::MultiWayJoin {
inputs,
plan,
var_order,
..
} = body
else {
return Ok(None);
};
if matches!(plan, Some(MultiwayPlan::PlannedHashRoute { .. })) {
return Ok(None);
}
if inputs.len() != expected_edges {
return Ok(None);
}
let kclique = match var_order.as_ref().and_then(|order| order.kclique.as_ref()) {
Some(plan) if usize::from(plan.k) == k => plan,
_ => return Ok(None),
};
let mut rel_ids: Vec<RelId> = Vec::with_capacity(expected_edges);
for input in inputs {
let RirNode::Scan { rel } = input else {
return Ok(None);
};
rel_ids.push(*rel);
}
let mut raw_bufs: Vec<&CudaBuffer> = Vec::with_capacity(expected_edges);
for rid in &rel_ids {
let name = match self.rel_names.get(rid) {
Some(n) => n.clone(),
None => return Ok(None),
};
match self.store.get(&name) {
Some(b) => raw_bufs.push(b),
None => return Ok(None),
}
}
let launch_stream = match self.wcoj_dispatch_stream_or_init() {
Some(s) => s,
None => return Ok(None),
};
let first_ty = match raw_bufs[0].schema.column_type(0) {
Some(t) => t,
None => return Ok(None),
};
let is_u64 = matches!(first_ty, xlog_core::ScalarType::U64);
let is_4byte = matches!(
first_ty,
xlog_core::ScalarType::U32 | xlog_core::ScalarType::Symbol
);
if !is_u64 && !is_4byte {
return Ok(None);
}
let Some(plan_params) = kclique_dispatch_params(kclique, k) else {
return Ok(None);
};
let head_schema = match build_kclique_head_schema(&raw_bufs, k) {
Some(schema) => schema,
None => return Ok(None),
};
let output_perm = match kclique_output_perm(kclique, k) {
Some(perm) => perm,
None => return Ok(None),
};
let mut laid_out: Vec<CudaBuffer> = Vec::with_capacity(expected_edges);
for (slot, &input_idx) in plan_params.edge_permutation.iter().enumerate() {
let src = raw_bufs[input_idx];
let swapped = if plan_params.swap_slots.contains(&slot) {
Some(
self.provider
.wcoj_project_2col_swap_recorded(src, launch_stream)?,
)
} else {
None
};
let oriented = swapped.as_ref().unwrap_or(src);
let res = if plan_params.required_sort_slots.contains(&slot) {
if is_u64 {
self.provider
.wcoj_layout_sort_u64_recorded(oriented, launch_stream)
} else {
self.provider
.wcoj_layout_sort_u32_recorded(oriented, launch_stream)
}
} else if is_u64 {
self.provider
.wcoj_layout_u64_recorded(oriented, launch_stream)
} else {
self.provider
.wcoj_layout_u32_recorded(oriented, launch_stream)
};
match res {
Ok(b) => laid_out.push(b),
Err(_) => return Ok(None),
}
}
let edge_refs: Vec<&CudaBuffer> = laid_out.iter().collect();
let result = match (k, is_u64) {
(5, false) => {
let arr: &[&CudaBuffer; 10] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique5_u32_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(5, true) => {
let arr: &[&CudaBuffer; 10] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique5_u64_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(6, false) => {
let arr: &[&CudaBuffer; 15] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique6_u32_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(6, true) => {
let arr: &[&CudaBuffer; 15] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique6_u64_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(7, false) => {
let arr: &[&CudaBuffer; 21] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique7_u32_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(7, true) => {
let arr: &[&CudaBuffer; 21] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique7_u64_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(8, false) => {
let arr: &[&CudaBuffer; 28] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique8_u32_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
(8, true) => {
let arr: &[&CudaBuffer; 28] = match edge_refs.as_slice().try_into() {
Ok(a) => a,
Err(_) => return Ok(None),
};
self.provider.wcoj_clique8_u64_recorded_planned(
arr,
plan_params.leader_edge_idx,
&plan_params.edge_order,
&plan_params.iteration_order,
launch_stream,
)
}
_ => return Ok(None),
};
match result {
Ok(buf) => {
let buf = if output_perm.iter().copied().eq(0..output_perm.len()) {
buf
} else {
self.provider.wcoj_project_output_columns_recorded(
&buf,
&output_perm,
head_schema,
launch_stream,
)?
};
match k {
5 => self.wcoj_clique5_dispatch_count += 1,
6 => self.wcoj_clique6_dispatch_count += 1,
7 => self.wcoj_clique7_dispatch_count += 1,
8 => self.wcoj_clique8_dispatch_count += 1,
_ => {}
}
Ok(Some(buf))
}
Err(_) => Ok(None),
}
}
}
/// Launch parameters derived from a `KCliqueVariableOrder` by
/// `kclique_dispatch_params` and consumed by the k-clique dispatcher.
#[derive(Debug)]
struct KCliqueDispatchParams {
/// For each slot, the index of the rule input that fills it.
edge_permutation: Vec<usize>,
/// Indexed by logical edge (the edge index computed from the variables'
/// planned positions); the value is the slot holding that edge.
edge_order: Vec<u8>,
/// The sequence `0..k` narrowed to `u8`.
iteration_order: Vec<u8>,
/// Slot holding the logical edge between positions 0 and 1.
leader_edge_idx: u32,
/// Slots whose plan entry has `swap_cols` set.
swap_slots: HashSet<usize>,
/// Slots listed in the plan's `sorted_layout_requirements.edge_slots`.
required_sort_slots: HashSet<usize>,
}
/// Turns a planner-produced k-clique variable order into kernel launch
/// parameters.
///
/// Returns `None` when the plan is inconsistent with `k`: a malformed edge
/// permutation or variable-position table, a logical edge left without a
/// slot, an index that does not fit in `u8`, or a swap/sort slot that is not
/// below `k * (k - 1) / 2`.
fn kclique_dispatch_params(plan: &KCliqueVariableOrder, k: usize) -> Option<KCliqueDispatchParams> {
    let edge_count = k * (k - 1) / 2;
    let edge_permutation = live_kclique_edge_permutation(plan, edge_count)?;
    let var_positions = live_kclique_variable_positions(plan, k)?;

    // Map every logical edge (expressed in planned positions) to its slot.
    let mut edge_order = vec![u8::MAX; edge_count];
    for (slot, &input_edge) in edge_permutation.iter().enumerate() {
        let (a, b) = clique_edge_pair(input_edge, k)?;
        let (pa, pb) = (var_positions[a], var_positions[b]);
        let (lo, hi) = if pa <= pb { (pa, pb) } else { (pb, pa) };
        let logical = clique_edge_idx_runtime(lo, hi, k)?;
        edge_order[logical] = u8::try_from(slot).ok()?;
    }
    // `u8::MAX` is the "unassigned" marker; any survivor means a gap.
    if edge_order.iter().any(|&slot| slot == u8::MAX) {
        return None;
    }

    let leader_logical = clique_edge_idx_runtime(0, 1, k)?;
    let leader_edge_idx = u32::from(edge_order[leader_logical]);

    let mut iteration_order: Vec<u8> = Vec::with_capacity(k);
    for idx in 0..k {
        iteration_order.push(u8::try_from(idx).ok()?);
    }

    let mut swap_slots: HashSet<usize> = HashSet::new();
    for swap in plan.column_swaps.iter() {
        if swap.swap_cols {
            swap_slots.insert(usize::from(swap.edge_slot));
        }
    }
    if !swap_slots.iter().all(|&slot| slot < edge_count) {
        return None;
    }

    let required_sort_slots: HashSet<usize> = plan
        .sorted_layout_requirements
        .edge_slots
        .iter()
        .map(|&slot| usize::from(slot))
        .collect();
    if !required_sort_slots.iter().all(|&slot| slot < edge_count) {
        return None;
    }

    Some(KCliqueDispatchParams {
        edge_permutation,
        edge_order,
        iteration_order,
        leader_edge_idx,
        swap_slots,
        required_sort_slots,
    })
}
/// Extracts the used prefix of `plan.edge_permutation` (everything before
/// the first `u8::MAX` entry) and checks that it is a permutation of
/// `0..expected_edges`.
///
/// Returns `None` when the prefix has the wrong length, contains an
/// out-of-range value, or repeats a value.
fn live_kclique_edge_permutation(
    plan: &KCliqueVariableOrder,
    expected_edges: usize,
) -> Option<Vec<usize>> {
    let mut live: Vec<usize> = Vec::new();
    for &raw in plan.edge_permutation.iter() {
        if raw == u8::MAX {
            break;
        }
        live.push(usize::from(raw));
    }
    if live.len() != expected_edges {
        return None;
    }

    let mut used = vec![false; expected_edges];
    for &edge in &live {
        // `get_mut` rejects out-of-range values; `replace` reports duplicates.
        let flag = used.get_mut(edge)?;
        if std::mem::replace(flag, true) {
            return None;
        }
    }
    Some(live)
}
/// Reads the first `k` entries of `plan.variable_positions` and checks that
/// they form a permutation of `0..k`.
///
/// Entry `v` of the result is the position assigned to original variable
/// `v`. Returns `None` when an entry is missing, out of range, or repeated.
fn live_kclique_variable_positions(plan: &KCliqueVariableOrder, k: usize) -> Option<Vec<usize>> {
    let mut taken = vec![false; k];
    (0..k)
        .map(|original_var| {
            let position = usize::from(*plan.variable_positions.get(original_var)?);
            let flag = taken.get_mut(position)?;
            if std::mem::replace(flag, true) {
                None
            } else {
                Some(position)
            }
        })
        .collect()
}
/// Index of the edge `(i, j)` among the edges of a `k`-clique enumerated as
/// `(0,1), (0,2), …, (0,k-1), (1,2), …, (k-2,k-1)`.
///
/// Returns `None` unless `i < j < k`.
fn clique_edge_idx_runtime(i: usize, j: usize, k: usize) -> Option<usize> {
    if i >= j || j >= k {
        return None;
    }
    // Edges contributed by rows 0..i, then the offset inside row `i`.
    let row_start = i * (k - 1) - i.saturating_sub(1) * i / 2;
    let offset_in_row = j - i - 1;
    Some(row_start + offset_in_row)
}
/// Inverse of the row-major edge numbering: returns the pair `(i, j)` with
/// `i < j < k` whose position in the enumeration
/// `(0,1), (0,2), …, (k-2,k-1)` is `edge_idx`.
///
/// Returns `None` when `edge_idx` is not below `k * (k - 1) / 2`.
fn clique_edge_pair(edge_idx: usize, k: usize) -> Option<(usize, usize)> {
    let mut remaining = edge_idx;
    for i in 0..k {
        // Row `i` holds the edges (i, i+1) ..= (i, k-1).
        let row_len = k - i - 1;
        if remaining < row_len {
            return Some((i, i + 1 + remaining));
        }
        remaining -= row_len;
    }
    None
}
/// Builds the `k`-column head schema, naming the columns `col0..col{k-1}`.
///
/// Variable 0 takes its type from column 0 of the `(0, 1)` edge input; every
/// other variable `v` takes its type from column 1 of the `(0, v)` edge
/// input. Returns `None` if an edge index, buffer, or column type is missing.
fn build_kclique_head_schema(raw_bufs: &[&CudaBuffer], k: usize) -> Option<Schema> {
    let columns = (0..k)
        .map(|variable| {
            // (other endpoint of the edge from variable 0, column to read)
            let (partner, col_idx) = match variable {
                0 => (1, 0),
                other => (other, 1),
            };
            let edge_idx = clique_edge_idx_runtime(0, partner, k)?;
            let ty = raw_bufs.get(edge_idx)?.schema.column_type(col_idx)?;
            Some((format!("col{}", variable), ty))
        })
        .collect::<Option<Vec<_>>>()?;
    Some(Schema::new(columns))
}
/// Output-column permutation for a k-clique result: the validated
/// variable-position table of `plan`, or `None` if that table is malformed.
fn kclique_output_perm(plan: &KCliqueVariableOrder, k: usize) -> Option<Vec<usize>> {
    live_kclique_variable_positions(plan, k)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the WCOJ pattern matchers and the env/config gate
    //! resolvers. Every test that mutates the process environment takes the
    //! module-wide `env_lock()` so env-touching tests never overlap.

    use std::sync::{Mutex, OnceLock};

    use super::{
        match_chain_join, match_multiway_triangle, w63_chain_enabled, wcoj_adaptive_enabled,
        wcoj_gate_enabled, ENV_USE_WCOJ_TRIANGLE_U32, ENV_WCOJ_W63_CHAIN_ENABLE,
    };
    use xlog_core::RelId;
    use xlog_ir::rir::ProjectExpr;
    use xlog_ir::RirNode;

    /// Triangle-shaped `MultiWayJoin` over scans of relations 1, 2 and 3.
    fn canonical_multiway() -> RirNode {
        RirNode::MultiWayJoin {
            inputs: vec![
                RirNode::Scan { rel: RelId(1) },
                RirNode::Scan { rel: RelId(2) },
                RirNode::Scan { rel: RelId(3) },
            ],
            slot_vars: vec![
                vec![Some(0u32), Some(1)],
                vec![Some(1u32), Some(2)],
                vec![Some(0u32), Some(2)],
            ],
            output_columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
            ],
            fallback: Box::new(RirNode::Unit),
            plan: None,
            var_order: None,
        }
    }

    /// Two-relation `ChainJoin` joining column 1 of relation 1 with column 0
    /// of relation 2.
    fn canonical_chain_join() -> RirNode {
        RirNode::ChainJoin {
            left: Box::new(RirNode::Scan { rel: RelId(1) }),
            right: Box::new(RirNode::Scan { rel: RelId(2) }),
            left_key: 1,
            right_key: 0,
            output_columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
            fallback: Box::new(RirNode::Unit),
        }
    }

    #[test]
    fn match_chain_returns_two_rels_and_keys() {
        let node = canonical_chain_join();
        let m = match_chain_join(&node).expect("must match canonical chain");
        assert_eq!(m.rel_left, RelId(1));
        assert_eq!(m.rel_right, RelId(2));
        assert_eq!(m.left_key, 1);
        assert_eq!(m.right_key, 0);
        assert_eq!(
            m.output_columns,
            vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
        );
    }

    #[test]
    fn match_chain_rejects_non_scan_inputs() {
        let mut node = canonical_chain_join();
        if let RirNode::ChainJoin { left, .. } = &mut node {
            **left = RirNode::Unit;
        }
        assert!(match_chain_join(&node).is_none());
    }

    #[test]
    fn match_chain_rejects_multiway_triangle() {
        let node = canonical_multiway();
        assert!(match_chain_join(&node).is_none());
    }

    /// The W6.3 chain gate is on when unset, off for "0"/"false", on for "1".
    #[test]
    fn w63_chain_env_defaults_on_and_can_disable() {
        // Use the module-wide lock, not a private one: a function-local mutex
        // would not serialize this test against the other env-mutating tests.
        let _guard = env_lock().lock().expect("WCOJ env lock poisoned");
        let old = std::env::var(ENV_WCOJ_W63_CHAIN_ENABLE).ok();
        // SAFETY (applies to every unsafe block in this test): env mutation is
        // serialized with the other tests in this module by `env_lock()`.
        unsafe {
            std::env::remove_var(ENV_WCOJ_W63_CHAIN_ENABLE);
        }
        assert!(w63_chain_enabled());
        unsafe {
            std::env::set_var(ENV_WCOJ_W63_CHAIN_ENABLE, "0");
        }
        assert!(!w63_chain_enabled());
        unsafe {
            std::env::set_var(ENV_WCOJ_W63_CHAIN_ENABLE, "false");
        }
        assert!(!w63_chain_enabled());
        unsafe {
            std::env::set_var(ENV_WCOJ_W63_CHAIN_ENABLE, "1");
        }
        assert!(w63_chain_enabled());
        // Put back whatever the variable held before the test ran.
        unsafe {
            match old {
                Some(v) => std::env::set_var(ENV_WCOJ_W63_CHAIN_ENABLE, v),
                None => std::env::remove_var(ENV_WCOJ_W63_CHAIN_ENABLE),
            }
        }
    }

    #[test]
    fn match_canonical_returns_three_rels() {
        let node = canonical_multiway();
        let m = match_multiway_triangle(&node).expect("must match canonical triangle");
        assert_eq!(m.rel_xy, RelId(1));
        assert_eq!(m.rel_yz, RelId(2));
        assert_eq!(m.rel_xz, RelId(3));
    }

    #[test]
    fn match_rejects_non_multiway_body() {
        let node = RirNode::Scan { rel: RelId(1) };
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_rejects_rotated_output_columns() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            *output_columns = vec![
                ProjectExpr::Column(1),
                ProjectExpr::Column(0),
                ProjectExpr::Column(3),
            ];
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_accepts_w22_z_shared_triangle_output_columns() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            *output_columns = vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(2),
                ProjectExpr::Column(3),
            ];
        }
        let m = match_multiway_triangle(&node)
            .expect("W2.2 matcher must accept the Z-shared output-column layout");
        assert_eq!(m.rel_xy, RelId(1));
        assert_eq!(m.rel_yz, RelId(2));
        assert_eq!(m.rel_xz, RelId(3));
    }

    #[test]
    fn match_rejects_invalid_w22_triangle_output_columns() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            *output_columns = vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(3),
                ProjectExpr::Column(3),
            ];
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_rejects_arity_mismatched_output_columns() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            *output_columns = vec![ProjectExpr::Column(0), ProjectExpr::Column(1)];
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_rejects_malformed_slot_vars() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { slot_vars, .. } = &mut node {
            *slot_vars = vec![
                vec![Some(0u32), Some(1)],
                vec![Some(1u32), Some(2)],
                vec![Some(0u32), Some(1)],
            ];
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_rejects_repeated_var_in_slot() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { slot_vars, .. } = &mut node {
            *slot_vars = vec![
                vec![Some(0u32), Some(0)],
                vec![Some(1u32), Some(2)],
                vec![Some(0u32), Some(2)],
            ];
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_rejects_non_scan_input() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { inputs, .. } = &mut node {
            inputs[0] = RirNode::Unit;
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    #[test]
    fn match_rejects_input_arity_mismatch() {
        let mut node = canonical_multiway();
        if let RirNode::MultiWayJoin { inputs, .. } = &mut node {
            inputs.pop();
        }
        assert!(match_multiway_triangle(&node).is_none());
    }

    /// Process-wide lock shared by every test that reads or writes env vars.
    fn env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Saved value of the triangle force variable, restored on drop.
    struct EnvSnapshot {
        force: Option<String>,
    }

    impl EnvSnapshot {
        /// Records the current value of the variable, then unsets it.
        fn capture_and_clear() -> Self {
            let snapshot = Self {
                force: std::env::var(ENV_USE_WCOJ_TRIANGLE_U32).ok(),
            };
            // SAFETY: only called from `with_wcoj_env`, which holds `env_lock()`.
            unsafe {
                std::env::remove_var(ENV_USE_WCOJ_TRIANGLE_U32);
            }
            snapshot
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            // SAFETY: the snapshot is declared after the lock guard in
            // `with_wcoj_env`, so it is dropped while the lock is still held.
            unsafe {
                match self.force.take() {
                    Some(v) => std::env::set_var(ENV_USE_WCOJ_TRIANGLE_U32, v),
                    None => std::env::remove_var(ENV_USE_WCOJ_TRIANGLE_U32),
                }
            }
        }
    }

    /// Runs `f` with the triangle force variable cleared, under `env_lock()`;
    /// the previous value is restored afterwards.
    fn with_wcoj_env<R>(f: impl FnOnce() -> R) -> R {
        let _guard = env_lock().lock().expect("WCOJ env lock poisoned");
        let _snapshot = EnvSnapshot::capture_and_clear();
        f()
    }

    /// Sets an env var; call only inside a `with_*_env` closure.
    fn set_env(name: &str, value: &str) {
        // SAFETY: every caller in this module runs inside a `with_*_env`
        // closure and therefore holds `env_lock()`.
        unsafe {
            std::env::set_var(name, value);
        }
    }

    #[test]
    fn stats_gate_defaults_on_when_env_unset() {
        with_wcoj_env(|| {
            assert!(wcoj_adaptive_enabled(None));
            assert!(wcoj_adaptive_enabled(Some(true)));
            assert!(!wcoj_adaptive_enabled(Some(false)));
        });
    }

    #[test]
    fn config_controls_stats_gate() {
        with_wcoj_env(|| {
            assert!(wcoj_adaptive_enabled(Some(true)));
            assert!(!wcoj_adaptive_enabled(Some(false)));
        });
    }

    #[test]
    fn force_resolver_config_still_overrides_env() {
        with_wcoj_env(|| {
            set_env(ENV_USE_WCOJ_TRIANGLE_U32, "1");
            assert!(wcoj_gate_enabled(None));
            assert!(!wcoj_gate_enabled(Some(false)));
            set_env(ENV_USE_WCOJ_TRIANGLE_U32, "0");
            assert!(!wcoj_gate_enabled(None));
            assert!(wcoj_gate_enabled(Some(true)));
        });
    }

    use super::{
        match_multiway_4cycle, wcoj_4cycle_adaptive_enabled, wcoj_4cycle_disabled,
        wcoj_4cycle_gate_enabled, ENV_DISABLE_WCOJ_4CYCLE, ENV_USE_WCOJ_4CYCLE,
        ENV_USE_WCOJ_4CYCLE_ADAPTIVE,
    };

    /// Saved values of the three 4-cycle variables, restored on drop.
    struct EnvSnapshot4Cycle {
        force: Option<String>,
        adaptive: Option<String>,
        disable: Option<String>,
    }

    impl EnvSnapshot4Cycle {
        /// Records the current values of the three variables, then unsets them.
        fn capture_and_clear() -> Self {
            let snap = Self {
                force: std::env::var(ENV_USE_WCOJ_4CYCLE).ok(),
                adaptive: std::env::var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE).ok(),
                disable: std::env::var(ENV_DISABLE_WCOJ_4CYCLE).ok(),
            };
            // SAFETY: only called from `with_4cycle_env`, which holds `env_lock()`.
            unsafe {
                std::env::remove_var(ENV_USE_WCOJ_4CYCLE);
                std::env::remove_var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE);
                std::env::remove_var(ENV_DISABLE_WCOJ_4CYCLE);
            }
            snap
        }
    }

    impl Drop for EnvSnapshot4Cycle {
        fn drop(&mut self) {
            // SAFETY: the snapshot is declared after the lock guard in
            // `with_4cycle_env`, so it is dropped while the lock is still held.
            unsafe {
                match self.force.take() {
                    Some(v) => std::env::set_var(ENV_USE_WCOJ_4CYCLE, v),
                    None => std::env::remove_var(ENV_USE_WCOJ_4CYCLE),
                }
                match self.adaptive.take() {
                    Some(v) => std::env::set_var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE, v),
                    None => std::env::remove_var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE),
                }
                match self.disable.take() {
                    Some(v) => std::env::set_var(ENV_DISABLE_WCOJ_4CYCLE, v),
                    None => std::env::remove_var(ENV_DISABLE_WCOJ_4CYCLE),
                }
            }
        }
    }

    /// Runs `f` with the three 4-cycle variables cleared, under `env_lock()`;
    /// the previous values are restored afterwards.
    fn with_4cycle_env<R>(f: impl FnOnce() -> R) -> R {
        let _guard = env_lock().lock().expect("4-cycle env lock poisoned");
        let _snap = EnvSnapshot4Cycle::capture_and_clear();
        f()
    }

    #[test]
    fn force_4cycle_resolver_defaults_off_when_env_unset() {
        with_4cycle_env(|| {
            assert!(!wcoj_4cycle_gate_enabled(None));
            assert!(wcoj_4cycle_gate_enabled(Some(true)));
            assert!(!wcoj_4cycle_gate_enabled(Some(false)));
        });
    }

    #[test]
    fn force_4cycle_resolver_env_can_enable() {
        with_4cycle_env(|| {
            set_env(ENV_USE_WCOJ_4CYCLE, "1");
            assert!(wcoj_4cycle_gate_enabled(None));
            set_env(ENV_USE_WCOJ_4CYCLE, "true");
            assert!(wcoj_4cycle_gate_enabled(None));
            set_env(ENV_USE_WCOJ_4CYCLE, "0");
            assert!(!wcoj_4cycle_gate_enabled(None));
        });
    }

    #[test]
    fn adaptive_4cycle_resolver_defaults_off_when_env_unset() {
        with_4cycle_env(|| {
            assert!(
                !wcoj_4cycle_adaptive_enabled(None),
                "4-cycle adaptive must be OPT-IN by default (unlike triangle's default-on)"
            );
            assert!(wcoj_4cycle_adaptive_enabled(Some(true)));
            assert!(!wcoj_4cycle_adaptive_enabled(Some(false)));
        });
    }

    #[test]
    fn adaptive_4cycle_resolver_env_can_enable() {
        with_4cycle_env(|| {
            set_env(ENV_USE_WCOJ_4CYCLE_ADAPTIVE, "1");
            assert!(wcoj_4cycle_adaptive_enabled(None));
            set_env(ENV_USE_WCOJ_4CYCLE_ADAPTIVE, "0");
            assert!(!wcoj_4cycle_adaptive_enabled(None));
            set_env(ENV_USE_WCOJ_4CYCLE_ADAPTIVE, "true");
            assert!(wcoj_4cycle_adaptive_enabled(None));
        });
    }

    #[test]
    fn kill_4cycle_resolver_honors_env_and_config() {
        with_4cycle_env(|| {
            assert!(!wcoj_4cycle_disabled(None));
            set_env(ENV_DISABLE_WCOJ_4CYCLE, "1");
            assert!(wcoj_4cycle_disabled(None));
            assert!(!wcoj_4cycle_disabled(Some(false)));
            set_env(ENV_DISABLE_WCOJ_4CYCLE, "0");
            assert!(wcoj_4cycle_disabled(Some(true)));
        });
    }

    /// 4-cycle-shaped `MultiWayJoin` over scans of relations 1 through 4.
    fn canonical_4cycle_multiway() -> RirNode {
        RirNode::MultiWayJoin {
            inputs: vec![
                RirNode::Scan { rel: RelId(1) },
                RirNode::Scan { rel: RelId(2) },
                RirNode::Scan { rel: RelId(3) },
                RirNode::Scan { rel: RelId(4) },
            ],
            slot_vars: vec![
                vec![Some(0u32), Some(1)],
                vec![Some(1u32), Some(2)],
                vec![Some(2u32), Some(3)],
                vec![Some(3u32), Some(0)],
            ],
            output_columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
                ProjectExpr::Column(5),
            ],
            fallback: Box::new(RirNode::Unit),
            plan: None,
            var_order: None,
        }
    }

    #[test]
    fn match_4cycle_canonical_returns_four_rels() {
        let node = canonical_4cycle_multiway();
        let m = match_multiway_4cycle(&node).expect("must match canonical 4-cycle");
        assert_eq!(m.rel_e1, RelId(1));
        assert_eq!(m.rel_e2, RelId(2));
        assert_eq!(m.rel_e3, RelId(3));
        assert_eq!(m.rel_e4, RelId(4));
    }

    #[test]
    fn match_4cycle_rejects_non_multiway() {
        assert!(match_multiway_4cycle(&RirNode::Scan { rel: RelId(1) }).is_none());
    }

    #[test]
    fn match_4cycle_rejects_triangle_shape() {
        let triangle = RirNode::MultiWayJoin {
            inputs: vec![
                RirNode::Scan { rel: RelId(1) },
                RirNode::Scan { rel: RelId(2) },
                RirNode::Scan { rel: RelId(3) },
            ],
            slot_vars: vec![
                vec![Some(0u32), Some(1)],
                vec![Some(1u32), Some(2)],
                vec![Some(0u32), Some(2)],
            ],
            output_columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
            ],
            fallback: Box::new(RirNode::Unit),
            plan: None,
            var_order: None,
        };
        assert!(match_multiway_4cycle(&triangle).is_none());
    }

    #[test]
    fn match_4cycle_rejects_rotated_output_columns() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            output_columns.swap(0, 1);
        }
        assert!(match_multiway_4cycle(&node).is_none());
    }

    #[test]
    fn match_4cycle_accepts_w22_alt_grouping_output_columns() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            *output_columns = vec![
                ProjectExpr::Column(5),
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
            ];
        }
        let m = match_multiway_4cycle(&node)
            .expect("W2.2 matcher must accept the Alt-grouping output-column layout");
        assert_eq!(m.rel_e1, RelId(1));
        assert_eq!(m.rel_e2, RelId(2));
        assert_eq!(m.rel_e3, RelId(3));
        assert_eq!(m.rel_e4, RelId(4));
    }

    #[test]
    fn match_4cycle_rejects_invalid_w22_output_columns() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            *output_columns = vec![
                ProjectExpr::Column(1),
                ProjectExpr::Column(0),
                ProjectExpr::Column(3),
                ProjectExpr::Column(5),
            ];
        }
        assert!(match_multiway_4cycle(&node).is_none());
    }

    #[test]
    fn match_4cycle_rejects_arity_mismatched_output_columns() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
            output_columns.pop();
        }
        assert!(match_multiway_4cycle(&node).is_none());
    }

    #[test]
    fn match_4cycle_rejects_unclosed_cycle() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { slot_vars, .. } = &mut node {
            slot_vars[3] = vec![Some(3), Some(99)];
        }
        assert!(match_multiway_4cycle(&node).is_none());
    }

    #[test]
    fn match_4cycle_rejects_non_scan_input() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { inputs, .. } = &mut node {
            inputs[0] = RirNode::Unit;
        }
        assert!(match_multiway_4cycle(&node).is_none());
    }

    #[test]
    fn match_4cycle_rejects_input_arity_mismatch() {
        let mut node = canonical_4cycle_multiway();
        if let RirNode::MultiWayJoin { inputs, .. } = &mut node {
            inputs.push(RirNode::Scan { rel: RelId(5) });
        }
        assert!(match_multiway_4cycle(&node).is_none());
    }
}