Skip to main content

forge_runtime/function/
router.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5    AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
6    KvHandle, MutationContext, QueryContext, RequestMetadata, Result, SharedRoleResolver,
7    WorkflowDispatch, default_role_resolver,
8    rate_limit::{RateLimitConfig, RateLimiterBackend},
9};
10use serde_json::Value;
11use tokio::time::timeout;
12use tracing::Instrument;
13
14use super::cache::QueryCacheCoordinator;
15use super::execution_log::{level_for as log_level_for, log_completion};
16use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
17#[cfg(feature = "gateway")]
18use super::rpc_signals::{RpcSignalContext, RpcSignalsEmitter};
19use crate::pg::Database;
20use crate::rate_limit::HybridRateLimiter;
21#[cfg(feature = "gateway")]
22use crate::signals::SignalsCollector;
23
24/// Shared auth enforcement: checks public flag, authentication, and role.
25///
26/// When a `RoleResolver` is provided, roles are resolved from JWT claims
27/// before the `require_role` check. This allows hierarchy expansion or
28/// remote permission lookups without changing the handler surface.
29fn require_auth(
30    is_public: bool,
31    required_role: Option<&str>,
32    auth: &AuthContext,
33    role_resolver: &SharedRoleResolver,
34) -> Result<()> {
35    if is_public {
36        return Ok(());
37    }
38    if !auth.is_authenticated() {
39        return Err(ForgeError::Unauthorized("Authentication required".into()));
40    }
41    if let Some(role) = required_role {
42        let effective_roles = role_resolver.resolve(auth);
43        if !effective_roles.iter().any(|r| r == role) {
44            return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
45        }
46    }
47    Ok(())
48}
49
/// Result of routing a function call.
///
/// The variant records which kind of target handled the call.
pub enum RouteResult {
    /// Query execution result (Arc to avoid cloning cached values).
    Query(Arc<Value>),
    /// Mutation execution result.
    Mutation(Value),
    /// Job dispatch result: a JSON object of the form `{ "job_id": ... }`.
    Job(Value),
    /// Workflow dispatch result: a JSON object of the form
    /// `{ "workflow_id": ... }`.
    Workflow(Value),
}
61
/// Result of routing a function call paired with telemetry the executor
/// wants to forward to spans/metrics. The cache flag is meaningful only for
/// queries; every other variant returns `cache_hit = false`.
pub struct RouteOutcome {
    /// The value produced by the routed call, tagged with its kind.
    pub result: RouteResult,
    /// `true` when the value was served from the query cache instead of
    /// running the handler.
    pub cache_hit: bool,
}
69
/// Shared mutation dependencies cloned once per request instead of per-field.
///
/// Held behind an `Arc` by the router; `Clone` is derived so builder methods
/// can obtain a mutable copy through `Arc::make_mut`.
#[derive(Clone)]
struct MutationDeps {
    /// HTTP client handed to every mutation context.
    http_client: CircuitBreakerClient,
    /// Dispatcher for background jobs, if one is configured.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Dispatcher for workflows, if one is configured.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Token issuer installed on mutation contexts when present.
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    /// Token lifetime configuration installed on every mutation context.
    token_ttl: forge_core::AuthTokenTtl,
    /// Per-request job dispatch cap; only applied to the context when > 0.
    max_jobs_per_request: usize,
    /// Optional KV store handle shared with query and mutation contexts.
    kv: Option<Arc<dyn KvHandle>>,
}
81
/// Routes and executes function calls with timeout, rate limiting, and observability.
pub struct FunctionRouter {
    /// Registry consulted first when resolving a function name.
    registry: Arc<FunctionRegistry>,
    /// Database handle providing the primary and read pools.
    db: Database,
    /// Dependencies injected into mutation contexts (also the source of the
    /// job/workflow dispatchers used for routing and of the KV handle).
    mutation_deps: Arc<MutationDeps>,
    /// Backend that builds bucket keys and enforces rate limits.
    rate_limiter: Arc<dyn RateLimiterBackend>,
    /// Resolver used to expand the caller's roles before role checks.
    role_resolver: SharedRoleResolver,
    /// Query result cache, also invalidated after successful mutations.
    cache: Arc<QueryCacheCoordinator>,
    /// Timeout applied when a function does not declare its own.
    default_timeout: Duration,
    /// Maximum serialized response size in bytes (0 = unlimited).
    max_result_size_bytes: usize,
    /// Emitter for RPC signals; `None` until a collector is attached.
    #[cfg(feature = "gateway")]
    signals: Option<RpcSignalsEmitter>,
}
96
97impl FunctionRouter {
98    /// Create a new function router.
99    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
100        Self::with_http_client(registry, db, CircuitBreakerClient::with_ssrf_protection())
101    }
102
103    /// Create a new function router with a custom HTTP client.
104    pub fn with_http_client(
105        registry: Arc<FunctionRegistry>,
106        db: Database,
107        http_client: CircuitBreakerClient,
108    ) -> Self {
109        let rate_limiter: Arc<dyn RateLimiterBackend> =
110            Arc::new(HybridRateLimiter::new(db.primary().clone()));
111        let cache = Arc::new(QueryCacheCoordinator::new(&registry));
112        Self {
113            registry,
114            db,
115            mutation_deps: Arc::new(MutationDeps {
116                http_client,
117                job_dispatcher: None,
118                workflow_dispatcher: None,
119                token_issuer: None,
120                token_ttl: forge_core::AuthTokenTtl::default(),
121                max_jobs_per_request: 0,
122                kv: None,
123            }),
124            rate_limiter,
125            role_resolver: default_role_resolver(),
126            cache,
127            default_timeout: Duration::from_secs(30),
128            max_result_size_bytes: 0,
129            #[cfg(feature = "gateway")]
130            signals: None,
131        }
132    }
133
134    /// Create a router with dispatch capabilities.
135    pub fn with_dispatch(
136        registry: Arc<FunctionRegistry>,
137        db: Database,
138        job_dispatcher: Option<Arc<dyn JobDispatch>>,
139        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
140    ) -> Self {
141        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
142    }
143
144    /// Create a router with dispatch and token issuer.
145    pub fn with_dispatch_and_issuer(
146        registry: Arc<FunctionRegistry>,
147        db: Database,
148        job_dispatcher: Option<Arc<dyn JobDispatch>>,
149        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
150        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
151    ) -> Self {
152        let mut router = Self::new(Arc::clone(&registry), db);
153        if let Some(jd) = job_dispatcher {
154            router = router.with_job_dispatcher(jd);
155        }
156        if let Some(wd) = workflow_dispatcher {
157            router = router.with_workflow_dispatcher(wd);
158        }
159        if let Some(issuer) = token_issuer {
160            router = router.with_token_issuer(issuer);
161        }
162        router
163    }
164
165    /// Set a custom role resolver for RBAC extension.
166    pub fn with_role_resolver(mut self, resolver: SharedRoleResolver) -> Self {
167        self.role_resolver = resolver;
168        self
169    }
170
171    /// Set a custom role resolver (mutable reference version).
172    pub fn set_role_resolver(&mut self, resolver: SharedRoleResolver) {
173        self.role_resolver = resolver;
174    }
175
176    /// Override the default [`HybridRateLimiter`] with a custom backend
177    /// (e.g. [`crate::rate_limit::StrictRateLimiter`] for cluster-correct quotas).
178    pub fn with_rate_limiter(mut self, rate_limiter: Arc<dyn RateLimiterBackend>) -> Self {
179        self.rate_limiter = rate_limiter;
180        self
181    }
182
183    /// Replace the rate-limiter backend (mutable variant for late binding).
184    pub fn set_rate_limiter(&mut self, rate_limiter: Arc<dyn RateLimiterBackend>) {
185        self.rate_limiter = rate_limiter;
186    }
187
188    /// Get a mutable reference to the inner mutation deps for builder methods.
189    fn deps_mut(&mut self) -> &mut MutationDeps {
190        Arc::make_mut(&mut self.mutation_deps)
191    }
192
193    /// Set the token issuer for this router (enables `ctx.issue_token()` in mutations).
194    pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
195        self.deps_mut().token_issuer = Some(issuer);
196        self
197    }
198
199    /// Set the token TTL config for this router (configures `ctx.issue_token_pair()` durations).
200    pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
201        self.deps_mut().token_ttl = ttl;
202        self
203    }
204
205    /// Set the token TTL config (mutable reference version).
206    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
207        self.deps_mut().token_ttl = ttl;
208    }
209
210    /// Set the job dispatcher for this router.
211    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
212        self.deps_mut().job_dispatcher = Some(dispatcher);
213        self
214    }
215
216    /// Set the workflow dispatcher for this router.
217    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
218        self.deps_mut().workflow_dispatcher = Some(dispatcher);
219        self
220    }
221
222    /// Attach a KV store handle so handlers can call `ctx.kv()`.
223    pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
224        self.deps_mut().kv = Some(kv);
225        self
226    }
227
228    /// Attach a KV store handle (mutable reference version).
229    pub fn set_kv(&mut self, kv: Arc<dyn KvHandle>) {
230        self.deps_mut().kv = Some(kv);
231    }
232
233    /// Set the default timeout applied to all function calls.
234    pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
235        self.default_timeout = timeout;
236        self
237    }
238
239    /// Set the maximum number of jobs a single mutation may dispatch.
240    /// A value of 0 disables the limit.
241    pub fn set_max_jobs_per_request(&mut self, limit: usize) {
242        self.deps_mut().max_jobs_per_request = limit;
243    }
244
245    /// Set the maximum serialized response size in bytes.
246    /// A value of 0 disables the limit.
247    pub fn set_max_result_size_bytes(&mut self, limit: usize) {
248        self.max_result_size_bytes = limit;
249    }
250
251    /// Set the signals collector for auto-capturing RPC events.
252    #[cfg(feature = "gateway")]
253    pub fn set_signals_collector(&mut self, collector: SignalsCollector, server_secret: String) {
254        self.signals = Some(RpcSignalsEmitter::new(collector, server_secret));
255    }
256
257    /// Execute a function call with timeout, observability, and signals emission.
258    pub async fn execute(
259        &self,
260        function_name: &str,
261        args: Value,
262        auth: AuthContext,
263        request: RequestMetadata,
264    ) -> Result<Value> {
265        let start = std::time::Instant::now();
266        let info = self.registry.get(function_name).map(|e| e.info());
267        let fn_timeout = info.and_then(|i| i.timeout).unwrap_or(self.default_timeout);
268        let log_level = log_level_for(info);
269
270        let kind = info.map(|i| i.kind.as_str()).unwrap_or("unknown");
271
272        // Capture signal metadata before auth/request are consumed.
273        #[cfg(feature = "gateway")]
274        let mut signal_ctx = self
275            .signals
276            .as_ref()
277            .map(|_| RpcSignalContext::capture(&auth, &request));
278
279        // Declare cache.hit as Empty so the inner cache branch can fill it
280        // via Span::current().record(...). Latency p99 reported for this span
281        // is then attributable to either real handler work or a pure cache
282        // round-trip without ambiguity.
283        let span = tracing::info_span!(
284            "fn.execute",
285            function = function_name,
286            fn.kind = %kind,
287            cache.hit = tracing::field::Empty,
288        );
289
290        let result = match timeout(
291            fn_timeout,
292            self.route(function_name, args.clone(), auth, request)
293                .instrument(span),
294        )
295        .await
296        {
297            Ok(result) => result,
298            Err(_) => {
299                let duration = start.elapsed();
300                log_completion(
301                    log_level,
302                    function_name,
303                    "unknown",
304                    &args,
305                    duration,
306                    false,
307                    Some(&format!("Timeout after {:?}", fn_timeout)),
308                );
309                crate::observability::record_fn_execution(
310                    function_name,
311                    kind,
312                    false,
313                    false,
314                    duration.as_secs_f64(),
315                );
316                #[cfg(feature = "gateway")]
317                if let (Some(emitter), Some(ctx)) = (&self.signals, signal_ctx.take()) {
318                    emitter.emit(function_name, kind, duration, false, ctx);
319                }
320                return Err(ForgeError::Timeout(format!(
321                    "Function '{}' timed out after {:?}",
322                    function_name, fn_timeout
323                )));
324            }
325        };
326
327        let duration = start.elapsed();
328
329        match result {
330            Ok(outcome) => {
331                let RouteOutcome { result, cache_hit } = outcome;
332                let (result_kind, value) = match result {
333                    RouteResult::Query(arc) => {
334                        let v = Arc::try_unwrap(arc).unwrap_or_else(|a| Value::clone(&a));
335                        ("query", v)
336                    }
337                    RouteResult::Mutation(v) => ("mutation", v),
338                    RouteResult::Job(v) => ("job", v),
339                    RouteResult::Workflow(v) => ("workflow", v),
340                };
341
342                log_completion(
343                    log_level,
344                    function_name,
345                    result_kind,
346                    &args,
347                    duration,
348                    true,
349                    None,
350                );
351                crate::observability::record_fn_execution(
352                    function_name,
353                    result_kind,
354                    true,
355                    cache_hit,
356                    duration.as_secs_f64(),
357                );
358                #[cfg(feature = "gateway")]
359                if let (Some(emitter), Some(ctx)) = (&self.signals, signal_ctx.take()) {
360                    emitter.emit(function_name, result_kind, duration, true, ctx);
361                }
362
363                Ok(value)
364            }
365            Err(e) => {
366                log_completion(
367                    log_level,
368                    function_name,
369                    kind,
370                    &args,
371                    duration,
372                    false,
373                    Some(&e.to_string()),
374                );
375                crate::observability::record_fn_execution(
376                    function_name,
377                    kind,
378                    false,
379                    false,
380                    duration.as_secs_f64(),
381                );
382                #[cfg(feature = "gateway")]
383                if let (Some(emitter), Some(ctx)) = (&self.signals, signal_ctx.take()) {
384                    emitter.emit(function_name, kind, duration, false, ctx);
385                }
386
387                Err(e)
388            }
389        }
390    }
391
392    /// Look up function metadata by name.
393    pub fn function_info(&self, function_name: &str) -> Option<FunctionInfo> {
394        self.registry.get(function_name).map(|e| e.info().clone())
395    }
396
397    /// Check if a function exists.
398    pub fn has_function(&self, function_name: &str) -> bool {
399        self.registry.get(function_name).is_some()
400    }
401
402    /// Get the function kind by name.
403    pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
404        self.registry.get(function_name).map(|e| e.kind())
405    }
406
407    /// Return info for all registered query and mutation functions.
408    pub fn function_infos(&self) -> Vec<FunctionInfo> {
409        self.registry
410            .functions()
411            .map(|(_, entry)| entry.info().clone())
412            .collect()
413    }
414
415    /// Shared handle to the query cache coordinator (used to wire cluster invalidation).
416    pub fn cache(&self) -> Arc<QueryCacheCoordinator> {
417        Arc::clone(&self.cache)
418    }
419
420    /// Reject a result value when its serialized size exceeds `max_result_size_bytes`.
421    ///
422    /// A limit of 0 means unlimited.
423    fn check_result_size(&self, value: &Value) -> Result<()> {
424        if self.max_result_size_bytes == 0 {
425            return Ok(());
426        }
427        let serialized_len = json_byte_length(value);
428        if serialized_len > self.max_result_size_bytes {
429            return Err(ForgeError::internal(format!(
430                "Response size {} bytes exceeds max_result_size_bytes limit of {} bytes",
431                serialized_len, self.max_result_size_bytes
432            )));
433        }
434        Ok(())
435    }
436
437    pub async fn route(
438        &self,
439        function_name: &str,
440        args: Value,
441        auth: AuthContext,
442        request: RequestMetadata,
443    ) -> Result<RouteOutcome> {
444        if let Some(entry) = self.registry.get(function_name) {
445            let info = entry.info();
446            require_auth(
447                info.is_public,
448                info.required_role,
449                &auth,
450                &self.role_resolver,
451            )?;
452            if info.requires_tenant_scope && auth.tenant_id().is_none() {
453                return Err(ForgeError::Forbidden(
454                    "this function requires a tenant scope but the auth context has no tenant_id \
455                     claim"
456                        .to_string(),
457                ));
458            }
459            self.check_rate_limit(info, function_name, &auth, &request)
460                .await?;
461
462            return match entry {
463                FunctionEntry::Webhook { info } => {
464                    // Webhooks are registered in the function registry for
465                    // metadata access only. They must be called via their
466                    // dedicated HTTP path which performs signature validation.
467                    return Err(ForgeError::InvalidArgument(format!(
468                        "Webhook '{}' cannot be called via RPC; use its dedicated HTTP endpoint",
469                        info.name
470                    )));
471                }
472                FunctionEntry::Query { handler, info, .. } => {
473                    let pool = if info.consistent {
474                        self.db.primary().clone()
475                    } else {
476                        self.db.read_pool().clone()
477                    };
478
479                    if !info.consistent
480                        && let Some(ttl) = info.cache_ttl
481                    {
482                        // Derive scope once, before auth is moved into ctx, so
483                        // get/set agree on the same cache key.
484                        let scope = QueryCacheCoordinator::auth_scope(&auth);
485                        if let Some(cached) =
486                            self.cache
487                                .get_by_scope(function_name, &args, scope.as_deref())
488                        {
489                            tracing::Span::current().record("cache.hit", true);
490                            crate::observability::record_fn_cache(function_name, true);
491                            return Ok(RouteOutcome {
492                                result: RouteResult::Query(cached),
493                                cache_hit: true,
494                            });
495                        }
496                        tracing::Span::current().record("cache.hit", false);
497                        crate::observability::record_fn_cache(function_name, false);
498
499                        let mut ctx = QueryContext::new(pool, auth, request);
500                        if let Some(ref kv) = self.mutation_deps.kv {
501                            ctx.set_kv(Arc::clone(kv));
502                        }
503                        let result = handler(&ctx, args.clone()).await?;
504                        self.check_result_size(&result)?;
505
506                        let arc = Arc::new(result);
507                        self.cache.set_arc_by_scope(
508                            function_name,
509                            &args,
510                            scope.as_deref(),
511                            Arc::clone(&arc),
512                            Duration::from_secs(ttl),
513                        );
514
515                        Ok(RouteOutcome {
516                            result: RouteResult::Query(arc),
517                            cache_hit: false,
518                        })
519                    } else {
520                        let mut ctx = QueryContext::new(pool, auth, request);
521                        if let Some(ref kv) = self.mutation_deps.kv {
522                            ctx.set_kv(Arc::clone(kv));
523                        }
524                        let result = handler(&ctx, args).await?;
525                        self.check_result_size(&result)?;
526                        Ok(RouteOutcome {
527                            result: RouteResult::Query(Arc::new(result)),
528                            cache_hit: false,
529                        })
530                    }
531                }
532                FunctionEntry::Mutation { handler, info } => {
533                    let result = if info.transactional {
534                        self.execute_transactional(info, handler, args, auth, request)
535                            .await
536                    } else {
537                        let deps = Arc::clone(&self.mutation_deps);
538                        let mut ctx = MutationContext::with_dispatch(
539                            self.db.primary().clone(),
540                            auth,
541                            request,
542                            deps.http_client.clone(),
543                            deps.job_dispatcher.clone(),
544                            deps.workflow_dispatcher.clone(),
545                        );
546                        if let Some(ref issuer) = deps.token_issuer {
547                            ctx.set_token_issuer(issuer.clone());
548                        }
549                        ctx.set_token_ttl(deps.token_ttl.clone());
550                        ctx.set_http_timeout(info.http_timeout);
551                        if deps.max_jobs_per_request > 0 {
552                            ctx.set_max_jobs_per_request(deps.max_jobs_per_request);
553                        }
554                        if let Some(ref kv) = deps.kv {
555                            ctx.set_kv(Arc::clone(kv));
556                        }
557                        let value = handler(&ctx, args).await?;
558                        self.check_result_size(&value)?;
559                        Ok(RouteResult::Mutation(value))
560                    };
561                    // Invalidation runs here, AFTER commit (execute_transactional
562                    // commits before returning). The cluster-wide NOTIFY also fires
563                    // post-commit, ensuring peer nodes invalidate too.
564                    if result.is_ok() {
565                        self.cache.invalidate_for_mutation(info);
566                    }
567                    result.map(|r| RouteOutcome {
568                        result: r,
569                        cache_hit: false,
570                    })
571                }
572            };
573        }
574
575        if let Some(ref job_dispatcher) = self.mutation_deps.job_dispatcher
576            && let Some(job_info) = job_dispatcher.get_info(function_name)
577        {
578            require_auth(
579                job_info.is_public,
580                job_info.required_role,
581                &auth,
582                &self.role_resolver,
583            )?;
584            match job_dispatcher
585                .dispatch_by_name(
586                    function_name,
587                    args.clone(),
588                    auth.principal_id(),
589                    auth.tenant_id(),
590                )
591                .await
592            {
593                Ok(job_id) => {
594                    return Ok(RouteOutcome {
595                        result: RouteResult::Job(serde_json::json!({ "job_id": job_id })),
596                        cache_hit: false,
597                    });
598                }
599                Err(ForgeError::NotFound(_)) => {}
600                Err(e) => return Err(e),
601            }
602        }
603
604        if let Some(ref workflow_dispatcher) = self.mutation_deps.workflow_dispatcher
605            && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
606        {
607            require_auth(
608                workflow_info.is_public,
609                workflow_info.required_role,
610                &auth,
611                &self.role_resolver,
612            )?;
613            match workflow_dispatcher
614                .start_by_name(
615                    function_name,
616                    args,
617                    auth.principal_id(),
618                    Some(request.trace_id().to_string()),
619                )
620                .await
621            {
622                Ok(workflow_id) => {
623                    return Ok(RouteOutcome {
624                        result: RouteResult::Workflow(
625                            serde_json::json!({ "workflow_id": workflow_id }),
626                        ),
627                        cache_hit: false,
628                    });
629                }
630                Err(ForgeError::NotFound(_)) => {}
631                Err(e) => return Err(e),
632            }
633        }
634
635        Err(ForgeError::NotFound(format!(
636            "Function '{}' not found",
637            function_name
638        )))
639    }
640
641    /// Check rate limit for a function call.
642    async fn check_rate_limit(
643        &self,
644        info: &FunctionInfo,
645        function_name: &str,
646        auth: &AuthContext,
647        request: &RequestMetadata,
648    ) -> Result<()> {
649        let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
650            (Some(r), Some(p)) => (r, p),
651            _ => return Ok(()),
652        };
653
654        let key_type = info.rate_limit_key.clone().unwrap_or_default();
655
656        let config = RateLimitConfig::new(requests, Duration::from_secs(per_secs))
657            .with_key(key_type.clone());
658
659        let bucket_key = self
660            .rate_limiter
661            .build_key(key_type, function_name, auth, request);
662
663        self.rate_limiter.enforce(&bucket_key, &config).await?;
664
665        Ok(())
666    }
667
668    async fn execute_transactional(
669        &self,
670        info: &FunctionInfo,
671        handler: &BoxedMutationFn,
672        args: Value,
673        auth: AuthContext,
674        request: RequestMetadata,
675    ) -> Result<RouteResult> {
676        let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
677        let fn_timeout = info.timeout.unwrap_or(self.default_timeout);
678
679        async {
680            let primary = self.db.primary();
681            let mut tx = primary.begin().await.map_err(ForgeError::Database)?;
682
683            // Bind the per-function deadline to PostgreSQL via SET LOCAL so
684            // PG cancels the in-flight query at the same instant the tokio
685            // timeout fires. Without this the connection sits busy until the
686            // pool-wide statement_timeout — wasting connections and producing
687            // misleading "still running" backends after a 504. SET LOCAL
688            // doesn't accept bind parameters, so the value is interpolated
689            // directly; it's an integer derived from a Duration so injection
690            // is impossible.
691            let timeout_ms = fn_timeout.as_millis().min(i64::MAX as u128) as i64;
692            #[allow(clippy::disallowed_methods)]
693            sqlx::query(&format!("SET LOCAL statement_timeout = {timeout_ms}"))
694                .execute(&mut *tx)
695                .await
696                .map_err(ForgeError::Database)?;
697
698            let deps = Arc::clone(&self.mutation_deps);
699            let (mut ctx, tx_handle) = MutationContext::with_transaction(
700                primary.clone(),
701                tx,
702                auth,
703                request,
704                deps.http_client.clone(),
705                deps.job_dispatcher.clone(),
706                deps.workflow_dispatcher.clone(),
707            );
708            if let Some(ref issuer) = deps.token_issuer {
709                ctx.set_token_issuer(issuer.clone());
710            }
711            ctx.set_token_ttl(deps.token_ttl.clone());
712            ctx.set_http_timeout(info.http_timeout);
713            if deps.max_jobs_per_request > 0 {
714                ctx.set_max_jobs_per_request(deps.max_jobs_per_request);
715            }
716            if let Some(ref kv) = deps.kv {
717                ctx.set_kv(Arc::clone(kv));
718            }
719
720            let result = handler(&ctx, args).await;
721            drop(ctx);
722
723            // After dropping ctx, the executor holds the only Arc to the
724            // transaction. Take it out via `lock().await.take()` so we never
725            // depend on `Arc::try_unwrap` succeeding — even if a handler
726            // accidentally retained a clone of the Arc through a destructured
727            // DbConn, the take() leaves a None behind that prevents further
728            // misuse rather than leaking the transaction.
729            let tx = tx_handle
730                .lock()
731                .await
732                .take()
733                .ok_or_else(|| ForgeError::internal("Transaction already taken from handle"))?;
734
735            match result {
736                Ok(value) => {
737                    self.check_result_size(&value)?;
738                    tx.commit().await.map_err(ForgeError::Database)?;
739                    Ok(RouteResult::Mutation(value))
740                }
741                Err(e) => {
742                    if let Err(rollback_err) = tx.rollback().await {
743                        tracing::error!(
744                            handler_error = %e,
745                            rollback_error = %rollback_err,
746                            "Mutation rollback failed; transaction will be released by Drop"
747                        );
748                    } else {
749                        tracing::warn!(
750                            handler_error = %e,
751                            "Mutation rolled back"
752                        );
753                    }
754                    Err(e)
755                }
756            }
757        }
758        .instrument(span)
759        .await
760    }
761}
762
763/// Measure the JSON-serialized byte length of a `serde_json::Value` without
764/// allocating a `String`. Uses a counting `io::Write` implementation fed to
765/// `serde_json::to_writer`.
766fn json_byte_length(value: &Value) -> usize {
767    struct Counter(usize);
768    impl std::io::Write for Counter {
769        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
770            self.0 += buf.len();
771            Ok(buf.len())
772        }
773        fn flush(&mut self) -> std::io::Result<()> {
774            Ok(())
775        }
776    }
777    let mut counter = Counter(0);
778    if serde_json::to_writer(&mut counter, value).is_ok() {
779        counter.0
780    } else {
781        usize::MAX
782    }
783}
784
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Resolver used by every test that does not install a custom one.
    fn resolver() -> SharedRoleResolver {
        default_role_resolver()
    }

    /// Build an authenticated context for a fresh random user that holds
    /// exactly `roles` and carries no claims.
    fn authed_as(roles: &[&str]) -> AuthContext {
        let owned_roles = roles.iter().map(|role| String::from(*role)).collect();
        AuthContext::authenticated(uuid::Uuid::new_v4(), owned_roles, HashMap::new())
    }

    #[test]
    fn require_auth_allows_public_functions_for_anonymous_callers() {
        let anonymous = AuthContext::unauthenticated();
        let outcome = require_auth(true, None, &anonymous, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_allows_public_functions_even_with_required_role() {
        // The public flag returns early, so the role requirement is never evaluated.
        let anonymous = AuthContext::unauthenticated();
        let outcome = require_auth(true, Some("admin"), &anonymous, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_rejects_anonymous_callers_with_unauthorized() {
        let anonymous = AuthContext::unauthenticated();
        let outcome = require_auth(false, None, &anonymous, &resolver());
        assert!(
            matches!(outcome, Err(ForgeError::Unauthorized(_))),
            "expected Unauthorized, got {outcome:?}"
        );
    }

    #[test]
    fn require_auth_accepts_authenticated_caller_without_role_requirement() {
        let caller = authed_as(&["user"]);
        let outcome = require_auth(false, None, &caller, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_accepts_caller_with_required_role() {
        let caller = authed_as(&["user", "admin"]);
        let outcome = require_auth(false, Some("admin"), &caller, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_rejects_caller_missing_required_role_with_forbidden() {
        let caller = authed_as(&["user"]);
        let outcome = require_auth(false, Some("admin"), &caller, &resolver());
        match outcome {
            // The rejection message must name the role that was missing.
            Err(ForgeError::Forbidden(msg)) => assert!(msg.contains("admin")),
            other => panic!("expected Forbidden, got {other:?}"),
        }
    }

    #[test]
    fn require_auth_consults_custom_role_resolver() {
        /// Resolver that grants "admin" to anyone who already holds "user".
        struct ExpandingResolver;

        impl forge_core::RoleResolver for ExpandingResolver {
            fn resolve(&self, auth: &AuthContext) -> Vec<String> {
                let mut expanded: Vec<String> = auth.roles().to_vec();
                let holds_user = expanded.iter().any(|role| role == "user");
                if holds_user {
                    expanded.push(String::from("admin"));
                }
                expanded
            }
        }

        let caller = authed_as(&["user"]);
        let expanding: SharedRoleResolver = Arc::new(ExpandingResolver);
        // The caller only holds "user"; the call passes solely because the
        // resolver adds "admin" before the role check runs.
        let outcome = require_auth(false, Some("admin"), &caller, &expanding);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();

        // Same user and roles every time; only the `tenant_id` claim varies.
        let auth_for_tenant = |tenant: &str| {
            let claims = HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String(tenant.to_string()),
                ),
            ]);
            AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
        };

        let auth_a = auth_for_tenant("tenant-a");
        let auth_b = auth_for_tenant("tenant-b");

        // Differing claims must yield differing cache scopes.
        let scope_a = QueryCacheCoordinator::auth_scope(&auth_a);
        let scope_b = QueryCacheCoordinator::auth_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}