use crate::{expr::Sort, lit};
use arrow::datatypes::DataType;
use std::fmt::{self, Formatter};
use std::hash::Hash;
use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue};
use sqlparser::ast;
use sqlparser::parser::ParserError::ParserError;
/// A window frame: `<units> BETWEEN <start_bound> AND <end_bound>`, which is
/// exactly how the `Display` impl renders it.
///
/// NOTE: `PartialOrd` and `Hash` are derived, so the field declaration order
/// below is part of this type's ordering and hashing behavior.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct WindowFrame {
/// Frame units: `ROWS`, `RANGE` or `GROUPS`.
pub units: WindowFrameUnits,
/// Bound at which the frame starts.
pub start_bound: WindowFrameBound,
/// Bound at which the frame ends.
pub end_bound: WindowFrameBound,
/// Causality flag, computed from `units` and `end_bound` by
/// `WindowFrame::new_bounds` (or set directly by `WindowFrame::new`) and read
/// through `WindowFrame::is_causal`.
/// NOTE(review): `start_bound`/`end_bound` are public, so mutating them
/// directly leaves this flag stale.
causal: bool,
}
impl fmt::Display for WindowFrame {
    /// Renders the frame in SQL syntax, e.g.
    /// `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let Self {
            units,
            start_bound,
            end_bound,
            ..
        } = self;
        write!(f, "{units} BETWEEN {start_bound} AND {end_bound}")
    }
}
impl fmt::Debug for WindowFrame {
    /// Single-line debug dump of every field. The private causality flag is
    /// reported under the name `is_causal`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            units,
            start_bound,
            end_bound,
            causal,
        } = self;
        write!(
            f,
            "WindowFrame {{ units: {units:?}, start_bound: {start_bound:?}, end_bound: {end_bound:?}, is_causal: {causal:?} }}"
        )
    }
}
impl TryFrom<ast::WindowFrame> for WindowFrame {
    type Error = DataFusionError;

    /// Converts a parsed SQL window frame into a [`WindowFrame`].
    ///
    /// A missing end bound defaults to `CURRENT ROW`. Bound offsets are parsed
    /// by `WindowFrameBound::try_parse` according to the frame units.
    ///
    /// # Errors
    ///
    /// Returns a planning error if the start bound is `UNBOUNDED FOLLOWING` or
    /// the end bound is `UNBOUNDED PRECEDING`. Any error raised while parsing
    /// the bound offsets is propagated.
    fn try_from(value: ast::WindowFrame) -> Result<Self> {
        let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?;
        let end_bound = match value.end_bound {
            Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?,
            None => WindowFrameBound::CurrentRow,
        };

        // The two checks must be independent. With an `else if`, a frame such
        // as `1 FOLLOWING AND UNBOUNDED PRECEDING` skipped the end-bound check
        // because its (bounded) start bound was FOLLOWING.
        if matches!(&start_bound, WindowFrameBound::Following(val) if val.is_null()) {
            return plan_err!(
                "Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING"
            );
        }
        if matches!(&end_bound, WindowFrameBound::Preceding(val) if val.is_null()) {
            return plan_err!(
                "Invalid window frame: end bound cannot be UNBOUNDED PRECEDING"
            );
        }

        let units = value.units.into();
        Ok(Self::new_bounds(units, start_bound, end_bound))
    }
}
impl WindowFrame {
    /// Builds the frame used when no explicit frame is given.
    ///
    /// * `None` yields `ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`
    ///   (non-causal). Presumably this is the "no `ORDER BY`" case; confirm
    ///   against callers.
    /// * `Some(strict)` yields a frame from `UNBOUNDED PRECEDING` to
    ///   `CURRENT ROW`. It uses `ROWS` units and is causal when `strict` is
    ///   `true`; otherwise it uses `RANGE` units and is non-causal.
    ///   NOTE(review): `strict` presumably means "the ordering has no ties";
    ///   confirm.
    pub fn new(order_by: Option<bool>) -> Self {
        match order_by {
            None => Self {
                units: WindowFrameUnits::Rows,
                start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
                end_bound: WindowFrameBound::Following(ScalarValue::UInt64(None)),
                causal: false,
            },
            Some(strict) => {
                let units = if strict {
                    WindowFrameUnits::Rows
                } else {
                    WindowFrameUnits::Range
                };
                Self {
                    units,
                    start_bound: WindowFrameBound::Preceding(ScalarValue::Null),
                    end_bound: WindowFrameBound::CurrentRow,
                    causal: strict,
                }
            }
        }
    }

    /// Returns the mirror image of this frame.
    ///
    /// The old end bound becomes the new start bound and vice versa, with
    /// PRECEDING and FOLLOWING swapped. For example,
    /// `3 PRECEDING AND 2 FOLLOWING` becomes `2 PRECEDING AND 3 FOLLOWING`.
    /// Causality is recomputed via [`WindowFrame::new_bounds`].
    pub fn reverse(&self) -> Self {
        let mirror = |bound: &WindowFrameBound| match bound {
            WindowFrameBound::Preceding(offset) => {
                WindowFrameBound::Following(offset.clone())
            }
            WindowFrameBound::Following(offset) => {
                WindowFrameBound::Preceding(offset.clone())
            }
            WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
        };
        let new_start = mirror(&self.end_bound);
        let new_end = mirror(&self.start_bound);
        Self::new_bounds(self.units, new_start, new_end)
    }

    /// Returns the causality flag computed when this frame was built.
    pub fn is_causal(&self) -> bool {
        self.causal
    }

    /// Builds a frame from its units and bounds and derives its causality flag.
    ///
    /// * `ROWS`: causal unless the end bound is `FOLLOWING` with a non-zero or
    ///   unbounded (null) offset.
    /// * `RANGE` / `GROUPS`: causal only when the end bound is `PRECEDING` with
    ///   an unbounded (null) offset or an offset greater than zero.
    ///
    /// If a zero of the offset's type cannot be constructed, the frame is
    /// treated as non-causal.
    pub fn new_bounds(
        units: WindowFrameUnits,
        start_bound: WindowFrameBound,
        end_bound: WindowFrameBound,
    ) -> Self {
        let causal = match (units, &end_bound) {
            (WindowFrameUnits::Rows, WindowFrameBound::Following(offset)) => {
                // `0 FOLLOWING` is the current row; anything else looks ahead.
                !offset.is_null()
                    && ScalarValue::new_zero(&offset.data_type())
                        .map(|zero| *offset == zero)
                        .unwrap_or(false)
            }
            (WindowFrameUnits::Rows, _) => true,
            (
                WindowFrameUnits::Range | WindowFrameUnits::Groups,
                WindowFrameBound::Preceding(offset),
            ) => {
                // `0 PRECEDING` still covers the current row's peers, so only a
                // strictly positive (or unbounded) offset is causal.
                offset.is_null()
                    || ScalarValue::new_zero(&offset.data_type())
                        .map(|zero| *offset > zero)
                        .unwrap_or(false)
            }
            (WindowFrameUnits::Range | WindowFrameUnits::Groups, _) => false,
        };
        Self {
            units,
            start_bound,
            end_bound,
            causal,
        }
    }

    /// Validates `order_by` against this frame's units, adjusting it if needed.
    ///
    /// * `RANGE` whose bounds are all UNBOUNDED / CURRENT ROW: an empty
    ///   `order_by` gets a constant sort key (`lit(1u64)`) appended.
    /// * Other `RANGE` frames: exactly one ORDER BY column is required.
    /// * `GROUPS`: at least one ORDER BY column is required.
    /// * `ROWS`: no requirement.
    ///
    /// # Errors
    ///
    /// Returns a planning error when the requirement above is not met.
    pub fn regularize_order_bys(&self, order_by: &mut Vec<Sort>) -> Result<()> {
        match self.units {
            WindowFrameUnits::Range => {
                if self.free_range() {
                    if order_by.is_empty() {
                        order_by.push(lit(1u64).sort(true, false));
                    }
                } else if order_by.len() != 1 {
                    return plan_err!("RANGE requires exactly one ORDER BY column");
                }
            }
            WindowFrameUnits::Groups => {
                if order_by.is_empty() {
                    return plan_err!("GROUPS requires an ORDER BY clause");
                }
            }
            WindowFrameUnits::Rows => {}
        }
        Ok(())
    }

    /// Returns whether this frame may be combined with a multi-column ORDER BY.
    ///
    /// Only `RANGE` frames with non-free bounds cannot be combined.
    pub fn can_accept_multi_orderby(&self) -> bool {
        match self.units {
            WindowFrameUnits::Rows | WindowFrameUnits::Groups => true,
            WindowFrameUnits::Range => self.free_range(),
        }
    }

    /// Returns whether both bounds are either UNBOUNDED or CURRENT ROW, i.e.
    /// neither bound carries a concrete offset.
    fn free_range(&self) -> bool {
        let offset_free = |bound: &WindowFrameBound| {
            bound.is_unbounded() || matches!(bound, WindowFrameBound::CurrentRow)
        };
        offset_free(&self.start_bound) && offset_free(&self.end_bound)
    }

    /// Returns whether the frame only ever grows, which is the case exactly
    /// when its start bound is unbounded.
    pub fn is_ever_expanding(&self) -> bool {
        self.start_bound.is_unbounded()
    }
}
/// One end of a window frame.
///
/// A null offset (see `WindowFrameBound::is_unbounded`) means UNBOUNDED. As
/// produced by `WindowFrameBound::try_parse`, offsets are `UInt64` values for
/// ROWS / GROUPS frames and `Utf8` text for RANGE frames.
///
/// NOTE: `PartialOrd` is derived, so variant order is significant.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum WindowFrameBound {
/// `<offset> PRECEDING`, or `UNBOUNDED PRECEDING` when the offset is null.
Preceding(ScalarValue),
/// `CURRENT ROW`.
CurrentRow,
/// `<offset> FOLLOWING`, or `UNBOUNDED FOLLOWING` when the offset is null.
Following(ScalarValue),
}
impl WindowFrameBound {
    /// Returns `true` for `UNBOUNDED PRECEDING` / `UNBOUNDED FOLLOWING`, i.e.
    /// a PRECEDING or FOLLOWING bound whose offset is null. `CURRENT ROW` is
    /// never unbounded.
    pub fn is_unbounded(&self) -> bool {
        match self {
            WindowFrameBound::Preceding(offset) | WindowFrameBound::Following(offset) => {
                offset.is_null()
            }
            WindowFrameBound::CurrentRow => false,
        }
    }
}
impl WindowFrameBound {
fn try_parse(
value: ast::WindowFrameBound,
units: &ast::WindowFrameUnits,
) -> Result<Self> {
Ok(match value {
ast::WindowFrameBound::Preceding(Some(v)) => {
Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?)
}
ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null),
ast::WindowFrameBound::Following(Some(v)) => {
Self::Following(convert_frame_bound_to_scalar_value(*v, units)?)
}
ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null),
ast::WindowFrameBound::CurrentRow => Self::CurrentRow,
})
}
}
/// Converts a frame-bound offset expression into a `ScalarValue`.
///
/// * `ROWS` / `GROUPS`: the offset must be an unsigned numeric literal, or a
///   bare `INTERVAL '<text>'` with no qualifiers. It is parsed as `UInt64`.
/// * `RANGE`: the offset's type depends on the ORDER BY column, so it is kept
///   as `Utf8` text. Accepted forms are a numeric literal or an
///   `INTERVAL '<text>' [<field>]` (the leading field, if present, is appended
///   after a space).
///
/// # Errors
///
/// Returns a SQL error if an INTERVAL does not wrap a single-quoted string.
/// Returns a planning error for any other unsupported expression. Errors from
/// parsing the `UInt64` value are propagated.
fn convert_frame_bound_to_scalar_value(
    v: ast::Expr,
    units: &ast::WindowFrameUnits,
) -> Result<ScalarValue> {
    // Both unit kinds accept `INTERVAL '<text>'` and need the quoted text; any
    // other expression inside the INTERVAL is rejected as a SQL error.
    let interval_text = |inner: ast::Expr| -> Result<String> {
        match inner {
            ast::Expr::Value(ast::Value::SingleQuotedString(text)) => Ok(text),
            e => sql_err!(ParserError(format!(
                "INTERVAL expression cannot be {e:?}"
            ))),
        }
    };

    match units {
        ast::WindowFrameUnits::Range => {
            let text = match v {
                ast::Expr::Value(ast::Value::Number(number, false)) => number,
                ast::Expr::Interval(ast::Interval {
                    value,
                    leading_field,
                    ..
                }) => {
                    let text = interval_text(*value)?;
                    match leading_field {
                        Some(field) => format!("{text} {field}"),
                        None => text,
                    }
                }
                _ => {
                    return plan_err!(
                        "Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval"
                    )
                }
            };
            Ok(ScalarValue::Utf8(Some(text)))
        }
        ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => {
            let text = match v {
                ast::Expr::Value(ast::Value::Number(number, false)) => number,
                ast::Expr::Interval(ast::Interval {
                    value,
                    leading_field: None,
                    leading_precision: None,
                    last_field: None,
                    fractional_seconds_precision: None,
                }) => interval_text(*value)?,
                _ => {
                    return plan_err!(
                        "Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"
                    )
                }
            };
            Ok(ScalarValue::try_from_string(text, &DataType::UInt64)?)
        }
    }
}
impl fmt::Display for WindowFrameBound {
    /// Renders the bound in SQL syntax: `CURRENT ROW`, `<n> PRECEDING` /
    /// `<n> FOLLOWING`, or `UNBOUNDED PRECEDING` / `UNBOUNDED FOLLOWING` when
    /// the offset is null.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let (offset, direction) = match self {
            WindowFrameBound::CurrentRow => return f.write_str("CURRENT ROW"),
            WindowFrameBound::Preceding(offset) => (offset, "PRECEDING"),
            WindowFrameBound::Following(offset) => (offset, "FOLLOWING"),
        };
        if offset.is_null() {
            write!(f, "UNBOUNDED {direction}")
        } else {
            write!(f, "{offset} {direction}")
        }
    }
}
/// Units in which a window frame's bounds are expressed: `ROWS`, `RANGE` or
/// `GROUPS` (see the `Display` impl).
///
/// NOTE: `PartialOrd` is derived, so variant order is significant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum WindowFrameUnits {
/// `ROWS`: offsets count physical rows.
Rows,
/// `RANGE`: offsets are values compared against the ORDER BY column.
Range,
/// `GROUPS`: offsets count peer groups of the ORDER BY key.
Groups,
}
impl fmt::Display for WindowFrameUnits {
    /// Renders the SQL keyword for these units.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let keyword = match self {
            WindowFrameUnits::Groups => "GROUPS",
            WindowFrameUnits::Range => "RANGE",
            WindowFrameUnits::Rows => "ROWS",
        };
        f.write_str(keyword)
    }
}
impl From<ast::WindowFrameUnits> for WindowFrameUnits {
fn from(value: ast::WindowFrameUnits) -> Self {
match value {
ast::WindowFrameUnits::Range => Self::Range,
ast::WindowFrameUnits::Groups => Self::Groups,
ast::WindowFrameUnits::Rows => Self::Rows,
}
}
}
// Unit tests for converting sqlparser window frames / bounds into the
// `WindowFrame` / `WindowFrameBound` representation defined in this module.
#[cfg(test)]
mod tests {
use super::*;
// `WindowFrame::try_from` rejects UNBOUNDED FOLLOWING as a start bound and
// UNBOUNDED PRECEDING as an end bound, and parses ROWS offsets as UInt64.
#[test]
fn test_window_frame_creation() -> Result<()> {
// Start bound UNBOUNDED FOLLOWING (end bound omitted, i.e. CURRENT ROW).
let window_frame = ast::WindowFrame {
units: ast::WindowFrameUnits::Range,
start_bound: ast::WindowFrameBound::Following(None),
end_bound: None,
};
let err = WindowFrame::try_from(window_frame).unwrap_err();
assert_eq!(
err.strip_backtrace(),
"Error during planning: Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING".to_owned()
);
// End bound UNBOUNDED PRECEDING.
let window_frame = ast::WindowFrame {
units: ast::WindowFrameUnits::Range,
start_bound: ast::WindowFrameBound::Preceding(None),
end_bound: Some(ast::WindowFrameBound::Preceding(None)),
};
let err = WindowFrame::try_from(window_frame).unwrap_err();
assert_eq!(
err.strip_backtrace(),
"Error during planning: Invalid window frame: end bound cannot be UNBOUNDED PRECEDING".to_owned()
);
// A valid ROWS frame: `2 PRECEDING AND 1 PRECEDING`.
let window_frame = ast::WindowFrame {
units: ast::WindowFrameUnits::Rows,
start_bound: ast::WindowFrameBound::Preceding(Some(Box::new(
ast::Expr::Value(ast::Value::Number("2".to_string(), false)),
))),
end_bound: Some(ast::WindowFrameBound::Preceding(Some(Box::new(
ast::Expr::Value(ast::Value::Number("1".to_string(), false)),
)))),
};
let window_frame = WindowFrame::try_from(window_frame)?;
assert_eq!(window_frame.units, WindowFrameUnits::Rows);
assert_eq!(
window_frame.start_bound,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2)))
);
assert_eq!(
window_frame.end_bound,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1)))
);
Ok(())
}
// Asserts that offset `$value` parses to `$expected` in `$unit` units, both as
// a PRECEDING and as a FOLLOWING bound.
macro_rules! test_bound {
($unit:ident, $value:expr, $expected:expr) => {
let preceding = WindowFrameBound::try_parse(
ast::WindowFrameBound::Preceding($value),
&ast::WindowFrameUnits::$unit,
)?;
assert_eq!(preceding, WindowFrameBound::Preceding($expected));
let following = WindowFrameBound::try_parse(
ast::WindowFrameBound::Following($value),
&ast::WindowFrameUnits::$unit,
)?;
assert_eq!(following, WindowFrameBound::Following($expected));
};
}
// Asserts that offset `$value` fails to parse in `$unit` units with error
// message `$expected` (backtrace stripped), for both PRECEDING and FOLLOWING.
macro_rules! test_bound_err {
($unit:ident, $value:expr, $expected:expr) => {
let err = WindowFrameBound::try_parse(
ast::WindowFrameBound::Preceding($value),
&ast::WindowFrameUnits::$unit,
)
.unwrap_err();
assert_eq!(err.strip_backtrace(), $expected);
let err = WindowFrameBound::try_parse(
ast::WindowFrameBound::Following($value),
&ast::WindowFrameUnits::$unit,
)
.unwrap_err();
assert_eq!(err.strip_backtrace(), $expected);
};
}
// Offset parsing per unit: no offset -> Null; numeric literal -> UInt64 for
// ROWS/GROUPS, Utf8 text for RANGE; an INTERVAL with a leading field is rejected
// for ROWS/GROUPS and becomes "<text> <FIELD>" for RANGE.
#[test]
fn test_window_frame_bound_creation() -> Result<()> {
test_bound!(Rows, None, ScalarValue::Null);
test_bound!(Groups, None, ScalarValue::Null);
test_bound!(Range, None, ScalarValue::Null);
let number = Some(Box::new(ast::Expr::Value(ast::Value::Number(
"42".to_string(),
false,
))));
test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42)));
test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42)));
test_bound!(
Range,
number.clone(),
ScalarValue::Utf8(Some("42".to_string()))
);
// `INTERVAL '1' DAY`
let number = Some(Box::new(ast::Expr::Interval(ast::Interval {
value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
"1".to_string(),
))),
leading_field: Some(ast::DateTimeField::Day),
fractional_seconds_precision: None,
last_field: None,
leading_precision: None,
})));
test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers");
test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers");
test_bound!(
Range,
number.clone(),
ScalarValue::Utf8(Some("1 DAY".to_string()))
);
Ok(())
}
}