use cloacina::database::universal_types::UniversalUuid;
use cloacina::executor::WorkflowExecutor;
use cloacina::runner::DefaultRunner;
use cloacina::*;
use serde_json::Value;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use crate::fixtures::get_or_init_fixture;
/// Task that suspends itself with `defer_until` and resumes once a polled
/// condition reports ready.
///
/// Every evaluation of the condition bumps a shared counter. It reports ready
/// when the counter already held 2 before the increment, i.e. on the third
/// evaluation. A 10 ms interval is passed to `defer_until`; this is presumably
/// the delay between checks (NOTE(review): confirm against `TaskHandle` docs).
///
/// After resuming, the task writes `deferred_result` and the final
/// `poll_count` into the context.
///
/// # Errors
/// A failure from `defer_until` is reported as `TaskError::ExecutionFailed`.
#[task(id = "deferred_flag_task", dependencies = [])]
async fn deferred_flag_task(
    context: &mut Context<Value>,
    handle: &mut TaskHandle,
) -> Result<(), TaskError> {
    let polls = Arc::new(AtomicUsize::new(0));

    // Each call hands out its own Arc clone so the returned future is 'static.
    let condition = {
        let polls = Arc::clone(&polls);
        move || {
            let counter = Arc::clone(&polls);
            async move { counter.fetch_add(1, Ordering::SeqCst) >= 2 }
        }
    };

    if let Err(e) = handle
        .defer_until(condition, Duration::from_millis(10))
        .await
    {
        return Err(TaskError::ExecutionFailed {
            message: format!("defer_until failed: {e}"),
            task_id: "deferred_flag_task".into(),
            timestamp: chrono::Utc::now(),
        });
    }

    context.insert(
        "deferred_result",
        Value::String("resumed_after_defer".into()),
    )?;
    // `Value::from(usize)` yields the same JSON number as `Number::from(u64)`.
    context.insert("poll_count", Value::from(polls.load(Ordering::SeqCst)))?;
    Ok(())
}
/// Downstream task that runs after `deferred_flag_task`.
///
/// If the context contains `deferred_result`, this writes `chain_result` as
/// `"chained: "` followed by that value's `Display` output. If the key is
/// absent, it does nothing.
///
/// NOTE(review): `Display` on a JSON `Value` emits JSON text, so a string value
/// keeps its surrounding quotes. Confirm that this is intended.
#[task(id = "after_deferred_task", dependencies = ["deferred_flag_task"])]
async fn after_deferred_task(context: &mut Context<Value>) -> Result<(), TaskError> {
    // Build the owned replacement first so the read borrow ends before the insert.
    let chained = context
        .get("deferred_result")
        .map(|val| Value::String(format!("chained: {}", val)));
    if let Some(chained) = chained {
        context.insert("chain_result", chained)?;
    }
    Ok(())
}
/// Long-running deferred task used to observe `sub_status` while deferral is
/// in progress.
///
/// The polled condition bumps a shared counter on every evaluation. It reports
/// ready once the counter held 4 before the increment, i.e. on the fifth
/// evaluation. A 200 ms interval is passed to `defer_until`; this is presumably
/// the spacing between checks (NOTE(review): confirm).
///
/// On resumption, the task writes `slow_deferred_result = "completed"`.
///
/// # Errors
/// A failure from `defer_until` is reported as `TaskError::ExecutionFailed`.
#[task(id = "slow_deferred_task", dependencies = [])]
async fn slow_deferred_task(
    context: &mut Context<Value>,
    handle: &mut TaskHandle,
) -> Result<(), TaskError> {
    let polls = Arc::new(AtomicUsize::new(0));

    // Each call hands out its own Arc clone so the returned future is 'static.
    let condition = {
        let polls = Arc::clone(&polls);
        move || {
            let counter = Arc::clone(&polls);
            async move { counter.fetch_add(1, Ordering::SeqCst) >= 4 }
        }
    };

    if let Err(e) = handle
        .defer_until(condition, Duration::from_millis(200))
        .await
    {
        return Err(TaskError::ExecutionFailed {
            message: format!("defer_until failed: {e}"),
            task_id: "slow_deferred_task".into(),
            timestamp: chrono::Utc::now(),
        });
    }

    context.insert("slow_deferred_result", Value::String("completed".into()))?;
    Ok(())
}
use async_trait::async_trait;
/// Minimal task used to describe workflow topology in these tests.
///
/// `execute` hands back the input context unchanged. The tests register the
/// real `#[task]` implementations separately on the runtime under matching
/// namespaces.
#[derive(Debug)]
struct SimpleTask {
    /// Task identifier, as given to the constructor.
    id: String,
    /// Namespaces of the upstream tasks this task depends on.
    dependencies: Vec<TaskNamespace>,
}

impl SimpleTask {
    /// Creates a task whose dependencies are namespace strings parsed by
    /// `TaskNamespace::from_string`.
    ///
    /// # Panics
    /// Panics if any dependency string cannot be parsed. The message names the
    /// offending string.
    fn new(id: &str, deps: Vec<&str>) -> Self {
        let dependencies = deps
            .into_iter()
            .map(|dep| {
                TaskNamespace::from_string(dep)
                    .unwrap_or_else(|e| panic!("invalid dependency namespace {dep:?}: {e:?}"))
            })
            .collect();
        Self {
            id: id.to_string(),
            dependencies,
        }
    }

    /// Creates a task whose dependencies are bare task ids. Each id is placed
    /// under tenant `"public"`, package `"embedded"` and the given workflow name.
    fn with_workflow(id: &str, deps: Vec<&str>, workflow_name: &str) -> Self {
        Self {
            id: id.to_string(),
            dependencies: deps
                .into_iter()
                .map(|dep| TaskNamespace::new("public", "embedded", workflow_name, dep))
                .collect(),
        }
    }
}

#[async_trait]
impl Task for SimpleTask {
    /// Returns the task's identifier.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns the namespaces this task depends on.
    fn dependencies(&self) -> &[TaskNamespace] {
        self.dependencies.as_slice()
    }

    /// No-op execution: returns the incoming context unchanged.
    async fn execute(&self, context: Context<Value>) -> Result<Context<Value>, TaskError> {
        Ok(context)
    }
}
/// End-to-end check that a task using `defer_until` resumes and completes.
///
/// The test builds `defer_pipeline` with a single `deferred_flag_task` and
/// registers the `#[task]` implementation under the workflow's namespace. It
/// runs the workflow through a `DefaultRunner` and polls for up to 10 s for the
/// lone task execution to report `Completed`. Finally, it asserts that the
/// task's `sub_status` has been cleared.
#[tokio::test]
async fn test_defer_until_full_workflow() {
    let shared_fixture = get_or_init_fixture().await;
    // Recover from a poisoned lock (left by an earlier panicking test) instead of failing here.
    let mut fixture = shared_fixture.lock().unwrap_or_else(|e| e.into_inner());
    fixture.reset_database().await;
    fixture.initialize().await;
    let db_url = fixture.get_database_url();
    let db = fixture.get_database();

    let workflow = Workflow::builder("defer_pipeline")
        .description("Workflow with deferred task")
        .add_task(Arc::new(SimpleTask::new("deferred_flag_task", vec![])))
        .unwrap()
        .build()
        .unwrap();

    // Bind the real task implementation to the same namespace the workflow uses.
    let runtime = cloacina::Runtime::empty();
    let task_ns = TaskNamespace::new(
        workflow.tenant(),
        workflow.package(),
        workflow.name(),
        "deferred_flag_task",
    );
    runtime.register_task(task_ns, || {
        Arc::new(deferred_flag_task_task()) as Arc<dyn cloacina::Task>
    });
    let workflow_template = workflow.clone();
    runtime.register_workflow("defer_pipeline".to_string(), move || {
        workflow_template.clone()
    });

    let schema = fixture.get_schema();
    let runner = DefaultRunner::builder()
        .database_url(&db_url)
        .schema(&schema)
        .runtime(runtime)
        .build()
        .await
        .unwrap();

    let execution = runner
        .execute_async("defer_pipeline", Context::new())
        .await
        .unwrap();
    let exec_id = execution.execution_id;
    let dal = cloacina::dal::DAL::new(db.clone());

    crate::fixtures::poll_until(
        Duration::from_secs(10),
        Duration::from_millis(100),
        "deferred task should complete",
        || {
            let dal = dal.clone();
            async move {
                let rows = dal
                    .task_execution()
                    .get_all_tasks_for_workflow(UniversalUuid(exec_id))
                    .await
                    .unwrap_or_default();
                // Exactly one row, and it has completed.
                matches!(rows.as_slice(), [only] if only.status == "Completed")
            }
        },
    )
    .await;

    let task_executions = dal
        .task_execution()
        .get_all_tasks_for_workflow(UniversalUuid(exec_id))
        .await
        .unwrap();
    assert_eq!(task_executions.len(), 1, "Expected 1 task execution");
    let task = &task_executions[0];
    assert_eq!(task.status, "Completed", "Deferred task should complete");
    assert!(
        task.sub_status.is_none(),
        "sub_status should be None after completion, got: {:?}",
        task.sub_status
    );

    runner.shutdown().await.unwrap();
}
/// Checks that a task depending on a deferred task runs once the deferred one
/// resumes.
///
/// The test builds `defer_chain_pipeline` (`deferred_flag_task` →
/// `after_deferred_task`) and registers both `#[task]` implementations. It
/// polls for up to 10 s until both task executions report `Completed`, then
/// asserts the final state.
#[tokio::test]
async fn test_defer_until_with_downstream_dependency() {
    let shared_fixture = get_or_init_fixture().await;
    // Recover from a poisoned lock (left by an earlier panicking test) instead of failing here.
    let mut fixture = shared_fixture.lock().unwrap_or_else(|e| e.into_inner());
    fixture.reset_database().await;
    fixture.initialize().await;
    let db_url = fixture.get_database_url();
    let db = fixture.get_database();

    // The downstream task's dependency is namespaced under this workflow's name.
    let workflow = Workflow::builder("defer_chain_pipeline")
        .description("Workflow with deferred task and downstream dependency")
        .add_task(Arc::new(SimpleTask::new("deferred_flag_task", vec![])))
        .unwrap()
        .add_task(Arc::new(SimpleTask::with_workflow(
            "after_deferred_task",
            vec!["deferred_flag_task"],
            "defer_chain_pipeline",
        )))
        .unwrap()
        .build()
        .unwrap();

    let runtime = cloacina::Runtime::empty();
    let upstream_ns = TaskNamespace::new(
        workflow.tenant(),
        workflow.package(),
        workflow.name(),
        "deferred_flag_task",
    );
    runtime.register_task(upstream_ns, || {
        Arc::new(deferred_flag_task_task()) as Arc<dyn cloacina::Task>
    });
    let downstream_ns = TaskNamespace::new(
        workflow.tenant(),
        workflow.package(),
        workflow.name(),
        "after_deferred_task",
    );
    runtime.register_task(downstream_ns, || {
        Arc::new(after_deferred_task_task()) as Arc<dyn cloacina::Task>
    });
    let workflow_template = workflow.clone();
    runtime.register_workflow("defer_chain_pipeline".to_string(), move || {
        workflow_template.clone()
    });

    let schema = fixture.get_schema();
    let runner = DefaultRunner::builder()
        .database_url(&db_url)
        .schema(&schema)
        .runtime(runtime)
        .build()
        .await
        .unwrap();

    let execution = runner
        .execute_async("defer_chain_pipeline", Context::new())
        .await
        .unwrap();
    let exec_id = execution.execution_id;
    let dal = cloacina::dal::DAL::new(db.clone());

    crate::fixtures::poll_until(
        Duration::from_secs(10),
        Duration::from_millis(100),
        "both deferred and downstream tasks should complete",
        || {
            let dal = dal.clone();
            async move {
                let rows = dal
                    .task_execution()
                    .get_all_tasks_for_workflow(UniversalUuid(exec_id))
                    .await
                    .unwrap_or_default();
                // Two rows, none of which is still in a non-Completed state.
                rows.len() == 2 && !rows.iter().any(|row| row.status != "Completed")
            }
        },
    )
    .await;

    let task_executions = dal
        .task_execution()
        .get_all_tasks_for_workflow(UniversalUuid(exec_id))
        .await
        .unwrap();
    assert_eq!(task_executions.len(), 2, "Expected 2 task executions");
    for row in task_executions.iter() {
        assert_eq!(
            row.status, "Completed",
            "Task '{}' should be Completed, got '{}'",
            row.task_name, row.status
        );
    }

    runner.shutdown().await.unwrap();
}
/// Observes `sub_status` while `slow_deferred_task` is waiting in
/// `defer_until`, and again after it completes.
///
/// The test samples the task row every 100 ms, at most 30 times, until it sees
/// `sub_status == "Deferred"`. It then waits up to 10 s for completion and
/// asserts that `sub_status` has been cleared.
#[tokio::test]
async fn test_sub_status_transitions_during_deferral() {
    let shared_fixture = get_or_init_fixture().await;
    // Recover from a poisoned lock (left by an earlier panicking test) instead of failing here.
    let mut fixture = shared_fixture.lock().unwrap_or_else(|e| e.into_inner());
    fixture.reset_database().await;
    fixture.initialize().await;
    let db_url = fixture.get_database_url();
    let db = fixture.get_database();

    let workflow = Workflow::builder("sub_status_pipeline")
        .description("Workflow for observing sub_status transitions")
        .add_task(Arc::new(SimpleTask::new("slow_deferred_task", vec![])))
        .unwrap()
        .build()
        .unwrap();

    let runtime = cloacina::Runtime::empty();
    let task_ns = TaskNamespace::new(
        workflow.tenant(),
        workflow.package(),
        workflow.name(),
        "slow_deferred_task",
    );
    runtime.register_task(task_ns, || {
        Arc::new(slow_deferred_task_task()) as Arc<dyn cloacina::Task>
    });
    let workflow_template = workflow.clone();
    runtime.register_workflow("sub_status_pipeline".to_string(), move || {
        workflow_template.clone()
    });

    let schema = fixture.get_schema();
    let runner = DefaultRunner::builder()
        .database_url(&db_url)
        .schema(&schema)
        .runtime(runtime)
        .build()
        .await
        .unwrap();

    let execution = runner
        .execute_async("sub_status_pipeline", Context::new())
        .await
        .unwrap();
    let exec_id = execution.execution_id;
    let dal = cloacina::dal::DAL::new(db.clone());

    // Sample until the first row reports "Deferred" or the attempt budget runs out.
    let mut saw_deferred = false;
    let mut attempts = 0;
    while !saw_deferred && attempts < 30 {
        attempts += 1;
        time::sleep(Duration::from_millis(100)).await;
        let snapshot = dal
            .task_execution()
            .get_all_tasks_for_workflow(UniversalUuid(exec_id))
            .await
            .unwrap();
        saw_deferred =
            snapshot.first().and_then(|row| row.sub_status.as_deref()) == Some("Deferred");
    }
    assert!(
        saw_deferred,
        "Should have observed sub_status='Deferred' during deferral"
    );

    crate::fixtures::poll_until(
        Duration::from_secs(10),
        Duration::from_millis(100),
        "slow deferred task should complete",
        || {
            let dal = dal.clone();
            async move {
                let rows = dal
                    .task_execution()
                    .get_all_tasks_for_workflow(UniversalUuid(exec_id))
                    .await
                    .unwrap_or_default();
                // Exactly one row, and it has completed.
                matches!(rows.as_slice(), [only] if only.status == "Completed")
            }
        },
    )
    .await;

    let tasks = dal
        .task_execution()
        .get_all_tasks_for_workflow(UniversalUuid(exec_id))
        .await
        .unwrap();
    assert_eq!(tasks.len(), 1);
    let task = &tasks[0];
    assert_eq!(task.status, "Completed");
    assert!(
        task.sub_status.is_none(),
        "sub_status should be None after completion, got: {:?}",
        task.sub_status
    );

    runner.shutdown().await.unwrap();
}