use serde::{Deserialize, Serialize};
use crate::sql::DatabaseType;
use crate::types::SortOrder;
/// A complete window function call: the function, the window it is evaluated
/// over, and an optional output alias.
///
/// `to_sql` renders it as `<function> <OVER clause>` followed by ` AS <alias>`
/// when an alias is present.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowFunction {
    /// The function being applied (e.g. `ROW_NUMBER()`, `SUM(x)`).
    pub function: WindowFn,
    /// The window specification rendered as the `OVER` clause.
    pub over: WindowSpec,
    /// Optional column alias, rendered as ` AS <alias>`.
    pub alias: Option<String>,
}
/// The function part of a window function call, rendered by `WindowFn::to_sql`.
///
/// String arguments are raw SQL fragments; they are inserted into the output
/// verbatim (not quoted or escaped).
// Every payload type is `Eq`, so the enum derives it too (clippy:
// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WindowFn {
    /// `ROW_NUMBER()`
    RowNumber,
    /// `RANK()`
    Rank,
    /// `DENSE_RANK()`
    DenseRank,
    /// `NTILE(n)`
    Ntile(u32),
    /// `PERCENT_RANK()`
    PercentRank,
    /// `CUME_DIST()`
    CumeDist,
    /// `LAG(expr[, offset[, default]])`
    Lag {
        /// Expression read from the earlier row.
        expr: String,
        /// Row offset; `None` leaves it to the database default.
        offset: Option<u32>,
        /// Fallback SQL expression used when no row exists at the offset.
        default: Option<String>,
    },
    /// `LEAD(expr[, offset[, default]])`
    Lead {
        /// Expression read from the later row.
        expr: String,
        /// Row offset; `None` leaves it to the database default.
        offset: Option<u32>,
        /// Fallback SQL expression used when no row exists at the offset.
        default: Option<String>,
    },
    /// `FIRST_VALUE(expr)`
    FirstValue(String),
    /// `LAST_VALUE(expr)`
    LastValue(String),
    /// `NTH_VALUE(expr, n)`
    NthValue(String, u32),
    /// `SUM(expr)`
    Sum(String),
    /// `AVG(expr)`
    Avg(String),
    /// `COUNT(expr)`
    Count(String),
    /// `MIN(expr)`
    Min(String),
    /// `MAX(expr)`
    Max(String),
    /// Any other function: `name(arg1, arg2, ...)`.
    Custom { name: String, args: Vec<String> },
}
impl WindowFn {
    /// Renders the function call (without its `OVER` clause) as SQL.
    ///
    /// Arguments are inserted verbatim. For `Lag`/`Lead`, a `default` given
    /// without an `offset` is rendered with an explicit offset of `1`, because
    /// the default is the third positional argument and cannot appear without
    /// the second. Previously such a default was silently dropped.
    pub fn to_sql(&self) -> String {
        match self {
            Self::RowNumber => "ROW_NUMBER()".to_string(),
            Self::Rank => "RANK()".to_string(),
            Self::DenseRank => "DENSE_RANK()".to_string(),
            Self::Ntile(n) => format!("NTILE({})", n),
            Self::PercentRank => "PERCENT_RANK()".to_string(),
            Self::CumeDist => "CUME_DIST()".to_string(),
            Self::Lag {
                expr,
                offset,
                default,
            } => Self::offset_call_sql("LAG", expr, *offset, default.as_deref()),
            Self::Lead {
                expr,
                offset,
                default,
            } => Self::offset_call_sql("LEAD", expr, *offset, default.as_deref()),
            Self::FirstValue(expr) => format!("FIRST_VALUE({})", expr),
            Self::LastValue(expr) => format!("LAST_VALUE({})", expr),
            Self::NthValue(expr, n) => format!("NTH_VALUE({}, {})", expr, n),
            Self::Sum(expr) => format!("SUM({})", expr),
            Self::Avg(expr) => format!("AVG({})", expr),
            Self::Count(expr) => format!("COUNT({})", expr),
            Self::Min(expr) => format!("MIN({})", expr),
            Self::Max(expr) => format!("MAX({})", expr),
            Self::Custom { name, args } => format!("{}({})", name, args.join(", ")),
        }
    }

    /// Formats a `LAG`/`LEAD`-style call: `NAME(expr[, offset[, default]])`.
    fn offset_call_sql(name: &str, expr: &str, offset: Option<u32>, default: Option<&str>) -> String {
        match (offset, default) {
            (None, None) => format!("{}({})", name, expr),
            (Some(offset), None) => format!("{}({}, {})", name, expr, offset),
            // `default` is positional, so an offset must precede it; 1 is the
            // offset SQL assumes when none is given, so this preserves meaning.
            (offset, Some(default)) => {
                format!("{}({}, {}, {})", name, expr, offset.unwrap_or(1), default)
            }
        }
    }
}
/// The contents of an `OVER` clause: an optional reference to a named window
/// plus `PARTITION BY`, `ORDER BY` and frame clauses.
///
/// The derived `Default` is the empty specification (no name, no partitions,
/// no ordering, no frame).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct WindowSpec {
    /// Name of a window declared in the query's `WINDOW` clause, if this spec
    /// refers to one.
    pub window_name: Option<String>,
    /// `PARTITION BY` expressions, joined with `, ` when rendered.
    pub partition_by: Vec<String>,
    /// `ORDER BY` items, rendered in insertion order.
    pub order_by: Vec<OrderSpec>,
    /// Optional frame clause (`ROWS` / `RANGE` / `GROUPS ...`).
    pub frame: Option<FrameClause>,
}
/// One `ORDER BY` item inside a window specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OrderSpec {
    /// The SQL expression to sort by (inserted verbatim).
    pub expr: String,
    /// Sort direction, rendered as `ASC` / `DESC`.
    pub direction: SortOrder,
    /// Optional explicit placement of NULLs (`NULLS FIRST` / `NULLS LAST`).
    pub nulls: Option<NullsPosition>,
}
/// Explicit NULL placement for an `ORDER BY` item.
///
/// The hint is omitted when rendering for dialects that lack the syntax.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NullsPosition {
    /// `NULLS FIRST`
    First,
    /// `NULLS LAST`
    Last,
}
/// A window frame clause such as `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FrameClause {
    /// Frame unit: `ROWS`, `RANGE` or `GROUPS`.
    pub frame_type: FrameType,
    /// Start bound; rendered alone when `end` is `None`.
    pub start: FrameBound,
    /// End bound; when present the clause is rendered as `BETWEEN start AND end`.
    pub end: Option<FrameBound>,
    /// Optional `EXCLUDE ...` option; only rendered for dialects that support it
    /// (see `FrameClause::to_sql`).
    pub exclude: Option<FrameExclude>,
}
/// The unit a window frame is measured in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameType {
    /// `ROWS`: physical rows.
    Rows,
    /// `RANGE`: rows whose ordering value falls in a range.
    Range,
    /// `GROUPS`: peer groups; rendered as-is only for dialects that support it
    /// (see `FrameClause::to_sql`).
    Groups,
}
/// One end of a window frame, rendered by `FrameBound::to_sql`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `n PRECEDING`
    Preceding(u32),
    /// `CURRENT ROW`
    CurrentRow,
    /// `n FOLLOWING`
    Following(u32),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}
/// Frame exclusion option appended after the frame bounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameExclude {
    /// `EXCLUDE CURRENT ROW`
    CurrentRow,
    /// `EXCLUDE GROUP`
    Group,
    /// `EXCLUDE TIES`
    Ties,
    /// `EXCLUDE NO OTHERS`
    NoOthers,
}
impl WindowSpec {
pub fn new() -> Self {
Self::default()
}
pub fn named(name: impl Into<String>) -> Self {
Self {
window_name: Some(name.into()),
..Default::default()
}
}
pub fn partition_by<I, S>(mut self, columns: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.partition_by = columns.into_iter().map(Into::into).collect();
self
}
pub fn order_by(mut self, column: impl Into<String>, direction: SortOrder) -> Self {
self.order_by.push(OrderSpec {
expr: column.into(),
direction,
nulls: None,
});
self
}
pub fn order_by_nulls(
mut self,
column: impl Into<String>,
direction: SortOrder,
nulls: NullsPosition,
) -> Self {
self.order_by.push(OrderSpec {
expr: column.into(),
direction,
nulls: Some(nulls),
});
self
}
pub fn rows(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
self.frame = Some(FrameClause {
frame_type: FrameType::Rows,
start,
end,
exclude: None,
});
self
}
pub fn range(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
self.frame = Some(FrameClause {
frame_type: FrameType::Range,
start,
end,
exclude: None,
});
self
}
pub fn groups(mut self, start: FrameBound, end: Option<FrameBound>) -> Self {
self.frame = Some(FrameClause {
frame_type: FrameType::Groups,
start,
end,
exclude: None,
});
self
}
pub fn rows_unbounded_preceding(self) -> Self {
self.rows(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
}
pub fn rows_unbounded_following(self) -> Self {
self.rows(FrameBound::CurrentRow, Some(FrameBound::UnboundedFollowing))
}
pub fn rows_around(self, n: u32) -> Self {
self.rows(FrameBound::Preceding(n), Some(FrameBound::Following(n)))
}
pub fn range_unbounded_preceding(self) -> Self {
self.range(FrameBound::UnboundedPreceding, Some(FrameBound::CurrentRow))
}
pub fn to_sql(&self, db_type: DatabaseType) -> String {
if let Some(ref name) = self.window_name {
return format!("OVER {}", name);
}
let mut parts = Vec::new();
if !self.partition_by.is_empty() {
parts.push(format!("PARTITION BY {}", self.partition_by.join(", ")));
}
if !self.order_by.is_empty() {
let orders: Vec<String> = self
.order_by
.iter()
.map(|o| {
let mut s = format!(
"{} {}",
o.expr,
match o.direction {
SortOrder::Asc => "ASC",
SortOrder::Desc => "DESC",
}
);
if let Some(nulls) = o.nulls {
if db_type != DatabaseType::MSSQL {
s.push_str(match nulls {
NullsPosition::First => " NULLS FIRST",
NullsPosition::Last => " NULLS LAST",
});
}
}
s
})
.collect();
parts.push(format!("ORDER BY {}", orders.join(", ")));
}
if let Some(ref frame) = self.frame {
parts.push(frame.to_sql(db_type));
}
if parts.is_empty() {
"OVER ()".to_string()
} else {
format!("OVER ({})", parts.join(" "))
}
}
}
impl FrameClause {
pub fn to_sql(&self, db_type: DatabaseType) -> String {
let frame_type = match self.frame_type {
FrameType::Rows => "ROWS",
FrameType::Range => "RANGE",
FrameType::Groups => {
match db_type {
DatabaseType::PostgreSQL | DatabaseType::SQLite => "GROUPS",
_ => "ROWS", }
}
};
let bounds = if let Some(ref end) = self.end {
format!("BETWEEN {} AND {}", self.start.to_sql(), end.to_sql())
} else {
self.start.to_sql()
};
let mut sql = format!("{} {}", frame_type, bounds);
if db_type == DatabaseType::PostgreSQL {
if let Some(exclude) = self.exclude {
sql.push_str(match exclude {
FrameExclude::CurrentRow => " EXCLUDE CURRENT ROW",
FrameExclude::Group => " EXCLUDE GROUP",
FrameExclude::Ties => " EXCLUDE TIES",
FrameExclude::NoOthers => " EXCLUDE NO OTHERS",
});
}
}
sql
}
}
impl FrameBound {
    /// Renders this bound as the keyword phrase used inside a frame clause,
    /// e.g. `3 PRECEDING` or `CURRENT ROW`.
    pub fn to_sql(&self) -> String {
        match *self {
            Self::Preceding(rows) => format!("{} PRECEDING", rows),
            Self::Following(rows) => format!("{} FOLLOWING", rows),
            Self::UnboundedPreceding => String::from("UNBOUNDED PRECEDING"),
            Self::CurrentRow => String::from("CURRENT ROW"),
            Self::UnboundedFollowing => String::from("UNBOUNDED FOLLOWING"),
        }
    }
}
impl WindowFunction {
pub fn new(function: WindowFn) -> WindowFunctionBuilder {
WindowFunctionBuilder {
function,
over: None,
alias: None,
}
}
pub fn over(mut self, spec: WindowSpec) -> Self {
self.over = spec;
self
}
pub fn alias(mut self, name: impl Into<String>) -> Self {
self.alias = Some(name.into());
self
}
pub fn to_sql(&self, db_type: DatabaseType) -> String {
let mut sql = format!("{} {}", self.function.to_sql(), self.over.to_sql(db_type));
if let Some(ref alias) = self.alias {
sql.push_str(" AS ");
sql.push_str(alias);
}
sql
}
}
/// Builder returned by `WindowFunction::new` and by the free constructor
/// helpers (`row_number()`, `sum(..)`, ...). Finish it with `build()`.
///
/// Marked `#[must_use]`: a builder that is dropped without `build()` produces
/// nothing, which is always a mistake.
#[derive(Debug, Clone)]
#[must_use = "a window function builder does nothing until `build()` is called"]
pub struct WindowFunctionBuilder {
    /// The function being called.
    function: WindowFn,
    /// Window specification; `build()` substitutes the empty spec (`OVER ()`) when unset.
    over: Option<WindowSpec>,
    /// Optional output column alias.
    alias: Option<String>,
}
impl WindowFunctionBuilder {
    /// Sets the window specification, replacing any previous one.
    pub fn over(self, spec: WindowSpec) -> Self {
        Self {
            over: Some(spec),
            ..self
        }
    }

    /// Sets the output alias, replacing any previous one.
    pub fn alias(self, name: impl Into<String>) -> Self {
        Self {
            alias: Some(name.into()),
            ..self
        }
    }

    /// Finishes the builder. An unset window becomes the default (empty) spec.
    pub fn build(self) -> WindowFunction {
        let Self {
            function,
            over,
            alias,
        } = self;
        WindowFunction {
            function,
            over: over.unwrap_or_default(),
            alias,
        }
    }
}
/// A window definition for a query's `WINDOW` clause, rendered as
/// `name AS (...)` by `NamedWindow::to_sql`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct NamedWindow {
    /// The window's name, referenced by `OVER name`.
    pub name: String,
    /// The specification placed inside the parentheses.
    pub spec: WindowSpec,
}
impl NamedWindow {
pub fn new(name: impl Into<String>, spec: WindowSpec) -> Self {
Self {
name: name.into(),
spec,
}
}
pub fn to_sql(&self, db_type: DatabaseType) -> String {
let spec_parts = {
let mut parts = Vec::new();
if !self.spec.partition_by.is_empty() {
parts.push(format!(
"PARTITION BY {}",
self.spec.partition_by.join(", ")
));
}
if !self.spec.order_by.is_empty() {
let orders: Vec<String> = self
.spec
.order_by
.iter()
.map(|o| {
format!(
"{} {}",
o.expr,
match o.direction {
SortOrder::Asc => "ASC",
SortOrder::Desc => "DESC",
}
)
})
.collect();
parts.push(format!("ORDER BY {}", orders.join(", ")));
}
if let Some(ref frame) = self.spec.frame {
parts.push(frame.to_sql(db_type));
}
parts.join(" ")
};
format!("{} AS ({})", self.name, spec_parts)
}
}
/// Starts a `ROW_NUMBER()` window function.
pub fn row_number() -> WindowFunctionBuilder {
    let function = WindowFn::RowNumber;
    WindowFunction::new(function)
}

/// Starts a `RANK()` window function.
pub fn rank() -> WindowFunctionBuilder {
    let function = WindowFn::Rank;
    WindowFunction::new(function)
}

/// Starts a `DENSE_RANK()` window function.
pub fn dense_rank() -> WindowFunctionBuilder {
    let function = WindowFn::DenseRank;
    WindowFunction::new(function)
}

/// Starts an `NTILE(n)` window function (`n` buckets).
pub fn ntile(n: u32) -> WindowFunctionBuilder {
    let function = WindowFn::Ntile(n);
    WindowFunction::new(function)
}

/// Starts a `PERCENT_RANK()` window function.
pub fn percent_rank() -> WindowFunctionBuilder {
    let function = WindowFn::PercentRank;
    WindowFunction::new(function)
}

/// Starts a `CUME_DIST()` window function.
pub fn cume_dist() -> WindowFunctionBuilder {
    let function = WindowFn::CumeDist;
    WindowFunction::new(function)
}
/// Starts `LAG(expr)`, with the offset and default left to the database.
pub fn lag(expr: impl Into<String>) -> WindowFunctionBuilder {
    lag_spec(expr.into(), None, None)
}

/// Starts `LAG(expr, offset)`.
pub fn lag_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
    lag_spec(expr.into(), Some(offset), None)
}

/// Starts `LAG(expr, offset, default)`; `default` is a raw SQL expression.
pub fn lag_full(
    expr: impl Into<String>,
    offset: u32,
    default: impl Into<String>,
) -> WindowFunctionBuilder {
    let expr = expr.into();
    lag_spec(expr, Some(offset), Some(default.into()))
}

/// Starts `LEAD(expr)`, with the offset and default left to the database.
pub fn lead(expr: impl Into<String>) -> WindowFunctionBuilder {
    lead_spec(expr.into(), None, None)
}

/// Starts `LEAD(expr, offset)`.
pub fn lead_offset(expr: impl Into<String>, offset: u32) -> WindowFunctionBuilder {
    lead_spec(expr.into(), Some(offset), None)
}

/// Starts `LEAD(expr, offset, default)`; `default` is a raw SQL expression.
pub fn lead_full(
    expr: impl Into<String>,
    offset: u32,
    default: impl Into<String>,
) -> WindowFunctionBuilder {
    let expr = expr.into();
    lead_spec(expr, Some(offset), Some(default.into()))
}

/// Shared constructor for the `LAG` helpers above.
fn lag_spec(expr: String, offset: Option<u32>, default: Option<String>) -> WindowFunctionBuilder {
    WindowFunction::new(WindowFn::Lag {
        expr,
        offset,
        default,
    })
}

/// Shared constructor for the `LEAD` helpers above.
fn lead_spec(expr: String, offset: Option<u32>, default: Option<String>) -> WindowFunctionBuilder {
    WindowFunction::new(WindowFn::Lead {
        expr,
        offset,
        default,
    })
}
/// Starts `FIRST_VALUE(expr)`.
pub fn first_value(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::FirstValue(expr))
}

/// Starts `LAST_VALUE(expr)`.
///
/// With the common default frame this sees only rows up to the current one;
/// pair it with a full-partition frame to get the partition's last value.
pub fn last_value(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::LastValue(expr))
}

/// Starts `NTH_VALUE(expr, n)`.
pub fn nth_value(expr: impl Into<String>, n: u32) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::NthValue(expr, n))
}

/// Starts `SUM(expr)` used as a window aggregate.
pub fn sum(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::Sum(expr))
}

/// Starts `AVG(expr)` used as a window aggregate.
pub fn avg(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::Avg(expr))
}

/// Starts `COUNT(expr)` used as a window aggregate.
pub fn count(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::Count(expr))
}

/// Starts `MIN(expr)` used as a window aggregate.
pub fn min(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::Min(expr))
}

/// Starts `MAX(expr)` used as a window aggregate.
pub fn max(expr: impl Into<String>) -> WindowFunctionBuilder {
    let expr: String = expr.into();
    WindowFunction::new(WindowFn::Max(expr))
}

/// Starts an arbitrary function call `name(args...)`, with arguments joined by `, `.
pub fn custom<I, S>(name: impl Into<String>, args: I) -> WindowFunctionBuilder
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let name: String = name.into();
    let args: Vec<String> = args.into_iter().map(Into::into).collect();
    WindowFunction::new(WindowFn::Custom { name, args })
}
/// MongoDB counterparts of the SQL window functions, built as
/// `$setWindowFields` aggregation stages represented with `serde_json` values.
pub mod mongodb {
    use serde::{Deserialize, Serialize};
    use serde_json::Value as JsonValue;

    /// A `$setWindowFields` aggregation stage.
    // `Default` is derived: `None`, `None` and an empty map are exactly what the
    // former hand-written impl produced (clippy: `derivable_impls`).
    #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
    pub struct SetWindowFields {
        /// `partitionBy` expression, if any.
        pub partition_by: Option<JsonValue>,
        /// `sortBy` document, if any.
        pub sort_by: Option<JsonValue>,
        /// `output` document: output field name -> window operator spec.
        pub output: serde_json::Map<String, JsonValue>,
    }

    impl SetWindowFields {
        /// Starts building a stage; equivalent to `set_window_fields()`.
        // Named `new` for API compatibility even though it returns the builder.
        #[allow(clippy::new_ret_no_self)]
        pub fn new() -> SetWindowFieldsBuilder {
            SetWindowFieldsBuilder::default()
        }

        /// Renders the stage as
        /// `{ "$setWindowFields": { "partitionBy"?, "sortBy"?, "output" } }`.
        pub fn to_bson(&self) -> JsonValue {
            let mut stage = serde_json::Map::new();
            if let Some(partition) = &self.partition_by {
                stage.insert("partitionBy".to_string(), partition.clone());
            }
            if let Some(sort) = &self.sort_by {
                stage.insert("sortBy".to_string(), sort.clone());
            }
            stage.insert("output".to_string(), JsonValue::Object(self.output.clone()));
            serde_json::json!({ "$setWindowFields": stage })
        }
    }

    /// Fluent builder for [`SetWindowFields`].
    ///
    /// Methods taking a field name convert it to a `"$name"` field path.
    #[derive(Debug, Clone, Default)]
    #[must_use = "builders do nothing until `build()` is called"]
    pub struct SetWindowFieldsBuilder {
        partition_by: Option<JsonValue>,
        sort_by: Option<JsonValue>,
        output: serde_json::Map<String, JsonValue>,
    }

    impl SetWindowFieldsBuilder {
        /// Partitions by the field `expr` (rendered as `"$expr"`).
        pub fn partition_by(mut self, expr: impl Into<String>) -> Self {
            self.partition_by = Some(field_path(expr.into()));
            self
        }

        /// Partitions by an arbitrary aggregation expression, used as given.
        pub fn partition_by_expr(mut self, expr: JsonValue) -> Self {
            self.partition_by = Some(expr);
            self
        }

        /// Sorts ascending by `field`, replacing any previous sort.
        pub fn sort_by(self, field: impl Into<String>) -> Self {
            self.with_sort(std::iter::once((field.into(), 1)))
        }

        /// Sorts descending by `field`, replacing any previous sort.
        pub fn sort_by_desc(self, field: impl Into<String>) -> Self {
            self.with_sort(std::iter::once((field.into(), -1)))
        }

        /// Sorts by several `(field, direction)` pairs in order, replacing any
        /// previous sort. Directions are used as given (normally `1` / `-1`).
        pub fn sort_by_fields(self, fields: Vec<(&str, i32)>) -> Self {
            self.with_sort(fields.into_iter().map(|(field, dir)| (field.to_string(), dir)))
        }

        /// Adds `output_field: { $rowNumber: {} }`.
        pub fn row_number(self, output_field: impl Into<String>) -> Self {
            self.output(output_field, serde_json::json!({ "$rowNumber": {} }))
        }

        /// Adds `output_field: { $rank: {} }`.
        pub fn rank(self, output_field: impl Into<String>) -> Self {
            self.output(output_field, serde_json::json!({ "$rank": {} }))
        }

        /// Adds `output_field: { $denseRank: {} }`.
        pub fn dense_rank(self, output_field: impl Into<String>) -> Self {
            self.output(output_field, serde_json::json!({ "$denseRank": {} }))
        }

        /// Adds a `$sum` of field `input`, optionally bounded by `window`.
        pub fn sum(
            self,
            output_field: impl Into<String>,
            input: impl Into<String>,
            window: Option<MongoWindow>,
        ) -> Self {
            self.windowed_accumulator("$sum", output_field.into(), input.into(), window)
        }

        /// Adds an `$avg` of field `input`, optionally bounded by `window`.
        pub fn avg(
            self,
            output_field: impl Into<String>,
            input: impl Into<String>,
            window: Option<MongoWindow>,
        ) -> Self {
            self.windowed_accumulator("$avg", output_field.into(), input.into(), window)
        }

        /// Adds `output_field: { $first: "$input" }`.
        pub fn first(self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
            self.output(output_field, serde_json::json!({ "$first": field_path(input.into()) }))
        }

        /// Adds `output_field: { $last: "$input" }`.
        pub fn last(self, output_field: impl Into<String>, input: impl Into<String>) -> Self {
            self.output(output_field, serde_json::json!({ "$last": field_path(input.into()) }))
        }

        /// Adds a `$shift` of field `output` by `by` documents (negative means
        /// earlier), with an optional `default` value.
        pub fn shift(
            self,
            output_field: impl Into<String>,
            output: impl Into<String>,
            by: i32,
            default: Option<JsonValue>,
        ) -> Self {
            let mut spec = serde_json::Map::new();
            spec.insert("output".to_string(), field_path(output.into()));
            spec.insert("by".to_string(), JsonValue::from(by));
            if let Some(def) = default {
                spec.insert("default".to_string(), def);
            }
            self.output(output_field, serde_json::json!({ "$shift": spec }))
        }

        /// Adds (or replaces) an output field with a raw operator document.
        pub fn output(mut self, field: impl Into<String>, spec: JsonValue) -> Self {
            self.output.insert(field.into(), spec);
            self
        }

        /// Finishes the builder.
        pub fn build(self) -> SetWindowFields {
            SetWindowFields {
                partition_by: self.partition_by,
                sort_by: self.sort_by,
                output: self.output,
            }
        }

        /// Replaces the sort document with the given `(field, direction)` keys.
        fn with_sort(mut self, keys: impl IntoIterator<Item = (String, i32)>) -> Self {
            let sort: serde_json::Map<String, JsonValue> = keys
                .into_iter()
                .map(|(field, dir)| (field, JsonValue::from(dir)))
                .collect();
            self.sort_by = Some(JsonValue::Object(sort));
            self
        }

        /// Adds `{ <operator>: "$input", window? }` under `output_field`.
        fn windowed_accumulator(
            self,
            operator: &str,
            output_field: String,
            input: String,
            window: Option<MongoWindow>,
        ) -> Self {
            let mut spec = serde_json::Map::new();
            spec.insert(operator.to_string(), field_path(input));
            if let Some(w) = window {
                spec.insert("window".to_string(), w.to_bson());
            }
            self.output(output_field, JsonValue::Object(spec))
        }
    }

    /// The `window` bounds of a windowed accumulator.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct MongoWindow {
        /// Document-count bounds (`documents: [start, end]`).
        pub documents: Option<[WindowBound; 2]>,
        /// Value-range bounds (`range: [start, end]`).
        pub range: Option<[WindowBound; 2]>,
        /// Time unit for `range` bounds (e.g. `"day"`).
        pub unit: Option<String>,
    }

    /// A window bound: a numeric offset or a keyword such as `"unbounded"` / `"current"`.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    #[serde(untagged)]
    pub enum WindowBound {
        Number(i64),
        Keyword(String),
    }

    impl MongoWindow {
        /// `documents: [start, end]` relative to the current document.
        pub fn documents(start: i64, end: i64) -> Self {
            Self {
                documents: Some([WindowBound::Number(start), WindowBound::Number(end)]),
                range: None,
                unit: None,
            }
        }

        /// `documents: ["unbounded", "unbounded"]` (whole partition).
        pub fn documents_unbounded() -> Self {
            Self {
                documents: Some([
                    WindowBound::keyword("unbounded"),
                    WindowBound::keyword("unbounded"),
                ]),
                range: None,
                unit: None,
            }
        }

        /// `documents: ["unbounded", "current"]` (running total).
        pub fn documents_to_current() -> Self {
            Self {
                documents: Some([
                    WindowBound::keyword("unbounded"),
                    WindowBound::keyword("current"),
                ]),
                range: None,
                unit: None,
            }
        }

        /// `range: [start, end]` measured in `unit` (e.g. `"day"`).
        pub fn range_with_unit(start: i64, end: i64, unit: impl Into<String>) -> Self {
            Self {
                documents: None,
                range: Some([WindowBound::Number(start), WindowBound::Number(end)]),
                unit: Some(unit.into()),
            }
        }

        /// Renders the `window` document, including only the parts that are set.
        pub fn to_bson(&self) -> JsonValue {
            let mut window = serde_json::Map::new();
            if let Some(docs) = &self.documents {
                window.insert("documents".to_string(), bounds_to_json(docs));
            }
            if let Some(range) = &self.range {
                window.insert("range".to_string(), bounds_to_json(range));
            }
            if let Some(unit) = &self.unit {
                window.insert("unit".to_string(), JsonValue::String(unit.clone()));
            }
            JsonValue::Object(window)
        }
    }

    impl WindowBound {
        /// Convenience constructor for keyword bounds.
        fn keyword(word: &str) -> Self {
            Self::Keyword(word.to_string())
        }

        /// Renders the bound as a JSON number or string.
        fn to_json(&self) -> JsonValue {
            match self {
                Self::Number(n) => JsonValue::from(*n),
                Self::Keyword(word) => JsonValue::String(word.clone()),
            }
        }
    }

    /// Renders a `[start, end]` bound pair as a JSON array.
    fn bounds_to_json(bounds: &[WindowBound; 2]) -> JsonValue {
        JsonValue::Array(bounds.iter().map(WindowBound::to_json).collect())
    }

    /// Formats a field name as a MongoDB field-path reference (`"$name"`).
    fn field_path(field: String) -> JsonValue {
        JsonValue::String(format!("${}", field))
    }

    /// Starts building a `$setWindowFields` stage.
    pub fn set_window_fields() -> SetWindowFieldsBuilder {
        SetWindowFields::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails, showing the whole rendered SQL, if any of `fragments` is missing from `sql`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(
                sql.contains(fragment),
                "expected `{}` in rendered SQL: {}",
                fragment,
                sql
            );
        }
    }

    /// A window ordered ascending by a single column.
    fn ordered_asc(column: &str) -> WindowSpec {
        WindowSpec::new().order_by(column, SortOrder::Asc)
    }

    /// A window ordered descending by a single column.
    fn ordered_desc(column: &str) -> WindowSpec {
        WindowSpec::new().order_by(column, SortOrder::Desc)
    }

    #[test]
    fn test_row_number() {
        let spec = WindowSpec::new()
            .partition_by(["dept"])
            .order_by("salary", SortOrder::Desc);
        let sql = row_number().over(spec).build().to_sql(DatabaseType::PostgreSQL);
        assert_contains_all(
            &sql,
            &["ROW_NUMBER()", "PARTITION BY dept", "ORDER BY salary DESC"],
        );
    }

    #[test]
    fn test_rank_functions() {
        let ranked = rank().over(ordered_desc("score")).build();
        assert_contains_all(&ranked.to_sql(DatabaseType::PostgreSQL), &["RANK()"]);
        let dense = dense_rank().over(ordered_desc("score")).build();
        assert_contains_all(&dense.to_sql(DatabaseType::PostgreSQL), &["DENSE_RANK()"]);
    }

    #[test]
    fn test_ntile() {
        let sql = ntile(4)
            .over(ordered_asc("value"))
            .build()
            .to_sql(DatabaseType::MySQL);
        assert_contains_all(&sql, &["NTILE(4)"]);
    }

    #[test]
    fn test_lag_lead() {
        let cases = [
            (lag("price"), "LAG(price)"),
            (lag_offset("price", 2), "LAG(price, 2)"),
            (lag_full("price", 1, "0"), "LAG(price, 1, 0)"),
            (lead("price"), "LEAD(price)"),
        ];
        for (builder, expected) in cases {
            let sql = builder
                .over(ordered_asc("date"))
                .build()
                .to_sql(DatabaseType::PostgreSQL);
            assert_contains_all(&sql, &[expected]);
        }
    }

    #[test]
    fn test_aggregate_window() {
        let running = WindowSpec::new()
            .partition_by(["account_id"])
            .order_by("date", SortOrder::Asc)
            .rows_unbounded_preceding();
        let sql = sum("amount")
            .over(running)
            .alias("running_total")
            .build()
            .to_sql(DatabaseType::PostgreSQL);
        assert_contains_all(
            &sql,
            &[
                "SUM(amount)",
                "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
                "AS running_total",
            ],
        );
    }

    #[test]
    fn test_frame_clauses() {
        let spec =
            ordered_asc("id").rows(FrameBound::Preceding(3), Some(FrameBound::Following(3)));
        assert_contains_all(
            &spec.to_sql(DatabaseType::PostgreSQL),
            &["ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING"],
        );
    }

    #[test]
    fn test_named_window() {
        let spec = WindowSpec::new()
            .partition_by(["dept"])
            .order_by("salary", SortOrder::Desc);
        let sql = NamedWindow::new("w", spec).to_sql(DatabaseType::PostgreSQL);
        assert_contains_all(&sql, &["w AS (", "PARTITION BY dept"]);
    }

    #[test]
    fn test_window_reference() {
        let rendered = WindowSpec::named("w").to_sql(DatabaseType::PostgreSQL);
        assert_eq!(rendered, "OVER w");
    }

    #[test]
    fn test_nulls_position() {
        let spec =
            WindowSpec::new().order_by_nulls("value", SortOrder::Desc, NullsPosition::Last);
        assert_contains_all(&spec.to_sql(DatabaseType::PostgreSQL), &["NULLS LAST"]);
        assert!(!spec.to_sql(DatabaseType::MSSQL).contains("NULLS"));
    }

    #[test]
    fn test_first_last_value() {
        let by_hire_date = || {
            WindowSpec::new()
                .partition_by(["dept"])
                .order_by("hire_date", SortOrder::Asc)
        };
        let first = first_value("salary").over(by_hire_date()).build();
        assert_contains_all(
            &first.to_sql(DatabaseType::PostgreSQL),
            &["FIRST_VALUE(salary)"],
        );
        let whole_partition = by_hire_date().rows(
            FrameBound::UnboundedPreceding,
            Some(FrameBound::UnboundedFollowing),
        );
        let last = last_value("salary").over(whole_partition).build();
        assert_contains_all(
            &last.to_sql(DatabaseType::PostgreSQL),
            &["LAST_VALUE(salary)"],
        );
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_row_number() {
            let bson = set_window_fields()
                .partition_by("state")
                .sort_by_desc("quantity")
                .row_number("rowNumber")
                .build()
                .to_bson();
            let output = &bson["$setWindowFields"]["output"];
            assert!(output["rowNumber"]["$rowNumber"].is_object());
        }

        #[test]
        fn test_rank() {
            let bson = set_window_fields()
                .sort_by("score")
                .rank("ranking")
                .dense_rank("denseRanking")
                .build()
                .to_bson();
            let output = &bson["$setWindowFields"]["output"];
            assert!(output["ranking"]["$rank"].is_object());
            assert!(output["denseRanking"]["$denseRank"].is_object());
        }

        #[test]
        fn test_running_total() {
            let bson = set_window_fields()
                .partition_by("account")
                .sort_by("date")
                .sum(
                    "runningTotal",
                    "amount",
                    Some(MongoWindow::documents_to_current()),
                )
                .build()
                .to_bson();
            let running = &bson["$setWindowFields"]["output"]["runningTotal"];
            assert!(running["$sum"].is_string());
            assert!(running["window"]["documents"].is_array());
        }

        #[test]
        fn test_shift_lag() {
            let bson = set_window_fields()
                .sort_by("date")
                .shift("prevPrice", "price", -1, Some(serde_json::json!(0)))
                .shift("nextPrice", "price", 1, None)
                .build()
                .to_bson();
            let output = &bson["$setWindowFields"]["output"];
            assert!(output["prevPrice"]["$shift"]["by"] == -1);
            assert!(output["nextPrice"]["$shift"]["by"] == 1);
        }

        #[test]
        fn test_window_bounds() {
            let documents = MongoWindow::documents(-3, 3).to_bson();
            assert_eq!(documents["documents"][0], -3);
            assert_eq!(documents["documents"][1], 3);
            let ranged = MongoWindow::range_with_unit(-7, 0, "day").to_bson();
            assert!(ranged["range"].is_array());
            assert_eq!(ranged["unit"], "day");
        }
    }
}