use crate::expression::Expr;
use crate::schema::{FileFormat, GenomicSchema};
use std::path::PathBuf;
/// A logical query plan: a tree of [`PlanNode`]s paired with a schema.
///
/// Plans are started with [`LogicalPlan::scan`] and grown by the chainable
/// builder methods, each of which wraps the current root in a new node.
#[derive(Debug, Clone)]
pub struct LogicalPlan {
/// Outermost node of the plan tree; its inputs are nested beneath it.
pub root: PlanNode,
/// Schema carried alongside the plan. `scan` initialises it from the file
/// format and the other builders in this file pass it through unchanged.
pub schema: GenomicSchema,
}
/// A node in the logical plan tree.
///
/// Every variant except `Scan` owns its input(s) in a `Box`, so a plan is a
/// finite tree whose leaves are `Scan` nodes.
#[derive(Debug, Clone)]
pub enum PlanNode {
/// Leaf node describing a file to read.
Scan {
/// Location of the file.
path: PathBuf,
/// Format the file is interpreted as.
format: FileFormat,
/// Optional list of column names attached to the scan; `None` when no
/// projection has been set (the state `LogicalPlan::scan` creates).
projection: Option<Vec<String>>,
},
/// Carries a predicate expression over the rows of `input`.
Filter {
input: Box<PlanNode>,
predicate: Expr,
},
/// Carries the list of column names to keep from `input`.
Select {
input: Box<PlanNode>,
columns: Vec<String>,
},
/// Carries a named expression to attach to `input` as a column.
WithColumn {
input: Box<PlanNode>,
name: String,
expr: Expr,
},
/// Carries a row count bound for `input`.
Limit {
input: Box<PlanNode>,
count: usize,
},
/// Carries a second kind of count bound for `input`.
///
/// NOTE(review): presumably caps how many records are read from the
/// source, as opposed to `Limit` capping rows returned — confirm against
/// the executor.
MaxScan {
input: Box<PlanNode>,
count: usize,
},
/// Combines two inputs.
Join {
left: Box<PlanNode>,
right: Box<PlanNode>,
/// Which kind of join to perform.
join_type: JoinType,
/// Column names the join is keyed on.
on: Vec<String>,
},
}
/// The kind of join requested by a [`PlanNode::Join`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
Inner,
Left,
Right,
Full,
/// NOTE(review): presumably an interval-overlap join rather than a
/// key-equality join — confirm against the executor.
Overlap,
}
impl LogicalPlan {
pub fn scan<P: Into<PathBuf>>(path: P, format: FileFormat) -> Self {
let schema = format.schema();
Self {
root: PlanNode::Scan {
path: path.into(),
format,
projection: None,
},
schema,
}
}
pub fn filter(self, predicate: Expr) -> Self {
Self {
root: PlanNode::Filter {
input: Box::new(self.root),
predicate,
},
schema: self.schema,
}
}
pub fn select(self, columns: &[&str]) -> Self {
Self {
root: PlanNode::Select {
input: Box::new(self.root),
columns: columns.iter().map(|s| s.to_string()).collect(),
},
schema: self.schema, }
}
pub fn with_column(self, name: &str, expr: Expr) -> Self {
Self {
root: PlanNode::WithColumn {
input: Box::new(self.root),
name: name.to_string(),
expr,
},
schema: self.schema, }
}
pub fn limit(self, count: usize) -> Self {
Self {
root: PlanNode::Limit {
input: Box::new(self.root),
count,
},
schema: self.schema,
}
}
pub fn max_scan(self, count: usize) -> Self {
Self {
root: PlanNode::MaxScan {
input: Box::new(self.root),
count,
},
schema: self.schema,
}
}
pub fn format(&self) -> Option<FileFormat> {
self.root.format()
}
}
impl PlanNode {
    /// Returns the file format of the scan this node ultimately reads from.
    ///
    /// Single-input nodes defer to their input; a `Join` reports the format of
    /// its left input only.
    pub fn format(&self) -> Option<FileFormat> {
        match self {
            PlanNode::Scan { format, .. } => Some(*format),
            PlanNode::Filter { input, .. }
            | PlanNode::Select { input, .. }
            | PlanNode::WithColumn { input, .. }
            | PlanNode::Limit { input, .. }
            | PlanNode::MaxScan { input, .. } => input.format(),
            PlanNode::Join { left, .. } => left.format(),
        }
    }

    /// Collects every filter predicate in this subtree.
    ///
    /// Predicates are returned outermost first; for a `Join`, the left input's
    /// predicates precede the right input's.
    pub fn filters(&self) -> Vec<&Expr> {
        let mut result = Vec::new();
        self.collect_filters(&mut result);
        result
    }

    /// Appends the predicates of this subtree to `acc` in pre-order.
    fn collect_filters<'a>(&'a self, acc: &mut Vec<&'a Expr>) {
        match self {
            PlanNode::Filter { input, predicate } => {
                acc.push(predicate);
                input.collect_filters(acc);
            }
            PlanNode::Select { input, .. }
            | PlanNode::WithColumn { input, .. }
            | PlanNode::Limit { input, .. }
            | PlanNode::MaxScan { input, .. } => input.collect_filters(acc),
            PlanNode::Join { left, right, .. } => {
                left.collect_filters(acc);
                right.collect_filters(acc);
            }
            PlanNode::Scan { .. } => {}
        }
    }

    /// Returns `true` if this subtree contains at least one `Filter` node.
    ///
    /// Stops at the first filter found instead of materialising the full list
    /// of predicates just to test it for emptiness.
    pub fn has_filters(&self) -> bool {
        match self {
            PlanNode::Filter { .. } => true,
            PlanNode::Scan { .. } => false,
            PlanNode::Select { input, .. }
            | PlanNode::WithColumn { input, .. }
            | PlanNode::Limit { input, .. }
            | PlanNode::MaxScan { input, .. } => input.has_filters(),
            PlanNode::Join { left, right, .. } => left.has_filters() || right.has_filters(),
        }
    }
}
impl LogicalPlan {
    /// Runs the rewrite passes in order: combine filters, push down filters,
    /// then prune columns.
    pub fn optimize(self) -> Self {
        self.combine_filters().push_down_filters().prune_columns()
    }

    /// Applies `PlanNode::combine_filters_node` to the root; the schema is kept.
    pub fn combine_filters(self) -> Self {
        let Self { root, schema } = self;
        Self {
            root: root.combine_filters_node(),
            schema,
        }
    }

    /// Applies `PlanNode::push_down_filters_node` to the root; the schema is kept.
    pub fn push_down_filters(self) -> Self {
        let Self { root, schema } = self;
        Self {
            root: root.push_down_filters_node(),
            schema,
        }
    }

    /// Applies `PlanNode::prune_columns_node` to the root; the schema is kept.
    pub fn prune_columns(self) -> Self {
        let Self { root, schema } = self;
        Self {
            root: root.prune_columns_node(),
            schema,
        }
    }

    /// Renders the plan tree as text, starting at indentation level 0.
    pub fn explain(&self) -> String {
        let top_level = 0;
        self.root.explain(top_level)
    }
}
impl PlanNode {
/// Merges every run of directly nested `Filter` nodes into a single `Filter`.
///
/// The merged predicate is built with `Expr::and`, outermost predicate first:
/// `Filter(a, Filter(b, Filter(c, x)))` becomes `Filter(a.and(b).and(c), x)`.
/// All other nodes are rebuilt around their rewritten inputs, including both
/// inputs of a `Join`. Filters separated by another node are left separate.
fn combine_filters_node(self) -> Self {
    match self {
        PlanNode::Filter { input, predicate } => {
            let mut combined = predicate;
            let mut below = *input;
            // Absorb the whole run of adjacent filters, not just the first
            // pair; otherwise three or more chained filters would leave a
            // second `Filter` behind.
            while let PlanNode::Filter { input, predicate } = below {
                combined = combined.and(predicate);
                below = *input;
            }
            PlanNode::Filter {
                input: Box::new(below.combine_filters_node()),
                predicate: combined,
            }
        }
        PlanNode::Select { input, columns } => PlanNode::Select {
            input: Box::new(input.combine_filters_node()),
            columns,
        },
        PlanNode::WithColumn { input, name, expr } => PlanNode::WithColumn {
            input: Box::new(input.combine_filters_node()),
            name,
            expr,
        },
        PlanNode::Limit { input, count } => PlanNode::Limit {
            input: Box::new(input.combine_filters_node()),
            count,
        },
        PlanNode::MaxScan { input, count } => PlanNode::MaxScan {
            input: Box::new(input.combine_filters_node()),
            count,
        },
        // Each side of a join is an independent subtree that may hold its own
        // filter chains.
        PlanNode::Join {
            left,
            right,
            join_type,
            on,
        } => PlanNode::Join {
            left: Box::new(left.combine_filters_node()),
            right: Box::new(right.combine_filters_node()),
            join_type,
            on,
        },
        scan @ PlanNode::Scan { .. } => scan,
    }
}
fn push_down_filters_node(self) -> Self {
match self {
PlanNode::Select { input, columns } => {
match *input {
PlanNode::Filter {
input: filter_input,
predicate,
} => {
let select_below = PlanNode::Select {
input: filter_input,
columns,
};
PlanNode::Filter {
input: Box::new(select_below.push_down_filters_node()),
predicate,
}
}
other => PlanNode::Select {
input: Box::new(other.push_down_filters_node()),
columns,
},
}
}
PlanNode::Filter { input, predicate } => PlanNode::Filter {
input: Box::new(input.push_down_filters_node()),
predicate,
},
PlanNode::WithColumn { input, name, expr } => PlanNode::WithColumn {
input: Box::new(input.push_down_filters_node()),
name,
expr,
},
PlanNode::Limit { input, count } => PlanNode::Limit {
input: Box::new(input.push_down_filters_node()),
count,
},
PlanNode::MaxScan { input, count } => PlanNode::MaxScan {
input: Box::new(input.push_down_filters_node()),
count,
},
other => other,
}
}
fn prune_columns_node(self) -> Self {
match self {
PlanNode::Filter { input, predicate } => PlanNode::Filter {
input: Box::new(input.prune_columns_node()),
predicate,
},
PlanNode::Select { input, columns } => PlanNode::Select {
input: Box::new(input.prune_columns_node()),
columns,
},
PlanNode::WithColumn { input, name, expr } => PlanNode::WithColumn {
input: Box::new(input.prune_columns_node()),
name,
expr,
},
PlanNode::Limit { input, count } => PlanNode::Limit {
input: Box::new(input.prune_columns_node()),
count,
},
other => other,
}
}
/// Renders this subtree as text, one node per line.
///
/// Each line is prefixed with `indent` copies of the indentation string and
/// every input is rendered one level deeper on the following line(s). The
/// result has no trailing newline. For a `Join`, the left subtree is rendered
/// first and the right subtree starts on its own line after it.
fn explain(&self, indent: usize) -> String {
    let prefix = " ".repeat(indent);
    match self {
        PlanNode::Scan {
            path,
            format,
            projection,
        } => {
            // The projection suffix is only shown when one has been set.
            let proj = match projection {
                Some(cols) => format!(" [{}]", cols.join(", ")),
                None => String::new(),
            };
            format!("{}Scan: {:?} ({:?}){}", prefix, path, format, proj)
        }
        PlanNode::Filter { input, predicate } => {
            format!(
                "{}Filter: {}\n{}",
                prefix,
                predicate,
                input.explain(indent + 1)
            )
        }
        PlanNode::Select { input, columns } => {
            format!(
                "{}Select: [{}]\n{}",
                prefix,
                columns.join(", "),
                input.explain(indent + 1)
            )
        }
        PlanNode::WithColumn { input, name, expr } => {
            format!(
                "{}WithColumn: {} = {}\n{}",
                prefix,
                name,
                expr,
                input.explain(indent + 1)
            )
        }
        PlanNode::Limit { input, count } => {
            format!("{}Limit: {}\n{}", prefix, count, input.explain(indent + 1))
        }
        PlanNode::MaxScan { input, count } => {
            format!(
                "{}MaxScan: {}\n{}",
                prefix,
                count,
                input.explain(indent + 1)
            )
        }
        PlanNode::Join {
            left,
            right,
            join_type,
            on,
        } => {
            // Child renderings carry no trailing newline, so the two sides
            // need an explicit separator; without it the right subtree's
            // first line is glued onto the left subtree's last line.
            format!(
                "{}Join: {:?} ON [{}]\n{}\n{}",
                prefix,
                join_type,
                on.join(", "),
                left.explain(indent + 1),
                right.explain(indent + 1)
            )
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expression::{col, lit};

    /// Plan that scans `test.vcf` as VCF; the starting point of every test.
    fn vcf_scan() -> LogicalPlan {
        LogicalPlan::scan("test.vcf", FileFormat::Vcf)
    }

    /// The `qual > 30.0` predicate shared by the filter tests.
    fn min_quality() -> Expr {
        col("qual").gt(lit(30.0))
    }

    #[test]
    fn test_scan_plan() {
        let scanned = vcf_scan();
        assert_eq!(scanned.format(), Some(FileFormat::Vcf));
    }

    #[test]
    fn test_filter_plan() {
        let filtered = vcf_scan().filter(min_quality());
        assert!(filtered.root.has_filters());
        assert_eq!(filtered.root.filters().len(), 1);
    }

    #[test]
    fn test_chained_filters() {
        let chained = vcf_scan().filter(min_quality()).filter(Expr::IsSnp);
        let predicate_count = chained.root.filters().len();
        assert_eq!(predicate_count, 2);
    }

    #[test]
    fn test_combine_filters() {
        let chained = vcf_scan().filter(min_quality()).filter(Expr::IsSnp);
        let combined = chained.combine_filters();
        assert_eq!(combined.root.filters().len(), 1);
    }

    #[test]
    fn test_select() {
        let projected = vcf_scan().select(&["chrom", "pos"]);
        if let PlanNode::Select { columns, .. } = projected.root {
            assert_eq!(columns, vec!["chrom", "pos"]);
        } else {
            panic!("Expected Select node");
        }
    }

    #[test]
    fn test_limit() {
        let limited = vcf_scan().limit(100);
        if let PlanNode::Limit { count, .. } = limited.root {
            assert_eq!(count, 100);
        } else {
            panic!("Expected Limit node");
        }
    }

    #[test]
    fn test_complex_plan() {
        let optimized = vcf_scan()
            .filter(min_quality())
            .filter(Expr::IsSnp)
            .select(&["chrom", "pos", "ref", "alt"])
            .limit(1000)
            .optimize();
        assert_eq!(optimized.root.filters().len(), 1);
    }

    #[test]
    fn test_explain() {
        let explanation = vcf_scan()
            .filter(min_quality())
            .select(&["chrom", "pos"])
            .explain();
        for node_name in ["Scan", "Filter", "Select"].iter() {
            assert!(explanation.contains(node_name));
        }
    }
}