use crate::{Expr, LogicalPlan};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::{any::Any, collections::HashSet, fmt, sync::Arc};
use super::InvariantLevel;
pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
fn as_any(&self) -> &dyn Any;
fn name(&self) -> &str;
fn inputs(&self) -> Vec<&LogicalPlan>;
fn schema(&self) -> &DFSchemaRef;
fn check_invariants(&self, check: InvariantLevel) -> Result<()>;
fn expressions(&self) -> Vec<Expr>;
fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
get_all_columns_from_schema(self.schema())
}
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
fn with_exprs_and_inputs(
&self,
exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
fn necessary_children_exprs(
&self,
_output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
None
}
fn dyn_hash(&self, state: &mut dyn Hasher);
fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;
fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;
fn supports_limit_pushdown(&self) -> bool {
false
}
}
impl Hash for dyn UserDefinedLogicalNode {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
}
}
impl PartialEq for dyn UserDefinedLogicalNode {
fn eq(&self, other: &Self) -> bool {
self.dyn_eq(other)
}
}
impl PartialOrd for dyn UserDefinedLogicalNode {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.dyn_ord(other)
}
}
// Marker impl with no methods: it declares that the `PartialEq` impl for
// `dyn UserDefinedLogicalNode` (which delegates to `dyn_eq`) is a full
// equivalence relation, so every `dyn_eq` implementation must uphold that.
impl Eq for dyn UserDefinedLogicalNode {}
pub trait UserDefinedLogicalNodeCore:
fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static
{
fn name(&self) -> &str;
fn inputs(&self) -> Vec<&LogicalPlan>;
fn schema(&self) -> &DFSchemaRef;
fn check_invariants(&self, _check: InvariantLevel) -> Result<()> {
Ok(())
}
fn expressions(&self) -> Vec<Expr>;
fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
get_all_columns_from_schema(self.schema())
}
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
fn with_exprs_and_inputs(
&self,
exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Self>;
fn necessary_children_exprs(
&self,
_output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
None
}
fn supports_limit_pushdown(&self) -> bool {
false }
}
impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.name()
}
fn inputs(&self) -> Vec<&LogicalPlan> {
self.inputs()
}
fn schema(&self) -> &DFSchemaRef {
self.schema()
}
fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
self.check_invariants(check)
}
fn expressions(&self) -> Vec<Expr> {
self.expressions()
}
fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
self.prevent_predicate_push_down_columns()
}
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt_for_explain(f)
}
fn with_exprs_and_inputs(
&self,
exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?))
}
fn necessary_children_exprs(
&self,
output_columns: &[usize],
) -> Option<Vec<Vec<usize>>> {
self.necessary_children_exprs(output_columns)
}
fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}
fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
match other.as_any().downcast_ref::<Self>() {
Some(o) => self == o,
None => false,
}
}
fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
other
.as_any()
.downcast_ref::<Self>()
.and_then(|other| self.partial_cmp(other))
}
fn supports_limit_pushdown(&self) -> bool {
self.supports_limit_pushdown()
}
}
/// Collects the name of every field in `schema` into a set.
///
/// Because the result is a set, fields that share a name contribute a single
/// entry.
fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
    let mut column_names = HashSet::new();
    for field in schema.fields().iter() {
        column_names.insert(field.name().clone());
    }
    column_names
}