1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
7use crate::function::AuthContext;
8use crate::http::CircuitBreakerClient;
9
10pub fn empty_saved_data() -> serde_json::Value {
12 serde_json::Value::Object(serde_json::Map::new())
13}
14
/// Per-attempt execution context handed to a running job.
///
/// Carries the job's identity and retry counters, authentication, a shared
/// saved-data JSON document, database and HTTP access, optional progress
/// reporting, and an overridable environment provider.
pub struct JobContext {
    /// Id of the job row in `forge_jobs`; a nil UUID marks an
    /// in-memory-only context whose saved data is never persisted.
    pub job_id: Uuid,
    /// Name of the job's type.
    pub job_type: String,
    /// Number of the current attempt (first attempt is 1, per `is_retry`).
    pub attempt: u32,
    /// Maximum number of attempts allowed for this job.
    pub max_attempts: u32,
    /// Authentication context; unauthenticated unless set via `with_auth`.
    pub auth: AuthContext,
    /// Job-scoped JSON document behind an async lock; written back to the
    /// `forge_jobs.job_context` column by `save`/`set_saved`.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Pool used for all database access from this context.
    db_pool: sqlx::PgPool,
    /// HTTP client wrapped in a circuit breaker.
    http_client: CircuitBreakerClient,
    /// Optional timeout applied when building HTTP clients via `http()`.
    http_timeout: Option<Duration>,
    /// Channel for progress updates; when `None`, `progress()` is a no-op.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Environment access, overridable via `with_env_provider`.
    env_provider: Arc<dyn EnvProvider>,
}
41
/// Progress notification emitted by a job via `JobContext::progress`.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Job the update belongs to.
    pub job_id: Uuid,
    /// Completion percentage; the sender clamps this to at most 100.
    pub percentage: u8,
    /// Human-readable status message.
    pub message: String,
}
52
53impl JobContext {
54 pub fn new(
56 job_id: Uuid,
57 job_type: String,
58 attempt: u32,
59 max_attempts: u32,
60 db_pool: sqlx::PgPool,
61 http_client: CircuitBreakerClient,
62 ) -> Self {
63 Self {
64 job_id,
65 job_type,
66 attempt,
67 max_attempts,
68 auth: AuthContext::unauthenticated(),
69 saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
70 db_pool,
71 http_client,
72 http_timeout: None,
73 progress_tx: None,
74 env_provider: Arc::new(RealEnvProvider::new()),
75 }
76 }
77
78 pub fn with_saved(mut self, data: serde_json::Value) -> Self {
80 self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
81 self
82 }
83
84 pub fn with_auth(mut self, auth: AuthContext) -> Self {
86 self.auth = auth;
87 self
88 }
89
90 pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
92 self.progress_tx = Some(tx);
93 self
94 }
95
96 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
98 self.env_provider = provider;
99 self
100 }
101
102 pub fn db(&self) -> crate::function::ForgeDb {
104 crate::function::ForgeDb::from_pool(&self.db_pool)
105 }
106
107 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
109 Ok(crate::function::ForgeConn::Pool(
110 self.db_pool.acquire().await?,
111 ))
112 }
113
114 pub fn http(&self) -> crate::http::HttpClient {
116 self.http_client.with_timeout(self.http_timeout)
117 }
118
119 pub fn raw_http(&self) -> &reqwest::Client {
121 self.http_client.inner()
122 }
123
124 pub fn http_with_circuit_breaker(&self) -> crate::http::HttpClient {
125 self.http()
126 }
127
128 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
129 self.http_timeout = timeout;
130 }
131
132 pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
134 let update = ProgressUpdate {
135 job_id: self.job_id,
136 percentage: percentage.min(100),
137 message: message.into(),
138 };
139
140 if let Some(ref tx) = self.progress_tx {
141 tx.send(update)
142 .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
143 }
144
145 Ok(())
146 }
147
148 pub async fn saved(&self) -> serde_json::Value {
153 self.saved_data.read().await.clone()
154 }
155
156 pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
161 let mut guard = self.saved_data.write().await;
162 *guard = data;
163 let persisted = Self::clone_and_drop(guard);
164 if self.job_id.is_nil() {
165 return Ok(());
166 }
167 self.persist_saved_data(persisted).await
168 }
169
170 pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
183 let mut guard = self.saved_data.write().await;
184 Self::apply_save(&mut guard, key, value);
185 let persisted = Self::clone_and_drop(guard);
186 if self.job_id.is_nil() {
187 return Ok(());
188 }
189 self.persist_saved_data(persisted).await
190 }
191
192 pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
194 let row = sqlx::query_scalar!(
195 r#"
196 SELECT status
197 FROM forge_jobs
198 WHERE id = $1
199 "#,
200 self.job_id
201 )
202 .fetch_optional(&self.db_pool)
203 .await
204 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
205
206 Ok(matches!(
207 row.as_deref(),
208 Some("cancel_requested") | Some("cancelled")
209 ))
210 }
211
212 pub async fn check_cancelled(&self) -> crate::Result<()> {
214 if self.is_cancel_requested().await? {
215 Err(crate::ForgeError::JobCancelled(
216 "Job cancellation requested".to_string(),
217 ))
218 } else {
219 Ok(())
220 }
221 }
222
223 async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
224 sqlx::query(
225 r#"
226 UPDATE forge_jobs
227 SET job_context = $2
228 WHERE id = $1
229 "#,
230 )
231 .bind(self.job_id)
232 .bind(data)
233 .execute(&self.db_pool)
234 .await
235 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
236
237 Ok(())
238 }
239
240 fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
241 if let Some(map) = data.as_object_mut() {
242 map.insert(key.to_string(), value);
243 } else {
244 let mut map = serde_json::Map::new();
245 map.insert(key.to_string(), value);
246 *data = serde_json::Value::Object(map);
247 }
248 }
249
250 fn clone_and_drop(
251 guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
252 ) -> serde_json::Value {
253 let cloned = guard.clone();
254 drop(guard);
255 cloned
256 }
257
258 pub async fn heartbeat(&self) -> crate::Result<()> {
260 sqlx::query(
261 r#"
262 UPDATE forge_jobs
263 SET last_heartbeat = NOW()
264 WHERE id = $1
265 "#,
266 )
267 .bind(self.job_id)
268 .execute(&self.db_pool)
269 .await
270 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
271
272 Ok(())
273 }
274
275 pub fn is_retry(&self) -> bool {
277 self.attempt > 1
278 }
279
280 pub fn is_last_attempt(&self) -> bool {
282 self.attempt >= self.max_attempts
283 }
284}
285
286impl EnvAccess for JobContext {
287 fn env_provider(&self) -> &dyn EnvProvider {
288 self.env_provider.as_ref()
289 }
290}
291
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a `JobContext` over a lazily-connected pool (no real database
    /// is contacted) and a default circuit-breaker HTTP client.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 3, 3);
        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        // Nil job id keeps saved data purely in memory.
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }
}