Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22/// Result of routing a function call.
23pub enum RouteResult {
24    /// Query execution result.
25    Query(Value),
26    /// Mutation execution result.
27    Mutation(Value),
28    /// Job dispatch result (returns job_id).
29    Job(Value),
30    /// Workflow dispatch result (returns workflow_id).
31    Workflow(Value),
32}
33
34/// Routes function calls to the appropriate handler.
35pub struct FunctionRouter {
36    registry: Arc<FunctionRegistry>,
37    db: Database,
38    http_client: CircuitBreakerClient,
39    job_dispatcher: Option<Arc<dyn JobDispatch>>,
40    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
41    rate_limiter: HybridRateLimiter,
42    query_cache: QueryCache,
43    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
44}
45
46impl FunctionRouter {
47    /// Create a new function router.
48    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
49        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
50        Self {
51            registry,
52            db,
53            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
54            job_dispatcher: None,
55            workflow_dispatcher: None,
56            rate_limiter,
57            query_cache: QueryCache::new(),
58            token_issuer: None,
59        }
60    }
61
62    /// Create a new function router with a custom HTTP client.
63    pub fn with_http_client(
64        registry: Arc<FunctionRegistry>,
65        db: Database,
66        http_client: CircuitBreakerClient,
67    ) -> Self {
68        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
69        Self {
70            registry,
71            db,
72            http_client,
73            job_dispatcher: None,
74            workflow_dispatcher: None,
75            rate_limiter,
76            query_cache: QueryCache::new(),
77            token_issuer: None,
78        }
79    }
80
81    /// Set the token issuer for this router (enables `ctx.issue_token()` in mutations).
82    pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
83        self.token_issuer = Some(issuer);
84        self
85    }
86
87    /// Set the job dispatcher for this router.
88    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
89        self.job_dispatcher = Some(dispatcher);
90        self
91    }
92
93    /// Set the workflow dispatcher for this router.
94    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
95        self.workflow_dispatcher = Some(dispatcher);
96        self
97    }
98
99    pub async fn route(
100        &self,
101        function_name: &str,
102        args: Value,
103        auth: AuthContext,
104        request: RequestMetadata,
105    ) -> Result<RouteResult> {
106        if let Some(entry) = self.registry.get(function_name) {
107            self.check_auth(entry.info(), &auth)?;
108            self.check_rate_limit(entry.info(), function_name, &auth, &request)
109                .await?;
110            // Skip scope enforcement for functions with no input args.
111            // The JWT still carries identity, accessible via ctx.require_user_id().
112            let enforce = !entry.info().is_public && entry.info().has_input_args;
113            auth.check_identity_args(function_name, &args, enforce)?;
114
115            return match entry {
116                FunctionEntry::Query { handler, info, .. } => {
117                    let pool = if info.consistent {
118                        self.db.primary().clone()
119                    } else {
120                        self.db.read_pool().clone()
121                    };
122
123                    let auth_scope = Self::auth_cache_scope(&auth);
124                    if let Some(ttl) = info.cache_ttl {
125                        if let Some(cached) =
126                            self.query_cache
127                                .get(function_name, &args, auth_scope.as_deref())
128                        {
129                            return Ok(RouteResult::Query(Value::clone(&cached)));
130                        }
131
132                        let ctx = QueryContext::new(pool, auth, request);
133                        let result = handler(&ctx, args.clone()).await?;
134
135                        self.query_cache.set(
136                            function_name,
137                            &args,
138                            auth_scope.as_deref(),
139                            result.clone(),
140                            Duration::from_secs(ttl),
141                        );
142
143                        Ok(RouteResult::Query(result))
144                    } else {
145                        let ctx = QueryContext::new(pool, auth, request);
146                        let result = handler(&ctx, args).await?;
147                        Ok(RouteResult::Query(result))
148                    }
149                }
150                FunctionEntry::Mutation { handler, info } => {
151                    if info.transactional {
152                        self.execute_transactional(handler, args, auth, request)
153                            .await
154                    } else {
155                        // Use primary for mutations
156                        let mut ctx = MutationContext::with_dispatch(
157                            self.db.primary().clone(),
158                            auth,
159                            request,
160                            self.http_client.clone(),
161                            self.job_dispatcher.clone(),
162                            self.workflow_dispatcher.clone(),
163                        );
164                        if let Some(ref issuer) = self.token_issuer {
165                            ctx.set_token_issuer(issuer.clone());
166                        }
167                        let result = handler(&ctx, args).await?;
168                        Ok(RouteResult::Mutation(result))
169                    }
170                }
171            };
172        }
173
174        if let Some(ref job_dispatcher) = self.job_dispatcher
175            && let Some(job_info) = job_dispatcher.get_info(function_name)
176        {
177            self.check_job_auth(&job_info, &auth)?;
178            auth.check_identity_args(function_name, &args, !job_info.is_public)?;
179            match job_dispatcher
180                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
181                .await
182            {
183                Ok(job_id) => {
184                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
185                }
186                Err(ForgeError::NotFound(_)) => {}
187                Err(e) => return Err(e),
188            }
189        }
190
191        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
192            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
193        {
194            self.check_workflow_auth(&workflow_info, &auth)?;
195            auth.check_identity_args(function_name, &args, !workflow_info.is_public)?;
196            match workflow_dispatcher
197                .start_by_name(function_name, args.clone(), auth.principal_id())
198                .await
199            {
200                Ok(workflow_id) => {
201                    return Ok(RouteResult::Workflow(
202                        serde_json::json!({ "workflow_id": workflow_id }),
203                    ));
204                }
205                Err(ForgeError::NotFound(_)) => {}
206                Err(e) => return Err(e),
207            }
208        }
209
210        Err(ForgeError::NotFound(format!(
211            "Function '{}' not found",
212            function_name
213        )))
214    }
215
216    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
217        if info.is_public {
218            return Ok(());
219        }
220
221        if !auth.is_authenticated() {
222            return Err(ForgeError::Unauthorized("Authentication required".into()));
223        }
224
225        if let Some(role) = info.required_role
226            && !auth.has_role(role)
227        {
228            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
229        }
230
231        Ok(())
232    }
233
234    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
235        if info.is_public {
236            return Ok(());
237        }
238
239        if !auth.is_authenticated() {
240            return Err(ForgeError::Unauthorized("Authentication required".into()));
241        }
242
243        if let Some(role) = info.required_role
244            && !auth.has_role(role)
245        {
246            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
247        }
248
249        Ok(())
250    }
251
252    fn check_workflow_auth(
253        &self,
254        info: &forge_core::workflow::WorkflowInfo,
255        auth: &AuthContext,
256    ) -> Result<()> {
257        if info.is_public {
258            return Ok(());
259        }
260
261        if !auth.is_authenticated() {
262            return Err(ForgeError::Unauthorized("Authentication required".into()));
263        }
264
265        if let Some(role) = info.required_role
266            && !auth.has_role(role)
267        {
268            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
269        }
270
271        Ok(())
272    }
273
274    /// Check rate limit for a function call.
275    async fn check_rate_limit(
276        &self,
277        info: &FunctionInfo,
278        function_name: &str,
279        auth: &AuthContext,
280        request: &RequestMetadata,
281    ) -> Result<()> {
282        // Skip if no rate limit configured
283        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
284            (Some(r), Some(p)) => (r, p),
285            _ => return Ok(()),
286        };
287
288        // Build rate limit config
289        let key_str = info.rate_limit_key.unwrap_or("user");
290        let key_type: RateLimitKey = match key_str.parse() {
291            Ok(k) => k,
292            Err(_) => {
293                tracing::warn!(
294                    function = %function_name,
295                    key = %key_str,
296                    "Invalid rate limit key, falling back to 'user'"
297                );
298                RateLimitKey::default()
299            }
300        };
301
302        let config =
303            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
304
305        // Build bucket key
306        let bucket_key = self
307            .rate_limiter
308            .build_key(key_type, function_name, auth, request);
309
310        // Enforce rate limit
311        self.rate_limiter.enforce(&bucket_key, &config).await?;
312
313        Ok(())
314    }
315
316    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
317        if !auth.is_authenticated() {
318            return Some("anon".to_string());
319        }
320
321        // Include role + claims fingerprint to avoid cross-scope cache bleed.
322        let mut roles = auth.roles().to_vec();
323        roles.sort();
324        roles.dedup();
325
326        let mut claims = BTreeMap::new();
327        for (k, v) in auth.claims() {
328            claims.insert(k.clone(), v.clone());
329        }
330
331        use std::collections::hash_map::DefaultHasher;
332        use std::hash::{Hash, Hasher};
333
334        let mut hasher = DefaultHasher::new();
335        roles.hash(&mut hasher);
336        serde_json::to_string(&claims)
337            .unwrap_or_default()
338            .hash(&mut hasher);
339
340        let principal = auth
341            .principal_id()
342            .unwrap_or_else(|| "authenticated".to_string());
343
344        Some(format!(
345            "subject:{principal}:scope:{:016x}",
346            hasher.finish()
347        ))
348    }
349
350    /// Get the function kind by name.
351    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
352        self.registry.get(function_name).map(|e| e.kind())
353    }
354
355    /// Check if a function exists.
356    pub fn has_function(&self, function_name: &str) -> bool {
357        self.registry.get(function_name).is_some()
358    }
359
360    async fn execute_transactional(
361        &self,
362        handler: &BoxedMutationFn,
363        args: Value,
364        auth: AuthContext,
365        request: RequestMetadata,
366    ) -> Result<RouteResult> {
367        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
368
369        async {
370            let primary = self.db.primary();
371            let tx = primary
372                .begin()
373                .await
374                .map_err(|e| ForgeError::Database(e.to_string()))?;
375
376            let job_dispatcher = self.job_dispatcher.clone();
377            let job_lookup: forge_core::JobInfoLookup =
378                Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
379
380            let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
381                primary.clone(),
382                tx,
383                auth,
384                request,
385                self.http_client.clone(),
386                job_lookup,
387            );
388            if let Some(ref issuer) = self.token_issuer {
389                ctx.set_token_issuer(issuer.clone());
390            }
391
392            match handler(&ctx, args).await {
393                Ok(value) => {
394                    let buffer = {
395                        let guard = outbox.lock().expect("outbox mutex poisoned");
396                        OutboxBuffer {
397                            jobs: guard.jobs.clone(),
398                            workflows: guard.workflows.clone(),
399                        }
400                    };
401
402                    let mut tx = Arc::try_unwrap(tx_handle)
403                        .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
404                        .into_inner();
405
406                    for job in &buffer.jobs {
407                        Self::insert_job(&mut tx, job).await?;
408                    }
409
410                    for workflow in &buffer.workflows {
411                        let version = self
412                            .workflow_dispatcher
413                            .as_ref()
414                            .and_then(|d| d.get_info(&workflow.workflow_name))
415                            .map(|info| info.version)
416                            .ok_or_else(|| {
417                                ForgeError::NotFound(format!(
418                                    "Workflow '{}' not found",
419                                    workflow.workflow_name
420                                ))
421                            })?;
422                        Self::insert_workflow(&mut tx, workflow, version).await?;
423                    }
424
425                    tx.commit()
426                        .await
427                        .map_err(|e| ForgeError::Database(e.to_string()))?;
428
429                    Ok(RouteResult::Mutation(value))
430                }
431                Err(e) => Err(e),
432            }
433        }
434        .instrument(span)
435        .await
436    }
437
438    async fn insert_job(
439        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
440        job: &PendingJob,
441    ) -> Result<()> {
442        let now = Utc::now();
443        sqlx::query(
444            r#"
445            INSERT INTO forge_jobs (
446                id, job_type, input, job_context, status, priority, attempts, max_attempts,
447                worker_capability, owner_subject, scheduled_at, created_at
448            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
449            "#,
450        )
451        .bind(job.id)
452        .bind(&job.job_type)
453        .bind(&job.args)
454        .bind(&job.context)
455        .bind(JobStatus::Pending.as_str())
456        .bind(job.priority)
457        .bind(0i32)
458        .bind(job.max_attempts)
459        .bind(&job.worker_capability)
460        .bind(&job.owner_subject)
461        .bind(now)
462        .bind(now)
463        .execute(&mut **tx)
464        .await
465        .map_err(|e| ForgeError::Database(e.to_string()))?;
466
467        Ok(())
468    }
469
470    async fn insert_workflow(
471        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
472        workflow: &PendingWorkflow,
473        version: u32,
474    ) -> Result<()> {
475        let now = Utc::now();
476        sqlx::query(
477            r#"
478            INSERT INTO forge_workflow_runs (
479                id, workflow_name, version, owner_subject, input, status, current_step,
480                step_results, started_at, trace_id
481            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
482            "#,
483        )
484        .bind(workflow.id)
485        .bind(&workflow.workflow_name)
486        .bind(version as i32)
487        .bind(&workflow.owner_subject)
488        .bind(&workflow.input)
489        .bind(WorkflowStatus::Created.as_str())
490        .bind(Option::<String>::None)
491        .bind(serde_json::json!({}))
492        .bind(now)
493        .bind(workflow.id.to_string())
494        .execute(&mut **tx)
495        .await
496        .map_err(|e| ForgeError::Database(e.to_string()))?;
497
498        Ok(())
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated context for `user_id` whose `sub` claim is the user id, with the
    /// given roles plus any extra string-valued claims.
    fn user_auth(
        user_id: uuid::Uuid,
        roles: &[&str],
        extra_claims: &[(&str, &str)],
    ) -> AuthContext {
        let mut claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        )]);
        claims.extend(extra_claims.iter().map(|(key, value)| {
            (key.to_string(), serde_json::Value::String(value.to_string()))
        }));
        AuthContext::authenticated(
            user_id,
            roles.iter().map(|role| role.to_string()).collect(),
            claims,
        )
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        let _auth = AuthContext::unauthenticated();

        // Can't test check_auth directly without a router instance,
        // but we can test the logic
        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &["user"], &[]);
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = user_auth(uuid::Uuid::new_v4(), &["user"], &[]);

        let result = auth.check_identity_args("list_orders", &serde_json::json!({}), true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let auth = user_auth(uuid::Uuid::new_v4(), &["user"], &[]);

        // enforce_scope=false simulates has_input_args=false
        let result = auth.check_identity_args("list_todos", &serde_json::Value::Null, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = user_auth(user_id, &["user"], &[("tenant_id", "tenant-a")]);
        let auth_b = user_auth(user_id, &["user"], &[("tenant_id", "tenant-b")]);

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope.as_deref(), Some("anon"));
    }

    #[test]
    fn test_auth_cache_scope_is_stable_and_ignores_role_order() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = user_auth(user_id, &["admin", "user"], &[("tenant_id", "t")]);
        let auth_b = user_auth(user_id, &["user", "admin"], &[("tenant_id", "t")]);

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        assert_eq!(scope_a, FunctionRouter::auth_cache_scope(&auth_a));
        assert_eq!(scope_a, FunctionRouter::auth_cache_scope(&auth_b));
    }

    #[test]
    fn test_auth_cache_scope_differs_per_user() {
        let auth_a = user_auth(uuid::Uuid::new_v4(), &["user"], &[]);
        let auth_b = user_auth(uuid::Uuid::new_v4(), &["user"], &[]);

        assert_ne!(
            FunctionRouter::auth_cache_scope(&auth_a),
            FunctionRouter::auth_cache_scope(&auth_b)
        );
    }
}