use async_trait::async_trait;
use deltaflow::{
HasEntityId, NoopRecorder, Pipeline, RunnerBuilder, SqliteTaskStore, Step, StepError, TaskStore,
};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Minimal task payload submitted to the test pipelines; it carries only an identifier.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct SlowInput {
    /// Identifier that is also reported to deltaflow as the entity id.
    id: String,
}

impl HasEntityId for SlowInput {
    /// Returns an owned copy of this input's `id`.
    fn entity_id(&self) -> String {
        String::clone(&self.id)
    }
}
/// Test step that measures how many of its own executions overlap in time.
///
/// Each `execute` call increments `concurrent` for as long as it is in flight,
/// folds the post-increment value into `max_observed`, then sleeps for
/// `duration` so that any overlapping executions have time to be observed.
struct SlowStep {
    /// Number of `execute` calls currently in flight.
    concurrent: Arc<AtomicUsize>,
    /// Highest value `concurrent` has reached so far.
    max_observed: Arc<AtomicUsize>,
    /// How long each execution sleeps before returning.
    duration: Duration,
}

/// Decrements the wrapped in-flight counter when dropped.
///
/// A guard is used instead of a trailing `fetch_sub` so the counter stays
/// accurate even when the `execute` future is dropped mid-sleep (for example
/// when a `tokio::select!` timeout or `JoinHandle::abort` cancels the runner);
/// otherwise the count would leak and later executions would report overlap
/// that never happened.
struct InFlightGuard<'a>(&'a AtomicUsize);

impl Drop for InFlightGuard<'_> {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

#[async_trait]
impl Step for SlowStep {
    type Input = SlowInput;
    type Output = ();

    fn name(&self) -> &'static str {
        "slow_step"
    }

    /// Records this execution in the concurrency counters, sleeps for
    /// `self.duration`, and returns `Ok(())`. The input is ignored.
    async fn execute(&self, _input: Self::Input) -> Result<Self::Output, StepError> {
        let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
        // Bound to a named variable (not `_`) so it lives until the end of the
        // body and decrements `concurrent` on every exit path, including cancellation.
        let _in_flight = InFlightGuard(&self.concurrent);
        self.max_observed.fetch_max(current, Ordering::SeqCst);
        tokio::time::sleep(self.duration).await;
        Ok(())
    }
}
/// Submits 10 tasks to a pipeline registered with concurrency 1 and checks that
/// `SlowStep` never observes more than one execution in flight at a time.
#[tokio::test]
async fn test_pipeline_concurrency_limit_enforced() {
    let pool = SqlitePool::connect(":memory:").await.unwrap();
    let store = SqliteTaskStore::new(pool.clone());
    store.run_migrations().await.unwrap();

    let concurrent = Arc::new(AtomicUsize::new(0));
    let max_observed = Arc::new(AtomicUsize::new(0));

    let pipeline = Pipeline::new("slow_pipeline")
        .start_with(SlowStep {
            concurrent: concurrent.clone(),
            max_observed: max_observed.clone(),
            duration: Duration::from_millis(100),
        })
        .with_recorder(NoopRecorder)
        .build();

    // `max_concurrent(10)` is set well above the per-pipeline limit of 1 so that
    // the per-pipeline limit is what keeps executions serialized (assuming
    // `max_concurrent` is the runner-wide cap — confirm against deltaflow docs).
    let runner = RunnerBuilder::new(store)
        .max_concurrent(10)
        .pipeline_with_concurrency(pipeline, 1)
        .poll_interval(Duration::from_millis(10))
        .build();

    for i in 0..10 {
        runner
            .submit(
                "slow_pipeline",
                SlowInput {
                    id: format!("task_{}", i),
                },
            )
            .await
            .unwrap();
    }

    // Let the runner work for a bounded time; 10 serialized 100 ms tasks need ~1 s.
    tokio::select! {
        _ = runner.run() => {}
        _ = tokio::time::sleep(Duration::from_secs(2)) => {}
    }

    // Asserting equality (not `<= 1`) also catches the case where nothing ran:
    // `max_observed` would then still be 0.
    let max = max_observed.load(Ordering::SeqCst);
    assert_eq!(
        max, 1,
        "Pipeline concurrency limit violated: observed {} concurrent executions, expected max 1",
        max
    );
}
#[tokio::test]
async fn test_database_running_count_respects_concurrency() {
let pool = SqlitePool::connect(":memory:").await.unwrap();
let store = SqliteTaskStore::new(pool.clone());
store.run_migrations().await.unwrap();
let concurrent = Arc::new(AtomicUsize::new(0));
let max_observed = Arc::new(AtomicUsize::new(0));
let pipeline = Pipeline::new("limited_pipeline")
.start_with(SlowStep {
concurrent: concurrent.clone(),
max_observed: max_observed.clone(),
duration: Duration::from_millis(200),
})
.with_recorder(NoopRecorder)
.build();
let runner = RunnerBuilder::new(store)
.max_concurrent(5)
.pipeline_with_concurrency(pipeline, 1)
.poll_interval(Duration::from_millis(10))
.build();
for i in 0..20 {
runner
.submit(
"limited_pipeline",
SlowInput {
id: format!("task_{}", i),
},
)
.await
.unwrap();
}
let runner_handle = tokio::spawn(async move {
tokio::select! {
_ = runner.run() => {}
_ = tokio::time::sleep(Duration::from_secs(3)) => {}
}
});
let mut max_running_in_db = 0i64;
for _ in 0..20 {
tokio::time::sleep(Duration::from_millis(100)).await;
let count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM delta_tasks WHERE pipeline = 'limited_pipeline' AND status = 'running'"
)
.fetch_one(&pool)
.await
.unwrap();
if count > max_running_in_db {
max_running_in_db = count;
}
}
runner_handle.abort();
assert!(
max_running_in_db <= 1,
"Database shows {} tasks as 'running' for pipeline with concurrency limit 1",
max_running_in_db
);
}
#[tokio::test]
async fn test_orphan_tasks_recovered_on_startup() {
let pool = SqlitePool::connect(":memory:").await.unwrap();
let store = SqliteTaskStore::new(pool.clone());
store.run_migrations().await.unwrap();
sqlx::query(
r#"
INSERT INTO delta_tasks (pipeline, input, status, started_at)
VALUES ('orphan_pipeline', '{"id": "orphan_1"}', 'running', datetime('now', '-5 minutes'))
"#,
)
.execute(&pool)
.await
.unwrap();
store
.enqueue("orphan_pipeline", serde_json::json!({"id": "pending_1"}))
.await
.unwrap();
let running_count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM delta_tasks WHERE status = 'running'")
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(running_count, 1, "Setup: should have 1 running task");
let pending_count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM delta_tasks WHERE status = 'pending'")
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(pending_count, 1, "Setup: should have 1 pending task");
let recovered = store.recover_orphans().await.unwrap();
assert_eq!(recovered, 1, "Should recover 1 orphaned task");
let claimed = store.claim(10).await.unwrap();
assert_eq!(
claimed.len(),
2,
"Should claim both orphan and pending tasks, but only got {}",
claimed.len()
);
}