use crate::error::{DbxError, DbxResult};
use crate::sql::executor::operators::{
FilterOperator, HashAggregateOperator, HashJoinOperator, LimitOperator, PhysicalOperator,
ProjectionOperator, SortOperator, TableScanOperator,
};
use crate::sql::executor::operators::{GridExchangeOperator, GridShuffleWriterOperator};
use crate::sql::planner::types::*;
use arrow::array::{RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
// Receiving end of a shuffle stream. Each message is a `DbxResult` carrying
// either a byte payload (`Some(bytes)`) or an end-of-stream marker (`None`).
// NOTE(review): payload encoding is decided by the shuffle operators — presumably
// serialized record batches; confirm against GridShuffleWriterOperator.
type ShuffleReceiver = tokio::sync::mpsc::Receiver<DbxResult<Option<Vec<u8>>>>;
// Shuffle outputs keyed by exchange id: one (target node address, receiver)
// pair per downstream node of that exchange.
type ShuffleMap = std::collections::HashMap<usize, Vec<(std::net::SocketAddr, ShuffleReceiver)>>;
/// Channel endpoints accumulated while building a distributed operator tree.
///
/// Populated as a side effect of `LocalExecutor::build_operator_distributed`:
/// each `GridExchange` node registers a sender and each `ShuffleWriter` node
/// registers its per-target receivers, both keyed by exchange id.
#[derive(Default)]
pub struct DistributedChannels {
    // Sender feeding each GridExchangeOperator, keyed by exchange id.
    pub exchanges: HashMap<usize, mpsc::Sender<DbxResult<Option<Vec<u8>>>>>,
    // Receivers produced by each ShuffleWriter, keyed by exchange id.
    pub shuffles: ShuffleMap,
}
/// Executes physical query plans against in-memory table data.
///
/// Both maps are shared (`Arc<RwLock<...>>`) with whatever component owns the
/// table storage; this executor only takes read locks on them.
pub struct LocalExecutor {
    // Table name -> in-memory (WOS) record batches for that table.
    table_store: Arc<RwLock<HashMap<String, Vec<RecordBatch>>>>,
    // Table name -> Arrow schema.
    table_schemas: Arc<RwLock<HashMap<String, Arc<Schema>>>>,
}
impl LocalExecutor {
    /// Creates an executor over shared in-memory table data and schemas.
    pub fn new(
        table_store: Arc<RwLock<HashMap<String, Vec<RecordBatch>>>>,
        table_schemas: Arc<RwLock<HashMap<String, Arc<Schema>>>>,
    ) -> Self {
        Self {
            table_store,
            table_schemas,
        }
    }

    /// Builds an operator tree for `plan` and drains it, returning every
    /// non-empty batch. Plans containing `GridExchange` / `ShuffleWriter`
    /// nodes fail here; use `execute_collect_distributed` for those.
    pub fn execute_collect(&self, plan: &PhysicalPlan) -> DbxResult<Vec<RecordBatch>> {
        let operator = self.build_operator(plan)?;
        Self::collect_batches(operator)
    }

    /// Same as `execute_collect`, but wires `GridExchange` / `ShuffleWriter`
    /// nodes through `channels`.
    pub fn execute_collect_distributed(
        &self,
        plan: &PhysicalPlan,
        channels: &mut DistributedChannels,
    ) -> DbxResult<Vec<RecordBatch>> {
        let operator = self.build_operator_distributed(plan, channels)?;
        Self::collect_batches(operator)
    }

    /// Drains `operator` to completion, keeping only non-empty batches.
    /// Shared by the two `execute_collect*` entry points.
    fn collect_batches(
        mut operator: Box<dyn PhysicalOperator>,
    ) -> DbxResult<Vec<RecordBatch>> {
        let mut results = Vec::new();
        while let Some(batch) = operator.next()? {
            if batch.num_rows() > 0 {
                results.push(batch);
            }
        }
        Ok(results)
    }

    /// Builds an operator tree for a purely local plan (no distributed nodes).
    pub fn build_operator(&self, plan: &PhysicalPlan) -> DbxResult<Box<dyn PhysicalOperator>> {
        self.build_operator_internal(plan, &mut None)
    }

    /// Builds an operator tree, registering any exchange/shuffle endpoints
    /// the plan needs into `channels`.
    pub fn build_operator_distributed(
        &self,
        plan: &PhysicalPlan,
        channels: &mut DistributedChannels,
    ) -> DbxResult<Box<dyn PhysicalOperator>> {
        self.build_operator_internal(plan, &mut Some(channels))
    }

    /// Recursively translates `plan` into a physical operator tree.
    ///
    /// `channels` is `Some` only in distributed mode; `GridExchange` and
    /// `ShuffleWriter` nodes return `DbxError::SqlExecution` without it.
    /// Plan variants not matched here (DML/DDL) are rejected with
    /// `DbxError::SqlNotSupported`.
    fn build_operator_internal(
        &self,
        plan: &PhysicalPlan,
        channels: &mut Option<&mut DistributedChannels>,
    ) -> DbxResult<Box<dyn PhysicalOperator>> {
        match plan {
            PhysicalPlan::TableScan {
                table,
                projection,
                filter,
                ros_files,
            } => {
                // NOTE(review): `.read().unwrap()` panics if a writer poisoned
                // the lock; this matches the locking style used elsewhere here.
                let store = self.table_store.read().unwrap();
                let wos_batches = store.get(table).cloned().unwrap_or_default();
                let schemas = self.table_schemas.read().unwrap();
                let schema = schemas
                    .get(table)
                    .cloned()
                    .ok_or_else(|| DbxError::TableNotFound(table.clone()))?;
                // Release both read locks before building downstream operators.
                drop(schemas);
                drop(store);
                let mut op = TableScanOperator::new(table.clone(), schema, projection.clone());
                if ros_files.is_empty() {
                    // In-memory data only.
                    op.set_data(wos_batches);
                } else {
                    // Tiered scan: in-memory batches plus on-disk ROS files.
                    op.start_tier_scan(wos_batches, ros_files.clone());
                }
                if let Some(f) = filter {
                    Ok(Box::new(FilterOperator::new(Box::new(op), f.clone())))
                } else {
                    Ok(Box::new(op))
                }
            }
            PhysicalPlan::HashAggregate {
                input,
                group_by,
                aggregates,
                mode,
            } => {
                let input_op = self.build_operator_internal(input, channels)?;
                let input_schema = input_op.schema().clone();
                // Output schema: group-by columns first, then one column per
                // aggregate expression.
                let mut output_fields = Vec::new();
                for &col_idx in group_by.iter() {
                    // Out-of-range group-by indices are silently skipped.
                    if col_idx < input_schema.fields().len() {
                        output_fields.push(input_schema.field(col_idx).clone());
                    }
                }
                for agg in aggregates {
                    let name = agg
                        .alias
                        .clone()
                        .unwrap_or_else(|| format!("agg_{}", agg.input));
                    // Count yields Int64; all other aggregates are widened to
                    // Float64 regardless of the input column type.
                    let dtype = match agg.function {
                        AggregateFunction::Count => DataType::Int64,
                        AggregateFunction::Sum
                        | AggregateFunction::Avg
                        | AggregateFunction::Min
                        | AggregateFunction::Max => DataType::Float64,
                    };
                    output_fields.push(Field::new(&name, dtype, true));
                }
                let output_schema = Arc::new(Schema::new(output_fields));
                Ok(Box::new(HashAggregateOperator::new(
                    input_op,
                    output_schema,
                    group_by.clone(),
                    aggregates.clone(),
                    *mode,
                )))
            }
            PhysicalPlan::Projection {
                input,
                exprs,
                aliases,
            } => {
                let input_op = self.build_operator_internal(input, channels)?;
                let input_schema = input_op.schema().clone();
                // NOTE(review): every alias-less expression falls back to the
                // name "col", so multiple unnamed projections produce duplicate
                // column names in the output schema.
                let output_fields: Vec<Field> = exprs
                    .iter()
                    .zip(aliases.iter())
                    .map(|(expr, alias)| {
                        let dtype = expr.get_type(&input_schema);
                        let name = alias.clone().unwrap_or_else(|| "col".to_string());
                        Field::new(&name, dtype, true)
                    })
                    .collect();
                let output_schema = Arc::new(Schema::new(output_fields));
                Ok(Box::new(ProjectionOperator::new(
                    input_op,
                    output_schema,
                    exprs.clone(),
                )))
            }
            PhysicalPlan::Limit {
                input,
                count,
                offset,
            } => {
                let input_op = self.build_operator_internal(input, channels)?;
                Ok(Box::new(LimitOperator::new(input_op, *count, *offset)))
            }
            PhysicalPlan::SortMerge { input, order_by } => {
                let input_op = self.build_operator_internal(input, channels)?;
                Ok(Box::new(SortOperator::new(input_op, order_by.clone())))
            }
            PhysicalPlan::HashJoin {
                left,
                right,
                on,
                join_type,
            } => {
                let left_op = self.build_operator_internal(left, channels)?;
                let right_op = self.build_operator_internal(right, channels)?;
                let left_schema = left_op.schema().clone();
                let right_schema = right_op.schema().clone();
                // Join output schema is simply left fields followed by right
                // fields; no duplicate-name resolution is performed here.
                let mut all_fields = left_schema.fields().to_vec();
                all_fields.extend(right_schema.fields().to_vec());
                let join_schema = Arc::new(Schema::new(all_fields));
                Ok(Box::new(HashJoinOperator::new(
                    left_op,
                    right_op,
                    join_schema,
                    on.clone(),
                    *join_type,
                )))
            }
            PhysicalPlan::GridExchange {
                exchange_id,
                schema_hint,
            } => {
                if let Some(ch) = channels.as_mut() {
                    // Register the sender so a remote/shuffle producer can feed
                    // this exchange; the operator reads from the paired receiver.
                    let (tx, rx) = mpsc::channel(64);
                    ch.exchanges.insert(*exchange_id, tx);
                    // `schema_hint` is only a column count: a placeholder
                    // all-Float64 schema (col_0..col_n) is synthesized.
                    let mut fields = Vec::with_capacity(*schema_hint);
                    for i in 0..*schema_hint {
                        fields.push(Field::new(format!("col_{}", i), DataType::Float64, true));
                    }
                    let schema = Arc::new(Schema::new(fields));
                    Ok(Box::new(GridExchangeOperator::new(schema, rx)))
                } else {
                    Err(DbxError::SqlExecution {
                        message: "GridExchange encountered but no DistributedChannels provided"
                            .to_string(),
                        context: "LocalExecutor::build_operator_internal".to_string(),
                    })
                }
            }
            PhysicalPlan::ShuffleWriter {
                input,
                hash_params,
                target_nodes,
                exchange_id,
                salting,
            } => {
                if let Some(ch) = channels.as_mut() {
                    // Reborrow `ch` so the child build can also use the channels.
                    let input_op = self.build_operator_internal(input, &mut Some(*ch))?;
                    // One channel per target node: the writer keeps the senders,
                    // the receivers are published for the network layer to drain.
                    let mut senders = Vec::new();
                    let mut receivers = Vec::new();
                    for target_addr in target_nodes {
                        let (tx, rx) = mpsc::channel(64);
                        senders.push(tx);
                        receivers.push((*target_addr, rx));
                    }
                    ch.shuffles.insert(*exchange_id, receivers);
                    Ok(Box::new(GridShuffleWriterOperator::new(
                        input_op,
                        hash_params.clone(),
                        *exchange_id,
                        salting.clone(),
                        senders,
                    )))
                } else {
                    Err(DbxError::SqlExecution {
                        message: "ShuffleWriter encountered but no DistributedChannels provided"
                            .to_string(),
                        context: "LocalExecutor::build_operator_internal".to_string(),
                    })
                }
            }
            // Anything else (DML/DDL) is out of scope for this executor.
            other => Err(DbxError::SqlNotSupported {
                feature: format!(
                    "LocalExecutor: {:?} plan type",
                    std::mem::discriminant(other)
                ),
                hint: "DML/DDL은 StorageEngine을 통해 실행하세요".to_string(),
            }),
        }
    }
}
/// Builds a small in-memory test table from `(id, name, value)` tuples.
///
/// Returns the fixed three-column schema (`id: Int32`, `name: Utf8`,
/// `value: Int64`, all non-nullable) together with a single record batch
/// holding every row. Panics only if the constructed arrays somehow fail to
/// match the schema, which cannot happen by construction.
pub fn make_dummy_table(rows: Vec<(i32, String, i64)>) -> (Arc<Schema>, Vec<RecordBatch>) {
    use arrow::array::{Int32Array, Int64Array};
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
        Field::new("value", DataType::Int64, false),
    ]));
    // Split the row tuples into one column vector per field in a single pass.
    let mut ids = Vec::with_capacity(rows.len());
    let mut names = Vec::with_capacity(rows.len());
    let mut values = Vec::with_capacity(rows.len());
    for (id, name, value) in &rows {
        ids.push(*id);
        names.push(name.as_str());
        values.push(*value);
    }
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(ids)),
            Arc::new(StringArray::from(names)),
            Arc::new(Int64Array::from(values)),
        ],
    )
    .unwrap();
    (schema, vec![batch])
}