1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
7use crate::function::AuthContext;
8use crate::http::CircuitBreakerClient;
9
10pub fn empty_saved_data() -> serde_json::Value {
12 serde_json::Value::Object(serde_json::Map::new())
13}
14
15pub struct JobContext {
17 pub job_id: Uuid,
19 pub job_type: String,
21 pub attempt: u32,
23 pub max_attempts: u32,
25 pub auth: AuthContext,
27 saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
29 db_pool: sqlx::PgPool,
31 http_client: CircuitBreakerClient,
33 http_timeout: Option<Duration>,
36 progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
38 env_provider: Arc<dyn EnvProvider>,
40}
41
42#[derive(Debug, Clone)]
44pub struct ProgressUpdate {
45 pub job_id: Uuid,
47 pub percentage: u8,
49 pub message: String,
51}
52
53impl JobContext {
54 pub fn new(
56 job_id: Uuid,
57 job_type: String,
58 attempt: u32,
59 max_attempts: u32,
60 db_pool: sqlx::PgPool,
61 http_client: CircuitBreakerClient,
62 ) -> Self {
63 Self {
64 job_id,
65 job_type,
66 attempt,
67 max_attempts,
68 auth: AuthContext::unauthenticated(),
69 saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
70 db_pool,
71 http_client,
72 http_timeout: None,
73 progress_tx: None,
74 env_provider: Arc::new(RealEnvProvider::new()),
75 }
76 }
77
78 pub fn with_saved(mut self, data: serde_json::Value) -> Self {
80 self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
81 self
82 }
83
84 pub fn with_auth(mut self, auth: AuthContext) -> Self {
86 self.auth = auth;
87 self
88 }
89
90 pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
92 self.progress_tx = Some(tx);
93 self
94 }
95
96 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
98 self.env_provider = provider;
99 self
100 }
101
102 pub fn db(&self) -> crate::function::ForgeDb {
104 crate::function::ForgeDb::from_pool(&self.db_pool)
105 }
106
107 pub fn db_conn(&self) -> crate::function::DbConn<'_> {
109 crate::function::DbConn::Pool(self.db_pool.clone())
110 }
111
112 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
114 Ok(crate::function::ForgeConn::Pool(
115 self.db_pool.acquire().await?,
116 ))
117 }
118
119 pub fn http(&self) -> crate::http::HttpClient {
121 self.http_client.with_timeout(self.http_timeout)
122 }
123
124 pub fn raw_http(&self) -> &reqwest::Client {
126 self.http_client.inner()
127 }
128
129 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
130 self.http_timeout = timeout;
131 }
132
133 pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
135 let update = ProgressUpdate {
136 job_id: self.job_id,
137 percentage: percentage.min(100),
138 message: message.into(),
139 };
140
141 if let Some(ref tx) = self.progress_tx {
142 tx.send(update)
143 .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
144 }
145
146 Ok(())
147 }
148
149 pub async fn saved(&self) -> serde_json::Value {
154 self.saved_data.read().await.clone()
155 }
156
157 pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
162 let mut guard = self.saved_data.write().await;
163 *guard = data;
164 let persisted = Self::clone_and_drop(guard);
165 if self.job_id.is_nil() {
166 return Ok(());
167 }
168 self.persist_saved_data(persisted).await
169 }
170
171 pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
184 let mut guard = self.saved_data.write().await;
185 Self::apply_save(&mut guard, key, value);
186 let persisted = Self::clone_and_drop(guard);
187 if self.job_id.is_nil() {
188 return Ok(());
189 }
190 self.persist_saved_data(persisted).await
191 }
192
193 pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
195 let row = sqlx::query_scalar!(
196 r#"
197 SELECT status
198 FROM forge_jobs
199 WHERE id = $1
200 "#,
201 self.job_id
202 )
203 .fetch_optional(&self.db_pool)
204 .await
205 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
206
207 Ok(matches!(
208 row.as_deref(),
209 Some("cancel_requested") | Some("cancelled")
210 ))
211 }
212
213 pub async fn check_cancelled(&self) -> crate::Result<()> {
215 if self.is_cancel_requested().await? {
216 Err(crate::ForgeError::JobCancelled(
217 "Job cancellation requested".to_string(),
218 ))
219 } else {
220 Ok(())
221 }
222 }
223
224 async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
225 sqlx::query!(
226 r#"
227 UPDATE forge_jobs
228 SET job_context = $2
229 WHERE id = $1
230 "#,
231 self.job_id,
232 data,
233 )
234 .execute(&self.db_pool)
235 .await
236 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
237
238 Ok(())
239 }
240
241 fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
242 if let Some(map) = data.as_object_mut() {
243 map.insert(key.to_string(), value);
244 } else {
245 let mut map = serde_json::Map::new();
246 map.insert(key.to_string(), value);
247 *data = serde_json::Value::Object(map);
248 }
249 }
250
251 fn clone_and_drop(
252 guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
253 ) -> serde_json::Value {
254 let cloned = guard.clone();
255 drop(guard);
256 cloned
257 }
258
259 pub async fn heartbeat(&self) -> crate::Result<()> {
261 sqlx::query!(
262 r#"
263 UPDATE forge_jobs
264 SET last_heartbeat = NOW()
265 WHERE id = $1
266 "#,
267 self.job_id,
268 )
269 .execute(&self.db_pool)
270 .await
271 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
272
273 Ok(())
274 }
275
276 pub fn is_retry(&self) -> bool {
278 self.attempt > 1
279 }
280
281 pub fn is_last_attempt(&self) -> bool {
283 self.attempt >= self.max_attempts
284 }
285}
286
287impl EnvAccess for JobContext {
288 fn env_provider(&self) -> &dyn EnvProvider {
289 self.env_provider.as_ref()
290 }
291}
292
#[cfg(test)]
// Tests unwrap and index freely: a panic is exactly how a failed expectation should surface.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a pool with `connect_lazy`, which does not open a connection up front.
    ///
    /// Tests must therefore avoid DB-backed calls. In particular they use a nil job id
    /// whenever saved data is written, so nothing is persisted.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Builds a context over a lazy pool and a default circuit-breaker HTTP client.
    fn context(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = context(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = context(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = context(Uuid::new_v4(), "test", 3, 3);
        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = context(Uuid::new_v4(), "test", 1, 3);
        assert!(ctx.progress(50, "halfway").is_ok());
    }

    #[tokio::test]
    async fn test_progress_clamps_and_sends() {
        let job_id = Uuid::new_v4();
        let (tx, rx) = mpsc::channel();
        let ctx = context(job_id, "test", 1, 3).with_progress(tx);

        ctx.progress(150, "done").unwrap();
        let update = rx.recv().unwrap();
        assert_eq!(update.job_id, job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "done");

        // With the receiver gone, sending must surface an error instead of panicking.
        drop(rx);
        assert!(ctx.progress(10, "after receiver dropped").is_err());
    }

    #[tokio::test]
    async fn test_new_context_starts_with_empty_saved_data() {
        let ctx = context(Uuid::nil(), "test_job", 1, 3);
        assert_eq!(ctx.saved().await, empty_saved_data());
        assert_eq!(ctx.saved().await, serde_json::json!({}));
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = context(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = context(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_save_replaces_non_object_saved_data() {
        let ctx = context(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!(42));

        ctx.save("key", serde_json::json!("value")).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"key": "value"}));
    }

    #[tokio::test]
    async fn test_set_saved_replaces_everything() {
        let ctx = context(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"old": true}));

        ctx.set_saved(serde_json::json!({"new": 1})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 1}));
    }
}
419}