use std::collections::BTreeMap;
use std::sync::Arc;
use ahash::AHashMap;
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion::physical_expr::{create_physical_expr, PhysicalExpr};
use datafusion::prelude::SessionContext;
use datafusion_common::ScalarValue;
use laminar_sql::translator::{WindowOperatorConfig, WindowType};
use crate::aggregate_state::{
compile_having_filter, expr_to_sql, extract_clauses, find_aggregate, resolve_expr_type,
AggFuncSpec, CompiledProjection, EowcStateCheckpoint,
};
use crate::error::DbError;
/// Window shape used by the incremental EOWC (emit-on-window-close) path.
///
/// All durations are milliseconds since the epoch / of length. Translator
/// `Cumulate` windows are folded into `Tumbling` by `from_config`.
#[derive(Debug, Clone)]
pub(crate) enum EowcWindowType {
    /// Fixed, non-overlapping windows of `size_ms`.
    Tumbling { size_ms: i64 },
    /// Overlapping windows: each `size_ms` long, a new one every `slide_ms`.
    Hopping { size_ms: i64, slide_ms: i64 },
    /// Gap-based session windows. Not supported by EOWC window assignment
    /// (`assign_windows` panics on this variant); such queries are expected
    /// to be routed to a different state implementation.
    Session { gap_ms: i64 },
}
impl EowcWindowType {
    /// Maps a translator-level window config onto the EOWC window model.
    ///
    /// Millisecond durations that do not fit an `i64` are clamped to
    /// `i64::MAX`. Cumulate windows are treated like tumbling windows here;
    /// sliding windows default their slide to the window size when none is
    /// configured; a missing session gap defaults to 0.
    fn from_config(config: &WindowOperatorConfig) -> Self {
        // Saturating u128-millis -> i64 conversion shared by every arm.
        let to_ms = |millis: u128| i64::try_from(millis).unwrap_or(i64::MAX);
        match config.window_type {
            WindowType::Session => Self::Session {
                gap_ms: config.gap.map_or(0, |g| to_ms(g.as_millis())),
            },
            WindowType::Sliding => {
                let size_ms = to_ms(config.size.as_millis());
                Self::Hopping {
                    size_ms,
                    slide_ms: config.slide.map_or(size_ms, |s| to_ms(s.as_millis())),
                }
            }
            WindowType::Tumbling | WindowType::Cumulate => Self::Tumbling {
                size_ms: to_ms(config.size.as_millis()),
            },
        }
    }
    /// Window length in milliseconds. For session windows this returns the
    /// gap — the only duration that variant carries.
    fn size_ms(&self) -> i64 {
        match *self {
            Self::Tumbling { size_ms } | Self::Hopping { size_ms, .. } => size_ms,
            Self::Session { gap_ms } => gap_ms,
        }
    }
}
/// Returns every window start (epoch millis) that contains `ts_ms`.
///
/// Starts are floor-aligned with `div_euclid`, so pre-epoch (negative)
/// timestamps fall into the correct window rather than the one above it.
/// Degenerate sizes (`<= 0`) collapse to the single sentinel window `0`
/// instead of dividing by zero.
///
/// # Panics
/// Panics for session windows: those are gap-based/stateful and must be
/// routed to `CoreWindowState`, never to the EOWC path.
fn assign_windows(ts_ms: i64, window_type: &EowcWindowType) -> Vec<i64> {
    match window_type {
        EowcWindowType::Tumbling { size_ms } => {
            if *size_ms <= 0 {
                return vec![0];
            }
            vec![ts_ms.div_euclid(*size_ms) * size_ms]
        }
        EowcWindowType::Hopping { size_ms, slide_ms } => {
            if *slide_ms <= 0 || *size_ms <= 0 {
                return vec![0];
            }
            let mut windows = Vec::new();
            // Latest slide-aligned start at or before ts.
            let last_start = ts_ms.div_euclid(*slide_ms) * slide_ms;
            // A window [start, start + size) contains ts iff start > ts - size.
            // Computing the bound once with saturating arithmetic — instead of
            // `start + size_ms` per iteration — avoids i64 overflow, which
            // would panic in debug builds and wrap in release builds (silently
            // dropping windows) when size_ms was clamped to i64::MAX.
            let min_start_exclusive = ts_ms.saturating_sub(*size_ms);
            let mut start = last_start;
            while start > min_start_exclusive {
                windows.push(start);
                start = start.saturating_sub(*slide_ms);
            }
            windows
        }
        EowcWindowType::Session { .. } => {
            unreachable!(
                "session window reached EOWC assign_windows — \
                 this is a routing bug; session queries must use CoreWindowState"
            )
        }
    }
}
/// Incremental end-of-window-close (EOWC) aggregation state.
///
/// Input rows arrive pre-projected as `(group cols..., agg inputs...,
/// __eowc_ts)`, are bucketed into event-time windows, and folded into
/// per-group DataFusion accumulators. Results are emitted only when a
/// window closes under the watermark (`close_windows`).
pub(crate) struct IncrementalEowcState {
    // Tumbling/hopping geometry; session queries never reach this state.
    window_type: EowcWindowType,
    // window_start (ms; BTreeMap so keys iterate in closing order) ->
    // row-encoded group key -> one accumulator per entry in `agg_specs`.
    #[allow(clippy::type_complexity)]
    windows:
        BTreeMap<i64, AHashMap<arrow::row::OwnedRow, Vec<Box<dyn datafusion_expr::Accumulator>>>>,
    // Encodes/decodes group-by columns to/from compact row keys.
    row_converter: arrow::row::RowConverter,
    // One spec per aggregate function in the query, in output order.
    agg_specs: Vec<AggFuncSpec>,
    // Number of leading group-by columns in the pre-agg batch.
    num_group_cols: usize,
    // Data types of the group-by columns (needed for key encode/decode).
    group_types: Vec<DataType>,
    // SQL text of the pre-aggregation projection; also feeds the query
    // fingerprint used to validate checkpoints.
    pre_agg_sql: String,
    // Emitted schema: window_start, window_end, group cols, aggregates.
    output_schema: SchemaRef,
    // Column index of `__eowc_ts` in the pre-agg batch.
    time_col_index: usize,
    // Physically compiled projection (fast path); None when any expression
    // failed to compile, in which case `pre_agg_sql` is executed instead.
    compiled_projection: Option<CompiledProjection>,
    // Cached logical plan for `pre_agg_sql`; only built when the compiled
    // projection is unavailable.
    cached_pre_agg_plan: Option<datafusion_expr::LogicalPlan>,
    // Compiled HAVING predicate over the output schema, if it compiled.
    having_filter: Option<Arc<dyn PhysicalExpr>>,
    // HAVING rendered back to SQL; kept only when compilation failed.
    having_sql: Option<String>,
    // Cardinality guard: cap on distinct groups buffered per window.
    max_groups_per_window: usize,
    // Extra time a window stays open past its end before it may close.
    allowed_lateness_ms: i64,
}
impl IncrementalEowcState {
    #[allow(clippy::too_many_lines)]
    /// Attempts to build incremental EOWC state from an aggregation query.
    ///
    /// Plans `sql`, locates its aggregate node, and derives:
    /// - a pre-aggregation projection (group exprs, aggregate inputs,
    ///   FILTER columns, and the event-time column aliased `__eowc_ts`),
    ///   both as SQL text and — when every expression compiles — as
    ///   physical expressions (`CompiledProjection`);
    /// - one `AggFuncSpec` per aggregate function;
    /// - the output schema `(window_start, window_end, groups..., aggs...)`;
    /// - an optional compiled (or SQL-fallback) HAVING filter.
    ///
    /// Returns `Ok(None)` when the query shape is unsupported by this fast
    /// path: no aggregate node, no aggregate functions, a projection above
    /// the aggregate that changes arity or types, or an aggregate
    /// expression that is not a plain `AggregateFunction`.
    ///
    /// # Errors
    /// Returns `DbError::Pipeline` if the SQL fails to plan or the row
    /// converter cannot be built.
    pub async fn try_from_sql(
        ctx: &SessionContext,
        sql: &str,
        window_config: &WindowOperatorConfig,
    ) -> Result<Option<Self>, DbError> {
        let df = ctx
            .sql(sql)
            .await
            .map_err(|e| DbError::Pipeline(format!("plan error: {e}")))?;
        let plan = df.logical_plan();
        let top_schema = Arc::new(plan.schema().as_arrow().clone());
        let Some(agg_info) = find_aggregate(plan) else {
            return Ok(None);
        };
        let group_exprs = agg_info.group_exprs;
        let aggr_exprs = agg_info.aggr_exprs;
        let agg_schema = agg_info.schema;
        let input_schema = agg_info.input_schema;
        let having_predicate = agg_info.having_predicate;
        if aggr_exprs.is_empty() {
            return Ok(None);
        }
        // Physical compilation is only attempted for single-table queries;
        // `compile_ok` degrades to the SQL fallback on any failure below.
        let compile_source = crate::sql_analysis::single_source_table(sql);
        let state = ctx.state();
        let props = state.execution_props();
        let input_df_schema = &agg_info.input_df_schema;
        let mut compiled_exprs: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
        let mut proj_fields: Vec<Field> = Vec::new();
        let mut compile_ok = compile_source.is_some();
        // The projection above the aggregate must be pass-through (same
        // column count and types), otherwise this fast path cannot
        // reproduce the query's output shape.
        if top_schema.fields().len() != agg_schema.fields().len() {
            return Ok(None);
        }
        for (top_f, agg_f) in top_schema.fields().iter().zip(agg_schema.fields()) {
            if top_f.data_type() != agg_f.data_type() {
                return Ok(None);
            }
        }
        let num_group_cols = group_exprs.len();
        // Output names come from the top schema (post-projection aliases);
        // types come from the aggregate schema.
        let mut group_col_names = Vec::new();
        let mut group_types = Vec::new();
        for i in 0..num_group_cols {
            let top_field = top_schema.field(i);
            let agg_field = agg_schema.field(i);
            group_col_names.push(top_field.name().clone());
            group_types.push(agg_field.data_type().clone());
        }
        let mut agg_specs = Vec::new();
        let mut pre_agg_select_items: Vec<String> = Vec::new();
        // Group-by expressions: plain columns keep their name; computed
        // expressions get a synthetic `__group_{i}` alias.
        for (i, group_expr) in group_exprs.iter().enumerate() {
            if let datafusion_expr::Expr::Column(col) = group_expr {
                pre_agg_select_items.push(format!("\"{}\"", col.name));
            } else {
                let group_sql = expr_to_sql(group_expr);
                pre_agg_select_items.push(format!("{group_sql} AS \"__group_{i}\""));
            }
            if compile_ok {
                match create_physical_expr(group_expr, input_df_schema, props) {
                    Ok(phys) => {
                        let dt = phys
                            .data_type(input_df_schema.as_arrow())
                            .unwrap_or(DataType::Utf8);
                        let name = match group_expr {
                            datafusion_expr::Expr::Column(col) => col.name.clone(),
                            _ => format!("__group_{i}"),
                        };
                        proj_fields.push(Field::new(name, dt, true));
                        compiled_exprs.push(phys);
                    }
                    Err(_) => compile_ok = false,
                }
            }
        }
        // Aggregate inputs become synthetic `__agg_input_{idx}` columns.
        // `next_col_idx` tracks their position in the pre-agg batch, which
        // starts right after the group columns.
        let mut next_col_idx = num_group_cols;
        for (i, expr) in aggr_exprs.iter().enumerate() {
            let agg_schema_idx = num_group_cols + i;
            let agg_field = agg_schema.field(agg_schema_idx);
            let output_name = if agg_schema_idx < top_schema.fields().len() {
                top_schema.field(agg_schema_idx).name().clone()
            } else {
                agg_field.name().clone()
            };
            if let datafusion_expr::Expr::AggregateFunction(agg_func) = expr {
                let udf = Arc::clone(&agg_func.func);
                let is_distinct = agg_func.params.distinct;
                let mut input_col_indices = Vec::new();
                let mut input_types = Vec::new();
                if agg_func.params.args.is_empty() {
                    // Zero-arg aggregate (COUNT(*)): feed a constant TRUE
                    // column so the accumulator still counts rows.
                    let col_idx = next_col_idx;
                    next_col_idx += 1;
                    pre_agg_select_items.push(format!("TRUE AS \"__agg_input_{col_idx}\""));
                    input_col_indices.push(col_idx);
                    input_types.push(DataType::Boolean);
                    if compile_ok {
                        match create_physical_expr(
                            &datafusion_expr::lit(true),
                            input_df_schema,
                            props,
                        ) {
                            Ok(phys) => {
                                proj_fields.push(Field::new(
                                    format!("__agg_input_{col_idx}"),
                                    DataType::Boolean,
                                    true,
                                ));
                                compiled_exprs.push(phys);
                            }
                            Err(_) => compile_ok = false,
                        }
                    }
                } else {
                    for arg_expr in &agg_func.params.args {
                        let col_idx = next_col_idx;
                        next_col_idx += 1;
                        let expr_sql = expr_to_sql(arg_expr);
                        if let Some(filter_expr) = &agg_func.params.filter {
                            // FILTER clause: pre-apply it by NULLing out
                            // filtered rows (NULL inputs are ignored by the
                            // accumulators).
                            let filter_sql = expr_to_sql(filter_expr);
                            pre_agg_select_items.push(format!(
                                "CASE WHEN {filter_sql} THEN {expr_sql} ELSE NULL END AS \"__agg_input_{col_idx}\""
                            ));
                            if compile_ok {
                                let case_expr =
                                    datafusion_expr::Expr::Case(datafusion_expr::expr::Case {
                                        expr: None,
                                        when_then_expr: vec![(
                                            Box::new(filter_expr.as_ref().clone()),
                                            Box::new(arg_expr.clone()),
                                        )],
                                        else_expr: Some(Box::new(datafusion_expr::lit(
                                            ScalarValue::Null,
                                        ))),
                                    });
                                match create_physical_expr(&case_expr, input_df_schema, props) {
                                    Ok(phys) => {
                                        let dt = resolve_expr_type(
                                            arg_expr,
                                            &input_schema,
                                            agg_field.data_type(),
                                        );
                                        proj_fields.push(Field::new(
                                            format!("__agg_input_{col_idx}"),
                                            dt,
                                            true,
                                        ));
                                        compiled_exprs.push(phys);
                                    }
                                    Err(_) => compile_ok = false,
                                }
                            }
                        } else {
                            pre_agg_select_items
                                .push(format!("{expr_sql} AS \"__agg_input_{col_idx}\""));
                            if compile_ok {
                                match create_physical_expr(arg_expr, input_df_schema, props) {
                                    Ok(phys) => {
                                        let dt = resolve_expr_type(
                                            arg_expr,
                                            &input_schema,
                                            agg_field.data_type(),
                                        );
                                        proj_fields.push(Field::new(
                                            format!("__agg_input_{col_idx}"),
                                            dt,
                                            true,
                                        ));
                                        compiled_exprs.push(phys);
                                    }
                                    Err(_) => compile_ok = false,
                                }
                            }
                        }
                        input_col_indices.push(col_idx);
                        let dt = resolve_expr_type(arg_expr, &input_schema, agg_field.data_type());
                        input_types.push(dt);
                    }
                }
                // In addition to NULLing inputs, emit an explicit boolean
                // `__agg_filter_{idx}` column so the accumulator can skip
                // filtered rows (e.g. for COUNT semantics).
                let filter_col_index = if let Some(filter_expr) = &agg_func.params.filter {
                    let col_idx = next_col_idx;
                    next_col_idx += 1;
                    let filter_sql = expr_to_sql(filter_expr);
                    pre_agg_select_items.push(format!(
                        "CASE WHEN {filter_sql} THEN TRUE ELSE FALSE END AS \"__agg_filter_{col_idx}\""
                    ));
                    if compile_ok {
                        let case_expr = datafusion_expr::Expr::Case(datafusion_expr::expr::Case {
                            expr: None,
                            when_then_expr: vec![(
                                Box::new(filter_expr.as_ref().clone()),
                                Box::new(datafusion_expr::lit(true)),
                            )],
                            else_expr: Some(Box::new(datafusion_expr::lit(false))),
                        });
                        match create_physical_expr(&case_expr, input_df_schema, props) {
                            Ok(phys) => {
                                proj_fields.push(Field::new(
                                    format!("__agg_filter_{col_idx}"),
                                    DataType::Boolean,
                                    true,
                                ));
                                compiled_exprs.push(phys);
                            }
                            Err(_) => compile_ok = false,
                        }
                    }
                    Some(col_idx)
                } else {
                    None
                };
                let return_type = udf
                    .return_type(&input_types)
                    .unwrap_or_else(|_| agg_field.data_type().clone());
                agg_specs.push(AggFuncSpec {
                    udf,
                    input_types,
                    input_col_indices,
                    output_name,
                    return_type,
                    distinct: is_distinct,
                    is_count_star: agg_func.params.args.is_empty(),
                    filter_col_index,
                });
            } else {
                // Anything other than a plain aggregate function call (e.g.
                // an aliased/wrapped aggregate) is not supported here.
                return Ok(None);
            }
        }
        // The event-time column is appended last, aliased `__eowc_ts`.
        let time_col_index = next_col_idx;
        pre_agg_select_items.push(format!(
            "\"{}\" AS \"__eowc_ts\"",
            window_config.time_column
        ));
        if compile_ok {
            let time_expr = datafusion_expr::Expr::Column(
                datafusion_common::Column::new_unqualified(&window_config.time_column),
            );
            match create_physical_expr(&time_expr, input_df_schema, props) {
                Ok(phys) => {
                    let dt = phys
                        .data_type(input_df_schema.as_arrow())
                        .unwrap_or(DataType::Int64);
                    proj_fields.push(Field::new("__eowc_ts", dt, true));
                    compiled_exprs.push(phys);
                }
                Err(_) => compile_ok = false,
            }
        }
        // SQL fallback for the projection: reuse the original FROM/WHERE.
        let clauses = extract_clauses(sql);
        let pre_agg_sql = format!(
            "SELECT {} FROM {}{}",
            pre_agg_select_items.join(", "),
            clauses.from_clause,
            clauses.where_clause,
        );
        let compiled_projection = if compile_ok {
            let source_table = compile_source.unwrap();
            // WHERE becomes a physical filter on the compiled fast path.
            let filter = if let Some(where_pred) = &agg_info.where_predicate {
                if let Ok(phys) = create_physical_expr(where_pred, input_df_schema, props) {
                    Some(phys)
                } else {
                    compile_ok = false;
                    None
                }
            } else {
                None
            };
            if compile_ok {
                Some(CompiledProjection {
                    source_table,
                    exprs: compiled_exprs,
                    filter,
                    output_schema: Arc::new(Schema::new(proj_fields)),
                })
            } else {
                None
            }
        } else {
            None
        };
        // Output schema: window bounds, then groups, then aggregates.
        let mut output_fields: Vec<Field> = vec![
            Field::new("window_start", DataType::Int64, false),
            Field::new("window_end", DataType::Int64, false),
        ];
        for (name, dt) in group_col_names.iter().zip(group_types.iter()) {
            output_fields.push(Field::new(name, dt.clone(), true));
        }
        for spec in &agg_specs {
            output_fields.push(Field::new(
                &spec.output_name,
                spec.return_type.clone(),
                true,
            ));
        }
        let output_schema = Arc::new(Schema::new(output_fields));
        // Prefer a compiled HAVING filter; keep the SQL text only as a
        // fallback when compilation fails.
        let having_filter = compile_having_filter(ctx, having_predicate.as_ref(), &output_schema);
        let having_sql = if having_filter.is_none() {
            having_predicate.as_ref().map(expr_to_sql)
        } else {
            None
        };
        let window_type = EowcWindowType::from_config(window_config);
        // Only pre-plan the SQL fallback when the compiled projection is
        // unavailable; best-effort (a planning failure just skips the cache).
        let cached_pre_agg_plan = if compiled_projection.is_none() {
            match ctx.sql(&pre_agg_sql).await {
                Ok(df) => Some(df.logical_plan().clone()),
                Err(_) => None,
            }
        } else {
            None
        };
        let sort_fields: Vec<arrow::row::SortField> = group_types
            .iter()
            .map(|dt| arrow::row::SortField::new(dt.clone()))
            .collect();
        let row_converter = arrow::row::RowConverter::new(sort_fields)
            .map_err(|e| DbError::Pipeline(format!("row converter init: {e}")))?;
        Ok(Some(Self {
            window_type,
            windows: BTreeMap::new(),
            row_converter,
            agg_specs,
            num_group_cols,
            group_types,
            pre_agg_sql,
            output_schema,
            time_col_index,
            compiled_projection,
            cached_pre_agg_plan,
            having_filter,
            having_sql,
            max_groups_per_window: 1_000_000,
            allowed_lateness_ms: i64::try_from(window_config.allowed_lateness.as_millis())
                .unwrap_or(0),
        }))
    }
    #[allow(clippy::too_many_lines)]
    /// Folds a pre-aggregation batch into the open windows.
    ///
    /// The batch uses the layout produced by the pre-agg projection: group
    /// columns first (`num_group_cols` of them), then aggregate input
    /// columns, then `__eowc_ts` at `time_col_index`. Rows with a NULL
    /// timestamp are dropped; every other row is applied to all windows it
    /// falls into (more than one for hopping windows).
    ///
    /// # Errors
    /// Returns `DbError::Pipeline` on timestamp extraction, row-encoding,
    /// or accumulator update failure.
    pub fn update_batch(&mut self, batch: &RecordBatch) -> Result<(), DbError> {
        if batch.num_rows() == 0 {
            return Ok(());
        }
        let ts_array = extract_i64_timestamps(batch, self.time_col_index)?;
        let has_groups = self.num_group_cols > 0;
        // Row-encode the group-by columns once for the whole batch.
        let rows = if has_groups {
            let group_cols: Vec<ArrayRef> = (0..self.num_group_cols)
                .map(|i| Arc::clone(batch.column(i)))
                .collect();
            let rows = self
                .row_converter
                .convert_columns(&group_cols)
                .map_err(|e| DbError::Pipeline(format!("row conversion: {e}")))?;
            Some(rows)
        } else {
            None
        };
        if !has_groups {
            // Global aggregate: a single shared key per window.
            let empty_key = crate::aggregate_state::global_aggregate_key();
            // window_start -> indices of the rows that fall into it.
            let mut grouped: AHashMap<i64, Vec<u32>> = AHashMap::new();
            for (row_idx, &ts_ms) in ts_array.iter().enumerate() {
                if ts_ms == NULL_TIMESTAMP {
                    continue;
                }
                #[allow(clippy::cast_possible_truncation)]
                let idx = row_idx as u32;
                for ws in assign_windows(ts_ms, &self.window_type) {
                    grouped.entry(ws).or_default().push(idx);
                }
            }
            for (window_start, indices) in &grouped {
                // Two-phase insert: check membership in a scoped borrow so
                // the fallible accumulator creation (`?`) can run outside it.
                let needs_insert = {
                    let wg = self.windows.entry(*window_start).or_default();
                    !wg.contains_key(&empty_key)
                };
                if needs_insert {
                    let mut accs = Vec::with_capacity(self.agg_specs.len());
                    for spec in &self.agg_specs {
                        accs.push(spec.create_accumulator()?);
                    }
                    self.windows
                        .entry(*window_start)
                        .or_default()
                        .insert(empty_key.clone(), accs);
                }
                let Some(accs) = self
                    .windows
                    .get_mut(window_start)
                    .and_then(|g| g.get_mut(&empty_key))
                else {
                    continue;
                };
                crate::aggregate_state::IncrementalAggState::update_group_accumulators(
                    accs,
                    batch,
                    indices,
                    &self.agg_specs,
                    None,
                )?;
            }
            return Ok(());
        }
        let rows_ref = rows.as_ref().expect("rows set when has_groups");
        // (window_start, group key) -> indices of the rows for that pair.
        let mut grouped: AHashMap<(i64, arrow::row::OwnedRow), Vec<u32>> = AHashMap::new();
        for (row_idx, &ts_ms) in ts_array.iter().enumerate() {
            if ts_ms == NULL_TIMESTAMP {
                continue;
            }
            let row_key = rows_ref.row(row_idx).owned();
            #[allow(clippy::cast_possible_truncation)]
            let idx = row_idx as u32;
            for ws in assign_windows(ts_ms, &self.window_type) {
                grouped.entry((ws, row_key.clone())).or_default().push(idx);
            }
        }
        for ((window_start, row_key), indices) in &grouped {
            let needs_insert = {
                let window_groups = self.windows.entry(*window_start).or_default();
                if window_groups.contains_key(row_key) {
                    false
                } else if window_groups.len() >= self.max_groups_per_window {
                    // Cardinality guard: new groups are dropped once the
                    // per-window cap is hit (existing groups keep updating).
                    tracing::warn!(
                        max_groups = self.max_groups_per_window,
                        window_start,
                        "EOWC per-window group cardinality limit reached"
                    );
                    continue;
                } else {
                    true
                }
            };
            if needs_insert {
                let mut accs = Vec::with_capacity(self.agg_specs.len());
                for spec in &self.agg_specs {
                    accs.push(spec.create_accumulator()?);
                }
                self.windows
                    .entry(*window_start)
                    .or_default()
                    .insert(row_key.clone(), accs);
            }
            let Some(accs) = self
                .windows
                .get_mut(window_start)
                .and_then(|g| g.get_mut(row_key))
            else {
                continue;
            };
            crate::aggregate_state::IncrementalAggState::update_group_accumulators(
                accs,
                batch,
                indices,
                &self.agg_specs,
                None,
            )?;
        }
        Ok(())
    }
pub fn close_windows(&mut self, watermark_ms: i64) -> Result<Vec<RecordBatch>, DbError> {
let size_ms = self.window_type.size_ms();
if size_ms <= 0 {
return Ok(Vec::new());
}
let to_close: Vec<i64> = self
.windows
.keys()
.copied()
.take_while(|&ws| {
ws.saturating_add(size_ms)
.saturating_add(self.allowed_lateness_ms)
<= watermark_ms
})
.collect();
if to_close.is_empty() {
return Ok(Vec::new());
}
let mut result_batches = Vec::new();
for window_start in to_close {
let Some(groups) = self.windows.remove(&window_start) else {
continue;
};
if groups.is_empty() {
continue;
}
let window_end = window_start.saturating_add(size_ms);
let batch = self.emit_window(window_start, window_end, groups)?;
if let Some(b) = batch {
result_batches.push(b);
}
}
Ok(result_batches)
}
    /// Renders one closed window's groups into a `RecordBatch` with the
    /// output schema (window bounds, group columns, aggregate results) by
    /// delegating to the shared emitter. Returns `Ok(None)` when the
    /// emitter produces no batch.
    fn emit_window(
        &self,
        window_start: i64,
        window_end: i64,
        groups: AHashMap<arrow::row::OwnedRow, Vec<Box<dyn datafusion_expr::Accumulator>>>,
    ) -> Result<Option<RecordBatch>, DbError> {
        crate::aggregate_state::emit_window_batch(
            window_start,
            window_end,
            groups,
            &self.row_converter,
            self.num_group_cols,
            &self.agg_specs,
            &self.output_schema,
        )
    }
    /// HAVING predicate as SQL text; present only when it could not be
    /// compiled to a physical expression (see `having_filter`).
    pub fn having_sql(&self) -> Option<&str> {
        self.having_sql.as_deref()
    }
    /// Compiled HAVING predicate over the output schema, if compilation
    /// succeeded.
    pub fn having_filter(&self) -> Option<&Arc<dyn PhysicalExpr>> {
        self.having_filter.as_ref()
    }
    /// Physically compiled pre-aggregation projection (fast path), if any.
    pub fn compiled_projection(&self) -> Option<&CompiledProjection> {
        self.compiled_projection.as_ref()
    }
    /// Cached logical plan for the pre-agg SQL; built only when the
    /// compiled projection is unavailable.
    pub fn cached_pre_agg_plan(&self) -> Option<&datafusion_expr::LogicalPlan> {
        self.cached_pre_agg_plan.as_ref()
    }
    /// Number of currently open (unclosed) windows. Test-only helper.
    #[cfg(test)]
    pub fn open_window_count(&self) -> usize {
        self.windows.len()
    }
    /// Stable fingerprint of the pre-agg SQL and output schema; used to
    /// validate that a checkpoint belongs to this exact query.
    pub(crate) fn query_fingerprint(&self) -> u64 {
        crate::aggregate_state::query_fingerprint(&self.pre_agg_sql, &self.output_schema)
    }
pub(crate) fn estimated_size_bytes(&self) -> usize {
let mut total = 0;
for groups in self.windows.values() {
for (key, accs) in groups {
total += key.as_ref().len();
for acc in accs {
total += acc.size();
}
}
}
total
}
    /// Serializes all open windows into a JSON-friendly checkpoint.
    ///
    /// Takes `&mut self` only because `Accumulator::state` requires a
    /// mutable receiver; the window map contents are not modified.
    ///
    /// # Errors
    /// Returns `DbError::Pipeline` when an accumulator fails to export its
    /// state or a group key cannot be decoded back to scalars.
    pub(crate) fn checkpoint_windows(&mut self) -> Result<EowcStateCheckpoint, DbError> {
        use crate::aggregate_state::{scalar_to_json, GroupCheckpoint, WindowCheckpoint};
        // The fingerprint ties the checkpoint to this exact query/schema so
        // restoring into different state is rejected.
        let fingerprint = self.query_fingerprint();
        let mut windows = Vec::with_capacity(self.windows.len());
        for (&window_start, groups) in &mut self.windows {
            let mut group_checkpoints = Vec::with_capacity(groups.len());
            for (key, accs) in groups {
                // Decode the row-encoded group key back to scalars for JSON.
                let sv_key = crate::aggregate_state::row_to_scalar_key_with_types(
                    &self.row_converter,
                    key,
                    &self.group_types,
                )?;
                let key_json: Vec<serde_json::Value> = sv_key.iter().map(scalar_to_json).collect();
                let mut acc_states = Vec::with_capacity(accs.len());
                for acc in accs {
                    let state = acc
                        .state()
                        .map_err(|e| DbError::Pipeline(format!("accumulator state: {e}")))?;
                    acc_states.push(state.iter().map(scalar_to_json).collect());
                }
                group_checkpoints.push(GroupCheckpoint {
                    key: key_json,
                    acc_states,
                    // Placeholder: EOWC does not track per-group update
                    // times, and `restore_windows` does not consume this.
                    last_updated_ms: i64::MIN,
                });
            }
            windows.push(WindowCheckpoint {
                window_start,
                groups: group_checkpoints,
            });
        }
        Ok(EowcStateCheckpoint {
            fingerprint,
            windows,
        })
    }
    /// Rebuilds open-window state from a checkpoint, replacing any current
    /// state, and returns the number of restored groups.
    ///
    /// The checkpoint fingerprint must match this state's query
    /// fingerprint; a mismatch (different query or schema) is an error and
    /// leaves the current state untouched (the clear happens after the
    /// check).
    ///
    /// # Errors
    /// Returns `DbError::Pipeline` on fingerprint mismatch, key/state
    /// decoding failure, or accumulator merge failure.
    pub(crate) fn restore_windows(
        &mut self,
        checkpoint: &EowcStateCheckpoint,
    ) -> Result<usize, DbError> {
        use crate::aggregate_state::json_to_scalar;
        let current_fp = self.query_fingerprint();
        if checkpoint.fingerprint != current_fp {
            return Err(DbError::Pipeline(format!(
                "EOWC checkpoint fingerprint mismatch: saved={}, current={}",
                checkpoint.fingerprint, current_fp
            )));
        }
        self.windows.clear();
        let mut total_groups = 0usize;
        for wc in &checkpoint.windows {
            let mut groups = AHashMap::new();
            for gc in &wc.groups {
                // Re-encode the JSON key back into the row format.
                let sv_key: Result<Vec<ScalarValue>, _> =
                    gc.key.iter().map(json_to_scalar).collect();
                let sv_key = sv_key?;
                let row_key = crate::aggregate_state::scalar_key_to_owned_row(
                    &self.row_converter,
                    &sv_key,
                    &self.group_types,
                )?;
                let mut accs = Vec::with_capacity(self.agg_specs.len());
                for (i, spec) in self.agg_specs.iter().enumerate() {
                    // Fresh accumulator, then merge the saved state into it.
                    let mut acc = spec.create_accumulator()?;
                    // Tolerate checkpoints with fewer state entries than
                    // specs: the extra accumulators simply start empty.
                    if i < gc.acc_states.len() {
                        let state_scalars: Result<Vec<ScalarValue>, _> =
                            gc.acc_states[i].iter().map(json_to_scalar).collect();
                        let state_scalars = state_scalars?;
                        // Each state scalar becomes a one-row array; together
                        // they form the single-row state batch merge_batch
                        // expects.
                        let arrays: Vec<arrow::array::ArrayRef> = state_scalars
                            .iter()
                            .map(|sv| {
                                sv.to_array()
                                    .map_err(|e| DbError::Pipeline(format!("scalar to array: {e}")))
                            })
                            .collect::<Result<_, _>>()?;
                        acc.merge_batch(&arrays)
                            .map_err(|e| DbError::Pipeline(format!("accumulator merge: {e}")))?;
                    }
                    accs.push(acc);
                }
                groups.insert(row_key, accs);
                total_groups += 1;
            }
            self.windows.insert(wc.window_start, groups);
        }
        Ok(total_groups)
    }
}
pub(crate) const NULL_TIMESTAMP: i64 = i64::MIN;
pub(crate) fn extract_i64_timestamps(
batch: &RecordBatch,
col_index: usize,
) -> Result<Vec<i64>, DbError> {
use arrow::array::{Array, Int64Array};
use arrow::datatypes::TimeUnit;
let col = batch.column(col_index);
let mut result = Vec::with_capacity(batch.num_rows());
match col.data_type() {
DataType::Int64 => {
let arr = col
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| DbError::Pipeline("expected Int64Array".to_string()))?;
for i in 0..arr.len() {
result.push(if arr.is_null(i) {
NULL_TIMESTAMP
} else {
arr.value(i)
});
}
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
let arr = col
.as_any()
.downcast_ref::<arrow::array::TimestampMillisecondArray>()
.ok_or_else(|| {
DbError::Pipeline("expected TimestampMillisecondArray".to_string())
})?;
for i in 0..arr.len() {
result.push(if arr.is_null(i) {
NULL_TIMESTAMP
} else {
arr.value(i)
});
}
}
DataType::Timestamp(TimeUnit::Second, _) => {
let arr = col
.as_any()
.downcast_ref::<arrow::array::TimestampSecondArray>()
.ok_or_else(|| DbError::Pipeline("expected TimestampSecondArray".to_string()))?;
for i in 0..arr.len() {
if arr.is_null(i) {
result.push(NULL_TIMESTAMP);
continue;
}
let v = arr.value(i);
result.push(v.saturating_mul(1000));
}
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
let arr = col
.as_any()
.downcast_ref::<arrow::array::TimestampMicrosecondArray>()
.ok_or_else(|| {
DbError::Pipeline("expected TimestampMicrosecondArray".to_string())
})?;
for i in 0..arr.len() {
if arr.is_null(i) {
result.push(NULL_TIMESTAMP);
continue;
}
let v = arr.value(i);
result.push(v / 1000);
}
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
let arr = col
.as_any()
.downcast_ref::<arrow::array::TimestampNanosecondArray>()
.ok_or_else(|| {
DbError::Pipeline("expected TimestampNanosecondArray".to_string())
})?;
for i in 0..arr.len() {
if arr.is_null(i) {
result.push(NULL_TIMESTAMP);
continue;
}
let v = arr.value(i);
result.push(v / 1_000_000);
}
}
other => {
return Err(DbError::Pipeline(format!(
"unsupported timestamp type for EOWC: {other}"
)));
}
}
Ok(result)
}
/// Unit tests: window assignment math, timestamp extraction, and the
/// incremental update/close/checkpoint lifecycle driven through a
/// hand-built SUM-by-symbol fixture.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    #[test]
    fn test_tumbling_window_assignment_aligns_to_boundary() {
        let wt = EowcWindowType::Tumbling { size_ms: 1000 };
        assert_eq!(assign_windows(0, &wt), vec![0]);
        assert_eq!(assign_windows(500, &wt), vec![0]);
        assert_eq!(assign_windows(999, &wt), vec![0]);
        assert_eq!(assign_windows(1000, &wt), vec![1000]);
        assert_eq!(assign_windows(1500, &wt), vec![1000]);
        assert_eq!(assign_windows(2000, &wt), vec![2000]);
    }
    #[test]
    fn test_hopping_window_assignment_returns_multiple_windows() {
        let wt = EowcWindowType::Hopping {
            size_ms: 1000,
            slide_ms: 500,
        };
        // Window order is walk-back order; sort before comparing.
        let mut ws = assign_windows(750, &wt);
        ws.sort_unstable();
        assert_eq!(ws, vec![0, 500]);
        let mut ws = assign_windows(1250, &wt);
        ws.sort_unstable();
        assert_eq!(ws, vec![500, 1000]);
    }
    #[test]
    #[should_panic(expected = "session window reached EOWC assign_windows")]
    fn test_session_window_assignment_panics() {
        let wt = EowcWindowType::Session { gap_ms: 5000 };
        assign_windows(1234, &wt);
    }
    #[test]
    fn test_tumbling_window_zero_size_returns_zero() {
        let wt = EowcWindowType::Tumbling { size_ms: 0 };
        assert_eq!(assign_windows(500, &wt), vec![0]);
    }
    #[test]
    fn test_extract_i64_timestamps_from_int64() {
        let schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![100, 200, 300]))],
        )
        .unwrap();
        let ts = extract_i64_timestamps(&batch, 0).unwrap();
        assert_eq!(ts, vec![100, 200, 300]);
    }
    #[test]
    fn test_extract_i64_timestamps_from_timestamp_millis() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None),
            false,
        )]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(arrow::array::TimestampMillisecondArray::from(
                vec![100, 200, 300],
            ))],
        )
        .unwrap();
        let ts = extract_i64_timestamps(&batch, 0).unwrap();
        assert_eq!(ts, vec![100, 200, 300]);
    }
    /// Builds a batch in the pre-agg layout the state expects:
    /// (group col "symbol", agg input, __eowc_ts).
    fn make_pre_agg_batch(
        groups: Vec<&str>,
        values: Vec<i64>,
        timestamps: Vec<i64>,
    ) -> RecordBatch {
        use arrow::array::StringArray;
        assert_eq!(groups.len(), values.len());
        assert_eq!(groups.len(), timestamps.len());
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("__agg_input_1", DataType::Int64, false),
            Field::new("__eowc_ts", DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    groups.into_iter().map(String::from).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values)),
                Arc::new(Int64Array::from(timestamps)),
            ],
        )
        .unwrap()
    }
    /// Hand-assembled state computing SUM(__agg_input_1) grouped by
    /// "symbol", bypassing `try_from_sql`.
    fn make_eowc_state(window_type: EowcWindowType) -> IncrementalEowcState {
        use datafusion::execution::FunctionRegistry;
        let ctx = SessionContext::new();
        let udf = ctx.udaf("sum").expect("SUM should be registered");
        let agg_specs = vec![AggFuncSpec {
            udf,
            input_types: vec![DataType::Int64],
            input_col_indices: vec![1],
            output_name: "total".to_string(),
            return_type: DataType::Int64,
            distinct: false,
            is_count_star: false,
            filter_col_index: None,
        }];
        let output_schema = Arc::new(Schema::new(vec![
            Field::new("window_start", DataType::Int64, false),
            Field::new("window_end", DataType::Int64, false),
            Field::new("symbol", DataType::Utf8, true),
            Field::new("total", DataType::Int64, true),
        ]));
        IncrementalEowcState {
            window_type,
            windows: BTreeMap::new(),
            agg_specs,
            num_group_cols: 1,
            group_types: vec![DataType::Utf8],
            row_converter: arrow::row::RowConverter::new(vec![arrow::row::SortField::new(
                DataType::Utf8,
            )])
            .unwrap(),
            pre_agg_sql: String::new(),
            output_schema,
            time_col_index: 2,
            compiled_projection: None,
            cached_pre_agg_plan: None,
            having_filter: None,
            having_sql: None,
            max_groups_per_window: 1_000_000,
            allowed_lateness_ms: 0,
        }
    }
    #[test]
    fn test_incremental_eowc_tumbling_sum() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch1 = make_pre_agg_batch(vec!["AAPL", "AAPL"], vec![10, 20], vec![100, 500]);
        state.update_batch(&batch1).unwrap();
        let batch2 = make_pre_agg_batch(vec!["AAPL"], vec![30], vec![800]);
        state.update_batch(&batch2).unwrap();
        assert_eq!(state.open_window_count(), 1);
        let batches = state.close_windows(1000).unwrap();
        assert_eq!(batches.len(), 1);
        let result = &batches[0];
        assert_eq!(result.num_rows(), 1);
        let ws = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ws.value(0), 0);
        let we = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(we.value(0), 1000);
        let total = result
            .column(3)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(total.value(0), 60);
        assert_eq!(state.open_window_count(), 0);
    }
    #[test]
    fn test_incremental_eowc_multi_group() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch = make_pre_agg_batch(
            vec!["AAPL", "GOOG", "AAPL", "GOOG"],
            vec![10, 100, 20, 200],
            vec![100, 200, 300, 400],
        );
        state.update_batch(&batch).unwrap();
        let batches = state.close_windows(1000).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
    }
    #[test]
    fn test_incremental_eowc_multi_window() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch = make_pre_agg_batch(vec!["AAPL", "AAPL"], vec![10, 20], vec![500, 1500]);
        state.update_batch(&batch).unwrap();
        assert_eq!(state.open_window_count(), 2);
        // Watermark 1000 closes only the [0, 1000) window.
        let batches = state.close_windows(1000).unwrap();
        assert_eq!(batches.len(), 1);
        let ws = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ws.value(0), 0);
        assert_eq!(state.open_window_count(), 1);
        let batches = state.close_windows(2000).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(state.open_window_count(), 0);
    }
    #[test]
    fn test_close_windows_returns_empty_when_none_closed() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch = make_pre_agg_batch(vec!["AAPL"], vec![10], vec![500]);
        state.update_batch(&batch).unwrap();
        let batches = state.close_windows(500).unwrap();
        assert!(batches.is_empty());
        assert_eq!(state.open_window_count(), 1);
    }
    #[test]
    fn test_empty_batch_update_is_noop() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("__agg_input_1", DataType::Int64, false),
            Field::new("__eowc_ts", DataType::Int64, false),
        ]));
        let batch = RecordBatch::new_empty(schema);
        state.update_batch(&batch).unwrap();
        assert_eq!(state.open_window_count(), 0);
    }
    #[test]
    fn test_window_close_emits_correct_schema() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch = make_pre_agg_batch(vec!["AAPL"], vec![42], vec![100]);
        state.update_batch(&batch).unwrap();
        let batches = state.close_windows(1000).unwrap();
        let result = &batches[0];
        let schema = result.schema();
        assert_eq!(schema.field(0).name(), "window_start");
        assert_eq!(schema.field(1).name(), "window_end");
        assert_eq!(schema.field(2).name(), "symbol");
        assert_eq!(schema.field(3).name(), "total");
        assert_eq!(schema.fields().len(), 4);
    }
    #[test]
    fn test_eowc_checkpoint_roundtrip_tumbling() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch = make_pre_agg_batch(
            vec!["AAPL", "AAPL", "GOOG"],
            vec![10, 20, 100],
            vec![100, 200, 1500],
        );
        state.update_batch(&batch).unwrap();
        assert_eq!(state.open_window_count(), 2);
        let cp = state.checkpoint_windows().unwrap();
        assert_eq!(cp.windows.len(), 2);
        // Restore into a second, identically-configured state and check
        // that both windows emit the same sums as the original would.
        let mut state2 = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let restored = state2.restore_windows(&cp).unwrap();
        assert!(restored > 0, "Should have restored groups");
        assert_eq!(state2.open_window_count(), 2);
        let batches = state2.close_windows(1000).unwrap();
        assert_eq!(batches.len(), 1);
        let result = &batches[0];
        assert_eq!(result.num_rows(), 1);
        let total = result
            .column(3)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(total.value(0), 30, "SUM should be 10+20=30");
        let batches = state2.close_windows(2000).unwrap();
        assert_eq!(batches.len(), 1);
        let total2 = batches[0]
            .column(3)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(total2.value(0), 100, "SUM should be 100");
    }
    #[test]
    fn test_eowc_checkpoint_empty_windows() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let cp = state.checkpoint_windows().unwrap();
        assert!(cp.windows.is_empty());
    }
    #[test]
    fn test_eowc_checkpoint_fingerprint_mismatch() {
        let mut state = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let batch = make_pre_agg_batch(vec!["AAPL"], vec![10], vec![100]);
        state.update_batch(&batch).unwrap();
        let mut cp = state.checkpoint_windows().unwrap();
        // Corrupt the fingerprint to simulate a checkpoint from a
        // different query.
        cp.fingerprint = 12345;
        let mut state2 = make_eowc_state(EowcWindowType::Tumbling { size_ms: 1000 });
        let result = state2.restore_windows(&cp);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("fingerprint mismatch"));
    }
}