kapot_scheduler/state/executor_manager.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::time::Duration;

use kapot_core::error::KapotError;
use kapot_core::error::Result;
use kapot_core::serde::protobuf;

use crate::cluster::{BoundTask, ClusterState, ExecutorSlot};
use crate::config::SchedulerConfig;

use crate::state::execution_graph::RunningTaskInfo;
use crate::state::task_manager::JobInfoCache;
use kapot_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
use kapot_core::serde::protobuf::{
    executor_status, CancelTasksParams, ExecutorHeartbeat, MultiTaskDefinition,
    RemoveJobDataParams, StopExecutorParams,
};
use kapot_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
use kapot_core::utils::{create_grpc_client_connection, get_time_before};
use dashmap::DashMap;
use log::{debug, error, info, warn};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tonic::transport::Channel;

type ExecutorClients = Arc<DashMap<String, ExecutorGrpcClient<Channel>>>;

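/// Scheduler-side view of the executor fleet. The manager binds schedulable tasks to
/// executor slots through the shared [`ClusterState`], tracks executor heartbeats, and
/// caches gRPC clients used to send launch, cancel, stop, and clean-up requests to executors.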
#[derive(Clone)]
pub struct ExecutorManager {
    cluster_state: Arc<dyn ClusterState>,
    config: Arc<SchedulerConfig>,
    clients: ExecutorClients,
}

impl ExecutorManager {
    pub(crate) fn new(
        cluster_state: Arc<dyn ClusterState>,
        config: Arc<SchedulerConfig>,
    ) -> Self {
        Self {
            cluster_state,
            config,
            clients: Default::default(),
        }
    }

    pub async fn init(&self) -> Result<()> {
        self.cluster_state.init().await?;

        Ok(())
    }

    /// Bind tasks from `active_jobs` that are ready to run to slots on available executors.
    ///
    /// Only executors that are currently alive are considered when binding slots.
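    ///
    /// # Example
    ///
    /// A rough usage sketch, not compiled here; `executor_manager` and `active_jobs` are
    /// assumed to already exist in the caller's context:
    ///
    /// ```ignore
    /// let bound_tasks = executor_manager
    ///     .bind_schedulable_tasks(active_jobs.clone())
    ///     .await?;
    /// // `bound_tasks` is a Vec<BoundTask> describing which tasks were bound to which slots.
    /// ```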
    pub async fn bind_schedulable_tasks(
        &self,
        active_jobs: Arc<HashMap<String, JobInfoCache>>,
    ) -> Result<Vec<BoundTask>> {
        if active_jobs.is_empty() {
            warn!("There are no active jobs for binding tasks");
            return Ok(vec![]);
        }
        let alive_executors = self.get_alive_executors();
        if alive_executors.is_empty() {
            warn!("There are no alive executors for binding tasks");
            return Ok(vec![]);
        }
        self.cluster_state
            .bind_schedulable_tasks(
                self.config.task_distribution,
                active_jobs,
                Some(alive_executors),
            )
            .await
    }

    /// Return reserved task slots to the pool of available slots. This operation is atomic,
    /// so either all of the reserved task slots are returned or none are.
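    ///
    /// A minimal usage sketch, assuming an existing `executor_manager` and a `Vec<ExecutorSlot>`
    /// of unused slots named `unused_slots`:
    ///
    /// ```ignore
    /// executor_manager.unbind_tasks(unused_slots).await?;
    /// ```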
    pub async fn unbind_tasks(&self, executor_slots: Vec<ExecutorSlot>) -> Result<()> {
        self.cluster_state.unbind_tasks(executor_slots).await
    }

    /// Send RPCs to executors to cancel the given running tasks.
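    ///
    /// The tasks are grouped by executor ID and the cancel RPCs are issued from a spawned
    /// background task, so this method returns immediately; RPC failures are only logged.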
    pub async fn cancel_running_tasks(&self, tasks: Vec<RunningTaskInfo>) -> Result<()> {
        let mut tasks_to_cancel: HashMap<String, Vec<protobuf::RunningTaskInfo>> =
            Default::default();

        for task_info in tasks {
            let infos = tasks_to_cancel.entry(task_info.executor_id).or_default();
            infos.push(protobuf::RunningTaskInfo {
                task_id: task_info.task_id as u32,
                job_id: task_info.job_id,
                stage_id: task_info.stage_id as u32,
                partition_id: task_info.partition_id as u32,
            });
        }

        let executor_manager = self.clone();
        tokio::spawn(async move {
            for (executor_id, infos) in tasks_to_cancel {
                if let Ok(mut client) = executor_manager.get_client(&executor_id).await {
                    if let Err(e) = client
                        .cancel_tasks(CancelTasksParams { task_infos: infos })
                        .await
                    {
                        error!(
                            "Failed to cancel tasks for executor ID {} due to {:?}",
                            executor_id, e
                        );
                    }
                } else {
                    error!(
                        "Failed to get client for executor ID {} to cancel tasks",
                        executor_id
                    )
                }
            }
        });

        Ok(())
    }

    /// Send RPCs to executors to clean up the job data after a delay of `clean_up_interval` seconds.
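    ///
    /// A minimal usage sketch, assuming an existing `executor_manager` and a hypothetical
    /// job ID:
    ///
    /// ```ignore
    /// // Ask executors to drop the data for "job-123" roughly 300 seconds from now.
    /// executor_manager.clean_up_job_data_delayed("job-123".to_string(), 300);
    /// ```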
    pub(crate) fn clean_up_job_data_delayed(
        &self,
        job_id: String,
        clean_up_interval: u64,
    ) {
        if clean_up_interval == 0 {
            info!(
                "The interval is 0, so the clean up for job data {} will not be triggered",
                job_id
            );
            return;
        }

        let executor_manager = self.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(clean_up_interval)).await;
            executor_manager.clean_up_job_data_inner(job_id).await;
        });
    }

    /// Send RPCs to executors to clean up the job data from a spawned background task.
    pub fn clean_up_job_data(&self, job_id: String) {
        let executor_manager = self.clone();
        tokio::spawn(async move {
            executor_manager.clean_up_job_data_inner(job_id).await;
        });
    }

    /// Send an RPC to each alive executor asking it to clean up the job data.
    async fn clean_up_job_data_inner(&self, job_id: String) {
        let alive_executors = self.get_alive_executors();
        for executor in alive_executors {
            let job_id_clone = job_id.to_owned();
            if let Ok(mut client) = self.get_client(&executor).await {
                tokio::spawn(async move {
                    if let Err(err) = client
                        .remove_job_data(RemoveJobDataParams {
                            job_id: job_id_clone,
                        })
                        .await
                    {
                        warn!(
                            "Failed to call remove_job_data on Executor {} due to {:?}",
                            executor, err
                        )
                    }
                });
            } else {
                warn!("Failed to get client for Executor {}", executor)
            }
        }
    }

    /// Get a list of all executors along with the timestamp of their last recorded heartbeat.
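    ///
    /// A rough usage sketch, not compiled here; `executor_manager` is assumed to exist:
    ///
    /// ```ignore
    /// for (metadata, last_heartbeat) in executor_manager.get_executor_state().await? {
    ///     println!("executor {} last heartbeat at {}s", metadata.id, last_heartbeat.as_secs());
    /// }
    /// ```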
    pub async fn get_executor_state(&self) -> Result<Vec<(ExecutorMetadata, Duration)>> {
        let heartbeat_timestamps: Vec<(String, u64)> = self
            .cluster_state
            .executor_heartbeats()
            .into_iter()
            .map(|(executor_id, heartbeat)| (executor_id, heartbeat.timestamp))
            .collect();

        let mut state: Vec<(ExecutorMetadata, Duration)> = vec![];
        for (executor_id, ts) in heartbeat_timestamps {
            let duration = Duration::from_secs(ts);

            let metadata = self.get_executor_metadata(&executor_id).await?;

            state.push((metadata, duration));
        }

        Ok(state)
    }

    /// Get executor metadata for the provided executor ID. Returns an error if the executor does not exist.
    pub async fn get_executor_metadata(
        &self,
        executor_id: &str,
    ) -> Result<ExecutorMetadata> {
        self.cluster_state.get_executor_metadata(executor_id).await
    }

    /// Save executor metadata. This is only used for pull-based task scheduling.
    ///
    /// For push-based task scheduling, use [`register_executor`] instead.
    pub async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> {
        self.cluster_state.save_executor_metadata(metadata).await
    }

    /// Register the executor with the scheduler.
    ///
    /// This will save the executor metadata and the executor data to persistent state.
    ///
    /// It's only used for push-based task scheduling.
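    ///
    /// A rough usage sketch, not compiled here; `metadata` and `spec` are assumed to have
    /// been built from the executor's registration request:
    ///
    /// ```ignore
    /// executor_manager.register_executor(metadata, spec).await?;
    /// ```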
    pub async fn register_executor(
        &self,
        metadata: ExecutorMetadata,
        specification: ExecutorData,
    ) -> Result<()> {
        debug!(
            "registering executor {} with {} task slots",
            metadata.id, specification.total_task_slots
        );

        ExecutorManager::test_connectivity(&metadata).await?;

        self.cluster_state
            .register_executor(metadata, specification)
            .await?;

        Ok(())
    }

    /// Remove the executor from the cluster.
    pub async fn remove_executor(
        &self,
        executor_id: &str,
        reason: Option<String>,
    ) -> Result<()> {
        info!("Removing executor {}: {:?}", executor_id, reason);
        self.cluster_state.remove_executor(executor_id).await
    }

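    /// Ask the executor identified by `executor_id` to stop itself (with `force: true`).
    ///
    /// The stop RPC is sent from a spawned task, so this method returns without waiting for
    /// the executor's response; failures are only logged.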
    pub async fn stop_executor(&self, executor_id: &str, stop_reason: String) {
        let executor_id = executor_id.to_string();
        match self.get_client(&executor_id).await {
            Ok(mut client) => {
                tokio::task::spawn(async move {
                    match client
                        .stop_executor(StopExecutorParams {
                            executor_id: executor_id.to_string(),
                            reason: stop_reason,
                            force: true,
                        })
                        .await
                    {
                        Err(error) => {
                            warn!("Failed to send stop_executor rpc due to {}", error);
                        }
                        Ok(_value) => {}
                    }
                });
            }
            Err(_) => {
                warn!(
                    "Failed to connect to Executor {}, it may already be dead",
                    executor_id
                );
            }
        }
    }

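    /// Launch the given multi-task definitions on the executor identified by `executor_id`,
    /// tagging the request with this scheduler's ID. Returns an error if the launch RPC fails.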
    pub async fn launch_multi_task(
        &self,
        executor_id: &str,
        multi_tasks: Vec<MultiTaskDefinition>,
        scheduler_id: String,
    ) -> Result<()> {
        let mut client = self.get_client(executor_id).await?;
        client
            .launch_multi_task(protobuf::LaunchMultiTaskParams {
                multi_tasks,
                scheduler_id,
            })
            .await
            .map_err(|e| {
                KapotError::Internal(format!(
                    "Failed to launch tasks on executor {}: {:?}",
                    executor_id, e
                ))
            })?;

        Ok(())
    }

    pub(crate) async fn save_executor_heartbeat(
        &self,
        heartbeat: ExecutorHeartbeat,
    ) -> Result<()> {
        self.cluster_state
            .save_executor_heartbeat(heartbeat.clone())
            .await?;

        Ok(())
    }

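    /// Returns `true` if the executor has no recorded heartbeat or if its most recent
    /// heartbeat reported a `Dead` status.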
    pub(crate) fn is_dead_executor(&self, executor_id: &str) -> bool {
        self.cluster_state
            .get_executor_heartbeat(executor_id)
            .map_or(true, |heartbeat| {
                matches!(
                    heartbeat.status,
                    Some(kapot_core::serde::generated::kapot::ExecutorStatus {
                        status: Some(executor_status::Status::Dead(_))
                    })
                )
            })
    }

    /// Retrieve the set of IDs for all executors that have reported an `Active` status and have
    /// been observed within the last `executor_timeout_seconds` seconds.
    pub(crate) fn get_alive_executors(&self) -> HashSet<String> {
        let last_seen_ts_threshold =
            get_time_before(self.config.executor_timeout_seconds);
        self.cluster_state
            .executor_heartbeats()
            .iter()
            .filter_map(|(exec, heartbeat)| {
                let active = matches!(
                    heartbeat
                        .status
                        .as_ref()
                        .and_then(|status| status.status.as_ref()),
                    Some(executor_status::Status::Active(_))
                );
                let live = heartbeat.timestamp > last_seen_ts_threshold;

                (active && live).then(|| exec.clone())
            })
            .collect()
    }

    /// Return a list of expired executor heartbeats.
    pub(crate) fn get_expired_executors(&self) -> Vec<ExecutorHeartbeat> {
        // Threshold for the last heartbeat from an Active executor before marking it dead
        let last_seen_threshold = get_time_before(self.config.executor_timeout_seconds);

        // Threshold for the last heartbeat from a Terminating executor before marking it dead
        let termination_wait_threshold =
            get_time_before(self.config.executor_termination_grace_period);

        self.cluster_state
            .executor_heartbeats()
            .iter()
            .filter_map(|(_exec, heartbeat)| {
                let terminating = matches!(
                    heartbeat
                        .status
                        .as_ref()
                        .and_then(|status| status.status.as_ref()),
                    Some(executor_status::Status::Terminating(_))
                );

                let grace_period_expired =
                    heartbeat.timestamp <= termination_wait_threshold;

                let expired = heartbeat.timestamp <= last_seen_threshold;

                ((terminating && grace_period_expired) || expired)
                    .then(|| heartbeat.clone())
            })
            .collect::<Vec<_>>()
    }

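    /// Get a gRPC client for the given executor, creating and caching one on first use.
    /// The executor's host and gRPC port are looked up from its stored metadata.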
    async fn get_client(&self, executor_id: &str) -> Result<ExecutorGrpcClient<Channel>> {
        let client = self.clients.get(executor_id).map(|value| value.clone());

        if let Some(client) = client {
            Ok(client)
        } else {
            let executor_metadata = self.get_executor_metadata(executor_id).await?;
            let executor_url = format!(
                "http://{}:{}",
                executor_metadata.host, executor_metadata.grpc_port
            );
            let connection = create_grpc_client_connection(executor_url).await?;
            let client = ExecutorGrpcClient::new(connection);

            self.clients.insert(executor_id.to_owned(), client.clone());

            Ok(client)
        }
    }

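    /// In non-test builds, verify that a gRPC connection to the executor can be established
    /// before it is registered with the cluster state.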
    #[cfg(not(test))]
    async fn test_connectivity(metadata: &ExecutorMetadata) -> Result<()> {
        let executor_url = format!("http://{}:{}", metadata.host, metadata.grpc_port);
        debug!("Connecting to executor {:?}", executor_url);
        let _ = protobuf::executor_grpc_client::ExecutorGrpcClient::connect(executor_url)
            .await
            .map_err(|e| {
                KapotError::Internal(format!(
                    "Failed to register executor at {}:{}, could not connect: {:?}",
                    metadata.host, metadata.grpc_port, e
                ))
            })?;
        Ok(())
    }

    #[cfg(test)]
    async fn test_connectivity(_metadata: &ExecutorMetadata) -> Result<()> {
        Ok(())
    }
}