use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int64Array};
use arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{ColumnarValue, Volatility, create_udf};
use datafusion::prelude::SessionContext;
pub fn register_window_functions(ctx: &SessionContext) -> Result<(), DataFusionError> {
ctx.register_udf(make_tumble_start());
ctx.register_udf(make_tumble_end());
ctx.register_udf(make_hop_start());
ctx.register_udf(make_hop_end());
ctx.register_udf(make_session_start());
ctx.register_udf(make_session_end());
Ok(())
}
/// Builds the `tumble_start(ts, size)` scalar UDF (`Int64, Int64 -> Int64`).
///
/// For a non-zero `size` the result is `ts` minus its Euclidean remainder
/// modulo `size`, i.e. `ts` rounded down to a multiple of `size` (negative
/// timestamps round toward negative infinity). A `size` of zero returns `ts`
/// unchanged.
///
/// The arithmetic cannot panic: the remainder uses the wrapping variant (which
/// yields `0` for `i64::MIN % -1`) and the subtraction saturates at
/// `i64::MIN` instead of overflowing for timestamps near the lower bound.
///
/// The UDF body returns `DataFusionError::Internal` if fewer than two
/// arguments are supplied, or if `apply2` rejects the arguments.
fn make_tumble_start() -> datafusion::logical_expr::ScalarUDF {
    create_udf(
        "tumble_start",
        vec![DataType::Int64, DataType::Int64],
        DataType::Int64,
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| {
            let a0 = args
                .first()
                .ok_or_else(|| DataFusionError::Internal("tumble_start: missing arg 0".into()))?;
            let a1 = args
                .get(1)
                .ok_or_else(|| DataFusionError::Internal("tumble_start: missing arg 1".into()))?;
            apply2(a0, a1, |t, s| {
                if s == 0 {
                    // No window size to align to; avoids a division by zero.
                    t
                } else {
                    // The remainder is always >= 0, so only underflow is possible.
                    t.saturating_sub(t.wrapping_rem_euclid(s))
                }
            })
        }),
    )
}
/// Builds the `tumble_end(ts, size)` scalar UDF (`Int64, Int64 -> Int64`).
///
/// For a non-zero `size` the result is the aligned window start
/// (`ts` minus its Euclidean remainder modulo `size`) plus `size`. A `size`
/// of zero returns `ts` unchanged.
///
/// The arithmetic cannot panic: the sum is evaluated in `i128`, so an
/// intermediate `ts - remainder` outside the `i64` range does not spoil a
/// final value that fits, and results that genuinely do not fit are clamped
/// to `i64::MIN` / `i64::MAX`.
///
/// The UDF body returns `DataFusionError::Internal` if fewer than two
/// arguments are supplied, or if `apply2` rejects the arguments.
fn make_tumble_end() -> datafusion::logical_expr::ScalarUDF {
    create_udf(
        "tumble_end",
        vec![DataType::Int64, DataType::Int64],
        DataType::Int64,
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| {
            let a0 = args
                .first()
                .ok_or_else(|| DataFusionError::Internal("tumble_end: missing arg 0".into()))?;
            let a1 = args
                .get(1)
                .ok_or_else(|| DataFusionError::Internal("tumble_end: missing arg 1".into()))?;
            apply2(a0, a1, |t, s| {
                if s == 0 {
                    // No window size to align to; avoids a division by zero.
                    return t;
                }
                // Widen to i128 so no intermediate step can overflow.
                let end = i128::from(t) - i128::from(t.wrapping_rem_euclid(s)) + i128::from(s);
                // The cast is lossless because the value was just clamped to the i64 range.
                end.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
            })
        }),
    )
}
/// Builds the `hop_start(ts, slide, size)` scalar UDF
/// (`Int64, Int64, Int64 -> Int64`).
///
/// For a non-zero `slide` the result is `ts` minus its Euclidean remainder
/// modulo `slide`, i.e. `ts` rounded down to a multiple of `slide`. A `slide`
/// of zero returns `ts` unchanged. The `size` argument is accepted but does
/// not influence the result.
///
/// The arithmetic cannot panic: the remainder uses the wrapping variant (which
/// yields `0` for `i64::MIN % -1`) and the subtraction saturates at
/// `i64::MIN` instead of overflowing for timestamps near the lower bound.
///
/// The UDF body returns `DataFusionError::Internal` if fewer than three
/// arguments are supplied, or if `apply3` rejects the arguments.
fn make_hop_start() -> datafusion::logical_expr::ScalarUDF {
    create_udf(
        "hop_start",
        vec![DataType::Int64, DataType::Int64, DataType::Int64],
        DataType::Int64,
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| {
            let a0 = args
                .first()
                .ok_or_else(|| DataFusionError::Internal("hop_start: missing arg 0".into()))?;
            let a1 = args
                .get(1)
                .ok_or_else(|| DataFusionError::Internal("hop_start: missing arg 1".into()))?;
            let a2 = args
                .get(2)
                .ok_or_else(|| DataFusionError::Internal("hop_start: missing arg 2".into()))?;
            apply3(a0, a1, a2, |t, sl, _sz| {
                if sl == 0 {
                    // No slide to align to; avoids a division by zero.
                    t
                } else {
                    // The remainder is always >= 0, so only underflow is possible.
                    t.saturating_sub(t.wrapping_rem_euclid(sl))
                }
            })
        }),
    )
}
/// Builds the `hop_end(ts, slide, size)` scalar UDF
/// (`Int64, Int64, Int64 -> Int64`).
///
/// For a non-zero `slide` the result is the slide-aligned start
/// (`ts` minus its Euclidean remainder modulo `slide`) plus `size`. A `slide`
/// of zero returns `ts` unchanged (`size` is not added in that case).
///
/// The arithmetic cannot panic: the sum is evaluated in `i128`, so an
/// intermediate `ts - remainder` outside the `i64` range does not spoil a
/// final value that fits, and results that genuinely do not fit are clamped
/// to `i64::MIN` / `i64::MAX`.
///
/// The UDF body returns `DataFusionError::Internal` if fewer than three
/// arguments are supplied, or if `apply3` rejects the arguments.
fn make_hop_end() -> datafusion::logical_expr::ScalarUDF {
    create_udf(
        "hop_end",
        vec![DataType::Int64, DataType::Int64, DataType::Int64],
        DataType::Int64,
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| {
            let a0 = args
                .first()
                .ok_or_else(|| DataFusionError::Internal("hop_end: missing arg 0".into()))?;
            let a1 = args
                .get(1)
                .ok_or_else(|| DataFusionError::Internal("hop_end: missing arg 1".into()))?;
            let a2 = args
                .get(2)
                .ok_or_else(|| DataFusionError::Internal("hop_end: missing arg 2".into()))?;
            apply3(a0, a1, a2, |t, sl, sz| {
                if sl == 0 {
                    // No slide to align to; avoids a division by zero.
                    return t;
                }
                // Widen to i128 so no intermediate step can overflow.
                let end = i128::from(t) - i128::from(t.wrapping_rem_euclid(sl)) + i128::from(sz);
                // The cast is lossless because the value was just clamped to the i64 range.
                end.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
            })
        }),
    )
}
/// Builds the `session_start(ts, gap)` scalar UDF (`Int64, Int64 -> Int64`).
///
/// The result is `ts` itself; `gap` is only routed through `apply2`, so it
/// still takes part in that helper's NULL propagation and type checks.
///
/// The UDF body returns `DataFusionError::Internal` if fewer than two
/// arguments are supplied, or if `apply2` rejects the arguments.
fn make_session_start() -> datafusion::logical_expr::ScalarUDF {
    create_udf(
        "session_start",
        vec![DataType::Int64, DataType::Int64],
        DataType::Int64,
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| match args {
            [] => Err(DataFusionError::Internal(
                "session_start: missing arg 0".into(),
            )),
            [_] => Err(DataFusionError::Internal(
                "session_start: missing arg 1".into(),
            )),
            // Any arguments beyond the first two are ignored.
            [ts, gap, ..] => apply2(ts, gap, |start, _| start),
        }),
    )
}
/// Builds the `session_end(ts, gap)` scalar UDF (`Int64, Int64 -> Int64`).
///
/// The result is `ts + gap`. The addition saturates at `i64::MIN` /
/// `i64::MAX` instead of overflowing, so extreme inputs cannot panic (with
/// overflow checks enabled) or silently wrap around (without them).
///
/// The UDF body returns `DataFusionError::Internal` if fewer than two
/// arguments are supplied, or if `apply2` rejects the arguments.
fn make_session_end() -> datafusion::logical_expr::ScalarUDF {
    create_udf(
        "session_end",
        vec![DataType::Int64, DataType::Int64],
        DataType::Int64,
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| {
            let a0 = args
                .first()
                .ok_or_else(|| DataFusionError::Internal("session_end: missing arg 0".into()))?;
            let a1 = args
                .get(1)
                .ok_or_else(|| DataFusionError::Internal("session_end: missing arg 1".into()))?;
            apply2(a0, a1, |t, gap| t.saturating_add(gap))
        }),
    )
}
/// Returns argument `idx` of `args` as an owned `Int64Array`.
///
/// * An array argument must downcast to `Int64Array`; a clone of it is returned.
/// * An `Int64` scalar becomes a one-element array holding that value
///   (a null element when the scalar is `None`).
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `idx` is out of range, when the
/// array is not an `Int64Array`, or when the scalar is not `ScalarValue::Int64`.
fn cast_to_int64_array(args: &[ColumnarValue], idx: usize) -> Result<Int64Array, DataFusionError> {
    use datafusion::scalar::ScalarValue;

    let Some(value) = args.get(idx) else {
        return Err(DataFusionError::Internal(format!(
            "window function: missing argument {idx}"
        )));
    };

    match value {
        ColumnarValue::Scalar(ScalarValue::Int64(v)) => Ok(Int64Array::from(vec![*v])),
        ColumnarValue::Scalar(other) => Err(DataFusionError::Internal(format!(
            "window function argument {idx} expected Int64 scalar, got {other:?}"
        ))),
        ColumnarValue::Array(arr) => match arr.as_any().downcast_ref::<Int64Array>() {
            Some(typed) => Ok(typed.clone()),
            None => Err(DataFusionError::Internal(format!(
                "window function argument {idx} expected Int64, got {:?}",
                arr.data_type()
            ))),
        },
    }
}
/// Applies the binary function `f` to two `Int64` columnar values.
///
/// * Two `Int64` scalars produce an `Int64` scalar; it is `None` if either
///   input is `None`.
/// * Otherwise both inputs are converted with `cast_to_int64_array` and the
///   result is an array. An input of length 1 is repeated for every row of
///   the other input. A row is null when either input is null at that row.
/// * If either converted input is empty, an empty array is returned.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when an input is not `Int64`, or when
/// the two lengths differ and neither of them is 1.
fn apply2(
    lhs: &ColumnarValue,
    rhs: &ColumnarValue,
    f: impl Fn(i64, i64) -> i64,
) -> Result<ColumnarValue, DataFusionError> {
    use datafusion::scalar::ScalarValue;

    // Scalar fast path: no arrays are materialised.
    if let (
        ColumnarValue::Scalar(ScalarValue::Int64(left)),
        ColumnarValue::Scalar(ScalarValue::Int64(right)),
    ) = (lhs, rhs)
    {
        let folded = (*left).zip(*right).map(|(x, y)| f(x, y));
        return Ok(ColumnarValue::Scalar(ScalarValue::Int64(folded)));
    }

    let left = cast_to_int64_array(std::slice::from_ref(lhs), 0)?;
    let right = cast_to_int64_array(std::slice::from_ref(rhs), 0)?;
    let (left_len, right_len) = (left.len(), right.len());

    if left_len == 0 || right_len == 0 {
        let empty = Int64Array::from(Vec::<Option<i64>>::new());
        return Ok(ColumnarValue::Array(Arc::new(empty) as ArrayRef));
    }

    let broadcastable = left_len == 1 || right_len == 1 || left_len == right_len;
    if !broadcastable {
        return Err(DataFusionError::Internal(format!(
            "window function: incompatible array lengths {} and {}",
            left_len, right_len
        )));
    }

    // Reads `row` from `col`, treating a one-element column as a constant;
    // yields `None` for a null slot.
    let pick = |col: &Int64Array, row: usize| -> Option<i64> {
        let at = if col.len() == 1 { 0 } else { row };
        if col.is_null(at) {
            None
        } else {
            Some(col.value(at))
        }
    };

    let rows = left_len.max(right_len);
    let output: Int64Array = (0..rows)
        .map(|row| match (pick(&left, row), pick(&right, row)) {
            (Some(x), Some(y)) => Some(f(x, y)),
            _ => None,
        })
        .collect();
    Ok(ColumnarValue::Array(Arc::new(output) as ArrayRef))
}
/// Applies the ternary function `f` to three `Int64` columnar values.
///
/// * Three `Int64` scalars produce an `Int64` scalar; it is `None` if any
///   input is `None`.
/// * Otherwise all inputs are converted with `cast_to_int64_array` and the
///   result is an array as long as the longest input. An input of length 1 is
///   repeated for every row. A row is null when any input is null at that row.
/// * If any converted input is empty, an empty array is returned.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when an input is not `Int64`, or when
/// an input's length is neither 1 nor the longest input's length.
fn apply3(
    a: &ColumnarValue,
    b: &ColumnarValue,
    c: &ColumnarValue,
    f: impl Fn(i64, i64, i64) -> i64,
) -> Result<ColumnarValue, DataFusionError> {
    use datafusion::scalar::ScalarValue;

    // Scalar fast path: no arrays are materialised.
    if let (
        ColumnarValue::Scalar(ScalarValue::Int64(x)),
        ColumnarValue::Scalar(ScalarValue::Int64(y)),
        ColumnarValue::Scalar(ScalarValue::Int64(z)),
    ) = (a, b, c)
    {
        let folded = match (*x, *y, *z) {
            (Some(x), Some(y), Some(z)) => Some(f(x, y, z)),
            _ => None,
        };
        return Ok(ColumnarValue::Scalar(ScalarValue::Int64(folded)));
    }

    let first = cast_to_int64_array(std::slice::from_ref(a), 0)?;
    let second = cast_to_int64_array(std::slice::from_ref(b), 0)?;
    let third = cast_to_int64_array(std::slice::from_ref(c), 0)?;

    if [&first, &second, &third].iter().any(|col| col.is_empty()) {
        let empty = Int64Array::from(Vec::<Option<i64>>::new());
        return Ok(ColumnarValue::Array(Arc::new(empty) as ArrayRef));
    }

    let max_len = first.len().max(second.len()).max(third.len());
    for (name, len) in [("a", first.len()), ("b", second.len()), ("c", third.len())] {
        if len == 1 || len == max_len {
            continue;
        }
        return Err(DataFusionError::Internal(format!(
            "window function: argument '{name}' length {len} incompatible with max length {max_len}"
        )));
    }

    // Reads `row` from `col`, treating a one-element column as a constant;
    // yields `None` for a null slot.
    let pick = |col: &Int64Array, row: usize| -> Option<i64> {
        let at = if col.len() == 1 { 0 } else { row };
        if col.is_null(at) {
            None
        } else {
            Some(col.value(at))
        }
    };

    let output: Int64Array = (0..max_len)
        .map(|row| {
            match (pick(&first, row), pick(&second, row), pick(&third, row)) {
                (Some(x), Some(y), Some(z)) => Some(f(x, y, z)),
                _ => None,
            }
        })
        .collect();
    Ok(ColumnarValue::Array(Arc::new(output) as ArrayRef))
}
#[cfg(test)]
mod tests {
    use arrow::array::cast::AsArray;
    use arrow::datatypes::Int64Type;
    use super::*;

    /// Creates a fresh session with all window functions registered.
    fn make_ctx() -> SessionContext {
        let ctx = SessionContext::new();
        register_window_functions(&ctx).unwrap();
        ctx
    }

    /// Runs `sql` and returns row 0 of column 0 of the first result batch.
    ///
    /// Panics if the query fails, returns no batches, or the column is not
    /// an `Int64` array.
    async fn query_i64(ctx: &SessionContext, sql: &str) -> i64 {
        let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
        batches
            .first()
            .expect("empty result")
            .column(0)
            .as_primitive::<Int64Type>()
            .value(0)
    }

    #[tokio::test]
    async fn tumble_start_aligns_to_window() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT tumble_start(65000, 60000) AS ws").await;
        assert_eq!(val, 60000, "65s → window starting at 60s");
    }

    #[tokio::test]
    async fn tumble_start_rounds_negative_timestamps_down() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT tumble_start(-1, 60000) AS ws").await;
        assert_eq!(val, -60000, "-1 belongs to the window [-60000, 0)");
    }

    #[tokio::test]
    async fn tumble_end_is_start_plus_size() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT tumble_end(65000, 60000) AS we").await;
        assert_eq!(val, 120000, "window end = 60000 + 60000");
    }

    #[tokio::test]
    async fn zero_size_returns_timestamp_unchanged() {
        let ctx = make_ctx();
        let start = query_i64(&ctx, "SELECT tumble_start(65000, 0) AS ws").await;
        assert_eq!(start, 65000, "zero size must not divide by zero");
        let end = query_i64(&ctx, "SELECT tumble_end(65000, 0) AS we").await;
        assert_eq!(end, 65000, "zero size must not divide by zero");
    }

    #[tokio::test]
    async fn hop_start_aligns_to_slide() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT hop_start(65000, 30000, 60000) AS hs").await;
        assert_eq!(val, 60000, "65s / 30s slide → hop start at 60s");
    }

    #[tokio::test]
    async fn hop_end_is_start_plus_size() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT hop_end(65000, 30000, 60000) AS he").await;
        assert_eq!(val, 120000, "hop end = 60000 + 60000");
    }

    #[tokio::test]
    async fn session_start_is_the_timestamp() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT session_start(1000, 500) AS ss").await;
        assert_eq!(val, 1000, "session start = timestamp");
    }

    #[tokio::test]
    async fn session_end_is_timestamp_plus_gap() {
        let ctx = make_ctx();
        let val = query_i64(&ctx, "SELECT session_end(1000, 500) AS se").await;
        assert_eq!(val, 1500, "session end = 1000 + 500");
    }

    #[tokio::test]
    async fn window_functions_work_on_table_column() {
        // `make_ctx` already registers the UDFs; no second registration needed.
        let ctx = make_ctx();
        ctx.sql(
            "CREATE TABLE events (ts BIGINT, user_id VARCHAR) AS VALUES (65000, 'alice'), (130000, 'bob')"
        ).await.unwrap().collect().await.unwrap();
        let result = ctx
            .sql("SELECT tumble_start(ts, 60000), user_id FROM events ORDER BY ts")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        let starts = result
            .first()
            .expect("empty result")
            .column(0)
            .as_primitive::<Int64Type>();
        assert_eq!(starts.value(0), 60000);
        assert_eq!(starts.value(1), 120000);
    }
}