// forge_core/job/context.rs

1use std::sync::{Arc, mpsc};
2
3use uuid::Uuid;
4
5use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
6use crate::function::AuthContext;
7
8/// Returns an empty JSON object for initializing job saved data.
9pub fn empty_saved_data() -> serde_json::Value {
10    serde_json::Value::Object(serde_json::Map::new())
11}
12
13/// Context available to job handlers.
14pub struct JobContext {
15    /// Job ID.
16    pub job_id: Uuid,
17    /// Job type/name.
18    pub job_type: String,
19    /// Current attempt number (1-based).
20    pub attempt: u32,
21    /// Maximum attempts allowed.
22    pub max_attempts: u32,
23    /// Authentication context (for queries/mutations).
24    pub auth: AuthContext,
25    /// Persisted job data (survives retries, accessible during compensation).
26    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
27    /// Database pool.
28    db_pool: sqlx::PgPool,
29    /// HTTP client for external calls.
30    http_client: reqwest::Client,
31    /// Progress reporter (sync channel for simplicity).
32    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
33    /// Environment variable provider.
34    env_provider: Arc<dyn EnvProvider>,
35}
36
37/// Progress update message.
38#[derive(Debug, Clone)]
39pub struct ProgressUpdate {
40    /// Job ID.
41    pub job_id: Uuid,
42    /// Progress percentage (0-100).
43    pub percentage: u8,
44    /// Status message.
45    pub message: String,
46}
47
48impl JobContext {
49    /// Create a new job context.
50    pub fn new(
51        job_id: Uuid,
52        job_type: String,
53        attempt: u32,
54        max_attempts: u32,
55        db_pool: sqlx::PgPool,
56        http_client: reqwest::Client,
57    ) -> Self {
58        Self {
59            job_id,
60            job_type,
61            attempt,
62            max_attempts,
63            auth: AuthContext::unauthenticated(),
64            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
65            db_pool,
66            http_client,
67            progress_tx: None,
68            env_provider: Arc::new(RealEnvProvider::new()),
69        }
70    }
71
72    /// Create a new job context with persisted saved data.
73    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
74        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
75        self
76    }
77
78    /// Set authentication context.
79    pub fn with_auth(mut self, auth: AuthContext) -> Self {
80        self.auth = auth;
81        self
82    }
83
84    /// Set progress channel.
85    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
86        self.progress_tx = Some(tx);
87        self
88    }
89
90    /// Set environment provider.
91    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
92        self.env_provider = provider;
93        self
94    }
95
96    /// Get database pool.
97    pub fn db(&self) -> &sqlx::PgPool {
98        &self.db_pool
99    }
100
101    /// Get HTTP client.
102    pub fn http(&self) -> &reqwest::Client {
103        &self.http_client
104    }
105
106    /// Report job progress.
107    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
108        let update = ProgressUpdate {
109            job_id: self.job_id,
110            percentage: percentage.min(100),
111            message: message.into(),
112        };
113
114        if let Some(ref tx) = self.progress_tx {
115            tx.send(update)
116                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
117        }
118
119        Ok(())
120    }
121
122    /// Get all saved job data.
123    ///
124    /// Returns data that was saved during job execution via `save()`.
125    /// This data persists across retries and is accessible in compensation handlers.
126    pub async fn saved(&self) -> serde_json::Value {
127        self.saved_data.read().await.clone()
128    }
129
130    /// Replace all saved job data.
131    ///
132    /// Replaces the entire saved data object. For updating individual keys,
133    /// use `save()` instead.
134    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
135        let mut guard = self.saved_data.write().await;
136        *guard = data;
137        let persisted = Self::clone_and_drop(guard);
138        if self.job_id.is_nil() {
139            return Ok(());
140        }
141        self.persist_saved_data(persisted).await
142    }
143
144    /// Save a key-value pair to persistent job data.
145    ///
146    /// Saved data persists across retries and is accessible in compensation handlers.
147    /// Use this to store information needed for rollback (e.g., transaction IDs,
148    /// resource handles, progress markers).
149    ///
150    /// # Example
151    ///
152    /// ```ignore
153    /// ctx.save("charge_id", json!(charge.id)).await?;
154    /// ctx.save("refund_amount", json!(amount)).await?;
155    /// ```
156    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
157        let mut guard = self.saved_data.write().await;
158        Self::apply_save(&mut guard, key, value);
159        let persisted = Self::clone_and_drop(guard);
160        if self.job_id.is_nil() {
161            return Ok(());
162        }
163        self.persist_saved_data(persisted).await
164    }
165
166    /// Check if cancellation has been requested for this job.
167    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
168        let row: Option<(String,)> = sqlx::query_as(
169            r#"
170            SELECT status
171            FROM forge_jobs
172            WHERE id = $1
173            "#,
174        )
175        .bind(self.job_id)
176        .fetch_optional(&self.db_pool)
177        .await
178        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
179
180        Ok(matches!(
181            row.as_ref().map(|(status,)| status.as_str()),
182            Some("cancel_requested") | Some("cancelled")
183        ))
184    }
185
186    /// Return an error if cancellation has been requested.
187    pub async fn check_cancelled(&self) -> crate::Result<()> {
188        if self.is_cancel_requested().await? {
189            Err(crate::ForgeError::JobCancelled(
190                "Job cancellation requested".to_string(),
191            ))
192        } else {
193            Ok(())
194        }
195    }
196
197    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
198        sqlx::query(
199            r#"
200            UPDATE forge_jobs
201            SET job_context = $2
202            WHERE id = $1
203            "#,
204        )
205        .bind(self.job_id)
206        .bind(data)
207        .execute(&self.db_pool)
208        .await
209        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
210
211        Ok(())
212    }
213
214    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
215        if let Some(map) = data.as_object_mut() {
216            map.insert(key.to_string(), value);
217        } else {
218            let mut map = serde_json::Map::new();
219            map.insert(key.to_string(), value);
220            *data = serde_json::Value::Object(map);
221        }
222    }
223
224    fn clone_and_drop(
225        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
226    ) -> serde_json::Value {
227        let cloned = guard.clone();
228        drop(guard);
229        cloned
230    }
231
232    /// Send heartbeat to keep job alive (async).
233    pub async fn heartbeat(&self) -> crate::Result<()> {
234        sqlx::query(
235            r#"
236            UPDATE forge_jobs
237            SET last_heartbeat = NOW()
238            WHERE id = $1
239            "#,
240        )
241        .bind(self.job_id)
242        .execute(&self.db_pool)
243        .await
244        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;
245
246        Ok(())
247    }
248
249    /// Check if this is a retry attempt.
250    pub fn is_retry(&self) -> bool {
251        self.attempt > 1
252    }
253
254    /// Check if this is the last attempt.
255    pub fn is_last_attempt(&self) -> bool {
256        self.attempt >= self.max_attempts
257    }
258}
259
260impl EnvAccess for JobContext {
261    fn env_provider(&self) -> &dyn EnvProvider {
262        self.env_provider.as_ref()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a Postgres pool that defers connecting until first use.
    ///
    /// It points at a nonexistent database, so tests using it must not issue
    /// queries. Saved-data tests use a nil job ID, which skips persistence.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Build a context with the given identity and attempt counters.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 3, 3);
        assert!(ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_last_attempt_beyond_max() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 4, 3);
        assert!(ctx.is_last_attempt());
        assert!(ctx.is_retry());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_progress_sends_clamped_update() {
        let (tx, rx) = std::sync::mpsc::channel();
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test", 1, 3).with_progress(tx);

        ctx.progress(150, "overshoot").unwrap();

        let update = rx.try_recv().expect("progress update should be delivered");
        assert_eq!(update.job_id, job_id);
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "overshoot");
    }

    #[tokio::test]
    async fn test_progress_without_channel_is_noop() {
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3);
        assert!(ctx.progress(10, "no listener").is_ok());
    }

    #[tokio::test]
    async fn test_progress_errors_when_receiver_dropped() {
        let (tx, rx) = std::sync::mpsc::channel();
        let ctx = make_ctx(Uuid::new_v4(), "test", 1, 3).with_progress(tx);
        drop(rx);

        assert!(ctx.progress(10, "nobody listening").is_err());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[tokio::test]
    async fn test_set_saved_replaces_everything() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"old": true}));

        ctx.set_saved(serde_json::json!({"new": 1})).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"new": 1}));
    }

    #[tokio::test]
    async fn test_save_over_non_object_starts_fresh_object() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!(42));

        ctx.save("key", serde_json::json!("value")).await.unwrap();

        assert_eq!(ctx.saved().await, serde_json::json!({"key": "value"}));
    }
}