use std::hash::{Hash, Hasher};
use std::sync::Arc;
use ahash::AHashMap;
use rustc_hash::FxHashMap;
use arrow::array::ArrayRef;
use arrow::compute;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion::physical_expr::{create_physical_expr, PhysicalExpr};
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{AggregateUDF, LogicalPlan};
use serde_json::json;
use crate::error::DbError;
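/// Decodes a row-format group key back into a `Vec<ScalarValue>`, casting
/// each component to its expected group type when the decoded type differs.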
pub(crate) fn row_to_scalar_key_with_types(
converter: &arrow::row::RowConverter,
row_key: &arrow::row::OwnedRow,
group_types: &[DataType],
) -> Result<Vec<ScalarValue>, DbError> {
let row_as_cols = converter
.convert_rows(std::iter::once(row_key.row()))
.map_err(|e| DbError::Pipeline(format!("row→key: {e}")))?;
let mut sv_key = Vec::with_capacity(group_types.len());
for (col_idx, arr) in row_as_cols.iter().enumerate() {
let sv = ScalarValue::try_from_array(arr, 0)
.map_err(|e| DbError::Pipeline(format!("group key decode: {e}")))?;
if sv.data_type() == group_types[col_idx] {
sv_key.push(sv);
} else {
sv_key.push(sv.cast_to(&group_types[col_idx]).unwrap_or(sv));
}
}
Ok(sv_key)
}
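/// Builds the result batch for a closed window: two leading `Int64` columns
/// for the window bounds, then the group-key columns, then one column per
/// aggregate. Returns `Ok(None)` when there are no groups to emit.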
pub(crate) fn emit_window_batch(
window_start: i64,
window_end: i64,
groups: AHashMap<Vec<ScalarValue>, Vec<Box<dyn datafusion_expr::Accumulator>>>,
group_types: &[DataType],
agg_specs: &[AggFuncSpec],
output_schema: &SchemaRef,
) -> Result<Option<RecordBatch>, DbError> {
let num_rows = groups.len();
if num_rows == 0 {
return Ok(None);
}
let num_group_cols = group_types.len();
let mut group_scalars: Vec<Vec<ScalarValue>> = (0..num_group_cols)
.map(|_| Vec::with_capacity(num_rows))
.collect();
let mut agg_scalars: Vec<Vec<ScalarValue>> = (0..agg_specs.len())
.map(|_| Vec::with_capacity(num_rows))
.collect();
for (key, mut accs) in groups {
for (i, sv) in key.into_iter().enumerate() {
group_scalars[i].push(sv);
}
for (i, acc) in accs.iter_mut().enumerate() {
let sv = acc
.evaluate()
.map_err(|e| DbError::Pipeline(format!("accumulator evaluate: {e}")))?;
agg_scalars[i].push(sv);
}
}
let win_start_array: ArrayRef =
Arc::new(arrow::array::Int64Array::from(vec![window_start; num_rows]));
let win_end_array: ArrayRef =
Arc::new(arrow::array::Int64Array::from(vec![window_end; num_rows]));
let mut group_arrays: Vec<ArrayRef> = Vec::with_capacity(num_group_cols);
for (col_idx, scalars) in group_scalars.into_iter().enumerate() {
let dt = &group_types[col_idx];
let array = ScalarValue::iter_to_array(scalars)
.map_err(|e| DbError::Pipeline(format!("group key array: {e}")))?;
if array.data_type() == dt {
group_arrays.push(array);
} else {
let casted = arrow::compute::cast(&array, dt).unwrap_or(array);
group_arrays.push(casted);
}
}
let mut agg_arrays: Vec<ArrayRef> = Vec::with_capacity(agg_specs.len());
for (agg_idx, scalars) in agg_scalars.into_iter().enumerate() {
let spec = &agg_specs[agg_idx];
let array = ScalarValue::iter_to_array(scalars)
.map_err(|e| DbError::Pipeline(format!("agg result array: {e}")))?;
if array.data_type() == &spec.return_type {
agg_arrays.push(array);
} else {
let casted = arrow::compute::cast(&array, &spec.return_type).unwrap_or(array);
agg_arrays.push(casted);
}
}
let mut all_arrays = vec![win_start_array, win_end_array];
all_arrays.extend(group_arrays);
all_arrays.extend(agg_arrays);
let batch = RecordBatch::try_new(Arc::clone(output_schema), all_arrays)
.map_err(|e| DbError::Pipeline(format!("result batch build: {e}")))?;
Ok(Some(batch))
}
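/// Serializes a `ScalarValue` into a tagged JSON object for checkpointing.
/// Types without a dedicated tag fall back to a lossy `"STR"` string
/// rendering, which `json_to_scalar` does not accept on restore.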
pub(crate) fn scalar_to_json(sv: &ScalarValue) -> serde_json::Value {
match sv {
ScalarValue::Null => json!({"t": "N"}),
ScalarValue::Boolean(None) => json!({"t": "B", "v": null}),
ScalarValue::Boolean(Some(b)) => json!({"t": "B", "v": b}),
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None) => json!({"t": "I64", "v": null}),
ScalarValue::Int8(Some(n)) => json!({"t": "I64", "v": i64::from(*n)}),
ScalarValue::Int16(Some(n)) => json!({"t": "I64", "v": i64::from(*n)}),
ScalarValue::Int32(Some(n)) => json!({"t": "I64", "v": i64::from(*n)}),
ScalarValue::Int64(Some(n)) => json!({"t": "I64", "v": n}),
ScalarValue::UInt8(None)
| ScalarValue::UInt16(None)
| ScalarValue::UInt32(None)
| ScalarValue::UInt64(None) => json!({"t": "U64", "v": null}),
ScalarValue::UInt8(Some(n)) => json!({"t": "U64", "v": u64::from(*n)}),
ScalarValue::UInt16(Some(n)) => json!({"t": "U64", "v": u64::from(*n)}),
ScalarValue::UInt32(Some(n)) => json!({"t": "U64", "v": u64::from(*n)}),
ScalarValue::UInt64(Some(n)) => json!({"t": "U64", "v": n}),
ScalarValue::Float32(None) | ScalarValue::Float64(None) => {
json!({"t": "F64", "v": null})
}
ScalarValue::Float32(Some(f)) => json!({"t": "F64", "v": f64::from(*f)}),
ScalarValue::Float64(Some(f)) => json!({"t": "F64", "v": f}),
ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None) => {
json!({"t": "S", "v": null})
}
ScalarValue::Utf8(Some(s))
| ScalarValue::LargeUtf8(Some(s))
| ScalarValue::Utf8View(Some(s)) => json!({"t": "S", "v": s}),
ScalarValue::List(arr) => {
use arrow::array::Array;
let list_arr: Option<&arrow::array::ListArray> = arr.as_any().downcast_ref();
match list_arr {
Some(list) if !list.is_empty() => {
let values = list.value(0);
let mut items = Vec::with_capacity(values.len());
for i in 0..values.len() {
let sv =
ScalarValue::try_from_array(&values, i).unwrap_or(ScalarValue::Null);
items.push(scalar_to_json(&sv));
}
json!({"t": "L", "v": items})
}
_ => json!({"t": "L", "v": []}),
}
}
other => json!({"t": "STR", "v": other.to_string()}),
}
}
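/// Reverses `scalar_to_json`: rebuilds a `ScalarValue` from its tagged JSON
/// form, widening integers to 64 bits and strings to `Utf8`.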
pub(crate) fn json_to_scalar(v: &serde_json::Value) -> Result<ScalarValue, DbError> {
let t = v
.get("t")
.and_then(|t| t.as_str())
.ok_or_else(|| DbError::Pipeline("missing type tag in scalar JSON".to_string()))?;
let val = v.get("v");
match t {
"N" => Ok(ScalarValue::Null),
"B" => match val.and_then(serde_json::Value::as_bool) {
Some(b) => Ok(ScalarValue::Boolean(Some(b))),
None => Ok(ScalarValue::Boolean(None)),
},
"I64" => match val {
Some(serde_json::Value::Number(n)) => Ok(ScalarValue::Int64(n.as_i64())),
_ => Ok(ScalarValue::Int64(None)),
},
"U64" => match val {
Some(serde_json::Value::Number(n)) => Ok(ScalarValue::UInt64(n.as_u64())),
_ => Ok(ScalarValue::UInt64(None)),
},
"F64" => match val {
Some(serde_json::Value::Number(n)) => Ok(ScalarValue::Float64(n.as_f64())),
_ => Ok(ScalarValue::Float64(None)),
},
"S" => match val.and_then(|v| v.as_str()) {
Some(s) => Ok(ScalarValue::Utf8(Some(s.to_string()))),
None => Ok(ScalarValue::Utf8(None)),
},
"L" => {
let items = val
.and_then(|v| v.as_array())
.ok_or_else(|| DbError::Pipeline("expected array for List scalar".to_string()))?;
let scalars: Result<Vec<ScalarValue>, _> = items.iter().map(json_to_scalar).collect();
let scalars = scalars?;
if scalars.is_empty() {
Ok(ScalarValue::List(Arc::new(
arrow::array::GenericListArray::new_null(
Arc::new(Field::new("item", DataType::Null, true)),
1,
),
)))
} else {
let arr = ScalarValue::new_list(&scalars, &scalars[0].data_type(), true);
Ok(ScalarValue::List(arr))
}
}
other => Err(DbError::Pipeline(format!(
"unsupported scalar type tag in checkpoint: {other}"
))),
}
}
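/// Hashes the pre-aggregation SQL together with the output schema's field
/// names and types, so checkpoints are only restored into a state with the
/// same shape.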
pub(crate) fn query_fingerprint(pre_agg_sql: &str, output_schema: &Schema) -> u64 {
let mut hasher = std::hash::DefaultHasher::new();
pre_agg_sql.hash(&mut hasher);
for field in output_schema.fields() {
field.name().hash(&mut hasher);
field.data_type().to_string().hash(&mut hasher);
}
hasher.finish()
}
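/// One group's serialized state: the JSON-encoded key plus one JSON-encoded
/// state vector per accumulator.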
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct GroupCheckpoint {
pub key: Vec<serde_json::Value>,
pub acc_states: Vec<Vec<serde_json::Value>>,
}
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct AggStateCheckpoint {
pub fingerprint: u64,
pub groups: Vec<GroupCheckpoint>,
}
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct WindowCheckpoint {
pub window_start: i64,
pub groups: Vec<GroupCheckpoint>,
}
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct EowcStateCheckpoint {
pub fingerprint: u64,
pub windows: Vec<WindowCheckpoint>,
}
#[derive(Clone, serde::Serialize, serde::Deserialize, Default)]
pub(crate) struct JoinStateCheckpoint {
#[serde(default)]
pub left_buffer_rows: u64,
#[serde(default)]
pub right_buffer_rows: u64,
#[serde(default)]
pub left_batches: Vec<Vec<u8>>,
#[serde(default)]
pub right_batches: Vec<Vec<u8>>,
#[serde(default)]
pub last_evicted_watermark: i64,
}
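/// Top-level checkpoint for a stream executor: a format `version` plus keyed
/// state maps for each kind of stateful operator.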
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct StreamExecutorCheckpoint {
pub version: u32,
#[serde(default)]
pub vnode_count: u16,
#[serde(default)]
pub agg_states: FxHashMap<String, AggStateCheckpoint>,
#[serde(default)]
pub eowc_states: FxHashMap<String, EowcStateCheckpoint>,
#[serde(default)]
pub core_window_states: FxHashMap<String, crate::core_window_state::CoreWindowCheckpoint>,
#[serde(default)]
pub join_states: FxHashMap<String, JoinStateCheckpoint>,
#[serde(default)]
pub raw_eowc_states: FxHashMap<String, RawEowcCheckpoint>,
}
const _: () = {
const fn assert_send<T: Send>() {}
assert_send::<StreamExecutorCheckpoint>();
};
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct RawEowcCheckpoint {
pub last_closed_boundary: i64,
pub accumulated_rows: usize,
pub sources: FxHashMap<String, Vec<Vec<u8>>>,
}
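/// One aggregate call extracted from the query: the UDF, the pre-agg columns
/// that feed it, its resolved types, and DISTINCT/FILTER handling.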
pub(crate) struct AggFuncSpec {
pub(crate) udf: Arc<AggregateUDF>,
pub(crate) input_types: Vec<DataType>,
pub(crate) input_col_indices: Vec<usize>,
pub(crate) output_name: String,
pub(crate) return_type: DataType,
pub(crate) distinct: bool,
pub(crate) filter_col_index: Option<usize>,
}
impl AggFuncSpec {
pub(crate) fn create_accumulator(
&self,
) -> Result<Box<dyn datafusion_expr::Accumulator>, DbError> {
let return_field = Arc::new(Field::new(
&self.output_name,
self.return_type.clone(),
true,
));
let schema = Schema::new(
self.input_types
.iter()
.enumerate()
.map(|(i, dt)| Field::new(format!("col_{i}"), dt.clone(), true))
.collect::<Vec<_>>(),
);
let expr_fields: Vec<Arc<Field>> = self
.input_types
.iter()
.enumerate()
.map(|(i, dt)| Arc::new(Field::new(format!("col_{i}"), dt.clone(), true)))
.collect();
let args = AccumulatorArgs {
return_field,
schema: &schema,
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: self.udf.name(),
is_distinct: self.distinct,
exprs: &[],
expr_fields: &expr_fields,
};
self.udf.accumulator(args).map_err(|e| {
DbError::Pipeline(format!(
"accumulator creation failed for '{}': {e}",
self.udf.name()
))
})
}
}
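/// Incrementally maintained GROUP BY state: a map from group key to one
/// accumulator per aggregate, fed by pre-aggregation batches via
/// `process_batch` and read out with `emit`.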
pub(crate) struct IncrementalAggState {
pre_agg_sql: String,
num_group_cols: usize,
#[allow(dead_code)]
group_col_names: Vec<String>,
group_types: Vec<DataType>,
agg_specs: Vec<AggFuncSpec>,
groups: AHashMap<Vec<ScalarValue>, Vec<Box<dyn datafusion_expr::Accumulator>>>,
output_schema: SchemaRef,
compiled_projection: Option<CompiledProjection>,
cached_pre_agg_plan: Option<LogicalPlan>,
having_filter: Option<Arc<dyn PhysicalExpr>>,
having_sql: Option<String>,
max_groups: usize,
}
impl IncrementalAggState {
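/// Tries to build incremental aggregation state from a SQL query. Returns
/// `Ok(None)` when the query has no aggregate, projects over aggregate
/// results, or uses an expression this path cannot maintain incrementally.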
#[allow(clippy::too_many_lines)]
pub async fn try_from_sql(ctx: &SessionContext, sql: &str) -> Result<Option<Self>, DbError> {
let df = ctx
.sql(sql)
.await
.map_err(|e| DbError::Pipeline(format!("plan error: {e}")))?;
let plan = df.logical_plan();
let top_schema = Arc::new(plan.schema().as_arrow().clone());
let Some(agg_info) = find_aggregate(plan) else {
return Ok(None);
};
let group_exprs = agg_info.group_exprs;
let aggr_exprs = agg_info.aggr_exprs;
let agg_schema = agg_info.schema;
let input_schema = agg_info.input_schema;
let having_predicate = agg_info.having_predicate;
if aggr_exprs.is_empty() {
return Ok(None);
}
if top_schema.fields().len() != agg_schema.fields().len() {
return Ok(None);
}
for (top_f, agg_f) in top_schema.fields().iter().zip(agg_schema.fields()) {
if top_f.data_type() != agg_f.data_type() {
return Ok(None);
}
}
let num_group_cols = group_exprs.len();
let mut group_col_names = Vec::new();
let mut group_types = Vec::new();
for i in 0..num_group_cols {
let top_field = top_schema.field(i);
let agg_field = agg_schema.field(i);
group_col_names.push(top_field.name().clone());
group_types.push(agg_field.data_type().clone());
}
let compile_source = crate::stream_executor::single_source_table(sql);
let state = ctx.state();
let props = state.execution_props();
let input_df_schema = &agg_info.input_df_schema;
let mut compiled_exprs: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
let mut proj_fields: Vec<Field> = Vec::new();
let mut compile_ok = compile_source.is_some();
let mut agg_specs = Vec::new();
let mut pre_agg_select_items: Vec<String> = Vec::new();
for (i, group_expr) in group_exprs.iter().enumerate() {
if let datafusion_expr::Expr::Column(col) = group_expr {
pre_agg_select_items.push(format!("\"{}\"", col.name));
} else {
let group_sql = expr_to_sql(group_expr);
pre_agg_select_items.push(format!("{group_sql} AS \"__group_{i}\""));
}
if compile_ok {
match create_physical_expr(group_expr, input_df_schema, props) {
Ok(phys) => {
let dt = phys
.data_type(input_df_schema.as_arrow())
.unwrap_or(DataType::Utf8);
let name = match group_expr {
datafusion_expr::Expr::Column(col) => col.name.clone(),
_ => format!("__group_{i}"),
};
proj_fields.push(Field::new(name, dt, true));
compiled_exprs.push(phys);
}
Err(_) => compile_ok = false,
}
}
}
let mut next_col_idx = num_group_cols;
for (i, expr) in aggr_exprs.iter().enumerate() {
let agg_schema_idx = num_group_cols + i;
let agg_field = agg_schema.field(agg_schema_idx);
let output_name = if agg_schema_idx < top_schema.fields().len() {
top_schema.field(agg_schema_idx).name().clone()
} else {
agg_field.name().clone()
};
if let datafusion_expr::Expr::AggregateFunction(agg_func) = expr {
let udf = Arc::clone(&agg_func.func);
let is_distinct = agg_func.params.distinct;
let mut input_col_indices = Vec::new();
let mut input_types = Vec::new();
if agg_func.params.args.is_empty() {
let col_idx = next_col_idx;
next_col_idx += 1;
pre_agg_select_items.push(format!("TRUE AS \"__agg_input_{col_idx}\""));
input_col_indices.push(col_idx);
input_types.push(DataType::Boolean);
if compile_ok {
match create_physical_expr(
&datafusion_expr::lit(true),
input_df_schema,
props,
) {
Ok(phys) => {
proj_fields.push(Field::new(
format!("__agg_input_{col_idx}"),
DataType::Boolean,
true,
));
compiled_exprs.push(phys);
}
Err(_) => compile_ok = false,
}
}
} else {
for arg_expr in &agg_func.params.args {
let col_idx = next_col_idx;
next_col_idx += 1;
let expr_sql = expr_to_sql(arg_expr);
if let Some(filter_expr) = &agg_func.params.filter {
let filter_sql = expr_to_sql(filter_expr);
pre_agg_select_items.push(format!(
"CASE WHEN {filter_sql} THEN {expr_sql} ELSE NULL END AS \"__agg_input_{col_idx}\""
));
if compile_ok {
let case_expr =
datafusion_expr::Expr::Case(datafusion_expr::expr::Case {
expr: None,
when_then_expr: vec![(
Box::new(filter_expr.as_ref().clone()),
Box::new(arg_expr.clone()),
)],
else_expr: Some(Box::new(datafusion_expr::lit(
ScalarValue::Null,
))),
});
match create_physical_expr(&case_expr, input_df_schema, props) {
Ok(phys) => {
let dt = resolve_expr_type(
arg_expr,
&input_schema,
agg_field.data_type(),
);
proj_fields.push(Field::new(
format!("__agg_input_{col_idx}"),
dt,
true,
));
compiled_exprs.push(phys);
}
Err(_) => compile_ok = false,
}
}
} else {
pre_agg_select_items
.push(format!("{expr_sql} AS \"__agg_input_{col_idx}\""));
if compile_ok {
match create_physical_expr(arg_expr, input_df_schema, props) {
Ok(phys) => {
let dt = resolve_expr_type(
arg_expr,
&input_schema,
agg_field.data_type(),
);
proj_fields.push(Field::new(
format!("__agg_input_{col_idx}"),
dt,
true,
));
compiled_exprs.push(phys);
}
Err(_) => compile_ok = false,
}
}
}
input_col_indices.push(col_idx);
let dt = resolve_expr_type(arg_expr, &input_schema, agg_field.data_type());
input_types.push(dt);
}
}
let filter_col_index = if let Some(filter_expr) = &agg_func.params.filter {
let col_idx = next_col_idx;
next_col_idx += 1;
let filter_sql = expr_to_sql(filter_expr);
pre_agg_select_items.push(format!(
"CASE WHEN {filter_sql} THEN TRUE ELSE FALSE END AS \"__agg_filter_{col_idx}\""
));
if compile_ok {
let case_expr = datafusion_expr::Expr::Case(datafusion_expr::expr::Case {
expr: None,
when_then_expr: vec![(
Box::new(filter_expr.as_ref().clone()),
Box::new(datafusion_expr::lit(true)),
)],
else_expr: Some(Box::new(datafusion_expr::lit(false))),
});
match create_physical_expr(&case_expr, input_df_schema, props) {
Ok(phys) => {
proj_fields.push(Field::new(
format!("__agg_filter_{col_idx}"),
DataType::Boolean,
true,
));
compiled_exprs.push(phys);
}
Err(_) => compile_ok = false,
}
}
Some(col_idx)
} else {
None
};
let return_type = udf
.return_type(&input_types)
.unwrap_or_else(|_| agg_field.data_type().clone());
agg_specs.push(AggFuncSpec {
udf,
input_types,
input_col_indices,
output_name,
return_type,
distinct: is_distinct,
filter_col_index,
});
} else {
return Ok(None);
}
}
let clauses = extract_clauses(sql);
let pre_agg_sql = format!(
"SELECT {} FROM {}{}",
pre_agg_select_items.join(", "),
clauses.from_clause,
clauses.where_clause,
);
let compiled_projection = if compile_ok {
let source_table = compile_source.unwrap();
let filter = if let Some(where_pred) = &agg_info.where_predicate {
if let Ok(phys) = create_physical_expr(where_pred, input_df_schema, props) {
Some(phys)
} else {
compile_ok = false;
None
}
} else {
None
};
if compile_ok {
Some(CompiledProjection {
source_table,
exprs: compiled_exprs,
filter,
output_schema: Arc::new(Schema::new(proj_fields)),
})
} else {
None
}
} else {
None
};
let mut output_fields: Vec<Field> = Vec::new();
for (name, dt) in group_col_names.iter().zip(group_types.iter()) {
output_fields.push(Field::new(name, dt.clone(), true));
}
for spec in &agg_specs {
output_fields.push(Field::new(
&spec.output_name,
spec.return_type.clone(),
true,
));
}
let output_schema = Arc::new(Schema::new(output_fields));
let having_filter = compile_having_filter(ctx, having_predicate.as_ref(), &output_schema);
let having_sql = if having_filter.is_none() {
having_predicate.as_ref().map(expr_to_sql)
} else {
None
};
let cached_pre_agg_plan = if compiled_projection.is_none() {
match ctx.sql(&pre_agg_sql).await {
Ok(df) => Some(df.logical_plan().clone()),
Err(e) => {
return Err(DbError::Pipeline(format!(
"pre-agg SQL planning failed for aggregate: {e}"
)));
}
}
} else {
None
};
Ok(Some(Self {
pre_agg_sql,
num_group_cols,
group_col_names,
group_types,
agg_specs,
groups: AHashMap::new(),
output_schema,
compiled_projection,
cached_pre_agg_plan,
having_filter,
having_sql,
max_groups: 1_000_000,
}))
}
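/// Routes a pre-aggregation batch into per-group accumulators. Group keys
/// are compared in Arrow row format; once `max_groups` is reached, new
/// groups are dropped with a warning instead of failing the pipeline.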
pub fn process_batch(&mut self, batch: &RecordBatch) -> Result<(), DbError> {
if batch.num_rows() == 0 {
return Ok(());
}
if self.num_group_cols == 0 {
return self.process_batch_no_groups(batch);
}
let group_cols: Vec<ArrayRef> = (0..self.num_group_cols)
.map(|i| Arc::clone(batch.column(i)))
.collect();
let sort_fields: Vec<arrow::row::SortField> = group_cols
.iter()
.map(|c| arrow::row::SortField::new(c.data_type().clone()))
.collect();
let converter = arrow::row::RowConverter::new(sort_fields)
.map_err(|e| DbError::Pipeline(format!("row converter: {e}")))?;
let rows = converter
.convert_columns(&group_cols)
.map_err(|e| DbError::Pipeline(format!("row conversion: {e}")))?;
let estimated_groups = (batch.num_rows() / 4).max(16);
let mut group_indices: FxHashMap<arrow::row::OwnedRow, Vec<u32>> =
FxHashMap::with_capacity_and_hasher(estimated_groups, rustc_hash::FxBuildHasher);
for row_idx in 0..batch.num_rows() {
#[allow(clippy::cast_possible_truncation)]
group_indices
.entry(rows.row(row_idx).owned())
.or_default()
.push(row_idx as u32);
}
for (row_key, indices) in &group_indices {
let sv_key = self.row_to_scalar_key(&converter, row_key)?;
if !self.groups.contains_key(&sv_key) {
if self.groups.len() >= self.max_groups {
tracing::warn!(
max_groups = self.max_groups,
current_groups = self.groups.len(),
"group cardinality limit reached, dropping new group"
);
continue;
}
let mut accs = Vec::with_capacity(self.agg_specs.len());
for spec in &self.agg_specs {
accs.push(spec.create_accumulator()?);
}
self.groups.insert(sv_key.clone(), accs);
}
let Some(accs) = self.groups.get_mut(&sv_key) else {
continue;
};
Self::update_group_accumulators(accs, batch, indices, &self.agg_specs)?;
}
Ok(())
}
fn process_batch_no_groups(&mut self, batch: &RecordBatch) -> Result<(), DbError> {
let empty_key: Vec<ScalarValue> = Vec::new();
if !self.groups.contains_key(&empty_key) {
let mut accs = Vec::with_capacity(self.agg_specs.len());
for spec in &self.agg_specs {
accs.push(spec.create_accumulator()?);
}
self.groups.insert(empty_key.clone(), accs);
}
let accs = self.groups.get_mut(&empty_key).unwrap();
#[allow(clippy::cast_possible_truncation)]
let all_indices: Vec<u32> = (0..batch.num_rows() as u32).collect();
Self::update_group_accumulators(accs, batch, &all_indices, &self.agg_specs)
}
fn row_to_scalar_key(
&self,
converter: &arrow::row::RowConverter,
row_key: &arrow::row::OwnedRow,
) -> Result<Vec<ScalarValue>, DbError> {
row_to_scalar_key_with_types(converter, row_key, &self.group_types)
}
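/// Feeds the selected rows into each aggregate's accumulator: input columns
/// are gathered with `take`, the optional FILTER mask is applied, and the
/// result is passed to `update_batch`.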
pub(crate) fn update_group_accumulators(
accs: &mut [Box<dyn datafusion_expr::Accumulator>],
batch: &RecordBatch,
indices: &[u32],
agg_specs: &[AggFuncSpec],
) -> Result<(), DbError> {
let index_array = arrow::array::UInt32Array::from(indices.to_vec());
for (i, spec) in agg_specs.iter().enumerate() {
let mut input_arrays: Vec<ArrayRef> = Vec::with_capacity(spec.input_col_indices.len());
for &col_idx in &spec.input_col_indices {
let arr = compute::take(batch.column(col_idx), &index_array, None)
.map_err(|e| DbError::Pipeline(format!("array take failed: {e}")))?;
input_arrays.push(arr);
}
if let Some(filter_idx) = spec.filter_col_index {
let filter_arr = compute::take(batch.column(filter_idx), &index_array, None)
.map_err(|e| DbError::Pipeline(format!("filter take: {e}")))?;
if let Some(mask) = filter_arr
.as_any()
.downcast_ref::<arrow::array::BooleanArray>()
{
let mut filtered = Vec::with_capacity(input_arrays.len());
for arr in &input_arrays {
filtered.push(
compute::filter(arr, mask)
.map_err(|e| DbError::Pipeline(format!("filter apply: {e}")))?,
);
}
input_arrays = filtered;
}
}
accs[i]
.update_batch(&input_arrays)
.map_err(|e| DbError::Pipeline(format!("accumulator update: {e}")))?;
}
Ok(())
}
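/// Materializes current accumulator values into a single output batch
/// without clearing state, so repeated calls reflect all input seen so far.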
pub fn emit(&mut self) -> Result<Vec<RecordBatch>, DbError> {
if self.groups.is_empty() {
return Ok(Vec::new());
}
let num_rows = self.groups.len();
let mut group_arrays: Vec<ArrayRef> = Vec::with_capacity(self.num_group_cols);
for (col_idx, dt) in self.group_types.iter().enumerate() {
let scalars: Vec<ScalarValue> =
self.groups.keys().map(|key| key[col_idx].clone()).collect();
let array = ScalarValue::iter_to_array(scalars)
.map_err(|e| DbError::Pipeline(format!("group key array build: {e}")))?;
if array.data_type() == dt {
group_arrays.push(array);
} else {
let casted = arrow::compute::cast(&array, dt).unwrap_or(array);
group_arrays.push(casted);
}
}
let mut agg_arrays: Vec<ArrayRef> = Vec::with_capacity(self.agg_specs.len());
for (agg_idx, spec) in self.agg_specs.iter().enumerate() {
let mut scalars: Vec<ScalarValue> = Vec::with_capacity(num_rows);
for accs in self.groups.values_mut() {
let sv = accs[agg_idx]
.evaluate()
.map_err(|e| DbError::Pipeline(format!("accumulator evaluate: {e}")))?;
scalars.push(sv);
}
let array = ScalarValue::iter_to_array(scalars)
.map_err(|e| DbError::Pipeline(format!("agg result array build: {e}")))?;
if array.data_type() == &spec.return_type {
agg_arrays.push(array);
} else {
let casted = arrow::compute::cast(&array, &spec.return_type).unwrap_or(array);
agg_arrays.push(casted);
}
}
let mut all_arrays = group_arrays;
all_arrays.extend(agg_arrays);
let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), all_arrays)
.map_err(|e| DbError::Pipeline(format!("result batch build: {e}")))?;
Ok(vec![batch])
}
#[allow(dead_code)]
pub fn pre_agg_sql(&self) -> &str {
&self.pre_agg_sql
}
pub fn having_filter(&self) -> Option<&Arc<dyn PhysicalExpr>> {
self.having_filter.as_ref()
}
pub fn having_sql(&self) -> Option<&str> {
self.having_sql.as_deref()
}
pub fn compiled_projection(&self) -> Option<&CompiledProjection> {
self.compiled_projection.as_ref()
}
pub fn cached_pre_agg_plan(&self) -> Option<&LogicalPlan> {
self.cached_pre_agg_plan.as_ref()
}
pub(crate) fn query_fingerprint(&self) -> u64 {
query_fingerprint(&self.pre_agg_sql, &self.output_schema)
}
pub(crate) fn estimated_size_bytes(&self) -> usize {
let mut total = 0;
for (key, accs) in &self.groups {
for sv in key {
total += sv.size();
}
for acc in accs {
total += acc.size();
}
}
total
}
#[allow(dead_code)]
pub(crate) fn group_count(&self) -> usize {
self.groups.len()
}
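/// Snapshots every group's key and accumulator state as JSON, tagged with
/// the query fingerprint so restores can verify compatibility.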
pub(crate) fn checkpoint_groups(&mut self) -> Result<AggStateCheckpoint, DbError> {
let fingerprint = self.query_fingerprint();
let mut groups = Vec::with_capacity(self.groups.len());
for (key, accs) in &mut self.groups {
let key_json: Vec<serde_json::Value> = key.iter().map(scalar_to_json).collect();
let mut acc_states = Vec::with_capacity(accs.len());
for acc in accs {
let state = acc
.state()
.map_err(|e| DbError::Pipeline(format!("accumulator state: {e}")))?;
acc_states.push(state.iter().map(scalar_to_json).collect());
}
groups.push(GroupCheckpoint {
key: key_json,
acc_states,
});
}
Ok(AggStateCheckpoint {
fingerprint,
groups,
})
}
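/// Rebuilds group state from a checkpoint by creating fresh accumulators and
/// replaying saved state through `merge_batch`. Clears existing groups and
/// rejects checkpoints whose fingerprint does not match.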
pub(crate) fn restore_groups(
&mut self,
checkpoint: &AggStateCheckpoint,
) -> Result<usize, DbError> {
let current_fp = self.query_fingerprint();
if checkpoint.fingerprint != current_fp {
return Err(DbError::Pipeline(format!(
"checkpoint fingerprint mismatch: saved={}, current={}",
checkpoint.fingerprint, current_fp
)));
}
self.groups.clear();
for gc in &checkpoint.groups {
let key: Result<Vec<ScalarValue>, _> = gc.key.iter().map(json_to_scalar).collect();
let key = key?;
let mut accs = Vec::with_capacity(self.agg_specs.len());
for (i, spec) in self.agg_specs.iter().enumerate() {
let mut acc = spec.create_accumulator()?;
if i < gc.acc_states.len() {
let state_scalars: Result<Vec<ScalarValue>, _> =
gc.acc_states[i].iter().map(json_to_scalar).collect();
let state_scalars = state_scalars?;
let arrays: Vec<ArrayRef> = state_scalars
.iter()
.map(|sv| {
sv.to_array()
.map_err(|e| DbError::Pipeline(format!("scalar to array: {e}")))
})
.collect::<Result<_, _>>()?;
acc.merge_batch(&arrays)
.map_err(|e| DbError::Pipeline(format!("accumulator merge: {e}")))?;
}
accs.push(acc);
}
self.groups.insert(key, accs);
}
Ok(checkpoint.groups.len())
}
}
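/// The parts of a `LogicalPlan::Aggregate` node needed to rebuild the
/// aggregate incrementally, plus any surrounding HAVING and WHERE predicates.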
pub(crate) struct AggregateInfo {
pub(crate) group_exprs: Vec<datafusion_expr::Expr>,
pub(crate) aggr_exprs: Vec<datafusion_expr::Expr>,
pub(crate) schema: Arc<Schema>,
pub(crate) input_schema: Arc<Schema>,
pub(crate) having_predicate: Option<datafusion_expr::Expr>,
pub(crate) input_df_schema: Arc<DFSchema>,
pub(crate) where_predicate: Option<datafusion_expr::Expr>,
}
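/// Recursively searches the plan for the first `Aggregate` node; a `Filter`
/// sitting directly above it is captured as the HAVING predicate.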
pub(crate) fn find_aggregate(plan: &LogicalPlan) -> Option<AggregateInfo> {
find_aggregate_inner(plan, None)
}
fn find_aggregate_inner(
plan: &LogicalPlan,
parent_filter: Option<&datafusion_expr::Expr>,
) -> Option<AggregateInfo> {
match plan {
LogicalPlan::Aggregate(agg) => {
let schema = Arc::new(agg.schema.as_arrow().clone());
let input_schema = Arc::new(agg.input.schema().as_arrow().clone());
let input_df_schema = Arc::clone(agg.input.schema());
let where_predicate = extract_where_predicate(&agg.input);
Some(AggregateInfo {
group_exprs: agg.group_expr.clone(),
aggr_exprs: agg.aggr_expr.clone(),
schema,
input_schema,
having_predicate: parent_filter.cloned(),
input_df_schema,
where_predicate,
})
}
LogicalPlan::Filter(filter) => {
if matches!(&*filter.input, LogicalPlan::Aggregate(_)) {
find_aggregate_inner(&filter.input, Some(&filter.predicate))
} else {
find_aggregate_inner(&filter.input, None)
}
}
LogicalPlan::Projection(proj) => find_aggregate_inner(&proj.input, None),
LogicalPlan::Sort(sort) => find_aggregate_inner(&sort.input, None),
LogicalPlan::Limit(limit) => find_aggregate_inner(&limit.input, None),
LogicalPlan::SubqueryAlias(alias) => find_aggregate_inner(&alias.input, None),
_ => {
for input in plan.inputs() {
if let Some(result) = find_aggregate_inner(input, None) {
return Some(result);
}
}
None
}
}
}
fn extract_where_predicate(plan: &LogicalPlan) -> Option<datafusion_expr::Expr> {
match plan {
LogicalPlan::Filter(f) => Some(f.predicate.clone()),
LogicalPlan::Projection(p) => extract_where_predicate(&p.input),
LogicalPlan::Sort(s) => extract_where_predicate(&s.input),
LogicalPlan::Limit(l) => extract_where_predicate(&l.input),
LogicalPlan::SubqueryAlias(a) => extract_where_predicate(&a.input),
_ => None,
}
}
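/// Renders a `ScalarValue` as a SQL literal, with explicit handling for
/// quote-escaped strings and the three Arrow interval types.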
fn scalar_value_to_sql(sv: &ScalarValue) -> String {
match sv {
ScalarValue::Utf8(Some(s))
| ScalarValue::LargeUtf8(Some(s))
| ScalarValue::Utf8View(Some(s)) => {
format!("'{}'", s.replace('\'', "''"))
}
ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None)
| ScalarValue::Utf8View(None)
| ScalarValue::Null
| ScalarValue::Boolean(None) => "NULL".to_string(),
ScalarValue::Boolean(Some(b)) => if *b { "TRUE" } else { "FALSE" }.to_string(),
ScalarValue::IntervalDayTime(Some(v)) => {
let mut parts = Vec::new();
if v.days != 0 {
parts.push(format!("{} days", v.days));
}
if v.milliseconds != 0 || parts.is_empty() {
let abs_ms = v.milliseconds.unsigned_abs();
let secs = abs_ms / 1000;
let frac = abs_ms % 1000;
let sign = if v.milliseconds < 0 { "-" } else { "" };
if frac == 0 {
parts.push(format!("{sign}{secs} seconds"));
} else {
parts.push(format!("{sign}{secs}.{frac:03} seconds"));
}
}
format!("INTERVAL '{}'", parts.join(" "))
}
ScalarValue::IntervalYearMonth(Some(v)) => {
let years = v / 12;
let months = v % 12;
let mut parts = Vec::new();
if years != 0 {
parts.push(format!("{years} years"));
}
if months != 0 || parts.is_empty() {
parts.push(format!("{months} months"));
}
format!("INTERVAL '{}'", parts.join(" "))
}
ScalarValue::IntervalMonthDayNano(Some(v)) => {
let mut parts = Vec::new();
if v.months != 0 {
parts.push(format!("{} months", v.months));
}
if v.days != 0 {
parts.push(format!("{} days", v.days));
}
let nanos = v.nanoseconds;
if nanos != 0 || parts.is_empty() {
let abs_ns = nanos.unsigned_abs();
let secs = abs_ns / 1_000_000_000;
let remainder_ns = abs_ns % 1_000_000_000;
let sign = if nanos < 0 { "-" } else { "" };
if remainder_ns == 0 {
parts.push(format!("{sign}{secs} seconds"));
} else {
let millis = remainder_ns / 1_000_000;
parts.push(format!("{sign}{secs}.{millis:03} seconds"));
}
}
format!("INTERVAL '{}'", parts.join(" "))
}
_ => sv.to_string(),
}
}
fn case_to_sql(case: &datafusion_expr::expr::Case) -> String {
use std::fmt::Write;
let mut sql = String::from("CASE");
if let Some(operand) = &case.expr {
let _ = write!(sql, " {}", expr_to_sql(operand));
}
for (when_expr, then_expr) in &case.when_then_expr {
let _ = write!(
sql,
" WHEN {} THEN {}",
expr_to_sql(when_expr),
expr_to_sql(then_expr)
);
}
if let Some(else_expr) = &case.else_expr {
let _ = write!(sql, " ELSE {}", expr_to_sql(else_expr));
}
sql.push_str(" END");
sql
}
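/// Best-effort conversion of a logical `Expr` back into SQL text; variants
/// without explicit handling fall back to their `Display` output.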
pub(crate) fn expr_to_sql(expr: &datafusion_expr::Expr) -> String {
use datafusion_expr::Expr;
match expr {
Expr::Column(col) => format!("\"{}\"", col.name),
Expr::Literal(sv, _) => scalar_value_to_sql(sv),
Expr::Alias(alias) => expr_to_sql(&alias.expr),
Expr::BinaryExpr(bin) => {
let left = expr_to_sql(&bin.left);
let right = expr_to_sql(&bin.right);
format!("({left} {op} {right})", op = bin.op)
}
Expr::Cast(cast) => {
let inner = expr_to_sql(&cast.expr);
format!("CAST({inner} AS {})", cast.data_type)
}
Expr::TryCast(cast) => {
let inner = expr_to_sql(&cast.expr);
format!("TRY_CAST({inner} AS {})", cast.data_type)
}
Expr::ScalarFunction(func) => {
let args: Vec<String> = func.args.iter().map(expr_to_sql).collect();
format!("{}({})", func.func.name(), args.join(", "))
}
Expr::AggregateFunction(agg) => {
let name = agg.func.name();
let args: Vec<String> = agg.params.args.iter().map(expr_to_sql).collect();
if agg.params.distinct {
format!("{name}(DISTINCT {})", args.join(", "))
} else {
format!("{name}({})", args.join(", "))
}
}
Expr::Case(case) => case_to_sql(case),
Expr::Not(inner) => format!("(NOT {})", expr_to_sql(inner)),
Expr::Negative(inner) => format!("(-{})", expr_to_sql(inner)),
Expr::IsNull(inner) => {
format!("({} IS NULL)", expr_to_sql(inner))
}
Expr::IsNotNull(inner) => {
format!("({} IS NOT NULL)", expr_to_sql(inner))
}
Expr::IsTrue(inner) => {
format!("({} IS TRUE)", expr_to_sql(inner))
}
Expr::IsFalse(inner) => {
format!("({} IS FALSE)", expr_to_sql(inner))
}
Expr::IsNotTrue(inner) => {
format!("({} IS NOT TRUE)", expr_to_sql(inner))
}
Expr::IsNotFalse(inner) => {
format!("({} IS NOT FALSE)", expr_to_sql(inner))
}
Expr::Between(between) => {
let e = expr_to_sql(&between.expr);
let low = expr_to_sql(&between.low);
let high = expr_to_sql(&between.high);
let not = if between.negated { " NOT" } else { "" };
format!("({e}{not} BETWEEN {low} AND {high})")
}
Expr::InList(in_list) => {
let e = expr_to_sql(&in_list.expr);
let items: Vec<String> = in_list.list.iter().map(expr_to_sql).collect();
let not = if in_list.negated { " NOT" } else { "" };
format!("({e}{not} IN ({}))", items.join(", "))
}
Expr::Like(like) => {
let e = expr_to_sql(&like.expr);
let pat = expr_to_sql(&like.pattern);
let kw = if like.case_insensitive {
"ILIKE"
} else {
"LIKE"
};
let not = if like.negated { " NOT" } else { "" };
if let Some(esc) = &like.escape_char {
format!("({e}{not} {kw} {pat} ESCAPE '{esc}')")
} else {
format!("({e}{not} {kw} {pat})")
}
}
#[allow(deprecated)]
Expr::Wildcard { .. } => "TRUE".to_string(),
other => other.to_string(),
}
}
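/// A physically compiled pre-aggregation: an optional WHERE filter plus
/// projection expressions evaluated directly on source batches, avoiding
/// per-batch SQL planning.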
#[allow(dead_code)]
pub(crate) struct CompiledProjection {
pub(crate) source_table: String,
pub(crate) exprs: Vec<Arc<dyn PhysicalExpr>>,
pub(crate) filter: Option<Arc<dyn PhysicalExpr>>,
pub(crate) output_schema: SchemaRef,
}
impl CompiledProjection {
pub(crate) fn source_table(&self) -> &str {
&self.source_table
}
pub(crate) fn evaluate(&self, batch: &RecordBatch) -> Result<RecordBatch, DbError> {
if batch.num_rows() == 0 {
return Ok(RecordBatch::new_empty(Arc::clone(&self.output_schema)));
}
let filtered = if let Some(ref filter) = self.filter {
let result = filter
.evaluate(batch)
.map_err(|e| DbError::Pipeline(format!("WHERE filter evaluate: {e}")))?;
let mask = result
.into_array(batch.num_rows())
.map_err(|e| DbError::Pipeline(format!("WHERE filter to array: {e}")))?;
let bool_arr = mask
.as_any()
.downcast_ref::<arrow::array::BooleanArray>()
.ok_or_else(|| DbError::Pipeline("WHERE filter not boolean".into()))?;
arrow::compute::filter_record_batch(batch, bool_arr)
.map_err(|e| DbError::Pipeline(format!("WHERE filter: {e}")))?
} else {
batch.clone()
};
if filtered.num_rows() == 0 {
return Ok(RecordBatch::new_empty(Arc::clone(&self.output_schema)));
}
let mut arrays = Vec::with_capacity(self.exprs.len());
for expr in &self.exprs {
let result = expr
.evaluate(&filtered)
.map_err(|e| DbError::Pipeline(format!("projection evaluate: {e}")))?;
let arr = result
.into_array(filtered.num_rows())
.map_err(|e| DbError::Pipeline(format!("projection to array: {e}")))?;
arrays.push(arr);
}
RecordBatch::try_new(Arc::clone(&self.output_schema), arrays)
.map_err(|e| DbError::Pipeline(format!("projection batch build: {e}")))
}
}
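/// Applies a compiled HAVING predicate to emitted batches, keeping only rows
/// where the predicate is true and dropping batches that end up empty.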
pub(crate) fn apply_compiled_having(
batches: &[RecordBatch],
having_filter: &Arc<dyn PhysicalExpr>,
) -> Result<Vec<RecordBatch>, DbError> {
let mut result = Vec::with_capacity(batches.len());
for batch in batches {
if batch.num_rows() == 0 {
continue;
}
let mask_result = having_filter
.evaluate(batch)
.map_err(|e| DbError::Pipeline(format!("HAVING evaluate: {e}")))?;
let mask = mask_result
.into_array(batch.num_rows())
.map_err(|e| DbError::Pipeline(format!("HAVING to array: {e}")))?;
let bool_arr = mask
.as_any()
.downcast_ref::<arrow::array::BooleanArray>()
.ok_or_else(|| DbError::Pipeline("HAVING filter not boolean".into()))?;
let filtered = arrow::compute::filter_record_batch(batch, bool_arr)
.map_err(|e| DbError::Pipeline(format!("HAVING filter: {e}")))?;
if filtered.num_rows() > 0 {
result.push(filtered);
}
}
Ok(result)
}
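/// Compiles the HAVING predicate against the aggregate output schema,
/// returning `None` on any failure so the caller can fall back to SQL text.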
pub(crate) fn compile_having_filter(
ctx: &SessionContext,
having_predicate: Option<&datafusion_expr::Expr>,
output_schema: &SchemaRef,
) -> Option<Arc<dyn PhysicalExpr>> {
let having_pred = having_predicate?;
let df_schema = DFSchema::try_from(output_schema.as_ref().clone()).ok()?;
let state = ctx.state();
let props = state.execution_props();
create_physical_expr(having_pred, &df_schema, props).ok()
}
pub(crate) struct SqlClauses {
pub(crate) from_clause: String,
pub(crate) where_clause: String,
}
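/// Extracts the FROM and WHERE clauses of a SELECT statement, preferring a
/// sqlparser AST pass and falling back to keyword-scanning heuristics.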
pub(crate) fn extract_clauses(sql: &str) -> SqlClauses {
if let Ok(clauses) = extract_clauses_ast(sql) {
return clauses;
}
SqlClauses {
from_clause: extract_from_clause_heuristic(sql),
where_clause: extract_where_clause_heuristic(sql),
}
}
fn extract_clauses_ast(sql: &str) -> Result<SqlClauses, DbError> {
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
let dialect = GenericDialect {};
let stmts = Parser::parse_sql(&dialect, sql)
.map_err(|e| DbError::Pipeline(format!("SQL parse error: {e}")))?;
let stmt = stmts
.into_iter()
.next()
.ok_or_else(|| DbError::Pipeline("empty SQL statement".to_string()))?;
let sqlparser::ast::Statement::Query(query) = stmt else {
return Err(DbError::Pipeline("expected SELECT statement".to_string()));
};
let sqlparser::ast::SetExpr::Select(select) = *query.body else {
return Err(DbError::Pipeline("expected simple SELECT".to_string()));
};
let from_clause = select
.from
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(", ");
let where_clause = select
.selection
.as_ref()
.map(|expr| format!(" WHERE {expr}"))
.unwrap_or_default();
Ok(SqlClauses {
from_clause,
where_clause,
})
}
fn extract_from_clause_heuristic(sql: &str) -> String {
let upper = sql.to_uppercase();
let from_pos = upper.find(" FROM ").map(|p| p + 6);
let Some(start) = from_pos else {
return String::new();
};
let rest = &sql[start..];
// Uppercase once instead of re-uppercasing the remainder per keyword.
let rest_upper = rest.to_uppercase();
let end_keywords = [" WHERE ", " GROUP ", " ORDER ", " LIMIT ", " HAVING "];
let end = end_keywords
.iter()
.filter_map(|kw| rest_upper.find(kw))
.min()
.unwrap_or(rest.len());
rest[..end].trim().to_string()
}
fn extract_where_clause_heuristic(sql: &str) -> String {
let upper = sql.to_uppercase();
let where_pos = upper.find(" WHERE ");
let Some(start) = where_pos else {
return String::new();
};
let rest = &sql[start..];
// Search after the leading " WHERE " (7 bytes), uppercased once; found
// offsets are mapped back into `rest`.
let rest_upper = rest[7..].to_uppercase();
let end_keywords = [" GROUP ", " ORDER ", " LIMIT ", " HAVING "];
let end = end_keywords
.iter()
.filter_map(|kw| rest_upper.find(kw).map(|p| p + 7))
.min()
.unwrap_or(rest.len());
format!(" {}", rest[..end].trim())
}
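/// Resolves the Arrow type an expression produces against `input_schema`,
/// returning `fallback_type` whenever the type cannot be determined.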
pub(crate) fn resolve_expr_type(
expr: &datafusion_expr::Expr,
input_schema: &Schema,
fallback_type: &DataType,
) -> DataType {
match expr {
datafusion_expr::Expr::Column(col) => input_schema
.field_with_name(&col.name)
.map_or_else(|_| fallback_type.clone(), |f| f.data_type().clone()),
datafusion_expr::Expr::Literal(sv, _) => sv.data_type(),
datafusion_expr::Expr::Cast(cast) => cast.data_type.clone(),
datafusion_expr::Expr::TryCast(cast) => cast.data_type.clone(),
datafusion_expr::Expr::BinaryExpr(bin) => {
resolve_expr_type(&bin.left, input_schema, fallback_type)
}
datafusion_expr::Expr::ScalarFunction(func) => {
let arg_types: Vec<DataType> = func
.args
.iter()
.map(|a| resolve_expr_type(a, input_schema, fallback_type))
.collect();
func.func
.return_type(&arg_types)
.unwrap_or_else(|_| fallback_type.clone())
}
#[allow(deprecated)]
datafusion_expr::Expr::Wildcard { .. } => DataType::Boolean,
_ => fallback_type.clone(),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_interval_day_time_seconds_only() {
use arrow::datatypes::IntervalDayTime;
let sv = ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(0, 10_000)));
let sql = scalar_value_to_sql(&sv);
assert_eq!(sql, "INTERVAL '10 seconds'");
}
#[test]
fn test_interval_day_time_days_only() {
use arrow::datatypes::IntervalDayTime;
let sv = ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(3, 0)));
let sql = scalar_value_to_sql(&sv);
assert_eq!(sql, "INTERVAL '3 days'");
}
#[test]
fn test_interval_day_time_mixed() {
use arrow::datatypes::IntervalDayTime;
let sv = ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(1, 5_500)));
let sql = scalar_value_to_sql(&sv);
assert_eq!(sql, "INTERVAL '1 days 5.500 seconds'");
}
#[test]
fn test_interval_year_month() {
let sv = ScalarValue::IntervalYearMonth(Some(15));
let sql = scalar_value_to_sql(&sv);
assert_eq!(sql, "INTERVAL '1 years 3 months'");
}
#[test]
fn test_interval_month_day_nano() {
use arrow::datatypes::IntervalMonthDayNano;
let sv =
ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new(2, 1, 3_000_000_000)));
let sql = scalar_value_to_sql(&sv);
assert_eq!(sql, "INTERVAL '2 months 1 days 3 seconds'");
}
#[tokio::test]
async fn test_try_from_sql_rejects_post_aggregate_projection() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("a", DataType::Float64, false),
Field::new("b", DataType::Float64, false),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float64Array::from(vec![1.0])),
Arc::new(arrow::array::Float64Array::from(vec![2.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]]).unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let result = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, SUM(a) / SUM(b) AS ratio FROM events GROUP BY name",
)
.await
.unwrap();
assert!(
result.is_none(),
"Post-aggregate projection should return None"
);
}
#[test]
fn test_extract_clauses_simple() {
let c = extract_clauses("SELECT a, SUM(b) FROM trades GROUP BY a");
assert_eq!(c.from_clause, "trades");
assert!(c.where_clause.is_empty());
}
#[test]
fn test_extract_clauses_with_where() {
let c = extract_clauses("SELECT * FROM events WHERE x > 1 GROUP BY y");
assert_eq!(c.from_clause, "events");
assert!(
c.where_clause.contains("WHERE"),
"should contain WHERE: {}",
c.where_clause
);
assert!(
c.where_clause.contains("x > 1"),
"should contain predicate: {}",
c.where_clause
);
}
#[test]
fn test_extract_clauses_with_join() {
let c = extract_clauses("SELECT * FROM events e JOIN dim d ON e.id = d.id");
assert!(
c.from_clause.contains("events"),
"should contain events: {}",
c.from_clause
);
assert!(
c.from_clause.contains("JOIN"),
"should contain JOIN: {}",
c.from_clause
);
assert!(
c.from_clause.contains("dim"),
"should contain dim: {}",
c.from_clause
);
}
#[test]
fn test_extract_clauses_keyword_in_string_literal() {
let c =
extract_clauses("SELECT * FROM logs WHERE msg = 'joined GROUP chat' GROUP BY user_id");
assert_eq!(c.from_clause, "logs");
assert!(
c.where_clause.contains("GROUP chat"),
"string literal should be preserved: {}",
c.where_clause
);
}
#[test]
fn test_extract_clauses_no_where() {
let c = extract_clauses("SELECT * FROM events GROUP BY y");
assert_eq!(c.from_clause, "events");
assert!(c.where_clause.is_empty());
}
#[tokio::test]
async fn test_try_from_sql_non_aggregate() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(arrow::array::Int64Array::from(vec![1]))],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]]).unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let result = IncrementalAggState::try_from_sql(&ctx, "SELECT * FROM events")
.await
.unwrap();
assert!(result.is_none(), "Non-aggregate query should return None");
}
#[tokio::test]
async fn test_try_from_sql_with_group_by() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a"])),
Arc::new(arrow::array::Float64Array::from(vec![1.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]]).unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let result = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, SUM(value) as total FROM events GROUP BY name",
)
.await
.unwrap();
assert!(result.is_some(), "Aggregate query should return Some");
let state = result.unwrap();
assert_eq!(state.num_group_cols, 1);
assert_eq!(state.agg_specs.len(), 1);
assert_eq!(state.group_col_names, vec!["name"]);
}
#[tokio::test]
async fn test_incremental_aggregation_across_batches() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let dummy_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float64Array::from(vec![0.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy_batch]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let mut state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, SUM(value) as total FROM events GROUP BY name",
)
.await
.unwrap()
.unwrap();
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch1 = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "b", "a"])),
Arc::new(arrow::array::Float64Array::from(vec![10.0, 20.0, 30.0])),
],
)
.unwrap();
state.process_batch(&batch1).unwrap();
let result1 = state.emit().unwrap();
assert_eq!(result1.len(), 1);
assert_eq!(result1[0].num_rows(), 2);
let batch2 = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "c"])),
Arc::new(arrow::array::Float64Array::from(vec![5.0, 15.0])),
],
)
.unwrap();
state.process_batch(&batch2).unwrap();
let result2 = state.emit().unwrap();
assert_eq!(result2.len(), 1);
assert_eq!(result2[0].num_rows(), 3);
let names = result2[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let totals = result2[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.unwrap();
for i in 0..result2[0].num_rows() {
match names.value(i) {
"a" => assert!(
(totals.value(i) - 45.0).abs() < f64::EPSILON,
"Expected 45.0 for group 'a', got {}",
totals.value(i)
),
"b" => assert!(
(totals.value(i) - 20.0).abs() < f64::EPSILON,
"Expected 20.0 for group 'b', got {}",
totals.value(i)
),
"c" => assert!(
(totals.value(i) - 15.0).abs() < f64::EPSILON,
"Expected 15.0 for group 'c', got {}",
totals.value(i)
),
other => panic!("Unexpected group: {other}"),
}
}
}
async fn setup_agg_state(sql: &str) -> (SessionContext, IncrementalAggState) {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float64Array::from(vec![0.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(&ctx, sql)
.await
.unwrap()
.expect("expected aggregate state");
(ctx, state)
}
#[tokio::test]
async fn test_distinct_flag_extracted() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float64Array::from(vec![0.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, COUNT(DISTINCT value) as cnt FROM events GROUP BY name",
)
.await
.unwrap()
.expect("expected aggregate state");
assert!(state.agg_specs[0].distinct, "DISTINCT flag should be set");
}
#[tokio::test]
async fn test_distinct_count_produces_correct_result() {
let (_, mut state) =
setup_agg_state("SELECT name, COUNT(DISTINCT value) as cnt FROM events GROUP BY name")
.await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "a", "a", "a"])),
Arc::new(arrow::array::Float64Array::from(vec![
10.0, 10.0, 10.0, 20.0,
])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let result = state.emit().unwrap();
assert_eq!(result.len(), 1);
let count_col = result[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.expect("count should be Int64");
assert_eq!(count_col.value(0), 2, "COUNT(DISTINCT) should be 2");
}
#[tokio::test]
async fn test_distinct_sum_produces_correct_result() {
let (_, mut state) =
setup_agg_state("SELECT name, SUM(DISTINCT value) as total FROM events GROUP BY name")
.await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "a", "a"])),
Arc::new(arrow::array::Float64Array::from(vec![10.0, 10.0, 20.0])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let result = state.emit().unwrap();
let total_col = result[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.expect("sum should be Float64");
assert!(
(total_col.value(0) - 30.0).abs() < f64::EPSILON,
"SUM(DISTINCT) should be 30, got {}",
total_col.value(0)
);
}
#[tokio::test]
async fn test_filter_clause_extracted() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float64Array::from(vec![0.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, SUM(value) FILTER (WHERE value > 0) as pos_sum FROM events GROUP BY name",
)
.await
.unwrap()
.expect("expected aggregate state");
assert!(
state.agg_specs[0].filter_col_index.is_some(),
"FILTER clause should set filter_col_index"
);
}
#[tokio::test]
async fn test_filter_clause_applied() {
let (_, mut state) = setup_agg_state(
"SELECT name, SUM(value) FILTER (WHERE value > 0) as pos_sum FROM events GROUP BY name",
)
.await;
let filter_col_idx = state.agg_specs[0]
.filter_col_index
.expect("filter_col_index should be set");
let num_cols = state.num_group_cols
+ state
.agg_specs
.iter()
.map(|s| s.input_col_indices.len())
.sum::<usize>()
+ state
.agg_specs
.iter()
.filter(|s| s.filter_col_index.is_some())
.count();
assert!(
filter_col_idx < num_cols,
"filter col index should be in range"
);
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
Field::new("__agg_filter_2", DataType::Boolean, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "a", "a"])),
Arc::new(arrow::array::Float64Array::from(vec![-5.0, 10.0, 20.0])),
Arc::new(arrow::array::BooleanArray::from(vec![false, true, true])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let result = state.emit().unwrap();
let total_col = result[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.expect("sum should be Float64");
assert!(
(total_col.value(0) - 30.0).abs() < f64::EPSILON,
"SUM with FILTER should be 30, got {}",
total_col.value(0)
);
}
#[tokio::test]
async fn test_having_clause_detected() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float64Array::from(vec![0.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, SUM(value) as total FROM events GROUP BY name HAVING SUM(value) > 100",
)
.await
.unwrap()
.expect("expected aggregate state");
assert!(
state.having_sql.is_some(),
"HAVING predicate should be extracted"
);
}
#[tokio::test]
async fn test_create_accumulator_error_propagated() {
let (_, mut state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a"])),
Arc::new(arrow::array::Float64Array::from(vec![1.0])),
],
)
.unwrap();
assert!(state.process_batch(&batch).is_ok());
}
#[tokio::test]
async fn test_type_inference_preserves_source_int32() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("amount", DataType::Int32, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Int32Array::from(vec![0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("orders", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, SUM(amount) as total FROM orders GROUP BY name",
)
.await
.unwrap()
.expect("expected aggregate state");
assert_eq!(
state.agg_specs[0].input_types[0],
DataType::Int32,
"SUM(int32_col) input type should be Int32, got {:?}",
state.agg_specs[0].input_types[0]
);
}
#[tokio::test]
async fn test_type_inference_preserves_source_float32() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("price", DataType::Float32, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Float32Array::from(vec![0.0f32])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("products", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, AVG(price) as avg_price FROM products GROUP BY name",
)
.await
.unwrap()
.expect("expected aggregate state");
assert_eq!(
state.agg_specs[0].input_types[0],
DataType::Float32,
"AVG(float32_col) input type should be Float32, got {:?}",
state.agg_specs[0].input_types[0]
);
}
#[tokio::test]
async fn test_type_inference_min_int64_column() {
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Int64, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["x"])),
Arc::new(arrow::array::Int64Array::from(vec![0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT name, MIN(value) as min_val FROM events GROUP BY name",
)
.await
.unwrap()
.expect("expected aggregate state");
assert_eq!(state.agg_specs[0].input_types[0], DataType::Int64);
}
#[test]
fn test_extract_clauses_subquery_in_where() {
let c = extract_clauses(
"SELECT * FROM orders WHERE amount > (SELECT AVG(amount) FROM orders WHERE status = 'active') GROUP BY name",
);
assert_eq!(c.from_clause, "orders");
assert!(
c.where_clause.contains("AVG"),
"subquery should be preserved: {}",
c.where_clause
);
}
#[test]
fn test_expr_to_sql_column() {
use datafusion_expr::col;
assert_eq!(expr_to_sql(&col("price")), "\"price\"");
}
#[test]
fn test_expr_to_sql_string_literal() {
let e = datafusion_expr::Expr::Literal(ScalarValue::Utf8(Some("it's".to_string())), None);
assert_eq!(expr_to_sql(&e), "'it''s'");
}
#[test]
fn test_expr_to_sql_null_literal() {
let e = datafusion_expr::Expr::Literal(ScalarValue::Null, None);
assert_eq!(expr_to_sql(&e), "NULL");
}
#[test]
fn test_expr_to_sql_boolean_literal() {
let t = datafusion_expr::Expr::Literal(ScalarValue::Boolean(Some(true)), None);
assert_eq!(expr_to_sql(&t), "TRUE");
let f = datafusion_expr::Expr::Literal(ScalarValue::Boolean(Some(false)), None);
assert_eq!(expr_to_sql(&f), "FALSE");
}
#[test]
fn test_expr_to_sql_binary_expr() {
use datafusion_expr::{col, lit};
let e = col("x").gt(lit(10));
let sql = expr_to_sql(&e);
assert!(sql.contains("\"x\""), "should contain column: {sql}");
assert!(sql.contains('>'), "should contain >: {sql}");
assert!(sql.contains("10"), "should contain 10: {sql}");
}
#[test]
fn test_expr_to_sql_cast() {
use datafusion_expr::Expr;
let e = Expr::Cast(datafusion_expr::expr::Cast {
expr: Box::new(datafusion_expr::col("x")),
data_type: DataType::Float64,
});
let sql = expr_to_sql(&e);
assert!(sql.contains("CAST"), "should contain CAST: {sql}");
assert!(sql.contains("Float64"), "should contain target type: {sql}");
}
#[test]
fn test_expr_to_sql_scalar_function() {
use datafusion_expr::Expr;
let func = datafusion::functions::string::upper();
let e = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
func,
args: vec![datafusion_expr::col("name")],
});
let sql = expr_to_sql(&e);
assert!(sql.contains("upper"), "should contain function name: {sql}");
assert!(sql.contains("\"name\""), "should contain arg: {sql}");
}
#[test]
fn test_expr_to_sql_case() {
use datafusion_expr::{col, lit};
let e = datafusion_expr::Expr::Case(datafusion_expr::expr::Case {
expr: None,
when_then_expr: vec![(Box::new(col("x").gt(lit(0))), Box::new(lit(1)))],
else_expr: Some(Box::new(lit(0))),
});
let sql = expr_to_sql(&e);
assert!(sql.starts_with("CASE"), "should start with CASE: {sql}");
assert!(sql.contains("WHEN"), "should contain WHEN: {sql}");
assert!(sql.contains("THEN"), "should contain THEN: {sql}");
assert!(sql.contains("ELSE"), "should contain ELSE: {sql}");
assert!(sql.ends_with("END"), "should end with END: {sql}");
}
#[test]
fn test_expr_to_sql_not() {
use datafusion_expr::col;
let e = datafusion_expr::Expr::Not(Box::new(col("active")));
assert_eq!(expr_to_sql(&e), "(NOT \"active\")");
}
#[test]
fn test_expr_to_sql_negative() {
use datafusion_expr::col;
let e = datafusion_expr::Expr::Negative(Box::new(col("x")));
assert_eq!(expr_to_sql(&e), "(-\"x\")");
}
#[test]
fn test_expr_to_sql_is_null() {
use datafusion_expr::col;
let e = datafusion_expr::Expr::IsNull(Box::new(col("x")));
assert_eq!(expr_to_sql(&e), "(\"x\" IS NULL)");
}
#[test]
fn test_expr_to_sql_is_not_null() {
use datafusion_expr::col;
let e = datafusion_expr::Expr::IsNotNull(Box::new(col("x")));
assert_eq!(expr_to_sql(&e), "(\"x\" IS NOT NULL)");
}
#[test]
fn test_expr_to_sql_between() {
use datafusion_expr::{col, lit};
let e = col("x").between(lit(1), lit(10));
let sql = expr_to_sql(&e);
assert!(sql.contains("BETWEEN"), "should contain BETWEEN: {sql}");
assert!(sql.contains("AND"), "should contain AND: {sql}");
}
#[test]
fn test_expr_to_sql_in_list() {
use datafusion_expr::{col, lit};
let e = col("status").in_list(vec![lit("a"), lit("b")], false);
let sql = expr_to_sql(&e);
assert!(sql.contains("IN"), "should contain IN: {sql}");
assert!(sql.contains("'a'"), "should contain 'a': {sql}");
assert!(sql.contains("'b'"), "should contain 'b': {sql}");
}
#[test]
fn test_expr_to_sql_like() {
use datafusion_expr::col;
let e = col("name").like(datafusion_expr::lit("foo%"));
let sql = expr_to_sql(&e);
assert!(sql.contains("LIKE"), "should contain LIKE: {sql}");
assert!(sql.contains("'foo%'"), "should contain pattern: {sql}");
}
#[test]
fn test_expr_to_sql_aggregate_function() {
use datafusion_expr::Expr;
let sum_udf = datafusion::functions_aggregate::sum::sum_udaf();
let e = Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
func: sum_udf,
params: datafusion_expr::expr::AggregateFunctionParams {
args: vec![datafusion_expr::col("x")],
distinct: false,
filter: None,
order_by: vec![],
null_treatment: None,
},
});
let sql = expr_to_sql(&e);
assert!(sql.contains("sum"), "should contain sum: {sql}");
assert!(sql.contains("\"x\""), "should contain arg: {sql}");
}
#[test]
fn test_expr_to_sql_aggregate_distinct() {
use datafusion_expr::Expr;
let count_udf = datafusion::functions_aggregate::count::count_udaf();
let e = Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
func: count_udf,
params: datafusion_expr::expr::AggregateFunctionParams {
args: vec![datafusion_expr::col("id")],
distinct: true,
filter: None,
order_by: vec![],
null_treatment: None,
},
});
let sql = expr_to_sql(&e);
assert!(sql.contains("DISTINCT"), "should contain DISTINCT: {sql}");
}
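// A hedged sketch, not part of the original suite: compound predicates should
// round-trip both operands. Only the operands and comparison operators are
// asserted; the exact conjunction spelling (e.g. AND) is left unchecked.
#[test]
fn test_expr_to_sql_compound_predicate_sketch() {
use datafusion_expr::{col, lit};
let e = col("x").gt(lit(1)).and(col("y").lt(lit(2)));
let sql = expr_to_sql(&e);
assert!(sql.contains("\"x\""), "should contain x: {sql}");
assert!(sql.contains('>'), "should contain >: {sql}");
assert!(sql.contains("\"y\""), "should contain y: {sql}");
assert!(sql.contains('<'), "should contain <: {sql}");
}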
#[tokio::test]
async fn test_group_by_expression_scalar_function() {
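// A GROUP BY expression must be emitted verbatim in the pre-agg SQL; wrapping
// it in identifier quotes would break the rewritten query.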
let ctx = laminar_sql::create_session_context();
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let dummy = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["hello"])),
Arc::new(arrow::array::Float64Array::from(vec![1.0])),
],
)
.unwrap();
let mem_table =
datafusion::datasource::MemTable::try_new(Arc::clone(&schema), vec![vec![dummy]])
.unwrap();
ctx.register_table("events", Arc::new(mem_table)).unwrap();
let state = IncrementalAggState::try_from_sql(
&ctx,
"SELECT upper(name), SUM(value) as total FROM events GROUP BY upper(name)",
)
.await
.unwrap()
.expect("expected aggregate state");
assert!(
state.pre_agg_sql.contains("upper("),
"pre-agg SQL should contain expression: {}",
state.pre_agg_sql
);
assert!(
!state.pre_agg_sql.contains("\"upper("),
"should NOT quote expression as identifier: {}",
state.pre_agg_sql
);
}
#[tokio::test]
async fn test_group_by_simple_column_still_works() {
let (_, state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
assert!(
state.pre_agg_sql.contains("\"name\""),
"simple column should be quoted: {}",
state.pre_agg_sql
);
}
#[tokio::test]
async fn test_group_cardinality_limit_enforced() {
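// Once max_groups is reached, rows for new keys are dropped rather than
// erroring, so the emitted batch never exceeds the cap.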
let (_, mut state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
state.max_groups = 3;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec![
"a", "b", "c", "d", "e",
])),
Arc::new(arrow::array::Float64Array::from(vec![
1.0, 2.0, 3.0, 4.0, 5.0,
])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let result = state.emit().unwrap();
assert_eq!(result.len(), 1);
assert!(
result[0].num_rows() <= 3,
"should have at most 3 groups, got {}",
result[0].num_rows()
);
}
#[tokio::test]
async fn test_group_cardinality_existing_groups_still_updated() {
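// The cap only rejects new keys: batch2's update to existing group "a" must
// still accumulate (10 + 5), while the new key "c" is dropped.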
let (_, mut state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
state.max_groups = 2;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch1 = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "b"])),
Arc::new(arrow::array::Float64Array::from(vec![10.0, 20.0])),
],
)
.unwrap();
state.process_batch(&batch1).unwrap();
let batch2 = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "c"])),
Arc::new(arrow::array::Float64Array::from(vec![5.0, 100.0])),
],
)
.unwrap();
state.process_batch(&batch2).unwrap();
let result = state.emit().unwrap();
assert_eq!(result[0].num_rows(), 2, "still only 2 groups");
let names = result[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let totals = result[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.unwrap();
for i in 0..2 {
if names.value(i) == "a" {
assert!(
(totals.value(i) - 15.0).abs() < f64::EPSILON,
"group 'a' should be 15, got {}",
totals.value(i)
);
}
}
}
#[test]
fn test_extract_clauses_multiple_joins() {
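// Every JOIN ... ON pair belongs to the FROM clause; WHERE should begin only
// at the top-level predicate.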
let c = extract_clauses(
"SELECT * FROM orders o JOIN customers c ON o.cust_id = c.id JOIN products p ON o.prod_id = p.id WHERE o.amount > 100 GROUP BY c.name",
);
assert!(
c.from_clause.contains("orders"),
"should contain orders: {}",
c.from_clause
);
assert!(
c.from_clause.contains("customers"),
"should contain customers: {}",
c.from_clause
);
assert!(
c.from_clause.contains("products"),
"should contain products: {}",
c.from_clause
);
assert!(
c.where_clause.contains("100"),
"WHERE should contain predicate: {}",
c.where_clause
);
}
#[test]
#[allow(clippy::too_many_lines)]
fn test_scalar_to_json_roundtrip() {
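// JSON has no narrow integer or f32 types, so the roundtrip widens
// Int8/16/32 to Int64, UInt8/16/32 to UInt64, and Float32 to Float64; the
// match arms below compare accordingly.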
let cases: Vec<ScalarValue> = vec![
ScalarValue::Null,
ScalarValue::Boolean(Some(true)),
ScalarValue::Boolean(None),
ScalarValue::Int8(Some(42)),
ScalarValue::Int16(Some(-100)),
ScalarValue::Int32(Some(999)),
ScalarValue::Int64(Some(123_456_789)),
ScalarValue::Int64(None),
ScalarValue::UInt8(Some(255)),
ScalarValue::UInt16(Some(65535)),
ScalarValue::UInt32(Some(1_000_000)),
ScalarValue::UInt64(Some(9_999_999_999)),
ScalarValue::UInt64(None),
ScalarValue::Float32(Some(1.5)),
ScalarValue::Float64(Some(9.876_54)),
ScalarValue::Float64(None),
ScalarValue::Utf8(Some("hello world".to_string())),
ScalarValue::Utf8(None),
];
for original in &cases {
let json_val = scalar_to_json(original);
let restored = json_to_scalar(&json_val).unwrap();
let orig_str = format!("{original:?}");
let rest_str = format!("{restored:?}");
match original {
ScalarValue::Null => {
assert!(
matches!(restored, ScalarValue::Null),
"{orig_str} != {rest_str}"
);
}
ScalarValue::Boolean(v) => {
assert_eq!(
*v,
match restored {
ScalarValue::Boolean(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::Int8(Some(n)) => {
assert_eq!(
Some(i64::from(*n)),
match restored {
ScalarValue::Int64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::Int16(Some(n)) => {
assert_eq!(
Some(i64::from(*n)),
match restored {
ScalarValue::Int64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::Int32(Some(n)) => {
assert_eq!(
Some(i64::from(*n)),
match restored {
ScalarValue::Int64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::Int64(v) => {
assert_eq!(
*v,
match restored {
ScalarValue::Int64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::UInt8(Some(n)) => {
assert_eq!(
Some(u64::from(*n)),
match restored {
ScalarValue::UInt64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::UInt16(Some(n)) => {
assert_eq!(
Some(u64::from(*n)),
match restored {
ScalarValue::UInt64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::UInt32(Some(n)) => {
assert_eq!(
Some(u64::from(*n)),
match restored {
ScalarValue::UInt64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::UInt64(v) => {
assert_eq!(
*v,
match restored {
ScalarValue::UInt64(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
ScalarValue::Float32(Some(f)) => {
let ScalarValue::Float64(restored_f) = restored else {
panic!("type mismatch: {rest_str}")
};
assert!(
(f64::from(*f) - restored_f.unwrap()).abs() < 1e-6,
"{orig_str} != {rest_str}"
);
}
ScalarValue::Float64(v) => {
let ScalarValue::Float64(restored_f) = restored else {
panic!("type mismatch: {rest_str}")
};
assert_eq!(*v, restored_f, "{orig_str} != {rest_str}");
}
ScalarValue::Utf8(v) => {
assert_eq!(
*v,
match restored {
ScalarValue::Utf8(r) => r,
_ => panic!("type mismatch: {rest_str}"),
}
);
}
_ => {}
}
}
}
#[tokio::test]
async fn test_agg_checkpoint_roundtrip_single_group() {
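// Checkpoint after partial aggregation, restore into a fresh state built from
// the same SQL, and confirm the accumulated SUM survives.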
let (_, mut state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "a"])),
Arc::new(arrow::array::Float64Array::from(vec![10.0, 20.0])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let cp = state.checkpoint_groups().unwrap();
assert_eq!(cp.groups.len(), 1);
let (_, mut state2) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let restored = state2.restore_groups(&cp).unwrap();
assert_eq!(restored, 1);
let result = state2.emit().unwrap();
assert_eq!(result.len(), 1);
let total = result[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.unwrap();
assert!(
(total.value(0) - 30.0).abs() < f64::EPSILON,
"Restored SUM should be 30, got {}",
total.value(0)
);
}
#[tokio::test]
async fn test_agg_checkpoint_roundtrip_multi_group() {
let (_, mut state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec![
"a", "b", "a", "b", "c",
])),
Arc::new(arrow::array::Float64Array::from(vec![
10.0, 20.0, 30.0, 40.0, 50.0,
])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let cp = state.checkpoint_groups().unwrap();
assert_eq!(cp.groups.len(), 3);
let (_, mut state2) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let restored = state2.restore_groups(&cp).unwrap();
assert_eq!(restored, 3);
let result = state2.emit().unwrap();
assert_eq!(result[0].num_rows(), 3);
}
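// A hedged sketch, not part of the original suite: AVG carries multi-field
// accumulator state (running sum plus count), so a checkpoint/restore
// roundtrip exercises more than the single-slot SUM cases above. Assumes the
// same pre-agg layout ("__agg_input_1" as Float64) that setup_agg_state
// produced for the SUM tests, and that checkpointing serializes every
// accumulator state field.
#[tokio::test]
async fn test_agg_checkpoint_roundtrip_avg_sketch() {
let (_, mut state) =
setup_agg_state("SELECT name, AVG(value) as avg_val FROM events GROUP BY name").await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a", "a"])),
Arc::new(arrow::array::Float64Array::from(vec![10.0, 20.0])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let cp = state.checkpoint_groups().unwrap();
let (_, mut state2) =
setup_agg_state("SELECT name, AVG(value) as avg_val FROM events GROUP BY name").await;
state2.restore_groups(&cp).unwrap();
let result = state2.emit().unwrap();
let avg = result[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.unwrap();
assert!(
(avg.value(0) - 15.0).abs() < f64::EPSILON,
"Restored AVG should be 15, got {}",
avg.value(0)
);
}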
#[tokio::test]
async fn test_checkpoint_empty_state_returns_none() {
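// With no streaming state registered, snapshot_state should yield None rather
// than an empty checkpoint blob.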
use crate::stream_executor::StreamExecutor;
let ctx = laminar_sql::create_session_context();
laminar_sql::register_streaming_functions(&ctx);
let mut executor = StreamExecutor::new(ctx);
let result = executor.snapshot_state().unwrap();
assert!(result.is_none());
}
#[test]
fn test_restore_corrupt_bytes_returns_error() {
use crate::stream_executor::StreamExecutor;
let ctx = laminar_sql::create_session_context();
let mut executor = StreamExecutor::new(ctx);
let result = executor.restore_state(b"not valid json");
assert!(result.is_err());
}
#[tokio::test]
async fn test_restore_fingerprint_mismatch_errors() {
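// The fingerprint ties a checkpoint to the query shape that produced it; a
// tampered value must be rejected rather than silently restored.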
let (_, mut state) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let pre_agg_schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("__agg_input_1", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
Arc::clone(&pre_agg_schema),
vec![
Arc::new(arrow::array::StringArray::from(vec!["a"])),
Arc::new(arrow::array::Float64Array::from(vec![10.0])),
],
)
.unwrap();
state.process_batch(&batch).unwrap();
let mut cp = state.checkpoint_groups().unwrap();
cp.fingerprint = 999_999;
let (_, mut state2) =
setup_agg_state("SELECT name, SUM(value) as total FROM events GROUP BY name").await;
let result = state2.restore_groups(&cp);
assert!(result.is_err(), "Fingerprint mismatch should error");
let err = result.unwrap_err().to_string();
assert!(
err.contains("fingerprint mismatch"),
"Error should mention fingerprint: {err}"
);
}
#[test]
fn test_checkpoint_backward_compat_without_join_states() {
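// A version-1 checkpoint predates join_states; deserialization should fill
// the missing field with its default so older snapshots still load.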
let json = r#"{
"version": 1,
"agg_states": {},
"eowc_states": {},
"core_window_states": {}
}"#;
let cp: StreamExecutorCheckpoint = serde_json::from_str(json).unwrap();
assert_eq!(cp.version, 1);
assert!(cp.join_states.is_empty());
}
#[test]
fn test_checkpoint_with_join_states_round_trip() {
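// Full serde round-trip of a v2 checkpoint carrying serialized join buffers
// and the last eviction watermark.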
let mut join_states = FxHashMap::default();
join_states.insert(
"enriched".to_string(),
JoinStateCheckpoint {
left_buffer_rows: 100,
right_buffer_rows: 50,
left_batches: vec![vec![1, 2, 3]],
right_batches: vec![vec![4, 5, 6]],
last_evicted_watermark: 42,
},
);
let cp = StreamExecutorCheckpoint {
version: 2,
vnode_count: laminar_core::state::VNODE_COUNT,
agg_states: FxHashMap::default(),
eowc_states: FxHashMap::default(),
core_window_states: FxHashMap::default(),
join_states,
raw_eowc_states: FxHashMap::default(),
};
let bytes = serde_json::to_vec(&cp).unwrap();
let restored: StreamExecutorCheckpoint = serde_json::from_slice(&bytes).unwrap();
assert_eq!(restored.join_states.len(), 1);
let js = restored.join_states.get("enriched").unwrap();
assert_eq!(js.left_buffer_rows, 100);
assert_eq!(js.right_buffer_rows, 50);
assert_eq!(js.left_batches, vec![vec![1, 2, 3]]);
assert_eq!(js.right_batches, vec![vec![4, 5, 6]]);
assert_eq!(js.last_evicted_watermark, 42);
}
}