use crate::aggregate_function::AggregateFunction;
use datafusion_common::{DataFusionError, Result};
use std::{fmt, str::FromStr};
/// A function that can be evaluated as a window function: either an
/// aggregate function or one of the [`BuiltInWindowFunction`]s.
///
/// Parsed case-insensitively from a name via `FromStr`, and displayed as the
/// wrapped function's own `Display` output.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunction {
/// An aggregate function (e.g. `MIN`, `MAX`, `COUNT`, `AVG`, `SUM`) used as a window function.
AggregateFunction(AggregateFunction),
/// A dedicated built-in window function such as `ROW_NUMBER` or `LAG`.
BuiltInWindowFunction(BuiltInWindowFunction),
}
impl FromStr for WindowFunction {
    type Err = DataFusionError;

    /// Resolves a window function from its name, ignoring case.
    ///
    /// The name is lower-cased once, then looked up as an
    /// [`AggregateFunction`] first and as a [`BuiltInWindowFunction`] second,
    /// so an aggregate wins if both would accept the same name.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Plan`] when neither lookup succeeds. The
    /// message quotes the lower-cased form of `name`.
    fn from_str(name: &str) -> Result<WindowFunction> {
        let lowered = name.to_lowercase();
        AggregateFunction::from_str(&lowered)
            .map(WindowFunction::AggregateFunction)
            .or_else(|_| {
                BuiltInWindowFunction::from_str(&lowered)
                    .map(WindowFunction::BuiltInWindowFunction)
            })
            .map_err(|_| {
                DataFusionError::Plan(format!(
                    "There is no window function named {}",
                    lowered
                ))
            })
    }
}
impl fmt::Display for BuiltInWindowFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"),
BuiltInWindowFunction::Rank => write!(f, "RANK"),
BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"),
BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"),
BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"),
BuiltInWindowFunction::Ntile => write!(f, "NTILE"),
BuiltInWindowFunction::Lag => write!(f, "LAG"),
BuiltInWindowFunction::Lead => write!(f, "LEAD"),
BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"),
BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"),
BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"),
}
}
}
impl fmt::Display for WindowFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WindowFunction::AggregateFunction(fun) => fun.fmt(f),
WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
}
}
}
/// Window functions that are built in, as opposed to aggregate functions
/// evaluated over a window.
///
/// Each variant is displayed as, and parsed (case-insensitively) from, the
/// upper-case SQL name noted on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BuiltInWindowFunction {
/// SQL name `ROW_NUMBER`.
RowNumber,
/// SQL name `RANK`.
Rank,
/// SQL name `DENSE_RANK`.
DenseRank,
/// SQL name `PERCENT_RANK`.
PercentRank,
/// SQL name `CUME_DIST`.
CumeDist,
/// SQL name `NTILE`.
Ntile,
/// SQL name `LAG`.
Lag,
/// SQL name `LEAD`.
Lead,
/// SQL name `FIRST_VALUE`.
FirstValue,
/// SQL name `LAST_VALUE`.
LastValue,
/// SQL name `NTH_VALUE`.
NthValue,
}
impl FromStr for BuiltInWindowFunction {
    type Err = DataFusionError;

    /// Parses a built-in window function from its SQL name, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Plan`] if `name` matches no built-in window
    /// function; the message echoes `name` exactly as it was given.
    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
        let found = match name.to_uppercase().as_str() {
            "ROW_NUMBER" => Some(Self::RowNumber),
            "RANK" => Some(Self::Rank),
            "DENSE_RANK" => Some(Self::DenseRank),
            "PERCENT_RANK" => Some(Self::PercentRank),
            "CUME_DIST" => Some(Self::CumeDist),
            "NTILE" => Some(Self::Ntile),
            "LAG" => Some(Self::Lag),
            "LEAD" => Some(Self::Lead),
            "FIRST_VALUE" => Some(Self::FirstValue),
            "LAST_VALUE" => Some(Self::LastValue),
            "NTH_VALUE" => Some(Self::NthValue),
            _ => None,
        };
        found.ok_or_else(|| {
            DataFusionError::Plan(format!(
                "There is no built-in window function named {}",
                name
            ))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every name must parse to the same function regardless of case, and
    /// display as its upper-case form.
    #[test]
    fn test_window_function_case_insensitive() -> Result<()> {
        // A fixed array is enough: the names are only iterated once, so `vec!`
        // would allocate for nothing (clippy::useless_vec).
        let names = [
            "row_number",
            "rank",
            "dense_rank",
            "percent_rank",
            "cume_dist",
            "ntile",
            "lag",
            "lead",
            "first_value",
            "last_value",
            "nth_value",
            "min",
            "max",
            "count",
            "avg",
            "sum",
        ];
        for name in names {
            let fun = WindowFunction::from_str(name)?;
            let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?;
            assert_eq!(fun, fun2);
            assert_eq!(fun.to_string(), name.to_uppercase());
        }
        Ok(())
    }

    /// Names resolve to the expected aggregate or built-in variant.
    #[test]
    fn test_window_function_from_str() -> Result<()> {
        assert_eq!(
            WindowFunction::from_str("max")?,
            WindowFunction::AggregateFunction(AggregateFunction::Max)
        );
        assert_eq!(
            WindowFunction::from_str("min")?,
            WindowFunction::AggregateFunction(AggregateFunction::Min)
        );
        assert_eq!(
            WindowFunction::from_str("avg")?,
            WindowFunction::AggregateFunction(AggregateFunction::Avg)
        );
        assert_eq!(
            WindowFunction::from_str("cume_dist")?,
            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist)
        );
        assert_eq!(
            WindowFunction::from_str("first_value")?,
            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue)
        );
        assert_eq!(
            WindowFunction::from_str("LAST_value")?,
            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue)
        );
        assert_eq!(
            WindowFunction::from_str("LAG")?,
            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag)
        );
        assert_eq!(
            WindowFunction::from_str("LEAD")?,
            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead)
        );
        Ok(())
    }

    /// `Display` and `FromStr` must agree for every built-in variant, so a
    /// displayed name can always be parsed back into the same function.
    #[test]
    fn test_built_in_window_function_display_round_trip() -> Result<()> {
        let functions = [
            BuiltInWindowFunction::RowNumber,
            BuiltInWindowFunction::Rank,
            BuiltInWindowFunction::DenseRank,
            BuiltInWindowFunction::PercentRank,
            BuiltInWindowFunction::CumeDist,
            BuiltInWindowFunction::Ntile,
            BuiltInWindowFunction::Lag,
            BuiltInWindowFunction::Lead,
            BuiltInWindowFunction::FirstValue,
            BuiltInWindowFunction::LastValue,
            BuiltInWindowFunction::NthValue,
        ];
        for fun in functions {
            let parsed = BuiltInWindowFunction::from_str(&fun.to_string())?;
            assert_eq!(parsed, fun);
        }
        Ok(())
    }

    /// Unknown names are rejected with a `Plan` error naming the input.
    #[test]
    fn test_window_function_from_str_unknown_name() {
        match WindowFunction::from_str("not_a_window_function") {
            Err(DataFusionError::Plan(msg)) => assert_eq!(
                msg,
                "There is no window function named not_a_window_function"
            ),
            other => panic!("expected a Plan error, got {:?}", other),
        }
        match BuiltInWindowFunction::from_str("not_a_window_function") {
            Err(DataFusionError::Plan(msg)) => assert_eq!(
                msg,
                "There is no built-in window function named not_a_window_function"
            ),
            other => panic!("expected a Plan error, got {:?}", other),
        }
    }
}