use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use ballista_core::error::Result;
use ballista_core::event_loop::{EventLoop, EventSender};
use ballista_core::serde::protobuf::{StopExecutorParams, TaskStatus};
use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
use ballista_core::utils::default_session_builder;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_proto::logical_plan::AsLogicalPlan;
use crate::config::SchedulerConfig;
use crate::metrics::SchedulerMetricsCollector;
use log::{error, warn};
use crate::scheduler_server::event::QueryStageSchedulerEvent;
use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler;
use crate::state::backend::StateBackendClient;
use crate::state::executor_manager::{
ExecutorManager, ExecutorReservation, DEFAULT_EXECUTOR_TIMEOUT_SECONDS,
};
use crate::state::task_manager::TaskLauncher;
use crate::state::SchedulerState;
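/// gRPC service definitions for the external scaler interface, generated at
/// build time from its protobuf definition. This lets an external autoscaler
/// (such as KEDA) query scheduler load to scale executors.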
#[allow(clippy::all)]
pub mod externalscaler {
include!(concat!(env!("OUT_DIR"), "/externalscaler.rs"));
}
pub mod event;
mod external_scaler;
mod grpc;
mod query_stage_scheduler;
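/// Builds a DataFusion [`SessionState`] from a [`SessionConfig`], allowing
/// session creation to be customized.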
pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState;
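/// A Ballista scheduler server, generic over the serialized logical and physical
/// plan representations. It wraps the shared [`SchedulerState`] and drives jobs
/// by posting [`QueryStageSchedulerEvent`]s to the query stage event loop.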
#[derive(Clone)]
pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
pub scheduler_name: String,
pub start_time: u128,
pub(crate) state: Arc<SchedulerState<T, U>>,
pub(crate) query_stage_event_loop: EventLoop<QueryStageSchedulerEvent>,
query_stage_scheduler: Arc<QueryStageScheduler<T, U>>,
}
impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T, U> {
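    /// Create a new [`SchedulerServer`].
    ///
    /// A minimal construction sketch, assuming the sled-backed `StandaloneClient`
    /// and test metrics collector used in the tests at the bottom of this file:
    ///
    /// ```ignore
    /// let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
    /// let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
    ///     SchedulerServer::new(
    ///         "localhost:50050".to_owned(),
    ///         state_storage,
    ///         BallistaCodec::default(),
    ///         SchedulerConfig::default(),
    ///         Arc::new(TestMetricsCollector::default()),
    ///     );
    /// scheduler.init().await?;
    /// ```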
pub fn new(
scheduler_name: String,
config_backend: Arc<dyn StateBackendClient>,
codec: BallistaCodec<T, U>,
config: SchedulerConfig,
metrics_collector: Arc<dyn SchedulerMetricsCollector>,
) -> Self {
let state = Arc::new(SchedulerState::new(
config_backend,
default_session_builder,
codec,
scheduler_name.clone(),
config.clone(),
));
let query_stage_scheduler =
Arc::new(QueryStageScheduler::new(state.clone(), metrics_collector));
let query_stage_event_loop = EventLoop::new(
"query_stage".to_owned(),
config.event_loop_buffer_size as usize,
query_stage_scheduler.clone(),
);
Self {
scheduler_name,
start_time: timestamp_millis() as u128,
state,
query_stage_event_loop,
query_stage_scheduler,
}
}
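    /// Like [`SchedulerServer::new`], but with a caller-provided [`TaskLauncher`];
    /// primarily useful for tests that need to intercept task launches.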
#[allow(dead_code)]
pub(crate) fn with_task_launcher(
scheduler_name: String,
config_backend: Arc<dyn StateBackendClient>,
codec: BallistaCodec<T, U>,
config: SchedulerConfig,
metrics_collector: Arc<dyn SchedulerMetricsCollector>,
task_launcher: Arc<dyn TaskLauncher>,
) -> Self {
let state = Arc::new(SchedulerState::with_task_launcher(
config_backend,
default_session_builder,
codec,
scheduler_name.clone(),
config.clone(),
task_launcher,
));
let query_stage_scheduler =
Arc::new(QueryStageScheduler::new(state.clone(), metrics_collector));
let query_stage_event_loop = EventLoop::new(
"query_stage".to_owned(),
config.event_loop_buffer_size as usize,
query_stage_scheduler.clone(),
);
Self {
scheduler_name,
start_time: timestamp_millis() as u128,
state,
query_stage_event_loop,
query_stage_scheduler,
}
}
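    /// Initialize the scheduler: initialize the shared state, start the query
    /// stage event loop, and spawn the background task that expires dead executors.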
pub async fn init(&mut self) -> Result<()> {
self.state.init().await?;
self.query_stage_event_loop.start()?;
self.expire_dead_executors()?;
Ok(())
}
pub(crate) fn pending_tasks(&self) -> usize {
self.query_stage_scheduler.pending_tasks()
}
pub(crate) fn metrics_collector(&self) -> &dyn SchedulerMetricsCollector {
self.query_stage_scheduler.metrics_collector()
}
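    /// Submit a job for execution by posting a `JobQueued` event to the query
    /// stage event loop.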
pub(crate) async fn submit_job(
&self,
job_id: &str,
job_name: &str,
ctx: Arc<SessionContext>,
plan: &LogicalPlan,
) -> Result<()> {
self.query_stage_event_loop
.get_sender()?
.post_event(QueryStageSchedulerEvent::JobQueued {
job_id: job_id.to_owned(),
job_name: job_name.to_owned(),
session_ctx: ctx,
plan: Box::new(plan.clone()),
queued_at: timestamp_millis(),
})
.await
}
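    /// Forward a batch of task status updates from an executor to the query stage
    /// event loop. Updates from executors already marked dead are ignored.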
pub(crate) async fn update_task_status(
&self,
executor_id: &str,
tasks_status: Vec<TaskStatus>,
) -> Result<()> {
        if self.state.executor_manager.is_dead_executor(executor_id) {
            warn!(
                "Received task status update from dead executor {}; ignoring it",
                executor_id
            );
            return Ok(());
        }
self.query_stage_event_loop
.get_sender()?
.post_event(QueryStageSchedulerEvent::TaskUpdating(
executor_id.to_owned(),
tasks_status,
))
.await
}
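    /// Offer executor task slot reservations to the query stage scheduler so they
    /// can be filled with pending tasks.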
pub(crate) async fn offer_reservation(
&self,
reservations: Vec<ExecutorReservation>,
) -> Result<()> {
self.query_stage_event_loop
.get_sender()?
.post_event(QueryStageSchedulerEvent::ReservationOffering(reservations))
.await
}
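    /// Spawn a background task that periodically removes executors whose
    /// heartbeats have expired and sends each of them a best-effort
    /// `stop_executor` RPC.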
fn expire_dead_executors(&self) -> Result<()> {
let state = self.state.clone();
let event_sender = self.query_stage_event_loop.get_sender()?;
tokio::task::spawn(async move {
loop {
let expired_executors = state.executor_manager.get_expired_executors();
for expired in expired_executors {
let executor_id = expired.executor_id.clone();
let executor_manager = state.executor_manager.clone();
                    let stop_reason = format!(
                        "Executor {} heartbeat timed out after {}s",
                        executor_id, DEFAULT_EXECUTOR_TIMEOUT_SECONDS
                    );
                    warn!("{}", stop_reason);
let sender_clone = event_sender.clone();
Self::remove_executor(
executor_manager,
sender_clone,
&executor_id,
Some(stop_reason.clone()),
)
.await
                    .unwrap_or_else(|e| {
                        error!("Failed to remove executor from scheduler state: {:?}", e);
                    });
match state.executor_manager.get_client(&executor_id).await {
Ok(mut client) => {
tokio::task::spawn(async move {
                                if let Err(error) = client
                                    .stop_executor(StopExecutorParams {
                                        executor_id,
                                        reason: stop_reason,
                                        force: true,
                                    })
                                    .await
                                {
                                    warn!("Failed to send stop_executor RPC: {}", error);
                                }
});
}
                        Err(_) => {
                            warn!(
                                "Failed to connect to executor {}; it is likely already dead",
                                executor_id
                            );
                        }
}
}
tokio::time::sleep(Duration::from_secs(DEFAULT_EXECUTOR_TIMEOUT_SECONDS))
.await;
}
});
Ok(())
}
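    /// Remove an executor from the [`ExecutorManager`] and post an `ExecutorLost`
    /// event to the query stage event loop.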
pub(crate) async fn remove_executor(
executor_manager: ExecutorManager,
event_sender: EventSender<QueryStageSchedulerEvent>,
executor_id: &str,
reason: Option<String>,
) -> Result<()> {
executor_manager
.remove_executor(executor_id, reason.clone())
.await?;
event_sender
.post_event(QueryStageSchedulerEvent::ExecutorLost(
executor_id.to_owned(),
reason,
))
.await?;
Ok(())
}
}
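/// Returns the number of seconds elapsed since the Unix epoch.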
pub fn timestamp_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs()
}
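/// Returns the number of milliseconds elapsed since the Unix epoch.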
pub fn timestamp_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_millis() as u64
}
#[cfg(all(test, feature = "sled"))]
mod test {
use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::logical_expr::{col, sum, LogicalPlan};
use datafusion::test_util::scan_empty;
use datafusion_proto::protobuf::LogicalPlanNode;
use ballista_core::config::{
BallistaConfig, TaskSchedulingPolicy, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS,
};
use ballista_core::error::Result;
use crate::config::SchedulerConfig;
use ballista_core::serde::protobuf::{
failed_task, job_status, task_status, ExecutionError, FailedTask, JobStatus,
MultiTaskDefinition, PhysicalPlanNode, ShuffleWritePartition, SuccessfulJob,
SuccessfulTask, TaskId, TaskStatus,
};
use ballista_core::serde::scheduler::{
ExecutorData, ExecutorMetadata, ExecutorSpecification,
};
use ballista_core::serde::BallistaCodec;
use crate::scheduler_server::{timestamp_millis, SchedulerServer};
use crate::state::backend::standalone::StandaloneClient;
use crate::test_utils::{
assert_completed_event, assert_failed_event, assert_no_submitted_event,
assert_submitted_event, ExplodingTableProvider, SchedulerTest, TaskRunnerFn,
TestMetricsCollector,
};
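    // Exercises pull-based scheduling: tasks are popped from the execution graph
    // as a polling executor would, and successful statuses are fed back until the
    // graph completes.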
#[tokio::test]
async fn test_pull_scheduling() -> Result<()> {
let plan = test_plan();
let task_slots = 4;
let scheduler = test_scheduler(TaskSchedulingPolicy::PullStaged).await?;
let executors = test_executors(task_slots);
for (executor_metadata, executor_data) in executors {
scheduler
.state
.executor_manager
.register_executor(executor_metadata, executor_data, false)
.await?;
}
let config = test_session(task_slots);
let ctx = scheduler
.state
.session_manager
.create_session(&config)
.await?;
let job_id = "job";
scheduler
.state
.submit_job(job_id, "", ctx, &plan, 0)
.await
.expect("submitting plan");
while let Some(graph) = scheduler
.state
.task_manager
.get_active_execution_graph(job_id)
.await
{
let task = {
let mut graph = graph.write().await;
graph.pop_next_task("executor-1")?
};
if let Some(task) = task {
let mut partitions: Vec<ShuffleWritePartition> = vec![];
let num_partitions = task
.output_partitioning
.map(|p| p.partition_count())
.unwrap_or(1);
for partition_id in 0..num_partitions {
partitions.push(ShuffleWritePartition {
partition_id: partition_id as u64,
path: "some/path".to_string(),
num_batches: 1,
num_rows: 1,
num_bytes: 1,
})
}
let task_status = TaskStatus {
task_id: task.task_id as u32,
job_id: task.partition.job_id.clone(),
stage_id: task.partition.stage_id as u32,
stage_attempt_num: task.stage_attempt_num as u32,
partition_id: task.partition.partition_id as u32,
launch_time: 0,
start_exec_time: 0,
end_exec_time: 0,
metrics: vec![],
status: Some(task_status::Status::Successful(SuccessfulTask {
executor_id: "executor-1".to_owned(),
partitions,
})),
};
scheduler
.state
.update_task_statuses("executor-1", vec![task_status])
.await?;
} else {
break;
}
}
let final_graph = scheduler
.state
.task_manager
.get_active_execution_graph(job_id)
.await
.expect("Fail to find graph in the cache");
let final_graph = final_graph.read().await;
assert!(final_graph.is_successful());
assert_eq!(final_graph.output_locations().len(), 4);
for output_location in final_graph.output_locations() {
assert_eq!(output_location.path, "some/path".to_owned());
assert_eq!(output_location.executor_meta.host, "localhost1".to_owned())
}
Ok(())
}
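    // Exercises push-based scheduling end to end through the SchedulerTest harness.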
#[tokio::test]
async fn test_push_scheduling() -> Result<()> {
let plan = test_plan();
let metrics_collector = Arc::new(TestMetricsCollector::default());
let mut test = SchedulerTest::new(
SchedulerConfig::default()
.with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
metrics_collector.clone(),
4,
1,
None,
)
.await?;
let status = test.run("job", "", &plan).await.expect("running plan");
match status.status {
Some(job_status::Status::Successful(SuccessfulJob {
partition_location,
})) => {
assert_eq!(partition_location.len(), 4);
}
other => {
panic!("Expected success status but found {:?}", other);
}
}
assert_submitted_event("job", &metrics_collector);
assert_completed_event("job", &metrics_collector);
Ok(())
}
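    // Uses a task runner that fails every task with a non-retryable execution
    // error, so the job as a whole is expected to fail.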
#[tokio::test]
async fn test_job_failure() -> Result<()> {
let plan = test_plan();
let runner = Arc::new(TaskRunnerFn::new(
|_executor_id: String, task: MultiTaskDefinition| {
let mut statuses = vec![];
for TaskId {
task_id,
partition_id,
..
} in task.task_ids
{
let timestamp = timestamp_millis();
statuses.push(TaskStatus {
task_id,
job_id: task.job_id.clone(),
stage_id: task.stage_id,
stage_attempt_num: task.stage_attempt_num,
partition_id,
launch_time: timestamp,
start_exec_time: timestamp,
end_exec_time: timestamp,
metrics: vec![],
status: Some(task_status::Status::Failed(FailedTask {
error: "ERROR".to_string(),
retryable: false,
count_to_failures: false,
failed_reason: Some(
failed_task::FailedReason::ExecutionError(
ExecutionError {},
),
),
})),
});
}
statuses
},
));
let metrics_collector = Arc::new(TestMetricsCollector::default());
let mut test = SchedulerTest::new(
SchedulerConfig::default()
.with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
metrics_collector.clone(),
4,
1,
Some(runner),
)
.await?;
let status = test.run("job", "", &plan).await.expect("running plan");
assert!(
matches!(
status,
JobStatus {
status: Some(job_status::Status::Failed(_))
}
),
"Expected job status to be failed but it was {:?}",
status
);
assert_submitted_event("job", &metrics_collector);
assert_failed_event("job", &metrics_collector);
Ok(())
}
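    // ExplodingTableProvider fails during planning, so the job should fail before
    // any tasks are submitted.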
#[tokio::test]
async fn test_planning_failure() -> Result<()> {
let metrics_collector = Arc::new(TestMetricsCollector::default());
let mut test = SchedulerTest::new(
SchedulerConfig::default()
.with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
metrics_collector.clone(),
4,
1,
None,
)
.await?;
let ctx = test.ctx().await?;
ctx.register_table("explode", Arc::new(ExplodingTableProvider))?;
let plan = ctx.sql("SELECT * FROM explode").await?.to_logical_plan()?;
let status = test.run("job", "", &plan).await?;
assert!(
matches!(
status,
JobStatus {
status: Some(job_status::Status::Failed(_))
}
),
"Expected job status to be failed but it was {:?}",
status
);
assert_no_submitted_event("job", &metrics_collector);
assert_failed_event("job", &metrics_collector);
Ok(())
}
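    /// A scheduler server backed by a temporary standalone (sled) state store.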
async fn test_scheduler(
scheduling_policy: TaskSchedulingPolicy,
) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> {
let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
SchedulerServer::new(
"localhost:50050".to_owned(),
state_storage.clone(),
BallistaCodec::default(),
SchedulerConfig::default().with_scheduler_policy(scheduling_policy),
Arc::new(TestMetricsCollector::default()),
);
scheduler.init().await?;
Ok(scheduler)
}
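    /// Two test executors that together provide `num_partitions` task slots,
    /// split roughly evenly between them.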
fn test_executors(num_partitions: usize) -> Vec<(ExecutorMetadata, ExecutorData)> {
let task_slots = (num_partitions as u32 + 1) / 2;
vec![
(
ExecutorMetadata {
id: "executor-1".to_string(),
host: "localhost1".to_string(),
port: 8080,
grpc_port: 9090,
specification: ExecutorSpecification { task_slots },
},
ExecutorData {
executor_id: "executor-1".to_owned(),
total_task_slots: task_slots,
available_task_slots: task_slots,
},
),
(
ExecutorMetadata {
id: "executor-2".to_string(),
host: "localhost2".to_string(),
port: 8080,
grpc_port: 9090,
specification: ExecutorSpecification {
task_slots: num_partitions as u32 - task_slots,
},
},
ExecutorData {
executor_id: "executor-2".to_owned(),
total_task_slots: num_partitions as u32 - task_slots,
available_task_slots: num_partitions as u32 - task_slots,
},
),
]
}
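    /// A simple aggregate plan over an empty two-column table.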
fn test_plan() -> LogicalPlan {
let schema = Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("gmv", DataType::UInt64, false),
]);
scan_empty(None, &schema, Some(vec![0, 1]))
.unwrap()
.aggregate(vec![col("id")], vec![sum(col("gmv"))])
.unwrap()
.build()
.unwrap()
}
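    /// A [`BallistaConfig`] with the default shuffle partition count set to
    /// `partitions`.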
fn test_session(partitions: usize) -> BallistaConfig {
BallistaConfig::builder()
.set(
BALLISTA_DEFAULT_SHUFFLE_PARTITIONS,
format!("{}", partitions).as_str(),
)
.build()
.expect("creating BallistaConfig")
}
}