1use datafusion::config::ConfigOptions;
19use datafusion::physical_plan::ExecutionPlan;
20
21use kapot_core::serde::protobuf::{
22 scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
23 TaskDefinition, TaskStatus,
24};
25use datafusion::prelude::SessionConfig;
26use tokio::sync::{OwnedSemaphorePermit, Semaphore};
27
28use crate::cpu_bound_executor::DedicatedExecutor;
29use crate::executor::Executor;
30use crate::{as_task_status, TaskExecutionTimes};
31use kapot_core::error::KapotError;
32use kapot_core::serde::scheduler::{ExecutorSpecification, PartitionId};
33use kapot_core::serde::KapotCodec;
34use datafusion::execution::context::TaskContext;
35use datafusion::functions::datetime::date_part;
36use datafusion::functions::unicode::substr;
37use datafusion::functions_aggregate::covariance::{covar_pop_udaf, covar_samp_udaf};
38use datafusion::functions_aggregate::sum::sum_udaf;
39use datafusion::functions_aggregate::variance::var_samp_udaf;
40use datafusion_proto::logical_plan::AsLogicalPlan;
41use datafusion_proto::physical_plan::AsExecutionPlan;
42use futures::FutureExt;
43use log::{debug, error, info, warn};
44use std::any::Any;
45use std::collections::HashMap;
46use std::convert::TryInto;
47use std::error::Error;
48use std::ops::Deref;
49use std::sync::mpsc::{Receiver, Sender, TryRecvError};
50use std::time::{SystemTime, UNIX_EPOCH};
51use std::{sync::Arc, time::Duration};
52use tonic::transport::Channel;
53
54pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
55 mut scheduler: SchedulerGrpcClient<Channel>,
56 executor: Arc<Executor>,
57 codec: KapotCodec<T, U>,
58) -> Result<(), KapotError> {
59 let executor_specification: ExecutorSpecification = executor
60 .metadata
61 .specification
62 .as_ref()
63 .unwrap()
64 .clone()
65 .into();
66 let available_task_slots =
67 Arc::new(Semaphore::new(executor_specification.task_slots as usize));
68
69 let (task_status_sender, mut task_status_receiver) =
70 std::sync::mpsc::channel::<TaskStatus>();
71 info!("Starting poll work loop with scheduler");
72
73 let dedicated_executor =
74 DedicatedExecutor::new("task_runner", executor_specification.task_slots as usize);
75
76 loop {
77 let permit = available_task_slots.acquire().await.unwrap();
79 drop(permit);
81
82 let mut active_job = false;
85
86 let task_status: Vec<TaskStatus> =
87 sample_tasks_status(&mut task_status_receiver).await;
88
89 let poll_work_result: anyhow::Result<
90 tonic::Response<PollWorkResult>,
91 tonic::Status,
92 > = scheduler
93 .poll_work(PollWorkParams {
94 metadata: Some(executor.metadata.clone()),
95 num_free_slots: available_task_slots.available_permits() as u32,
96 task_status,
97 })
98 .await;
99
100 match poll_work_result {
101 Ok(result) => {
102 let tasks = result.into_inner().tasks;
103 active_job = !tasks.is_empty();
104
105 for task in tasks {
106 let task_status_sender = task_status_sender.clone();
107
108 let permit =
110 available_task_slots.clone().acquire_owned().await.unwrap();
111
112 match run_received_task(
113 executor.clone(),
114 permit,
115 task_status_sender,
116 task,
117 &codec,
118 &dedicated_executor,
119 )
120 .await
121 {
122 Ok(_) => {}
123 Err(e) => {
124 warn!("Failed to run task: {:?}", e);
125 }
126 }
127 }
128 }
129 Err(error) => {
130 warn!("Executor poll work loop failed. If this continues to happen the Scheduler might be marked as dead. Error: {}", error);
131 }
132 }
133
134 if !active_job {
135 tokio::time::sleep(Duration::from_millis(100)).await;
136 }
137 }
138}
139
/// Renders a panic payload, such as the `Err` value of `catch_unwind`, as a
/// human-readable message.
///
/// Recognised payload types:
/// - `&str` and `String`, the payloads produced by `panic!`.
/// - `Box<dyn Error + Send>` and `Box<dyn Error + Send + Sync>`, rendered with
///   their `Display` output.
///
/// Any other payload yields `"Unknown error occurred"`.
///
/// The parameter deliberately stays `&Box<dyn Any + Send>` rather than
/// `&(dyn Any + Send)`. With the latter, passing `&payload` would unsize-coerce
/// the `Box` itself into `dyn Any`, and every downcast below would then miss.
pub(crate) fn any_to_string(any: &Box<dyn Any + Send>) -> String {
    if let Some(message) = any.downcast_ref::<&str>() {
        (*message).to_string()
    } else if let Some(message) = any.downcast_ref::<String>() {
        message.clone()
    } else if let Some(error) = any.downcast_ref::<Box<dyn Error + Send>>() {
        error.to_string()
    } else if let Some(error) = any.downcast_ref::<Box<dyn Error + Send + Sync>>() {
        error.to_string()
    } else {
        "Unknown error occurred".to_string()
    }
}
152
153async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
154 executor: Arc<Executor>,
155 permit: OwnedSemaphorePermit,
156 task_status_sender: Sender<TaskStatus>,
157 task: TaskDefinition,
158 codec: &KapotCodec<T, U>,
159 dedicated_executor: &DedicatedExecutor,
160) -> Result<(), KapotError> {
161 let task_id = task.task_id;
162 let task_attempt_num = task.task_attempt_num;
163 let job_id = task.job_id;
164 let stage_id = task.stage_id;
165 let stage_attempt_num = task.stage_attempt_num;
166 let task_launch_time = task.launch_time;
167 let partition_id = task.partition_id;
168 let start_exec_time = SystemTime::now()
169 .duration_since(UNIX_EPOCH)
170 .unwrap()
171 .as_millis() as u64;
172 let task_identity = format!(
173 "TID {task_id} {job_id}/{stage_id}.{stage_attempt_num}/{partition_id}.{task_attempt_num}"
174 );
175 info!("Received task {}", task_identity);
176
177 let mut task_props = HashMap::new();
178 for kv_pair in task.props {
179 task_props.insert(kv_pair.key, kv_pair.value);
180 }
181 let mut config = ConfigOptions::new();
182 for (k, v) in task_props {
183 config.set(&k, &v)?;
184 }
185 let session_config = SessionConfig::from(config);
186
187 let mut task_scalar_functions = HashMap::new();
188 let mut task_aggregate_functions = HashMap::new();
189 let mut task_window_functions = HashMap::new();
190 for scalar_func in executor.scalar_functions.clone() {
192 task_scalar_functions.insert(scalar_func.0.clone(), scalar_func.1);
193 }
194 for agg_func in executor.aggregate_functions.clone() {
195 task_aggregate_functions.insert(agg_func.0, agg_func.1);
196 }
197 task_aggregate_functions.insert("var".to_string(), var_samp_udaf());
200 task_aggregate_functions.insert("covar_samp".to_string(), covar_samp_udaf());
201 task_aggregate_functions.insert("covar_pop".to_string(), covar_pop_udaf());
202 task_aggregate_functions.insert("SUM".to_string(), sum_udaf());
203
204 task_scalar_functions.insert("date_part".to_string(), date_part());
206 task_scalar_functions.insert("substr".to_string(), substr());
207
208 for window_func in executor.window_functions.clone() {
209 task_window_functions.insert(window_func.0, window_func.1);
210 }
211 let runtime = executor.get_runtime(false);
212 let session_id = task.session_id.clone();
213 let task_context = Arc::new(TaskContext::new(
214 Some(task_identity.clone()),
215 session_id,
216 session_config,
217 task_scalar_functions,
218 task_aggregate_functions,
219 task_window_functions,
220 runtime.clone(),
221 ));
222
223 let plan: Arc<dyn ExecutionPlan> =
224 U::try_decode(task.plan.as_slice()).and_then(|proto| {
225 proto.try_into_physical_plan(
226 task_context.deref(),
227 runtime.deref(),
228 codec.physical_extension_codec(),
229 )
230 })?;
231
232 let query_stage_exec = executor.execution_engine.create_query_stage_exec(
233 job_id.clone(),
234 stage_id as usize,
235 plan,
236 &executor.work_dir,
237 )?;
238 dedicated_executor.spawn(async move {
239 use std::panic::AssertUnwindSafe;
240 let part = PartitionId {
241 job_id: job_id.clone(),
242 stage_id: stage_id as usize,
243 partition_id: partition_id as usize,
244 };
245
246 let execution_result = match AssertUnwindSafe(executor.execute_query_stage(
247 task_id as usize,
248 part.clone(),
249 query_stage_exec.clone(),
250 task_context,
251 ))
252 .catch_unwind()
253 .await
254 {
255 Ok(Ok(r)) => Ok(r),
256 Ok(Err(r)) => Err(r),
257 Err(r) => {
258 error!("Error executing task: {:?}", any_to_string(&r));
259 Err(KapotError::Internal(format!("{:#?}", any_to_string(&r))))
260 }
261 };
262
263 info!("Done with task {}", task_identity);
264 debug!("Statistics: {:?}", execution_result);
265
266 let plan_metrics = query_stage_exec.collect_plan_metrics();
267 let operator_metrics = plan_metrics
268 .into_iter()
269 .map(|m| m.try_into())
270 .collect::<Result<Vec<_>, KapotError>>()
271 .ok();
272
273 let end_exec_time = SystemTime::now()
274 .duration_since(UNIX_EPOCH)
275 .unwrap()
276 .as_millis() as u64;
277
278 let task_execution_times = TaskExecutionTimes {
279 launch_time: task_launch_time,
280 start_exec_time,
281 end_exec_time,
282 };
283
284 let _ = task_status_sender.send(as_task_status(
285 execution_result,
286 executor.metadata.id.clone(),
287 task_id as usize,
288 stage_attempt_num as usize,
289 part,
290 operator_metrics,
291 task_execution_times,
292 ));
293
294 drop(permit);
296 });
297
298 Ok(())
299}
300
301async fn sample_tasks_status(
302 task_status_receiver: &mut Receiver<TaskStatus>,
303) -> Vec<TaskStatus> {
304 let mut task_status: Vec<TaskStatus> = vec![];
305
306 loop {
307 match task_status_receiver.try_recv() {
308 anyhow::Result::Ok(status) => {
309 task_status.push(status);
310 }
311 Err(TryRecvError::Empty) => {
312 break;
313 }
314 Err(TryRecvError::Disconnected) => {
315 error!("Task statuses channel disconnected");
316 }
317 }
318 }
319
320 task_status
321}