use std::collections::VecDeque;
use std::fmt::{Debug, Display};
use std::sync::Arc;

use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json;
use tracing::{debug, info, span, Instrument, Level};

use crate::context::Context;
use crate::distributed::context_store::ContextStore;
use crate::distributed::{ItemProcessedOutcome, StepCallbacks, StepError, WorkQueue};
use crate::error::FloxideError;
use crate::{Checkpoint, CheckpointStore};
/// A single schedulable unit of work produced and consumed by a [`Workflow`].
///
/// Items are serialized into checkpoints and onto distributed work queues,
/// hence the `Serialize + DeserializeOwned + Clone + PartialEq` bounds.
pub trait WorkItem:
    Debug + Display + Send + Sync + Serialize + DeserializeOwned + Clone + PartialEq
{
    /// Identifier for this work-item instance.
    // NOTE(review): uniqueness scope (per run vs. per item) is not visible
    // from this file — confirm against the implementors/derive macro.
    fn instance_id(&self) -> String;
    /// Returns `true` when this item is a terminal node of the workflow
    /// graph (used e.g. by `resume` to detect already-finished runs).
    fn is_terminal(&self) -> bool;
}
/// A graph workflow executed over a shared context type `C`.
///
/// Implementors define how one [`WorkItem`] is processed
/// ([`Workflow::process_work_item`]); the provided methods drive a whole run
/// in several modes:
///
/// * [`Workflow::run`] — in-memory, single process;
/// * [`Workflow::run_with_checkpoint`] / [`Workflow::resume`] — persist a
///   [`Checkpoint`] after every step so a run survives restarts;
/// * [`Workflow::start_distributed`] / [`Workflow::step_distributed`] —
///   share work through a [`WorkQueue`] and a [`ContextStore`] across
///   multiple workers.
#[async_trait]
pub trait Workflow<C: Context>: Debug + Clone + Send + Sync {
    /// Input used to seed the first work item of a run.
    type Input: Send + Sync + Serialize + DeserializeOwned;
    /// Value produced when a terminal work item completes.
    type Output: Send + Sync + Serialize + DeserializeOwned;
    /// Queue element type driving this workflow.
    type WorkItem: WorkItem;

    /// Workflow name, used to label tracing spans.
    fn name(&self) -> &'static str;

    /// Builds the initial work item from the run input.
    fn start_work_item(&self, input: Self::Input) -> Self::WorkItem;

    /// Runs the workflow to completion in-process.
    ///
    /// Work items are drained FIFO from an in-memory queue; processing an
    /// item may enqueue successors. Returns the first output produced by a
    /// terminal item, or the first processing error.
    ///
    /// # Panics
    /// Panics if the queue drains without any item yielding an output; a
    /// well-formed workflow must always reach a terminal branch.
    async fn run<'a>(
        &'a self,
        ctx: &'a crate::WorkflowCtx<C>,
        input: Self::Input,
    ) -> Result<Self::Output, FloxideError> {
        let span = span!(Level::INFO, "workflow_run", workflow = self.name());
        // Instrument the future instead of holding a `span.enter()` guard
        // across `.await`: a guard held over an await point attributes other
        // tasks' events to this span once the task yields.
        async move {
            let mut queue: VecDeque<Self::WorkItem> = VecDeque::new();
            queue.push_back(self.start_work_item(input));
            while let Some(item) = queue.pop_front() {
                debug!(?item, queue_len = queue.len(), "Processing work item");
                if let Some(output) = self.process_work_item(ctx, item, &mut queue).await? {
                    return Ok(output);
                }
                debug!(queue_len = queue.len(), "Queue state after processing");
            }
            unreachable!("Workflow did not reach terminal branch");
        }
        .instrument(span)
        .await
    }

    /// Processes one work item, pushing any successor items onto `queue`.
    ///
    /// Returns `Ok(Some(output))` when `item` was terminal, `Ok(None)` to
    /// continue the run, or an error to abort it.
    async fn process_work_item<'a>(
        &'a self,
        ctx: &'a crate::WorkflowCtx<C>,
        item: Self::WorkItem,
        queue: &mut std::collections::VecDeque<Self::WorkItem>,
    ) -> Result<Option<Self::Output>, FloxideError>;

    /// Runs the workflow, persisting a checkpoint after every processed item.
    ///
    /// If a checkpoint for `id` already exists it is loaded and execution
    /// continues from its saved queue and context; otherwise a fresh run is
    /// seeded from `input`.
    ///
    /// # Errors
    /// Returns [`FloxideError::AlreadyCompleted`] when the stored queue is
    /// empty (the run already finished), or propagates any store or
    /// processing error.
    ///
    /// # Panics
    /// Panics if the queue drains without a terminal output (invariant of a
    /// well-formed workflow).
    async fn run_with_checkpoint<CS: CheckpointStore<C, Self::WorkItem> + Send + Sync>(
        &self,
        ctx: &crate::WorkflowCtx<C>,
        input: Self::Input,
        store: &CS,
        id: &str,
    ) -> Result<Self::Output, FloxideError> {
        let span = span!(
            Level::INFO,
            "workflow_run_with_checkpoint",
            workflow = self.name(),
            run_id = id
        );
        async move {
            // Resume from a saved checkpoint when available, otherwise seed a
            // fresh queue from the input.
            let mut cp: Checkpoint<C, Self::WorkItem> = match store
                .load(id)
                .await
                .map_err(|e| FloxideError::Generic(e.to_string()))?
            {
                Some(saved) => {
                    debug!("Loaded existing checkpoint");
                    saved
                }
                None => {
                    debug!("No checkpoint found, starting new");
                    let mut init_q = VecDeque::new();
                    init_q.push_back(self.start_work_item(input));
                    Checkpoint::new(ctx.store.clone(), init_q)
                }
            };
            let mut queue = cp.queue.clone();
            if queue.is_empty() {
                info!("Workflow already completed (empty queue)");
                return Err(FloxideError::AlreadyCompleted);
            }
            // Execute against the checkpointed context, not the caller's
            // (possibly newer) one. Only a shared borrow is needed.
            let mut wf_ctx = ctx.clone();
            wf_ctx.store = cp.context.clone();
            let ctx_ref = &wf_ctx;
            while let Some(item) = queue.pop_front() {
                debug!(?item, queue_len = queue.len(), "Processing work item");
                if let Some(output) = self.process_work_item(ctx_ref, item, &mut queue).await? {
                    return Ok(output);
                }
                debug!(queue_len = queue.len(), "Queue state after processing");
                // Persist progress so a crash resumes after this item.
                cp.context = ctx_ref.store.clone();
                cp.queue = queue.clone();
                store
                    .save(id, &cp)
                    .await
                    .map_err(|e| FloxideError::Generic(e.to_string()))?;
                debug!("Checkpoint saved");
            }
            unreachable!("Workflow did not reach terminal branch");
        }
        .instrument(span)
        .await
    }

    /// Resumes a previously checkpointed run identified by `id`.
    ///
    /// # Errors
    /// [`FloxideError::NotStarted`] when no checkpoint exists for `id`;
    /// [`FloxideError::AlreadyCompleted`] when the saved queue is empty or
    /// holds only a terminal item; otherwise any store/processing error.
    async fn resume<CS: CheckpointStore<C, Self::WorkItem> + Send + Sync>(
        &self,
        store: &CS,
        id: &str,
    ) -> Result<Self::Output, FloxideError> {
        let span = span!(
            Level::INFO,
            "workflow_resume",
            workflow = self.name(),
            checkpoint_id = id
        );
        async move {
            let mut cp = store
                .load(id)
                .await
                .map_err(|e| FloxideError::Generic(e.to_string()))?
                .ok_or(FloxideError::NotStarted)?;
            debug!("Loaded checkpoint for resume");
            // Rebuild the workflow context from the persisted context store.
            let wf_ctx = crate::WorkflowCtx::new(cp.context.clone());
            let ctx = &wf_ctx;
            let mut queue: VecDeque<Self::WorkItem> = cp.queue.clone();
            if queue.is_empty() {
                info!("Workflow already completed (empty queue)");
                return Err(FloxideError::AlreadyCompleted);
            }
            // A lone terminal item means the final output was already
            // produced in a previous session.
            if queue.len() == 1
                && queue
                    .front()
                    .map(|item| item.is_terminal())
                    .unwrap_or(false)
            {
                info!("Workflow already completed (terminal node in queue)");
                return Err(FloxideError::AlreadyCompleted);
            }
            while let Some(item) = queue.pop_front() {
                debug!(?item, queue_len = queue.len(), "Processing work item");
                if let Some(output) = self.process_work_item(ctx, item, &mut queue).await? {
                    return Ok(output);
                }
                // Persist progress after each item, mirroring
                // `run_with_checkpoint`.
                cp.context = ctx.store.clone();
                cp.queue = queue.clone();
                store
                    .save(id, &cp)
                    .await
                    .map_err(|e| FloxideError::Generic(e.to_string()))?;
                debug!("Checkpoint saved");
                debug!(queue_len = queue.len(), "Queue state after processing");
            }
            unreachable!("Workflow did not reach terminal branch");
        }
        .instrument(span)
        .await
    }

    /// Seeds a distributed run: stores the initial context under `id` and
    /// enqueues the first work item.
    ///
    /// Idempotent — does nothing when a context for `id` already exists.
    async fn start_distributed<CS, Q>(
        &self,
        ctx: &crate::WorkflowCtx<C>,
        input: Self::Input,
        context_store: &CS,
        queue: &Q,
        id: &str,
    ) -> Result<(), FloxideError>
    where
        CS: ContextStore<C> + Send + Sync,
        Q: WorkQueue<C, Self::WorkItem> + Send + Sync,
        C: crate::merge::Merge + Default,
    {
        let seed_span =
            span!(Level::DEBUG, "start_distributed", workflow = self.name(), run_id = %id);
        async move {
            debug!(run_id = %id, "start_distributed seeding");
            if context_store
                .get(id)
                .await
                .map_err(|e| FloxideError::Generic(e.to_string()))?
                .is_none()
            {
                let item = self.start_work_item(input);
                // NOTE(review): `set` + `enqueue` are two separate calls; a
                // crash between them leaves a context with no queued work.
                // Confirm callers can detect or retry that state.
                context_store
                    .set(id, ctx.store.clone())
                    .await
                    .map_err(|e| FloxideError::Generic(e.to_string()))?;
                queue
                    .enqueue(id, item)
                    .await
                    .map_err(|e| FloxideError::Generic(e.to_string()))?;
            }
            Ok(())
        }
        .instrument(seed_span)
        .await
    }

    /// Dequeues and processes a single work item as worker `worker_id`.
    ///
    /// Returns `Ok(None)` when the queue is empty or the processed item was
    /// non-terminal, and `Ok(Some((run_id, output)))` when the item produced
    /// the run's final output. Lifecycle callbacks are invoked before and
    /// after processing; a callback failure is surfaced as the step error.
    async fn step_distributed<CS, Q>(
        &self,
        context_store: &CS,
        queue: &Q,
        worker_id: usize,
        callbacks: Arc<dyn StepCallbacks<C, Self>>,
    ) -> Result<Option<(String, Self::Output)>, StepError<Self::WorkItem>>
    where
        C: 'static + crate::merge::Merge + Default,
        CS: ContextStore<C> + Send + Sync,
        Q: crate::distributed::WorkQueue<C, Self::WorkItem> + Send + Sync,
    {
        // Nothing to do when the shared queue is empty.
        let work = queue.dequeue().await.map_err(|e| StepError {
            error: FloxideError::Generic(e.to_string()),
            run_id: None,
            work_item: None,
        })?;
        let (run_id, item) = match work {
            None => return Ok(None),
            Some((rid, it)) => (rid, it),
        };
        let step_span = span!(Level::DEBUG, "step_distributed",
            workflow = self.name(), run_id = %run_id, worker = worker_id);
        async move {
            debug!(worker = worker_id, run_id = %run_id, ?item, "Worker dequeued item");
            let on_started_result = callbacks.on_started(run_id.clone(), item.clone()).await;
            if let Err(e) = on_started_result {
                return Err(StepError {
                    error: FloxideError::Generic(format!("on_started_state_updates failed: {:?}", e)),
                    run_id: Some(run_id.clone()),
                    work_item: Some(item.clone()),
                });
            }
            // The run context must have been seeded by `start_distributed`.
            let ctx_val = context_store.get(&run_id).await.map_err(|e| StepError {
                error: FloxideError::Generic(e.to_string()),
                run_id: Some(run_id.clone()),
                work_item: Some(item.clone()),
            })?;
            let ctx_val = ctx_val.ok_or_else(|| StepError {
                error: FloxideError::NotStarted,
                run_id: Some(run_id.clone()),
                work_item: Some(item.clone()),
            })?;
            let wf_ctx = crate::WorkflowCtx::new(ctx_val.clone());
            let ctx_ref = &wf_ctx;
            // Successors produced by this item; enqueued on success below.
            let mut local_q = VecDeque::new();
            let process_result = self
                .process_work_item(ctx_ref, item.clone(), &mut local_q)
                .await;
            match process_result {
                Ok(Some(out)) => {
                    // Terminal item: merge our context changes, then report
                    // the serialized output through the callback.
                    context_store
                        .merge(&run_id, wf_ctx.store.clone())
                        .await
                        .map_err(|e| StepError {
                            error: FloxideError::Generic(e.to_string()),
                            run_id: Some(run_id.clone()),
                            work_item: Some(item.clone()),
                        })?;
                    debug!(worker = worker_id, run_id = %run_id, "Context merged (terminal)");
                    let output_json = serde_json::to_value(&out).map_err(|e| StepError {
                        error: FloxideError::Generic(format!("Failed to serialize output: {}", e)),
                        run_id: Some(run_id.clone()),
                        work_item: Some(item.clone()),
                    })?;
                    let on_item_processed_result = callbacks
                        .on_item_processed(
                            run_id.clone(),
                            item.clone(),
                            ItemProcessedOutcome::SuccessTerminal(output_json),
                        )
                        .await;
                    if let Err(e) = on_item_processed_result {
                        return Err(StepError {
                            error: e,
                            run_id: Some(run_id.clone()),
                            work_item: Some(item.clone()),
                        });
                    }
                    Ok(Some((run_id.clone(), out)))
                }
                Ok(None) => {
                    // Non-terminal: hand successors to the shared queue
                    // (consuming `local_q` avoids cloning each item), then
                    // merge context changes.
                    for succ in local_q {
                        queue
                            .enqueue(&run_id, succ)
                            .await
                            .map_err(|e| StepError {
                                error: FloxideError::Generic(e.to_string()),
                                run_id: Some(run_id.clone()),
                                work_item: Some(item.clone()),
                            })?;
                    }
                    context_store
                        .merge(&run_id, wf_ctx.store.clone())
                        .await
                        .map_err(|e| StepError {
                            error: FloxideError::Generic(e.to_string()),
                            run_id: Some(run_id.clone()),
                            work_item: Some(item.clone()),
                        })?;
                    debug!(worker = worker_id, run_id = %run_id, "Context merged");
                    let on_item_processed_result = callbacks
                        .on_item_processed(
                            run_id.clone(),
                            item.clone(),
                            ItemProcessedOutcome::SuccessNonTerminal,
                        )
                        .await;
                    if let Err(e) = on_item_processed_result {
                        return Err(StepError {
                            error: e,
                            run_id: Some(run_id.clone()),
                            work_item: Some(item.clone()),
                        });
                    }
                    Ok(None)
                }
                Err(e) => {
                    // Report the failure. Note: if the callback itself fails,
                    // its error replaces the original processing error.
                    let on_item_processed_result = callbacks
                        .on_item_processed(
                            run_id.clone(),
                            item.clone(),
                            ItemProcessedOutcome::Error(e.clone()),
                        )
                        .await;
                    if let Err(e) = on_item_processed_result {
                        return Err(StepError {
                            error: e,
                            run_id: Some(run_id.clone()),
                            work_item: Some(item.clone()),
                        });
                    }
                    Err(StepError {
                        error: e,
                        run_id: Some(run_id),
                        work_item: Some(item),
                    })
                }
            }
        }
        .instrument(step_span)
        .await
    }

    /// Returns a static DOT (Graphviz) rendering of the workflow graph.
    // NOTE(review): presumably generated by the derive macro — confirm.
    fn to_dot(&self) -> &'static str;
}