1use std::sync::Arc;
19use std::time::{Duration, SystemTime, UNIX_EPOCH};
20
21use kapot_core::error::Result;
22use kapot_core::event_loop::{EventLoop, EventSender};
23use kapot_core::serde::protobuf::TaskStatus;
24use kapot_core::serde::KapotCodec;
25
26use datafusion::execution::context::SessionState;
27use datafusion::logical_expr::LogicalPlan;
28use datafusion::prelude::{SessionConfig, SessionContext};
29use datafusion_proto::logical_plan::AsLogicalPlan;
30use datafusion_proto::physical_plan::AsExecutionPlan;
31
32use crate::cluster::KapotCluster;
33use crate::config::SchedulerConfig;
34use crate::metrics::SchedulerMetricsCollector;
35use kapot_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
36use log::{error, warn};
37
38use crate::scheduler_server::event::QueryStageSchedulerEvent;
39use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler;
40
41use crate::state::executor_manager::ExecutorManager;
42
43use crate::state::task_manager::TaskLauncher;
44use crate::state::SchedulerState;
45
/// Protobuf/gRPC bindings for the external scaler API.
///
/// The contents are pulled in verbatim from `$OUT_DIR/externalscaler.rs`, i.e. a file
/// produced at build time (presumably by the crate's build script — confirm).
// Clippy is silenced wholesale because this is generated code we do not hand-edit.
#[allow(clippy::all)]
pub mod externalscaler {
    include!(concat!(env!("OUT_DIR"), "/externalscaler.rs"));
}
51
52pub mod event;
53mod external_scaler;
54mod grpc;
55pub(crate) mod query_stage_scheduler;
56
/// Plain function pointer that turns a [`SessionConfig`] into a DataFusion [`SessionState`].
pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState;
58
59#[derive(Clone)]
60pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
61 pub scheduler_name: String,
62 pub start_time: u128,
63 pub state: Arc<SchedulerState<T, U>>,
64 pub(crate) query_stage_event_loop: EventLoop<QueryStageSchedulerEvent>,
65 query_stage_scheduler: Arc<QueryStageScheduler<T, U>>,
66 config: Arc<SchedulerConfig>,
67}
68
69impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T, U> {
70 pub fn new(
71 scheduler_name: String,
72 cluster: KapotCluster,
73 codec: KapotCodec<T, U>,
74 config: Arc<SchedulerConfig>,
75 metrics_collector: Arc<dyn SchedulerMetricsCollector>,
76 ) -> Self {
77 let state = Arc::new(SchedulerState::new(
78 cluster,
79 codec,
80 scheduler_name.clone(),
81 config.clone(),
82 ));
83 let query_stage_scheduler = Arc::new(QueryStageScheduler::new(
84 state.clone(),
85 metrics_collector,
86 config.clone(),
87 ));
88 let query_stage_event_loop = EventLoop::new(
89 "query_stage".to_owned(),
90 config.event_loop_buffer_size as usize,
91 query_stage_scheduler.clone(),
92 );
93
94 Self {
95 scheduler_name,
96 start_time: timestamp_millis() as u128,
97 state,
98 query_stage_event_loop,
99 query_stage_scheduler,
100 config,
101 }
102 }
103
104 #[allow(dead_code)]
105 pub fn new_with_task_launcher(
106 scheduler_name: String,
107 cluster: KapotCluster,
108 codec: KapotCodec<T, U>,
109 config: Arc<SchedulerConfig>,
110 metrics_collector: Arc<dyn SchedulerMetricsCollector>,
111 task_launcher: Arc<dyn TaskLauncher>,
112 ) -> Self {
113 let state = Arc::new(SchedulerState::new_with_task_launcher(
114 cluster,
115 codec,
116 scheduler_name.clone(),
117 config.clone(),
118 task_launcher,
119 ));
120 let query_stage_scheduler = Arc::new(QueryStageScheduler::new(
121 state.clone(),
122 metrics_collector,
123 config.clone(),
124 ));
125 let query_stage_event_loop = EventLoop::new(
126 "query_stage".to_owned(),
127 config.event_loop_buffer_size as usize,
128 query_stage_scheduler.clone(),
129 );
130
131 Self {
132 scheduler_name,
133 start_time: timestamp_millis() as u128,
134 state,
135 query_stage_event_loop,
136 query_stage_scheduler,
137 config,
138 }
139 }
140
141 pub async fn init(&mut self) -> Result<()> {
142 self.state.init().await?;
143 self.query_stage_event_loop.start()?;
144 self.expire_dead_executors()?;
145
146 Ok(())
147 }
148
149 pub fn pending_job_number(&self) -> usize {
150 self.state.task_manager.pending_job_number()
151 }
152
153 pub fn running_job_number(&self) -> usize {
154 self.state.task_manager.running_job_number()
155 }
156
157 pub(crate) fn metrics_collector(&self) -> &dyn SchedulerMetricsCollector {
158 self.query_stage_scheduler.metrics_collector()
159 }
160
161 pub(crate) async fn submit_job(
162 &self,
163 job_id: &str,
164 job_name: &str,
165 ctx: Arc<SessionContext>,
166 plan: &LogicalPlan,
167 ) -> Result<()> {
168 self.query_stage_event_loop
169 .get_sender()?
170 .post_event(QueryStageSchedulerEvent::JobQueued {
171 job_id: job_id.to_owned(),
172 job_name: job_name.to_owned(),
173 session_ctx: ctx,
174 plan: Box::new(plan.clone()),
175 queued_at: timestamp_millis(),
176 })
177 .await
178 }
179
180 pub(crate) async fn update_task_status(
183 &self,
184 executor_id: &str,
185 tasks_status: Vec<TaskStatus>,
186 ) -> Result<()> {
187 if self.state.config.is_push_staged_scheduling()
189 && self.state.executor_manager.is_dead_executor(executor_id)
190 {
191 let error_msg = format!(
192 "Receive buggy tasks status from dead Executor {executor_id}, task status update ignored."
193 );
194 warn!("{}", error_msg);
195 return Ok(());
196 }
197 self.query_stage_event_loop
198 .get_sender()?
199 .post_event(QueryStageSchedulerEvent::TaskUpdating(
200 executor_id.to_owned(),
201 tasks_status,
202 ))
203 .await
204 }
205
206 pub(crate) async fn revive_offers(&self) -> Result<()> {
207 self.query_stage_event_loop
208 .get_sender()?
209 .post_event(QueryStageSchedulerEvent::ReviveOffers)
210 .await
211 }
212
213 fn expire_dead_executors(&self) -> Result<()> {
216 let state = self.state.clone();
217 let event_sender = self.query_stage_event_loop.get_sender()?;
218 tokio::task::spawn(async move {
219 loop {
220 let expired_executors = state.executor_manager.get_expired_executors();
221 for expired in expired_executors {
222 let executor_id = expired.executor_id.clone();
223
224 let sender_clone = event_sender.clone();
225
226 let terminating = matches!(
227 expired
228 .status
229 .as_ref()
230 .and_then(|status| status.status.as_ref()),
231 Some(kapot_core::serde::protobuf::executor_status::Status::Terminating(_))
232 );
233
234 let stop_reason = if terminating {
235 format!(
236 "TERMINATING executor {executor_id} heartbeat timed out after {}s", state.config.executor_termination_grace_period,
237 )
238 } else {
239 format!(
240 "ACTIVE executor {executor_id} heartbeat timed out after {}s",
241 state.config.executor_timeout_seconds,
242 )
243 };
244
245 warn!("{stop_reason}");
246
247 Self::remove_executor(
249 state.executor_manager.clone(),
250 sender_clone,
251 &executor_id,
252 Some(stop_reason.clone()),
253 0,
254 );
255
256 if !terminating {
259 state
260 .executor_manager
261 .stop_executor(&executor_id, stop_reason)
262 .await;
263 }
264 }
265 tokio::time::sleep(Duration::from_secs(
266 state.config.expire_dead_executor_interval_seconds,
267 ))
268 .await;
269 }
270 });
271 Ok(())
272 }
273
274 pub(crate) fn remove_executor(
275 executor_manager: ExecutorManager,
276 event_sender: EventSender<QueryStageSchedulerEvent>,
277 executor_id: &str,
278 reason: Option<String>,
279 wait_secs: u64,
280 ) {
281 let executor_id = executor_id.to_owned();
282 tokio::spawn(async move {
283 tokio::time::sleep(Duration::from_secs(wait_secs)).await;
285
286 if let Err(e) = executor_manager
288 .remove_executor(&executor_id, reason.clone())
289 .await
290 {
291 error!("error removing executor {executor_id}: {e:?}");
292 }
293
294 if let Err(e) = event_sender
295 .post_event(QueryStageSchedulerEvent::ExecutorLost(executor_id, reason))
296 .await
297 {
298 error!("error sending ExecutorLost event: {e:?}");
299 }
300 });
301 }
302
303 async fn do_register_executor(&self, metadata: ExecutorMetadata) -> Result<()> {
304 let executor_data = ExecutorData {
305 executor_id: metadata.id.clone(),
306 total_task_slots: metadata.specification.task_slots,
307 available_task_slots: metadata.specification.task_slots,
308 };
309
310 self.state
312 .executor_manager
313 .register_executor(metadata, executor_data)
314 .await?;
315
316 if self.state.config.is_push_staged_scheduling() {
319 self.revive_offers().await?;
320 }
321
322 Ok(())
323 }
324}
325
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics with "Time went backwards" if the system clock reads earlier than the epoch.
pub fn timestamp_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_secs()
}
332
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// The `u128` millisecond count from [`std::time::Duration::as_millis`] is clamped to
/// `u64::MAX` rather than having its high bits silently truncated by an `as` cast.
///
/// # Panics
/// Panics with "Time went backwards" if the system clock reads earlier than the epoch.
pub fn timestamp_millis() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    // After clamping, the value always fits in u64, so this cast is lossless.
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}
339
#[cfg(all(test, feature = "sled"))]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::functions_aggregate::sum::sum;
    use datafusion::logical_expr::{col, LogicalPlan};

    use datafusion::test_util::scan_empty_with_partitions;
    use datafusion_proto::protobuf::LogicalPlanNode;
    use datafusion_proto::protobuf::PhysicalPlanNode;

    use kapot_core::config::{
        KapotConfig, TaskSchedulingPolicy, KAPOT_DEFAULT_SHUFFLE_PARTITIONS,
    };
    use kapot_core::error::Result;

    use crate::config::SchedulerConfig;

    use kapot_core::serde::protobuf::{
        failed_task, job_status, task_status, ExecutionError, FailedTask, JobStatus,
        MultiTaskDefinition, ShuffleWritePartition, SuccessfulJob, SuccessfulTask,
        TaskId, TaskStatus,
    };
    use kapot_core::serde::scheduler::{
        ExecutorData, ExecutorMetadata, ExecutorSpecification,
    };
    use kapot_core::serde::KapotCodec;

    use crate::scheduler_server::{timestamp_millis, SchedulerServer};

    use crate::test_utils::{
        assert_completed_event, assert_failed_event, assert_no_submitted_event,
        assert_submitted_event, test_cluster_context, ExplodingTableProvider,
        SchedulerTest, TaskRunnerFn, TestMetricsCollector,
    };

    /// Drives a job to completion under pull-staged scheduling by repeatedly popping
    /// tasks for `executor-1` and reporting them as successful.
    #[tokio::test]
    async fn test_pull_scheduling() -> Result<()> {
        let plan = test_plan();
        let task_slots = 4;

        let scheduler = test_scheduler(TaskSchedulingPolicy::PullStaged).await?;

        let executors = test_executors(task_slots);
        for (executor_metadata, executor_data) in executors {
            scheduler
                .state
                .executor_manager
                .register_executor(executor_metadata, executor_data)
                .await?;
        }

        let config = test_session(task_slots);

        let ctx = scheduler
            .state
            .session_manager
            .create_session(&config)
            .await?;

        let job_id = "job";

        scheduler
            .state
            .task_manager
            .queue_job(job_id, "", timestamp_millis())?;

        scheduler
            .state
            .submit_job(job_id, "", ctx, &plan, 0)
            .await
            .expect("submitting plan");

        while let Some(graph) = scheduler
            .state
            .task_manager
            .get_active_execution_graph(job_id)
        {
            let task = {
                let mut graph = graph.write().await;
                graph.pop_next_task("executor-1")?
            };
            if let Some(task) = task {
                let partitions: Vec<ShuffleWritePartition> =
                    (0..task.get_output_partition_number())
                        .map(|partition_id| ShuffleWritePartition {
                            partition_id: partition_id as u64,
                            path: "some/path".to_string(),
                            num_batches: 1,
                            num_rows: 1,
                            num_bytes: 1,
                        })
                        .collect();

                let task_status = TaskStatus {
                    task_id: task.task_id as u32,
                    job_id: task.partition.job_id.clone(),
                    stage_id: task.partition.stage_id as u32,
                    stage_attempt_num: task.stage_attempt_num as u32,
                    partition_id: task.partition.partition_id as u32,
                    launch_time: 0,
                    start_exec_time: 0,
                    end_exec_time: 0,
                    metrics: vec![],
                    status: Some(task_status::Status::Successful(SuccessfulTask {
                        executor_id: "executor-1".to_owned(),
                        partitions,
                    })),
                };

                scheduler
                    .state
                    .update_task_statuses("executor-1", vec![task_status])
                    .await?;
            } else {
                break;
            }
        }

        let final_graph = scheduler
            .state
            .task_manager
            .get_active_execution_graph(job_id)
            .expect("Fail to find graph in the cache");

        let final_graph = final_graph.read().await;
        assert!(final_graph.is_successful());
        assert_eq!(final_graph.output_locations().len(), 4);

        for output_location in final_graph.output_locations() {
            assert_eq!(output_location.path, "some/path".to_owned());
            assert_eq!(output_location.executor_meta.host, "localhost1".to_owned())
        }

        Ok(())
    }

    /// Runs a job end-to-end under push-staged scheduling and checks it succeeds
    /// with one output location per partition.
    #[tokio::test]
    async fn test_push_scheduling() -> Result<()> {
        let plan = test_plan();

        let metrics_collector = Arc::new(TestMetricsCollector::default());

        let mut test = SchedulerTest::new(
            SchedulerConfig::default()
                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
            metrics_collector.clone(),
            4,
            1,
            None,
        )
        .await?;

        let status = test.run("job", "", &plan).await.expect("running plan");

        match status.status {
            Some(job_status::Status::Successful(SuccessfulJob {
                partition_location,
                ..
            })) => {
                assert_eq!(partition_location.len(), 4);
            }
            other => {
                panic!("Expected success status but found {other:?}");
            }
        }

        assert_submitted_event("job", &metrics_collector);
        assert_completed_event("job", &metrics_collector);

        Ok(())
    }

    /// Every task fails with a non-retryable execution error, so the job must fail.
    #[tokio::test]
    async fn test_job_failure() -> Result<()> {
        let plan = test_plan();

        let runner = Arc::new(TaskRunnerFn::new(
            |_executor_id: String, task: MultiTaskDefinition| {
                let mut statuses = vec![];

                for TaskId {
                    task_id,
                    partition_id,
                    ..
                } in task.task_ids
                {
                    let timestamp = timestamp_millis();
                    statuses.push(TaskStatus {
                        task_id,
                        job_id: task.job_id.clone(),
                        stage_id: task.stage_id,
                        stage_attempt_num: task.stage_attempt_num,
                        partition_id,
                        launch_time: timestamp,
                        start_exec_time: timestamp,
                        end_exec_time: timestamp,
                        metrics: vec![],
                        status: Some(task_status::Status::Failed(FailedTask {
                            error: "ERROR".to_string(),
                            retryable: false,
                            count_to_failures: false,
                            failed_reason: Some(
                                failed_task::FailedReason::ExecutionError(
                                    ExecutionError {},
                                ),
                            ),
                        })),
                    });
                }

                statuses
            },
        ));

        let metrics_collector = Arc::new(TestMetricsCollector::default());

        let mut test = SchedulerTest::new(
            SchedulerConfig::default()
                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
            metrics_collector.clone(),
            4,
            1,
            Some(runner),
        )
        .await?;

        let status = test.run("job", "", &plan).await.expect("running plan");

        // The message must be the format string itself; passing it through "{}"
        // would print `{status:?}` literally instead of the actual status.
        assert!(
            matches!(
                status,
                JobStatus {
                    status: Some(job_status::Status::Failed(_)),
                    ..
                }
            ),
            "Expected job status to be failed but it was {status:?}"
        );

        assert_submitted_event("job", &metrics_collector);
        assert_failed_event("job", &metrics_collector);

        Ok(())
    }

    /// A table provider that fails during planning must fail the job without ever
    /// emitting a "submitted" event.
    #[tokio::test]
    async fn test_planning_failure() -> Result<()> {
        let metrics_collector = Arc::new(TestMetricsCollector::default());
        let mut test = SchedulerTest::new(
            SchedulerConfig::default()
                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
            metrics_collector.clone(),
            4,
            1,
            None,
        )
        .await?;

        let ctx = test.ctx().await?;

        ctx.register_table("explode", Arc::new(ExplodingTableProvider))?;

        let plan = ctx
            .sql("SELECT * FROM explode")
            .await?
            .into_optimized_plan()?;

        let status = test.run("job", "", &plan).await?;

        // See `test_job_failure`: interpolate the status directly in the message.
        assert!(
            matches!(
                status,
                JobStatus {
                    status: Some(job_status::Status::Failed(_)),
                    ..
                }
            ),
            "Expected job status to be failed but it was {status:?}"
        );

        assert_no_submitted_event("job", &metrics_collector);
        assert_failed_event("job", &metrics_collector);

        Ok(())
    }

    /// Builds and initializes a scheduler over the test cluster with the given policy.
    async fn test_scheduler(
        scheduling_policy: TaskSchedulingPolicy,
    ) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> {
        let cluster = test_cluster_context();

        let config = SchedulerConfig::default().with_scheduler_policy(scheduling_policy);
        let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
            SchedulerServer::new(
                "localhost:50050".to_owned(),
                cluster,
                KapotCodec::default(),
                Arc::new(config),
                Arc::new(TestMetricsCollector::default()),
            );
        scheduler.init().await?;

        Ok(scheduler)
    }

    /// Two executors whose task slots add up to `num_partitions`; `executor-1` gets
    /// the rounded-up half.
    fn test_executors(num_partitions: usize) -> Vec<(ExecutorMetadata, ExecutorData)> {
        let task_slots = (num_partitions as u32 + 1) / 2;
        let remaining_slots = num_partitions as u32 - task_slots;

        vec![
            (
                ExecutorMetadata {
                    id: "executor-1".to_string(),
                    host: "localhost1".to_string(),
                    port: 8080,
                    grpc_port: 9090,
                    specification: ExecutorSpecification { task_slots },
                },
                ExecutorData {
                    executor_id: "executor-1".to_owned(),
                    total_task_slots: task_slots,
                    available_task_slots: task_slots,
                },
            ),
            (
                ExecutorMetadata {
                    id: "executor-2".to_string(),
                    host: "localhost2".to_string(),
                    port: 8080,
                    grpc_port: 9090,
                    specification: ExecutorSpecification {
                        task_slots: remaining_slots,
                    },
                },
                ExecutorData {
                    executor_id: "executor-2".to_owned(),
                    total_task_slots: remaining_slots,
                    available_task_slots: remaining_slots,
                },
            ),
        ]
    }

    /// `SELECT id, SUM(gmv) ... GROUP BY id` over an empty two-partition scan.
    fn test_plan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("gmv", DataType::UInt64, false),
        ]);

        scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2)
            .unwrap()
            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
            .unwrap()
            .build()
            .unwrap()
    }

    /// Session config whose default shuffle partition count is `partitions`.
    fn test_session(partitions: usize) -> KapotConfig {
        KapotConfig::builder()
            .set(
                KAPOT_DEFAULT_SHUFFLE_PARTITIONS,
                format!("{partitions}").as_str(),
            )
            .build()
            .expect("creating kapotConfig")
    }
}