1use std::sync::{Arc, mpsc};
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use serde::Serialize;
7
8use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
9use crate::function::{AuthContext, JobDispatch, KvHandle, WorkflowDispatch};
10use crate::http::CircuitBreakerClient;
11
12pub fn empty_saved_data() -> serde_json::Value {
14 serde_json::Value::Object(serde_json::Map::new())
15}
16
/// Per-attempt execution context passed to a job handler.
///
/// Carries the job's identity and attempt counters, the auth context the job
/// runs under, a JSON "saved data" object that can be updated with
/// `JobContext::save`, database and HTTP handles, and optional integrations
/// (progress channel, KV store, job and workflow dispatchers). Build it with
/// `JobContext::new` and the `with_*` methods.
#[non_exhaustive]
pub struct JobContext {
    /// Id of the job; used as the `forge_jobs.id` key by the database helpers.
    pub job_id: Uuid,
    /// Name of the job type being executed.
    pub job_type: String,
    /// Current attempt number (the first attempt is `1`).
    pub attempt: u32,
    /// Maximum number of attempts configured for this job.
    pub max_attempts: u32,
    /// Auth context the job runs under; its principal id is forwarded when
    /// dispatching follow-up jobs and workflows.
    pub auth: AuthContext,
    // Saved JSON state behind an async lock so `save` can update it through `&self`.
    saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
    // Postgres pool used for `forge_jobs` updates and exposed via `db`/`conn`.
    db_pool: sqlx::PgPool,
    // Circuit-breaker HTTP client (`CircuitBreakerClient`) backing `http`/`raw_http`.
    http_client: CircuitBreakerClient,
    // Optional timeout applied by `http()`; `None` unless `set_http_timeout` is called.
    http_timeout: Option<Duration>,
    // Optional sink for `progress()` updates; `None` makes `progress()` a no-op.
    progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
    // Source for environment-variable lookups via the `EnvAccess` impl.
    env_provider: Arc<dyn EnvProvider>,
    // Optional key-value store handle returned by `kv()`.
    kv: Option<Arc<dyn KvHandle>>,
    // Optional dispatcher used by `dispatch_job`/`dispatch`.
    job_dispatch: Option<Arc<dyn JobDispatch>>,
    // Optional dispatcher used by `start_workflow`.
    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
}
41
42#[derive(Debug, Clone)]
44pub struct ProgressUpdate {
45 pub job_id: Uuid,
46 pub percentage: u8,
48 pub message: String,
49}
50
51impl JobContext {
52 pub fn new(
54 job_id: Uuid,
55 job_type: String,
56 attempt: u32,
57 max_attempts: u32,
58 db_pool: sqlx::PgPool,
59 http_client: CircuitBreakerClient,
60 ) -> Self {
61 Self {
62 job_id,
63 job_type,
64 attempt,
65 max_attempts,
66 auth: AuthContext::unauthenticated(),
67 saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
68 db_pool,
69 http_client,
70 http_timeout: None,
71 progress_tx: None,
72 env_provider: Arc::new(RealEnvProvider::new()),
73 kv: None,
74 job_dispatch: None,
75 workflow_dispatch: None,
76 }
77 }
78
79 pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
82 self.kv = Some(kv);
83 self
84 }
85
86 pub fn with_job_dispatch(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
90 self.job_dispatch = Some(dispatcher);
91 self
92 }
93
94 pub fn with_workflow_dispatch(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
99 self.workflow_dispatch = Some(dispatcher);
100 self
101 }
102
103 pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
105 self.kv
106 .as_deref()
107 .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
108 }
109
110 pub fn with_saved(mut self, data: serde_json::Value) -> Self {
112 self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
113 self
114 }
115
116 pub fn with_auth(mut self, auth: AuthContext) -> Self {
118 self.auth = auth;
119 self
120 }
121
122 pub fn with_tenant_id(mut self, tenant_id: Uuid) -> Self {
128 let mut claims = self.auth.claims().clone();
129 claims.insert(
130 "tenant_id".to_string(),
131 serde_json::Value::String(tenant_id.to_string()),
132 );
133 self.auth = if self.auth.is_authenticated() {
134 if let Some(user_id) = self.auth.user_id() {
135 AuthContext::authenticated(user_id, self.auth.roles().to_vec(), claims)
136 } else {
137 AuthContext::authenticated_without_uuid(self.auth.roles().to_vec(), claims)
138 }
139 } else {
140 AuthContext::authenticated_without_uuid(Vec::new(), claims)
141 };
142 self
143 }
144
145 pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
147 self.progress_tx = Some(tx);
148 self
149 }
150
151 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
153 self.env_provider = provider;
154 self
155 }
156
157 pub fn db(&self) -> crate::function::ForgeDb {
159 crate::function::ForgeDb::from_pool(&self.db_pool)
160 }
161
162 pub fn db_conn(&self) -> crate::function::DbConn<'_> {
164 crate::function::DbConn::Pool(self.db_pool.clone())
165 }
166
167 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
169 Ok(crate::function::ForgeConn::Pool(
170 self.db_pool.acquire().await?,
171 ))
172 }
173
174 pub fn http(&self) -> crate::http::HttpClient {
176 self.http_client.with_timeout(self.http_timeout)
177 }
178
179 pub fn raw_http(&self) -> &reqwest::Client {
181 self.http_client.inner()
182 }
183
184 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
185 self.http_timeout = timeout;
186 }
187
188 #[doc(hidden)]
195 pub fn pool(&self) -> &sqlx::PgPool {
196 &self.db_pool
197 }
198
199 #[doc(hidden)]
205 pub fn circuit_breaker_client(&self) -> &CircuitBreakerClient {
206 &self.http_client
207 }
208
209 #[doc(hidden)]
214 pub fn kv_handle(&self) -> Option<Arc<dyn KvHandle>> {
215 self.kv.clone()
216 }
217
218 pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
220 let update = ProgressUpdate {
221 job_id: self.job_id,
222 percentage: percentage.min(100),
223 message: message.into(),
224 };
225
226 if let Some(ref tx) = self.progress_tx {
227 tx.send(update).map_err(|e| {
228 crate::ForgeError::internal(format!("Failed to send progress: {e}"))
229 })?;
230 }
231
232 Ok(())
233 }
234
235 pub async fn saved(&self) -> serde_json::Value {
240 self.saved_data.read().await.clone()
241 }
242
243 pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
259 let mut guard = self.saved_data.write().await;
260 Self::apply_save(&mut guard, key, value);
261 let persisted = Self::clone_and_drop(guard);
262 if self.job_id.is_nil() {
263 return Ok(());
264 }
265 self.persist_saved_data(persisted).await
266 }
267
268 pub async fn dispatch_job<T: Serialize>(
276 &self,
277 job_type: &str,
278 args: &T,
279 ) -> crate::Result<Uuid> {
280 let args_json = serde_json::to_value(args)
281 .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
282 let dispatcher = self
283 .job_dispatch
284 .as_ref()
285 .ok_or_else(|| crate::ForgeError::internal("Job dispatch not available"))?;
286 dispatcher
287 .dispatch_by_name(
288 job_type,
289 args_json,
290 self.auth.principal_id(),
291 self.auth.tenant_id(),
292 )
293 .await
294 }
295
296 pub async fn dispatch<J: crate::ForgeJob>(&self, args: &J::Args) -> crate::Result<Uuid> {
299 self.dispatch_job(J::info().name, args).await
300 }
301
302 pub async fn start_workflow<T: Serialize>(
310 &self,
311 workflow_name: &str,
312 args: &T,
313 ) -> crate::Result<Uuid> {
314 let input_json = serde_json::to_value(args)
315 .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
316 let dispatcher = self
317 .workflow_dispatch
318 .as_ref()
319 .ok_or_else(|| crate::ForgeError::internal("Workflow dispatch not available"))?;
320 dispatcher
321 .start_by_name(workflow_name, input_json, self.auth.principal_id(), None)
322 .await
323 }
324
325 pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
327 let row = sqlx::query_scalar!(
328 r#"
329 SELECT status
330 FROM forge_jobs
331 WHERE id = $1
332 "#,
333 self.job_id
334 )
335 .fetch_optional(&self.db_pool)
336 .await
337 .map_err(crate::ForgeError::Database)?;
338
339 Ok(matches!(
340 row.as_deref(),
341 Some("cancel_requested") | Some("cancelled")
342 ))
343 }
344
345 pub async fn check_cancelled(&self) -> crate::Result<()> {
347 if self.is_cancel_requested().await? {
348 Err(crate::ForgeError::JobCancelled(
349 "Job cancellation requested".to_string(),
350 ))
351 } else {
352 Ok(())
353 }
354 }
355
356 async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
357 sqlx::query!(
358 r#"
359 UPDATE forge_jobs
360 SET job_context = $2
361 WHERE id = $1
362 "#,
363 self.job_id,
364 data,
365 )
366 .execute(&self.db_pool)
367 .await
368 .map_err(crate::ForgeError::Database)?;
369
370 Ok(())
371 }
372
373 fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
374 if let Some(map) = data.as_object_mut() {
375 map.insert(key.to_string(), value);
376 } else {
377 let mut map = serde_json::Map::new();
378 map.insert(key.to_string(), value);
379 *data = serde_json::Value::Object(map);
380 }
381 }
382
383 fn clone_and_drop(
384 guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
385 ) -> serde_json::Value {
386 let cloned = guard.clone();
387 drop(guard);
388 cloned
389 }
390
391 pub async fn heartbeat(&self) -> crate::Result<()> {
393 sqlx::query!(
394 r#"
395 UPDATE forge_jobs
396 SET last_heartbeat = NOW()
397 WHERE id = $1
398 "#,
399 self.job_id,
400 )
401 .execute(&self.db_pool)
402 .await
403 .map_err(crate::ForgeError::Database)?;
404
405 Ok(())
406 }
407
408 pub fn is_retry(&self) -> bool {
410 self.attempt > 1
411 }
412
413 pub fn is_last_attempt(&self) -> bool {
415 self.attempt >= self.max_attempts
416 }
417}
418
419impl EnvAccess for JobContext {
420 fn env_provider(&self) -> &dyn EnvProvider {
421 self.env_provider.as_ref()
422 }
423}
424
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Pool that is never actually connected: `connect_lazy` defers I/O, so
    /// tests that do not touch the database run without Postgres.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Context for `job_id` with the given attempt counters.
    fn ctx_with(job_id: Uuid, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            "test_job".to_string(),
            attempt,
            max_attempts,
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    /// Context with a nil job id, so `save` stays in memory.
    fn nil_ctx() -> JobContext {
        ctx_with(Uuid::nil(), 1, 3)
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = ctx_with(job_id, 1, 3);

        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        assert!(ctx_with(Uuid::new_v4(), 2, 3).is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        assert!(ctx_with(Uuid::new_v4(), 3, 3).is_last_attempt());
        // Attempts beyond the limit still count as the last one.
        assert!(ctx_with(Uuid::new_v4(), 4, 3).is_last_attempt());
        assert!(!ctx_with(Uuid::new_v4(), 2, 3).is_last_attempt());
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        let ctx = nil_ctx().with_saved(serde_json::json!({"foo": "bar"}));

        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = nil_ctx();

        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();

        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[test]
    fn empty_saved_data_is_an_empty_object() {
        let data = empty_saved_data();
        let obj = data.as_object().expect("empty_saved_data is an object");
        assert!(obj.is_empty());
    }

    #[tokio::test]
    async fn progress_without_channel_is_a_noop() {
        let ctx = nil_ctx();
        ctx.progress(42, "boot")
            .expect("noop progress should not error");
    }

    #[tokio::test]
    async fn progress_clamps_percentage_to_100() {
        let (tx, rx) = mpsc::channel();
        let ctx = nil_ctx().with_progress(tx);
        ctx.progress(250, "over").expect("send should succeed");
        let update = rx.recv().expect("update available");
        assert_eq!(update.percentage, 100);
        assert_eq!(update.message, "over");
        assert_eq!(update.job_id, ctx.job_id);
    }

    #[tokio::test]
    async fn progress_returns_internal_error_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel::<ProgressUpdate>();
        drop(rx);
        let ctx = nil_ctx().with_progress(tx);
        let err = ctx
            .progress(10, "lost")
            .expect_err("dropped receiver should fail send");
        match err {
            crate::ForgeError::Internal { context: msg, .. } => {
                assert!(msg.contains("Failed to send progress"), "got: {msg}");
            }
            other => panic!("expected ForgeError::Internal, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn with_auth_threads_authenticated_principal() {
        let user = Uuid::new_v4();
        let ctx = nil_ctx().with_auth(AuthContext::authenticated(
            user,
            vec!["admin".to_string()],
            Default::default(),
        ));
        assert_eq!(ctx.auth.user_id(), Some(user));
        assert!(ctx.auth.has_role("admin"));
    }

    #[tokio::test]
    async fn with_tenant_id_on_unauthenticated_context_adds_claim_without_roles() {
        let tenant = Uuid::new_v4();
        let ctx = nil_ctx().with_tenant_id(tenant);

        assert!(ctx.auth.is_authenticated());
        assert!(ctx.auth.roles().is_empty());
        assert_eq!(
            ctx.auth.claims().get("tenant_id"),
            Some(&serde_json::Value::String(tenant.to_string()))
        );
    }

    #[tokio::test]
    async fn with_tenant_id_preserves_user_and_roles() {
        let user = Uuid::new_v4();
        let tenant = Uuid::new_v4();
        let ctx = nil_ctx()
            .with_auth(AuthContext::authenticated(
                user,
                vec!["admin".to_string()],
                Default::default(),
            ))
            .with_tenant_id(tenant);

        assert_eq!(ctx.auth.user_id(), Some(user));
        assert!(ctx.auth.has_role("admin"));
        assert_eq!(
            ctx.auth.claims().get("tenant_id"),
            Some(&serde_json::Value::String(tenant.to_string()))
        );
    }

    #[tokio::test]
    async fn with_env_provider_reaches_through_env_access_trait() {
        use crate::env::MockEnvProvider;
        let mut mock = MockEnvProvider::new();
        mock.set("API_KEY", "sk_test");
        let ctx = nil_ctx().with_env_provider(Arc::new(mock));

        assert_eq!(ctx.env("API_KEY"), Some("sk_test".to_string()));
        assert!(ctx.env("MISSING").is_none());
    }

    #[tokio::test]
    async fn save_promotes_non_object_value_into_object() {
        let ctx = nil_ctx().with_saved(serde_json::Value::Null);
        ctx.save("charge", serde_json::json!("ch_1"))
            .await
            .expect("save coerces non-object data");

        let saved = ctx.saved().await;
        assert!(saved.is_object(), "saved should be an object after save()");
        assert_eq!(saved["charge"], "ch_1");
    }

    #[tokio::test]
    async fn kv_errors_when_no_store_configured() {
        let ctx = nil_ctx();
        assert!(ctx.kv().is_err());
        assert!(ctx.kv_handle().is_none());
    }

    #[tokio::test]
    async fn dispatch_job_errors_without_dispatcher() {
        let ctx = nil_ctx();
        let err = ctx
            .dispatch_job("send_email", &serde_json::json!({"to": "a@example.com"}))
            .await
            .expect_err("no job dispatcher configured");
        match err {
            crate::ForgeError::Internal { context: msg, .. } => {
                assert!(msg.contains("Job dispatch not available"), "got: {msg}");
            }
            other => panic!("expected ForgeError::Internal, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn start_workflow_errors_without_dispatcher() {
        let ctx = nil_ctx();
        let err = ctx
            .start_workflow("onboarding", &serde_json::json!({}))
            .await
            .expect_err("no workflow dispatcher configured");
        match err {
            crate::ForgeError::Internal { context: msg, .. } => {
                assert!(msg.contains("Workflow dispatch not available"), "got: {msg}");
            }
            other => panic!("expected ForgeError::Internal, got {other:?}"),
        }
    }

    #[test]
    fn progress_update_carries_job_id_percentage_and_message() {
        let id = Uuid::new_v4();
        let update = ProgressUpdate {
            job_id: id,
            percentage: 75,
            message: "almost there".to_string(),
        };
        assert_eq!(update.job_id, id);
        assert_eq!(update.percentage, 75);
        assert_eq!(update.message, "almost there");
    }
}
661}