use std::{
collections::{HashMap, HashSet},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use datafusion_common::DataFusionError;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_plan::{
ExecutionPlan, RecordBatchStream, display::DisplayableExecutionPlan,
execution_plan::reset_plan_states, stream::RecordBatchStreamAdapter,
};
use futures::{Stream, StreamExt, TryStreamExt, future::join_all};
use log::{debug, error, warn};
use tokio::{sync::mpsc::Sender, task::AbortHandle};
use crate::{
DistError, DistResult, JobId,
cluster::{DistCluster, NodeId, NodeStatus},
config::DistConfig,
event::{Event, EventHandler, local_jobs, send_event_with_timeout, start_event_handler},
executor::{DefaultExecutor, DistExecutor, logging_executor_metrics},
heartbeat::Heartbeater,
network::{DistNetwork, ScheduledTasks, StageInfo},
planner::{
DefaultPlanner, DisplayableStagePlans, DistPlanner, StageId, TaskId,
check_initial_stage_plans, resolve_stage_plan,
},
scheduler::{DefaultScheduler, DisplayableTaskDistribution, DistScheduler},
util::{ReceiverStreamBuilder, timestamp_ms},
};
/// Per-node runtime that bundles the collaborators used to plan, schedule,
/// dispatch and execute distributed jobs.
///
/// The struct derives `Clone`; the shared state (`status`, `stages`, ...) is
/// held behind `Arc`, so clones refer to the same underlying values.
#[derive(Debug, Clone)]
pub struct DistRuntime {
/// Identifier of the local node, obtained from `network.local_node()`.
pub node_id: NodeId,
/// Status of the local node; starts as `Available` and is switched to
/// `Terminating` by `shutdown`.
pub status: Arc<Mutex<NodeStatus>>,
/// Task context handed to `ExecutionPlan::execute` for locally executed tasks.
pub task_ctx: Arc<TaskContext>,
/// Runtime configuration (heartbeat interval, event queue size, job TTL, ...).
pub config: Arc<DistConfig>,
/// Cluster membership view, used to look up the alive nodes.
pub cluster: Arc<dyn DistCluster>,
/// Transport used to send tasks to, and query jobs from, other nodes.
pub network: Arc<dyn DistNetwork>,
/// Splits a physical plan into stage plans.
pub planner: Arc<dyn DistPlanner>,
/// Assigns tasks to nodes.
pub scheduler: Arc<dyn DistScheduler>,
/// Provides the handle on which local driver tasks are spawned.
pub executor: Arc<dyn DistExecutor>,
/// Heartbeat component; shares `status` and `stages` with this runtime.
pub heartbeater: Arc<Heartbeater>,
/// State of the stages received by this node, keyed by stage id. Shared with
/// the heartbeater and the event handler.
pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
/// Sending half of the event channel whose receiver is owned by the event
/// handler started in `new`.
pub event_sender: Sender<Event>,
}
impl DistRuntime {
/// Builds a runtime for the local node with the default planner, scheduler
/// and executor.
///
/// Besides assembling the struct, this creates the event channel (bounded by
/// `config.event_queue_size`) and starts the event handler that owns the
/// receiving half. The heartbeater is only constructed here; it is started by
/// [`DistRuntime::start`].
///
/// # Panics
///
/// `tokio::sync::mpsc::channel` panics when the requested capacity is `0`, so
/// `config.event_queue_size` must be at least `1`.
/// NOTE(review): `start_event_handler` presumably spawns a task and therefore
/// needs a running Tokio runtime — confirm against its definition.
pub fn new(
    task_ctx: Arc<TaskContext>,
    config: Arc<DistConfig>,
    cluster: Arc<dyn DistCluster>,
    network: Arc<dyn DistNetwork>,
) -> Self {
    // Resolve the local node id once so the heartbeater and the runtime are
    // guaranteed to agree on it.
    let node_id = network.local_node();
    let status = Arc::new(Mutex::new(NodeStatus::Available));
    let stages = Arc::new(Mutex::new(HashMap::new()));

    let heartbeater = Heartbeater {
        node_id: node_id.clone(),
        cluster: cluster.clone(),
        stages: stages.clone(),
        heartbeat_interval: config.heartbeat_interval,
        status: status.clone(),
    };

    let (sender, receiver) = tokio::sync::mpsc::channel::<Event>(config.event_queue_size);
    let event_handler = EventHandler {
        config: config.clone(),
        cluster: cluster.clone(),
        network: network.clone(),
        local_stages: stages.clone(),
        sender: sender.clone(),
        receiver,
    };
    start_event_handler(event_handler);

    Self {
        node_id,
        status,
        task_ctx,
        config,
        cluster,
        network,
        planner: Arc::new(DefaultPlanner),
        scheduler: Arc::new(DefaultScheduler::new()),
        executor: Arc::new(DefaultExecutor::new()),
        heartbeater: Arc::new(heartbeater),
        stages,
        event_sender: sender,
    }
}
/// Returns the runtime with its planner replaced by `planner`; every other
/// field is carried over unchanged.
pub fn with_planner(mut self, planner: Arc<dyn DistPlanner>) -> Self {
    self.planner = planner;
    self
}
/// Returns the runtime with its scheduler replaced by `scheduler`; every other
/// field is carried over unchanged.
pub fn with_scheduler(mut self, scheduler: Arc<dyn DistScheduler>) -> Self {
    self.scheduler = scheduler;
    self
}
/// Returns the runtime with its executor replaced by `executor`; every other
/// field is carried over unchanged.
pub fn with_executor(mut self, executor: Arc<dyn DistExecutor>) -> Self {
    self.executor = executor;
    self
}
pub async fn start(&self) {
self.heartbeater.start();
start_job_cleaner(self.stages.clone(), self.config.clone());
}
/// Marks the local node as `Terminating` and immediately sends one heartbeat.
///
/// Once the status is `Terminating`, `receive_tasks` rejects new tasks.
pub async fn shutdown(&self) {
    {
        // Scope the guard so it is released before the heartbeat is awaited.
        let mut status = self.status.lock();
        *status = NodeStatus::Terminating;
    }
    debug!("Set node status to Terminating, no new tasks will be assigned");
    self.heartbeater.send_heartbeat().await;
}
pub async fn submit(
&self,
job_id: impl Into<JobId>,
plan: Arc<dyn ExecutionPlan>,
job_meta: Arc<HashMap<String, String>>,
) -> DistResult<HashMap<TaskId, NodeId>> {
let job_id = job_id.into();
debug!(
"Submitting job {job_id} with meta {job_meta:?} and physical plan: \n{}",
DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
);
let mut stage_plans = self.planner.plan_stages(job_id.clone(), plan)?;
debug!(
"job {job_id} initial stage plans:\n{}",
DisplayableStagePlans(&stage_plans)
);
check_initial_stage_plans(job_id.clone(), &stage_plans)?;
let node_states = self.cluster.alive_nodes().await?;
debug!(
"alive nodes: {}",
node_states
.keys()
.map(|n| n.to_string())
.collect::<Vec<_>>()
.join(", ")
);
let task_distribution = self
.scheduler
.schedule(&self.node_id, &node_states, &stage_plans)
.await?;
debug!(
"job {job_id} task distribution: {}",
DisplayableTaskDistribution(&task_distribution)
);
let stage0_task_distribution: HashMap<TaskId, NodeId> = task_distribution
.iter()
.filter(|(task_id, _)| task_id.stage == 0)
.map(|(task_id, node_id)| (task_id.clone(), node_id.clone()))
.collect();
if stage0_task_distribution.is_empty() {
return Err(DistError::internal(format!(
"Not found stage0 task distribution in {task_distribution:?} for job {job_id}"
)));
}
for (_, stage_plan) in stage_plans.iter_mut() {
*stage_plan = resolve_stage_plan(stage_plan.clone(), &task_distribution, self.clone())?;
}
debug!(
"job {job_id} final stage plans:\n{}",
DisplayableStagePlans(&stage_plans)
);
let mut node_stages = HashMap::new();
let mut node_tasks = HashMap::new();
for (task_id, node_id) in task_distribution.iter() {
node_stages
.entry(node_id.clone())
.or_insert_with(HashSet::new)
.insert(task_id.stage_id());
node_tasks
.entry(node_id.clone())
.or_insert_with(Vec::new)
.push(task_id.clone());
}
let mut handles = Vec::with_capacity(node_stages.len());
for (node_id, stage_ids) in node_stages {
let node_stage_plans = stage_ids
.iter()
.map(|stage_id| {
(
stage_id.clone(),
stage_plans
.get(stage_id)
.cloned()
.expect("stage id should be valid"),
)
})
.collect::<HashMap<_, _>>();
let tasks = node_tasks.get(&node_id).cloned().unwrap_or_default();
let scheduled_tasks = ScheduledTasks::new(
node_stage_plans,
tasks,
Arc::new(task_distribution.clone()),
job_meta.clone(),
);
if node_id == self.node_id {
self.receive_tasks(scheduled_tasks).await?;
} else {
debug!(
"Sending job {job_id} tasks [{}] to {node_id}",
scheduled_tasks
.task_ids
.iter()
.map(|t| format!("{}/{}", t.stage, t.partition))
.collect::<Vec<String>>()
.join(", ")
);
let network = self.network.clone();
let handle = tokio::spawn(async move {
network.send_tasks(node_id.clone(), scheduled_tasks).await?;
Ok::<_, DistError>(())
});
handles.push(handle);
}
}
for handle in handles {
handle.await??;
}
logging_executor_metrics(self.executor.handle());
Ok(stage0_task_distribution)
}
/// Starts executing one partition of a locally held stage and returns its
/// output as a record batch stream.
///
/// The plan is executed by a driver task spawned on the executor handle; its
/// batches flow through a channel with capacity 2 into the returned stream.
/// The stream is wrapped in a [`TaskStream`], which records output metrics and
/// reports completion to the stage state when it is dropped.
///
/// # Errors
///
/// Fails if the stage is unknown on this node, if the partition is not
/// assigned to this node, or if the task cannot be registered as running.
pub async fn execute_local(&self, task_id: TaskId) -> DistResult<SendableRecordBatchStream> {
    let partition = task_id.partition as usize;
    let stage_id = task_id.stage_id();

    // Everything that touches the stage map happens inside this scope so the
    // lock is released before the output stream is assembled.
    let (task_set_id, schema, stream_builder) = {
        let mut stages = self.stages.lock();
        let Some(stage_state) = stages.get_mut(&stage_id) else {
            return Err(DistError::internal(format!("Stage {stage_id} not found")));
        };
        let (task_set_id, plan) = stage_state.get_plan(partition)?;
        let schema = plan.schema();

        let mut stream_builder = ReceiverStreamBuilder::new(2);
        let batch_tx = stream_builder.tx();
        let task_ctx = self.task_ctx.clone();
        let driver_task = async move {
            let mut df_stream = plan.execute(partition, task_ctx)?;
            while let Some(batch) = df_stream.next().await {
                // A failed send means the receiving side is gone; stop quietly.
                if let Err(e) = batch_tx.send(batch.map_err(DistError::from)).await {
                    warn!("Dist driver task failed to send batch to channel: {e}");
                    return Ok(());
                }
            }
            Ok(()) as DistResult<()>
        };

        let abort_handle = stream_builder.spawn_on(driver_task, self.executor.handle());
        stage_state.start_task(partition, task_set_id, abort_handle)?;
        (task_set_id, schema, stream_builder)
    };

    let batches = stream_builder.build();
    let adapted = Box::pin(RecordBatchStreamAdapter::new(
        schema,
        batches.map_err(DataFusionError::from),
    ));
    Ok(Box::pin(TaskStream::new(
        task_id,
        task_set_id,
        self.stages.clone(),
        self.event_sender.clone(),
        adapted,
    )))
}
/// Asks another node to execute `task_id` and returns the resulting stream.
///
/// # Errors
///
/// Returns an internal error when `node_id` is the local node, and otherwise
/// whatever the network layer reports.
pub async fn execute_remote(
    &self,
    node_id: NodeId,
    task_id: TaskId,
) -> DistResult<SendableRecordBatchStream> {
    if node_id != self.node_id {
        debug!("Executing remote task {task_id} on node {node_id}");
        return self.network.execute_task(node_id, task_id).await;
    }
    Err(DistError::internal(format!(
        "remote node id {node_id} is actually self"
    )))
}
/// Registers the tasks assigned to this node and, if any of them belong to
/// stage 0, notifies the event handler.
///
/// The stage states built from `scheduled_tasks` are merged into the local
/// stage map; an entry that already exists for the same stage id is replaced.
///
/// # Errors
///
/// Fails if the node is in `Terminating` status, if the job id cannot be
/// derived from `scheduled_tasks`, if a task refers to a stage without a plan,
/// or if the stage 0 event cannot be delivered in time.
pub async fn receive_tasks(&self, scheduled_tasks: ScheduledTasks) -> DistResult<()> {
    if matches!(*self.status.lock(), NodeStatus::Terminating) {
        return Err(DistError::internal(
            "Local node is in Terminating status, cannot receive tasks",
        ));
    }

    // Resolve the job id outside the logging macro: `debug!` only evaluates
    // its arguments when debug logging is enabled, so a `?` placed inside it
    // would make the error path depend on the active log level.
    let job_id = scheduled_tasks.job_id()?;
    debug!(
        "Received job {} tasks: [{}] and plans of stages: [{}]",
        job_id,
        scheduled_tasks
            .task_ids
            .iter()
            .map(|t| format!("{}/{}", t.stage, t.partition))
            .collect::<Vec<String>>()
            .join(", "),
        scheduled_tasks
            .stage_plans
            .keys()
            .map(|k| k.stage.to_string())
            .collect::<Vec<String>>()
            .join(", ")
    );

    let stage_states = StageState::from_scheduled_tasks(scheduled_tasks)?;
    let stage0_ids = stage_states
        .keys()
        .filter(|id| id.stage == 0)
        .cloned()
        .collect::<Vec<StageId>>();
    {
        // Keep the lock scoped so it is never held across the await below.
        let mut guard = self.stages.lock();
        guard.extend(stage_states);
    }

    if !stage0_ids.is_empty() {
        send_event_with_timeout(&self.event_sender, Event::ReceivedStage0Tasks(stage0_ids))
            .await?;
    }
    Ok(())
}
/// Removes every local stage that belongs to one of `job_ids`, aborting the
/// tasks of those stages that are still running.
///
/// An empty `job_ids` is logged and otherwise ignored.
pub fn cleanup_local_jobs(&self, job_ids: Vec<JobId>) {
    debug!(
        "Cleaning up local Jobs [{}]",
        job_ids
            .iter()
            .map(|id| id.to_string())
            .collect::<Vec<_>>()
            .join(", "),
    );
    if job_ids.is_empty() {
        return;
    }
    let targets: HashSet<JobId> = job_ids.into_iter().collect();
    let mut stages = self.stages.lock();
    cleanup_stages(&mut stages, |stage_id| targets.contains(&stage_id.job_id));
}
/// Reports the stages held by this node by delegating to `local_jobs`,
/// forwarding the optional `job_ids` filter unchanged.
pub fn get_local_jobs(&self, job_ids: Option<&Vec<JobId>>) -> HashMap<StageId, StageInfo> {
    let stages = &self.stages;
    local_jobs(stages, job_ids)
}
pub async fn get_all_jobs(&self) -> DistResult<HashMap<StageId, StageInfo>> {
let mut combined_status = local_jobs(&self.stages, None);
let node_states = self.cluster.alive_nodes().await?;
let mut futures = Vec::new();
for node_id in node_states.keys() {
if *node_id != self.node_id {
let network = self.network.clone();
let node_id = node_id.clone();
futures.push(async move { network.get_jobs(node_id, None).await });
}
}
for remote_status in join_all(futures).await {
let remote_status = remote_status?;
for (stage_id, remote_stage_info) in remote_status {
combined_status
.entry(stage_id)
.and_modify(|existing| {
existing.merge(&remote_stage_info);
})
.or_insert(remote_stage_info);
}
}
Ok(combined_status)
}
}
/// Execution state of one stage on the local node.
#[derive(Debug)]
pub struct StageState {
/// Identifier of the stage this state belongs to.
pub stage_id: StageId,
/// Value of `timestamp_ms()` when the state was created; the job cleaner
/// compares it against the configured job TTL.
pub created_at_ms: i64,
/// Plan of the stage as received from the scheduler. Task sets execute a
/// copy obtained through `reset_plan_states`, not this instance.
pub stage_plan: Arc<dyn ExecutionPlan>,
/// Partitions of this stage that were assigned to the local node.
pub assigned_partitions: HashSet<usize>,
/// Execution rounds of the stage; a new set is added by `get_plan` when
/// every existing set has already executed the requested partition.
pub task_sets: Vec<TaskSet>,
/// Node assignment of all tasks of the job.
pub job_task_distribution: Arc<HashMap<TaskId, NodeId>>,
/// Metadata supplied with the job at submission time.
pub job_meta: Arc<HashMap<String, String>>,
}
impl StageState {
pub fn from_scheduled_tasks(
scheduled_tasks: ScheduledTasks,
) -> DistResult<HashMap<StageId, StageState>> {
let mut stage_tasks: HashMap<StageId, HashSet<TaskId>> = HashMap::new();
for task_id in scheduled_tasks.task_ids {
let stage_id = task_id.stage_id();
stage_tasks.entry(stage_id).or_default().insert(task_id);
}
let mut stage_states = HashMap::new();
for (stage_id, assigned_task_ids) in stage_tasks {
let stage_state = StageState {
stage_id: stage_id.clone(),
created_at_ms: timestamp_ms(),
stage_plan: scheduled_tasks
.stage_plans
.get(&stage_id)
.ok_or_else(|| {
DistError::internal(format!("Not found plan of stage {stage_id}"))
})?
.clone(),
assigned_partitions: assigned_task_ids
.iter()
.map(|task_id| task_id.partition as usize)
.collect(),
task_sets: Vec::new(),
job_task_distribution: scheduled_tasks.job_task_distribution.clone(),
job_meta: scheduled_tasks.job_meta.clone(),
};
stage_states.insert(stage_id, stage_state);
}
Ok(stage_states)
}
/// Number of partitions currently registered as running, summed over all
/// task sets.
pub fn num_running_tasks(&self) -> usize {
    let mut running = 0;
    for task_set in &self.task_sets {
        running += task_set.running_partitions.len();
    }
    running
}
/// Number of assigned partitions that no task set has started yet, i.e. that
/// appear neither as running nor as dropped in any task set.
pub fn num_pending_tasks(&self) -> usize {
    let mut executed: HashSet<usize> = HashSet::new();
    for task_set in &self.task_sets {
        executed.extend(task_set.running_partitions.keys().copied());
        executed.extend(task_set.dropped_partitions.keys().copied());
    }
    self.assigned_partitions
        .iter()
        .filter(|&partition| !executed.contains(partition))
        .count()
}
/// Whether every assigned partition has been started and none is still
/// running.
pub fn all_assigned_partitions_completed(&self) -> bool {
    let nothing_running = self.num_running_tasks() == 0;
    nothing_running && self.num_pending_tasks() == 0
}
/// Picks the plan instance that should execute `partition` and returns it
/// together with the id of the task set that owns it.
///
/// The first task set that has never executed `partition` is reused. If every
/// existing set has already run it, a new set is created around a copy of the
/// stage plan produced by `reset_plan_states`.
///
/// # Errors
///
/// Fails if `partition` is not assigned to this node or if resetting the plan
/// state fails.
pub fn get_plan(&mut self, partition: usize) -> DistResult<(Uuid, Arc<dyn ExecutionPlan>)> {
    if !self.assigned_partitions.contains(&partition) {
        let task_id = self.stage_id.task_id(partition as u32);
        return Err(DistError::internal(format!(
            "Task {task_id} not found in this node"
        )));
    }

    let reusable = self
        .task_sets
        .iter()
        .find(|task_set| task_set.never_executed(&partition));
    if let Some(task_set) = reusable {
        return Ok((task_set.id, task_set.shared_plan.clone()));
    }

    let task_set_id = Uuid::new_v4();
    let shared_plan = reset_plan_states(self.stage_plan.clone())
        .map_err(|e| DistError::internal(format!("Failed to reset plan state: {e}")))?;
    self.task_sets.push(TaskSet {
        id: task_set_id,
        shared_plan: shared_plan.clone(),
        running_partitions: HashMap::new(),
        dropped_partitions: HashMap::new(),
    });
    Ok((task_set_id, shared_plan))
}
/// Records `partition` as running in the task set identified by
/// `task_set_id`, keeping `abort_handle` so the task can be aborted later.
///
/// # Errors
///
/// Fails if no task set with that id exists.
pub fn start_task(
    &mut self,
    partition: usize,
    task_set_id: Uuid,
    abort_handle: AbortHandle,
) -> DistResult<()> {
    let owner = self
        .task_sets
        .iter_mut()
        .find(|task_set| task_set.id == task_set_id);
    match owner {
        Some(task_set) => {
            task_set.running_partitions.insert(partition, abort_handle);
            Ok(())
        }
        None => Err(DistError::internal(format!(
            "Task set {task_set_id} not found"
        ))),
    }
}
/// Moves the partition of `task_id` from running to dropped in the task set
/// identified by `task_set_id`, storing `task_metrics` for it.
///
/// Does nothing when the task set is unknown.
pub fn complete_task(&mut self, task_id: TaskId, task_set_id: Uuid, task_metrics: TaskMetrics) {
    let partition = task_id.partition as usize;
    let Some(task_set) = self
        .task_sets
        .iter_mut()
        .find(|task_set| task_set.id == task_set_id)
    else {
        return;
    };
    task_set.running_partitions.remove(&partition);
    task_set.dropped_partitions.insert(partition, task_metrics);
}
/// Whether no partition of this stage has been started: every task set has
/// neither running nor dropped partitions (trivially true without task sets).
pub fn never_executed(&self) -> bool {
    let any_activity = self.task_sets.iter().any(|set| {
        !set.running_partitions.is_empty() || !set.dropped_partitions.is_empty()
    });
    !any_activity
}
/// Aborts the running partitions of every task set of this stage.
pub fn abort_running_tasks(&mut self) {
    self.task_sets
        .iter_mut()
        .for_each(|task_set| task_set.abort_running_partitions());
}
}
/// One execution round of a stage: a plan instance plus the bookkeeping of
/// the partitions executed against it.
#[derive(Debug)]
pub struct TaskSet {
/// Identifier of this set, generated with `Uuid::new_v4()`.
pub id: Uuid,
/// Plan instance shared by all partitions of this set.
pub shared_plan: Arc<dyn ExecutionPlan>,
/// Partitions currently executing, with the abort handle of their driver task.
pub running_partitions: HashMap<usize, AbortHandle>,
/// Partitions whose output stream has been dropped, with the metrics recorded
/// at that point.
pub dropped_partitions: HashMap<usize, TaskMetrics>,
}
impl TaskSet {
    /// Whether `partition` has never been started in this set, i.e. it is
    /// neither running nor dropped.
    pub fn never_executed(&self, partition: &usize) -> bool {
        let seen = self.running_partitions.contains_key(partition)
            || self.dropped_partitions.contains_key(partition);
        !seen
    }

    /// Aborts every running partition and empties the running map.
    pub fn abort_running_partitions(&mut self) {
        self.running_partitions
            .drain()
            .for_each(|(_, abort_handle)| abort_handle.abort());
    }
}
/// Output metrics of one task, captured when its [`TaskStream`] is dropped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMetrics {
/// Total number of rows in the batches the stream yielded.
pub output_rows: usize,
/// Sum of `get_array_memory_size()` over the batches the stream yielded.
pub output_bytes: usize,
/// `true` if the stream was polled to its end before being dropped.
pub completed: bool,
}
/// Record batch stream of a locally executed task that tracks output metrics
/// and reports them to the owning stage state when dropped.
pub struct TaskStream {
/// Task whose output this stream carries.
pub task_id: TaskId,
/// Task set the task was started in.
pub task_set_id: Uuid,
/// Shared stage map, updated when the stream is dropped.
pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
/// Channel used to send `Event::CheckJobCompleted` from `drop`.
pub event_sender: Sender<Event>,
/// Wrapped stream that produces the batches.
pub stream: SendableRecordBatchStream,
/// Rows yielded so far.
pub output_rows: usize,
/// Array memory size of the batches yielded so far.
pub output_bytes: usize,
/// Set once the wrapped stream has reported its end.
pub completed: bool,
}
impl TaskStream {
    /// Wraps `stream` for the given task with all metrics zeroed and
    /// `completed` set to `false`.
    pub fn new(
        task_id: TaskId,
        task_set_id: Uuid,
        stages: Arc<Mutex<HashMap<StageId, StageState>>>,
        event_sender: Sender<Event>,
        stream: SendableRecordBatchStream,
    ) -> Self {
        Self {
            output_rows: 0,
            output_bytes: 0,
            completed: false,
            task_id,
            task_set_id,
            stages,
            event_sender,
            stream,
        }
    }
}
impl Stream for TaskStream {
    type Item = Result<RecordBatch, DataFusionError>;

    /// Forwards the poll to the wrapped stream and updates the bookkeeping:
    /// row and byte counters for every successful batch, `completed` once the
    /// wrapped stream ends. Errors and `Pending` pass through untouched.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = self.stream.as_mut().poll_next(cx);
        match &polled {
            Poll::Ready(Some(Ok(batch))) => {
                self.output_rows += batch.num_rows();
                self.output_bytes += batch.get_array_memory_size();
            }
            Poll::Ready(None) => {
                self.completed = true;
            }
            Poll::Ready(Some(Err(_))) | Poll::Pending => {}
        }
        polled
    }
}
impl RecordBatchStream for TaskStream {
    /// Schema of the wrapped stream.
    fn schema(&self) -> Arc<Schema> {
        let inner = &self.stream;
        inner.schema()
    }
}
impl Drop for TaskStream {
    /// Records the task as finished in its stage state and, when that was the
    /// last outstanding partition of a stage 0, asks the event handler to
    /// check whether the job is complete.
    ///
    /// The event is sent with `try_send`, so it is dropped (and an error is
    /// logged) if the send fails.
    fn drop(&mut self) {
        let task_id = self.task_id.clone();
        let task_metrics = TaskMetrics {
            output_rows: self.output_rows,
            output_bytes: self.output_bytes,
            completed: self.completed,
        };
        debug!("Task {task_id} dropped with metrics: {task_metrics:?}");

        let stage_id = task_id.stage_id();
        // The guard is a temporary of this statement, so the lock is released
        // before the event is sent.
        let should_send = match self.stages.lock().get_mut(&stage_id) {
            Some(stage_state) => {
                stage_state.complete_task(task_id.clone(), self.task_set_id, task_metrics);
                stage_state.stage_id.stage == 0 && stage_state.all_assigned_partitions_completed()
            }
            None => false,
        };
        if !should_send {
            return;
        }

        let event = Event::CheckJobCompleted(task_id.job_id.clone());
        if let Err(e) = self.event_sender.try_send(event) {
            error!(
                "Failed to send CheckJobCompleted event after task {task_id} stream dropped: {e}"
            );
        }
    }
}
fn start_job_cleaner(stages: Arc<Mutex<HashMap<StageId, StageState>>>, config: Arc<DistConfig>) {
tokio::spawn(async move {
loop {
tokio::time::sleep(config.job_ttl_check_interval).await;
let mut guard = stages.lock();
let mut to_cleanup = Vec::new();
for (stage_id, stage_state) in guard.iter() {
let age_ms = timestamp_ms() - stage_state.created_at_ms;
if age_ms >= config.job_ttl.as_millis() as i64 {
to_cleanup.push(stage_id.clone());
}
}
if !to_cleanup.is_empty() {
debug!(
"Stages [{}] lifetime exceed job ttl {}s, cleaning up.",
to_cleanup
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(", "),
config.job_ttl.as_secs()
);
cleanup_stages(&mut guard, |stage_id| to_cleanup.contains(stage_id));
}
drop(guard);
}
});
}
/// Removes every stage for which `should_cleanup` returns `true`, aborting its
/// running tasks before the entry is discarded.
pub(crate) fn cleanup_stages(
    stages: &mut HashMap<StageId, StageState>,
    mut should_cleanup: impl FnMut(&StageId) -> bool,
) {
    stages.retain(|stage_id, stage_state| {
        let remove = should_cleanup(stage_id);
        if remove {
            stage_state.abort_running_tasks();
        }
        !remove
    });
}