1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use forge_core::{
7 AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
8 MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
9 Result, WorkflowDispatch,
10 job::JobStatus,
11 rate_limit::{RateLimitConfig, RateLimitKey},
12 workflow::WorkflowStatus,
13};
14use serde_json::Value;
15use tracing::Instrument;
16
17use super::cache::QueryCache;
18use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
19use crate::db::Database;
20use crate::rate_limit::HybridRateLimiter;
21
22pub enum RouteResult {
24 Query(Value),
26 Mutation(Value),
28 Job(Value),
30 Workflow(Value),
32}
33
/// Routes function calls by name to registered queries/mutations, or — when the
/// corresponding dispatchers are configured — to jobs and workflows.
pub struct FunctionRouter {
    /// Registry of queries and mutations, consulted first by `route`.
    registry: Arc<FunctionRegistry>,
    /// Database handle: `primary()` serves mutations and `consistent` queries,
    /// `read_pool()` serves all other queries.
    db: Database,
    /// HTTP client handed to every `MutationContext`.
    http_client: CircuitBreakerClient,
    /// Optional job dispatcher; when `None`, job names are never routed.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher; also supplies workflow versions for
    /// workflows buffered by transactional mutations.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Rate limiter built on the primary pool; applied to registry functions only.
    rate_limiter: HybridRateLimiter,
    /// Result cache for queries that declare a `cache_ttl`.
    query_cache: QueryCache,
    /// Optional token issuer installed on every `MutationContext` when set.
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
}
45
46impl FunctionRouter {
47 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
49 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
50 Self {
51 registry,
52 db,
53 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
54 job_dispatcher: None,
55 workflow_dispatcher: None,
56 rate_limiter,
57 query_cache: QueryCache::new(),
58 token_issuer: None,
59 }
60 }
61
62 pub fn with_http_client(
64 registry: Arc<FunctionRegistry>,
65 db: Database,
66 http_client: CircuitBreakerClient,
67 ) -> Self {
68 let rate_limiter = HybridRateLimiter::new(db.primary().clone());
69 Self {
70 registry,
71 db,
72 http_client,
73 job_dispatcher: None,
74 workflow_dispatcher: None,
75 rate_limiter,
76 query_cache: QueryCache::new(),
77 token_issuer: None,
78 }
79 }
80
81 pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
83 self.token_issuer = Some(issuer);
84 self
85 }
86
87 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
89 self.job_dispatcher = Some(dispatcher);
90 self
91 }
92
93 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
95 self.workflow_dispatcher = Some(dispatcher);
96 self
97 }
98
99 pub async fn route(
100 &self,
101 function_name: &str,
102 args: Value,
103 auth: AuthContext,
104 request: RequestMetadata,
105 ) -> Result<RouteResult> {
106 if let Some(entry) = self.registry.get(function_name) {
107 self.check_auth(entry.info(), &auth)?;
108 self.check_rate_limit(entry.info(), function_name, &auth, &request)
109 .await?;
110 let enforce = !entry.info().is_public && entry.info().has_input_args;
113 Self::check_identity_args(function_name, &args, &auth, enforce)?;
114
115 return match entry {
116 FunctionEntry::Query { handler, info, .. } => {
117 let pool = if info.consistent {
118 self.db.primary().clone()
119 } else {
120 self.db.read_pool().clone()
121 };
122
123 let auth_scope = Self::auth_cache_scope(&auth);
124 if let Some(ttl) = info.cache_ttl {
125 if let Some(cached) =
126 self.query_cache
127 .get(function_name, &args, auth_scope.as_deref())
128 {
129 return Ok(RouteResult::Query(Value::clone(&cached)));
130 }
131
132 let ctx = QueryContext::new(pool, auth, request);
133 let result = handler(&ctx, args.clone()).await?;
134
135 self.query_cache.set(
136 function_name,
137 &args,
138 auth_scope.as_deref(),
139 result.clone(),
140 Duration::from_secs(ttl),
141 );
142
143 Ok(RouteResult::Query(result))
144 } else {
145 let ctx = QueryContext::new(pool, auth, request);
146 let result = handler(&ctx, args).await?;
147 Ok(RouteResult::Query(result))
148 }
149 }
150 FunctionEntry::Mutation { handler, info } => {
151 if info.transactional {
152 self.execute_transactional(handler, args, auth, request)
153 .await
154 } else {
155 let mut ctx = MutationContext::with_dispatch(
157 self.db.primary().clone(),
158 auth,
159 request,
160 self.http_client.clone(),
161 self.job_dispatcher.clone(),
162 self.workflow_dispatcher.clone(),
163 );
164 if let Some(ref issuer) = self.token_issuer {
165 ctx.set_token_issuer(issuer.clone());
166 }
167 let result = handler(&ctx, args).await?;
168 Ok(RouteResult::Mutation(result))
169 }
170 }
171 };
172 }
173
174 if let Some(ref job_dispatcher) = self.job_dispatcher
175 && let Some(job_info) = job_dispatcher.get_info(function_name)
176 {
177 self.check_job_auth(&job_info, &auth)?;
178 Self::check_identity_args(function_name, &args, &auth, !job_info.is_public)?;
179 match job_dispatcher
180 .dispatch_by_name(function_name, args.clone(), auth.principal_id())
181 .await
182 {
183 Ok(job_id) => {
184 return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
185 }
186 Err(ForgeError::NotFound(_)) => {}
187 Err(e) => return Err(e),
188 }
189 }
190
191 if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
192 && let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
193 {
194 self.check_workflow_auth(&workflow_info, &auth)?;
195 Self::check_identity_args(function_name, &args, &auth, !workflow_info.is_public)?;
196 match workflow_dispatcher
197 .start_by_name(function_name, args.clone(), auth.principal_id())
198 .await
199 {
200 Ok(workflow_id) => {
201 return Ok(RouteResult::Workflow(
202 serde_json::json!({ "workflow_id": workflow_id }),
203 ));
204 }
205 Err(ForgeError::NotFound(_)) => {}
206 Err(e) => return Err(e),
207 }
208 }
209
210 Err(ForgeError::NotFound(format!(
211 "Function '{}' not found",
212 function_name
213 )))
214 }
215
216 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
217 if info.is_public {
218 return Ok(());
219 }
220
221 if !auth.is_authenticated() {
222 return Err(ForgeError::Unauthorized("Authentication required".into()));
223 }
224
225 if let Some(role) = info.required_role
226 && !auth.has_role(role)
227 {
228 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
229 }
230
231 Ok(())
232 }
233
234 fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
235 if info.is_public {
236 return Ok(());
237 }
238
239 if !auth.is_authenticated() {
240 return Err(ForgeError::Unauthorized("Authentication required".into()));
241 }
242
243 if let Some(role) = info.required_role
244 && !auth.has_role(role)
245 {
246 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
247 }
248
249 Ok(())
250 }
251
252 fn check_workflow_auth(
253 &self,
254 info: &forge_core::workflow::WorkflowInfo,
255 auth: &AuthContext,
256 ) -> Result<()> {
257 if info.is_public {
258 return Ok(());
259 }
260
261 if !auth.is_authenticated() {
262 return Err(ForgeError::Unauthorized("Authentication required".into()));
263 }
264
265 if let Some(role) = info.required_role
266 && !auth.has_role(role)
267 {
268 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
269 }
270
271 Ok(())
272 }
273
274 async fn check_rate_limit(
276 &self,
277 info: &FunctionInfo,
278 function_name: &str,
279 auth: &AuthContext,
280 request: &RequestMetadata,
281 ) -> Result<()> {
282 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
284 (Some(r), Some(p)) => (r, p),
285 _ => return Ok(()),
286 };
287
288 let key_str = info.rate_limit_key.unwrap_or("user");
290 let key_type: RateLimitKey = match key_str.parse() {
291 Ok(k) => k,
292 Err(_) => {
293 tracing::warn!(
294 function = %function_name,
295 key = %key_str,
296 "Invalid rate limit key, falling back to 'user'"
297 );
298 RateLimitKey::default()
299 }
300 };
301
302 let config =
303 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
304
305 let bucket_key = self
307 .rate_limiter
308 .build_key(key_type, function_name, auth, request);
309
310 self.rate_limiter.enforce(&bucket_key, &config).await?;
312
313 Ok(())
314 }
315
316 fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
317 if !auth.is_authenticated() {
318 return Some("anon".to_string());
319 }
320
321 let mut roles = auth.roles().to_vec();
323 roles.sort();
324 roles.dedup();
325
326 let mut claims = BTreeMap::new();
327 for (k, v) in auth.claims() {
328 claims.insert(k.clone(), v.clone());
329 }
330
331 use std::collections::hash_map::DefaultHasher;
332 use std::hash::{Hash, Hasher};
333
334 let mut hasher = DefaultHasher::new();
335 roles.hash(&mut hasher);
336 serde_json::to_string(&claims)
337 .unwrap_or_default()
338 .hash(&mut hasher);
339
340 let principal = auth
341 .principal_id()
342 .unwrap_or_else(|| "authenticated".to_string());
343
344 Some(format!(
345 "subject:{principal}:scope:{:016x}",
346 hasher.finish()
347 ))
348 }
349
350 fn check_identity_args(
351 function_name: &str,
352 args: &Value,
353 auth: &AuthContext,
354 enforce_scope: bool,
355 ) -> Result<()> {
356 if auth.is_admin() {
357 return Ok(());
358 }
359
360 let Some(obj) = args.as_object() else {
361 if enforce_scope && auth.is_authenticated() {
362 return Err(ForgeError::Forbidden(format!(
363 "Function '{function_name}' must include identity or tenant scope arguments"
364 )));
365 }
366 return Ok(());
367 };
368
369 let mut principal_values: Vec<String> = Vec::new();
370 if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
371 principal_values.push(user_id);
372 }
373 if let Some(subject) = auth.principal_id()
374 && !principal_values.iter().any(|v| v == &subject)
375 {
376 principal_values.push(subject);
377 }
378
379 let mut has_scope_key = false;
380
381 for key in [
382 "user_id",
383 "userId",
384 "owner_id",
385 "ownerId",
386 "owner_subject",
387 "ownerSubject",
388 "subject",
389 "sub",
390 "principal_id",
391 "principalId",
392 ] {
393 let Some(value) = obj.get(key) else {
394 continue;
395 };
396 has_scope_key = true;
397
398 if !auth.is_authenticated() {
399 return Err(ForgeError::Unauthorized(format!(
400 "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
401 )));
402 }
403
404 let Value::String(actual) = value else {
405 return Err(ForgeError::InvalidArgument(format!(
406 "Function '{function_name}' argument '{key}' must be a non-empty string"
407 )));
408 };
409
410 if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
411 return Err(ForgeError::Forbidden(format!(
412 "Function '{function_name}' argument '{key}' does not match authenticated principal"
413 )));
414 }
415 }
416
417 for key in ["tenant_id", "tenantId"] {
418 let Some(value) = obj.get(key) else {
419 continue;
420 };
421 has_scope_key = true;
422
423 if !auth.is_authenticated() {
424 return Err(ForgeError::Unauthorized(format!(
425 "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
426 )));
427 }
428
429 let expected = auth
430 .claim("tenant_id")
431 .and_then(|v| v.as_str())
432 .ok_or_else(|| {
433 ForgeError::Forbidden(format!(
434 "Function '{function_name}' argument '{key}' is not allowed for this principal"
435 ))
436 })?;
437
438 let Value::String(actual) = value else {
439 return Err(ForgeError::InvalidArgument(format!(
440 "Function '{function_name}' argument '{key}' must be a non-empty string"
441 )));
442 };
443
444 if actual.trim().is_empty() || actual != expected {
445 return Err(ForgeError::Forbidden(format!(
446 "Function '{function_name}' argument '{key}' does not match authenticated tenant"
447 )));
448 }
449 }
450
451 if enforce_scope && auth.is_authenticated() && !has_scope_key {
452 return Err(ForgeError::Forbidden(format!(
453 "Function '{function_name}' must include identity or tenant scope arguments"
454 )));
455 }
456
457 Ok(())
458 }
459
460 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
462 self.registry.get(function_name).map(|e| e.kind())
463 }
464
465 pub fn has_function(&self, function_name: &str) -> bool {
467 self.registry.get(function_name).is_some()
468 }
469
470 async fn execute_transactional(
471 &self,
472 handler: &BoxedMutationFn,
473 args: Value,
474 auth: AuthContext,
475 request: RequestMetadata,
476 ) -> Result<RouteResult> {
477 let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
478
479 async {
480 let primary = self.db.primary();
481 let tx = primary
482 .begin()
483 .await
484 .map_err(|e| ForgeError::Database(e.to_string()))?;
485
486 let job_dispatcher = self.job_dispatcher.clone();
487 let job_lookup: forge_core::JobInfoLookup =
488 Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
489
490 let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
491 primary.clone(),
492 tx,
493 auth,
494 request,
495 self.http_client.clone(),
496 job_lookup,
497 );
498 if let Some(ref issuer) = self.token_issuer {
499 ctx.set_token_issuer(issuer.clone());
500 }
501
502 match handler(&ctx, args).await {
503 Ok(value) => {
504 let buffer = {
505 let guard = outbox.lock().expect("outbox mutex poisoned");
506 OutboxBuffer {
507 jobs: guard.jobs.clone(),
508 workflows: guard.workflows.clone(),
509 }
510 };
511
512 let mut tx = Arc::try_unwrap(tx_handle)
513 .map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
514 .into_inner();
515
516 for job in &buffer.jobs {
517 Self::insert_job(&mut tx, job).await?;
518 }
519
520 for workflow in &buffer.workflows {
521 let version = self
522 .workflow_dispatcher
523 .as_ref()
524 .and_then(|d| d.get_info(&workflow.workflow_name))
525 .map(|info| info.version)
526 .ok_or_else(|| {
527 ForgeError::NotFound(format!(
528 "Workflow '{}' not found",
529 workflow.workflow_name
530 ))
531 })?;
532 Self::insert_workflow(&mut tx, workflow, version).await?;
533 }
534
535 tx.commit()
536 .await
537 .map_err(|e| ForgeError::Database(e.to_string()))?;
538
539 Ok(RouteResult::Mutation(value))
540 }
541 Err(e) => Err(e),
542 }
543 }
544 .instrument(span)
545 .await
546 }
547
548 async fn insert_job(
549 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
550 job: &PendingJob,
551 ) -> Result<()> {
552 let now = Utc::now();
553 sqlx::query(
554 r#"
555 INSERT INTO forge_jobs (
556 id, job_type, input, job_context, status, priority, attempts, max_attempts,
557 worker_capability, owner_subject, scheduled_at, created_at
558 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
559 "#,
560 )
561 .bind(job.id)
562 .bind(&job.job_type)
563 .bind(&job.args)
564 .bind(&job.context)
565 .bind(JobStatus::Pending.as_str())
566 .bind(job.priority)
567 .bind(0i32)
568 .bind(job.max_attempts)
569 .bind(&job.worker_capability)
570 .bind(&job.owner_subject)
571 .bind(now)
572 .bind(now)
573 .execute(&mut **tx)
574 .await
575 .map_err(|e| ForgeError::Database(e.to_string()))?;
576
577 Ok(())
578 }
579
580 async fn insert_workflow(
581 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
582 workflow: &PendingWorkflow,
583 version: u32,
584 ) -> Result<()> {
585 let now = Utc::now();
586 sqlx::query(
587 r#"
588 INSERT INTO forge_workflow_runs (
589 id, workflow_name, version, owner_subject, input, status, current_step,
590 step_results, started_at, trace_id
591 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
592 "#,
593 )
594 .bind(workflow.id)
595 .bind(&workflow.workflow_name)
596 .bind(version as i32)
597 .bind(&workflow.owner_subject)
598 .bind(&workflow.input)
599 .bind(WorkflowStatus::Created.as_str())
600 .bind(Option::<String>::None)
601 .bind(serde_json::json!({}))
602 .bind(now)
603 .bind(workflow.id.to_string())
604 .execute(&mut **tx)
605 .await
606 .map_err(|e| ForgeError::Database(e.to_string()))?;
607
608 Ok(())
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an authenticated context with the `user` role, a `sub` claim equal
    /// to `user_id`, and any extra string claims.
    fn user_auth(user_id: uuid::Uuid, extra_claims: &[(&str, &str)]) -> AuthContext {
        let mut claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String(user_id.to_string()),
        )]);
        for &(name, value) in extra_claims {
            claims.insert(
                name.to_string(),
                serde_json::Value::String(value.to_string()),
            );
        }
        AuthContext::authenticated(user_id, vec!["user".to_string()], claims)
    }

    #[test]
    fn test_check_auth_public() {
        // NOTE(review): this only asserts the fixture's own flag. `check_auth`
        // needs a router instance (registry + database), so it is not exercised here.
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            has_input_args: false,
        };

        let _auth = AuthContext::unauthenticated();

        assert!(info.is_public);
    }

    #[test]
    fn test_identity_args_reject_cross_user_value() {
        let user_id = uuid::Uuid::new_v4();
        let auth = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_allow_matching_subject() {
        let sub = "user_123";
        let auth = AuthContext::authenticated_without_uuid(
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(sub.to_string()),
            )]),
        );
        let args = serde_json::json!({
            "subject": sub
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_require_auth_for_identity_keys() {
        let auth = AuthContext::unauthenticated();
        let args = serde_json::json!({
            "user_id": uuid::Uuid::new_v4().to_string()
        });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Unauthorized(_))));
    }

    #[test]
    fn test_identity_args_require_scope_for_non_public_calls() {
        let user_id = uuid::Uuid::new_v4();
        let auth = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result =
            FunctionRouter::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_identity_args_skip_scope_for_no_input_functions() {
        let user_id = uuid::Uuid::new_v4();
        let auth = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result = FunctionRouter::check_identity_args(
            "list_todos",
            &serde_json::Value::Null,
            &auth,
            false,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_identity_args_reject_non_string_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({ "user_id": 42 });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::InvalidArgument(_))));
    }

    #[test]
    fn test_tenant_args_reject_cross_tenant_value() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[("tenant_id", "tenant-a")]);
        let args = serde_json::json!({ "tenant_id": "tenant-b" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_tenant_args_allow_matching_tenant() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[("tenant_id", "tenant-a")]);
        let args = serde_json::json!({ "tenantId": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_tenant_args_reject_principal_without_tenant_claim() {
        let auth = user_auth(uuid::Uuid::new_v4(), &[]);
        let args = serde_json::json!({ "tenant_id": "tenant-a" });

        let result = FunctionRouter::check_identity_args("list_orders", &args, &auth, true);
        assert!(matches!(result, Err(ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_auth_cache_scope_anonymous() {
        let scope = FunctionRouter::auth_cache_scope(&AuthContext::unauthenticated());
        assert_eq!(scope, Some("anon".to_string()));
    }

    #[test]
    fn test_auth_cache_scope_stable_for_same_identity() {
        let user_id = uuid::Uuid::new_v4();
        let first = user_auth(user_id, &[("tenant_id", "tenant-a"), ("plan", "pro")]);
        let second = user_auth(user_id, &[("plan", "pro"), ("tenant_id", "tenant-a")]);

        assert_eq!(
            FunctionRouter::auth_cache_scope(&first),
            FunctionRouter::auth_cache_scope(&second)
        );
    }

    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-a".to_string()),
                ),
            ]),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-b".to_string()),
                ),
            ]),
        );

        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}
768}