Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22/// Result of routing a function call.
23pub enum RouteResult {
24    /// Query execution result.
25    Query(Value),
26    /// Mutation execution result.
27    Mutation(Value),
28    /// Job dispatch result (returns job_id).
29    Job(Value),
30    /// Workflow dispatch result (returns workflow_id).
31    Workflow(Value),
32}
33
34/// Routes function calls to the appropriate handler.
35pub struct FunctionRouter {
36    registry: Arc<FunctionRegistry>,
37    db: Database,
38    http_client: CircuitBreakerClient,
39    job_dispatcher: Option<Arc<dyn JobDispatch>>,
40    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
41    rate_limiter: HybridRateLimiter,
42    query_cache: QueryCache,
43    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
44}
45
46impl FunctionRouter {
47    /// Create a new function router.
48    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
49        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
50        Self {
51            registry,
52            db,
53            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
54            job_dispatcher: None,
55            workflow_dispatcher: None,
56            rate_limiter,
57            query_cache: QueryCache::new(),
58            token_issuer: None,
59        }
60    }
61
62    /// Create a new function router with a custom HTTP client.
63    pub fn with_http_client(
64        registry: Arc<FunctionRegistry>,
65        db: Database,
66        http_client: CircuitBreakerClient,
67    ) -> Self {
68        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
69        Self {
70            registry,
71            db,
72            http_client,
73            job_dispatcher: None,
74            workflow_dispatcher: None,
75            rate_limiter,
76            query_cache: QueryCache::new(),
77            token_issuer: None,
78        }
79    }
80
81    /// Set the token issuer for this router (enables `ctx.issue_token()` in mutations).
82    pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
83        self.token_issuer = Some(issuer);
84        self
85    }
86
87    /// Set the job dispatcher for this router.
88    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
89        self.job_dispatcher = Some(dispatcher);
90        self
91    }
92
93    /// Set the workflow dispatcher for this router.
94    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
95        self.workflow_dispatcher = Some(dispatcher);
96        self
97    }
98
99    pub async fn route(
100        &self,
101        function_name: &str,
102        args: Value,
103        auth: AuthContext,
104        request: RequestMetadata,
105    ) -> Result<RouteResult> {
106        if let Some(entry) = self.registry.get(function_name) {
107            self.check_auth(entry.info(), &auth)?;
108            self.check_rate_limit(entry.info(), function_name, &auth, &request)
109                .await?;
110            // Skip scope enforcement for functions with no input args.
111            // The JWT still carries identity, accessible via ctx.require_user_id().
112            let enforce = !entry.info().is_public && entry.info().has_input_args;
113            Self::check_identity_args(function_name, &args, &auth, enforce)?;
114
115            return match entry {
116                FunctionEntry::Query { handler, info, .. } => {
117                    let pool = if info.consistent {
118                        self.db.primary().clone()
119                    } else {
120                        self.db.read_pool().clone()
121                    };
122
123                    let auth_scope = Self::auth_cache_scope(&auth);
124                    if let Some(ttl) = info.cache_ttl {
125                        if let Some(cached) =
126                            self.query_cache
127                                .get(function_name, &args, auth_scope.as_deref())
128                        {
129                            return Ok(RouteResult::Query(Value::clone(&cached)));
130                        }
131
132                        let ctx = QueryContext::new(pool, auth, request);
133                        let result = handler(&ctx, args.clone()).await?;
134
135                        self.query_cache.set(
136                            function_name,
137                            &args,
138                            auth_scope.as_deref(),
139                            result.clone(),
140                            Duration::from_secs(ttl),
141                        );
142
143                        Ok(RouteResult::Query(result))
144                    } else {
145                        let ctx = QueryContext::new(pool, auth, request);
146                        let result = handler(&ctx, args).await?;
147                        Ok(RouteResult::Query(result))
148                    }
149                }
150                FunctionEntry::Mutation { handler, info } => {
151                    if info.transactional {
152                        self.execute_transactional(handler, args, auth, request)
153                            .await
154                    } else {
155                        // Use primary for mutations
156                        let mut ctx = MutationContext::with_dispatch(
157                            self.db.primary().clone(),
158                            auth,
159                            request,
160                            self.http_client.clone(),
161                            self.job_dispatcher.clone(),
162                            self.workflow_dispatcher.clone(),
163                        );
164                        if let Some(ref issuer) = self.token_issuer {
165                            ctx.set_token_issuer(issuer.clone());
166                        }
167                        let result = handler(&ctx, args).await?;
168                        Ok(RouteResult::Mutation(result))
169                    }
170                }
171            };
172        }
173
174        if let Some(ref job_dispatcher) = self.job_dispatcher
175            && let Some(job_info) = job_dispatcher.get_info(function_name)
176        {
177            self.check_job_auth(&job_info, &auth)?;
178            Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
179            match job_dispatcher
180                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
181                .await
182            {
183                Ok(job_id) => {
184                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
185                }
186                Err(ForgeError::NotFound(_)) => {}
187                Err(e) => return Err(e),
188            }
189        }
190
191        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
192            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
193        {
194            self.check_workflow_auth(&workflow_info, &auth)?;
195            Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
196            match workflow_dispatcher
197                .start_by_name(function_name, args.clone(), auth.principal_id())
198                .await
199            {
200                Ok(workflow_id) => {
201                    return Ok(RouteResult::Workflow(
202                        serde_json::json!({ "workflow_id": workflow_id }),
203                    ));
204                }
205                Err(ForgeError::NotFound(_)) => {}
206                Err(e) => return Err(e),
207            }
208        }
209
210        Err(ForgeError::NotFound(format!(
211            "Function '{}' not found",
212            function_name
213        )))
214    }
215
216    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
217        if info.is_public {
218            return Ok(());
219        }
220
221        if !auth.is_authenticated() {
222            return Err(ForgeError::Unauthorized("Authentication required".into()));
223        }
224
225        if let Some(role) = info.required_role
226            && !auth.has_role(role)
227        {
228            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
229        }
230
231        Ok(())
232    }
233
234    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
235        if info.is_public {
236            return Ok(());
237        }
238
239        if !auth.is_authenticated() {
240            return Err(ForgeError::Unauthorized("Authentication required".into()));
241        }
242
243        if let Some(role) = info.required_role
244            && !auth.has_role(role)
245        {
246            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
247        }
248
249        Ok(())
250    }
251
252    fn check_workflow_auth(
253        &self,
254        info: &forge_core::workflow::WorkflowInfo,
255        auth: &AuthContext,
256    ) -> Result<()> {
257        if info.is_public {
258            return Ok(());
259        }
260
261        if !auth.is_authenticated() {
262            return Err(ForgeError::Unauthorized("Authentication required".into()));
263        }
264
265        if let Some(role) = info.required_role
266            && !auth.has_role(role)
267        {
268            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
269        }
270
271        Ok(())
272    }
273
274    /// Check rate limit for a function call.
275    async fn check_rate_limit(
276        &self,
277        info: &FunctionInfo,
278        function_name: &str,
279        auth: &AuthContext,
280        request: &RequestMetadata,
281    ) -> Result<()> {
282        // Skip if no rate limit configured
283        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
284            (Some(r), Some(p)) => (r, p),
285            _ => return Ok(()),
286        };
287
288        // Build rate limit config
289        let key_str = info.rate_limit_key.unwrap_or("user");
290        let key_type: RateLimitKey = match key_str.parse() {
291            Ok(k) => k,
292            Err(_) => {
293                tracing::warn!(
294                    function = %function_name,
295                    key = %key_str,
296                    "Invalid rate limit key, falling back to 'user'"
297                );
298                RateLimitKey::default()
299            }
300        };
301
302        let config =
303            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
304
305        // Build bucket key
306        let bucket_key = self
307            .rate_limiter
308            .build_key(key_type, function_name, auth, request);
309
310        // Enforce rate limit
311        self.rate_limiter.enforce(&bucket_key, &config).await?;
312
313        Ok(())
314    }
315
316    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
317        if !auth.is_authenticated() {
318            return Some("anon".to_string());
319        }
320
321        // Include role + claims fingerprint to avoid cross-scope cache bleed.
322        let mut roles = auth.roles().to_vec();
323        roles.sort();
324        roles.dedup();
325
326        let mut claims = BTreeMap::new();
327        for (k, v) in auth.claims() {
328            claims.insert(k.clone(), v.clone());
329        }
330
331        use std::collections::hash_map::DefaultHasher;
332        use std::hash::{Hash, Hasher};
333
334        let mut hasher = DefaultHasher::new();
335        roles.hash(&mut hasher);
336        serde_json::to_string(&claims)
337            .unwrap_or_default()
338            .hash(&mut hasher);
339
340        let principal = auth
341            .principal_id()
342            .unwrap_or_else(|| "authenticated".to_string());
343
344        Some(format!(
345            "subject:{principal}:scope:{:016x}",
346            hasher.finish()
347        ))
348    }
349
350    fn check_identity_args(
351        function_name: &str,
352        args: &Value,
353        auth: &AuthContext,
354        enforce_scope: bool,
355    ) -> Result<()> {
356        if auth.is_admin() {
357            return Ok(());
358        }
359
360        let Some(obj) = args.as_object() else {
361            if enforce_scope && auth.is_authenticated() {
362                return Err(ForgeError::Forbidden(format!(
363                    "Function '{function_name}' must include identity or tenant scope arguments"
364                )));
365            }
366            return Ok(());
367        };
368
369        let mut principal_values: Vec<String> = Vec::new();
370        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
371            principal_values.push(user_id);
372        }
373        if let Some(subject) = auth.principal_id()
374            && !principal_values.iter().any(|v| v == &subject)
375        {
376            principal_values.push(subject);
377        }
378
379        let mut has_scope_key = false;
380
381        for key in [
382            "user_id",
383            "userId",
384            "owner_id",
385            "ownerId",
386            "owner_subject",
387            "ownerSubject",
388            "subject",
389            "sub",
390            "principal_id",
391            "principalId",
392        ] {
393            let Some(value) = obj.get(key) else {
394                continue;
395            };
396            has_scope_key = true;
397
398            if !auth.is_authenticated() {
399                return Err(ForgeError::Unauthorized(format!(
400                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
401                )));
402            }
403
404            let Value::String(actual) = value else {
405                return Err(ForgeError::InvalidArgument(format!(
406                    "Function '{function_name}' argument '{key}' must be a non-empty string"
407                )));
408            };
409
410            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
411                return Err(ForgeError::Forbidden(format!(
412                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
413                )));
414            }
415        }
416
417        for key in ["tenant_id", "tenantId"] {
418            let Some(value) = obj.get(key) else {
419                continue;
420            };
421            has_scope_key = true;
422
423            if !auth.is_authenticated() {
424                return Err(ForgeError::Unauthorized(format!(
425                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
426                )));
427            }
428
429            let expected = auth
430                .claim("tenant_id")
431                .and_then(|v| v.as_str())
432                .ok_or_else(|| {
433                    ForgeError::Forbidden(format!(
434                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
435                    ))
436                })?;
437
438            let Value::String(actual) = value else {
439                return Err(ForgeError::InvalidArgument(format!(
440                    "Function '{function_name}' argument '{key}' must be a non-empty string"
441                )));
442            };
443
444            if actual.trim().is_empty() || actual != expected {
445                return Err(ForgeError::Forbidden(format!(
446                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
447                )));
448            }
449        }
450
451        if enforce_scope && auth.is_authenticated() && !has_scope_key {
452            return Err(ForgeError::Forbidden(format!(
453                "Function '{function_name}' must include identity or tenant scope arguments"
454            )));
455        }
456
457        Ok(())
458    }
459
460    /// Get the function kind by name.
461    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
462        self.registry.get(function_name).map(|e| e.kind())
463    }
464
465    /// Check if a function exists.
466    pub fn has_function(&self, function_name: &str) -> bool {
467        self.registry.get(function_name).is_some()
468    }
469
470    async fn execute_transactional(
471        &self,
472        handler: &BoxedMutationFn,
473        args: Value,
474        auth: AuthContext,
475        request: RequestMetadata,
476    ) -> Result<RouteResult> {
477        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
478
479        async {
480            let primary = self.db.primary();
481            let tx = primary
482                .begin()
483                .await
484                .map_err(|e| ForgeError::Database(e.to_string()))?;
485
486            let job_dispatcher = self.job_dispatcher.clone();
487            let job_lookup: forge_core::JobInfoLookup =
488                Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
489
490            let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
491                primary.clone(),
492                tx,
493                auth,
494                request,
495                self.http_client.clone(),
496                job_lookup,
497            );
498            if let Some(ref issuer) = self.token_issuer {
499                ctx.set_token_issuer(issuer.clone());
500            }
501
502            match handler(&ctx, args).await {
503                Ok(value) => {
504                    let buffer = {
505                        let guard = outbox.lock().expect("outbox mutex poisoned");
506                        OutboxBuffer {
507                            jobs: guard.jobs.clone(),
508                            workflows: guard.workflows.clone(),
509                        }
510                    };
511
512                    let mut tx = Arc::try_unwrap(tx_handle)
513                        .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
514                        .into_inner();
515
516                    for job in &buffer.jobs {
517                        Self::insert_job(&mut tx, job).await?;
518                    }
519
520                    for workflow in &buffer.workflows {
521                        let version = self
522                            .workflow_dispatcher
523                            .as_ref()
524                            .and_then(|d| d.get_info(&workflow.workflow_name))
525                            .map(|info| info.version)
526                            .ok_or_else(|| {
527                                ForgeError::NotFound(format!(
528                                    "Workflow '{}' not found",
529                                    workflow.workflow_name
530                                ))
531                            })?;
532                        Self::insert_workflow(&mut tx, workflow, version).await?;
533                    }
534
535                    tx.commit()
536                        .await
537                        .map_err(|e| ForgeError::Database(e.to_string()))?;
538
539                    Ok(RouteResult::Mutation(value))
540                }
541                Err(e) => Err(e),
542            }
543        }
544        .instrument(span)
545        .await
546    }
547
548    async fn insert_job(
549        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
550        job: &PendingJob,
551    ) -> Result<()> {
552        let now = Utc::now();
553        sqlx::query(
554            r#"
555            INSERT INTO forge_jobs (
556                id, job_type, input, job_context, status, priority, attempts, max_attempts,
557                worker_capability, owner_subject, scheduled_at, created_at
558            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
559            "#,
560        )
561        .bind(job.id)
562        .bind(&job.job_type)
563        .bind(&job.args)
564        .bind(&job.context)
565        .bind(JobStatus::Pending.as_str())
566        .bind(job.priority)
567        .bind(0i32)
568        .bind(job.max_attempts)
569        .bind(&job.worker_capability)
570        .bind(&job.owner_subject)
571        .bind(now)
572        .bind(now)
573        .execute(&mut **tx)
574        .await
575        .map_err(|e| ForgeError::Database(e.to_string()))?;
576
577        Ok(())
578    }
579
580    async fn insert_workflow(
581        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
582        workflow: &PendingWorkflow,
583        version: u32,
584    ) -> Result<()> {
585        let now = Utc::now();
586        sqlx::query(
587            r#"
588            INSERT INTO forge_workflow_runs (
589                id, workflow_name, version, owner_subject, input, status, current_step,
590                step_results, started_at, trace_id
591            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
592            "#,
593        )
594        .bind(workflow.id)
595        .bind(&workflow.workflow_name)
596        .bind(version as i32)
597        .bind(&workflow.owner_subject)
598        .bind(&workflow.input)
599        .bind(WorkflowStatus::Created.as_str())
600        .bind(Option::<String>::None)
601        .bind(serde_json::json!({}))
602        .bind(now)
603        .bind(workflow.id.to_string())
604        .execute(&mut **tx)
605        .await
606        .map_err(|e| ForgeError::Database(e.to_string()))?;
607
608        Ok(())
609    }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated `user`-role context whose `sub` claim equals `user_id`.
    fn user_auth(user_id: uuid::Uuid) -> AuthContext {
        AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    /// Like [`user_auth`], but also carrying a `tenant_id` claim.
    fn tenant_user_auth(user_id: uuid::Uuid, tenant: &str) -> AuthContext {
        AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String(tenant.to_string()),
                ),
            ]),
        )
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        let _auth = AuthContext::unauthenticated();

        // `check_auth` needs a router (and therefore a Database), so this only
        // sanity-checks that the fixture is marked public.
        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4());
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_user_id() {
        let user_id = uuid::Uuid::new_v4();
        let auth = user_auth(user_id);
        let args = serde_json::json!({ "user_id": user_id.to_string() });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_non_string_value() {
        let auth = user_auth(uuid::Uuid::new_v4());
        let args = serde_json::json!({ "user_id": 42 });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::InvalidArgument(_))));
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result =
            FunctionRouter::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let auth = user_auth(uuid::Uuid::new_v4());

        // enforce_scope=false simulates has_input_args=false
        let result = FunctionRouter::check_identity_args(
            "list_todos",
            &serde_json::Value::Null,
            &auth,
            false,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_tenant_args_allow_matching_tenant() {
        let auth = tenant_user_auth(uuid::Uuid::new_v4(), "tenant-a");
        let args = serde_json::json!({ "tenant_id": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_tenant_args_reject_other_tenant() {
        let auth = tenant_user_auth(uuid::Uuid::new_v4(), "tenant-a");
        let args = serde_json::json!({ "tenantId": "tenant-b" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_tenant_args_reject_when_claim_missing() {
        let auth = user_auth(uuid::Uuid::new_v4());
        let args = serde_json::json!({ "tenant_id": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope, Some("anon".to_string()));
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = tenant_user_auth(user_id, "tenant-a");
        let auth_b = tenant_user_auth(user_id, "tenant-b");

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}