Skip to main content

forge_core/function/
context.rs

1//! Execution contexts for queries and mutations.
2//!
3//! Every function receives a context providing access to:
4//!
5//! - Database connection (pool or transaction)
6//! - Authentication state (user ID, roles, claims)
7//! - Request metadata (request ID, trace ID, client IP)
8//! - Environment variables
9//! - Job/workflow dispatch (mutations only)
10//!
11//! # QueryContext vs MutationContext
12//!
13//! | Feature | QueryContext | MutationContext |
14//! |---------|--------------|-----------------|
15//! | Database | Pool (read-only) | Transaction or pool |
16//! | Dispatch jobs | No | Yes |
17//! | Start workflows | No | Yes |
18//! | HTTP client | No | Yes (circuit breaker) |
19//!
20//! # Transactional Mutations
21//!
22//! When `transactional = true` (default), mutations run in a transaction.
23//! Jobs and workflows dispatched during the mutation are buffered and only
24//! inserted after the transaction commits successfully.
25//!
26//! ```text
27//! BEGIN
28//!   ├── ctx.db().execute(...)
29//!   ├── ctx.dispatch_job("send_email", ...)  // buffered
30//!   └── return Ok(result)
31//! COMMIT
32//!   └── INSERT INTO forge_jobs (buffered jobs)
33//! ```
34
35use std::collections::HashMap;
36use std::sync::{Arc, Mutex};
37
38use sqlx::postgres::{PgArguments, PgQueryResult, PgRow};
39use sqlx::{FromRow, Postgres, Transaction};
40use tokio::sync::Mutex as AsyncMutex;
41use uuid::Uuid;
42
43use super::dispatch::{JobDispatch, WorkflowDispatch};
44use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
45use crate::http::CircuitBreakerClient;
46use crate::job::JobInfo;
47
/// Abstracts over pool and transaction connections so handlers can work with either.
///
/// The query helpers implemented on this type execute a prepared `sqlx` query
/// against whichever backend the variant holds.
pub enum DbConn<'a> {
    /// Borrowed connection pool; queries are executed directly against it.
    Pool(&'a sqlx::PgPool),
    /// Shared handle to an open transaction. The async mutex is locked for
    /// each query issued through this variant.
    Transaction(Arc<AsyncMutex<Transaction<'static, Postgres>>>),
}
53
54impl DbConn<'_> {
55    pub async fn fetch_one<'q, O>(
56        &self,
57        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
58    ) -> sqlx::Result<O>
59    where
60        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
61    {
62        match self {
63            DbConn::Pool(pool) => query.fetch_one(*pool).await,
64            DbConn::Transaction(tx) => query.fetch_one(&mut **tx.lock().await).await,
65        }
66    }
67
68    pub async fn fetch_optional<'q, O>(
69        &self,
70        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
71    ) -> sqlx::Result<Option<O>>
72    where
73        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
74    {
75        match self {
76            DbConn::Pool(pool) => query.fetch_optional(*pool).await,
77            DbConn::Transaction(tx) => query.fetch_optional(&mut **tx.lock().await).await,
78        }
79    }
80
81    pub async fn fetch_all<'q, O>(
82        &self,
83        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
84    ) -> sqlx::Result<Vec<O>>
85    where
86        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
87    {
88        match self {
89            DbConn::Pool(pool) => query.fetch_all(*pool).await,
90            DbConn::Transaction(tx) => query.fetch_all(&mut **tx.lock().await).await,
91        }
92    }
93
94    pub async fn execute<'q>(
95        &self,
96        query: sqlx::query::Query<'q, Postgres, PgArguments>,
97    ) -> sqlx::Result<PgQueryResult> {
98        match self {
99            DbConn::Pool(pool) => query.execute(*pool).await,
100            DbConn::Transaction(tx) => query.execute(&mut **tx.lock().await).await,
101        }
102    }
103}
104
/// A job buffered by a transactional mutation instead of being dispatched
/// immediately.
#[derive(Debug, Clone)]
pub struct PendingJob {
    /// Identifier generated when the job was buffered; returned to the caller.
    pub id: Uuid,
    /// Job type name as passed to the dispatch call.
    pub job_type: String,
    /// JSON-serialized job arguments.
    pub args: serde_json::Value,
    /// Initial job context (an empty JSON object when none was supplied).
    pub context: serde_json::Value,
    /// Principal that dispatched the job, if one could be determined.
    pub owner_subject: Option<String>,
    /// Priority taken from the registered job info.
    pub priority: i32,
    /// Maximum attempts taken from the job's retry configuration.
    pub max_attempts: i32,
    /// Worker capability taken from the registered job info, if any.
    pub worker_capability: Option<String>,
}
116
/// A workflow start request buffered by a transactional mutation instead of
/// being started immediately.
#[derive(Debug, Clone)]
pub struct PendingWorkflow {
    /// Identifier generated when the workflow was buffered; returned to the caller.
    pub id: Uuid,
    /// Workflow name as passed to the start call.
    pub workflow_name: String,
    /// JSON-serialized workflow input.
    pub input: serde_json::Value,
    /// Principal that requested the workflow, if one could be determined.
    pub owner_subject: Option<String>,
}
124
/// Accumulates jobs and workflows dispatched while a mutation runs inside a
/// transaction. Starts out empty (`Default`).
#[derive(Default)]
pub struct OutboxBuffer {
    /// Buffered jobs, in dispatch order.
    pub jobs: Vec<PendingJob>,
    /// Buffered workflow starts, in dispatch order.
    pub workflows: Vec<PendingWorkflow>,
}
130
/// Authentication context available to all functions.
///
/// Fields are private; the state is read through the accessor methods.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// The authenticated user ID (if any). `None` for unauthenticated requests
    /// and for authenticated requests whose subject is not a UUID.
    user_id: Option<Uuid>,
    /// User roles.
    roles: Vec<String>,
    /// Custom claims from JWT. The raw subject is read from the `"sub"` entry.
    claims: HashMap<String, serde_json::Value>,
    /// Whether the request is authenticated.
    authenticated: bool,
}
143
144impl AuthContext {
145    /// Create an unauthenticated context.
146    pub fn unauthenticated() -> Self {
147        Self {
148            user_id: None,
149            roles: Vec::new(),
150            claims: HashMap::new(),
151            authenticated: false,
152        }
153    }
154
155    /// Create an authenticated context with a UUID user ID.
156    pub fn authenticated(
157        user_id: Uuid,
158        roles: Vec<String>,
159        claims: HashMap<String, serde_json::Value>,
160    ) -> Self {
161        Self {
162            user_id: Some(user_id),
163            roles,
164            claims,
165            authenticated: true,
166        }
167    }
168
169    /// Create an authenticated context without requiring a UUID user ID.
170    ///
171    /// Use this for auth providers that don't use UUID subjects (e.g., Firebase,
172    /// Clerk). The raw subject string is available via `subject()` method
173    /// from the "sub" claim.
174    pub fn authenticated_without_uuid(
175        roles: Vec<String>,
176        claims: HashMap<String, serde_json::Value>,
177    ) -> Self {
178        Self {
179            user_id: None,
180            roles,
181            claims,
182            authenticated: true,
183        }
184    }
185
186    /// Check if the user is authenticated.
187    pub fn is_authenticated(&self) -> bool {
188        self.authenticated
189    }
190
191    /// Get the user ID if authenticated.
192    pub fn user_id(&self) -> Option<Uuid> {
193        self.user_id
194    }
195
196    /// Get the user ID, returning an error if not authenticated.
197    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
198        self.user_id
199            .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
200    }
201
202    /// Check if the user has a specific role.
203    pub fn has_role(&self, role: &str) -> bool {
204        self.roles.iter().any(|r| r == role)
205    }
206
207    /// Require a specific role, returning an error if not present.
208    pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
209        if self.has_role(role) {
210            Ok(())
211        } else {
212            Err(crate::error::ForgeError::Forbidden(format!(
213                "Required role '{}' not present",
214                role
215            )))
216        }
217    }
218
219    /// Get a custom claim value.
220    pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
221        self.claims.get(key)
222    }
223
224    /// Get all custom claims.
225    pub fn claims(&self) -> &HashMap<String, serde_json::Value> {
226        &self.claims
227    }
228
229    /// Get all roles.
230    pub fn roles(&self) -> &[String] {
231        &self.roles
232    }
233
234    /// Get the raw subject claim.
235    ///
236    /// This works with any provider's subject format (UUID, email, custom ID).
237    /// For providers like Firebase or Clerk that don't use UUIDs, use this
238    /// instead of `user_id()`.
239    pub fn subject(&self) -> Option<&str> {
240        self.claims.get("sub").and_then(|v| v.as_str())
241    }
242
243    /// Like `require_user_id()` but returns the raw subject string for non-UUID providers.
244    pub fn require_subject(&self) -> crate::error::Result<&str> {
245        if !self.authenticated {
246            return Err(crate::error::ForgeError::Unauthorized(
247                "Authentication required".to_string(),
248            ));
249        }
250        self.subject().ok_or_else(|| {
251            crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
252        })
253    }
254
255    /// Get a stable principal identifier for access control and cache scoping.
256    ///
257    /// Prefers the raw JWT `sub` claim and falls back to UUID user_id.
258    pub fn principal_id(&self) -> Option<String> {
259        self.subject()
260            .map(ToString::to_string)
261            .or_else(|| self.user_id.map(|id| id.to_string()))
262    }
263
264    /// Check whether this principal should be treated as privileged admin.
265    pub fn is_admin(&self) -> bool {
266        self.roles.iter().any(|r| r == "admin")
267    }
268}
269
/// Request metadata available to all functions.
#[derive(Debug, Clone)]
pub struct RequestMetadata {
    /// Unique request ID for tracing. Freshly generated by the constructors.
    pub request_id: Uuid,
    /// Trace ID for distributed tracing. Either supplied by the caller or a
    /// newly generated UUID rendered as a string.
    pub trace_id: String,
    /// Client IP address. The constructors leave this unset.
    pub client_ip: Option<String>,
    /// User agent string. The constructors leave this unset.
    pub user_agent: Option<String>,
    /// Request timestamp (UTC), captured when the metadata is constructed.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
284
285impl RequestMetadata {
286    /// Create new request metadata.
287    pub fn new() -> Self {
288        Self {
289            request_id: Uuid::new_v4(),
290            trace_id: Uuid::new_v4().to_string(),
291            client_ip: None,
292            user_agent: None,
293            timestamp: chrono::Utc::now(),
294        }
295    }
296
297    /// Create with a specific trace ID.
298    pub fn with_trace_id(trace_id: String) -> Self {
299        Self {
300            request_id: Uuid::new_v4(),
301            trace_id,
302            client_ip: None,
303            user_agent: None,
304            timestamp: chrono::Utc::now(),
305        }
306    }
307}
308
309impl Default for RequestMetadata {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
/// Context for query functions (read-only database access).
///
/// Exposes the database pool, the caller's authentication state, request
/// metadata and environment variables.
pub struct QueryContext {
    /// Authentication context.
    pub auth: AuthContext,
    /// Request metadata.
    pub request: RequestMetadata,
    /// Database pool for read operations.
    db_pool: sqlx::PgPool,
    /// Environment variable provider.
    env_provider: Arc<dyn EnvProvider>,
}
326
327impl QueryContext {
328    /// Create a new query context.
329    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
330        Self {
331            auth,
332            request,
333            db_pool,
334            env_provider: Arc::new(RealEnvProvider::new()),
335        }
336    }
337
338    /// Create a query context with a custom environment provider.
339    pub fn with_env(
340        db_pool: sqlx::PgPool,
341        auth: AuthContext,
342        request: RequestMetadata,
343        env_provider: Arc<dyn EnvProvider>,
344    ) -> Self {
345        Self {
346            auth,
347            request,
348            db_pool,
349            env_provider,
350        }
351    }
352
353    pub fn db(&self) -> &sqlx::PgPool {
354        &self.db_pool
355    }
356
357    /// Returns a `DbConn` wrapping the pool, allowing shared helper functions
358    /// that accept `DbConn` to work with both query and mutation contexts.
359    pub fn db_conn(&self) -> DbConn<'_> {
360        DbConn::Pool(&self.db_pool)
361    }
362
363    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
364        self.auth.require_user_id()
365    }
366
367    /// Like `require_user_id()` but for non-UUID auth providers.
368    pub fn require_subject(&self) -> crate::error::Result<&str> {
369        self.auth.require_subject()
370    }
371}
372
373impl EnvAccess for QueryContext {
374    fn env_provider(&self) -> &dyn EnvProvider {
375        self.env_provider.as_ref()
376    }
377}
378
/// Callback type for looking up job info by name.
///
/// Returns `None` when no job is registered under the given name. The callback
/// is reference-counted and must be `Send + Sync`.
pub type JobInfoLookup = Arc<dyn Fn(&str) -> Option<JobInfo> + Send + Sync>;
381
/// Context for mutation functions (transactional database access).
///
/// When a transaction handle and outbox are attached, database access goes
/// through the transaction and dispatched jobs/workflows are buffered in the
/// outbox; otherwise the pool and the optional dispatchers are used directly.
pub struct MutationContext {
    /// Authentication context.
    pub auth: AuthContext,
    /// Request metadata.
    pub request: RequestMetadata,
    /// Database pool. Used for `db()` when no transaction is attached and
    /// always reachable through `pool()`.
    db_pool: sqlx::PgPool,
    /// HTTP client with circuit breaker for external requests.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher for dispatching background jobs.
    job_dispatch: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher for starting workflows.
    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
    /// Environment variable provider.
    env_provider: Arc<dyn EnvProvider>,
    /// Transaction handle for transactional mutations (`None` otherwise).
    tx: Option<Arc<AsyncMutex<Transaction<'static, Postgres>>>>,
    /// Outbox buffer for jobs/workflows dispatched during transaction.
    outbox: Option<Arc<Mutex<OutboxBuffer>>>,
    /// Job info lookup for transactional dispatch; supplies the priority,
    /// retry limit and worker capability recorded on buffered jobs.
    job_info_lookup: Option<JobInfoLookup>,
}
405
406impl MutationContext {
407    /// Create a new mutation context.
408    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
409        Self {
410            auth,
411            request,
412            db_pool,
413            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
414            job_dispatch: None,
415            workflow_dispatch: None,
416            env_provider: Arc::new(RealEnvProvider::new()),
417            tx: None,
418            outbox: None,
419            job_info_lookup: None,
420        }
421    }
422
423    /// Create a mutation context with dispatch capabilities.
424    pub fn with_dispatch(
425        db_pool: sqlx::PgPool,
426        auth: AuthContext,
427        request: RequestMetadata,
428        http_client: CircuitBreakerClient,
429        job_dispatch: Option<Arc<dyn JobDispatch>>,
430        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
431    ) -> Self {
432        Self {
433            auth,
434            request,
435            db_pool,
436            http_client,
437            job_dispatch,
438            workflow_dispatch,
439            env_provider: Arc::new(RealEnvProvider::new()),
440            tx: None,
441            outbox: None,
442            job_info_lookup: None,
443        }
444    }
445
446    /// Create a mutation context with a custom environment provider.
447    pub fn with_env(
448        db_pool: sqlx::PgPool,
449        auth: AuthContext,
450        request: RequestMetadata,
451        http_client: CircuitBreakerClient,
452        job_dispatch: Option<Arc<dyn JobDispatch>>,
453        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
454        env_provider: Arc<dyn EnvProvider>,
455    ) -> Self {
456        Self {
457            auth,
458            request,
459            db_pool,
460            http_client,
461            job_dispatch,
462            workflow_dispatch,
463            env_provider,
464            tx: None,
465            outbox: None,
466            job_info_lookup: None,
467        }
468    }
469
470    /// Returns handles to transaction and outbox for the caller to commit/flush.
471    #[allow(clippy::type_complexity)]
472    pub fn with_transaction(
473        db_pool: sqlx::PgPool,
474        tx: Transaction<'static, Postgres>,
475        auth: AuthContext,
476        request: RequestMetadata,
477        http_client: CircuitBreakerClient,
478        job_info_lookup: JobInfoLookup,
479    ) -> (
480        Self,
481        Arc<AsyncMutex<Transaction<'static, Postgres>>>,
482        Arc<Mutex<OutboxBuffer>>,
483    ) {
484        let tx_handle = Arc::new(AsyncMutex::new(tx));
485        let outbox = Arc::new(Mutex::new(OutboxBuffer::default()));
486
487        let ctx = Self {
488            auth,
489            request,
490            db_pool,
491            http_client,
492            job_dispatch: None,
493            workflow_dispatch: None,
494            env_provider: Arc::new(RealEnvProvider::new()),
495            tx: Some(tx_handle.clone()),
496            outbox: Some(outbox.clone()),
497            job_info_lookup: Some(job_info_lookup),
498        };
499
500        (ctx, tx_handle, outbox)
501    }
502
503    pub fn is_transactional(&self) -> bool {
504        self.tx.is_some()
505    }
506
507    pub fn db(&self) -> DbConn<'_> {
508        match &self.tx {
509            Some(tx) => DbConn::Transaction(tx.clone()),
510            None => DbConn::Pool(&self.db_pool),
511        }
512    }
513
514    /// Direct pool access for operations that cannot run inside a transaction.
515    pub fn pool(&self) -> &sqlx::PgPool {
516        &self.db_pool
517    }
518
519    /// Get the HTTP client for external requests.
520    ///
521    /// The client includes circuit breaker protection that tracks failure rates
522    /// per host. After repeated failures, requests fail fast to prevent cascade
523    /// failures when downstream services are unhealthy.
524    pub fn http(&self) -> &reqwest::Client {
525        self.http_client.inner()
526    }
527
528    /// Get the circuit breaker client directly for advanced usage.
529    pub fn http_with_circuit_breaker(&self) -> &CircuitBreakerClient {
530        &self.http_client
531    }
532
533    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
534        self.auth.require_user_id()
535    }
536
537    pub fn require_subject(&self) -> crate::error::Result<&str> {
538        self.auth.require_subject()
539    }
540
541    /// In transactional mode, buffers for atomic commit; otherwise dispatches immediately.
542    pub async fn dispatch_job<T: serde::Serialize>(
543        &self,
544        job_type: &str,
545        args: T,
546    ) -> crate::error::Result<Uuid> {
547        let args_json = serde_json::to_value(args)?;
548
549        // Transactional mode: buffer the job for atomic commit
550        if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
551            let job_info = job_info_lookup(job_type).ok_or_else(|| {
552                crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
553            })?;
554
555            let pending = PendingJob {
556                id: Uuid::new_v4(),
557                job_type: job_type.to_string(),
558                args: args_json,
559                context: serde_json::json!({}),
560                owner_subject: self.auth.principal_id(),
561                priority: job_info.priority.as_i32(),
562                max_attempts: job_info.retry.max_attempts as i32,
563                worker_capability: job_info.worker_capability.map(|s| s.to_string()),
564            };
565
566            let job_id = pending.id;
567            outbox
568                .lock()
569                .expect("outbox lock poisoned")
570                .jobs
571                .push(pending);
572            return Ok(job_id);
573        }
574
575        // Non-transactional mode: dispatch immediately
576        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
577            crate::error::ForgeError::Internal("Job dispatch not available".into())
578        })?;
579        dispatcher
580            .dispatch_by_name(job_type, args_json, self.auth.principal_id())
581            .await
582    }
583
584    /// Dispatch a job with initial context.
585    pub async fn dispatch_job_with_context<T: serde::Serialize>(
586        &self,
587        job_type: &str,
588        args: T,
589        context: serde_json::Value,
590    ) -> crate::error::Result<Uuid> {
591        let args_json = serde_json::to_value(args)?;
592
593        if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
594            let job_info = job_info_lookup(job_type).ok_or_else(|| {
595                crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
596            })?;
597
598            let pending = PendingJob {
599                id: Uuid::new_v4(),
600                job_type: job_type.to_string(),
601                args: args_json,
602                context,
603                owner_subject: self.auth.principal_id(),
604                priority: job_info.priority.as_i32(),
605                max_attempts: job_info.retry.max_attempts as i32,
606                worker_capability: job_info.worker_capability.map(|s| s.to_string()),
607            };
608
609            let job_id = pending.id;
610            outbox
611                .lock()
612                .expect("outbox lock poisoned")
613                .jobs
614                .push(pending);
615            return Ok(job_id);
616        }
617
618        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
619            crate::error::ForgeError::Internal("Job dispatch not available".into())
620        })?;
621        dispatcher
622            .dispatch_by_name(job_type, args_json, self.auth.principal_id())
623            .await
624    }
625
626    /// Request cancellation for a job.
627    pub async fn cancel_job(
628        &self,
629        job_id: Uuid,
630        reason: Option<String>,
631    ) -> crate::error::Result<bool> {
632        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
633            crate::error::ForgeError::Internal("Job dispatch not available".into())
634        })?;
635        dispatcher.cancel(job_id, reason).await
636    }
637
638    /// In transactional mode, buffers for atomic commit; otherwise starts immediately.
639    pub async fn start_workflow<T: serde::Serialize>(
640        &self,
641        workflow_name: &str,
642        input: T,
643    ) -> crate::error::Result<Uuid> {
644        let input_json = serde_json::to_value(input)?;
645
646        // Transactional mode: buffer the workflow for atomic commit
647        if let Some(outbox) = &self.outbox {
648            let pending = PendingWorkflow {
649                id: Uuid::new_v4(),
650                workflow_name: workflow_name.to_string(),
651                input: input_json,
652                owner_subject: self.auth.principal_id(),
653            };
654
655            let workflow_id = pending.id;
656            outbox
657                .lock()
658                .expect("outbox lock poisoned")
659                .workflows
660                .push(pending);
661            return Ok(workflow_id);
662        }
663
664        // Non-transactional mode: start immediately
665        let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
666            crate::error::ForgeError::Internal("Workflow dispatch not available".into())
667        })?;
668        dispatcher
669            .start_by_name(workflow_name, input_json, self.auth.principal_id())
670            .await
671    }
672}
673
674impl EnvAccess for MutationContext {
675    fn env_provider(&self) -> &dyn EnvProvider {
676        self.env_provider.as_ref()
677    }
678}
679
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// An unauthenticated context exposes no identity.
    #[test]
    fn test_auth_context_unauthenticated() {
        let anonymous = AuthContext::unauthenticated();

        assert!(!anonymous.is_authenticated());
        assert!(anonymous.user_id().is_none());
        assert!(anonymous.require_user_id().is_err());
    }

    /// An authenticated context reports its user ID and answers role checks.
    #[test]
    fn test_auth_context_authenticated() {
        let expected_id = Uuid::new_v4();
        let granted_roles = vec!["admin".to_string(), "user".to_string()];
        let auth = AuthContext::authenticated(expected_id, granted_roles, HashMap::new());

        assert!(auth.is_authenticated());
        assert_eq!(auth.user_id(), Some(expected_id));
        assert!(auth.require_user_id().is_ok());

        assert!(auth.has_role("admin"));
        assert!(auth.has_role("user"));
        assert!(!auth.has_role("superadmin"));

        assert!(auth.require_role("admin").is_ok());
        assert!(auth.require_role("superadmin").is_err());
    }

    /// Custom claims are retrievable by key; unknown keys yield `None`.
    #[test]
    fn test_auth_context_with_claims() {
        let claims = HashMap::from([("org_id".to_string(), serde_json::json!("org-123"))]);
        let auth = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);

        assert_eq!(auth.claim("org_id"), Some(&serde_json::json!("org-123")));
        assert!(auth.claim("nonexistent").is_none());
    }

    /// Fresh metadata gets a generated trace ID; an explicit one is kept verbatim.
    #[test]
    fn test_request_metadata() {
        let generated = RequestMetadata::new();
        assert!(!generated.trace_id.is_empty());
        assert!(generated.client_ip.is_none());

        let explicit = RequestMetadata::with_trace_id("trace-123".to_string());
        assert_eq!(explicit.trace_id, "trace-123");
    }
}