use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::rewrite::TreeNodeRewritable;
use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
use std::sync::Arc;
/// Physical optimizer rule that runs [`check_finiteness_requirements`] over
/// every node of a physical plan (see its [`PhysicalOptimizerRule`] impl).
///
/// The rule carries no configuration or state of its own, so it is a
/// zero-sized type and `new()` and `default()` are interchangeable.
#[derive(Default)]
pub struct PipelineChecker {}

impl PipelineChecker {
    /// Creates a new [`PipelineChecker`].
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for PipelineChecker {
    /// Wraps `plan` in a [`PipelineStatePropagator`], runs
    /// [`check_finiteness_requirements`] over it through `transform_up`, and
    /// hands back the plan held by the resulting state.
    ///
    /// The configuration argument is not consulted.
    ///
    /// # Errors
    ///
    /// Any error produced while the traversal runs is returned unchanged.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        PipelineStatePropagator::new(plan)
            .transform_up(&check_finiteness_requirements)
            .map(|checked| checked.plan)
    }

    /// Name under which this rule is reported: `"PipelineChecker"`.
    fn name(&self) -> &str {
        "PipelineChecker"
    }

    /// Always `true`.
    ///
    /// NOTE(review): presumably this asks the caller to verify that the plan
    /// schema is unchanged after the rule runs — confirm against the
    /// `PhysicalOptimizerRule` trait definition.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Per-node state used while [`check_finiteness_requirements`] is applied to a
/// plan: the plan node itself plus the unboundedness flags computed for it and
/// for its direct children.
#[derive(Clone, Debug)]
pub struct PipelineStatePropagator {
/// The plan node this state describes.
pub(crate) plan: Arc<dyn ExecutionPlan>,
/// Value returned by `plan.unbounded_output(..)` for this node; starts out
/// as `false` until `check_finiteness_requirements` has filled it in.
pub(crate) unbounded: bool,
/// One flag per child of `plan`, in child order, holding each child's
/// `unbounded` value; starts out as all `false`.
pub(crate) children_unbounded: Vec<bool>,
}
impl PipelineStatePropagator {
    /// Builds the initial state for `plan`.
    ///
    /// The node starts out marked as not unbounded, and `children_unbounded`
    /// gets one `false` entry for every child that `plan.children()` reports.
    pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
        let children_unbounded = vec![false; plan.children().len()];
        Self {
            plan,
            unbounded: false,
            children_unbounded,
        }
    }
}
impl TreeNodeRewritable for PipelineStatePropagator {
    /// Applies `transform` to a fresh [`PipelineStatePropagator`] built for
    /// each child of `self.plan`, then rebuilds this node from the results.
    ///
    /// * A node without children is returned as-is.
    /// * Otherwise the returned state keeps `self.unbounded`, replaces
    ///   `children_unbounded` with the `unbounded` flag of every transformed
    ///   child (in child order), and obtains its plan from
    ///   `with_new_children_if_necessary` using the transformed child plans.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error produced by `transform`, or the
    /// error from `with_new_children_if_necessary`.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        let child_plans = self.plan.children();
        if child_plans.is_empty() {
            // Leaf node: nothing to rewrite.
            return Ok(self);
        }

        let rewritten = child_plans
            .into_iter()
            .map(PipelineStatePropagator::new)
            .map(transform)
            .collect::<Result<Vec<_>>>()?;

        // Split every transformed child into its flag and its plan in one pass.
        let (children_unbounded, plans): (Vec<bool>, Vec<_>) = rewritten
            .into_iter()
            .map(|child| (child.unbounded, child.plan))
            .unzip();

        let plan = with_new_children_if_necessary(self.plan, plans)?;
        Ok(PipelineStatePropagator {
            plan,
            unbounded: self.unbounded,
            children_unbounded,
        })
    }
}
/// Computes the `unbounded` flag of a single plan node.
///
/// Calls `unbounded_output` on the node's plan with the flags already
/// gathered for its children and stores the answer in the returned state;
/// the plan and the children flags are carried over untouched.
///
/// # Errors
///
/// Returns whatever error `unbounded_output` reports. On success the result
/// is always `Some(..)`.
pub fn check_finiteness_requirements(
    input: PipelineStatePropagator,
) -> Result<Option<PipelineStatePropagator>> {
    let PipelineStatePropagator {
        plan,
        children_unbounded,
        ..
    } = input;
    let unbounded = plan.unbounded_output(&children_unbounded)?;
    Ok(Some(PipelineStatePropagator {
        plan,
        unbounded,
        children_unbounded,
    }))
}
// SQL-level tests. Every test builds a `QueryCase` from a SQL string, a list
// of source-boundedness scenarios (each carrying an `expect_fail` flag) and an
// `error_operator` string, then awaits `QueryCase::run`.
// NOTE(review): how `run` interprets `expect_fail` and `error_operator` lives
// in `test_utils`, not in this file — presumably it plans the query once per
// scenario and checks that planning fails exactly when flagged, with an error
// mentioning `error_operator`; confirm there. The `source_types` pair is
// presumably (left input, right input) — confirm as well.
#[cfg(test)]
mod sql_tests {
use super::*;
use crate::physical_optimizer::test_utils::{
BinaryTestCase, QueryCase, SourceType, UnaryTestCase,
};
// LEFT JOIN: flagged to fail whenever the second source is unbounded
// (test2, test3); the other two combinations are expected to pass.
#[tokio::test]
async fn test_hash_left_join_swap() -> Result<()> {
let test1 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Bounded),
expect_fail: false,
};
let test2 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Unbounded),
expect_fail: true,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: true,
};
let test4 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
expect_fail: false,
};
let case = QueryCase {
sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1"
.to_string(),
cases: vec![
Arc::new(test1),
Arc::new(test2),
Arc::new(test3),
Arc::new(test4),
],
error_operator: "Join Error".to_string(),
};
case.run().await?;
Ok(())
}
// RIGHT JOIN: mirror image of the LEFT JOIN table — flagged to fail whenever
// the first source is unbounded (test1, test2).
#[tokio::test]
async fn test_hash_right_join_swap() -> Result<()> {
let test1 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Bounded),
expect_fail: true,
};
let test2 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Unbounded),
expect_fail: true,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: false,
};
let test4 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
expect_fail: false,
};
let case = QueryCase {
sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1"
.to_string(),
cases: vec![
Arc::new(test1),
Arc::new(test2),
Arc::new(test3),
Arc::new(test4),
],
error_operator: "Join Error".to_string(),
};
case.run().await?;
Ok(())
}
// INNER JOIN: only the scenario with both sources unbounded (test2) is
// flagged to fail.
#[tokio::test]
async fn test_hash_inner_join_swap() -> Result<()> {
let test1 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Bounded),
expect_fail: false,
};
let test2 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Unbounded),
expect_fail: true,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: false,
};
let test4 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
expect_fail: false,
};
let case = QueryCase {
sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1"
.to_string(),
cases: vec![
Arc::new(test1),
Arc::new(test2),
Arc::new(test3),
Arc::new(test4),
],
error_operator: "Join Error".to_string(),
};
case.run().await?;
Ok(())
}
// FULL JOIN: every scenario with at least one unbounded source is flagged to
// fail; only bounded/bounded (test4) is expected to pass.
#[tokio::test]
async fn test_hash_full_outer_join_swap() -> Result<()> {
let test1 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Bounded),
expect_fail: true,
};
let test2 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Unbounded),
expect_fail: true,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: true,
};
let test4 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
expect_fail: false,
};
let case = QueryCase {
sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1"
.to_string(),
cases: vec![
Arc::new(test1),
Arc::new(test2),
Arc::new(test3),
Arc::new(test4),
],
error_operator: "Join Error".to_string(),
};
case.run().await?;
Ok(())
}
// GROUP BY aggregate: the unbounded source is flagged to fail, the bounded
// one to pass.
#[tokio::test]
async fn test_aggregate() -> Result<()> {
let test1 = UnaryTestCase {
source_type: SourceType::Bounded,
expect_fail: false,
};
let test2 = UnaryTestCase {
source_type: SourceType::Unbounded,
expect_fail: true,
};
let case = QueryCase {
sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(),
cases: vec![Arc::new(test1), Arc::new(test2)],
error_operator: "Aggregate Error".to_string(),
};
case.run().await?;
Ok(())
}
// Window aggregate with PARTITION BY and a LIMIT: the unbounded source is
// flagged to fail; note the expected operator string is "Sort Error".
#[tokio::test]
async fn test_window_agg_hash_partition() -> Result<()> {
let test1 = UnaryTestCase {
source_type: SourceType::Bounded,
expect_fail: false,
};
let test2 = UnaryTestCase {
source_type: SourceType::Unbounded,
expect_fail: true,
};
let case = QueryCase {
sql: "SELECT
c9,
SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1
FROM test
LIMIT 5".to_string(),
cases: vec![Arc::new(test1), Arc::new(test2)],
error_operator: "Sort Error".to_string()
};
case.run().await?;
Ok(())
}
// Window aggregate without PARTITION BY: same expectations as the partitioned
// variant, again with "Sort Error" as the operator string.
#[tokio::test]
async fn test_window_agg_single_partition() -> Result<()> {
let test1 = UnaryTestCase {
source_type: SourceType::Bounded,
expect_fail: false,
};
let test2 = UnaryTestCase {
source_type: SourceType::Unbounded,
expect_fail: true,
};
let case = QueryCase {
sql: "SELECT
c9,
SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1
FROM test".to_string(),
cases: vec![Arc::new(test1), Arc::new(test2)],
error_operator: "Sort Error".to_string()
};
case.run().await?;
Ok(())
}
// CROSS JOIN: every scenario with at least one unbounded source is flagged to
// fail; only bounded/bounded (test4) is expected to pass.
#[tokio::test]
async fn test_hash_cross_join() -> Result<()> {
let test1 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Bounded),
expect_fail: true,
};
let test2 = BinaryTestCase {
source_types: (SourceType::Unbounded, SourceType::Unbounded),
expect_fail: true,
};
let test3 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Unbounded),
expect_fail: true,
};
let test4 = BinaryTestCase {
source_types: (SourceType::Bounded, SourceType::Bounded),
expect_fail: false,
};
let case = QueryCase {
sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(),
cases: vec![
Arc::new(test1),
Arc::new(test2),
Arc::new(test3),
Arc::new(test4),
],
error_operator: "Cross Join Error".to_string(),
};
case.run().await?;
Ok(())
}
// EXPLAIN ANALYZE: the unbounded source is flagged to fail, the bounded one
// to pass.
#[tokio::test]
async fn test_analyzer() -> Result<()> {
let test1 = UnaryTestCase {
source_type: SourceType::Bounded,
expect_fail: false,
};
let test2 = UnaryTestCase {
source_type: SourceType::Unbounded,
expect_fail: true,
};
let case = QueryCase {
sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(),
cases: vec![Arc::new(test1), Arc::new(test2)],
error_operator: "Analyze Error".to_string(),
};
case.run().await?;
Ok(())
}
}