use std::{num::NonZero, sync::Arc, time::SystemTime};
use async_trait::async_trait;
use eyre::bail;
use ora_proto::server::v1::executor_service_client::ExecutorServiceClient;
use tokio_util::sync::CancellationToken;
use tonic::transport::Channel;
use uuid::Uuid;
use crate::job_type::{JobType, JobTypeExt, JobTypeMetadata};
mod run;
pub use eyre::Result;
/// Options an [`Executor`] is constructed with.
///
/// How each value is consumed is decided by code outside this section of the
/// file; the field docs below only state what the names and types show.
#[derive(Clone, Debug)]
pub struct ExecutorOptions {
    /// Name of the executor. The default is the empty string.
    pub name: String,
    /// Non-zero limit on concurrent executions. The default is `1`.
    pub max_concurrent_executions: NonZero<u32>,
    /// Grace period associated with cancellation. The default is 30 seconds.
    pub cancellation_grace_period: std::time::Duration,
}

impl Default for ExecutorOptions {
    /// Build the default options: empty name, one concurrent execution and a
    /// 30 second cancellation grace period.
    fn default() -> Self {
        let name = String::new();
        // `NonZero::<u32>::MIN` is the constant `1`, so no fallible
        // construction is needed.
        let max_concurrent_executions = NonZero::<u32>::MIN;
        let cancellation_grace_period = std::time::Duration::from_secs(30);

        Self {
            name,
            max_concurrent_executions,
            cancellation_grace_period,
        }
    }
}
/// Executor state: the service client, the options it was built with and the
/// registered job handlers.
///
/// Nothing in this section of the file reads `options` or `client` after
/// construction; presumably the `run` submodule does (not visible here).
pub struct Executor {
    /// Options supplied at construction time.
    options: ExecutorOptions,
    /// Client for the server's executor service over a tonic `Channel`.
    client: ExecutorServiceClient<Channel>,
    /// Registered handlers. [`Executor::add_handler`] and
    /// [`Executor::try_add_handler`] keep job type IDs unique in this list.
    handlers: Vec<Arc<dyn ExecutionHandlerRaw + Send + Sync>>,
}

impl Executor {
    /// Create an executor that uses `client` and [`ExecutorOptions::default`].
    pub fn new(client: ExecutorServiceClient<Channel>) -> Self {
        Self::with_options(client, ExecutorOptions::default())
    }

    /// Create an executor that uses `client` and the given `options`.
    ///
    /// The executor starts without any registered handlers.
    pub fn with_options(client: ExecutorServiceClient<Channel>, options: ExecutorOptions) -> Self {
        Self {
            client,
            options,
            handlers: Vec::new(),
        }
    }

    /// Register `handler`, panicking if its job type is already handled.
    ///
    /// This is the panicking counterpart of [`Executor::try_add_handler`] and
    /// shares its duplicate check so the two cannot drift apart.
    ///
    /// # Panics
    ///
    /// Panics with `A handler for job type {id} is already registered` when a
    /// handler with the same job type ID has been registered before.
    pub fn add_handler(&mut self, handler: Arc<dyn ExecutionHandlerRaw + Send + Sync>) {
        if let Err(error) = self.try_add_handler(handler) {
            // `{}` formatting of the report prints exactly the `bail!` message.
            panic!("{error}");
        }
    }

    /// Register `handler`, reporting a duplicate job type as an error.
    ///
    /// # Errors
    ///
    /// Returns an error, and leaves the registered handlers unchanged, when a
    /// handler with the same job type ID has been registered before.
    pub fn try_add_handler(
        &mut self,
        handler: Arc<dyn ExecutionHandlerRaw + Send + Sync>,
    ) -> eyre::Result<()> {
        // Look the candidate's ID up once instead of once per registered
        // handler; the borrow ends before `handler` is moved into the list.
        let job_type_id = &handler.job_type_metadata().id;
        let already_registered = self
            .handlers
            .iter()
            .any(|registered| registered.job_type_metadata().id == *job_type_id);

        if already_registered {
            bail!(
                "A handler for job type {} is already registered",
                job_type_id
            );
        }

        self.handlers.push(handler);
        Ok(())
    }
}
/// Information about one execution, passed by value to a handler's `execute`.
///
/// All fields are private; read them through the accessor methods.
#[derive(Clone, Debug)]
pub struct ExecutionContext {
    /// ID of the execution.
    execution_id: Uuid,
    /// ID of the job.
    job_id: Uuid,
    /// Target time of the execution.
    target_execution_time: SystemTime,
    /// Attempt number (whether counting starts at 0 or 1 is not visible here).
    attempt_number: u64,
    /// ID of the job type; typed handlers compare it against their own ID.
    job_type_id: String,
    /// Token backing [`ExecutionContext::cancelled`] and
    /// [`ExecutionContext::is_cancelled`].
    cancellation_token: CancellationToken,
}

impl ExecutionContext {
    /// Returns the execution ID.
    #[must_use]
    pub fn execution_id(&self) -> Uuid {
        self.execution_id
    }

    /// Returns the job ID.
    #[must_use]
    pub fn job_id(&self) -> Uuid {
        self.job_id
    }

    /// Returns the target execution time.
    #[must_use]
    pub fn target_execution_time(&self) -> SystemTime {
        self.target_execution_time
    }

    /// Returns the attempt number.
    #[must_use]
    pub fn attempt_number(&self) -> u64 {
        self.attempt_number
    }

    /// Returns the job type ID as a string slice.
    #[must_use]
    pub fn job_type_id(&self) -> &str {
        self.job_type_id.as_str()
    }

    /// Waits until the context's cancellation token has been cancelled.
    pub async fn cancelled(&self) {
        let token = &self.cancellation_token;
        token.cancelled().await;
    }

    /// Returns `true` if the context's cancellation token has been cancelled.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let token = &self.cancellation_token;
        token.is_cancelled()
    }
}
/// A typed handler for executions of the job type `J`.
///
/// Implementors receive the already-deserialized job input and return the
/// typed output; [`ExecutionHandler::raw_handler`] adapts them to the
/// JSON-string based [`ExecutionHandlerRaw`] interface.
#[async_trait]
pub trait ExecutionHandler<J>
where
    J: JobType,
{
    /// Run one execution with the given `context` and job `input`.
    async fn execute(&self, context: ExecutionContext, input: J) -> eyre::Result<J::Output>;

    /// Wrap this handler into a shareable, type-erased [`ExecutionHandlerRaw`].
    ///
    /// The wrapper parses the JSON input into `J`, calls
    /// [`ExecutionHandler::execute`] and serializes the output back to JSON.
    /// Every failure along the way is reported as a `String`.
    fn raw_handler(self) -> Arc<dyn ExecutionHandlerRaw + Send + Sync>
    where
        Self: Sized + Send + Sync + 'static,
    {
        /// Adapter from a typed handler to the raw JSON interface.
        struct JsonAdapter<J, F> {
            /// The wrapped typed handler.
            inner: F,
            /// Ties the adapter to its job type without storing a `J`.
            _job_type: std::marker::PhantomData<J>,
            /// Metadata obtained from `J::metadata()` when the adapter was built.
            metadata: JobTypeMetadata,
        }

        #[async_trait]
        impl<J, F> ExecutionHandlerRaw for JsonAdapter<J, F>
        where
            J: JobType,
            F: ExecutionHandler<J> + Send + Sync + 'static,
        {
            fn can_execute(&self, context: &ExecutionContext) -> bool {
                context.job_type_id == J::id()
            }

            async fn execute(
                &self,
                context: ExecutionContext,
                input_json: &str,
            ) -> Result<String, String> {
                // Step 1: JSON text -> typed job input.
                let input: J = match serde_json::from_str(input_json) {
                    Ok(parsed) => parsed,
                    Err(e) => return Err(format!("Failed to parse job input JSON: {e}")),
                };

                // Step 2: run the typed handler; its error is flattened to text.
                let output = match self.inner.execute(context, input).await {
                    Ok(value) => value,
                    Err(report) => return Err(report.to_string()),
                };

                // Step 3: typed output -> JSON text.
                match serde_json::to_string(&output) {
                    Ok(output_json) => Ok(output_json),
                    Err(e) => Err(format!("Failed to serialize job output JSON: {e}")),
                }
            }

            fn job_type_metadata(&self) -> &JobTypeMetadata {
                &self.metadata
            }
        }

        Arc::new(JsonAdapter {
            inner: self,
            _job_type: std::marker::PhantomData,
            metadata: J::metadata(),
        })
    }
}
/// Lets any `Send + Sync + 'static` function or closure of the shape
/// `Fn(ExecutionContext, J) -> Fut` act as an [`ExecutionHandler`] for `J`,
/// provided `Fut` is a `Send + 'static` future that resolves to
/// `eyre::Result<J::Output>`.
#[async_trait]
impl<J, F, Fut> ExecutionHandler<J> for F
where
    J: JobType,
    F: Fn(ExecutionContext, J) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = eyre::Result<J::Output>> + Send + 'static,
{
    /// Call the function with `context` and `input` and await its future.
    async fn execute(&self, context: ExecutionContext, input: J) -> eyre::Result<J::Output> {
        let pending = self(context, input);
        pending.await
    }
}
/// Type-erased execution handler that exchanges job input and output as
/// strings.
///
/// `Executor` stores its handlers as `Arc<dyn ExecutionHandlerRaw + Send + Sync>`;
/// `ExecutionHandler::raw_handler` produces such a value from a typed handler.
#[async_trait]
pub trait ExecutionHandlerRaw {
/// Returns whether this handler can execute the execution described by
/// `context`.
fn can_execute(&self, context: &ExecutionContext) -> bool;
/// Runs the execution with `input_json` as its input.
///
/// Returns the output string on success, or an error message on failure.
/// The adapter built by `ExecutionHandler::raw_handler` uses JSON for both
/// the input and the output.
async fn execute(&self, context: ExecutionContext, input_json: &str) -> Result<String, String>;
/// Returns the metadata of the job type this handler is registered for.
/// `Executor` uses the metadata's `id` to reject duplicate registrations.
fn job_type_metadata(&self) -> &JobTypeMetadata;
}
pub trait IntoExecutionHandler: Sized + Send + Sync + 'static {
fn handler<J>(self) -> Arc<dyn ExecutionHandlerRaw + Send + Sync>
where
Self: ExecutionHandler<J>,
J: JobType,
{
<Self as ExecutionHandler<J>>::raw_handler(self)
}
}
impl<W> IntoExecutionHandler for W where W: Sized + Send + Sync + 'static {}