use xlog_core::{AggOp, RelId, ScalarType};
/// Join flavour carried by the `join_type` field of [`RirNode::Join`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
/// Inner join.
Inner,
/// Left outer join.
LeftOuter,
/// Semi join.
Semi,
/// Anti join.
Anti,
}
/// Scalar expression tree.
///
/// Used as the predicate of [`RirNode::Filter`] and as the payload of
/// [`ProjectExpr::Computed`]. Sub-expressions are boxed so the type can be
/// recursive.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// Column reference; the `usize` is presumably the column's position in the
/// input row (TODO confirm against the evaluator).
Column(usize),
/// Literal value.
Const(ConstValue),
/// Binary comparison `left <op> right`.
Compare {
left: Box<Expr>,
op: CompareOp,
right: Box<Expr>,
},
/// N-ary conjunction over the listed operands.
And(Vec<Expr>),
/// N-ary disjunction over the listed operands.
Or(Vec<Expr>),
/// Logical negation of the operand.
Not(Box<Expr>),
/// Binary addition.
Add(Box<Expr>, Box<Expr>),
/// Binary subtraction.
Sub(Box<Expr>, Box<Expr>),
/// Binary multiplication.
Mul(Box<Expr>, Box<Expr>),
/// Binary division.
Div(Box<Expr>, Box<Expr>),
/// Binary modulo / remainder.
Mod(Box<Expr>, Box<Expr>),
/// Absolute value of the operand.
Abs(Box<Expr>),
/// Minimum of the two operands.
Min(Box<Expr>, Box<Expr>),
/// Maximum of the two operands.
Max(Box<Expr>, Box<Expr>),
/// Exponentiation; presumably `(base, exponent)` — TODO confirm operand order.
Pow(Box<Expr>, Box<Expr>),
/// Conversion of the operand to the given [`ScalarType`].
Cast(Box<Expr>, ScalarType),
/// If-then-else: carries a condition and one expression per branch.
Conditional {
condition: Box<Expr>,
then_expr: Box<Expr>,
else_expr: Box<Expr>,
},
}
/// One output column of a projection.
///
/// Appears in [`RirNode::Project`], in the `output_columns` of the join
/// variants, and in [`VariableOrder::kernel_output_cols`].
#[derive(Debug, Clone, PartialEq)]
pub enum ProjectExpr {
/// Pass-through of an existing column, identified by a `usize`
/// (presumably its index in the input — TODO confirm).
Column(usize),
/// Column computed from an [`Expr`], paired with a [`ScalarType`]
/// (presumably the result type — TODO confirm).
Computed(Expr, ScalarType),
}
/// Per-lookup descriptor stored in [`VariableOrder::lookup_perms`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LookupPerm {
/// Presumably the position of the looked-up input among the join inputs
/// (TODO confirm against the planner).
pub input_idx: u8,
/// Flag requesting a column swap; presumably applies to the two columns of
/// a binary input — TODO confirm.
pub swap_cols: bool,
}
/// Upper bound on `k` for k-clique plans; also the length of
/// `KCliqueVariableOrder::variable_positions`.
pub const K_CLIQUE_MAX_K: usize = 8;

/// Length of `KCliqueVariableOrder::edge_permutation`.
///
/// A complete graph on `K_CLIQUE_MAX_K` vertices has `K * (K - 1) / 2` edges,
/// which evaluates to 28 for `K = 8`.
pub const K_CLIQUE_MAX_EDGES: usize = K_CLIQUE_MAX_K * (K_CLIQUE_MAX_K - 1) / 2;
/// Column-swap directive for one edge slot, stored in
/// [`KCliqueVariableOrder::column_swaps`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnSwap {
/// Edge slot this directive refers to.
pub edge_slot: u8,
/// Whether the slot's columns are to be swapped.
pub swap_cols: bool,
}
/// Sorted-layout requirements, stored in
/// [`KCliqueVariableOrder::sorted_layout_requirements`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SortedLayoutSpec {
/// Edge slots covered by this spec.
pub edge_slots: Vec<u8>,
/// One list of key columns per entry; presumably parallel to `edge_slots`
/// (TODO confirm).
pub key_columns: Vec<Vec<u8>>,
}
/// Helper-split descriptor, stored in
/// [`KCliqueVariableOrder::helper_split_specs`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HelperSplitSpec {
/// Identifier of the helper.
pub helper_id: u8,
/// Variable the helper is associated with.
pub variable: u8,
/// Edge slots associated with the helper.
pub edge_slots: Vec<u8>,
}
/// Newtype over a `u8` identifying a stream group.
///
/// Derives ordering and hashing, so it can be sorted or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StreamGroupId(pub u8);
/// Variable-order plan for a k-clique join.
///
/// Carried by [`MultiwayPlan::WcojWithPlan`] and, optionally, by
/// [`VariableOrder::kclique`]. The two arrays are fixed-size, bounded by
/// [`K_CLIQUE_MAX_K`] and [`K_CLIQUE_MAX_EDGES`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KCliqueVariableOrder {
/// Clique size; presumably only the first `k` entries of
/// `variable_positions` are meaningful (TODO confirm).
pub k: u8,
/// Per-variable positions, padded to `K_CLIQUE_MAX_K` entries.
pub variable_positions: [u8; K_CLIQUE_MAX_K],
/// Edge permutation, padded to `K_CLIQUE_MAX_EDGES` entries.
pub edge_permutation: [u8; K_CLIQUE_MAX_EDGES],
/// Column-swap directives, one per listed edge slot.
pub column_swaps: Vec<ColumnSwap>,
/// Sorted-layout requirements for this plan.
pub sorted_layout_requirements: SortedLayoutSpec,
/// Helper-split descriptors for this plan.
pub helper_split_specs: Vec<HelperSplitSpec>,
/// Stream group this plan is assigned to.
pub stream_group: StreamGroupId,
}
impl KCliqueVariableOrder {
pub fn new(
k: u8,
variable_positions: [u8; K_CLIQUE_MAX_K],
edge_permutation: [u8; K_CLIQUE_MAX_EDGES],
column_swaps: Vec<ColumnSwap>,
sorted_layout_requirements: SortedLayoutSpec,
helper_split_specs: Vec<HelperSplitSpec>,
stream_group: StreamGroupId,
) -> Self {
Self {
k,
variable_positions,
edge_permutation,
column_swaps,
sorted_layout_requirements,
helper_split_specs,
stream_group,
}
}
}
/// Pair of predicted costs, attached as `planner_evidence` to a planned hash route.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CostPredictionRecord {
    /// Predicted cost of the WCOJ route.
    pub wcoj_cost: f64,
    /// Predicted cost of the hash route.
    pub hash_cost: f64,
}

impl CostPredictionRecord {
    /// Returns the placeholder record: an infinite WCOJ cost and a zero hash cost.
    pub fn empty() -> Self {
        CostPredictionRecord {
            hash_cost: 0.0,
            wcoj_cost: f64::INFINITY,
        }
    }
}
/// Reason recorded on [`MultiwayPlan::PlannedHashRoute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlannedHashReason {
/// The planner predicted that the hash route wins.
PlannerPredictsHashWins,
/// Statistics were incomplete, so the hash route was taken as the safe default.
IncompleteStatsSafeDefault,
}
/// Planner decision attached to [`RirNode::MultiWayJoin`] via its `plan` field.
#[derive(Debug, Clone, PartialEq)]
pub enum MultiwayPlan {
/// Run the WCOJ route using the given k-clique variable order.
WcojWithPlan(KCliqueVariableOrder),
/// Take the hash route; records why, together with the cost predictions
/// behind the decision.
PlannedHashRoute {
reason: PlannedHashReason,
planner_evidence: CostPredictionRecord,
},
}
/// Variable order attached to [`RirNode::MultiWayJoin`] via its `var_order` field.
///
/// Built by [`VariableOrder::legacy`] (no k-clique plan) or
/// [`VariableOrder::kclique`] (k-clique plan only, other fields left at their
/// zero/empty values).
#[derive(Debug, Clone, PartialEq)]
pub struct VariableOrder {
/// Presumably the index of the leader input among the join inputs
/// (TODO confirm).
pub leader_idx: u8,
/// Lookup descriptors.
pub lookup_perms: Vec<LookupPerm>,
/// Output columns produced by the kernel.
pub kernel_output_cols: Vec<ProjectExpr>,
/// k-clique plan, when one is present.
pub kclique: Option<KCliqueVariableOrder>,
}
impl VariableOrder {
pub fn legacy(
leader_idx: u8,
lookup_perms: Vec<LookupPerm>,
kernel_output_cols: Vec<ProjectExpr>,
) -> Self {
Self {
leader_idx,
lookup_perms,
kernel_output_cols,
kclique: None,
}
}
pub fn kclique(kclique: KCliqueVariableOrder) -> Self {
Self {
leader_idx: 0,
lookup_perms: Vec::new(),
kernel_output_cols: Vec::new(),
kclique: Some(kclique),
}
}
}
/// Comparison operator used by [`Expr::Compare`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
/// Equal (`==`).
Eq,
/// Not equal (`!=`).
Ne,
/// Less than (`<`).
Lt,
/// Less than or equal (`<=`).
Le,
/// Greater than (`>`).
Gt,
/// Greater than or equal (`>=`).
Ge,
}
/// Literal value carried by [`Expr::Const`].
///
/// Only `PartialEq` is derived (not `Eq`) because of the floating-point variants.
#[derive(Debug, Clone, PartialEq)]
pub enum ConstValue {
/// Unsigned 32-bit integer.
U32(u32),
/// Unsigned 64-bit integer.
U64(u64),
/// Signed 32-bit integer.
I32(i32),
/// Signed 64-bit integer.
I64(i64),
/// 32-bit float.
F32(f32),
/// 64-bit float.
F64(f64),
/// Boolean.
Bool(bool),
/// Symbol, stored as an owned string.
Symbol(String),
}
/// Node of the relational IR tree.
///
/// Child nodes are boxed (or held in a `Vec`) so the type can be recursive.
/// Unlike most types in this file, only `Debug` and `Clone` are derived.
#[derive(Debug, Clone)]
// NOTE(review): the size lint is silenced rather than boxing the large
// variant; presumably `TensorMaskedJoin`'s inline fields are the cause — confirm.
#[allow(clippy::large_enum_variant)]
pub enum RirNode {
/// Node with no payload and no children.
Unit,
/// Reference to a single relation; the only variant `is_leaf` reports as a leaf.
Scan {
rel: RelId,
},
/// Applies `predicate` to `input`.
Filter {
input: Box<RirNode>,
predicate: Expr,
},
/// Projects `input` onto the listed output columns.
Project {
input: Box<RirNode>,
columns: Vec<ProjectExpr>,
},
/// Binary join of `left` and `right` on the listed key columns, with the
/// given join flavour.
Join {
left: Box<RirNode>,
right: Box<RirNode>,
left_keys: Vec<usize>,
right_keys: Vec<usize>,
join_type: JoinType,
},
/// Binary join on a single key per side with explicit output columns.
/// `fallback` holds an alternative plan node (presumably equivalent —
/// TODO confirm).
ChainJoin {
left: Box<RirNode>,
right: Box<RirNode>,
left_key: usize,
right_key: usize,
output_columns: Vec<ProjectExpr>,
fallback: Box<RirNode>,
},
/// Grouping of `input` by `key_cols`; each `aggs` entry pairs a `usize`
/// (presumably the aggregated column) with an aggregate operator.
GroupBy {
input: Box<RirNode>,
key_cols: Vec<usize>,
aggs: Vec<(usize, AggOp)>,
},
/// Union of all `inputs`.
Union {
inputs: Vec<RirNode>,
},
/// Deduplication of `input` with respect to `key_cols`.
Distinct {
input: Box<RirNode>,
key_cols: Vec<usize>,
},
/// Difference of `left` and `right`.
Diff {
left: Box<RirNode>,
right: Box<RirNode>,
},
/// Fixpoint for the SCC `scc_id`: a base plan, a recursive plan, and the
/// delta and full relations it refers to.
Fixpoint {
scc_id: u32,
base: Box<RirNode>,
recursive: Box<RirNode>,
delta_rel: RelId,
full_rel: RelId,
},
/// Join over several inputs at once. `slot_vars` holds one list per entry
/// (presumably per input) of optional variable ids; `plan` and `var_order`
/// are optional planner annotations; `fallback` holds an alternative plan
/// node (presumably equivalent — TODO confirm).
MultiWayJoin {
inputs: Vec<RirNode>,
slot_vars: Vec<Vec<Option<u32>>>,
output_columns: Vec<ProjectExpr>,
fallback: Box<RirNode>,
plan: Option<MultiwayPlan>,
var_order: Option<VariableOrder>,
},
/// Tensor-masked join. Has no child nodes; the relations it touches are
/// listed in `rel_index` as `(id, name)` pairs.
TensorMaskedJoin {
mask_name: String,
schema_size: usize,
left_keys: Vec<usize>,
right_keys: Vec<usize>,
rel_index: Vec<(RelId, String)>,
head_rel_name: String,
head_rel_id: RelId,
max_active_rules: usize,
head_projection: Vec<usize>,
},
}
impl RirNode {
pub fn is_leaf(&self) -> bool {
matches!(self, RirNode::Scan { .. })
}
pub fn referenced_relations(&self) -> Vec<RelId> {
let mut rels = Vec::new();
self.collect_relations(&mut rels);
rels
}
fn collect_relations(&self, rels: &mut Vec<RelId>) {
match self {
RirNode::Unit => {}
RirNode::Scan { rel } => rels.push(*rel),
RirNode::Filter { input, .. } | RirNode::Project { input, .. } => {
input.collect_relations(rels);
}
RirNode::Join { left, right, .. }
| RirNode::ChainJoin { left, right, .. }
| RirNode::Diff { left, right } => {
left.collect_relations(rels);
right.collect_relations(rels);
}
RirNode::Union { inputs } => {
for input in inputs {
input.collect_relations(rels);
}
}
RirNode::GroupBy { input, .. } | RirNode::Distinct { input, .. } => {
input.collect_relations(rels);
}
RirNode::Fixpoint {
base,
recursive,
delta_rel,
full_rel,
..
} => {
base.collect_relations(rels);
recursive.collect_relations(rels);
rels.push(*delta_rel);
rels.push(*full_rel);
}
RirNode::TensorMaskedJoin { rel_index, .. } => {
for (rel_id, _) in rel_index {
rels.push(*rel_id);
}
}
RirNode::MultiWayJoin { inputs, .. } => {
for input in inputs {
input.collect_relations(rels);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::ScalarType;

    /// A `Scan` is constructible, matchable, and reported as a leaf.
    #[test]
    fn test_scan_node() {
        let node = RirNode::Scan { rel: RelId(1) };
        assert!(matches!(node, RirNode::Scan { rel: RelId(1) }));
        assert!(node.is_leaf());
    }

    /// A `Join` is not a leaf and reports the relations of both sides.
    #[test]
    fn test_join_node() {
        let left = Box::new(RirNode::Scan { rel: RelId(1) });
        let right = Box::new(RirNode::Scan { rel: RelId(2) });
        let join = RirNode::Join {
            left,
            right,
            left_keys: vec![0],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        assert!(matches!(join, RirNode::Join { .. }));
        assert!(!join.is_leaf());
        let rels = join.referenced_relations();
        assert!(rels.contains(&RelId(1)));
        assert!(rels.contains(&RelId(2)));
    }

    /// A `Fixpoint` reports the relations of its subplans as well as its own
    /// delta and full relations.
    #[test]
    fn test_fixpoint_node() {
        let base = Box::new(RirNode::Scan { rel: RelId(1) });
        let recursive = Box::new(RirNode::Scan { rel: RelId(2) });
        let fp = RirNode::Fixpoint {
            scc_id: 0,
            base,
            recursive,
            delta_rel: RelId(3),
            full_rel: RelId(4),
        };
        assert!(matches!(fp, RirNode::Fixpoint { scc_id: 0, .. }));
        assert!(!fp.is_leaf());
        let rels = fp.referenced_relations();
        assert!(rels.contains(&RelId(1)));
        assert!(rels.contains(&RelId(2)));
        assert!(rels.contains(&RelId(3)));
        assert!(rels.contains(&RelId(4)));
    }

    /// An anti join keeps its join type.
    #[test]
    fn test_anti_join() {
        let left = Box::new(RirNode::Scan { rel: RelId(1) });
        let right = Box::new(RirNode::Scan { rel: RelId(2) });
        let anti = RirNode::Join {
            left,
            right,
            left_keys: vec![0],
            right_keys: vec![0],
            join_type: JoinType::Anti,
        };
        // A single assertion on the whole shape: an `if let` without an `else`
        // would pass silently if the node were not a `Join` at all.
        assert!(matches!(
            anti,
            RirNode::Join {
                join_type: JoinType::Anti,
                ..
            }
        ));
    }

    /// Arithmetic expressions can be built from columns and constants.
    #[test]
    fn test_expr_arithmetic() {
        let expr = Expr::Add(
            Box::new(Expr::Column(0)),
            Box::new(Expr::Const(ConstValue::I64(1))),
        );
        assert!(matches!(expr, Expr::Add(_, _)));
    }

    /// A computed projection wraps an expression together with a scalar type.
    #[test]
    fn test_project_expr_computed() {
        let proj = ProjectExpr::Computed(
            Expr::Add(
                Box::new(Expr::Column(0)),
                Box::new(Expr::Const(ConstValue::I64(1))),
            ),
            ScalarType::I64,
        );
        assert!(matches!(proj, ProjectExpr::Computed(_, _)));
    }
}