Skip to main content

forge_core/function/
context.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use sqlx::postgres::{PgArguments, PgQueryResult, PgRow};
5use sqlx::{FromRow, Postgres, Transaction};
6use tokio::sync::Mutex as AsyncMutex;
7use uuid::Uuid;
8
9use super::dispatch::{JobDispatch, WorkflowDispatch};
10use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
11use crate::http::CircuitBreakerClient;
12use crate::job::JobInfo;
13
14/// Abstracts over pool and transaction connections so handlers can work with either.
15pub enum DbConn<'a> {
16    Pool(&'a sqlx::PgPool),
17    Transaction(Arc<AsyncMutex<Transaction<'static, Postgres>>>),
18}
19
20impl DbConn<'_> {
21    pub async fn fetch_one<'q, O>(
22        &self,
23        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
24    ) -> sqlx::Result<O>
25    where
26        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
27    {
28        match self {
29            DbConn::Pool(pool) => query.fetch_one(*pool).await,
30            DbConn::Transaction(tx) => query.fetch_one(&mut **tx.lock().await).await,
31        }
32    }
33
34    pub async fn fetch_optional<'q, O>(
35        &self,
36        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
37    ) -> sqlx::Result<Option<O>>
38    where
39        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
40    {
41        match self {
42            DbConn::Pool(pool) => query.fetch_optional(*pool).await,
43            DbConn::Transaction(tx) => query.fetch_optional(&mut **tx.lock().await).await,
44        }
45    }
46
47    pub async fn fetch_all<'q, O>(
48        &self,
49        query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
50    ) -> sqlx::Result<Vec<O>>
51    where
52        O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
53    {
54        match self {
55            DbConn::Pool(pool) => query.fetch_all(*pool).await,
56            DbConn::Transaction(tx) => query.fetch_all(&mut **tx.lock().await).await,
57        }
58    }
59
60    pub async fn execute<'q>(
61        &self,
62        query: sqlx::query::Query<'q, Postgres, PgArguments>,
63    ) -> sqlx::Result<PgQueryResult> {
64        match self {
65            DbConn::Pool(pool) => query.execute(*pool).await,
66            DbConn::Transaction(tx) => query.execute(&mut **tx.lock().await).await,
67        }
68    }
69}
70
71#[derive(Debug, Clone)]
72pub struct PendingJob {
73    pub id: Uuid,
74    pub job_type: String,
75    pub args: serde_json::Value,
76    pub context: serde_json::Value,
77    pub priority: i32,
78    pub max_attempts: i32,
79    pub worker_capability: Option<String>,
80}
81
82#[derive(Debug, Clone)]
83pub struct PendingWorkflow {
84    pub id: Uuid,
85    pub workflow_name: String,
86    pub input: serde_json::Value,
87}
88
89#[derive(Default)]
90pub struct OutboxBuffer {
91    pub jobs: Vec<PendingJob>,
92    pub workflows: Vec<PendingWorkflow>,
93}
94
95/// Authentication context available to all functions.
96#[derive(Debug, Clone)]
97pub struct AuthContext {
98    /// The authenticated user ID (if any).
99    user_id: Option<Uuid>,
100    /// User roles.
101    roles: Vec<String>,
102    /// Custom claims from JWT.
103    claims: HashMap<String, serde_json::Value>,
104    /// Whether the request is authenticated.
105    authenticated: bool,
106}
107
108impl AuthContext {
109    /// Create an unauthenticated context.
110    pub fn unauthenticated() -> Self {
111        Self {
112            user_id: None,
113            roles: Vec::new(),
114            claims: HashMap::new(),
115            authenticated: false,
116        }
117    }
118
119    /// Create an authenticated context with a UUID user ID.
120    pub fn authenticated(
121        user_id: Uuid,
122        roles: Vec<String>,
123        claims: HashMap<String, serde_json::Value>,
124    ) -> Self {
125        Self {
126            user_id: Some(user_id),
127            roles,
128            claims,
129            authenticated: true,
130        }
131    }
132
133    /// Create an authenticated context without requiring a UUID user ID.
134    ///
135    /// Use this for auth providers that don't use UUID subjects (e.g., Firebase,
136    /// Clerk). The raw subject string is available via `subject()` method
137    /// from the "sub" claim.
138    pub fn authenticated_without_uuid(
139        roles: Vec<String>,
140        claims: HashMap<String, serde_json::Value>,
141    ) -> Self {
142        Self {
143            user_id: None,
144            roles,
145            claims,
146            authenticated: true,
147        }
148    }
149
150    /// Check if the user is authenticated.
151    pub fn is_authenticated(&self) -> bool {
152        self.authenticated
153    }
154
155    /// Get the user ID if authenticated.
156    pub fn user_id(&self) -> Option<Uuid> {
157        self.user_id
158    }
159
160    /// Get the user ID, returning an error if not authenticated.
161    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
162        self.user_id
163            .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
164    }
165
166    /// Check if the user has a specific role.
167    pub fn has_role(&self, role: &str) -> bool {
168        self.roles.iter().any(|r| r == role)
169    }
170
171    /// Require a specific role, returning an error if not present.
172    pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
173        if self.has_role(role) {
174            Ok(())
175        } else {
176            Err(crate::error::ForgeError::Forbidden(format!(
177                "Required role '{}' not present",
178                role
179            )))
180        }
181    }
182
183    /// Get a custom claim value.
184    pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
185        self.claims.get(key)
186    }
187
188    /// Get all roles.
189    pub fn roles(&self) -> &[String] {
190        &self.roles
191    }
192
193    /// Get the raw subject claim.
194    ///
195    /// This works with any provider's subject format (UUID, email, custom ID).
196    /// For providers like Firebase or Clerk that don't use UUIDs, use this
197    /// instead of `user_id()`.
198    pub fn subject(&self) -> Option<&str> {
199        self.claims.get("sub").and_then(|v| v.as_str())
200    }
201
202    /// Like `require_user_id()` but returns the raw subject string for non-UUID providers.
203    pub fn require_subject(&self) -> crate::error::Result<&str> {
204        if !self.authenticated {
205            return Err(crate::error::ForgeError::Unauthorized(
206                "Authentication required".to_string(),
207            ));
208        }
209        self.subject().ok_or_else(|| {
210            crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
211        })
212    }
213}
214
215/// Request metadata available to all functions.
216#[derive(Debug, Clone)]
217pub struct RequestMetadata {
218    /// Unique request ID for tracing.
219    pub request_id: Uuid,
220    /// Trace ID for distributed tracing.
221    pub trace_id: String,
222    /// Client IP address.
223    pub client_ip: Option<String>,
224    /// User agent string.
225    pub user_agent: Option<String>,
226    /// Request timestamp.
227    pub timestamp: chrono::DateTime<chrono::Utc>,
228}
229
230impl RequestMetadata {
231    /// Create new request metadata.
232    pub fn new() -> Self {
233        Self {
234            request_id: Uuid::new_v4(),
235            trace_id: Uuid::new_v4().to_string(),
236            client_ip: None,
237            user_agent: None,
238            timestamp: chrono::Utc::now(),
239        }
240    }
241
242    /// Create with a specific trace ID.
243    pub fn with_trace_id(trace_id: String) -> Self {
244        Self {
245            request_id: Uuid::new_v4(),
246            trace_id,
247            client_ip: None,
248            user_agent: None,
249            timestamp: chrono::Utc::now(),
250        }
251    }
252}
253
254impl Default for RequestMetadata {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260/// Context for query functions (read-only database access).
261pub struct QueryContext {
262    /// Authentication context.
263    pub auth: AuthContext,
264    /// Request metadata.
265    pub request: RequestMetadata,
266    /// Database pool for read operations.
267    db_pool: sqlx::PgPool,
268    /// Environment variable provider.
269    env_provider: Arc<dyn EnvProvider>,
270}
271
272impl QueryContext {
273    /// Create a new query context.
274    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
275        Self {
276            auth,
277            request,
278            db_pool,
279            env_provider: Arc::new(RealEnvProvider::new()),
280        }
281    }
282
283    /// Create a query context with a custom environment provider.
284    pub fn with_env(
285        db_pool: sqlx::PgPool,
286        auth: AuthContext,
287        request: RequestMetadata,
288        env_provider: Arc<dyn EnvProvider>,
289    ) -> Self {
290        Self {
291            auth,
292            request,
293            db_pool,
294            env_provider,
295        }
296    }
297
298    /// Get a reference to the database pool.
299    pub fn db(&self) -> &sqlx::PgPool {
300        &self.db_pool
301    }
302
303    /// Get the authenticated user ID or return an error.
304    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
305        self.auth.require_user_id()
306    }
307
308    /// Like `require_user_id()` but for non-UUID auth providers.
309    pub fn require_subject(&self) -> crate::error::Result<&str> {
310        self.auth.require_subject()
311    }
312}
313
314impl EnvAccess for QueryContext {
315    fn env_provider(&self) -> &dyn EnvProvider {
316        self.env_provider.as_ref()
317    }
318}
319
320/// Callback type for looking up job info by name.
321pub type JobInfoLookup = Arc<dyn Fn(&str) -> Option<JobInfo> + Send + Sync>;
322
323/// Context for mutation functions (transactional database access).
324pub struct MutationContext {
325    /// Authentication context.
326    pub auth: AuthContext,
327    /// Request metadata.
328    pub request: RequestMetadata,
329    /// Database pool for transactional operations.
330    db_pool: sqlx::PgPool,
331    /// HTTP client with circuit breaker for external requests.
332    http_client: CircuitBreakerClient,
333    /// Optional job dispatcher for dispatching background jobs.
334    job_dispatch: Option<Arc<dyn JobDispatch>>,
335    /// Optional workflow dispatcher for starting workflows.
336    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
337    /// Environment variable provider.
338    env_provider: Arc<dyn EnvProvider>,
339    /// Transaction handle for transactional mutations.
340    tx: Option<Arc<AsyncMutex<Transaction<'static, Postgres>>>>,
341    /// Outbox buffer for jobs/workflows dispatched during transaction.
342    outbox: Option<Arc<Mutex<OutboxBuffer>>>,
343    /// Job info lookup for transactional dispatch.
344    job_info_lookup: Option<JobInfoLookup>,
345}
346
347impl MutationContext {
348    /// Create a new mutation context.
349    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
350        Self {
351            auth,
352            request,
353            db_pool,
354            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
355            job_dispatch: None,
356            workflow_dispatch: None,
357            env_provider: Arc::new(RealEnvProvider::new()),
358            tx: None,
359            outbox: None,
360            job_info_lookup: None,
361        }
362    }
363
364    /// Create a mutation context with dispatch capabilities.
365    pub fn with_dispatch(
366        db_pool: sqlx::PgPool,
367        auth: AuthContext,
368        request: RequestMetadata,
369        http_client: CircuitBreakerClient,
370        job_dispatch: Option<Arc<dyn JobDispatch>>,
371        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
372    ) -> Self {
373        Self {
374            auth,
375            request,
376            db_pool,
377            http_client,
378            job_dispatch,
379            workflow_dispatch,
380            env_provider: Arc::new(RealEnvProvider::new()),
381            tx: None,
382            outbox: None,
383            job_info_lookup: None,
384        }
385    }
386
387    /// Create a mutation context with a custom environment provider.
388    pub fn with_env(
389        db_pool: sqlx::PgPool,
390        auth: AuthContext,
391        request: RequestMetadata,
392        http_client: CircuitBreakerClient,
393        job_dispatch: Option<Arc<dyn JobDispatch>>,
394        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
395        env_provider: Arc<dyn EnvProvider>,
396    ) -> Self {
397        Self {
398            auth,
399            request,
400            db_pool,
401            http_client,
402            job_dispatch,
403            workflow_dispatch,
404            env_provider,
405            tx: None,
406            outbox: None,
407            job_info_lookup: None,
408        }
409    }
410
411    /// Returns handles to transaction and outbox for the caller to commit/flush.
412    #[allow(clippy::type_complexity)]
413    pub fn with_transaction(
414        db_pool: sqlx::PgPool,
415        tx: Transaction<'static, Postgres>,
416        auth: AuthContext,
417        request: RequestMetadata,
418        http_client: CircuitBreakerClient,
419        job_info_lookup: JobInfoLookup,
420    ) -> (
421        Self,
422        Arc<AsyncMutex<Transaction<'static, Postgres>>>,
423        Arc<Mutex<OutboxBuffer>>,
424    ) {
425        let tx_handle = Arc::new(AsyncMutex::new(tx));
426        let outbox = Arc::new(Mutex::new(OutboxBuffer::default()));
427
428        let ctx = Self {
429            auth,
430            request,
431            db_pool,
432            http_client,
433            job_dispatch: None,
434            workflow_dispatch: None,
435            env_provider: Arc::new(RealEnvProvider::new()),
436            tx: Some(tx_handle.clone()),
437            outbox: Some(outbox.clone()),
438            job_info_lookup: Some(job_info_lookup),
439        };
440
441        (ctx, tx_handle, outbox)
442    }
443
444    pub fn is_transactional(&self) -> bool {
445        self.tx.is_some()
446    }
447
448    pub fn db(&self) -> DbConn<'_> {
449        match &self.tx {
450            Some(tx) => DbConn::Transaction(tx.clone()),
451            None => DbConn::Pool(&self.db_pool),
452        }
453    }
454
455    /// Direct pool access for operations that cannot run inside a transaction.
456    pub fn pool(&self) -> &sqlx::PgPool {
457        &self.db_pool
458    }
459
460    /// Get the HTTP client for external requests.
461    ///
462    /// The client includes circuit breaker protection that tracks failure rates
463    /// per host. After repeated failures, requests fail fast to prevent cascade
464    /// failures when downstream services are unhealthy.
465    pub fn http(&self) -> &reqwest::Client {
466        self.http_client.inner()
467    }
468
469    /// Get the circuit breaker client directly for advanced usage.
470    pub fn http_with_circuit_breaker(&self) -> &CircuitBreakerClient {
471        &self.http_client
472    }
473
474    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
475        self.auth.require_user_id()
476    }
477
478    pub fn require_subject(&self) -> crate::error::Result<&str> {
479        self.auth.require_subject()
480    }
481
482    /// In transactional mode, buffers for atomic commit; otherwise dispatches immediately.
483    pub async fn dispatch_job<T: serde::Serialize>(
484        &self,
485        job_type: &str,
486        args: T,
487    ) -> crate::error::Result<Uuid> {
488        let args_json = serde_json::to_value(args)?;
489
490        // Transactional mode: buffer the job for atomic commit
491        if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
492            let job_info = job_info_lookup(job_type).ok_or_else(|| {
493                crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
494            })?;
495
496            let pending = PendingJob {
497                id: Uuid::new_v4(),
498                job_type: job_type.to_string(),
499                args: args_json,
500                context: serde_json::json!({}),
501                priority: job_info.priority.as_i32(),
502                max_attempts: job_info.retry.max_attempts as i32,
503                worker_capability: job_info.worker_capability.map(|s| s.to_string()),
504            };
505
506            let job_id = pending.id;
507            outbox.lock().unwrap().jobs.push(pending);
508            return Ok(job_id);
509        }
510
511        // Non-transactional mode: dispatch immediately
512        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
513            crate::error::ForgeError::Internal("Job dispatch not available".into())
514        })?;
515        dispatcher.dispatch_by_name(job_type, args_json).await
516    }
517
518    /// Dispatch a job with initial context.
519    pub async fn dispatch_job_with_context<T: serde::Serialize>(
520        &self,
521        job_type: &str,
522        args: T,
523        context: serde_json::Value,
524    ) -> crate::error::Result<Uuid> {
525        let args_json = serde_json::to_value(args)?;
526
527        if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
528            let job_info = job_info_lookup(job_type).ok_or_else(|| {
529                crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
530            })?;
531
532            let pending = PendingJob {
533                id: Uuid::new_v4(),
534                job_type: job_type.to_string(),
535                args: args_json,
536                context,
537                priority: job_info.priority.as_i32(),
538                max_attempts: job_info.retry.max_attempts as i32,
539                worker_capability: job_info.worker_capability.map(|s| s.to_string()),
540            };
541
542            let job_id = pending.id;
543            outbox.lock().unwrap().jobs.push(pending);
544            return Ok(job_id);
545        }
546
547        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
548            crate::error::ForgeError::Internal("Job dispatch not available".into())
549        })?;
550        dispatcher.dispatch_by_name(job_type, args_json).await
551    }
552
553    /// Request cancellation for a job.
554    pub async fn cancel_job(
555        &self,
556        job_id: Uuid,
557        reason: Option<String>,
558    ) -> crate::error::Result<bool> {
559        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
560            crate::error::ForgeError::Internal("Job dispatch not available".into())
561        })?;
562        dispatcher.cancel(job_id, reason).await
563    }
564
565    /// In transactional mode, buffers for atomic commit; otherwise starts immediately.
566    pub async fn start_workflow<T: serde::Serialize>(
567        &self,
568        workflow_name: &str,
569        input: T,
570    ) -> crate::error::Result<Uuid> {
571        let input_json = serde_json::to_value(input)?;
572
573        // Transactional mode: buffer the workflow for atomic commit
574        if let Some(outbox) = &self.outbox {
575            let pending = PendingWorkflow {
576                id: Uuid::new_v4(),
577                workflow_name: workflow_name.to_string(),
578                input: input_json,
579            };
580
581            let workflow_id = pending.id;
582            outbox.lock().unwrap().workflows.push(pending);
583            return Ok(workflow_id);
584        }
585
586        // Non-transactional mode: start immediately
587        let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
588            crate::error::ForgeError::Internal("Workflow dispatch not available".into())
589        })?;
590        dispatcher.start_by_name(workflow_name, input_json).await
591    }
592}
593
594impl EnvAccess for MutationContext {
595    fn env_provider(&self) -> &dyn EnvProvider {
596        self.env_provider.as_ref()
597    }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// An unauthenticated context exposes no identity at all.
    #[test]
    fn test_auth_context_unauthenticated() {
        let ctx = AuthContext::unauthenticated();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
        assert!(ctx.subject().is_none());
        assert!(ctx.require_subject().is_err());
        assert!(ctx.roles().is_empty());
    }

    /// A UUID-authenticated context exposes its user ID and roles.
    #[test]
    fn test_auth_context_authenticated() {
        let user_id = Uuid::new_v4();
        let ctx = AuthContext::authenticated(
            user_id,
            vec!["admin".to_string(), "user".to_string()],
            HashMap::new(),
        );

        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id(), Some(user_id));
        assert!(ctx.require_user_id().is_ok());
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("user"));
        assert!(!ctx.has_role("superadmin"));
        assert!(ctx.require_role("admin").is_ok());
        assert!(ctx.require_role("superadmin").is_err());
        assert_eq!(
            ctx.roles().to_vec(),
            vec!["admin".to_string(), "user".to_string()]
        );
    }

    /// Custom claims are retrievable by name.
    #[test]
    fn test_auth_context_with_claims() {
        let mut claims = HashMap::new();
        claims.insert("org_id".to_string(), serde_json::json!("org-123"));

        let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);

        assert_eq!(ctx.claim("org_id"), Some(&serde_json::json!("org-123")));
        assert!(ctx.claim("nonexistent").is_none());
    }

    /// Non-UUID providers: authenticated, no UUID, subject read from "sub".
    #[test]
    fn test_auth_context_without_uuid() {
        let mut claims = HashMap::new();
        claims.insert("sub".to_string(), serde_json::json!("clerk_user_42"));

        let ctx = AuthContext::authenticated_without_uuid(vec!["member".to_string()], claims);

        assert!(ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
        assert!(ctx.has_role("member"));
        assert_eq!(ctx.subject(), Some("clerk_user_42"));
        assert_eq!(ctx.require_subject().ok(), Some("clerk_user_42"));
    }

    /// A missing or non-string "sub" claim yields no subject.
    #[test]
    fn test_require_subject_without_usable_sub_claim() {
        let no_sub = AuthContext::authenticated_without_uuid(vec![], HashMap::new());
        assert!(no_sub.subject().is_none());
        assert!(no_sub.require_subject().is_err());

        let mut claims = HashMap::new();
        claims.insert("sub".to_string(), serde_json::json!(42));
        let numeric_sub = AuthContext::authenticated_without_uuid(vec![], claims);
        assert!(numeric_sub.subject().is_none());
        assert!(numeric_sub.require_subject().is_err());
    }

    /// Constructors fill in IDs and leave client details empty.
    #[test]
    fn test_request_metadata() {
        let meta = RequestMetadata::new();
        assert!(!meta.trace_id.is_empty());
        assert!(meta.client_ip.is_none());
        assert!(meta.user_agent.is_none());

        let meta2 = RequestMetadata::with_trace_id("trace-123".to_string());
        assert_eq!(meta2.trace_id, "trace-123");
        assert!(meta2.client_ip.is_none());

        let meta3 = RequestMetadata::default();
        assert!(!meta3.trace_id.is_empty());
        assert!(meta3.user_agent.is_none());
    }
}
651}