Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15
16use super::cache::QueryCache;
17use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
18use crate::db::Database;
19use crate::rate_limit::HybridRateLimiter;
20
21/// Result of routing a function call.
22pub enum RouteResult {
23    /// Query execution result.
24    Query(Value),
25    /// Mutation execution result.
26    Mutation(Value),
27    /// Job dispatch result (returns job_id).
28    Job(Value),
29    /// Workflow dispatch result (returns workflow_id).
30    Workflow(Value),
31}
32
/// Routes function calls to the appropriate handler.
///
/// Names are resolved against the [`FunctionRegistry`] (queries and mutations)
/// first, then against the optional job and workflow dispatchers.
pub struct FunctionRouter {
    /// Registered query and mutation functions.
    registry: Arc<FunctionRegistry>,
    /// Database handle supplying the primary and read pools.
    db: Database,
    /// HTTP client passed into mutation contexts.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher; when `None`, job names are never routed and
    /// transactional mutations see no job info.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher; when `None`, workflow names are never routed.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Enforces per-function rate limits configured on `FunctionInfo`.
    rate_limiter: HybridRateLimiter,
    /// Stores results of queries that declare a `cache_ttl`.
    query_cache: QueryCache,
}
43
44impl FunctionRouter {
45    /// Create a new function router.
46    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
47        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
48        Self {
49            registry,
50            db,
51            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
52            job_dispatcher: None,
53            workflow_dispatcher: None,
54            rate_limiter,
55            query_cache: QueryCache::new(),
56        }
57    }
58
59    /// Create a new function router with a custom HTTP client.
60    pub fn with_http_client(
61        registry: Arc<FunctionRegistry>,
62        db: Database,
63        http_client: CircuitBreakerClient,
64    ) -> Self {
65        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
66        Self {
67            registry,
68            db,
69            http_client,
70            job_dispatcher: None,
71            workflow_dispatcher: None,
72            rate_limiter,
73            query_cache: QueryCache::new(),
74        }
75    }
76
77    /// Set the job dispatcher for this router.
78    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
79        self.job_dispatcher = Some(dispatcher);
80        self
81    }
82
83    /// Set the workflow dispatcher for this router.
84    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
85        self.workflow_dispatcher = Some(dispatcher);
86        self
87    }
88
89    pub async fn route(
90        &self,
91        function_name: &str,
92        args: Value,
93        auth: AuthContext,
94        request: RequestMetadata,
95    ) -> Result<RouteResult> {
96        if let Some(entry) = self.registry.get(function_name) {
97            self.check_auth(entry.info(), &auth)?;
98            self.check_rate_limit(entry.info(), function_name, &auth, &request)
99                .await?;
100            Self::check_identity_args(function_name, &args, &auth, !entry.info().is_public)?;
101
102            return match entry {
103                FunctionEntry::Query { handler, info, .. } => {
104                    let pool = if info.consistent {
105                        self.db.primary().clone()
106                    } else {
107                        self.db.read_pool().clone()
108                    };
109
110                    let auth_scope = Self::auth_cache_scope(&auth);
111                    if let Some(ttl) = info.cache_ttl {
112                        if let Some(cached) =
113                            self.query_cache
114                                .get(function_name, &args, auth_scope.as_deref())
115                        {
116                            return Ok(RouteResult::Query(Value::clone(&cached)));
117                        }
118
119                        let ctx = QueryContext::new(pool, auth, request);
120                        let result = handler(&ctx, args.clone()).await?;
121
122                        self.query_cache.set(
123                            function_name,
124                            &args,
125                            auth_scope.as_deref(),
126                            result.clone(),
127                            Duration::from_secs(ttl),
128                        );
129
130                        Ok(RouteResult::Query(result))
131                    } else {
132                        let ctx = QueryContext::new(pool, auth, request);
133                        let result = handler(&ctx, args).await?;
134                        Ok(RouteResult::Query(result))
135                    }
136                }
137                FunctionEntry::Mutation { handler, info } => {
138                    if info.transactional {
139                        self.execute_transactional(handler, args, auth, request)
140                            .await
141                    } else {
142                        // Use primary for mutations
143                        let ctx = MutationContext::with_dispatch(
144                            self.db.primary().clone(),
145                            auth,
146                            request,
147                            self.http_client.clone(),
148                            self.job_dispatcher.clone(),
149                            self.workflow_dispatcher.clone(),
150                        );
151                        let result = handler(&ctx, args).await?;
152                        Ok(RouteResult::Mutation(result))
153                    }
154                }
155            };
156        }
157
158        if let Some(ref job_dispatcher) = self.job_dispatcher
159            && let Some(job_info) = job_dispatcher.get_info(function_name)
160        {
161            self.check_job_auth(&job_info, &auth)?;
162            Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
163            match job_dispatcher
164                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
165                .await
166            {
167                Ok(job_id) => {
168                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
169                }
170                Err(ForgeError::NotFound(_)) => {}
171                Err(e) => return Err(e),
172            }
173        }
174
175        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
176            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
177        {
178            self.check_workflow_auth(&workflow_info, &auth)?;
179            Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
180            match workflow_dispatcher
181                .start_by_name(function_name, args.clone(), auth.principal_id())
182                .await
183            {
184                Ok(workflow_id) => {
185                    return Ok(RouteResult::Workflow(
186                        serde_json::json!({ "workflow_id": workflow_id }),
187                    ));
188                }
189                Err(ForgeError::NotFound(_)) => {}
190                Err(e) => return Err(e),
191            }
192        }
193
194        Err(ForgeError::NotFound(format!(
195            "Function '{}' not found",
196            function_name
197        )))
198    }
199
200    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
201        if info.is_public {
202            return Ok(());
203        }
204
205        if !auth.is_authenticated() {
206            return Err(ForgeError::Unauthorized("Authentication required".into()));
207        }
208
209        if let Some(role) = info.required_role
210            && !auth.has_role(role)
211        {
212            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
213        }
214
215        Ok(())
216    }
217
218    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
219        if info.is_public {
220            return Ok(());
221        }
222
223        if !auth.is_authenticated() {
224            return Err(ForgeError::Unauthorized("Authentication required".into()));
225        }
226
227        if let Some(role) = info.required_role
228            && !auth.has_role(role)
229        {
230            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
231        }
232
233        Ok(())
234    }
235
236    fn check_workflow_auth(
237        &self,
238        info: &forge_core::workflow::WorkflowInfo,
239        auth: &AuthContext,
240    ) -> Result<()> {
241        if info.is_public {
242            return Ok(());
243        }
244
245        if !auth.is_authenticated() {
246            return Err(ForgeError::Unauthorized("Authentication required".into()));
247        }
248
249        if let Some(role) = info.required_role
250            && !auth.has_role(role)
251        {
252            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
253        }
254
255        Ok(())
256    }
257
258    /// Check rate limit for a function call.
259    async fn check_rate_limit(
260        &self,
261        info: &FunctionInfo,
262        function_name: &str,
263        auth: &AuthContext,
264        request: &RequestMetadata,
265    ) -> Result<()> {
266        // Skip if no rate limit configured
267        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
268            (Some(r), Some(p)) => (r, p),
269            _ => return Ok(()),
270        };
271
272        // Build rate limit config
273        let key_str = info.rate_limit_key.unwrap_or("user");
274        let key_type: RateLimitKey = match key_str.parse() {
275            Ok(k) => k,
276            Err(_) => {
277                tracing::warn!(
278                    function = %function_name,
279                    key = %key_str,
280                    "Invalid rate limit key, falling back to 'user'"
281                );
282                RateLimitKey::default()
283            }
284        };
285
286        let config =
287            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
288
289        // Build bucket key
290        let bucket_key = self
291            .rate_limiter
292            .build_key(key_type, function_name, auth, request);
293
294        // Enforce rate limit
295        self.rate_limiter.enforce(&bucket_key, &config).await?;
296
297        Ok(())
298    }
299
300    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
301        if !auth.is_authenticated() {
302            return Some("anon".to_string());
303        }
304
305        // Include role + claims fingerprint to avoid cross-scope cache bleed.
306        let mut roles = auth.roles().to_vec();
307        roles.sort();
308        roles.dedup();
309
310        let mut claims = BTreeMap::new();
311        for (k, v) in auth.claims() {
312            claims.insert(k.clone(), v.clone());
313        }
314
315        use std::collections::hash_map::DefaultHasher;
316        use std::hash::{Hash, Hasher};
317
318        let mut hasher = DefaultHasher::new();
319        roles.hash(&mut hasher);
320        serde_json::to_string(&claims)
321            .unwrap_or_default()
322            .hash(&mut hasher);
323
324        let principal = auth
325            .principal_id()
326            .unwrap_or_else(|| "authenticated".to_string());
327
328        Some(format!(
329            "subject:{principal}:scope:{:016x}",
330            hasher.finish()
331        ))
332    }
333
334    fn check_identity_args(
335        function_name: &str,
336        args: &Value,
337        auth: &AuthContext,
338        enforce_scope: bool,
339    ) -> Result<()> {
340        if auth.is_admin() {
341            return Ok(());
342        }
343
344        let Some(obj) = args.as_object() else {
345            if enforce_scope && auth.is_authenticated() {
346                return Err(ForgeError::Forbidden(format!(
347                    "Function '{function_name}' must include identity or tenant scope arguments"
348                )));
349            }
350            return Ok(());
351        };
352
353        let mut principal_values: Vec<String> = Vec::new();
354        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
355            principal_values.push(user_id);
356        }
357        if let Some(subject) = auth.principal_id()
358            && !principal_values.iter().any(|v| v == &subject)
359        {
360            principal_values.push(subject);
361        }
362
363        let mut has_scope_key = false;
364
365        for key in [
366            "user_id",
367            "userId",
368            "owner_id",
369            "ownerId",
370            "owner_subject",
371            "ownerSubject",
372            "subject",
373            "sub",
374            "principal_id",
375            "principalId",
376        ] {
377            let Some(value) = obj.get(key) else {
378                continue;
379            };
380            has_scope_key = true;
381
382            if !auth.is_authenticated() {
383                return Err(ForgeError::Unauthorized(format!(
384                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
385                )));
386            }
387
388            let Value::String(actual) = value else {
389                return Err(ForgeError::InvalidArgument(format!(
390                    "Function '{function_name}' argument '{key}' must be a non-empty string"
391                )));
392            };
393
394            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
395                return Err(ForgeError::Forbidden(format!(
396                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
397                )));
398            }
399        }
400
401        for key in ["tenant_id", "tenantId"] {
402            let Some(value) = obj.get(key) else {
403                continue;
404            };
405            has_scope_key = true;
406
407            if !auth.is_authenticated() {
408                return Err(ForgeError::Unauthorized(format!(
409                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
410                )));
411            }
412
413            let expected = auth
414                .claim("tenant_id")
415                .and_then(|v| v.as_str())
416                .ok_or_else(|| {
417                    ForgeError::Forbidden(format!(
418                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
419                    ))
420                })?;
421
422            let Value::String(actual) = value else {
423                return Err(ForgeError::InvalidArgument(format!(
424                    "Function '{function_name}' argument '{key}' must be a non-empty string"
425                )));
426            };
427
428            if actual.trim().is_empty() || actual != expected {
429                return Err(ForgeError::Forbidden(format!(
430                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
431                )));
432            }
433        }
434
435        if enforce_scope && auth.is_authenticated() && !has_scope_key {
436            return Err(ForgeError::Forbidden(format!(
437                "Function '{function_name}' must include identity or tenant scope arguments"
438            )));
439        }
440
441        Ok(())
442    }
443
444    /// Get the function kind by name.
445    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
446        self.registry.get(function_name).map(|e| e.kind())
447    }
448
449    /// Check if a function exists.
450    pub fn has_function(&self, function_name: &str) -> bool {
451        self.registry.get(function_name).is_some()
452    }
453
454    async fn execute_transactional(
455        &self,
456        handler: &BoxedMutationFn,
457        args: Value,
458        auth: AuthContext,
459        request: RequestMetadata,
460    ) -> Result<RouteResult> {
461        // Use primary for transactional mutations
462        let primary = self.db.primary();
463        let tx = primary
464            .begin()
465            .await
466            .map_err(|e| ForgeError::Database(e.to_string()))?;
467
468        let job_dispatcher = self.job_dispatcher.clone();
469        let job_lookup: forge_core::JobInfoLookup =
470            Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
471
472        let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
473            primary.clone(),
474            tx,
475            auth,
476            request,
477            self.http_client.clone(),
478            job_lookup,
479        );
480
481        match handler(&ctx, args).await {
482            Ok(value) => {
483                let buffer = {
484                    let guard = outbox.lock().expect("outbox mutex poisoned");
485                    OutboxBuffer {
486                        jobs: guard.jobs.clone(),
487                        workflows: guard.workflows.clone(),
488                    }
489                };
490
491                let mut tx = Arc::try_unwrap(tx_handle)
492                    .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
493                    .into_inner();
494
495                for job in &buffer.jobs {
496                    Self::insert_job(&mut tx, job).await?;
497                }
498
499                for workflow in &buffer.workflows {
500                    let version = self
501                        .workflow_dispatcher
502                        .as_ref()
503                        .and_then(|d| d.get_info(&workflow.workflow_name))
504                        .map(|info| info.version)
505                        .ok_or_else(|| {
506                            ForgeError::NotFound(format!(
507                                "Workflow '{}' not found",
508                                workflow.workflow_name
509                            ))
510                        })?;
511                    Self::insert_workflow(&mut tx, workflow, version).await?;
512                }
513
514                tx.commit()
515                    .await
516                    .map_err(|e| ForgeError::Database(e.to_string()))?;
517
518                Ok(RouteResult::Mutation(value))
519            }
520            Err(e) => Err(e),
521        }
522    }
523
524    async fn insert_job(
525        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
526        job: &PendingJob,
527    ) -> Result<()> {
528        let now = Utc::now();
529        sqlx::query(
530            r#"
531            INSERT INTO forge_jobs (
532                id, job_type, input, job_context, status, priority, attempts, max_attempts,
533                worker_capability, owner_subject, scheduled_at, created_at
534            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
535            "#,
536        )
537        .bind(job.id)
538        .bind(&job.job_type)
539        .bind(&job.args)
540        .bind(&job.context)
541        .bind(JobStatus::Pending.as_str())
542        .bind(job.priority)
543        .bind(0i32)
544        .bind(job.max_attempts)
545        .bind(&job.worker_capability)
546        .bind(&job.owner_subject)
547        .bind(now)
548        .bind(now)
549        .execute(&mut **tx)
550        .await
551        .map_err(|e| ForgeError::Database(e.to_string()))?;
552
553        Ok(())
554    }
555
556    async fn insert_workflow(
557        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
558        workflow: &PendingWorkflow,
559        version: u32,
560    ) -> Result<()> {
561        let now = Utc::now();
562        sqlx::query(
563            r#"
564            INSERT INTO forge_workflow_runs (
565                id, workflow_name, version, owner_subject, input, status, current_step,
566                step_results, started_at, trace_id
567            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
568            "#,
569        )
570        .bind(workflow.id)
571        .bind(&workflow.workflow_name)
572        .bind(version as i32)
573        .bind(&workflow.owner_subject)
574        .bind(&workflow.input)
575        .bind(WorkflowStatus::Created.as_str())
576        .bind(Option::<String>::None)
577        .bind(serde_json::json!({}))
578        .bind(now)
579        .bind(workflow.id.to_string())
580        .execute(&mut **tx)
581        .await
582        .map_err(|e| ForgeError::Database(e.to_string()))?;
583
584        Ok(())
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated `user`-role context whose `sub` claim equals `user_id`,
    /// optionally carrying a `tenant_id` claim.
    fn user_auth(user_id: uuid::Uuid, tenant_id: Option<&str>) -> AuthContext {
        let mut claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        )]);
        if let Some(tenant) = tenant_id {
            claims.insert(
                "tenant_id".to_string(),
                serde_json::Value::String(tenant.to_string()),
            );
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
        };

        let _auth = AuthContext::unauthenticated();

        // Can't test check_auth directly without a router instance (it needs a
        // Database); this only checks the fixture itself.
        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let user_id = uuid::Uuid::new_v4();
        let auth = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let user_id = uuid::Uuid::new_v4();
        let auth = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result =
            FunctionRouter::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_user_id() {
        let user_id = uuid::Uuid::new_v4();
        let auth = user_auth(user_id, None);
        let args = serde_json::json!({ "user_id": user_id.to_string() });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_non_string_identity() {
        let auth = user_auth(uuid::Uuid::new_v4(), None);
        let args = serde_json::json!({ "user_id": 42 });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::InvalidArgument(_))));
    }

    #[test]
    fn test_identity_args_reject_cross_tenant_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), Some("tenant-a"));
        let args = serde_json::json!({ "tenant_id": "tenant-b" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_tenant() {
        let auth = user_auth(uuid::Uuid::new_v4(), Some("tenant-a"));
        let args = serde_json::json!({ "tenant_id": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_tenant_without_claim() {
        let auth = user_auth(uuid::Uuid::new_v4(), None);
        let args = serde_json::json!({ "tenant_id": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_object_args() {
        let auth = user_auth(uuid::Uuid::new_v4(), None);
        let args = serde_json::json!(["not", "an", "object"]);

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_public_calls_need_no_scope() {
        let auth = user_auth(uuid::Uuid::new_v4(), None);

        let result =
            FunctionRouter::check_identity_args("list_orders", &serde_json::json!({}), &auth, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope.as_deref(), Some("anon"));
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-a".to_string()),
                ),
            ]),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-b".to_string()),
                ),
            ]),
        );

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}