1use sqlx::postgres::PgPoolOptions;
4use sqlx::PgPool;
5use std::collections::HashMap;
6use std::time::Duration;
7
/// Returns the Postgres connection URL used by these helpers.
///
/// The value of the `DATABASE_URL` environment variable is used when it can
/// be read; on any read error (unset or not valid Unicode) a hard-coded
/// local default is returned instead.
pub fn database_url() -> String {
    const FALLBACK_URL: &str = "postgres://postgres:test@localhost:15432/awa_test";

    match std::env::var("DATABASE_URL") {
        Ok(configured) => configured,
        Err(_) => FALLBACK_URL.to_owned(),
    }
}
13
14pub fn database_url_with_app_name(app_name: &str) -> String {
16 let mut url = database_url();
17 let sep = if url.contains('?') { '&' } else { '?' };
18 url.push(sep);
19 url.push_str("application_name=");
20 url.push_str(app_name);
21 url
22}
23
24pub async fn pool(max_connections: u32) -> PgPool {
26 PgPoolOptions::new()
27 .max_connections(max_connections)
28 .connect(&database_url())
29 .await
30 .expect("Failed to connect to database")
31}
32
33pub async fn pool_with_url(url: &str, max_connections: u32) -> PgPool {
35 PgPoolOptions::new()
36 .max_connections(max_connections)
37 .connect(url)
38 .await
39 .expect("Failed to connect to database")
40}
41
42pub async fn setup(max_connections: u32) -> PgPool {
44 let pool = pool(max_connections).await;
45 awa_model::migrations::run(&pool)
46 .await
47 .expect("Failed to run migrations");
48 pool
49}
50
51pub async fn clean_queue(pool: &PgPool, queue: &str) {
58 sqlx::query("DELETE FROM awa.jobs WHERE queue = $1")
59 .bind(queue)
60 .execute(pool)
61 .await
62 .expect("Failed to clean queue jobs");
63 sqlx::query("DELETE FROM awa.queue_meta WHERE queue = $1")
64 .bind(queue)
65 .execute(pool)
66 .await
67 .expect("Failed to clean queue meta");
68 sqlx::query("DELETE FROM awa.queue_state_counts WHERE queue = $1")
69 .bind(queue)
70 .execute(pool)
71 .await
72 .expect("Failed to clean queue state counts");
73}
74
75pub async fn queue_state_counts(pool: &PgPool, queue: &str) -> HashMap<String, i64> {
77 let rows: Vec<(String, i64)> = sqlx::query_as(
78 r#"
79 SELECT state::text, count(*)::bigint
80 FROM awa.jobs
81 WHERE queue = $1
82 GROUP BY state
83 "#,
84 )
85 .bind(queue)
86 .fetch_all(pool)
87 .await
88 .expect("Failed to query state counts");
89
90 rows.into_iter().collect()
91}
92
/// Looks up the count recorded for `state` in `counts`.
///
/// Returns `0` when the map has no entry for `state`.
pub fn state_count(counts: &HashMap<String, i64>, state: &str) -> i64 {
    counts.get(state).map_or(0, |&total| total)
}
97
98pub async fn wait_for_counts(
100 pool: &PgPool,
101 queue: &str,
102 predicate: impl Fn(&HashMap<String, i64>) -> bool,
103 timeout: Duration,
104) -> HashMap<String, i64> {
105 let start = std::time::Instant::now();
106 loop {
107 let counts = queue_state_counts(pool, queue).await;
108 if predicate(&counts) {
109 return counts;
110 }
111 assert!(
112 start.elapsed() < timeout,
113 "Timed out waiting for queue {queue} counts; last counts: {counts:?}"
114 );
115 tokio::time::sleep(Duration::from_millis(50)).await;
116 }
117}