use crate::scheduler_server::event::QueryStageSchedulerEvent;

use crate::state::execution_graph::{
    ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription,
};
use crate::state::executor_manager::ExecutorManager;

use kapot_core::error::KapotError;
use kapot_core::error::Result;

use crate::cluster::JobState;
use kapot_core::serde::protobuf::{
    job_status, JobStatus, KeyValuePair, MultiTaskDefinition, TaskDefinition, TaskId,
    TaskStatus,
};
use kapot_core::serde::scheduler::ExecutorMetadata;
use kapot_core::serde::KapotCodec;
use dashmap::DashMap;

use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use log::{debug, error, info, warn};
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

use kapot_core::config::KAPOT_DATA_CACHE_ENABLED;
use tracing::trace;

type ActiveJobCache = Arc<DashMap<String, JobInfoCache>>;

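/// Maximum number of failures tolerated for a single task; passed to
/// `ExecutionGraph::update_task_status`.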
pub const TASK_MAX_FAILURES: usize = 4;
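/// Maximum number of failures tolerated for a single stage; passed to
/// `ExecutionGraph::update_task_status` alongside `TASK_MAX_FAILURES`.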
pub const STAGE_MAX_FAILURES: usize = 4;

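/// Abstraction over how prepared tasks are launched on executors, so that alternative
/// implementations (e.g. for testing) can be injected via `TaskManager::with_launcher`.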
#[async_trait::async_trait]
pub trait TaskLauncher: Send + Sync + 'static {
    async fn launch_tasks(
        &self,
        executor: &ExecutorMetadata,
        tasks: Vec<MultiTaskDefinition>,
        executor_manager: &ExecutorManager,
    ) -> Result<()>;
}

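/// Default `TaskLauncher` that forwards task launch requests to the `ExecutorManager`.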
struct DefaultTaskLauncher {
    scheduler_id: String,
}

impl DefaultTaskLauncher {
    pub fn new(scheduler_id: String) -> Self {
        Self { scheduler_id }
    }
}

#[async_trait::async_trait]
impl TaskLauncher for DefaultTaskLauncher {
    async fn launch_tasks(
        &self,
        executor: &ExecutorMetadata,
        tasks: Vec<MultiTaskDefinition>,
        executor_manager: &ExecutorManager,
    ) -> Result<()> {
        if log::max_level() >= log::Level::Info {
            let tasks_ids: Vec<String> = tasks
                .iter()
                .map(|task| {
                    let task_ids: Vec<u32> = task
                        .task_ids
                        .iter()
                        .map(|task_id| task_id.partition_id)
                        .collect();
                    format!("{}/{}/{:?}", task.job_id, task.stage_id, task_ids)
                })
                .collect();
            info!(
                "Launching multi task on executor {:?} for {:?}",
                executor.id, tasks_ids
            );
        }
        executor_manager
            .launch_multi_task(&executor.id, tasks, self.scheduler_id.clone())
            .await?;
        Ok(())
    }
}

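/// Tracks active jobs and their execution graphs, persists job state through the
/// configured `JobState`, and launches ready tasks on executors via a `TaskLauncher`.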
#[derive(Clone)]
pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
    state: Arc<dyn JobState>,
    codec: KapotCodec<T, U>,
    scheduler_id: String,
    active_job_cache: ActiveJobCache,
    launcher: Arc<dyn TaskLauncher>,
}

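/// Per-job entry of the active job cache: the shared execution graph, the job's last
/// observed status, and lazily encoded physical plans keyed by stage id.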
#[derive(Clone)]
pub struct JobInfoCache {
    pub execution_graph: Arc<RwLock<ExecutionGraph>>,
    pub status: Option<job_status::Status>,
    encoded_stage_plans: HashMap<usize, Vec<u8>>,
}

impl JobInfoCache {
    pub fn new(graph: ExecutionGraph) -> Self {
        let status = graph.status().status.clone();
        Self {
            execution_graph: Arc::new(RwLock::new(graph)),
            status,
            encoded_stage_plans: HashMap::new(),
        }
    }
}

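/// Stage-level state changes collected while applying task status updates to an
/// execution graph: stages that became resolved or successful, stages that failed with
/// an error message, running stages to roll back, and successful stages to resubmit.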
#[derive(Clone)]
pub struct UpdatedStages {
    pub resolved_stages: HashSet<usize>,
    pub successful_stages: HashSet<usize>,
    pub failed_stages: HashMap<usize, String>,
    pub rollback_running_stages: HashMap<usize, HashSet<String>>,
    pub resubmit_successful_stages: HashSet<usize>,
}

impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> {
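    /// Create a new `TaskManager` that launches tasks with the `DefaultTaskLauncher`.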
    pub fn new(
        state: Arc<dyn JobState>,
        codec: KapotCodec<T, U>,
        scheduler_id: String,
    ) -> Self {
        Self {
            state,
            codec,
            scheduler_id: scheduler_id.clone(),
            active_job_cache: Arc::new(DashMap::new()),
            launcher: Arc::new(DefaultTaskLauncher::new(scheduler_id)),
        }
    }

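    /// Create a `TaskManager` with a caller-provided `TaskLauncher`, mainly useful in tests.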
    #[allow(dead_code)]
    pub(crate) fn with_launcher(
        state: Arc<dyn JobState>,
        codec: KapotCodec<T, U>,
        scheduler_id: String,
        launcher: Arc<dyn TaskLauncher>,
    ) -> Self {
        Self {
            state,
            codec,
            scheduler_id,
            active_job_cache: Arc::new(DashMap::new()),
            launcher,
        }
    }

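    /// Record a newly accepted job in the job state together with the time it was queued.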
    pub fn queue_job(&self, job_id: &str, job_name: &str, queued_at: u64) -> Result<()> {
        self.state.accept_job(job_id, job_name, queued_at)
    }

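    /// Number of pending jobs reported by the underlying job state.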
    pub fn pending_job_number(&self) -> usize {
        self.state.pending_job_number()
    }

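    /// Number of jobs currently held in the active job cache.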
    pub fn running_job_number(&self) -> usize {
        self.active_job_cache.len()
    }

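    /// Build an `ExecutionGraph` for the given physical plan, persist it through the job
    /// state, revive it, and add it to the active job cache.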
    pub async fn submit_job(
        &self,
        job_id: &str,
        job_name: &str,
        session_id: &str,
        plan: Arc<dyn ExecutionPlan>,
        queued_at: u64,
    ) -> Result<()> {
        let mut graph = ExecutionGraph::new(
            &self.scheduler_id,
            job_id,
            job_name,
            session_id,
            plan,
            queued_at,
        )?;
        info!("Submitting execution graph: {:?}", graph);

        self.state.submit_job(job_id.to_string(), &graph).await?;

        graph.revive();
        self.active_job_cache
            .insert(job_id.to_owned(), JobInfoCache::new(graph));

        Ok(())
    }

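    /// Snapshot of the active job cache restricted to jobs whose status is `Running`.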
    pub fn get_running_job_cache(&self) -> Arc<HashMap<String, JobInfoCache>> {
        let ret = self
            .active_job_cache
            .iter()
            .filter_map(|pair| {
                let (job_id, job_info) = pair.pair();
                if matches!(job_info.status, Some(job_status::Status::Running(_))) {
                    Some((job_id.clone(), job_info.clone()))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();
        Arc::new(ret)
    }

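    /// Build a `JobOverview` for every known job, reading active jobs from the in-memory
    /// cache and all other jobs from the persisted job state.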
    pub async fn get_jobs(&self) -> Result<Vec<JobOverview>> {
        let job_ids = self.state.get_jobs().await?;

        let mut jobs = vec![];
        for job_id in &job_ids {
            if let Some(cached) = self.get_active_execution_graph(job_id) {
                let graph = cached.read().await;
                jobs.push(graph.deref().into());
            } else {
                let graph = self.state
                    .get_execution_graph(job_id)
                    .await?
                    .ok_or_else(|| KapotError::Internal(format!("Error getting job overview, no execution graph found for job {job_id}")))?;
                jobs.push((&graph).into());
            }
        }
        Ok(jobs)
    }

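    /// Get the status of a job, preferring the active cache and falling back to the
    /// persisted job state for jobs that are no longer active.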
    pub async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>> {
        if let Some(graph) = self.get_active_execution_graph(job_id) {
            let guard = graph.read().await;

            Ok(Some(guard.status().clone()))
        } else {
            self.state.get_job_status(job_id).await
        }
    }

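    /// Get the execution graph for a job, cloning the cached graph for active jobs and
    /// otherwise loading it from the persisted job state.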
    pub(crate) async fn get_job_execution_graph(
        &self,
        job_id: &str,
    ) -> Result<Option<Arc<ExecutionGraph>>> {
        if let Some(cached) = self.get_active_execution_graph(job_id) {
            let guard = cached.read().await;

            Ok(Some(Arc::new(guard.deref().clone())))
        } else {
            let graph = self.state.get_execution_graph(job_id).await?;

            Ok(graph.map(Arc::new))
        }
    }

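    /// Apply a batch of task status updates reported by an executor, grouped by job, and
    /// return the scheduler events produced by the affected execution graphs.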
    pub(crate) async fn update_task_statuses(
        &self,
        executor: &ExecutorMetadata,
        task_status: Vec<TaskStatus>,
    ) -> Result<Vec<QueryStageSchedulerEvent>> {
        let mut job_updates: HashMap<String, Vec<TaskStatus>> = HashMap::new();
        for status in task_status {
            trace!("Task Update\n{:?}", status);
            let job_id = status.job_id.clone();
            let job_task_statuses = job_updates.entry(job_id).or_default();
            job_task_statuses.push(status);
        }

        let mut events: Vec<QueryStageSchedulerEvent> = vec![];
        for (job_id, statuses) in job_updates {
            let num_tasks = statuses.len();
            debug!("Updating {} tasks in job {}", num_tasks, job_id);

            let job_events = if let Some(cached) =
                self.get_active_execution_graph(&job_id)
            {
                let mut graph = cached.write().await;
                graph.update_task_status(
                    executor,
                    statuses,
                    TASK_MAX_FAILURES,
                    STAGE_MAX_FAILURES,
                )?
            } else {
328 error!("Fail to find job {} in the active cache and it may not be curated by this scheduler", job_id);
                vec![]
            };

            for event in job_events {
                events.push(event);
            }
        }

        Ok(events)
    }

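    /// Move a job that has finished successfully out of the active cache and persist its
    /// final execution graph.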
    pub(crate) async fn succeed_job(&self, job_id: &str) -> Result<()> {
        debug!("Moving job {} from Active to Success", job_id);

        if let Some(graph) = self.remove_active_execution_graph(job_id) {
            let graph = graph.read().await.clone();
            if graph.is_successful() {
                self.state.save_job(job_id, &graph).await?;
            } else {
                error!("Job {} has not finished and cannot be completed", job_id);
                return Ok(());
            }
        } else {
355 warn!("Fail to find job {} in the cache", job_id);
        }

        Ok(())
    }

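    /// Cancel a running job, returning the running tasks that must be cancelled on
    /// executors and the number of tasks that were still pending.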
    pub(crate) async fn cancel_job(
        &self,
        job_id: &str,
    ) -> Result<(Vec<RunningTaskInfo>, usize)> {
        self.abort_job(job_id, "Cancelled".to_owned()).await
    }

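    /// Abort a job with the given failure reason, persist the failed state, and return
    /// the running tasks to cancel together with the pending task count.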
    pub(crate) async fn abort_job(
        &self,
        job_id: &str,
        failure_reason: String,
    ) -> Result<(Vec<RunningTaskInfo>, usize)> {
        let (tasks_to_cancel, pending_tasks) = if let Some(graph) =
            self.remove_active_execution_graph(job_id)
        {
            let mut guard = graph.write().await;

            let pending_tasks = guard.available_tasks();
            let running_tasks = guard.running_tasks();

            info!(
                "Cancelling {} running tasks for job {}",
                running_tasks.len(),
                job_id
            );

            guard.fail_job(failure_reason);

            self.state.save_job(job_id, &guard).await?;

            (running_tasks, pending_tasks)
        } else {
395 warn!("Fail to find job {} in the cache, unable to cancel tasks for job, fail the job state only.", job_id);
            (vec![], 0)
        };

        Ok((tasks_to_cancel, pending_tasks))
    }

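    /// Mark a job that was never scheduled as failed in the persisted job state.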
    pub async fn fail_unscheduled_job(
        &self,
        job_id: &str,
        failure_reason: String,
    ) -> Result<()> {
        self.state
            .fail_unscheduled_job(job_id, failure_reason)
            .await
    }

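    /// Revive an active job's execution graph, persist the updated state, and return the
    /// number of tasks that became available as a result.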
    pub async fn update_job(&self, job_id: &str) -> Result<usize> {
        debug!("Update active job {job_id}");
        if let Some(graph) = self.get_active_execution_graph(job_id) {
            let mut graph = graph.write().await;

            let curr_available_tasks = graph.available_tasks();

            graph.revive();

424 println!("Saving job with status {:?}", graph.status());

            self.state.save_job(job_id, &graph).await?;

            let new_tasks = graph.available_tasks() - curr_available_tasks;

            Ok(new_tasks)
        } else {
432 warn!("Fail to find job {} in the cache", job_id);

            Ok(0)
        }
    }

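    /// Reset stages that had tasks running on a lost executor and return the running
    /// tasks that should be cancelled.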
    pub async fn executor_lost(&self, executor_id: &str) -> Result<Vec<RunningTaskInfo>> {
        let mut running_tasks_to_cancel: Vec<RunningTaskInfo> = vec![];
        let updated_graphs: DashMap<String, ExecutionGraph> = DashMap::new();
        {
            for pairs in self.active_job_cache.iter() {
                let (job_id, job_info) = pairs.pair();
                let mut graph = job_info.execution_graph.write().await;
                let reset = graph.reset_stages_on_lost_executor(executor_id)?;
                if !reset.0.is_empty() {
                    updated_graphs.insert(job_id.to_owned(), graph.clone());
                    running_tasks_to_cancel.extend(reset.1);
                }
            }
        }

        Ok(running_tasks_to_cancel)
    }

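    /// Number of tasks currently available for scheduling in the given job, or 0 if the
    /// job is not in the active cache.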
    pub async fn get_available_task_count(&self, job_id: &str) -> Result<usize> {
        if let Some(graph) = self.get_active_execution_graph(job_id) {
            let available_tasks = graph.read().await.available_tasks();
            Ok(available_tasks)
        } else {
466 warn!("Fail to find job {} in the cache", job_id);
            Ok(0)
        }
    }

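    /// Encode a single `TaskDescription` into a `TaskDefinition` protobuf, reusing the
    /// cached encoded stage plan when one is already available for the stage.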
    #[allow(dead_code)]
    pub fn prepare_task_definition(
        &self,
        task: TaskDescription,
    ) -> Result<TaskDefinition> {
        debug!("Preparing task definition for {:?}", task);

        let job_id = task.partition.job_id.clone();
        let stage_id = task.partition.stage_id;

        if let Some(mut job_info) = self.active_job_cache.get_mut(&job_id) {
            let plan = if let Some(plan) = job_info.encoded_stage_plans.get(&stage_id) {
                plan.clone()
            } else {
                let mut plan_buf: Vec<u8> = vec![];
                let plan_proto = U::try_from_physical_plan(
                    task.plan,
                    self.codec.physical_extension_codec(),
                )?;
                plan_proto.try_encode(&mut plan_buf)?;

                job_info
                    .encoded_stage_plans
                    .insert(stage_id, plan_buf.clone());

                plan_buf
            };

            let mut props = vec![];
            if task.data_cache {
                props.push(KeyValuePair {
                    key: KAPOT_DATA_CACHE_ENABLED.to_string(),
                    value: "true".to_string(),
                });
            }

            let task_definition = TaskDefinition {
                task_id: task.task_id as u32,
                task_attempt_num: task.task_attempt as u32,
                job_id,
                stage_id: stage_id as u32,
                stage_attempt_num: task.stage_attempt_num as u32,
                partition_id: task.partition.partition_id as u32,
                plan,
                session_id: task.session_id,
                launch_time: SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64,
                props,
            };
            Ok(task_definition)
        } else {
            Err(KapotError::General(format!(
                "Cannot prepare task definition for job {job_id} which is not in active cache"
            )))
        }
    }

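    /// Prepare `MultiTaskDefinition`s for each per-stage batch of tasks and hand them to
    /// the configured `TaskLauncher` for the given executor.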
    pub(crate) async fn launch_multi_task(
        &self,
        executor: &ExecutorMetadata,
        tasks: Vec<Vec<TaskDescription>>,
        executor_manager: &ExecutorManager,
    ) -> Result<()> {
        let mut multi_tasks = vec![];
        for stage_tasks in tasks {
            match self.prepare_multi_task_definition(stage_tasks) {
                Ok(stage_tasks) => multi_tasks.extend(stage_tasks),
                Err(e) => error!("Failed to prepare task definition: {:?}", e),
            }
        }

        if !multi_tasks.is_empty() {
            self.launcher
                .launch_tasks(executor, multi_tasks, executor_manager)
                .await
        } else {
            Ok(())
        }
    }

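    /// Convert a batch of tasks belonging to the same job stage into `MultiTaskDefinition`s,
    /// encoding the stage plan once and splitting the tasks by whether data caching is enabled.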
    #[allow(dead_code)]
    fn prepare_multi_task_definition(
        &self,
        tasks: Vec<TaskDescription>,
    ) -> Result<Vec<MultiTaskDefinition>> {
        if let Some(task) = tasks.first() {
            let session_id = task.session_id.clone();
            let job_id = task.partition.job_id.clone();
            let stage_id = task.partition.stage_id;
            let stage_attempt_num = task.stage_attempt_num;

            if log::max_level() >= log::Level::Debug {
                let task_ids: Vec<usize> = tasks
                    .iter()
                    .map(|task| task.partition.partition_id)
                    .collect();
                debug!("Preparing multi task definition for tasks {:?} belonging to job stage {}/{}", task_ids, job_id, stage_id);
                trace!("With task details {:?}", tasks);
            }

            if let Some(mut job_info) = self.active_job_cache.get_mut(&job_id) {
                let plan = if let Some(plan) = job_info.encoded_stage_plans.get(&stage_id)
                {
                    plan.clone()
                } else {
                    let mut plan_buf: Vec<u8> = vec![];
                    let plan_proto = U::try_from_physical_plan(
                        task.plan.clone(),
                        self.codec.physical_extension_codec(),
                    )?;
                    plan_proto.try_encode(&mut plan_buf)?;

                    job_info
                        .encoded_stage_plans
                        .insert(stage_id, plan_buf.clone());

                    plan_buf
                };

                let launch_time = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;

                let (tasks_with_data_cache, tasks_without_data_cache): (Vec<_>, Vec<_>) =
                    tasks.into_iter().partition(|task| task.data_cache);

                let mut multi_tasks = vec![];
                if !tasks_with_data_cache.is_empty() {
                    let task_ids = tasks_with_data_cache
                        .into_iter()
                        .map(|task| TaskId {
                            task_id: task.task_id as u32,
                            task_attempt_num: task.task_attempt as u32,
                            partition_id: task.partition.partition_id as u32,
                        })
                        .collect();
                    multi_tasks.push(MultiTaskDefinition {
                        task_ids,
                        job_id: job_id.clone(),
                        stage_id: stage_id as u32,
                        stage_attempt_num: stage_attempt_num as u32,
                        plan: plan.clone(),
                        session_id: session_id.clone(),
                        launch_time,
                        props: vec![KeyValuePair {
                            key: KAPOT_DATA_CACHE_ENABLED.to_string(),
                            value: "true".to_string(),
                        }],
                    });
                }
                if !tasks_without_data_cache.is_empty() {
                    let task_ids = tasks_without_data_cache
                        .into_iter()
                        .map(|task| TaskId {
                            task_id: task.task_id as u32,
                            task_attempt_num: task.task_attempt as u32,
                            partition_id: task.partition.partition_id as u32,
                        })
                        .collect();
                    multi_tasks.push(MultiTaskDefinition {
                        task_ids,
                        job_id,
                        stage_id: stage_id as u32,
                        stage_attempt_num: stage_attempt_num as u32,
                        plan,
                        session_id,
                        launch_time,
                        props: vec![],
                    });
                }

                Ok(multi_tasks)
            } else {
                Err(KapotError::General(format!("Cannot prepare multi task definition for job {job_id} which is not in active cache")))
            }
        } else {
            Err(KapotError::General(
                "Cannot prepare multi task definition for an empty vec".to_string(),
            ))
        }
    }

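    /// Get the `ExecutionGraph` of an active job from the cache, if present.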
    pub(crate) fn get_active_execution_graph(
        &self,
        job_id: &str,
    ) -> Option<Arc<RwLock<ExecutionGraph>>> {
        self.active_job_cache
            .get(job_id)
            .as_deref()
            .map(|cached| cached.execution_graph.clone())
    }

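    /// Remove a job from the active cache and return its `ExecutionGraph`, if it was cached.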
    pub(crate) fn remove_active_execution_graph(
        &self,
        job_id: &str,
    ) -> Option<Arc<RwLock<ExecutionGraph>>> {
        self.active_job_cache
            .remove(job_id)
            .map(|value| value.1.execution_graph)
    }

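    /// Generate a random 7-character alphanumeric job id.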
    pub fn generate_job_id(&self) -> String {
        let mut rng = thread_rng();
        std::iter::repeat(())
            .map(|()| rng.sample(Alphanumeric))
            .map(char::from)
            .take(7)
            .collect()
    }

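    /// Spawn a background task that removes the job's persisted state after
    /// `clean_up_interval` seconds; an interval of 0 disables the clean up.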
    pub(crate) fn clean_up_job_delayed(&self, job_id: String, clean_up_interval: u64) {
        if clean_up_interval == 0 {
692 info!("The interval is 0 and the clean up for the failed job state {} will not triggered", job_id);
            return;
        }

        let state = self.state.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(clean_up_interval)).await;
            if let Err(err) = state.remove_job(&job_id).await {
                error!("Failed to delete job {job_id}: {err:?}");
            }
        });
    }
}

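/// Lightweight summary of a job (status, timing, and stage progress) used when listing jobs.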
pub struct JobOverview {
    pub job_id: String,
    pub job_name: String,
    pub status: JobStatus,
    pub start_time: u64,
    pub end_time: u64,
    pub num_stages: usize,
    pub completed_stages: usize,
}

impl From<&ExecutionGraph> for JobOverview {
    fn from(value: &ExecutionGraph) -> Self {
        let mut completed_stages = 0;
        for stage in value.stages().values() {
            if let ExecutionStage::Successful(_) = stage {
                completed_stages += 1;
            }
        }

        Self {
            job_id: value.job_id().to_string(),
            job_name: value.job_name().to_string(),
            status: value.status().clone(),
            start_time: value.start_time(),
            end_time: value.end_time(),
            num_stages: value.stage_count(),
            completed_stages,
        }
    }
}