use std::{pin::Pin, sync::Arc, time::SystemTime};
use eyre::Context;
use schemars::Schema;
use tokio_util::sync::CancellationToken;
use wgroup::WaitGroup;
use crate::{
execution::ExecutionId,
job::JobId,
job_type::{JobType, JobTypeId},
proto::executors::v1::execution_service_client::ExecutionServiceClient,
};
mod capabilities;
mod connection;
mod executions;
mod heartbeat;
mod main_loop;
/// Error type that job handlers return on failure; an alias for [`eyre::Report`].
pub type HandlerError = eyre::Report;

/// Result type returned by job handlers: `Ok(T)` on success, [`HandlerError`] on failure.
pub type HandlerResult<T> = std::result::Result<T, HandlerError>;
/// Configuration for an [`Executor`].
///
/// Construct with [`ExecutorOptions::default`] and override fields as needed,
/// e.g. `ExecutorOptions { name: "worker".into(), ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorOptions {
    /// Name of this executor. Defaults to the empty string.
    pub name: String,
    /// Grace period associated with cancellation. Defaults to 30 seconds.
    ///
    /// NOTE(review): how the grace period is enforced (e.g. how long in-flight
    /// handlers may keep running after cancellation) lives outside this file — confirm.
    pub cancellation_grace_period: std::time::Duration,
}

impl Default for ExecutorOptions {
    /// Returns an empty `name` and a 30 second `cancellation_grace_period`.
    fn default() -> Self {
        Self {
            name: String::new(),
            cancellation_grace_period: std::time::Duration::from_secs(30),
        }
    }
}
/// Serves executions of the job types registered on it.
///
/// Create one with [`Executor::new`] or [`Executor::new_with_options`], then
/// register handlers with [`Executor::handler`] / [`Executor::handler_with_options`].
/// The executor must then be started; the start logic is not part of this file
/// section.
#[must_use = "an executor does nothing until it is started"]
pub struct Executor<C> {
    // Settings supplied at construction time (and `with_name`).
    options: ExecutorOptions,
    // Client for the proto-defined execution service.
    client: ExecutionServiceClient<C>,
    // Callback installed through `on_execution_failed`, if any.
    on_execution_failed: Option<ExecutionFailedCb>,
    // One entry per registered job type; re-registering a type replaces its entry.
    queues: Vec<ExecutorJobQueue>,
}

impl<C> Executor<C> {
    /// Creates an executor using [`ExecutorOptions::default`] and no handlers.
    pub fn new(client: ExecutionServiceClient<C>) -> Self {
        Self::new_with_options(Default::default(), client)
    }

    /// Creates an executor with the given `options`, no handlers and no
    /// execution-failed callback.
    pub fn new_with_options(options: ExecutorOptions, client: ExecutionServiceClient<C>) -> Self {
        Executor {
            queues: vec![],
            on_execution_failed: None,
            client,
            options,
        }
    }

    /// Installs `cb` as the execution-failed callback.
    ///
    /// The callback receives the execution's [`ExecutionContext`] and a `&str`.
    /// Only one callback is kept: calling this again replaces the previous one.
    pub fn on_execution_failed<F>(&mut self, cb: F)
    where
        F: Fn(ExecutionContext, &str) + Send + Sync + 'static,
    {
        let callback: ExecutionFailedCb = Arc::new(cb);
        self.on_execution_failed = Some(callback);
    }

    /// Returns the options this executor was built with.
    pub fn options(&self) -> &ExecutorOptions {
        &self.options
    }

    /// Builder-style setter for [`ExecutorOptions::name`].
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        let name: String = name.into();
        self.options.name = name;
        self
    }

    /// Registers `handler` for job type `J` using [`HandlerOptions::default`].
    ///
    /// See [`Executor::handler_with_options`] for details.
    pub fn handler<F, J, Fut>(self, handler: F) -> Self
    where
        J: JobType,
        F: Fn(ExecutionContext, J) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = HandlerResult<J::Output>> + Send + 'static,
    {
        self.handler_with_options::<F, J, Fut>(handler, Default::default())
    }

    /// Registers `handler` for job type `J` with explicit [`HandlerOptions`].
    ///
    /// Any handler previously registered for the same job type ID is removed first.
    /// At execution time the JSON payload is deserialized into `J`; a failure yields
    /// an error wrapped with "invalid payload for execution". The handler's output is
    /// serialized back to JSON; a failure yields an error wrapped with
    /// "failed to serialize handler output". The JSON schemas of `J` and `J::Output`
    /// are recorded. The input schema's top-level `description`, if it is a string,
    /// becomes the queue description.
    pub fn handler_with_options<F, J, Fut>(mut self, handler: F, options: HandlerOptions) -> Self
    where
        J: JobType,
        F: Fn(ExecutionContext, J) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = HandlerResult<J::Output>> + Send + 'static,
    {
        let job_type_id = J::job_type_id();

        // Keep at most one queue per job type: the newest registration wins.
        self.queues
            .retain(|queue| queue.job_type_id.as_str() != job_type_id.as_str());

        // Erase `J` / `Fut` so handlers of different job types can share one `Vec`:
        // the stored closure takes and returns JSON strings.
        let typed_handler = Arc::new(handler);
        let erased_handler: HandlerFn =
            Arc::new(move |ctx: ExecutionContext, raw_payload: String| {
                let typed_handler = Arc::clone(&typed_handler);
                let future = async move {
                    let input: J = serde_json::from_str(&raw_payload)
                        .wrap_err("invalid payload for execution")?;
                    let output = typed_handler(ctx, input).await?;
                    serde_json::to_string(&output).wrap_err("failed to serialize handler output")
                };
                Box::pin(future)
            });

        let input_schema = schemars::schema_for!(J);
        let output_schema = schemars::schema_for!(J::Output);
        let description = input_schema
            .as_object()
            .and_then(|fields| fields.get("description"))
            .and_then(|value| value.as_str())
            .map(String::from);

        let queue = ExecutorJobQueue {
            max_concurrent_jobs: options.max_concurrent,
            job_type_id,
            input_schema,
            output_schema,
            description,
            handler: erased_handler,
        };
        self.queues.push(queue);
        self
    }
}
/// Handle tied to the lifetime of a started executor.
///
/// As the `#[must_use]` message states, the executor stops when this handle is
/// dropped, so keep it alive for as long as the executor should run. The stop
/// logic itself is not part of this file section.
#[must_use = "The executor stops when this handle is dropped."]
pub struct ExecutorHandle {
// Optional wait group held by the handle.
// NOTE(review): presumably used to wait for in-flight work when the handle is
// dropped — confirm against the code that constructs `ExecutorHandle`.
_wg: Option<WaitGroup>,
}
/// Metadata about a single execution.
///
/// Passed by value to job handlers (see [`Executor::handler`]) and to the
/// execution-failed callback (see [`Executor::on_execution_failed`]).
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// ID of this execution.
    execution_id: ExecutionId,
    /// ID of the job this execution belongs to.
    job_id: JobId,
    /// Target execution time recorded for this execution.
    target_execution_time: SystemTime,
    /// Attempt number recorded for this execution.
    attempt_number: u64,
    /// Job type of the job being executed.
    job_type_id: JobTypeId,
    /// Token backing [`ExecutionContext::cancelled`] and
    /// [`ExecutionContext::is_cancelled`].
    cancellation_token: CancellationToken,
}

impl ExecutionContext {
    /// Returns the ID of this execution.
    #[must_use]
    pub fn execution_id(&self) -> ExecutionId {
        self.execution_id
    }

    /// Returns the ID of the job this execution belongs to.
    #[must_use]
    pub fn job_id(&self) -> JobId {
        self.job_id
    }

    /// Returns the target execution time recorded for this execution.
    #[must_use]
    pub fn target_execution_time(&self) -> SystemTime {
        self.target_execution_time
    }

    /// Returns the attempt number recorded for this execution.
    ///
    /// NOTE(review): whether counting starts at 0 or 1 is decided by whoever
    /// builds the context — not visible here.
    #[must_use]
    pub fn attempt_number(&self) -> u64 {
        self.attempt_number
    }

    /// Returns the job type ID of the job being executed.
    #[must_use]
    pub fn job_type_id(&self) -> &JobTypeId {
        &self.job_type_id
    }

    /// Waits until this execution's cancellation token is cancelled.
    ///
    /// Long-running handlers can race their work against this future to stop
    /// early when the execution is cancelled.
    pub async fn cancelled(&self) {
        self.cancellation_token.cancelled().await;
    }

    /// Returns `true` if this execution's cancellation token has been cancelled.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancellation_token.is_cancelled()
    }
}
/// Per-handler settings accepted by [`Executor::handler_with_options`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerOptions {
    /// Concurrency limit for this job type; copied into the job queue's
    /// `max_concurrent_jobs`. Defaults to `1`.
    ///
    /// NOTE(review): enforcement (and how a value of `0` is treated) happens
    /// outside this file — confirm before relying on it.
    pub max_concurrent: u64,
}

impl Default for HandlerOptions {
    /// Returns options with `max_concurrent` set to `1`.
    fn default() -> Self {
        Self { max_concurrent: 1 }
    }
}
/// Boxed future returned by a [`HandlerFn`]. In `Executor::handler_with_options`
/// it resolves to the handler's output serialized as a JSON string.
type HandlerFuture = Pin<Box<dyn Future<Output = HandlerResult<String>> + Send>>;

/// Type-erased job handler: takes the execution context and the raw JSON payload.
type HandlerFn = Arc<dyn Fn(ExecutionContext, String) -> HandlerFuture + Send + Sync>;

/// Callback stored by `Executor::on_execution_failed`; receives the execution
/// context and a string slice.
type ExecutionFailedCb = Arc<dyn for<'msg> Fn(ExecutionContext, &'msg str) + Send + Sync>;

/// Registration of a single job type on an executor.
struct ExecutorJobQueue {
    // Copied from `HandlerOptions::max_concurrent`.
    max_concurrent_jobs: u64,
    // Job type this queue serves; unique within an executor's queues.
    job_type_id: JobTypeId,
    // JSON schema of the job payload type.
    input_schema: Schema,
    // JSON schema of the handler's output type.
    output_schema: Schema,
    // Top-level `description` of `input_schema`, when present as a string.
    description: Option<String>,
    // Type-erased handler that consumes and produces JSON strings.
    handler: HandlerFn,
}