Skip to main content

forge_core/job/
context.rs

1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
7use crate::function::AuthContext;
8use crate::http::CircuitBreakerClient;
9
10/// Returns an empty JSON object for initializing job saved data.
11pub fn empty_saved_data() -> serde_json::Value {
12    serde_json::Value::Object(serde_json::Map::new())
13}
14
15/// Context available to job handlers.
16pub struct JobContext {
17    /// Job ID.
18    pub job_id: Uuid,
19    /// Job type/name.
20    pub job_type: String,
21    /// Current attempt number (1-based).
22    pub attempt: u32,
23    /// Maximum attempts allowed.
24    pub max_attempts: u32,
25    /// Authentication context (for queries/mutations).
26    pub auth: AuthContext,
27    /// Persisted job data (survives retries, accessible during compensation).
28    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
29    /// Database pool.
30    db_pool: sqlx::PgPool,
31    /// HTTP client for external calls.
32    http_client: CircuitBreakerClient,
33    /// Default timeout for outbound HTTP requests made through the
34    /// circuit-breaker client. `None` means unlimited.
35    http_timeout: Option<Duration>,
36    /// Progress reporter (sync channel for simplicity).
37    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
38    /// Environment variable provider.
39    env_provider: Arc<dyn EnvProvider>,
40}
41
42/// Progress update message.
43#[derive(Debug, Clone)]
44pub struct ProgressUpdate {
45    /// Job ID.
46    pub job_id: Uuid,
47    /// Progress percentage (0-100).
48    pub percentage: u8,
49    /// Status message.
50    pub message: String,
51}
52
53impl JobContext {
54    /// Create a new job context.
55    pub fn new(
56        job_id: Uuid,
57        job_type: String,
58        attempt: u32,
59        max_attempts: u32,
60        db_pool: sqlx::PgPool,
61        http_client: CircuitBreakerClient,
62    ) -> Self {
63        Self {
64            job_id,
65            job_type,
66            attempt,
67            max_attempts,
68            auth: AuthContext::unauthenticated(),
69            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
70            db_pool,
71            http_client,
72            http_timeout: None,
73            progress_tx: None,
74            env_provider: Arc::new(RealEnvProvider::new()),
75        }
76    }
77
78    /// Create a new job context with persisted saved data.
79    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
80        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
81        self
82    }
83
84    /// Set authentication context.
85    pub fn with_auth(mut self, auth: AuthContext) -> Self {
86        self.auth = auth;
87        self
88    }
89
90    /// Set progress channel.
91    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
92        self.progress_tx = Some(tx);
93        self
94    }
95
96    /// Set environment provider.
97    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
98        self.env_provider = provider;
99        self
100    }
101
102    /// Get database pool.
103    pub fn db(&self) -> crate::function::ForgeDb {
104        crate::function::ForgeDb::from_pool(&self.db_pool)
105    }
106
107    /// Acquire a connection compatible with sqlx compile-time checked macros.
108    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
109        Ok(crate::function::ForgeConn::Pool(
110            self.db_pool.acquire().await?,
111        ))
112    }
113
114    /// Get the HTTP client for external requests.
115    pub fn http(&self) -> crate::http::HttpClient {
116        self.http_client.with_timeout(self.http_timeout)
117    }
118
119    /// Get the raw reqwest client, bypassing circuit breaker execution.
120    pub fn raw_http(&self) -> &reqwest::Client {
121        self.http_client.inner()
122    }
123
124    pub fn http_with_circuit_breaker(&self) -> crate::http::HttpClient {
125        self.http()
126    }
127
128    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
129        self.http_timeout = timeout;
130    }
131
132    /// Report job progress.
133    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
134        let update = ProgressUpdate {
135            job_id: self.job_id,
136            percentage: percentage.min(100),
137            message: message.into(),
138        };
139
140        if let Some(ref tx) = self.progress_tx {
141            tx.send(update)
142                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
143        }
144
145        Ok(())
146    }
147
148    /// Get all saved job data.
149    ///
150    /// Returns data that was saved during job execution via `save()`.
151    /// This data persists across retries and is accessible in compensation handlers.
152    pub async fn saved(&self) -> serde_json::Value {
153        self.saved_data.read().await.clone()
154    }
155
156    /// Replace all saved job data.
157    ///
158    /// Replaces the entire saved data object. For updating individual keys,
159    /// use `save()` instead.
160    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
161        let mut guard = self.saved_data.write().await;
162        *guard = data;
163        let persisted = Self::clone_and_drop(guard);
164        if self.job_id.is_nil() {
165            return Ok(());
166        }
167        self.persist_saved_data(persisted).await
168    }
169
170    /// Save a key-value pair to persistent job data.
171    ///
172    /// Saved data persists across retries and is accessible in compensation handlers.
173    /// Use this to store information needed for rollback (e.g., transaction IDs,
174    /// resource handles, progress markers).
175    ///
176    /// # Example
177    ///
178    /// ```ignore
179    /// ctx.save("charge_id", json!(charge.id)).await?;
180    /// ctx.save("refund_amount", json!(amount)).await?;
181    /// ```
182    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
183        let mut guard = self.saved_data.write().await;
184        Self::apply_save(&mut guard, key, value);
185        let persisted = Self::clone_and_drop(guard);
186        if self.job_id.is_nil() {
187            return Ok(());
188        }
189        self.persist_saved_data(persisted).await
190    }
191
192    /// Check if cancellation has been requested for this job.
193    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
194        let row = sqlx::query_scalar!(
195            r#"
196            SELECT status
197            FROM forge_jobs
198            WHERE id = $1
199            "#,
200            self.job_id
201        )
202        .fetch_optional(&self.db_pool)
203        .await
204        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
205
206        Ok(matches!(
207            row.as_deref(),
208            Some("cancel_requested") | Some("cancelled")
209        ))
210    }
211
212    /// Return an error if cancellation has been requested.
213    pub async fn check_cancelled(&self) -> crate::Result<()> {
214        if self.is_cancel_requested().await? {
215            Err(crate::ForgeError::JobCancelled(
216                "Job cancellation requested".to_string(),
217            ))
218        } else {
219            Ok(())
220        }
221    }
222
223    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
224        sqlx::query(
225            r#"
226            UPDATE forge_jobs
227            SET job_context = $2
228            WHERE id = $1
229            "#,
230        )
231        .bind(self.job_id)
232        .bind(data)
233        .execute(&self.db_pool)
234        .await
235        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
236
237        Ok(())
238    }
239
240    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
241        if let Some(map) = data.as_object_mut() {
242            map.insert(key.to_string(), value);
243        } else {
244            let mut map = serde_json::Map::new();
245            map.insert(key.to_string(), value);
246            *data = serde_json::Value::Object(map);
247        }
248    }
249
250    fn clone_and_drop(
251        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
252    ) -> serde_json::Value {
253        let cloned = guard.clone();
254        drop(guard);
255        cloned
256    }
257
258    /// Send heartbeat to keep job alive (async).
259    pub async fn heartbeat(&self) -> crate::Result<()> {
260        sqlx::query(
261            r#"
262            UPDATE forge_jobs
263            SET last_heartbeat = NOW()
264            WHERE id = $1
265            "#,
266        )
267        .bind(self.job_id)
268        .execute(&self.db_pool)
269        .await
270        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
271
272        Ok(())
273    }
274
275    /// Check if this is a retry attempt.
276    pub fn is_retry(&self) -> bool {
277        self.attempt > 1
278    }
279
280    /// Check if this is the last attempt.
281    pub fn is_last_attempt(&self) -> bool {
282        self.attempt >= self.max_attempts
283    }
284}
285
286impl EnvAccess for JobContext {
287    fn env_provider(&self) -> &dyn EnvProvider {
288        self.env_provider.as_ref()
289    }
290}
291
#[cfg(test)]
// Tests unwrap and index freely: a panic is the intended failure mode here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a pool that never actually connects; these tests must not hit the DB.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Build a context with the given identity and attempt counters.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 3, 3);
        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_clamps_and_sends() {
        let (tx, rx) = mpsc::channel();
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test", 1, 3).with_progress(tx);

        ctx.progress(150, "done").unwrap();

        let update = rx.recv().unwrap();
        assert_eq!(update.job_id, job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "done");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3);
        assert!(ctx.progress(10, "ignored").is_ok());
    }

    #[tokio::test]
    async fn test_progress_fails_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel();
        drop(rx);
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3).with_progress(tx);
        assert!(ctx.progress(10, "lost").is_err());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_save_overwrites_existing_key() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("step", serde_json::json!(1)).await.unwrap();
        ctx.save("step", serde_json::json!(2)).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"step": 2}));
    }

    #[tokio::test]
    async fn test_save_replaces_non_object_data() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!(42));

        ctx.save("k", serde_json::json!("v")).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"k": "v"}));
    }

    #[tokio::test]
    async fn test_set_saved_replaces_all_data() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"old": true}));

        ctx.set_saved(serde_json::json!({"new": 1})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 1}));
    }
}