pub mod cron;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use axum::{
Json, Router,
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
pub use cron::{
ConcurrencyPolicy, CreateCronJobRequest, CronJob, CronJobResponse, CronJobStatus, CronState,
cron_jobs_router, cron_jobs_router_with_state, start_cron_scheduler, validate_cron_expression,
};
/// Input payload of a background run: a string-keyed map of arbitrary JSON values.
pub type WorkflowState = HashMap<String, Value>;
/// Lifecycle state of a background run.
///
/// Serialized as a lowercase string: `"queued"`, `"running"`, `"completed"`, `"failed"`,
/// or `"cancelled"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RunStatus {
    /// Submitted (or re-queued for a retry) and waiting to start.
    Queued,
    /// An execution attempt is in progress.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error; timeouts are also recorded this way.
    Failed,
    /// Stopped by a cancellation request.
    Cancelled,
}
/// Server-side record of one submitted background run.
#[derive(Clone, Debug)]
pub struct BackgroundRun {
    /// Identifier assigned at submission.
    pub run_id: String,
    /// Workflow this run was submitted for.
    pub workflow_id: String,
    /// Current lifecycle state.
    pub status: RunStatus,
    /// Input supplied at submission.
    pub input: WorkflowState,
    /// Output of a successful run.
    pub result: Option<Value>,
    /// Failure message of a failed run.
    pub error: Option<String>,
    /// When the run was submitted.
    pub created_at: DateTime<Utc>,
    /// When this record last changed.
    pub updated_at: DateTime<Utc>,
    /// Time limit applied to each execution attempt; `None` means no limit.
    pub timeout: Option<Duration>,
    /// Maximum number of times a failed attempt may be retried.
    pub max_retries: u32,
    /// Number of retries consumed so far.
    pub retry_count: u32,
    /// Cancelling this token stops the in-flight attempt.
    pub cancel_token: CancellationToken,
}
/// JSON body accepted by `POST /runs`; field names are camelCase on the wire.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct SubmitRunRequest {
    /// Workflow to run (`workflowId`).
    pub workflow_id: String,
    /// Input values for the run (`input`).
    pub input: WorkflowState,
    /// Per-attempt timeout in seconds (`timeoutSecs`); omitted means no timeout.
    #[serde(default)]
    pub timeout_secs: Option<u64>,
    /// Number of retries allowed after a failed attempt (`maxRetries`); omitted means 0.
    #[serde(default)]
    pub max_retries: Option<u32>,
}
/// JSON body returned by `POST /runs`; field names are camelCase on the wire.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct SubmitRunResponse {
    /// Identifier of the newly created run (`runId`).
    pub run_id: String,
    /// Status at submission time (`status`).
    pub status: RunStatus,
    /// RFC 3339 creation timestamp (`createdAt`).
    pub created_at: String,
}
/// JSON view of a run returned by the status and cancel endpoints.
///
/// Field names are camelCase on the wire; `None` fields are omitted entirely.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct RunStatusResponse {
    /// Run identifier (`runId`).
    pub run_id: String,
    /// Current lifecycle state.
    pub status: RunStatus,
    /// RFC 3339 creation timestamp (`createdAt`).
    pub created_at: String,
    /// RFC 3339 timestamp of the last change (`updatedAt`).
    pub updated_at: String,
    /// Output of a completed run, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Failure message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Retries used so far (`retryCount`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_count: Option<u32>,
    /// Retries still available (`retriesRemaining`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retries_remaining: Option<u32>,
}
/// In-memory registry of background runs, keyed by run id.
///
/// Cloning is cheap: every clone shares the same map through an `Arc`.
#[derive(Clone, Debug, Default)]
pub struct RunStore {
    /// Run id -> run record, guarded by an async read/write lock.
    runs: Arc<RwLock<HashMap<String, BackgroundRun>>>,
}
impl RunStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `run`, replacing any existing run with the same `run_id`.
    pub async fn insert(&self, run: BackgroundRun) {
        self.runs.write().await.insert(run.run_id.clone(), run);
    }

    /// Returns a snapshot (clone) of the run, or `None` if `run_id` is unknown.
    pub async fn get(&self, run_id: &str) -> Option<BackgroundRun> {
        self.runs.read().await.get(run_id).cloned()
    }

    /// Sets the status of a run that has not reached a terminal state yet.
    ///
    /// Terminal runs (`Completed`, `Failed`, `Cancelled`) are left untouched. This stops a
    /// late writer from overwriting a final outcome, e.g. a worker finishing just as the
    /// run is cancelled. Unknown run ids are ignored.
    pub async fn update_status(&self, run_id: &str, status: RunStatus) {
        self.modify_active(run_id, |run| run.status = status).await;
    }

    /// Marks a non-terminal run as completed with `result`.
    ///
    /// Terminal or unknown runs are ignored.
    pub async fn set_completed(&self, run_id: &str, result: Value) {
        self.modify_active(run_id, |run| {
            run.status = RunStatus::Completed;
            run.result = Some(result);
        })
        .await;
    }

    /// Marks a non-terminal run as failed with `error`.
    ///
    /// Terminal or unknown runs are ignored.
    pub async fn set_failed(&self, run_id: &str, error: String) {
        self.modify_active(run_id, |run| {
            run.status = RunStatus::Failed;
            run.error = Some(error);
        })
        .await;
    }

    /// Re-queues a run for another attempt if it still has retries left.
    ///
    /// On success this increments `retry_count`, clears `error`, sets the status to
    /// `Queued`, and returns `true`. It returns `false` in three cases:
    /// - the run is unknown;
    /// - the retry budget is exhausted;
    /// - the run has already completed or been cancelled. A cancellation that races with
    ///   a failing attempt therefore sticks.
    pub async fn retry(&self, run_id: &str) -> bool {
        let mut runs = self.runs.write().await;
        let Some(run) = runs.get_mut(run_id) else {
            return false;
        };
        let finished_for_good = matches!(run.status, RunStatus::Completed | RunStatus::Cancelled);
        if finished_for_good || run.retry_count >= run.max_retries {
            return false;
        }
        run.retry_count += 1;
        run.status = RunStatus::Queued;
        run.error = None;
        run.updated_at = Utc::now();
        true
    }

    /// Returns whether `status` is final, i.e. no further transitions are expected.
    fn is_terminal(status: RunStatus) -> bool {
        matches!(status, RunStatus::Completed | RunStatus::Failed | RunStatus::Cancelled)
    }

    /// Applies `apply` to a non-terminal run and bumps `updated_at`.
    ///
    /// The check and the mutation both happen under one write lock. Returns `false`
    /// (without calling `apply`) if the run is unknown or already terminal.
    async fn modify_active(&self, run_id: &str, apply: impl FnOnce(&mut BackgroundRun)) -> bool {
        let mut runs = self.runs.write().await;
        match runs.get_mut(run_id) {
            Some(run) if !Self::is_terminal(run.status) => {
                apply(run);
                run.updated_at = Utc::now();
                true
            }
            _ => false,
        }
    }
}
/// Spawns Tokio tasks that execute runs and record their progress in a [`RunStore`].
#[derive(Clone, Debug)]
pub struct BackgroundRunner {
    /// Store the runner reads runs from and writes status updates to.
    store: RunStore,
}
impl BackgroundRunner {
    /// Creates a runner that reads and updates runs in `store`.
    pub fn new(store: RunStore) -> Self {
        Self { store }
    }

    /// Spawns a detached Tokio task that executes the run identified by `run_id`.
    ///
    /// The task snapshots the run from the store, marks it `Running`, and executes one
    /// attempt under the run's optional timeout and cancellation token. It then records
    /// the outcome:
    /// - completion stores the result;
    /// - a failure re-queues and re-executes the run while retries remain, otherwise it is
    ///   recorded as failed;
    /// - a timeout is recorded as failed, without retrying;
    /// - a cancellation is recorded as cancelled.
    ///
    /// Unknown run ids, and runs that are no longer `Queued` when the task starts, are left
    /// untouched. Must be called from within a Tokio runtime because it uses `tokio::spawn`.
    pub fn execute(&self, run_id: String) {
        let store = self.store.clone();
        tokio::spawn(async move {
            let Some(run) = store.get(&run_id).await else {
                return;
            };
            // Only queued runs may start. Without this check, a run cancelled before its
            // task was scheduled would be flipped back to `Running`.
            if run.status != RunStatus::Queued {
                return;
            }
            if run.cancel_token.is_cancelled() {
                store.update_status(&run_id, RunStatus::Cancelled).await;
                return;
            }
            store.update_status(&run_id, RunStatus::Running).await;

            match Self::run_with_timeout(run.timeout, &run.cancel_token).await {
                RunOutcome::Completed(value) => {
                    store.set_completed(&run_id, value).await;
                }
                RunOutcome::Failed(error) => {
                    if store.retry(&run_id).await {
                        // `execute` spawns its own task, so no extra spawn is needed here.
                        BackgroundRunner::new(store).execute(run_id);
                    } else {
                        store.set_failed(&run_id, error).await;
                    }
                }
                RunOutcome::Cancelled => {
                    store.update_status(&run_id, RunStatus::Cancelled).await;
                }
                RunOutcome::TimedOut => {
                    store.set_failed(&run_id, "run timed out".to_string()).await;
                }
            }
        });
    }

    /// Runs one execution attempt and races it against `cancel_token`.
    ///
    /// When `timeout_duration` is set, the attempt is also raced against that deadline.
    ///
    /// NOTE(review): the attempt is currently a placeholder that completes immediately with
    /// an empty JSON object, so it never yields `RunOutcome::Failed`.
    async fn run_with_timeout(
        timeout_duration: Option<Duration>,
        cancel_token: &CancellationToken,
    ) -> RunOutcome {
        let work = async {
            if cancel_token.is_cancelled() {
                return RunOutcome::Cancelled;
            }
            RunOutcome::Completed(Value::Object(serde_json::Map::new()))
        };
        match timeout_duration {
            Some(duration) => {
                tokio::select! {
                    _ = cancel_token.cancelled() => RunOutcome::Cancelled,
                    // An elapsed deadline surfaces as `Err`.
                    result = tokio::time::timeout(duration, work) => {
                        result.unwrap_or(RunOutcome::TimedOut)
                    }
                }
            }
            None => {
                tokio::select! {
                    _ = cancel_token.cancelled() => RunOutcome::Cancelled,
                    outcome = work => outcome,
                }
            }
        }
    }
}
/// Result of a single execution attempt of a background run.
#[derive(Debug)]
enum RunOutcome {
    /// The attempt finished and produced this value.
    Completed(Value),
    /// The attempt failed with this message; the run may be retried.
    // Never constructed yet: the placeholder work in `BackgroundRunner::run_with_timeout`
    // cannot fail. The variant is kept so the retry path stays wired up, and the lint is
    // silenced only here rather than for the whole enum.
    #[allow(dead_code)]
    Failed(String),
    /// The run's cancellation token fired before the attempt finished.
    Cancelled,
    /// The attempt exceeded the run's timeout.
    TimedOut,
}
/// Shared state for the background-run HTTP handlers.
///
/// Holds the run store plus a runner bound to that same store.
#[derive(Clone, Debug)]
pub struct BackgroundState {
    /// Registry of all submitted runs.
    pub store: RunStore,
    /// Executes runs, recording their progress in `store`.
    pub runner: BackgroundRunner,
}
impl BackgroundState {
    /// Creates state with an empty store and a runner that shares it.
    pub fn new() -> Self {
        let store = RunStore::new();
        Self { runner: BackgroundRunner::new(store.clone()), store }
    }
}
impl Default for BackgroundState {
    /// Equivalent to [`BackgroundState::new`].
    fn default() -> Self {
        BackgroundState::new()
    }
}
async fn submit_run(
State(state): State<BackgroundState>,
Json(request): Json<SubmitRunRequest>,
) -> impl IntoResponse {
let run_id = uuid::Uuid::new_v4().to_string();
let now = Utc::now();
let run = BackgroundRun {
run_id: run_id.clone(),
workflow_id: request.workflow_id,
status: RunStatus::Queued,
input: request.input,
result: None,
error: None,
created_at: now,
updated_at: now,
timeout: request.timeout_secs.map(Duration::from_secs),
max_retries: request.max_retries.unwrap_or(0),
retry_count: 0,
cancel_token: CancellationToken::new(),
};
state.store.insert(run).await;
state.runner.execute(run_id.clone());
let response =
SubmitRunResponse { run_id, status: RunStatus::Queued, created_at: now.to_rfc3339() };
(StatusCode::CREATED, Json(response))
}
async fn get_run_status(
State(state): State<BackgroundState>,
Path(run_id): Path<String>,
) -> impl IntoResponse {
match state.store.get(&run_id).await {
Some(run) => {
let retries_remaining = if run.max_retries > 0 {
Some(run.max_retries.saturating_sub(run.retry_count))
} else {
None
};
let retry_count = if run.max_retries > 0 { Some(run.retry_count) } else { None };
let response = RunStatusResponse {
run_id: run.run_id,
status: run.status,
created_at: run.created_at.to_rfc3339(),
updated_at: run.updated_at.to_rfc3339(),
result: run.result,
error: run.error,
retry_count,
retries_remaining,
};
(StatusCode::OK, Json(response)).into_response()
}
None => (StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": "run not found" })))
.into_response(),
}
}
async fn cancel_run(
State(state): State<BackgroundState>,
Path(run_id): Path<String>,
) -> impl IntoResponse {
match state.store.get(&run_id).await {
Some(run) => {
match run.status {
RunStatus::Completed | RunStatus::Failed | RunStatus::Cancelled => {
let response = RunStatusResponse {
run_id: run.run_id,
status: run.status,
created_at: run.created_at.to_rfc3339(),
updated_at: run.updated_at.to_rfc3339(),
result: run.result,
error: run.error,
retry_count: if run.max_retries > 0 { Some(run.retry_count) } else { None },
retries_remaining: if run.max_retries > 0 {
Some(run.max_retries.saturating_sub(run.retry_count))
} else {
None
},
};
(StatusCode::OK, Json(response)).into_response()
}
RunStatus::Queued | RunStatus::Running => {
run.cancel_token.cancel();
state.store.update_status(&run_id, RunStatus::Cancelled).await;
let updated = state.store.get(&run_id).await.unwrap();
let response = RunStatusResponse {
run_id: updated.run_id,
status: updated.status,
created_at: updated.created_at.to_rfc3339(),
updated_at: updated.updated_at.to_rfc3339(),
result: updated.result,
error: updated.error,
retry_count: if updated.max_retries > 0 {
Some(updated.retry_count)
} else {
None
},
retries_remaining: if updated.max_retries > 0 {
Some(updated.max_retries.saturating_sub(updated.retry_count))
} else {
None
},
};
(StatusCode::OK, Json(response)).into_response()
}
}
}
None => (StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": "run not found" })))
.into_response(),
}
}
/// Builds the background-run routes backed by a fresh, private [`BackgroundState`].
pub fn background_runs_router() -> Router {
    background_runs_router_with_state(BackgroundState::new())
}
/// Builds the background-run routes backed by the given `state`.
///
/// - `POST /runs` submits a run.
/// - `GET /runs/{run_id}` reports its status.
/// - `DELETE /runs/{run_id}` cancels it.
pub fn background_runs_router_with_state(state: BackgroundState) -> Router {
    let single_run = get(get_run_status).delete(cancel_run);
    Router::new()
        .route("/runs", post(submit_run))
        .route("/runs/{run_id}", single_run)
        .with_state(state)
}