kapot_scheduler/scheduler_server/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use kapot_core::error::Result;
use kapot_core::event_loop::{EventLoop, EventSender};
use kapot_core::serde::protobuf::TaskStatus;
use kapot_core::serde::KapotCodec;

use datafusion::execution::context::SessionState;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;

use crate::cluster::KapotCluster;
use crate::config::SchedulerConfig;
use crate::metrics::SchedulerMetricsCollector;
use kapot_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
use log::{error, warn};

use crate::scheduler_server::event::QueryStageSchedulerEvent;
use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler;

use crate::state::executor_manager::ExecutorManager;

use crate::state::task_manager::TaskLauncher;
use crate::state::SchedulerState;

// include the generated protobuf source as a submodule
#[allow(clippy::all)]
pub mod externalscaler {
    include!(concat!(env!("OUT_DIR"), "/externalscaler.rs"));
}

pub mod event;
mod external_scaler;
mod grpc;
pub(crate) mod query_stage_scheduler;

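/// Function type used to build a DataFusion `SessionState` from a `SessionConfig`.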
pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState;

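/// Scheduler server holding the shared scheduler state, the query stage event
/// loop, and the query stage scheduler that processes its events.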
#[derive(Clone)]
pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
    pub scheduler_name: String,
    pub start_time: u128,
    pub state: Arc<SchedulerState<T, U>>,
    pub(crate) query_stage_event_loop: EventLoop<QueryStageSchedulerEvent>,
    query_stage_scheduler: Arc<QueryStageScheduler<T, U>>,
    config: Arc<SchedulerConfig>,
}

impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T, U> {
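    /// Create a new scheduler server from the given cluster, codec, config,
    /// and metrics collector.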
    pub fn new(
        scheduler_name: String,
        cluster: KapotCluster,
        codec: KapotCodec<T, U>,
        config: Arc<SchedulerConfig>,
        metrics_collector: Arc<dyn SchedulerMetricsCollector>,
    ) -> Self {
        let state = Arc::new(SchedulerState::new(
            cluster,
            codec,
            scheduler_name.clone(),
            config.clone(),
        ));
        let query_stage_scheduler = Arc::new(QueryStageScheduler::new(
            state.clone(),
            metrics_collector,
            config.clone(),
        ));
        let query_stage_event_loop = EventLoop::new(
            "query_stage".to_owned(),
            config.event_loop_buffer_size as usize,
            query_stage_scheduler.clone(),
        );

        Self {
            scheduler_name,
            start_time: timestamp_millis() as u128,
            state,
            query_stage_event_loop,
            query_stage_scheduler,
            config,
        }
    }

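    /// Create a new scheduler server with a custom [`TaskLauncher`]
    /// implementation, typically used in tests.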
    #[allow(dead_code)]
    pub fn new_with_task_launcher(
        scheduler_name: String,
        cluster: KapotCluster,
        codec: KapotCodec<T, U>,
        config: Arc<SchedulerConfig>,
        metrics_collector: Arc<dyn SchedulerMetricsCollector>,
        task_launcher: Arc<dyn TaskLauncher>,
    ) -> Self {
        let state = Arc::new(SchedulerState::new_with_task_launcher(
            cluster,
            codec,
            scheduler_name.clone(),
            config.clone(),
            task_launcher,
        ));
        let query_stage_scheduler = Arc::new(QueryStageScheduler::new(
            state.clone(),
            metrics_collector,
            config.clone(),
        ));
        let query_stage_event_loop = EventLoop::new(
            "query_stage".to_owned(),
            config.event_loop_buffer_size as usize,
            query_stage_scheduler.clone(),
        );

        Self {
            scheduler_name,
            start_time: timestamp_millis() as u128,
            state,
            query_stage_event_loop,
            query_stage_scheduler,
            config,
        }
    }

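    /// Initialize the scheduler: set up shared state, start the query stage
    /// event loop, and spawn the background task that expires dead executors.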
    pub async fn init(&mut self) -> Result<()> {
        self.state.init().await?;
        self.query_stage_event_loop.start()?;
        self.expire_dead_executors()?;

        Ok(())
    }

    pub fn pending_job_number(&self) -> usize {
        self.state.task_manager.pending_job_number()
    }

    pub fn running_job_number(&self) -> usize {
        self.state.task_manager.running_job_number()
    }

    pub(crate) fn metrics_collector(&self) -> &dyn SchedulerMetricsCollector {
        self.query_stage_scheduler.metrics_collector()
    }

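    /// Queue a job for execution by posting a `JobQueued` event to the query
    /// stage event loop.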
    pub(crate) async fn submit_job(
        &self,
        job_id: &str,
        job_name: &str,
        ctx: Arc<SessionContext>,
        plan: &LogicalPlan,
    ) -> Result<()> {
        self.query_stage_event_loop
            .get_sender()?
            .post_event(QueryStageSchedulerEvent::JobQueued {
                job_id: job_id.to_owned(),
                job_name: job_name.to_owned(),
                session_ctx: ctx,
                plan: Box::new(plan.clone()),
                queued_at: timestamp_millis(),
            })
            .await
    }

    /// Post a task status update event to the channel. This does not guarantee
    /// that the event has been processed by the time the call returns.
    pub(crate) async fn update_task_status(
        &self,
        executor_id: &str,
        tasks_status: Vec<TaskStatus>,
    ) -> Result<()> {
        // We may receive stale task updates from executors that are already marked dead.
        if self.state.config.is_push_staged_scheduling()
            && self.state.executor_manager.is_dead_executor(executor_id)
        {
            let error_msg = format!(
                "Received task status from dead executor {executor_id}; ignoring the update."
            );
            warn!("{}", error_msg);
            return Ok(());
        }
        self.query_stage_event_loop
            .get_sender()?
            .post_event(QueryStageSchedulerEvent::TaskUpdating(
                executor_id.to_owned(),
                tasks_status,
            ))
            .await
    }

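    /// Post a `ReviveOffers` event to the query stage event loop so pending
    /// tasks can be matched against available executor slots.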
    pub(crate) async fn revive_offers(&self) -> Result<()> {
        self.query_stage_event_loop
            .get_sender()?
            .post_event(QueryStageSchedulerEvent::ReviveOffers)
            .await
    }

    /// Spawn an async task which periodically checks the active executors'
    /// status and expires dead executors.
    fn expire_dead_executors(&self) -> Result<()> {
        let state = self.state.clone();
        let event_sender = self.query_stage_event_loop.get_sender()?;
        tokio::task::spawn(async move {
            loop {
                let expired_executors = state.executor_manager.get_expired_executors();
                for expired in expired_executors {
                    let executor_id = expired.executor_id.clone();

                    let sender_clone = event_sender.clone();

                    let terminating = matches!(
                        expired
                            .status
                            .as_ref()
                            .and_then(|status| status.status.as_ref()),
                        Some(kapot_core::serde::protobuf::executor_status::Status::Terminating(_))
                    );

                    let stop_reason = if terminating {
                        format!(
                            "TERMINATING executor {executor_id} heartbeat timed out after {}s",
                            state.config.executor_termination_grace_period,
                        )
                    } else {
                        format!(
                            "ACTIVE executor {executor_id} heartbeat timed out after {}s",
                            state.config.executor_timeout_seconds,
                        )
                    };

                    warn!("{stop_reason}");

                    // The executor has expired, so remove it immediately.
                    Self::remove_executor(
                        state.executor_manager.clone(),
                        sender_clone,
                        &executor_id,
                        Some(stop_reason.clone()),
                        0,
                    );

                    // If the executor is not already terminating then stop it. If it is
                    // terminating then it should already be shutting down and we do not
                    // need to do anything here.
                    if !terminating {
                        state
                            .executor_manager
                            .stop_executor(&executor_id, stop_reason)
                            .await;
                    }
                }
                tokio::time::sleep(Duration::from_secs(
                    state.config.expire_dead_executor_interval_seconds,
                ))
                .await;
            }
        });
        Ok(())
    }

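    /// Remove an executor after `wait_secs` seconds: update the executor manager
    /// and post an `ExecutorLost` event to the query stage event loop.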
    pub(crate) fn remove_executor(
        executor_manager: ExecutorManager,
        event_sender: EventSender<QueryStageSchedulerEvent>,
        executor_id: &str,
        reason: Option<String>,
        wait_secs: u64,
    ) {
        let executor_id = executor_id.to_owned();
        tokio::spawn(async move {
            // Wait for `wait_secs` before removing the executor
            tokio::time::sleep(Duration::from_secs(wait_secs)).await;

            // Update the executor manager immediately here
            if let Err(e) = executor_manager
                .remove_executor(&executor_id, reason.clone())
                .await
            {
                error!("error removing executor {executor_id}: {e:?}");
            }

            if let Err(e) = event_sender
                .post_event(QueryStageSchedulerEvent::ExecutorLost(executor_id, reason))
                .await
            {
                error!("error sending ExecutorLost event: {e:?}");
            }
        });
    }

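    /// Register a new executor with the scheduler state and, when push-based
    /// scheduling is enabled, revive offers so its slots can be used right away.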
    async fn do_register_executor(&self, metadata: ExecutorMetadata) -> Result<()> {
        let executor_data = ExecutorData {
            executor_id: metadata.id.clone(),
            total_task_slots: metadata.specification.task_slots,
            available_task_slots: metadata.specification.task_slots,
        };

        // Save the executor to state
        self.state
            .executor_manager
            .register_executor(metadata, executor_data)
            .await?;

        // If we are using push-based scheduling then reserve this executor's slots
        // and send them for scheduling tasks.
        if self.state.config.is_push_staged_scheduling() {
            self.revive_offers().await?;
        }

        Ok(())
    }
}

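/// Current Unix timestamp in whole seconds.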
pub fn timestamp_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_secs()
}

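/// Current Unix timestamp in milliseconds, truncated to `u64`.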
pub fn timestamp_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_millis() as u64
}

#[cfg(all(test, feature = "sled"))]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::functions_aggregate::sum::sum;
    use datafusion::logical_expr::{col, LogicalPlan};

    use datafusion::test_util::scan_empty_with_partitions;
    use datafusion_proto::protobuf::LogicalPlanNode;
    use datafusion_proto::protobuf::PhysicalPlanNode;

    use kapot_core::config::{
        KapotConfig, TaskSchedulingPolicy, KAPOT_DEFAULT_SHUFFLE_PARTITIONS,
    };
    use kapot_core::error::Result;

    use crate::config::SchedulerConfig;

    use kapot_core::serde::protobuf::{
        failed_task, job_status, task_status, ExecutionError, FailedTask, JobStatus,
        MultiTaskDefinition, ShuffleWritePartition, SuccessfulJob, SuccessfulTask,
        TaskId, TaskStatus,
    };
    use kapot_core::serde::scheduler::{
        ExecutorData, ExecutorMetadata, ExecutorSpecification,
    };
    use kapot_core::serde::KapotCodec;

    use crate::scheduler_server::{timestamp_millis, SchedulerServer};

    use crate::test_utils::{
        assert_completed_event, assert_failed_event, assert_no_submitted_event,
        assert_submitted_event, test_cluster_context, ExplodingTableProvider,
        SchedulerTest, TaskRunnerFn, TestMetricsCollector,
    };

    #[tokio::test]
    async fn test_pull_scheduling() -> Result<()> {
        let plan = test_plan();
        let task_slots = 4;

        let scheduler = test_scheduler(TaskSchedulingPolicy::PullStaged).await?;

        let executors = test_executors(task_slots);
        for (executor_metadata, executor_data) in executors {
            scheduler
                .state
                .executor_manager
                .register_executor(executor_metadata, executor_data)
                .await?;
        }

        let config = test_session(task_slots);

        let ctx = scheduler
            .state
            .session_manager
            .create_session(&config)
            .await?;

        let job_id = "job";

        // Enqueue job
        scheduler
            .state
            .task_manager
            .queue_job(job_id, "", timestamp_millis())?;

        // Submit job
        scheduler
            .state
            .submit_job(job_id, "", ctx, &plan, 0)
            .await
            .expect("submitting plan");

        // Drive the ExecutionGraph to completion: pop tasks for executor-1 and
        // report each one back as successful.
        while let Some(graph) = scheduler
            .state
            .task_manager
            .get_active_execution_graph(job_id)
        {
            let task = {
                let mut graph = graph.write().await;
                graph.pop_next_task("executor-1")?
            };
            if let Some(task) = task {
                let mut partitions: Vec<ShuffleWritePartition> = vec![];

                let num_partitions = task.get_output_partition_number();

                for partition_id in 0..num_partitions {
                    partitions.push(ShuffleWritePartition {
                        partition_id: partition_id as u64,
                        path: "some/path".to_string(),
                        num_batches: 1,
                        num_rows: 1,
                        num_bytes: 1,
                    })
                }

                // Complete the task
                let task_status = TaskStatus {
                    task_id: task.task_id as u32,
                    job_id: task.partition.job_id.clone(),
                    stage_id: task.partition.stage_id as u32,
                    stage_attempt_num: task.stage_attempt_num as u32,
                    partition_id: task.partition.partition_id as u32,
                    launch_time: 0,
                    start_exec_time: 0,
                    end_exec_time: 0,
                    metrics: vec![],
                    status: Some(task_status::Status::Successful(SuccessfulTask {
                        executor_id: "executor-1".to_owned(),
                        partitions,
                    })),
                };

                scheduler
                    .state
                    .update_task_statuses("executor-1", vec![task_status])
                    .await?;
            } else {
                break;
            }
        }

        let final_graph = scheduler
            .state
            .task_manager
            .get_active_execution_graph(job_id)
            .expect("Failed to find graph in the cache");

        let final_graph = final_graph.read().await;
        assert!(final_graph.is_successful());
        assert_eq!(final_graph.output_locations().len(), 4);

        for output_location in final_graph.output_locations() {
            assert_eq!(output_location.path, "some/path".to_owned());
            assert_eq!(output_location.executor_meta.host, "localhost1".to_owned())
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_push_scheduling() -> Result<()> {
        let plan = test_plan();

        let metrics_collector = Arc::new(TestMetricsCollector::default());

        let mut test = SchedulerTest::new(
            SchedulerConfig::default()
                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
            metrics_collector.clone(),
            4,
            1,
            None,
        )
        .await?;

        let status = test.run("job", "", &plan).await.expect("running plan");

        match status.status {
            Some(job_status::Status::Successful(SuccessfulJob {
                partition_location,
                ..
            })) => {
                assert_eq!(partition_location.len(), 4);
            }
            other => {
                panic!("Expected success status but found {:?}", other);
            }
        }

        assert_submitted_event("job", &metrics_collector);
        assert_completed_event("job", &metrics_collector);

        Ok(())
    }

    // Simulate a task failure and ensure the job status is updated correctly
    #[tokio::test]
    async fn test_job_failure() -> Result<()> {
        let plan = test_plan();

        let runner = Arc::new(TaskRunnerFn::new(
            |_executor_id: String, task: MultiTaskDefinition| {
                let mut statuses = vec![];

                for TaskId {
                    task_id,
                    partition_id,
                    ..
                } in task.task_ids
                {
                    let timestamp = timestamp_millis();
                    statuses.push(TaskStatus {
                        task_id,
                        job_id: task.job_id.clone(),
                        stage_id: task.stage_id,
                        stage_attempt_num: task.stage_attempt_num,
                        partition_id,
                        launch_time: timestamp,
                        start_exec_time: timestamp,
                        end_exec_time: timestamp,
                        metrics: vec![],
                        status: Some(task_status::Status::Failed(FailedTask {
                            error: "ERROR".to_string(),
                            retryable: false,
                            count_to_failures: false,
                            failed_reason: Some(
                                failed_task::FailedReason::ExecutionError(
                                    ExecutionError {},
                                ),
                            ),
                        })),
                    });
                }

                statuses
            },
        ));

        let metrics_collector = Arc::new(TestMetricsCollector::default());

        let mut test = SchedulerTest::new(
            SchedulerConfig::default()
                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
            metrics_collector.clone(),
            4,
            1,
            Some(runner),
        )
        .await?;

        let status = test.run("job", "", &plan).await.expect("running plan");

        assert!(
            matches!(
                status,
                JobStatus {
                    status: Some(job_status::Status::Failed(_)),
                    ..
                }
            ),
            "Expected job status to be failed but it was {status:?}"
        );

        assert_submitted_event("job", &metrics_collector);
        assert_failed_event("job", &metrics_collector);

        Ok(())
    }

    // If the physical planning fails, the job should be marked as failed.
    // Here we simulate a planning failure using ExplodingTableProvider to test this.
    #[tokio::test]
    async fn test_planning_failure() -> Result<()> {
        let metrics_collector = Arc::new(TestMetricsCollector::default());
        let mut test = SchedulerTest::new(
            SchedulerConfig::default()
                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
            metrics_collector.clone(),
            4,
            1,
            None,
        )
        .await?;

        let ctx = test.ctx().await?;

        ctx.register_table("explode", Arc::new(ExplodingTableProvider))?;

        let plan = ctx
            .sql("SELECT * FROM explode")
            .await?
            .into_optimized_plan()?;

        // This should fail when we try to create the physical plan
        let status = test.run("job", "", &plan).await?;

        assert!(
            matches!(
                status,
                JobStatus {
                    status: Some(job_status::Status::Failed(_)),
                    ..
                }
            ),
            "Expected job status to be failed but it was {status:?}"
        );

        assert_no_submitted_event("job", &metrics_collector);
        assert_failed_event("job", &metrics_collector);

        Ok(())
    }

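    /// Build and initialize a scheduler server against the test cluster context
    /// using the given scheduling policy.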
    async fn test_scheduler(
        scheduling_policy: TaskSchedulingPolicy,
    ) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> {
        let cluster = test_cluster_context();

        let config = SchedulerConfig::default().with_scheduler_policy(scheduling_policy);
        let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
            SchedulerServer::new(
                "localhost:50050".to_owned(),
                cluster,
                KapotCodec::default(),
                Arc::new(config),
                Arc::new(TestMetricsCollector::default()),
            );
        scheduler.init().await?;

        Ok(scheduler)
    }

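    /// Build two test executors on `localhost1` and `localhost2` that split
    /// `num_partitions` task slots between them.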
    fn test_executors(num_partitions: usize) -> Vec<(ExecutorMetadata, ExecutorData)> {
        let task_slots = (num_partitions as u32 + 1) / 2;

        vec![
            (
                ExecutorMetadata {
                    id: "executor-1".to_string(),
                    host: "localhost1".to_string(),
                    port: 8080,
                    grpc_port: 9090,
                    specification: ExecutorSpecification { task_slots },
                },
                ExecutorData {
                    executor_id: "executor-1".to_owned(),
                    total_task_slots: task_slots,
                    available_task_slots: task_slots,
                },
            ),
            (
                ExecutorMetadata {
                    id: "executor-2".to_string(),
                    host: "localhost2".to_string(),
                    port: 8080,
                    grpc_port: 9090,
                    specification: ExecutorSpecification {
                        task_slots: num_partitions as u32 - task_slots,
                    },
                },
                ExecutorData {
                    executor_id: "executor-2".to_owned(),
                    total_task_slots: num_partitions as u32 - task_slots,
                    available_task_slots: num_partitions as u32 - task_slots,
                },
            ),
        ]
    }

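    /// Aggregate plan over an empty two-partition scan, used by the scheduling tests.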
    fn test_plan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("gmv", DataType::UInt64, false),
        ]);

        // partitions need to be > 1 for DataFusion's optimizer to insert a repartition node
        // behavior changed with: https://github.com/apache/datafusion/pull/11875
        scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2)
            .unwrap()
            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
            .unwrap()
            .build()
            .unwrap()
    }

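    /// KapotConfig with the default shuffle partition count set to `partitions`.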
    fn test_session(partitions: usize) -> KapotConfig {
        KapotConfig::builder()
            .set(
                KAPOT_DEFAULT_SHUFFLE_PARTITIONS,
                format!("{partitions}").as_str(),
            )
            .build()
            .expect("creating KapotConfig")
    }
}