1use std::{panic::AssertUnwindSafe, sync::Arc, time::UNIX_EPOCH};
2
3use eyre::{Context, OptionExt};
4use futures::FutureExt;
5use ora_proto::{
6 common::v1::JobType,
7 server::v1::{
8 executor_message::ExecutorMessageKind, server_message::ServerMessageKind,
9 ExecutorCapabilities, ExecutorConnectionRequest, ExecutorConnectionResponse,
10 ExecutorHeartbeat, ExecutorMessage,
11 },
12};
13use parking_lot::Mutex;
14use tokio_util::sync::CancellationToken;
15use tracing::Instrument;
16use uuid::Uuid;
17use wgroup::WaitGroup;
18
19#[allow(clippy::wildcard_imports)]
20use tonic::codegen::*;
21
22use crate::{executor::ExecutionContext, IndexMap};
23
24use super::{ExecutionHandlerRaw, Executor, ExecutorOptions};
25
26impl<C> Executor<C>
27where
28 C: tonic::client::GrpcService<tonic::body::BoxBody> + Clone,
29 C::Error: Into<StdError>,
30 C::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
31 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
32{
33 #[tracing::instrument(skip_all, name = "executor_loop", fields(executor_id, executor_name))]
41 pub async fn run(&mut self) -> eyre::Result<()> {
42 let executor_span = tracing::Span::current();
43
44 executor_span.record("executor_name", &self.options.name);
45
46 let (executor_requests, recv) = flume::bounded(0);
47
48 let mut state = ExecutorState {
49 executor_id: None,
50 options: &self.options,
51 handlers: &self.handlers,
52 executor_requests,
53 heartbeat_interval: std::time::Duration::from_secs(1),
55 in_progress_executions: Arc::new(Mutex::new(IndexMap::default())),
56 wg: WaitGroup::new(),
57 };
58
59 let send_chan_guard = state.wg.add_with("send-channel");
60
61 let mut server_messages = self
62 .client
63 .executor_connection(tonic::Request::new(async_stream::stream!({
67 loop {
68 tokio::select! {
69 _ = send_chan_guard.waiting() => {
70 tracing::debug!("send channel closed, stopping stream");
71 return;
72 }
73 msg = recv.recv_async() => {
74 if let Ok(msg) = msg {
75 tracing::trace!(?msg, "sending message");
76 yield msg;
77 } else {
78 tracing::debug!("send channel closed, stopping stream");
79 return;
80 }
81 }
82 }
83 }
84 })))
85 .await?
86 .into_inner();
87
88 state
90 .executor_requests
91 .send_async(ExecutorConnectionRequest {
92 message: Some(ExecutorMessage {
93 executor_message_kind: Some(ExecutorMessageKind::Capabilities(
94 ExecutorCapabilities {
95 max_concurrent_executions: state
96 .options
97 .max_concurrent_executions
98 .get(),
99 name: state.options.name.clone(),
100 supported_job_types: self
101 .handlers
102 .iter()
103 .map(|h| {
104 let handler_meta = h.job_type_metadata();
105
106 JobType {
107 id: handler_meta.id.to_string(),
108 name: handler_meta.name.clone(),
109 description: handler_meta.description.clone(),
110 input_schema_json: handler_meta.input_schema_json.clone(),
111 output_schema_json: handler_meta.output_schema_json.clone(),
112 }
113 })
114 .collect(),
115 },
116 )),
117 }),
118 })
119 .await?;
120
121 loop {
122 tokio::select! {
123 _ = tokio::time::sleep(state.heartbeat_interval) => {
124 if state.executor_requests.send(ExecutorConnectionRequest {
125 message: Some(ExecutorMessage {
126 executor_message_kind: Some(ExecutorMessageKind::Heartbeat(
127 ExecutorHeartbeat {},
128 )),
129 }),
130 }).is_err() {
131 return Ok(());
132 }
133 }
134 server_msg = server_messages.message() => {
135 match server_msg {
136 Ok(Some(server_msg)) => {
137 handle_server_response(&mut state, &executor_span, server_msg).await?;
138 }
139 Ok(None) => {
140 tracing::info!("incoming stream closed by the server");
141
142 if !state.in_progress_executions.lock().is_empty() {
143 tracing::warn!("cancelling executions in progress");
144
145 loop {
146 let execution_state = {
147 let mut in_progress_executions = state.in_progress_executions.lock();
148
149 if in_progress_executions.is_empty() {
150 break;
151 }
152
153 let execution_id = in_progress_executions.keys().copied().next();
154
155 if let Some(execution_id) = execution_id {
156 in_progress_executions.swap_remove(&execution_id)
157 } else {
158 None
159 }
160 };
161
162 if let Some(mut execution_state) = execution_state {
163 execution_state.cancellation_token.cancel();
164
165 tokio::select! {
166 _ = &mut execution_state.handle => {}
167 _ = tokio::time::sleep(state.options.cancellation_grace_period) => {
168 execution_state.handle.abort();
169 }
170 }
171 } else {
172 break;
173 }
174 }
175 }
176
177 return Ok(());
178 }
179 Err(error) => {
180 tracing::warn!(?error, "received error from the server");
181 }
182 }
183 }
184 }
185 }
186 }
187}
188
189#[tracing::instrument(name = "handle_server_message", skip_all)]
190async fn handle_server_response(
191 state: &mut ExecutorState<'_>,
192 executor_span: &tracing::Span,
193 response: ExecutorConnectionResponse,
194) -> eyre::Result<()> {
195 let Some(message) = response.message else {
196 tracing::warn!("received empty message from the server");
197 return Ok(());
198 };
199
200 let Some(message) = message.server_message_kind else {
201 tracing::warn!("received unknown or missing message kind from the server");
202 return Ok(());
203 };
204
205 match message {
206 ServerMessageKind::Properties(executor_properties) => {
207 executor_span.record("executor_id", &executor_properties.executor_id);
208 state.executor_id = Some(executor_properties.executor_id);
209
210 if let Some(max_hb_interval) = executor_properties.max_heartbeat_interval {
211 if let Ok(max_hb_interval) = std::time::Duration::try_from(max_hb_interval) {
212 state.heartbeat_interval = max_hb_interval / 2;
213 tracing::debug!(
214 heartbeat_interval = ?state.heartbeat_interval,
215 "using heartbeat interval"
216 );
217 }
218 }
219
220 tracing::info!("received executor properties");
221 }
222 ServerMessageKind::ExecutionReady(execution_ready) => {
223 spawn_execution(state, execution_ready).await?;
224 }
225 ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
226 let execution_id: Uuid = execution_cancelled
227 .execution_id
228 .parse()
229 .wrap_err("expected execution ID to be UUID")?;
230
231 let execution_state = state
232 .in_progress_executions
233 .lock()
234 .swap_remove(&execution_id);
235
236 if let Some(execution_state) = execution_state {
237 tokio::spawn(
238 cancel_execution(execution_state, state.options.cancellation_grace_period)
239 .instrument(tracing::Span::current()),
240 );
241 } else {
242 tracing::warn!("received cancellation for unknown execution");
243 }
244 }
245 }
246
247 Ok(())
248}
249
250#[tracing::instrument(skip_all, fields(
251 execution_id = %execution_state.execution_id,
252))]
253async fn cancel_execution(mut execution_state: ExecutionState, grace_period: std::time::Duration) {
254 tracing::debug!("cancelling execution");
255 execution_state.cancellation_token.cancel();
256
257 tokio::select! {
258 _ = &mut execution_state.handle => {
259 tracing::debug!("execution cancelled");
260 }
261 _ = tokio::time::sleep(grace_period) => {
262 if !execution_state.handle.is_finished() {
263 tracing::warn!("execution did not cancel in time, aborting forcefully");
264 execution_state.handle.abort();
265 }
266 }
267 }
268
269 tracing::debug!("cancelled execution");
270}
271
272#[tracing::instrument(skip_all,
273 fields(
274 execution_id = %execution_ready.execution_id,
275 job_id = %execution_ready.job_id,
276 )
277)]
278async fn spawn_execution(
279 state: &ExecutorState<'_>,
280 execution_ready: ora_proto::server::v1::ExecutionReady,
281) -> eyre::Result<()> {
282 let execution_span = tracing::Span::current();
283
284 let executor_requests = state.executor_requests.clone();
285
286 tracing::debug!("received new execution");
287
288 let execution_id: Uuid = execution_ready
289 .execution_id
290 .parse()
291 .wrap_err("expected execution ID to be UUID")?;
292
293 let job_id: Uuid = execution_ready
294 .job_id
295 .parse()
296 .wrap_err("expected job ID to be UUID")?;
297
298 let cancellation_token = CancellationToken::new();
299
300 let ctx = ExecutionContext {
301 execution_id,
302 job_id,
303 target_execution_time: execution_ready
304 .target_execution_time
305 .and_then(|t| t.try_into().ok())
306 .unwrap_or(UNIX_EPOCH),
307 attempt_number: execution_ready.attempt_number,
308 job_type_id: execution_ready.job_type_id,
309 cancellation_token: cancellation_token.clone(),
310 };
311
312 let handler = state
313 .handlers
314 .iter()
315 .find(|h| h.can_execute(&ctx))
316 .ok_or_eyre("no handler found for the execution")?
317 .clone();
318
319 tracing::trace!("found handler for the execution");
320
321 let now = std::time::SystemTime::now();
322
323 if executor_requests
324 .send_async(ExecutorConnectionRequest {
325 message: Some(ExecutorMessage {
326 executor_message_kind: Some(ExecutorMessageKind::ExecutionStarted(
327 ora_proto::server::v1::ExecutionStarted {
328 timestamp: Some(now.into()),
329 execution_id: execution_ready.execution_id,
330 },
331 )),
332 }),
333 })
334 .await
335 .is_err()
336 {
337 tracing::debug!("not sending execution started message, executor is shutting down");
338 return Ok(());
339 }
340 tracing::trace!("sent execution started message");
341
342 let execution_guard = state.wg.add_with(&format!("execution-{execution_id}"));
343
344 let cancellation_grace_period = state.options.cancellation_grace_period;
345
346 let handle = tokio::spawn({
347 let in_progress_executions = state.in_progress_executions.clone();
348 let in_progress_executions2 = state.in_progress_executions.clone();
349 tracing::debug!("executing handler");
350
351 let execution_id = ctx.execution_id;
352
353 async move {
354 let mut warn_bomb = ExecutionDropWarnBomb::new(tracing::Span::current());
355
356 let handler_fut = async move {
357 match AssertUnwindSafe(handler.execute(ctx, &execution_ready.input_payload_json))
358 .catch_unwind()
359 .await
360 {
361 Ok(task_result) => match task_result {
362 Ok(output_json) => {
363 tracing::debug!("execution succeeded");
364 let now = std::time::SystemTime::now();
365
366 if let Err(error) = executor_requests
367 .send_async(ExecutorConnectionRequest {
368 message: Some(ExecutorMessage {
369 executor_message_kind: Some(
370 ExecutorMessageKind::ExecutionSucceeded(
371 ora_proto::server::v1::ExecutionSucceeded {
372 timestamp: Some(now.into()),
373 execution_id: execution_id.to_string(),
374 output_payload_json: output_json,
375 },
376 ),
377 ),
378 }),
379 })
380 .await
381 {
382 tracing::warn!(?error, "failed to send execution result");
383 }
384 }
385 Err(error) => {
386 tracing::debug!(error, "execution failed");
387 let now = std::time::SystemTime::now();
388
389 if let Err(error) = executor_requests
390 .send_async(ExecutorConnectionRequest {
391 message: Some(ExecutorMessage {
392 executor_message_kind: Some(
393 ExecutorMessageKind::ExecutionFailed(
394 ora_proto::server::v1::ExecutionFailed {
395 timestamp: Some(now.into()),
396 execution_id: execution_id.to_string(),
397 error_message: error,
398 },
399 ),
400 ),
401 }),
402 })
403 .await
404 {
405 tracing::warn!(?error, "failed to send execution result");
406 }
407 }
408 },
409 Err(panic_out) => {
410 tracing::warn!("handler panicked");
411 let now = std::time::SystemTime::now();
412
413 let error_message = if let Some(error) = panic_out.downcast_ref::<&str>() {
414 (*error).to_string()
415 } else if let Some(error) = panic_out.downcast_ref::<String>() {
416 error.clone()
417 } else {
418 "handler panicked".to_string()
419 };
420
421 if let Err(error) = executor_requests
422 .send_async(ExecutorConnectionRequest {
423 message: Some(ExecutorMessage {
424 executor_message_kind: Some(
425 ExecutorMessageKind::ExecutionFailed(
426 ora_proto::server::v1::ExecutionFailed {
427 timestamp: Some(now.into()),
428 execution_id: execution_id.to_string(),
429 error_message,
430 },
431 ),
432 ),
433 }),
434 })
435 .await
436 {
437 tracing::warn!(?error, "failed to send execution result");
438 }
439 }
440 }
441
442 if in_progress_executions
443 .lock()
444 .swap_remove(&execution_id)
445 .is_none()
446 {
447 tracing::debug!(
448 "execution was not found in the in-progress list, it must have been cancelled"
449 );
450 }
451 };
452
453 let mut handler_fut = std::pin::pin!(handler_fut);
454
455 loop {
456 tokio::select! {
457 _ = execution_guard.waiting() => {
458 let execution_state = in_progress_executions2.lock().swap_remove(&execution_id);
459
460 if let Some(execution_state) = execution_state {
461 tokio::spawn(
462 cancel_execution(execution_state, cancellation_grace_period)
463 .instrument(tracing::Span::current()));
464 }
465
466 (&mut handler_fut).await;
467 }
468 _ = &mut handler_fut => {
469 break;
470 }
471 }
472 }
473 warn_bomb.defuse();
474 }
475 .instrument(execution_span)
476 });
477
478 state.in_progress_executions.lock().insert(
479 execution_id,
480 ExecutionState {
481 execution_id,
482 cancellation_token,
483 handle,
484 },
485 );
486
487 Ok(())
488}
489
/// Mutable state of a single executor connection, created by `Executor::run` and used for
/// the lifetime of that connection.
struct ExecutorState<'s> {
    /// Executor ID assigned by the server in its `Properties` message; `None` until received.
    executor_id: Option<String>,
    /// Options of the executor this state belongs to.
    options: &'s ExecutorOptions,
    /// Registered handlers; the first one whose `can_execute` accepts an execution runs it.
    handlers: &'s [Arc<dyn ExecutionHandlerRaw + Send + Sync>],
    /// Sender for messages to the server; its receiving end feeds the outgoing request
    /// stream of the connection.
    executor_requests: flume::Sender<ExecutorConnectionRequest>,
    /// Delay between heartbeats. Starts at one second and is set to half of the server's
    /// maximum heartbeat interval once the server sends its properties.
    heartbeat_interval: std::time::Duration,
    /// Executions that are currently running, keyed by execution ID. Shared with the
    /// execution tasks, which remove their own entry once they finish.
    in_progress_executions: Arc<Mutex<IndexMap<Uuid, ExecutionState>>>,
    /// Wait group with one guard for the outgoing stream and one per running execution.
    wg: WaitGroup,
}
499
/// Bookkeeping for a running execution, used to cancel or abort it.
struct ExecutionState {
    /// ID of the execution.
    execution_id: Uuid,
    /// Token handed to the handler through its `ExecutionContext`; cancelled to request a
    /// cooperative stop.
    cancellation_token: CancellationToken,
    /// Handle of the Tokio task running the execution; aborted if the execution does not
    /// stop within the cancellation grace period.
    handle: tokio::task::JoinHandle<()>,
}
505
506struct ExecutionDropWarnBomb {
507 span: tracing::Span,
508 defused: bool,
509}
510
511impl ExecutionDropWarnBomb {
512 fn new(span: tracing::Span) -> Self {
513 Self {
514 span,
515 defused: false,
516 }
517 }
518
519 fn defuse(&mut self) {
520 self.defused = true;
521 }
522}
523
524impl Drop for ExecutionDropWarnBomb {
525 fn drop(&mut self) {
526 if !self.defused {
527 self.span.in_scope(|| {
528 tracing::warn!("execution was dropped during execution");
529 });
530 }
531 }
532}