kapot_executor/
execution_loop.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::config::ConfigOptions;
19use datafusion::physical_plan::ExecutionPlan;
20
21use kapot_core::serde::protobuf::{
22    scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
23    TaskDefinition, TaskStatus,
24};
25use datafusion::prelude::SessionConfig;
26use tokio::sync::{OwnedSemaphorePermit, Semaphore};
27
28use crate::cpu_bound_executor::DedicatedExecutor;
29use crate::executor::Executor;
30use crate::{as_task_status, TaskExecutionTimes};
31use kapot_core::error::KapotError;
32use kapot_core::serde::scheduler::{ExecutorSpecification, PartitionId};
33use kapot_core::serde::KapotCodec;
34use datafusion::execution::context::TaskContext;
35use datafusion::functions::datetime::date_part;
36use datafusion::functions::unicode::substr;
37use datafusion::functions_aggregate::covariance::{covar_pop_udaf, covar_samp_udaf};
38use datafusion::functions_aggregate::sum::sum_udaf;
39use datafusion::functions_aggregate::variance::var_samp_udaf;
40use datafusion_proto::logical_plan::AsLogicalPlan;
41use datafusion_proto::physical_plan::AsExecutionPlan;
42use futures::FutureExt;
43use log::{debug, error, info, warn};
44use std::any::Any;
45use std::collections::HashMap;
46use std::convert::TryInto;
47use std::error::Error;
48use std::ops::Deref;
49use std::sync::mpsc::{Receiver, Sender, TryRecvError};
50use std::time::{SystemTime, UNIX_EPOCH};
51use std::{sync::Arc, time::Duration};
52use tonic::transport::Channel;
53
54pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
55    mut scheduler: SchedulerGrpcClient<Channel>,
56    executor: Arc<Executor>,
57    codec: KapotCodec<T, U>,
58) -> Result<(), KapotError> {
59    let executor_specification: ExecutorSpecification = executor
60        .metadata
61        .specification
62        .as_ref()
63        .unwrap()
64        .clone()
65        .into();
66    let available_task_slots =
67        Arc::new(Semaphore::new(executor_specification.task_slots as usize));
68
69    let (task_status_sender, mut task_status_receiver) =
70        std::sync::mpsc::channel::<TaskStatus>();
71    info!("Starting poll work loop with scheduler");
72
73    let dedicated_executor =
74        DedicatedExecutor::new("task_runner", executor_specification.task_slots as usize);
75
76    loop {
77        // Wait for task slots to be available before asking for new work
78        let permit = available_task_slots.acquire().await.unwrap();
79        // Make the slot available again
80        drop(permit);
81
82        // Keeps track of whether we received task in last iteration
83        // to avoid going in sleep mode between polling
84        let mut active_job = false;
85
86        let task_status: Vec<TaskStatus> =
87            sample_tasks_status(&mut task_status_receiver).await;
88
89        let poll_work_result: anyhow::Result<
90            tonic::Response<PollWorkResult>,
91            tonic::Status,
92        > = scheduler
93            .poll_work(PollWorkParams {
94                metadata: Some(executor.metadata.clone()),
95                num_free_slots: available_task_slots.available_permits() as u32,
96                task_status,
97            })
98            .await;
99
100        match poll_work_result {
101            Ok(result) => {
102                let tasks = result.into_inner().tasks;
103                active_job = !tasks.is_empty();
104
105                for task in tasks {
106                    let task_status_sender = task_status_sender.clone();
107
108                    // Acquire a permit/slot for the task
109                    let permit =
110                        available_task_slots.clone().acquire_owned().await.unwrap();
111
112                    match run_received_task(
113                        executor.clone(),
114                        permit,
115                        task_status_sender,
116                        task,
117                        &codec,
118                        &dedicated_executor,
119                    )
120                    .await
121                    {
122                        Ok(_) => {}
123                        Err(e) => {
124                            warn!("Failed to run task: {:?}", e);
125                        }
126                    }
127                }
128            }
129            Err(error) => {
130                warn!("Executor poll work loop failed. If this continues to happen the Scheduler might be marked as dead. Error: {}", error);
131            }
132        }
133
134        if !active_job {
135            tokio::time::sleep(Duration::from_millis(100)).await;
136        }
137    }
138}
139
/// Tries to get a meaningful description from a panic payload.
///
/// Recognised payload types are `&str` and `String` (what `panic!` produces)
/// as well as boxed errors, both `Box<dyn Error + Send>` and
/// `Box<dyn Error + Send + Sync>`. Any other payload yields the generic
/// message `"Unknown error occurred"`.
pub(crate) fn any_to_string(any: &Box<dyn Any + Send>) -> String {
    if let Some(s) = any.downcast_ref::<&str>() {
        (*s).to_string()
    } else if let Some(s) = any.downcast_ref::<String>() {
        s.clone()
    } else if let Some(error) = any.downcast_ref::<Box<dyn Error + Send>>() {
        error.to_string()
    } else if let Some(error) = any.downcast_ref::<Box<dyn Error + Send + Sync>>() {
        // `Box<dyn Error + Send + Sync>` is a distinct type from
        // `Box<dyn Error + Send>`, so it needs its own downcast.
        error.to_string()
    } else {
        "Unknown error occurred".to_string()
    }
}
152
153async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
154    executor: Arc<Executor>,
155    permit: OwnedSemaphorePermit,
156    task_status_sender: Sender<TaskStatus>,
157    task: TaskDefinition,
158    codec: &KapotCodec<T, U>,
159    dedicated_executor: &DedicatedExecutor,
160) -> Result<(), KapotError> {
161    let task_id = task.task_id;
162    let task_attempt_num = task.task_attempt_num;
163    let job_id = task.job_id;
164    let stage_id = task.stage_id;
165    let stage_attempt_num = task.stage_attempt_num;
166    let task_launch_time = task.launch_time;
167    let partition_id = task.partition_id;
168    let start_exec_time = SystemTime::now()
169        .duration_since(UNIX_EPOCH)
170        .unwrap()
171        .as_millis() as u64;
172    let task_identity = format!(
173        "TID {task_id} {job_id}/{stage_id}.{stage_attempt_num}/{partition_id}.{task_attempt_num}"
174    );
175    info!("Received task {}", task_identity);
176
177    let mut task_props = HashMap::new();
178    for kv_pair in task.props {
179        task_props.insert(kv_pair.key, kv_pair.value);
180    }
181    let mut config = ConfigOptions::new();
182    for (k, v) in task_props {
183        config.set(&k, &v)?;
184    }
185    let session_config = SessionConfig::from(config);
186
187    let mut task_scalar_functions = HashMap::new();
188    let mut task_aggregate_functions = HashMap::new();
189    let mut task_window_functions = HashMap::new();
190    // TODO combine the functions from Executor's functions and TaskDefintion's function resources
191    for scalar_func in executor.scalar_functions.clone() {
192        task_scalar_functions.insert(scalar_func.0.clone(), scalar_func.1);
193    }
194    for agg_func in executor.aggregate_functions.clone() {
195        task_aggregate_functions.insert(agg_func.0, agg_func.1);
196    }
197    // since DataFusion 38 some internal functions were converted to UDAF, so
198    // we have to register them manually
199    task_aggregate_functions.insert("var".to_string(), var_samp_udaf());
200    task_aggregate_functions.insert("covar_samp".to_string(), covar_samp_udaf());
201    task_aggregate_functions.insert("covar_pop".to_string(), covar_pop_udaf());
202    task_aggregate_functions.insert("SUM".to_string(), sum_udaf());
203
204    // TODO which other functions need adding here?
205    task_scalar_functions.insert("date_part".to_string(), date_part());
206    task_scalar_functions.insert("substr".to_string(), substr());
207
208    for window_func in executor.window_functions.clone() {
209        task_window_functions.insert(window_func.0, window_func.1);
210    }
211    let runtime = executor.get_runtime(false);
212    let session_id = task.session_id.clone();
213    let task_context = Arc::new(TaskContext::new(
214        Some(task_identity.clone()),
215        session_id,
216        session_config,
217        task_scalar_functions,
218        task_aggregate_functions,
219        task_window_functions,
220        runtime.clone(),
221    ));
222
223    let plan: Arc<dyn ExecutionPlan> =
224        U::try_decode(task.plan.as_slice()).and_then(|proto| {
225            proto.try_into_physical_plan(
226                task_context.deref(),
227                runtime.deref(),
228                codec.physical_extension_codec(),
229            )
230        })?;
231
232    let query_stage_exec = executor.execution_engine.create_query_stage_exec(
233        job_id.clone(),
234        stage_id as usize,
235        plan,
236        &executor.work_dir,
237    )?;
238    dedicated_executor.spawn(async move {
239        use std::panic::AssertUnwindSafe;
240        let part = PartitionId {
241            job_id: job_id.clone(),
242            stage_id: stage_id as usize,
243            partition_id: partition_id as usize,
244        };
245
246        let execution_result = match AssertUnwindSafe(executor.execute_query_stage(
247            task_id as usize,
248            part.clone(),
249            query_stage_exec.clone(),
250            task_context,
251        ))
252        .catch_unwind()
253        .await
254        {
255            Ok(Ok(r)) => Ok(r),
256            Ok(Err(r)) => Err(r),
257            Err(r) => {
258                error!("Error executing task: {:?}", any_to_string(&r));
259                Err(KapotError::Internal(format!("{:#?}", any_to_string(&r))))
260            }
261        };
262
263        info!("Done with task {}", task_identity);
264        debug!("Statistics: {:?}", execution_result);
265
266        let plan_metrics = query_stage_exec.collect_plan_metrics();
267        let operator_metrics = plan_metrics
268            .into_iter()
269            .map(|m| m.try_into())
270            .collect::<Result<Vec<_>, KapotError>>()
271            .ok();
272
273        let end_exec_time = SystemTime::now()
274            .duration_since(UNIX_EPOCH)
275            .unwrap()
276            .as_millis() as u64;
277
278        let task_execution_times = TaskExecutionTimes {
279            launch_time: task_launch_time,
280            start_exec_time,
281            end_exec_time,
282        };
283
284        let _ = task_status_sender.send(as_task_status(
285            execution_result,
286            executor.metadata.id.clone(),
287            task_id as usize,
288            stage_attempt_num as usize,
289            part,
290            operator_metrics,
291            task_execution_times,
292        ));
293
294        // Release the permit after the work is done
295        drop(permit);
296    });
297
298    Ok(())
299}
300
301async fn sample_tasks_status(
302    task_status_receiver: &mut Receiver<TaskStatus>,
303) -> Vec<TaskStatus> {
304    let mut task_status: Vec<TaskStatus> = vec![];
305
306    loop {
307        match task_status_receiver.try_recv() {
308            anyhow::Result::Ok(status) => {
309                task_status.push(status);
310            }
311            Err(TryRecvError::Empty) => {
312                break;
313            }
314            Err(TryRecvError::Disconnected) => {
315                error!("Task statuses channel disconnected");
316            }
317        }
318    }
319
320    task_status
321}