kapot_executor/
executor_server.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use kapot_core::KAPOT_VERSION;
19use std::collections::HashMap;
20use std::convert::TryInto;
21use std::path::{Path, PathBuf};
22use std::sync::atomic::{AtomicBool, Ordering};
23use std::sync::Arc;
24use std::time::{Duration, SystemTime, UNIX_EPOCH};
25use tokio::sync::mpsc;
26
27use log::{debug, error, info, warn};
28use tonic::transport::Channel;
29use tonic::{Request, Response, Status};
30
31use kapot_core::config::KAPOT_DATA_CACHE_ENABLED;
32use kapot_core::error::KapotError;
33use kapot_core::serde::protobuf::{
34    executor_grpc_server::{ExecutorGrpc, ExecutorGrpcServer},
35    executor_metric, executor_status,
36    scheduler_grpc_client::SchedulerGrpcClient,
37    CancelTasksParams, CancelTasksResult, ExecutorMetric, ExecutorStatus,
38    HeartBeatParams, LaunchMultiTaskParams, LaunchMultiTaskResult, LaunchTaskParams,
39    LaunchTaskResult, RegisterExecutorParams, RemoveJobDataParams, RemoveJobDataResult,
40    StopExecutorParams, StopExecutorResult, TaskStatus, UpdateTaskStatusParams,
41};
42use kapot_core::serde::scheduler::from_proto::{
43    get_task_definition, get_task_definition_vec,
44};
45use kapot_core::serde::scheduler::PartitionId;
46use kapot_core::serde::scheduler::TaskDefinition;
47use kapot_core::serde::KapotCodec;
48use kapot_core::utils::{create_grpc_client_connection, create_grpc_server};
49use dashmap::DashMap;
50use datafusion::config::ConfigOptions;
51use datafusion::execution::TaskContext;
52use datafusion::prelude::SessionConfig;
53use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
54use tokio::sync::mpsc::error::TryRecvError;
55use tokio::task::JoinHandle;
56
57use crate::cpu_bound_executor::DedicatedExecutor;
58use crate::executor::Executor;
59use crate::executor_process::ExecutorProcessConfig;
60use crate::shutdown::ShutdownNotifier;
61use crate::{as_task_status, TaskExecutionTimes};
62
/// Handle of the spawned executor gRPC server task; resolves to the server's result.
type ServerHandle = JoinHandle<Result<(), KapotError>>;
/// Shared cache of scheduler gRPC clients, keyed by scheduler id.
type SchedulerClients = Arc<DashMap<String, SchedulerGrpcClient<Channel>>>;
65
/// A task definition paired with the id of the scheduler that launched it (its
/// "curator"), so the task's status can later be reported back to that same scheduler.
#[derive(Debug)]
struct CuratorTaskDefinition {
    /// Id of the scheduler that sent this task; also used to build its gRPC URL.
    scheduler_id: String,
    task: TaskDefinition,
}
72
/// A task status paired with the id of the scheduler that launched the task, so
/// the status reporter can route it back to that scheduler.
#[derive(Debug)]
struct CuratorTaskStatus {
    /// Id of the scheduler the status must be reported to.
    scheduler_id: String,
    task_status: TaskStatus,
}
79
80pub async fn startup<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
81    mut scheduler: SchedulerGrpcClient<Channel>,
82    config: Arc<ExecutorProcessConfig>,
83    executor: Arc<Executor>,
84    codec: KapotCodec<T, U>,
85    stop_send: mpsc::Sender<bool>,
86    shutdown_noti: &ShutdownNotifier,
87) -> Result<ServerHandle, KapotError> {
88    let channel_buf_size = executor.concurrent_tasks * 50;
89    let (tx_task, rx_task) = mpsc::channel::<CuratorTaskDefinition>(channel_buf_size);
90    let (tx_task_status, rx_task_status) =
91        mpsc::channel::<CuratorTaskStatus>(channel_buf_size);
92
93    let executor_server = ExecutorServer::new(
94        scheduler.clone(),
95        executor.clone(),
96        ExecutorEnv {
97            tx_task,
98            tx_task_status,
99            tx_stop: stop_send,
100        },
101        codec,
102        config.grpc_max_encoding_message_size as usize,
103        config.grpc_max_decoding_message_size as usize,
104    );
105
106    // 1. Start executor grpc service
107    let server = {
108        let executor_meta = executor.metadata.clone();
109        let addr = format!("{}:{}", config.bind_host, executor_meta.grpc_port);
110        let addr = addr.parse().unwrap();
111
112        info!(
113            "kapot v{} Rust Executor Grpc Server listening on {:?}",
114            KAPOT_VERSION, addr
115        );
116        let server = ExecutorGrpcServer::new(executor_server.clone())
117            .max_encoding_message_size(config.grpc_max_encoding_message_size as usize)
118            .max_decoding_message_size(config.grpc_max_decoding_message_size as usize);
119        let mut grpc_shutdown = shutdown_noti.subscribe_for_shutdown();
120        tokio::spawn(async move {
121            let shutdown_signal = grpc_shutdown.recv();
122            let grpc_server_future = create_grpc_server()
123                .add_service(server)
124                .serve_with_shutdown(addr, shutdown_signal);
125            grpc_server_future.await.map_err(|e| {
126                error!("Tonic error, Could not start Executor Grpc Server.");
127                KapotError::TonicError(e)
128            })
129        })
130    };
131
132    // 2. Do executor registration
133    // TODO the executor registration should happen only after the executor grpc server started.
134    let executor_server = Arc::new(executor_server);
135    match register_executor(&mut scheduler, executor.clone()).await {
136        Ok(_) => {
137            info!("Executor registration succeed");
138        }
139        Err(error) => {
140            error!("Executor registration failed due to: {}", error);
141            // abort the Executor Grpc Future
142            server.abort();
143            return Err(error);
144        }
145    };
146
147    // 3. Start Heartbeater loop
148    {
149        let heartbeater = Heartbeater::new(executor_server.clone());
150        heartbeater.start(shutdown_noti, config.executor_heartbeat_interval_seconds);
151    }
152
153    // 4. Start TaskRunnerPool loop
154    {
155        let task_runner_pool = TaskRunnerPool::new(executor_server.clone());
156        task_runner_pool.start(rx_task, rx_task_status, shutdown_noti);
157    }
158
159    Ok(server)
160}
161
162#[allow(clippy::clone_on_copy)]
163async fn register_executor(
164    scheduler: &mut SchedulerGrpcClient<Channel>,
165    executor: Arc<Executor>,
166) -> Result<(), KapotError> {
167    let result = scheduler
168        .register_executor(RegisterExecutorParams {
169            metadata: Some(executor.metadata.clone()),
170        })
171        .await?;
172    if result.into_inner().success {
173        Ok(())
174    } else {
175        Err(KapotError::General(
176            "Executor registration failed!!!".to_owned(),
177        ))
178    }
179}
180
/// gRPC service implementation of the executor.
///
/// Holds the executor, the channels feeding its background loops, and the gRPC
/// clients used to talk back to schedulers.
#[derive(Clone)]
pub struct ExecutorServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
    /// Creation time in milliseconds since the UNIX epoch (not read in this file).
    _start_time: u128,
    executor: Arc<Executor>,
    /// Channels used to hand work from gRPC handlers to the background loops.
    executor_env: ExecutorEnv,
    /// Codec used to decode task definitions received over gRPC.
    codec: KapotCodec<T, U>,
    /// Client for the scheduler this executor registered with; heartbeats try it first.
    scheduler_to_register: SchedulerGrpcClient<Channel>,
    /// Scheduler clients keyed by scheduler id, created lazily on first use.
    schedulers: SchedulerClients,
    /// Maximum gRPC message sizes applied to lazily created scheduler clients.
    grpc_max_encoding_message_size: usize,
    grpc_max_decoding_message_size: usize,
}
192
193#[derive(Clone)]
194struct ExecutorEnv {
195    /// Receive `TaskDefinition` from rpc then send to CPU bound tasks pool `dedicated_executor`.
196    tx_task: mpsc::Sender<CuratorTaskDefinition>,
197    /// Receive `TaskStatus` from CPU bound tasks pool `dedicated_executor` then use rpc send back to scheduler.
198    tx_task_status: mpsc::Sender<CuratorTaskStatus>,
199    /// Receive stop executor request from rpc.
200    tx_stop: mpsc::Sender<bool>,
201}
202
203unsafe impl Sync for ExecutorEnv {}
204
/// Global flag indicating whether the executor is terminating. This should be
/// set to `true` when the executor receives a shutdown signal.
///
/// Heartbeats read it (with `Acquire` ordering) to report a `Terminating` rather
/// than `Active` status to the scheduler. It is not set anywhere in this file.
/// NOTE(review): confirm the writer stores with at least `Release` ordering.
pub static TERMINATING: AtomicBool = AtomicBool::new(false);
208
209impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T, U> {
210    fn new(
211        scheduler_to_register: SchedulerGrpcClient<Channel>,
212        executor: Arc<Executor>,
213        executor_env: ExecutorEnv,
214        codec: KapotCodec<T, U>,
215        grpc_max_encoding_message_size: usize,
216        grpc_max_decoding_message_size: usize,
217    ) -> Self {
218        Self {
219            _start_time: SystemTime::now()
220                .duration_since(UNIX_EPOCH)
221                .unwrap()
222                .as_millis(),
223            executor,
224            executor_env,
225            codec,
226            scheduler_to_register,
227            schedulers: Default::default(),
228            grpc_max_encoding_message_size,
229            grpc_max_decoding_message_size,
230        }
231    }
232
233    async fn get_scheduler_client(
234        &self,
235        scheduler_id: &str,
236    ) -> Result<SchedulerGrpcClient<Channel>, KapotError> {
237        let scheduler = self.schedulers.get(scheduler_id).map(|value| value.clone());
238        // If channel does not exist, create a new one
239        if let Some(scheduler) = scheduler {
240            Ok(scheduler)
241        } else {
242            let scheduler_url = format!("http://{scheduler_id}");
243            let connection = create_grpc_client_connection(scheduler_url).await?;
244            let scheduler = SchedulerGrpcClient::new(connection)
245                .max_encoding_message_size(self.grpc_max_encoding_message_size)
246                .max_decoding_message_size(self.grpc_max_decoding_message_size);
247
248            {
249                self.schedulers
250                    .insert(scheduler_id.to_owned(), scheduler.clone());
251            }
252
253            Ok(scheduler)
254        }
255    }
256
257    /// 1. First Heartbeat to its registration scheduler, if successful then return; else go next.
258    /// 2. Heartbeat to schedulers which has launching tasks to this executor until one succeeds
259    async fn heartbeat(&self) {
260        let status = if TERMINATING.load(Ordering::Acquire) {
261            executor_status::Status::Terminating(String::default())
262        } else {
263            executor_status::Status::Active(String::default())
264        };
265
266        let heartbeat_params = HeartBeatParams {
267            executor_id: self.executor.metadata.id.clone(),
268            metrics: self.get_executor_metrics(),
269            status: Some(ExecutorStatus {
270                status: Some(status),
271            }),
272            metadata: Some(self.executor.metadata.clone()),
273        };
274        let mut scheduler = self.scheduler_to_register.clone();
275        match scheduler
276            .heart_beat_from_executor(heartbeat_params.clone())
277            .await
278        {
279            Ok(_) => {
280                return;
281            }
282            Err(e) => {
283                warn!(
284                    "Fail to update heartbeat to its registration scheduler due to {:?}",
285                    e
286                );
287            }
288        };
289
290        for mut item in self.schedulers.iter_mut() {
291            let scheduler_id = item.key().clone();
292            let scheduler = item.value_mut();
293
294            match scheduler
295                .heart_beat_from_executor(heartbeat_params.clone())
296                .await
297            {
298                Ok(_) => {
299                    break;
300                }
301                Err(e) => {
302                    warn!(
303                        "Fail to update heartbeat to scheduler {} due to {:?}",
304                        scheduler_id, e
305                    );
306                }
307            }
308        }
309    }
310
311    /// This method should not return Err. If task fails, a failure task status should be sent
312    /// to the channel to notify the scheduler.
313    async fn run_task(&self, task_identity: String, curator_task: CuratorTaskDefinition) {
314        let start_exec_time = SystemTime::now()
315            .duration_since(UNIX_EPOCH)
316            .unwrap()
317            .as_millis() as u64;
318        info!("Start to run task {}", task_identity);
319        let task = curator_task.task;
320
321        let task_id = task.task_id;
322        let job_id = task.job_id;
323        let stage_id = task.stage_id;
324        let stage_attempt_num = task.stage_attempt_num;
325        let partition_id = task.partition_id;
326        let plan = task.plan;
327
328        let part = PartitionId {
329            job_id: job_id.clone(),
330            stage_id,
331            partition_id,
332        };
333
334        let query_stage_exec = self
335            .executor
336            .execution_engine
337            .create_query_stage_exec(
338                job_id.clone(),
339                stage_id,
340                plan,
341                &self.executor.work_dir,
342            )
343            .unwrap();
344
345        let task_context = {
346            let task_props = task.props;
347            let data_cache = task_props
348                .get(KAPOT_DATA_CACHE_ENABLED)
349                .map(|data_cache| data_cache.parse().unwrap_or(false))
350                .unwrap_or(false);
351            let mut config = ConfigOptions::new();
352            for (k, v) in task_props.iter() {
353                if let Err(e) = config.set(k, v) {
354                    debug!("Fail to set session config for ({},{}): {:?}", k, v, e);
355                }
356            }
357            let session_config = SessionConfig::from(config);
358
359            let function_registry = task.function_registry;
360            if data_cache {
361                info!("Data cache will be enabled for {}", task_identity);
362            }
363            let runtime = self.executor.get_runtime(data_cache);
364
365            Arc::new(TaskContext::new(
366                Some(task_identity.clone()),
367                task.session_id,
368                session_config,
369                function_registry.scalar_functions.clone(),
370                function_registry.aggregate_functions.clone(),
371                function_registry.window_functions.clone(),
372                runtime,
373            ))
374        };
375
376        info!("Start to execute shuffle write for task {}", task_identity);
377
378        let execution_result = self
379            .executor
380            .execute_query_stage(
381                task_id,
382                part.clone(),
383                query_stage_exec.clone(),
384                task_context,
385            )
386            .await;
387        info!("Done with task {}", task_identity);
388        debug!("Statistics: {:?}", execution_result);
389
390        let plan_metrics = query_stage_exec.collect_plan_metrics();
391        let operator_metrics = match plan_metrics
392            .into_iter()
393            .map(|m| m.try_into())
394            .collect::<Result<Vec<_>, KapotError>>()
395        {
396            Ok(metrics) => Some(metrics),
397            Err(_) => None,
398        };
399        let executor_id = &self.executor.metadata.id;
400
401        let end_exec_time = SystemTime::now()
402            .duration_since(UNIX_EPOCH)
403            .unwrap()
404            .as_millis() as u64;
405        let task_execution_times = TaskExecutionTimes {
406            launch_time: task.launch_time,
407            start_exec_time,
408            end_exec_time,
409        };
410
411        let task_status = as_task_status(
412            execution_result,
413            executor_id.clone(),
414            task_id,
415            stage_attempt_num,
416            part,
417            operator_metrics,
418            task_execution_times,
419        );
420
421        let scheduler_id = curator_task.scheduler_id;
422        let task_status_sender = self.executor_env.tx_task_status.clone();
423        task_status_sender
424            .send(CuratorTaskStatus {
425                scheduler_id,
426                task_status,
427            })
428            .await
429            .unwrap();
430    }
431
432    // TODO populate with real metrics
433    fn get_executor_metrics(&self) -> Vec<ExecutorMetric> {
434        let available_memory = ExecutorMetric {
435            metric: Some(executor_metric::Metric::AvailableMemory(u64::MAX)),
436        };
437        let executor_metrics = vec![available_memory];
438        executor_metrics
439    }
440}
441
442/// Heartbeater will run forever until a shutdown notification received.
443struct Heartbeater<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
444    executor_server: Arc<ExecutorServer<T, U>>,
445}
446
447impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> Heartbeater<T, U> {
448    fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
449        Self { executor_server }
450    }
451
452    fn start(
453        &self,
454        shutdown_noti: &ShutdownNotifier,
455        executor_heartbeat_interval_seconds: u64,
456    ) {
457        let executor_server = self.executor_server.clone();
458        let mut heartbeat_shutdown = shutdown_noti.subscribe_for_shutdown();
459        let heartbeat_complete = shutdown_noti.shutdown_complete_tx.clone();
460        tokio::spawn(async move {
461            info!("Starting heartbeater to send heartbeat the scheduler periodically");
462            // As long as the shutdown notification has not been received
463            while !heartbeat_shutdown.is_shutdown() {
464                executor_server.heartbeat().await;
465                tokio::select! {
466                    _ = tokio::time::sleep(Duration::from_secs(executor_heartbeat_interval_seconds)) => {},
467                    _ = heartbeat_shutdown.recv() => {
468                        info!("Stop heartbeater");
469                        drop(heartbeat_complete);
470                        return;
471                    }
472                };
473            }
474        });
475    }
476}
477
478/// There are two loop(future) running separately in tokio runtime.
479/// First is for sending back task status to scheduler
480/// Second is for receiving task from scheduler and run.
481/// The two loops will run forever until a shutdown notification received.
482struct TaskRunnerPool<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
483    executor_server: Arc<ExecutorServer<T, U>>,
484}
485
486impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T, U> {
487    fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
488        Self { executor_server }
489    }
490
491    fn start(
492        &self,
493        mut rx_task: mpsc::Receiver<CuratorTaskDefinition>,
494        mut rx_task_status: mpsc::Receiver<CuratorTaskStatus>,
495        shutdown_noti: &ShutdownNotifier,
496    ) {
497        //1. loop for task status reporting
498        let executor_server = self.executor_server.clone();
499        let mut tasks_status_shutdown = shutdown_noti.subscribe_for_shutdown();
500        let tasks_status_complete = shutdown_noti.shutdown_complete_tx.clone();
501        tokio::spawn(async move {
502            info!("Starting the task status reporter");
503            // As long as the shutdown notification has not been received
504            while !tasks_status_shutdown.is_shutdown() {
505                let mut curator_task_status_map: HashMap<String, Vec<TaskStatus>> =
506                    HashMap::new();
507                // First try to fetch task status from the channel in *blocking* mode
508                let maybe_task_status: Option<CuratorTaskStatus> = tokio::select! {
509                     task_status = rx_task_status.recv() => task_status,
510                    _ = tasks_status_shutdown.recv() => {
511                        info!("Stop task status reporting loop");
512                        drop(tasks_status_complete);
513                        return;
514                    }
515                };
516
517                let mut fetched_task_num = 0usize;
518                if let Some(task_status) = maybe_task_status {
519                    let task_status_vec = curator_task_status_map
520                        .entry(task_status.scheduler_id)
521                        .or_default();
522                    task_status_vec.push(task_status.task_status);
523                    fetched_task_num += 1;
524                } else {
525                    info!("Channel is closed and will exit the task status report loop.");
526                    drop(tasks_status_complete);
527                    return;
528                }
529
530                // Then try to fetch by non-blocking mode to fetch as much finished tasks as possible
531                loop {
532                    match rx_task_status.try_recv() {
533                        Ok(task_status) => {
534                            let task_status_vec = curator_task_status_map
535                                .entry(task_status.scheduler_id)
536                                .or_default();
537                            task_status_vec.push(task_status.task_status);
538                            fetched_task_num += 1;
539                        }
540                        Err(TryRecvError::Empty) => {
541                            info!("Fetched {} tasks status to report", fetched_task_num);
542                            break;
543                        }
544                        Err(TryRecvError::Disconnected) => {
545                            info!("Channel is closed and will exit the task status report loop");
546                            drop(tasks_status_complete);
547                            return;
548                        }
549                    }
550                }
551
552                for (scheduler_id, tasks_status) in curator_task_status_map.into_iter() {
553                    match executor_server.get_scheduler_client(&scheduler_id).await {
554                        Ok(mut scheduler) => {
555                            if let Err(e) = scheduler
556                                .update_task_status(UpdateTaskStatusParams {
557                                    executor_id: executor_server
558                                        .executor
559                                        .metadata
560                                        .id
561                                        .clone(),
562                                    task_status: tasks_status.clone(),
563                                })
564                                .await
565                            {
566                                error!(
567                                    "Fail to update tasks {:?} due to {:?}",
568                                    tasks_status, e
569                                );
570                            }
571                        }
572                        Err(e) => {
573                            error!(
574                                "Fail to connect to scheduler {} due to {:?}",
575                                scheduler_id, e
576                            );
577                        }
578                    }
579                }
580            }
581        });
582
583        //2. loop for task fetching and running
584        let executor_server = self.executor_server.clone();
585        let mut task_runner_shutdown = shutdown_noti.subscribe_for_shutdown();
586        let task_runner_complete = shutdown_noti.shutdown_complete_tx.clone();
587        tokio::spawn(async move {
588            info!("Starting the task runner pool");
589
590            // Use a dedicated executor for CPU bound tasks so that the main tokio
591            // executor can still answer requests even when under load
592            let dedicated_executor = DedicatedExecutor::new(
593                "task_runner",
594                executor_server.executor.concurrent_tasks,
595            );
596
597            // As long as the shutdown notification has not been received
598            while !task_runner_shutdown.is_shutdown() {
599                let maybe_task: Option<CuratorTaskDefinition> = tokio::select! {
600                     task = rx_task.recv() => task,
601                    _ = task_runner_shutdown.recv() => {
602                        info!("Stop the task runner pool");
603                        drop(task_runner_complete);
604                        return;
605                    }
606                };
607                if let Some(curator_task) = maybe_task {
608                    let task_identity = format!(
609                        "TID {} {}/{}.{}/{}.{}",
610                        &curator_task.task.task_id,
611                        &curator_task.task.job_id,
612                        &curator_task.task.stage_id,
613                        &curator_task.task.stage_attempt_num,
614                        &curator_task.task.partition_id,
615                        &curator_task.task.task_attempt_num,
616                    );
617                    info!("Received task {:?}", &task_identity);
618
619                    let server = executor_server.clone();
620                    dedicated_executor.spawn(async move {
621                        server.run_task(task_identity.clone(), curator_task).await;
622                    });
623                } else {
624                    info!("Channel is closed and will exit the task receive loop");
625                    drop(task_runner_complete);
626                    return;
627                }
628            }
629        });
630    }
631}
632
633#[tonic::async_trait]
634impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
635    for ExecutorServer<T, U>
636{
637    async fn launch_task(
638        &self,
639        request: Request<LaunchTaskParams>,
640    ) -> Result<Response<LaunchTaskResult>, Status> {
641        let LaunchTaskParams {
642            tasks,
643            scheduler_id,
644        } = request.into_inner();
645        let task_sender = self.executor_env.tx_task.clone();
646        for task in tasks {
647            task_sender
648                .send(CuratorTaskDefinition {
649                    scheduler_id: scheduler_id.clone(),
650                    task: get_task_definition(
651                        task,
652                        self.executor.get_runtime(false),
653                        self.executor.scalar_functions.clone(),
654                        self.executor.aggregate_functions.clone(),
655                        self.executor.window_functions.clone(),
656                        self.codec.clone(),
657                    )
658                    .map_err(|e| Status::invalid_argument(format!("{e}")))?,
659                })
660                .await
661                .unwrap();
662        }
663        Ok(Response::new(LaunchTaskResult { success: true }))
664    }
665
666    /// by this interface, it can reduce the deserialization cost for multiple tasks
667    /// belong to the same job stage running on the same one executor
668    async fn launch_multi_task(
669        &self,
670        request: Request<LaunchMultiTaskParams>,
671    ) -> Result<Response<LaunchMultiTaskResult>, Status> {
672        let LaunchMultiTaskParams {
673            multi_tasks,
674            scheduler_id,
675        } = request.into_inner();
676        let task_sender = self.executor_env.tx_task.clone();
677        for multi_task in multi_tasks {
678            let multi_task: Vec<TaskDefinition> = get_task_definition_vec(
679                multi_task,
680                self.executor.get_runtime(false),
681                self.executor.scalar_functions.clone(),
682                self.executor.aggregate_functions.clone(),
683                self.executor.window_functions.clone(),
684                self.codec.clone(),
685            )
686            .map_err(|e| Status::invalid_argument(format!("{e}")))?;
687            for task in multi_task {
688                task_sender
689                    .send(CuratorTaskDefinition {
690                        scheduler_id: scheduler_id.clone(),
691                        task,
692                    })
693                    .await
694                    .unwrap();
695            }
696        }
697        Ok(Response::new(LaunchMultiTaskResult { success: true }))
698    }
699
700    async fn stop_executor(
701        &self,
702        request: Request<StopExecutorParams>,
703    ) -> Result<Response<StopExecutorResult>, Status> {
704        let stop_request = request.into_inner();
705        if stop_request.executor_id != self.executor.metadata.id {
706            warn!(
707                "The executor id {} in request is different from {}. The stop request will be ignored",
708                stop_request.executor_id, self.executor.metadata.id
709            );
710            return Ok(Response::new(StopExecutorResult {}));
711        }
712        let stop_reason = stop_request.reason;
713        let force = stop_request.force;
714        info!(
715            "Receive stop executor request, reason: {:?}, force {:?}",
716            stop_reason, force
717        );
718        let stop_sender = self.executor_env.tx_stop.clone();
719        stop_sender.send(force).await.unwrap();
720        Ok(Response::new(StopExecutorResult {}))
721    }
722
723    async fn cancel_tasks(
724        &self,
725        request: Request<CancelTasksParams>,
726    ) -> Result<Response<CancelTasksResult>, Status> {
727        let task_infos = request.into_inner().task_infos;
728        info!("Cancelling tasks for {:?}", task_infos);
729
730        let mut cancelled = true;
731
732        for task in task_infos {
733            if let Err(e) = self
734                .executor
735                .cancel_task(
736                    task.task_id as usize,
737                    task.job_id,
738                    task.stage_id as usize,
739                    task.partition_id as usize,
740                )
741                .await
742            {
743                error!("Error cancelling task: {:?}", e);
744                cancelled = false;
745            }
746        }
747
748        Ok(Response::new(CancelTasksResult { cancelled }))
749    }
750
751    async fn remove_job_data(
752        &self,
753        request: Request<RemoveJobDataParams>,
754    ) -> Result<Response<RemoveJobDataResult>, Status> {
755        let job_id = request.into_inner().job_id;
756
757        let work_dir = PathBuf::from(&self.executor.work_dir);
758        let mut path = work_dir.clone();
759        path.push(&job_id);
760
761        // Verify it's an existing directory
762        if !path.is_dir() {
763            return if !path.exists() {
764                Ok(Response::new(RemoveJobDataResult {}))
765            } else {
766                Err(Status::invalid_argument(format!(
767                    "Path {path:?} is not for a directory!!!"
768                )))
769            };
770        }
771
772        if !is_subdirectory(path.as_path(), work_dir.as_path()) {
773            return Err(Status::invalid_argument(format!(
774                "Path {path:?} is not a subdirectory of {work_dir:?}!!!"
775            )));
776        }
777
778        info!("Remove data for job {:?}", job_id);
779
780        std::fs::remove_dir_all(&path)?;
781
782        Ok(Response::new(RemoveJobDataResult {}))
783    }
784}
785
/// Returns `true` when the canonicalized `path` lies strictly below the
/// canonicalized `base_path`, i.e. the resolved path's parent starts with the
/// resolved base. The base itself, anything above it, and any path that fails to
/// canonicalize (for example because it does not exist) yield `false`.
fn is_subdirectory(path: &Path, base_path: &Path) -> bool {
    match (path.canonicalize(), base_path.canonicalize()) {
        (Ok(resolved), Ok(resolved_base)) => resolved
            .parent()
            .map_or(false, |parent| parent.starts_with(&resolved_base)),
        _ => false,
    }
}
798
#[cfg(test)]
mod test {
    use crate::executor_server::is_subdirectory;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Ensures `<base_dir>/<job_id>` exists (creating it if needed) and returns it.
    fn prepare_testing_job_directory(base_dir: &Path, job_id: &str) -> PathBuf {
        let job_dir = base_dir.join(job_id);
        if !job_dir.exists() {
            fs::create_dir(&job_dir).unwrap();
        }
        job_dir
    }

    #[tokio::test]
    async fn test_is_subdirectory() {
        let base_dir = TempDir::new().unwrap().into_path();
        let base = base_dir.as_path();

        // A regular job id yields a directory inside the work dir.
        let job_a = prepare_testing_job_directory(base, "job_a");
        assert!(is_subdirectory(&job_a, base));

        // Empty and "current dir" job ids resolve to the work dir itself.
        for job_id in ["", "."] {
            let job_path = prepare_testing_job_directory(base, job_id);
            assert!(!is_subdirectory(&job_path, base));
        }

        // A parent-dir job id must not be able to escape the work dir.
        let escaping = prepare_testing_job_directory(base, "..");
        assert!(!is_subdirectory(&escaping, base));
    }
}