Skip to main content

forge_core/job/
context.rs

1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use serde::Serialize;
7
8use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
9use crate::function::{AuthContext, JobDispatch, KvHandle, WorkflowDispatch};
10use crate::http::CircuitBreakerClient;
11
12/// Returns an empty JSON object for initializing job saved data.
13pub fn empty_saved_data() -> serde_json::Value {
14    serde_json::Value::Object(serde_json::Map::new())
15}
16
/// Context available to job handlers.
///
/// Identity, retry counters and `auth` are public fields; everything else is
/// private and reached through the accessor and `with_*` builder methods.
#[non_exhaustive]
pub struct JobContext {
    /// Id of the job; used as the `forge_jobs.id` key by the queries in this type.
    pub job_id: Uuid,
    /// Job type name supplied by the caller of `new()`.
    pub job_type: String,
    /// Current attempt (1-based).
    pub attempt: u32,
    /// Attempt count at which `is_last_attempt()` starts returning `true`.
    pub max_attempts: u32,
    /// Auth context for this run; `new()` initialises it as unauthenticated.
    pub auth: AuthContext,
    /// Persisted across retries; accessible in compensation handlers.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    /// Postgres pool behind `db()`, `db_conn()`, `conn()` and the direct
    /// `forge_jobs` queries.
    db_pool: sqlx::PgPool,
    /// Circuit-breaker HTTP client behind `http()` and `raw_http()`.
    http_client: CircuitBreakerClient,
    /// `None` means unlimited.
    http_timeout: Option<Duration>,
    /// Channel `progress()` sends to; while `None`, progress reports are
    /// accepted and discarded.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    /// Backs the `EnvAccess` impl; `new()` installs a `RealEnvProvider`.
    env_provider: Arc<dyn EnvProvider>,
    /// Optional KV store handle; `kv()` returns an internal error while unset.
    kv: Option<Arc<dyn KvHandle>>,
    /// If absent, `dispatch_job` fails with an internal error rather than bypassing the trait.
    job_dispatch: Option<Arc<dyn JobDispatch>>,
    /// Required so dispatched workflows resolve the active version + signature
    /// instead of resuming as `BlockedSignatureMismatch`.
    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
}
41
/// Progress update message.
///
/// Built by `JobContext::progress` and sent over the context's progress
/// channel when one is attached.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Id of the job reporting progress.
    pub job_id: Uuid,
    /// 0–100. `JobContext::progress` clamps larger inputs to 100.
    pub percentage: u8,
    /// Free-form status text supplied by the caller of `progress`.
    pub message: String,
}
50
51impl JobContext {
52    /// Create a new job context.
53    pub fn new(
54        job_id: Uuid,
55        job_type: String,
56        attempt: u32,
57        max_attempts: u32,
58        db_pool: sqlx::PgPool,
59        http_client: CircuitBreakerClient,
60    ) -> Self {
61        Self {
62            job_id,
63            job_type,
64            attempt,
65            max_attempts,
66            auth: AuthContext::unauthenticated(),
67            saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
68            db_pool,
69            http_client,
70            http_timeout: None,
71            progress_tx: None,
72            env_provider: Arc::new(RealEnvProvider::new()),
73            kv: None,
74            job_dispatch: None,
75            workflow_dispatch: None,
76        }
77    }
78
79    /// Attach a KV store handle. Called by the runtime before handing the
80    /// context to the handler.
81    pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
82        self.kv = Some(kv);
83        self
84    }
85
86    /// Attach a job dispatcher so `dispatch_job` routes through the
87    /// `JobDispatch` trait (the only path that resolves registered job
88    /// metadata).
89    pub fn with_job_dispatch(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
90        self.job_dispatch = Some(dispatcher);
91        self
92    }
93
94    /// Attach a workflow dispatcher so `start_workflow` routes through the
95    /// `WorkflowDispatch` trait, which writes the active version + signature.
96    /// Without this, dispatched workflows would resume as
97    /// `BlockedSignatureMismatch` on first attempt.
98    pub fn with_workflow_dispatch(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
99        self.workflow_dispatch = Some(dispatcher);
100        self
101    }
102
103    /// Access the KV store.
104    pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
105        self.kv
106            .as_deref()
107            .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
108    }
109
110    /// Create a new job context with persisted saved data.
111    pub fn with_saved(mut self, data: serde_json::Value) -> Self {
112        self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
113        self
114    }
115
116    /// Set authentication context.
117    pub fn with_auth(mut self, auth: AuthContext) -> Self {
118        self.auth = auth;
119        self
120    }
121
122    /// Inject a tenant ID into the auth context claims.
123    ///
124    /// Merges the `tenant_id` claim into the existing auth context so that
125    /// `ctx.auth.tenant_id()` returns the value for the duration of this job.
126    /// Used by the executor when the job record carries a tenant ID.
127    pub fn with_tenant_id(mut self, tenant_id: Uuid) -> Self {
128        let mut claims = self.auth.claims().clone();
129        claims.insert(
130            "tenant_id".to_string(),
131            serde_json::Value::String(tenant_id.to_string()),
132        );
133        self.auth = if self.auth.is_authenticated() {
134            if let Some(user_id) = self.auth.user_id() {
135                AuthContext::authenticated(user_id, self.auth.roles().to_vec(), claims)
136            } else {
137                AuthContext::authenticated_without_uuid(self.auth.roles().to_vec(), claims)
138            }
139        } else {
140            AuthContext::authenticated_without_uuid(Vec::new(), claims)
141        };
142        self
143    }
144
145    /// Set progress channel.
146    pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
147        self.progress_tx = Some(tx);
148        self
149    }
150
151    /// Set environment provider.
152    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
153        self.env_provider = provider;
154        self
155    }
156
157    /// Get database pool.
158    pub fn db(&self) -> crate::function::ForgeDb {
159        crate::function::ForgeDb::from_pool(&self.db_pool)
160    }
161
162    /// Get a `DbConn` for use in shared helper functions.
163    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
164        crate::function::DbConn::Pool(self.db_pool.clone())
165    }
166
167    /// Acquire a connection compatible with sqlx compile-time checked macros.
168    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
169        Ok(crate::function::ForgeConn::Pool(
170            self.db_pool.acquire().await?,
171        ))
172    }
173
174    /// Get the HTTP client for external requests.
175    pub fn http(&self) -> crate::http::HttpClient {
176        self.http_client.with_timeout(self.http_timeout)
177    }
178
    /// Get the raw reqwest client, bypassing circuit breaker execution.
    ///
    /// The per-context `http_timeout` is not consulted on this path; only
    /// `http()` passes it on.
    pub fn raw_http(&self) -> &reqwest::Client {
        self.http_client.inner()
    }
183
    /// Set the timeout that `http()` hands to the HTTP client.
    ///
    /// `None` means unlimited. Clients already obtained from `http()` were
    /// built with the previous value.
    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
        self.http_timeout = timeout;
    }
187
    /// Get the raw database pool for bridge handlers that need to construct
    /// other context types (e.g., CronContext from a cron bridge job).
    ///
    /// Not intended for use in application job handlers. Use `db()`, `db_conn()`,
    /// or `conn()` instead. This exists so bridge handlers (cron-to-job,
    /// workflow-to-job) can construct sub-contexts that require a pool reference.
    ///
    /// Hidden from rendered docs; returns a borrow of the pool, not a clone.
    #[doc(hidden)]
    pub fn pool(&self) -> &sqlx::PgPool {
        &self.db_pool
    }
198
    /// Get the circuit breaker HTTP client for bridge handlers that need to
    /// construct sub-contexts (e.g., CronContext from a cron bridge job).
    ///
    /// Not intended for use in application job handlers. Use `http()` or
    /// `raw_http()` instead.
    ///
    /// Hidden from rendered docs; returns a borrow of the stored client
    /// without applying `http_timeout`.
    #[doc(hidden)]
    pub fn circuit_breaker_client(&self) -> &CircuitBreakerClient {
        &self.http_client
    }
208
209    /// Get the KV handle for bridge handlers that need to propagate it to
210    /// sub-contexts (e.g., CronContext from a cron bridge job).
211    ///
212    /// Not intended for use in application job handlers. Use `kv()` instead.
213    #[doc(hidden)]
214    pub fn kv_handle(&self) -> Option<Arc<dyn KvHandle>> {
215        self.kv.clone()
216    }
217
218    /// Report job progress.
219    pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
220        let update = ProgressUpdate {
221            job_id: self.job_id,
222            percentage: percentage.min(100),
223            message: message.into(),
224        };
225
226        if let Some(ref tx) = self.progress_tx {
227            tx.send(update).map_err(|e| {
228                crate::ForgeError::internal(format!("Failed to send progress: {e}"))
229            })?;
230        }
231
232        Ok(())
233    }
234
235    /// Get all saved job data.
236    ///
237    /// Returns data that was saved during job execution via `save()`.
238    /// This data persists across retries and is accessible in compensation handlers.
239    pub async fn saved(&self) -> serde_json::Value {
240        self.saved_data.read().await.clone()
241    }
242
243    /// Save a key-value pair to persistent job data.
244    ///
245    /// Merges `key` into the saved data object and persists the result to the
246    /// database. Saved data survives retries and is accessible in compensation
247    /// handlers. Use this to store information needed for rollback (e.g.,
248    /// transaction IDs, resource handles, progress markers).
249    ///
250    /// Read saved data back with [`saved()`](Self::saved).
251    ///
252    /// # Example
253    ///
254    /// ```ignore
255    /// ctx.save("charge_id", json!(charge.id)).await?;
256    /// ctx.save("refund_amount", json!(amount)).await?;
257    /// ```
258    pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
259        let mut guard = self.saved_data.write().await;
260        Self::apply_save(&mut guard, key, value);
261        let persisted = Self::clone_and_drop(guard);
262        if self.job_id.is_nil() {
263            return Ok(());
264        }
265        self.persist_saved_data(persisted).await
266    }
267
268    /// Dispatch a sub-job directly.
269    ///
270    /// Routes through the `JobDispatch` trait so registered job metadata
271    /// (queue/capability, priority, retry policy) is honoured. The dispatch is
272    /// non-transactional: once the parent job returns, the child remains
273    /// enqueued regardless of success. Use the transactional dispatch on
274    /// `MutationContext` for commit-dependent fan-out.
275    pub async fn dispatch_job<T: Serialize>(
276        &self,
277        job_type: &str,
278        args: &T,
279    ) -> crate::Result<Uuid> {
280        let args_json = serde_json::to_value(args)
281            .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
282        let dispatcher = self
283            .job_dispatch
284            .as_ref()
285            .ok_or_else(|| crate::ForgeError::internal("Job dispatch not available"))?;
286        dispatcher
287            .dispatch_by_name(
288                job_type,
289                args_json,
290                self.auth.principal_id(),
291                self.auth.tenant_id(),
292            )
293            .await
294    }
295
296    /// Type-safe dispatch: resolves the job name from the type's `ForgeJob`
297    /// impl and serializes the args at the call site.
298    pub async fn dispatch<J: crate::ForgeJob>(&self, args: &J::Args) -> crate::Result<Uuid> {
299        self.dispatch_job(J::info().name, args).await
300    }
301
302    /// Start a workflow directly.
303    ///
304    /// Routes through the `WorkflowDispatch` trait, which writes the active
305    /// version + signature onto the run row and enqueues the
306    /// `$workflow_resume` job. Calling raw SQL here would leave both columns
307    /// blank and the executor would immediately mark the run as
308    /// `BlockedSignatureMismatch`.
309    pub async fn start_workflow<T: Serialize>(
310        &self,
311        workflow_name: &str,
312        args: &T,
313    ) -> crate::Result<Uuid> {
314        let input_json = serde_json::to_value(args)
315            .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
316        let dispatcher = self
317            .workflow_dispatch
318            .as_ref()
319            .ok_or_else(|| crate::ForgeError::internal("Workflow dispatch not available"))?;
320        dispatcher
321            .start_by_name(workflow_name, input_json, self.auth.principal_id(), None)
322            .await
323    }
324
325    /// Check if cancellation has been requested for this job.
326    pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
327        let row = sqlx::query_scalar!(
328            r#"
329            SELECT status
330            FROM forge_jobs
331            WHERE id = $1
332            "#,
333            self.job_id
334        )
335        .fetch_optional(&self.db_pool)
336        .await
337        .map_err(crate::ForgeError::Database)?;
338
339        Ok(matches!(
340            row.as_deref(),
341            Some("cancel_requested") | Some("cancelled")
342        ))
343    }
344
345    /// Return an error if cancellation has been requested.
346    pub async fn check_cancelled(&self) -> crate::Result<()> {
347        if self.is_cancel_requested().await? {
348            Err(crate::ForgeError::JobCancelled(
349                "Job cancellation requested".to_string(),
350            ))
351        } else {
352            Ok(())
353        }
354    }
355
356    async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
357        sqlx::query!(
358            r#"
359            UPDATE forge_jobs
360            SET job_context = $2
361            WHERE id = $1
362            "#,
363            self.job_id,
364            data,
365        )
366        .execute(&self.db_pool)
367        .await
368        .map_err(crate::ForgeError::Database)?;
369
370        Ok(())
371    }
372
373    fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
374        if let Some(map) = data.as_object_mut() {
375            map.insert(key.to_string(), value);
376        } else {
377            let mut map = serde_json::Map::new();
378            map.insert(key.to_string(), value);
379            *data = serde_json::Value::Object(map);
380        }
381    }
382
383    fn clone_and_drop(
384        guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
385    ) -> serde_json::Value {
386        let cloned = guard.clone();
387        drop(guard);
388        cloned
389    }
390
391    /// Send heartbeat to keep job alive (async).
392    pub async fn heartbeat(&self) -> crate::Result<()> {
393        sqlx::query!(
394            r#"
395            UPDATE forge_jobs
396            SET last_heartbeat = NOW()
397            WHERE id = $1
398            "#,
399            self.job_id,
400        )
401        .execute(&self.db_pool)
402        .await
403        .map_err(crate::ForgeError::Database)?;
404
405        Ok(())
406    }
407
    /// Check if this is a retry attempt.
    ///
    /// Attempts are 1-based, so anything after the first counts as a retry.
    pub fn is_retry(&self) -> bool {
        self.attempt > 1
    }
412
    /// Check if this is the last attempt.
    ///
    /// True once `attempt` has reached `max_attempts`; also true if it has
    /// somehow exceeded it.
    pub fn is_last_attempt(&self) -> bool {
        self.attempt >= self.max_attempts
    }
417}
418
419impl EnvAccess for JobContext {
420    fn env_provider(&self) -> &dyn EnvProvider {
421        self.env_provider.as_ref()
422    }
423}
424
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Pool created with `connect_lazy` against a database that is not
    /// expected to exist. None of the tests below run a query (they use nil
    /// job ids or never touch the database), so no connection is needed.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Context with the given identity/retry fields, the mock pool and a
    /// default circuit-breaker client.
    fn ctx_with(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    /// First-attempt context with a nil job id, so `save()` stays in memory.
    fn nil_ctx() -> JobContext {
        ctx_with(Uuid::nil(), "test_job", 1, 3)
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = ctx_with(job_id, "test_job", 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        let ctx = ctx_with(Uuid::new_v4(), "test", 2, 3);

        assert!(ctx.is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        let ctx = ctx_with(Uuid::new_v4(), "test", 3, 3);

        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };

        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = nil_ctx().with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = nil_ctx();

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[test]
    fn empty_saved_data_is_an_empty_object() {
        let data = empty_saved_data();
        let obj = data.as_object().expect("empty_saved_data is an object");
        assert!(obj.is_empty());
    }

    #[tokio::test]
    async fn progress_without_channel_is_a_noop() {
        let ctx = nil_ctx();
        ctx.progress(42, "boot")
            .expect("noop progress should not error");
    }

    #[tokio::test]
    async fn progress_clamps_percentage_to_100() {
        let (tx, rx) = mpsc::channel();
        let ctx = nil_ctx().with_progress(tx);
        ctx.progress(250, "over").expect("send should succeed");
        let update = rx.recv().expect("update available");
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "over");
        assert_eq!(update.job_id, ctx.job_id);
    }

    // Despite the name, the failure is reported as `ForgeError::Internal`.
    #[tokio::test]
    async fn progress_returns_job_error_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel::<ProgressUpdate>();
        drop(rx);
        let ctx = nil_ctx().with_progress(tx);
        let err = ctx
            .progress(10, "lost")
            .expect_err("dropped receiver should fail send");
        match err {
            crate::ForgeError::Internal { context: msg, .. } => {
                assert!(msg.contains("Failed to send progress"), "got: {msg}");
            }
            other => panic!("expected ForgeError::Internal, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn with_auth_threads_authenticated_principal() {
        let user = Uuid::new_v4();
        let ctx = nil_ctx().with_auth(AuthContext::authenticated(
            user,
            vec!["admin".to_string()],
            Default::default(),
        ));
        assert_eq!(ctx.auth.user_id(), Some(user));
        assert!(ctx.auth.has_role("admin"));
    }

    #[tokio::test]
    async fn with_env_provider_reaches_through_env_access_trait() {
        use crate::env::MockEnvProvider;
        let mut mock = MockEnvProvider::new();
        mock.set("API_KEY", "sk_test");
        let ctx = nil_ctx().with_env_provider(Arc::new(mock));

        assert_eq!(ctx.env("API_KEY"), Some("sk_test".to_string()));
        assert!(ctx.env("MISSING").is_none());
    }

    #[tokio::test]
    async fn save_promotes_non_object_value_into_object() {
        // If saved data is somehow not an object (e.g., legacy nullable column),
        // save() must replace it with an object containing the new key rather
        // than silently dropping the write.
        let ctx = nil_ctx().with_saved(serde_json::Value::Null);
        ctx.save("charge", serde_json::json!("ch_1"))
            .await
            .expect("save coerces non-object data");

        let saved = ctx.saved().await;
        assert!(saved.is_object(), "saved should be an object after save()");
        assert_eq!(saved["charge"], "ch_1");
    }

    #[test]
    fn progress_update_carries_job_id_percentage_and_message() {
        let id = Uuid::new_v4();
        let update = ProgressUpdate {
            job_id: id,
            percentage: 75,
            message: "almost there".to_string(),
        };
        assert_eq!(update.job_id, id);
        assert_eq!(update.percentage, 75);
        assert_eq!(update.message, "almost there");
    }
}