1use std::sync::{Arc, mpsc};
2
3use uuid::Uuid;
4
5use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
6use crate::function::AuthContext;
7
8pub fn empty_saved_data() -> serde_json::Value {
10 serde_json::Value::Object(serde_json::Map::new())
11}
12
/// Per-attempt execution context handed to a job.
///
/// Bundles the job's identity and attempt counters with shared resources (database pool,
/// HTTP client, environment provider) and the job's saved JSON data. The saved data lives in
/// memory and is written to `forge_jobs.job_context` by `save`/`set_saved`, except when
/// `job_id` is nil.
pub struct JobContext {
    /// Id of this job's row in `forge_jobs`.
    pub job_id: Uuid,
    /// The job's type name.
    pub job_type: String,
    /// Current attempt number; `is_retry` treats any value greater than 1 as a retry.
    pub attempt: u32,
    /// Attempt count at which `is_last_attempt` starts returning `true`.
    pub max_attempts: u32,
    /// Authentication context the job runs under (unauthenticated unless set via `with_auth`).
    pub auth: AuthContext,
    /// JSON data saved by the job, behind an async lock so it can be updated through `&self`.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    db_pool: sqlx::PgPool,
    http_client: reqwest::Client,
    /// Optional channel receiving progress updates; when `None`, updates are discarded.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Provider backing environment lookups through `EnvAccess`.
    env_provider: Arc<dyn EnvProvider>,
}
36
37#[derive(Debug, Clone)]
39pub struct ProgressUpdate {
40 pub job_id: Uuid,
42 pub percentage: u8,
44 pub message: String,
46}
47
48impl JobContext {
49 pub fn new(
51 job_id: Uuid,
52 job_type: String,
53 attempt: u32,
54 max_attempts: u32,
55 db_pool: sqlx::PgPool,
56 http_client: reqwest::Client,
57 ) -> Self {
58 Self {
59 job_id,
60 job_type,
61 attempt,
62 max_attempts,
63 auth: AuthContext::unauthenticated(),
64 saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
65 db_pool,
66 http_client,
67 progress_tx: None,
68 env_provider: Arc::new(RealEnvProvider::new()),
69 }
70 }
71
72 pub fn with_saved(mut self, data: serde_json::Value) -> Self {
74 self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
75 self
76 }
77
78 pub fn with_auth(mut self, auth: AuthContext) -> Self {
80 self.auth = auth;
81 self
82 }
83
84 pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
86 self.progress_tx = Some(tx);
87 self
88 }
89
90 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
92 self.env_provider = provider;
93 self
94 }
95
96 pub fn db(&self) -> &sqlx::PgPool {
98 &self.db_pool
99 }
100
101 pub fn db_conn(&self) -> crate::function::DbConn<'_> {
103 crate::function::DbConn::Pool(&self.db_pool)
104 }
105
106 pub fn http(&self) -> &reqwest::Client {
108 &self.http_client
109 }
110
111 pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
113 let update = ProgressUpdate {
114 job_id: self.job_id,
115 percentage: percentage.min(100),
116 message: message.into(),
117 };
118
119 if let Some(ref tx) = self.progress_tx {
120 tx.send(update)
121 .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
122 }
123
124 Ok(())
125 }
126
127 pub async fn saved(&self) -> serde_json::Value {
132 self.saved_data.read().await.clone()
133 }
134
135 pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
140 let mut guard = self.saved_data.write().await;
141 *guard = data;
142 let persisted = Self::clone_and_drop(guard);
143 if self.job_id.is_nil() {
144 return Ok(());
145 }
146 self.persist_saved_data(persisted).await
147 }
148
149 pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
162 let mut guard = self.saved_data.write().await;
163 Self::apply_save(&mut guard, key, value);
164 let persisted = Self::clone_and_drop(guard);
165 if self.job_id.is_nil() {
166 return Ok(());
167 }
168 self.persist_saved_data(persisted).await
169 }
170
171 pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
173 let row: Option<(String,)> = sqlx::query_as(
174 r#"
175 SELECT status
176 FROM forge_jobs
177 WHERE id = $1
178 "#,
179 )
180 .bind(self.job_id)
181 .fetch_optional(&self.db_pool)
182 .await
183 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
184
185 Ok(matches!(
186 row.as_ref().map(|(status,)| status.as_str()),
187 Some("cancel_requested") | Some("cancelled")
188 ))
189 }
190
191 pub async fn check_cancelled(&self) -> crate::Result<()> {
193 if self.is_cancel_requested().await? {
194 Err(crate::ForgeError::JobCancelled(
195 "Job cancellation requested".to_string(),
196 ))
197 } else {
198 Ok(())
199 }
200 }
201
202 async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
203 sqlx::query(
204 r#"
205 UPDATE forge_jobs
206 SET job_context = $2
207 WHERE id = $1
208 "#,
209 )
210 .bind(self.job_id)
211 .bind(data)
212 .execute(&self.db_pool)
213 .await
214 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
215
216 Ok(())
217 }
218
219 fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
220 if let Some(map) = data.as_object_mut() {
221 map.insert(key.to_string(), value);
222 } else {
223 let mut map = serde_json::Map::new();
224 map.insert(key.to_string(), value);
225 *data = serde_json::Value::Object(map);
226 }
227 }
228
229 fn clone_and_drop(
230 guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
231 ) -> serde_json::Value {
232 let cloned = guard.clone();
233 drop(guard);
234 cloned
235 }
236
237 pub async fn heartbeat(&self) -> crate::Result<()> {
239 sqlx::query(
240 r#"
241 UPDATE forge_jobs
242 SET last_heartbeat = NOW()
243 WHERE id = $1
244 "#,
245 )
246 .bind(self.job_id)
247 .execute(&self.db_pool)
248 .await
249 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
250
251 Ok(())
252 }
253
254 pub fn is_retry(&self) -> bool {
256 self.attempt > 1
257 }
258
259 pub fn is_last_attempt(&self) -> bool {
261 self.attempt >= self.max_attempts
262 }
263}
264
265impl EnvAccess for JobContext {
266 fn env_provider(&self) -> &dyn EnvProvider {
267 self.env_provider.as_ref()
268 }
269}
270
#[cfg(test)]
// Unwraps and JSON indexing are the clearest way to assert in tests; a panic is a failure.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a lazily-connecting pool pointed at a nonexistent database.
    ///
    /// Tests using it must not issue queries; they rely on a nil job id to skip persistence.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Builds a context over `lazy_pool` with the given identity and attempt counters.
    fn test_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = test_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = test_ctx(Uuid::new_v4(), "test", 2, 3);

        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = test_ctx(Uuid::new_v4(), "test", 3, 3);

        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_clamps_percentage() {
        let (tx, rx) = std::sync::mpsc::channel();
        let job_id = Uuid::new_v4();
        let ctx = test_ctx(job_id, "test_job", 1, 3).with_progress(tx);

        ctx.progress(150, "almost").unwrap();

        let update = rx.try_recv().unwrap();
        assert_eq!(update.job_id, job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "almost");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = test_ctx(Uuid::new_v4(), "test_job", 1, 3);

        assert!(ctx.progress(10, "nobody listening").is_ok());
    }

    #[tokio::test]
    async fn test_progress_errors_when_receiver_dropped() {
        let (tx, rx) = std::sync::mpsc::channel();
        let ctx = test_ctx(Uuid::new_v4(), "test_job", 1, 3).with_progress(tx);
        drop(rx);

        assert!(ctx.progress(10, "lost").is_err());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_set_saved_replaces_data() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"old": true}));

        ctx.set_saved(serde_json::json!({"new": 1})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 1}));
    }

    #[tokio::test]
    async fn test_save_replaces_non_object_saved_data() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!(5));

        ctx.save("k", serde_json::json!(true)).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"k": true}));
    }
}