Skip to main content

forge_core/job/
context.rs

1use std::sync::{Arc, mpsc};
2
3use uuid::Uuid;
4
5use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
6use crate::function::AuthContext;
7
8/// Returns an empty JSON object for initializing job saved data.
9pub fn empty_saved_data() -> serde_json::Value {
10    serde_json::Value::Object(serde_json::Map::new())
11}
12
13/// Context available to job handlers.
14pub struct JobContext {
15    /// Job ID.
16    pub job_id: Uuid,
17    /// Job type/name.
18    pub job_type: String,
19    /// Current attempt number (1-based).
20    pub attempt: u32,
21    /// Maximum attempts allowed.
22    pub max_attempts: u32,
23    /// Authentication context (for queries/mutations).
24    pub auth: AuthContext,
25    /// Persisted job data (survives retries, accessible during compensation).
26    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
27    /// Database pool.
28    db_pool: sqlx::PgPool,
29    /// HTTP client for external calls.
30    http_client: reqwest::Client,
31    /// Progress reporter (sync channel for simplicity).
32    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
33    /// Environment variable provider.
34    env_provider: Arc<dyn EnvProvider>,
35}
36
37/// Progress update message.
38#[derive(Debug, Clone)]
39pub struct ProgressUpdate {
40    /// Job ID.
41    pub job_id: Uuid,
42    /// Progress percentage (0-100).
43    pub percentage: u8,
44    /// Status message.
45    pub message: String,
46}
47
48impl JobContext {
49    /// Create a new job context.
50    pub fn new(
51        job_id: Uuid,
52        job_type: String,
53        attempt: u32,
54        max_attempts: u32,
55        db_pool: sqlx::PgPool,
56        http_client: reqwest::Client,
57    ) -> Self {
58        Self {
59            job_id,
60            job_type,
61            attempt,
62            max_attempts,
63            auth: AuthContext::unauthenticated(),
64            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
65            db_pool,
66            http_client,
67            progress_tx: None,
68            env_provider: Arc::new(RealEnvProvider::new()),
69        }
70    }
71
72    /// Create a new job context with persisted saved data.
73    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
74        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
75        self
76    }
77
78    /// Set authentication context.
79    pub fn with_auth(mut self, auth: AuthContext) -> Self {
80        self.auth = auth;
81        self
82    }
83
84    /// Set progress channel.
85    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
86        self.progress_tx = Some(tx);
87        self
88    }
89
90    /// Set environment provider.
91    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
92        self.env_provider = provider;
93        self
94    }
95
96    /// Get database pool.
97    pub fn db(&self) -> &sqlx::PgPool {
98        &self.db_pool
99    }
100
101    /// Returns a `DbConn` wrapping the pool for shared helper functions.
102    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
103        crate::function::DbConn::Pool(&self.db_pool)
104    }
105
106    /// Get HTTP client.
107    pub fn http(&self) -> &reqwest::Client {
108        &self.http_client
109    }
110
111    /// Report job progress.
112    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
113        let update = ProgressUpdate {
114            job_id: self.job_id,
115            percentage: percentage.min(100),
116            message: message.into(),
117        };
118
119        if let Some(ref tx) = self.progress_tx {
120            tx.send(update)
121                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
122        }
123
124        Ok(())
125    }
126
127    /// Get all saved job data.
128    ///
129    /// Returns data that was saved during job execution via `save()`.
130    /// This data persists across retries and is accessible in compensation handlers.
131    pub async fn saved(&self) -> serde_json::Value {
132        self.saved_data.read().await.clone()
133    }
134
135    /// Replace all saved job data.
136    ///
137    /// Replaces the entire saved data object. For updating individual keys,
138    /// use `save()` instead.
139    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
140        let mut guard = self.saved_data.write().await;
141        *guard = data;
142        let persisted = Self::clone_and_drop(guard);
143        if self.job_id.is_nil() {
144            return Ok(());
145        }
146        self.persist_saved_data(persisted).await
147    }
148
149    /// Save a key-value pair to persistent job data.
150    ///
151    /// Saved data persists across retries and is accessible in compensation handlers.
152    /// Use this to store information needed for rollback (e.g., transaction IDs,
153    /// resource handles, progress markers).
154    ///
155    /// # Example
156    ///
157    /// ```ignore
158    /// ctx.save("charge_id", json!(charge.id)).await?;
159    /// ctx.save("refund_amount", json!(amount)).await?;
160    /// ```
161    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
162        let mut guard = self.saved_data.write().await;
163        Self::apply_save(&mut guard, key, value);
164        let persisted = Self::clone_and_drop(guard);
165        if self.job_id.is_nil() {
166            return Ok(());
167        }
168        self.persist_saved_data(persisted).await
169    }
170
171    /// Check if cancellation has been requested for this job.
172    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
173        let row: Option<(String,)> = sqlx::query_as(
174            r#"
175            SELECT status
176            FROM forge_jobs
177            WHERE id = $1
178            "#,
179        )
180        .bind(self.job_id)
181        .fetch_optional(&self.db_pool)
182        .await
183        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
184
185        Ok(matches!(
186            row.as_ref().map(|(status,)| status.as_str()),
187            Some("cancel_requested") | Some("cancelled")
188        ))
189    }
190
191    /// Return an error if cancellation has been requested.
192    pub async fn check_cancelled(&self) -> crate::Result<()> {
193        if self.is_cancel_requested().await? {
194            Err(crate::ForgeError::JobCancelled(
195                "Job cancellation requested".to_string(),
196            ))
197        } else {
198            Ok(())
199        }
200    }
201
202    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
203        sqlx::query(
204            r#"
205            UPDATE forge_jobs
206            SET job_context = $2
207            WHERE id = $1
208            "#,
209        )
210        .bind(self.job_id)
211        .bind(data)
212        .execute(&self.db_pool)
213        .await
214        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
215
216        Ok(())
217    }
218
219    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
220        if let Some(map) = data.as_object_mut() {
221            map.insert(key.to_string(), value);
222        } else {
223            let mut map = serde_json::Map::new();
224            map.insert(key.to_string(), value);
225            *data = serde_json::Value::Object(map);
226        }
227    }
228
229    fn clone_and_drop(
230        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
231    ) -> serde_json::Value {
232        let cloned = guard.clone();
233        drop(guard);
234        cloned
235    }
236
237    /// Send heartbeat to keep job alive (async).
238    pub async fn heartbeat(&self) -> crate::Result<()> {
239        sqlx::query(
240            r#"
241            UPDATE forge_jobs
242            SET last_heartbeat = NOW()
243            WHERE id = $1
244            "#,
245        )
246        .bind(self.job_id)
247        .execute(&self.db_pool)
248        .await
249        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
250
251        Ok(())
252    }
253
254    /// Check if this is a retry attempt.
255    pub fn is_retry(&self) -> bool {
256        self.attempt > 1
257    }
258
259    /// Check if this is the last attempt.
260    pub fn is_last_attempt(&self) -> bool {
261        self.attempt >= self.max_attempts
262    }
263}
264
265impl EnvAccess for JobContext {
266    fn env_provider(&self) -> &dyn EnvProvider {
267        self.env_provider.as_ref()
268    }
269}
270
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a pool that only connects when a query actually runs.
    ///
    /// Tests below either never touch the database or use a nil job ID
    /// (which skips persistence), so an unreachable URL is fine. Must be
    /// called inside a Tokio runtime.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Build a context with the given ID and attempt counters.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 3, 3);
        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_is_clamped_and_sent() {
        let (tx, rx) = mpsc::channel();
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test", 1, 3).with_progress(tx);

        ctx.progress(150, "done").unwrap();

        let update = rx.recv().unwrap();
        assert_eq!(update.job_id, job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "done");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3);
        assert!(ctx.progress(10, "working").is_ok());
    }

    #[tokio::test]
    async fn test_progress_fails_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel();
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3).with_progress(tx);
        drop(rx);
        assert!(ctx.progress(10, "working").is_err());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_set_saved_replaces_data() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("a", serde_json::json!(1)).await.unwrap();
        ctx.set_saved(serde_json::json!({"b": 2})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"b": 2}));
    }

    #[tokio::test]
    async fn test_save_on_non_object_replaces_with_object() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!(42));

        ctx.save("k", serde_json::json!("v")).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"k": "v"}));
    }
}