use std::future::Future;
use std::ops::ControlFlow;
use bytes::Bytes;
use sayiir_core::codec::LoopDecision;
use sayiir_core::error::{BoxError, CodecError, WorkflowError};
use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot};
use sayiir_core::workflow::{MaxIterationsPolicy, WorkflowContinuation};
use sayiir_persistence::SnapshotStore;
use crate::error::RuntimeError;
/// Splits a loop envelope into its [`LoopDecision`] and the inner payload.
///
/// The first byte selects the decoder: when it is `0` or `1` the whole input
/// is handed to `sayiir_core::codec::decode_loop_envelope`; any other leading
/// byte makes the input be parsed as JSON with a string `_loop` tag and a
/// `value` field, and `value` is re-serialized to JSON bytes as the payload.
///
/// # Errors
///
/// Fails when the input is empty, is not valid JSON (on the JSON path), lacks
/// a string `_loop` tag, carries a tag that does not parse as a
/// [`LoopDecision`], or has no `value` field.
pub(crate) fn decode_loop_envelope(bytes: &[u8]) -> Result<(LoopDecision, Bytes), BoxError> {
    let Some(&lead) = bytes.first() else {
        return Err("empty loop envelope".into());
    };

    // Leading 0/1 is delegated to the core codec untouched.
    if matches!(lead, 0 | 1) {
        return sayiir_core::codec::decode_loop_envelope(bytes);
    }

    let document: serde_json::Value = serde_json::from_slice(bytes)?;

    let Some(tag) = document
        .get("_loop")
        .and_then(serde_json::Value::as_str)
    else {
        return Err("missing or invalid '_loop' tag in LoopResult JSON".into());
    };
    let decision = tag
        .parse::<LoopDecision>()
        .map_err(|_| format!("unknown loop decision tag: '{tag}'"))?;

    let Some(payload) = document.get("value") else {
        return Err("missing 'value' field in LoopResult JSON".into());
    };
    let encoded = serde_json::to_vec(payload)?;

    Ok((decision, Bytes::from(encoded)))
}
/// Payload carried out of a loop when it stops iterating; produced on the
/// `ControlFlow::Break` side of `resolve_loop_iteration`.
pub(crate) struct LoopExit(pub Bytes);
/// Payload fed into the next loop iteration; produced on the
/// `ControlFlow::Continue` side of `resolve_loop_iteration`.
pub(crate) struct LoopNext(pub Bytes);
/// Borrowed description of one loop, shared by the loop helpers in this file.
pub(crate) struct LoopConfig<'a> {
// Loop identifier; used as the task id in decode errors, as the loop id in
// `MaxIterationsExceeded`, and passed to the hooks.
pub id: &'a str,
// Loop body continuation; handed to the hooks by `run_loop_async`.
pub body: &'a WorkflowContinuation,
// Exclusive upper bound on the iteration index.
pub max_iterations: u32,
// What to do when the body asks for another pass but the bound is reached.
pub on_max: MaxIterationsPolicy,
// Iteration index at which `run_loop_async` starts counting.
pub start_iteration: u32,
}
/// Decodes the output of one loop-body run and decides what the loop does next.
///
/// * `output` — raw bytes returned by the body; decoded with
///   [`decode_loop_envelope`].
/// * `iteration` — zero-based index of the iteration that produced `output`.
/// * `cfg` — loop settings (`id`, `max_iterations`, `on_max`).
///
/// Returns `ControlFlow::Break(LoopExit(..))` when the body reported
/// `LoopDecision::Done`, or when it reported `Again` on the last permitted
/// iteration and the policy is `MaxIterationsPolicy::ExitWithLast`. Returns
/// `ControlFlow::Continue(LoopNext(..))` when another iteration is allowed.
///
/// # Errors
///
/// * `CodecError::DecodeFailed` (converted into [`RuntimeError`]) when the
///   envelope cannot be decoded.
/// * `WorkflowError::MaxIterationsExceeded` when the body asks for another
///   pass, the bound is reached, and the policy is `MaxIterationsPolicy::Fail`.
pub(crate) fn resolve_loop_iteration(
    output: &Bytes,
    iteration: u32,
    cfg: &LoopConfig<'_>,
) -> Result<ControlFlow<LoopExit, LoopNext>, RuntimeError> {
    let (decision, inner) = decode_loop_envelope(output).map_err(|e| CodecError::DecodeFailed {
        task_id: cfg.id.to_string(),
        expected_type: "LoopEnvelope",
        source: e,
    })?;

    match decision {
        LoopDecision::Done => Ok(ControlFlow::Break(LoopExit(inner))),
        LoopDecision::Again => {
            // Saturate instead of `iteration + 1`: at `u32::MAX` a plain add
            // panics in debug builds and wraps to 0 in release builds, which
            // would wrongly report that another iteration is allowed.
            let next_iteration = iteration.saturating_add(1);
            if next_iteration >= cfg.max_iterations {
                match cfg.on_max {
                    MaxIterationsPolicy::Fail => Err(WorkflowError::MaxIterationsExceeded {
                        loop_id: cfg.id.to_string(),
                        max_iterations: cfg.max_iterations,
                    }
                    .into()),
                    MaxIterationsPolicy::ExitWithLast => Ok(ControlFlow::Break(LoopExit(inner))),
                }
            } else {
                Ok(ControlFlow::Continue(LoopNext(inner)))
            }
        }
    }
}
/// Callbacks invoked by `run_loop_async` at fixed points of a loop run.
///
/// Every method has a default that does nothing (the async ones resolve to
/// `Ok(())`), so implementors override only what they need.
// The default bodies deliberately ignore their parameters; the allow keeps the
// parameter names readable instead of underscore-prefixing them.
#[allow(unused_variables)]
pub(crate) trait LoopHooks: Send {
/// Called after each body execution, before its output is interpreted.
fn clear_body_tasks(&mut self, body: &WorkflowContinuation) {}
/// Called once with the final payload just before the loop returns it.
fn on_loop_exit(
&mut self,
loop_id: &str,
output: &Bytes,
) -> impl Future<Output = Result<(), RuntimeError>> + Send {
async { Ok(()) }
}
/// Called when the loop is about to run again; `next_iteration` is the
/// index of the upcoming iteration.
fn on_iteration_progress(
&mut self,
loop_id: &str,
next_iteration: u32,
body: &WorkflowContinuation,
) -> impl Future<Output = Result<(), RuntimeError>> + Send {
async { Ok(()) }
}
}
/// Hook set that overrides nothing: every callback uses the no-op default
/// provided by `LoopHooks`.
pub(crate) struct NoHooks;
impl LoopHooks for NoHooks {}
/// Hook set that records loop progress in a workflow snapshot and saves it
/// through a backend.
pub(crate) struct CheckpointingLoopHooks<'a, B> {
// Snapshot mutated by the hooks before each save.
pub snapshot: &'a mut WorkflowSnapshot,
// Store the snapshot is saved to.
pub backend: &'a B,
// When true, iteration progress also updates the snapshot's execution position.
pub track_position: bool,
}
impl<B: SnapshotStore> LoopHooks for CheckpointingLoopHooks<'_, B> {
    /// Removes the stored result of every task id reported by the body.
    fn clear_body_tasks(&mut self, body: &WorkflowContinuation) {
        let serializable_body = body.to_serializable();
        let body_task_ids = serializable_body.task_ids();
        for task_id in &body_task_ids {
            self.snapshot.remove_task_result(task_id);
        }
    }

    /// Clears the loop's iteration counter, marks the loop itself completed
    /// with `output`, and saves the snapshot.
    async fn on_loop_exit(&mut self, loop_id: &str, output: &Bytes) -> Result<(), RuntimeError> {
        self.snapshot.clear_loop_iteration(loop_id);

        let final_output = output.clone();
        self.snapshot
            .mark_task_completed(loop_id.to_string(), final_output);

        self.backend.save_snapshot(self.snapshot).await?;
        Ok(())
    }

    /// Records the upcoming iteration index (and, when `track_position` is
    /// set, the in-loop execution position) and saves the snapshot.
    async fn on_iteration_progress(
        &mut self,
        loop_id: &str,
        next_iteration: u32,
        body: &WorkflowContinuation,
    ) -> Result<(), RuntimeError> {
        self.snapshot.set_loop_iteration(loop_id, next_iteration);

        if self.track_position {
            let position = ExecutionPosition::InLoop {
                loop_id: loop_id.to_string(),
                iteration: next_iteration,
                next_task_id: Some(body.first_task_id().to_string()),
            };
            self.snapshot.update_position(position);
        }

        self.backend.save_snapshot(self.snapshot).await?;
        Ok(())
    }
}
/// Drives a loop body until it reports completion or the iteration bound is hit.
///
/// Starting at `cfg.start_iteration`, each pass awaits `execute_body` with the
/// current input, calls `hooks.clear_body_tasks`, and then interprets the
/// output via `resolve_loop_iteration`:
///
/// * on exit, `hooks.on_loop_exit` is awaited and the payload is returned;
/// * otherwise `hooks.on_iteration_progress` is awaited with the next
///   iteration index and the payload becomes the next input.
///
/// # Errors
///
/// Propagates any error from the body, the hooks, or
/// `resolve_loop_iteration`. Returns `WorkflowError::MaxIterationsExceeded`
/// when the iteration range is exhausted without an exit (including when
/// `start_iteration >= max_iterations`, in which case the body never runs).
#[tracing::instrument(
    name = "loop",
    skip_all,
    fields(loop_id = %cfg.id),
)]
pub(crate) async fn run_loop_async<F, Fut, H>(
    cfg: &LoopConfig<'_>,
    initial_input: Bytes,
    execute_body: F,
    hooks: &mut H,
) -> Result<Bytes, RuntimeError>
where
    F: Fn(Bytes) -> Fut,
    Fut: Future<Output = Result<Bytes, RuntimeError>>,
    H: LoopHooks,
{
    tracing::debug!("starting loop execution");

    let mut current_input = initial_input;
    let mut iteration = cfg.start_iteration;

    while iteration < cfg.max_iterations {
        let body_output = execute_body(current_input.clone()).await?;
        hooks.clear_body_tasks(cfg.body);

        let carried = match resolve_loop_iteration(&body_output, iteration, cfg)? {
            ControlFlow::Break(LoopExit(result)) => {
                hooks.on_loop_exit(cfg.id, &result).await?;
                return Ok(result);
            }
            ControlFlow::Continue(LoopNext(next_input)) => next_input,
        };

        // `iteration < max_iterations <= u32::MAX`, so this cannot overflow.
        iteration += 1;
        hooks
            .on_iteration_progress(cfg.id, iteration, cfg.body)
            .await?;
        current_input = carried;
    }

    Err(WorkflowError::MaxIterationsExceeded {
        loop_id: cfg.id.to_string(),
        max_iterations: cfg.max_iterations,
    }
    .into())
}