Skip to main content

forge_core/job/
context.rs

1use std::sync::{Arc, mpsc};
2
3use uuid::Uuid;
4
5use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
6use crate::function::AuthContext;
7
8/// Returns an empty JSON object for initializing job saved data.
9pub fn empty_saved_data() -> serde_json::Value {
10    serde_json::Value::Object(serde_json::Map::new())
11}
12
/// Context available to job handlers.
///
/// Construct with [`JobContext::new`] and customize with the `with_*` builder
/// methods. Carries job metadata (ID, type, attempt counters), the
/// authentication context, and handles to shared infrastructure: the database
/// pool, HTTP client, optional progress channel and environment provider.
pub struct JobContext {
    /// Job ID.
    pub job_id: Uuid,
    /// Job type/name.
    pub job_type: String,
    /// Current attempt number (1-based).
    pub attempt: u32,
    /// Maximum attempts allowed.
    pub max_attempts: u32,
    /// Authentication context (for queries/mutations).
    pub auth: AuthContext,
    /// Persisted job data (survives retries, accessible during compensation).
    ///
    /// Kept behind an async `RwLock` so the `&self` methods `save`/`set_saved`
    /// can update it; those methods also write it to the `job_context` column
    /// of `forge_jobs` unless `job_id` is nil.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Database pool.
    db_pool: sqlx::PgPool,
    /// HTTP client for external calls.
    http_client: reqwest::Client,
    /// Progress reporter (sync channel for simplicity).
    ///
    /// `None` (the default from `new`) means progress updates are discarded.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Environment variable provider.
    env_provider: Arc<dyn EnvProvider>,
}
36
/// Progress update message.
///
/// Sent on the channel configured with [`JobContext::with_progress`] each time
/// a handler calls [`JobContext::progress`].
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Job ID.
    pub job_id: Uuid,
    /// Progress percentage (0-100).
    ///
    /// [`JobContext::progress`] clamps larger values to 100 before sending.
    pub percentage: u8,
    /// Status message.
    pub message: String,
}
47
48impl JobContext {
49    /// Create a new job context.
50    pub fn new(
51        job_id: Uuid,
52        job_type: String,
53        attempt: u32,
54        max_attempts: u32,
55        db_pool: sqlx::PgPool,
56        http_client: reqwest::Client,
57    ) -> Self {
58        Self {
59            job_id,
60            job_type,
61            attempt,
62            max_attempts,
63            auth: AuthContext::unauthenticated(),
64            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
65            db_pool,
66            http_client,
67            progress_tx: None,
68            env_provider: Arc::new(RealEnvProvider::new()),
69        }
70    }
71
72    /// Create a new job context with persisted saved data.
73    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
74        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
75        self
76    }
77
78    /// Set authentication context.
79    pub fn with_auth(mut self, auth: AuthContext) -> Self {
80        self.auth = auth;
81        self
82    }
83
84    /// Set progress channel.
85    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
86        self.progress_tx = Some(tx);
87        self
88    }
89
90    /// Set environment provider.
91    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
92        self.env_provider = provider;
93        self
94    }
95
96    /// Get database pool.
97    pub fn db(&self) -> &sqlx::PgPool {
98        &self.db_pool
99    }
100
101    /// Returns a `DbConn` wrapping the pool for shared helper functions.
102    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
103        crate::function::DbConn::Pool(&self.db_pool)
104    }
105
106    /// Acquire a connection compatible with sqlx compile-time checked macros.
107    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
108        Ok(crate::function::ForgeConn::Pool(
109            self.db_pool.acquire().await?,
110        ))
111    }
112
113    /// Get HTTP client.
114    pub fn http(&self) -> &reqwest::Client {
115        &self.http_client
116    }
117
118    /// Report job progress.
119    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
120        let update = ProgressUpdate {
121            job_id: self.job_id,
122            percentage: percentage.min(100),
123            message: message.into(),
124        };
125
126        if let Some(ref tx) = self.progress_tx {
127            tx.send(update)
128                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
129        }
130
131        Ok(())
132    }
133
134    /// Get all saved job data.
135    ///
136    /// Returns data that was saved during job execution via `save()`.
137    /// This data persists across retries and is accessible in compensation handlers.
138    pub async fn saved(&self) -> serde_json::Value {
139        self.saved_data.read().await.clone()
140    }
141
142    /// Replace all saved job data.
143    ///
144    /// Replaces the entire saved data object. For updating individual keys,
145    /// use `save()` instead.
146    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
147        let mut guard = self.saved_data.write().await;
148        *guard = data;
149        let persisted = Self::clone_and_drop(guard);
150        if self.job_id.is_nil() {
151            return Ok(());
152        }
153        self.persist_saved_data(persisted).await
154    }
155
156    /// Save a key-value pair to persistent job data.
157    ///
158    /// Saved data persists across retries and is accessible in compensation handlers.
159    /// Use this to store information needed for rollback (e.g., transaction IDs,
160    /// resource handles, progress markers).
161    ///
162    /// # Example
163    ///
164    /// ```ignore
165    /// ctx.save("charge_id", json!(charge.id)).await?;
166    /// ctx.save("refund_amount", json!(amount)).await?;
167    /// ```
168    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
169        let mut guard = self.saved_data.write().await;
170        Self::apply_save(&mut guard, key, value);
171        let persisted = Self::clone_and_drop(guard);
172        if self.job_id.is_nil() {
173            return Ok(());
174        }
175        self.persist_saved_data(persisted).await
176    }
177
178    /// Check if cancellation has been requested for this job.
179    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
180        let row: Option<(String,)> = sqlx::query_as(
181            r#"
182            SELECT status
183            FROM forge_jobs
184            WHERE id = $1
185            "#,
186        )
187        .bind(self.job_id)
188        .fetch_optional(&self.db_pool)
189        .await
190        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
191
192        Ok(matches!(
193            row.as_ref().map(|(status,)| status.as_str()),
194            Some("cancel_requested") | Some("cancelled")
195        ))
196    }
197
198    /// Return an error if cancellation has been requested.
199    pub async fn check_cancelled(&self) -> crate::Result<()> {
200        if self.is_cancel_requested().await? {
201            Err(crate::ForgeError::JobCancelled(
202                "Job cancellation requested".to_string(),
203            ))
204        } else {
205            Ok(())
206        }
207    }
208
209    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
210        sqlx::query(
211            r#"
212            UPDATE forge_jobs
213            SET job_context = $2
214            WHERE id = $1
215            "#,
216        )
217        .bind(self.job_id)
218        .bind(data)
219        .execute(&self.db_pool)
220        .await
221        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
222
223        Ok(())
224    }
225
226    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
227        if let Some(map) = data.as_object_mut() {
228            map.insert(key.to_string(), value);
229        } else {
230            let mut map = serde_json::Map::new();
231            map.insert(key.to_string(), value);
232            *data = serde_json::Value::Object(map);
233        }
234    }
235
236    fn clone_and_drop(
237        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
238    ) -> serde_json::Value {
239        let cloned = guard.clone();
240        drop(guard);
241        cloned
242    }
243
244    /// Send heartbeat to keep job alive (async).
245    pub async fn heartbeat(&self) -> crate::Result<()> {
246        sqlx::query(
247            r#"
248            UPDATE forge_jobs
249            SET last_heartbeat = NOW()
250            WHERE id = $1
251            "#,
252        )
253        .bind(self.job_id)
254        .execute(&self.db_pool)
255        .await
256        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
257
258        Ok(())
259    }
260
261    /// Check if this is a retry attempt.
262    pub fn is_retry(&self) -> bool {
263        self.attempt > 1
264    }
265
266    /// Check if this is the last attempt.
267    pub fn is_last_attempt(&self) -> bool {
268        self.attempt >= self.max_attempts
269    }
270}
271
272impl EnvAccess for JobContext {
273    fn env_provider(&self) -> &dyn EnvProvider {
274        self.env_provider.as_ref()
275    }
276}
277
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a pool that only connects when a query is actually executed.
    ///
    /// Called from `#[tokio::test]` functions because pool creation may need a
    /// running runtime. None of these tests touch the database: they either
    /// avoid DB methods or use a nil job ID, which keeps saved data in memory.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Build a `test_job` context for `job_id` with the given attempt counters.
    fn test_context(job_id: Uuid, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            "test_job".to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = test_context(job_id, 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = test_context(Uuid::new_v4(), 2, 3);

        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        assert!(test_context(Uuid::new_v4(), 3, 3).is_last_attempt());
        assert!(test_context(Uuid::new_v4(), 4, 3).is_last_attempt());
        assert!(!test_context(Uuid::new_v4(), 2, 3).is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_is_clamped_and_sent() {
        let (tx, rx) = mpsc::channel();
        let ctx = test_context(Uuid::nil(), 1, 3).with_progress(tx);

        ctx.progress(150, "overshoot").unwrap();

        let update = rx.try_recv().unwrap();
        assert_eq!(update.job_id, Uuid::nil());
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "overshoot");
    }

    #[tokio::test]
    async fn test_progress_errors_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel();
        drop(rx);
        let ctx = test_context(Uuid::nil(), 1, 3).with_progress(tx);

        assert!(ctx.progress(10, "lost").is_err());
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = test_context(Uuid::nil(), 1, 3);

        assert!(ctx.progress(10, "ignored").is_ok());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = test_context(Uuid::nil(), 1, 3).with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = test_context(Uuid::nil(), 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_save_replaces_non_object_data() {
        let ctx = test_context(Uuid::nil(), 1, 3).with_saved(serde_json::json!([1, 2, 3]));

        ctx.save("key", serde_json::json!(true)).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"key": true}));
    }

    #[tokio::test]
    async fn test_set_saved_replaces_everything() {
        let ctx = test_context(Uuid::nil(), 1, 3).with_saved(serde_json::json!({"old": 1}));

        ctx.set_saved(serde_json::json!({"new": 2})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 2}));
    }
}