ora_client/
executor.rs

1//! An implementation of a job executor.
2
3use std::{num::NonZero, sync::Arc, time::SystemTime};
4
5use async_trait::async_trait;
6use eyre::bail;
7use ora_proto::server::v1::executor_service_client::ExecutorServiceClient;
8use tokio_util::sync::CancellationToken;
9use tonic::transport::Channel;
10use uuid::Uuid;
11
12use crate::job_type::{JobType, JobTypeExt, JobTypeMetadata};
13
14mod run;
15
16pub use eyre::Result;
17
/// Configuration knobs for an [`Executor`].
///
/// Use [`ExecutorOptions::default`] and override individual fields with
/// struct update syntax when only a few settings need to change.
#[derive(Debug, Clone)]
pub struct ExecutorOptions {
    /// Human-readable name of this executor.
    ///
    /// Defaults to an empty string.
    pub name: String,
    /// Upper bound on how many executions may run at the same time.
    ///
    /// Defaults to 1.
    pub max_concurrent_executions: NonZero<u32>,
    /// How long a cancelled job is given to finish on its own
    /// before its future is dropped.
    ///
    /// Defaults to 30 seconds.
    pub cancellation_grace_period: std::time::Duration,
}
31
32impl Default for ExecutorOptions {
33    fn default() -> Self {
34        Self {
35            name: String::new(),
36            max_concurrent_executions: NonZero::new(1).unwrap(),
37            cancellation_grace_period: std::time::Duration::from_secs(30),
38        }
39    }
40}
41
42type ExecutionFailedCb = Arc<dyn Fn(ExecutionContext, &str) + Send + Sync>;
43
44/// An executor for running jobs.
45pub struct Executor<C = Channel> {
46    options: ExecutorOptions,
47    client: ExecutorServiceClient<C>,
48    handlers: Vec<Arc<dyn ExecutionHandlerRaw + Send + Sync>>,
49    on_execution_failed: Option<ExecutionFailedCb>,
50}
51
52impl<C> Executor<C> {
53    /// Create a new executor.
54    pub fn new(client: ExecutorServiceClient<C>) -> Self {
55        Self::with_options(client, ExecutorOptions::default())
56    }
57
58    /// Create a new executor with the given options.
59    pub fn with_options(client: ExecutorServiceClient<C>, options: ExecutorOptions) -> Self {
60        Self {
61            client,
62            options,
63            handlers: Vec::new(),
64            on_execution_failed: None,
65        }
66    }
67
68    /// Set a callback to be called when an execution fails.
69    ///
70    /// Only one callback can be set at a time,
71    /// the previous one will be replaced.
72    pub fn on_execution_failed(
73        &mut self,
74        callback: impl Fn(ExecutionContext, &str) + Send + Sync + 'static,
75    ) {
76        self.on_execution_failed = Some(Arc::new(callback));
77    }
78
79    /// Get the options of the executor.
80    pub fn options(&self) -> &ExecutorOptions {
81        &self.options
82    }
83
84    /// Add a new handler to the executor.
85    ///
86    /// # Panics
87    ///
88    /// Panics if a handler for the same job type is already registered.
89    pub fn add_handler(&mut self, handler: Arc<dyn ExecutionHandlerRaw + Send + Sync>) {
90        assert!(
91            !self
92                .handlers
93                .iter()
94                .any(|h| h.job_type_metadata().id == handler.job_type_metadata().id),
95            "A handler for job type {} is already registered",
96            handler.job_type_metadata().id
97        );
98
99        self.handlers.push(handler);
100    }
101
102    /// Try to add a new handler to the executor.
103    ///
104    /// If a handler for the same job type is already registered,
105    /// this function will return an error.
106    pub fn try_add_handler(
107        &mut self,
108        handler: Arc<dyn ExecutionHandlerRaw + Send + Sync>,
109    ) -> eyre::Result<()> {
110        if self
111            .handlers
112            .iter()
113            .any(|h| h.job_type_metadata().id == handler.job_type_metadata().id)
114        {
115            bail!(
116                "A handler for job type {} is already registered",
117                handler.job_type_metadata().id
118            );
119        }
120
121        self.handlers.push(handler);
122        Ok(())
123    }
124}
125
126/// The context in which a job is executed.
127#[derive(Debug, Clone)]
128pub struct ExecutionContext {
129    execution_id: Uuid,
130    job_id: Uuid,
131    target_execution_time: SystemTime,
132    attempt_number: u64,
133    job_type_id: String,
134    cancellation_token: CancellationToken,
135}
136
137impl ExecutionContext {
138    /// Get the ID of the current execution.
139    #[must_use]
140    pub fn execution_id(&self) -> Uuid {
141        self.execution_id
142    }
143
144    /// Get the ID of the job.
145    #[must_use]
146    pub fn job_id(&self) -> Uuid {
147        self.job_id
148    }
149
150    /// Get the target execution time of the job.
151    #[must_use]
152    pub fn target_execution_time(&self) -> SystemTime {
153        self.target_execution_time
154    }
155
156    /// Get the attempt number of the job.
157    ///
158    /// The first attempt has number 1.
159    #[must_use]
160    pub fn attempt_number(&self) -> u64 {
161        self.attempt_number
162    }
163
164    /// The job type of the current job.
165    #[must_use]
166    pub fn job_type_id(&self) -> &str {
167        &self.job_type_id
168    }
169
170    /// Wait for the execution to be cancelled.
171    pub async fn cancelled(&self) {
172        self.cancellation_token.cancelled().await;
173    }
174
175    /// Check if the execution has been cancelled.
176    #[must_use]
177    pub fn is_cancelled(&self) -> bool {
178        self.cancellation_token.is_cancelled()
179    }
180}
181
182/// An execution handler for a specific job type.
183#[async_trait]
184pub trait ExecutionHandler<J>
185where
186    J: JobType,
187{
188    /// Execute the given job execution.
189    async fn execute(&self, context: ExecutionContext, input: J) -> eyre::Result<J::Output>;
190
191    /// Return a raw handler to be used by an executor.
192    fn raw_handler(self) -> Arc<dyn ExecutionHandlerRaw + Send + Sync>
193    where
194        Self: Sized + Send + Sync + 'static,
195    {
196        struct H<J, F>(F, std::marker::PhantomData<J>, JobTypeMetadata);
197
198        #[async_trait]
199        impl<J, F> ExecutionHandlerRaw for H<J, F>
200        where
201            J: JobType,
202            F: ExecutionHandler<J> + Send + Sync + 'static,
203        {
204            fn can_execute(&self, context: &ExecutionContext) -> bool {
205                context.job_type_id == J::id()
206            }
207
208            async fn execute(
209                &self,
210                context: ExecutionContext,
211                input_json: &str,
212            ) -> Result<String, String> {
213                let input = serde_json::from_str::<J>(input_json)
214                    .map_err(|e| format!("Failed to parse job input JSON: {e}"))?;
215
216                let result = self
217                    .0
218                    .execute(context, input)
219                    .await
220                    .map_err(|e| format!("{e:?}"))?;
221
222                let output_json = serde_json::to_string(&result)
223                    .map_err(|e| format!("Failed to serialize job output JSON: {e}"))?;
224
225                Ok(output_json)
226            }
227
228            fn job_type_metadata(&self) -> &JobTypeMetadata {
229                &self.2
230            }
231        }
232
233        Arc::new(H(self, std::marker::PhantomData, J::metadata()))
234    }
235}
236
237#[async_trait]
238impl<J, F, Fut> ExecutionHandler<J> for F
239where
240    J: JobType,
241    F: Fn(ExecutionContext, J) -> Fut + Send + Sync + 'static,
242    Fut: std::future::Future<Output = eyre::Result<J::Output>> + Send + 'static,
243{
244    async fn execute(&self, context: ExecutionContext, input: J) -> eyre::Result<J::Output> {
245        self(context, input).await
246    }
247}
248
249/// A handler for executing jobs.
250#[async_trait]
251pub trait ExecutionHandlerRaw {
252    /// Returns whether the handler can execute the
253    /// given job execution.
254    fn can_execute(&self, context: &ExecutionContext) -> bool;
255
256    /// Execute the given job execution.
257    ///
258    /// The Ok variant must be a valid JSON,
259    /// while the Err variant must be an error message of any kind.
260    ///
261    /// Note that while the input and outputs should be JSON,
262    /// this might not be enforced by either the executor or the server.
263    async fn execute(&self, context: ExecutionContext, input_json: &str) -> Result<String, String>;
264
265    /// Get information about the job type this handler can execute.
266    fn job_type_metadata(&self) -> &JobTypeMetadata;
267}
268
269/// A helper blanket trait for types that might implement [`ExecutionHandler`]
270/// for multiple [`JobType`]s.
271pub trait IntoExecutionHandler: Sized + Send + Sync + 'static {
272    /// Convert `self` into a [`RawHandler`] that can be registered
273    /// in workers.
274    fn handler<J>(self) -> Arc<dyn ExecutionHandlerRaw + Send + Sync>
275    where
276        Self: ExecutionHandler<J>,
277        J: JobType,
278    {
279        <Self as ExecutionHandler<J>>::raw_handler(self)
280    }
281}
282
283impl<W> IntoExecutionHandler for W where W: Sized + Send + Sync + 'static {}