1use awa_model::{AwaError, JobArgs, JobRow};
6use awa_worker::context::ProgressState;
7use awa_worker::{JobContext, JobError, JobResult, Worker};
8use sqlx::PgPool;
9use std::any::Any;
10use std::collections::HashMap;
11use std::sync::atomic::AtomicBool;
12use std::sync::Arc;
13
14pub struct TestClient {
18 pool: PgPool,
19}
20
21impl TestClient {
22 pub async fn from_pool(pool: PgPool) -> Self {
24 Self { pool }
25 }
26
27 pub fn pool(&self) -> &PgPool {
29 &self.pool
30 }
31
32 pub async fn migrate(&self) -> Result<(), AwaError> {
34 awa_model::migrations::run(&self.pool).await
35 }
36
37 pub async fn clean(&self) -> Result<(), AwaError> {
39 sqlx::query("DELETE FROM awa.jobs")
40 .execute(&self.pool)
41 .await?;
42 sqlx::query("DELETE FROM awa.queue_meta")
43 .execute(&self.pool)
44 .await?;
45 Ok(())
46 }
47
48 pub async fn insert(&self, args: &impl JobArgs) -> Result<JobRow, AwaError> {
50 awa_model::insert(&self.pool, args).await
51 }
52
53 pub async fn work_one<W: Worker>(&self, worker: &W) -> Result<WorkResult, AwaError> {
58 self.work_one_in_queue(worker, None).await
59 }
60
61 pub async fn work_one_in_queue<W: Worker>(
63 &self,
64 worker: &W,
65 queue: Option<&str>,
66 ) -> Result<WorkResult, AwaError> {
67 let jobs: Vec<JobRow> = sqlx::query_as::<_, JobRow>(
69 r#"
70 WITH claimed AS (
71 SELECT id FROM awa.jobs
72 WHERE state = 'available' AND kind = $1
73 AND ($2::text IS NULL OR queue = $2)
74 ORDER BY run_at ASC, id ASC
75 LIMIT 1
76 FOR UPDATE SKIP LOCKED
77 )
78 UPDATE awa.jobs
79 SET state = 'running',
80 attempt = attempt + 1,
81 run_lease = run_lease + 1,
82 attempted_at = now(),
83 heartbeat_at = now(),
84 deadline_at = now() + interval '5 minutes'
85 FROM claimed
86 WHERE awa.jobs.id = claimed.id
87 RETURNING awa.jobs.*
88 "#,
89 )
90 .bind(worker.kind())
91 .bind(queue)
92 .fetch_all(&self.pool)
93 .await?;
94
95 let job = match jobs.into_iter().next() {
96 Some(job) => job,
97 None => return Ok(WorkResult::NoJob),
98 };
99
100 let cancel = Arc::new(AtomicBool::new(false));
101 let state: Arc<HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>> =
102 Arc::new(HashMap::new());
103 let progress = Arc::new(std::sync::Mutex::new(ProgressState::new(
104 job.progress.clone(),
105 )));
106 let ctx = JobContext::new(
107 job.clone(),
108 cancel,
109 state,
110 self.pool.clone(),
111 progress.clone(),
112 );
113
114 let result = worker.perform(&job, &ctx).await;
115
116 let progress_snapshot: Option<serde_json::Value> = {
118 let guard = progress.lock().expect("progress lock poisoned");
119 guard.clone_latest()
120 };
121
122 match &result {
124 Ok(JobResult::Completed) => {
125 sqlx::query(
126 "UPDATE awa.jobs SET state = 'completed', finalized_at = now(), progress = NULL WHERE id = $1",
127 )
128 .bind(job.id)
129 .execute(&self.pool)
130 .await?;
131 Ok(WorkResult::Completed(job))
132 }
133 Ok(JobResult::Cancel(reason)) => {
134 sqlx::query(
135 "UPDATE awa.jobs SET state = 'cancelled', finalized_at = now(), progress = $2 WHERE id = $1",
136 )
137 .bind(job.id)
138 .bind(&progress_snapshot)
139 .execute(&self.pool)
140 .await?;
141 Ok(WorkResult::Cancelled(job, reason.clone()))
142 }
143 Ok(JobResult::RetryAfter(_)) | Err(JobError::Retryable(_)) => {
144 sqlx::query(
145 "UPDATE awa.jobs SET state = 'retryable', finalized_at = now(), progress = $2 WHERE id = $1",
146 )
147 .bind(job.id)
148 .bind(&progress_snapshot)
149 .execute(&self.pool)
150 .await?;
151 Ok(WorkResult::Retryable(job))
152 }
153 Ok(JobResult::Snooze(_)) => {
154 sqlx::query(
155 "UPDATE awa.jobs SET state = 'available', attempt = attempt - 1, progress = $2 WHERE id = $1",
156 )
157 .bind(job.id)
158 .bind(&progress_snapshot)
159 .execute(&self.pool)
160 .await?;
161 Ok(WorkResult::Snoozed(job))
162 }
163 Ok(JobResult::WaitForCallback) => {
164 let has_callback: Option<(Option<uuid::Uuid>,)> =
166 sqlx::query_as("SELECT callback_id FROM awa.jobs WHERE id = $1")
167 .bind(job.id)
168 .fetch_optional(&self.pool)
169 .await?;
170 match has_callback {
171 Some((Some(_),)) => {
172 sqlx::query(
173 "UPDATE awa.jobs SET state = 'waiting_external', heartbeat_at = NULL, deadline_at = NULL, progress = $2 WHERE id = $1",
174 )
175 .bind(job.id)
176 .bind(&progress_snapshot)
177 .execute(&self.pool)
178 .await?;
179 let updated = self.get_job(job.id).await?;
180 Ok(WorkResult::WaitingExternal(updated))
181 }
182 _ => {
183 sqlx::query(
184 "UPDATE awa.jobs SET state = 'failed', finalized_at = now() WHERE id = $1",
185 )
186 .bind(job.id)
187 .execute(&self.pool)
188 .await?;
189 Ok(WorkResult::Failed(
190 job,
191 "WaitForCallback returned without calling register_callback"
192 .to_string(),
193 ))
194 }
195 }
196 }
197 Err(JobError::Terminal(msg)) => {
198 sqlx::query(
199 "UPDATE awa.jobs SET state = 'failed', finalized_at = now(), progress = $2 WHERE id = $1",
200 )
201 .bind(job.id)
202 .bind(&progress_snapshot)
203 .execute(&self.pool)
204 .await?;
205 Ok(WorkResult::Failed(job, msg.clone()))
206 }
207 }
208 }
209
210 pub async fn get_job(&self, job_id: i64) -> Result<JobRow, AwaError> {
212 awa_model::admin::get_job(&self.pool, job_id).await
213 }
214}
215
216#[derive(Debug)]
218pub enum WorkResult {
219 NoJob,
221 Completed(JobRow),
223 Retryable(JobRow),
225 Snoozed(JobRow),
227 Cancelled(JobRow, String),
229 Failed(JobRow, String),
231 WaitingExternal(JobRow),
233}
234
235impl WorkResult {
236 pub fn is_completed(&self) -> bool {
237 matches!(self, WorkResult::Completed(_))
238 }
239
240 pub fn is_failed(&self) -> bool {
241 matches!(self, WorkResult::Failed(_, _))
242 }
243
244 pub fn is_no_job(&self) -> bool {
245 matches!(self, WorkResult::NoJob)
246 }
247
248 pub fn is_waiting_external(&self) -> bool {
249 matches!(self, WorkResult::WaitingExternal(_))
250 }
251}