use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::{
logical_plan::{EmptyRelation, Limit, LogicalPlan},
utils::from_plan,
};
/// Optimizer rule that replaces a `Limit` which can never produce any row with an
/// `EmptyRelation`.
///
/// A `Limit` is considered unproductive when its `fetch` is `0`, or when the `skip` of the
/// nearest enclosing `Limit` (looking through `Projection` and `Sort` nodes only) is greater
/// than or equal to its `fetch`.
#[derive(Default)]
pub struct EliminateLimit;

impl EliminateLimit {
    /// Creates a new instance of the rule.
    ///
    /// The rule carries no state, so this is equivalent to `EliminateLimit::default()`.
    pub fn new() -> Self {
        Self
    }
}
/// Limit-related context handed down from the nodes above the plan node being rewritten.
enum Ancestor {
/// An enclosing `Limit` applies to the current node; `skip` is that limit's `skip` value.
/// The context is built when recursing below a `Limit` and is kept while descending
/// through `Projection` and `Sort` nodes.
FromLimit { skip: Option<usize> },
/// No enclosing `Limit` applies: used at the root and below any node other than
/// `Limit`, `Projection` or `Sort`.
NotRelevant,
}
/// Recursively rewrites `plan`, swapping every `Limit` that cannot emit a row for an
/// `EmptyRelation` carrying the schema of that limit's input.
///
/// A `Limit` with `fetch = Some(n)` is replaced when `n == 0`, or when `ancestor` is
/// `Ancestor::FromLimit` and its `skip` (treated as `0` when absent) is at least `n`.
/// A `Limit` without `fetch` is never replaced.
///
/// Children are rewritten under the following context:
/// * below a `Limit`: `Ancestor::FromLimit` carrying that limit's own `skip`;
/// * below a `Projection` or `Sort`: the incoming `ancestor`, unchanged;
/// * below any other node: `Ancestor::NotRelevant`.
///
/// The node itself is then rebuilt with `from_plan` from its own expressions and the
/// rewritten children. The first error produced by a child rewrite or by `from_plan` is
/// returned as-is.
fn eliminate_limit(
    _optimizer: &EliminateLimit,
    ancestor: &Ancestor,
    plan: &LogicalPlan,
    _optimizer_config: &OptimizerConfig,
) -> Result<LogicalPlan> {
    // Only initialised on the `Limit` path; declared here so the reference taken in that
    // arm stays valid for the recursion further down.
    let limit_context;

    let child_context: &Ancestor = match plan {
        LogicalPlan::Limit(Limit {
            skip, fetch, input, ..
        }) => {
            let rows_skipped_above = match ancestor {
                Ancestor::FromLimit { skip: above } => above.unwrap_or(0),
                Ancestor::NotRelevant => 0,
            };

            // With no `fetch` the limit is unbounded and must be kept.
            let never_emits =
                fetch.map_or(false, |limit| limit == 0 || rows_skipped_above >= limit);
            if never_emits {
                return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
                    produce_one_row: false,
                    schema: input.schema().clone(),
                }));
            }

            limit_context = Ancestor::FromLimit { skip: *skip };
            &limit_context
        }
        LogicalPlan::Projection { .. } | LogicalPlan::Sort { .. } => ancestor,
        _ => &Ancestor::NotRelevant,
    };

    let rewritten_inputs = plan
        .inputs()
        .into_iter()
        .map(|child| eliminate_limit(_optimizer, child_context, child, _optimizer_config))
        .collect::<Result<Vec<_>>>()?;

    from_plan(plan, &plan.expressions(), &rewritten_inputs)
}
impl OptimizerRule for EliminateLimit {
    /// Applies the rule to `plan`, starting at the root with no enclosing `Limit` in effect.
    ///
    /// Returns the rewritten plan, or the first error raised while rewriting.
    fn optimize(
        &self,
        plan: &LogicalPlan,
        optimizer_config: &OptimizerConfig,
    ) -> Result<LogicalPlan> {
        // Nothing sits above the root, so no ancestor `skip` applies to it.
        let root_context = Ancestor::NotRelevant;
        eliminate_limit(self, &root_context, plan, optimizer_config)
    }

    /// Returns the identifier of this rule.
    fn name(&self) -> &str {
        "eliminate_limit"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use datafusion_common::Column;
    use datafusion_expr::{
        col,
        logical_plan::{builder::LogicalPlanBuilder, JoinType},
        sum,
    };

    /// Wraps `scan` in the `SUM(b)` grouped-by-`a` aggregate shared by most tests and
    /// returns the builder so each test can stack its own limits on top.
    fn grouped_sum(scan: LogicalPlan) -> LogicalPlanBuilder {
        LogicalPlanBuilder::from(scan)
            .aggregate(vec![col("a")], vec![sum(col("b"))])
            .unwrap()
    }

    /// Runs `EliminateLimit` on `plan` and checks both the debug rendering of the result
    /// and that the output schema equals the input schema.
    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
        let optimized_plan = EliminateLimit::new()
            .optimize(plan, &OptimizerConfig::new())
            .expect("failed to optimize plan");
        assert_eq!(format!("{:?}", optimized_plan), expected);
        assert_eq!(plan.schema(), optimized_plan.schema());
    }

    /// A `fetch = 0` limit at the root collapses the whole plan.
    #[test]
    fn limit_0_root() {
        let plan = grouped_sum(test_table_scan().unwrap())
            .limit(None, Some(0))
            .unwrap()
            .build()
            .unwrap();

        assert_optimized_plan_eq(&plan, "EmptyRelation");
    }

    /// Only the union branch containing the `fetch = 0` limit is collapsed.
    #[test]
    fn limit_0_nested() {
        let scan = test_table_scan().unwrap();
        let untouched_branch = grouped_sum(scan.clone()).build().unwrap();
        let plan = grouped_sum(scan)
            .limit(None, Some(0))
            .unwrap()
            .union(untouched_branch)
            .unwrap()
            .build()
            .unwrap();

        let expected = "Union\
            \n EmptyRelation\
            \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
            \n TableScan: test";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// The outer limit skips as many rows as the inner limit fetches.
    #[test]
    fn limit_fetch_with_ancestor_limit_skip() {
        let plan = grouped_sum(test_table_scan().unwrap())
            .limit(None, Some(2))
            .unwrap()
            .limit(Some(2), None)
            .unwrap()
            .build()
            .unwrap();

        let expected = "Limit: skip=2, fetch=None\
            \n EmptyRelation";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// The ancestor skip is still honoured when a sort sits between the two limits.
    #[test]
    fn multi_limit_offset_sort_eliminate() {
        let plan = grouped_sum(test_table_scan().unwrap())
            .limit(None, Some(2))
            .unwrap()
            .sort(vec![col("a")])
            .unwrap()
            .limit(Some(2), Some(1))
            .unwrap()
            .build()
            .unwrap();

        let expected = "Limit: skip=2, fetch=1\
            \n Sort: #test.a\
            \n EmptyRelation";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// An ancestor limit without `skip` leaves the inner limit in place.
    #[test]
    fn limit_fetch_with_ancestor_limit_fetch() {
        let plan = grouped_sum(test_table_scan().unwrap())
            .limit(None, Some(2))
            .unwrap()
            .sort(vec![col("a")])
            .unwrap()
            .limit(None, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let expected = "Limit: skip=None, fetch=1\
            \n Sort: #test.a\
            \n Limit: skip=None, fetch=2\
            \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
            \n TableScan: test";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// The ancestor skip exceeds the inner fetch, so the inner limit is eliminated.
    #[test]
    fn limit_with_ancestor_limit() {
        let plan = grouped_sum(test_table_scan().unwrap())
            .limit(Some(2), Some(1))
            .unwrap()
            .sort(vec![col("a")])
            .unwrap()
            .limit(Some(3), Some(1))
            .unwrap()
            .build()
            .unwrap();

        let expected = "Limit: skip=3, fetch=1\
            \n Sort: #test.a\
            \n EmptyRelation";
        assert_optimized_plan_eq(&plan, expected);
    }

    /// A join between the two limits discards the ancestor context, so nothing changes.
    #[test]
    fn limit_join_with_ancestor_limit() {
        let left_scan = test_table_scan().unwrap();
        let right_scan = test_table_scan_with_name("test1").unwrap();
        let plan = LogicalPlanBuilder::from(left_scan)
            .limit(Some(2), Some(1))
            .unwrap()
            .join_using(
                &right_scan,
                JoinType::Inner,
                vec![Column::from_name("a".to_string())],
            )
            .unwrap()
            .limit(Some(3), Some(1))
            .unwrap()
            .build()
            .unwrap();

        let expected = "Limit: skip=3, fetch=1\
            \n Inner Join: Using #test.a = #test1.a\
            \n Limit: skip=2, fetch=1\
            \n TableScan: test\
            \n TableScan: test1";
        assert_optimized_plan_eq(&plan, expected);
    }
}