ora_client/
executor.rs

1//! An implementation of a job executor.
2
3use std::{num::NonZero, sync::Arc, time::SystemTime};
4
5use async_trait::async_trait;
6use eyre::bail;
7use ora_proto::server::v1::executor_service_client::ExecutorServiceClient;
8use tokio_util::sync::CancellationToken;
9use tonic::transport::Channel;
10use uuid::Uuid;
11
12use crate::job_type::{JobType, JobTypeExt, JobTypeMetadata};
13
14mod run;
15
16pub use eyre::Result;
17
/// Options for configuring an executor.
///
/// Start from [`ExecutorOptions::default`] and use struct-update syntax
/// to override individual fields.
#[derive(Debug, Clone)]
pub struct ExecutorOptions {
    /// The name of the executor.
    ///
    /// Defaults to an empty string.
    pub name: String,
    /// The maximum number of concurrent executions.
    ///
    /// Defaults to 1.
    pub max_concurrent_executions: NonZero<u32>,
    /// The grace period for job cancellations,
    /// after which the futures will be dropped.
    ///
    /// Defaults to 30 seconds.
    pub cancellation_grace_period: std::time::Duration,
}

impl Default for ExecutorOptions {
    fn default() -> Self {
        Self {
            name: String::new(),
            // `MIN` of `NonZero<u32>` is 1; using the constant avoids a fallible
            // construction followed by `unwrap()`.
            max_concurrent_executions: NonZero::<u32>::MIN,
            cancellation_grace_period: std::time::Duration::from_secs(30),
        }
    }
}
41
42/// An executor for running jobs.
43pub struct Executor<C = Channel> {
44    options: ExecutorOptions,
45    client: ExecutorServiceClient<C>,
46    handlers: Vec<Arc<dyn ExecutionHandlerRaw + Send + Sync>>,
47}
48
49impl<C> Executor<C> {
50    /// Create a new executor.
51    pub fn new(client: ExecutorServiceClient<C>) -> Self {
52        Self::with_options(client, ExecutorOptions::default())
53    }
54
55    /// Create a new executor with the given options.
56    pub fn with_options(client: ExecutorServiceClient<C>, options: ExecutorOptions) -> Self {
57        Self {
58            client,
59            options,
60            handlers: Vec::new(),
61        }
62    }
63
64    /// Get the options of the executor.
65    pub fn options(&self) -> &ExecutorOptions {
66        &self.options
67    }
68
69    /// Add a new handler to the executor.
70    ///
71    /// # Panics
72    ///
73    /// Panics if a handler for the same job type is already registered.
74    pub fn add_handler(&mut self, handler: Arc<dyn ExecutionHandlerRaw + Send + Sync>) {
75        assert!(
76            !self
77                .handlers
78                .iter()
79                .any(|h| h.job_type_metadata().id == handler.job_type_metadata().id),
80            "A handler for job type {} is already registered",
81            handler.job_type_metadata().id
82        );
83
84        self.handlers.push(handler);
85    }
86
87    /// Try to add a new handler to the executor.
88    ///
89    /// If a handler for the same job type is already registered,
90    /// this function will return an error.
91    pub fn try_add_handler(
92        &mut self,
93        handler: Arc<dyn ExecutionHandlerRaw + Send + Sync>,
94    ) -> eyre::Result<()> {
95        if self
96            .handlers
97            .iter()
98            .any(|h| h.job_type_metadata().id == handler.job_type_metadata().id)
99        {
100            bail!(
101                "A handler for job type {} is already registered",
102                handler.job_type_metadata().id
103            );
104        }
105
106        self.handlers.push(handler);
107        Ok(())
108    }
109}
110
111/// The context in which a job is executed.
112#[derive(Debug, Clone)]
113pub struct ExecutionContext {
114    execution_id: Uuid,
115    job_id: Uuid,
116    target_execution_time: SystemTime,
117    attempt_number: u64,
118    job_type_id: String,
119    cancellation_token: CancellationToken,
120}
121
122impl ExecutionContext {
123    /// Get the ID of the current execution.
124    #[must_use]
125    pub fn execution_id(&self) -> Uuid {
126        self.execution_id
127    }
128
129    /// Get the ID of the job.
130    #[must_use]
131    pub fn job_id(&self) -> Uuid {
132        self.job_id
133    }
134
135    /// Get the target execution time of the job.
136    #[must_use]
137    pub fn target_execution_time(&self) -> SystemTime {
138        self.target_execution_time
139    }
140
141    /// Get the attempt number of the job.
142    ///
143    /// The first attempt has number 1.
144    #[must_use]
145    pub fn attempt_number(&self) -> u64 {
146        self.attempt_number
147    }
148
149    /// The job type of the current job.
150    #[must_use]
151    pub fn job_type_id(&self) -> &str {
152        &self.job_type_id
153    }
154
155    /// Wait for the execution to be cancelled.
156    pub async fn cancelled(&self) {
157        self.cancellation_token.cancelled().await;
158    }
159
160    /// Check if the execution has been cancelled.
161    #[must_use]
162    pub fn is_cancelled(&self) -> bool {
163        self.cancellation_token.is_cancelled()
164    }
165}
166
167/// An execution handler for a specific job type.
168#[async_trait]
169pub trait ExecutionHandler<J>
170where
171    J: JobType,
172{
173    /// Execute the given job execution.
174    async fn execute(&self, context: ExecutionContext, input: J) -> eyre::Result<J::Output>;
175
176    /// Return a raw handler to be used by an executor.
177    fn raw_handler(self) -> Arc<dyn ExecutionHandlerRaw + Send + Sync>
178    where
179        Self: Sized + Send + Sync + 'static,
180    {
181        struct H<J, F>(F, std::marker::PhantomData<J>, JobTypeMetadata);
182
183        #[async_trait]
184        impl<J, F> ExecutionHandlerRaw for H<J, F>
185        where
186            J: JobType,
187            F: ExecutionHandler<J> + Send + Sync + 'static,
188        {
189            fn can_execute(&self, context: &ExecutionContext) -> bool {
190                context.job_type_id == J::id()
191            }
192
193            async fn execute(
194                &self,
195                context: ExecutionContext,
196                input_json: &str,
197            ) -> Result<String, String> {
198                let input = serde_json::from_str::<J>(input_json)
199                    .map_err(|e| format!("Failed to parse job input JSON: {e}"))?;
200
201                let result = self
202                    .0
203                    .execute(context, input)
204                    .await
205                    .map_err(|e| format!("{e:?}"))?;
206
207                let output_json = serde_json::to_string(&result)
208                    .map_err(|e| format!("Failed to serialize job output JSON: {e}"))?;
209
210                Ok(output_json)
211            }
212
213            fn job_type_metadata(&self) -> &JobTypeMetadata {
214                &self.2
215            }
216        }
217
218        Arc::new(H(self, std::marker::PhantomData, J::metadata()))
219    }
220}
221
222#[async_trait]
223impl<J, F, Fut> ExecutionHandler<J> for F
224where
225    J: JobType,
226    F: Fn(ExecutionContext, J) -> Fut + Send + Sync + 'static,
227    Fut: std::future::Future<Output = eyre::Result<J::Output>> + Send + 'static,
228{
229    async fn execute(&self, context: ExecutionContext, input: J) -> eyre::Result<J::Output> {
230        self(context, input).await
231    }
232}
233
234/// A handler for executing jobs.
235#[async_trait]
236pub trait ExecutionHandlerRaw {
237    /// Returns whether the handler can execute the
238    /// given job execution.
239    fn can_execute(&self, context: &ExecutionContext) -> bool;
240
241    /// Execute the given job execution.
242    ///
243    /// The Ok variant must be a valid JSON,
244    /// while the Err variant must be an error message of any kind.
245    ///
246    /// Note that while the input and outputs should be JSON,
247    /// this might not be enforced by either the executor or the server.
248    async fn execute(&self, context: ExecutionContext, input_json: &str) -> Result<String, String>;
249
250    /// Get information about the job type this handler can execute.
251    fn job_type_metadata(&self) -> &JobTypeMetadata;
252}
253
254/// A helper blanket trait for types that might implement [`ExecutionHandler`]
255/// for multiple [`JobType`]s.
256pub trait IntoExecutionHandler: Sized + Send + Sync + 'static {
257    /// Convert `self` into a [`RawHandler`] that can be registered
258    /// in workers.
259    fn handler<J>(self) -> Arc<dyn ExecutionHandlerRaw + Send + Sync>
260    where
261        Self: ExecutionHandler<J>,
262        J: JobType,
263    {
264        <Self as ExecutionHandler<J>>::raw_handler(self)
265    }
266}
267
268impl<W> IntoExecutionHandler for W where W: Sized + Send + Sync + 'static {}