1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
6 KvHandle, MutationContext, QueryContext, RequestMetadata, Result, SharedRoleResolver,
7 WorkflowDispatch, default_role_resolver,
8 rate_limit::{RateLimitConfig, RateLimiterBackend},
9};
10use serde_json::Value;
11use tokio::time::timeout;
12use tracing::Instrument;
13
14use super::cache::QueryCacheCoordinator;
15use super::execution_log::{level_for as log_level_for, log_completion};
16use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
17#[cfg(feature = "gateway")]
18use super::rpc_signals::{RpcSignalContext, RpcSignalsEmitter};
19use crate::pg::Database;
20use crate::rate_limit::HybridRateLimiter;
21#[cfg(feature = "gateway")]
22use crate::signals::SignalsCollector;
23
24fn require_auth(
30 is_public: bool,
31 required_role: Option<&str>,
32 auth: &AuthContext,
33 role_resolver: &SharedRoleResolver,
34) -> Result<()> {
35 if is_public {
36 return Ok(());
37 }
38 if !auth.is_authenticated() {
39 return Err(ForgeError::Unauthorized("Authentication required".into()));
40 }
41 if let Some(role) = required_role {
42 let effective_roles = role_resolver.resolve(auth);
43 if !effective_roles.iter().any(|r| r == role) {
44 return Err(ForgeError::Forbidden(format!("Role '{role}' required")));
45 }
46 }
47 Ok(())
48}
49
50pub enum RouteResult {
52 Query(Arc<Value>),
54 Mutation(Value),
56 Job(Value),
58 Workflow(Value),
60}
61
62pub struct RouteOutcome {
66 pub result: RouteResult,
67 pub cache_hit: bool,
68}
69
70#[derive(Clone)]
72struct MutationDeps {
73 http_client: CircuitBreakerClient,
74 job_dispatcher: Option<Arc<dyn JobDispatch>>,
75 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
76 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
77 token_ttl: forge_core::AuthTokenTtl,
78 max_jobs_per_request: usize,
79 kv: Option<Arc<dyn KvHandle>>,
80}
81
82pub struct FunctionRouter {
84 registry: Arc<FunctionRegistry>,
85 db: Database,
86 mutation_deps: Arc<MutationDeps>,
87 rate_limiter: Arc<dyn RateLimiterBackend>,
88 role_resolver: SharedRoleResolver,
89 cache: Arc<QueryCacheCoordinator>,
90 default_timeout: Duration,
91 max_result_size_bytes: usize,
93 #[cfg(feature = "gateway")]
94 signals: Option<RpcSignalsEmitter>,
95}
96
97impl FunctionRouter {
98 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
100 Self::with_http_client(registry, db, CircuitBreakerClient::with_ssrf_protection())
101 }
102
103 pub fn with_http_client(
105 registry: Arc<FunctionRegistry>,
106 db: Database,
107 http_client: CircuitBreakerClient,
108 ) -> Self {
109 let rate_limiter: Arc<dyn RateLimiterBackend> =
110 Arc::new(HybridRateLimiter::new(db.primary().clone()));
111 let cache = Arc::new(QueryCacheCoordinator::new(®istry));
112 Self {
113 registry,
114 db,
115 mutation_deps: Arc::new(MutationDeps {
116 http_client,
117 job_dispatcher: None,
118 workflow_dispatcher: None,
119 token_issuer: None,
120 token_ttl: forge_core::AuthTokenTtl::default(),
121 max_jobs_per_request: 0,
122 kv: None,
123 }),
124 rate_limiter,
125 role_resolver: default_role_resolver(),
126 cache,
127 default_timeout: Duration::from_secs(30),
128 max_result_size_bytes: 0,
129 #[cfg(feature = "gateway")]
130 signals: None,
131 }
132 }
133
134 pub fn with_dispatch(
136 registry: Arc<FunctionRegistry>,
137 db: Database,
138 job_dispatcher: Option<Arc<dyn JobDispatch>>,
139 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
140 ) -> Self {
141 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
142 }
143
144 pub fn with_dispatch_and_issuer(
146 registry: Arc<FunctionRegistry>,
147 db: Database,
148 job_dispatcher: Option<Arc<dyn JobDispatch>>,
149 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
150 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
151 ) -> Self {
152 let mut router = Self::new(Arc::clone(®istry), db);
153 if let Some(jd) = job_dispatcher {
154 router = router.with_job_dispatcher(jd);
155 }
156 if let Some(wd) = workflow_dispatcher {
157 router = router.with_workflow_dispatcher(wd);
158 }
159 if let Some(issuer) = token_issuer {
160 router = router.with_token_issuer(issuer);
161 }
162 router
163 }
164
165 pub fn with_role_resolver(mut self, resolver: SharedRoleResolver) -> Self {
167 self.role_resolver = resolver;
168 self
169 }
170
171 pub fn set_role_resolver(&mut self, resolver: SharedRoleResolver) {
173 self.role_resolver = resolver;
174 }
175
176 pub fn with_rate_limiter(mut self, rate_limiter: Arc<dyn RateLimiterBackend>) -> Self {
179 self.rate_limiter = rate_limiter;
180 self
181 }
182
183 pub fn set_rate_limiter(&mut self, rate_limiter: Arc<dyn RateLimiterBackend>) {
185 self.rate_limiter = rate_limiter;
186 }
187
188 fn deps_mut(&mut self) -> &mut MutationDeps {
190 Arc::make_mut(&mut self.mutation_deps)
191 }
192
193 pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
195 self.deps_mut().token_issuer = Some(issuer);
196 self
197 }
198
199 pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
201 self.deps_mut().token_ttl = ttl;
202 self
203 }
204
205 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
207 self.deps_mut().token_ttl = ttl;
208 }
209
210 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
212 self.deps_mut().job_dispatcher = Some(dispatcher);
213 self
214 }
215
216 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
218 self.deps_mut().workflow_dispatcher = Some(dispatcher);
219 self
220 }
221
222 pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
224 self.deps_mut().kv = Some(kv);
225 self
226 }
227
228 pub fn set_kv(&mut self, kv: Arc<dyn KvHandle>) {
230 self.deps_mut().kv = Some(kv);
231 }
232
233 pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
235 self.default_timeout = timeout;
236 self
237 }
238
239 pub fn set_max_jobs_per_request(&mut self, limit: usize) {
242 self.deps_mut().max_jobs_per_request = limit;
243 }
244
245 pub fn set_max_result_size_bytes(&mut self, limit: usize) {
248 self.max_result_size_bytes = limit;
249 }
250
251 #[cfg(feature = "gateway")]
253 pub fn set_signals_collector(&mut self, collector: SignalsCollector, server_secret: String) {
254 self.signals = Some(RpcSignalsEmitter::new(collector, server_secret));
255 }
256
257 pub async fn execute(
259 &self,
260 function_name: &str,
261 args: Value,
262 auth: AuthContext,
263 request: RequestMetadata,
264 ) -> Result<Value> {
265 let start = std::time::Instant::now();
266 let info = self.registry.get(function_name).map(|e| e.info());
267 let fn_timeout = info.and_then(|i| i.timeout).unwrap_or(self.default_timeout);
268 let log_level = log_level_for(info);
269
270 let kind = info.map(|i| i.kind.as_str()).unwrap_or("unknown");
271
272 #[cfg(feature = "gateway")]
274 let mut signal_ctx = self
275 .signals
276 .as_ref()
277 .map(|_| RpcSignalContext::capture(&auth, &request));
278
279 let span = tracing::info_span!(
284 "fn.execute",
285 function = function_name,
286 fn.kind = %kind,
287 cache.hit = tracing::field::Empty,
288 );
289
290 let result = match timeout(
291 fn_timeout,
292 self.route(function_name, args.clone(), auth, request)
293 .instrument(span),
294 )
295 .await
296 {
297 Ok(result) => result,
298 Err(_) => {
299 let duration = start.elapsed();
300 log_completion(
301 log_level,
302 function_name,
303 "unknown",
304 &args,
305 duration,
306 false,
307 Some(&format!("Timeout after {:?}", fn_timeout)),
308 );
309 crate::observability::record_fn_execution(
310 function_name,
311 kind,
312 false,
313 false,
314 duration.as_secs_f64(),
315 );
316 #[cfg(feature = "gateway")]
317 if let (Some(emitter), Some(ctx)) = (&self.signals, signal_ctx.take()) {
318 emitter.emit(function_name, kind, duration, false, ctx);
319 }
320 return Err(ForgeError::Timeout(format!(
321 "Function '{}' timed out after {:?}",
322 function_name, fn_timeout
323 )));
324 }
325 };
326
327 let duration = start.elapsed();
328
329 match result {
330 Ok(outcome) => {
331 let RouteOutcome { result, cache_hit } = outcome;
332 let (result_kind, value) = match result {
333 RouteResult::Query(arc) => {
334 let v = Arc::try_unwrap(arc).unwrap_or_else(|a| Value::clone(&a));
335 ("query", v)
336 }
337 RouteResult::Mutation(v) => ("mutation", v),
338 RouteResult::Job(v) => ("job", v),
339 RouteResult::Workflow(v) => ("workflow", v),
340 };
341
342 log_completion(
343 log_level,
344 function_name,
345 result_kind,
346 &args,
347 duration,
348 true,
349 None,
350 );
351 crate::observability::record_fn_execution(
352 function_name,
353 result_kind,
354 true,
355 cache_hit,
356 duration.as_secs_f64(),
357 );
358 #[cfg(feature = "gateway")]
359 if let (Some(emitter), Some(ctx)) = (&self.signals, signal_ctx.take()) {
360 emitter.emit(function_name, result_kind, duration, true, ctx);
361 }
362
363 Ok(value)
364 }
365 Err(e) => {
366 log_completion(
367 log_level,
368 function_name,
369 kind,
370 &args,
371 duration,
372 false,
373 Some(&e.to_string()),
374 );
375 crate::observability::record_fn_execution(
376 function_name,
377 kind,
378 false,
379 false,
380 duration.as_secs_f64(),
381 );
382 #[cfg(feature = "gateway")]
383 if let (Some(emitter), Some(ctx)) = (&self.signals, signal_ctx.take()) {
384 emitter.emit(function_name, kind, duration, false, ctx);
385 }
386
387 Err(e)
388 }
389 }
390 }
391
392 pub fn function_info(&self, function_name: &str) -> Option<FunctionInfo> {
394 self.registry.get(function_name).map(|e| e.info().clone())
395 }
396
397 pub fn has_function(&self, function_name: &str) -> bool {
399 self.registry.get(function_name).is_some()
400 }
401
402 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
404 self.registry.get(function_name).map(|e| e.kind())
405 }
406
407 pub fn function_infos(&self) -> Vec<FunctionInfo> {
409 self.registry
410 .functions()
411 .map(|(_, entry)| entry.info().clone())
412 .collect()
413 }
414
415 pub fn cache(&self) -> Arc<QueryCacheCoordinator> {
417 Arc::clone(&self.cache)
418 }
419
420 fn check_result_size(&self, value: &Value) -> Result<()> {
424 if self.max_result_size_bytes == 0 {
425 return Ok(());
426 }
427 let serialized_len = json_byte_length(value);
428 if serialized_len > self.max_result_size_bytes {
429 return Err(ForgeError::internal(format!(
430 "Response size {} bytes exceeds max_result_size_bytes limit of {} bytes",
431 serialized_len, self.max_result_size_bytes
432 )));
433 }
434 Ok(())
435 }
436
437 pub async fn route(
438 &self,
439 function_name: &str,
440 args: Value,
441 auth: AuthContext,
442 request: RequestMetadata,
443 ) -> Result<RouteOutcome> {
444 if let Some(entry) = self.registry.get(function_name) {
445 let info = entry.info();
446 require_auth(
447 info.is_public,
448 info.required_role,
449 &auth,
450 &self.role_resolver,
451 )?;
452 if info.requires_tenant_scope && auth.tenant_id().is_none() {
453 return Err(ForgeError::Forbidden(
454 "this function requires a tenant scope but the auth context has no tenant_id \
455 claim"
456 .to_string(),
457 ));
458 }
459 self.check_rate_limit(info, function_name, &auth, &request)
460 .await?;
461
462 return match entry {
463 FunctionEntry::Webhook { info } => {
464 return Err(ForgeError::InvalidArgument(format!(
468 "Webhook '{}' cannot be called via RPC; use its dedicated HTTP endpoint",
469 info.name
470 )));
471 }
472 FunctionEntry::Query { handler, info, .. } => {
473 let pool = if info.consistent {
474 self.db.primary().clone()
475 } else {
476 self.db.read_pool().clone()
477 };
478
479 if !info.consistent
480 && let Some(ttl) = info.cache_ttl
481 {
482 let scope = QueryCacheCoordinator::auth_scope(&auth);
485 if let Some(cached) =
486 self.cache
487 .get_by_scope(function_name, &args, scope.as_deref())
488 {
489 tracing::Span::current().record("cache.hit", true);
490 crate::observability::record_fn_cache(function_name, true);
491 return Ok(RouteOutcome {
492 result: RouteResult::Query(cached),
493 cache_hit: true,
494 });
495 }
496 tracing::Span::current().record("cache.hit", false);
497 crate::observability::record_fn_cache(function_name, false);
498
499 let mut ctx = QueryContext::new(pool, auth, request);
500 if let Some(ref kv) = self.mutation_deps.kv {
501 ctx.set_kv(Arc::clone(kv));
502 }
503 let result = handler(&ctx, args.clone()).await?;
504 self.check_result_size(&result)?;
505
506 let arc = Arc::new(result);
507 self.cache.set_arc_by_scope(
508 function_name,
509 &args,
510 scope.as_deref(),
511 Arc::clone(&arc),
512 Duration::from_secs(ttl),
513 );
514
515 Ok(RouteOutcome {
516 result: RouteResult::Query(arc),
517 cache_hit: false,
518 })
519 } else {
520 let mut ctx = QueryContext::new(pool, auth, request);
521 if let Some(ref kv) = self.mutation_deps.kv {
522 ctx.set_kv(Arc::clone(kv));
523 }
524 let result = handler(&ctx, args).await?;
525 self.check_result_size(&result)?;
526 Ok(RouteOutcome {
527 result: RouteResult::Query(Arc::new(result)),
528 cache_hit: false,
529 })
530 }
531 }
532 FunctionEntry::Mutation { handler, info } => {
533 let result = if info.transactional {
534 self.execute_transactional(info, handler, args, auth, request)
535 .await
536 } else {
537 let deps = Arc::clone(&self.mutation_deps);
538 let mut ctx = MutationContext::with_dispatch(
539 self.db.primary().clone(),
540 auth,
541 request,
542 deps.http_client.clone(),
543 deps.job_dispatcher.clone(),
544 deps.workflow_dispatcher.clone(),
545 );
546 if let Some(ref issuer) = deps.token_issuer {
547 ctx.set_token_issuer(issuer.clone());
548 }
549 ctx.set_token_ttl(deps.token_ttl.clone());
550 ctx.set_http_timeout(info.http_timeout);
551 if deps.max_jobs_per_request > 0 {
552 ctx.set_max_jobs_per_request(deps.max_jobs_per_request);
553 }
554 if let Some(ref kv) = deps.kv {
555 ctx.set_kv(Arc::clone(kv));
556 }
557 let value = handler(&ctx, args).await?;
558 self.check_result_size(&value)?;
559 Ok(RouteResult::Mutation(value))
560 };
561 if result.is_ok() {
565 self.cache.invalidate_for_mutation(info);
566 }
567 result.map(|r| RouteOutcome {
568 result: r,
569 cache_hit: false,
570 })
571 }
572 };
573 }
574
575 if let Some(ref job_dispatcher) = self.mutation_deps.job_dispatcher
576 && let Some(job_info) = job_dispatcher.get_info(function_name)
577 {
578 require_auth(
579 job_info.is_public,
580 job_info.required_role,
581 &auth,
582 &self.role_resolver,
583 )?;
584 match job_dispatcher
585 .dispatch_by_name(
586 function_name,
587 args.clone(),
588 auth.principal_id(),
589 auth.tenant_id(),
590 )
591 .await
592 {
593 Ok(job_id) => {
594 return Ok(RouteOutcome {
595 result: RouteResult::Job(serde_json::json!({ "job_id": job_id })),
596 cache_hit: false,
597 });
598 }
599 Err(ForgeError::NotFound(_)) => {}
600 Err(e) => return Err(e),
601 }
602 }
603
604 if let Some(ref workflow_dispatcher) = self.mutation_deps.workflow_dispatcher
605 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
606 {
607 require_auth(
608 workflow_info.is_public,
609 workflow_info.required_role,
610 &auth,
611 &self.role_resolver,
612 )?;
613 match workflow_dispatcher
614 .start_by_name(
615 function_name,
616 args,
617 auth.principal_id(),
618 Some(request.trace_id().to_string()),
619 )
620 .await
621 {
622 Ok(workflow_id) => {
623 return Ok(RouteOutcome {
624 result: RouteResult::Workflow(
625 serde_json::json!({ "workflow_id": workflow_id }),
626 ),
627 cache_hit: false,
628 });
629 }
630 Err(ForgeError::NotFound(_)) => {}
631 Err(e) => return Err(e),
632 }
633 }
634
635 Err(ForgeError::NotFound(format!(
636 "Function '{}' not found",
637 function_name
638 )))
639 }
640
641 async fn check_rate_limit(
643 &self,
644 info: &FunctionInfo,
645 function_name: &str,
646 auth: &AuthContext,
647 request: &RequestMetadata,
648 ) -> Result<()> {
649 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
650 (Some(r), Some(p)) => (r, p),
651 _ => return Ok(()),
652 };
653
654 let key_type = info.rate_limit_key.clone().unwrap_or_default();
655
656 let config = RateLimitConfig::new(requests, Duration::from_secs(per_secs))
657 .with_key(key_type.clone());
658
659 let bucket_key = self
660 .rate_limiter
661 .build_key(key_type, function_name, auth, request);
662
663 self.rate_limiter.enforce(&bucket_key, &config).await?;
664
665 Ok(())
666 }
667
668 async fn execute_transactional(
669 &self,
670 info: &FunctionInfo,
671 handler: &BoxedMutationFn,
672 args: Value,
673 auth: AuthContext,
674 request: RequestMetadata,
675 ) -> Result<RouteResult> {
676 let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
677 let fn_timeout = info.timeout.unwrap_or(self.default_timeout);
678
679 async {
680 let primary = self.db.primary();
681 let mut tx = primary.begin().await.map_err(ForgeError::Database)?;
682
683 let timeout_ms = fn_timeout.as_millis().min(i64::MAX as u128) as i64;
692 #[allow(clippy::disallowed_methods)]
693 sqlx::query(&format!("SET LOCAL statement_timeout = {timeout_ms}"))
694 .execute(&mut *tx)
695 .await
696 .map_err(ForgeError::Database)?;
697
698 let deps = Arc::clone(&self.mutation_deps);
699 let (mut ctx, tx_handle) = MutationContext::with_transaction(
700 primary.clone(),
701 tx,
702 auth,
703 request,
704 deps.http_client.clone(),
705 deps.job_dispatcher.clone(),
706 deps.workflow_dispatcher.clone(),
707 );
708 if let Some(ref issuer) = deps.token_issuer {
709 ctx.set_token_issuer(issuer.clone());
710 }
711 ctx.set_token_ttl(deps.token_ttl.clone());
712 ctx.set_http_timeout(info.http_timeout);
713 if deps.max_jobs_per_request > 0 {
714 ctx.set_max_jobs_per_request(deps.max_jobs_per_request);
715 }
716 if let Some(ref kv) = deps.kv {
717 ctx.set_kv(Arc::clone(kv));
718 }
719
720 let result = handler(&ctx, args).await;
721 drop(ctx);
722
723 let tx = tx_handle
730 .lock()
731 .await
732 .take()
733 .ok_or_else(|| ForgeError::internal("Transaction already taken from handle"))?;
734
735 match result {
736 Ok(value) => {
737 self.check_result_size(&value)?;
738 tx.commit().await.map_err(ForgeError::Database)?;
739 Ok(RouteResult::Mutation(value))
740 }
741 Err(e) => {
742 if let Err(rollback_err) = tx.rollback().await {
743 tracing::error!(
744 handler_error = %e,
745 rollback_error = %rollback_err,
746 "Mutation rollback failed; transaction will be released by Drop"
747 );
748 } else {
749 tracing::warn!(
750 handler_error = %e,
751 "Mutation rolled back"
752 );
753 }
754 Err(e)
755 }
756 }
757 }
758 .instrument(span)
759 .await
760 }
761}
762
763fn json_byte_length(value: &Value) -> usize {
767 struct Counter(usize);
768 impl std::io::Write for Counter {
769 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
770 self.0 += buf.len();
771 Ok(buf.len())
772 }
773 fn flush(&mut self) -> std::io::Result<()> {
774 Ok(())
775 }
776 }
777 let mut counter = Counter(0);
778 if serde_json::to_writer(&mut counter, value).is_ok() {
779 counter.0
780 } else {
781 usize::MAX
782 }
783}
784
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Role resolver used by most tests: the default one.
    fn resolver() -> SharedRoleResolver {
        default_role_resolver()
    }

    /// Builds an authenticated context with a fresh user id, the given roles and no claims.
    fn authed_as(roles: &[&str]) -> AuthContext {
        let owned_roles = roles.iter().map(|role| (*role).to_owned()).collect();
        AuthContext::authenticated(uuid::Uuid::new_v4(), owned_roles, HashMap::new())
    }

    #[test]
    fn require_auth_allows_public_functions_for_anonymous_callers() {
        let anonymous = AuthContext::unauthenticated();
        let outcome = require_auth(true, None, &anonymous, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_allows_public_functions_even_with_required_role() {
        // `is_public` wins over the role requirement.
        let anonymous = AuthContext::unauthenticated();
        let outcome = require_auth(true, Some("admin"), &anonymous, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_rejects_anonymous_callers_with_unauthorized() {
        let anonymous = AuthContext::unauthenticated();
        let outcome = require_auth(false, None, &anonymous, &resolver());
        match outcome {
            Err(ForgeError::Unauthorized(_)) => {}
            other => panic!("expected Unauthorized, got {other:?}"),
        }
    }

    #[test]
    fn require_auth_accepts_authenticated_caller_without_role_requirement() {
        let caller = authed_as(&["user"]);
        let outcome = require_auth(false, None, &caller, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_accepts_caller_with_required_role() {
        let caller = authed_as(&["user", "admin"]);
        let outcome = require_auth(false, Some("admin"), &caller, &resolver());
        assert!(outcome.is_ok());
    }

    #[test]
    fn require_auth_rejects_caller_missing_required_role_with_forbidden() {
        let caller = authed_as(&["user"]);
        let outcome = require_auth(false, Some("admin"), &caller, &resolver());
        match outcome {
            Err(ForgeError::Forbidden(msg)) => assert!(msg.contains("admin")),
            other => panic!("expected Forbidden, got {other:?}"),
        }
    }

    #[test]
    fn require_auth_consults_custom_role_resolver() {
        /// Resolver that grants `admin` to every caller holding `user`.
        struct ExpandingResolver;

        impl forge_core::RoleResolver for ExpandingResolver {
            fn resolve(&self, auth: &AuthContext) -> Vec<String> {
                let mut roles: Vec<String> = auth.roles().to_vec();
                let is_user = roles.iter().any(|r| r == "user");
                if is_user {
                    roles.push("admin".to_string());
                }
                roles
            }
        }

        let caller = authed_as(&["user"]);
        let expanding: SharedRoleResolver = Arc::new(ExpandingResolver);
        let outcome = require_auth(false, Some("admin"), &caller, &expanding);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();

        // Same user and roles; only the `tenant_id` claim differs between contexts.
        let auth_for_tenant = |tenant: &str| {
            let claims = HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String(tenant.to_string()),
                ),
            ]);
            AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
        };
        let auth_a = auth_for_tenant("tenant-a");
        let auth_b = auth_for_tenant("tenant-b");

        let scope_a = QueryCacheCoordinator::auth_scope(&auth_a);
        let scope_b = QueryCacheCoordinator::auth_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}