use crate::dml::CopyTo;
use crate::expr::{Alias, Sort, WindowFunction};
use crate::logical_plan::builder::build_join_schema;
use crate::logical_plan::{
    Aggregate, Analyze, Distinct, Extension, Filter, Join, Limit, Partitioning, Prepare,
    Projection, Repartition, Sort as SortPlan, Subquery, SubqueryAlias, Union, Unnest,
    Values, Window,
};
use crate::signature::{Signature, TypeSignature};
use crate::{
    BinaryExpr, Cast, CreateMemoryTable, CreateView, DdlStatement, DmlStatement, Expr,
    ExprSchemable, GroupingSet, LogicalPlan, LogicalPlanBuilder, Operator, TableScan,
    TryCast,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::tree_node::{
    RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion,
};
use datafusion_common::{
    internal_err, plan_err, Column, Constraints, DFField, DFSchema, DFSchemaRef,
    DataFusionError, Result, ScalarValue, TableReference,
};
use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::sync::Arc;
/// The `UInt8` scalar literal `1`.
///
/// NOTE(review): judging by the name, this is presumably the literal that
/// stands in for `*` when `COUNT(*)` is expanded — confirm against callers.
pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1));
/// Collects every column referenced by any expression in `expr` into `accum`.
///
/// Expressions are visited in order; the first error reported by
/// [`expr_to_columns`] stops the walk and is returned.
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
    expr.iter().try_for_each(|e| expr_to_columns(e, accum))
}
/// Counts the grouping expressions in `group_expr`.
///
/// When the first element is a grouping set, the count is the number of its
/// distinct expressions; otherwise it is simply the length of the slice.
///
/// # Errors
///
/// Returns a plan error if a grouping set is the first element but is not the
/// only element.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
    match group_expr.first() {
        Some(Expr::GroupingSet(_)) if group_expr.len() > 1 => plan_err!(
            "Invalid group by expressions, GroupingSet must be the only expression"
        ),
        Some(Expr::GroupingSet(grouping_set)) => Ok(grouping_set.distinct_expr().len()),
        _ => Ok(group_expr.len()),
    }
}
/// Returns every subset of `slice` as a vector of borrowed elements.
///
/// Subsets are produced in increasing bit-mask order (bit `i` of the mask
/// selects `slice[i]`), and elements inside a subset keep their original
/// relative order. The first subset is therefore always the empty one.
///
/// # Errors
///
/// Returns an error message if `slice` has 64 or more elements, since the
/// masks are enumerated in a `u64`.
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
    if slice.len() >= 64 {
        return Err("The size of the set must be less than 64.".into());
    }
    let subset_count: u64 = 1 << slice.len();
    let subsets = (0..subset_count)
        .map(|mask| {
            slice
                .iter()
                .enumerate()
                // keep the element when its bit is set in the current mask
                .filter(|&(pos, _)| mask & (1u64 << pos) != 0)
                .map(|(_, item)| item)
                .collect()
        })
        .collect();
    Ok(subsets)
}
/// Validates the number of expressions inside a single grouping set.
///
/// # Errors
///
/// Returns a plan error when `size` is greater than 65535.
fn check_grouping_set_size_limit(size: usize) -> Result<()> {
    let max_grouping_set_size = 65535;
    if size <= max_grouping_set_size {
        Ok(())
    } else {
        plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}")
    }
}
/// Validates the number of grouping sets in a list of grouping sets.
///
/// # Errors
///
/// Returns a plan error when `size` is greater than 4096.
fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
    let max_grouping_sets_size = 4096;
    if size <= max_grouping_sets_size {
        Ok(())
    } else {
        plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}")
    }
}
/// Concatenates two grouping sets into a new one (`left` followed by `right`).
///
/// # Errors
///
/// Fails if the combined length is rejected by
/// `check_grouping_set_size_limit`.
fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
    let merged_len = left.len() + right.len();
    check_grouping_set_size_limit(merged_len)?;
    let mut merged = Vec::with_capacity(merged_len);
    merged.extend_from_slice(left);
    merged.extend_from_slice(right);
    Ok(merged)
}
/// Builds the cartesian product of two lists of grouping sets.
///
/// Every set in `left` is merged with every set in `right`; the output is
/// ordered by `left` first, then by `right`.
///
/// # Errors
///
/// Fails if the number of resulting sets is rejected by
/// `check_grouping_sets_size_limit`, or if any merged set is rejected by
/// `merge_grouping_set`.
fn cross_join_grouping_sets<T: Clone>(
    left: &[Vec<T>],
    right: &[Vec<T>],
) -> Result<Vec<Vec<T>>> {
    check_grouping_sets_size_limit(left.len() * right.len())?;
    left.iter()
        .flat_map(|lhs| right.iter().map(move |rhs| merge_grouping_set(lhs, rhs)))
        .collect()
}
/// Normalizes a GROUP BY list that mixes grouping sets with other
/// expressions into a single `GroupingSet::GroupingSets` expression.
///
/// If `group_expr` contains no grouping set, or contains exactly one
/// expression, it is returned unchanged. Otherwise each element is expanded
/// into its own list of grouping sets and the lists are cross-joined, left to
/// right, into one combined list.
///
/// # Errors
///
/// Returns a plan error when any intermediate or final list is rejected by
/// the grouping-set size checks, or when a `CUBE` has too many expressions
/// for `powerset`.
pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
    let has_grouping_set = group_expr
        .iter()
        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
    // Nothing to expand: either plain expressions only, or a lone element.
    if !has_grouping_set || group_expr.len() == 1 {
        return Ok(group_expr);
    }
    // Expand every element into a list of grouping sets that borrow from
    // `group_expr`.
    let partial_sets = group_expr
        .iter()
        .map(|expr| {
            let exprs = match expr {
                // Explicit GROUPING SETS: used as-is.
                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
                    check_grouping_sets_size_limit(grouping_sets.len())?;
                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
                }
                // CUBE: every subset of the listed expressions.
                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
                    let grouping_sets =
                        powerset(group_exprs).map_err(DataFusionError::Plan)?;
                    check_grouping_sets_size_limit(grouping_sets.len())?;
                    grouping_sets
                }
                // ROLLUP: every prefix of the listed expressions, including
                // the empty prefix.
                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
                    let size = group_exprs.len();
                    let slice = group_exprs.as_slice();
                    // NOTE(review): the value checked here is the total number
                    // of expressions across all prefixes plus one, while the
                    // number of sets produced below is `size + 1` — confirm
                    // this stricter bound is intended.
                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
                    (0..(size + 1))
                        .map(|i| slice[0..i].iter().collect())
                        .collect()
                }
                // Any other expression acts as a single one-element set.
                expr => vec![vec![expr]],
            };
            Ok(exprs)
        })
        .collect::<Result<Vec<_>>>()?;
    // Cross-join the partial lists pairwise, then turn the borrowed
    // expressions into owned clones.
    let grouping_sets = partial_sets
        .into_iter()
        .map(Ok)
        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
        .transpose()?
        .map(|e| {
            e.into_iter()
                .map(|e| e.into_iter().cloned().collect())
                .collect()
        })
        .unwrap_or_default();
    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
        grouping_sets,
    ))])
}
/// Flattens a GROUP BY list into the plain expressions it groups on.
///
/// When the first element is a grouping set, its distinct expressions are
/// returned; otherwise the slice is copied as-is.
///
/// # Errors
///
/// Returns a plan error if a grouping set is the first element but is not the
/// only element.
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
    match group_expr.first() {
        Some(Expr::GroupingSet(_)) if group_expr.len() > 1 => plan_err!(
            "Invalid group by expressions, GroupingSet must be the only expression"
        ),
        Some(Expr::GroupingSet(grouping_set)) => Ok(grouping_set.distinct_expr()),
        _ => Ok(group_expr.to_vec()),
    }
}
/// Adds every `Expr::Column` found while walking `expr` (via
/// [`inspect_expr_pre`]) to `accum`.
///
/// All other expression variants contribute nothing themselves; they are
/// listed only so the match stays exhaustive.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    inspect_expr_pre(expr, |expr| {
        match expr {
            Expr::Column(qc) => {
                accum.insert(qc.clone());
            }
            // Every remaining variant is spelled out instead of using a `_`
            // arm, so adding a new `Expr` variant fails to compile here until
            // it is handled explicitly.
            Expr::ScalarVariable(_, _)
            | Expr::Alias(_)
            | Expr::Literal(_)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::Sort { .. }
            | Expr::ScalarFunction(..)
            | Expr::ScalarUDF(..)
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::AggregateUDF { .. }
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard
            | Expr::QualifiedWildcard { .. }
            | Expr::GetIndexedField { .. }
            | Expr::Placeholder(_)
            | Expr::OuterReferenceColumn { .. } => {}
        }
        Ok(())
    })
}
/// Resolves the identifiers named in `EXCEPT` / `EXCLUDE` wildcard options to
/// columns of `schema`.
///
/// Identifiers are looked up by qualified name when `qualifier` is `Some`, and
/// by unqualified name otherwise. The order of the returned columns follows
/// the iteration order of an intermediate `HashSet` and is therefore
/// unspecified.
///
/// # Errors
///
/// Returns a plan error if the same identifier is listed more than once, and
/// propagates any schema lookup failure.
fn get_excluded_columns(
    opt_exclude: Option<ExcludeSelectItem>,
    opt_except: Option<ExceptSelectItem>,
    schema: &DFSchema,
    qualifier: &Option<TableReference>,
) -> Result<Vec<Column>> {
    let mut idents = vec![];
    if let Some(excepts) = opt_except {
        idents.push(excepts.first_element);
        idents.extend(excepts.additional_elements);
    }
    match opt_exclude {
        Some(ExcludeSelectItem::Single(ident)) => idents.push(ident),
        Some(ExcludeSelectItem::Multiple(idents_inner)) => idents.extend(idents_inner),
        None => {}
    }

    // De-duplicate; a shrinking count means some identifier was repeated.
    let requested = idents.len();
    let distinct: HashSet<_> = idents.into_iter().collect();
    if distinct.len() != requested {
        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
    }

    distinct
        .into_iter()
        .map(|ident| -> Result<Column> {
            let col_name = ident.value.as_str();
            let field = match qualifier {
                Some(qualifier) => schema.field_with_qualified_name(qualifier, col_name)?,
                None => schema.field_with_unqualified_name(col_name)?,
            };
            Ok(field.qualified_column())
        })
        .collect()
}
/// Returns one `Expr::Column` per field of `schema`, in schema order, leaving
/// out every column contained in `columns_to_skip`.
fn get_exprs_except_skipped(
    schema: &DFSchema,
    columns_to_skip: HashSet<Column>,
) -> Vec<Expr> {
    // With an empty skip set the filter keeps every column, so no separate
    // fast path is required to produce the same output.
    schema
        .fields()
        .iter()
        .map(|field| field.qualified_column())
        .filter(|col| !columns_to_skip.contains(col))
        .map(Expr::Column)
        .collect()
}
pub fn expand_wildcard(
    schema: &DFSchema,
    plan: &LogicalPlan,
    wildcard_options: Option<WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
    let using_columns = plan.using_columns()?;
    let mut columns_to_skip = using_columns
        .into_iter()
        .flat_map(|cols| {
            let mut cols = cols.into_iter().collect::<Vec<_>>();
            cols.sort();
            let mut out_column_names: HashSet<String> = HashSet::new();
            cols.into_iter()
                .filter_map(|c| {
                    if out_column_names.contains(&c.name) {
                        Some(c)
                    } else {
                        out_column_names.insert(c.name);
                        None
                    }
                })
                .collect::<Vec<_>>()
        })
        .collect::<HashSet<_>>();
    let excluded_columns = if let Some(WildcardAdditionalOptions {
        opt_exclude,
        opt_except,
        ..
    }) = wildcard_options
    {
        get_excluded_columns(opt_exclude, opt_except, schema, &None)?
    } else {
        vec![]
    };
    columns_to_skip.extend(excluded_columns);
    Ok(get_exprs_except_skipped(schema, columns_to_skip))
}
pub fn expand_qualified_wildcard(
    qualifier: &str,
    schema: &DFSchema,
    wildcard_options: Option<WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
    let qualifier = TableReference::from(qualifier);
    let qualified_fields: Vec<DFField> = schema
        .fields_with_qualified(&qualifier)
        .into_iter()
        .cloned()
        .collect();
    if qualified_fields.is_empty() {
        return plan_err!("Invalid qualifier {qualifier}");
    }
    let qualified_schema =
        DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?
            .with_functional_dependencies(schema.functional_dependencies().clone());
    let excluded_columns = if let Some(WildcardAdditionalOptions {
        opt_exclude,
        opt_except,
        ..
    }) = wildcard_options
    {
        get_excluded_columns(opt_exclude, opt_except, schema, &Some(qualifier))?
    } else {
        vec![]
    };
    let mut columns_to_skip = HashSet::new();
    columns_to_skip.extend(excluded_columns);
    Ok(get_exprs_except_skipped(&qualified_schema, columns_to_skip))
}
/// Ordered sort key of a window expression, as built by `generate_sort_key`.
///
/// Each entry pairs a sort expression with a flag that is `true` when the
/// entry was added for a `PARTITION BY` expression and `false` when it was
/// added only for `ORDER BY`.
type WindowSortKey = Vec<(Expr, bool)>;
/// Builds the combined sort key for a window with the given `PARTITION BY`
/// and `ORDER BY` expressions.
///
/// Partition expressions come first (flagged `true`). Each one is wrapped as
/// an ascending, nulls-last sort; if an `ORDER BY` entry sorts on the same
/// expression, that entry is used instead so its direction and null ordering
/// are kept. Remaining `ORDER BY` entries follow (flagged `false`).
/// Duplicate keys are added only once.
///
/// # Errors
///
/// Returns a plan error if any element of `order_by` is not an `Expr::Sort`.
pub fn generate_sort_key(
    partition_by: &[Expr],
    order_by: &[Expr],
) -> Result<WindowSortKey> {
    // ORDER BY keys rewritten to ascending / nulls-last so they can be
    // compared with partition keys regardless of their own options.
    let normalized_order_by_keys = order_by
        .iter()
        .map(|e| match e {
            Expr::Sort(Sort { expr, .. }) => {
                Ok(Expr::Sort(Sort::new(expr.clone(), true, false)))
            }
            _ => plan_err!("Order by only accepts sort expressions"),
        })
        .collect::<Result<Vec<_>>>()?;

    let mut sort_keys: WindowSortKey = Vec::new();

    for partition_expr in partition_by {
        let default_key = partition_expr.clone().sort(true, false);
        let matching_order_by = normalized_order_by_keys
            .iter()
            .position(|key| *key == default_key);
        let key = match matching_order_by {
            Some(pos) => order_by[pos].clone(),
            None => default_key,
        };
        if !sort_keys.iter().any(|(existing, _)| *existing == key) {
            sort_keys.push((key, true));
        }
    }

    for order_expr in order_by {
        if !sort_keys.iter().any(|(existing, _)| existing == order_expr) {
            sort_keys.push((order_expr.clone(), false));
        }
    }

    Ok(sort_keys)
}
/// Orders two sort expressions relative to `schema`.
///
/// Comparison proceeds through these tie-breakers, in order:
/// 1. the column indexes referenced by each expression, compared pairwise;
/// 2. the number of referenced columns (more columns sorts first);
/// 3. direction (descending sorts before ascending);
/// 4. null ordering (nulls-first sorts before nulls-last).
///
/// If either argument is not an `Expr::Sort`, the two are considered equal.
pub fn compare_sort_expr(
    sort_expr_a: &Expr,
    sort_expr_b: &Expr,
    schema: &DFSchemaRef,
) -> Ordering {
    match (sort_expr_a, sort_expr_b) {
        (Expr::Sort(a), Expr::Sort(b)) => {
            let ref_indexes_a = find_column_indexes_referenced_by_expr(&a.expr, schema);
            let ref_indexes_b = find_column_indexes_referenced_by_expr(&b.expr, schema);

            // First position at which the referenced indexes differ, if any.
            let by_position = ref_indexes_a
                .iter()
                .zip(ref_indexes_b.iter())
                .map(|(idx_a, idx_b)| idx_a.cmp(idx_b))
                .find(|ord| *ord != Ordering::Equal)
                .unwrap_or(Ordering::Equal);

            by_position
                // reversed on purpose: the longer index list compares as Less
                .then_with(|| ref_indexes_b.len().cmp(&ref_indexes_a.len()))
                // `false < true`, so descending compares as Less
                .then_with(|| a.asc.cmp(&b.asc))
                // reversed on purpose: nulls-first compares as Less
                .then_with(|| b.nulls_first.cmp(&a.nulls_first))
        }
        _ => Ordering::Equal,
    }
}
/// Groups window expressions that share the same sort key.
///
/// Groups appear in the order their sort key is first encountered, and the
/// expressions inside each group keep their input order.
///
/// # Errors
///
/// Returns an internal error if any element of `window_expr` is not an
/// `Expr::WindowFunction`, and propagates errors from `generate_sort_key`.
pub fn group_window_expr_by_sort_keys(
    window_expr: &[Expr],
) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
    let mut groups: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
    for expr in window_expr {
        match expr {
            Expr::WindowFunction(WindowFunction {
                partition_by,
                order_by,
                ..
            }) => {
                let sort_key = generate_sort_key(partition_by, order_by)?;
                match groups.iter().position(|(key, _)| *key == sort_key) {
                    Some(pos) => groups[pos].1.push(expr),
                    None => groups.push((sort_key, vec![expr])),
                }
            }
            other => {
                return internal_err!(
                    "Impossibly got non-window expr {other:?}"
                )
            }
        }
    }
    Ok(groups)
}
/// Collects the distinct aggregate expressions (`AggregateFunction` or
/// `AggregateUDF`) found in `exprs`, searching nested expressions as well.
pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
    fn is_aggregate(nested_expr: &Expr) -> bool {
        matches!(
            nested_expr,
            Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
        )
    }
    find_exprs_in_exprs(exprs, &is_aggregate)
}
/// Collects the distinct `Expr::Sort` expressions found in `exprs`, searching
/// nested expressions as well.
pub fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
    fn is_sort(nested_expr: &Expr) -> bool {
        matches!(nested_expr, Expr::Sort { .. })
    }
    find_exprs_in_exprs(exprs, &is_sort)
}
/// Collects the distinct `Expr::WindowFunction` expressions found in `exprs`,
/// searching nested expressions as well.
pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
    fn is_window(nested_expr: &Expr) -> bool {
        matches!(nested_expr, Expr::WindowFunction { .. })
    }
    find_exprs_in_exprs(exprs, &is_window)
}
/// Collects the distinct `Expr::OuterReferenceColumn` expressions found
/// anywhere inside `expr`.
pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
    fn is_outer_reference(nested_expr: &Expr) -> bool {
        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
    }
    find_exprs_in_expr(expr, &is_outer_reference)
}
/// Searches each expression in `exprs` for sub-expressions accepted by
/// `test_fn` and returns them without duplicates, in first-seen order.
fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut unique = vec![];
    for root in exprs {
        for found in find_exprs_in_expr(root, test_fn) {
            if !unique.contains(&found) {
                unique.push(found);
            }
        }
    }
    unique
}
/// Walks `expr` and returns the distinct sub-expressions accepted by
/// `test_fn`, in visit order.
///
/// Once a node is accepted its children are not visited, so matches nested
/// inside another match are not reported separately.
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut found = vec![];
    expr.apply(&mut |node| {
        if !test_fn(node) {
            return Ok(VisitRecursion::Continue);
        }
        if !found.contains(node) {
            found.push(node.clone());
        }
        // a match ends the descent into this subtree
        Ok(VisitRecursion::Skip)
    })
    // the visitor closure above only ever returns `Ok`
    .expect("no way to return error during recursion");
    found
}
/// Calls `f` on `expr` and the nodes reached by `expr.apply`, stopping at the
/// first error.
///
/// The callback's error type `E` is independent of the tree walker's own
/// error type: a failure is stashed, the walk is halted, and the stashed
/// error is returned afterwards.
pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
where
    F: FnMut(&Expr) -> Result<(), E>,
{
    let mut outcome = Ok(());
    expr.apply(&mut |node| {
        Ok(match f(node) {
            Ok(()) => VisitRecursion::Continue,
            Err(e) => {
                outcome = Err(e);
                VisitRecursion::Stop
            }
        })
    })
    // the visitor closure above only ever returns `Ok`
    .expect("no way to return error during recursion");
    outcome
}
/// Rebuilds a plan node of the same variant as `plan`, substituting the given
/// expressions and inputs.
///
/// `plan` acts as a template: properties that are not expressions or inputs
/// (names, options, fetch/skip values, and for several variants the schema)
/// are copied from it, while `expr` and `inputs` supply the replacement
/// expressions and child plans. How `expr` is split up depends on the variant
/// and is described on each arm below.
///
/// # Errors
///
/// Propagates errors from the fallible constructors used for the individual
/// variants, and returns a plan error for `Explain` when `expr` or `inputs`
/// is empty.
///
/// # Panics
///
/// Most arms index `inputs[0]` (and `inputs[1]` for joins) and slice `expr`
/// without checking lengths, and several arms assert on the number of
/// expressions or inputs, so a mismatched `expr`/`inputs` panics.
pub fn from_plan(
    plan: &LogicalPlan,
    expr: &[Expr],
    inputs: &[LogicalPlan],
) -> Result<LogicalPlan> {
    match plan {
        // Projection: all of `expr` become the projected expressions; the
        // template's schema is reused.
        LogicalPlan::Projection(Projection { schema, .. }) => {
            Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
                expr.to_vec(),
                Arc::new(inputs[0].clone()),
                schema.clone(),
            )?))
        }
        // DML: only the input is replaced.
        LogicalPlan::Dml(DmlStatement {
            table_name,
            table_schema,
            op,
            ..
        }) => Ok(LogicalPlan::Dml(DmlStatement {
            table_name: table_name.clone(),
            table_schema: table_schema.clone(),
            op: op.clone(),
            input: Arc::new(inputs[0].clone()),
        })),
        // COPY: only the input is replaced.
        LogicalPlan::Copy(CopyTo {
            input: _,
            output_url,
            file_format,
            per_thread_output,
            options,
        }) => Ok(LogicalPlan::Copy(CopyTo {
            input: Arc::new(inputs[0].clone()),
            output_url: output_url.clone(),
            file_format: file_format.clone(),
            per_thread_output: *per_thread_output,
            options: options.clone(),
        })),
        // VALUES: `expr` is a flattened row-major list; split it back into
        // rows of one expression per schema field.
        LogicalPlan::Values(Values { schema, .. }) => Ok(LogicalPlan::Values(Values {
            schema: schema.clone(),
            values: expr
                .chunks_exact(schema.fields().len())
                .map(|s| s.to_vec())
                .collect::<Vec<_>>(),
        })),
        // Filter: exactly one expression, the predicate.
        LogicalPlan::Filter { .. } => {
            assert_eq!(1, expr.len());
            let predicate = expr[0].clone();
            // Rewriter that strips aliases from the predicate.
            struct RemoveAliases {}
            impl TreeNodeRewriter for RemoveAliases {
                type N = Expr;
                fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
                    match expr {
                        // Do not descend into subquery expressions.
                        Expr::Exists { .. }
                        | Expr::ScalarSubquery(_)
                        | Expr::InSubquery(_) => {
                            Ok(RewriteRecursion::Stop)
                        }
                        Expr::Alias(_) => Ok(RewriteRecursion::Mutate),
                        _ => Ok(RewriteRecursion::Continue),
                    }
                }
                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
                    Ok(expr.unalias())
                }
            }
            let mut remove_aliases = RemoveAliases {};
            let predicate = predicate.rewrite(&mut remove_aliases)?;
            Ok(LogicalPlan::Filter(Filter::try_new(
                predicate,
                Arc::new(inputs[0].clone()),
            )?))
        }
        // Repartition: keep the scheme kind and partition count; hash and
        // distribute-by schemes take all of `expr` as their expressions.
        LogicalPlan::Repartition(Repartition {
            partitioning_scheme,
            ..
        }) => match partitioning_scheme {
            Partitioning::RoundRobinBatch(n) => {
                Ok(LogicalPlan::Repartition(Repartition {
                    partitioning_scheme: Partitioning::RoundRobinBatch(*n),
                    input: Arc::new(inputs[0].clone()),
                }))
            }
            Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition {
                partitioning_scheme: Partitioning::Hash(expr.to_owned(), *n),
                input: Arc::new(inputs[0].clone()),
            })),
            Partitioning::DistributeBy(_) => Ok(LogicalPlan::Repartition(Repartition {
                partitioning_scheme: Partitioning::DistributeBy(expr.to_owned()),
                input: Arc::new(inputs[0].clone()),
            })),
        },
        // Window: the first `window_expr.len()` expressions are the window
        // expressions; the template's schema is reused.
        LogicalPlan::Window(Window {
            window_expr,
            schema,
            ..
        }) => Ok(LogicalPlan::Window(Window {
            input: Arc::new(inputs[0].clone()),
            window_expr: expr[0..window_expr.len()].to_vec(),
            schema: schema.clone(),
        })),
        // Aggregate: the first `group_expr.len()` expressions are the group
        // expressions and the remainder are the aggregate expressions.
        LogicalPlan::Aggregate(Aggregate {
            group_expr, schema, ..
        }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
            Arc::new(inputs[0].clone()),
            expr[0..group_expr.len()].to_vec(),
            expr[group_expr.len()..].to_vec(),
            schema.clone(),
        )?)),
        // Sort: all of `expr` become the sort expressions.
        LogicalPlan::Sort(SortPlan { fetch, .. }) => Ok(LogicalPlan::Sort(SortPlan {
            expr: expr.to_vec(),
            input: Arc::new(inputs[0].clone()),
            fetch: *fetch,
        })),
        // Join: the schema is rebuilt from the two new inputs.
        LogicalPlan::Join(Join {
            join_type,
            join_constraint,
            on,
            null_equals_null,
            ..
        }) => {
            let schema =
                build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
            // The leading `on.len()` expressions are the equi-join conditions,
            // each expected to be `left = right` (possibly wrapped in an alias).
            let equi_expr_count = on.len();
            assert!(expr.len() >= equi_expr_count);
            let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| {
                    let unalias_expr = equi_expr.clone().unalias();
                    if let Expr::BinaryExpr(BinaryExpr { left, op:Operator::Eq, right }) = unalias_expr {
                        Ok((*left, *right))
                    } else {
                        internal_err!(
                            "The front part expressions should be an binary equiality expression, actual:{equi_expr}"
                        )
                    }
                }).collect::<Result<Vec<(Expr, Expr)>>>()?;
            // If anything follows the equi-join conditions, the last
            // expression is taken as the join filter.
            let filter_expr =
                (expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone());
            Ok(LogicalPlan::Join(Join {
                left: Arc::new(inputs[0].clone()),
                right: Arc::new(inputs[1].clone()),
                join_type: *join_type,
                join_constraint: *join_constraint,
                on: new_on,
                filter: filter_expr,
                schema: DFSchemaRef::new(schema),
                null_equals_null: *null_equals_null,
            }))
        }
        // Cross join: rebuilt from scratch through the builder.
        LogicalPlan::CrossJoin(_) => {
            let left = inputs[0].clone();
            let right = inputs[1].clone();
            LogicalPlanBuilder::from(left).cross_join(right)?.build()
        }
        // Subquery: the input becomes the new subquery plan.
        LogicalPlan::Subquery(Subquery {
            outer_ref_columns, ..
        }) => {
            let subquery = LogicalPlanBuilder::from(inputs[0].clone()).build()?;
            Ok(LogicalPlan::Subquery(Subquery {
                subquery: Arc::new(subquery),
                outer_ref_columns: outer_ref_columns.clone(),
            }))
        }
        // Subquery alias: same alias over the new input.
        LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => {
            Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
                inputs[0].clone(),
                alias.clone(),
            )?))
        }
        // Limit: only the input is replaced.
        LogicalPlan::Limit(Limit { skip, fetch, .. }) => Ok(LogicalPlan::Limit(Limit {
            skip: *skip,
            fetch: *fetch,
            input: Arc::new(inputs[0].clone()),
        })),
        // CREATE TABLE ... AS: the input is replaced.
        // NOTE(review): `constraints` is reset to empty rather than copied
        // from the template — confirm this is intended.
        LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable {
            name,
            if_not_exists,
            or_replace,
            ..
        })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(
            CreateMemoryTable {
                input: Arc::new(inputs[0].clone()),
                constraints: Constraints::empty(),
                name: name.clone(),
                if_not_exists: *if_not_exists,
                or_replace: *or_replace,
            },
        ))),
        // CREATE VIEW: only the input is replaced.
        LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
            name,
            or_replace,
            definition,
            ..
        })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
            input: Arc::new(inputs[0].clone()),
            name: name.clone(),
            or_replace: *or_replace,
            definition: definition.clone(),
        }))),
        // Extension: delegate reconstruction to the user-defined node.
        LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension {
            node: e.node.from_template(expr, inputs),
        })),
        // Union: every entry of `inputs` becomes a child; schema is reused.
        LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union {
            inputs: inputs.iter().cloned().map(Arc::new).collect(),
            schema: schema.clone(),
        })),
        // Distinct: only the input is replaced.
        LogicalPlan::Distinct(Distinct { .. }) => Ok(LogicalPlan::Distinct(Distinct {
            input: Arc::new(inputs[0].clone()),
        })),
        // Analyze: takes no expressions and exactly one input.
        LogicalPlan::Analyze(a) => {
            assert!(expr.is_empty());
            assert_eq!(inputs.len(), 1);
            Ok(LogicalPlan::Analyze(Analyze {
                verbose: a.verbose,
                schema: a.schema.clone(),
                input: Arc::new(inputs[0].clone()),
            }))
        }
        // Explain: `expr` and `inputs` are only checked for non-emptiness;
        // the template plan itself is returned unchanged.
        LogicalPlan::Explain(_) => {
            if expr.is_empty() {
                return plan_err!("Invalid EXPLAIN command. Expression is empty");
            }
            if inputs.is_empty() {
                return plan_err!("Invalid EXPLAIN command. Inputs are empty");
            }
            Ok(plan.clone())
        }
        // Prepare: only the input is replaced.
        LogicalPlan::Prepare(Prepare {
            name, data_types, ..
        }) => Ok(LogicalPlan::Prepare(Prepare {
            name: name.clone(),
            data_types: data_types.clone(),
            input: Arc::new(inputs[0].clone()),
        })),
        // Table scan: a leaf node; `expr` replaces the scan's filters.
        LogicalPlan::TableScan(ts) => {
            assert!(inputs.is_empty(), "{plan:?}  should have no inputs");
            Ok(LogicalPlan::TableScan(TableScan {
                filters: expr.to_vec(),
                ..ts.clone()
            }))
        }
        // Leaf nodes with neither expressions nor inputs are cloned as-is.
        LogicalPlan::EmptyRelation(_)
        | LogicalPlan::Ddl(_)
        | LogicalPlan::Statement(_) => {
            assert!(expr.is_empty(), "{plan:?} should have no exprs");
            assert!(inputs.is_empty(), "{plan:?}  should have no inputs");
            Ok(plan.clone())
        }
        // DescribeTable is cloned without inspecting `expr` or `inputs`.
        LogicalPlan::DescribeTable(_) => Ok(plan.clone()),
        // Unnest: recompute the schema from the new input, swapping the
        // nested column's field for the unnested field of the template schema.
        LogicalPlan::Unnest(Unnest {
            column,
            schema,
            options,
            ..
        }) => {
            let input = Arc::new(inputs[0].clone());
            let nested_field = input.schema().field_from_column(column)?;
            let unnested_field = schema.field_from_column(column)?;
            let fields = input
                .schema()
                .fields()
                .iter()
                .map(|f| {
                    if f == nested_field {
                        unnested_field.clone()
                    } else {
                        f.clone()
                    }
                })
                .collect::<Vec<_>>();
            // Metadata and functional dependencies come from the new input.
            let schema = Arc::new(
                DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?
                    .with_functional_dependencies(
                        input.schema().functional_dependencies().clone(),
                    ),
            );
            Ok(LogicalPlan::Unnest(Unnest {
                input,
                column: column.clone(),
                schema,
                options: options.clone(),
            }))
        }
    }
}
fn agg_cols(agg: &Aggregate) -> Vec<Column> {
    agg.aggr_expr
        .iter()
        .chain(&agg.group_expr)
        .flat_map(find_columns_referenced_by_expr)
        .collect()
}
/// Converts `exprs` into fields for a plan that is, or sits on top of, an
/// aggregate.
///
/// A bare column expression that is referenced by one of the aggregate's
/// own expressions is resolved against the schema of the aggregate's
/// *input*; every other expression is resolved against the schema of `plan`.
///
/// # Errors
///
/// Returns the first error produced by `Expr::to_field`.
fn exprlist_to_fields_aggregate(
    exprs: &[Expr],
    plan: &LogicalPlan,
    agg: &Aggregate,
) -> Result<Vec<DFField>> {
    let referenced = agg_cols(agg);
    exprs
        .iter()
        .map(|expr| {
            // Pick the schema the expression has to be resolved against.
            let schema = match expr {
                Expr::Column(c) if referenced.contains(c) => agg.input.schema(),
                _ => plan.schema(),
            };
            expr.to_field(schema)
        })
        .collect()
}
/// Creates one field per expression in `expr`, resolved against `plan`.
///
/// When `plan` is an `Aggregate`, or a `Window` whose direct input is an
/// `Aggregate`, the conversion is delegated to
/// `exprlist_to_fields_aggregate`. Otherwise each expression is converted
/// against the schema of `plan`.
///
/// # Errors
///
/// Returns the first error raised while converting an expression.
pub fn exprlist_to_fields<'a>(
    expr: impl IntoIterator<Item = &'a Expr>,
    plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
    let exprs: Vec<Expr> = expr.into_iter().cloned().collect();

    // Find the aggregate node, if any, that needs the special handling.
    let aggregate = match plan {
        LogicalPlan::Aggregate(agg) => Some(agg),
        LogicalPlan::Window(window) => match window.input.as_ref() {
            LogicalPlan::Aggregate(agg) => Some(agg),
            _ => None,
        },
        _ => None,
    };

    match aggregate {
        Some(agg) => exprlist_to_fields_aggregate(&exprs, plan, agg),
        None => {
            let schema = plan.schema();
            exprs.iter().map(|e| e.to_field(schema)).collect()
        }
    }
}
/// Rewrites `e` into a column reference when `input_schema` already contains
/// a field named after the expression.
///
/// * Columns, outer-reference columns and scalar subqueries are returned
///   unchanged.
/// * Aliases, casts and try-casts are rebuilt around the columnized inner
///   expression.
/// * Any other expression is replaced by a reference to the field whose
///   unqualified name equals the expression's display name. If the display
///   name cannot be computed or no such field exists, the expression is
///   returned unchanged.
pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
    match e {
        // Returned as-is. `e` is owned, so no clone is needed.
        Expr::Column(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarSubquery(_) => e,
        // NOTE(review): the alias fields matched by `..` are not carried
        // over to the rebuilt alias — confirm this is intended.
        Expr::Alias(Alias { expr, name, .. }) => {
            columnize_expr(*expr, input_schema).alias(name)
        }
        Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast {
            expr: Box::new(columnize_expr(*expr, input_schema)),
            data_type,
        }),
        Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new(
            Box::new(columnize_expr(*expr, input_schema)),
            data_type,
        )),
        _ => {
            // Lookup failures are deliberately ignored: they mean the
            // expression cannot be columnized and is kept as it is.
            let column = e.display_name().ok().and_then(|name| {
                input_schema
                    .field_with_unqualified_name(&name)
                    .ok()
                    .map(|field| Expr::Column(field.qualified_column()))
            });
            column.unwrap_or(e)
        }
    }
}
pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
    exprs
        .iter()
        .flat_map(find_columns_referenced_by_expr)
        .map(Expr::Column)
        .collect()
}
/// Returns a clone of every `Expr::Column` that `inspect_expr_pre` visits
/// while walking `e`, in visiting order.
///
/// # Panics
///
/// Panics if `inspect_expr_pre` reports an error. The visitor used here
/// never returns one itself.
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
    let mut columns = Vec::new();
    inspect_expr_pre(e, |node| -> Result<()> {
        if let Expr::Column(column) = node {
            columns.push(column.clone());
        }
        Ok(())
    })
    .expect("Unexpected error");
    columns
}
/// Converts `expr` into an `Expr::Column`.
///
/// A column expression is looked up in the schema of `plan` and replaced by
/// the qualified column of the matching field. Any other expression becomes
/// a column named after its display name.
///
/// # Errors
///
/// Fails when the column cannot be found in the plan's schema, or when the
/// display name of the expression cannot be computed.
pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    let column = if let Expr::Column(col) = expr {
        plan.schema().field_from_column(col)?.qualified_column()
    } else {
        Column::from_name(expr.display_name()?)
    };
    Ok(Expr::Column(column))
}
/// Collects one index for each column and literal that `inspect_expr_pre`
/// visits while walking `e`, in visiting order.
///
/// * A column contributes its index in `schema`. Columns that cannot be
///   resolved are silently skipped.
/// * A literal contributes the marker value `usize::MAX`.
///
/// # Panics
///
/// Panics if `inspect_expr_pre` reports an error. The visitor used here
/// never returns one itself.
pub(crate) fn find_column_indexes_referenced_by_expr(
    e: &Expr,
    schema: &DFSchemaRef,
) -> Vec<usize> {
    let mut positions = Vec::new();
    inspect_expr_pre(e, |node| -> Result<()> {
        match node {
            // `extend` with an `Option` pushes the index only on success.
            Expr::Column(column) => positions.extend(schema.index_of_column(column).ok()),
            Expr::Literal(_) => positions.push(usize::MAX),
            _ => {}
        }
        Ok(())
    })
    .unwrap();
    positions
}
/// Reports whether `data_type` is one of the types this function accepts as
/// hashable.
///
/// Accepted are:
/// * null, boolean, the signed and unsigned integers, and both float widths;
/// * timestamps without a time zone, in any time unit;
/// * `Utf8`, `LargeUtf8`, `Decimal128`, `Date32`, `Date64` and
///   `FixedSizeBinary`;
/// * dictionaries whose value type is `Utf8` and whose key type satisfies
///   `DataType::is_dictionary_key_type`.
///
/// Every other type, including timestamps that carry a time zone, yields
/// `false`.
pub fn can_hash(data_type: &DataType) -> bool {
    match data_type {
        DataType::Null
        | DataType::Boolean
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64
        | DataType::Float32
        | DataType::Float64
        | DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Decimal128(_, _)
        | DataType::Date32
        | DataType::Date64
        | DataType::FixedSizeBinary(_) => true,
        // Kept exhaustive so that a new time unit forces a decision here.
        DataType::Timestamp(unit, None) => match unit {
            TimeUnit::Second
            | TimeUnit::Millisecond
            | TimeUnit::Microsecond
            | TimeUnit::Nanosecond => true,
        },
        DataType::Dictionary(key_type, value_type)
            if matches!(value_type.as_ref(), DataType::Utf8) =>
        {
            DataType::is_dictionary_key_type(key_type)
        }
        _ => false,
    }
}
/// Checks whether every column in `columns` belongs to `schema`.
///
/// Returns `Ok(true)` when `schema.is_column_from_schema` accepts all of the
/// columns (including when `columns` is empty). Returns `Ok(false)` as soon
/// as one column is rejected.
///
/// # Errors
///
/// Propagates the first error returned by `is_column_from_schema`.
pub fn check_all_columns_from_schema(
    columns: &HashSet<Column>,
    schema: DFSchemaRef,
) -> Result<bool> {
    for column in columns {
        match schema.is_column_from_schema(column) {
            Ok(true) => continue,
            // Either `Ok(false)` or an error: both end the scan right away.
            outcome => return outcome,
        }
    }
    Ok(true)
}
/// Tries to arrange `left_key` and `right_key` as a `(left, right)` equijoin
/// key pair for the given schemas.
///
/// * Returns `Ok(None)` if either key references no columns.
/// * Returns the keys in the given order when all columns of `left_key`
///   belong to `left_schema` and all columns of `right_key` belong to
///   `right_schema`.
/// * Otherwise returns the keys swapped when all columns of `right_key`
///   belong to `left_schema` and all columns of `left_key` belong to
///   `right_schema`.
/// * Returns `Ok(None)` in every other case.
///
/// # Errors
///
/// Propagates errors from `Expr::to_columns` and from
/// `check_all_columns_from_schema`.
pub fn find_valid_equijoin_key_pair(
    left_key: &Expr,
    right_key: &Expr,
    left_schema: DFSchemaRef,
    right_schema: DFSchemaRef,
) -> Result<Option<(Expr, Expr)>> {
    let left_columns = left_key.to_columns()?;
    let right_columns = right_key.to_columns()?;
    if left_columns.is_empty() || right_columns.is_empty() {
        return Ok(None);
    }

    // Both direct checks are always evaluated before the swapped order is
    // considered.
    let left_in_left =
        check_all_columns_from_schema(&left_columns, left_schema.clone())?;
    let right_in_right =
        check_all_columns_from_schema(&right_columns, right_schema.clone())?;
    if left_in_left && right_in_right {
        return Ok(Some((left_key.clone(), right_key.clone())));
    }

    // The keys may have been written in the opposite order.
    let swapped = check_all_columns_from_schema(&right_columns, left_schema)?
        && check_all_columns_from_schema(&left_columns, right_schema)?;
    Ok(swapped.then(|| (right_key.clone(), left_key.clone())))
}
/// Builds the error message shown when no signature of `func_name` matches
/// the supplied argument types.
///
/// The message quotes the attempted call (function name followed by the
/// joined `input_expr_types`) and then lists one candidate per line, each
/// taken from `func_signature.type_signature.to_string_repr()` and prefixed
/// with a tab.
pub fn generate_signature_error_msg(
    func_name: &str,
    func_signature: Signature,
    input_expr_types: &[DataType],
) -> String {
    let mut candidate_lines = Vec::new();
    for args_str in func_signature.type_signature.to_string_repr().iter() {
        candidate_lines.push(format!("\t{func_name}({args_str})"));
    }
    let candidate_signatures = candidate_lines.join("\n");
    let provided_types = TypeSignature::join_types(input_expr_types, ", ");
    format!(
            "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
            func_name, provided_types, candidate_signatures
        )
}
// Unit tests for the window sort-key grouping helpers, `find_sort_exprs`,
// `generate_sort_key` and `enumerate_grouping_sets`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr_vec_fmt;
    use crate::{
        col, cube, expr, grouping_set, rollup, AggregateFunction, WindowFrame,
        WindowFunction,
    };
    // Grouping an empty slice of expressions is expected to yield no groups.
    #[test]
    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
        let result = group_window_expr_by_sort_keys(&[])?;
        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
        assert_eq!(expected, result);
        Ok(())
    }
    // Four window expressions that are all built with empty expression lists
    // are expected to end up in one group under the empty sort key.
    #[test]
    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
        let max1 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Max),
            vec![col("name")],
            vec![],
            vec![],
            WindowFrame::new(false),
        ));
        let max2 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Max),
            vec![col("name")],
            vec![],
            vec![],
            WindowFrame::new(false),
        ));
        let min3 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Min),
            vec![col("name")],
            vec![],
            vec![],
            WindowFrame::new(false),
        ));
        let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Sum),
            vec![col("age")],
            vec![],
            vec![],
            WindowFrame::new(false),
        ));
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs)?;
        // A single group holding all four expressions in input order.
        let key = vec![];
        let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
            vec![(key, vec![&max1, &max2, &min3, &sum4])];
        assert_eq!(expected, result);
        Ok(())
    }
    // Window expressions built with different sort-expression lists are
    // expected to be split into one group per distinct sort key.
    #[test]
    fn test_group_window_expr_by_sort_keys() -> Result<()> {
        let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true));
        let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true));
        let created_at_desc =
            Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true));
        // `max1` and `min3` share the same sort expressions (age, name).
        let max1 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Max),
            vec![col("name")],
            vec![],
            vec![age_asc.clone(), name_desc.clone()],
            WindowFrame::new(true),
        ));
        // `max2` has no sort expressions at all.
        let max2 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Max),
            vec![col("name")],
            vec![],
            vec![],
            WindowFrame::new(false),
        ));
        let min3 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Min),
            vec![col("name")],
            vec![],
            vec![age_asc.clone(), name_desc.clone()],
            WindowFrame::new(true),
        ));
        // `sum4` uses three sort expressions in a different order.
        let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
            WindowFunction::AggregateFunction(AggregateFunction::Sum),
            vec![col("age")],
            vec![],
            vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
            WindowFrame::new(true),
        ));
        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
        let result = group_window_expr_by_sort_keys(exprs)?;
        // Every key entry pairs a sort expression with `false`.
        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
        let key2 = vec![];
        let key3 = vec![
            (name_desc, false),
            (age_asc, false),
            (created_at_desc, false),
        ];
        // Groups are expected in order of first appearance of their key.
        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
            (key1, vec![&max1, &min3]),
            (key2, vec![&max2]),
            (key3, vec![&sum4]),
        ];
        assert_eq!(expected, result);
        Ok(())
    }
    // The two window functions list overlapping sort expressions; the
    // expected result contains each distinct sort expression exactly once.
    #[test]
    fn test_find_sort_exprs() -> Result<()> {
        let exprs = &[
            Expr::WindowFunction(expr::WindowFunction::new(
                WindowFunction::AggregateFunction(AggregateFunction::Max),
                vec![col("name")],
                vec![],
                vec![
                    Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
                    Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
                ],
                WindowFrame::new(true),
            )),
            Expr::WindowFunction(expr::WindowFunction::new(
                WindowFunction::AggregateFunction(AggregateFunction::Sum),
                vec![col("age")],
                vec![],
                vec![
                    Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
                    Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
                    Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
                ],
                WindowFrame::new(true),
            )),
        ];
        let expected = vec![
            Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
            Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
            Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
        ];
        let result = find_sort_exprs(exprs);
        assert_eq!(expected, result);
        Ok(())
    }
    // For every combination of `asc` and `nulls_first`, columns that occur in
    // both `partition_by` and `order_by` are expected to appear only once in
    // the generated sort key.
    #[test]
    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
        let asc_or_desc = [true, false];
        let nulls_first_or_last = [true, false];
        let partition_by = &[col("age"), col("name"), col("created_at")];
        for asc_ in asc_or_desc {
            for nulls_first_ in nulls_first_or_last {
                let order_by = &[
                    Expr::Sort(Sort {
                        expr: Box::new(col("age")),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    }),
                    Expr::Sort(Sort {
                        expr: Box::new(col("name")),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    }),
                ];
                // `age` and `name` carry the options given in `order_by`;
                // `created_at`, present only in `partition_by`, is expected
                // with `asc: true` and `nulls_first: false`.
                let expected = vec![
                    (
                        Expr::Sort(Sort {
                            expr: Box::new(col("age")),
                            asc: asc_,
                            nulls_first: nulls_first_,
                        }),
                        true,
                    ),
                    (
                        Expr::Sort(Sort {
                            expr: Box::new(col("name")),
                            asc: asc_,
                            nulls_first: nulls_first_,
                        }),
                        true,
                    ),
                    (
                        Expr::Sort(Sort {
                            expr: Box::new(col("created_at")),
                            asc: true,
                            nulls_first: false,
                        }),
                        true,
                    ),
                ];
                let result = generate_sort_key(partition_by, order_by)?;
                assert_eq!(expected, result);
            }
        }
        Ok(())
    }
    // Checks the formatted output of `enumerate_grouping_sets` for single
    // expressions and for combinations of a plain column with CUBE, ROLLUP
    // and GROUPING SETS expressions.
    #[test]
    fn test_enumerate_grouping_sets() -> Result<()> {
        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
        let simple_col = col("simple_col");
        let cube = cube(multi_cols.clone());
        let rollup = rollup(multi_cols.clone());
        let grouping_set = grouping_set(vec![multi_cols]);
        // A single expression is expected to come back unchanged.
        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[simple_col]", &result);
        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[CUBE (col1, col2, col3)]", &result);
        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
        // Plain column + CUBE: eight sets, each starting with the column.
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col2), \
            (simple_col, col1, col2), \
            (simple_col, col3), \
            (simple_col, col1, col3), \
            (simple_col, col2, col3), \
            (simple_col, col1, col2, col3))]",
            &result
        );
        // Plain column + ROLLUP: four sets, each starting with the column.
        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3))]",
            &result
        );
        // Plain column + a one-set GROUPING SETS: a single combined set.
        let sets =
            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col, col1, col2, col3))]",
            &result
        );
        // Plain column + GROUPING SETS + ROLLUP: the expected sets keep
        // repeated columns.
        let sets = enumerate_grouping_sets(vec![
            simple_col.clone(),
            grouping_set,
            rollup.clone(),
        ])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );
        // Plain column + CUBE + ROLLUP: 8 x 4 = 32 expected sets.
        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1), \
            (simple_col, col1, col1), \
            (simple_col, col1, col1, col2), \
            (simple_col, col1, col1, col2, col3), \
            (simple_col, col2), \
            (simple_col, col2, col1), \
            (simple_col, col2, col1, col2), \
            (simple_col, col2, col1, col2, col3), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col1), \
            (simple_col, col1, col2, col1, col2), \
            (simple_col, col1, col2, col1, col2, col3), \
            (simple_col, col3), \
            (simple_col, col3, col1), \
            (simple_col, col3, col1, col2), \
            (simple_col, col3, col1, col2, col3), \
            (simple_col, col1, col3), \
            (simple_col, col1, col3, col1), \
            (simple_col, col1, col3, col1, col2), \
            (simple_col, col1, col3, col1, col2, col3), \
            (simple_col, col2, col3), \
            (simple_col, col2, col3, col1), \
            (simple_col, col2, col3, col1, col2), \
            (simple_col, col2, col3, col1, col2, col3), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );
        Ok(())
    }
}