use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use taquba::{EnqueueOptions, JobRecord, PermanentFailure, Queue, Worker, WorkerError};
use tokio::sync::Mutex;
use tracing::{debug, instrument, warn};
use crate::error::{Error, Result};
use crate::runner::{Step, StepError, StepErrorKind, StepOutcome, StepRunner};
use crate::terminal::{RunOutcome, TerminalHook, TerminalStatus};
/// Header key under which the runtime stores the run id on every step job.
pub const HEADER_RUN_ID: &str = "workflow.run_id";
/// Header key under which the runtime stores the step number (the first step is `0`).
pub const HEADER_STEP: &str = "workflow.step";
/// Prefix of header keys owned by the runtime: `submit` rejects user headers that start
/// with it, and such keys are stripped before headers are handed to the step runner.
pub const RESERVED_HEADER_PREFIX: &str = "workflow.";
/// Prefix of the dedup key (`run:<run_id>:<step>`) attached to each enqueued step job.
const DEDUP_PREFIX: &str = "run:";
/// Scheduling options for one step job; `enqueue_step` copies them into `EnqueueOptions`.
#[derive(Debug, Default)]
struct StepEnqueueOpts {
/// Forwarded as `EnqueueOptions::run_at`; set when a step asks to continue after a delay.
run_at: Option<SystemTime>,
/// Forwarded as `EnqueueOptions::priority`.
priority: Option<u32>,
/// Forwarded as `EnqueueOptions::max_attempts`.
max_attempts: Option<u32>,
}
/// Description of a new workflow run, passed to `WorkflowRuntime::submit`.
#[derive(Debug, Clone, Default)]
pub struct RunSpec {
/// Identifier for the run; when `None`, `submit` generates a ULID.
pub run_id: Option<String>,
/// Payload enqueued for step 0.
pub input: Vec<u8>,
/// User headers copied onto every step job of the run. `submit` rejects any key that
/// starts with the reserved `workflow.` prefix.
pub headers: HashMap<String, String>,
/// Priority passed to the queue for step 0; later steps reuse the priority recorded on
/// the job that produced them.
pub priority: Option<u32>,
/// Attempt limit passed to the queue for step 0; later steps reuse the `max_attempts`
/// recorded on the job that produced them.
pub max_attempts_per_step: Option<u32>,
}
/// Result of a successful `WorkflowRuntime::submit`.
#[derive(Debug, Clone)]
pub struct RunHandle {
/// Id of the run (caller-supplied or generated by `submit`).
pub run_id: String,
/// Queue job id returned when step 0 was enqueued.
pub first_job_id: String,
}
/// Snapshot of a run as tracked in the runtime's in-memory registry.
#[derive(Debug, Clone)]
pub struct RunStatus {
/// Id of the run this snapshot describes.
pub run_id: String,
/// Last state recorded for the run.
pub state: RunState,
/// Step number last recorded: the step being processed when `Running`, or the step
/// that was most recently enqueued when `Pending`.
pub current_step: u32,
}
/// Coarse state of a run that is still present in the registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RunState {
/// Recorded after a step job has been enqueued (on submit or when advancing to the next step).
Pending,
/// Recorded when a worker starts processing a step job.
Running,
}
/// Builder for [`WorkflowRuntime`], created by `WorkflowRuntime::builder`.
pub struct WorkflowRuntimeBuilder<R, H> {
// Queue that step jobs are enqueued on and consumed from.
queue: Arc<Queue>,
// Name of the queue used for step jobs.
queue_name: String,
// Executes individual steps.
runner: R,
// Notified when a run reaches a terminal outcome.
terminal_hook: H,
// Concurrency limit handed to the worker loop.
max_concurrent_steps: usize,
// Poll interval handed to the worker loop.
poll_interval: Duration,
}
impl<R: StepRunner, H: TerminalHook> WorkflowRuntimeBuilder<R, H> {
    /// Overrides the name of the queue that step jobs are enqueued on and consumed from.
    pub fn queue_name(self, name: impl Into<String>) -> Self {
        Self {
            queue_name: name.into(),
            ..self
        }
    }

    /// Sets the concurrency limit passed to the worker loop.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    pub fn max_concurrent_steps(self, n: usize) -> Self {
        assert!(n > 0, "max_concurrent_steps must be at least 1");
        Self {
            max_concurrent_steps: n,
            ..self
        }
    }

    /// Sets the poll interval passed to the worker loop.
    pub fn poll_interval(self, interval: Duration) -> Self {
        Self {
            poll_interval: interval,
            ..self
        }
    }

    /// Finishes the builder, producing a runtime with an empty run registry.
    pub fn build(self) -> WorkflowRuntime<R, H> {
        let Self {
            queue,
            queue_name,
            runner,
            terminal_hook,
            max_concurrent_steps,
            poll_interval,
        } = self;
        WorkflowRuntime {
            inner: Arc::new(RuntimeInner {
                queue,
                queue_name,
                runner,
                terminal_hook,
                max_concurrent_steps,
                poll_interval,
                registry: Mutex::new(HashMap::new()),
            }),
        }
    }
}
/// Handle to the workflow runtime. Clones share the same inner state through an `Arc`.
pub struct WorkflowRuntime<R, H> {
// Shared state: queue, runner, terminal hook, settings and the run registry.
inner: Arc<RuntimeInner<R, H>>,
}
impl<R, H> Clone for WorkflowRuntime<R, H> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// State shared by every clone of a `WorkflowRuntime` and by the step worker.
struct RuntimeInner<R, H> {
// Queue that step jobs are enqueued on and consumed from.
queue: Arc<Queue>,
// Name of the queue used for step jobs.
queue_name: String,
// Executes individual steps.
runner: R,
// Notified when a run reaches a terminal outcome.
terminal_hook: H,
// Concurrency limit handed to the worker loop.
max_concurrent_steps: usize,
// Poll interval handed to the worker loop.
poll_interval: Duration,
// Run id -> last recorded status. Entries are inserted on submit, updated as steps are
// processed, and removed when a run reaches a terminal outcome.
registry: Mutex<HashMap<String, RunStatus>>,
}
impl<R: StepRunner, H: TerminalHook> WorkflowRuntime<R, H> {
/// Starts a builder with the defaults: queue name `"workflow-steps"`, 16 concurrent
/// steps, and a 250 ms poll interval.
pub fn builder(queue: Arc<Queue>, runner: R, terminal_hook: H) -> WorkflowRuntimeBuilder<R, H> {
    let queue_name = String::from("workflow-steps");
    let poll_interval = Duration::from_millis(250);
    WorkflowRuntimeBuilder {
        queue,
        queue_name,
        runner,
        terminal_hook,
        max_concurrent_steps: 16,
        poll_interval,
    }
}
#[instrument(skip(self, spec), fields(run_id))]
pub async fn submit(&self, spec: RunSpec) -> Result<RunHandle> {
let run_id = spec.run_id.unwrap_or_else(|| ulid::Ulid::new().to_string());
tracing::Span::current().record("run_id", run_id.as_str());
for k in spec.headers.keys() {
if k.starts_with(RESERVED_HEADER_PREFIX) {
return Err(Error::ReservedHeaderInSubmit(k.clone()));
}
}
let mut registry = self.inner.registry.lock().await;
if registry.contains_key(&run_id) {
return Err(Error::DuplicateRun(run_id));
}
let job_id = self
.inner
.enqueue_step(
&run_id,
0,
spec.input,
&spec.headers,
StepEnqueueOpts {
priority: spec.priority,
max_attempts: spec.max_attempts_per_step,
..Default::default()
},
)
.await?;
registry.insert(
run_id.clone(),
RunStatus {
run_id: run_id.clone(),
state: RunState::Pending,
current_step: 0,
},
);
drop(registry);
debug!(run_id = %run_id, job_id = %job_id, "run submitted");
Ok(RunHandle {
run_id,
first_job_id: job_id,
})
}
/// Returns a copy of the registry entry for `run_id`, or `None` when this process is not
/// tracking such a run (never submitted here, or already terminated).
pub async fn status(&self, run_id: &str) -> Option<RunStatus> {
    let runs = self.inner.registry.lock().await;
    runs.get(run_id).cloned()
}
/// Drives the step worker by delegating to `taquba::run_worker_concurrent` with this
/// runtime's queue, queue name, concurrency limit, poll interval and the `shutdown` future.
///
/// # Errors
///
/// Propagates any error returned by the worker loop.
pub async fn run<F>(&self, shutdown: F) -> Result<()>
where
    F: Future<Output = ()>,
    R: 'static,
    H: 'static,
{
    let inner = &self.inner;
    let worker = Arc::new(StepWorker {
        inner: Arc::clone(inner),
    });
    taquba::run_worker_concurrent(
        &inner.queue,
        &inner.queue_name,
        worker,
        inner.max_concurrent_steps,
        inner.poll_interval,
        shutdown,
    )
    .await?;
    Ok(())
}
}
/// Adapter that lets the shared runtime state be used as a taquba `Worker`.
struct StepWorker<R, H> {
// Shared runtime state that actually processes each job.
inner: Arc<RuntimeInner<R, H>>,
}
impl<R: StepRunner + 'static, H: TerminalHook + 'static> Worker for StepWorker<R, H> {
async fn process(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
self.inner.process_step(job).await
}
}
impl<R: StepRunner, H: TerminalHook> RuntimeInner<R, H> {
async fn enqueue_step(
&self,
run_id: &str,
step_number: u32,
payload: Vec<u8>,
user_headers: &HashMap<String, String>,
opts: StepEnqueueOpts,
) -> Result<String> {
let mut headers = user_headers.clone();
headers.insert(HEADER_RUN_ID.to_string(), run_id.to_string());
headers.insert(HEADER_STEP.to_string(), step_number.to_string());
let enqueue_opts = EnqueueOptions {
headers,
run_at: opts.run_at,
priority: opts.priority,
max_attempts: opts.max_attempts,
dedup_key: Some(format!("{DEDUP_PREFIX}{run_id}:{step_number}")),
};
Ok(self
.queue
.enqueue_with(&self.queue_name, payload, enqueue_opts)
.await?)
}
/// Returns a copy of `headers` without the entries whose key starts with the reserved
/// prefix, i.e. only the headers supplied by the user.
fn split_headers(headers: &HashMap<String, String>) -> HashMap<String, String> {
    let mut user_headers = HashMap::new();
    for (name, value) in headers {
        if name.starts_with(RESERVED_HEADER_PREFIX) {
            continue;
        }
        user_headers.insert(name.clone(), value.clone());
    }
    user_headers
}
fn parse_step_headers(job: &JobRecord) -> std::result::Result<(String, u32), Error> {
let run_id = job
.headers
.get(HEADER_RUN_ID)
.ok_or(Error::MissingHeader(HEADER_RUN_ID))?
.to_string();
let step_str = job
.headers
.get(HEADER_STEP)
.ok_or(Error::MissingHeader(HEADER_STEP))?;
let step_number: u32 = step_str.parse().map_err(|_| Error::InvalidStepHeader {
header: HEADER_STEP,
value: step_str.clone(),
})?;
Ok((run_id, step_number))
}
/// Notifies the terminal hook of a run's final outcome and waits for it to return.
async fn fire_terminal_hook(&self, outcome: RunOutcome) {
    let hook = &self.terminal_hook;
    hook.on_termination(&outcome).await;
}
/// Drops the registry entry for `run_id`; a no-op when there is none.
async fn registry_remove(&self, run_id: &str) {
    let mut runs = self.registry.lock().await;
    runs.remove(run_id);
}
/// Inserts `status` into the registry under its own run id, replacing any previous entry.
async fn registry_set(&self, status: RunStatus) {
    let key = status.run_id.clone();
    let mut runs = self.registry.lock().await;
    runs.insert(key, status);
}
async fn process_step(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
let (run_id, step_number) = match Self::parse_step_headers(job) {
Ok(v) => v,
Err(e) => {
warn!(job_id = %job.id, error = %e, "workflow step has malformed headers");
if e.is_permanent() {
return Err(PermanentFailure::new(e.to_string()).into());
}
return Err(e.to_string().into());
}
};
let user_headers = Self::split_headers(&job.headers);
self.registry_set(RunStatus {
run_id: run_id.clone(),
state: RunState::Running,
current_step: step_number,
})
.await;
let step = Step {
run_id: run_id.clone(),
step_number,
payload: job.payload.clone(),
headers: user_headers.clone(),
job_id: job.id.clone(),
attempts: job.attempts,
};
let inherit_opts = || StepEnqueueOpts {
run_at: None,
priority: Some(job.priority),
max_attempts: Some(job.max_attempts),
};
match self.runner.run_step(&step).await {
Ok(StepOutcome::Continue { payload }) => {
self.advance(
&run_id,
step_number + 1,
payload,
&user_headers,
inherit_opts(),
)
.await
}
Ok(StepOutcome::ContinueAfter { payload, delay }) => {
let opts = StepEnqueueOpts {
run_at: Some(SystemTime::now() + delay),
..inherit_opts()
};
self.advance(&run_id, step_number + 1, payload, &user_headers, opts)
.await
}
Ok(StepOutcome::Succeed { result }) => {
self.fire_terminal_hook(RunOutcome {
run_id: run_id.clone(),
status: TerminalStatus::Succeeded,
result: Some(result),
error: None,
headers: user_headers,
final_step: step_number,
})
.await;
self.registry_remove(&run_id).await;
Ok(())
}
Ok(StepOutcome::Fail { reason }) => {
self.fire_terminal_hook(RunOutcome {
run_id: run_id.clone(),
status: TerminalStatus::Failed,
result: None,
error: Some(reason),
headers: user_headers,
final_step: step_number,
})
.await;
self.registry_remove(&run_id).await;
Ok(())
}
Err(StepError {
message,
kind: StepErrorKind::Permanent,
}) => {
self.fire_terminal_hook(RunOutcome {
run_id: run_id.clone(),
status: TerminalStatus::Failed,
result: None,
error: Some(message.clone()),
headers: user_headers,
final_step: step_number,
})
.await;
self.registry_remove(&run_id).await;
Err(PermanentFailure::new(message).into())
}
Err(StepError {
message,
kind: StepErrorKind::Transient,
}) => {
if job.attempts >= job.max_attempts {
self.fire_terminal_hook(RunOutcome {
run_id: run_id.clone(),
status: TerminalStatus::Failed,
result: None,
error: Some(message.clone()),
headers: user_headers,
final_step: step_number,
})
.await;
self.registry_remove(&run_id).await;
}
Err(message.into())
}
}
}
/// Enqueues step `next_step` of `run_id` and records the run as `Pending` at that step.
///
/// # Errors
///
/// Returns the enqueue error (stringified) when the queue rejects the job; the registry
/// is left untouched in that case.
async fn advance(
    &self,
    run_id: &str,
    next_step: u32,
    payload: Vec<u8>,
    user_headers: &HashMap<String, String>,
    opts: StepEnqueueOpts,
) -> std::result::Result<(), WorkerError> {
    // Hold the registry lock across the enqueue, exactly as `submit` does. Once the job
    // is visible in the queue another worker may pick it up straight away; without the
    // lock it could record `Running` (or even finish the run and remove the entry)
    // before the `Pending` write below, leaving a stale or resurrected registry entry.
    let mut registry = self.registry.lock().await;
    match self
        .enqueue_step(run_id, next_step, payload, user_headers, opts)
        .await
    {
        Ok(_) => {
            registry.insert(
                run_id.to_string(),
                RunStatus {
                    run_id: run_id.to_string(),
                    state: RunState::Pending,
                    current_step: next_step,
                },
            );
            Ok(())
        }
        Err(e) => Err(e.to_string().into()),
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::terminal::NoopTerminalHook;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicU32, Ordering};
use taquba::object_store::memory::InMemory;
use taquba::{OpenOptions, QueueConfig};
use tokio::sync::oneshot;
/// Terminal hook that forwards a clone of every outcome to an unbounded channel so a
/// test can await it.
struct ChannelHook {
    tx: tokio::sync::mpsc::UnboundedSender<RunOutcome>,
}

impl TerminalHook for ChannelHook {
    async fn on_termination(&self, outcome: &RunOutcome) {
        let snapshot = outcome.clone();
        // A closed receiver only means the test has already finished; ignore it.
        let _ = self.tx.send(snapshot);
    }
}
/// Runner that replays a fixed list of outcomes, one per `run_step` call, front to back.
struct ScriptedRunner {
    script: Arc<StdMutex<Vec<StepOutcome>>>,
}

impl ScriptedRunner {
    /// Wraps `steps` so they can be consumed from `&self`.
    fn new(steps: Vec<StepOutcome>) -> Self {
        let script = Arc::new(StdMutex::new(steps));
        Self { script }
    }
}

impl StepRunner for ScriptedRunner {
    /// Pops the next scripted outcome; panics if the script is exhausted.
    async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
        let mut remaining = self.script.lock().unwrap();
        Ok(remaining.remove(0))
    }
}
/// Opens a queue named "test" on a new in-memory object store.
async fn fresh_queue() -> Arc<Queue> {
    let queue = Queue::open(Arc::new(InMemory::new()), "test").await;
    Arc::new(queue.unwrap())
}
/// Opens an in-memory queue with zero retry backoff and 50 ms reaper/scheduler intervals,
/// so retry-driven tests do not have to wait on the default timings.
async fn fresh_queue_fast_retry() -> Arc<Queue> {
    let tick = Duration::from_millis(50);
    let default_queue_config = QueueConfig {
        retry_backoff_base: Duration::ZERO,
        ..QueueConfig::default()
    };
    let opts = OpenOptions {
        default_queue_config,
        reaper_interval: tick,
        scheduler_interval: tick,
        ..OpenOptions::default()
    };
    let queue = Queue::open_with_options(Arc::new(InMemory::new()), "test", opts).await;
    Arc::new(queue.unwrap())
}
/// Spawns `runtime.run` on a background task and returns the sender that stops it.
/// Sending on (or dropping) the returned sender resolves the shutdown future.
fn spawn_runtime<R, H>(runtime: WorkflowRuntime<R, H>) -> oneshot::Sender<()>
where
    R: StepRunner + 'static,
    H: TerminalHook + 'static,
{
    let (stop_tx, stop_rx) = oneshot::channel::<()>();
    let until_stopped = async move {
        let _ = stop_rx.await;
    };
    tokio::spawn(async move {
        let _ = runtime.run(until_stopped).await;
    });
    stop_tx
}
/// A run whose only step succeeds fires the hook once with the step's result and is
/// removed from the registry.
#[tokio::test]
async fn single_step_succeeds_and_fires_hook() {
    let queue = fresh_queue().await;
    let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
    let runner = ScriptedRunner::new(vec![StepOutcome::Succeed {
        result: b"done".to_vec(),
    }]);
    let runtime = WorkflowRuntime::builder(queue, runner, ChannelHook { tx: hook_tx }).build();
    let shutdown = spawn_runtime(runtime.clone());

    let spec = RunSpec {
        input: b"in".to_vec(),
        ..Default::default()
    };
    let handle = runtime.submit(spec).await.unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(outcome.run_id, handle.run_id);
    assert_eq!(outcome.status, TerminalStatus::Succeeded);
    assert_eq!(outcome.result.as_deref(), Some(b"done".as_slice()));
    assert_eq!(outcome.final_step, 0);
    assert!(runtime.status(&handle.run_id).await.is_none());

    let _ = shutdown.send(());
}
/// Two `Continue` outcomes followed by `Succeed` end the run at step 2 with the final result.
#[tokio::test]
async fn multi_step_run_advances_through_continue() {
    let queue = fresh_queue().await;
    let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
    let script = vec![
        StepOutcome::Continue {
            payload: b"step1".to_vec(),
        },
        StepOutcome::Continue {
            payload: b"step2".to_vec(),
        },
        StepOutcome::Succeed {
            result: b"final".to_vec(),
        },
    ];
    let runtime = WorkflowRuntime::builder(
        queue,
        ScriptedRunner::new(script),
        ChannelHook { tx: hook_tx },
    )
    .build();
    let shutdown = spawn_runtime(runtime.clone());

    let spec = RunSpec {
        input: b"start".to_vec(),
        ..Default::default()
    };
    let handle = runtime.submit(spec).await.unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(outcome.run_id, handle.run_id);
    assert_eq!(outcome.final_step, 2);
    assert_eq!(outcome.status, TerminalStatus::Succeeded);
    assert_eq!(outcome.result.as_deref(), Some(b"final".as_slice()));

    let _ = shutdown.send(());
}
/// A permanent step error fails the run, fires the hook, clears the registry entry and
/// leaves exactly one dead job on the queue.
#[tokio::test]
async fn permanent_failure_dead_letters_and_fires_hook() {
    struct FailingRunner;
    impl StepRunner for FailingRunner {
        async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
            Err(StepError::permanent("nope"))
        }
    }

    let queue = fresh_queue().await;
    let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
    let hook = ChannelHook { tx: hook_tx };
    let runtime = WorkflowRuntime::builder(queue.clone(), FailingRunner, hook).build();
    let shutdown = spawn_runtime(runtime.clone());

    let spec = RunSpec {
        input: b"x".to_vec(),
        ..Default::default()
    };
    let handle = runtime.submit(spec).await.unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(outcome.run_id, handle.run_id);
    assert_eq!(outcome.status, TerminalStatus::Failed);
    assert_eq!(outcome.error.as_deref(), Some("nope"));
    assert!(runtime.status(&handle.run_id).await.is_none());

    let stats = queue.stats("workflow-steps").await.unwrap();
    assert_eq!(stats.dead, 1, "permanent error should dead-letter");

    let _ = shutdown.send(());
}
/// A `Fail` verdict ends the run as failed and fires the hook, but the job itself is
/// treated as processed: nothing lands in the dead-letter count.
#[tokio::test]
async fn fail_outcome_terminates_run_without_dead_letter() {
    struct VerdictRunner;
    impl StepRunner for VerdictRunner {
        async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
            let reason = "agent declined the task".to_string();
            Ok(StepOutcome::Fail { reason })
        }
    }

    let queue = fresh_queue().await;
    let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
    let hook = ChannelHook { tx: hook_tx };
    let runtime = WorkflowRuntime::builder(queue.clone(), VerdictRunner, hook).build();
    let shutdown = spawn_runtime(runtime.clone());

    let spec = RunSpec {
        input: b"x".to_vec(),
        ..Default::default()
    };
    let handle = runtime.submit(spec).await.unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
        .await
        .expect("hook fired in time")
        .expect("hook channel open");
    assert_eq!(outcome.run_id, handle.run_id);
    assert_eq!(outcome.status, TerminalStatus::Failed);
    assert_eq!(outcome.error.as_deref(), Some("agent declined the task"));
    assert!(runtime.status(&handle.run_id).await.is_none());

    let stats = queue.stats("workflow-steps").await.unwrap();
    assert_eq!(stats.dead, 0, "Fail verdict must not dead-letter");

    let _ = shutdown.send(());
}
#[tokio::test]
async fn duplicate_submit_in_process_is_rejected() {
struct PauseRunner;
impl StepRunner for PauseRunner {
async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
std::future::pending().await
}
}
let queue = fresh_queue().await;
let runtime = WorkflowRuntime::builder(queue, PauseRunner, NoopTerminalHook).build();
let shutdown = spawn_runtime(runtime.clone());
let handle = runtime
.submit(RunSpec {
run_id: Some("fixed-id".to_string()),
input: b"x".to_vec(),
..Default::default()
})
.await
.unwrap();
for _ in 0..40 {
if runtime.status(&handle.run_id).await.is_some() {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
assert!(runtime.status(&handle.run_id).await.is_some());
let err = runtime
.submit(RunSpec {
run_id: Some("fixed-id".to_string()),
input: b"y".to_vec(),
..Default::default()
})
.await
.unwrap_err();
assert!(matches!(err, Error::DuplicateRun(id) if id == "fixed-id"));
let _ = shutdown.send(());
}
/// A user header under the reserved `workflow.` prefix makes `submit` fail with
/// `Error::ReservedHeaderInSubmit` naming the offending key.
#[tokio::test]
async fn reserved_header_on_submit_is_rejected() {
    let queue = fresh_queue().await;
    let runtime = WorkflowRuntime::<ScriptedRunner, NoopTerminalHook>::builder(
        queue,
        ScriptedRunner::new(vec![]),
        NoopTerminalHook,
    )
    .build();

    let headers = HashMap::from([("workflow.run_id".to_string(), "evil".to_string())]);
    let spec = RunSpec {
        input: b"x".to_vec(),
        headers,
        ..Default::default()
    };
    let err = runtime.submit(spec).await.unwrap_err();

    assert!(
        matches!(&err, Error::ReservedHeaderInSubmit(k) if k == "workflow.run_id"),
        "got: {err:?}"
    );
}
/// User headers given at submit time survive a step transition and reach the terminal
/// hook, while the runtime's own reserved headers are not exposed.
#[tokio::test]
async fn user_headers_thread_through_to_terminal_hook() {
    let queue = fresh_queue().await;
    let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
    let script = vec![
        StepOutcome::Continue { payload: vec![] },
        StepOutcome::Succeed { result: vec![] },
    ];
    let runtime = WorkflowRuntime::builder(
        queue,
        ScriptedRunner::new(script),
        ChannelHook { tx: hook_tx },
    )
    .build();
    let shutdown = spawn_runtime(runtime.clone());

    let headers = HashMap::from([
        ("trace_id".to_string(), "abc-123".to_string()),
        ("tenant".to_string(), "acme".to_string()),
    ]);
    let spec = RunSpec {
        input: b"x".to_vec(),
        headers,
        ..Default::default()
    };
    runtime.submit(spec).await.unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(outcome.headers.get("trace_id").unwrap(), "abc-123");
    assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
    assert!(!outcome.headers.contains_key(HEADER_RUN_ID));
    assert!(!outcome.headers.contains_key(HEADER_STEP));

    let _ = shutdown.send(());
}
/// Simulates a process restart: runtime A completes step 0 while shutting down, then a
/// second runtime B on the same queue picks the run up at step 1 and finishes it.
#[tokio::test]
async fn restart_resumes_at_next_step() {
// Runner for runtime A: step 0 blocks on a one-shot gate and continues with whatever
// payload the test sends through it; any later step never completes.
struct GatedRunner {
gate: tokio::sync::Mutex<Option<oneshot::Receiver<Vec<u8>>>>,
}
impl StepRunner for GatedRunner {
async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
match step.step_number {
0 => {
let rx = self.gate.lock().await.take().expect("gate consumed twice");
let payload = rx.await.expect("gate sender dropped");
Ok(StepOutcome::Continue { payload })
}
_ => std::future::pending().await,
}
}
}
// Runner for runtime B: asserts it is only handed step 1 with the payload produced by
// runtime A, then succeeds.
struct CompleteOnStep1;
impl StepRunner for CompleteOnStep1 {
async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
assert_eq!(step.step_number, 1, "runtime B should only ever see step 1");
assert_eq!(step.payload.as_slice(), b"step1-payload");
Ok(StepOutcome::Succeed {
result: b"resumed".to_vec(),
})
}
}
let queue = fresh_queue().await;
let (gate_tx, gate_rx) = oneshot::channel::<Vec<u8>>();
// Runtime A processes one step at a time.
let runtime_a = WorkflowRuntime::builder(
queue.clone(),
GatedRunner {
gate: tokio::sync::Mutex::new(Some(gate_rx)),
},
NoopTerminalHook,
)
.max_concurrent_steps(1)
.build();
// Runtime A is spawned by hand (not via `spawn_runtime`) so the test can await its task.
let (shutdown_a_tx, shutdown_a_rx) = oneshot::channel::<()>();
let worker_a = {
let runtime_a = runtime_a.clone();
tokio::spawn(async move {
let _ = runtime_a
.run(async move {
let _ = shutdown_a_rx.await;
})
.await;
})
};
let handle = runtime_a
.submit(RunSpec {
input: b"input".to_vec(),
..Default::default()
})
.await
.unwrap();
// Poll (up to 80 times, 25 ms apart) until runtime A reports step 0 as running.
for _ in 0..80 {
if let Some(s) = runtime_a.status(&handle.run_id).await {
if s.state == RunState::Running && s.current_step == 0 {
break;
}
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
let s = runtime_a.status(&handle.run_id).await.expect("status");
assert_eq!(s.state, RunState::Running);
assert_eq!(s.current_step, 0);
// Request shutdown first, then open the gate so the in-flight step 0 can finish and
// enqueue step 1; finally wait for runtime A's task to end.
let _ = shutdown_a_tx.send(());
let _ = gate_tx.send(b"step1-payload".to_vec());
worker_a.await.expect("runtime A drained cleanly");
// Runtime B shares the queue but starts with its own, empty registry.
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let runtime_b =
WorkflowRuntime::builder(queue, CompleteOnStep1, ChannelHook { tx }).build();
let shutdown_b = spawn_runtime(runtime_b.clone());
let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
.await
.expect("hook fired in time")
.expect("hook channel open");
assert_eq!(outcome.run_id, handle.run_id);
assert_eq!(outcome.status, TerminalStatus::Succeeded);
assert_eq!(outcome.result.as_deref(), Some(b"resumed".as_slice()));
assert_eq!(outcome.final_step, 1);
let _ = shutdown_b.send(());
}
/// Shared body for the transient-error tests: with a runner that always fails
/// transiently, the runner must be invoked exactly `max_attempts` times and the terminal
/// hook must fire exactly once, reporting a failed run.
async fn assert_transient_retries_until_max(max_attempts: u32) {
    // Counts invocations and always reports a transient error.
    struct AlwaysTransient {
        calls: Arc<AtomicU32>,
    }
    impl StepRunner for AlwaysTransient {
        async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Err(StepError::transient("flaky"))
        }
    }

    let queue = fresh_queue_fast_retry().await;
    let calls = Arc::new(AtomicU32::new(0));
    let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
    let runner = AlwaysTransient {
        calls: Arc::clone(&calls),
    };
    let runtime = WorkflowRuntime::builder(queue, runner, ChannelHook { tx: hook_tx }).build();
    let shutdown = spawn_runtime(runtime.clone());

    let spec = RunSpec {
        input: b"x".to_vec(),
        max_attempts_per_step: Some(max_attempts),
        ..Default::default()
    };
    runtime.submit(spec).await.unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(3), hook_rx.recv())
        .await
        .expect("hook fired in time")
        .expect("hook channel open");
    assert_eq!(outcome.status, TerminalStatus::Failed);
    assert_eq!(outcome.error.as_deref(), Some("flaky"));
    assert_eq!(
        calls.load(Ordering::SeqCst),
        max_attempts,
        "runner called once per attempt up to max_attempts"
    );

    // Give a spurious second hook invocation a moment to show up before checking.
    tokio::time::sleep(Duration::from_millis(50)).await;
    assert!(hook_rx.try_recv().is_err(), "hook fired more than once");

    let _ = shutdown.send(());
}
/// With a single allowed attempt, the first transient error must end the run and fire
/// the hook exactly once.
#[tokio::test]
async fn transient_fires_once_on_single_attempt() {
assert_transient_retries_until_max(1).await;
}
/// With three allowed attempts, the runner must be called three times before the run is
/// reported as failed, and the hook must still fire only once.
#[tokio::test]
async fn transient_retries_up_to_max_attempts() {
assert_transient_retries_until_max(3).await;
}
}