use std::time::Duration;

use kapot_core::error::KapotError;
use kapot_core::error::Result;
use kapot_core::serde::protobuf;

use crate::cluster::{BoundTask, ClusterState, ExecutorSlot};
use crate::config::SchedulerConfig;

use crate::state::execution_graph::RunningTaskInfo;
use crate::state::task_manager::JobInfoCache;
use kapot_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
use kapot_core::serde::protobuf::{
    executor_status, CancelTasksParams, ExecutorHeartbeat, MultiTaskDefinition,
    RemoveJobDataParams, StopExecutorParams,
};
use kapot_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
use kapot_core::utils::{create_grpc_client_connection, get_time_before};
use dashmap::DashMap;
use log::{debug, error, info, warn};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tonic::transport::Channel;

type ExecutorClients = Arc<DashMap<String, ExecutorGrpcClient<Channel>>>;

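/// [`ExecutorManager`] is the scheduler's view of the executors in the cluster. It
/// tracks executor registration, heartbeats and task slots through the shared
/// [`ClusterState`] and caches gRPC clients for talking to individual executors.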
#[derive(Clone)]
pub struct ExecutorManager {
    cluster_state: Arc<dyn ClusterState>,
    config: Arc<SchedulerConfig>,
    clients: ExecutorClients,
}

impl ExecutorManager {
    pub(crate) fn new(
        cluster_state: Arc<dyn ClusterState>,
        config: Arc<SchedulerConfig>,
    ) -> Self {
        Self {
            cluster_state,
            config,
            clients: Default::default(),
        }
    }

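    /// Initialize the underlying cluster state store.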
    pub async fn init(&self) -> Result<()> {
        self.cluster_state.init().await?;

        Ok(())
    }

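    /// Bind schedulable tasks from the given active jobs to free task slots on alive
    /// executors, using the configured task distribution policy. Returns the resulting
    /// bindings, or an empty list if there are no active jobs or no alive executors.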
    pub async fn bind_schedulable_tasks(
        &self,
        active_jobs: Arc<HashMap<String, JobInfoCache>>,
    ) -> Result<Vec<BoundTask>> {
        if active_jobs.is_empty() {
            warn!("There are no active jobs for binding tasks");
            return Ok(vec![]);
        }
        let alive_executors = self.get_alive_executors();
        if alive_executors.is_empty() {
            warn!("There are no alive executors for binding tasks");
            return Ok(vec![]);
        }
        self.cluster_state
            .bind_schedulable_tasks(
                self.config.task_distribution,
                active_jobs,
                Some(alive_executors),
            )
            .await
    }

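    /// Release the given executor slots back to the cluster state so they can be
    /// bound to other tasks.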
    pub async fn unbind_tasks(&self, executor_slots: Vec<ExecutorSlot>) -> Result<()> {
        self.cluster_state.unbind_tasks(executor_slots).await
    }

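    /// Send cancellation requests for the given running tasks to their executors. The
    /// RPCs are issued from a background task, so this returns without waiting for the
    /// executors to acknowledge the cancellations.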
    pub async fn cancel_running_tasks(&self, tasks: Vec<RunningTaskInfo>) -> Result<()> {
        let mut tasks_to_cancel: HashMap<String, Vec<protobuf::RunningTaskInfo>> =
            Default::default();

        for task_info in tasks {
            let infos = tasks_to_cancel.entry(task_info.executor_id).or_default();
            infos.push(protobuf::RunningTaskInfo {
                task_id: task_info.task_id as u32,
                job_id: task_info.job_id,
                stage_id: task_info.stage_id as u32,
                partition_id: task_info.partition_id as u32,
            });
        }

        let executor_manager = self.clone();
        tokio::spawn(async move {
            for (executor_id, infos) in tasks_to_cancel {
                if let Ok(mut client) = executor_manager.get_client(&executor_id).await {
                    if let Err(e) = client
                        .cancel_tasks(CancelTasksParams { task_infos: infos })
                        .await
                    {
                        error!(
                            "Failed to cancel tasks for executor ID {} due to {:?}",
                            executor_id, e
                        );
                    }
                } else {
                    error!(
                        "Failed to get client for executor ID {} to cancel tasks",
                        executor_id
                    )
                }
            }
        });

        Ok(())
    }

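    /// Schedule a clean up of the given job's data on all alive executors after
    /// `clean_up_interval` seconds. If the interval is 0, no clean up is scheduled.
    ///
    /// A minimal usage sketch, assuming an `ExecutorManager` value named
    /// `executor_manager` (the job id and interval are illustrative):
    ///
    /// ```ignore
    /// // Remove the data for job "job-1" five minutes from now.
    /// executor_manager.clean_up_job_data_delayed("job-1".to_string(), 300);
    /// ```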
    pub(crate) fn clean_up_job_data_delayed(
        &self,
        job_id: String,
        clean_up_interval: u64,
    ) {
        if clean_up_interval == 0 {
            info!(
                "The interval is 0, so clean up of data for job {} will not be triggered",
                job_id
            );
            return;
        }

        let executor_manager = self.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(clean_up_interval)).await;
            executor_manager.clean_up_job_data_inner(job_id).await;
        });
    }

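    /// Immediately clean up the given job's data on all alive executors. The work runs
    /// in a background task.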
    pub fn clean_up_job_data(&self, job_id: String) {
        let executor_manager = self.clone();
        tokio::spawn(async move {
            executor_manager.clean_up_job_data_inner(job_id).await;
        });
    }

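    /// Ask every alive executor to remove its local data for the given job, logging a
    /// warning for any executor that cannot be reached.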
    async fn clean_up_job_data_inner(&self, job_id: String) {
        let alive_executors = self.get_alive_executors();
        for executor in alive_executors {
            let job_id_clone = job_id.to_owned();
            if let Ok(mut client) = self.get_client(&executor).await {
                tokio::spawn(async move {
                    if let Err(err) = client
                        .remove_job_data(RemoveJobDataParams {
                            job_id: job_id_clone,
                        })
                        .await
                    {
                        warn!(
                            "Failed to call remove_job_data on Executor {} due to {:?}",
                            executor, err
                        )
                    }
                });
            } else {
                warn!("Failed to get client for Executor {}", executor)
            }
        }
    }

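    /// Return the metadata of every executor known to the cluster state together with
    /// its last heartbeat timestamp, expressed as a [`Duration`] in seconds.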
    pub async fn get_executor_state(&self) -> Result<Vec<(ExecutorMetadata, Duration)>> {
        let heartbeat_timestamps: Vec<(String, u64)> = self
            .cluster_state
            .executor_heartbeats()
            .into_iter()
            .map(|(executor_id, heartbeat)| (executor_id, heartbeat.timestamp))
            .collect();

        let mut state: Vec<(ExecutorMetadata, Duration)> = vec![];
        for (executor_id, ts) in heartbeat_timestamps {
            let duration = Duration::from_secs(ts);

            let metadata = self.get_executor_metadata(&executor_id).await?;

            state.push((metadata, duration));
        }

        Ok(state)
    }

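    /// Look up the metadata (host, gRPC port, etc.) for a single executor in the
    /// cluster state.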
    pub async fn get_executor_metadata(
        &self,
        executor_id: &str,
    ) -> Result<ExecutorMetadata> {
        self.cluster_state.get_executor_metadata(executor_id).await
    }

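    /// Persist the given executor metadata in the cluster state.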
    pub async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> {
        self.cluster_state.save_executor_metadata(metadata).await
    }

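    /// Register a new executor with the scheduler. Outside of test builds this first
    /// checks that the executor's gRPC endpoint is reachable, then records the
    /// executor and its task slots in the cluster state.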
    pub async fn register_executor(
        &self,
        metadata: ExecutorMetadata,
        specification: ExecutorData,
    ) -> Result<()> {
        debug!(
            "Registering executor {} with {} task slots",
            metadata.id, specification.total_task_slots
        );

        ExecutorManager::test_connectivity(&metadata).await?;

        self.cluster_state
            .register_executor(metadata, specification)
            .await?;

        Ok(())
    }

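    /// Remove an executor from the cluster state, for example because it stopped or
    /// timed out. The optional `reason` is only used for logging.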
    pub async fn remove_executor(
        &self,
        executor_id: &str,
        reason: Option<String>,
    ) -> Result<()> {
        info!("Removing executor {}: {:?}", executor_id, reason);
        self.cluster_state.remove_executor(executor_id).await
    }

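    /// Ask an executor to stop itself by sending it a `stop_executor` RPC with
    /// `force: true`. The RPC is sent from a background task, and failures are only
    /// logged.
    ///
    /// A minimal usage sketch, assuming an `ExecutorManager` value named
    /// `executor_manager` (the executor id and reason are illustrative):
    ///
    /// ```ignore
    /// executor_manager
    ///     .stop_executor("executor-1", "scheduler requested shutdown".to_string())
    ///     .await;
    /// ```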
    pub async fn stop_executor(&self, executor_id: &str, stop_reason: String) {
        let executor_id = executor_id.to_string();
        match self.get_client(&executor_id).await {
            Ok(mut client) => {
                tokio::task::spawn(async move {
                    match client
                        .stop_executor(StopExecutorParams {
                            executor_id: executor_id.to_string(),
                            reason: stop_reason,
                            force: true,
                        })
                        .await
                    {
                        Err(error) => {
                            warn!("Failed to send stop_executor rpc due to {}", error);
                        }
                        Ok(_value) => {}
                    }
                });
            }
            Err(_) => {
                warn!(
                    "Failed to connect to executor {}; it is likely already dead",
                    executor_id
                );
            }
        }
    }

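    /// Launch a batch of tasks on the given executor via its `launch_multi_task` RPC,
    /// tagging them with the id of the scheduler that is launching them.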
    pub async fn launch_multi_task(
        &self,
        executor_id: &str,
        multi_tasks: Vec<MultiTaskDefinition>,
        scheduler_id: String,
    ) -> Result<()> {
        let mut client = self.get_client(executor_id).await?;
        client
            .launch_multi_task(protobuf::LaunchMultiTaskParams {
                multi_tasks,
                scheduler_id,
            })
            .await
            .map_err(|e| {
                KapotError::Internal(format!(
                    "Failed to launch tasks on executor {}: {:?}",
                    executor_id, e
                ))
            })?;

        Ok(())
    }

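    /// Record the latest heartbeat received from an executor in the cluster state.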
    pub(crate) async fn save_executor_heartbeat(
        &self,
        heartbeat: ExecutorHeartbeat,
    ) -> Result<()> {
        self.cluster_state
            .save_executor_heartbeat(heartbeat.clone())
            .await?;

        Ok(())
    }

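    /// Returns true if the executor has reported a `Dead` status, or if no heartbeat
    /// is known for it at all.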
    pub(crate) fn is_dead_executor(&self, executor_id: &str) -> bool {
        self.cluster_state
            .get_executor_heartbeat(executor_id)
            .map_or(true, |heartbeat| {
                matches!(
                    heartbeat.status,
                    Some(kapot_core::serde::generated::kapot::ExecutorStatus {
                        status: Some(executor_status::Status::Dead(_))
                    })
                )
            })
    }

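    /// Return the ids of executors currently considered alive: their latest heartbeat
    /// reports an `Active` status and is newer than the configured executor timeout.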
    pub(crate) fn get_alive_executors(&self) -> HashSet<String> {
        let last_seen_ts_threshold =
            get_time_before(self.config.executor_timeout_seconds);
        self.cluster_state
            .executor_heartbeats()
            .iter()
            .filter_map(|(exec, heartbeat)| {
                let active = matches!(
                    heartbeat
                        .status
                        .as_ref()
                        .and_then(|status| status.status.as_ref()),
                    Some(executor_status::Status::Active(_))
                );
                let live = heartbeat.timestamp > last_seen_ts_threshold;

                (active && live).then(|| exec.clone())
            })
            .collect()
    }

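    /// Return the heartbeats of executors that should be removed: executors whose last
    /// heartbeat is older than the executor timeout, and terminating executors whose
    /// termination grace period has elapsed.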
    pub(crate) fn get_expired_executors(&self) -> Vec<ExecutorHeartbeat> {
        let last_seen_threshold = get_time_before(self.config.executor_timeout_seconds);

        let termination_wait_threshold =
            get_time_before(self.config.executor_termination_grace_period);

        self.cluster_state
            .executor_heartbeats()
            .iter()
            .filter_map(|(_exec, heartbeat)| {
                let terminating = matches!(
                    heartbeat
                        .status
                        .as_ref()
                        .and_then(|status| status.status.as_ref()),
                    Some(executor_status::Status::Terminating(_))
                );

                let grace_period_expired =
                    heartbeat.timestamp <= termination_wait_threshold;

                let expired = heartbeat.timestamp <= last_seen_threshold;

                ((terminating && grace_period_expired) || expired)
                    .then(|| heartbeat.clone())
            })
            .collect::<Vec<_>>()
    }

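    /// Get a gRPC client for the given executor, creating and caching a new connection
    /// if one does not already exist.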
    async fn get_client(&self, executor_id: &str) -> Result<ExecutorGrpcClient<Channel>> {
        let client = self.clients.get(executor_id).map(|value| value.clone());

        if let Some(client) = client {
            Ok(client)
        } else {
            let executor_metadata = self.get_executor_metadata(executor_id).await?;
            let executor_url = format!(
                "http://{}:{}",
                executor_metadata.host, executor_metadata.grpc_port
            );
            let connection = create_grpc_client_connection(executor_url).await?;
            let client = ExecutorGrpcClient::new(connection);

            {
                self.clients.insert(executor_id.to_owned(), client.clone());
            }
            Ok(client)
        }
    }

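    /// Check that an executor's gRPC endpoint accepts connections before it is
    /// registered.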
    #[cfg(not(test))]
    async fn test_connectivity(metadata: &ExecutorMetadata) -> Result<()> {
        let executor_url = format!("http://{}:{}", metadata.host, metadata.grpc_port);
        debug!("Connecting to executor {:?}", executor_url);
        let _ = protobuf::executor_grpc_client::ExecutorGrpcClient::connect(executor_url)
            .await
            .map_err(|e| {
                KapotError::Internal(format!(
                    "Failed to register executor at {}:{}, could not connect: {:?}",
                    metadata.host, metadata.grpc_port, e
                ))
            })?;
        Ok(())
    }

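    /// Test builds skip the connectivity check so executors can be registered without
    /// a running gRPC server.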
    #[cfg(test)]
    async fn test_connectivity(_metadata: &ExecutorMetadata) -> Result<()> {
        Ok(())
    }
}