use xlog_core::{AggOp, RelId, ScalarType};
/// Flavor of join stored in the `join_type` field of [`RirNode::Join`].
///
/// The variant names follow the usual relational vocabulary. This file only
/// declares the tags; how each one is evaluated is decided by whatever
/// consumes the plan.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    LeftOuter,
    /// Semi join.
    Semi,
    /// Anti join.
    Anti,
}
/// Scalar expression tree.
///
/// Used as the `predicate` of [`RirNode::Filter`] and as the payload of
/// [`ProjectExpr::Computed`]. Sub-expressions are heap-allocated so the type
/// can nest to arbitrary depth.
///
/// This file defines only the shape of the tree; evaluation rules (typing,
/// overflow, division by zero, empty operand lists, ...) live elsewhere.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
    /// Column reference by index.
    /// NOTE(review): presumably a position in the input row - confirm with the evaluator.
    Column(usize),
    /// Literal constant.
    Const(ConstValue),
    /// Comparison of `left` and `right` using `op`.
    Compare {
        left: Box<Self>,
        op: CompareOp,
        right: Box<Self>,
    },
    /// N-ary logical AND over the listed operands.
    And(Vec<Self>),
    /// N-ary logical OR over the listed operands.
    Or(Vec<Self>),
    /// Logical negation of the operand.
    Not(Box<Self>),
    // Binary arithmetic nodes; operands presumably in (lhs, rhs) order - confirm
    // with the evaluator before relying on it for non-commutative operators.
    /// Addition.
    Add(Box<Self>, Box<Self>),
    /// Subtraction.
    Sub(Box<Self>, Box<Self>),
    /// Multiplication.
    Mul(Box<Self>, Box<Self>),
    /// Division.
    Div(Box<Self>, Box<Self>),
    /// Remainder / modulo.
    Mod(Box<Self>, Box<Self>),
    /// Absolute value of the operand.
    Abs(Box<Self>),
    /// Minimum of the two operands.
    Min(Box<Self>, Box<Self>),
    /// Maximum of the two operands.
    Max(Box<Self>, Box<Self>),
    /// Exponentiation; presumably (base, exponent) - confirm with the evaluator.
    Pow(Box<Self>, Box<Self>),
    /// Conversion of the operand to the given scalar type.
    Cast(Box<Self>, ScalarType),
    /// Conditional selection between `then_expr` and `else_expr`, driven by `condition`.
    Conditional {
        condition: Box<Self>,
        then_expr: Box<Self>,
        else_expr: Box<Self>,
    },
}
/// One output column of a [`RirNode::Project`].
#[derive(Clone, Debug, PartialEq)]
pub enum ProjectExpr {
    /// Output column given by an index.
    /// NOTE(review): presumably an input-column position passed through unchanged - confirm.
    Column(usize),
    /// Output column produced by an expression, paired with a scalar type
    /// (presumably the type of the computed result - confirm with the planner).
    Computed(Expr, ScalarType),
}
/// Comparison operator carried in the `op` field of [`Expr::Compare`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
/// Literal constant embedded in an expression via [`Expr::Const`].
///
/// Only `PartialEq` is derived (not `Eq`): the `f32`/`f64` payloads do not
/// implement `Eq`, and two NaN constants compare unequal.
#[derive(Clone, Debug, PartialEq)]
pub enum ConstValue {
    /// Unsigned 32-bit integer.
    U32(u32),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// Signed 32-bit integer.
    I32(i32),
    /// Signed 64-bit integer.
    I64(i64),
    /// 32-bit float.
    F32(f32),
    /// 64-bit float.
    F64(f64),
    /// Boolean.
    Bool(bool),
    /// Symbol constant, stored as an owned string.
    Symbol(String),
}
/// Node of a relational plan tree.
///
/// Operators own their inputs (`Box<Self>` / `Vec<Self>`), so a plan is a
/// plain tree that can be cloned as a whole. This file declares structure
/// only; the meaning of each operator is defined by the code that consumes
/// the plan.
#[derive(Clone, Debug)]
pub enum RirNode {
    /// Node with no inputs and no fields.
    /// NOTE(review): presumably a constant single-row source - confirm with the executor.
    Unit,
    /// Reads the relation identified by `rel`.
    Scan { rel: RelId },
    /// Applies `predicate` to the rows of `input`.
    Filter {
        input: Box<Self>,
        predicate: Expr,
    },
    /// Produces the listed output `columns` from `input`.
    Project {
        input: Box<Self>,
        columns: Vec<ProjectExpr>,
    },
    /// Joins `left` with `right`.
    ///
    /// `left_keys` / `right_keys` are column indices on the respective side;
    /// presumably paired positionally and of equal length - confirm with the
    /// planner. `join_type` selects the join flavor.
    Join {
        left: Box<Self>,
        right: Box<Self>,
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
        join_type: JoinType,
    },
    /// Groups `input` by `key_cols` and applies `aggs`.
    ///
    /// Each `aggs` entry is a `(usize, AggOp)` pair; presumably
    /// (input column, aggregate operator) - confirm with the executor.
    GroupBy {
        input: Box<Self>,
        key_cols: Vec<usize>,
        aggs: Vec<(usize, AggOp)>,
    },
    /// Combines any number of inputs.
    Union { inputs: Vec<Self> },
    /// De-duplicates `input`, with `key_cols` naming the columns involved.
    Distinct {
        input: Box<Self>,
        key_cols: Vec<usize>,
    },
    /// Difference of two inputs; presumably rows of `left` absent from `right` - confirm.
    Diff {
        left: Box<Self>,
        right: Box<Self>,
    },
    /// Recursive computation made of a `base` plan and a `recursive` plan.
    ///
    /// `scc_id` tags the node (presumably the strongly connected component of
    /// the rule graph it was built for - confirm). `delta_rel` and `full_rel`
    /// are relation ids attached to the node; `collect_relations` reports
    /// both as referenced.
    Fixpoint {
        scc_id: u32,
        base: Box<Self>,
        recursive: Box<Self>,
        delta_rel: RelId,
        full_rel: RelId,
    },
    /// Join variant that carries no child nodes.
    ///
    /// Relations are named through `rel_index` (id plus a string per entry)
    /// and the head relation through `head_rel_name` / `head_rel_id`. The
    /// remaining fields (`mask_name`, `schema_size`, `max_active_rules`,
    /// `head_projection`) are passed through untouched by this file; their
    /// semantics are defined by the consumer of the plan.
    TensorMaskedJoin {
        mask_name: String,
        schema_size: usize,
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
        rel_index: Vec<(RelId, String)>,
        head_rel_name: String,
        head_rel_id: RelId,
        max_active_rules: usize,
        head_projection: Vec<usize>,
    },
}
impl RirNode {
pub fn is_leaf(&self) -> bool {
matches!(self, RirNode::Scan { .. })
}
pub fn referenced_relations(&self) -> Vec<RelId> {
let mut rels = Vec::new();
self.collect_relations(&mut rels);
rels
}
fn collect_relations(&self, rels: &mut Vec<RelId>) {
match self {
RirNode::Unit => {}
RirNode::Scan { rel } => rels.push(*rel),
RirNode::Filter { input, .. } | RirNode::Project { input, .. } => {
input.collect_relations(rels);
}
RirNode::Join { left, right, .. } | RirNode::Diff { left, right } => {
left.collect_relations(rels);
right.collect_relations(rels);
}
RirNode::Union { inputs } => {
for input in inputs {
input.collect_relations(rels);
}
}
RirNode::GroupBy { input, .. } | RirNode::Distinct { input, .. } => {
input.collect_relations(rels);
}
RirNode::Fixpoint {
base,
recursive,
delta_rel,
full_rel,
..
} => {
base.collect_relations(rels);
recursive.collect_relations(rels);
rels.push(*delta_rel);
rels.push(*full_rel);
}
RirNode::TensorMaskedJoin { rel_index, .. } => {
for (rel_id, _) in rel_index {
rels.push(*rel_id);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::ScalarType;

    /// A scan is a leaf and references exactly its own relation.
    #[test]
    fn test_scan_node() {
        let node = RirNode::Scan { rel: RelId(1) };
        assert!(matches!(node, RirNode::Scan { rel: RelId(1) }));
        assert!(node.is_leaf());
        assert_eq!(node.referenced_relations(), vec![RelId(1)]);
    }

    /// A join is not a leaf and references the relations of both sides.
    #[test]
    fn test_join_node() {
        let left = Box::new(RirNode::Scan { rel: RelId(1) });
        let right = Box::new(RirNode::Scan { rel: RelId(2) });
        let join = RirNode::Join {
            left,
            right,
            left_keys: vec![0],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        assert!(matches!(join, RirNode::Join { .. }));
        assert!(!join.is_leaf());
        let rels = join.referenced_relations();
        assert!(rels.contains(&RelId(1)));
        assert!(rels.contains(&RelId(2)));
    }

    /// A fixpoint references its sub-plans' relations plus its own
    /// `delta_rel` and `full_rel`.
    #[test]
    fn test_fixpoint_node() {
        let base = Box::new(RirNode::Scan { rel: RelId(1) });
        let recursive = Box::new(RirNode::Scan { rel: RelId(2) });
        let fp = RirNode::Fixpoint {
            scc_id: 0,
            base,
            recursive,
            delta_rel: RelId(3),
            full_rel: RelId(4),
        };
        assert!(matches!(fp, RirNode::Fixpoint { scc_id: 0, .. }));
        assert!(!fp.is_leaf());
        let rels = fp.referenced_relations();
        assert!(rels.contains(&RelId(1)));
        assert!(rels.contains(&RelId(2)));
        assert!(rels.contains(&RelId(3)));
        assert!(rels.contains(&RelId(4)));
    }

    /// The join type is stored as given.
    #[test]
    fn test_anti_join() {
        let left = Box::new(RirNode::Scan { rel: RelId(1) });
        let right = Box::new(RirNode::Scan { rel: RelId(2) });
        let anti = RirNode::Join {
            left,
            right,
            left_keys: vec![0],
            right_keys: vec![0],
            join_type: JoinType::Anti,
        };
        // Assert on the whole pattern so the check cannot be skipped silently,
        // which an `if let` around the assertion would allow.
        assert!(matches!(
            anti,
            RirNode::Join {
                join_type: JoinType::Anti,
                ..
            }
        ));
    }

    /// Arithmetic expressions can be built from columns and constants.
    #[test]
    fn test_expr_arithmetic() {
        let expr = Expr::Add(
            Box::new(Expr::Column(0)),
            Box::new(Expr::Const(ConstValue::I64(1))),
        );
        assert!(matches!(expr, Expr::Add(_, _)));
    }

    /// A computed projection pairs an expression with a scalar type.
    #[test]
    fn test_project_expr_computed() {
        let proj = ProjectExpr::Computed(
            Expr::Add(
                Box::new(Expr::Column(0)),
                Box::new(Expr::Const(ConstValue::I64(1))),
            ),
            ScalarType::I64,
        );
        assert!(matches!(proj, ProjectExpr::Computed(_, _)));
    }
}