// ora_client/executor/run.rs

1use std::{panic::AssertUnwindSafe, sync::Arc, time::UNIX_EPOCH};
2
3use eyre::{Context, OptionExt};
4use futures::FutureExt;
5use ora_proto::{
6    common::v1::JobType,
7    server::v1::{
8        executor_message::ExecutorMessageKind, server_message::ServerMessageKind,
9        ExecutorCapabilities, ExecutorConnectionRequest, ExecutorConnectionResponse,
10        ExecutorHeartbeat, ExecutorMessage,
11    },
12};
13use parking_lot::Mutex;
14use tokio_util::sync::CancellationToken;
15use tracing::Instrument;
16use uuid::Uuid;
17use wgroup::WaitGroup;
18
19#[allow(clippy::wildcard_imports)]
20use tonic::codegen::*;
21
22use crate::{executor::ExecutionContext, IndexMap};
23
24use super::{ExecutionHandlerRaw, Executor, ExecutorOptions};
25
26impl<C> Executor<C>
27where
28    C: tonic::client::GrpcService<tonic::body::BoxBody> + Clone,
29    C::Error: Into<StdError>,
30    C::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
31    <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
32{
33    /// Run the executor until an error occurs.
34    ///
35    /// The error includes any errors that would prevent
36    /// the executor from running, e.g. network errors.
37    ///
38    /// Individual task errors are not included as they are
39    /// part of the job execution.
40    #[tracing::instrument(skip_all, name = "executor_loop", fields(executor_id, executor_name))]
41    pub async fn run(&mut self) -> eyre::Result<()> {
42        let executor_span = tracing::Span::current();
43
44        executor_span.record("executor_name", &self.options.name);
45
46        let (executor_requests, recv) = flume::bounded(0);
47
48        let mut state = ExecutorState {
49            executor_id: None,
50            options: &self.options,
51            handlers: &self.handlers,
52            executor_requests,
53            // Start with a pessimistic heartbeat interval.
54            heartbeat_interval: std::time::Duration::from_secs(1),
55            in_progress_executions: Arc::new(Mutex::new(IndexMap::default())),
56            wg: WaitGroup::new(),
57        };
58
59        let send_chan_guard = state.wg.add_with("send-channel");
60
61        let mut server_messages = self
62            .client
63            // `Receiver::into_stream` exists, however it will cause
64            // this future to be `!Send` and the compiler goes absolutely
65            // bonkers about it.
66            .executor_connection(tonic::Request::new(async_stream::stream!({
67                loop {
68                    tokio::select! {
69                        _ = send_chan_guard.waiting() => {
70                            tracing::debug!("send channel closed, stopping stream");
71                            return;
72                        }
73                        msg = recv.recv_async() => {
74                            if let Ok(msg) = msg {
75                                tracing::trace!(?msg, "sending message");
76                                yield msg;
77                            } else {
78                                tracing::debug!("send channel closed, stopping stream");
79                                return;
80                            }
81                        }
82                    }
83                }
84            })))
85            .await?
86            .into_inner();
87
88        // Initial setup messages.
89        state
90            .executor_requests
91            .send_async(ExecutorConnectionRequest {
92                message: Some(ExecutorMessage {
93                    executor_message_kind: Some(ExecutorMessageKind::Capabilities(
94                        ExecutorCapabilities {
95                            max_concurrent_executions: state
96                                .options
97                                .max_concurrent_executions
98                                .get(),
99                            name: state.options.name.clone(),
100                            supported_job_types: self
101                                .handlers
102                                .iter()
103                                .map(|h| {
104                                    let handler_meta = h.job_type_metadata();
105
106                                    JobType {
107                                        id: handler_meta.id.to_string(),
108                                        name: handler_meta.name.clone(),
109                                        description: handler_meta.description.clone(),
110                                        input_schema_json: handler_meta.input_schema_json.clone(),
111                                        output_schema_json: handler_meta.output_schema_json.clone(),
112                                    }
113                                })
114                                .collect(),
115                        },
116                    )),
117                }),
118            })
119            .await?;
120
121        loop {
122            tokio::select! {
123                _ = tokio::time::sleep(state.heartbeat_interval) => {
124                    if state.executor_requests.send(ExecutorConnectionRequest {
125                        message: Some(ExecutorMessage {
126                            executor_message_kind: Some(ExecutorMessageKind::Heartbeat(
127                                ExecutorHeartbeat {},
128                            )),
129                        }),
130                    }).is_err() {
131                        return Ok(());
132                    }
133                }
134                server_msg = server_messages.message() => {
135                    match server_msg {
136                        Ok(Some(server_msg)) => {
137                            handle_server_response(&mut state, &executor_span, server_msg).await?;
138                        }
139                        Ok(None) => {
140                            tracing::info!("incoming stream closed by the server");
141
142                            if !state.in_progress_executions.lock().is_empty() {
143                                tracing::warn!("cancelling executions in progress");
144
145                                loop {
146                                    let execution_state = {
147                                        let mut in_progress_executions = state.in_progress_executions.lock();
148
149                                        if in_progress_executions.is_empty() {
150                                            break;
151                                        }
152
153                                        let execution_id = in_progress_executions.keys().copied().next();
154
155                                        if let Some(execution_id) = execution_id {
156                                            in_progress_executions.swap_remove(&execution_id)
157                                        } else {
158                                            None
159                                        }
160                                    };
161
162                                    if let Some(mut execution_state) = execution_state {
163                                        execution_state.cancellation_token.cancel();
164
165                                        tokio::select! {
166                                            _ = &mut execution_state.handle => {}
167                                            _ = tokio::time::sleep(state.options.cancellation_grace_period) => {
168                                                execution_state.handle.abort();
169                                            }
170                                        }
171                                    } else {
172                                        break;
173                                    }
174                                }
175                            }
176
177                            return Ok(());
178                        }
179                        Err(error) => {
180                            tracing::warn!(?error, "received error from the server");
181                        }
182                    }
183                }
184            }
185        }
186    }
187}
188
189#[tracing::instrument(name = "handle_server_message", skip_all)]
190async fn handle_server_response(
191    state: &mut ExecutorState<'_>,
192    executor_span: &tracing::Span,
193    response: ExecutorConnectionResponse,
194) -> eyre::Result<()> {
195    let Some(message) = response.message else {
196        tracing::warn!("received empty message from the server");
197        return Ok(());
198    };
199
200    let Some(message) = message.server_message_kind else {
201        tracing::warn!("received unknown or missing message kind from the server");
202        return Ok(());
203    };
204
205    match message {
206        ServerMessageKind::Properties(executor_properties) => {
207            executor_span.record("executor_id", &executor_properties.executor_id);
208            state.executor_id = Some(executor_properties.executor_id);
209
210            if let Some(max_hb_interval) = executor_properties.max_heartbeat_interval {
211                if let Ok(max_hb_interval) = std::time::Duration::try_from(max_hb_interval) {
212                    state.heartbeat_interval = max_hb_interval / 2;
213                    tracing::debug!(
214                        heartbeat_interval = ?state.heartbeat_interval,
215                        "using heartbeat interval"
216                    );
217                }
218            }
219
220            tracing::info!("received executor properties");
221        }
222        ServerMessageKind::ExecutionReady(execution_ready) => {
223            spawn_execution(state, execution_ready).await?;
224        }
225        ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
226            let execution_id: Uuid = execution_cancelled
227                .execution_id
228                .parse()
229                .wrap_err("expected execution ID to be UUID")?;
230
231            let execution_state = state
232                .in_progress_executions
233                .lock()
234                .swap_remove(&execution_id);
235
236            if let Some(execution_state) = execution_state {
237                tokio::spawn(
238                    cancel_execution(execution_state, state.options.cancellation_grace_period)
239                        .instrument(tracing::Span::current()),
240                );
241            } else {
242                tracing::warn!("received cancellation for unknown execution");
243            }
244        }
245    }
246
247    Ok(())
248}
249
250#[tracing::instrument(skip_all, fields(
251    execution_id = %execution_state.execution_id,
252))]
253async fn cancel_execution(mut execution_state: ExecutionState, grace_period: std::time::Duration) {
254    tracing::debug!("cancelling execution");
255    execution_state.cancellation_token.cancel();
256
257    tokio::select! {
258        _ = &mut execution_state.handle => {
259            tracing::debug!("execution cancelled");
260        }
261        _ = tokio::time::sleep(grace_period) => {
262            if !execution_state.handle.is_finished() {
263                tracing::warn!("execution did not cancel in time, aborting forcefully");
264                execution_state.handle.abort();
265            }
266        }
267    }
268
269    tracing::debug!("cancelled execution");
270}
271
272#[tracing::instrument(skip_all,
273    fields(
274        execution_id = %execution_ready.execution_id,
275        job_id = %execution_ready.job_id,
276    )
277)]
278async fn spawn_execution(
279    state: &ExecutorState<'_>,
280    execution_ready: ora_proto::server::v1::ExecutionReady,
281) -> eyre::Result<()> {
282    let execution_span = tracing::Span::current();
283
284    let executor_requests = state.executor_requests.clone();
285
286    tracing::debug!("received new execution");
287
288    let execution_id: Uuid = execution_ready
289        .execution_id
290        .parse()
291        .wrap_err("expected execution ID to be UUID")?;
292
293    let job_id: Uuid = execution_ready
294        .job_id
295        .parse()
296        .wrap_err("expected job ID to be UUID")?;
297
298    let cancellation_token = CancellationToken::new();
299
300    let ctx = ExecutionContext {
301        execution_id,
302        job_id,
303        target_execution_time: execution_ready
304            .target_execution_time
305            .and_then(|t| t.try_into().ok())
306            .unwrap_or(UNIX_EPOCH),
307        attempt_number: execution_ready.attempt_number,
308        job_type_id: execution_ready.job_type_id,
309        cancellation_token: cancellation_token.clone(),
310    };
311
312    let handler = state
313        .handlers
314        .iter()
315        .find(|h| h.can_execute(&ctx))
316        .ok_or_eyre("no handler found for the execution")?
317        .clone();
318
319    tracing::trace!("found handler for the execution");
320
321    let now = std::time::SystemTime::now();
322
323    if executor_requests
324        .send_async(ExecutorConnectionRequest {
325            message: Some(ExecutorMessage {
326                executor_message_kind: Some(ExecutorMessageKind::ExecutionStarted(
327                    ora_proto::server::v1::ExecutionStarted {
328                        timestamp: Some(now.into()),
329                        execution_id: execution_ready.execution_id,
330                    },
331                )),
332            }),
333        })
334        .await
335        .is_err()
336    {
337        tracing::debug!("not sending execution started message, executor is shutting down");
338        return Ok(());
339    }
340    tracing::trace!("sent execution started message");
341
342    let execution_guard = state.wg.add_with(&format!("execution-{execution_id}"));
343
344    let cancellation_grace_period = state.options.cancellation_grace_period;
345
346    let handle = tokio::spawn({
347        let in_progress_executions = state.in_progress_executions.clone();
348        let in_progress_executions2 = state.in_progress_executions.clone();
349        tracing::debug!("executing handler");
350
351        let execution_id = ctx.execution_id;
352
353        async move {
354            let mut warn_bomb = ExecutionDropWarnBomb::new(tracing::Span::current());
355
356            let handler_fut = async move {
357                match AssertUnwindSafe(handler.execute(ctx, &execution_ready.input_payload_json))
358                    .catch_unwind()
359                    .await
360                {
361                    Ok(task_result) => match task_result {
362                        Ok(output_json) => {
363                            tracing::debug!("execution succeeded");
364                            let now = std::time::SystemTime::now();
365
366                            if let Err(error) = executor_requests
367                                .send_async(ExecutorConnectionRequest {
368                                    message: Some(ExecutorMessage {
369                                        executor_message_kind: Some(
370                                            ExecutorMessageKind::ExecutionSucceeded(
371                                                ora_proto::server::v1::ExecutionSucceeded {
372                                                    timestamp: Some(now.into()),
373                                                    execution_id: execution_id.to_string(),
374                                                    output_payload_json: output_json,
375                                                },
376                                            ),
377                                        ),
378                                    }),
379                                })
380                                .await
381                            {
382                                tracing::warn!(?error, "failed to send execution result");
383                            }
384                        }
385                        Err(error) => {
386                            tracing::debug!(error, "execution failed");
387                            let now = std::time::SystemTime::now();
388
389                            if let Err(error) = executor_requests
390                                .send_async(ExecutorConnectionRequest {
391                                    message: Some(ExecutorMessage {
392                                        executor_message_kind: Some(
393                                            ExecutorMessageKind::ExecutionFailed(
394                                                ora_proto::server::v1::ExecutionFailed {
395                                                    timestamp: Some(now.into()),
396                                                    execution_id: execution_id.to_string(),
397                                                    error_message: error,
398                                                },
399                                            ),
400                                        ),
401                                    }),
402                                })
403                                .await
404                            {
405                                tracing::warn!(?error, "failed to send execution result");
406                            }
407                        }
408                    },
409                    Err(panic_out) => {
410                        tracing::warn!("handler panicked");
411                        let now = std::time::SystemTime::now();
412
413                        let error_message = if let Some(error) = panic_out.downcast_ref::<&str>() {
414                            (*error).to_string()
415                        } else if let Some(error) = panic_out.downcast_ref::<String>() {
416                            error.clone()
417                        } else {
418                            "handler panicked".to_string()
419                        };
420
421                        if let Err(error) = executor_requests
422                            .send_async(ExecutorConnectionRequest {
423                                message: Some(ExecutorMessage {
424                                    executor_message_kind: Some(
425                                        ExecutorMessageKind::ExecutionFailed(
426                                            ora_proto::server::v1::ExecutionFailed {
427                                                timestamp: Some(now.into()),
428                                                execution_id: execution_id.to_string(),
429                                                error_message,
430                                            },
431                                        ),
432                                    ),
433                                }),
434                            })
435                            .await
436                        {
437                            tracing::warn!(?error, "failed to send execution result");
438                        }
439                    }
440                }
441
442                if in_progress_executions
443                    .lock()
444                    .swap_remove(&execution_id)
445                    .is_none()
446                {
447                    tracing::debug!(
448                        "execution was not found in the in-progress list, it must have been cancelled"
449                    );
450                }
451            };
452
453            let mut handler_fut = std::pin::pin!(handler_fut);
454
455            loop {
456                tokio::select! {
457                    _ = execution_guard.waiting() => {
458                        let execution_state = in_progress_executions2.lock().swap_remove(&execution_id);
459
460                        if let Some(execution_state) = execution_state {
461                            tokio::spawn(
462                                cancel_execution(execution_state, cancellation_grace_period)
463                                .instrument(tracing::Span::current()));
464                        }
465
466                        (&mut handler_fut).await;
467                    }
468                    _ = &mut handler_fut => {
469                        break;
470                    }
471                }
472            }
473            warn_bomb.defuse();
474        }
475        .instrument(execution_span)
476    });
477
478    state.in_progress_executions.lock().insert(
479        execution_id,
480        ExecutionState {
481            execution_id,
482            cancellation_token,
483            handle,
484        },
485    );
486
487    Ok(())
488}
489
490struct ExecutorState<'s> {
491    executor_id: Option<String>,
492    options: &'s ExecutorOptions,
493    handlers: &'s [Arc<dyn ExecutionHandlerRaw + Send + Sync>],
494    executor_requests: flume::Sender<ExecutorConnectionRequest>,
495    heartbeat_interval: std::time::Duration,
496    in_progress_executions: Arc<Mutex<IndexMap<Uuid, ExecutionState>>>,
497    wg: WaitGroup,
498}
499
500struct ExecutionState {
501    execution_id: Uuid,
502    cancellation_token: CancellationToken,
503    handle: tokio::task::JoinHandle<()>,
504}
505
506struct ExecutionDropWarnBomb {
507    span: tracing::Span,
508    defused: bool,
509}
510
511impl ExecutionDropWarnBomb {
512    fn new(span: tracing::Span) -> Self {
513        Self {
514            span,
515            defused: false,
516        }
517    }
518
519    fn defuse(&mut self) {
520        self.defused = true;
521    }
522}
523
524impl Drop for ExecutionDropWarnBomb {
525    fn drop(&mut self) {
526        if !self.defused {
527            self.span.in_scope(|| {
528                tracing::warn!("execution was dropped during execution");
529            });
530        }
531    }
532}