use std::{
cmp::min,
fmt::{self, Display},
sync::Arc,
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use crate::datasource::TableProvider;
use crate::sql::parser::FileType;
use super::expr::Expr;
use super::extension::UserDefinedLogicalNode;
use super::{
col,
display::{GraphvizVisitor, IndentVisitor},
};
use crate::logical_plan::dfschema::DFSchemaRef;
/// Kind of join carried by [`LogicalPlan::Join`] in its `join_type` field.
#[derive(Debug, Clone, Copy)]
pub enum JoinType {
/// Inner join.
Inner,
/// Left join.
Left,
/// Right join.
Right,
}
/// A node in a logical query plan tree.
///
/// Every variant is one kind of plan node. Nodes that have children hold them
/// as `Arc<LogicalPlan>` (or as a `Vec<LogicalPlan>` for `Union`), so a single
/// `LogicalPlan` value is the root of a whole tree. `Debug` is implemented by
/// hand elsewhere in this file rather than derived.
#[derive(Clone)]
pub enum LogicalPlan {
/// A list of expressions evaluated over a single input.
Projection {
/// Expressions of the projection (returned by `expressions()`).
expr: Vec<Expr>,
/// Child plan.
input: Arc<LogicalPlan>,
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
},
/// A predicate expression applied to a single input.
///
/// Carries no schema of its own: `schema()` forwards the input's schema.
Filter {
/// Predicate expression (returned by `expressions()`).
predicate: Expr,
/// Child plan.
input: Arc<LogicalPlan>,
},
/// Grouping and aggregate expressions over a single input.
Aggregate {
/// Child plan.
input: Arc<LogicalPlan>,
/// Grouping expressions; listed first by `expressions()`.
group_expr: Vec<Expr>,
/// Aggregate expressions; listed after `group_expr` by `expressions()`.
aggr_expr: Vec<Expr>,
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
},
/// Sort expressions over a single input.
///
/// Carries no schema of its own: `schema()` forwards the input's schema.
Sort {
/// Sort expressions (returned by `expressions()`).
expr: Vec<Expr>,
/// Child plan.
input: Arc<LogicalPlan>,
},
/// Join of two inputs on pairs of column names.
Join {
/// Left child plan.
left: Arc<LogicalPlan>,
/// Right child plan.
right: Arc<LogicalPlan>,
/// Pairs of column names; displayed as `left = right` and turned into
/// column expressions by `expressions()`.
on: Vec<(String, String)>,
/// Kind of join.
join_type: JoinType,
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
},
/// Repartitioning of a single input according to a [`Partitioning`] scheme.
///
/// Carries no schema of its own: `schema()` forwards the input's schema.
Repartition {
/// Child plan.
input: Arc<LogicalPlan>,
/// Scheme to repartition by.
partitioning_scheme: Partitioning,
},
/// Union of several input plans.
Union {
/// Child plans, held by value rather than behind `Arc`.
inputs: Vec<LogicalPlan>,
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
/// Optional alias. NOTE(review): not read anywhere in this file.
alias: Option<String>,
},
/// Scan of a table supplied by a [`TableProvider`]. Leaf node.
TableScan {
/// Table name as printed by `display()`.
table_name: String,
/// Provider backing the scan.
source: Arc<dyn TableProvider>,
/// Optional list of selected column indices (the tests in this file pass
/// `Some(vec![0, 3])` to select the first and fourth fields).
projection: Option<Vec<usize>>,
/// Schema reported by `schema()` for this node.
projected_schema: DFSchemaRef,
/// Filter expressions attached to the scan; printed by `display()` when
/// non-empty. NOTE(review): presumably pushed down to the provider — confirm.
filters: Vec<Expr>,
/// Optional row limit attached to the scan; printed by `display()` when set.
limit: Option<usize>,
},
/// Relation with no input plan. Leaf node.
EmptyRelation {
/// NOTE(review): presumably selects between yielding a single row and
/// yielding none — confirm against the physical planner.
produce_one_row: bool,
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
},
/// Limit of `n` over a single input.
///
/// Carries no schema of its own: `schema()` forwards the input's schema.
Limit {
/// The limit value printed by `display()`.
n: usize,
/// Child plan.
input: Arc<LogicalPlan>,
},
/// Description of an external table to create. Leaf node.
CreateExternalTable {
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
/// Table name as printed by `display()`.
name: String,
/// Location string of the table's data.
location: String,
/// File type of the external data.
file_type: FileType,
/// NOTE(review): presumably whether the file has a header row — confirm.
has_header: bool,
},
/// Explain node wrapping another plan.
///
/// Treated as a leaf by `inputs()` and `accept()`: the wrapped `plan` is not
/// visited.
Explain {
/// Verbosity flag of the explain request.
verbose: bool,
/// The plan being explained.
plan: Arc<LogicalPlan>,
/// Already-rendered plan texts.
stringified_plans: Vec<StringifiedPlan>,
/// Schema reported by `schema()` for this node.
schema: DFSchemaRef,
},
/// Plan node supplied through the [`UserDefinedLogicalNode`] extension trait.
Extension {
/// The user-defined node; schema, expressions, inputs and display are all
/// delegated to it.
node: Arc<dyn UserDefinedLogicalNode + Send + Sync>,
},
}
impl LogicalPlan {
    /// Returns the schema this node reports as its output.
    ///
    /// `Filter`, `Sort`, `Repartition` and `Limit` store no schema and forward
    /// the schema of their single input; `Extension` asks the user-defined
    /// node; every other variant returns the schema it stores.
    pub fn schema(&self) -> &DFSchemaRef {
        match self {
            // Pass-through nodes: same schema as the child.
            LogicalPlan::Filter { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Repartition { input, .. }
            | LogicalPlan::Limit { input, .. } => input.schema(),
            LogicalPlan::TableScan {
                projected_schema, ..
            } => projected_schema,
            LogicalPlan::Extension { node } => node.schema(),
            // Nodes that carry their own schema.
            LogicalPlan::EmptyRelation { schema, .. }
            | LogicalPlan::Projection { schema, .. }
            | LogicalPlan::Aggregate { schema, .. }
            | LogicalPlan::Join { schema, .. }
            | LogicalPlan::CreateExternalTable { schema, .. }
            | LogicalPlan::Explain { schema, .. }
            | LogicalPlan::Union { schema, .. } => schema,
        }
    }

    /// Collects the schemas found while walking down from this node.
    ///
    /// A node's own schema (when it stores one) comes first, followed by the
    /// schemas gathered from its inputs; for `Join` the left side precedes the
    /// right side. `Union`, `Extension` and `Explain` contribute only their own
    /// schema and are not descended into.
    pub fn all_schemas(&self) -> Vec<&DFSchemaRef> {
        match self {
            LogicalPlan::TableScan {
                projected_schema, ..
            } => vec![projected_schema],
            LogicalPlan::Aggregate { input, schema, .. }
            | LogicalPlan::Projection { input, schema, .. } => {
                let mut collected = vec![schema];
                collected.extend(input.all_schemas());
                collected
            }
            LogicalPlan::Join {
                left,
                right,
                schema,
                ..
            } => {
                let mut collected = vec![schema];
                collected.extend(left.all_schemas());
                collected.extend(right.all_schemas());
                collected
            }
            LogicalPlan::Extension { node } => vec![node.schema()],
            LogicalPlan::Union { schema, .. }
            | LogicalPlan::Explain { schema, .. }
            | LogicalPlan::EmptyRelation { schema, .. }
            | LogicalPlan::CreateExternalTable { schema, .. } => vec![schema],
            // Schema-less nodes contribute nothing themselves.
            LogicalPlan::Limit { input, .. }
            | LogicalPlan::Repartition { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Filter { input, .. } => input.all_schemas(),
        }
    }

    /// Builds the two-column schema (`plan_type`, `plan`) used for explain
    /// output; both columns are non-nullable `Utf8`.
    pub fn explain_schema() -> SchemaRef {
        let fields = vec![
            Field::new("plan_type", DataType::Utf8, false),
            Field::new("plan", DataType::Utf8, false),
        ];
        SchemaRef::new(Schema::new(fields))
    }

    /// Returns clones of the expressions held directly by this node.
    ///
    /// Expressions of child plans are not included. `Aggregate` lists its
    /// grouping expressions before its aggregate expressions, and `Join`
    /// yields a column expression for the left then the right name of each
    /// `on` pair. Nodes without expressions return an empty vector.
    pub fn expressions(self: &LogicalPlan) -> Vec<Expr> {
        match self {
            LogicalPlan::Projection { expr, .. } | LogicalPlan::Sort { expr, .. } => {
                expr.clone()
            }
            LogicalPlan::Filter { predicate, .. } => vec![predicate.clone()],
            LogicalPlan::Repartition {
                partitioning_scheme,
                ..
            } => {
                // Only hash partitioning carries expressions.
                if let Partitioning::Hash(keys, _) = partitioning_scheme {
                    keys.clone()
                } else {
                    Vec::new()
                }
            }
            LogicalPlan::Aggregate {
                group_expr,
                aggr_expr,
                ..
            } => group_expr
                .iter()
                .chain(aggr_expr.iter())
                .cloned()
                .collect(),
            LogicalPlan::Join { on, .. } => {
                let mut columns = Vec::with_capacity(on.len() * 2);
                for (left_name, right_name) in on {
                    columns.push(col(left_name));
                    columns.push(col(right_name));
                }
                columns
            }
            LogicalPlan::Extension { node } => node.expressions(),
            LogicalPlan::TableScan { .. }
            | LogicalPlan::EmptyRelation { .. }
            | LogicalPlan::Limit { .. }
            | LogicalPlan::CreateExternalTable { .. }
            | LogicalPlan::Explain { .. }
            | LogicalPlan::Union { .. } => Vec::new(),
        }
    }

    /// Returns references to the direct child plans of this node.
    ///
    /// `Join` yields left then right; `Extension` delegates to the
    /// user-defined node. `TableScan`, `EmptyRelation`, `CreateExternalTable`
    /// and `Explain` report no inputs.
    pub fn inputs(self: &LogicalPlan) -> Vec<&LogicalPlan> {
        match self {
            LogicalPlan::Projection { input, .. }
            | LogicalPlan::Filter { input, .. }
            | LogicalPlan::Repartition { input, .. }
            | LogicalPlan::Aggregate { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Limit { input, .. } => vec![input],
            LogicalPlan::Join { left, right, .. } => vec![left, right],
            LogicalPlan::Union { inputs, .. } => inputs.iter().collect(),
            LogicalPlan::Extension { node } => node.inputs(),
            LogicalPlan::TableScan { .. }
            | LogicalPlan::EmptyRelation { .. }
            | LogicalPlan::CreateExternalTable { .. }
            | LogicalPlan::Explain { .. } => vec![],
        }
    }
}
/// Partitioning scheme carried by [`LogicalPlan::Repartition`].
#[derive(Debug, Clone)]
pub enum Partitioning {
/// Round-robin batch partitioning; the value is displayed as `partition_count`.
RoundRobinBatch(usize),
/// Hash partitioning on the given expressions; the `usize` is displayed as
/// `partition_count`.
Hash(Vec<Expr>, usize),
}
/// Visitor driven by [`LogicalPlan::accept`] over a plan tree.
///
/// For each node, `accept` calls `pre_visit`, then visits the node's children,
/// then calls `post_visit`. Returning `Ok(false)` from either hook makes
/// `accept` stop and return `Ok(false)`; returning `Err` aborts the walk with
/// that error.
pub trait PlanVisitor {
/// Error type the hooks may return; `accept` propagates it unchanged.
type Error;
/// Called on a node before any of its children are visited.
///
/// Return `Ok(true)` to continue into the children, `Ok(false)` to stop.
fn pre_visit(&mut self, plan: &LogicalPlan)
-> std::result::Result<bool, Self::Error>;
/// Called on a node after all of its children were visited without stopping.
///
/// The default implementation does nothing and returns `Ok(true)`.
fn post_visit(
&mut self,
_plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
Ok(true)
}
}
impl LogicalPlan {
    /// Walks this plan tree depth-first, driving `visitor`.
    ///
    /// For every node the order is: `pre_visit`, then each child in order
    /// (`Join` visits left before right), then `post_visit`. As soon as a hook
    /// or a child traversal yields `Ok(false)`, the walk stops and `Ok(false)`
    /// is returned without calling the remaining hooks. Errors from the
    /// visitor are returned unchanged.
    ///
    /// `TableScan`, `EmptyRelation`, `CreateExternalTable` and `Explain` are
    /// treated as leaves.
    pub fn accept<V>(&self, visitor: &mut V) -> std::result::Result<bool, V::Error>
    where
        V: PlanVisitor,
    {
        let descend = visitor.pre_visit(self)?;
        if !descend {
            return Ok(false);
        }

        let children_completed = match self {
            // Single-input nodes.
            LogicalPlan::Projection { input, .. }
            | LogicalPlan::Filter { input, .. }
            | LogicalPlan::Repartition { input, .. }
            | LogicalPlan::Aggregate { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Limit { input, .. } => input.accept(visitor)?,
            LogicalPlan::Join { left, right, .. } => {
                // The right side is only visited when the left side did not stop.
                if left.accept(visitor)? {
                    right.accept(visitor)?
                } else {
                    false
                }
            }
            LogicalPlan::Union { inputs, .. } => {
                let mut completed = true;
                for child in inputs {
                    if !child.accept(visitor)? {
                        completed = false;
                        break;
                    }
                }
                completed
            }
            LogicalPlan::Extension { node } => {
                let mut completed = true;
                for child in node.inputs() {
                    if !child.accept(visitor)? {
                        completed = false;
                        break;
                    }
                }
                completed
            }
            // Leaves: nothing to descend into.
            LogicalPlan::TableScan { .. }
            | LogicalPlan::EmptyRelation { .. }
            | LogicalPlan::CreateExternalTable { .. }
            | LogicalPlan::Explain { .. } => true,
        };

        if !children_completed {
            return Ok(false);
        }

        // The post-visit verdict is the result of the whole call.
        visitor.post_visit(self)
    }
}
impl LogicalPlan {
    /// Returns a value whose `Display` output is the whole plan tree rendered
    /// by [`IndentVisitor`] with schema output disabled.
    ///
    /// A failure reported while visiting the plan is returned from `fmt` as
    /// `fmt::Error` instead of panicking.
    pub fn display_indent(&self) -> impl fmt::Display + '_ {
        struct Wrapper<'a>(&'a LogicalPlan);
        impl<'a> fmt::Display for Wrapper<'a> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                let with_schema = false;
                let mut visitor = IndentVisitor::new(f, with_schema);
                // `Display::fmt` must report failures to its caller; a panic
                // here would take down anything that merely prints a plan.
                self.0.accept(&mut visitor).map_err(|_| fmt::Error)?;
                Ok(())
            }
        }
        Wrapper(self)
    }

    /// Same as [`LogicalPlan::display_indent`], but with schema output enabled
    /// on the [`IndentVisitor`].
    ///
    /// A failure reported while visiting the plan is returned from `fmt` as
    /// `fmt::Error` instead of panicking.
    pub fn display_indent_schema(&self) -> impl fmt::Display + '_ {
        struct Wrapper<'a>(&'a LogicalPlan);
        impl<'a> fmt::Display for Wrapper<'a> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                let with_schema = true;
                let mut visitor = IndentVisitor::new(f, with_schema);
                // Propagate instead of panicking (see `display_indent`).
                self.0.accept(&mut visitor).map_err(|_| fmt::Error)?;
                Ok(())
            }
        }
        Wrapper(self)
    }

    /// Returns a value whose `Display` output is a GraphViz `digraph` of the
    /// plan.
    ///
    /// The output is framed by `// Begin DataFusion GraphViz Plan ...` and
    /// `// End DataFusion GraphViz Plan` comment lines and contains the plan
    /// twice: first under the label `LogicalPlan`, then, with schemas enabled
    /// on the visitor, under the label `Detailed LogicalPlan`.
    ///
    /// A failure reported while visiting the plan is returned from `fmt` as
    /// `fmt::Error` instead of panicking.
    pub fn display_graphviz(&self) -> impl fmt::Display + '_ {
        struct Wrapper<'a>(&'a LogicalPlan);
        impl<'a> fmt::Display for Wrapper<'a> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                writeln!(
                    f,
                    "// Begin DataFusion GraphViz Plan (see https://graphviz.org)"
                )?;
                writeln!(f, "digraph {{")?;

                let mut visitor = GraphvizVisitor::new(f);

                // First pass: plan without schemas.
                visitor.pre_visit_plan("LogicalPlan")?;
                self.0.accept(&mut visitor).map_err(|_| fmt::Error)?;
                visitor.post_visit_plan()?;

                // Second pass: same plan, this time with schemas.
                visitor.set_with_schema(true);
                visitor.pre_visit_plan("Detailed LogicalPlan")?;
                self.0.accept(&mut visitor).map_err(|_| fmt::Error)?;
                visitor.post_visit_plan()?;

                writeln!(f, "}}")?;
                writeln!(f, "// End DataFusion GraphViz Plan")?;
                Ok(())
            }
        }
        Wrapper(self)
    }

    /// Returns a value whose `Display` output is a one-line description of
    /// this node only; child plans are not rendered.
    ///
    /// `Extension` nodes delegate to the user-defined node's
    /// `fmt_for_explain`.
    pub fn display(&self) -> impl fmt::Display + '_ {
        struct Wrapper<'a>(&'a LogicalPlan);
        impl<'a> fmt::Display for Wrapper<'a> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self.0 {
                    LogicalPlan::EmptyRelation { .. } => write!(f, "EmptyRelation"),
                    LogicalPlan::TableScan {
                        table_name,
                        projection,
                        filters,
                        limit,
                        ..
                    } => {
                        // One separating space after the name, none when the
                        // name is empty.
                        let sep = " ".repeat(min(1, table_name.len()));
                        write!(
                            f,
                            "TableScan: {}{}projection={:?}",
                            table_name, sep, projection
                        )?;
                        if !filters.is_empty() {
                            write!(f, ", filters={:?}", filters)?;
                        }
                        if let Some(n) = limit {
                            write!(f, ", limit={}", n)?;
                        }
                        Ok(())
                    }
                    LogicalPlan::Projection { expr, .. } => {
                        write!(f, "Projection: ")?;
                        for (i, expr_item) in expr.iter().enumerate() {
                            if i > 0 {
                                write!(f, ", ")?;
                            }
                            write!(f, "{:?}", expr_item)?;
                        }
                        Ok(())
                    }
                    LogicalPlan::Filter {
                        predicate: expr, ..
                    } => write!(f, "Filter: {:?}", expr),
                    LogicalPlan::Aggregate {
                        group_expr,
                        aggr_expr,
                        ..
                    } => write!(
                        f,
                        "Aggregate: groupBy=[{:?}], aggr=[{:?}]",
                        group_expr, aggr_expr
                    ),
                    LogicalPlan::Sort { expr, .. } => {
                        write!(f, "Sort: ")?;
                        for (i, expr_item) in expr.iter().enumerate() {
                            if i > 0 {
                                write!(f, ", ")?;
                            }
                            write!(f, "{:?}", expr_item)?;
                        }
                        Ok(())
                    }
                    LogicalPlan::Join { on: keys, .. } => {
                        let join_expr: Vec<String> =
                            keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect();
                        write!(f, "Join: {}", join_expr.join(", "))
                    }
                    LogicalPlan::Repartition {
                        partitioning_scheme,
                        ..
                    } => match partitioning_scheme {
                        Partitioning::RoundRobinBatch(n) => write!(
                            f,
                            "Repartition: RoundRobinBatch partition_count={}",
                            n
                        ),
                        Partitioning::Hash(expr, n) => {
                            let hash_expr: Vec<String> =
                                expr.iter().map(|e| format!("{:?}", e)).collect();
                            write!(
                                f,
                                "Repartition: Hash({}) partition_count={}",
                                hash_expr.join(", "),
                                n
                            )
                        }
                    },
                    LogicalPlan::Limit { n, .. } => write!(f, "Limit: {}", n),
                    LogicalPlan::CreateExternalTable { name, .. } => {
                        write!(f, "CreateExternalTable: {:?}", name)
                    }
                    LogicalPlan::Explain { .. } => write!(f, "Explain"),
                    LogicalPlan::Union { .. } => write!(f, "Union"),
                    LogicalPlan::Extension { node } => node.fmt_for_explain(f),
                }
            }
        }
        Wrapper(self)
    }
}
impl fmt::Debug for LogicalPlan {
    /// Debug output is exactly the `Display` output of `display_indent()`,
    /// written through the same formatter.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rendered = self.display_indent();
        Display::fmt(&rendered, f)
    }
}
/// Label describing which kind of plan a stringified plan holds.
#[derive(Debug, Clone, PartialEq)]
pub enum PlanType {
    /// Converted to the string `"logical_plan"`.
    LogicalPlan,
    /// Converted to the string `"logical_plan after <optimizer_name>"`.
    OptimizedLogicalPlan {
        /// Name appended to the label.
        optimizer_name: String,
    },
    /// Converted to the string `"physical_plan"`.
    PhysicalPlan,
}

impl From<&PlanType> for String {
    /// Produces the textual label for a plan type.
    fn from(t: &PlanType) -> Self {
        match t {
            PlanType::LogicalPlan => String::from("logical_plan"),
            PlanType::PhysicalPlan => String::from("physical_plan"),
            PlanType::OptimizedLogicalPlan { optimizer_name } => {
                let mut label = String::from("logical_plan after ");
                label.push_str(optimizer_name);
                label
            }
        }
    }
}
/// A plan rendered to text, together with the label saying which kind of plan
/// it is.
#[derive(Debug, Clone, PartialEq)]
// `Arc<String>` triggers clippy's `rc_buffer` lint, which is silenced here.
// NOTE(review): presumably the `Arc` keeps clones of the rendered text cheap —
// confirm before changing the field type.
#[allow(clippy::rc_buffer)]
pub struct StringifiedPlan {
    /// Which kind of plan `plan` holds.
    pub plan_type: PlanType,
    /// The rendered plan text.
    pub plan: Arc<String>,
}

impl StringifiedPlan {
    /// Creates a stringified plan from a plan type and anything convertible
    /// into a `String`.
    pub fn new(plan_type: PlanType, plan: impl Into<String>) -> Self {
        let text: String = plan.into();
        Self {
            plan_type,
            plan: Arc::new(text),
        }
    }

    /// Returns `true` when this plan should be shown: always in verbose mode,
    /// otherwise only for `PlanType::LogicalPlan`.
    pub fn should_display(&self, verbose_mode: bool) -> bool {
        if verbose_mode {
            return true;
        }
        self.plan_type == PlanType::LogicalPlan
    }
}
// Unit tests for plan display (indent / indent with schema / graphviz) and for
// the visitor protocol implemented by `LogicalPlan::accept`.
#[cfg(test)]
mod tests {
use super::super::{col, lit, LogicalPlanBuilder};
use super::*;
/// Five-column schema used by the display tests.
fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Int32, false),
])
}
/// Builds `Projection(Filter(TableScan))` over `employee.csv`, scanning only
/// columns 0 and 3 (`id` and `state`).
fn display_plan() -> LogicalPlan {
LogicalPlanBuilder::scan_empty(
"employee.csv",
&employee_schema(),
Some(vec![0, 3]),
)
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
}
// Checks the exact text produced by `display_indent()`.
#[test]
fn test_display_indent() {
let plan = display_plan();
// NOTE(review): both nested lines are expected with a single leading space.
// Confirm this matches the per-level indentation emitted by `IndentVisitor`.
let expected = "Projection: #id\
\n Filter: #state Eq Utf8(\"CO\")\
\n TableScan: employee.csv projection=Some([0, 3])";
assert_eq!(expected, format!("{}", plan.display_indent()));
}
// Checks the exact text produced by `display_indent_schema()`: same lines as
// above, each followed by the node's schema in brackets.
#[test]
fn test_display_indent_schema() {
let plan = display_plan();
// NOTE(review): same single-space indentation as in `test_display_indent`;
// confirm against `IndentVisitor`.
let expected = "Projection: #id [id:Int32]\
\n Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\
\n TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]";
assert_eq!(expected, format!("{}", plan.display_indent_schema()));
}
// Spot-checks the graphviz output: begin/end markers, the table scan node
// without schema, and the table scan node with schema.
#[test]
fn test_display_graphviz() {
let plan = display_plan();
let graphviz = format!("{}", plan.display_graphviz());
// On failure each assertion prints the full graphviz output.
assert!(
graphviz.contains(
r#"// Begin DataFusion GraphViz Plan (see https://graphviz.org)"#
),
"\n{}",
plan.display_graphviz()
);
assert!(
graphviz.contains(
r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])"]"#
),
"\n{}",
plan.display_graphviz()
);
assert!(graphviz.contains(r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#),
"\n{}", plan.display_graphviz());
assert!(
graphviz.contains(r#"// End DataFusion GraphViz Plan"#),
"\n{}",
plan.display_graphviz()
);
}
/// Visitor that never stops and never fails; it records one string per hook
/// call so tests can assert on the visit order.
#[derive(Debug, Default)]
struct OkVisitor {
strings: Vec<String>,
}
impl PlanVisitor for OkVisitor {
type Error = String;
// Records "pre_visit <Variant>"; only the three variants used by
// `test_plan` are supported.
fn pre_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
let s = match plan {
LogicalPlan::Projection { .. } => "pre_visit Projection",
LogicalPlan::Filter { .. } => "pre_visit Filter",
LogicalPlan::TableScan { .. } => "pre_visit TableScan",
_ => unimplemented!("unknown plan type"),
};
self.strings.push(s.into());
Ok(true)
}
// Records "post_visit <Variant>"; only the three variants used by
// `test_plan` are supported.
fn post_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
let s = match plan {
LogicalPlan::Projection { .. } => "post_visit Projection",
LogicalPlan::Filter { .. } => "post_visit Filter",
LogicalPlan::TableScan { .. } => "post_visit TableScan",
_ => unimplemented!("unknown plan type"),
};
self.strings.push(s.into());
Ok(true)
}
}
// A full walk calls pre_visit top-down and post_visit bottom-up.
#[test]
fn visit_order() {
let mut visitor = OkVisitor::default();
let plan = test_plan();
let res = plan.accept(&mut visitor);
assert!(res.is_ok());
assert_eq!(
visitor.strings,
vec![
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
"post_visit Filter",
"post_visit Projection"
]
);
}
/// Countdown used to trigger an action on the N-th call.
///
/// The default value (`val == None`) never triggers.
#[derive(Debug, Default)]
struct OptionalCounter {
val: Option<usize>,
}
impl OptionalCounter {
/// Counter that triggers on the call made after `val` earlier calls.
fn new(val: usize) -> Self {
Self { val: Some(val) }
}
/// Returns `true` once the counter has reached zero; otherwise decrements
/// it (if set) and returns `false`.
fn dec(&mut self) -> bool {
if Some(0) == self.val {
true
} else {
self.val = self.val.take().map(|i| i - 1);
false
}
}
}
/// Visitor that returns `Ok(false)` from `pre_visit` or `post_visit` once the
/// corresponding counter triggers, and otherwise behaves like `OkVisitor`.
#[derive(Debug, Default)]
struct StoppingVisitor {
inner: OkVisitor,
return_false_from_pre_in: OptionalCounter,
return_false_from_post_in: OptionalCounter,
}
impl PlanVisitor for StoppingVisitor {
type Error = String;
fn pre_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
// Stop without recording anything once the counter triggers.
if self.return_false_from_pre_in.dec() {
return Ok(false);
}
self.inner.pre_visit(plan)
}
fn post_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
// Stop without recording anything once the counter triggers.
if self.return_false_from_post_in.dec() {
return Ok(false);
}
self.inner.post_visit(plan)
}
}
// Stopping in the third pre_visit call leaves only the first two recorded.
#[test]
fn early_stopping_pre_visit() {
let mut visitor = StoppingVisitor {
return_false_from_pre_in: OptionalCounter::new(2),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
assert!(res.is_ok());
assert_eq!(
visitor.inner.strings,
vec!["pre_visit Projection", "pre_visit Filter",]
);
}
// Stopping in the second post_visit call: no further hooks are recorded.
#[test]
fn early_stopping_post_visit() {
let mut visitor = StoppingVisitor {
return_false_from_post_in: OptionalCounter::new(1),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
assert!(res.is_ok());
assert_eq!(
visitor.inner.strings,
vec![
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
]
);
}
/// Visitor that returns an error from `pre_visit` or `post_visit` once the
/// corresponding counter triggers, and otherwise behaves like `OkVisitor`.
#[derive(Debug, Default)]
struct ErrorVisitor {
inner: OkVisitor,
return_error_from_pre_in: OptionalCounter,
return_error_from_post_in: OptionalCounter,
}
impl PlanVisitor for ErrorVisitor {
type Error = String;
fn pre_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
// Fail without recording anything once the counter triggers.
if self.return_error_from_pre_in.dec() {
return Err("Error in pre_visit".into());
}
self.inner.pre_visit(plan)
}
fn post_visit(
&mut self,
plan: &LogicalPlan,
) -> std::result::Result<bool, Self::Error> {
// Fail without recording anything once the counter triggers.
if self.return_error_from_post_in.dec() {
return Err("Error in post_visit".into());
}
self.inner.post_visit(plan)
}
}
// An error from the third pre_visit call is returned by `accept`, and only
// the first two pre_visit calls are recorded.
#[test]
fn error_pre_visit() {
let mut visitor = ErrorVisitor {
return_error_from_pre_in: OptionalCounter::new(2),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
if let Err(e) = res {
assert_eq!("Error in pre_visit", e);
} else {
panic!("Expected an error");
}
assert_eq!(
visitor.inner.strings,
vec!["pre_visit Projection", "pre_visit Filter",]
);
}
// An error from the second post_visit call is returned by `accept`, and no
// further hooks are recorded.
#[test]
fn error_post_visit() {
let mut visitor = ErrorVisitor {
return_error_from_post_in: OptionalCounter::new(1),
..Default::default()
};
let plan = test_plan();
let res = plan.accept(&mut visitor);
if let Err(e) = res {
assert_eq!("Error in post_visit", e);
} else {
panic!("Expected an error");
}
assert_eq!(
visitor.inner.strings,
vec![
"pre_visit Projection",
"pre_visit Filter",
"pre_visit TableScan",
"post_visit TableScan",
]
);
}
/// Builds the `Projection(Filter(TableScan))` plan used by the visitor tests.
///
/// NOTE(review): the scan schema only contains `id`, yet the filter refers to
/// `state`. The visitor tests depend only on the node shapes; presumably the
/// builder does not validate column references here — confirm.
fn test_plan() -> LogicalPlan {
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
LogicalPlanBuilder::scan_empty("", &schema, Some(vec![0]))
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
.project(vec![col("id")])
.unwrap()
.build()
.unwrap()
}
}