Skip to main content

forge_core/function/
context.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use uuid::Uuid;
5
6use super::dispatch::{JobDispatch, WorkflowDispatch};
7use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
8
9/// Authentication context available to all functions.
10#[derive(Debug, Clone)]
11pub struct AuthContext {
12    /// The authenticated user ID (if any).
13    user_id: Option<Uuid>,
14    /// User roles.
15    roles: Vec<String>,
16    /// Custom claims from JWT.
17    claims: HashMap<String, serde_json::Value>,
18    /// Whether the request is authenticated.
19    authenticated: bool,
20}
21
22impl AuthContext {
23    /// Create an unauthenticated context.
24    pub fn unauthenticated() -> Self {
25        Self {
26            user_id: None,
27            roles: Vec::new(),
28            claims: HashMap::new(),
29            authenticated: false,
30        }
31    }
32
33    /// Create an authenticated context with a UUID user ID.
34    pub fn authenticated(
35        user_id: Uuid,
36        roles: Vec<String>,
37        claims: HashMap<String, serde_json::Value>,
38    ) -> Self {
39        Self {
40            user_id: Some(user_id),
41            roles,
42            claims,
43            authenticated: true,
44        }
45    }
46
47    /// Create an authenticated context without requiring a UUID user ID.
48    ///
49    /// Use this for auth providers that don't use UUID subjects (e.g., Firebase,
50    /// Clerk). The raw subject string is available via `subject()` method
51    /// from the "sub" claim.
52    pub fn authenticated_without_uuid(
53        roles: Vec<String>,
54        claims: HashMap<String, serde_json::Value>,
55    ) -> Self {
56        Self {
57            user_id: None,
58            roles,
59            claims,
60            authenticated: true,
61        }
62    }
63
64    /// Check if the user is authenticated.
65    pub fn is_authenticated(&self) -> bool {
66        self.authenticated
67    }
68
69    /// Get the user ID if authenticated.
70    pub fn user_id(&self) -> Option<Uuid> {
71        self.user_id
72    }
73
74    /// Get the user ID, returning an error if not authenticated.
75    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
76        self.user_id
77            .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
78    }
79
80    /// Check if the user has a specific role.
81    pub fn has_role(&self, role: &str) -> bool {
82        self.roles.iter().any(|r| r == role)
83    }
84
85    /// Require a specific role, returning an error if not present.
86    pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
87        if self.has_role(role) {
88            Ok(())
89        } else {
90            Err(crate::error::ForgeError::Forbidden(format!(
91                "Required role '{}' not present",
92                role
93            )))
94        }
95    }
96
97    /// Get a custom claim value.
98    pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
99        self.claims.get(key)
100    }
101
102    /// Get all roles.
103    pub fn roles(&self) -> &[String] {
104        &self.roles
105    }
106
107    /// Get the raw subject claim.
108    ///
109    /// This works with any provider's subject format (UUID, email, custom ID).
110    /// For providers like Firebase or Clerk that don't use UUIDs, use this
111    /// instead of `user_id()`.
112    pub fn subject(&self) -> Option<&str> {
113        self.claims.get("sub").and_then(|v| v.as_str())
114    }
115
116    /// Like `require_user_id()` but returns the raw subject string for non-UUID providers.
117    pub fn require_subject(&self) -> crate::error::Result<&str> {
118        if !self.authenticated {
119            return Err(crate::error::ForgeError::Unauthorized(
120                "Authentication required".to_string(),
121            ));
122        }
123        self.subject().ok_or_else(|| {
124            crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
125        })
126    }
127}
128
129/// Request metadata available to all functions.
130#[derive(Debug, Clone)]
131pub struct RequestMetadata {
132    /// Unique request ID for tracing.
133    pub request_id: Uuid,
134    /// Trace ID for distributed tracing.
135    pub trace_id: String,
136    /// Client IP address.
137    pub client_ip: Option<String>,
138    /// User agent string.
139    pub user_agent: Option<String>,
140    /// Request timestamp.
141    pub timestamp: chrono::DateTime<chrono::Utc>,
142}
143
144impl RequestMetadata {
145    /// Create new request metadata.
146    pub fn new() -> Self {
147        Self {
148            request_id: Uuid::new_v4(),
149            trace_id: Uuid::new_v4().to_string(),
150            client_ip: None,
151            user_agent: None,
152            timestamp: chrono::Utc::now(),
153        }
154    }
155
156    /// Create with a specific trace ID.
157    pub fn with_trace_id(trace_id: String) -> Self {
158        Self {
159            request_id: Uuid::new_v4(),
160            trace_id,
161            client_ip: None,
162            user_agent: None,
163            timestamp: chrono::Utc::now(),
164        }
165    }
166}
167
168impl Default for RequestMetadata {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174/// Context for query functions (read-only database access).
175pub struct QueryContext {
176    /// Authentication context.
177    pub auth: AuthContext,
178    /// Request metadata.
179    pub request: RequestMetadata,
180    /// Database pool for read operations.
181    db_pool: sqlx::PgPool,
182    /// Environment variable provider.
183    env_provider: Arc<dyn EnvProvider>,
184}
185
186impl QueryContext {
187    /// Create a new query context.
188    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
189        Self {
190            auth,
191            request,
192            db_pool,
193            env_provider: Arc::new(RealEnvProvider::new()),
194        }
195    }
196
197    /// Create a query context with a custom environment provider.
198    pub fn with_env(
199        db_pool: sqlx::PgPool,
200        auth: AuthContext,
201        request: RequestMetadata,
202        env_provider: Arc<dyn EnvProvider>,
203    ) -> Self {
204        Self {
205            auth,
206            request,
207            db_pool,
208            env_provider,
209        }
210    }
211
212    /// Get a reference to the database pool.
213    pub fn db(&self) -> &sqlx::PgPool {
214        &self.db_pool
215    }
216
217    /// Get the authenticated user ID or return an error.
218    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
219        self.auth.require_user_id()
220    }
221
222    /// Like `require_user_id()` but for non-UUID auth providers.
223    pub fn require_subject(&self) -> crate::error::Result<&str> {
224        self.auth.require_subject()
225    }
226}
227
228impl EnvAccess for QueryContext {
229    fn env_provider(&self) -> &dyn EnvProvider {
230        self.env_provider.as_ref()
231    }
232}
233
234/// Context for mutation functions (transactional database access).
235pub struct MutationContext {
236    /// Authentication context.
237    pub auth: AuthContext,
238    /// Request metadata.
239    pub request: RequestMetadata,
240    /// Database pool for transactional operations.
241    db_pool: sqlx::PgPool,
242    /// HTTP client for external requests.
243    http_client: reqwest::Client,
244    /// Optional job dispatcher for dispatching background jobs.
245    job_dispatch: Option<Arc<dyn JobDispatch>>,
246    /// Optional workflow dispatcher for starting workflows.
247    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
248    /// Environment variable provider.
249    env_provider: Arc<dyn EnvProvider>,
250}
251
252impl MutationContext {
253    /// Create a new mutation context.
254    pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
255        Self {
256            auth,
257            request,
258            db_pool,
259            http_client: reqwest::Client::new(),
260            job_dispatch: None,
261            workflow_dispatch: None,
262            env_provider: Arc::new(RealEnvProvider::new()),
263        }
264    }
265
266    /// Create a mutation context with dispatch capabilities.
267    pub fn with_dispatch(
268        db_pool: sqlx::PgPool,
269        auth: AuthContext,
270        request: RequestMetadata,
271        http_client: reqwest::Client,
272        job_dispatch: Option<Arc<dyn JobDispatch>>,
273        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
274    ) -> Self {
275        Self {
276            auth,
277            request,
278            db_pool,
279            http_client,
280            job_dispatch,
281            workflow_dispatch,
282            env_provider: Arc::new(RealEnvProvider::new()),
283        }
284    }
285
286    /// Create a mutation context with a custom environment provider.
287    pub fn with_env(
288        db_pool: sqlx::PgPool,
289        auth: AuthContext,
290        request: RequestMetadata,
291        http_client: reqwest::Client,
292        job_dispatch: Option<Arc<dyn JobDispatch>>,
293        workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
294        env_provider: Arc<dyn EnvProvider>,
295    ) -> Self {
296        Self {
297            auth,
298            request,
299            db_pool,
300            http_client,
301            job_dispatch,
302            workflow_dispatch,
303            env_provider,
304        }
305    }
306
307    /// Get a reference to the database pool.
308    pub fn db(&self) -> &sqlx::PgPool {
309        &self.db_pool
310    }
311
312    /// Get a reference to the HTTP client.
313    pub fn http(&self) -> &reqwest::Client {
314        &self.http_client
315    }
316
317    /// Get the authenticated user ID or return an error.
318    pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
319        self.auth.require_user_id()
320    }
321
322    /// Like `require_user_id()` but for non-UUID auth providers.
323    pub fn require_subject(&self) -> crate::error::Result<&str> {
324        self.auth.require_subject()
325    }
326
327    /// Dispatch a background job.
328    ///
329    /// # Arguments
330    /// * `job_type` - The registered name of the job type
331    /// * `args` - The arguments for the job (will be serialized to JSON)
332    ///
333    /// # Returns
334    /// The UUID of the dispatched job, or an error if dispatch is not available.
335    pub async fn dispatch_job<T: serde::Serialize>(
336        &self,
337        job_type: &str,
338        args: T,
339    ) -> crate::error::Result<Uuid> {
340        let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
341            crate::error::ForgeError::Internal("Job dispatch not available".into())
342        })?;
343        let args_json = serde_json::to_value(args)?;
344        dispatcher.dispatch_by_name(job_type, args_json).await
345    }
346
347    /// Start a workflow.
348    ///
349    /// # Arguments
350    /// * `workflow_name` - The registered name of the workflow
351    /// * `input` - The input for the workflow (will be serialized to JSON)
352    ///
353    /// # Returns
354    /// The UUID of the started workflow run, or an error if dispatch is not available.
355    pub async fn start_workflow<T: serde::Serialize>(
356        &self,
357        workflow_name: &str,
358        input: T,
359    ) -> crate::error::Result<Uuid> {
360        let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
361            crate::error::ForgeError::Internal("Workflow dispatch not available".into())
362        })?;
363        let input_json = serde_json::to_value(input)?;
364        dispatcher.start_by_name(workflow_name, input_json).await
365    }
366}
367
368impl EnvAccess for MutationContext {
369    fn env_provider(&self) -> &dyn EnvProvider {
370        self.env_provider.as_ref()
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_context_unauthenticated() {
        let ctx = AuthContext::unauthenticated();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
        assert!(ctx.roles().is_empty());
    }

    #[test]
    fn test_auth_context_authenticated() {
        let user_id = Uuid::new_v4();
        let ctx = AuthContext::authenticated(
            user_id,
            vec!["admin".to_string(), "user".to_string()],
            HashMap::new(),
        );

        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id(), Some(user_id));
        assert!(ctx.require_user_id().is_ok());
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("user"));
        assert!(!ctx.has_role("superadmin"));
        assert!(ctx.require_role("admin").is_ok());
        assert!(ctx.require_role("superadmin").is_err());
        assert_eq!(ctx.roles().len(), 2);
    }

    #[test]
    fn test_auth_context_with_claims() {
        let mut claims = HashMap::new();
        claims.insert("org_id".to_string(), serde_json::json!("org-123"));

        let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);

        assert_eq!(ctx.claim("org_id"), Some(&serde_json::json!("org-123")));
        assert!(ctx.claim("nonexistent").is_none());
    }

    // Covers the non-UUID provider path: authenticated, no UUID, subject
    // taken from the "sub" claim.
    #[test]
    fn test_auth_context_without_uuid() {
        let mut claims = HashMap::new();
        claims.insert("sub".to_string(), serde_json::json!("user_abc123"));

        let ctx = AuthContext::authenticated_without_uuid(vec!["user".to_string()], claims);

        assert!(ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
        assert!(ctx.has_role("user"));
        assert_eq!(ctx.subject(), Some("user_abc123"));
        assert_eq!(ctx.require_subject().ok(), Some("user_abc123"));
    }

    #[test]
    fn test_require_subject_errors() {
        // Unauthenticated contexts are rejected outright.
        let anon = AuthContext::unauthenticated();
        assert!(anon.subject().is_none());
        assert!(anon.require_subject().is_err());

        // Authenticated, but the claims carry no "sub" entry.
        let no_sub = AuthContext::authenticated(Uuid::new_v4(), vec![], HashMap::new());
        assert!(no_sub.subject().is_none());
        assert!(no_sub.require_subject().is_err());

        // A non-string "sub" claim is not usable as a subject.
        let mut claims = HashMap::new();
        claims.insert("sub".to_string(), serde_json::json!(42));
        let numeric = AuthContext::authenticated_without_uuid(vec![], claims);
        assert!(numeric.subject().is_none());
        assert!(numeric.require_subject().is_err());
    }

    #[test]
    fn test_request_metadata() {
        let meta = RequestMetadata::new();
        assert!(!meta.trace_id.is_empty());
        assert!(meta.client_ip.is_none());
        assert!(meta.user_agent.is_none());

        let meta2 = RequestMetadata::with_trace_id("trace-123".to_string());
        assert_eq!(meta2.trace_id, "trace-123");
        assert!(meta2.client_ip.is_none());
        assert!(meta2.user_agent.is_none());
    }

    #[test]
    fn test_request_metadata_default() {
        let meta = RequestMetadata::default();
        assert!(!meta.trace_id.is_empty());
        assert!(meta.client_ip.is_none());
        assert!(meta.user_agent.is_none());
    }
}