ora_client/executor/
run.rs

1use std::{panic::AssertUnwindSafe, sync::Arc, time::UNIX_EPOCH};
2
3use eyre::{Context, OptionExt};
4use futures::FutureExt;
5use ora_proto::{
6    common::v1::JobType,
7    server::v1::{
8        executor_message::ExecutorMessageKind, server_message::ServerMessageKind,
9        ExecutorCapabilities, ExecutorConnectionRequest, ExecutorConnectionResponse,
10        ExecutorHeartbeat, ExecutorMessage,
11    },
12};
13use parking_lot::Mutex;
14use tokio_util::sync::CancellationToken;
15use tracing::Instrument;
16use uuid::Uuid;
17use wgroup::WaitGroup;
18
19#[allow(clippy::wildcard_imports)]
20use tonic::codegen::*;
21
22use crate::{
23    executor::{ExecutionContext, ExecutionFailedCb},
24    IndexMap,
25};
26
27use super::{ExecutionHandlerRaw, Executor, ExecutorOptions};
28
29impl<C> Executor<C>
30where
31    C: tonic::client::GrpcService<tonic::body::BoxBody> + Clone,
32    C::Error: Into<StdError>,
33    C::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
34    <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
35{
36    /// Run the executor until an error occurs.
37    ///
38    /// The error includes any errors that would prevent
39    /// the executor from running, e.g. network errors.
40    ///
41    /// Individual task errors are not included as they are
42    /// part of the job execution.
43    #[tracing::instrument(skip_all, name = "executor_loop", fields(executor_id, executor_name))]
44    pub async fn run(&mut self) -> eyre::Result<()> {
45        let executor_span = tracing::Span::current();
46
47        executor_span.record("executor_name", &self.options.name);
48
49        let (executor_requests, recv) = flume::bounded(0);
50
51        let on_execution_failed = self.on_execution_failed.clone();
52
53        let mut state = ExecutorState {
54            executor_id: None,
55            options: &self.options,
56            handlers: &self.handlers,
57            executor_requests,
58            // Start with a pessimistic heartbeat interval.
59            heartbeat_interval: std::time::Duration::from_secs(1),
60            in_progress_executions: Arc::new(Mutex::new(IndexMap::default())),
61            wg: WaitGroup::new(),
62        };
63
64        let send_chan_guard = state.wg.add_with("send-channel");
65
66        let mut server_messages = self
67            .client
68            // `Receiver::into_stream` exists, however it will cause
69            // this future to be `!Send` and the compiler goes absolutely
70            // bonkers about it.
71            .executor_connection(tonic::Request::new(async_stream::stream!({
72                loop {
73                    tokio::select! {
74                        _ = send_chan_guard.waiting() => {
75                            tracing::debug!("send channel closed, stopping stream");
76                            return;
77                        }
78                        msg = recv.recv_async() => {
79                            if let Ok(msg) = msg {
80                                tracing::trace!(?msg, "sending message");
81                                yield msg;
82                            } else {
83                                tracing::debug!("send channel closed, stopping stream");
84                                return;
85                            }
86                        }
87                    }
88                }
89            })))
90            .await?
91            .into_inner();
92
93        // Initial setup messages.
94        state
95            .executor_requests
96            .send_async(ExecutorConnectionRequest {
97                message: Some(ExecutorMessage {
98                    executor_message_kind: Some(ExecutorMessageKind::Capabilities(
99                        ExecutorCapabilities {
100                            max_concurrent_executions: state
101                                .options
102                                .max_concurrent_executions
103                                .get(),
104                            name: state.options.name.clone(),
105                            supported_job_types: self
106                                .handlers
107                                .iter()
108                                .map(|h| {
109                                    let handler_meta = h.job_type_metadata();
110
111                                    JobType {
112                                        id: handler_meta.id.to_string(),
113                                        name: handler_meta.name.clone(),
114                                        description: handler_meta.description.clone(),
115                                        input_schema_json: handler_meta.input_schema_json.clone(),
116                                        output_schema_json: handler_meta.output_schema_json.clone(),
117                                    }
118                                })
119                                .collect(),
120                        },
121                    )),
122                }),
123            })
124            .await?;
125
126        loop {
127            tokio::select! {
128                _ = tokio::time::sleep(state.heartbeat_interval) => {
129                    if state.executor_requests.send(ExecutorConnectionRequest {
130                        message: Some(ExecutorMessage {
131                            executor_message_kind: Some(ExecutorMessageKind::Heartbeat(
132                                ExecutorHeartbeat {},
133                            )),
134                        }),
135                    }).is_err() {
136                        return Ok(());
137                    }
138                }
139                server_msg = server_messages.message() => {
140                    match server_msg {
141                        Ok(Some(server_msg)) => {
142                            handle_server_response(&mut state, &executor_span, server_msg, on_execution_failed.clone()).await?;
143                        }
144                        Ok(None) => {
145                            tracing::info!("incoming stream closed by the server");
146
147                            if !state.in_progress_executions.lock().is_empty() {
148                                tracing::warn!("cancelling executions in progress");
149
150                                loop {
151                                    let execution_state = {
152                                        let mut in_progress_executions = state.in_progress_executions.lock();
153
154                                        if in_progress_executions.is_empty() {
155                                            break;
156                                        }
157
158                                        let execution_id = in_progress_executions.keys().copied().next();
159
160                                        if let Some(execution_id) = execution_id {
161                                            in_progress_executions.swap_remove(&execution_id)
162                                        } else {
163                                            None
164                                        }
165                                    };
166
167                                    if let Some(mut execution_state) = execution_state {
168                                        execution_state.cancellation_token.cancel();
169
170                                        tokio::select! {
171                                            _ = &mut execution_state.handle => {}
172                                            _ = tokio::time::sleep(state.options.cancellation_grace_period) => {
173                                                execution_state.handle.abort();
174                                            }
175                                        }
176                                    } else {
177                                        break;
178                                    }
179                                }
180                            }
181
182                            return Ok(());
183                        }
184                        Err(error) => {
185                            tracing::warn!(?error, "received error from the server");
186                        }
187                    }
188                }
189            }
190        }
191    }
192}
193
194#[tracing::instrument(name = "handle_server_message", skip_all)]
195async fn handle_server_response(
196    state: &mut ExecutorState<'_>,
197    executor_span: &tracing::Span,
198    response: ExecutorConnectionResponse,
199    on_execution_failed: Option<ExecutionFailedCb>,
200) -> eyre::Result<()> {
201    let Some(message) = response.message else {
202        tracing::warn!("received empty message from the server");
203        return Ok(());
204    };
205
206    let Some(message) = message.server_message_kind else {
207        tracing::warn!("received unknown or missing message kind from the server");
208        return Ok(());
209    };
210
211    match message {
212        ServerMessageKind::Properties(executor_properties) => {
213            executor_span.record("executor_id", &executor_properties.executor_id);
214            state.executor_id = Some(executor_properties.executor_id);
215
216            if let Some(max_hb_interval) = executor_properties.max_heartbeat_interval {
217                if let Ok(max_hb_interval) = std::time::Duration::try_from(max_hb_interval) {
218                    state.heartbeat_interval = max_hb_interval / 2;
219                    tracing::debug!(
220                        heartbeat_interval = ?state.heartbeat_interval,
221                        "using heartbeat interval"
222                    );
223                }
224            }
225
226            tracing::info!("received executor properties");
227        }
228        ServerMessageKind::ExecutionReady(execution_ready) => {
229            spawn_execution(state, execution_ready, on_execution_failed.clone()).await?;
230        }
231        ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
232            let execution_id: Uuid = execution_cancelled
233                .execution_id
234                .parse()
235                .wrap_err("expected execution ID to be UUID")?;
236
237            let execution_state = state
238                .in_progress_executions
239                .lock()
240                .swap_remove(&execution_id);
241
242            if let Some(execution_state) = execution_state {
243                tokio::spawn(
244                    cancel_execution(execution_state, state.options.cancellation_grace_period)
245                        .instrument(tracing::Span::current()),
246                );
247            } else {
248                tracing::warn!("received cancellation for unknown execution");
249            }
250        }
251    }
252
253    Ok(())
254}
255
256#[tracing::instrument(skip_all, fields(
257    execution_id = %execution_state.execution_id,
258))]
259async fn cancel_execution(mut execution_state: ExecutionState, grace_period: std::time::Duration) {
260    tracing::debug!("cancelling execution");
261    execution_state.cancellation_token.cancel();
262
263    tokio::select! {
264        _ = &mut execution_state.handle => {
265            tracing::debug!("execution cancelled");
266        }
267        _ = tokio::time::sleep(grace_period) => {
268            if !execution_state.handle.is_finished() {
269                tracing::warn!("execution did not cancel in time, aborting forcefully");
270                execution_state.handle.abort();
271            }
272        }
273    }
274
275    tracing::debug!("cancelled execution");
276}
277
278#[tracing::instrument(skip_all,
279    fields(
280        execution_id = %execution_ready.execution_id,
281        job_id = %execution_ready.job_id,
282    )
283)]
284async fn spawn_execution(
285    state: &ExecutorState<'_>,
286    execution_ready: ora_proto::server::v1::ExecutionReady,
287    on_execution_failed: Option<ExecutionFailedCb>,
288) -> eyre::Result<()> {
289    let execution_span = tracing::Span::current();
290
291    let executor_requests = state.executor_requests.clone();
292
293    tracing::debug!("received new execution");
294
295    let execution_id: Uuid = execution_ready
296        .execution_id
297        .parse()
298        .wrap_err("expected execution ID to be UUID")?;
299
300    let job_id: Uuid = execution_ready
301        .job_id
302        .parse()
303        .wrap_err("expected job ID to be UUID")?;
304
305    let cancellation_token = CancellationToken::new();
306
307    let ctx = ExecutionContext {
308        execution_id,
309        job_id,
310        target_execution_time: execution_ready
311            .target_execution_time
312            .and_then(|t| t.try_into().ok())
313            .unwrap_or(UNIX_EPOCH),
314        attempt_number: execution_ready.attempt_number,
315        job_type_id: execution_ready.job_type_id,
316        cancellation_token: cancellation_token.clone(),
317    };
318
319    let handler = state
320        .handlers
321        .iter()
322        .find(|h| h.can_execute(&ctx))
323        .ok_or_eyre("no handler found for the execution")?
324        .clone();
325
326    tracing::trace!("found handler for the execution");
327
328    let now = std::time::SystemTime::now();
329
330    if executor_requests
331        .send_async(ExecutorConnectionRequest {
332            message: Some(ExecutorMessage {
333                executor_message_kind: Some(ExecutorMessageKind::ExecutionStarted(
334                    ora_proto::server::v1::ExecutionStarted {
335                        timestamp: Some(now.into()),
336                        execution_id: execution_ready.execution_id,
337                    },
338                )),
339            }),
340        })
341        .await
342        .is_err()
343    {
344        tracing::debug!("not sending execution started message, executor is shutting down");
345        return Ok(());
346    }
347    tracing::trace!("sent execution started message");
348
349    let execution_guard = state.wg.add_with(&format!("execution-{execution_id}"));
350
351    let cancellation_grace_period = state.options.cancellation_grace_period;
352
353    let handle = tokio::spawn({
354        let in_progress_executions = state.in_progress_executions.clone();
355        let in_progress_executions2 = state.in_progress_executions.clone();
356        tracing::debug!("executing handler");
357
358        let execution_id = ctx.execution_id;
359
360        async move {
361            let mut warn_bomb = ExecutionDropWarnBomb::new(tracing::Span::current());
362
363            let handler_fut = async move {
364                match AssertUnwindSafe(handler.execute(ctx.clone(), &execution_ready.input_payload_json))
365                    .catch_unwind()
366                    .await
367                {
368                    Ok(task_result) => match task_result {
369                        Ok(output_json) => {
370                            tracing::debug!("execution succeeded");
371                            let now = std::time::SystemTime::now();
372
373                            if let Err(error) = executor_requests
374                                .send_async(ExecutorConnectionRequest {
375                                    message: Some(ExecutorMessage {
376                                        executor_message_kind: Some(
377                                            ExecutorMessageKind::ExecutionSucceeded(
378                                                ora_proto::server::v1::ExecutionSucceeded {
379                                                    timestamp: Some(now.into()),
380                                                    execution_id: execution_id.to_string(),
381                                                    output_payload_json: output_json,
382                                                },
383                                            ),
384                                        ),
385                                    }),
386                                })
387                                .await
388                            {
389                                tracing::warn!(?error, "failed to send execution result");
390                            }
391                        }
392                        Err(error) => {
393                            tracing::debug!(error, "execution failed");
394
395                            if let Some(on_execution_failed) = &on_execution_failed {
396                                on_execution_failed(ctx, &error);
397                            }
398
399                            let now = std::time::SystemTime::now();
400
401                            if let Err(error) = executor_requests
402                                .send_async(ExecutorConnectionRequest {
403                                    message: Some(ExecutorMessage {
404                                        executor_message_kind: Some(
405                                            ExecutorMessageKind::ExecutionFailed(
406                                                ora_proto::server::v1::ExecutionFailed {
407                                                    timestamp: Some(now.into()),
408                                                    execution_id: execution_id.to_string(),
409                                                    error_message: error,
410                                                },
411                                            ),
412                                        ),
413                                    }),
414                                })
415                                .await
416                            {
417                                tracing::warn!(?error, "failed to send execution result");
418                            }
419                        }
420                    },
421                    Err(panic_out) => {
422                        tracing::warn!("handler panicked");
423                        let now = std::time::SystemTime::now();
424
425                        let error_message = if let Some(error) = panic_out.downcast_ref::<&str>() {
426                            (*error).to_string()
427                        } else if let Some(error) = panic_out.downcast_ref::<String>() {
428                            error.clone()
429                        } else {
430                            "handler panicked".to_string()
431                        };
432
433                        if let Some(on_execution_failed) = &on_execution_failed {
434                            on_execution_failed(ctx, &error_message);
435                        }
436
437                        if let Err(error) = executor_requests
438                            .send_async(ExecutorConnectionRequest {
439                                message: Some(ExecutorMessage {
440                                    executor_message_kind: Some(
441                                        ExecutorMessageKind::ExecutionFailed(
442                                            ora_proto::server::v1::ExecutionFailed {
443                                                timestamp: Some(now.into()),
444                                                execution_id: execution_id.to_string(),
445                                                error_message,
446                                            },
447                                        ),
448                                    ),
449                                }),
450                            })
451                            .await
452                        {
453                            tracing::warn!(?error, "failed to send execution result");
454                        }
455                    }
456                }
457
458                if in_progress_executions
459                    .lock()
460                    .swap_remove(&execution_id)
461                    .is_none()
462                {
463                    tracing::debug!(
464                        "execution was not found in the in-progress list, it must have been cancelled"
465                    );
466                }
467            };
468
469            let mut handler_fut = std::pin::pin!(handler_fut);
470
471            loop {
472                tokio::select! {
473                    _ = execution_guard.waiting() => {
474                        let execution_state = in_progress_executions2.lock().swap_remove(&execution_id);
475
476                        if let Some(execution_state) = execution_state {
477                            tokio::spawn(
478                                cancel_execution(execution_state, cancellation_grace_period)
479                                .instrument(tracing::Span::current()));
480                        }
481
482                        (&mut handler_fut).await;
483                    }
484                    _ = &mut handler_fut => {
485                        break;
486                    }
487                }
488            }
489            warn_bomb.defuse();
490        }
491        .instrument(execution_span)
492    });
493
494    state.in_progress_executions.lock().insert(
495        execution_id,
496        ExecutionState {
497            execution_id,
498            cancellation_token,
499            handle,
500        },
501    );
502
503    Ok(())
504}
505
506struct ExecutorState<'s> {
507    executor_id: Option<String>,
508    options: &'s ExecutorOptions,
509    handlers: &'s [Arc<dyn ExecutionHandlerRaw + Send + Sync>],
510    executor_requests: flume::Sender<ExecutorConnectionRequest>,
511    heartbeat_interval: std::time::Duration,
512    in_progress_executions: Arc<Mutex<IndexMap<Uuid, ExecutionState>>>,
513    wg: WaitGroup,
514}
515
516struct ExecutionState {
517    execution_id: Uuid,
518    cancellation_token: CancellationToken,
519    handle: tokio::task::JoinHandle<()>,
520}
521
522struct ExecutionDropWarnBomb {
523    span: tracing::Span,
524    defused: bool,
525}
526
527impl ExecutionDropWarnBomb {
528    fn new(span: tracing::Span) -> Self {
529        Self {
530            span,
531            defused: false,
532        }
533    }
534
535    fn defuse(&mut self) {
536        self.defused = true;
537    }
538}
539
540impl Drop for ExecutionDropWarnBomb {
541    fn drop(&mut self) {
542        if !self.defused {
543            self.span.in_scope(|| {
544                tracing::warn!("execution in progess was dropped");
545            });
546        }
547    }
548}