Skip to main content

forge_core/function/
context.rs

//! Execution contexts for queries and mutations.
//!
//! Every function receives a context providing access to:
//!
//! - Database connection (pool or transaction)
//! - Authentication state (user ID, roles, claims)
//! - Request metadata (request ID, trace ID, client IP, user agent)
//! - Environment variables
//! - Job/workflow dispatch (mutations only)
//!
//! # QueryContext vs MutationContext
//!
//! | Feature | QueryContext | MutationContext |
//! |---------|--------------|-----------------|
//! | Database | Pool (intended for reads; not enforced) | Transaction or pool |
//! | Dispatch jobs | No | Yes |
//! | Start workflows | No | Yes |
//! | HTTP client | No | Yes (circuit breaker) |
//!
//! # Transactional Mutations
//!
//! When `transactional = true` (the default), mutations run in a transaction.
//! Jobs and workflows dispatched during the mutation are buffered and only
//! inserted after the transaction commits successfully.
//!
//! ```text
//! BEGIN
//!   ├── ctx.db().execute(...)
//!   ├── ctx.dispatch_job("send_email", ...)  // buffered
//!   └── return Ok(result)
//! COMMIT
//!   └── INSERT INTO forge_jobs (buffered jobs)
//! ```
34
35use std::collections::HashMap;
36use std::sync::{Arc, Mutex};
37
38use sqlx::postgres::{PgArguments, PgQueryResult, PgRow};
39use sqlx::{FromRow, Postgres, Transaction};
40use tokio::sync::Mutex as AsyncMutex;
41use uuid::Uuid;
42
43use super::dispatch::{JobDispatch, WorkflowDispatch};
44use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
45use crate::http::CircuitBreakerClient;
46use crate::job::JobInfo;
47
48/// Abstracts over pool and transaction connections so handlers can work with either.
49pub enum DbConn<'a> {
50    Pool(&'a sqlx::PgPool),
51    Transaction(Arc<AsyncMutex<Transaction<'static, Postgres>>>),
52}
53
54impl DbConn<'_> {
55    pub async fn fetch_one<'q, O>(
56        &self,
57        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
58    ) -> sqlx::Result<O>
59    where
60        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
61    {
62        match self {
63            DbConn::Pool(pool) => query.fetch_one(*pool).await,
64            DbConn::Transaction(tx) => query.fetch_one(&mut **tx.lock().await).await,
65        }
66    }
67
68    pub async fn fetch_optional<'q, O>(
69        &self,
70        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
71    ) -> sqlx::Result<Option<O>>
72    where
73        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
74    {
75        match self {
76            DbConn::Pool(pool) => query.fetch_optional(*pool).await,
77            DbConn::Transaction(tx) => query.fetch_optional(&mut **tx.lock().await).await,
78        }
79    }
80
81    pub async fn fetch_all<'q, O>(
82        &self,
83        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
84    ) -> sqlx::Result<Vec<O>>
85    where
86        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
87    {
88        match self {
89            DbConn::Pool(pool) => query.fetch_all(*pool).await,
90            DbConn::Transaction(tx) => query.fetch_all(&mut **tx.lock().await).await,
91        }
92    }
93
94    pub async fn execute<'q>(
95        &self,
96        query: sqlx::query::Query<'q, Postgres, PgArguments>,
97    ) -> sqlx::Result<PgQueryResult> {
98        match self {
99            DbConn::Pool(pool) => query.execute(*pool).await,
100            DbConn::Transaction(tx) => query.execute(&mut **tx.lock().await).await,
101        }
102    }
103}
104
105#[derive(Debug, Clone)]
106pub struct PendingJob {
107    pub id: Uuid,
108    pub job_type: String,
109    pub args: serde_json::Value,
110    pub context: serde_json::Value,
111    pub owner_subject: Option<String>,
112    pub priority: i32,
113    pub max_attempts: i32,
114    pub worker_capability: Option<String>,
115}
116
117#[derive(Debug, Clone)]
118pub struct PendingWorkflow {
119    pub id: Uuid,
120    pub workflow_name: String,
121    pub input: serde_json::Value,
122    pub owner_subject: Option<String>,
123}
124
125#[derive(Default)]
126pub struct OutboxBuffer {
127    pub jobs: Vec<PendingJob>,
128    pub workflows: Vec<PendingWorkflow>,
129}
130
131/// Authentication context available to all functions.
132#[derive(Debug, Clone)]
133pub struct AuthContext {
134    /// The authenticated user ID (if any).
135    user_id: Option<Uuid>,
136    /// User roles.
137    roles: Vec<String>,
138    /// Custom claims from JWT.
139    claims: HashMap<String, serde_json::Value>,
140    /// Whether the request is authenticated.
141    authenticated: bool,
142}
143
144impl AuthContext {
145    /// Create an unauthenticated context.
146    pub fn unauthenticated() -> Self {
147        Self {
148            user_id: None,
149            roles: Vec::new(),
150            claims: HashMap::new(),
151            authenticated: false,
152        }
153    }
154
155    /// Create an authenticated context with a UUID user ID.
156    pub fn authenticated(
157        user_id: Uuid,
158        roles: Vec<String>,
159        claims: HashMap<String, serde_json::Value>,
160    ) -> Self {
161        Self {
162            user_id: Some(user_id),
163            roles,
164            claims,
165            authenticated: true,
166        }
167    }
168
169    /// Create an authenticated context without requiring a UUID user ID.
170    ///
171    /// Use this for auth providers that don't use UUID subjects (e.g., Firebase,
172    /// Clerk). The raw subject string is available via `subject()` method
173    /// from the "sub" claim.
174    pub fn authenticated_without_uuid(
175        roles: Vec<String>,
176        claims: HashMap<String, serde_json::Value>,
177    ) -> Self {
178        Self {
179            user_id: None,
180            roles,
181            claims,
182            authenticated: true,
183        }
184    }
185
186    /// Check if the user is authenticated.
187    pub fn is_authenticated(&self) -> bool {
188        self.authenticated
189    }
190
191    /// Get the user ID if authenticated.
192    pub fn user_id(&self) -> Option<Uuid> {
193        self.user_id
194    }
195
196    /// Get the user ID, returning an error if not authenticated.
197    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
198        self.user_id
199            .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
200    }
201
202    /// Check if the user has a specific role.
203    pub fn has_role(&self, role: &str) -> bool {
204        self.roles.iter().any(|r| r == role)
205    }
206
207    /// Require a specific role, returning an error if not present.
208    pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
209        if self.has_role(role) {
210            Ok(())
211        } else {
212            Err(crate::error::ForgeError::Forbidden(format!(
213                "Required role '{}' not present",
214                role
215            )))
216        }
217    }
218
219    /// Get a custom claim value.
220    pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
221        self.claims.get(key)
222    }
223
224    /// Get all custom claims.
225    pub fn claims(&self) -> &HashMap<String, serde_json::Value> {
226        &self.claims
227    }
228
229    /// Get all roles.
230    pub fn roles(&self) -> &[String] {
231        &self.roles
232    }
233
234    /// Get the raw subject claim.
235    ///
236    /// This works with any provider's subject format (UUID, email, custom ID).
237    /// For providers like Firebase or Clerk that don't use UUIDs, use this
238    /// instead of `user_id()`.
239    pub fn subject(&self) -> Option<&str> {
240        self.claims.get("sub").and_then(|v| v.as_str())
241    }
242
243    /// Like `require_user_id()` but returns the raw subject string for non-UUID providers.
244    pub fn require_subject(&self) -> crate::error::Result<&str> {
245        if !self.authenticated {
246            return Err(crate::error::ForgeError::Unauthorized(
247                "Authentication required".to_string(),
248            ));
249        }
250        self.subject().ok_or_else(|| {
251            crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
252        })
253    }
254
255    /// Get a stable principal identifier for access control and cache scoping.
256    ///
257    /// Prefers the raw JWT `sub` claim and falls back to UUID user_id.
258    pub fn principal_id(&self) -> Option<String> {
259        self.subject()
260            .map(ToString::to_string)
261            .or_else(|| self.user_id.map(|id| id.to_string()))
262    }
263
264    /// Check whether this principal should be treated as privileged admin.
265    pub fn is_admin(&self) -> bool {
266        self.roles.iter().any(|r| r.eq_ignore_ascii_case("admin"))
267    }
268}
269
270/// Request metadata available to all functions.
271#[derive(Debug, Clone)]
272pub struct RequestMetadata {
273    /// Unique request ID for tracing.
274    pub request_id: Uuid,
275    /// Trace ID for distributed tracing.
276    pub trace_id: String,
277    /// Client IP address.
278    pub client_ip: Option<String>,
279    /// User agent string.
280    pub user_agent: Option<String>,
281    /// Request timestamp.
282    pub timestamp: chrono::DateTime<chrono::Utc>,
283}
284
285impl RequestMetadata {
286    /// Create new request metadata.
287    pub fn new() -> Self {
288        Self {
289            request_id: Uuid::new_v4(),
290            trace_id: Uuid::new_v4().to_string(),
291            client_ip: None,
292            user_agent: None,
293            timestamp: chrono::Utc::now(),
294        }
295    }
296
297    /// Create with a specific trace ID.
298    pub fn with_trace_id(trace_id: String) -> Self {
299        Self {
300            request_id: Uuid::new_v4(),
301            trace_id,
302            client_ip: None,
303            user_agent: None,
304            timestamp: chrono::Utc::now(),
305        }
306    }
307}
308
309impl Default for RequestMetadata {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
315/// Context for query functions (read-only database access).
316pub struct QueryContext {
317    /// Authentication context.
318    pub auth: AuthContext,
319    /// Request metadata.
320    pub request: RequestMetadata,
321    /// Database pool for read operations.
322    db_pool: sqlx::PgPool,
323    /// Environment variable provider.
324    env_provider: Arc<dyn EnvProvider>,
325}
326
327impl QueryContext {
328    /// Create a new query context.
329    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
330        Self {
331            auth,
332            request,
333            db_pool,
334            env_provider: Arc::new(RealEnvProvider::new()),
335        }
336    }
337
338    /// Create a query context with a custom environment provider.
339    pub fn with_env(
340        db_pool: sqlx::PgPool,
341        auth: AuthContext,
342        request: RequestMetadata,
343        env_provider: Arc<dyn EnvProvider>,
344    ) -> Self {
345        Self {
346            auth,
347            request,
348            db_pool,
349            env_provider,
350        }
351    }
352
353    pub fn db(&self) -> &sqlx::PgPool {
354        &self.db_pool
355    }
356
357    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
358        self.auth.require_user_id()
359    }
360
361    /// Like `require_user_id()` but for non-UUID auth providers.
362    pub fn require_subject(&self) -> crate::error::Result<&str> {
363        self.auth.require_subject()
364    }
365}
366
367impl EnvAccess for QueryContext {
368    fn env_provider(&self) -> &dyn EnvProvider {
369        self.env_provider.as_ref()
370    }
371}
372
373/// Callback type for looking up job info by name.
374pub type JobInfoLookup = Arc<dyn Fn(&str) -> Option<JobInfo> + Send + Sync>;
375
376/// Context for mutation functions (transactional database access).
377pub struct MutationContext {
378    /// Authentication context.
379    pub auth: AuthContext,
380    /// Request metadata.
381    pub request: RequestMetadata,
382    /// Database pool for transactional operations.
383    db_pool: sqlx::PgPool,
384    /// HTTP client with circuit breaker for external requests.
385    http_client: CircuitBreakerClient,
386    /// Optional job dispatcher for dispatching background jobs.
387    job_dispatch: Option<Arc<dyn JobDispatch>>,
388    /// Optional workflow dispatcher for starting workflows.
389    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
390    /// Environment variable provider.
391    env_provider: Arc<dyn EnvProvider>,
392    /// Transaction handle for transactional mutations.
393    tx: Option<Arc<AsyncMutex<Transaction<'static, Postgres>>>>,
394    /// Outbox buffer for jobs/workflows dispatched during transaction.
395    outbox: Option<Arc<Mutex<OutboxBuffer>>>,
396    /// Job info lookup for transactional dispatch.
397    job_info_lookup: Option<JobInfoLookup>,
398}
399
400impl MutationContext {
401    /// Create a new mutation context.
402    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
403        Self {
404            auth,
405            request,
406            db_pool,
407            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
408            job_dispatch: None,
409            workflow_dispatch: None,
410            env_provider: Arc::new(RealEnvProvider::new()),
411            tx: None,
412            outbox: None,
413            job_info_lookup: None,
414        }
415    }
416
417    /// Create a mutation context with dispatch capabilities.
418    pub fn with_dispatch(
419        db_pool: sqlx::PgPool,
420        auth: AuthContext,
421        request: RequestMetadata,
422        http_client: CircuitBreakerClient,
423        job_dispatch: Option<Arc<dyn JobDispatch>>,
424        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
425    ) -> Self {
426        Self {
427            auth,
428            request,
429            db_pool,
430            http_client,
431            job_dispatch,
432            workflow_dispatch,
433            env_provider: Arc::new(RealEnvProvider::new()),
434            tx: None,
435            outbox: None,
436            job_info_lookup: None,
437        }
438    }
439
440    /// Create a mutation context with a custom environment provider.
441    pub fn with_env(
442        db_pool: sqlx::PgPool,
443        auth: AuthContext,
444        request: RequestMetadata,
445        http_client: CircuitBreakerClient,
446        job_dispatch: Option<Arc<dyn JobDispatch>>,
447        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
448        env_provider: Arc<dyn EnvProvider>,
449    ) -> Self {
450        Self {
451            auth,
452            request,
453            db_pool,
454            http_client,
455            job_dispatch,
456            workflow_dispatch,
457            env_provider,
458            tx: None,
459            outbox: None,
460            job_info_lookup: None,
461        }
462    }
463
464    /// Returns handles to transaction and outbox for the caller to commit/flush.
465    #[allow(clippy::type_complexity)]
466    pub fn with_transaction(
467        db_pool: sqlx::PgPool,
468        tx: Transaction<'static, Postgres>,
469        auth: AuthContext,
470        request: RequestMetadata,
471        http_client: CircuitBreakerClient,
472        job_info_lookup: JobInfoLookup,
473    ) -> (
474        Self,
475        Arc<AsyncMutex<Transaction<'static, Postgres>>>,
476        Arc<Mutex<OutboxBuffer>>,
477    ) {
478        let tx_handle = Arc::new(AsyncMutex::new(tx));
479        let outbox = Arc::new(Mutex::new(OutboxBuffer::default()));
480
481        let ctx = Self {
482            auth,
483            request,
484            db_pool,
485            http_client,
486            job_dispatch: None,
487            workflow_dispatch: None,
488            env_provider: Arc::new(RealEnvProvider::new()),
489            tx: Some(tx_handle.clone()),
490            outbox: Some(outbox.clone()),
491            job_info_lookup: Some(job_info_lookup),
492        };
493
494        (ctx, tx_handle, outbox)
495    }
496
497    pub fn is_transactional(&self) -> bool {
498        self.tx.is_some()
499    }
500
501    pub fn db(&self) -> DbConn<'_> {
502        match &self.tx {
503            Some(tx) => DbConn::Transaction(tx.clone()),
504            None => DbConn::Pool(&self.db_pool),
505        }
506    }
507
508    /// Direct pool access for operations that cannot run inside a transaction.
509    pub fn pool(&self) -> &sqlx::PgPool {
510        &self.db_pool
511    }
512
513    /// Get the HTTP client for external requests.
514    ///
515    /// The client includes circuit breaker protection that tracks failure rates
516    /// per host. After repeated failures, requests fail fast to prevent cascade
517    /// failures when downstream services are unhealthy.
518    pub fn http(&self) -> &reqwest::Client {
519        self.http_client.inner()
520    }
521
522    /// Get the circuit breaker client directly for advanced usage.
523    pub fn http_with_circuit_breaker(&self) -> &CircuitBreakerClient {
524        &self.http_client
525    }
526
527    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
528        self.auth.require_user_id()
529    }
530
531    pub fn require_subject(&self) -> crate::error::Result<&str> {
532        self.auth.require_subject()
533    }
534
535    /// In transactional mode, buffers for atomic commit; otherwise dispatches immediately.
536    pub async fn dispatch_job<T: serde::Serialize>(
537        &self,
538        job_type: &str,
539        args: T,
540    ) -> crate::error::Result<Uuid> {
541        let args_json = serde_json::to_value(args)?;
542
543        // Transactional mode: buffer the job for atomic commit
544        if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
545            let job_info = job_info_lookup(job_type).ok_or_else(|| {
546                crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
547            })?;
548
549            let pending = PendingJob {
550                id: Uuid::new_v4(),
551                job_type: job_type.to_string(),
552                args: args_json,
553                context: serde_json::json!({}),
554                owner_subject: self.auth.principal_id(),
555                priority: job_info.priority.as_i32(),
556                max_attempts: job_info.retry.max_attempts as i32,
557                worker_capability: job_info.worker_capability.map(|s| s.to_string()),
558            };
559
560            let job_id = pending.id;
561            outbox
562                .lock()
563                .expect("outbox lock poisoned")
564                .jobs
565                .push(pending);
566            return Ok(job_id);
567        }
568
569        // Non-transactional mode: dispatch immediately
570        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
571            crate::error::ForgeError::Internal("Job dispatch not available".into())
572        })?;
573        dispatcher
574            .dispatch_by_name(job_type, args_json, self.auth.principal_id())
575            .await
576    }
577
578    /// Dispatch a job with initial context.
579    pub async fn dispatch_job_with_context<T: serde::Serialize>(
580        &self,
581        job_type: &str,
582        args: T,
583        context: serde_json::Value,
584    ) -> crate::error::Result<Uuid> {
585        let args_json = serde_json::to_value(args)?;
586
587        if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
588            let job_info = job_info_lookup(job_type).ok_or_else(|| {
589                crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
590            })?;
591
592            let pending = PendingJob {
593                id: Uuid::new_v4(),
594                job_type: job_type.to_string(),
595                args: args_json,
596                context,
597                owner_subject: self.auth.principal_id(),
598                priority: job_info.priority.as_i32(),
599                max_attempts: job_info.retry.max_attempts as i32,
600                worker_capability: job_info.worker_capability.map(|s| s.to_string()),
601            };
602
603            let job_id = pending.id;
604            outbox
605                .lock()
606                .expect("outbox lock poisoned")
607                .jobs
608                .push(pending);
609            return Ok(job_id);
610        }
611
612        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
613            crate::error::ForgeError::Internal("Job dispatch not available".into())
614        })?;
615        dispatcher
616            .dispatch_by_name(job_type, args_json, self.auth.principal_id())
617            .await
618    }
619
620    /// Request cancellation for a job.
621    pub async fn cancel_job(
622        &self,
623        job_id: Uuid,
624        reason: Option<String>,
625    ) -> crate::error::Result<bool> {
626        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
627            crate::error::ForgeError::Internal("Job dispatch not available".into())
628        })?;
629        dispatcher.cancel(job_id, reason).await
630    }
631
632    /// In transactional mode, buffers for atomic commit; otherwise starts immediately.
633    pub async fn start_workflow<T: serde::Serialize>(
634        &self,
635        workflow_name: &str,
636        input: T,
637    ) -> crate::error::Result<Uuid> {
638        let input_json = serde_json::to_value(input)?;
639
640        // Transactional mode: buffer the workflow for atomic commit
641        if let Some(outbox) = &self.outbox {
642            let pending = PendingWorkflow {
643                id: Uuid::new_v4(),
644                workflow_name: workflow_name.to_string(),
645                input: input_json,
646                owner_subject: self.auth.principal_id(),
647            };
648
649            let workflow_id = pending.id;
650            outbox
651                .lock()
652                .expect("outbox lock poisoned")
653                .workflows
654                .push(pending);
655            return Ok(workflow_id);
656        }
657
658        // Non-transactional mode: start immediately
659        let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
660            crate::error::ForgeError::Internal("Workflow dispatch not available".into())
661        })?;
662        dispatcher
663            .start_by_name(workflow_name, input_json, self.auth.principal_id())
664            .await
665    }
666}
667
668impl EnvAccess for MutationContext {
669    fn env_provider(&self) -> &dyn EnvProvider {
670        self.env_provider.as_ref()
671    }
672}
673
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_context_unauthenticated() {
        let ctx = AuthContext::unauthenticated();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
        assert!(ctx.require_subject().is_err());
        assert!(ctx.principal_id().is_none());
    }

    #[test]
    fn test_auth_context_authenticated() {
        let user_id = Uuid::new_v4();
        let ctx = AuthContext::authenticated(
            user_id,
            vec!["admin".to_string(), "user".to_string()],
            HashMap::new(),
        );

        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id(), Some(user_id));
        assert!(ctx.require_user_id().is_ok());
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("user"));
        assert!(!ctx.has_role("superadmin"));
        assert!(ctx.require_role("admin").is_ok());
        assert!(ctx.require_role("superadmin").is_err());
    }

    #[test]
    fn test_auth_context_with_claims() {
        let mut claims = HashMap::new();
        claims.insert("org_id".to_string(), serde_json::json!("org-123"));

        let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);

        assert_eq!(ctx.claim("org_id"), Some(&serde_json::json!("org-123")));
        assert!(ctx.claim("nonexistent").is_none());
    }

    #[test]
    fn test_auth_context_without_uuid_uses_subject() {
        let mut claims = HashMap::new();
        claims.insert("sub".to_string(), serde_json::json!("user_abc"));

        let ctx = AuthContext::authenticated_without_uuid(vec!["member".to_string()], claims);

        assert!(ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
        assert_eq!(ctx.subject(), Some("user_abc"));
        assert_eq!(ctx.require_subject().unwrap(), "user_abc");
        assert_eq!(ctx.principal_id().as_deref(), Some("user_abc"));
    }

    #[test]
    fn test_require_subject_missing_sub_claim() {
        let ctx = AuthContext::authenticated_without_uuid(vec![], HashMap::new());
        assert!(ctx.subject().is_none());
        assert!(ctx.require_subject().is_err());
        assert!(ctx.principal_id().is_none());
    }

    #[test]
    fn test_principal_id_falls_back_to_user_id() {
        let user_id = Uuid::new_v4();
        let ctx = AuthContext::authenticated(user_id, vec![], HashMap::new());

        assert!(ctx.subject().is_none());
        assert_eq!(ctx.principal_id(), Some(user_id.to_string()));
    }

    #[test]
    fn test_is_admin_is_case_insensitive_but_has_role_is_not() {
        let ctx =
            AuthContext::authenticated(Uuid::new_v4(), vec!["Admin".to_string()], HashMap::new());

        assert!(ctx.is_admin());
        assert!(ctx.has_role("Admin"));
        assert!(!ctx.has_role("admin"));
    }

    #[test]
    fn test_request_metadata() {
        let meta = RequestMetadata::new();
        assert!(!meta.trace_id.is_empty());
        assert!(meta.client_ip.is_none());
        assert!(meta.user_agent.is_none());

        let meta2 = RequestMetadata::with_trace_id("trace-123".to_string());
        assert_eq!(meta2.trace_id, "trace-123");

        let meta3 = RequestMetadata::default();
        assert!(!meta3.trace_id.is_empty());
        assert!(meta3.client_ip.is_none());
    }
}