/// Source of tables for [`QueryExecutor`].
///
/// Implementations map a canonical (qualified) table name to a shared
/// [`ExecutorTable`] handle.
pub trait TableProvider<P: Pager<Blob = EntryHandle> + Send + Sync> {
    /// Look up the table registered under `canonical_name`, returning an error
    /// when the lookup fails.
    fn get_table(&self, canonical_name: &str) -> ExecutorResult<Arc<ExecutorTable<P>>>;
}
/// Executes `SELECT` plans against tables obtained from a [`TableProvider`].
pub struct QueryExecutor<P: Pager<Blob = EntryHandle> + Send + Sync> {
    /// Shared, dynamically dispatched table lookup.
    provider: Arc<dyn TableProvider<P>>,
}
impl<P> QueryExecutor<P>
where
P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
{
pub fn new(provider: Arc<dyn TableProvider<P>>) -> Self {
Self { provider }
}
/// Run `plan` without any row-id (MVCC) filter.
pub fn execute_select(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
    let no_row_filter = None;
    self.execute_select_with_filter(plan, no_row_filter)
}
/// Dispatch `plan` to the matching execution strategy.
///
/// * no tables: a constant `SELECT` without `FROM`;
/// * more than one table: a cross product;
/// * exactly one table: plain aggregates, computed aggregates, or a
///   projection scan, in that order of precedence.
///
/// `row_filter` is only forwarded on the single-table paths; the no-table and
/// cross-product paths do not receive it.
pub fn execute_select_with_filter(
    &self,
    plan: SelectPlan,
    row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
) -> ExecutorResult<SelectExecution<P>> {
    match plan.tables.len() {
        0 => return self.execute_select_without_table(plan),
        1 => {}
        _ => return self.execute_cross_product(plan),
    }

    let qualified_name = plan.tables[0].qualified_name();
    let table = self.provider.get_table(&qualified_name)?;

    if !plan.aggregates.is_empty() {
        return self.execute_aggregates(table, qualified_name, plan, row_filter);
    }
    if self.has_computed_aggregates(&plan) {
        return self.execute_computed_aggregates(table, qualified_name, plan, row_filter);
    }
    self.execute_projection(table, qualified_name, plan, row_filter)
}
/// Returns `true` when any computed projection of `plan` contains an aggregate call.
fn has_computed_aggregates(&self, plan: &SelectPlan) -> bool {
    plan.projections.iter().any(|projection| {
        matches!(
            projection,
            SelectProjection::Computed { expr, .. } if Self::expr_contains_aggregate(expr)
        )
    })
}
/// Recursively searches `expr` for an aggregate call.
fn expr_contains_aggregate(expr: &ScalarExpr<String>) -> bool {
    match expr {
        ScalarExpr::Column(_) | ScalarExpr::Literal(_) => false,
        ScalarExpr::Aggregate(_) => true,
        ScalarExpr::GetField { base, .. } => Self::expr_contains_aggregate(base),
        ScalarExpr::Binary { left, right, .. } => {
            // Left operand first; `any` short-circuits just like `||`.
            [left, right]
                .into_iter()
                .any(|operand| Self::expr_contains_aggregate(operand))
        }
    }
}
fn execute_select_without_table(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
use arrow::array::ArrayRef;
use arrow::datatypes::Field;
let mut fields = Vec::new();
let mut arrays: Vec<ArrayRef> = Vec::new();
for proj in &plan.projections {
match proj {
SelectProjection::Computed { expr, alias } => {
let (field_name, dtype, array) = match expr {
ScalarExpr::Literal(lit) => {
let (dtype, array) = Self::literal_to_array(lit)?;
(alias.clone(), dtype, array)
}
_ => {
return Err(Error::InvalidArgumentError(
"SELECT without FROM only supports literal expressions".into(),
));
}
};
fields.push(Field::new(field_name, dtype, true));
arrays.push(array);
}
_ => {
return Err(Error::InvalidArgumentError(
"SELECT without FROM only supports computed projections".into(),
));
}
}
}
let schema = Arc::new(Schema::new(fields));
let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
.map_err(|e| Error::Internal(format!("failed to create record batch: {}", e)))?;
Ok(SelectExecution::new_single_batch(
String::new(), schema,
batch,
))
}
fn literal_to_array(lit: &llkv_expr::literal::Literal) -> ExecutorResult<(DataType, ArrayRef)> {
use arrow::array::{
ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, StructArray,
new_null_array,
};
use arrow::datatypes::{DataType, Field};
use llkv_expr::literal::Literal;
match lit {
Literal::Integer(v) => {
let val = i64::try_from(*v).unwrap_or(0);
Ok((
DataType::Int64,
Arc::new(Int64Array::from(vec![val])) as ArrayRef,
))
}
Literal::Float(v) => Ok((
DataType::Float64,
Arc::new(Float64Array::from(vec![*v])) as ArrayRef,
)),
Literal::Boolean(v) => Ok((
DataType::Boolean,
Arc::new(BooleanArray::from(vec![*v])) as ArrayRef,
)),
Literal::String(v) => Ok((
DataType::Utf8,
Arc::new(StringArray::from(vec![v.clone()])) as ArrayRef,
)),
Literal::Null => Ok((DataType::Null, new_null_array(&DataType::Null, 1))),
Literal::Struct(struct_fields) => {
let mut inner_fields = Vec::new();
let mut inner_arrays = Vec::new();
for (field_name, field_lit) in struct_fields {
let (field_dtype, field_array) = Self::literal_to_array(field_lit)?;
inner_fields.push(Field::new(field_name.clone(), field_dtype, true));
inner_arrays.push(field_array);
}
let struct_array =
StructArray::try_new(inner_fields.clone().into(), inner_arrays, None).map_err(
|e| Error::Internal(format!("failed to create struct array: {}", e)),
)?;
Ok((
DataType::Struct(inner_fields.into()),
Arc::new(struct_array) as ArrayRef,
))
}
}
}
fn execute_cross_product(&self, plan: SelectPlan) -> ExecutorResult<SelectExecution<P>> {
use arrow::compute::concat_batches;
if plan.tables.len() < 2 {
return Err(Error::InvalidArgumentError(
"cross product requires at least 2 tables".into(),
));
}
let mut tables = Vec::new();
for table_ref in &plan.tables {
let qualified_name = table_ref.qualified_name();
let table = self.provider.get_table(&qualified_name)?;
tables.push((table_ref.clone(), table));
}
if tables.len() > 2 {
return Err(Error::InvalidArgumentError(
"cross products with more than 2 tables not yet supported".into(),
));
}
let (left_ref, left_table) = &tables[0];
let (right_ref, right_table) = &tables[1];
use llkv_join::{JoinOptions, JoinType, TableJoinExt};
let mut result_batches = Vec::new();
left_table.table.join_stream(
&right_table.table,
&[], &JoinOptions {
join_type: JoinType::Inner,
..Default::default()
},
|batch| {
result_batches.push(batch);
},
)?;
let mut combined_fields = Vec::new();
for col in &left_table.schema.columns {
let qualified_name = format!("{}.{}.{}", left_ref.schema, left_ref.table, col.name);
combined_fields.push(arrow::datatypes::Field::new(
qualified_name,
col.data_type.clone(),
col.nullable,
));
}
for col in &right_table.schema.columns {
let qualified_name = format!("{}.{}.{}", right_ref.schema, right_ref.table, col.name);
combined_fields.push(arrow::datatypes::Field::new(
qualified_name,
col.data_type.clone(),
col.nullable,
));
}
let combined_schema = Arc::new(Schema::new(combined_fields));
let mut combined_batch = if result_batches.is_empty() {
RecordBatch::new_empty(Arc::clone(&combined_schema))
} else if result_batches.len() == 1 {
let batch = result_batches.into_iter().next().unwrap();
RecordBatch::try_new(Arc::clone(&combined_schema), batch.columns().to_vec()).map_err(
|e| {
Error::Internal(format!(
"failed to create batch with qualified names: {}",
e
))
},
)?
} else {
let original_batch = concat_batches(&result_batches[0].schema(), &result_batches)
.map_err(|e| Error::Internal(format!("failed to concatenate batches: {}", e)))?;
RecordBatch::try_new(
Arc::clone(&combined_schema),
original_batch.columns().to_vec(),
)
.map_err(|e| {
Error::Internal(format!(
"failed to create batch with qualified names: {}",
e
))
})?
};
if !plan.projections.is_empty() {
let mut selected_fields = Vec::new();
let mut selected_columns = Vec::new();
for proj in &plan.projections {
match proj {
SelectProjection::AllColumns => {
selected_fields = combined_schema.fields().iter().cloned().collect();
selected_columns = combined_batch.columns().to_vec();
break;
}
SelectProjection::AllColumnsExcept { exclude } => {
let exclude_lower: Vec<String> =
exclude.iter().map(|e| e.to_ascii_lowercase()).collect();
for (idx, field) in combined_schema.fields().iter().enumerate() {
let field_name_lower = field.name().to_ascii_lowercase();
if !exclude_lower.contains(&field_name_lower) {
selected_fields.push(field.clone());
selected_columns.push(combined_batch.column(idx).clone());
}
}
break;
}
SelectProjection::Column { name, alias } => {
let col_name = name.to_ascii_lowercase();
if let Some((idx, field)) = combined_schema
.fields()
.iter()
.enumerate()
.find(|(_, f)| f.name().to_ascii_lowercase() == col_name)
{
let output_name = alias.as_ref().unwrap_or(name).clone();
selected_fields.push(Arc::new(arrow::datatypes::Field::new(
output_name,
field.data_type().clone(),
field.is_nullable(),
)));
selected_columns.push(combined_batch.column(idx).clone());
} else {
return Err(Error::InvalidArgumentError(format!(
"column '{}' not found in cross product result",
name
)));
}
}
SelectProjection::Computed { expr, alias } => {
if let ScalarExpr::Column(col_name) = expr {
let col_name_lower = col_name.to_ascii_lowercase();
if let Some((idx, field)) = combined_schema
.fields()
.iter()
.enumerate()
.find(|(_, f)| f.name().to_ascii_lowercase() == col_name_lower)
{
selected_fields.push(Arc::new(arrow::datatypes::Field::new(
alias.clone(),
field.data_type().clone(),
field.is_nullable(),
)));
selected_columns.push(combined_batch.column(idx).clone());
} else {
return Err(Error::InvalidArgumentError(format!(
"column '{}' not found in cross product result",
col_name
)));
}
} else {
return Err(Error::InvalidArgumentError(
"complex computed projections not yet supported in cross products"
.into(),
));
}
}
}
}
let projected_schema = Arc::new(Schema::new(selected_fields));
combined_batch = RecordBatch::try_new(projected_schema, selected_columns)
.map_err(|e| Error::Internal(format!("failed to apply projections: {}", e)))?;
}
Ok(SelectExecution::new_single_batch(
format!(
"{},{}",
left_ref.qualified_name(),
right_ref.qualified_name()
),
combined_batch.schema(),
combined_batch,
))
}
/// Plan a single-table projection scan, optionally filtered and ordered.
///
/// No rows are read here. The scan is deferred until the returned
/// [`SelectExecution`] is streamed.
///
/// * An empty projection list means "all columns".
/// * Without a `WHERE` clause, a full-table-scan filter on the first column is
///   used, and the scan is flagged as a full table scan.
/// * Only the first ORDER BY target is pushed down to the scan.
fn execute_projection(
    &self,
    table: Arc<ExecutorTable<P>>,
    display_name: String,
    plan: SelectPlan,
    row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
) -> ExecutorResult<SelectExecution<P>> {
    let table_ref = table.as_ref();
    let projections = if !plan.projections.is_empty() {
        build_projected_columns(table_ref, &plan.projections)?
    } else {
        build_wildcard_projections(table_ref)
    };
    let schema = schema_for_projections(table_ref, &projections)?;

    let full_table_scan = plan.filter.is_none();
    let filter_expr = if let Some(expr) = plan.filter {
        crate::expression::translate_predicate(
            expr,
            table_ref.schema.as_ref(),
            |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
        )?
    } else {
        let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
            Error::InvalidArgumentError(
                "table has no columns; cannot perform wildcard scan".into(),
            )
        })?;
        crate::expression::full_table_scan_filter(field_id)
    };

    let expanded_order = expand_order_targets(&plan.order_by, &projections)?;
    let physical_order = expanded_order
        .first()
        .map(|first| resolve_scan_order(table_ref, &projections, first))
        .transpose()?;

    if row_filter.is_some() {
        if physical_order.is_some() {
            tracing::debug!("Applying MVCC row filter with ORDER BY");
        } else {
            tracing::debug!("Applying MVCC row filter");
        }
    }
    let options = ScanStreamOptions {
        include_nulls: true,
        order: physical_order,
        row_id_filter: row_filter,
    };

    Ok(SelectExecution::new_projection(
        display_name,
        schema,
        table,
        projections,
        filter_expr,
        options,
        full_table_scan,
        expanded_order,
    ))
}
/// Evaluate the plan's plain aggregate list (`COUNT(*)`, `COUNT`, `COUNT(DISTINCT)`,
/// `SUM`, `MIN`, `MAX`, `COUNT_NULLS`) over one table and return a single-row batch.
///
/// When there is neither a `WHERE` clause nor a row-id filter, the table's
/// `total_rows` counter is passed to the accumulators as the `COUNT(*)` override.
///
/// # Errors
/// `InvalidArgumentError` for an unknown column, `SUM`/`MIN`/`MAX` on a non-INT64
/// column, `DISTINCT` with `COUNT_NULLS`, an empty aggregate list, or a row count
/// that does not fit in `i64`. Scan and accumulator errors are propagated, except
/// `NotFound` from the scan, which is ignored.
fn execute_aggregates(
    &self,
    table: Arc<ExecutorTable<P>>,
    display_name: String,
    plan: SelectPlan,
    row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
) -> ExecutorResult<SelectExecution<P>> {
    let table_ref = table.as_ref();

    // Translate each planned aggregate into an executable spec.
    let mut specs: Vec<AggregateSpec> = Vec::with_capacity(plan.aggregates.len());
    for aggregate in plan.aggregates {
        let spec = match aggregate {
            AggregateExpr::CountStar { alias } => AggregateSpec {
                alias,
                kind: AggregateKind::CountStar,
            },
            AggregateExpr::Column {
                column,
                alias,
                function,
                distinct,
            } => {
                let col = table_ref.schema.resolve(&column).ok_or_else(|| {
                    Error::InvalidArgumentError(format!(
                        "unknown column '{}' in aggregate",
                        column
                    ))
                })?;
                let field_id = col.field_id;
                let is_int64 = col.data_type == DataType::Int64;
                let kind = match function {
                    AggregateFunction::Count if distinct => {
                        AggregateKind::CountDistinctField { field_id }
                    }
                    AggregateFunction::Count => AggregateKind::CountField { field_id },
                    AggregateFunction::SumInt64 if !is_int64 => {
                        return Err(Error::InvalidArgumentError(
                            "SUM currently supports only INTEGER columns".into(),
                        ));
                    }
                    AggregateFunction::SumInt64 => AggregateKind::SumInt64 { field_id },
                    AggregateFunction::MinInt64 if !is_int64 => {
                        return Err(Error::InvalidArgumentError(
                            "MIN currently supports only INTEGER columns".into(),
                        ));
                    }
                    AggregateFunction::MinInt64 => AggregateKind::MinInt64 { field_id },
                    AggregateFunction::MaxInt64 if !is_int64 => {
                        return Err(Error::InvalidArgumentError(
                            "MAX currently supports only INTEGER columns".into(),
                        ));
                    }
                    AggregateFunction::MaxInt64 => AggregateKind::MaxInt64 { field_id },
                    AggregateFunction::CountNulls if distinct => {
                        return Err(Error::InvalidArgumentError(
                            "DISTINCT is not supported for COUNT_NULLS".into(),
                        ));
                    }
                    AggregateFunction::CountNulls => AggregateKind::CountNulls { field_id },
                };
                AggregateSpec { alias, kind }
            }
        };
        specs.push(spec);
    }
    if specs.is_empty() {
        return Err(Error::InvalidArgumentError(
            "aggregate query requires at least one aggregate expression".into(),
        ));
    }

    let had_filter = plan.filter.is_some();
    let filter_expr = if let Some(expr) = plan.filter {
        crate::expression::translate_predicate(
            expr,
            table.schema.as_ref(),
            |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
        )?
    } else {
        let field_id = table.schema.first_field_id().ok_or_else(|| {
            Error::InvalidArgumentError(
                "table has no columns; cannot perform aggregate scan".into(),
            )
        })?;
        crate::expression::full_table_scan_filter(field_id)
    };

    // Build one scan projection per column-backed spec and remember where each
    // spec's input lives; `COUNT(*)` specs have no input column.
    let projection_for = |field_id| {
        ScanProjection::from(StoreProjection::with_alias(
            LogicalFieldId::for_user(table.table.table_id(), field_id),
            table
                .schema
                .column_by_field_id(field_id)
                .map(|c| c.name.clone())
                .unwrap_or_else(|| format!("col{field_id}")),
        ))
    };
    let mut projections = Vec::new();
    let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(specs.len());
    for input_field in specs.iter().map(|spec| spec.kind.field_id()) {
        spec_to_projection.push(input_field.map(|field_id| {
            projections.push(projection_for(field_id));
            projections.len() - 1
        }));
    }
    if projections.is_empty() {
        // Always give the scan at least one column to read.
        let field_id = table.schema.first_field_id().ok_or_else(|| {
            Error::InvalidArgumentError(
                "table has no columns; cannot perform aggregate scan".into(),
            )
        })?;
        projections.push(projection_for(field_id));
    }

    let count_star_override = if !had_filter && row_filter.is_none() {
        let total_rows = table.total_rows.load(Ordering::SeqCst);
        tracing::debug!(
            "[AGGREGATE] Using COUNT(*) shortcut: total_rows={}",
            total_rows
        );
        let total_rows = i64::try_from(total_rows).map_err(|_| {
            Error::InvalidArgumentError("COUNT(*) result exceeds supported range".into())
        })?;
        Some(total_rows)
    } else {
        tracing::debug!(
            "[AGGREGATE] NOT using COUNT(*) shortcut: had_filter={}, has_row_filter={}",
            had_filter,
            row_filter.is_some()
        );
        None
    };

    let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
    for (spec, &projection_index) in specs.iter().zip(&spec_to_projection) {
        let accumulator = AggregateAccumulator::new_with_projection_index(
            spec,
            projection_index,
            count_star_override,
        )?;
        let override_value = if matches!(spec.kind, AggregateKind::CountStar) {
            tracing::debug!(
                "[AGGREGATE] CountStar override_value={:?}",
                count_star_override
            );
            count_star_override
        } else {
            None
        };
        states.push(AggregateState {
            alias: spec.alias.clone(),
            accumulator,
            override_value,
        });
    }

    let options = ScanStreamOptions {
        include_nulls: true,
        order: None,
        row_id_filter: row_filter,
    };
    // The scan callback cannot return an error, so the first failure is parked here.
    let mut error: Option<Error> = None;
    let scan_result = table
        .table
        .scan_stream(projections, &filter_expr, options, |batch| {
            if error.is_none() {
                error = states
                    .iter_mut()
                    .find_map(|state| state.update(&batch).err());
            }
        });
    match scan_result {
        Ok(()) | Err(llkv_result::Error::NotFound) => {}
        Err(err) => return Err(err),
    }
    if let Some(err) = error {
        return Err(err);
    }

    let finalized = states
        .into_iter()
        .map(|state| state.finalize())
        .collect::<ExecutorResult<Vec<_>>>()?;
    let (fields, arrays): (Vec<_>, Vec<ArrayRef>) = finalized.into_iter().unzip();
    let schema = Arc::new(Schema::new(fields));
    let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
    Ok(SelectExecution::new_single_batch(display_name, schema, batch))
}
/// Evaluate a projection list whose computed expressions embed aggregate calls,
/// for example `SELECT SUM(a) + 1, COUNT(*) * 2 FROM t`, producing one row.
///
/// Every projection must be a computed expression. Wildcards and bare columns
/// are rejected before the table is scanned.
///
/// Each distinct aggregate call is computed once by
/// [`Self::compute_aggregate_values`]; the surrounding integer arithmetic is then
/// evaluated per projection.
///
/// Output columns are nullable `Int64`. A NULL aggregate result, or a `NULL`
/// literal, propagates through the arithmetic as SQL requires instead of being
/// treated as `0`.
///
/// # Errors
/// * `InvalidArgumentError` for unsupported projections, unknown columns, or
///   arithmetic failures (division/modulo by zero, overflow, out-of-range literals).
/// * `Internal` if an aggregate does not yield exactly one `Int64` value.
fn execute_computed_aggregates(
    &self,
    table: Arc<ExecutorTable<P>>,
    display_name: String,
    plan: SelectPlan,
    row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
) -> ExecutorResult<SelectExecution<P>> {
    use arrow::array::Int64Array;
    use arrow::datatypes::Field;
    use llkv_expr::expr::AggregateCall;

    let table_ref = table.as_ref();

    // Validate the projection list and gather the distinct aggregate calls
    // before any scan work, so rejected queries never touch the table.
    let mut aggregate_specs: Vec<(String, AggregateCall<String>)> = Vec::new();
    for proj in &plan.projections {
        match proj {
            SelectProjection::AllColumns | SelectProjection::AllColumnsExcept { .. } => {
                return Err(Error::InvalidArgumentError(
                    "Wildcard projections not supported with computed aggregates".into(),
                ));
            }
            SelectProjection::Column { name, .. } => {
                // Report an unknown column as such before the generic rejection.
                table_ref.schema.resolve(name).ok_or_else(|| {
                    Error::InvalidArgumentError(format!("unknown column '{}'", name))
                })?;
                return Err(Error::InvalidArgumentError(
                    "Regular columns not supported in aggregate queries without GROUP BY"
                        .into(),
                ));
            }
            SelectProjection::Computed { expr, .. } => {
                Self::collect_aggregates(expr, &mut aggregate_specs);
            }
        }
    }

    let computed_aggregates = self.compute_aggregate_values(
        Arc::clone(&table),
        &plan.filter,
        &aggregate_specs,
        row_filter,
    )?;

    let mut fields = Vec::with_capacity(plan.projections.len());
    let mut arrays: Vec<ArrayRef> = Vec::with_capacity(plan.projections.len());
    for proj in &plan.projections {
        // Every non-computed projection was rejected above.
        if let SelectProjection::Computed { expr, alias } = proj {
            let value = Self::evaluate_expr_with_aggregates(expr, &computed_aggregates)?;
            fields.push(Field::new(alias, DataType::Int64, true));
            arrays.push(Arc::new(Int64Array::from(vec![value])) as ArrayRef);
        }
    }

    let schema = Arc::new(Schema::new(fields));
    let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)?;
    Ok(SelectExecution::new_single_batch(
        display_name,
        schema,
        batch,
    ))
}

/// Append every distinct aggregate call found in `expr` to `aggregates`.
///
/// Calls are keyed by their `Debug` rendering; a call whose key is already
/// present is not added again.
fn collect_aggregates(
    expr: &ScalarExpr<String>,
    aggregates: &mut Vec<(String, llkv_expr::expr::AggregateCall<String>)>,
) {
    match expr {
        ScalarExpr::Aggregate(agg) => {
            let key = format!("{:?}", agg);
            if !aggregates.iter().any(|(k, _)| k == &key) {
                aggregates.push((key, agg.clone()));
            }
        }
        ScalarExpr::Binary { left, right, .. } => {
            Self::collect_aggregates(left, aggregates);
            Self::collect_aggregates(right, aggregates);
        }
        ScalarExpr::GetField { base, .. } => {
            Self::collect_aggregates(base, aggregates);
        }
        ScalarExpr::Column(_) | ScalarExpr::Literal(_) => {}
    }
}

/// Compute every aggregate in `aggregate_specs` with a single table scan.
///
/// The returned map uses the same keys as the input specs (see
/// [`Self::collect_aggregates`]). A `None` value means the aggregate's
/// accumulator produced NULL.
///
/// No `COUNT(*)` row-count override is supplied here, so `COUNT(*)` is left to
/// the accumulator and the scan.
fn compute_aggregate_values(
    &self,
    table: Arc<ExecutorTable<P>>,
    filter: &Option<llkv_expr::expr::Expr<'static, String>>,
    aggregate_specs: &[(String, llkv_expr::expr::AggregateCall<String>)],
    row_filter: Option<std::sync::Arc<dyn RowIdFilter<P>>>,
) -> ExecutorResult<FxHashMap<String, Option<i64>>> {
    use llkv_expr::expr::AggregateCall;

    let table_ref = table.as_ref();
    // Resolve an aggregate's argument column to its field id.
    let field_id_of = |col_name: &String| {
        table_ref
            .schema
            .resolve(col_name)
            .map(|col| col.field_id)
            .ok_or_else(|| Error::InvalidArgumentError(format!("unknown column '{}'", col_name)))
    };

    let mut specs: Vec<AggregateSpec> = Vec::with_capacity(aggregate_specs.len());
    for (key, agg) in aggregate_specs {
        let kind = match agg {
            AggregateCall::CountStar => AggregateKind::CountStar,
            AggregateCall::Count(col_name) => AggregateKind::CountField {
                field_id: field_id_of(col_name)?,
            },
            AggregateCall::Sum(col_name) => AggregateKind::SumInt64 {
                field_id: field_id_of(col_name)?,
            },
            AggregateCall::Min(col_name) => AggregateKind::MinInt64 {
                field_id: field_id_of(col_name)?,
            },
            AggregateCall::Max(col_name) => AggregateKind::MaxInt64 {
                field_id: field_id_of(col_name)?,
            },
            AggregateCall::CountNulls(col_name) => AggregateKind::CountNulls {
                field_id: field_id_of(col_name)?,
            },
        };
        specs.push(AggregateSpec {
            alias: key.clone(),
            kind,
        });
    }

    let filter_expr = match filter {
        Some(expr) => crate::expression::translate_predicate(
            expr.clone(),
            table_ref.schema.as_ref(),
            |name| Error::InvalidArgumentError(format!("unknown column '{}'", name)),
        )?,
        None => {
            let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
                Error::InvalidArgumentError(
                    "table has no columns; cannot perform aggregate scan".into(),
                )
            })?;
            crate::expression::full_table_scan_filter(field_id)
        }
    };

    let projection_for = |field_id| {
        ScanProjection::from(StoreProjection::with_alias(
            LogicalFieldId::for_user(table.table.table_id(), field_id),
            table
                .schema
                .column_by_field_id(field_id)
                .map(|c| c.name.clone())
                .unwrap_or_else(|| format!("col{field_id}")),
        ))
    };
    let mut projections: Vec<ScanProjection> = Vec::new();
    let mut spec_to_projection: Vec<Option<usize>> = Vec::with_capacity(specs.len());
    for spec in &specs {
        match spec.kind.field_id() {
            Some(field_id) => {
                spec_to_projection.push(Some(projections.len()));
                projections.push(projection_for(field_id));
            }
            None => spec_to_projection.push(None),
        }
    }
    if projections.is_empty() {
        // Always give the scan at least one column to read.
        let field_id = table_ref.schema.first_field_id().ok_or_else(|| {
            Error::InvalidArgumentError(
                "table has no columns; cannot perform aggregate scan".into(),
            )
        })?;
        projections.push(projection_for(field_id));
    }

    let mut states: Vec<AggregateState> = Vec::with_capacity(specs.len());
    for (spec, &projection_index) in specs.iter().zip(&spec_to_projection) {
        states.push(AggregateState {
            alias: spec.alias.clone(),
            accumulator: AggregateAccumulator::new_with_projection_index(
                spec,
                projection_index,
                None,
            )?,
            override_value: None,
        });
    }

    let options = ScanStreamOptions {
        include_nulls: true,
        order: None,
        row_id_filter: row_filter,
    };
    // The scan callback cannot return an error, so the first failure is parked here.
    let mut error: Option<Error> = None;
    let scan_result = table
        .table
        .scan_stream(projections, &filter_expr, options, |batch| {
            if error.is_none() {
                error = states
                    .iter_mut()
                    .find_map(|state| state.update(&batch).err());
            }
        });
    match scan_result {
        // `NotFound` from the scan is ignored; accumulators keep their initial state.
        Ok(()) | Err(llkv_result::Error::NotFound) => {}
        Err(err) => return Err(err),
    }
    if let Some(err) = error {
        return Err(err);
    }

    let mut results =
        FxHashMap::with_capacity_and_hasher(states.len(), Default::default());
    for state in states {
        let alias = state.alias.clone();
        let (_field, array) = state.finalize()?;
        let values = array
            .as_any()
            .downcast_ref::<arrow::array::Int64Array>()
            .ok_or_else(|| Error::Internal("Expected Int64Array from aggregate".into()))?;
        if values.len() != 1 {
            return Err(Error::Internal(format!(
                "Expected single value from aggregate, got {}",
                values.len()
            )));
        }
        // Preserve SQL NULL instead of coercing it to 0.
        let value = if values.is_null(0) {
            None
        } else {
            Some(values.value(0))
        };
        results.insert(alias, value);
    }
    Ok(results)
}

/// Evaluate integer arithmetic over precomputed aggregate values.
///
/// `None` represents SQL NULL. It comes from a NULL aggregate result or a
/// `NULL` literal, and any arithmetic with a NULL operand yields NULL. Float
/// literals are truncated with `as i64`; booleans map to `0`/`1`.
///
/// # Errors
/// * `InvalidArgumentError` for string/struct literals, column references,
///   field access, integer literals outside `i64`, division or modulo by zero
///   (with non-NULL operands), and overflow.
/// * `Internal` if an aggregate was not precomputed.
fn evaluate_expr_with_aggregates(
    expr: &ScalarExpr<String>,
    aggregates: &FxHashMap<String, Option<i64>>,
) -> ExecutorResult<Option<i64>> {
    use llkv_expr::expr::BinaryOp;
    use llkv_expr::literal::Literal;
    match expr {
        ScalarExpr::Literal(Literal::Integer(v)) => i64::try_from(*v).map(Some).map_err(|_| {
            Error::InvalidArgumentError(format!(
                "integer literal {} is out of range for INT64",
                v
            ))
        }),
        ScalarExpr::Literal(Literal::Float(v)) => Ok(Some(*v as i64)),
        ScalarExpr::Literal(Literal::Boolean(v)) => Ok(Some(i64::from(*v))),
        ScalarExpr::Literal(Literal::Null) => Ok(None),
        ScalarExpr::Literal(Literal::String(_)) => Err(Error::InvalidArgumentError(
            "String literals not supported in aggregate expressions".into(),
        )),
        ScalarExpr::Literal(Literal::Struct(_)) => Err(Error::InvalidArgumentError(
            "Struct literals not supported in aggregate expressions".into(),
        )),
        ScalarExpr::Column(_) => Err(Error::InvalidArgumentError(
            "Column references not supported in aggregate-only expressions".into(),
        )),
        ScalarExpr::Aggregate(agg) => {
            let key = format!("{:?}", agg);
            aggregates.get(&key).copied().ok_or_else(|| {
                Error::Internal(format!("Aggregate value not found for key: {}", key))
            })
        }
        ScalarExpr::Binary { left, op, right } => {
            // Evaluate both sides first so errors surface even next to a NULL.
            let left_val = Self::evaluate_expr_with_aggregates(left, aggregates)?;
            let right_val = Self::evaluate_expr_with_aggregates(right, aggregates)?;
            let (lhs, rhs) = match (left_val, right_val) {
                (Some(lhs), Some(rhs)) => (lhs, rhs),
                _ => return Ok(None),
            };
            let result = match op {
                BinaryOp::Add => lhs.checked_add(rhs),
                BinaryOp::Subtract => lhs.checked_sub(rhs),
                BinaryOp::Multiply => lhs.checked_mul(rhs),
                BinaryOp::Divide => {
                    if rhs == 0 {
                        return Err(Error::InvalidArgumentError("Division by zero".into()));
                    }
                    lhs.checked_div(rhs)
                }
                BinaryOp::Modulo => {
                    if rhs == 0 {
                        return Err(Error::InvalidArgumentError("Modulo by zero".into()));
                    }
                    lhs.checked_rem(rhs)
                }
            };
            result.map(Some).ok_or_else(|| {
                Error::InvalidArgumentError("Arithmetic overflow in expression".into())
            })
        }
        ScalarExpr::GetField { .. } => Err(Error::InvalidArgumentError(
            "GetField not supported in aggregate-only expressions".into(),
        )),
    }
}
}
/// A planned `SELECT` whose result batches are produced by [`SelectExecution::stream`].
#[derive(Clone)]
pub struct SelectExecution<P: Pager<Blob = EntryHandle> + Send + Sync> {
    /// Name reported to callers: a qualified table name, a comma-joined pair
    /// for cross products, or empty for `SELECT` without `FROM`.
    table_name: String,
    /// Schema of every batch this execution produces.
    schema: Arc<Schema>,
    /// How the batches are produced.
    stream: SelectStream<P>,
}
/// Row source behind a [`SelectExecution`].
#[derive(Clone)]
enum SelectStream<P: Pager<Blob = EntryHandle> + Send + Sync> {
    /// A deferred table scan, run when the execution is streamed.
    Projection {
        table: Arc<ExecutorTable<P>>,
        projections: Vec<ScanProjection>,
        filter_expr: LlkvExpr<'static, FieldId>,
        options: ScanStreamOptions<P>,
        /// `true` when no `WHERE` clause was supplied.
        full_table_scan: bool,
        order_by: Vec<OrderByPlan>,
    },
    /// A single precomputed batch (aggregates, cross products, `SELECT` without `FROM`).
    Aggregation { batch: RecordBatch },
}
impl<P> SelectExecution<P>
where
P: Pager<Blob = EntryHandle> + Send + Sync,
{
#[allow(clippy::too_many_arguments)]
fn new_projection(
table_name: String,
schema: Arc<Schema>,
table: Arc<ExecutorTable<P>>,
projections: Vec<ScanProjection>,
filter_expr: LlkvExpr<'static, FieldId>,
options: ScanStreamOptions<P>,
full_table_scan: bool,
order_by: Vec<OrderByPlan>,
) -> Self {
Self {
table_name,
schema,
stream: SelectStream::Projection {
table,
projections,
filter_expr,
options,
full_table_scan,
order_by,
},
}
}
pub fn new_single_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
Self {
table_name,
schema,
stream: SelectStream::Aggregation { batch },
}
}
pub fn from_batch(table_name: String, schema: Arc<Schema>, batch: RecordBatch) -> Self {
Self::new_single_batch(table_name, schema, batch)
}
pub fn table_name(&self) -> &str {
&self.table_name
}
pub fn schema(&self) -> Arc<Schema> {
Arc::clone(&self.schema)
}
pub fn stream(
self,
mut on_batch: impl FnMut(RecordBatch) -> ExecutorResult<()>,
) -> ExecutorResult<()> {
let schema = Arc::clone(&self.schema);
match self.stream {
SelectStream::Projection {
table,
projections,
filter_expr,
options,
full_table_scan,
order_by,
} => {
let total_rows = table.total_rows.load(Ordering::SeqCst);
if total_rows == 0 {
return Ok(());
}
let mut error: Option<Error> = None;
let mut produced = false;
let mut produced_rows: u64 = 0;
let capture_nulls_first = matches!(options.order, Some(spec) if spec.nulls_first);
let needs_post_sort = order_by.len() > 1;
let collect_batches = needs_post_sort || capture_nulls_first;
let include_nulls = options.include_nulls;
let has_row_id_filter = options.row_id_filter.is_some();
let scan_options = options;
let mut buffered_batches: Vec<RecordBatch> = Vec::new();
table
.table
.scan_stream(projections, &filter_expr, scan_options, |batch| {
if error.is_some() {
return;
}
produced = true;
produced_rows = produced_rows.saturating_add(batch.num_rows() as u64);
if collect_batches {
buffered_batches.push(batch);
} else if let Err(err) = on_batch(batch) {
error = Some(err);
}
})?;
if let Some(err) = error {
return Err(err);
}
if !produced {
if total_rows > 0 {
for batch in synthesize_null_scan(Arc::clone(&schema), total_rows)? {
on_batch(batch)?;
}
}
return Ok(());
}
let mut null_batches: Vec<RecordBatch> = Vec::new();
if include_nulls
&& full_table_scan
&& produced_rows < total_rows
&& !has_row_id_filter
{
let missing = total_rows - produced_rows;
if missing > 0 {
null_batches = synthesize_null_scan(Arc::clone(&schema), missing)?;
}
}
if collect_batches {
if needs_post_sort {
if !null_batches.is_empty() {
buffered_batches.extend(null_batches);
}
if !buffered_batches.is_empty() {
let combined =
concat_batches(&schema, &buffered_batches).map_err(|err| {
Error::InvalidArgumentError(format!(
"failed to concatenate result batches for ORDER BY: {}",
err
))
})?;
let sorted_batch =
sort_record_batch_with_order(&schema, &combined, &order_by)?;
on_batch(sorted_batch)?;
}
} else if capture_nulls_first {
for batch in null_batches {
on_batch(batch)?;
}
for batch in buffered_batches {
on_batch(batch)?;
}
}
} else if !null_batches.is_empty() {
for batch in null_batches {
on_batch(batch)?;
}
}
Ok(())
}
SelectStream::Aggregation { batch } => on_batch(batch),
}
}
pub fn collect(self) -> ExecutorResult<Vec<RecordBatch>> {
let mut batches = Vec::new();
self.stream(|batch| {
batches.push(batch);
Ok(())
})?;
Ok(batches)
}
pub fn collect_rows(self) -> ExecutorResult<RowBatch> {
let schema = self.schema();
let mut rows: Vec<Vec<PlanValue>> = Vec::new();
self.stream(|batch| {
for row_idx in 0..batch.num_rows() {
let mut row: Vec<PlanValue> = Vec::with_capacity(batch.num_columns());
for col_idx in 0..batch.num_columns() {
let value = llkv_plan::plan_value_from_array(batch.column(col_idx), row_idx)?;
row.push(value);
}
rows.push(row);
}
Ok(())
})?;
let columns = schema
.fields()
.iter()
.map(|field| field.name().to_string())
.collect();
Ok(RowBatch { columns, rows })
}
pub fn into_rows(self) -> ExecutorResult<Vec<Vec<PlanValue>>> {
Ok(self.collect_rows()?.rows)
}
}
impl<P: Pager<Blob = EntryHandle> + Send + Sync> fmt::Debug for SelectExecution<P> {
    /// Shows the table name and schema; the underlying row source is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SelectExecution");
        builder.field("table_name", &self.table_name);
        builder.field("schema", &self.schema);
        builder.finish()
    }
}