#![cfg(feature = "postgres-tests")]
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use sea_orm::{Database, DatabaseConnection};
use sea_orm_migration::MigratorTrait;
use ferro_queue::{claim, delete_job, enqueue, CreateJobsTable};
/// Migrator used only by these tests.
///
/// Its migration list consists solely of the crate's `CreateJobsTable` migration.
struct TestMigrator;
#[async_trait::async_trait]
impl MigratorTrait for TestMigrator {
    fn migrations() -> Vec<Box<dyn sea_orm_migration::MigrationTrait>> {
        let create_jobs: Box<dyn sea_orm_migration::MigrationTrait> = Box::new(CreateJobsTable);
        vec![create_jobs]
    }
}
/// Connects to the Postgres database named by `DATABASE_URL` and re-applies the test
/// migrations so every run starts from a freshly created jobs table.
///
/// Returns `None` when `DATABASE_URL` is not set.
///
/// # Panics
///
/// Panics if the connection cannot be established or if applying the migrations fails.
async fn fresh_pg_db() -> Option<DatabaseConnection> {
    let database_url = match std::env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(_) => return None,
    };
    let connection = Database::connect(&database_url)
        .await
        .expect("connect to postgres");

    // Best-effort rollback of whatever a previous run left behind; its outcome is
    // intentionally ignored, and only the subsequent `up` is required to succeed.
    let _ = TestMigrator::down(&connection, None).await;
    TestMigrator::up(&connection, None)
        .await
        .expect("migrate");

    Some(connection)
}
/// Races two workers, each with its own `DatabaseConnection`, against a queue of `N`
/// freshly enqueued jobs and asserts that every job is claimed by exactly one worker.
///
/// The test is skipped (returns early without failing) when `DATABASE_URL` is not set.
///
/// # Panics
///
/// Panics if connecting, migrating, enqueueing or claiming fails. It also panics if either
/// worker task panics or does not complete, or if any job is claimed twice or left unclaimed.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn two_workers_claim_each_job_exactly_once_postgres() {
    // Read the URL once: it both gates the test and is reused for the worker connections.
    let db_url = match std::env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(_) => {
            eprintln!("DATABASE_URL not set — skipping postgres race test");
            return;
        }
    };
    let conn_setup = fresh_pg_db().await.expect("DATABASE_URL checked above");

    const N: usize = 20;
    let now = chrono::Utc::now();
    for _ in 0..N {
        enqueue(
            &conn_setup,
            "default",
            "TestJobPg",
            "{}",
            3,
            None,
            None,
            now,
        )
        .await
        .expect("enqueue failed");
    }

    // Each worker gets its own connection, separate from the setup connection.
    let conn1 = Database::connect(&db_url)
        .await
        .expect("connect conn1 to postgres");
    let conn2 = Database::connect(&db_url)
        .await
        .expect("connect conn2 to postgres");

    /// Repeatedly claims and deletes jobs from the `"default"` queue as `worker_id`,
    /// recording each claimed job id in `out`, until `claim` returns `Ok(None)`.
    ///
    /// Panics on any `claim` error.
    async fn drain(
        conn: sea_orm::DatabaseConnection,
        worker_id: &'static str,
        out: Arc<Mutex<Vec<i64>>>,
    ) {
        loop {
            match claim(&conn, "default", worker_id).await {
                Ok(Some(row)) => {
                    // The guard is a temporary dropped at the end of this statement, so the
                    // std `Mutex` is never held across the `.await` below.
                    out.lock().unwrap().push(row.id);
                    // The outcome of deletion is not what this test checks, so its result is
                    // deliberately ignored.
                    let _ = delete_job(&conn, row.id).await;
                }
                Ok(None) => break,
                Err(e) => panic!("claim error: {e:?}"),
            }
        }
    }

    let c1: Arc<Mutex<Vec<i64>>> = Arc::new(Mutex::new(Vec::new()));
    let c2: Arc<Mutex<Vec<i64>>> = Arc::new(Mutex::new(Vec::new()));
    let (h1, h2) = (
        tokio::spawn(drain(conn1, "w1", c1.clone())),
        tokio::spawn(drain(conn2, "w2", c2.clone())),
    );

    // A panic inside a spawned task only surfaces as a `JoinError`. Discarding the join
    // results would let the test pass even if one worker crashed while the other drained
    // the rest of the queue.
    let (r1, r2) = tokio::join!(h1, h2);
    r1.expect("worker w1 panicked or was cancelled");
    r2.expect("worker w2 panicked or was cancelled");

    let mut all: Vec<i64> = c1.lock().unwrap().clone();
    all.extend(c2.lock().unwrap().iter().copied());
    let unique: HashSet<i64> = all.iter().copied().collect();
    assert_eq!(
        unique.len(),
        all.len(),
        "a job was claimed more than once (total claimed: {}, unique: {})",
        all.len(),
        unique.len()
    );
    assert_eq!(
        unique.len(),
        N,
        "not all jobs were claimed exactly once (expected {N}, got {})",
        unique.len()
    );
}