use super::expr::Expr;
use super::query::{AggregateExpr, OrderClause};
use super::SqlValue;
/// Window functions this builder can produce; variant names mirror the SQL
/// window functions of the same name.
///
/// Each variant is created by the like-named constructor function in this
/// module, which also fills in [`WindowExpr::args`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFn {
/// Built by [`row_number`]; no arguments.
RowNumber,
/// Built by [`rank`]; no arguments.
Rank,
/// Built by [`dense_rank`]; no arguments.
DenseRank,
/// Built by [`ntile`]; one literal argument (the bucket count).
Ntile,
/// Built by [`lag`]; arguments are `[column, offset]` plus an optional default literal.
Lag,
/// Built by [`lead`]; arguments are `[column, offset]` plus an optional default literal.
Lead,
/// Built by [`first_value`]; one column argument.
FirstValue,
/// Built by [`last_value`]; one column argument.
LastValue,
}
/// Unit in which a [`WindowFrame`]'s boundaries are expressed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameKind {
/// Row-based frame (presumably rendered as SQL `ROWS` — confirm against the renderer).
Rows,
/// Range-based frame (presumably rendered as SQL `RANGE` — confirm against the renderer).
Range,
}
/// One edge of a [`WindowFrame`].
///
/// The `i64` offsets carried by `Preceding` / `Following` are not validated
/// here; negative values are stored as-is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameBoundary {
/// Open start of the partition.
UnboundedPreceding,
/// A fixed offset before the current row.
Preceding(i64),
/// The current row.
CurrentRow,
/// A fixed offset after the current row.
Following(i64),
/// Open end of the partition.
UnboundedFollowing,
}
/// A window frame clause: a unit ([`FrameKind`]), a start boundary, and an
/// optional end boundary.
///
/// Nothing in this type checks that `start` comes before `end`, or that the
/// boundaries make sense for `kind`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WindowFrame {
/// Frame unit: [`FrameKind::Rows`] or [`FrameKind::Range`].
pub kind: FrameKind,
/// First edge of the frame.
pub start: FrameBoundary,
/// Second edge of the frame. `None` means only `start` is given
/// (NOTE(review): presumably rendered as the SQL short form `<kind> <start>`
/// rather than `<kind> BETWEEN <start> AND <end>` — confirm against the renderer).
pub end: Option<FrameBoundary>,
}
/// A fully specified window-function call, as boxed inside `Expr::Window`
/// or `AggregateExpr::Window` by [`WindowBuilder`].
#[derive(Debug, Clone, PartialEq)]
pub struct WindowExpr {
/// Which window function is invoked.
pub kind: WindowFn,
/// Positional arguments, in the order the constructor functions push them
/// (e.g. `[column, offset, default?]` for `lag` / `lead`).
pub args: Vec<Expr>,
/// Partition columns, in the order they were added.
pub partition_by: Vec<&'static str>,
/// Ordering items, in the order they were added.
pub order_by: Vec<OrderClause>,
/// Optional explicit frame; `None` leaves the frame unspecified.
pub frame: Option<WindowFrame>,
}
/// Fluent builder for a [`WindowExpr`].
///
/// Start from one of the constructor functions in this module (e.g.
/// [`row_number`], [`lag`]), then chain [`partition_by`](Self::partition_by),
/// [`order_by`](Self::order_by) and [`frame`](Self::frame). Finish with
/// [`build`](Self::build), or convert into `Expr` / `AggregateExpr` via `From`/`Into`.
///
/// The builder is `Clone`, so a common prefix (e.g. the partitioning) can be
/// reused across several window expressions.
#[derive(Debug, Clone)]
#[must_use = "a WindowBuilder does nothing until it is built or converted into an expression"]
pub struct WindowBuilder {
    inner: WindowExpr,
}
impl WindowBuilder {
    /// Creates a builder for `kind` with the given positional arguments and
    /// no partitioning, ordering, or frame.
    fn new(kind: WindowFn, args: Vec<Expr>) -> Self {
        Self {
            inner: WindowExpr {
                kind,
                args,
                partition_by: Vec::new(),
                order_by: Vec::new(),
                frame: None,
            },
        }
    }

    /// Appends `column` to the partition list.
    ///
    /// Calls accumulate: `.partition_by("a").partition_by("b")` partitions by `a`, then `b`.
    pub fn partition_by(mut self, column: &'static str) -> Self {
        self.inner.partition_by.push(column);
        self
    }

    /// Appends `(column, descending)` pairs to the ordering list, keeping their order.
    ///
    /// A `true` flag marks the column as descending. Repeated calls accumulate,
    /// and an empty slice is a no-op.
    pub fn order_by(mut self, items: &[(&'static str, bool)]) -> Self {
        // Destructure each `&(&'static str, bool)` by value so `column` is the
        // `&'static str` itself, rather than a `&&str` that needs deref coercion.
        self.inner.order_by.extend(
            items
                .iter()
                .map(|&(column, desc)| OrderClause { column, desc }),
        );
        self
    }

    /// Sets the window frame, replacing any frame set by an earlier call.
    pub fn frame(mut self, frame: WindowFrame) -> Self {
        self.inner.frame = Some(frame);
        self
    }

    /// Finishes the builder, boxing the accumulated [`WindowExpr`] into `Expr::Window`.
    #[must_use]
    pub fn build(self) -> Expr {
        Expr::Window(Box::new(self.inner))
    }
}
impl From<WindowBuilder> for Expr {
fn from(b: WindowBuilder) -> Self {
b.build()
}
}
impl From<WindowBuilder> for AggregateExpr {
fn from(b: WindowBuilder) -> Self {
AggregateExpr::Window(Box::new(b.inner))
}
}
/// Starts a [`WindowFn::RowNumber`] builder; the function takes no arguments.
pub fn row_number() -> WindowBuilder {
    WindowBuilder::new(WindowFn::RowNumber, vec![])
}
/// Starts a [`WindowFn::Rank`] builder; the function takes no arguments.
pub fn rank() -> WindowBuilder {
    WindowBuilder::new(WindowFn::Rank, vec![])
}
/// Starts a [`WindowFn::DenseRank`] builder; the function takes no arguments.
pub fn dense_rank() -> WindowBuilder {
    WindowBuilder::new(WindowFn::DenseRank, vec![])
}
/// Starts a [`WindowFn::Ntile`] builder whose single argument is the literal
/// bucket count `buckets`.
///
/// The count is not validated here. Zero or negative values are passed through
/// as-is, and SQL engines typically reject them at execution time.
pub fn ntile(buckets: i64) -> WindowBuilder {
    WindowBuilder::new(WindowFn::Ntile, vec![Expr::Literal(SqlValue::I64(buckets))])
}
/// Starts a [`WindowFn::Lag`] builder over `column` with the given `offset`.
///
/// Arguments are stored as `[Column(column), Literal(I64(offset))]`. When
/// `default` is `Some`, it is appended as a third literal argument; with `None`,
/// the default argument is omitted entirely rather than emitted as NULL.
pub fn lag(column: &'static str, offset: i64, default: Option<SqlValue>) -> WindowBuilder {
    let mut args = vec![Expr::Column(column), Expr::Literal(SqlValue::I64(offset))];
    // `Option` is iterable, so this appends zero or one literal.
    args.extend(default.map(Expr::Literal));
    WindowBuilder::new(WindowFn::Lag, args)
}
/// Starts a [`WindowFn::Lead`] builder over `column` with the given `offset`.
///
/// Arguments are stored as `[Column(column), Literal(I64(offset))]`. When
/// `default` is `Some`, it is appended as a third literal argument; with `None`,
/// the default argument is omitted entirely rather than emitted as NULL.
pub fn lead(column: &'static str, offset: i64, default: Option<SqlValue>) -> WindowBuilder {
    let mut args = vec![Expr::Column(column), Expr::Literal(SqlValue::I64(offset))];
    // `Option` is iterable, so this appends zero or one literal.
    args.extend(default.map(Expr::Literal));
    WindowBuilder::new(WindowFn::Lead, args)
}
/// Starts a [`WindowFn::FirstValue`] builder whose single argument is `column`.
pub fn first_value(column: &'static str) -> WindowBuilder {
    WindowBuilder::new(WindowFn::FirstValue, vec![Expr::Column(column)])
}
/// Starts a [`WindowFn::LastValue`] builder whose single argument is `column`.
///
/// No frame is set by default. In standard SQL, when the window has an ordering
/// and no explicit frame, the default frame ends at the current row's last peer.
/// `LAST_VALUE` then returns that row rather than the partition's last row.
/// To get the latter, set a frame ending in [`FrameBoundary::UnboundedFollowing`]
/// via [`WindowBuilder::frame`].
pub fn last_value(column: &'static str) -> WindowBuilder {
    WindowBuilder::new(WindowFn::LastValue, vec![Expr::Column(column)])
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn row_number_has_no_args() {
        let b = row_number();
        assert!(b.inner.args.is_empty());
        assert_eq!(b.inner.kind, WindowFn::RowNumber);
    }

    #[test]
    fn lag_default_none_omits_default_arg() {
        let b = lag("price", 1, None);
        assert_eq!(b.inner.args.len(), 2);
    }

    #[test]
    fn lag_default_some_includes_default_arg() {
        let b = lag("price", 1, Some(SqlValue::I64(0)));
        assert_eq!(b.inner.args.len(), 3);
    }

    #[test]
    fn lead_stores_column_offset_and_default_in_order() {
        let b = lead("price", 2, Some(SqlValue::I64(0)));
        assert_eq!(b.inner.kind, WindowFn::Lead);
        assert_eq!(
            b.inner.args,
            vec![
                Expr::Column("price"),
                Expr::Literal(SqlValue::I64(2)),
                Expr::Literal(SqlValue::I64(0)),
            ]
        );
    }

    #[test]
    fn lead_default_none_omits_default_arg() {
        let b = lead("price", 1, None);
        assert_eq!(
            b.inner.args,
            vec![Expr::Column("price"), Expr::Literal(SqlValue::I64(1))]
        );
    }

    #[test]
    fn ntile_stores_bucket_count_literal() {
        let b = ntile(4);
        assert_eq!(b.inner.kind, WindowFn::Ntile);
        assert_eq!(b.inner.args, vec![Expr::Literal(SqlValue::I64(4))]);
    }

    #[test]
    fn first_and_last_value_take_single_column_arg() {
        let first = first_value("price");
        assert_eq!(first.inner.kind, WindowFn::FirstValue);
        assert_eq!(first.inner.args, vec![Expr::Column("price")]);

        let last = last_value("price");
        assert_eq!(last.inner.kind, WindowFn::LastValue);
        assert_eq!(last.inner.args, vec![Expr::Column("price")]);
        assert!(last.inner.frame.is_none());
    }

    #[test]
    fn partition_by_appends() {
        let b = row_number()
            .partition_by("tenant_id")
            .partition_by("region");
        assert_eq!(b.inner.partition_by, vec!["tenant_id", "region"]);
    }

    #[test]
    fn order_by_appends_multiple() {
        let b = row_number().order_by(&[("score", true), ("id", false)]);
        assert_eq!(b.inner.order_by.len(), 2);
        assert_eq!(b.inner.order_by[0].column, "score");
        assert!(b.inner.order_by[0].desc);
        assert!(!b.inner.order_by[1].desc);
    }

    #[test]
    fn order_by_empty_slice_is_noop() {
        let b = row_number().order_by(&[]);
        assert!(b.inner.order_by.is_empty());
    }

    #[test]
    fn order_by_accumulates_across_calls() {
        let b = row_number()
            .order_by(&[("score", true)])
            .order_by(&[("id", false)]);
        assert_eq!(b.inner.order_by.len(), 2);
        assert_eq!(b.inner.order_by[1].column, "id");
    }

    #[test]
    fn frame_replaces_prior() {
        let f1 = WindowFrame {
            kind: FrameKind::Rows,
            start: FrameBoundary::UnboundedPreceding,
            end: Some(FrameBoundary::CurrentRow),
        };
        let f2 = WindowFrame {
            kind: FrameKind::Range,
            start: FrameBoundary::Preceding(5),
            end: Some(FrameBoundary::Following(5)),
        };
        let b = row_number().frame(f1).frame(f2.clone());
        assert_eq!(b.inner.frame, Some(f2));
    }

    #[test]
    fn lowers_to_expr_window_variant() {
        let e: Expr = row_number().into();
        assert!(matches!(e, Expr::Window(_)));
    }

    #[test]
    fn build_preserves_builder_state() {
        let e = rank()
            .partition_by("tenant_id")
            .order_by(&[("score", true)])
            .build();
        if let Expr::Window(w) = e {
            assert_eq!(w.kind, WindowFn::Rank);
            assert_eq!(w.partition_by, vec!["tenant_id"]);
            assert_eq!(w.order_by.len(), 1);
            assert!(w.frame.is_none());
        } else {
            panic!("expected Expr::Window, got {:?}", e);
        }
    }

    #[test]
    fn lowers_to_aggregate_expr_window_variant() {
        let a: AggregateExpr = rank().partition_by("tenant_id").into();
        assert!(matches!(a, AggregateExpr::Window(_)));
    }
}