use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use redb::Database;
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::watch;
use tokio::task::JoinSet;
use tracing::Instrument as _;
use tracing::{error, info, info_span, warn};
use super::builder::{EngineBuilder, NoStore};
use super::execution::{
claim_suspended_step, poll_timers, spawn_workflow_task, validate_key_component,
};
use super::invocation::{Invocation, InvocationBuilder};
use super::{Senders, WorkflowFn, WorkflowState};
use crate::context::{Context, STEPS, StepData, SuspendPoint, serialize_step};
use crate::error::{EngineError, StateError, SubscribeError};
use crate::metadata::{self, MetadataStatus, WORKFLOW_META, WorkflowMetadata};
use crate::retry::RetryPolicy;
use crate::stream::StatusStream;
/// Workflow engine state: the redb database, the registry of workflow
/// functions, and the bookkeeping for the background tasks it spawns (workflow
/// tasks and the timer poller).
///
/// Obtained through [`Engine::builder`] (see the builder module); workflows
/// are added with `register` before `start` is called.
pub struct Engine {
    // --- storage -----------------------------------------------------------
    /// Shared database handle; cloned into the timer poller and workflow tasks.
    pub(super) db: Arc<Database>,

    // --- registry ----------------------------------------------------------
    /// Workflow functions keyed by the name they were registered under.
    pub(super) workflows: HashMap<String, WorkflowFn>,

    // --- runtime state -----------------------------------------------------
    /// Set by `start`, cleared by `stop`/`wait_all`; checked before spawning
    /// and by the timer poller loop.
    pub(super) running: Arc<AtomicBool>,
    /// Every task the engine spawns (the timer poller plus workflow tasks);
    /// drained by `wait_all`.
    pub(super) tasks: Arc<tokio::sync::Mutex<JoinSet<()>>>,
    /// Counter handed to `poll_timers` and `spawn_workflow_task`; its exact
    /// meaning lives in the execution module. NOTE(review): confirm there.
    pub(super) timer_serial: Arc<AtomicU64>,
    /// Live status channels, looked up and inserted by `instance_id`.
    pub(super) senders: Senders,

    // --- configuration -----------------------------------------------------
    /// Retry policy passed to every spawned workflow task, if any.
    pub(super) default_retry: Option<RetryPolicy>,
    /// When true, `start` re-spawns stored instances whose status is `Running`.
    pub(super) resume_on_start: bool,
}
impl Engine {
    /// Returns a builder for configuring a new engine.
    ///
    /// The builder starts with no store selected, no default retry policy, and
    /// `resume_on_start` enabled.
    #[must_use]
    pub fn builder() -> EngineBuilder<NoStore> {
        EngineBuilder {
            store: NoStore,
            default_retry: None,
            resume_on_start: true,
        }
    }

    /// Registers `workflow` under `name` so it can later be invoked, resumed,
    /// or signalled.
    ///
    /// Registering again under an existing name replaces the earlier function.
    ///
    /// # Panics
    ///
    /// Panics if the engine is currently running, or if `name` contains `'/'`
    /// (that character separates the workflow name from the instance id in
    /// storage keys such as `"{workflow_name}/{instance_id}"`).
    pub fn register<F, Fut>(&mut self, name: impl Into<String>, workflow: F)
    where
        F: Fn(Context) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), EngineError>> + Send + 'static,
    {
        assert!(
            !self.running.load(Ordering::Acquire),
            "cannot register workflows after the engine has started"
        );
        let name = name.into();
        assert!(
            !name.contains('/'),
            "workflow name must not contain '/': '{name}'"
        );
        info!(name, "registered workflow");
        self.workflows
            .insert(name, Arc::new(move |ctx| Box::pin(workflow(ctx))));
    }

    /// Starts the engine.
    ///
    /// This marks the engine running and spawns the background timer poller,
    /// which calls `poll_timers` roughly once per second until the engine is
    /// stopped. When `resume_on_start` is set, it also re-spawns every stored
    /// instance (of a registered workflow) whose metadata status is `Running`.
    ///
    /// # Errors
    ///
    /// Returns an error if listing instance metadata fails during auto-resume.
    /// Failures to resume an individual instance are only logged. The engine
    /// is already marked running, with the poller spawned, by the time such an
    /// error is returned.
    ///
    /// # Panics
    ///
    /// Panics if the engine is currently running.
    pub async fn start(&mut self) -> Result<(), EngineError> {
        assert!(
            !self.running.load(Ordering::Acquire),
            "engine has already been started"
        );
        self.running.store(true, Ordering::Release);

        // The poller outlives this call, so it gets its own owned handles.
        let db = Arc::clone(&self.db);
        let running = Arc::clone(&self.running);
        let workflows = self.workflows.clone();
        let timer_serial = Arc::clone(&self.timer_serial);
        let poller_tasks = Arc::clone(&self.tasks);
        let default_retry = self.default_retry.clone();
        let poller_senders = Arc::clone(&self.senders);
        let poller = async move {
            info!("timer poller started");
            while running.load(Ordering::Acquire) {
                tokio::time::sleep(Duration::from_secs(1)).await;
                // Re-check after sleeping. `stop`/`wait_all` may have cleared
                // the flag during the sleep, and no poll should run after
                // shutdown began. `wait_all` holds the task-set lock while it
                // joins, and `poll_timers` is handed that same task set
                // (presumably to spawn resumed workflows). Polling now could
                // therefore wait on the lock and keep `wait_all` from finishing.
                if !running.load(Ordering::Acquire) {
                    break;
                }
                if let Err(e) = poll_timers(
                    &db,
                    &workflows,
                    &timer_serial,
                    &poller_tasks,
                    default_retry.as_ref(),
                    &poller_senders,
                )
                .await
                {
                    error!(error = %e, "timer poll failed");
                }
            }
            info!("timer poller stopped");
        }
        .instrument(info_span!("timer_poller"));
        self.tasks.lock().await.spawn(poller);

        if self.resume_on_start {
            self.resume_running_instances().await?;
        }
        info!("engine started");
        Ok(())
    }

    /// Re-spawns every stored instance whose status is `Running`, for each
    /// registered workflow.
    ///
    /// Per-instance spawn failures are logged and skipped. Only a failure to
    /// list a workflow's instances aborts the scan.
    async fn resume_running_instances(&self) -> Result<(), EngineError> {
        for workflow_name in self.workflows.keys() {
            let instances = metadata::list_metadata(&self.db, workflow_name)?;
            for (instance_id, meta) in instances {
                if !matches!(meta.status(), MetadataStatus::Running) {
                    continue;
                }
                info!(
                    workflow = %workflow_name,
                    instance_id = %instance_id,
                    "auto-resuming workflow instance"
                );
                if let Err(e) = self
                    .spawn_workflow(workflow_name, instance_id.clone(), None)
                    .await
                {
                    warn!(
                        workflow = %workflow_name,
                        instance_id = %instance_id,
                        error = %e,
                        "failed to auto-resume workflow instance, skipping"
                    );
                }
            }
        }
        Ok(())
    }

    /// Starts building an invocation of `workflow_name`.
    ///
    /// Nothing is validated or spawned by this call itself; the returned
    /// builder starts with no input payload.
    #[must_use]
    pub fn invoke(&self, workflow_name: impl Into<String>) -> InvocationBuilder<'_> {
        InvocationBuilder {
            engine: self,
            workflow_name: workflow_name.into(),
            input_payload: Ok(None),
        }
    }

    /// Re-spawns instance `instance_id` of `workflow_name`.
    ///
    /// The instance's existence and current status are not checked.
    /// `spawn_workflow` unconditionally (re)writes its metadata as `Running`
    /// before spawning the task.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - either identifier fails `validate_key_component`;
    /// - the engine is not running (`NotStarted`);
    /// - the workflow is not registered (`WorkflowNotFound`);
    /// - writing the metadata fails.
    pub async fn resume(
        &self,
        workflow_name: impl Into<String>,
        instance_id: impl Into<String>,
    ) -> Result<Invocation, EngineError> {
        let workflow_name = workflow_name.into();
        let instance_id = instance_id.into();
        validate_key_component(&workflow_name, "workflow_name")?;
        validate_key_component(&instance_id, "instance_id")?;
        self.spawn_workflow(&workflow_name, instance_id, None).await
    }

    /// Delivers `payload` to the suspension point `point` of an instance and
    /// re-spawns the instance so it can continue.
    ///
    /// The payload is stored, through `claim_suspended_step`, as the completed
    /// result of the step keyed by `point.key()`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - either identifier fails validation;
    /// - the engine is not running (`NotStarted`);
    /// - the workflow is not registered (`WorkflowNotFound`);
    /// - the payload cannot be serialized;
    /// - the step cannot be claimed.
    ///
    /// The engine-state checks run *before* anything is written to storage.
    /// A concurrent `stop`/`wait_all` landing between the claim and the spawn
    /// can still yield `NotStarted` after the signal was persisted.
    pub async fn signal<T>(
        &self,
        workflow_name: &str,
        instance_id: &str,
        point: &SuspendPoint<T>,
        payload: T,
    ) -> Result<Invocation, EngineError>
    where
        T: Serialize + DeserializeOwned + Send,
    {
        validate_key_component(workflow_name, "workflow_name")?;
        validate_key_component(instance_id, "instance_id")?;
        // Fail fast before `claim_suspended_step` mutates storage. Otherwise a
        // stopped engine or an unknown workflow would report failure for a
        // signal that had already been written.
        if !self.running.load(Ordering::Acquire) {
            return Err(EngineError::NotStarted);
        }
        let workflow = self
            .workflows
            .get(workflow_name)
            .ok_or_else(|| EngineError::WorkflowNotFound(workflow_name.to_string()))?;

        let step_key = point.key();
        let data: StepData<T> = StepData::Completed {
            result: payload,
            status: None,
        };
        let step_bytes = serialize_step(&data, step_key)?;
        claim_suspended_step(&self.db, workflow_name, instance_id, step_key, &step_bytes)?;
        info!(
            workflow = workflow_name,
            instance = instance_id,
            step = step_key,
            "signal delivered"
        );

        let mut tasks = self.tasks.lock().await;
        // Re-check under the task-set lock. `wait_all` clears the flag and then
        // holds this lock, so a task is only added to a set it will still join.
        if !self.running.load(Ordering::Acquire) {
            return Err(EngineError::NotStarted);
        }
        let tx = self.get_or_create_sender(instance_id);
        let rx = tx.subscribe();
        spawn_workflow_task(
            &mut tasks,
            workflow,
            &self.db,
            workflow_name,
            instance_id,
            &self.timer_serial,
            self.default_retry.clone(),
            &tx,
            &self.senders,
        );
        Ok(Invocation {
            instance_id: instance_id.to_string(),
            status: rx,
        })
    }

    /// Returns a stream of status updates for an instance.
    ///
    /// If this engine holds a live status channel for `instance_id`, the
    /// stream follows it. Otherwise the stream is a snapshot built from the
    /// instance's stored metadata.
    ///
    /// NOTE(review): live channels are keyed by `instance_id` alone, and
    /// `workflow_name` is only consulted on the metadata path. Instances of
    /// different workflows that share an id would alias here.
    ///
    /// # Errors
    ///
    /// - `NotFound` if no metadata exists.
    /// - `StaleRunning` if metadata says `Running` but no live channel exists
    ///   (presumably an instance left over from an earlier engine run).
    /// - `Storage` if reading metadata fails.
    pub fn subscribe(
        &self,
        workflow_name: &str,
        instance_id: &str,
    ) -> Result<StatusStream, SubscribeError> {
        let senders = self
            .senders
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(tx) = senders.get(instance_id) {
            return Ok(StatusStream::live(tx.subscribe()));
        }
        // Release the map before touching the database.
        drop(senders);
        let meta = metadata::read_metadata(&self.db, workflow_name, instance_id)
            .map_err(|e| SubscribeError::Storage(e.to_string().into()))?;
        match meta {
            None => Err(SubscribeError::NotFound {
                workflow_name: workflow_name.to_string(),
                instance_id: instance_id.to_string(),
            }),
            Some(meta) => match meta.status() {
                MetadataStatus::Suspended { key, status } => {
                    Ok(StatusStream::snapshot(WorkflowState::Suspended {
                        key: key.clone(),
                        status: status.clone(),
                    }))
                }
                MetadataStatus::Completed(msg) => Ok(StatusStream::snapshot(
                    WorkflowState::Completed(msg.clone()),
                )),
                MetadataStatus::Failed(msg) => {
                    Ok(StatusStream::snapshot(WorkflowState::Failed(msg.clone())))
                }
                MetadataStatus::Running => Err(SubscribeError::StaleRunning {
                    workflow_name: workflow_name.to_string(),
                    instance_id: instance_id.to_string(),
                }),
            },
        }
    }

    /// Returns the current state of an instance.
    ///
    /// The live channel's latest value is preferred when one exists for
    /// `instance_id`. Otherwise the state is derived from stored metadata, and
    /// a stored `Running` status is reported as `WorkflowState::Started`.
    ///
    /// NOTE(review): like `subscribe`, the live lookup ignores `workflow_name`.
    ///
    /// # Errors
    ///
    /// - `NotFound` if no metadata exists.
    /// - `Storage` if reading metadata fails.
    pub fn state(
        &self,
        workflow_name: &str,
        instance_id: &str,
    ) -> Result<WorkflowState, StateError> {
        let senders = self
            .senders
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(tx) = senders.get(instance_id) {
            return Ok(tx.borrow().clone());
        }
        drop(senders);
        let meta = metadata::read_metadata(&self.db, workflow_name, instance_id)
            .map_err(|e| StateError::Storage(e.to_string().into()))?;
        match meta {
            None => Err(StateError::NotFound {
                workflow_name: workflow_name.to_string(),
                instance_id: instance_id.to_string(),
            }),
            Some(meta) => match meta.status() {
                MetadataStatus::Running => Ok(WorkflowState::Started),
                MetadataStatus::Suspended { key, status } => Ok(WorkflowState::Suspended {
                    key: key.clone(),
                    status: status.clone(),
                }),
                MetadataStatus::Completed(msg) => Ok(WorkflowState::Completed(msg.clone())),
                MetadataStatus::Failed(msg) => Ok(WorkflowState::Failed(msg.clone())),
            },
        }
    }

    /// Reads the stored metadata of one instance (`Ok(None)` if absent).
    ///
    /// # Errors
    ///
    /// Propagates storage errors from `metadata::read_metadata`.
    pub fn get_metadata(
        &self,
        workflow_name: &str,
        instance_id: &str,
    ) -> Result<Option<WorkflowMetadata>, EngineError> {
        metadata::read_metadata(&self.db, workflow_name, instance_id)
    }

    /// Lists `(instance_id, metadata)` pairs stored for `workflow_name`.
    ///
    /// # Errors
    ///
    /// Propagates storage errors from `metadata::list_metadata`.
    pub fn list_instances(
        &self,
        workflow_name: &str,
    ) -> Result<Vec<(String, WorkflowMetadata)>, EngineError> {
        metadata::list_metadata(&self.db, workflow_name)
    }

    /// Clears the running flag.
    ///
    /// In-flight tasks are neither awaited nor cancelled. The timer poller
    /// exits at its next flag check, and later spawn attempts fail with
    /// `NotStarted`. Use `wait_all` to also wait for outstanding tasks.
    #[expect(clippy::unused_async)]
    pub async fn stop(&self) {
        self.running.store(false, Ordering::Release);
        info!("engine stopped");
    }

    /// Clears the running flag, then waits for every task in the task set to
    /// finish. This includes workflow tasks and the timer poller.
    ///
    /// Panicked tasks are logged. Tasks that ended because they were
    /// cancelled are ignored.
    ///
    /// NOTE: the task-set lock is held for the whole wait. Concurrent
    /// `signal`/`resume`/spawn calls block until this returns and then fail
    /// with `NotStarted`.
    pub async fn wait_all(&self) {
        self.running.store(false, Ordering::Release);
        info!("waiting for all workflows to complete");
        let mut tasks = self.tasks.lock().await;
        while let Some(result) = tasks.join_next().await {
            if let Err(e) = result {
                if e.is_panic() {
                    error!("workflow task panicked: {e}");
                }
            }
        }
        info!("all workflows completed");
    }

    /// Persists `Running` metadata for an instance and spawns its workflow
    /// task.
    ///
    /// The metadata and the optional `input_bytes` are written in a single
    /// write transaction. `input_bytes` is stored under
    /// `"{workflow_name}/{instance_id}/_input"`. The returned `Invocation` is
    /// subscribed to the instance's status channel.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `workflow_name` fails validation;
    /// - the engine is not running;
    /// - the workflow is not registered;
    /// - serialization or the database write fails.
    pub(super) async fn spawn_workflow(
        &self,
        workflow_name: &str,
        instance_id: String,
        input_bytes: Option<Vec<u8>>,
    ) -> Result<Invocation, EngineError> {
        validate_key_component(workflow_name, "workflow_name")?;
        // Take the task-set lock before checking `running`. `wait_all` clears
        // the flag and then holds this lock, so every task spawned here lands
        // in a set that `wait_all` will still join.
        let mut tasks = self.tasks.lock().await;
        if !self.running.load(Ordering::Acquire) {
            return Err(EngineError::NotStarted);
        }
        let workflow = self
            .workflows
            .get(workflow_name)
            .ok_or_else(|| EngineError::WorkflowNotFound(workflow_name.to_string()))?;
        let meta = WorkflowMetadata::new(MetadataStatus::Running);
        let meta_key = format!("{workflow_name}/{instance_id}");
        let meta_bytes = postcard::to_allocvec(&meta).map_err(|e| EngineError::Serialization {
            key: meta_key.clone(),
            source: Box::new(e),
        })?;
        let write_txn = self.db.begin_write()?;
        {
            // Tables borrow the transaction, so they must be dropped before commit.
            let mut meta_table = write_txn.open_table(WORKFLOW_META)?;
            meta_table.insert(meta_key.as_str(), meta_bytes.as_slice())?;
            if let Some(ref input) = input_bytes {
                let mut steps_table = write_txn.open_table(STEPS)?;
                let input_key = format!("{workflow_name}/{instance_id}/_input");
                steps_table.insert(input_key.as_str(), input.as_slice())?;
            }
        }
        write_txn.commit()?;
        let tx = self.get_or_create_sender(&instance_id);
        let rx = tx.subscribe();
        spawn_workflow_task(
            &mut tasks,
            workflow,
            &self.db,
            workflow_name,
            &instance_id,
            &self.timer_serial,
            self.default_retry.clone(),
            &tx,
            &self.senders,
        );
        Ok(Invocation {
            instance_id,
            status: rx,
        })
    }

    /// Returns the status sender registered for `instance_id`, creating one
    /// initialised to `WorkflowState::Started` if none exists.
    ///
    /// An existing sender's value is reset to `Started`. Because the closure
    /// passed to `send_if_modified` returns `false`, current receivers are not
    /// notified of that reset. They only see it through their next `borrow`.
    fn get_or_create_sender(&self, instance_id: &str) -> watch::Sender<WorkflowState> {
        let mut senders = self
            .senders
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(tx) = senders.get(instance_id) {
            tx.send_if_modified(|state| {
                *state = WorkflowState::Started;
                false
            });
            return tx.clone();
        }
        let (tx, _) = watch::channel(WorkflowState::Started);
        senders.insert(instance_id.to_string(), tx.clone());
        tx
    }
}