Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15
16use super::cache::QueryCache;
17use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
18use crate::db::Database;
19use crate::rate_limit::RateLimiter;
20
21/// Result of routing a function call.
22pub enum RouteResult {
23    /// Query execution result.
24    Query(Value),
25    /// Mutation execution result.
26    Mutation(Value),
27    /// Job dispatch result (returns job_id).
28    Job(Value),
29    /// Workflow dispatch result (returns workflow_id).
30    Workflow(Value),
31}
32
/// Routes function calls to the appropriate handler.
///
/// A call is resolved against the function registry first, then against the
/// optional job dispatcher, and finally against the optional workflow
/// dispatcher.
pub struct FunctionRouter {
    /// Registry of query and mutation handlers, looked up by function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle; supplies the primary pool and the read pool.
    db: Database,
    /// HTTP client handed to mutation contexts.
    http_client: CircuitBreakerClient,
    /// Dispatcher used for job calls; `None` until one is attached.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Dispatcher used for workflow calls; `None` until one is attached.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Rate limiter constructed from the primary database pool.
    rate_limiter: RateLimiter,
    /// Cache for results of queries that declare a `cache_ttl`.
    query_cache: QueryCache,
}
43
44impl FunctionRouter {
45    /// Create a new function router.
46    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
47        let rate_limiter = RateLimiter::new(db.primary().clone());
48        Self {
49            registry,
50            db,
51            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
52            job_dispatcher: None,
53            workflow_dispatcher: None,
54            rate_limiter,
55            query_cache: QueryCache::new(),
56        }
57    }
58
59    /// Create a new function router with a custom HTTP client.
60    pub fn with_http_client(
61        registry: Arc<FunctionRegistry>,
62        db: Database,
63        http_client: CircuitBreakerClient,
64    ) -> Self {
65        let rate_limiter = RateLimiter::new(db.primary().clone());
66        Self {
67            registry,
68            db,
69            http_client,
70            job_dispatcher: None,
71            workflow_dispatcher: None,
72            rate_limiter,
73            query_cache: QueryCache::new(),
74        }
75    }
76
77    /// Set the job dispatcher for this router.
78    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
79        self.job_dispatcher = Some(dispatcher);
80        self
81    }
82
83    /// Set the workflow dispatcher for this router.
84    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
85        self.workflow_dispatcher = Some(dispatcher);
86        self
87    }
88
89    pub async fn route(
90        &self,
91        function_name: &str,
92        args: Value,
93        auth: AuthContext,
94        request: RequestMetadata,
95    ) -> Result<RouteResult> {
96        if let Some(entry) = self.registry.get(function_name) {
97            self.check_auth(entry.info(), &auth)?;
98            self.check_rate_limit(entry.info(), function_name, &auth, &request)
99                .await?;
100            Self::check_identity_args(function_name, &args, &auth, !entry.info().is_public)?;
101
102            return match entry {
103                FunctionEntry::Query { handler, info, .. } => {
104                    let auth_scope = Self::auth_cache_scope(&auth);
105                    if let Some(ttl) = info.cache_ttl {
106                        if let Some(cached) =
107                            self.query_cache
108                                .get(function_name, &args, auth_scope.as_deref())
109                        {
110                            return Ok(RouteResult::Query(Value::clone(&cached)));
111                        }
112
113                        // Execute and cache result (use read replica for queries)
114                        let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
115                        let result = handler(&ctx, args.clone()).await?;
116
117                        self.query_cache.set(
118                            function_name,
119                            &args,
120                            auth_scope.as_deref(),
121                            result.clone(),
122                            Duration::from_secs(ttl),
123                        );
124
125                        Ok(RouteResult::Query(result))
126                    } else {
127                        // Use read replica for queries
128                        let ctx = QueryContext::new(self.db.read_pool().clone(), auth, request);
129                        let result = handler(&ctx, args).await?;
130                        Ok(RouteResult::Query(result))
131                    }
132                }
133                FunctionEntry::Mutation { handler, info } => {
134                    if info.transactional {
135                        self.execute_transactional(handler, args, auth, request)
136                            .await
137                    } else {
138                        // Use primary for mutations
139                        let ctx = MutationContext::with_dispatch(
140                            self.db.primary().clone(),
141                            auth,
142                            request,
143                            self.http_client.clone(),
144                            self.job_dispatcher.clone(),
145                            self.workflow_dispatcher.clone(),
146                        );
147                        let result = handler(&ctx, args).await?;
148                        Ok(RouteResult::Mutation(result))
149                    }
150                }
151            };
152        }
153
154        if let Some(ref job_dispatcher) = self.job_dispatcher
155            && let Some(job_info) = job_dispatcher.get_info(function_name)
156        {
157            self.check_job_auth(&job_info, &auth)?;
158            Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
159            match job_dispatcher
160                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
161                .await
162            {
163                Ok(job_id) => {
164                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
165                }
166                Err(ForgeError::NotFound(_)) => {}
167                Err(e) => return Err(e),
168            }
169        }
170
171        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
172            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
173        {
174            self.check_workflow_auth(&workflow_info, &auth)?;
175            Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
176            match workflow_dispatcher
177                .start_by_name(function_name, args.clone(), auth.principal_id())
178                .await
179            {
180                Ok(workflow_id) => {
181                    return Ok(RouteResult::Workflow(
182                        serde_json::json!({ "workflow_id": workflow_id }),
183                    ));
184                }
185                Err(ForgeError::NotFound(_)) => {}
186                Err(e) => return Err(e),
187            }
188        }
189
190        Err(ForgeError::NotFound(format!(
191            "Function '{}' not found",
192            function_name
193        )))
194    }
195
196    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
197        if info.is_public {
198            return Ok(());
199        }
200
201        if !auth.is_authenticated() {
202            return Err(ForgeError::Unauthorized("Authentication required".into()));
203        }
204
205        if let Some(role) = info.required_role
206            && !auth.has_role(role)
207        {
208            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
209        }
210
211        Ok(())
212    }
213
214    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
215        if info.is_public {
216            return Ok(());
217        }
218
219        if !auth.is_authenticated() {
220            return Err(ForgeError::Unauthorized("Authentication required".into()));
221        }
222
223        if let Some(role) = info.required_role
224            && !auth.has_role(role)
225        {
226            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
227        }
228
229        Ok(())
230    }
231
232    fn check_workflow_auth(
233        &self,
234        info: &forge_core::workflow::WorkflowInfo,
235        auth: &AuthContext,
236    ) -> Result<()> {
237        if info.is_public {
238            return Ok(());
239        }
240
241        if !auth.is_authenticated() {
242            return Err(ForgeError::Unauthorized("Authentication required".into()));
243        }
244
245        if let Some(role) = info.required_role
246            && !auth.has_role(role)
247        {
248            return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
249        }
250
251        Ok(())
252    }
253
254    /// Check rate limit for a function call.
255    async fn check_rate_limit(
256        &self,
257        info: &FunctionInfo,
258        function_name: &str,
259        auth: &AuthContext,
260        request: &RequestMetadata,
261    ) -> Result<()> {
262        // Skip if no rate limit configured
263        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
264            (Some(r), Some(p)) => (r, p),
265            _ => return Ok(()),
266        };
267
268        // Build rate limit config
269        let key_str = info.rate_limit_key.unwrap_or("user");
270        let key_type: RateLimitKey = match key_str.parse() {
271            Ok(k) => k,
272            Err(_) => {
273                tracing::warn!(
274                    function = %function_name,
275                    key = %key_str,
276                    "Invalid rate limit key, falling back to 'user'"
277                );
278                RateLimitKey::default()
279            }
280        };
281
282        let config =
283            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
284
285        // Build bucket key
286        let bucket_key = self
287            .rate_limiter
288            .build_key(key_type, function_name, auth, request);
289
290        // Enforce rate limit
291        self.rate_limiter.enforce(&bucket_key, &config).await?;
292
293        Ok(())
294    }
295
296    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
297        if !auth.is_authenticated() {
298            return Some("anon".to_string());
299        }
300
301        // Include role + claims fingerprint to avoid cross-scope cache bleed.
302        let mut roles = auth.roles().to_vec();
303        roles.sort();
304        roles.dedup();
305
306        let mut claims = BTreeMap::new();
307        for (k, v) in auth.claims() {
308            claims.insert(k.clone(), v.clone());
309        }
310
311        use std::collections::hash_map::DefaultHasher;
312        use std::hash::{Hash, Hasher};
313
314        let mut hasher = DefaultHasher::new();
315        roles.hash(&mut hasher);
316        serde_json::to_string(&claims)
317            .unwrap_or_default()
318            .hash(&mut hasher);
319
320        let principal = auth
321            .principal_id()
322            .unwrap_or_else(|| "authenticated".to_string());
323
324        Some(format!(
325            "subject:{principal}:scope:{:016x}",
326            hasher.finish()
327        ))
328    }
329
330    fn check_identity_args(
331        function_name: &str,
332        args: &Value,
333        auth: &AuthContext,
334        enforce_scope: bool,
335    ) -> Result<()> {
336        if auth.is_admin() {
337            return Ok(());
338        }
339
340        let Some(obj) = args.as_object() else {
341            if enforce_scope && auth.is_authenticated() {
342                return Err(ForgeError::Forbidden(format!(
343                    "Function '{function_name}' must include identity or tenant scope arguments"
344                )));
345            }
346            return Ok(());
347        };
348
349        let mut principal_values: Vec<String> = Vec::new();
350        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
351            principal_values.push(user_id);
352        }
353        if let Some(subject) = auth.principal_id()
354            && !principal_values.iter().any(|v| v == &subject)
355        {
356            principal_values.push(subject);
357        }
358
359        let mut has_scope_key = false;
360
361        for key in [
362            "user_id",
363            "userId",
364            "owner_id",
365            "ownerId",
366            "owner_subject",
367            "ownerSubject",
368            "subject",
369            "sub",
370            "principal_id",
371            "principalId",
372        ] {
373            let Some(value) = obj.get(key) else {
374                continue;
375            };
376            has_scope_key = true;
377
378            if !auth.is_authenticated() {
379                return Err(ForgeError::Unauthorized(format!(
380                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
381                )));
382            }
383
384            let Value::String(actual) = value else {
385                return Err(ForgeError::InvalidArgument(format!(
386                    "Function '{function_name}' argument '{key}' must be a non-empty string"
387                )));
388            };
389
390            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
391                return Err(ForgeError::Forbidden(format!(
392                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
393                )));
394            }
395        }
396
397        for key in ["tenant_id", "tenantId"] {
398            let Some(value) = obj.get(key) else {
399                continue;
400            };
401            has_scope_key = true;
402
403            if !auth.is_authenticated() {
404                return Err(ForgeError::Unauthorized(format!(
405                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
406                )));
407            }
408
409            let expected = auth
410                .claim("tenant_id")
411                .and_then(|v| v.as_str())
412                .ok_or_else(|| {
413                    ForgeError::Forbidden(format!(
414                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
415                    ))
416                })?;
417
418            let Value::String(actual) = value else {
419                return Err(ForgeError::InvalidArgument(format!(
420                    "Function '{function_name}' argument '{key}' must be a non-empty string"
421                )));
422            };
423
424            if actual.trim().is_empty() || actual != expected {
425                return Err(ForgeError::Forbidden(format!(
426                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
427                )));
428            }
429        }
430
431        if enforce_scope && auth.is_authenticated() && !has_scope_key {
432            return Err(ForgeError::Forbidden(format!(
433                "Function '{function_name}' must include identity or tenant scope arguments"
434            )));
435        }
436
437        Ok(())
438    }
439
440    /// Get the function kind by name.
441    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
442        self.registry.get(function_name).map(|e| e.kind())
443    }
444
445    /// Check if a function exists.
446    pub fn has_function(&self, function_name: &str) -> bool {
447        self.registry.get(function_name).is_some()
448    }
449
450    async fn execute_transactional(
451        &self,
452        handler: &BoxedMutationFn,
453        args: Value,
454        auth: AuthContext,
455        request: RequestMetadata,
456    ) -> Result<RouteResult> {
457        // Use primary for transactional mutations
458        let primary = self.db.primary();
459        let tx = primary
460            .begin()
461            .await
462            .map_err(|e| ForgeError::Database(e.to_string()))?;
463
464        let job_dispatcher = self.job_dispatcher.clone();
465        let job_lookup: forge_core::JobInfoLookup =
466            Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
467
468        let (ctx, tx_handle, outbox) = MutationContext::with_transaction(
469            primary.clone(),
470            tx,
471            auth,
472            request,
473            self.http_client.clone(),
474            job_lookup,
475        );
476
477        match handler(&ctx, args).await {
478            Ok(value) => {
479                let buffer = {
480                    let guard = outbox.lock().expect("outbox mutex poisoned");
481                    OutboxBuffer {
482                        jobs: guard.jobs.clone(),
483                        workflows: guard.workflows.clone(),
484                    }
485                };
486
487                let mut tx = Arc::try_unwrap(tx_handle)
488                    .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
489                    .into_inner();
490
491                for job in &buffer.jobs {
492                    Self::insert_job(&mut tx, job).await?;
493                }
494
495                for workflow in &buffer.workflows {
496                    let version = self
497                        .workflow_dispatcher
498                        .as_ref()
499                        .and_then(|d| d.get_info(&workflow.workflow_name))
500                        .map(|info| info.version)
501                        .ok_or_else(|| {
502                            ForgeError::NotFound(format!(
503                                "Workflow '{}' not found",
504                                workflow.workflow_name
505                            ))
506                        })?;
507                    Self::insert_workflow(&mut tx, workflow, version).await?;
508                }
509
510                tx.commit()
511                    .await
512                    .map_err(|e| ForgeError::Database(e.to_string()))?;
513
514                Ok(RouteResult::Mutation(value))
515            }
516            Err(e) => Err(e),
517        }
518    }
519
520    async fn insert_job(
521        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
522        job: &PendingJob,
523    ) -> Result<()> {
524        let now = Utc::now();
525        sqlx::query(
526            r#"
527            INSERT INTO forge_jobs (
528                id, job_type, input, job_context, status, priority, attempts, max_attempts,
529                worker_capability, owner_subject, scheduled_at, created_at
530            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
531            "#,
532        )
533        .bind(job.id)
534        .bind(&job.job_type)
535        .bind(&job.args)
536        .bind(&job.context)
537        .bind(JobStatus::Pending.as_str())
538        .bind(job.priority)
539        .bind(0i32)
540        .bind(job.max_attempts)
541        .bind(&job.worker_capability)
542        .bind(&job.owner_subject)
543        .bind(now)
544        .bind(now)
545        .execute(&mut **tx)
546        .await
547        .map_err(|e| ForgeError::Database(e.to_string()))?;
548
549        Ok(())
550    }
551
552    async fn insert_workflow(
553        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
554        workflow: &PendingWorkflow,
555        version: u32,
556    ) -> Result<()> {
557        let now = Utc::now();
558        sqlx::query(
559            r#"
560            INSERT INTO forge_workflow_runs (
561                id, workflow_name, version, owner_subject, input, status, current_step,
562                step_results, started_at, trace_id
563            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
564            "#,
565        )
566        .bind(workflow.id)
567        .bind(&workflow.workflow_name)
568        .bind(version as i32)
569        .bind(&workflow.owner_subject)
570        .bind(&workflow.input)
571        .bind(WorkflowStatus::Created.as_str())
572        .bind(Option::<String>::None)
573        .bind(serde_json::json!({}))
574        .bind(now)
575        .bind(workflow.id.to_string())
576        .execute(&mut **tx)
577        .await
578        .map_err(|e| ForgeError::Database(e.to_string()))?;
579
580        Ok(())
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Function name passed to every identity-argument check below.
    const FUNCTION: &str = "list_orders";

    /// Builds a claims map whose values are all JSON strings.
    fn string_claims(pairs: &[(&str, &str)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(name, value)| {
                (
                    name.to_string(),
                    serde_json::Value::String(value.to_string()),
                )
            })
            .collect()
    }

    /// Authenticated caller with role `user`, `sub` set to the UUID, and an
    /// optional `tenant_id` claim.
    fn uuid_user(user_id: uuid::Uuid, tenant: Option<&str>) -> AuthContext {
        let subject = user_id.to_string();
        let mut pairs = vec![("sub", subject.as_str())];
        if let Some(tenant_id) = tenant {
            pairs.push(("tenant_id", tenant_id));
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], string_claims(&pairs))
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            transactional: false,
        };

        let _auth = AuthContext::unauthenticated();

        // `check_auth` needs a router instance, which is not available here,
        // so only the precondition it branches on is asserted.
        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = uuid_user(uuid::Uuid::new_v4(), None);
        // A freshly generated id that belongs to somebody else.
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let outcome = FunctionRouter::check_identity_args(FUNCTION, &args, &auth, true);
        assert!(matches!(outcome, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            string_claims(&[("sub", sub)]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let outcome = FunctionRouter::check_identity_args(FUNCTION, &args, &auth, true);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let outcome = FunctionRouter::check_identity_args(FUNCTION, &args, &auth, true);
        assert!(matches!(outcome, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = uuid_user(uuid::Uuid::new_v4(), None);
        let no_args = serde_json::json!({});

        let outcome = FunctionRouter::check_identity_args(FUNCTION, &no_args, &auth, true);
        assert!(matches!(outcome, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        // Same user and roles; only the tenant claim differs.
        let user_id = uuid::Uuid::new_v4();
        let auth_a = uuid_user(user_id, Some("tenant-a"));
        let auth_b = uuid_user(user_id, Some("tenant-b"));

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}
715}