Skip to main content

forge_runtime/function/
router.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8    MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9    Result, WorkflowDispatch,
10    job::JobStatus,
11    rate_limit::{RateLimitConfig, RateLimitKey},
12    workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22/// Shared auth enforcement: checks public flag, authentication, and role.
23fn require_auth(is_public: bool, required_role: Option<&str>, auth: &AuthContext) -> Result<()> {
24    if is_public {
25        return Ok(());
26    }
27    if !auth.is_authenticated() {
28        return Err(ForgeError::Unauthorized("Authentication required".into()));
29    }
30    if let Some(role) = required_role
31        && !auth.has_role(role)
32    {
33        return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
34    }
35    Ok(())
36}
37
38/// Result of routing a function call.
39pub enum RouteResult {
40    /// Query execution result.
41    Query(Value),
42    /// Mutation execution result.
43    Mutation(Value),
44    /// Job dispatch result (returns job_id).
45    Job(Value),
46    /// Workflow dispatch result (returns workflow_id).
47    Workflow(Value),
48}
49
50/// Routes function calls to the appropriate handler.
51pub struct FunctionRouter {
52    registry: Arc<FunctionRegistry>,
53    db: Database,
54    http_client: CircuitBreakerClient,
55    job_dispatcher: Option<Arc<dyn JobDispatch>>,
56    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
57    rate_limiter: HybridRateLimiter,
58    query_cache: QueryCache,
59    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
60    token_ttl: forge_core::AuthTokenTtl,
61}
62
63impl FunctionRouter {
64    /// Create a new function router.
65    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
66        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
67        Self {
68            registry,
69            db,
70            http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
71            job_dispatcher: None,
72            workflow_dispatcher: None,
73            rate_limiter,
74            query_cache: QueryCache::new(),
75            token_issuer: None,
76            token_ttl: forge_core::AuthTokenTtl::default(),
77        }
78    }
79
80    /// Create a new function router with a custom HTTP client.
81    pub fn with_http_client(
82        registry: Arc<FunctionRegistry>,
83        db: Database,
84        http_client: CircuitBreakerClient,
85    ) -> Self {
86        let rate_limiter = HybridRateLimiter::new(db.primary().clone());
87        Self {
88            registry,
89            db,
90            http_client,
91            job_dispatcher: None,
92            workflow_dispatcher: None,
93            rate_limiter,
94            query_cache: QueryCache::new(),
95            token_issuer: None,
96            token_ttl: forge_core::AuthTokenTtl::default(),
97        }
98    }
99
100    /// Set the token issuer for this router (enables `ctx.issue_token()` in mutations).
101    pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
102        self.token_issuer = Some(issuer);
103        self
104    }
105
106    /// Set the token TTL config for this router (configures `ctx.issue_token_pair()` durations).
107    pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
108        self.token_ttl = ttl;
109        self
110    }
111
112    /// Set the token TTL config (mutable reference version).
113    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
114        self.token_ttl = ttl;
115    }
116
117    /// Set the job dispatcher for this router.
118    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
119        self.job_dispatcher = Some(dispatcher);
120        self
121    }
122
123    /// Set the workflow dispatcher for this router.
124    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
125        self.workflow_dispatcher = Some(dispatcher);
126        self
127    }
128
129    pub async fn route(
130        &self,
131        function_name: &str,
132        args: Value,
133        auth: AuthContext,
134        request: RequestMetadata,
135    ) -> Result<RouteResult> {
136        if let Some(entry) = self.registry.get(function_name) {
137            self.check_auth(entry.info(), &auth)?;
138            self.check_rate_limit(entry.info(), function_name, &auth, &request)
139                .await?;
140            // Skip scope enforcement for functions with no input args.
141            // The JWT still carries identity, accessible via ctx.require_user_id().
142            let enforce = !entry.info().is_public && entry.info().has_input_args;
143            auth.check_identity_args(function_name, &args, enforce)?;
144
145            return match entry {
146                FunctionEntry::Query { handler, info, .. } => {
147                    let pool = if info.consistent {
148                        self.db.primary().clone()
149                    } else {
150                        self.db.read_pool().clone()
151                    };
152
153                    let auth_scope = Self::auth_cache_scope(&auth);
154                    if let Some(ttl) = info.cache_ttl {
155                        if let Some(cached) =
156                            self.query_cache
157                                .get(function_name, &args, auth_scope.as_deref())
158                        {
159                            return Ok(RouteResult::Query(Value::clone(&cached)));
160                        }
161
162                        let ctx = QueryContext::new(pool, auth, request);
163                        let result = handler(&ctx, args.clone()).await?;
164
165                        self.query_cache.set(
166                            function_name,
167                            &args,
168                            auth_scope.as_deref(),
169                            result.clone(),
170                            Duration::from_secs(ttl),
171                        );
172
173                        Ok(RouteResult::Query(result))
174                    } else {
175                        let ctx = QueryContext::new(pool, auth, request);
176                        let result = handler(&ctx, args).await?;
177                        Ok(RouteResult::Query(result))
178                    }
179                }
180                FunctionEntry::Mutation { handler, info } => {
181                    if info.transactional {
182                        self.execute_transactional(info, handler, args, auth, request)
183                            .await
184                    } else {
185                        // Use primary for mutations
186                        let mut ctx = MutationContext::with_dispatch(
187                            self.db.primary().clone(),
188                            auth,
189                            request,
190                            self.http_client.clone(),
191                            self.job_dispatcher.clone(),
192                            self.workflow_dispatcher.clone(),
193                        );
194                        if let Some(ref issuer) = self.token_issuer {
195                            ctx.set_token_issuer(issuer.clone());
196                        }
197                        ctx.set_token_ttl(self.token_ttl.clone());
198                        ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
199                        let result = handler(&ctx, args).await?;
200                        Ok(RouteResult::Mutation(result))
201                    }
202                }
203            };
204        }
205
206        if let Some(ref job_dispatcher) = self.job_dispatcher
207            && let Some(job_info) = job_dispatcher.get_info(function_name)
208        {
209            self.check_job_auth(&job_info, &auth)?;
210            auth.check_identity_args(function_name, &args, !job_info.is_public)?;
211            match job_dispatcher
212                .dispatch_by_name(function_name, args.clone(), auth.principal_id())
213                .await
214            {
215                Ok(job_id) => {
216                    return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
217                }
218                Err(ForgeError::NotFound(_)) => {}
219                Err(e) => return Err(e),
220            }
221        }
222
223        if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
224            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
225        {
226            self.check_workflow_auth(&workflow_info, &auth)?;
227            auth.check_identity_args(function_name, &args, !workflow_info.is_public)?;
228            match workflow_dispatcher
229                .start_by_name(function_name, args.clone(), auth.principal_id())
230                .await
231            {
232                Ok(workflow_id) => {
233                    return Ok(RouteResult::Workflow(
234                        serde_json::json!({ "workflow_id": workflow_id }),
235                    ));
236                }
237                Err(ForgeError::NotFound(_)) => {}
238                Err(e) => return Err(e),
239            }
240        }
241
242        Err(ForgeError::NotFound(format!(
243            "Function '{}' not found",
244            function_name
245        )))
246    }
247
248    fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
249        require_auth(info.is_public, info.required_role, auth)
250    }
251
252    fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
253        require_auth(info.is_public, info.required_role, auth)
254    }
255
256    fn check_workflow_auth(
257        &self,
258        info: &forge_core::workflow::WorkflowInfo,
259        auth: &AuthContext,
260    ) -> Result<()> {
261        require_auth(info.is_public, info.required_role, auth)
262    }
263
264    /// Check rate limit for a function call.
265    async fn check_rate_limit(
266        &self,
267        info: &FunctionInfo,
268        function_name: &str,
269        auth: &AuthContext,
270        request: &RequestMetadata,
271    ) -> Result<()> {
272        // Skip if no rate limit configured
273        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
274            (Some(r), Some(p)) => (r, p),
275            _ => return Ok(()),
276        };
277
278        // Build rate limit config
279        let key_str = info.rate_limit_key.unwrap_or("user");
280        let key_type: RateLimitKey = match key_str.parse() {
281            Ok(k) => k,
282            Err(_) => {
283                tracing::error!(
284                    function = %function_name,
285                    key = %key_str,
286                    "Invalid rate limit key, falling back to 'user'"
287                );
288                RateLimitKey::default()
289            }
290        };
291
292        let config =
293            RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
294
295        // Build bucket key
296        let bucket_key = self
297            .rate_limiter
298            .build_key(key_type, function_name, auth, request);
299
300        // Enforce rate limit
301        self.rate_limiter.enforce(&bucket_key, &config).await?;
302
303        Ok(())
304    }
305
306    fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
307        if !auth.is_authenticated() {
308            return Some("anon".to_string());
309        }
310
311        // Include role + claims fingerprint to avoid cross-scope cache bleed.
312        let mut roles = auth.roles().to_vec();
313        roles.sort();
314        roles.dedup();
315
316        let mut claims = BTreeMap::new();
317        for (k, v) in auth.claims() {
318            claims.insert(k.clone(), v.clone());
319        }
320
321        use std::collections::hash_map::DefaultHasher;
322        use std::hash::{Hash, Hasher};
323
324        let mut hasher = DefaultHasher::new();
325        roles.hash(&mut hasher);
326        serde_json::to_string(&claims)
327            .unwrap_or_default()
328            .hash(&mut hasher);
329
330        let principal = auth
331            .principal_id()
332            .unwrap_or_else(|| "authenticated".to_string());
333
334        Some(format!(
335            "subject:{principal}:scope:{:016x}",
336            hasher.finish()
337        ))
338    }
339
340    /// Get the function kind by name.
341    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
342        self.registry.get(function_name).map(|e| e.kind())
343    }
344
345    /// Check if a function exists.
346    pub fn has_function(&self, function_name: &str) -> bool {
347        self.registry.get(function_name).is_some()
348    }
349
350    async fn execute_transactional(
351        &self,
352        info: &FunctionInfo,
353        handler: &BoxedMutationFn,
354        args: Value,
355        auth: AuthContext,
356        request: RequestMetadata,
357    ) -> Result<RouteResult> {
358        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
359
360        async {
361            let primary = self.db.primary();
362            let tx = primary
363                .begin()
364                .await
365                .map_err(|e| ForgeError::Database(e.to_string()))?;
366
367            let job_dispatcher = self.job_dispatcher.clone();
368            let job_lookup: forge_core::JobInfoLookup =
369                Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
370
371            let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
372                primary.clone(),
373                tx,
374                auth,
375                request,
376                self.http_client.clone(),
377                job_lookup,
378            );
379            if let Some(ref issuer) = self.token_issuer {
380                ctx.set_token_issuer(issuer.clone());
381            }
382            ctx.set_token_ttl(self.token_ttl.clone());
383            ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
384
385            match handler(&ctx, args).await {
386                Ok(value) => {
387                    let buffer = {
388                        let guard = outbox.lock().unwrap_or_else(|poisoned| {
389                            tracing::error!("Outbox mutex was poisoned, recovering");
390                            poisoned.into_inner()
391                        });
392                        OutboxBuffer {
393                            jobs: guard.jobs.clone(),
394                            workflows: guard.workflows.clone(),
395                        }
396                    };
397
398                    let mut tx = Arc::try_unwrap(tx_handle)
399                        .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
400                        .into_inner();
401
402                    for job in &buffer.jobs {
403                        Self::insert_job(&mut tx, job).await?;
404                    }
405
406                    for workflow in &buffer.workflows {
407                        let version = self
408                            .workflow_dispatcher
409                            .as_ref()
410                            .and_then(|d| d.get_info(&workflow.workflow_name))
411                            .map(|info| info.version)
412                            .ok_or_else(|| {
413                                ForgeError::NotFound(format!(
414                                    "Workflow '{}' not found",
415                                    workflow.workflow_name
416                                ))
417                            })?;
418                        Self::insert_workflow(&mut tx, workflow, version).await?;
419                    }
420
421                    tx.commit()
422                        .await
423                        .map_err(|e| ForgeError::Database(e.to_string()))?;
424
425                    Ok(RouteResult::Mutation(value))
426                }
427                Err(e) => Err(e),
428            }
429        }
430        .instrument(span)
431        .await
432    }
433
434    async fn insert_job(
435        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
436        job: &PendingJob,
437    ) -> Result<()> {
438        let now = Utc::now();
439        sqlx::query(
440            r#"
441            INSERT INTO forge_jobs (
442                id, job_type, input, job_context, status, priority, attempts, max_attempts,
443                worker_capability, owner_subject, scheduled_at, created_at
444            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
445            "#,
446        )
447        .bind(job.id)
448        .bind(&job.job_type)
449        .bind(&job.args)
450        .bind(&job.context)
451        .bind(JobStatus::Pending.as_str())
452        .bind(job.priority)
453        .bind(0i32)
454        .bind(job.max_attempts)
455        .bind(&job.worker_capability)
456        .bind(&job.owner_subject)
457        .bind(now)
458        .bind(now)
459        .execute(&mut **tx)
460        .await
461        .map_err(|e| ForgeError::Database(e.to_string()))?;
462
463        Ok(())
464    }
465
466    async fn insert_workflow(
467        tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
468        workflow: &PendingWorkflow,
469        version: u32,
470    ) -> Result<()> {
471        let now = Utc::now();
472        sqlx::query(
473            r#"
474            INSERT INTO forge_workflow_runs (
475                id, workflow_name, version, owner_subject, input, status, current_step,
476                step_results, started_at, trace_id
477            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
478            "#,
479        )
480        .bind(workflow.id)
481        .bind(&workflow.workflow_name)
482        .bind(version as i32)
483        .bind(&workflow.owner_subject)
484        .bind(&workflow.input)
485        .bind(WorkflowStatus::Created.as_str())
486        .bind(Option::<String>::None)
487        .bind(serde_json::json!({}))
488        .bind(now)
489        .bind(workflow.id.to_string())
490        .execute(&mut **tx)
491        .await
492        .map_err(|e| ForgeError::Database(e.to_string()))?;
493
494        Ok(())
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Authenticated context for `user_id` with the given roles and a matching
    /// `sub` claim.
    fn authenticated_user(user_id: uuid::Uuid, roles: &[&str]) -> AuthContext {
        AuthContext::authenticated(
            user_id,
            roles
                .iter()
                .map(|role| role.to_string())
                .collect::<Vec<String>>(),
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        // Public functions are callable without authentication; this is the
        // exact check `FunctionRouter::check_auth` performs.
        let auth = AuthContext::unauthenticated();
        assert!(require_auth(info.is_public, info.required_role, &auth).is_ok());
    }

    #[test]
    fn test_require_auth_rejects_unauthenticated_private_call() {
        let auth = AuthContext::unauthenticated();
        let result = require_auth(false, None, &auth);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_require_auth_enforces_required_role() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);
        assert!(require_auth(false, None, &auth).is_ok());
        assert!(require_auth(false, Some("user"), &auth).is_ok());
        assert!(matches!(
            require_auth(false, Some("admin"), &auth),
            Err(ForgeError::Forbidden(_))
        ));
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = auth.check_identity_args("list_orders", &args, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);

        let result = auth.check_identity_args("list_orders", &serde_json::json!({}), true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);

        // enforce_scope=false simulates has_input_args=false
        let result = auth.check_identity_args("list_todos", &serde_json::Value::Null, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope.as_deref(), Some("anon"));
    }

    #[test]
    fn test_auth_cache_scope_is_stable_and_per_principal() {
        let auth = authenticated_user(uuid::Uuid::new_v4(), &["user"]);
        assert_eq!(
            FunctionRouter::auth_cache_scope(&auth),
            FunctionRouter::auth_cache_scope(&auth)
        );

        let other = authenticated_user(uuid::Uuid::new_v4(), &["user"]);
        assert_ne!(
            FunctionRouter::auth_cache_scope(&auth),
            FunctionRouter::auth_cache_scope(&other)
        );
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-a".to_string()),
                ),
            ]),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-b".to_string()),
                ),
            ]),
        );

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}