use crate::aggregate_function::AggregateFunction;
use crate::type_coercion::functions::data_types;
use crate::utils;
use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF};
use arrow::datatypes::DataType;
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
use std::sync::Arc;
use std::{fmt, str::FromStr};
use strum_macros::EnumIter;
/// A function that can be evaluated as a window function.
///
/// Each variant wraps one kind of callable: a built-in aggregate, a dedicated
/// built-in window function, or a user-defined aggregate / window function
/// shared behind an `Arc`.
///
/// Note: variant order is significant for the derived `Hash` impl; do not
/// reorder variants casually.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunction {
    /// A built-in aggregate function (e.g. `MIN`, `MAX`, `COUNT`) used as a window function.
    AggregateFunction(AggregateFunction),
    /// A built-in function that only exists as a window function (e.g. `ROW_NUMBER`, `LAG`).
    BuiltInWindowFunction(BuiltInWindowFunction),
    /// A user-defined aggregate function used as a window function.
    AggregateUDF(Arc<AggregateUDF>),
    /// A user-defined window function.
    WindowUDF(Arc<WindowUDF>),
}
/// Looks up a DataFusion window function by name, ignoring case.
///
/// The name is first tried as a [`BuiltInWindowFunction`] (e.g. `row_number`,
/// `lag`). Only if that fails is it tried as a built-in [`AggregateFunction`]
/// (e.g. `max`, `count`). Returns `None` when neither lookup succeeds.
pub fn find_df_window_func(name: &str) -> Option<WindowFunction> {
    let lowered = name.to_lowercase();
    BuiltInWindowFunction::from_str(&lowered)
        .map(WindowFunction::BuiltInWindowFunction)
        .or_else(|_| {
            AggregateFunction::from_str(&lowered).map(WindowFunction::AggregateFunction)
        })
        .ok()
}
/// Formats the function as its upper-case SQL name (e.g. `ROW_NUMBER`).
impl fmt::Display for BuiltInWindowFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.name())
    }
}
impl fmt::Display for WindowFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            WindowFunction::AggregateFunction(fun) => fun.fmt(f),
            WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
            WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
            WindowFunction::WindowUDF(fun) => fun.fmt(f),
        }
    }
}
/// Window functions that are built into DataFusion and are not aggregates.
///
/// Note: variant order defines the `EnumIter` iteration order and the derived
/// `Hash` discriminants; do not reorder variants casually.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
pub enum BuiltInWindowFunction {
    /// SQL `ROW_NUMBER`; takes no arguments.
    RowNumber,
    /// SQL `RANK`; takes no arguments.
    Rank,
    /// SQL `DENSE_RANK`; takes no arguments.
    DenseRank,
    /// SQL `PERCENT_RANK`; takes no arguments.
    PercentRank,
    /// SQL `CUME_DIST`; takes no arguments.
    CumeDist,
    /// SQL `NTILE`; takes one integer argument.
    Ntile,
    /// SQL `LAG`; takes one to three arguments.
    Lag,
    /// SQL `LEAD`; takes one to three arguments.
    Lead,
    /// SQL `FIRST_VALUE`; takes one argument.
    FirstValue,
    /// SQL `LAST_VALUE`; takes one argument.
    LastValue,
    /// SQL `NTH_VALUE`; takes two arguments.
    NthValue,
}
impl BuiltInWindowFunction {
    fn name(&self) -> &str {
        use BuiltInWindowFunction::*;
        match self {
            RowNumber => "ROW_NUMBER",
            Rank => "RANK",
            DenseRank => "DENSE_RANK",
            PercentRank => "PERCENT_RANK",
            CumeDist => "CUME_DIST",
            Ntile => "NTILE",
            Lag => "LAG",
            Lead => "LEAD",
            FirstValue => "FIRST_VALUE",
            LastValue => "LAST_VALUE",
            NthValue => "NTH_VALUE",
        }
    }
}
impl FromStr for BuiltInWindowFunction {
    type Err = DataFusionError;
    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
        Ok(match name.to_uppercase().as_str() {
            "ROW_NUMBER" => BuiltInWindowFunction::RowNumber,
            "RANK" => BuiltInWindowFunction::Rank,
            "DENSE_RANK" => BuiltInWindowFunction::DenseRank,
            "PERCENT_RANK" => BuiltInWindowFunction::PercentRank,
            "CUME_DIST" => BuiltInWindowFunction::CumeDist,
            "NTILE" => BuiltInWindowFunction::Ntile,
            "LAG" => BuiltInWindowFunction::Lag,
            "LEAD" => BuiltInWindowFunction::Lead,
            "FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
            "LAST_VALUE" => BuiltInWindowFunction::LastValue,
            "NTH_VALUE" => BuiltInWindowFunction::NthValue,
            _ => return plan_err!("There is no built-in window function named {name}"),
        })
    }
}
#[deprecated(
    since = "27.0.0",
    note = "please use `WindowFunction::return_type` instead"
)]
pub fn return_type(
    fun: &WindowFunction,
    input_expr_types: &[DataType],
) -> Result<DataType> {
    fun.return_type(input_expr_types)
}
impl WindowFunction {
    /// Returns the data type this window function produces when called with
    /// arguments of types `input_expr_types`.
    ///
    /// Dispatches to the `return_type` of whichever function kind is wrapped,
    /// and propagates its error unchanged.
    pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
        match self {
            Self::AggregateFunction(aggregate) => aggregate.return_type(input_expr_types),
            Self::BuiltInWindowFunction(built_in) => built_in.return_type(input_expr_types),
            Self::AggregateUDF(udaf) => udaf.return_type(input_expr_types),
            Self::WindowUDF(udwf) => udwf.return_type(input_expr_types),
        }
    }
}
impl BuiltInWindowFunction {
    /// Returns the data type produced by this built-in window function when it
    /// is called with arguments of types `input_expr_types`.
    ///
    /// The ranking functions (`ROW_NUMBER`, `RANK`, `DENSE_RANK`) and `NTILE`
    /// return `UInt64`. `PERCENT_RANK` and `CUME_DIST` return `Float64`. The
    /// value functions (`LAG`, `LEAD`, `FIRST_VALUE`, `LAST_VALUE`, `NTH_VALUE`)
    /// return the type of their first argument.
    ///
    /// # Errors
    ///
    /// Returns a planning error if `input_expr_types` does not satisfy
    /// [`BuiltInWindowFunction::signature`]. A value function called with no
    /// arguments at all also returns this error instead of panicking.
    pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
        // Built lazily: the message is only needed on the failure paths.
        let signature_error = || {
            plan_datafusion_err!(
                "{}",
                utils::generate_signature_error_msg(
                    &format!("{self}"),
                    self.signature(),
                    input_expr_types,
                )
            )
        };

        // Only the validation matters here; the coerced types are discarded.
        data_types(input_expr_types, &self.signature()).map_err(|_| signature_error())?;

        match self {
            BuiltInWindowFunction::RowNumber
            | BuiltInWindowFunction::Rank
            | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
            BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
                Ok(DataType::Float64)
            }
            BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
            BuiltInWindowFunction::Lag
            | BuiltInWindowFunction::Lead
            | BuiltInWindowFunction::FirstValue
            | BuiltInWindowFunction::LastValue
            | BuiltInWindowFunction::NthValue => {
                // Checked access instead of `[0]`. The signature check above is
                // not guaranteed to reject an empty argument list
                // (NOTE(review): depends on how `data_types` treats zero
                // arguments), and indexing would then panic.
                input_expr_types.first().cloned().ok_or_else(signature_error)
            }
        }
    }
}
#[deprecated(
    since = "27.0.0",
    note = "please use `WindowFunction::signature` instead"
)]
pub fn signature(fun: &WindowFunction) -> Signature {
    fun.signature()
}
impl WindowFunction {
    pub fn signature(&self) -> Signature {
        match self {
            WindowFunction::AggregateFunction(fun) => fun.signature(),
            WindowFunction::BuiltInWindowFunction(fun) => fun.signature(),
            WindowFunction::AggregateUDF(fun) => fun.signature().clone(),
            WindowFunction::WindowUDF(fun) => fun.signature().clone(),
        }
    }
}
#[deprecated(
    since = "27.0.0",
    note = "please use `BuiltInWindowFunction::signature` instead"
)]
pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
    fun.signature()
}
impl BuiltInWindowFunction {
    pub fn signature(&self) -> Signature {
        match self {
            BuiltInWindowFunction::RowNumber
            | BuiltInWindowFunction::Rank
            | BuiltInWindowFunction::DenseRank
            | BuiltInWindowFunction::PercentRank
            | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
            BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
                Signature::one_of(
                    vec![
                        TypeSignature::Any(1),
                        TypeSignature::Any(2),
                        TypeSignature::Any(3),
                    ],
                    Volatility::Immutable,
                )
            }
            BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
                Signature::any(1, Volatility::Immutable)
            }
            BuiltInWindowFunction::Ntile => Signature::uniform(
                1,
                vec![
                    DataType::UInt64,
                    DataType::UInt32,
                    DataType::UInt16,
                    DataType::UInt8,
                    DataType::Int64,
                    DataType::Int32,
                    DataType::Int16,
                    DataType::Int8,
                ],
                Volatility::Immutable,
            ),
            BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use strum::IntoEnumIterator;

    /// Resolves `name` via `find_df_window_func` and asserts that calling it
    /// with argument types `args` yields `expected`.
    fn assert_return_type(name: &str, args: &[DataType], expected: DataType) -> Result<()> {
        let fun = find_df_window_func(name).unwrap();
        let observed = fun.return_type(args)?;
        assert_eq!(expected, observed);
        Ok(())
    }

    #[test]
    fn test_count_return_type() -> Result<()> {
        assert_return_type("count", &[DataType::Utf8], DataType::Int64)?;
        assert_return_type("count", &[DataType::UInt64], DataType::Int64)
    }

    #[test]
    fn test_first_value_return_type() -> Result<()> {
        assert_return_type("first_value", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("first_value", &[DataType::UInt64], DataType::UInt64)
    }

    #[test]
    fn test_last_value_return_type() -> Result<()> {
        assert_return_type("last_value", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("last_value", &[DataType::Float64], DataType::Float64)
    }

    #[test]
    fn test_lead_return_type() -> Result<()> {
        assert_return_type("lead", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("lead", &[DataType::Float64], DataType::Float64)
    }

    #[test]
    fn test_lag_return_type() -> Result<()> {
        assert_return_type("lag", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("lag", &[DataType::Float64], DataType::Float64)
    }

    #[test]
    fn test_nth_value_return_type() -> Result<()> {
        assert_return_type(
            "nth_value",
            &[DataType::Utf8, DataType::UInt64],
            DataType::Utf8,
        )?;
        assert_return_type(
            "nth_value",
            &[DataType::Float64, DataType::UInt64],
            DataType::Float64,
        )
    }

    #[test]
    fn test_percent_rank_return_type() -> Result<()> {
        assert_return_type("percent_rank", &[], DataType::Float64)
    }

    #[test]
    fn test_cume_dist_return_type() -> Result<()> {
        assert_return_type("cume_dist", &[], DataType::Float64)
    }

    #[test]
    fn test_ntile_return_type() -> Result<()> {
        assert_return_type("ntile", &[DataType::Int16], DataType::UInt64)
    }

    #[test]
    fn test_window_function_case_insensitive() -> Result<()> {
        let names = [
            "row_number",
            "rank",
            "dense_rank",
            "percent_rank",
            "cume_dist",
            "ntile",
            "lag",
            "lead",
            "first_value",
            "last_value",
            "nth_value",
            "min",
            "max",
            "count",
            "avg",
            "sum",
        ];
        for lower in names {
            let upper = lower.to_uppercase();
            let from_lower = find_df_window_func(lower).unwrap();
            let from_upper = find_df_window_func(upper.as_str()).unwrap();
            assert_eq!(from_lower, from_upper);
            assert_eq!(from_lower.to_string(), upper);
        }
        Ok(())
    }

    #[test]
    fn test_find_df_window_function() {
        let aggregate = WindowFunction::AggregateFunction;
        let built_in = WindowFunction::BuiltInWindowFunction;
        let cases = [
            ("max", Some(aggregate(AggregateFunction::Max))),
            ("min", Some(aggregate(AggregateFunction::Min))),
            ("avg", Some(aggregate(AggregateFunction::Avg))),
            ("cume_dist", Some(built_in(BuiltInWindowFunction::CumeDist))),
            ("first_value", Some(built_in(BuiltInWindowFunction::FirstValue))),
            ("LAST_value", Some(built_in(BuiltInWindowFunction::LastValue))),
            ("LAG", Some(built_in(BuiltInWindowFunction::Lag))),
            ("LEAD", Some(built_in(BuiltInWindowFunction::Lead))),
            ("not_exist", None),
        ];
        for (name, expected) in cases {
            assert_eq!(find_df_window_func(name), expected);
        }
    }

    #[test]
    fn test_display_and_from_str() {
        BuiltInWindowFunction::iter().for_each(|original| {
            let parsed = BuiltInWindowFunction::from_str(&original.to_string()).unwrap();
            assert_eq!(parsed, original);
        });
    }
}