use crate::error::{QueryError, Result};
use crate::executor::filter::Value;
use crate::executor::scan::{ColumnData, RecordBatch};
use std::cmp::Ordering;
use std::collections::HashMap;
/// A window function that `evaluate_window` can compute over each partition.
///
/// The ranking variants (`RowNumber`, `Rank`, `DenseRank`) only look at row
/// positions and ordering keys. The remaining variants read a value from the
/// target column of some row in the same sorted partition.
#[derive(Debug, Clone, PartialEq)]
pub enum WindowFunction {
    /// 1-based position of the row inside its sorted partition.
    RowNumber,
    /// Rank with gaps: rows with equal ordering keys share the position of
    /// the first row of their group.
    Rank,
    /// Rank without gaps: each new group of ordering keys increments by one.
    DenseRank,
    /// Target-column value `offset` rows before the current row.
    Lag {
        /// Number of rows to look back.
        offset: usize,
        /// Value used when the looked-up row falls outside the partition;
        /// `None` yields `Value::Null`.
        default: Option<Value>,
    },
    /// Target-column value `offset` rows after the current row.
    Lead {
        /// Number of rows to look ahead.
        offset: usize,
        /// Value used when the looked-up row falls outside the partition;
        /// `None` yields `Value::Null`.
        default: Option<Value>,
    },
    /// Target-column value of the first row of the sorted partition.
    FirstValue,
    /// Target-column value of the last row of the sorted partition.
    LastValue,
    /// Target-column value of the `n`-th (1-based) row of the sorted partition.
    NthValue {
        /// 1-based row index; `0` or an index past the partition yields
        /// `Value::Null`.
        n: usize,
    },
}

impl WindowFunction {
    /// `LAG` looking one row back, with no default value.
    pub fn lag() -> Self {
        Self::lag_offset(1)
    }

    /// `LAG` looking `offset` rows back, with no default value.
    pub fn lag_offset(offset: usize) -> Self {
        Self::Lag {
            offset,
            default: None,
        }
    }

    /// `LAG` looking `offset` rows back, falling back to `default`.
    pub fn lag_offset_default(offset: usize, default: Value) -> Self {
        Self::Lag {
            offset,
            default: Some(default),
        }
    }

    /// `LEAD` looking one row ahead, with no default value.
    pub fn lead() -> Self {
        Self::lead_offset(1)
    }

    /// `LEAD` looking `offset` rows ahead, with no default value.
    pub fn lead_offset(offset: usize) -> Self {
        Self::Lead {
            offset,
            default: None,
        }
    }

    /// `LEAD` looking `offset` rows ahead, falling back to `default`.
    pub fn lead_offset_default(offset: usize, default: Value) -> Self {
        Self::Lead {
            offset,
            default: Some(default),
        }
    }

    /// `NTH_VALUE` selecting the `n`-th (1-based) row of the partition.
    pub fn nth_value(n: usize) -> Self {
        Self::NthValue { n }
    }

    /// Returns `true` when evaluating this function reads the target column.
    ///
    /// The ranking functions return `false`; every value-returning function
    /// returns `true`.
    pub fn reads_target(&self) -> bool {
        match self {
            Self::RowNumber | Self::Rank | Self::DenseRank => false,
            Self::Lag { .. }
            | Self::Lead { .. }
            | Self::FirstValue
            | Self::LastValue
            | Self::NthValue { .. } => true,
        }
    }
}
/// One `ORDER BY` term of a window specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OrderKey {
    /// Index of the column whose values are compared.
    pub column: usize,
    /// `true` keeps the natural comparison order; `false` reverses it.
    pub ascending: bool,
}

impl OrderKey {
    /// Ascending ordering on `column`.
    pub fn asc(column: usize) -> Self {
        OrderKey {
            ascending: true,
            column,
        }
    }

    /// Descending ordering on `column`.
    pub fn desc(column: usize) -> Self {
        OrderKey {
            ascending: false,
            column,
        }
    }
}
/// Partitioning and ordering of a window: the `PARTITION BY` column indices
/// and the `ORDER BY` terms applied inside each partition.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct WindowSpec {
    /// Column indices whose values define the partitions; empty means the
    /// whole input forms a single partition.
    pub partition_by: Vec<usize>,
    /// Ordering terms, most significant first; empty keeps input row order.
    pub order_by: Vec<OrderKey>,
}

impl WindowSpec {
    /// Builds a specification from explicit partition columns and order keys.
    pub fn new(partition_by: Vec<usize>, order_by: Vec<OrderKey>) -> Self {
        WindowSpec {
            partition_by,
            order_by,
        }
    }

    /// Builds an unpartitioned specification that only orders the rows.
    pub fn ordered(order_by: Vec<OrderKey>) -> Self {
        Self::new(Vec::new(), order_by)
    }
}
fn compare_values(left: &Value, right: &Value) -> Ordering {
match (left, right) {
(Value::Null, Value::Null) => return Ordering::Equal,
(Value::Null, _) => return Ordering::Greater,
(_, Value::Null) => return Ordering::Less,
_ => {}
}
match (left, right) {
(Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
(Value::String(a), Value::String(b)) => a.cmp(b),
(Value::Int32(a), Value::Int32(b)) => a.cmp(b),
(Value::Int64(a), Value::Int64(b)) => a.cmp(b),
(Value::Int32(a), Value::Int64(b)) => (*a as i64).cmp(b),
(Value::Int64(a), Value::Int32(b)) => a.cmp(&(*b as i64)),
(Value::Float32(a), Value::Float32(b)) => float_cmp(*a as f64, *b as f64),
(Value::Float64(a), Value::Float64(b)) => float_cmp(*a, *b),
(Value::Float32(a), Value::Float64(b)) => float_cmp(*a as f64, *b),
(Value::Float64(a), Value::Float32(b)) => float_cmp(*a, *b as f64),
(Value::Int32(a), Value::Float32(b)) => float_cmp(*a as f64, *b as f64),
(Value::Int32(a), Value::Float64(b)) => float_cmp(*a as f64, *b),
(Value::Int64(a), Value::Float32(b)) => float_cmp(*a as f64, *b as f64),
(Value::Int64(a), Value::Float64(b)) => float_cmp(*a as f64, *b),
(Value::Float32(a), Value::Int32(b)) => float_cmp(*a as f64, *b as f64),
(Value::Float32(a), Value::Int64(b)) => float_cmp(*a as f64, *b as f64),
(Value::Float64(a), Value::Int32(b)) => float_cmp(*a, *b as f64),
(Value::Float64(a), Value::Int64(b)) => float_cmp(*a, *b as f64),
_ => Ordering::Equal,
}
}
/// Total ordering for `f64` in which NaN sorts after every other value and
/// two NaNs compare equal.
fn float_cmp(a: f64, b: f64) -> Ordering {
    // `partial_cmp` only fails when at least one side is NaN; comparing the
    // `is_nan` flags (false < true) then places NaN last and ties NaN with NaN.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}
/// Returns `true` when both key tuples have the same length and every pair
/// of corresponding values compares equal under `compare_values`.
fn order_keys_equal(a: &[Value], b: &[Value]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        if compare_values(lhs, rhs) != Ordering::Equal {
            return false;
        }
    }
    true
}
/// Builds a string key identifying a tuple of partition values, so that rows
/// with equal partition values land in the same hash bucket.
///
/// Each value is encoded as a type tag, its payload and a trailing `|`:
/// * `Null` -> `N|`
/// * booleans -> `b:0|` / `b:1|`
/// * `Int32` and `Int64` share the tag `i:` and are rendered as `i64`
/// * `Float32` and `Float64` share the tag `f:` and are rendered as `f64`
/// * strings -> `s:<byte length>:<text>|`; the length prefix keeps strings
///   that contain `|` or `:` from colliding with other tuples
/// * anything else -> `x:<Debug output>|`
///
/// Signed zero is canonicalised: `-0.0` and `0.0` compare equal when ordering
/// rows, so they must also share a partition.
// The catch-all arm keeps this function compiling if `Value` gains variants;
// presumably the arms above cover every variant today, hence the lint allowance.
#[allow(unreachable_patterns)]
fn partition_fingerprint(values: &[Value]) -> String {
    /// Appends the encoding of one float, folding `-0.0` into `0.0`.
    fn push_float(key: &mut String, value: f64) {
        // `-0.0 == 0.0` is true, so this branch rewrites both zeros to `0.0`.
        let canonical = if value == 0.0 { 0.0 } else { value };
        key.push_str("f:");
        key.push_str(&format!("{:?}", canonical));
        key.push('|');
    }

    /// Appends the encoding of one integer widened to `i64`.
    fn push_int(key: &mut String, value: i64) {
        key.push_str("i:");
        key.push_str(&value.to_string());
        key.push('|');
    }

    let mut key = String::new();
    for value in values {
        match value {
            Value::Null => key.push_str("N|"),
            Value::Boolean(b) => {
                key.push_str("b:");
                key.push_str(if *b { "1" } else { "0" });
                key.push('|');
            }
            Value::Int32(i) => push_int(&mut key, *i as i64),
            Value::Int64(i) => push_int(&mut key, *i),
            Value::Float32(fl) => push_float(&mut key, *fl as f64),
            Value::Float64(fl) => push_float(&mut key, *fl),
            Value::String(s) => {
                key.push_str("s:");
                key.push_str(&s.len().to_string());
                key.push(':');
                key.push_str(s);
                key.push('|');
            }
            other => {
                key.push_str("x:");
                key.push_str(&format!("{:?}", other));
                key.push('|');
            }
        }
    }
    key
}
/// Evaluates `func` over `num_rows` rows and returns one value per input row.
///
/// Rows are grouped into partitions according to `spec.partition_by`, each
/// partition is sorted by `spec.order_by`, and the function is evaluated
/// inside every partition. The result at index `r` belongs to input row `r`,
/// regardless of how the rows were partitioned or sorted.
///
/// # Parameters
/// * `func` - the window function to evaluate.
/// * `spec` - partitioning and ordering of the window.
/// * `num_rows` - number of input rows; `0` yields an empty result.
/// * `target_column` - column read by the value-returning functions
///   (`Lag`, `Lead`, `FirstValue`, `LastValue`, `NthValue`); the ranking
///   functions never read it.
/// * `value_at` - accessor called as `value_at(row, column)`.
///
/// # Errors
/// This function currently always returns `Ok`.
pub fn evaluate_window(
    func: &WindowFunction,
    spec: &WindowSpec,
    num_rows: usize,
    target_column: usize,
    value_at: impl Fn(usize, usize) -> Value,
) -> Result<Vec<Value>> {
    if num_rows == 0 {
        return Ok(Vec::new());
    }

    let partitions = build_partitions(spec, num_rows, &value_at);
    let mut output = vec![Value::Null; num_rows];

    // A plain `offset as isize` wraps for offsets above `isize::MAX` (for
    // example `usize::MAX` becomes `-1`, turning LAG into LEAD). Any offset
    // of at least `num_rows` already points outside every partition, so
    // clamping to `num_rows` keeps results identical while keeping the value
    // small. `num_rows` itself fits in `isize` because a vector of that many
    // values was just allocated.
    let clamp_offset = |offset: usize| offset.min(num_rows) as isize;

    for partition in &partitions {
        let sorted = sort_partition(spec, partition, &value_at);
        match func {
            WindowFunction::RowNumber => {
                assign_row_number(&sorted, &mut output);
            }
            WindowFunction::Rank => {
                assign_rank(spec, &sorted, &value_at, &mut output, true);
            }
            WindowFunction::DenseRank => {
                assign_rank(spec, &sorted, &value_at, &mut output, false);
            }
            WindowFunction::Lag { offset, default } => {
                assign_lag_lead(
                    &sorted,
                    target_column,
                    &value_at,
                    default,
                    clamp_offset(*offset),
                    true,
                    &mut output,
                );
            }
            WindowFunction::Lead { offset, default } => {
                assign_lag_lead(
                    &sorted,
                    target_column,
                    &value_at,
                    default,
                    clamp_offset(*offset),
                    false,
                    &mut output,
                );
            }
            WindowFunction::FirstValue => {
                assign_nth_in_partition(&sorted, target_column, &value_at, Some(0), &mut output);
            }
            WindowFunction::LastValue => {
                // `checked_sub` yields `None` for an empty partition.
                let last = sorted.len().checked_sub(1);
                assign_nth_in_partition(&sorted, target_column, &value_at, last, &mut output);
            }
            WindowFunction::NthValue { n } => {
                // `n` is 1-based; `n == 0` maps to `None` and therefore to null.
                let index = n.checked_sub(1);
                assign_nth_in_partition(&sorted, target_column, &value_at, index, &mut output);
            }
        }
    }
    Ok(output)
}
/// Evaluates `func` over the rows of `batch`, validating column indices first.
///
/// # Errors
/// Returns an execution error when
/// * the largest column index referenced by `spec` (partition or order
///   columns) is not a column of `batch`, or
/// * `func` reads the target column and `target_column` is not a column of
///   `batch`.
pub fn evaluate_window_batch(
    func: &WindowFunction,
    spec: &WindowSpec,
    target_column: usize,
    batch: &RecordBatch,
) -> Result<Vec<Value>> {
    let column_count = batch.columns.len();

    // Only the largest referenced index needs checking: if it is in range,
    // every other referenced index is too.
    let highest_referenced = spec
        .partition_by
        .iter()
        .copied()
        .chain(spec.order_by.iter().map(|key| key.column))
        .max();
    match highest_referenced {
        Some(max_column) if max_column >= column_count => {
            return Err(QueryError::execution(format!(
                "Window column index {} out of bounds (batch has {} columns)",
                max_column, column_count
            )));
        }
        _ => {}
    }

    // Ranking functions never touch the target column, so it is only
    // validated for functions that read it.
    if func.reads_target() && target_column >= column_count {
        return Err(QueryError::execution(format!(
            "Window target column index {} out of bounds (batch has {} columns)",
            target_column, column_count
        )));
    }

    let cell = |row: usize, column: usize| column_value(&batch.columns[column], row);
    evaluate_window(func, spec, batch.num_rows, target_column, cell)
}
/// Reads the cell at `row_idx` of `column` as a `Value`.
///
/// Returns `Value::Null` when the row index is past the end of the column,
/// when the cell itself is empty, and for every cell of a binary column.
fn column_value(column: &ColumnData, row_idx: usize) -> Value {
    match column {
        ColumnData::Boolean(data) => match data.get(row_idx) {
            Some(Some(cell)) => Value::Boolean(*cell),
            _ => Value::Null,
        },
        ColumnData::Int32(data) => match data.get(row_idx) {
            Some(Some(cell)) => Value::Int32(*cell),
            _ => Value::Null,
        },
        ColumnData::Int64(data) => match data.get(row_idx) {
            Some(Some(cell)) => Value::Int64(*cell),
            _ => Value::Null,
        },
        ColumnData::Float32(data) => match data.get(row_idx) {
            Some(Some(cell)) => Value::Float32(*cell),
            _ => Value::Null,
        },
        ColumnData::Float64(data) => match data.get(row_idx) {
            Some(Some(cell)) => Value::Float64(*cell),
            _ => Value::Null,
        },
        ColumnData::String(data) => match data.get(row_idx) {
            Some(Some(cell)) => Value::String(cell.clone()),
            _ => Value::Null,
        },
        // Binary cells are always surfaced as null.
        ColumnData::Binary(_) => Value::Null,
    }
}
/// Groups the row indices `0..num_rows` into partitions.
///
/// Rows whose `spec.partition_by` values produce the same fingerprint share a
/// partition. Partitions are returned in order of first appearance and the
/// rows inside each partition keep their input order. Without partition
/// columns, all rows form a single partition.
fn build_partitions(
    spec: &WindowSpec,
    num_rows: usize,
    value_at: &impl Fn(usize, usize) -> Value,
) -> Vec<Vec<usize>> {
    if spec.partition_by.is_empty() {
        return vec![(0..num_rows).collect()];
    }

    // Maps each fingerprint to the position of its partition in `partitions`,
    // which keeps the partitions in first-seen order.
    let mut slot_of: HashMap<String, usize> = HashMap::new();
    let mut partitions: Vec<Vec<usize>> = Vec::new();

    for row in 0..num_rows {
        let key_values: Vec<Value> = spec
            .partition_by
            .iter()
            .map(|&column| value_at(row, column))
            .collect();
        let slot = *slot_of
            .entry(partition_fingerprint(&key_values))
            .or_insert_with(|| {
                partitions.push(Vec::new());
                partitions.len() - 1
            });
        partitions[slot].push(row);
    }

    partitions
}
/// Returns the rows of `partition` sorted by `spec.order_by`.
///
/// Keys are compared in order, the first non-equal comparison decides, and
/// descending keys reverse their comparison. The sort is stable, so rows
/// that tie on every key keep their relative input order. Without order
/// keys the partition is returned unchanged.
fn sort_partition(
    spec: &WindowSpec,
    partition: &[usize],
    value_at: &impl Fn(usize, usize) -> Value,
) -> Vec<usize> {
    let mut rows = partition.to_vec();
    if !spec.order_by.is_empty() {
        rows.sort_by(|&lhs, &rhs| {
            spec.order_by
                .iter()
                .map(|key| {
                    let left = value_at(lhs, key.column);
                    let right = value_at(rhs, key.column);
                    let ordering = compare_values(&left, &right);
                    if key.ascending {
                        ordering
                    } else {
                        ordering.reverse()
                    }
                })
                // Lazily stop at the first key that tells the rows apart.
                .find(|ordering| *ordering != Ordering::Equal)
                .unwrap_or(Ordering::Equal)
        });
    }
    rows
}
/// Collects the values of `row` for every ordering column of `spec`, in the
/// order the keys are declared.
fn order_key_values(
    spec: &WindowSpec,
    row: usize,
    value_at: &impl Fn(usize, usize) -> Value,
) -> Vec<Value> {
    let mut values = Vec::with_capacity(spec.order_by.len());
    for key in &spec.order_by {
        values.push(value_at(row, key.column));
    }
    values
}
/// Writes the 1-based position of every row of the sorted partition into
/// `output`, indexed by the row's original index.
fn assign_row_number(sorted: &[usize], output: &mut [Value]) {
    for (&row, number) in sorted.iter().zip(1i64..) {
        output[row] = Value::Int64(number);
    }
}
/// Writes the rank of every row of the sorted partition into `output`.
///
/// Rows whose ordering keys compare equal to the keys of the current group
/// share a rank. With `competition == true` a new group takes its 1-based
/// position in the partition (ranks have gaps); otherwise it takes the
/// previous rank plus one (dense ranks).
fn assign_rank(
    spec: &WindowSpec,
    sorted: &[usize],
    value_at: &impl Fn(usize, usize) -> Value,
    output: &mut [Value],
    competition: bool,
) {
    let mut rank: i64 = 0;
    // Ordering keys of the first row of the group currently being ranked.
    let mut group_key: Option<Vec<Value>> = None;

    for (position, &row) in sorted.iter().enumerate() {
        let key = order_key_values(spec, row, value_at);
        let same_group = group_key
            .as_ref()
            .map_or(false, |previous| order_keys_equal(previous, &key));

        if !same_group {
            rank = if competition {
                (position + 1) as i64
            } else {
                rank + 1
            };
            group_key = Some(key);
        }
        output[row] = Value::Int64(rank);
    }
}
/// Writes the `LAG`/`LEAD` result of every row of the sorted partition into
/// `output`, indexed by the row's original index.
///
/// For the row at sorted position `p`, the value is read from the target
/// column of the row at position `p - offset` (`backward == true`) or
/// `p + offset` (`backward == false`). When that position lies outside the
/// partition, `default` is used, or `Value::Null` when no default is given.
///
/// All index arithmetic is checked, so extreme offsets such as `isize::MAX`
/// or `isize::MIN` select the default instead of overflowing.
fn assign_lag_lead(
    sorted: &[usize],
    target_column: usize,
    value_at: &impl Fn(usize, usize) -> Value,
    default: &Option<Value>,
    offset: isize,
    backward: bool,
    output: &mut [Value],
) {
    // `None` means the signed offset is not representable (negating
    // `isize::MIN`); such an offset is out of range for every row.
    let signed_offset = if backward {
        offset.checked_neg()
    } else {
        Some(offset)
    };

    for (position, &row) in sorted.iter().enumerate() {
        // A slice never holds more than `isize::MAX` elements, so the cast of
        // `position` is lossless.
        let source_row = signed_offset
            .and_then(|delta| (position as isize).checked_add(delta))
            .filter(|&index| index >= 0)
            .and_then(|index| sorted.get(index as usize));

        output[row] = match source_row {
            Some(&source) => value_at(source, target_column),
            None => default.clone().unwrap_or(Value::Null),
        };
    }
}
/// Writes one shared value into `output` for every row of the sorted
/// partition: the target-column value of the row at sorted position `index`.
///
/// The shared value is `Value::Null` when `index` is `None` or lies past the
/// end of the partition.
fn assign_nth_in_partition(
    sorted: &[usize],
    target_column: usize,
    value_at: &impl Fn(usize, usize) -> Value,
    index: Option<usize>,
    output: &mut [Value],
) {
    let picked = index
        .and_then(|position| sorted.get(position))
        .map_or(Value::Null, |&source_row| value_at(source_row, target_column));

    sorted
        .iter()
        .for_each(|&row| output[row] = picked.clone());
}
// Unit tests for window evaluation and value comparison.
#[cfg(test)]
mod tests {
use super::*;
// Builds an in-memory, column-major table for tests: `columns[c][r]` is the
// value of column `c` at row `r`. Returns the row count (taken from the first
// column, or 0 when there are no columns) and an accessor with the
// `value_at(row, column)` shape that `evaluate_window` expects.
fn table(columns: Vec<Vec<Value>>) -> (usize, impl Fn(usize, usize) -> Value) {
let num_rows = columns.first().map(|c| c.len()).unwrap_or(0);
(num_rows, move |row: usize, col: usize| {
columns[col][row].clone()
})
}
// ROW_NUMBER over a single ascending key: results are reported per input row,
// so the row holding the smallest value (input row 1) gets number 1.
#[test]
fn unit_row_number_basic() -> Result<()> {
let (rows, value_at) = table(vec![vec![
Value::Int64(30),
Value::Int64(10),
Value::Int64(20),
]]);
let spec = WindowSpec::ordered(vec![OrderKey::asc(0)]);
let out = evaluate_window(&WindowFunction::RowNumber, &spec, rows, 0, value_at)?;
assert_eq!(out[1], Value::Int64(1));
assert_eq!(out[2], Value::Int64(2));
assert_eq!(out[0], Value::Int64(3));
Ok(())
}
// Nulls compare greater than any non-null value and equal to each other.
#[test]
fn unit_compare_nulls_last() {
assert_eq!(
compare_values(&Value::Int64(1), &Value::Null),
Ordering::Less
);
assert_eq!(
compare_values(&Value::Null, &Value::Int64(1)),
Ordering::Greater
);
assert_eq!(compare_values(&Value::Null, &Value::Null), Ordering::Equal);
}
// Integers and floats compare by numeric value across types.
#[test]
fn unit_compare_mixed_numeric() {
assert_eq!(
compare_values(&Value::Int32(5), &Value::Float64(5.5)),
Ordering::Less
);
assert_eq!(
compare_values(&Value::Float64(2.0), &Value::Int64(2)),
Ordering::Equal
);
}
}