1use std::sync::{Arc, mpsc};
2
3use uuid::Uuid;
4
5use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
6use crate::function::AuthContext;
7
8pub fn empty_saved_data() -> serde_json::Value {
10 serde_json::Value::Object(serde_json::Map::new())
11}
12
/// Execution context handed to a job for one attempt.
///
/// Bundles the job's identity and retry bookkeeping with the resources it may
/// use (database pool, HTTP client, environment provider), an optional
/// progress channel, and a key/value "saved data" store that is mirrored to
/// `forge_jobs.job_context`.
pub struct JobContext {
    /// Identifier of the job row in `forge_jobs`. A nil UUID disables
    /// persistence of saved data.
    pub job_id: Uuid,
    /// Name of the job type being executed.
    pub job_type: String,
    /// Current attempt number; `is_retry` treats any value above 1 as a retry.
    pub attempt: u32,
    /// Maximum number of attempts allowed for this job.
    pub max_attempts: u32,
    /// Authentication context the job runs under (unauthenticated by default).
    pub auth: AuthContext,
    /// In-memory saved data behind an async read/write lock; writes go through
    /// `save`/`set_saved`.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Database connection pool used for all queries issued by the context.
    db_pool: sqlx::PgPool,
    /// Shared HTTP client exposed through `http()`.
    http_client: reqwest::Client,
    /// Where `progress()` sends updates; `None` means updates are discarded.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Source of environment variables exposed via `EnvAccess`.
    env_provider: Arc<dyn EnvProvider>,
}
36
37#[derive(Debug, Clone)]
39pub struct ProgressUpdate {
40 pub job_id: Uuid,
42 pub percentage: u8,
44 pub message: String,
46}
47
48impl JobContext {
49 pub fn new(
51 job_id: Uuid,
52 job_type: String,
53 attempt: u32,
54 max_attempts: u32,
55 db_pool: sqlx::PgPool,
56 http_client: reqwest::Client,
57 ) -> Self {
58 Self {
59 job_id,
60 job_type,
61 attempt,
62 max_attempts,
63 auth: AuthContext::unauthenticated(),
64 saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
65 db_pool,
66 http_client,
67 progress_tx: None,
68 env_provider: Arc::new(RealEnvProvider::new()),
69 }
70 }
71
72 pub fn with_saved(mut self, data: serde_json::Value) -> Self {
74 self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
75 self
76 }
77
78 pub fn with_auth(mut self, auth: AuthContext) -> Self {
80 self.auth = auth;
81 self
82 }
83
84 pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
86 self.progress_tx = Some(tx);
87 self
88 }
89
90 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
92 self.env_provider = provider;
93 self
94 }
95
96 pub fn db(&self) -> &sqlx::PgPool {
98 &self.db_pool
99 }
100
101 pub fn db_conn(&self) -> crate::function::DbConn<'_> {
103 crate::function::DbConn::Pool(&self.db_pool)
104 }
105
106 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
108 Ok(crate::function::ForgeConn::Pool(
109 self.db_pool.acquire().await?,
110 ))
111 }
112
113 pub fn http(&self) -> &reqwest::Client {
115 &self.http_client
116 }
117
118 pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
120 let update = ProgressUpdate {
121 job_id: self.job_id,
122 percentage: percentage.min(100),
123 message: message.into(),
124 };
125
126 if let Some(ref tx) = self.progress_tx {
127 tx.send(update)
128 .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
129 }
130
131 Ok(())
132 }
133
134 pub async fn saved(&self) -> serde_json::Value {
139 self.saved_data.read().await.clone()
140 }
141
142 pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
147 let mut guard = self.saved_data.write().await;
148 *guard = data;
149 let persisted = Self::clone_and_drop(guard);
150 if self.job_id.is_nil() {
151 return Ok(());
152 }
153 self.persist_saved_data(persisted).await
154 }
155
156 pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
169 let mut guard = self.saved_data.write().await;
170 Self::apply_save(&mut guard, key, value);
171 let persisted = Self::clone_and_drop(guard);
172 if self.job_id.is_nil() {
173 return Ok(());
174 }
175 self.persist_saved_data(persisted).await
176 }
177
178 pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
180 let row: Option<(String,)> = sqlx::query_as(
181 r#"
182 SELECT status
183 FROM forge_jobs
184 WHERE id = $1
185 "#,
186 )
187 .bind(self.job_id)
188 .fetch_optional(&self.db_pool)
189 .await
190 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
191
192 Ok(matches!(
193 row.as_ref().map(|(status,)| status.as_str()),
194 Some("cancel_requested") | Some("cancelled")
195 ))
196 }
197
198 pub async fn check_cancelled(&self) -> crate::Result<()> {
200 if self.is_cancel_requested().await? {
201 Err(crate::ForgeError::JobCancelled(
202 "Job cancellation requested".to_string(),
203 ))
204 } else {
205 Ok(())
206 }
207 }
208
209 async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
210 sqlx::query(
211 r#"
212 UPDATE forge_jobs
213 SET job_context = $2
214 WHERE id = $1
215 "#,
216 )
217 .bind(self.job_id)
218 .bind(data)
219 .execute(&self.db_pool)
220 .await
221 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
222
223 Ok(())
224 }
225
226 fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
227 if let Some(map) = data.as_object_mut() {
228 map.insert(key.to_string(), value);
229 } else {
230 let mut map = serde_json::Map::new();
231 map.insert(key.to_string(), value);
232 *data = serde_json::Value::Object(map);
233 }
234 }
235
236 fn clone_and_drop(
237 guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
238 ) -> serde_json::Value {
239 let cloned = guard.clone();
240 drop(guard);
241 cloned
242 }
243
244 pub async fn heartbeat(&self) -> crate::Result<()> {
246 sqlx::query(
247 r#"
248 UPDATE forge_jobs
249 SET last_heartbeat = NOW()
250 WHERE id = $1
251 "#,
252 )
253 .bind(self.job_id)
254 .execute(&self.db_pool)
255 .await
256 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
257
258 Ok(())
259 }
260
261 pub fn is_retry(&self) -> bool {
263 self.attempt > 1
264 }
265
266 pub fn is_last_attempt(&self) -> bool {
268 self.attempt >= self.max_attempts
269 }
270}
271
272impl EnvAccess for JobContext {
273 fn env_provider(&self) -> &dyn EnvProvider {
274 self.env_provider.as_ref()
275 }
276}
277
#[cfg(test)]
// Test fixtures are known-good, so unwrapping results and indexing JSON values is acceptable.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a lazily-connecting pool that points at a nonexistent database.
    ///
    /// None of these tests issue queries: saved-data tests use a nil job id,
    /// which skips persistence.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Builds a context with the given identity and attempt counters.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        assert!(make_ctx(Uuid::new_v4(), "test", 3, 3).is_last_attempt());
        // Overshooting the limit still counts as the last attempt.
        assert!(make_ctx(Uuid::new_v4(), "test", 4, 3).is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3);
        assert!(ctx.progress(10, "ignored").is_ok());
    }

    #[tokio::test]
    async fn test_progress_clamps_and_sends() {
        let (tx, rx) = mpsc::channel();
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3).with_progress(tx);

        ctx.progress(150, "overshoot").unwrap();

        let update = rx.try_recv().unwrap();
        assert_eq!(update.job_id, ctx.job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "overshoot");
    }

    #[tokio::test]
    async fn test_progress_fails_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel();
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3).with_progress(tx);
        drop(rx);

        assert!(ctx.progress(10, "nobody listening").is_err());
    }

    #[tokio::test]
    async fn test_new_context_starts_with_empty_saved_data() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);
        assert_eq!(ctx.saved().await, serde_json::json!({}));
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_save_replaces_non_object_data() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!(["not", "an", "object"]));

        ctx.save("key", serde_json::json!(1)).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"key": 1}));
    }

    #[tokio::test]
    async fn test_set_saved_replaces_data() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"old": true}));

        ctx.set_saved(serde_json::json!({"new": 1})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 1}));
    }
}