//! forge_core/job/context.rs — execution context handed to job handlers.
1use std::sync::{Arc, mpsc};
2
3use uuid::Uuid;
4
5use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
6use crate::function::AuthContext;
7
8/// Returns an empty JSON object for initializing job saved data.
9pub fn empty_saved_data() -> serde_json::Value {
10    serde_json::Value::Object(serde_json::Map::new())
11}
12
/// Context available to job handlers.
///
/// Bundles the job's identity and retry metadata with the handles a handler
/// needs at runtime: auth, database pool, HTTP client, persisted saved data,
/// an optional progress channel, and an environment-variable provider.
pub struct JobContext {
    /// Job ID (a nil UUID marks an in-memory/test context — see `save`).
    pub job_id: Uuid,
    /// Job type/name.
    pub job_type: String,
    /// Current attempt number (1-based; values > 1 indicate a retry).
    pub attempt: u32,
    /// Maximum attempts allowed.
    pub max_attempts: u32,
    /// Authentication context (for queries/mutations).
    pub auth: AuthContext,
    /// Persisted job data (survives retries, accessible during compensation).
    /// Behind an async `RwLock` so reads (`saved`) and writes (`save`,
    /// `set_saved`) can be issued from concurrent tasks.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Database pool.
    db_pool: sqlx::PgPool,
    /// HTTP client for external calls.
    http_client: reqwest::Client,
    /// Progress reporter (sync channel for simplicity); `None` when no
    /// listener was attached via `with_progress`.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Environment variable provider (swappable, e.g. for tests).
    env_provider: Arc<dyn EnvProvider>,
}
36
/// Progress update message emitted through the job's progress channel.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// ID of the job reporting progress.
    pub job_id: Uuid,
    /// Progress percentage (0-100; clamped by `JobContext::progress`).
    pub percentage: u8,
    /// Human-readable status message.
    pub message: String,
}
47
48impl JobContext {
49    /// Create a new job context.
50    pub fn new(
51        job_id: Uuid,
52        job_type: String,
53        attempt: u32,
54        max_attempts: u32,
55        db_pool: sqlx::PgPool,
56        http_client: reqwest::Client,
57    ) -> Self {
58        Self {
59            job_id,
60            job_type,
61            attempt,
62            max_attempts,
63            auth: AuthContext::unauthenticated(),
64            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
65            db_pool,
66            http_client,
67            progress_tx: None,
68            env_provider: Arc::new(RealEnvProvider::new()),
69        }
70    }
71
72    /// Create a new job context with persisted saved data.
73    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
74        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
75        self
76    }
77
78    /// Set authentication context.
79    pub fn with_auth(mut self, auth: AuthContext) -> Self {
80        self.auth = auth;
81        self
82    }
83
84    /// Set progress channel.
85    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
86        self.progress_tx = Some(tx);
87        self
88    }
89
90    /// Set environment provider.
91    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
92        self.env_provider = provider;
93        self
94    }
95
96    /// Get database pool.
97    pub fn db(&self) -> crate::function::ForgeDb {
98        crate::function::ForgeDb::from_pool(&self.db_pool)
99    }
100
101    /// Acquire a connection compatible with sqlx compile-time checked macros.
102    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
103        Ok(crate::function::ForgeConn::Pool(
104            self.db_pool.acquire().await?,
105        ))
106    }
107
108    /// Get HTTP client.
109    pub fn http(&self) -> &reqwest::Client {
110        &self.http_client
111    }
112
113    /// Report job progress.
114    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
115        let update = ProgressUpdate {
116            job_id: self.job_id,
117            percentage: percentage.min(100),
118            message: message.into(),
119        };
120
121        if let Some(ref tx) = self.progress_tx {
122            tx.send(update)
123                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
124        }
125
126        Ok(())
127    }
128
129    /// Get all saved job data.
130    ///
131    /// Returns data that was saved during job execution via `save()`.
132    /// This data persists across retries and is accessible in compensation handlers.
133    pub async fn saved(&self) -> serde_json::Value {
134        self.saved_data.read().await.clone()
135    }
136
137    /// Replace all saved job data.
138    ///
139    /// Replaces the entire saved data object. For updating individual keys,
140    /// use `save()` instead.
141    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
142        let mut guard = self.saved_data.write().await;
143        *guard = data;
144        let persisted = Self::clone_and_drop(guard);
145        if self.job_id.is_nil() {
146            return Ok(());
147        }
148        self.persist_saved_data(persisted).await
149    }
150
151    /// Save a key-value pair to persistent job data.
152    ///
153    /// Saved data persists across retries and is accessible in compensation handlers.
154    /// Use this to store information needed for rollback (e.g., transaction IDs,
155    /// resource handles, progress markers).
156    ///
157    /// # Example
158    ///
159    /// ```ignore
160    /// ctx.save("charge_id", json!(charge.id)).await?;
161    /// ctx.save("refund_amount", json!(amount)).await?;
162    /// ```
163    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
164        let mut guard = self.saved_data.write().await;
165        Self::apply_save(&mut guard, key, value);
166        let persisted = Self::clone_and_drop(guard);
167        if self.job_id.is_nil() {
168            return Ok(());
169        }
170        self.persist_saved_data(persisted).await
171    }
172
173    /// Check if cancellation has been requested for this job.
174    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
175        let row: Option<(String,)> = sqlx::query_as(
176            r#"
177            SELECT status
178            FROM forge_jobs
179            WHERE id = $1
180            "#,
181        )
182        .bind(self.job_id)
183        .fetch_optional(&self.db_pool)
184        .await
185        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
186
187        Ok(matches!(
188            row.as_ref().map(|(status,)| status.as_str()),
189            Some("cancel_requested") | Some("cancelled")
190        ))
191    }
192
193    /// Return an error if cancellation has been requested.
194    pub async fn check_cancelled(&self) -> crate::Result<()> {
195        if self.is_cancel_requested().await? {
196            Err(crate::ForgeError::JobCancelled(
197                "Job cancellation requested".to_string(),
198            ))
199        } else {
200            Ok(())
201        }
202    }
203
204    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
205        sqlx::query(
206            r#"
207            UPDATE forge_jobs
208            SET job_context = $2
209            WHERE id = $1
210            "#,
211        )
212        .bind(self.job_id)
213        .bind(data)
214        .execute(&self.db_pool)
215        .await
216        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
217
218        Ok(())
219    }
220
221    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
222        if let Some(map) = data.as_object_mut() {
223            map.insert(key.to_string(), value);
224        } else {
225            let mut map = serde_json::Map::new();
226            map.insert(key.to_string(), value);
227            *data = serde_json::Value::Object(map);
228        }
229    }
230
231    fn clone_and_drop(
232        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
233    ) -> serde_json::Value {
234        let cloned = guard.clone();
235        drop(guard);
236        cloned
237    }
238
239    /// Send heartbeat to keep job alive (async).
240    pub async fn heartbeat(&self) -> crate::Result<()> {
241        sqlx::query(
242            r#"
243            UPDATE forge_jobs
244            SET last_heartbeat = NOW()
245            WHERE id = $1
246            "#,
247        )
248        .bind(self.job_id)
249        .execute(&self.db_pool)
250        .await
251        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
252
253        Ok(())
254    }
255
256    /// Check if this is a retry attempt.
257    pub fn is_retry(&self) -> bool {
258        self.attempt > 1
259    }
260
261    /// Check if this is the last attempt.
262    pub fn is_last_attempt(&self) -> bool {
263        self.attempt >= self.max_attempts
264    }
265}
266
267impl EnvAccess for JobContext {
268    fn env_provider(&self) -> &dyn EnvProvider {
269        self.env_provider.as_ref()
270    }
271}
272
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a lazily-connecting pool; no database is contacted because
    /// these tests never execute a query.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Construct a `JobContext` with the given id/attempt/max_attempts and
    /// canned values for everything else.
    fn test_ctx(job_id: Uuid, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            "test_job".to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = test_ctx(job_id, 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        // Attempt 2 of 3 is a retry.
        assert!(test_ctx(Uuid::new_v4(), 2, 3).is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        // Attempt 3 of 3 is the final attempt.
        assert!(test_ctx(Uuid::new_v4(), 3, 3).is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        // Nil job ID keeps saved data purely in memory (no DB writes).
        let ctx = test_ctx(Uuid::nil(), 1, 3).with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = test_ctx(Uuid::nil(), 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }
}
399}