Skip to main content

forge_core/job/
context.rs

1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
7use crate::function::AuthContext;
8use crate::http::CircuitBreakerClient;
9
10/// Returns an empty JSON object for initializing job saved data.
11pub fn empty_saved_data() -> serde_json::Value {
12    serde_json::Value::Object(serde_json::Map::new())
13}
14
/// Context available to job handlers.
///
/// Carries job identity, retry bookkeeping, auth, database and HTTP access,
/// plus a persisted key-value store (`saved_data`) that survives retries.
/// Built via [`JobContext::new`] and customized with the `with_*` builders.
pub struct JobContext {
    /// Job ID.
    pub job_id: Uuid,
    /// Job type/name.
    pub job_type: String,
    /// Current attempt number (1-based).
    pub attempt: u32,
    /// Maximum attempts allowed.
    pub max_attempts: u32,
    /// Authentication context (for queries/mutations).
    pub auth: AuthContext,
    /// Persisted job data (survives retries, accessible during compensation).
    /// Guarded by an async RwLock so handlers can read/write concurrently;
    /// wrapped in `Arc` so clones of the lock can be shared if needed.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Database pool.
    db_pool: sqlx::PgPool,
    /// HTTP client for external calls (circuit-breaker wrapped; raw client
    /// reachable via `raw_http()`).
    http_client: CircuitBreakerClient,
    /// Default timeout for outbound HTTP requests made through the
    /// circuit-breaker client. `None` means unlimited.
    http_timeout: Option<Duration>,
    /// Progress reporter (sync channel for simplicity). `None` disables
    /// progress reporting — `progress()` becomes a no-op.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Environment variable provider (swappable for tests via
    /// `with_env_provider`).
    env_provider: Arc<dyn EnvProvider>,
}
41
/// Progress update message sent over the job's progress channel.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Job ID the update belongs to.
    pub job_id: Uuid,
    /// Progress percentage (0-100). `JobContext::progress` clamps values
    /// above 100 before sending.
    pub percentage: u8,
    /// Human-readable status message.
    pub message: String,
}
52
impl JobContext {
    /// Create a new job context.
    ///
    /// Starts unauthenticated, with empty saved data, no progress channel,
    /// no default HTTP timeout, and the real process environment provider.
    /// Customize with the `with_*` builder methods below.
    pub fn new(
        job_id: Uuid,
        job_type: String,
        attempt: u32,
        max_attempts: u32,
        db_pool: sqlx::PgPool,
        http_client: CircuitBreakerClient,
    ) -> Self {
        Self {
            job_id,
            job_type,
            attempt,
            max_attempts,
            auth: AuthContext::unauthenticated(),
            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
            db_pool,
            http_client,
            http_timeout: None,
            progress_tx: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    /// Create a new job context with persisted saved data.
    ///
    /// Replaces the saved-data store wholesale (e.g. restoring data loaded
    /// from the database on a retry). Builder-style: consumes and returns `self`.
    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
        self
    }

    /// Set authentication context.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }

    /// Set progress channel. Without this, `progress()` is a no-op.
    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
        self.progress_tx = Some(tx);
        self
    }

    /// Set environment provider (e.g. a mock provider in tests).
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Get database pool, wrapped in the project's `ForgeDb` facade.
    pub fn db(&self) -> crate::function::ForgeDb {
        crate::function::ForgeDb::from_pool(&self.db_pool)
    }

    /// Acquire a connection compatible with sqlx compile-time checked macros.
    ///
    /// The `'static` lifetime is sound because the acquired pool connection
    /// is an owned handle, not a borrow of `self`.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// Get the HTTP client for external requests.
    ///
    /// Applies the context's default timeout (`set_http_timeout`) to the
    /// circuit-breaker client.
    pub fn http(&self) -> crate::http::HttpClient {
        self.http_client.with_timeout(self.http_timeout)
    }

    /// Get the raw reqwest client, bypassing circuit breaker execution.
    /// Use sparingly — requests made here are not tracked by the breaker.
    pub fn raw_http(&self) -> &reqwest::Client {
        self.http_client.inner()
    }

    /// Set the default timeout for HTTP requests made via `http()`.
    /// `None` means unlimited.
    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
        self.http_timeout = timeout;
    }

    /// Report job progress.
    ///
    /// `percentage` is clamped to 100. No-op (returns `Ok`) when no progress
    /// channel was configured; errors only if the channel send fails
    /// (receiver dropped).
    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
        let update = ProgressUpdate {
            job_id: self.job_id,
            percentage: percentage.min(100),
            message: message.into(),
        };

        if let Some(ref tx) = self.progress_tx {
            tx.send(update)
                .map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
        }

        Ok(())
    }

    /// Get all saved job data.
    ///
    /// Returns data that was saved during job execution via `save()`.
    /// This data persists across retries and is accessible in compensation handlers.
    pub async fn saved(&self) -> serde_json::Value {
        self.saved_data.read().await.clone()
    }

    /// Replace all saved job data.
    ///
    /// Replaces the entire saved data object. For updating individual keys,
    /// use `save()` instead. A nil job ID (e.g. in-memory test contexts)
    /// skips database persistence and only updates the in-memory store.
    pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
        let mut guard = self.saved_data.write().await;
        *guard = data;
        // Snapshot under the lock, then release before the DB round-trip.
        // NOTE(review): because the lock is dropped before persisting,
        // concurrent set_saved/save calls may persist out of order — confirm
        // this is acceptable or hold the (await-safe) tokio lock across persist.
        let persisted = Self::clone_and_drop(guard);
        if self.job_id.is_nil() {
            return Ok(());
        }
        self.persist_saved_data(persisted).await
    }

    /// Save a key-value pair to persistent job data.
    ///
    /// Saved data persists across retries and is accessible in compensation handlers.
    /// Use this to store information needed for rollback (e.g., transaction IDs,
    /// resource handles, progress markers).
    ///
    /// # Example
    ///
    /// ```ignore
    /// ctx.save("charge_id", json!(charge.id)).await?;
    /// ctx.save("refund_amount", json!(amount)).await?;
    /// ```
    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
        let mut guard = self.saved_data.write().await;
        Self::apply_save(&mut guard, key, value);
        // Same snapshot-then-persist pattern (and ordering caveat) as set_saved.
        let persisted = Self::clone_and_drop(guard);
        if self.job_id.is_nil() {
            return Ok(());
        }
        self.persist_saved_data(persisted).await
    }

    /// Check if cancellation has been requested for this job.
    ///
    /// Reads the job's current status row; a missing row is treated as
    /// "not cancelled".
    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
        let row = sqlx::query_scalar!(
            r#"
            SELECT status
            FROM forge_jobs
            WHERE id = $1
            "#,
            self.job_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;

        Ok(matches!(
            row.as_deref(),
            Some("cancel_requested") | Some("cancelled")
        ))
    }

    /// Return an error if cancellation has been requested.
    ///
    /// Convenience wrapper over `is_cancel_requested` for use at handler
    /// checkpoints: `ctx.check_cancelled().await?;`.
    pub async fn check_cancelled(&self) -> crate::Result<()> {
        if self.is_cancel_requested().await? {
            Err(crate::ForgeError::JobCancelled(
                "Job cancellation requested".to_string(),
            ))
        } else {
            Ok(())
        }
    }

    /// Write the full saved-data object to the job's row.
    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
        sqlx::query!(
            r#"
            UPDATE forge_jobs
            SET job_context = $2
            WHERE id = $1
            "#,
            self.job_id,
            data,
        )
        .execute(&self.db_pool)
        .await
        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;

        Ok(())
    }

    /// Insert `key` into `data`, coercing non-object data (e.g. after a
    /// `set_saved` with a scalar/array) into a fresh object first.
    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
        if let Some(map) = data.as_object_mut() {
            map.insert(key.to_string(), value);
        } else {
            let mut map = serde_json::Map::new();
            map.insert(key.to_string(), value);
            *data = serde_json::Value::Object(map);
        }
    }

    /// Clone the guarded value and explicitly release the write lock, so the
    /// caller can await on persistence without holding the lock.
    fn clone_and_drop(
        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
    ) -> serde_json::Value {
        let cloned = guard.clone();
        drop(guard);
        cloned
    }

    /// Send heartbeat to keep job alive (async).
    pub async fn heartbeat(&self) -> crate::Result<()> {
        sqlx::query!(
            r#"
            UPDATE forge_jobs
            SET last_heartbeat = NOW()
            WHERE id = $1
            "#,
            self.job_id,
        )
        .execute(&self.db_pool)
        .await
        .map_err(|e| crate::ForgeError::Database(e.to_string()))?;

        Ok(())
    }

    /// Check if this is a retry attempt (attempt numbers are 1-based).
    pub fn is_retry(&self) -> bool {
        self.attempt > 1
    }

    /// Check if this is the last attempt.
    pub fn is_last_attempt(&self) -> bool {
        self.attempt >= self.max_attempts
    }
}
281
282impl EnvAccess for JobContext {
283    fn env_provider(&self) -> &dyn EnvProvider {
284        self.env_provider.as_ref()
285    }
286}
287
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Lazily-connected pool against a database that never has to exist;
    /// no connection is attempted until a query actually runs.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Circuit-breaker HTTP client with default settings.
    fn mock_http() -> CircuitBreakerClient {
        CircuitBreakerClient::with_defaults(reqwest::Client::new())
    }

    /// Convenience constructor for a context with the given identity and
    /// attempt counters, backed by the mock pool/client above.
    fn make_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            mock_pool(),
            mock_http(),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = make_ctx(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        // Attempt 2 of 3 counts as a retry.
        assert!(make_ctx(Uuid::new_v4(), "test", 2, 3).is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        // Attempt 3 of 3 is the final attempt.
        assert!(make_ctx(Uuid::new_v4(), "test", 3, 3).is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        // Nil job ID → saved data stays purely in memory (no DB writes).
        let ctx =
            make_ctx(Uuid::nil(), "test_job", 1, 3).with_saved(serde_json::json!({"foo": "bar"}));

        assert_eq!(ctx.saved().await["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = make_ctx(Uuid::nil(), "test_job", 1, 3);

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }
}