1use std::{panic::AssertUnwindSafe, sync::Arc, time::UNIX_EPOCH};
2
3use eyre::{Context, OptionExt};
4use futures::FutureExt;
5use ora_proto::{
6 common::v1::JobType,
7 server::v1::{
8 executor_message::ExecutorMessageKind, server_message::ServerMessageKind,
9 ExecutorCapabilities, ExecutorConnectionRequest, ExecutorConnectionResponse,
10 ExecutorHeartbeat, ExecutorMessage,
11 },
12};
13use parking_lot::Mutex;
14use tokio_util::sync::CancellationToken;
15use tracing::Instrument;
16use uuid::Uuid;
17use wgroup::WaitGroup;
18
19#[allow(clippy::wildcard_imports)]
20use tonic::codegen::*;
21
22use crate::{
23 executor::{ExecutionContext, ExecutionFailedCb},
24 IndexMap,
25};
26
27use super::{ExecutionHandlerRaw, Executor, ExecutorOptions};
28
29impl<C> Executor<C>
30where
31 C: tonic::client::GrpcService<tonic::body::BoxBody> + Clone,
32 C::Error: Into<StdError>,
33 C::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
34 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
35{
36 #[tracing::instrument(skip_all, name = "executor_loop", fields(executor_id, executor_name))]
44 pub async fn run(&mut self) -> eyre::Result<()> {
45 let executor_span = tracing::Span::current();
46
47 executor_span.record("executor_name", &self.options.name);
48
49 let (executor_requests, recv) = flume::bounded(0);
50
51 let on_execution_failed = self.on_execution_failed.clone();
52
53 let mut state = ExecutorState {
54 executor_id: None,
55 options: &self.options,
56 handlers: &self.handlers,
57 executor_requests,
58 heartbeat_interval: std::time::Duration::from_secs(1),
60 in_progress_executions: Arc::new(Mutex::new(IndexMap::default())),
61 wg: WaitGroup::new(),
62 };
63
64 let send_chan_guard = state.wg.add_with("send-channel");
65
66 let mut server_messages = self
67 .client
68 .executor_connection(tonic::Request::new(async_stream::stream!({
72 loop {
73 tokio::select! {
74 _ = send_chan_guard.waiting() => {
75 tracing::debug!("send channel closed, stopping stream");
76 return;
77 }
78 msg = recv.recv_async() => {
79 if let Ok(msg) = msg {
80 tracing::trace!(?msg, "sending message");
81 yield msg;
82 } else {
83 tracing::debug!("send channel closed, stopping stream");
84 return;
85 }
86 }
87 }
88 }
89 })))
90 .await?
91 .into_inner();
92
93 state
95 .executor_requests
96 .send_async(ExecutorConnectionRequest {
97 message: Some(ExecutorMessage {
98 executor_message_kind: Some(ExecutorMessageKind::Capabilities(
99 ExecutorCapabilities {
100 max_concurrent_executions: state
101 .options
102 .max_concurrent_executions
103 .get(),
104 name: state.options.name.clone(),
105 supported_job_types: self
106 .handlers
107 .iter()
108 .map(|h| {
109 let handler_meta = h.job_type_metadata();
110
111 JobType {
112 id: handler_meta.id.to_string(),
113 name: handler_meta.name.clone(),
114 description: handler_meta.description.clone(),
115 input_schema_json: handler_meta.input_schema_json.clone(),
116 output_schema_json: handler_meta.output_schema_json.clone(),
117 }
118 })
119 .collect(),
120 },
121 )),
122 }),
123 })
124 .await?;
125
126 loop {
127 tokio::select! {
128 _ = tokio::time::sleep(state.heartbeat_interval) => {
129 if state.executor_requests.send(ExecutorConnectionRequest {
130 message: Some(ExecutorMessage {
131 executor_message_kind: Some(ExecutorMessageKind::Heartbeat(
132 ExecutorHeartbeat {},
133 )),
134 }),
135 }).is_err() {
136 return Ok(());
137 }
138 }
139 server_msg = server_messages.message() => {
140 match server_msg {
141 Ok(Some(server_msg)) => {
142 handle_server_response(&mut state, &executor_span, server_msg, on_execution_failed.clone()).await?;
143 }
144 Ok(None) => {
145 tracing::info!("incoming stream closed by the server");
146
147 if !state.in_progress_executions.lock().is_empty() {
148 tracing::warn!("cancelling executions in progress");
149
150 loop {
151 let execution_state = {
152 let mut in_progress_executions = state.in_progress_executions.lock();
153
154 if in_progress_executions.is_empty() {
155 break;
156 }
157
158 let execution_id = in_progress_executions.keys().copied().next();
159
160 if let Some(execution_id) = execution_id {
161 in_progress_executions.swap_remove(&execution_id)
162 } else {
163 None
164 }
165 };
166
167 if let Some(mut execution_state) = execution_state {
168 execution_state.cancellation_token.cancel();
169
170 tokio::select! {
171 _ = &mut execution_state.handle => {}
172 _ = tokio::time::sleep(state.options.cancellation_grace_period) => {
173 execution_state.handle.abort();
174 }
175 }
176 } else {
177 break;
178 }
179 }
180 }
181
182 return Ok(());
183 }
184 Err(error) => {
185 tracing::warn!(?error, "received error from the server");
186 }
187 }
188 }
189 }
190 }
191 }
192}
193
194#[tracing::instrument(name = "handle_server_message", skip_all)]
195async fn handle_server_response(
196 state: &mut ExecutorState<'_>,
197 executor_span: &tracing::Span,
198 response: ExecutorConnectionResponse,
199 on_execution_failed: Option<ExecutionFailedCb>,
200) -> eyre::Result<()> {
201 let Some(message) = response.message else {
202 tracing::warn!("received empty message from the server");
203 return Ok(());
204 };
205
206 let Some(message) = message.server_message_kind else {
207 tracing::warn!("received unknown or missing message kind from the server");
208 return Ok(());
209 };
210
211 match message {
212 ServerMessageKind::Properties(executor_properties) => {
213 executor_span.record("executor_id", &executor_properties.executor_id);
214 state.executor_id = Some(executor_properties.executor_id);
215
216 if let Some(max_hb_interval) = executor_properties.max_heartbeat_interval {
217 if let Ok(max_hb_interval) = std::time::Duration::try_from(max_hb_interval) {
218 state.heartbeat_interval = max_hb_interval / 2;
219 tracing::debug!(
220 heartbeat_interval = ?state.heartbeat_interval,
221 "using heartbeat interval"
222 );
223 }
224 }
225
226 tracing::info!("received executor properties");
227 }
228 ServerMessageKind::ExecutionReady(execution_ready) => {
229 spawn_execution(state, execution_ready, on_execution_failed.clone()).await?;
230 }
231 ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
232 let execution_id: Uuid = execution_cancelled
233 .execution_id
234 .parse()
235 .wrap_err("expected execution ID to be UUID")?;
236
237 let execution_state = state
238 .in_progress_executions
239 .lock()
240 .swap_remove(&execution_id);
241
242 if let Some(execution_state) = execution_state {
243 tokio::spawn(
244 cancel_execution(execution_state, state.options.cancellation_grace_period)
245 .instrument(tracing::Span::current()),
246 );
247 } else {
248 tracing::warn!("received cancellation for unknown execution");
249 }
250 }
251 }
252
253 Ok(())
254}
255
256#[tracing::instrument(skip_all, fields(
257 execution_id = %execution_state.execution_id,
258))]
259async fn cancel_execution(mut execution_state: ExecutionState, grace_period: std::time::Duration) {
260 tracing::debug!("cancelling execution");
261 execution_state.cancellation_token.cancel();
262
263 tokio::select! {
264 _ = &mut execution_state.handle => {
265 tracing::debug!("execution cancelled");
266 }
267 _ = tokio::time::sleep(grace_period) => {
268 if !execution_state.handle.is_finished() {
269 tracing::warn!("execution did not cancel in time, aborting forcefully");
270 execution_state.handle.abort();
271 }
272 }
273 }
274
275 tracing::debug!("cancelled execution");
276}
277
278#[tracing::instrument(skip_all,
279 fields(
280 execution_id = %execution_ready.execution_id,
281 job_id = %execution_ready.job_id,
282 )
283)]
284async fn spawn_execution(
285 state: &ExecutorState<'_>,
286 execution_ready: ora_proto::server::v1::ExecutionReady,
287 on_execution_failed: Option<ExecutionFailedCb>,
288) -> eyre::Result<()> {
289 let execution_span = tracing::Span::current();
290
291 let executor_requests = state.executor_requests.clone();
292
293 tracing::debug!("received new execution");
294
295 let execution_id: Uuid = execution_ready
296 .execution_id
297 .parse()
298 .wrap_err("expected execution ID to be UUID")?;
299
300 let job_id: Uuid = execution_ready
301 .job_id
302 .parse()
303 .wrap_err("expected job ID to be UUID")?;
304
305 let cancellation_token = CancellationToken::new();
306
307 let ctx = ExecutionContext {
308 execution_id,
309 job_id,
310 target_execution_time: execution_ready
311 .target_execution_time
312 .and_then(|t| t.try_into().ok())
313 .unwrap_or(UNIX_EPOCH),
314 attempt_number: execution_ready.attempt_number,
315 job_type_id: execution_ready.job_type_id,
316 cancellation_token: cancellation_token.clone(),
317 };
318
319 let handler = state
320 .handlers
321 .iter()
322 .find(|h| h.can_execute(&ctx))
323 .ok_or_eyre("no handler found for the execution")?
324 .clone();
325
326 tracing::trace!("found handler for the execution");
327
328 let now = std::time::SystemTime::now();
329
330 if executor_requests
331 .send_async(ExecutorConnectionRequest {
332 message: Some(ExecutorMessage {
333 executor_message_kind: Some(ExecutorMessageKind::ExecutionStarted(
334 ora_proto::server::v1::ExecutionStarted {
335 timestamp: Some(now.into()),
336 execution_id: execution_ready.execution_id,
337 },
338 )),
339 }),
340 })
341 .await
342 .is_err()
343 {
344 tracing::debug!("not sending execution started message, executor is shutting down");
345 return Ok(());
346 }
347 tracing::trace!("sent execution started message");
348
349 let execution_guard = state.wg.add_with(&format!("execution-{execution_id}"));
350
351 let cancellation_grace_period = state.options.cancellation_grace_period;
352
353 let handle = tokio::spawn({
354 let in_progress_executions = state.in_progress_executions.clone();
355 let in_progress_executions2 = state.in_progress_executions.clone();
356 tracing::debug!("executing handler");
357
358 let execution_id = ctx.execution_id;
359
360 async move {
361 let mut warn_bomb = ExecutionDropWarnBomb::new(tracing::Span::current());
362
363 let handler_fut = async move {
364 match AssertUnwindSafe(handler.execute(ctx.clone(), &execution_ready.input_payload_json))
365 .catch_unwind()
366 .await
367 {
368 Ok(task_result) => match task_result {
369 Ok(output_json) => {
370 tracing::debug!("execution succeeded");
371 let now = std::time::SystemTime::now();
372
373 if let Err(error) = executor_requests
374 .send_async(ExecutorConnectionRequest {
375 message: Some(ExecutorMessage {
376 executor_message_kind: Some(
377 ExecutorMessageKind::ExecutionSucceeded(
378 ora_proto::server::v1::ExecutionSucceeded {
379 timestamp: Some(now.into()),
380 execution_id: execution_id.to_string(),
381 output_payload_json: output_json,
382 },
383 ),
384 ),
385 }),
386 })
387 .await
388 {
389 tracing::warn!(?error, "failed to send execution result");
390 }
391 }
392 Err(error) => {
393 tracing::debug!(error, "execution failed");
394
395 if let Some(on_execution_failed) = &on_execution_failed {
396 on_execution_failed(ctx, &error);
397 }
398
399 let now = std::time::SystemTime::now();
400
401 if let Err(error) = executor_requests
402 .send_async(ExecutorConnectionRequest {
403 message: Some(ExecutorMessage {
404 executor_message_kind: Some(
405 ExecutorMessageKind::ExecutionFailed(
406 ora_proto::server::v1::ExecutionFailed {
407 timestamp: Some(now.into()),
408 execution_id: execution_id.to_string(),
409 error_message: error,
410 },
411 ),
412 ),
413 }),
414 })
415 .await
416 {
417 tracing::warn!(?error, "failed to send execution result");
418 }
419 }
420 },
421 Err(panic_out) => {
422 tracing::warn!("handler panicked");
423 let now = std::time::SystemTime::now();
424
425 let error_message = if let Some(error) = panic_out.downcast_ref::<&str>() {
426 (*error).to_string()
427 } else if let Some(error) = panic_out.downcast_ref::<String>() {
428 error.clone()
429 } else {
430 "handler panicked".to_string()
431 };
432
433 if let Some(on_execution_failed) = &on_execution_failed {
434 on_execution_failed(ctx, &error_message);
435 }
436
437 if let Err(error) = executor_requests
438 .send_async(ExecutorConnectionRequest {
439 message: Some(ExecutorMessage {
440 executor_message_kind: Some(
441 ExecutorMessageKind::ExecutionFailed(
442 ora_proto::server::v1::ExecutionFailed {
443 timestamp: Some(now.into()),
444 execution_id: execution_id.to_string(),
445 error_message,
446 },
447 ),
448 ),
449 }),
450 })
451 .await
452 {
453 tracing::warn!(?error, "failed to send execution result");
454 }
455 }
456 }
457
458 if in_progress_executions
459 .lock()
460 .swap_remove(&execution_id)
461 .is_none()
462 {
463 tracing::debug!(
464 "execution was not found in the in-progress list, it must have been cancelled"
465 );
466 }
467 };
468
469 let mut handler_fut = std::pin::pin!(handler_fut);
470
471 loop {
472 tokio::select! {
473 _ = execution_guard.waiting() => {
474 let execution_state = in_progress_executions2.lock().swap_remove(&execution_id);
475
476 if let Some(execution_state) = execution_state {
477 tokio::spawn(
478 cancel_execution(execution_state, cancellation_grace_period)
479 .instrument(tracing::Span::current()));
480 }
481
482 (&mut handler_fut).await;
483 }
484 _ = &mut handler_fut => {
485 break;
486 }
487 }
488 }
489 warn_bomb.defuse();
490 }
491 .instrument(execution_span)
492 });
493
494 state.in_progress_executions.lock().insert(
495 execution_id,
496 ExecutionState {
497 execution_id,
498 cancellation_token,
499 handle,
500 },
501 );
502
503 Ok(())
504}
505
/// State of a single `Executor::run` invocation, shared with the server-message handlers.
///
/// Field order is kept as-is: fields are dropped in declaration order.
struct ExecutorState<'s> {
    /// ID assigned by the server. `None` until an executor properties message arrives.
    executor_id: Option<String>,
    /// Options of the owning executor (name, concurrency limit, cancellation grace period).
    options: &'s ExecutorOptions,
    /// Registered execution handlers. A new execution is run by the first one whose
    /// `can_execute` accepts it.
    handlers: &'s [Arc<dyn ExecutionHandlerRaw + Send + Sync>],
    /// Sending side of the channel that feeds the outgoing request stream.
    executor_requests: flume::Sender<ExecutorConnectionRequest>,
    /// Delay between heartbeats. Starts at one second and becomes half of the server's
    /// maximum heartbeat interval once that is announced.
    heartbeat_interval: std::time::Duration,
    /// Executions that were spawned and have not yet finished or been cancelled, keyed by
    /// execution ID.
    in_progress_executions: Arc<Mutex<IndexMap<Uuid, ExecutionState>>>,
    /// Wait group holding a guard for the outgoing stream and one for each spawned execution.
    wg: WaitGroup,
}
515
/// Bookkeeping for one spawned execution, used to cancel it.
struct ExecutionState {
    /// ID of the execution.
    execution_id: Uuid,
    /// Same token that was cloned into the handler's `ExecutionContext`. Cancelling is
    /// always done through this token first.
    cancellation_token: CancellationToken,
    /// Handle of the task running the handler, used to wait for it or abort it.
    handle: tokio::task::JoinHandle<()>,
}
521
522struct ExecutionDropWarnBomb {
523 span: tracing::Span,
524 defused: bool,
525}
526
527impl ExecutionDropWarnBomb {
528 fn new(span: tracing::Span) -> Self {
529 Self {
530 span,
531 defused: false,
532 }
533 }
534
535 fn defuse(&mut self) {
536 self.defused = true;
537 }
538}
539
540impl Drop for ExecutionDropWarnBomb {
541 fn drop(&mut self) {
542 if !self.defused {
543 self.span.in_scope(|| {
544 tracing::warn!("execution in progess was dropped");
545 });
546 }
547 }
548}