use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use arrow::array::RecordBatch;
use arrow::compute;
use bytes::Bytes;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::ProjectionMask;
use tracing::{debug, info};
use super::select::{
ColumnRef, Condition, InputFormat, OutputFormat, ParsedQuery, Record, SelectColumn, SelectError,
};
/// Bounded cache of query plans, keyed by the SQL text supplied to `insert`.
///
/// The map lives behind an `Arc<RwLock<..>>`, so clones of this handle share
/// the same underlying entries.
#[derive(Clone)]
pub struct QueryPlanCache {
/// Cached plans indexed by SQL string, guarded by a reader/writer lock.
cache: Arc<RwLock<HashMap<String, CachedPlan>>>,
/// Capacity bound that `insert` consults when deciding whether to evict.
max_entries: usize,
}
impl QueryPlanCache {
    /// Creates an empty cache that holds at most `max_entries` plans.
    ///
    /// A `max_entries` of zero disables caching: `insert` becomes a no-op.
    pub fn new(max_entries: usize) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_entries,
        }
    }

    /// Returns a clone of the plan stored under `sql`.
    ///
    /// Yields `None` when there is no entry for `sql` or when the lock is
    /// poisoned.
    pub fn get(&self, sql: &str) -> Option<CachedPlan> {
        self.cache.read().ok()?.get(sql).cloned()
    }

    /// Stores `plan` under `sql`, evicting one entry if the cache is full.
    ///
    /// Eviction only happens when a *new* key would push the map past
    /// `max_entries`; overwriting an existing key never removes another
    /// entry. The evicted entry is whichever key the map iterates first,
    /// which for a `HashMap` is arbitrary (not least-recently-used).
    ///
    /// The call is best-effort: it does nothing when the lock is poisoned or
    /// when the cache was created with `max_entries == 0`.
    pub fn insert(&self, sql: String, plan: CachedPlan) {
        if self.max_entries == 0 {
            return;
        }
        if let Ok(mut cache) = self.cache.write() {
            let adds_new_key = !cache.contains_key(&sql);
            if adds_new_key && cache.len() >= self.max_entries {
                if let Some(victim) = cache.keys().next().cloned() {
                    cache.remove(&victim);
                }
            }
            cache.insert(sql, plan);
        }
    }

    /// Removes every cached plan; does nothing if the lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.clear();
        }
    }
}
/// A query plan as stored in `QueryPlanCache`.
#[derive(Clone)]
pub struct CachedPlan {
/// The parsed query this plan was built from.
pub parsed_query: ParsedQuery,
/// Optional list of column indices to project.
/// `OptimizedSelectExecutor::execute` currently stores `None` here.
pub projection_indices: Option<Vec<usize>>,
/// Predicates recorded for pushdown.
/// `OptimizedSelectExecutor::execute` currently stores an empty list here.
pub pushdown_predicates: Vec<PushdownPredicate>,
}
/// A predicate description that can be attached to a `CachedPlan`.
#[derive(Clone, Debug)]
pub enum PushdownPredicate {
/// Comparison of a single column against a value.
ColumnFilter {
/// Name of the column being compared.
column: String,
/// Comparison operator to apply.
op: super::select::CompareOp,
/// Value the column is compared with.
value: super::select::FieldValue,
},
/// Numeric bounds on a single column; either bound may be absent.
/// NOTE(review): whether the bounds are inclusive is not established in
/// this file — confirm against the consumer.
NumericRange {
/// Name of the column being bounded.
column: String,
/// Lower bound, if any.
min: Option<f64>,
/// Upper bound, if any.
max: Option<f64>,
},
}
/// Executes one parsed select query against an in-memory object, with a
/// dedicated Parquet path and an optional shared plan cache.
pub struct OptimizedSelectExecutor {
/// The parsed query to run.
query: ParsedQuery,
/// Format of the input bytes handed to `execute`.
input_format: InputFormat,
/// Format used when serializing the result records.
output_format: OutputFormat,
/// Optional shared plan cache consulted and populated by `execute`.
query_cache: Option<Arc<QueryPlanCache>>,
/// Input size in bytes at or above which Parquet input uses the parallel path.
parallel_threshold: usize, }
impl OptimizedSelectExecutor {
/// Builds an executor for `parsed_query`.
///
/// `query_cache`, when supplied, is consulted and populated by `execute`.
/// The parallel threshold is fixed at 10 MiB of input.
pub fn new(
    parsed_query: ParsedQuery,
    input_format: InputFormat,
    output_format: OutputFormat,
    query_cache: Option<Arc<QueryPlanCache>>,
) -> Self {
    /// Input size (bytes) from which Parquet input is processed in parallel.
    const DEFAULT_PARALLEL_THRESHOLD: usize = 10 * 1024 * 1024;

    Self {
        query: parsed_query,
        input_format,
        output_format,
        query_cache,
        parallel_threshold: DEFAULT_PARALLEL_THRESHOLD,
    }
}
/// Runs the query over `data` and returns the serialized result.
///
/// `sql` is used only as the plan-cache key. On a cache miss the executor's
/// own parsed query is stored under that key; on a hit the cached plan is
/// fetched but not otherwise used. Parquet input at or above the parallel
/// threshold goes through the parallel path, everything else through the
/// sequential one.
///
/// # Errors
///
/// Propagates any `SelectError` raised while parsing, filtering, projecting
/// or serializing.
pub fn execute(&self, data: &[u8], sql: &str) -> Result<Vec<u8>, SelectError> {
    info!(
        size = data.len(),
        format = ?self.input_format,
        "Executing optimized S3 Select query"
    );

    if let Some(cache) = self.query_cache.as_ref() {
        match cache.get(sql) {
            Some(_cached_plan) => {
                debug!("Using cached query plan");
            }
            None => {
                cache.insert(
                    sql.to_string(),
                    CachedPlan {
                        parsed_query: self.query.clone(),
                        projection_indices: None,
                        pushdown_predicates: Vec::new(),
                    },
                );
                debug!("Cached query plan for future use");
            }
        }
    }

    let is_parquet = matches!(self.input_format, InputFormat::Parquet);
    let large_enough = data.len() >= self.parallel_threshold;
    let records = if is_parquet && large_enough {
        self.execute_parallel(data)?
    } else {
        self.execute_optimized(data)?
    };

    super::select::SelectExecutor::serialize_output_static(records, &self.output_format)
}
/// Runs the query sequentially, dispatching on the input format.
///
/// Parquet is handed to `execute_parquet_optimized`. Every other format
/// shares one pipeline on a `SelectExecutor`: parse the bytes into records,
/// filter them, then project the selected columns. Only the parse call
/// differs per format.
///
/// # Errors
///
/// Propagates any error from parsing, filtering or projection.
fn execute_optimized(&self, data: &[u8]) -> Result<Vec<Record>, SelectError> {
    // Parquet never needs the row-based executor, so leave before building it.
    if matches!(self.input_format, InputFormat::Parquet) {
        return self.execute_parquet_optimized(data);
    }

    let executor = super::select::SelectExecutor::new_from_parts(
        self.query.clone(),
        self.input_format.clone(),
        self.output_format.clone(),
    );

    let records = match &self.input_format {
        InputFormat::Csv(config) => executor.parse_csv_public(data, config)?,
        InputFormat::Json(config) => executor.parse_json_public(data, config)?,
        InputFormat::Avro => executor.parse_avro_public(data)?,
        InputFormat::Orc => executor.parse_orc_public(data)?,
        InputFormat::Protobuf => executor.parse_protobuf_public(data)?,
        InputFormat::MessagePack => executor.parse_messagepack_public(data)?,
        // Already handled by the guard above; listed so the match stays
        // exhaustive and a newly added format is a compile error here.
        InputFormat::Parquet => return self.execute_parquet_optimized(data),
    };

    let filtered = executor.filter_records_public(records)?;
    executor.project_columns_public(filtered)
}
fn execute_parquet_optimized(&self, data: &[u8]) -> Result<Vec<Record>, SelectError> {
debug!("Executing Parquet query with optimizations");
let bytes = Bytes::copy_from_slice(data);
let builder = ParquetRecordBatchReaderBuilder::try_new(bytes)?;
let (projection_mask, column_indices) =
self.build_projection_mask_from_builder(&builder)?;
let reader = builder.with_projection(projection_mask).build()?;
let mut all_records = Vec::new();
let mut rows_processed = 0;
let limit = self.query.limit.unwrap_or(usize::MAX);
for batch_result in reader {
let batch = batch_result?;
let filtered_batch = if let Some(condition) = &self.query.where_clause {
self.apply_arrow_filter(&batch, condition, &column_indices)?
} else {
batch
};
let records = self.batch_to_records(&filtered_batch, &column_indices)?;
for record in records {
if rows_processed >= limit {
return Ok(all_records);
}
all_records.push(record);
rows_processed += 1;
}
if rows_processed >= limit {
break;
}
}
debug!(
records = all_records.len(),
"Parquet query completed with optimizations"
);
Ok(all_records)
}
/// Builds the Parquet projection for this query.
///
/// Returns the projection mask together with a map from column name to that
/// column's position in the batches the projected reader will yield.
///
/// When the query needs every column (empty list or the `"*"` marker from
/// `get_needed_columns`), nothing is pruned. Otherwise only top-level fields
/// whose name is referenced by the query are selected; referenced names that
/// do not exist in the file are ignored.
///
/// The indices come from enumerating the Arrow schema's top-level fields, so
/// they are root-column indices and are passed to `ProjectionMask::roots`.
/// `ProjectionMask::leaves` expects leaf indices instead, which differ from
/// root indices as soon as the file has a nested (struct/list/map) column.
///
/// # Errors
///
/// Propagates any error from `get_needed_columns`.
fn build_projection_mask_from_builder(
    &self,
    builder: &ParquetRecordBatchReaderBuilder<Bytes>,
) -> Result<(ProjectionMask, HashMap<String, usize>), SelectError> {
    use std::collections::HashSet;

    let schema = builder.schema();
    let needed_columns = self.get_needed_columns()?;

    let read_everything =
        needed_columns.is_empty() || needed_columns.iter().any(|name| name == "*");
    if read_everything {
        let column_indices = schema
            .fields()
            .iter()
            .enumerate()
            .map(|(idx, field)| (field.name().clone(), idx))
            .collect();
        return Ok((ProjectionMask::all(), column_indices));
    }

    // Set lookup keeps the scan linear in the number of schema fields.
    let wanted: HashSet<&str> = needed_columns.iter().map(String::as_str).collect();

    let mut root_indices = Vec::new();
    let mut column_indices = HashMap::new();
    for (idx, field) in schema.fields().iter().enumerate() {
        if wanted.contains(field.name().as_str()) {
            // Projected batches keep schema order, so the position of this
            // column in a batch is the number of columns selected before it.
            column_indices.insert(field.name().clone(), root_indices.len());
            root_indices.push(idx);
        }
    }

    let mask = ProjectionMask::roots(builder.parquet_schema(), root_indices);
    debug!(
        columns = ?needed_columns,
        "Built projection mask for column pruning"
    );
    Ok((mask, column_indices))
}
/// Lists the column names the query references.
///
/// The result is the single marker `"*"` as soon as the SELECT list contains
/// `*` or a positional column. Otherwise it holds the named SELECT columns
/// and aggregate arguments, followed by whatever
/// `extract_columns_from_condition` reports for the WHERE clause. Names may
/// repeat. An aggregate without a column argument contributes nothing.
fn get_needed_columns(&self) -> Result<Vec<String>, SelectError> {
    let mut needed = Vec::new();

    for selected in &self.query.columns {
        match selected {
            SelectColumn::Column(ColumnRef::Named(name)) => needed.push(name.clone()),
            SelectColumn::Column(ColumnRef::All) | SelectColumn::Column(ColumnRef::Indexed(_)) => {
                return Ok(vec!["*".to_string()]);
            }
            SelectColumn::Aggregate {
                column: Some(col_name),
                ..
            } => needed.push(col_name.clone()),
            SelectColumn::Aggregate { column: None, .. } => {}
        }
    }

    if let Some(condition) = self.query.where_clause.as_ref() {
        needed.extend(self.extract_columns_from_condition(condition));
    }

    Ok(needed)
}
/// Collects the columns referenced by `condition`, recursing through
/// `AND`/`OR`/`NOT`.
///
/// Named column operands contribute their name. Any other column operand
/// (for example a positional reference) contributes the `"*"` marker, which
/// `build_projection_mask_from_builder` treats as "read every column". A
/// positional reference cannot be resolved to a name here, so pruning must
/// be disabled rather than risk dropping the column the predicate needs.
/// Non-column operands contribute nothing. Names may repeat.
fn extract_columns_from_condition(&self, condition: &Condition) -> Vec<String> {
    /// Appends the column (or the `"*"` marker) that `operand` refers to.
    fn note_operand(operand: &super::select::Operand, columns: &mut Vec<String>) {
        match operand {
            super::select::Operand::Column(ColumnRef::Named(name)) => {
                columns.push(name.clone());
            }
            super::select::Operand::Column(_) => columns.push("*".to_string()),
            _ => {}
        }
    }

    let mut columns = Vec::new();
    match condition {
        Condition::Comparison { left, right, .. } => {
            note_operand(left, &mut columns);
            note_operand(right, &mut columns);
        }
        Condition::And(left, right) | Condition::Or(left, right) => {
            columns.extend(self.extract_columns_from_condition(left));
            columns.extend(self.extract_columns_from_condition(right));
        }
        Condition::Not(inner) => {
            columns.extend(self.extract_columns_from_condition(inner));
        }
        Condition::IsNull(operand) | Condition::IsNotNull(operand) => {
            note_operand(operand, &mut columns);
        }
        Condition::Like { value, .. } => {
            note_operand(value, &mut columns);
        }
    }
    columns
}
/// Returns the rows of `batch` for which `condition` holds.
///
/// Builds a boolean mask with `build_filter_array` and applies it with
/// Arrow's `filter_record_batch`.
///
/// # Errors
///
/// Propagates mask construction errors and Arrow filter errors.
fn apply_arrow_filter(
    &self,
    batch: &RecordBatch,
    condition: &Condition,
    column_indices: &HashMap<String, usize>,
) -> Result<RecordBatch, SelectError> {
    let keep_mask = self.build_filter_array(batch, condition, column_indices)?;
    Ok(compute::filter_record_batch(batch, &keep_mask)?)
}
/// Evaluates `condition` against every row of `batch` and returns the
/// outcomes as a boolean array, one entry per row in row order.
///
/// Each row is first converted to a `Record`, then passed to
/// `evaluate_condition_public`. `_column_indices` is accepted but unused.
///
/// # Errors
///
/// Stops at, and returns, the first row conversion or evaluation error.
fn build_filter_array(
    &self,
    batch: &RecordBatch,
    condition: &Condition,
    _column_indices: &HashMap<String, usize>,
) -> Result<arrow::array::BooleanArray, SelectError> {
    use arrow::array::BooleanArray;

    let verdicts = (0..batch.num_rows())
        .map(|row_idx| -> Result<_, SelectError> {
            let record = self.batch_row_to_record(batch, row_idx)?;
            Ok(super::select::evaluate_condition_public(
                condition, &record,
            )?)
        })
        .collect::<Result<Vec<_>, SelectError>>()?;

    Ok(BooleanArray::from(verdicts))
}
/// Converts row `row_idx` of `batch` into a `Record::Map` keyed by field
/// name, with one entry per column present in the batch.
///
/// # Errors
///
/// Returns the first error raised while extracting a cell value.
fn batch_row_to_record(
    &self,
    batch: &RecordBatch,
    row_idx: usize,
) -> Result<Record, SelectError> {
    let schema = batch.schema();

    let cells = schema
        .fields()
        .iter()
        .zip(batch.columns())
        .map(|(field, column)| -> Result<_, SelectError> {
            let value = super::select::SelectExecutor::extract_arrow_value_public(
                column.as_ref(),
                row_idx,
            )?;
            Ok((field.name().clone(), value))
        })
        .collect::<Result<HashMap<_, _>, SelectError>>()?;

    Ok(Record::Map(cells))
}
/// Converts every row of `batch` into a `Record`, preserving row order.
///
/// `_column_indices` is accepted but unused.
///
/// # Errors
///
/// Returns the first row conversion error.
fn batch_to_records(
    &self,
    batch: &RecordBatch,
    _column_indices: &HashMap<String, usize>,
) -> Result<Vec<Record>, SelectError> {
    (0..batch.num_rows())
        .map(|row_idx| self.batch_row_to_record(batch, row_idx))
        .collect()
}
/// Entry point for the parallel path.
///
/// Only Parquet has a parallel implementation; any other format falls back
/// to `execute_optimized`.
fn execute_parallel(&self, data: &[u8]) -> Result<Vec<Record>, SelectError> {
    info!("Executing query in parallel mode");

    if !matches!(self.input_format, InputFormat::Parquet) {
        return self.execute_optimized(data);
    }
    self.execute_parquet_parallel(data)
}
fn execute_parquet_parallel(&self, data: &[u8]) -> Result<Vec<Record>, SelectError> {
use rayon::prelude::*;
let bytes = Bytes::copy_from_slice(data);
let builder = ParquetRecordBatchReaderBuilder::try_new(bytes)?;
let (projection_mask, column_indices) =
self.build_projection_mask_from_builder(&builder)?;
let reader = builder.with_projection(projection_mask).build()?;
let limit = self.query.limit.unwrap_or(usize::MAX);
let batches: Result<Vec<_>, _> = reader.collect();
let batches = batches?;
let query = &self.query;
let column_indices = &column_indices;
let mut all_records: Vec<Record> = batches
.par_iter()
.map(|batch| {
let filtered_batch = if let Some(condition) = &query.where_clause {
self.apply_arrow_filter(batch, condition, column_indices)
} else {
Ok(batch.clone())
};
match filtered_batch {
Ok(batch) => self.batch_to_records(&batch, column_indices),
Err(e) => Err(e),
}
})
.collect::<Result<Vec<Vec<Record>>, _>>()?
.into_iter()
.flatten()
.collect();
all_records.truncate(limit);
info!(
records = all_records.len(),
"Parallel query execution completed"
);
Ok(all_records)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Parsed form of a bare `SELECT *` with no clauses, shared by all tests.
    fn select_all_query() -> ParsedQuery {
        ParsedQuery {
            columns: vec![SelectColumn::Column(ColumnRef::All)],
            from_alias: None,
            where_clause: None,
            group_by: None,
            order_by: None,
            limit: None,
        }
    }

    /// A plan wrapping `select_all_query` with no projection or predicates.
    fn select_all_plan() -> CachedPlan {
        CachedPlan {
            parsed_query: select_all_query(),
            projection_indices: None,
            pushdown_predicates: Vec::new(),
        }
    }

    /// An executor reading JSON Lines and writing default JSON output.
    fn json_lines_executor(cache: Option<Arc<QueryPlanCache>>) -> OptimizedSelectExecutor {
        OptimizedSelectExecutor::new(
            select_all_query(),
            InputFormat::Json(super::super::select::JsonInput {
                json_type: super::super::select::JsonType::Lines,
            }),
            OutputFormat::Json(super::super::select::JsonOutput::default()),
            cache,
        )
    }

    /// An inserted plan can be read back under the same SQL key.
    #[test]
    fn test_query_plan_cache() {
        let cache = QueryPlanCache::new(10);
        cache.insert("SELECT * FROM s3object".to_string(), select_all_plan());

        let cached = cache.get("SELECT * FROM s3object");
        assert!(cached.is_some());
    }

    /// Inserting more distinct keys than the capacity never exceeds it.
    #[test]
    fn test_cache_eviction() {
        let cache = QueryPlanCache::new(2);
        for i in 0..3 {
            cache.insert(format!("SELECT {}", i), select_all_plan());
        }

        let within_bound = cache
            .cache
            .read()
            .map(|entries| entries.len() <= 2)
            .unwrap_or(false);
        assert!(within_bound);
    }

    /// Executing twice with a cache succeeds and leaves the plan cached.
    #[test]
    fn test_optimized_executor_with_cache() {
        let cache = Arc::new(QueryPlanCache::new(10));
        let executor = json_lines_executor(Some(cache.clone()));

        let data = br#"{"name":"Alice","age":30}
{"name":"Bob","age":25}"#;
        let sql = "SELECT * FROM s3object";

        let first_run = executor.execute(data, sql);
        assert!(first_run.is_ok());

        let second_run = executor.execute(data, sql);
        assert!(second_run.is_ok());

        assert!(cache.get(sql).is_some());
    }

    /// A tiny payload is below the default parallel threshold.
    #[test]
    fn test_parallel_execution_threshold() {
        let executor = json_lines_executor(None);

        let small_data = br#"{"name":"Alice"}"#;
        assert!(small_data.len() < executor.parallel_threshold);
    }
}