1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
7use crate::function::AuthContext;
8use crate::http::CircuitBreakerClient;
9
10pub fn empty_saved_data() -> serde_json::Value {
12 serde_json::Value::Object(serde_json::Map::new())
13}
14
/// Per-attempt execution context handed to a job.
///
/// Created with [`JobContext::new`] and customised with the `with_*` builder methods.
pub struct JobContext {
    /// Id of the job's row in `forge_jobs`. With a nil id, `save`/`set_saved` only update
    /// the in-memory saved data and skip the database write.
    pub job_id: Uuid,
    /// Job type name.
    pub job_type: String,
    /// Current attempt number; `is_retry` treats any value above 1 as a retry.
    pub attempt: u32,
    /// Attempt limit; `is_last_attempt` compares `attempt` against it.
    pub max_attempts: u32,
    /// Authentication context; unauthenticated unless set with `with_auth`.
    pub auth: AuthContext,
    /// JSON state saved by the job. Shared behind an async lock and written to
    /// `forge_jobs.job_context` by `save`/`set_saved` for non-nil job ids.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Pool used for `db`/`conn` and for this context's own `forge_jobs` queries.
    db_pool: sqlx::PgPool,
    /// Client exposed through `http` and `raw_http`.
    http_client: CircuitBreakerClient,
    /// Timeout handed to `CircuitBreakerClient::with_timeout` by `http`; `None` unless set.
    http_timeout: Option<Duration>,
    /// Progress channel; `progress` is a no-op while this is `None`.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Environment-variable source returned by the `EnvAccess` implementation.
    env_provider: Arc<dyn EnvProvider>,
}
41
/// A progress report emitted by [`JobContext::progress`].
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Job the report belongs to.
    pub job_id: Uuid,
    /// Completion percentage; `JobContext::progress` clamps it to at most 100.
    pub percentage: u8,
    /// Free-form status message supplied by the job.
    pub message: String,
}
52
53impl JobContext {
54 pub fn new(
56 job_id: Uuid,
57 job_type: String,
58 attempt: u32,
59 max_attempts: u32,
60 db_pool: sqlx::PgPool,
61 http_client: CircuitBreakerClient,
62 ) -> Self {
63 Self {
64 job_id,
65 job_type,
66 attempt,
67 max_attempts,
68 auth: AuthContext::unauthenticated(),
69 saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
70 db_pool,
71 http_client,
72 http_timeout: None,
73 progress_tx: None,
74 env_provider: Arc::new(RealEnvProvider::new()),
75 }
76 }
77
78 pub fn with_saved(mut self, data: serde_json::Value) -> Self {
80 self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
81 self
82 }
83
84 pub fn with_auth(mut self, auth: AuthContext) -> Self {
86 self.auth = auth;
87 self
88 }
89
90 pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
92 self.progress_tx = Some(tx);
93 self
94 }
95
96 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
98 self.env_provider = provider;
99 self
100 }
101
102 pub fn db(&self) -> crate::function::ForgeDb {
104 crate::function::ForgeDb::from_pool(&self.db_pool)
105 }
106
107 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
109 Ok(crate::function::ForgeConn::Pool(
110 self.db_pool.acquire().await?,
111 ))
112 }
113
114 pub fn http(&self) -> crate::http::HttpClient {
116 self.http_client.with_timeout(self.http_timeout)
117 }
118
119 pub fn raw_http(&self) -> &reqwest::Client {
121 self.http_client.inner()
122 }
123
124 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
125 self.http_timeout = timeout;
126 }
127
128 pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
130 let update = ProgressUpdate {
131 job_id: self.job_id,
132 percentage: percentage.min(100),
133 message: message.into(),
134 };
135
136 if let Some(ref tx) = self.progress_tx {
137 tx.send(update)
138 .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
139 }
140
141 Ok(())
142 }
143
144 pub async fn saved(&self) -> serde_json::Value {
149 self.saved_data.read().await.clone()
150 }
151
152 pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
157 let mut guard = self.saved_data.write().await;
158 *guard = data;
159 let persisted = Self::clone_and_drop(guard);
160 if self.job_id.is_nil() {
161 return Ok(());
162 }
163 self.persist_saved_data(persisted).await
164 }
165
166 pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
179 let mut guard = self.saved_data.write().await;
180 Self::apply_save(&mut guard, key, value);
181 let persisted = Self::clone_and_drop(guard);
182 if self.job_id.is_nil() {
183 return Ok(());
184 }
185 self.persist_saved_data(persisted).await
186 }
187
188 pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
190 let row = sqlx::query_scalar!(
191 r#"
192 SELECT status
193 FROM forge_jobs
194 WHERE id = $1
195 "#,
196 self.job_id
197 )
198 .fetch_optional(&self.db_pool)
199 .await
200 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
201
202 Ok(matches!(
203 row.as_deref(),
204 Some("cancel_requested") | Some("cancelled")
205 ))
206 }
207
208 pub async fn check_cancelled(&self) -> crate::Result<()> {
210 if self.is_cancel_requested().await? {
211 Err(crate::ForgeError::JobCancelled(
212 "Job cancellation requested".to_string(),
213 ))
214 } else {
215 Ok(())
216 }
217 }
218
219 async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
220 sqlx::query!(
221 r#"
222 UPDATE forge_jobs
223 SET job_context = $2
224 WHERE id = $1
225 "#,
226 self.job_id,
227 data,
228 )
229 .execute(&self.db_pool)
230 .await
231 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
232
233 Ok(())
234 }
235
236 fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
237 if let Some(map) = data.as_object_mut() {
238 map.insert(key.to_string(), value);
239 } else {
240 let mut map = serde_json::Map::new();
241 map.insert(key.to_string(), value);
242 *data = serde_json::Value::Object(map);
243 }
244 }
245
246 fn clone_and_drop(
247 guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
248 ) -> serde_json::Value {
249 let cloned = guard.clone();
250 drop(guard);
251 cloned
252 }
253
254 pub async fn heartbeat(&self) -> crate::Result<()> {
256 sqlx::query!(
257 r#"
258 UPDATE forge_jobs
259 SET last_heartbeat = NOW()
260 WHERE id = $1
261 "#,
262 self.job_id,
263 )
264 .execute(&self.db_pool)
265 .await
266 .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
267
268 Ok(())
269 }
270
271 pub fn is_retry(&self) -> bool {
273 self.attempt > 1
274 }
275
276 pub fn is_last_attempt(&self) -> bool {
278 self.attempt >= self.max_attempts
279 }
280}
281
282impl EnvAccess for JobContext {
283 fn env_provider(&self) -> &dyn EnvProvider {
284 self.env_provider.as_ref()
285 }
286}
287
#[cfg(test)]
// Tests assert on known fixtures, so panicking via `unwrap` or JSON indexing is the intended
// failure mode here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a context backed by a lazily-connecting pool. `connect_lazy` opens no
    /// connection up front, so tests that never query do not need a running database.
    fn test_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = test_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = test_ctx(Uuid::new_v4(), "test", 2, 3);

        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = test_ctx(Uuid::new_v4(), "test", 3, 3);

        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_clamps_and_sends() {
        let (tx, rx) = mpsc::channel();
        let job_id = Uuid::new_v4();
        let ctx = test_ctx(job_id, "test_job", 1, 3).with_progress(tx);

        ctx.progress(150, "overflow").unwrap();

        let update = rx.try_recv().unwrap();
        assert_eq!(update.job_id, job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "overflow");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = test_ctx(Uuid::new_v4(), "test_job", 1, 3);

        assert!(ctx.progress(10, "no listener").is_ok());
    }

    #[tokio::test]
    async fn test_progress_errors_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel();
        drop(rx);
        let ctx = test_ctx(Uuid::new_v4(), "test_job", 1, 3).with_progress(tx);

        assert!(ctx.progress(10, "nobody listening").is_err());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_set_saved_replaces_in_memory() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"old": true}));

        ctx.set_saved(serde_json::json!({"new": 1})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 1}));
    }

    #[tokio::test]
    async fn test_save_replaces_non_object_saved_data() {
        let ctx = test_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.set_saved(serde_json::json!([1, 2, 3])).await.unwrap();
        ctx.save("step", serde_json::json!(2)).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"step": 2}));
    }
}