ora_client/executor/
run.rs

1use std::{panic::AssertUnwindSafe, sync::Arc, time::UNIX_EPOCH};
2
3use eyre::{Context, OptionExt};
4use futures::FutureExt;
5use ora_proto::{
6    common::v1::JobType,
7    server::v1::{
8        executor_message::ExecutorMessageKind, server_message::ServerMessageKind,
9        ExecutorCapabilities, ExecutorConnectionRequest, ExecutorConnectionResponse,
10        ExecutorHeartbeat, ExecutorMessage,
11    },
12};
13use parking_lot::Mutex;
14use tokio_util::sync::CancellationToken;
15use tracing::Instrument;
16use uuid::Uuid;
17use wgroup::WaitGroup;
18
19#[allow(clippy::wildcard_imports)]
20use tonic::codegen::*;
21
22use crate::{executor::ExecutionContext, IndexMap};
23
24use super::{ExecutionHandlerRaw, Executor, ExecutorOptions};
25
26impl<C> Executor<C>
27where
28    C: tonic::client::GrpcService<tonic::body::BoxBody> + Clone,
29    C::Error: Into<StdError>,
30    C::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
31    <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
32{
33    /// Run the executor until an error occurs.
34    ///
35    /// The error includes any errors that would prevent
36    /// the executor from running, e.g. network errors.
37    ///
38    /// Individual task errors are not included as they are
39    /// part of the job execution.
40    #[tracing::instrument(skip_all, name = "executor_loop", fields(executor_id, executor_name))]
41    pub async fn run(&mut self) -> eyre::Result<()> {
42        let executor_span = tracing::Span::current();
43
44        executor_span.record("executor_name", &self.options.name);
45
46        let (executor_requests, recv) = flume::bounded(0);
47
48        let mut state = ExecutorState {
49            executor_id: None,
50            options: &self.options,
51            handlers: &self.handlers,
52            executor_requests,
53            // Start with a pessimistic heartbeat interval.
54            heartbeat_interval: std::time::Duration::from_secs(1),
55            in_progress_executions: Arc::new(Mutex::new(IndexMap::default())),
56            wg: WaitGroup::new(),
57        };
58
59        let send_chan_guard = state.wg.add_with("send-channel");
60
61        let mut server_messages = self
62            .client
63            // `Receiver::into_stream` exists, however it will cause
64            // this future to be `!Send` and the compiler goes absolutely
65            // bonkers about it.
66            .executor_connection(tonic::Request::new(async_stream::stream!({
67                loop {
68                    tokio::select! {
69                        _ = send_chan_guard.waiting() => {
70                            tracing::debug!("send channel closed, stopping stream");
71                            return;
72                        }
73                        msg = recv.recv_async() => {
74                            if let Ok(msg) = msg {
75                                yield msg;
76                            } else {
77                                tracing::debug!("send channel closed, stopping stream");
78                                return;
79                            }
80                        }
81                    }
82                }
83            })))
84            .await?
85            .into_inner();
86
87        // Initial setup messages.
88        state
89            .executor_requests
90            .send_async(ExecutorConnectionRequest {
91                message: Some(ExecutorMessage {
92                    executor_message_kind: Some(ExecutorMessageKind::Capabilities(
93                        ExecutorCapabilities {
94                            max_concurrent_executions: state
95                                .options
96                                .max_concurrent_executions
97                                .get(),
98                            name: state.options.name.clone(),
99                            supported_job_types: self
100                                .handlers
101                                .iter()
102                                .map(|h| {
103                                    let handler_meta = h.job_type_metadata();
104
105                                    JobType {
106                                        id: handler_meta.id.to_string(),
107                                        name: handler_meta.name.clone(),
108                                        description: handler_meta.description.clone(),
109                                        input_schema_json: handler_meta.input_schema_json.clone(),
110                                        output_schema_json: handler_meta.output_schema_json.clone(),
111                                    }
112                                })
113                                .collect(),
114                        },
115                    )),
116                }),
117            })
118            .await?;
119
120        loop {
121            tokio::select! {
122                _ = tokio::time::sleep(state.heartbeat_interval) => {
123                    if state.executor_requests.send(ExecutorConnectionRequest {
124                        message: Some(ExecutorMessage {
125                            executor_message_kind: Some(ExecutorMessageKind::Heartbeat(
126                                ExecutorHeartbeat {},
127                            )),
128                        }),
129                    }).is_err() {
130                        return Ok(());
131                    }
132                }
133                server_msg = server_messages.message() => {
134                    match server_msg {
135                        Ok(Some(server_msg)) => {
136                            handle_server_response(&mut state, &executor_span, server_msg).await?;
137                        }
138                        Ok(None) => {
139                            tracing::info!("incoming stream closed by the server");
140
141                            if !state.in_progress_executions.lock().is_empty() {
142                                tracing::warn!("cancelling executions in progress");
143
144                                loop {
145                                    let execution_state = {
146                                        let mut in_progress_executions = state.in_progress_executions.lock();
147
148                                        if in_progress_executions.is_empty() {
149                                            break;
150                                        }
151
152                                        let execution_id = in_progress_executions.keys().copied().next();
153
154                                        if let Some(execution_id) = execution_id {
155                                            in_progress_executions.swap_remove(&execution_id)
156                                        } else {
157                                            None
158                                        }
159                                    };
160
161                                    if let Some(mut execution_state) = execution_state {
162                                        execution_state.cancellation_token.cancel();
163
164                                        tokio::select! {
165                                            _ = &mut execution_state.handle => {}
166                                            _ = tokio::time::sleep(state.options.cancellation_grace_period) => {
167                                                execution_state.handle.abort();
168                                            }
169                                        }
170                                    } else {
171                                        break;
172                                    }
173                                }
174                            }
175
176                            return Ok(());
177                        }
178                        Err(error) => {
179                            tracing::warn!(?error, "received error from the server");
180                        }
181                    }
182                }
183            }
184        }
185    }
186}
187
188#[tracing::instrument(name = "handle_server_message", skip_all)]
189async fn handle_server_response(
190    state: &mut ExecutorState<'_>,
191    executor_span: &tracing::Span,
192    response: ExecutorConnectionResponse,
193) -> eyre::Result<()> {
194    let Some(message) = response.message else {
195        tracing::warn!("received empty message from the server");
196        return Ok(());
197    };
198
199    let Some(message) = message.server_message_kind else {
200        tracing::warn!("received unknown or missing message kind from the server");
201        return Ok(());
202    };
203
204    match message {
205        ServerMessageKind::Properties(executor_properties) => {
206            executor_span.record("executor_id", &executor_properties.executor_id);
207            state.executor_id = Some(executor_properties.executor_id);
208
209            if let Some(max_hb_interval) = executor_properties.max_heartbeat_interval {
210                if let Ok(max_hb_interval) = std::time::Duration::try_from(max_hb_interval) {
211                    state.heartbeat_interval = max_hb_interval / 2;
212                    tracing::debug!(
213                        heartbeat_interval = ?state.heartbeat_interval,
214                        "using heartbeat interval"
215                    );
216                }
217            }
218
219            tracing::info!("received executor properties");
220        }
221        ServerMessageKind::ExecutionReady(execution_ready) => {
222            spawn_execution(state, execution_ready).await?;
223        }
224        ServerMessageKind::ExecutionCancelled(execution_cancelled) => {
225            let execution_id: Uuid = execution_cancelled
226                .execution_id
227                .parse()
228                .wrap_err("expected execution ID to be UUID")?;
229
230            let execution_state = state
231                .in_progress_executions
232                .lock()
233                .swap_remove(&execution_id);
234
235            if let Some(execution_state) = execution_state {
236                tokio::spawn(
237                    cancel_execution(execution_state, state.options.cancellation_grace_period)
238                        .instrument(tracing::Span::current()),
239                );
240            } else {
241                tracing::warn!("received cancellation for unknown execution");
242            }
243        }
244    }
245
246    Ok(())
247}
248
249#[tracing::instrument(skip_all, fields(
250    execution_id = %execution_state.execution_id,
251))]
252async fn cancel_execution(mut execution_state: ExecutionState, grace_period: std::time::Duration) {
253    tracing::debug!("cancelling execution");
254    execution_state.cancellation_token.cancel();
255
256    tokio::select! {
257        _ = &mut execution_state.handle => {
258            tracing::debug!("execution cancelled");
259        }
260        _ = tokio::time::sleep(grace_period) => {
261            if !execution_state.handle.is_finished() {
262                tracing::warn!("execution did not cancel in time, aborting forcefully");
263                execution_state.handle.abort();
264            }
265        }
266    }
267
268    tracing::debug!("cancelled execution");
269}
270
271#[tracing::instrument(skip_all,
272    fields(
273        execution_id = %execution_ready.execution_id,
274        job_id = %execution_ready.job_id,
275    )
276)]
277async fn spawn_execution(
278    state: &ExecutorState<'_>,
279    execution_ready: ora_proto::server::v1::ExecutionReady,
280) -> eyre::Result<()> {
281    let execution_span = tracing::Span::current();
282
283    let executor_requests = state.executor_requests.clone();
284
285    tracing::debug!("received new execution");
286
287    let execution_id: Uuid = execution_ready
288        .execution_id
289        .parse()
290        .wrap_err("expected execution ID to be UUID")?;
291
292    let job_id: Uuid = execution_ready
293        .job_id
294        .parse()
295        .wrap_err("expected job ID to be UUID")?;
296
297    let cancellation_token = CancellationToken::new();
298
299    let ctx = ExecutionContext {
300        execution_id,
301        job_id,
302        target_execution_time: execution_ready
303            .target_execution_time
304            .and_then(|t| t.try_into().ok())
305            .unwrap_or(UNIX_EPOCH),
306        attempt_number: execution_ready.attempt_number,
307        job_type_id: execution_ready.job_type_id,
308        cancellation_token: cancellation_token.clone(),
309    };
310
311    let handler = state
312        .handlers
313        .iter()
314        .find(|h| h.can_execute(&ctx))
315        .ok_or_eyre("no handler found for the execution")?
316        .clone();
317
318    tracing::trace!("found handler for the execution");
319
320    let now = std::time::SystemTime::now();
321
322    if executor_requests
323        .send_async(ExecutorConnectionRequest {
324            message: Some(ExecutorMessage {
325                executor_message_kind: Some(ExecutorMessageKind::ExecutionStarted(
326                    ora_proto::server::v1::ExecutionStarted {
327                        timestamp: Some(now.into()),
328                        execution_id: execution_ready.execution_id,
329                    },
330                )),
331            }),
332        })
333        .await
334        .is_err()
335    {
336        tracing::debug!("not sending execution started message, executor is shutting down");
337        return Ok(());
338    }
339    tracing::trace!("sent execution started message");
340
341    let execution_guard = state.wg.add_with(&format!("execution-{execution_id}"));
342
343    let cancellation_grace_period = state.options.cancellation_grace_period;
344
345    let handle = tokio::spawn({
346        let in_progress_executions = state.in_progress_executions.clone();
347        let in_progress_executions2 = state.in_progress_executions.clone();
348        tracing::debug!("executing handler");
349
350        let execution_id = ctx.execution_id;
351
352        async move {
353            let mut warn_bomb = ExecutionDropWarnBomb::new(tracing::Span::current());
354
355            let handler_fut = async move {
356                match AssertUnwindSafe(handler.execute(ctx, &execution_ready.input_payload_json))
357                    .catch_unwind()
358                    .await
359                {
360                    Ok(task_result) => match task_result {
361                        Ok(output_json) => {
362                            tracing::debug!("execution succeeded");
363                            let now = std::time::SystemTime::now();
364
365                            if let Err(error) = executor_requests
366                                .send_async(ExecutorConnectionRequest {
367                                    message: Some(ExecutorMessage {
368                                        executor_message_kind: Some(
369                                            ExecutorMessageKind::ExecutionSucceeded(
370                                                ora_proto::server::v1::ExecutionSucceeded {
371                                                    timestamp: Some(now.into()),
372                                                    execution_id: execution_id.to_string(),
373                                                    output_payload_json: output_json,
374                                                },
375                                            ),
376                                        ),
377                                    }),
378                                })
379                                .await
380                            {
381                                tracing::warn!(?error, "failed to send execution result");
382                            }
383                        }
384                        Err(error) => {
385                            tracing::debug!(error, "execution failed");
386                            let now = std::time::SystemTime::now();
387
388                            if let Err(error) = executor_requests
389                                .send_async(ExecutorConnectionRequest {
390                                    message: Some(ExecutorMessage {
391                                        executor_message_kind: Some(
392                                            ExecutorMessageKind::ExecutionFailed(
393                                                ora_proto::server::v1::ExecutionFailed {
394                                                    timestamp: Some(now.into()),
395                                                    execution_id: execution_id.to_string(),
396                                                    error_message: error,
397                                                },
398                                            ),
399                                        ),
400                                    }),
401                                })
402                                .await
403                            {
404                                tracing::warn!(?error, "failed to send execution result");
405                            }
406                        }
407                    },
408                    Err(panic_out) => {
409                        tracing::warn!("handler panicked");
410                        let now = std::time::SystemTime::now();
411
412                        let error_message = if let Some(error) = panic_out.downcast_ref::<&str>() {
413                            (*error).to_string()
414                        } else if let Some(error) = panic_out.downcast_ref::<String>() {
415                            error.clone()
416                        } else {
417                            "handler panicked".to_string()
418                        };
419
420                        if let Err(error) = executor_requests
421                            .send_async(ExecutorConnectionRequest {
422                                message: Some(ExecutorMessage {
423                                    executor_message_kind: Some(
424                                        ExecutorMessageKind::ExecutionFailed(
425                                            ora_proto::server::v1::ExecutionFailed {
426                                                timestamp: Some(now.into()),
427                                                execution_id: execution_id.to_string(),
428                                                error_message,
429                                            },
430                                        ),
431                                    ),
432                                }),
433                            })
434                            .await
435                        {
436                            tracing::warn!(?error, "failed to send execution result");
437                        }
438                    }
439                }
440
441                if in_progress_executions
442                    .lock()
443                    .swap_remove(&execution_id)
444                    .is_none()
445                {
446                    tracing::debug!(
447                        "execution was not found in the in-progress list, it must have been cancelled"
448                    );
449                }
450            };
451
452            let mut handler_fut = std::pin::pin!(handler_fut);
453
454            loop {
455                tokio::select! {
456                    _ = execution_guard.waiting() => {
457                        let execution_state = in_progress_executions2.lock().swap_remove(&execution_id);
458
459                        if let Some(execution_state) = execution_state {
460                            tokio::spawn(
461                                cancel_execution(execution_state, cancellation_grace_period)
462                                .instrument(tracing::Span::current()));
463                        }
464
465                        (&mut handler_fut).await;
466                    }
467                    _ = &mut handler_fut => {
468                        break;
469                    }
470                }
471            }
472            warn_bomb.defuse();
473        }
474        .instrument(execution_span)
475    });
476
477    state.in_progress_executions.lock().insert(
478        execution_id,
479        ExecutionState {
480            execution_id,
481            cancellation_token,
482            handle,
483        },
484    );
485
486    Ok(())
487}
488
489struct ExecutorState<'s> {
490    executor_id: Option<String>,
491    options: &'s ExecutorOptions,
492    handlers: &'s [Arc<dyn ExecutionHandlerRaw + Send + Sync>],
493    executor_requests: flume::Sender<ExecutorConnectionRequest>,
494    heartbeat_interval: std::time::Duration,
495    in_progress_executions: Arc<Mutex<IndexMap<Uuid, ExecutionState>>>,
496    wg: WaitGroup,
497}
498
499struct ExecutionState {
500    execution_id: Uuid,
501    cancellation_token: CancellationToken,
502    handle: tokio::task::JoinHandle<()>,
503}
504
505struct ExecutionDropWarnBomb {
506    span: tracing::Span,
507    defused: bool,
508}
509
510impl ExecutionDropWarnBomb {
511    fn new(span: tracing::Span) -> Self {
512        Self {
513            span,
514            defused: false,
515        }
516    }
517
518    fn defuse(&mut self) {
519        self.defused = true;
520    }
521}
522
523impl Drop for ExecutionDropWarnBomb {
524    fn drop(&mut self) {
525        if !self.defused {
526            self.span.in_scope(|| {
527                tracing::warn!("execution was dropped during execution");
528            });
529        }
530    }
531}