use crate::aggregate_function::AggregateFunction;
use crate::type_coercion::functions::data_types;
use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility};
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
use std::{fmt, str::FromStr};
/// A function that can be used in a window expression.
///
/// It is one of: a built-in aggregate function, a built-in window function,
/// or a user-defined aggregate function.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunction {
/// A built-in aggregate function (see [`AggregateFunction`]).
AggregateFunction(AggregateFunction),
/// A built-in window function (see [`BuiltInWindowFunction`]).
BuiltInWindowFunction(BuiltInWindowFunction),
/// A user-defined aggregate function, shared behind an [`Arc`].
AggregateUDF(Arc<AggregateUDF>),
}
/// Looks up a window function by name, ignoring case.
///
/// The name is lower-cased and first tried as an [`AggregateFunction`]; only
/// if that fails is it tried as a [`BuiltInWindowFunction`]. Returns `None`
/// when neither parser recognizes the name.
pub fn find_df_window_func(name: &str) -> Option<WindowFunction> {
    let lowered = name.to_lowercase();

    // Aggregates take precedence; the built-in parser is only consulted
    // (lazily) when the aggregate lookup fails.
    AggregateFunction::from_str(&lowered)
        .ok()
        .map(WindowFunction::AggregateFunction)
        .or_else(|| {
            BuiltInWindowFunction::from_str(&lowered)
                .ok()
                .map(WindowFunction::BuiltInWindowFunction)
        })
}
/// Renders a built-in window function as its upper-case SQL name
/// (for example `ROW_NUMBER`).
impl fmt::Display for BuiltInWindowFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Pick the literal first, then emit it with a single write.
        let sql_name = match self {
            BuiltInWindowFunction::RowNumber => "ROW_NUMBER",
            BuiltInWindowFunction::Rank => "RANK",
            BuiltInWindowFunction::DenseRank => "DENSE_RANK",
            BuiltInWindowFunction::PercentRank => "PERCENT_RANK",
            BuiltInWindowFunction::CumeDist => "CUME_DIST",
            BuiltInWindowFunction::Ntile => "NTILE",
            BuiltInWindowFunction::Lag => "LAG",
            BuiltInWindowFunction::Lead => "LEAD",
            BuiltInWindowFunction::FirstValue => "FIRST_VALUE",
            BuiltInWindowFunction::LastValue => "LAST_VALUE",
            BuiltInWindowFunction::NthValue => "NTH_VALUE",
        };
        f.write_str(sql_name)
    }
}
impl fmt::Display for WindowFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WindowFunction::AggregateFunction(fun) => fun.fmt(f),
WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
}
}
}
/// The set of built-in window functions.
///
/// Each variant is parsed from (case-insensitively) and displayed as the SQL
/// name given in its comment. Argument counts and return types below are the
/// ones declared by `signature_for_built_in` and `return_type_for_built_in`
/// in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BuiltInWindowFunction {
/// `ROW_NUMBER`: no arguments, returns `UInt64`.
RowNumber,
/// `RANK`: no arguments, returns `UInt64`.
Rank,
/// `DENSE_RANK`: no arguments, returns `UInt64`.
DenseRank,
/// `PERCENT_RANK`: no arguments, returns `Float64`.
PercentRank,
/// `CUME_DIST`: no arguments, returns `Float64`.
CumeDist,
/// `NTILE`: one argument, returns `UInt32`.
Ntile,
/// `LAG`: one to three arguments, returns the type of the first argument.
Lag,
/// `LEAD`: one to three arguments, returns the type of the first argument.
Lead,
/// `FIRST_VALUE`: one argument, returns the type of that argument.
FirstValue,
/// `LAST_VALUE`: one argument, returns the type of that argument.
LastValue,
/// `NTH_VALUE`: two arguments, returns the type of the first argument.
NthValue,
}
/// Parses a built-in window function from its SQL name, ignoring case.
///
/// # Errors
///
/// Returns a `DataFusionError::Plan` that echoes the name as it was given
/// when it matches no built-in window function.
impl FromStr for BuiltInWindowFunction {
    type Err = DataFusionError;

    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
        // Compare against the canonical upper-case spellings.
        let canonical = name.to_uppercase();
        match canonical.as_str() {
            "ROW_NUMBER" => Ok(BuiltInWindowFunction::RowNumber),
            "RANK" => Ok(BuiltInWindowFunction::Rank),
            "DENSE_RANK" => Ok(BuiltInWindowFunction::DenseRank),
            "PERCENT_RANK" => Ok(BuiltInWindowFunction::PercentRank),
            "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist),
            "NTILE" => Ok(BuiltInWindowFunction::Ntile),
            "LAG" => Ok(BuiltInWindowFunction::Lag),
            "LEAD" => Ok(BuiltInWindowFunction::Lead),
            "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue),
            "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue),
            "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue),
            // The message reports the caller's original spelling, not the
            // upper-cased form used for matching.
            _ => Err(DataFusionError::Plan(format!(
                "There is no built-in window function named {name}"
            ))),
        }
    }
}
pub fn return_type(
fun: &WindowFunction,
input_expr_types: &[DataType],
) -> Result<DataType> {
match fun {
WindowFunction::AggregateFunction(fun) => {
aggregate_function::return_type(fun, input_expr_types)
}
WindowFunction::BuiltInWindowFunction(fun) => {
return_type_for_built_in(fun, input_expr_types)
}
WindowFunction::AggregateUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
}
}
fn return_type_for_built_in(
fun: &BuiltInWindowFunction,
input_expr_types: &[DataType],
) -> Result<DataType> {
data_types(input_expr_types, &signature_for_built_in(fun))?;
match fun {
BuiltInWindowFunction::RowNumber
| BuiltInWindowFunction::Rank
| BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
Ok(DataType::Float64)
}
BuiltInWindowFunction::Ntile => Ok(DataType::UInt32),
BuiltInWindowFunction::Lag
| BuiltInWindowFunction::Lead
| BuiltInWindowFunction::FirstValue
| BuiltInWindowFunction::LastValue
| BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
}
}
/// Returns the argument signature accepted by a window function.
///
/// Built-in aggregates and built-in window functions are looked up through
/// their respective signature helpers; a user-defined aggregate supplies a
/// clone of its own stored signature.
pub fn signature(fun: &WindowFunction) -> Signature {
    match fun {
        WindowFunction::AggregateUDF(udf) => udf.signature.clone(),
        WindowFunction::BuiltInWindowFunction(built_in) => {
            signature_for_built_in(built_in)
        }
        WindowFunction::AggregateFunction(aggregate) => {
            aggregate_function::signature(aggregate)
        }
    }
}
/// Returns the argument signature of a built-in window function.
///
/// Every signature is declared `Volatility::Immutable`. `Lag` and `Lead`
/// accept one, two or three arguments of any type; all other functions take
/// a fixed number of arguments of any type (zero, one or two).
pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
    let fixed_arity = match fun {
        BuiltInWindowFunction::RowNumber
        | BuiltInWindowFunction::Rank
        | BuiltInWindowFunction::DenseRank
        | BuiltInWindowFunction::PercentRank
        | BuiltInWindowFunction::CumeDist => 0,
        BuiltInWindowFunction::FirstValue
        | BuiltInWindowFunction::LastValue
        | BuiltInWindowFunction::Ntile => 1,
        BuiltInWindowFunction::NthValue => 2,
        // The only variadic case: handled separately and returned early.
        BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
            let accepted = vec![
                TypeSignature::Any(1),
                TypeSignature::Any(2),
                TypeSignature::Any(3),
            ];
            return Signature::one_of(accepted, Volatility::Immutable);
        }
    };
    Signature::any(fixed_arity, Volatility::Immutable)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves `name` via `find_df_window_func` and asserts that
    /// `return_type` for `args` equals `expected`.
    fn assert_return_type(
        name: &str,
        args: &[DataType],
        expected: DataType,
    ) -> Result<()> {
        let fun = find_df_window_func(name).unwrap();
        let observed = return_type(&fun, args)?;
        assert_eq!(expected, observed);
        Ok(())
    }

    /// Wraps a built-in window function in the `WindowFunction` lookup result.
    fn built_in(fun: BuiltInWindowFunction) -> Option<WindowFunction> {
        Some(WindowFunction::BuiltInWindowFunction(fun))
    }

    /// Wraps a built-in aggregate in the `WindowFunction` lookup result.
    fn aggregate(fun: AggregateFunction) -> Option<WindowFunction> {
        Some(WindowFunction::AggregateFunction(fun))
    }

    #[test]
    fn test_count_return_type() -> Result<()> {
        assert_return_type("count", &[DataType::Utf8], DataType::Int64)?;
        assert_return_type("count", &[DataType::UInt64], DataType::Int64)?;
        Ok(())
    }

    #[test]
    fn test_first_value_return_type() -> Result<()> {
        assert_return_type("first_value", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("first_value", &[DataType::UInt64], DataType::UInt64)?;
        Ok(())
    }

    #[test]
    fn test_last_value_return_type() -> Result<()> {
        assert_return_type("last_value", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("last_value", &[DataType::Float64], DataType::Float64)?;
        Ok(())
    }

    #[test]
    fn test_lead_return_type() -> Result<()> {
        assert_return_type("lead", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("lead", &[DataType::Float64], DataType::Float64)?;
        Ok(())
    }

    #[test]
    fn test_lag_return_type() -> Result<()> {
        assert_return_type("lag", &[DataType::Utf8], DataType::Utf8)?;
        assert_return_type("lag", &[DataType::Float64], DataType::Float64)?;
        Ok(())
    }

    #[test]
    fn test_nth_value_return_type() -> Result<()> {
        assert_return_type(
            "nth_value",
            &[DataType::Utf8, DataType::UInt64],
            DataType::Utf8,
        )?;
        assert_return_type(
            "nth_value",
            &[DataType::Float64, DataType::UInt64],
            DataType::Float64,
        )?;
        Ok(())
    }

    #[test]
    fn test_percent_rank_return_type() -> Result<()> {
        assert_return_type("percent_rank", &[], DataType::Float64)
    }

    #[test]
    fn test_cume_dist_return_type() -> Result<()> {
        assert_return_type("cume_dist", &[], DataType::Float64)
    }

    #[test]
    fn test_window_function_case_insensitive() -> Result<()> {
        let lower_case_names = vec![
            "row_number",
            "rank",
            "dense_rank",
            "percent_rank",
            "cume_dist",
            "ntile",
            "lag",
            "lead",
            "first_value",
            "last_value",
            "nth_value",
            "min",
            "max",
            "count",
            "avg",
            "sum",
        ];
        for lower in lower_case_names {
            let upper = lower.to_uppercase();
            let from_lower = find_df_window_func(lower).unwrap();
            let from_upper = find_df_window_func(upper.as_str()).unwrap();
            // Both spellings must resolve to the same function, and that
            // function must display as the upper-case name.
            assert_eq!(from_lower, from_upper);
            assert_eq!(from_lower.to_string(), upper);
        }
        Ok(())
    }

    #[test]
    fn test_find_df_window_function() {
        let expectations = vec![
            ("max", aggregate(AggregateFunction::Max)),
            ("min", aggregate(AggregateFunction::Min)),
            ("avg", aggregate(AggregateFunction::Avg)),
            ("cume_dist", built_in(BuiltInWindowFunction::CumeDist)),
            ("first_value", built_in(BuiltInWindowFunction::FirstValue)),
            ("LAST_value", built_in(BuiltInWindowFunction::LastValue)),
            ("LAG", built_in(BuiltInWindowFunction::Lag)),
            ("LEAD", built_in(BuiltInWindowFunction::Lead)),
            ("not_exist", None),
        ];
        for (name, expected) in expectations {
            assert_eq!(find_df_window_func(name), expected);
        }
    }
}