1use std::{panic::AssertUnwindSafe, sync::Arc, time::UNIX_EPOCH};
2
3use eyre::{Context, OptionExt};
4use futures::FutureExt;
5use ora_proto::{
6 common::v1::JobType,
7 server::v1::{
8 executor_message::ExecutorMessageKind, server_message::ServerMessageKind,
9 ExecutorCapabilities, ExecutorConnectionRequest, ExecutorConnectionResponse,
10 ExecutorHeartbeat, ExecutorMessage,
11 },
12};
13use parking_lot::Mutex;
14use tokio_util::sync::CancellationToken;
15use tracing::Instrument;
16use uuid::Uuid;
17use wgroup::WaitGroup;
18
19#[allow(clippy::wildcard_imports)]
20use tonic::codegen::*;
21
22use crate::{executor::ExecutionContext, IndexMap};
23
24use super::{ExecutionHandlerRaw, Executor, ExecutorOptions};
25
26impl<C> Executor<C>
27where
28 C: tonic::client::GrpcService<tonic::body::BoxBody> + Clone,
29 C::Error: Into<StdError>,
30 C::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
31 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
32{
33 #[tracing::instrument(skip_all, name = "executor_loop", fields(executor_id, executor_name))]
41 pub async fn run(&mut self) -> eyre::Result<()> {
42 let executor_span = tracing::Span::current();
43
44 executor_span.record("executor_name", &self.options.name);
45
46 let (executor_requests, recv) = flume::bounded(0);
47
48 let mut state = ExecutorState {
49 executor_id: None,
50 options: &self.options,
51 handlers: &self.handlers,
52 executor_requests,
53 heartbeat_interval: std::time::Duration::from_secs(1),
55 in_progress_executions: Arc::new(Mutex::new(IndexMap::default())),
56 wg: WaitGroup::new(),
57 };
58
59 let send_chan_guard = state.wg.add_with("send-channel");
60
61 let mut server_messages = self
62 .client
63 .executor_connection(tonic::Request::new(async_stream::stream!({
67 loop {
68 tokio::select! {
69 _ = send_chan_guard.waiting() => {
70 tracing::debug!("send channel closed, stopping stream");
71 return;
72 }
73 msg = recv.recv_async() => {
74 if let Ok(msg) = msg {
75 yield msg;
76 } else {
77 tracing::debug!("send channel closed, stopping stream");
78 return;
79 }
80 }
81 }
82 }
83 })))
84 .await?
85 .into_inner();
86
87 state
89 .executor_requests
90 .send_async(ExecutorConnectionRequest {
91 message: Some(ExecutorMessage {
92 executor_message_kind: Some(ExecutorMessageKind::Capabilities(
93 ExecutorCapabilities {
94 max_concurrent_executions: state
95 .options
96 .max_concurrent_executions
97 .get(),
98 name: state.options.name.clone(),
99 supported_job_types: self
100 .handlers
101 .iter()
102 .map(|h| {
103 let handler_meta = h.job_type_metadata();
104
105 JobType {
106 id: handler_meta.id.to_string(),
107 name: handler_meta.name.clone(),
108 description: handler_meta.description.clone(),
109 input_schema_json: handler_meta.input_schema_json.clone(),
110 output_schema_json: handler_meta.output_schema_json.clone(),
111 }
112 })
113 .collect(),
114 },
115 )),
116 }),
117 })
118 .await?;
119
120 loop {
121 tokio::select! {
122 _ = tokio::time::sleep(state.heartbeat_interval) => {
123 if state.executor_requests.send(ExecutorConnectionRequest {
124 message: Some(ExecutorMessage {
125 executor_message_kind: Some(ExecutorMessageKind::Heartbeat(
126 ExecutorHeartbeat {},
127 )),
128 }),
129 }).is_err() {
130 return Ok(());
131 }
132 }
133 server_msg = server_messages.message() => {
134 match server_msg {
135 Ok(Some(server_msg)) => {
136 handle_server_response(&mut state, &executor_span, server_msg).await?;
137 }
138 Ok(None) => {
139 tracing::info!("incoming stream closed by the server");
140
141 if !state.in_progress_executions.lock().is_empty() {
142 tracing::warn!("cancelling executions in progress");
143
144 loop {
145 let execution_state = {
146 let mut in_progress_executions = state.in_progress_executions.lock();
147
148 if in_progress_executions.is_empty() {
149 break;
150 }
151
152 let execution_id = in_progress_executions.keys().copied().next();
153
154 if let Some(execution_id) = execution_id {
155 in_progress_executions.swap_remove(&execution_id)
156 } else {
157 None
158 }
159 };
160
161 if let Some(mut execution_state) = execution_state {
162 execution_state.cancellation_token.cancel();
163
164 tokio::select! {
165 _ = &mut execution_state.handle => {}
166 _ = tokio::time::sleep(state.options.cancellation_grace_period) => {
167 execution_state.handle.abort();
168 }
169 }
170 } else {
171 break;
172 }
173 }
174 }
175
176 return Ok(());
177 }
178 Err(error) => {
179 tracing::warn!(?error, "received error from the server");
180 }
181 }
182 }
183 }
184 }
185 }
186}
187
188#[tracing::instrument(name = "handle_server_message", skip_all)]
189async fn handle_server_response(
190 state: &mut ExecutorState<'_>,
191 executor_span: &tracing::Span,
192 response: ExecutorConnectionResponse,
193) -> eyre::Result<()> {
194 let Some(message) = response.message else {
195 tracing::warn!("received empty message from the server");
196 return Ok(());
197 };
198
199 let Some(message) = message.server_message_kind else {
200 tracing::warn!("received unknown or missing message kind from the server");
201 return Ok(());
202 };
203
204 match message {
205 ServerMessageKind::Properties(executor_properties) => {
206 executor_span.record("executor_id", &executor_properties.executor_id);
207 state.executor_id = Some(executor_properties.executor_id);
208
209 if let Some(max_hb_interval) = executor_properties.max_heartbeat_interval {
210 if let Ok(max_hb_interval) = std::time::Duration::try_from(max_hb_interval) {
211 state.heartbeat_interval = max_hb_interval / 2;
212 tracing::debug!(
213 heartbeat_interval = ?state.heartbeat_interval,
214 "using heartbeat interval"
215 );
216 }
217 }
218
219 tracing::info!("received executor properties");
220 }
221 ServerMessageKind::ExecutionReady(execution_ready) => {
222 spawn_execution(state, execution_ready).await?;
223 }
224 ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
225 let execution_id: Uuid = execution_cancelled
226 .execution_id
227 .parse()
228 .wrap_err("expected execution ID to be UUID")?;
229
230 let execution_state = state
231 .in_progress_executions
232 .lock()
233 .swap_remove(&execution_id);
234
235 if let Some(execution_state) = execution_state {
236 tokio::spawn(
237 cancel_execution(execution_state, state.options.cancellation_grace_period)
238 .instrument(tracing::Span::current()),
239 );
240 } else {
241 tracing::warn!("received cancellation for unknown execution");
242 }
243 }
244 }
245
246 Ok(())
247}
248
249#[tracing::instrument(skip_all, fields(
250 execution_id = %execution_state.execution_id,
251))]
252async fn cancel_execution(mut execution_state: ExecutionState, grace_period: std::time::Duration) {
253 tracing::debug!("cancelling execution");
254 execution_state.cancellation_token.cancel();
255
256 tokio::select! {
257 _ = &mut execution_state.handle => {
258 tracing::debug!("execution cancelled");
259 }
260 _ = tokio::time::sleep(grace_period) => {
261 if !execution_state.handle.is_finished() {
262 tracing::warn!("execution did not cancel in time, aborting forcefully");
263 execution_state.handle.abort();
264 }
265 }
266 }
267
268 tracing::debug!("cancelled execution");
269}
270
271#[tracing::instrument(skip_all,
272 fields(
273 execution_id = %execution_ready.execution_id,
274 job_id = %execution_ready.job_id,
275 )
276)]
277async fn spawn_execution(
278 state: &ExecutorState<'_>,
279 execution_ready: ora_proto::server::v1::ExecutionReady,
280) -> eyre::Result<()> {
281 let execution_span = tracing::Span::current();
282
283 let executor_requests = state.executor_requests.clone();
284
285 tracing::debug!("received new execution");
286
287 let execution_id: Uuid = execution_ready
288 .execution_id
289 .parse()
290 .wrap_err("expected execution ID to be UUID")?;
291
292 let job_id: Uuid = execution_ready
293 .job_id
294 .parse()
295 .wrap_err("expected job ID to be UUID")?;
296
297 let cancellation_token = CancellationToken::new();
298
299 let ctx = ExecutionContext {
300 execution_id,
301 job_id,
302 target_execution_time: execution_ready
303 .target_execution_time
304 .and_then(|t| t.try_into().ok())
305 .unwrap_or(UNIX_EPOCH),
306 attempt_number: execution_ready.attempt_number,
307 job_type_id: execution_ready.job_type_id,
308 cancellation_token: cancellation_token.clone(),
309 };
310
311 let handler = state
312 .handlers
313 .iter()
314 .find(|h| h.can_execute(&ctx))
315 .ok_or_eyre("no handler found for the execution")?
316 .clone();
317
318 tracing::trace!("found handler for the execution");
319
320 let now = std::time::SystemTime::now();
321
322 if executor_requests
323 .send_async(ExecutorConnectionRequest {
324 message: Some(ExecutorMessage {
325 executor_message_kind: Some(ExecutorMessageKind::ExecutionStarted(
326 ora_proto::server::v1::ExecutionStarted {
327 timestamp: Some(now.into()),
328 execution_id: execution_ready.execution_id,
329 },
330 )),
331 }),
332 })
333 .await
334 .is_err()
335 {
336 tracing::debug!("not sending execution started message, executor is shutting down");
337 return Ok(());
338 }
339 tracing::trace!("sent execution started message");
340
341 let execution_guard = state.wg.add_with(&format!("execution-{execution_id}"));
342
343 let cancellation_grace_period = state.options.cancellation_grace_period;
344
345 let handle = tokio::spawn({
346 let in_progress_executions = state.in_progress_executions.clone();
347 let in_progress_executions2 = state.in_progress_executions.clone();
348 tracing::debug!("executing handler");
349
350 let execution_id = ctx.execution_id;
351
352 async move {
353 let mut warn_bomb = ExecutionDropWarnBomb::new(tracing::Span::current());
354
355 let handler_fut = async move {
356 match AssertUnwindSafe(handler.execute(ctx, &execution_ready.input_payload_json))
357 .catch_unwind()
358 .await
359 {
360 Ok(task_result) => match task_result {
361 Ok(output_json) => {
362 tracing::debug!("execution succeeded");
363 let now = std::time::SystemTime::now();
364
365 if let Err(error) = executor_requests
366 .send_async(ExecutorConnectionRequest {
367 message: Some(ExecutorMessage {
368 executor_message_kind: Some(
369 ExecutorMessageKind::ExecutionSucceeded(
370 ora_proto::server::v1::ExecutionSucceeded {
371 timestamp: Some(now.into()),
372 execution_id: execution_id.to_string(),
373 output_payload_json: output_json,
374 },
375 ),
376 ),
377 }),
378 })
379 .await
380 {
381 tracing::warn!(?error, "failed to send execution result");
382 }
383 }
384 Err(error) => {
385 tracing::debug!(error, "execution failed");
386 let now = std::time::SystemTime::now();
387
388 if let Err(error) = executor_requests
389 .send_async(ExecutorConnectionRequest {
390 message: Some(ExecutorMessage {
391 executor_message_kind: Some(
392 ExecutorMessageKind::ExecutionFailed(
393 ora_proto::server::v1::ExecutionFailed {
394 timestamp: Some(now.into()),
395 execution_id: execution_id.to_string(),
396 error_message: error,
397 },
398 ),
399 ),
400 }),
401 })
402 .await
403 {
404 tracing::warn!(?error, "failed to send execution result");
405 }
406 }
407 },
408 Err(panic_out) => {
409 tracing::warn!("handler panicked");
410 let now = std::time::SystemTime::now();
411
412 let error_message = if let Some(error) = panic_out.downcast_ref::<&str>() {
413 (*error).to_string()
414 } else if let Some(error) = panic_out.downcast_ref::<String>() {
415 error.clone()
416 } else {
417 "handler panicked".to_string()
418 };
419
420 if let Err(error) = executor_requests
421 .send_async(ExecutorConnectionRequest {
422 message: Some(ExecutorMessage {
423 executor_message_kind: Some(
424 ExecutorMessageKind::ExecutionFailed(
425 ora_proto::server::v1::ExecutionFailed {
426 timestamp: Some(now.into()),
427 execution_id: execution_id.to_string(),
428 error_message,
429 },
430 ),
431 ),
432 }),
433 })
434 .await
435 {
436 tracing::warn!(?error, "failed to send execution result");
437 }
438 }
439 }
440
441 if in_progress_executions
442 .lock()
443 .swap_remove(&execution_id)
444 .is_none()
445 {
446 tracing::debug!(
447 "execution was not found in the in-progress list, it must have been cancelled"
448 );
449 }
450 };
451
452 let mut handler_fut = std::pin::pin!(handler_fut);
453
454 loop {
455 tokio::select! {
456 _ = execution_guard.waiting() => {
457 let execution_state = in_progress_executions2.lock().swap_remove(&execution_id);
458
459 if let Some(execution_state) = execution_state {
460 tokio::spawn(
461 cancel_execution(execution_state, cancellation_grace_period)
462 .instrument(tracing::Span::current()));
463 }
464
465 (&mut handler_fut).await;
466 }
467 _ = &mut handler_fut => {
468 break;
469 }
470 }
471 }
472 warn_bomb.defuse();
473 }
474 .instrument(execution_span)
475 });
476
477 state.in_progress_executions.lock().insert(
478 execution_id,
479 ExecutionState {
480 execution_id,
481 cancellation_token,
482 handle,
483 },
484 );
485
486 Ok(())
487}
488
/// Mutable state of a single `Executor::run` invocation, borrowed from the
/// executor being run.
struct ExecutorState<'s> {
    /// Executor ID assigned by the server; `None` until the server's
    /// executor properties message has been received.
    executor_id: Option<String>,
    /// Options of the executor being run.
    options: &'s ExecutorOptions,
    /// Registered execution handlers; an execution is run by the first
    /// handler whose `can_execute` accepts it.
    handlers: &'s [Arc<dyn ExecutionHandlerRaw + Send + Sync>],
    /// Sending side of the outgoing message channel; messages sent here are
    /// yielded by the request stream passed to `executor_connection`.
    executor_requests: flume::Sender<ExecutorConnectionRequest>,
    /// Delay between heartbeats: 1 second initially, replaced by half of the
    /// maximum heartbeat interval the server announces in its properties.
    heartbeat_interval: std::time::Duration,
    /// Spawned executions that have not finished or been cancelled yet, keyed
    /// by execution ID. Shared with the execution tasks, which remove their
    /// own entry once their handler completes.
    in_progress_executions: Arc<Mutex<IndexMap<Uuid, ExecutionState>>>,
    /// Wait group holding one guard for the outgoing request stream and one
    /// per spawned execution.
    wg: WaitGroup,
}
498
/// Bookkeeping kept for a spawned execution so that it can be cancelled.
struct ExecutionState {
    /// ID of the execution.
    execution_id: Uuid,
    /// Token also handed to the handler through its `ExecutionContext`;
    /// cancelling it signals the handler that the execution was cancelled.
    cancellation_token: CancellationToken,
    /// Handle of the task running the handler; the task is aborted if it has
    /// not finished within the cancellation grace period.
    handle: tokio::task::JoinHandle<()>,
}
504
505struct ExecutionDropWarnBomb {
506 span: tracing::Span,
507 defused: bool,
508}
509
510impl ExecutionDropWarnBomb {
511 fn new(span: tracing::Span) -> Self {
512 Self {
513 span,
514 defused: false,
515 }
516 }
517
518 fn defuse(&mut self) {
519 self.defused = true;
520 }
521}
522
523impl Drop for ExecutionDropWarnBomb {
524 fn drop(&mut self) {
525 if !self.defused {
526 self.span.in_scope(|| {
527 tracing::warn!("execution was dropped during execution");
528 });
529 }
530 }
531}