//! forge_core/job/context.rs — context available to background job handlers.
1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
7use crate::function::AuthContext;
8use crate::http::CircuitBreakerClient;
9
10/// Returns an empty JSON object for initializing job saved data.
11pub fn empty_saved_data() -> serde_json::Value {
12    serde_json::Value::Object(serde_json::Map::new())
13}
14
/// Context available to job handlers.
///
/// Carries per-job identity and retry metadata alongside shared handles
/// (database pool, circuit-breaker HTTP client, environment provider) that
/// handlers use while the job runs.
pub struct JobContext {
    /// Job ID.
    pub job_id: Uuid,
    /// Job type/name.
    pub job_type: String,
    /// Current attempt number (1-based).
    pub attempt: u32,
    /// Maximum attempts allowed.
    pub max_attempts: u32,
    /// Authentication context (for queries/mutations).
    pub auth: AuthContext,
    /// Persisted job data (survives retries, accessible during compensation).
    // Async RwLock: mutated by `save`/`set_saved`, read by `saved`; wrapped in
    // Arc so the in-memory view can be shared.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Database pool.
    db_pool: sqlx::PgPool,
    /// HTTP client for external calls.
    http_client: CircuitBreakerClient,
    /// Default timeout for outbound HTTP requests made through the
    /// circuit-breaker client. `None` means unlimited.
    http_timeout: Option<Duration>,
    /// Progress reporter (sync channel for simplicity).
    // Optional: when absent, `progress()` becomes a no-op.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Environment variable provider.
    // Trait object so tests can swap in a fake provider via `with_env_provider`.
    env_provider: Arc<dyn EnvProvider>,
}
41
/// Progress update message sent through the job's progress channel.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Job ID.
    pub job_id: Uuid,
    /// Progress percentage (0-100); values above 100 are clamped by the sender.
    pub percentage: u8,
    /// Status message.
    pub message: String,
}
52
impl JobContext {
    /// Create a new job context.
    ///
    /// Defaults: unauthenticated `AuthContext`, empty saved data, no HTTP
    /// timeout, no progress channel, and the real OS environment provider.
    /// Use the `with_*` builder methods to override any of these.
    pub fn new(
        job_id: Uuid,
        job_type: String,
        attempt: u32,
        max_attempts: u32,
        db_pool: sqlx::PgPool,
        http_client: CircuitBreakerClient,
    ) -> Self {
        Self {
            job_id,
            job_type,
            attempt,
            max_attempts,
            auth: AuthContext::unauthenticated(),
            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
            db_pool,
            http_client,
            http_timeout: None,
            progress_tx: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    /// Create a new job context with persisted saved data.
    ///
    /// Replaces the in-memory saved-data store with `data` (e.g. data loaded
    /// from a previous attempt). Does not write to the database by itself.
    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
        self
    }

    /// Set authentication context.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }

    /// Set progress channel.
    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
        self.progress_tx = Some(tx);
        self
    }

    /// Set environment provider.
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Get database pool.
    pub fn db(&self) -> crate::function::ForgeDb {
        crate::function::ForgeDb::from_pool(&self.db_pool)
    }

    /// Get a `DbConn` for use in shared helper functions.
    // Cloning a PgPool is cheap (sqlx pools are internally Arc-backed handles).
    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
        crate::function::DbConn::Pool(self.db_pool.clone())
    }

    /// Acquire a connection compatible with sqlx compile-time checked macros.
    ///
    /// # Errors
    ///
    /// Returns the underlying `sqlx::Error` if no connection can be acquired
    /// from the pool.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// Get the HTTP client for external requests.
    ///
    /// The returned client applies this context's default timeout
    /// (see [`JobContext::set_http_timeout`]); `None` means unlimited.
    pub fn http(&self) -> crate::http::HttpClient {
        self.http_client.with_timeout(self.http_timeout)
    }

    /// Get the raw reqwest client, bypassing circuit breaker execution.
    pub fn raw_http(&self) -> &reqwest::Client {
        self.http_client.inner()
    }

    /// Set the default timeout applied to requests made via [`JobContext::http`].
    /// `None` removes the limit.
    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
        self.http_timeout = timeout;
    }

    /// Report job progress.
    ///
    /// `percentage` is clamped to 100. If no progress channel was configured
    /// via `with_progress`, this is a no-op and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Job` if the receiving end of the progress channel
    /// has been dropped.
    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
        let update = ProgressUpdate {
            job_id: self.job_id,
            percentage: percentage.min(100),
            message: message.into(),
        };

        if let Some(ref tx) = self.progress_tx {
            tx.send(update)
                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
        }

        Ok(())
    }

    /// Get all saved job data.
    ///
    /// Returns data that was saved during job execution via `save()`.
    /// This data persists across retries and is accessible in compensation handlers.
    pub async fn saved(&self) -> serde_json::Value {
        self.saved_data.read().await.clone()
    }

    /// Replace all saved job data.
    ///
    /// Replaces the entire saved data object. For updating individual keys,
    /// use `save()` instead.
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Database` if persisting to `forge_jobs` fails.
    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
        let mut guard = self.saved_data.write().await;
        *guard = data;
        // Snapshot and release the lock BEFORE the database round-trip so
        // other tasks are not blocked on the write lock during I/O.
        let persisted = Self::clone_and_drop(guard);
        // A nil job ID means this context is not backed by a forge_jobs row
        // (e.g. in tests); keep the update in memory only.
        if self.job_id.is_nil() {
            return Ok(());
        }
        self.persist_saved_data(persisted).await
    }

    /// Save a key-value pair to persistent job data.
    ///
    /// Saved data persists across retries and is accessible in compensation handlers.
    /// Use this to store information needed for rollback (e.g., transaction IDs,
    /// resource handles, progress markers).
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Database` if persisting to `forge_jobs` fails.
    ///
    /// # Example
    ///
    /// ```ignore
    /// ctx.save("charge_id", json!(charge.id)).await?;
    /// ctx.save("refund_amount", json!(amount)).await?;
    /// ```
    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
        let mut guard = self.saved_data.write().await;
        Self::apply_save(&mut guard, key, value);
        // Same pattern as `set_saved`: drop the lock before hitting the DB.
        let persisted = Self::clone_and_drop(guard);
        if self.job_id.is_nil() {
            return Ok(());
        }
        self.persist_saved_data(persisted).await
    }

    /// Check if cancellation has been requested for this job.
    ///
    /// Also returns `true` once the job has already been cancelled. A job row
    /// that no longer exists yields `false`.
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Database` if the status query fails.
    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
        let row = sqlx::query_scalar!(
            r#"
            SELECT status
            FROM forge_jobs
            WHERE id = $1
            "#,
            self.job_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;

        Ok(matches!(
            row.as_deref(),
            Some("cancel_requested") | Some("cancelled")
        ))
    }

    /// Return an error if cancellation has been requested.
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::JobCancelled` when cancellation was requested, or
    /// `ForgeError::Database` if the underlying status check fails.
    pub async fn check_cancelled(&self) -> crate::Result<()> {
        if self.is_cancel_requested().await? {
            Err(crate::ForgeError::JobCancelled(
                "Job cancellation requested".to_string(),
            ))
        } else {
            Ok(())
        }
    }

    /// Write the current saved data to this job's `forge_jobs.job_context`
    /// column. Callers are expected to skip nil job IDs before calling.
    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
        sqlx::query!(
            r#"
            UPDATE forge_jobs
            SET job_context = $2
            WHERE id = $1
            "#,
            self.job_id,
            data,
        )
        .execute(&self.db_pool)
        .await
        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;

        Ok(())
    }

    /// Insert `key` into `data`, coercing `data` into an object first if it is
    /// currently any non-object JSON value (the previous value is discarded).
    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
        if let Some(map) = data.as_object_mut() {
            map.insert(key.to_string(), value);
        } else {
            let mut map = serde_json::Map::new();
            map.insert(key.to_string(), value);
            *data = serde_json::Value::Object(map);
        }
    }

    /// Clone the guarded value and explicitly release the write lock, so the
    /// caller can do slow I/O with the snapshot without holding the lock.
    fn clone_and_drop(
        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
    ) -> serde_json::Value {
        let cloned = guard.clone();
        drop(guard);
        cloned
    }

    /// Send heartbeat to keep job alive (async).
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Database` if the heartbeat update fails.
    pub async fn heartbeat(&self) -> crate::Result<()> {
        sqlx::query!(
            r#"
            UPDATE forge_jobs
            SET last_heartbeat = NOW()
            WHERE id = $1
            "#,
            self.job_id,
        )
        .execute(&self.db_pool)
        .await
        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;

        Ok(())
    }

    /// Check if this is a retry attempt (attempt numbers are 1-based).
    pub fn is_retry(&self) -> bool {
        self.attempt > 1
    }

    /// Check if this is the last attempt.
    pub fn is_last_attempt(&self) -> bool {
        self.attempt >= self.max_attempts
    }
}
286
287impl EnvAccess for JobContext {
288    fn env_provider(&self) -> &dyn EnvProvider {
289        self.env_provider.as_ref()
290    }
291}
292
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Lazily-connected pool aimed at a nonexistent database; no test in this
    /// module actually issues a query, so the connection is never opened.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Circuit-breaker client over a fresh reqwest client with default settings.
    fn breaker() -> CircuitBreakerClient {
        CircuitBreakerClient::with_defaults(reqwest::Client::new())
    }

    /// Convenience constructor wrapping `JobContext::new` with mock infrastructure.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            lazy_pool(),
            breaker(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        // Attempt 2 of 3: past the first try, so it counts as a retry.
        let ctx = make_ctx(Uuid::new_v4(), "test", 2, 3);
        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        // Attempt 3 of 3: no attempts remain after this one.
        let ctx = make_ctx(Uuid::new_v4(), "test", 3, 3);
        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        // Nil job ID keeps saved-data operations purely in memory.
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }
}
419}