use std::cmp::Ordering;
use std::collections::HashMap;
use super::super::engine::binding::{Binding, Value, Var};
use super::aggregation::create_aggregator;
use super::value_compare::{total_compare_values, values_equal};
/// The window function to evaluate, together with its arguments.
///
/// Row positions mentioned below refer to the partition after it has been
/// sorted by the window's ORDER BY specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WindowFuncType {
/// 1-based position of the row within its partition.
RowNumber,
/// 1-based position of the first row of the current row's peer group
/// (tied rows share a rank and leave gaps after them).
Rank,
/// 1-based index of the current row's peer group (ties share a rank, no gaps).
DenseRank,
/// 1-based bucket number when the partition is split into the given number
/// of buckets; a bucket count of 0 evaluates to `Value::Null`.
Ntile(i64),
/// Zero-based rank divided by (partition size - 1); `0.0` for partitions
/// with at most one row.
PercentRank,
/// Fraction of the partition's rows that are in the current row's peer
/// group or an earlier one.
CumeDist,
/// Value of the variable at the first row of the frame.
FirstValue(Var),
/// Value of the variable at the last row of the frame.
LastValue(Var),
/// Value of the variable at the n-th (1-based) row of the frame;
/// `Value::Null` when n is 0 or the frame has fewer rows.
NthValue(Var, i64),
/// Value of the variable `offset` rows before the current row, falling back
/// to the optional default (or `Value::Null`) when no value is available.
Lag(Var, i64, Option<Value>),
/// Value of the variable `offset` rows after the current row, falling back
/// to the optional default (or `Value::Null`) when no value is available.
Lead(Var, i64, Option<Value>),
/// Named aggregate (looked up through `create_aggregator`) applied to the
/// variable over the frame; `Value::Null` when no aggregator is found.
Aggregate(String, Var),
}
/// Convenience constructors, one per variant.
impl WindowFuncType {
    /// Creates a `RowNumber` function.
    pub fn row_number() -> Self {
        Self::RowNumber
    }

    /// Creates a `Rank` function.
    pub fn rank() -> Self {
        Self::Rank
    }

    /// Creates a `DenseRank` function.
    pub fn dense_rank() -> Self {
        Self::DenseRank
    }

    /// Creates an `Ntile` function that distributes rows over `n` buckets.
    pub fn ntile(n: i64) -> Self {
        Self::Ntile(n)
    }

    /// Creates a `PercentRank` function.
    pub fn percent_rank() -> Self {
        Self::PercentRank
    }

    /// Creates a `CumeDist` function.
    pub fn cume_dist() -> Self {
        Self::CumeDist
    }

    /// Creates a `Lag` function reading `var` from `offset` rows back, with an
    /// optional fallback value.
    pub fn lag(var: Var, offset: i64, default: Option<Value>) -> Self {
        Self::Lag(var, offset, default)
    }

    /// Creates a `Lead` function reading `var` from `offset` rows ahead, with
    /// an optional fallback value.
    pub fn lead(var: Var, offset: i64, default: Option<Value>) -> Self {
        Self::Lead(var, offset, default)
    }

    /// Creates a `FirstValue` function over `var`.
    pub fn first_value(var: Var) -> Self {
        Self::FirstValue(var)
    }

    /// Creates a `LastValue` function over `var`.
    pub fn last_value(var: Var) -> Self {
        Self::LastValue(var)
    }

    /// Creates an `NthValue` function reading `var` from the `n`-th (1-based)
    /// row of the frame.
    pub fn nth_value(var: Var, n: i64) -> Self {
        Self::NthValue(var, n)
    }

    /// Creates an `Aggregate` function over `var`.
    ///
    /// The aggregate name is stored upper-cased, so `"sum"` and `"SUM"`
    /// produce equal values.
    pub fn aggregate(name: &str, var: Var) -> Self {
        Self::Aggregate(name.to_uppercase(), var)
    }
}
/// Unit in which a window frame's bounds are interpreted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
/// Bounds count individual rows relative to the current row.
Rows,
/// `CurrentRow` bounds extend to the whole peer group of the current row.
/// NOTE(review): numeric offsets are currently resolved like `Rows` by
/// `WindowExecutor::get_frame_bounds` — confirm this is intended.
Range,
/// Bounds count peer groups relative to the current row's peer group.
Groups,
}
/// One edge (start or end) of a window frame. Defaults to `CurrentRow`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FrameBound {
/// The first row of the partition.
UnboundedPreceding,
/// The last row of the partition.
UnboundedFollowing,
/// The current row (or its peer group, depending on the frame type).
#[default]
CurrentRow,
/// The given number of rows or peer groups before the current row.
Preceding(i64),
/// The given number of rows or peer groups after the current row.
Following(i64),
}
/// Full description of a window frame: unit, start bound, end bound and
/// exclusion mode.
#[derive(Debug, Clone)]
pub struct FrameSpec {
/// Unit in which `start` and `end` are interpreted.
pub frame_type: FrameType,
/// First edge of the frame.
pub start: FrameBound,
/// Last edge of the frame.
pub end: FrameBound,
/// Rows to leave out of the frame.
/// NOTE(review): the executor code in this file does not read this field
/// when computing frame bounds — confirm whether exclusion is applied elsewhere.
pub exclude: FrameExclude,
}
/// Frame exclusion mode. Defaults to `NoOthers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FrameExclude {
/// Exclude nothing.
#[default]
NoOthers,
/// Exclude the current row.
CurrentRow,
/// Exclude the current row's peer group (presumably including the current
/// row itself, as in SQL's `EXCLUDE GROUP` — verify against the consumer).
Group,
/// Exclude the current row's peers (presumably keeping the current row, as
/// in SQL's `EXCLUDE TIES` — verify against the consumer).
Ties,
}
impl Default for FrameSpec {
fn default() -> Self {
Self {
frame_type: FrameType::Range,
start: FrameBound::UnboundedPreceding,
end: FrameBound::CurrentRow,
exclude: FrameExclude::NoOthers,
}
}
}
impl FrameSpec {
pub fn entire_partition() -> Self {
Self {
frame_type: FrameType::Rows,
start: FrameBound::UnboundedPreceding,
end: FrameBound::UnboundedFollowing,
exclude: FrameExclude::NoOthers,
}
}
pub fn running() -> Self {
Self {
frame_type: FrameType::Rows,
start: FrameBound::UnboundedPreceding,
end: FrameBound::CurrentRow,
exclude: FrameExclude::NoOthers,
}
}
pub fn sliding(n: i64) -> Self {
Self {
frame_type: FrameType::Rows,
start: FrameBound::Preceding(n),
end: FrameBound::CurrentRow,
exclude: FrameExclude::NoOthers,
}
}
}
/// Sort direction of one ORDER BY key. Defaults to `Asc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SortDirection {
/// Ascending order.
#[default]
Asc,
/// Descending order.
Desc,
}
/// Placement of rows whose sort variable is unbound. Defaults to `First`.
///
/// Note that `WindowOrderBy::new` explicitly picks `Last`, not this default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NullsOrder {
/// Unbound values sort before bound ones.
#[default]
First,
/// Unbound values sort after bound ones.
Last,
}
/// One ORDER BY key of a window definition.
#[derive(Debug, Clone)]
pub struct WindowOrderBy {
/// Variable whose value is compared.
pub var: Var,
/// Ascending or descending order.
pub direction: SortDirection,
/// Where rows with the variable unbound are placed.
pub nulls: NullsOrder,
}
impl WindowOrderBy {
pub fn new(var: Var) -> Self {
Self {
var,
direction: SortDirection::Asc,
nulls: NullsOrder::Last,
}
}
pub fn desc(mut self) -> Self {
self.direction = SortDirection::Desc;
self
}
pub fn nulls_first(mut self) -> Self {
self.nulls = NullsOrder::First;
self
}
}
/// A window definition: how rows are partitioned, ordered and framed.
///
/// The derived default has no name, no partitioning, no ordering and the
/// default `FrameSpec`.
#[derive(Debug, Clone, Default)]
pub struct WindowDef {
/// Optional window name; the executor code in this file does not read it.
pub name: Option<String>,
/// Variables whose values form the partition key; empty means the whole
/// input is a single partition.
pub partition_by: Vec<Var>,
/// Sort keys applied within each partition, in priority order.
pub order_by: Vec<WindowOrderBy>,
/// Frame used by frame-based functions (first/last/nth value, aggregates).
pub frame: FrameSpec,
}
impl WindowDef {
pub fn partition_by(vars: Vec<Var>) -> Self {
Self {
partition_by: vars,
..Default::default()
}
}
pub fn with_order_by(mut self, order: Vec<WindowOrderBy>) -> Self {
self.order_by = order;
self
}
pub fn with_frame(mut self, frame: FrameSpec) -> Self {
self.frame = frame;
self
}
}
/// A window function call: what to compute, where to store the result and
/// over which window.
#[derive(Debug, Clone)]
pub struct WindowFunc {
/// Function to evaluate.
pub func_type: WindowFuncType,
/// Variable that receives the computed value in each output binding.
pub result_var: Var,
/// Partitioning, ordering and frame the function runs over.
pub window: WindowDef,
}
impl WindowFunc {
pub fn new(func_type: WindowFuncType, result_var: Var, window: WindowDef) -> Self {
Self {
func_type,
result_var,
window,
}
}
}
/// Evaluates window functions over a list of bindings.
///
/// The type carries no state; everything is exposed through associated
/// functions, with `execute` as the entry point.
pub struct WindowExecutor;
/// A binding together with its position in the executor's input, so that
/// results can be written back in input order after partitions are sorted.
#[derive(Debug, Clone)]
struct IndexedBinding {
// Position of the binding in the original input slice.
index: usize,
// The row itself (extended with the result variable once computed).
binding: Binding,
}
impl WindowExecutor {
/// Applies each window function in `functions`, in order, to `bindings`.
///
/// Every function sees the output of the previous one, so later functions
/// can read variables produced by earlier ones. The input is returned
/// unchanged when there are no bindings or no functions.
pub fn execute(bindings: Vec<Binding>, functions: &[WindowFunc]) -> Vec<Binding> {
    if functions.is_empty() || bindings.is_empty() {
        return bindings;
    }
    functions
        .iter()
        .fold(bindings, |rows, func| Self::apply_window_function(&rows, func))
}
fn apply_window_function(bindings: &[Binding], func: &WindowFunc) -> Vec<Binding> {
let partitions = Self::partition_bindings(bindings, &func.window.partition_by);
let mut result: Vec<Option<Binding>> = vec![None; bindings.len()];
for (_key, mut partition) in partitions {
Self::sort_partition(&mut partition, &func.window.order_by);
let computed = Self::compute_for_partition(&partition, func);
for entry in computed {
if entry.index < result.len() {
result[entry.index] = Some(entry.binding);
}
}
}
result
.into_iter()
.enumerate()
.map(|(idx, binding)| binding.unwrap_or_else(|| bindings[idx].clone()))
.collect()
}
/// Groups `bindings` by the values of the `partition_by` variables.
///
/// Returns `(key, rows)` pairs in order of each key's first appearance in
/// the input; rows inside a partition keep their input order and remember
/// their original index. An unbound variable contributes `None` to the key.
/// With no partition variables the whole input forms one partition with an
/// empty key.
fn partition_bindings(
    bindings: &[Binding],
    partition_by: &[Var],
) -> Vec<(Vec<Option<Value>>, Vec<IndexedBinding>)> {
    if partition_by.is_empty() {
        let mut everything = Vec::with_capacity(bindings.len());
        for (index, binding) in bindings.iter().enumerate() {
            everything.push(IndexedBinding {
                index,
                binding: binding.clone(),
            });
        }
        return vec![(Vec::new(), everything)];
    }

    // Maps each partition key to its slot in `groups`, which preserves
    // first-seen order.
    let mut slot_of: HashMap<Vec<Option<Value>>, usize> = HashMap::new();
    let mut groups: Vec<(Vec<Option<Value>>, Vec<IndexedBinding>)> = Vec::new();

    for (index, binding) in bindings.iter().enumerate() {
        let key: Vec<Option<Value>> = partition_by
            .iter()
            .map(|var| binding.get(var).cloned())
            .collect();
        let known = slot_of.get(&key).copied();
        let slot = match known {
            Some(slot) => slot,
            None => {
                let slot = groups.len();
                slot_of.insert(key.clone(), slot);
                groups.push((key, Vec::new()));
                slot
            }
        };
        groups[slot].1.push(IndexedBinding {
            index,
            binding: binding.clone(),
        });
    }

    groups
}
/// Sorts a partition in place according to the ORDER BY keys.
///
/// Keys are compared in priority order; the first key that distinguishes two
/// rows decides. Rows whose sort variable is unbound are placed according to
/// `spec.nulls`, independently of the sort direction: `NullsOrder::First`
/// always puts them before bound values and `NullsOrder::Last` always after,
/// for both ascending and descending keys. The direction only affects the
/// comparison between two bound values.
///
/// Rows that compare equal on every key keep their original input order.
/// With no ORDER BY keys the partition is left untouched.
fn sort_partition(partition: &mut [IndexedBinding], order_by: &[WindowOrderBy]) {
    if order_by.is_empty() {
        return;
    }
    partition.sort_by(|a, b| {
        for spec in order_by {
            let lhs = a.binding.get(&spec.var);
            let rhs = b.binding.get(&spec.var);
            let cmp = match (lhs, rhs) {
                (None, None) => Ordering::Equal,
                // Null placement is absolute, so it must not be reversed for
                // descending keys.
                (None, Some(_)) => match spec.nulls {
                    NullsOrder::First => Ordering::Less,
                    NullsOrder::Last => Ordering::Greater,
                },
                (Some(_), None) => match spec.nulls {
                    NullsOrder::First => Ordering::Greater,
                    NullsOrder::Last => Ordering::Less,
                },
                (Some(x), Some(y)) => {
                    let natural = total_compare_values(x, y);
                    match spec.direction {
                        SortDirection::Asc => natural,
                        SortDirection::Desc => natural.reverse(),
                    }
                }
            };
            if cmp != Ordering::Equal {
                return cmp;
            }
        }
        // Tie-break on input position to keep the result deterministic.
        a.index.cmp(&b.index)
    });
}
/// Evaluates `func` for every row of an already sorted partition.
///
/// Each output row carries the input row's original index and its binding
/// merged with `func.result_var` bound to the computed value. When the merge
/// yields nothing, the row's binding is passed through unchanged.
fn compute_for_partition(
    partition: &[IndexedBinding],
    func: &WindowFunc,
) -> Vec<IndexedBinding> {
    let peer_groups = Self::compute_peer_groups(partition, &func.window.order_by);
    let row_count = partition.len();
    let mut computed = Vec::with_capacity(row_count);

    for (row_idx, row) in partition.iter().enumerate() {
        let value = Self::compute_value(partition, row_idx, &peer_groups, row_count, func);
        let addition = Binding::one(func.result_var.clone(), value);
        let binding = match row.binding.merge(&addition) {
            Some(merged) => merged,
            None => row.binding.clone(),
        };
        computed.push(IndexedBinding {
            index: row.index,
            binding,
        });
    }

    computed
}
/// Assigns a peer-group id to every row of a sorted partition.
///
/// Ids start at 0 and increase by one whenever a row differs from its
/// predecessor on any ORDER BY variable; two rows agree on a variable when it
/// is unbound in both or bound to values that `values_equal` accepts. Without
/// ORDER BY keys every row belongs to group 0.
fn compute_peer_groups(partition: &[IndexedBinding], order_by: &[WindowOrderBy]) -> Vec<usize> {
    if order_by.is_empty() {
        return vec![0; partition.len()];
    }

    let mut group_ids = Vec::with_capacity(partition.len());
    if partition.is_empty() {
        return group_ids;
    }
    group_ids.push(0);

    let mut group = 0;
    for pair in partition.windows(2) {
        let (earlier, later) = (&pair[0].binding, &pair[1].binding);
        let differs = order_by.iter().any(|spec| {
            match (earlier.get(&spec.var), later.get(&spec.var)) {
                (None, None) => false,
                (Some(x), Some(y)) => !values_equal(x, y),
                _ => true,
            }
        });
        if differs {
            group += 1;
        }
        group_ids.push(group);
    }
    group_ids
}
fn compute_value(
partition: &[IndexedBinding],
row_idx: usize,
peer_groups: &[usize],
partition_size: usize,
func: &WindowFunc,
) -> Value {
match &func.func_type {
WindowFuncType::RowNumber => Value::Integer((row_idx + 1) as i64),
WindowFuncType::Rank => {
let current_group = peer_groups[row_idx];
let first_in_group = peer_groups
.iter()
.position(|&g| g == current_group)
.unwrap();
Value::Integer((first_in_group + 1) as i64)
}
WindowFuncType::DenseRank => {
Value::Integer((peer_groups[row_idx] + 1) as i64)
}
WindowFuncType::Ntile(n) => {
let n = *n as usize;
if n == 0 || partition_size == 0 {
return Value::Null;
}
let bucket_size = partition_size / n;
let remainder = partition_size % n;
let mut row = 0;
let mut bucket = 1;
for i in 0..n {
let size = bucket_size + if i < remainder { 1 } else { 0 };
if row_idx < row + size {
bucket = i + 1;
break;
}
row += size;
}
Value::Integer(bucket as i64)
}
WindowFuncType::PercentRank => {
if partition_size <= 1 {
return Value::Float(0.0);
}
let current_group = peer_groups[row_idx];
let first_in_group = peer_groups
.iter()
.position(|&g| g == current_group)
.unwrap();
let rank = first_in_group as f64;
Value::Float(rank / (partition_size - 1) as f64)
}
WindowFuncType::CumeDist => {
let current_group = peer_groups[row_idx];
let count = peer_groups.iter().filter(|&&g| g <= current_group).count();
Value::Float(count as f64 / partition_size as f64)
}
WindowFuncType::FirstValue(var) => {
let (start, _) = Self::get_frame_bounds(
row_idx,
partition_size,
peer_groups,
&func.window.frame,
);
partition
.get(start)
.and_then(|b| b.binding.get(var))
.cloned()
.unwrap_or(Value::Null)
}
WindowFuncType::LastValue(var) => {
let (_, end) = Self::get_frame_bounds(
row_idx,
partition_size,
peer_groups,
&func.window.frame,
);
if end > 0 {
partition
.get(end - 1)
.and_then(|b| b.binding.get(var))
.cloned()
.unwrap_or(Value::Null)
} else {
Value::Null
}
}
WindowFuncType::NthValue(var, n) => {
let (start, end) = Self::get_frame_bounds(
row_idx,
partition_size,
peer_groups,
&func.window.frame,
);
let n = *n as usize;
if n == 0 {
return Value::Null;
}
let target_idx = start + n - 1;
if target_idx < end {
partition
.get(target_idx)
.and_then(|b| b.binding.get(var))
.cloned()
.unwrap_or(Value::Null)
} else {
Value::Null
}
}
WindowFuncType::Lag(var, offset, default) => {
let offset = *offset as usize;
if row_idx >= offset {
partition
.get(row_idx - offset)
.and_then(|b| b.binding.get(var))
.cloned()
.unwrap_or_else(|| default.clone().unwrap_or(Value::Null))
} else {
default.clone().unwrap_or(Value::Null)
}
}
WindowFuncType::Lead(var, offset, default) => {
let offset = *offset as usize;
let target = row_idx + offset;
if target < partition_size {
partition
.get(target)
.and_then(|b| b.binding.get(var))
.cloned()
.unwrap_or_else(|| default.clone().unwrap_or(Value::Null))
} else {
default.clone().unwrap_or(Value::Null)
}
}
WindowFuncType::Aggregate(agg_name, var) => {
let (start, end) = Self::get_frame_bounds(
row_idx,
partition_size,
peer_groups,
&func.window.frame,
);
if let Some(mut aggregator) = create_aggregator(agg_name) {
for i in start..end {
if let Some(binding) = partition.get(i) {
let value = binding.binding.get(var);
aggregator.accumulate(value);
}
}
aggregator.finalize()
} else {
Value::Null
}
}
}
}
/// Resolves the frame of the row at `row_idx` to a half-open index range
/// `(start, end)` into the sorted partition.
///
/// `peer_groups` holds one non-decreasing group id per partition row. The
/// result always satisfies `start <= end <= partition_size`; an empty frame is
/// reported as `start == end`.
///
/// * `Rows` — offsets count rows relative to the current row.
/// * `Groups` — offsets count peer groups; `CurrentRow` covers the whole peer
///   group of the current row.
/// * `Range` — `CurrentRow` covers the whole peer group; numeric offsets are
///   treated as row counts, exactly like `Rows` (value-distance semantics are
///   not implemented).
///
/// An end bound of `Preceding(n)` that reaches before the first row (or first
/// group) produces an empty frame. Negative offsets are treated as 0. The
/// frame's `exclude` setting is not consulted.
fn get_frame_bounds(
    row_idx: usize,
    partition_size: usize,
    peer_groups: &[usize],
    frame: &FrameSpec,
) -> (usize, usize) {
    // Negative offsets are invalid; clamp them so they cannot wrap around.
    fn clamp_offset(n: i64) -> usize {
        if n <= 0 {
            0
        } else {
            n as usize
        }
    }

    let current_group = || peer_groups[row_idx];
    // Index of the first row whose group id is at least `group`.
    let group_start = |group: usize| {
        peer_groups
            .iter()
            .position(|&g| g >= group)
            .unwrap_or(partition_size)
    };
    // Index just past the last row whose group id is at most `group`.
    let group_end = |group: usize| {
        peer_groups
            .iter()
            .position(|&g| g > group)
            .unwrap_or(partition_size)
    };

    let start = match (&frame.start, frame.frame_type) {
        (FrameBound::UnboundedPreceding, _) => 0,
        (FrameBound::UnboundedFollowing, _) => partition_size,
        (FrameBound::CurrentRow, FrameType::Rows) => row_idx,
        (FrameBound::CurrentRow, FrameType::Range)
        | (FrameBound::CurrentRow, FrameType::Groups) => group_start(current_group()),
        (FrameBound::Preceding(n), FrameType::Groups) => {
            group_start(current_group().saturating_sub(clamp_offset(*n)))
        }
        (FrameBound::Preceding(n), FrameType::Rows)
        | (FrameBound::Preceding(n), FrameType::Range) => {
            row_idx.saturating_sub(clamp_offset(*n))
        }
        (FrameBound::Following(n), FrameType::Groups) => {
            group_start(current_group().saturating_add(clamp_offset(*n)))
        }
        (FrameBound::Following(n), FrameType::Rows)
        | (FrameBound::Following(n), FrameType::Range) => {
            row_idx.saturating_add(clamp_offset(*n))
        }
    };

    // `end` is exclusive.
    let end = match (&frame.end, frame.frame_type) {
        (FrameBound::UnboundedFollowing, _) => partition_size,
        (FrameBound::UnboundedPreceding, _) => 0,
        (FrameBound::CurrentRow, FrameType::Rows) => row_idx.saturating_add(1),
        (FrameBound::CurrentRow, FrameType::Range)
        | (FrameBound::CurrentRow, FrameType::Groups) => group_end(current_group()),
        (FrameBound::Preceding(n), FrameType::Groups) => {
            let back = clamp_offset(*n);
            let group = current_group();
            if back > group {
                // The bound lies before the first group: nothing is included.
                0
            } else {
                group_end(group - back)
            }
        }
        (FrameBound::Preceding(n), FrameType::Rows)
        | (FrameBound::Preceding(n), FrameType::Range) => {
            // Subtract after adding one so that a bound before row 0 yields
            // 0 (empty) rather than 1.
            row_idx.saturating_add(1).saturating_sub(clamp_offset(*n))
        }
        (FrameBound::Following(n), FrameType::Groups) => {
            group_end(current_group().saturating_add(clamp_offset(*n)))
        }
        (FrameBound::Following(n), FrameType::Rows)
        | (FrameBound::Following(n), FrameType::Range) => {
            row_idx.saturating_add(clamp_offset(*n)).saturating_add(1)
        }
    };

    let start = start.min(partition_size);
    (start, end.min(partition_size).max(start))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one binding from `(variable name, value)` pairs by merging
    /// single-variable bindings. An empty slice yields `Binding::empty()`; a
    /// pair whose merge yields nothing is skipped.
    fn make_binding(pairs: &[(&str, Value)]) -> Binding {
        let mut entries = pairs.iter();
        let first = match entries.next() {
            Some(&(name, ref value)) => Binding::one(Var::new(name), value.clone()),
            None => return Binding::empty(),
        };
        entries.fold(first, |acc, &(name, ref value)| {
            let next = Binding::one(Var::new(name), value.clone());
            acc.merge(&next).unwrap_or(acc)
        })
    }

    /// Collects the integer values bound to `var`, in binding order. Bindings
    /// where `var` is unbound or not an integer are skipped.
    fn get_values(bindings: &[Binding], var: &str) -> Vec<i64> {
        let var = Var::new(var);
        let mut values = Vec::new();
        for binding in bindings {
            if let Some(Value::Integer(i)) = binding.get(&var) {
                values.push(*i);
            }
        }
        values
    }

    /// One single-variable binding per integer in `values`.
    fn int_rows(var: &str, values: impl Iterator<Item = i64>) -> Vec<Binding> {
        values
            .map(|i| make_binding(&[(var, Value::Integer(i))]))
            .collect()
    }

    /// A binding with a string `dept` and an integer stored under `field`.
    fn dept_row(dept: &str, field: &str, n: i64) -> Binding {
        make_binding(&[
            ("dept", Value::String(dept.to_string())),
            (field, Value::Integer(n)),
        ])
    }

    /// An unpartitioned window ordered ascending by `var`.
    fn ordered_by(var: &str) -> WindowDef {
        WindowDef::default().with_order_by(vec![WindowOrderBy::new(Var::new(var))])
    }

    /// An unpartitioned window ordered descending by `var`.
    fn ordered_by_desc(var: &str) -> WindowDef {
        WindowDef::default().with_order_by(vec![WindowOrderBy::new(Var::new(var)).desc()])
    }

    #[test]
    fn test_row_number() {
        let bindings = vec![
            dept_row("A", "salary", 100),
            dept_row("A", "salary", 200),
            dept_row("B", "salary", 150),
        ];
        let func = WindowFunc::new(
            WindowFuncType::RowNumber,
            Var::new("rn"),
            WindowDef::partition_by(vec![Var::new("dept")])
                .with_order_by(vec![WindowOrderBy::new(Var::new("salary"))]),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        // Numbering restarts in each department.
        assert_eq!(get_values(&result, "rn"), vec![1, 2, 1]);
    }

    #[test]
    fn test_rank_with_ties() {
        let bindings = int_rows("score", vec![100, 100, 90, 80].into_iter());
        let func = WindowFunc::new(
            WindowFuncType::Rank,
            Var::new("rank"),
            ordered_by_desc("score"),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        // The tie at 100 shares rank 1 and leaves a gap before rank 3.
        assert_eq!(get_values(&result, "rank"), vec![1, 1, 3, 4]);
    }

    #[test]
    fn test_dense_rank() {
        let bindings = int_rows("score", vec![100, 100, 90].into_iter());
        let func = WindowFunc::new(
            WindowFuncType::DenseRank,
            Var::new("drank"),
            ordered_by_desc("score"),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        assert_eq!(get_values(&result, "drank"), vec![1, 1, 2]);
    }

    #[test]
    fn test_ntile() {
        let bindings = int_rows("val", 1..=10);
        let func = WindowFunc::new(
            WindowFuncType::Ntile(4),
            Var::new("bucket"),
            ordered_by("val"),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        // 10 rows over 4 buckets: the first two buckets get the extra rows.
        assert_eq!(
            get_values(&result, "bucket"),
            vec![1, 1, 1, 2, 2, 2, 3, 3, 4, 4]
        );
    }

    #[test]
    fn test_lag_lead() {
        let bindings = int_rows("val", 1..=5);
        let lag_func = WindowFunc::new(
            WindowFuncType::Lag(Var::new("val"), 1, Some(Value::Integer(0))),
            Var::new("prev"),
            ordered_by("val"),
        );
        let lead_func = WindowFunc::new(
            WindowFuncType::Lead(Var::new("val"), 1, Some(Value::Integer(0))),
            Var::new("next"),
            ordered_by("val"),
        );
        let result = WindowExecutor::execute(bindings, &[lag_func, lead_func]);
        // The default (0) fills in where no neighbouring row exists.
        assert_eq!(get_values(&result, "prev"), vec![0, 1, 2, 3, 4]);
        assert_eq!(get_values(&result, "next"), vec![2, 3, 4, 5, 0]);
    }

    #[test]
    fn test_running_sum() {
        let bindings = int_rows("val", 1..=5);
        let func = WindowFunc::new(
            WindowFuncType::Aggregate("SUM".to_string(), Var::new("val")),
            Var::new("running_sum"),
            ordered_by("val").with_frame(FrameSpec::running()),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        assert_eq!(get_values(&result, "running_sum"), vec![1, 3, 6, 10, 15]);
    }

    #[test]
    fn test_first_last_value() {
        let bindings = int_rows("val", 1..=5);
        let first_func = WindowFunc::new(
            WindowFuncType::FirstValue(Var::new("val")),
            Var::new("first"),
            ordered_by("val").with_frame(FrameSpec::entire_partition()),
        );
        let last_func = WindowFunc::new(
            WindowFuncType::LastValue(Var::new("val")),
            Var::new("last"),
            ordered_by("val").with_frame(FrameSpec::entire_partition()),
        );
        let result = WindowExecutor::execute(bindings, &[first_func, last_func]);
        assert_eq!(get_values(&result, "first"), vec![1, 1, 1, 1, 1]);
        assert_eq!(get_values(&result, "last"), vec![5, 5, 5, 5, 5]);
    }

    #[test]
    fn test_partitioned_sum() {
        let bindings = vec![
            dept_row("A", "salary", 100),
            dept_row("A", "salary", 200),
            dept_row("B", "salary", 150),
            dept_row("B", "salary", 250),
        ];
        let func = WindowFunc::new(
            WindowFuncType::Aggregate("SUM".to_string(), Var::new("salary")),
            Var::new("dept_total"),
            WindowDef::partition_by(vec![Var::new("dept")])
                .with_frame(FrameSpec::entire_partition()),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        assert_eq!(get_values(&result, "dept_total"), vec![300, 300, 400, 400]);
    }

    #[test]
    fn test_window_preserves_input_order() {
        // Departments are interleaved so that partitioning has to regroup rows.
        let bindings = vec![
            dept_row("A", "seq", 1),
            dept_row("B", "seq", 1),
            dept_row("A", "seq", 2),
            dept_row("B", "seq", 2),
        ];
        let func = WindowFunc::new(
            WindowFuncType::RowNumber,
            Var::new("rn"),
            WindowDef::partition_by(vec![Var::new("dept")])
                .with_order_by(vec![WindowOrderBy::new(Var::new("seq"))]),
        );
        let result = WindowExecutor::execute(bindings, &[func]);

        let dept_var = Var::new("dept");
        let mut depts: Vec<String> = Vec::new();
        for binding in &result {
            if let Some(Value::String(s)) = binding.get(&dept_var) {
                depts.push(s.clone());
            }
        }
        assert_eq!(depts, vec!["A", "B", "A", "B"]);
        assert_eq!(get_values(&result, "rn"), vec![1, 1, 2, 2]);
    }

    #[test]
    fn test_percent_rank() {
        let bindings = int_rows("val", 1..=4);
        let func = WindowFunc::new(
            WindowFuncType::PercentRank,
            Var::new("prank"),
            ordered_by("val"),
        );
        let result = WindowExecutor::execute(bindings, &[func]);
        assert_eq!(result.len(), 4);

        let prank_var = Var::new("prank");
        for (i, binding) in result.iter().enumerate() {
            // A missing or non-float result must fail the test instead of
            // being silently skipped.
            let pr = match binding.get(&prank_var) {
                Some(Value::Float(pr)) => *pr,
                other => panic!("Row {}: expected a float percent rank, got {:?}", i, other),
            };
            let expected = i as f64 / 3.0;
            assert!(
                (pr - expected).abs() < 0.001,
                "Row {}: expected {}, got {}",
                i,
                expected,
                pr
            );
        }
    }
}