use anyhow::Result;
use serde::{Deserialize, Serialize};
/// Describes one SQL window-function expression: the function call plus the
/// contents of its `OVER (...)` clause.
///
/// `to_sql` renders the value to a SQL string. All fields are public, so a
/// value can also be assembled directly instead of through the constructor
/// helpers on this type.
#[derive(Debug, Serialize, Deserialize)]
pub struct WindowFunctionBuilder {
/// Function call emitted in front of `OVER`.
pub function_type: WindowFunctionType,
/// Entries joined with `, ` after `PARTITION BY`; when empty the clause is left out.
pub partition_by: Vec<String>,
/// Sort keys emitted after `ORDER BY`; when empty the clause is left out.
pub order_by: Vec<OrderByColumn>,
/// Optional frame clause (`ROWS`/`RANGE BETWEEN ... AND ...`); `None` emits no frame.
pub window_frame: Option<WindowFrame>,
}
/// The function call placed in front of the `OVER` clause.
///
/// Every `column` field is copied into the generated SQL verbatim (no quoting
/// or escaping is applied by `to_sql`).
#[derive(Debug, Serialize, Deserialize)]
pub enum WindowFunctionType {
/// Rendered as `ROW_NUMBER()`.
RowNumber,
/// Rendered as `RANK()`.
Rank,
/// Rendered as `DENSE_RANK()`.
DenseRank,
/// Rendered as a `LEAD(...)` call on `column`. `offset` is passed as the
/// offset argument and `default` is an optional fallback expression that is
/// inserted verbatim.
Lead { column: String, offset: i32, default: Option<String> },
/// Rendered as a `LAG(...)` call on `column`. `offset` is passed as the
/// offset argument and `default` is an optional fallback expression that is
/// inserted verbatim.
Lag { column: String, offset: i32, default: Option<String> },
/// Rendered as `FIRST_VALUE(column)`.
FirstValue { column: String },
/// Rendered as `LAST_VALUE(column)`.
LastValue { column: String },
/// Rendered as `NTH_VALUE(column, n)`.
NthValue { column: String, n: i32 },
/// Rendered as `PERCENT_RANK()`.
PercentRank,
/// Rendered as `CUME_DIST()`.
CumeDist,
/// Rendered as `SUM(column)`.
Sum { column: String },
/// Rendered as `AVG(column)`.
Avg { column: String },
/// Rendered as `COUNT(column)`.
Count { column: String },
/// Rendered as `MAX(column)`.
Max { column: String },
/// Rendered as `MIN(column)`.
Min { column: String },
}
/// One sort key inside the `ORDER BY` part of a window definition.
#[derive(Debug, Serialize, Deserialize)]
pub struct OrderByColumn {
/// Column name or expression; inserted into the SQL verbatim.
pub column: String,
/// Sort direction; only `Desc` adds a keyword to the generated SQL.
pub direction: SortDirection,
/// Optional explicit `NULLS FIRST` / `NULLS LAST` placement; `None` adds nothing.
pub nulls: Option<NullHandling>,
}
/// Sort direction of an `ORDER BY` key.
#[derive(Debug, Serialize, Deserialize)]
pub enum SortDirection {
/// Ascending; `to_sql` emits no keyword for it and leaves the direction implicit.
Asc,
/// Descending; `to_sql` appends ` DESC` to the key.
Desc,
}
/// Explicit placement of NULLs for an `ORDER BY` key.
#[derive(Debug, Serialize, Deserialize)]
pub enum NullHandling {
/// Appends ` NULLS FIRST` to the key.
First,
/// Appends ` NULLS LAST` to the key.
Last,
}
/// Frame clause of a window definition, rendered by `to_sql` as
/// `<mode> BETWEEN <start> AND <end>`.
#[derive(Debug, Serialize, Deserialize)]
pub struct WindowFrame {
/// Frame unit keyword (`ROWS` or `RANGE`).
pub mode: FrameMode,
/// Lower bound of the frame.
pub start: FrameBound,
/// Upper bound of the frame; when `None`, `to_sql` writes `CURRENT ROW` as the end bound.
pub end: Option<FrameBound>,
}
/// Unit keyword of a window frame.
#[derive(Debug, Serialize, Deserialize)]
pub enum FrameMode {
/// Rendered as `ROWS`.
Rows,
/// Rendered as `RANGE`.
Range,
}
/// One endpoint of a window frame.
#[derive(Debug, Serialize, Deserialize)]
pub enum FrameBound {
/// Rendered as `UNBOUNDED PRECEDING`.
UnboundedPreceding,
/// Rendered as `CURRENT ROW`.
CurrentRow,
/// Rendered as `UNBOUNDED FOLLOWING`.
UnboundedFollowing,
/// Rendered as `<n> PRECEDING`; the number is written as-is, without validation.
Preceding(i32),
/// Rendered as `<n> FOLLOWING`; the number is written as-is, without validation.
Following(i32),
}
impl WindowFunctionBuilder {
pub fn row_number(partition_by: Vec<String>, order_by: Vec<String>) -> Self {
Self {
function_type: WindowFunctionType::RowNumber,
partition_by,
order_by: order_by.into_iter().map(|col| OrderByColumn {
column: col,
direction: SortDirection::Asc,
nulls: None,
}).collect(),
window_frame: None,
}
}
pub fn lead(column: String, offset: i32, partition_by: Vec<String>, order_by: Vec<String>) -> Self {
Self {
function_type: WindowFunctionType::Lead {
column: column.clone(),
offset,
default: None,
},
partition_by,
order_by: order_by.into_iter().map(|col| OrderByColumn {
column: col,
direction: SortDirection::Asc,
nulls: None,
}).collect(),
window_frame: None,
}
}
pub fn lag(column: String, offset: i32, partition_by: Vec<String>, order_by: Vec<String>) -> Self {
Self {
function_type: WindowFunctionType::Lag {
column: column.clone(),
offset,
default: None,
},
partition_by,
order_by: order_by.into_iter().map(|col| OrderByColumn {
column: col,
direction: SortDirection::Asc,
nulls: None,
}).collect(),
window_frame: None,
}
}
pub fn running_sum(column: String, partition_by: Vec<String>, order_by: Vec<String>) -> Self {
Self {
function_type: WindowFunctionType::Sum { column },
partition_by,
order_by: order_by.into_iter().map(|col| OrderByColumn {
column: col,
direction: SortDirection::Asc,
nulls: None,
}).collect(),
window_frame: Some(WindowFrame {
mode: FrameMode::Rows,
start: FrameBound::UnboundedPreceding,
end: Some(FrameBound::CurrentRow),
}),
}
}
pub fn to_sql(&self, alias: Option<&str>) -> String {
let mut sql = String::new();
match &self.function_type {
WindowFunctionType::RowNumber => sql.push_str("ROW_NUMBER()"),
WindowFunctionType::Rank => sql.push_str("RANK()"),
WindowFunctionType::DenseRank => sql.push_str("DENSE_RANK()"),
WindowFunctionType::Lead { column, offset, default } => {
sql.push_str(&format!("LEAD({}", column));
if *offset != 1 {
sql.push_str(&format!(", {}", offset));
}
if let Some(def) = default {
sql.push_str(&format!(", {}", def));
}
sql.push(')');
}
WindowFunctionType::Lag { column, offset, default } => {
sql.push_str(&format!("LAG({}", column));
if *offset != 1 {
sql.push_str(&format!(", {}", offset));
}
if let Some(def) = default {
sql.push_str(&format!(", {}", def));
}
sql.push(')');
}
WindowFunctionType::FirstValue { column } => {
sql.push_str(&format!("FIRST_VALUE({})", column));
}
WindowFunctionType::LastValue { column } => {
sql.push_str(&format!("LAST_VALUE({})", column));
}
WindowFunctionType::NthValue { column, n } => {
sql.push_str(&format!("NTH_VALUE({}, {})", column, n));
}
WindowFunctionType::PercentRank => sql.push_str("PERCENT_RANK()"),
WindowFunctionType::CumeDist => sql.push_str("CUME_DIST()"),
WindowFunctionType::Sum { column } => sql.push_str(&format!("SUM({})", column)),
WindowFunctionType::Avg { column } => sql.push_str(&format!("AVG({})", column)),
WindowFunctionType::Count { column } => sql.push_str(&format!("COUNT({})", column)),
WindowFunctionType::Max { column } => sql.push_str(&format!("MAX({})", column)),
WindowFunctionType::Min { column } => sql.push_str(&format!("MIN({})", column)),
}
sql.push_str(" OVER (");
if !self.partition_by.is_empty() {
sql.push_str("PARTITION BY ");
sql.push_str(&self.partition_by.join(", "));
if !self.order_by.is_empty() {
sql.push(' ');
}
}
if !self.order_by.is_empty() {
sql.push_str("ORDER BY ");
let order_strs: Vec<String> = self.order_by.iter().map(|col| {
let mut s = col.column.clone();
match col.direction {
SortDirection::Desc => s.push_str(" DESC"),
_ => {}
}
if let Some(nulls) = &col.nulls {
match nulls {
NullHandling::First => s.push_str(" NULLS FIRST"),
NullHandling::Last => s.push_str(" NULLS LAST"),
}
}
s
}).collect();
sql.push_str(&order_strs.join(", "));
}
if let Some(frame) = &self.window_frame {
sql.push(' ');
match frame.mode {
FrameMode::Rows => sql.push_str("ROWS"),
FrameMode::Range => sql.push_str("RANGE"),
}
sql.push_str(" BETWEEN ");
sql.push_str(&frame_bound_to_sql(&frame.start));
sql.push_str(" AND ");
if let Some(end) = &frame.end {
sql.push_str(&frame_bound_to_sql(end));
} else {
sql.push_str("CURRENT ROW");
}
}
sql.push(')');
if let Some(alias) = alias {
sql.push_str(&format!(" AS {}", alias));
}
sql
}
}
/// Renders a single frame endpoint as the SQL text used inside
/// `BETWEEN ... AND ...`.
///
/// Numeric bounds are written exactly as stored; no range validation is done.
fn frame_bound_to_sql(bound: &FrameBound) -> String {
    // Numeric bounds need formatting and return early; the remaining variants
    // map to fixed keywords.
    let keyword = match bound {
        FrameBound::Preceding(rows) => return format!("{} PRECEDING", rows),
        FrameBound::Following(rows) => return format!("{} FOLLOWING", rows),
        FrameBound::UnboundedPreceding => "UNBOUNDED PRECEDING",
        FrameBound::CurrentRow => "CURRENT ROW",
        FrameBound::UnboundedFollowing => "UNBOUNDED FOLLOWING",
    };
    String::from(keyword)
}
/// Namespace for canned SQL snippets that demonstrate common window-function
/// patterns. Every method returns a fixed string; nothing is parameterised.
pub struct WindowPatterns;

impl WindowPatterns {
    /// Select-list fragment contrasting `ROW_NUMBER`, `RANK` and `DENSE_RANK`
    /// over the same partition and ordering.
    pub fn ranking_with_ties() -> String {
        let snippet: &str = r#"-- Ranking with different tie handling
ROW_NUMBER() OVER (PARTITION BY category ORDER BY score DESC) as row_rank,
RANK() OVER (PARTITION BY category ORDER BY score DESC) as rank_with_gaps,
DENSE_RANK() OVER (PARTITION BY category ORDER BY score DESC) as dense_rank"#;
        String::from(snippet)
    }

    /// Select-list fragment that uses `LAG` to expose the previous value, the
    /// difference to it, and a 0/1 change flag.
    pub fn change_detection() -> String {
        let snippet: &str = r#"-- Detect changes from previous row
LAG(value, 1) OVER (PARTITION BY id ORDER BY date) as prev_value,
value - LAG(value, 1) OVER (PARTITION BY id ORDER BY date) as change,
CASE
WHEN value != LAG(value, 1) OVER (PARTITION BY id ORDER BY date) THEN 1
ELSE 0
END as is_changed"#;
        String::from(snippet)
    }

    /// Select-list fragment with a running sum, a 7-row moving average and a
    /// running count per account.
    pub fn running_totals() -> String {
        let snippet: &str = r#"-- Running totals and averages
SUM(amount) OVER (PARTITION BY account ORDER BY date ROWS UNBOUNDED PRECEDING) as running_total,
AVG(amount) OVER (PARTITION BY account ORDER BY date ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) as moving_avg_7,
COUNT(*) OVER (PARTITION BY account ORDER BY date ROWS UNBOUNDED PRECEDING) as running_count"#;
        String::from(snippet)
    }

    /// Select-list fragment combining `PERCENT_RANK`, `NTILE(4)`, `NTILE(10)`
    /// and `CUME_DIST` over one ordering.
    pub fn percentile_ranking() -> String {
        let snippet: &str = r#"-- Percentile ranking
PERCENT_RANK() OVER (ORDER BY value) as percentile,
NTILE(4) OVER (ORDER BY value) as quartile,
NTILE(10) OVER (ORDER BY value) as decile,
CUME_DIST() OVER (ORDER BY value) as cumulative_dist"#;
        String::from(snippet)
    }

    /// Complete query (two CTEs plus a grouped select) implementing the
    /// gaps-and-islands technique via the difference of two row numbers.
    pub fn gaps_and_islands() -> String {
        let snippet: &str = r#"-- Gaps and Islands pattern
WITH numbered AS (
SELECT *,
ROW_NUMBER() OVER (ORDER BY date) as rn,
ROW_NUMBER() OVER (PARTITION BY status ORDER BY date) as status_rn
FROM events
),
islands AS (
SELECT *,
rn - status_rn as island_id
FROM numbered
)
SELECT
status,
MIN(date) as island_start,
MAX(date) as island_end,
COUNT(*) as island_length
FROM islands
GROUP BY status, island_id
ORDER BY island_start"#;
        String::from(snippet)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned column lists the constructors take.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// `ROW_NUMBER()` with partition, ordering and alias renders exactly.
    #[test]
    fn test_row_number_generation() {
        let wf = WindowFunctionBuilder::row_number(cols(&["department"]), cols(&["salary"]));
        assert_eq!(
            wf.to_sql(Some("rank")),
            "ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary) AS rank"
        );
    }

    /// Covers both `LEAD` (implicit offset of 1) and `LAG` (explicit offset).
    #[test]
    fn test_lead_lag_generation() {
        let lead = WindowFunctionBuilder::lead(
            "price".to_string(),
            1,
            cols(&["product"]),
            cols(&["date"]),
        );
        // An offset of 1 is the implicit value and is left out of the call.
        assert_eq!(
            lead.to_sql(Some("next_price")),
            "LEAD(price) OVER (PARTITION BY product ORDER BY date) AS next_price"
        );

        let lag = WindowFunctionBuilder::lag(
            "price".to_string(),
            2,
            cols(&["product"]),
            cols(&["date"]),
        );
        assert_eq!(
            lag.to_sql(Some("prev_price")),
            "LAG(price, 2) OVER (PARTITION BY product ORDER BY date) AS prev_price"
        );
    }

    /// The running sum carries the unbounded-preceding-to-current-row frame.
    #[test]
    fn test_running_sum() {
        let sum = WindowFunctionBuilder::running_sum(
            "amount".to_string(),
            cols(&["account"]),
            cols(&["date"]),
        );
        assert_eq!(
            sum.to_sql(Some("running_total")),
            "SUM(amount) OVER (PARTITION BY account ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total"
        );
    }

    /// Without an alias no `AS` suffix is appended.
    #[test]
    fn test_no_alias() {
        let wf = WindowFunctionBuilder::row_number(cols(&["department"]), cols(&["salary"]));
        assert_eq!(
            wf.to_sql(None),
            "ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary)"
        );
    }
}