use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use anyhow::{Context, Result, anyhow};
use ironflow::{
DeadLetter, DeadLetterQuery, OutboxStore, PgStore, RetryPolicy, RuntimeConfig, WorkflowBuilder,
WorkflowRuntime, WorkflowService, WorkflowServiceConfig,
};
use sqlx::PgPool;
use tokio::task::JoinHandle;
use super::db;
/// Installs a global `tracing` subscriber that enables `ironflow` targets at `debug`
/// level and above.
///
/// Safe to call from every test: when a global subscriber is already installed,
/// `try_init` returns an error, which is discarded so the existing one stays active.
pub fn init_test_tracing() {
    let subscriber = tracing_subscriber::fmt().with_env_filter("ironflow=debug");
    let _ = subscriber.try_init();
}
/// Asserts that `events` has exactly `expected_types.len()` entries and that each
/// entry's `"type"` field equals the corresponding expected string, in order.
///
/// # Panics
///
/// Panics on a length mismatch, or at the first event whose `"type"` field differs
/// from the expected value.
pub fn assert_event_types(events: &[serde_json::Value], expected_types: &[&str]) {
    let want = expected_types.len();
    let got = events.len();
    assert_eq!(
        got, want,
        "event count mismatch: expected {}, got {}",
        want, got
    );
    // Lengths are equal past this point, so zipping visits every pair.
    for (i, (event, expected_type)) in events.iter().zip(expected_types).enumerate() {
        let actual = &event["type"];
        assert_eq!(
            *actual, *expected_type,
            "event {i} type mismatch: expected {expected_type}, got {}",
            actual
        );
    }
}
/// Retry limit installed by `test_runtime_config` as `retry_policy.max_attempts`.
pub const TEST_MAX_ATTEMPTS: u32 = 3;
/// Delay between re-checks in the `TestApp::wait_for_*` helpers.
pub const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Default timeout intended for test waits (10 s); not referenced within this module.
pub const DEFAULT_TEST_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Lock duration for tests (30 s); not referenced within this module — presumably
/// consumed by callers that configure locks/leases, confirm at use sites.
pub const TEST_LOCK_DURATION: Duration = Duration::from_millis(30_000);
/// Counts how many callers are inside a section at the same time and remembers the
/// highest count observed.
///
/// Call [`enter`](Self::enter) when a unit of work starts and [`exit`](Self::exit)
/// when it ends; [`max_concurrent`](Self::max_concurrent) reports the peak. Every
/// atomic operation uses `Ordering::SeqCst`.
#[derive(Default)]
pub struct ConcurrencyTracker {
    /// Callers currently between `enter` and `exit`.
    current: AtomicUsize,
    /// Largest value `current` has reached.
    max_seen: AtomicUsize,
}

impl ConcurrencyTracker {
    /// Creates a zeroed tracker inside an `Arc` so it can be shared across tasks.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Marks one more caller as active and folds the new count into the peak.
    pub fn enter(&self) {
        let previously_active = self.current.fetch_add(1, Ordering::SeqCst);
        let now_active = previously_active + 1;
        self.max_seen.fetch_max(now_active, Ordering::SeqCst);
    }

    /// Marks one caller as finished.
    ///
    /// Should be paired with an earlier `enter`; an unmatched call wraps the
    /// counter around, since atomic subtraction wraps on underflow.
    pub fn exit(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }

    /// Highest number of simultaneously active callers seen so far.
    pub fn max_concurrent(&self) -> usize {
        self.max_seen.load(Ordering::SeqCst)
    }
}
/// Runtime configuration tuned for fast integration tests.
///
/// Effects are polled every 50 ms and timers every 100 ms, and the shutdown timeout
/// is 5 s. The retry policy allows [`TEST_MAX_ATTEMPTS`] attempts with a 50 ms base
/// delay and a 200 ms max delay. Every other field keeps its
/// `RuntimeConfig::default()` value.
pub fn test_runtime_config() -> RuntimeConfig {
    let retry_policy = RetryPolicy {
        max_attempts: TEST_MAX_ATTEMPTS,
        base_delay: Duration::from_millis(50),
        max_delay: Duration::from_millis(200),
    };
    RuntimeConfig {
        retry_policy,
        shutdown_timeout: Duration::from_secs(5),
        timer_poll_interval: Duration::from_millis(100),
        effect_poll_interval: Duration::from_millis(50),
        ..Default::default()
    }
}
/// Repeatedly awaits `check` until it yields `Some(value)` or `timeout` elapses.
///
/// `check` always runs at least once, even when `timeout` is zero. Between attempts
/// the task sleeps for `interval`, clamped to the time left before the deadline.
/// The wait therefore ends close to `timeout`, rather than overshooting it by up to
/// a full `interval`, and the condition still gets a final check at the deadline.
///
/// # Errors
///
/// Returns the first error produced by `check`. Otherwise, once the deadline has
/// been reached without a match, returns a "timeout waiting for condition" error.
pub async fn wait_until<F, Fut, T>(timeout: Duration, interval: Duration, check: F) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<Option<T>>>,
{
    let deadline = tokio::time::Instant::now() + timeout;
    loop {
        if let Some(result) = check().await? {
            return Ok(result);
        }
        let now = tokio::time::Instant::now();
        // `>=` (not `>`): with the sleep clamped below, a wait that lands exactly on
        // the deadline would otherwise sleep for zero and, under a paused tokio
        // clock, never advance past it.
        if now >= deadline {
            return Err(anyhow!("timeout waiting for condition"));
        }
        // Never sleep past the deadline; a coarse `interval` would otherwise add up to
        // one full interval of extra latency to every timed-out wait.
        let remaining = deadline.saturating_duration_since(now);
        tokio::time::sleep(interval.min(remaining)).await;
    }
}
/// A workflow runtime running in the background, plus handles for inspecting its
/// state from tests.
///
/// Created by `TestAppBuilder::build_and_run`, which spawns the runtime on a tokio
/// task. The runtime is signalled to stop by `TestApp::shutdown`, or on drop.
pub struct TestApp {
    /// Store built over a clone of the same pool the runtime was configured with.
    pub store: PgStore,
    /// Service handle produced by building the engine.
    pub service: Arc<WorkflowService<PgStore>>,
    /// Pool used by the `db::*` query helpers in the `wait_*`/`count_*` methods.
    pool: PgPool,
    /// Retry limit taken from the runtime config; passed to dead-letter queries.
    max_attempts: u32,
    /// Sender that tells the runtime to stop; `None` once it has been taken and used.
    shutdown: Option<tokio::sync::oneshot::Sender<()>>,
    /// Join handle of the spawned runtime task; `None` once taken by `shutdown`.
    handle: Option<JoinHandle<anyhow::Result<()>>>,
}
/// Builder for a `TestApp`: collects effect handlers and an optional runtime
/// config, then spawns the runtime in `build_and_run`.
pub struct TestAppBuilder<'a> {
    /// Pool cloned into the engine's store and into the resulting `TestApp`.
    pool: &'a PgPool,
    /// Underlying ironflow builder that accumulates registered handlers.
    builder: WorkflowBuilder<PgStore>,
    /// Config applied at build time; starts as `test_runtime_config()`.
    runtime_config: RuntimeConfig,
}
impl<'a> TestAppBuilder<'a> {
pub fn new(pool: &'a PgPool) -> Self {
let store = PgStore::new(pool.clone());
Self {
pool,
builder: WorkflowRuntime::builder(store, WorkflowServiceConfig::default()),
runtime_config: test_runtime_config(),
}
}
pub fn register<H: ironflow::EffectHandler + Send + Sync + 'static>(
mut self,
handler: H,
) -> Self
where
H::Workflow: ironflow::Workflow + Send + Sync + 'static,
<H::Workflow as ironflow::Workflow>::State: Default + Send + Sync + serde::Serialize,
<H::Workflow as ironflow::Workflow>::Input:
serde::de::DeserializeOwned + ironflow::HasWorkflowId + Send + Sync,
<H::Workflow as ironflow::Workflow>::Event: serde::Serialize + Send + Sync,
<H::Workflow as ironflow::Workflow>::Effect:
serde::Serialize + serde::de::DeserializeOwned + Send + Sync,
H::Error: std::fmt::Display + Send + Sync,
{
self.builder = self.builder.register(handler);
self
}
pub fn config(mut self, config: RuntimeConfig) -> Self {
self.runtime_config = config;
self
}
pub async fn build_and_run(self) -> Result<TestApp> {
let max_attempts = self.runtime_config.retry_policy.max_attempts;
let engine = self
.builder
.config(self.runtime_config)
.build_engine()
.map_err(anyhow::Error::from)?;
let store = PgStore::new(self.pool.clone());
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let runtime = engine.runtime;
let handle = tokio::spawn(async move {
runtime
.run(async move {
let _ = shutdown_rx.await;
})
.await
.map_err(anyhow::Error::from)
});
Ok(TestApp {
store,
service: engine.service,
pool: self.pool.clone(),
max_attempts,
shutdown: Some(shutdown_tx),
handle: Some(handle),
})
}
}
impl TestApp {
    /// Returns a [`TestAppBuilder`] that will run against `pool`.
    pub fn builder(pool: &PgPool) -> TestAppBuilder<'_> {
        TestAppBuilder::new(pool)
    }

    /// Signals the background runtime to stop and waits for its task to finish.
    ///
    /// # Errors
    ///
    /// Fails if the runtime task could not be joined (it panicked or was cancelled),
    /// or if the runtime itself returned an error.
    // NOTE(review): presumably `dead_code` is allowed because not every test binary
    // that includes this module calls `shutdown` — confirm.
    #[allow(dead_code)]
    pub async fn shutdown(mut self) -> Result<()> {
        // A failed send only means the runtime already dropped its receiver.
        let _ = self.shutdown.take().map(|tx| tx.send(()));
        if let Some(task) = self.handle.take() {
            // Outer `?`: join failure; inner `?`: the runtime's own result.
            task.await??;
        }
        Ok(())
    }

    /// Polls until `workflow_type`/`workflow_id` has at least `expected` events, then
    /// returns all of them (possibly more than `expected`).
    ///
    /// # Errors
    ///
    /// Fails on a query error, or with a timeout error if `timeout` elapses first.
    /// The error carries the workflow and expected count as context.
    pub async fn wait_for_events(
        &self,
        workflow_type: &str,
        workflow_id: &str,
        expected: usize,
        timeout: Duration,
    ) -> Result<Vec<serde_json::Value>> {
        let db_pool = &self.pool;
        wait_until(timeout, DEFAULT_POLL_INTERVAL, || async {
            let events = db::fetch_events(db_pool, workflow_type, workflow_id).await?;
            Ok((events.len() >= expected).then_some(events))
        })
        .await
        .with_context(|| format!("waiting for {expected} events on {workflow_type}/{workflow_id}"))
    }

    /// Counts timers for the given workflow, optionally filtered by `key`.
    ///
    /// NOTE(review): the literal `true` passed to `db::count_timers` presumably
    /// restricts the count to pending timers — confirm against `db::count_timers`.
    pub async fn count_pending_timers(
        &self,
        workflow_type: &str,
        workflow_id: &str,
        key: Option<&str>,
    ) -> Result<i64> {
        let pending = db::count_timers(&self.pool, workflow_type, workflow_id, true, key).await?;
        Ok(pending)
    }

    /// Fetches dead letters matching `query`, using this app's configured retry limit.
    pub async fn fetch_dead_letters(&self, query: DeadLetterQuery) -> Result<Vec<DeadLetter>> {
        Ok(self
            .store
            .fetch_dead_letters(&query, self.max_attempts)
            .await?)
    }

    /// Polls until at least one dead letter matches `query`, then returns the matches.
    ///
    /// # Errors
    ///
    /// Fails on a query error, or with a timeout error if `timeout` elapses first.
    pub async fn wait_for_dead_letter(
        &self,
        query: DeadLetterQuery,
        timeout: Duration,
    ) -> Result<Vec<DeadLetter>> {
        wait_until(timeout, DEFAULT_POLL_INTERVAL, || async {
            let found = self.fetch_dead_letters(query.clone()).await?;
            Ok((!found.is_empty()).then_some(found))
        })
        .await
        .context("waiting for dead letter")
    }

    /// Polls until `db::is_effect_processed` reports `true` for `workflow_id`.
    ///
    /// # Errors
    ///
    /// Fails on a query error, or with a timeout error if `timeout` elapses first.
    pub async fn wait_for_effect_processed(
        &self,
        workflow_id: &str,
        timeout: Duration,
    ) -> Result<()> {
        let db_pool = &self.pool;
        wait_until(timeout, DEFAULT_POLL_INTERVAL, || async {
            let processed = db::is_effect_processed(db_pool, workflow_id).await?;
            Ok(processed.then_some(()))
        })
        .await
        .with_context(|| format!("waiting for effect to be processed: {workflow_id}"))
    }
}
impl Drop for TestApp {
    /// Sends the stop signal if `shutdown` was never called.
    ///
    /// `drop` cannot `.await`, so the runtime task is not joined here. Its
    /// `JoinHandle` field is dropped right afterwards, which detaches the task.
    fn drop(&mut self) {
        // A failed send only means the runtime already dropped its receiver.
        let _ = self.shutdown.take().map(|sender| sender.send(()));
    }
}