Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22/// Shared auth enforcement: checks public flag, authentication, and role.
23fn require_auth(is_public: bool, required_role: Option<&str>, auth: &AuthContext) -> Result<()> {
24    if is_public {
25        return Ok(());
26    }
27    if !auth.is_authenticated() {
28        return Err(ForgeError::Unauthorized("Authentication required".into()));
29    }
30    if let Some(role) = required_role
31        && !auth.has_role(role)
32    {
33        return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
34    }
35    Ok(())
36}
37
38/// Result of routing a function call.
39pub enum RouteResult {
40    /// Query execution result.
41    Query(Value),
42    /// Mutation execution result.
43    Mutation(Value),
44    /// Job dispatch result (returns job_id).
45    Job(Value),
46    /// Workflow dispatch result (returns workflow_id).
47    Workflow(Value),
48}
49
/// Routes function calls to the appropriate handler.
///
/// Queries and mutations are resolved through the registry. Names not found
/// there are offered to the optional job dispatcher and then to the optional
/// workflow dispatcher.
pub struct FunctionRouter {
    /// Registry of query and mutation handlers, looked up by function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle. Mutations and `consistent` queries use the primary
    /// pool; other queries use the read pool.
    db: Database,
    /// HTTP client handed to every mutation context.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher. Used to dispatch jobs by name and to look up
    /// job metadata inside transactional mutations.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher. Used to start workflows by name and to
    /// validate workflows buffered by transactional mutations.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Rate limiter consulted before registry functions run.
    rate_limiter: HybridRateLimiter,
    /// Result cache for queries that declare a `cache_ttl`.
    query_cache: QueryCache,
    /// Optional token issuer passed to mutation contexts (enables `ctx.issue_token()`).
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    /// Token lifetime configuration passed to mutation contexts.
    token_ttl: forge_core::AuthTokenTtl,
}
62
63impl FunctionRouter {
64    /// Create a new function router.
65    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
66        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67        Self {
68            registry,
69            db,
70            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
71            job_dispatcher: None,
72            workflow_dispatcher: None,
73            rate_limiter,
74            query_cache: QueryCache::new(),
75            token_issuer: None,
76            token_ttl: forge_core::AuthTokenTtl::default(),
77        }
78    }
79
80    /// Create a new function router with a custom HTTP client.
81    pub fn with_http_client(
82        registry: Arc<FunctionRegistry>,
83        db: Database,
84        http_client: CircuitBreakerClient,
85    ) -> Self {
86        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
87        Self {
88            registry,
89            db,
90            http_client,
91            job_dispatcher: None,
92            workflow_dispatcher: None,
93            rate_limiter,
94            query_cache: QueryCache::new(),
95            token_issuer: None,
96            token_ttl: forge_core::AuthTokenTtl::default(),
97        }
98    }
99
100    /// Set the token issuer for this router (enables `ctx.issue_token()` in mutations).
101    pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
102        self.token_issuer = Some(issuer);
103        self
104    }
105
106    /// Set the token TTL config for this router (configures `ctx.issue_token_pair()` durations).
107    pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
108        self.token_ttl = ttl;
109        self
110    }
111
112    /// Set the token TTL config (mutable reference version).
113    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
114        self.token_ttl = ttl;
115    }
116
117    /// Set the job dispatcher for this router.
118    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
119        self.job_dispatcher = Some(dispatcher);
120        self
121    }
122
123    /// Set the workflow dispatcher for this router.
124    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
125        self.workflow_dispatcher = Some(dispatcher);
126        self
127    }
128
129    pub async fn route(
130        &self,
131        function_name: &str,
132        args: Value,
133        auth: AuthContext,
134        request: RequestMetadata,
135    ) -> Result<RouteResult> {
136        if let Some(entry) = self.registry.get(function_name) {
137            self.check_auth(entry.info(), &auth)?;
138            self.check_rate_limit(entry.info(), function_name, &auth, &request)
139                .await?;
140
141            return match entry {
142                FunctionEntry::Query { handler, info, .. } => {
143                    let pool = if info.consistent {
144                        self.db.primary().clone()
145                    } else {
146                        self.db.read_pool().clone()
147                    };
148
149                    let auth_scope = Self::auth_cache_scope(&auth);
150                    if let Some(ttl) = info.cache_ttl {
151                        if let Some(cached) =
152                            self.query_cache
153                                .get(function_name, &args, auth_scope.as_deref())
154                        {
155                            return Ok(RouteResult::Query(Value::clone(&cached)));
156                        }
157
158                        let ctx = QueryContext::new(pool, auth, request);
159                        let result = handler(&ctx, args.clone()).await?;
160
161                        self.query_cache.set(
162                            function_name,
163                            &args,
164                            auth_scope.as_deref(),
165                            result.clone(),
166                            Duration::from_secs(ttl),
167                        );
168
169                        Ok(RouteResult::Query(result))
170                    } else {
171                        let ctx = QueryContext::new(pool, auth, request);
172                        let result = handler(&ctx, args).await?;
173                        Ok(RouteResult::Query(result))
174                    }
175                }
176                FunctionEntry::Mutation { handler, info } => {
177                    if info.transactional {
178                        self.execute_transactional(info, handler, args, auth, request)
179                            .await
180                    } else {
181                        // Use primary for mutations
182                        let mut ctx = MutationContext::with_dispatch(
183                            self.db.primary().clone(),
184                            auth,
185                            request,
186                            self.http_client.clone(),
187                            self.job_dispatcher.clone(),
188                            self.workflow_dispatcher.clone(),
189                        );
190                        if let Some(ref issuer) = self.token_issuer {
191                            ctx.set_token_issuer(issuer.clone());
192                        }
193                        ctx.set_token_ttl(self.token_ttl.clone());
194                        ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
195                        let result = handler(&ctx, args).await?;
196                        Ok(RouteResult::Mutation(result))
197                    }
198                }
199            };
200        }
201
202        if let Some(ref job_dispatcher) = self.job_dispatcher
203            && let Some(job_info) = job_dispatcher.get_info(function_name)
204        {
205            self.check_job_auth(&job_info, &auth)?;
206            match job_dispatcher
207                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
208                .await
209            {
210                Ok(job_id) => {
211                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
212                }
213                Err(ForgeError::NotFound(_)) => {}
214                Err(e) => return Err(e),
215            }
216        }
217
218        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
219            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
220        {
221            self.check_workflow_auth(&workflow_info, &auth)?;
222            match workflow_dispatcher
223                .start_by_name(function_name, args.clone(), auth.principal_id())
224                .await
225            {
226                Ok(workflow_id) => {
227                    return Ok(RouteResult::Workflow(
228                        serde_json::json!({ "workflow_id": workflow_id }),
229                    ));
230                }
231                Err(ForgeError::NotFound(_)) => {}
232                Err(e) => return Err(e),
233            }
234        }
235
236        Err(ForgeError::NotFound(format!(
237            "Function '{}' not found",
238            function_name
239        )))
240    }
241
242    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
243        require_auth(info.is_public, info.required_role, auth)
244    }
245
246    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
247        require_auth(info.is_public, info.required_role, auth)
248    }
249
250    fn check_workflow_auth(
251        &self,
252        info: &forge_core::workflow::WorkflowInfo,
253        auth: &AuthContext,
254    ) -> Result<()> {
255        require_auth(info.is_public, info.required_role, auth)
256    }
257
258    /// Check rate limit for a function call.
259    async fn check_rate_limit(
260        &self,
261        info: &FunctionInfo,
262        function_name: &str,
263        auth: &AuthContext,
264        request: &RequestMetadata,
265    ) -> Result<()> {
266        // Skip if no rate limit configured
267        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
268            (Some(r), Some(p)) => (r, p),
269            _ => return Ok(()),
270        };
271
272        // Build rate limit config
273        let key_str = info.rate_limit_key.unwrap_or("user");
274        let key_type: RateLimitKey = match key_str.parse() {
275            Ok(k) => k,
276            Err(_) => {
277                tracing::error!(
278                    function = %function_name,
279                    key = %key_str,
280                    "Invalid rate limit key, falling back to 'user'"
281                );
282                RateLimitKey::default()
283            }
284        };
285
286        let config =
287            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
288
289        // Build bucket key
290        let bucket_key = self
291            .rate_limiter
292            .build_key(key_type, function_name, auth, request);
293
294        // Enforce rate limit
295        self.rate_limiter.enforce(&bucket_key, &config).await?;
296
297        Ok(())
298    }
299
300    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
301        if !auth.is_authenticated() {
302            return Some("anon".to_string());
303        }
304
305        // Include role + claims fingerprint to avoid cross-scope cache bleed.
306        let mut roles = auth.roles().to_vec();
307        roles.sort();
308        roles.dedup();
309
310        let mut claims = BTreeMap::new();
311        for (k, v) in auth.claims() {
312            claims.insert(k.clone(), v.clone());
313        }
314
315        use std::collections::hash_map::DefaultHasher;
316        use std::hash::{Hash, Hasher};
317
318        let mut hasher = DefaultHasher::new();
319        roles.hash(&mut hasher);
320        serde_json::to_string(&claims)
321            .unwrap_or_default()
322            .hash(&mut hasher);
323
324        let principal = auth
325            .principal_id()
326            .unwrap_or_else(|| "authenticated".to_string());
327
328        Some(format!(
329            "subject:{principal}:scope:{:016x}",
330            hasher.finish()
331        ))
332    }
333
334    /// Get the function kind by name.
335    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
336        self.registry.get(function_name).map(|e| e.kind())
337    }
338
339    /// Check if a function exists.
340    pub fn has_function(&self, function_name: &str) -> bool {
341        self.registry.get(function_name).is_some()
342    }
343
344    async fn execute_transactional(
345        &self,
346        info: &FunctionInfo,
347        handler: &BoxedMutationFn,
348        args: Value,
349        auth: AuthContext,
350        request: RequestMetadata,
351    ) -> Result<RouteResult> {
352        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
353
354        async {
355            let primary = self.db.primary();
356            let tx = primary
357                .begin()
358                .await
359                .map_err(|e| ForgeError::Database(e.to_string()))?;
360
361            let job_dispatcher = self.job_dispatcher.clone();
362            let job_lookup: forge_core::JobInfoLookup =
363                Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
364
365            let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
366                primary.clone(),
367                tx,
368                auth,
369                request,
370                self.http_client.clone(),
371                job_lookup,
372            );
373            if let Some(ref issuer) = self.token_issuer {
374                ctx.set_token_issuer(issuer.clone());
375            }
376            ctx.set_token_ttl(self.token_ttl.clone());
377            ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
378
379            match handler(&ctx, args).await {
380                Ok(value) => {
381                    let buffer = {
382                        let guard = outbox.lock().unwrap_or_else(|poisoned| {
383                            tracing::error!("Outbox mutex was poisoned, recovering");
384                            poisoned.into_inner()
385                        });
386                        OutboxBuffer {
387                            jobs: guard.jobs.clone(),
388                            workflows: guard.workflows.clone(),
389                        }
390                    };
391
392                    let mut tx = Arc::try_unwrap(tx_handle)
393                        .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
394                        .into_inner();
395
396                    for job in &buffer.jobs {
397                        Self::insert_job(&mut tx, job).await?;
398                    }
399
400                    for workflow in &buffer.workflows {
401                        if self
402                            .workflow_dispatcher
403                            .as_ref()
404                            .and_then(|d| d.get_info(&workflow.workflow_name))
405                            .is_none()
406                        {
407                            return Err(ForgeError::NotFound(format!(
408                                "Workflow '{}' not found",
409                                workflow.workflow_name
410                            )));
411                        }
412                        Self::insert_workflow(&mut tx, workflow).await?;
413                    }
414
415                    tx.commit()
416                        .await
417                        .map_err(|e| ForgeError::Database(e.to_string()))?;
418
419                    Ok(RouteResult::Mutation(value))
420                }
421                Err(e) => Err(e),
422            }
423        }
424        .instrument(span)
425        .await
426    }
427
428    async fn insert_job(
429        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
430        job: &PendingJob,
431    ) -> Result<()> {
432        let now = Utc::now();
433        sqlx::query!(
434            r#"
435            INSERT INTO forge_jobs (
436                id, job_type, input, job_context, status, priority, attempts, max_attempts,
437                worker_capability, owner_subject, scheduled_at, created_at
438            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
439            "#,
440            job.id,
441            &job.job_type,
442            job.args as _,
443            job.context as _,
444            JobStatus::Pending.as_str(),
445            job.priority,
446            0i32,
447            job.max_attempts,
448            job.worker_capability.as_deref(),
449            job.owner_subject as _,
450            now,
451            now,
452        )
453        .execute(&mut **tx)
454        .await
455        .map_err(|e| ForgeError::Database(e.to_string()))?;
456
457        Ok(())
458    }
459
460    async fn insert_workflow(
461        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
462        workflow: &PendingWorkflow,
463    ) -> Result<()> {
464        let now = Utc::now();
465        sqlx::query!(
466            r#"
467            INSERT INTO forge_workflow_runs (
468                id, workflow_name, owner_subject, input, status, current_step,
469                step_results, started_at, trace_id
470            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
471            "#,
472            workflow.id,
473            &workflow.workflow_name,
474            workflow.owner_subject as _,
475            workflow.input as _,
476            WorkflowStatus::Created.as_str(),
477            Option::<String>::None,
478            serde_json::json!({}) as _,
479            now,
480            workflow.id.to_string(),
481        )
482        .execute(&mut **tx)
483        .await
484        .map_err(|e| ForgeError::Database(e.to_string()))?;
485
486        Ok(())
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            max_upload_size_bytes: None,
        };

        let auth = AuthContext::unauthenticated();

        // `check_auth` delegates to the free `require_auth`, which can be
        // exercised without a router: public functions skip authentication.
        assert!(require_auth(info.is_public, info.required_role, &auth).is_ok());
    }

    #[test]
    fn test_require_auth_rejects_unauthenticated_private_call() {
        let auth = AuthContext::unauthenticated();
        assert!(matches!(
            require_auth(false, None, &auth),
            Err(ForgeError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_require_auth_enforces_required_role() {
        let auth = AuthContext::authenticated(
            uuid::Uuid::new_v4(),
            vec!["user".to_string()],
            HashMap::new(),
        );

        assert!(require_auth(false, None, &auth).is_ok());
        assert!(require_auth(false, Some("user"), &auth).is_ok());
        assert!(matches!(
            require_auth(false, Some("admin"), &auth),
            Err(ForgeError::Forbidden(_))
        ));
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let auth = AuthContext::unauthenticated();
        assert_eq!(
            FunctionRouter::auth_cache_scope(&auth),
            Some("anon".to_string())
        );
    }

    #[test]
    fn test_auth_cache_scope_ignores_role_order_and_duplicates() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["admin".to_string(), "user".to_string()],
            HashMap::new(),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string(), "admin".to_string(), "user".to_string()],
            HashMap::new(),
        );

        assert_eq!(
            FunctionRouter::auth_cache_scope(&auth_a),
            FunctionRouter::auth_cache_scope(&auth_b)
        );
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-a".to_string()),
                ),
            ]),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-b".to_string()),
                ),
            ]),
        );

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}
560}