Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22/// Result of routing a function call.
23pub enum RouteResult {
24    /// Query execution result.
25    Query(Value),
26    /// Mutation execution result.
27    Mutation(Value),
28    /// Job dispatch result (returns job_id).
29    Job(Value),
30    /// Workflow dispatch result (returns workflow_id).
31    Workflow(Value),
32}
33
34/// Routes function calls to the appropriate handler.
35pub struct FunctionRouter {
36    registry: Arc<FunctionRegistry>,
37    db: Database,
38    http_client: CircuitBreakerClient,
39    job_dispatcher: Option<Arc<dyn JobDispatch>>,
40    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
41    rate_limiter: HybridRateLimiter,
42    query_cache: QueryCache,
43}
44
45impl FunctionRouter {
46    /// Create a new function router.
47    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
48        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
49        Self {
50            registry,
51            db,
52            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
53            job_dispatcher: None,
54            workflow_dispatcher: None,
55            rate_limiter,
56            query_cache: QueryCache::new(),
57        }
58    }
59
60    /// Create a new function router with a custom HTTP client.
61    pub fn with_http_client(
62        registry: Arc<FunctionRegistry>,
63        db: Database,
64        http_client: CircuitBreakerClient,
65    ) -> Self {
66        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67        Self {
68            registry,
69            db,
70            http_client,
71            job_dispatcher: None,
72            workflow_dispatcher: None,
73            rate_limiter,
74            query_cache: QueryCache::new(),
75        }
76    }
77
78    /// Set the job dispatcher for this router.
79    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
80        self.job_dispatcher = Some(dispatcher);
81        self
82    }
83
84    /// Set the workflow dispatcher for this router.
85    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
86        self.workflow_dispatcher = Some(dispatcher);
87        self
88    }
89
90    pub async fn route(
91        &self,
92        function_name: &str,
93        args: Value,
94        auth: AuthContext,
95        request: RequestMetadata,
96    ) -> Result<RouteResult> {
97        if let Some(entry) = self.registry.get(function_name) {
98            self.check_auth(entry.info(), &auth)?;
99            self.check_rate_limit(entry.info(), function_name, &auth, &request)
100                .await?;
101            Self::check_identity_args(function_name, &args, &auth, !entry.info().is_public)?;
102
103            return match entry {
104                FunctionEntry::Query { handler, info, .. } => {
105                    let pool = if info.consistent {
106                        self.db.primary().clone()
107                    } else {
108                        self.db.read_pool().clone()
109                    };
110
111                    let auth_scope = Self::auth_cache_scope(&auth);
112                    if let Some(ttl) = info.cache_ttl {
113                        if let Some(cached) =
114                            self.query_cache
115                                .get(function_name, &args, auth_scope.as_deref())
116                        {
117                            return Ok(RouteResult::Query(Value::clone(&cached)));
118                        }
119
120                        let ctx = QueryContext::new(pool, auth, request);
121                        let result = handler(&ctx, args.clone()).await?;
122
123                        self.query_cache.set(
124                            function_name,
125                            &args,
126                            auth_scope.as_deref(),
127                            result.clone(),
128                            Duration::from_secs(ttl),
129                        );
130
131                        Ok(RouteResult::Query(result))
132                    } else {
133                        let ctx = QueryContext::new(pool, auth, request);
134                        let result = handler(&ctx, args).await?;
135                        Ok(RouteResult::Query(result))
136                    }
137                }
138                FunctionEntry::Mutation { handler, info } => {
139                    if info.transactional {
140                        self.execute_transactional(handler, args, auth, request)
141                            .await
142                    } else {
143                        // Use primary for mutations
144                        let ctx = MutationContext::with_dispatch(
145                            self.db.primary().clone(),
146                            auth,
147                            request,
148                            self.http_client.clone(),
149                            self.job_dispatcher.clone(),
150                            self.workflow_dispatcher.clone(),
151                        );
152                        let result = handler(&ctx, args).await?;
153                        Ok(RouteResult::Mutation(result))
154                    }
155                }
156            };
157        }
158
159        if let Some(ref job_dispatcher) = self.job_dispatcher
160            && let Some(job_info) = job_dispatcher.get_info(function_name)
161        {
162            self.check_job_auth(&job_info, &auth)?;
163            Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
164            match job_dispatcher
165                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
166                .await
167            {
168                Ok(job_id) => {
169                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
170                }
171                Err(ForgeError::NotFound(_)) => {}
172                Err(e) => return Err(e),
173            }
174        }
175
176        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
177            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
178        {
179            self.check_workflow_auth(&workflow_info, &auth)?;
180            Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
181            match workflow_dispatcher
182                .start_by_name(function_name, args.clone(), auth.principal_id())
183                .await
184            {
185                Ok(workflow_id) => {
186                    return Ok(RouteResult::Workflow(
187                        serde_json::json!({ "workflow_id": workflow_id }),
188                    ));
189                }
190                Err(ForgeError::NotFound(_)) => {}
191                Err(e) => return Err(e),
192            }
193        }
194
195        Err(ForgeError::NotFound(format!(
196            "Function '{}' not found",
197            function_name
198        )))
199    }
200
201    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
202        if info.is_public {
203            return Ok(());
204        }
205
206        if !auth.is_authenticated() {
207            return Err(ForgeError::Unauthorized("Authentication required".into()));
208        }
209
210        if let Some(role) = info.required_role
211            && !auth.has_role(role)
212        {
213            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
214        }
215
216        Ok(())
217    }
218
219    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
220        if info.is_public {
221            return Ok(());
222        }
223
224        if !auth.is_authenticated() {
225            return Err(ForgeError::Unauthorized("Authentication required".into()));
226        }
227
228        if let Some(role) = info.required_role
229            && !auth.has_role(role)
230        {
231            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
232        }
233
234        Ok(())
235    }
236
237    fn check_workflow_auth(
238        &self,
239        info: &forge_core::workflow::WorkflowInfo,
240        auth: &AuthContext,
241    ) -> Result<()> {
242        if info.is_public {
243            return Ok(());
244        }
245
246        if !auth.is_authenticated() {
247            return Err(ForgeError::Unauthorized("Authentication required".into()));
248        }
249
250        if let Some(role) = info.required_role
251            && !auth.has_role(role)
252        {
253            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
254        }
255
256        Ok(())
257    }
258
259    /// Check rate limit for a function call.
260    async fn check_rate_limit(
261        &self,
262        info: &FunctionInfo,
263        function_name: &str,
264        auth: &AuthContext,
265        request: &RequestMetadata,
266    ) -> Result<()> {
267        // Skip if no rate limit configured
268        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
269            (Some(r), Some(p)) => (r, p),
270            _ => return Ok(()),
271        };
272
273        // Build rate limit config
274        let key_str = info.rate_limit_key.unwrap_or("user");
275        let key_type: RateLimitKey = match key_str.parse() {
276            Ok(k) => k,
277            Err(_) => {
278                tracing::warn!(
279                    function = %function_name,
280                    key = %key_str,
281                    "Invalid rate limit key, falling back to 'user'"
282                );
283                RateLimitKey::default()
284            }
285        };
286
287        let config =
288            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
289
290        // Build bucket key
291        let bucket_key = self
292            .rate_limiter
293            .build_key(key_type, function_name, auth, request);
294
295        // Enforce rate limit
296        self.rate_limiter.enforce(&bucket_key, &config).await?;
297
298        Ok(())
299    }
300
301    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
302        if !auth.is_authenticated() {
303            return Some("anon".to_string());
304        }
305
306        // Include role + claims fingerprint to avoid cross-scope cache bleed.
307        let mut roles = auth.roles().to_vec();
308        roles.sort();
309        roles.dedup();
310
311        let mut claims = BTreeMap::new();
312        for (k, v) in auth.claims() {
313            claims.insert(k.clone(), v.clone());
314        }
315
316        use std::collections::hash_map::DefaultHasher;
317        use std::hash::{Hash, Hasher};
318
319        let mut hasher = DefaultHasher::new();
320        roles.hash(&mut hasher);
321        serde_json::to_string(&claims)
322            .unwrap_or_default()
323            .hash(&mut hasher);
324
325        let principal = auth
326            .principal_id()
327            .unwrap_or_else(|| "authenticated".to_string());
328
329        Some(format!(
330            "subject:{principal}:scope:{:016x}",
331            hasher.finish()
332        ))
333    }
334
335    fn check_identity_args(
336        function_name: &str,
337        args: &Value,
338        auth: &AuthContext,
339        enforce_scope: bool,
340    ) -> Result<()> {
341        if auth.is_admin() {
342            return Ok(());
343        }
344
345        let Some(obj) = args.as_object() else {
346            if enforce_scope && auth.is_authenticated() {
347                return Err(ForgeError::Forbidden(format!(
348                    "Function '{function_name}' must include identity or tenant scope arguments"
349                )));
350            }
351            return Ok(());
352        };
353
354        let mut principal_values: Vec<String> = Vec::new();
355        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
356            principal_values.push(user_id);
357        }
358        if let Some(subject) = auth.principal_id()
359            && !principal_values.iter().any(|v| v == &subject)
360        {
361            principal_values.push(subject);
362        }
363
364        let mut has_scope_key = false;
365
366        for key in [
367            "user_id",
368            "userId",
369            "owner_id",
370            "ownerId",
371            "owner_subject",
372            "ownerSubject",
373            "subject",
374            "sub",
375            "principal_id",
376            "principalId",
377        ] {
378            let Some(value) = obj.get(key) else {
379                continue;
380            };
381            has_scope_key = true;
382
383            if !auth.is_authenticated() {
384                return Err(ForgeError::Unauthorized(format!(
385                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
386                )));
387            }
388
389            let Value::String(actual) = value else {
390                return Err(ForgeError::InvalidArgument(format!(
391                    "Function '{function_name}' argument '{key}' must be a non-empty string"
392                )));
393            };
394
395            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
396                return Err(ForgeError::Forbidden(format!(
397                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
398                )));
399            }
400        }
401
402        for key in ["tenant_id", "tenantId"] {
403            let Some(value) = obj.get(key) else {
404                continue;
405            };
406            has_scope_key = true;
407
408            if !auth.is_authenticated() {
409                return Err(ForgeError::Unauthorized(format!(
410                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
411                )));
412            }
413
414            let expected = auth
415                .claim("tenant_id")
416                .and_then(|v| v.as_str())
417                .ok_or_else(|| {
418                    ForgeError::Forbidden(format!(
419                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
420                    ))
421                })?;
422
423            let Value::String(actual) = value else {
424                return Err(ForgeError::InvalidArgument(format!(
425                    "Function '{function_name}' argument '{key}' must be a non-empty string"
426                )));
427            };
428
429            if actual.trim().is_empty() || actual != expected {
430                return Err(ForgeError::Forbidden(format!(
431                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
432                )));
433            }
434        }
435
436        if enforce_scope && auth.is_authenticated() && !has_scope_key {
437            return Err(ForgeError::Forbidden(format!(
438                "Function '{function_name}' must include identity or tenant scope arguments"
439            )));
440        }
441
442        Ok(())
443    }
444
445    /// Get the function kind by name.
446    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
447        self.registry.get(function_name).map(|e| e.kind())
448    }
449
450    /// Check if a function exists.
451    pub fn has_function(&self, function_name: &str) -> bool {
452        self.registry.get(function_name).is_some()
453    }
454
455    async fn execute_transactional(
456        &self,
457        handler: &BoxedMutationFn,
458        args: Value,
459        auth: AuthContext,
460        request: RequestMetadata,
461    ) -> Result<RouteResult> {
462        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
463
464        async {
465            let primary = self.db.primary();
466            let tx = primary
467                .begin()
468                .await
469                .map_err(|e| ForgeError::Database(e.to_string()))?;
470
471            let job_dispatcher = self.job_dispatcher.clone();
472            let job_lookup: forge_core::JobInfoLookup =
473                Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
474
475            let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
476                primary.clone(),
477                tx,
478                auth,
479                request,
480                self.http_client.clone(),
481                job_lookup,
482            );
483
484            match handler(&ctx, args).await {
485                Ok(value) => {
486                    let buffer = {
487                        let guard = outbox.lock().expect("outbox mutex poisoned");
488                        OutboxBuffer {
489                            jobs: guard.jobs.clone(),
490                            workflows: guard.workflows.clone(),
491                        }
492                    };
493
494                    let mut tx = Arc::try_unwrap(tx_handle)
495                        .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
496                        .into_inner();
497
498                    for job in &buffer.jobs {
499                        Self::insert_job(&mut tx, job).await?;
500                    }
501
502                    for workflow in &buffer.workflows {
503                        let version = self
504                            .workflow_dispatcher
505                            .as_ref()
506                            .and_then(|d| d.get_info(&workflow.workflow_name))
507                            .map(|info| info.version)
508                            .ok_or_else(|| {
509                                ForgeError::NotFound(format!(
510                                    "Workflow '{}' not found",
511                                    workflow.workflow_name
512                                ))
513                            })?;
514                        Self::insert_workflow(&mut tx, workflow, version).await?;
515                    }
516
517                    tx.commit()
518                        .await
519                        .map_err(|e| ForgeError::Database(e.to_string()))?;
520
521                    Ok(RouteResult::Mutation(value))
522                }
523                Err(e) => Err(e),
524            }
525        }
526        .instrument(span)
527        .await
528    }
529
530    async fn insert_job(
531        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
532        job: &PendingJob,
533    ) -> Result<()> {
534        let now = Utc::now();
535        sqlx::query(
536            r#"
537            INSERT INTO forge_jobs (
538                id, job_type, input, job_context, status, priority, attempts, max_attempts,
539                worker_capability, owner_subject, scheduled_at, created_at
540            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
541            "#,
542        )
543        .bind(job.id)
544        .bind(&job.job_type)
545        .bind(&job.args)
546        .bind(&job.context)
547        .bind(JobStatus::Pending.as_str())
548        .bind(job.priority)
549        .bind(0i32)
550        .bind(job.max_attempts)
551        .bind(&job.worker_capability)
552        .bind(&job.owner_subject)
553        .bind(now)
554        .bind(now)
555        .execute(&mut **tx)
556        .await
557        .map_err(|e| ForgeError::Database(e.to_string()))?;
558
559        Ok(())
560    }
561
562    async fn insert_workflow(
563        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
564        workflow: &PendingWorkflow,
565        version: u32,
566    ) -> Result<()> {
567        let now = Utc::now();
568        sqlx::query(
569            r#"
570            INSERT INTO forge_workflow_runs (
571                id, workflow_name, version, owner_subject, input, status, current_step,
572                step_results, started_at, trace_id
573            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
574            "#,
575        )
576        .bind(workflow.id)
577        .bind(&workflow.workflow_name)
578        .bind(version as i32)
579        .bind(&workflow.owner_subject)
580        .bind(&workflow.input)
581        .bind(WorkflowStatus::Created.as_str())
582        .bind(Option::<String>::None)
583        .bind(serde_json::json!({}))
584        .bind(now)
585        .bind(workflow.id.to_string())
586        .execute(&mut **tx)
587        .await
588        .map_err(|e| ForgeError::Database(e.to_string()))?;
589
590        Ok(())
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated context for `user_id` with the `user` role, a `sub`
    /// claim equal to `user_id`, and any additional string claims.
    fn user_auth(user_id: uuid::Uuid, extra_claims: &[(&str, &str)]) -> AuthContext {
        let mut claims: HashMap<String, serde_json::Value> = HashMap::new();
        claims.insert(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        );
        for &(name, value) in extra_claims {
            claims.insert(
                name.to_string(),
                serde_json::Value::String(value.to_string()),
            );
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
        };

        let _auth = AuthContext::unauthenticated();

        // check_auth needs a router (and therefore a database), so only the
        // public flag itself is verified here.
        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let someone_else = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let outcome =
            FunctionRouter::check_identity_args("list_orders", &someone_else, &auth, true);
        assert!(matches!(outcome, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let subject = "user_123";
        let claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(subject.to_string()),
        )]);
        let auth = AuthContext::authenticated_without_uuid(vec!["user".to_string()], claims);
        let own_subject = serde_json::json!({ "subject": subject });

        let outcome =
            FunctionRouter::check_identity_args("list_orders", &own_subject, &auth, true);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let anonymous = AuthContext::unauthenticated();
        let scoped = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let outcome = FunctionRouter::check_identity_args("list_orders", &scoped, &anonymous, true);
        assert!(matches!(outcome, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let unscoped = serde_json::json!({});

        let outcome = FunctionRouter::check_identity_args("list_orders", &unscoped, &auth, true);
        assert!(matches!(outcome, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let tenant_a = user_auth(user_id, &[("tenant_id", "tenant-a")]);
        let tenant_b = user_auth(user_id, &[("tenant_id", "tenant-b")]);

        assert_ne!(
            FunctionRouter::auth_cache_scope(&tenant_a),
            FunctionRouter::auth_cache_scope(&tenant_b)
        );
    }
}