use crate::error::{DbxError, DbxResult};
use crate::grid::manager::GridManager;
use crate::grid::protocol::{GridMessage, QueryMessage};
use crate::grid::quic::QuicChannel;
use crate::sql::executor::fragment_splitter::FragmentSplitter;
use crate::sql::executor::local_executor::LocalExecutor;
use crate::sql::planner::types::PhysicalPlan;
use crate::storage::metadata::MetadataRegistry;
use arrow::array::RecordBatch;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
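/// Cumulative row count per worker address; used to rank peers by how much
/// of a stage's scanned data they host.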
type NodeScoreMap = HashMap<SocketAddr, usize>;
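/// Coordinates distributed execution of a physical plan: splits it into
/// fragments, dispatches worker stages over QUIC, and merges streamed
/// results through the coordinator's local operator tree.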
pub struct DistributedExecutor {
quic_channel: Arc<QuicChannel>,
grid_manager: Arc<GridManager>,
local_executor: Arc<LocalExecutor>,
peer_addrs: Vec<SocketAddr>,
metadata_registry: Arc<MetadataRegistry>,
}
impl DistributedExecutor {
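    /// Creates an executor over the given QUIC transport, grid membership,
    /// local execution engine, peer list, and table metadata.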
pub fn new(
quic_channel: Arc<QuicChannel>,
grid_manager: Arc<GridManager>,
local_executor: Arc<LocalExecutor>,
peer_addrs: Vec<SocketAddr>,
metadata_registry: Arc<MetadataRegistry>,
) -> Self {
Self {
quic_channel,
grid_manager,
local_executor,
peer_addrs,
metadata_registry,
}
}
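    /// Executes a physical plan, distributing work across the grid.
    ///
    /// Falls back to purely local execution when no peers are configured or
    /// the plan yields no distributed split. Otherwise the coordinator
    /// fragment is drained on a blocking thread while worker stages are
    /// dispatched in locality order and retried on timeout.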
pub async fn execute(&self, plan: PhysicalPlan) -> DbxResult<Vec<RecordBatch>> {
if self.peer_addrs.is_empty() {
return self.local_executor.execute_collect(&plan);
}
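        // Split the plan into a coordinator fragment plus per-stage worker
        // fragments.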
let dag = FragmentSplitter::split(plan)?;
let coord_plan = match dag.coordinator_plan {
            None => {
                info!("No distributed split found; executing locally");
                // A missing coordinator plan means the splitter produced a
                // single stage holding a single plan; run that plan locally.
                let worker_plan = dag
                    .stages
                    .into_iter()
                    .next()
                    .and_then(|stage| stage.plans.into_iter().next())
                    .expect("fragment DAG without a coordinator plan must hold one plan");
                return self.local_executor.execute_collect(&worker_plan);
            }
Some(p) => p,
};
        // Unique id for this execution; keys the exchange streams and stage
        // barriers. Read the clock once so the two components agree.
        let execution_id = {
            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default();
            format!("exec-{}-{}", now.as_secs(), now.subsec_nanos())
        };
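        // Build the coordinator's operator tree; exchange operators register
        // their input senders in `channels`.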
let mut channels = crate::sql::executor::local_executor::DistributedChannels::default();
let mut root_op = self
.local_executor
.build_operator_distributed(&coord_plan, &mut channels)?;
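        // Publish each exchange sender under (execution_id, exchange_id) so
        // incoming worker batches can be routed to the matching operator.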
let query_streams = self.grid_manager.get_query_streams();
for (e_id, tx) in channels.exchanges {
query_streams.insert((execution_id.clone(), e_id), tx);
}
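        // Drain the coordinator tree on a blocking thread; `next()` blocks on
        // the exchange channels until worker batches arrive.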
let coord_task = tokio::task::spawn_blocking(move || {
let mut results = Vec::new();
while let Some(batch) = root_op.next()? {
if batch.num_rows() > 0 {
results.push(batch);
}
}
Ok::<_, DbxError>(results)
});
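        // Workers stream their result batches back to this address.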
let coordinator_addr = self.quic_channel.local_addr.to_string();
        // Walks a physical plan and collects every TableScan node so partition
        // locality can be scored per worker below.
        fn collect_table_scans(p: &PhysicalPlan) -> Vec<&PhysicalPlan> {
            let mut scans = Vec::new();
            match p {
                PhysicalPlan::TableScan { .. } => scans.push(p),
                PhysicalPlan::Projection { input, .. }
                | PhysicalPlan::SortMerge { input, .. }
                | PhysicalPlan::Limit { input, .. }
                | PhysicalPlan::HashAggregate { input, .. } => {
                    scans.extend(collect_table_scans(input))
                }
                PhysicalPlan::HashJoin { left, right, .. } => {
                    scans.extend(collect_table_scans(left));
                    scans.extend(collect_table_scans(right));
                }
                _ => {}
            }
            scans
        }
        for stage in dag.stages {
            let stage_id = stage.stage_id;
            let mut pending_workers = self.peer_addrs.clone();
            let mut node_scores: NodeScoreMap = HashMap::new();
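            // Score each peer by the rows of this stage's scanned partitions
            // it hosts; a higher score means better data locality.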
for p in &stage.plans {
let scans = collect_table_scans(p);
for scan in scans {
if let PhysicalPlan::TableScan {
table, ros_files, ..
} = scan
&& let Some(table_meta) = self.metadata_registry.tables.get(table)
{
for part_ref in table_meta.partitions.iter() {
let part = part_ref.value();
if (ros_files.is_empty() || ros_files.contains(&part.file_path))
&& let Some(ref addr_str) = part.node_addr
&& let Ok(addr) = addr_str.parse::<std::net::SocketAddr>()
{
*node_scores.entry(addr).or_insert(0) += part.row_count.max(1);
}
}
}
}
}
            // Highest-scoring (most local) workers first.
            pending_workers.sort_by_key(|addr| {
                std::cmp::Reverse(node_scores.get(addr).copied().unwrap_or(0))
            });
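            // Serialize the stage's fragment plans once; every worker receives
            // the same bytes.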
let mut plans_bytes = Vec::new();
for p in stage.plans {
let bytes =
bincode::serialize(&p).map_err(|e| DbxError::Serialization(e.to_string()))?;
plans_bytes.push(bytes);
}
let max_retries = 3;
let timeout_secs = std::env::var("DBX_WORKER_TIMEOUT_SECS")
.unwrap_or_else(|_| "30".to_string())
.parse::<u64>()
.unwrap_or(30);
let timeout_duration = std::time::Duration::from_secs(timeout_secs);
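            // Dispatch with bounded retries: workers that do not signal stage
            // completion within the timeout are re-sent the fragment.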
for retry_count in 0..=max_retries {
if pending_workers.is_empty() {
break;
}
info!(
"Dispatching Stage {} to {} workers (exec_id: {}, retry: {})",
stage_id,
pending_workers.len(),
execution_id,
retry_count
);
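                // One completion barrier per (execution, stage, peer); the
                // message handler fires it when the worker finishes the stage.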
let stage_barriers = self.grid_manager.get_stage_barriers();
let mut awaiters = Vec::new();
for peer in &pending_workers {
let (tx, mut rx) = mpsc::channel(1);
stage_barriers.insert((execution_id.clone(), stage_id, *peer), tx);
let msg = GridMessage::Query(QueryMessage::ExecuteFragment {
execution_id: execution_id.clone(),
stage_id,
plans_bytes: plans_bytes.clone(),
coordinator_addr: coordinator_addr.clone(),
});
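                    // A failed send is not fatal: the barrier below times out
                    // and the peer is retried in the next round.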
if let Err(e) = self.quic_channel.send_message(*peer, msg).await {
warn!("Failed to send Stage {} to {}: {:?}", stage_id, peer, e);
}
let peer_addr = *peer;
awaiters.push(async move {
                        // Only an actual completion message counts as success;
                        // `Ok(None)` means the sender was dropped without the
                        // worker reporting in.
                        let res = tokio::time::timeout(timeout_duration, rx.recv()).await;
                        (peer_addr, matches!(res, Ok(Some(_))))
});
}
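                // Await every worker's outcome in parallel; timed-out peers go
                // into the next retry round.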
let results = futures::future::join_all(awaiters).await;
let mut retry_peers = Vec::new();
for (peer_addr, success) in results {
stage_barriers.remove(&(execution_id.clone(), stage_id, peer_addr));
if !success {
warn!("Worker {} timed out on Stage {}", peer_addr, stage_id);
retry_peers.push(peer_addr);
}
}
pending_workers = retry_peers;
if !pending_workers.is_empty() && retry_count == max_retries {
error!(
"Max retries exceeded for Stage {} on workers: {:?}",
stage_id, pending_workers
);
return Err(DbxError::Network(format!(
"Stage {} timed out after {} retries",
stage_id, max_retries
)));
}
}
}
        let final_results = coord_task
            .await
            .map_err(|e| DbxError::Network(format!("coordinator task failed: {:?}", e)))??;
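        // Tear down this execution's exchange routing entries.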
let keys_to_remove: Vec<_> = query_streams
.iter()
.filter(|k| k.key().0 == execution_id)
.map(|k| k.key().clone())
.collect();
for k in keys_to_remove {
query_streams.remove(&k);
}
info!(
"Distributed execution {} finished successfully",
execution_id
);
Ok(final_results)
}
}